In [None]:
# Weights & Biases "<entity>/<project>" path that main_fetch_values reads runs from.
WANDB_PROJECT: str = "temple/lung-registration"

In [None]:
import sys
import warnings

sys.path.append('../')

import wandb
import numpy as np
import pandas as pd

In [None]:
def to_string(x):
    """Join every leaf element of a (possibly nested) sequence into one "-"-separated string.

    Used to turn list-valued hyperparameters into a single readable cell value.

    Args:
        x: a scalar, a list/tuple, a numpy array, or any nesting of those. Nesting may
            be ragged (sub-lists of different lengths).

    Returns:
        ``str()`` of each leaf in row-major order, joined with "-". A scalar becomes its
        own string form, and an empty sequence becomes "".
    """
    def _leaves(item):
        # ndarray -> nested Python lists; a 0-d array becomes a plain scalar.
        if isinstance(item, np.ndarray):
            item = item.tolist()
        if isinstance(item, (list, tuple)):
            for sub in item:
                yield from _leaves(sub)
        else:
            yield item

    # str() each leaf: str.join rejects non-string items (numbers, NaN), which the
    # previous np.array(...).flatten() approach passed straight through.
    return "-".join(str(leaf) for leaf in _leaves(x))


def is_any_element_in_list(list1, list2):
    """Return True if some ``element`` of list1 appears in list2 as ``"test_<element>"``.

    Note the asymmetry: list1 holds bare names, while list2 is searched for the
    ``test_``-prefixed form. The search stops at the first match.
    """
    prefixed_names = (f"test_{name}" for name in list1)
    return any(candidate in list2 for candidate in prefixed_names)

In [None]:
# Substrings marking W&B summary keys that are not scalar results worth tabulating.
_EXCLUDED_SUMMARY_PATTERNS = ("weights", "gradients", "step", "_wandb", "_timestamp", "graph",
                              "val_loss_epoch", "train_loss_epoch")


def _should_include_key(key):
    """Return True when a summary key contains none of the excluded substrings."""
    return not any(pattern in key for pattern in _EXCLUDED_SUMMARY_PATTERNS)


def _run_to_record(run):
    """Flatten one W&B run into ``{"name": run.id, "m_<metric>": ..., "h_<hparam>": ...}``."""
    record = {"name": run.id}
    # .summary holds the output metric values; ._json_dict is used to omit large files.
    record.update({f"m_{k}": v for k, v in run.summary._json_dict.items() if _should_include_key(k)})
    # .config holds the hyperparameters; keys starting with "_" are special values and skipped.
    record.update({f"h_{k}": v for k, v in run.config.items() if not k.startswith("_")})
    return record


def _join_if_present(value):
    """Apply to_string to a config value, leaving missing cells (NaN/None) untouched."""
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return value
    return to_string(value)


def main_fetch_values(tags=None):
    """Fetch runs from WANDB_PROJECT and tabulate their metrics and hyperparameters.

    A run is kept when ``tags`` is None, when it carries a ``"test_<tag>"`` tag for
    some entry of ``tags``, or when it is tagged ``"identity"``.

    Args:
        tags: iterable of bare tag names, or None to keep every run.

    Returns:
        DataFrame with one column per run (labelled by the run's display name) and one
        row per field, with rows sorted alphabetically. The rows are:
          - ``name``: the run id.
          - ``m_*``: summary metrics.
          - ``h_*``: config values. ``h_criteria_warped`` / ``h_criteria_flow`` are
            joined into "-"-separated strings.
          - ``m_duration_hours``: ``_runtime`` converted to hours.
        If no run matches, an empty DataFrame is returned.

    Warns:
        UserWarning when two runs share a display name. The later run fetched replaces
        the earlier one.
    """
    api = wandb.Api()
    records = {}
    for run in api.runs(WANDB_PROJECT):
        wanted = tags is None or is_any_element_in_list(tags, run.tags) or "identity" in run.tags
        if not wanted:
            continue
        if run.name in records:
            # Display names are the table's key; make the overwrite visible instead of silent.
            warnings.warn(
                f"Duplicate run name {run.name!r}: run {run.id} replaces run {records[run.name]['name']}."
            )
        records[run.name] = _run_to_record(run)

    runs_df = pd.DataFrame(records).transpose()

    # List-valued criteria are joined for readability; runs lacking them keep NaN.
    for column in ("h_criteria_warped", "h_criteria_flow"):
        if column in runs_df.columns:
            runs_df[column] = runs_df[column].map(_join_if_present)

    if "m__runtime" in runs_df.columns:
        runs_df = runs_df.assign(m_duration_hours=runs_df["m__runtime"] / 3600).drop(columns=["m__runtime"])

    return runs_df.reindex(sorted(runs_df.columns), axis=1).transpose()

In [None]:
# Experiment groups to compare. Runs tagged "test_<group>" are fetched, plus every run
# tagged "identity" (main_fetch_values always includes those). Other groups previously
# compared here: criteria_warped, criteria_warped_mul, criteria_flow, registration_depth,
# registration_sampling, registration_target, registration_stride.
SELECTED_GROUPS = ["identity_loss", "temporal_dependence"]

# Baseline run; MSE and SSIM improvements are measured relative to it.
IDENTITY_RUN_NAME = "transmorph-identity"

# Weights of the composite score terms (they sum to 1.0).
SCORE_WEIGHTS = {"mse": 0.45, "ssim": 0.45, "neg_jac_det": 0.05, "duration": 0.05}

# Columns feeding the score. main_fetch_values is transposed twice, so these arrive as
# object dtype and are coerced to numbers before any arithmetic.
SCORE_INPUTS = ["m_mse_mean_epoch", "m_ssim_mean_epoch", "m_perc_neg_jac_det_mean_epoch", "m_duration_hours"]


def one_minus_max_normalised(values):
    """Return ``1 - values / values.max()``, so the largest value maps to 0 and 0 maps to 1.

    If the maximum is 0, the ratio would be 0/0 = NaN for every run and would wipe out
    the whole score. Assuming non-negative inputs, every present value is 0 in that case,
    so 1.0 is returned for them; missing entries stay NaN.
    """
    max_value = values.max()
    if max_value == 0:
        return values.where(values.isna(), 1.0)
    return 1.0 - values / max_value


runs = main_fetch_values(SELECTED_GROUPS).transpose()  # one row per run, indexed by run name
runs = runs.assign(**{column: pd.to_numeric(runs[column], errors="coerce") for column in SCORE_INPUTS})

assert IDENTITY_RUN_NAME in runs.index, f"baseline run {IDENTITY_RUN_NAME!r} was not fetched"
id_run = runs.loc[IDENTITY_RUN_NAME]

# Fraction of the identity's MSE that each run removes (1.0 = zero error).
mse_gain = (runs["m_mse_mean_epoch"] - id_run["m_mse_mean_epoch"]) / (0.0 - id_run["m_mse_mean_epoch"])
# Fraction of the identity's SSIM gap to 1.0 that each run closes (1.0 = perfect SSIM).
ssim_gain = (runs["m_ssim_mean_epoch"] - id_run["m_ssim_mean_epoch"]) / (1.0 - id_run["m_ssim_mean_epoch"])

runs["m_score"] = (
    SCORE_WEIGHTS["mse"] * mse_gain
    + SCORE_WEIGHTS["ssim"] * ssim_gain
    + SCORE_WEIGHTS["neg_jac_det"] * one_minus_max_normalised(runs["m_perc_neg_jac_det_mean_epoch"])
    + SCORE_WEIGHTS["duration"] * one_minus_max_normalised(runs["m_duration_hours"])
)
runs = runs.sort_values("m_score", ascending=False)
runs
