In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from XrayTo3DShape import MODEL_NAMES, filter_wandb_run, get_run_from_model_name

In [None]:
# Path templates for the per-run evaluation outputs. Both live under the same
# "<run_id>/evaluation" directory; `{run_id}` (and `{clinical_log_name}` for the
# clinical file) are filled in later with `str.format`.
evaluation_dir_template = "../../runs/2d-3d-benchmark/{run_id}/evaluation"
generalized_metrics_template = "/".join((evaluation_dir_template, "metric-log.csv"))
clinical_parameters_template = "/".join((evaluation_dir_template, "{clinical_log_name}"))

In [None]:
# Clinical-metric column names, one list per anatomy. The plotting loop iterates
# over these and uses each entry as a column name of the merged dataframe.

# Hip landmarks: every landmark has a left (_L) and right (_R) entry.
hip_COLUMNS = [
    f"{landmark}_{side}" for landmark in ("ASIS", "PT", "IS", "PSIS") for side in ("L", "R")
]

vertebra_COLUMNS = "spl spa avbh pvbh svbl ivbl vcl".split()

# Femur: three scalar columns followed by the x/y/z components of FNA and FDA.
femur_COLUMNS = ["FHR", "FHC", "NSA"] + [
    f"{vector}_{axis}" for vector in ("FNA", "FDA") for axis in "xyz"
]

# Per-anatomy lookup used by the plotting loop:
#   columns             -> clinical-metric columns to plot
#   clinical_log_name   -> file name of the clinical error CSV inside the run's
#                          evaluation directory
#   subject_id_post_fix -> string whose *length* decides how "subject-id" is
#                          sliced to obtain "id" (trailing characters removed
#                          for hip/vertebra; leading characters kept for femur)
anatomy_wise_details = {
    "hip": dict(
        columns=hip_COLUMNS,
        clinical_log_name="hip_landmark_error.csv",
        subject_id_post_fix="_hip_msk",
    ),
    "vertebra": dict(
        columns=vertebra_COLUMNS,
        clinical_log_name="vertebra_morphometry_error.csv",
        subject_id_post_fix="-seg-vert-msk",
    ),
    "femur": dict(
        columns=femur_COLUMNS,
        clinical_log_name="femur_morphometry_error.csv",
        # Despite the key name this is an example *prefix*; only its length is used.
        subject_id_post_fix="s0174_femur_righ",
    ),
}

In [None]:
# Clinical metrics keyed by anatomy. Only "hip" is populated here; the entries
# are the left/right variants of each hip landmark.
clinical_metrics_dict = {
    "hip": [
        "ASIS_L", "ASIS_R",
        "PT_L", "PT_R",
        "IS_L", "IS_R",
        "PSIS_L", "PSIS_R",
    ],
}

In [None]:
# Subjects whose clinical-metric value is at or above this quantile (computed on
# the reference model) are treated as outliers and not plotted.
OUTLIER_QUANTILE = 0.90

# For every anatomy: load each model's clinical + generalized metric CSVs, merge
# them on subject id, then save one "podium" plot per (clinical metric, subject)
# showing every model's DSC against its clinical-metric value.
for ANATOMY in [
    "vertebra",
    "femur",
    "hip",
]:
    anatomy_details = anatomy_wise_details[ANATOMY]
    runs = filter_wandb_run(anatomy=ANATOMY, tags=["model-compare", "dropout"])
    merged_csv_dict = {}  # model name -> merged dataframe
    model_run_dict = {}  # model name -> run id (used in the plot annotations)

    # --- load and merge the evaluation CSVs of every model -------------------
    for model in MODEL_NAMES:
        run = get_run_from_model_name(model, runs)
        print(run.id, model)

        try:
            clinical_csv = pd.read_csv(
                clinical_parameters_template.format(
                    run_id=run.id,
                    clinical_log_name=anatomy_details["clinical_log_name"],
                )
            )
            generalized_metric_csv = pd.read_csv(
                generalized_metrics_template.format(run_id=run.id)
            )
        except FileNotFoundError as e:
            # Best effort: a model without evaluation output is simply left out.
            print(e)
            continue

        # Derive "id" from "subject-id" so it can be joined with the clinical CSV.
        post_fix = anatomy_details["subject_id_post_fix"]
        if ANATOMY == "femur":
            # For femur the stored string is a prefix: keep the leading characters.
            post_fix_length = len(post_fix)
        else:
            # Otherwise strip the trailing post-fix.
            post_fix_length = -len(post_fix)
        generalized_metric_csv["id"] = generalized_metric_csv["subject-id"].str[:post_fix_length]

        merged_csv = pd.merge(clinical_csv, generalized_metric_csv, on="id", how="left")
        merged_csv_dict[model] = merged_csv
        model_run_dict[model] = run.id

    if not merged_csv_dict:
        # Nothing loaded for this anatomy. Previously this fell through and used
        # a stale `merged_csv` from an earlier anatomy (or raised NameError).
        print(f"no evaluation logs found for {ANATOMY}; skipping")
        continue

    # The last successfully loaded model of *this* anatomy defines the outlier
    # threshold and the subject list (explicit, instead of relying on the
    # variable that happened to leak out of the loop above).
    reference_csv = list(merged_csv_dict.values())[-1]

    for clinical_metric in anatomy_details["columns"]:
        # generate subject list and remove outliers
        threshold = reference_csv[clinical_metric].quantile(OUTLIER_QUANTILE)
        subjects_list = reference_csv.loc[reference_csv[clinical_metric] < threshold, "id"]

        # The outlier filter does not depend on the subject, so apply it once
        # per model here rather than once per (subject, model).
        filtered_by_model = {
            model: df[df[clinical_metric] < threshold]
            for model, df in merged_csv_dict.items()
            if clinical_metric in df.columns
        }

        fig_dir = Path(f"podium_plot/{ANATOMY}/{clinical_metric}")
        fig_dir.mkdir(exist_ok=True, parents=True)  # create required subdirs

        for subject in subjects_list:
            DSC = []
            CM = []
            fig, ax = plt.subplots()

            for model in MODEL_NAMES:
                try:
                    model_rows = filtered_by_model[model]
                    subject_rows = model_rows[model_rows["id"] == subject]
                    model_dsc = subject_rows["DSC"].values[0]
                    model_cm = subject_rows[clinical_metric].values[0]
                except (IndexError, KeyError) as e:
                    # Model not loaded, or subject filtered out for this model.
                    print(e)
                    continue

                # Append only once both values exist so DSC and CM stay aligned.
                DSC.append(model_dsc)
                CM.append(model_cm)
                ax.annotate(f"{model}({model_run_dict[model]})", (model_dsc, model_cm))

            # Connect the models in order of increasing DSC.
            ind = np.argsort(DSC)
            ax.plot(np.take(DSC, ind), np.take(CM, ind))
            ax.set_xlabel("DSC")
            ax.set_ylabel(clinical_metric)
            ax.set_title(f"{subject}")
            fig.tight_layout()
            fig.savefig(str(fig_dir / f"{subject}.png"))
            plt.close(fig)  # release the figure; many are created in this loop