In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import scienceplots  # noqa: F401 -- side-effect import: registers the "science" matplotlib styles used below
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from tueplots import fontsizes
from XrayTo3DShape import MODEL_NAMES, filter_wandb_run, get_run_from_model_name

# Apply the SciencePlots style only after every import has run; "no-latex"
# keeps text rendering independent of a local LaTeX installation.
plt.style.use(["science", "no-latex"])

In [None]:
# `fontsizes.neurips2022()` only *returns* an rcParams dict; without applying it
# explicitly the call has no effect on the figures below.
plt.rcParams.update(fontsizes.neurips2022())

In [None]:
# Landmark columns (read from hip_landmark_error.csv), left side before right
# for each landmark.
# NOTE(review): names look like pelvic landmarks (ASIS, PT, IS, PSIS) — confirm.
COLUMNS = [
    f"{landmark}_{side}"
    for landmark in ("ASIS", "PT", "IS", "PSIS")
    for side in ("L", "R")
]
# Segmentation-quality metric columns (presumably from metric-log.csv).
ROWS = ["DSC", "NSD", "ASD", "HD95"]

In [None]:
# Per-run evaluation outputs; fill in `{run_id}` via str.format(run_id=...).
EVALUATION_DIR_TEMPLATE = "../runs/2d-3d-benchmark/{run_id}/evaluation/"
generalized_metrics_template = EVALUATION_DIR_TEMPLATE + "metric-log.csv"
clinical_parameters_template = EVALUATION_DIR_TEMPLATE + "hip_landmark_error.csv"

In [None]:
# Select one trained model's run for the chosen anatomy.
# NOTE(review): presumably filter_wandb_run queries W&B runs by anatomy and tags — confirm.
ANATOMY = "hip"
RUN_TAGS = ["model-compare", "dropout"]
MODEL_INDEX = 1  # position in MODEL_NAMES of the architecture to analyse

runs = filter_wandb_run(anatomy=ANATOMY, tags=RUN_TAGS)
model = MODEL_NAMES[MODEL_INDEX]
run = get_run_from_model_name(model, runs)
print(run.id, model)

In [None]:
# Resolve both per-run CSV paths for the selected run, then load them.
clinical_csv_path = clinical_parameters_template.format(run_id=run.id)
metrics_csv_path = generalized_metrics_template.format(run_id=run.id)
clinical_csv = pd.read_csv(clinical_csv_path)
generalized_metric_csv = pd.read_csv(metrics_csv_path)

In [None]:
# Join key for the merge below: the first five characters of "subject-id".
# NOTE(review): assumes this prefix equals the clinical CSV's "id" values — confirm.
generalized_metric_csv["id"] = generalized_metric_csv["subject-id"].str.slice(stop=5)

In [None]:
# Left join: every clinical row is kept; subjects with no matching metric-log
# row end up with NaN in the metric columns.
merged_csv = clinical_csv.merge(generalized_metric_csv, how="left", on="id")

In [None]:
# One panel per landmark: linear fit of landmark error against DSC, titled with R^2.
subplot_sz = 5
rows = 1
cols = len(COLUMNS)
rw = "DSC"
outlier_quantile = 0.90  # landmark errors at/above this quantile are dropped as outliers
label_fontsize = 25

# squeeze=False keeps `ax` 2-D even if there is only a single landmark column.
fig, ax = plt.subplots(
    rows, cols, figsize=(cols * subplot_sz, rows * subplot_sz), squeeze=False
)
for clm_idx, clm in enumerate(COLUMNS):
    axis = ax[0, clm_idx]

    threshold = merged_csv[clm].quantile(outlier_quantile)
    # The left merge can leave subjects without metrics (NaN); LinearRegression rejects NaN.
    subset = merged_csv[merged_csv[clm] < threshold].dropna(subset=[rw])
    x = subset[rw]
    y = subset[clm]

    axis.set_xlabel(rw, fontsize=label_fontsize)
    axis.set_ylabel(clm, fontsize=label_fontsize)
    axis.xaxis.set_tick_params(labelsize=label_fontsize)
    axis.yaxis.set_tick_params(labelsize=label_fontsize)

    # A regression (and R^2) is meaningless with fewer than two points.
    if len(subset) < 2:
        axis.set_title("too few samples", fontsize=label_fontsize)
        continue

    features = x.to_numpy().reshape(-1, 1)
    regressor = LinearRegression().fit(features, y.to_numpy())
    r2 = r2_score(y, regressor.predict(features))

    # A straight line only needs its end points; avoids drawing overlapping
    # segments back and forth across unsorted x values.
    x_ends = [x.min(), x.max()]
    axis.plot(x_ends, regressor.predict([[value] for value in x_ends]))
    axis.scatter(x, y, s=subplot_sz * 5)
    axis.set_title(rf"$R^2={r2:.2f}$", fontsize=label_fontsize)
fig.tight_layout()
fig.savefig("hip_dice_clinical_relationship_dsc.pdf")

In [None]:
# Grid of panels: one row per metric (ROWS), one column per landmark (COLUMNS);
# each panel is a linear fit of landmark error against the metric, titled with R^2.
subplot_sz = 5
rows = len(ROWS)
cols = len(COLUMNS)
outlier_quantile = 0.90  # landmark errors at/above this quantile are dropped as outliers
label_fontsize = 20

# The outlier threshold depends only on the landmark column, so compute it once per column.
thresholds = {clm: merged_csv[clm].quantile(outlier_quantile) for clm in COLUMNS}

# squeeze=False keeps `ax` 2-D even if ROWS or COLUMNS has a single entry.
fig, ax = plt.subplots(
    rows, cols, figsize=(cols * subplot_sz, rows * subplot_sz), squeeze=False
)
for rw_idx, rw in enumerate(ROWS):
    for clm_idx, clm in enumerate(COLUMNS):
        axis = ax[rw_idx, clm_idx]

        # The left merge can leave subjects without metrics (NaN); LinearRegression rejects NaN.
        subset = merged_csv[merged_csv[clm] < thresholds[clm]].dropna(subset=[rw])
        x = subset[rw]
        y = subset[clm]

        # x varies by row (metric) and y by column (landmark), so label every panel;
        # tight_layout below keeps the labels from overlapping.
        axis.set_xlabel(rw, fontsize=label_fontsize)
        axis.set_ylabel(clm, fontsize=label_fontsize)

        # A regression (and R^2) is meaningless with fewer than two points.
        if len(subset) < 2:
            axis.set_title("too few samples", fontsize=label_fontsize)
            continue

        features = x.to_numpy().reshape(-1, 1)
        regressor = LinearRegression().fit(features, y.to_numpy())
        r2 = r2_score(y, regressor.predict(features))

        # A straight line only needs its end points.
        x_ends = [x.min(), x.max()]
        axis.plot(x_ends, regressor.predict([[value] for value in x_ends]))
        axis.scatter(x, y, s=subplot_sz * 5)
        axis.set_title(rf"$R^2={r2:.2f}$", fontsize=label_fontsize)
fig.tight_layout()
fig.savefig("hip_dice_clinical_relationship.pdf")

In [None]:
# Render both previews as rich tables; a bare tuple expression shows only plain-text reprs.
from IPython.display import display

display(clinical_csv.head(), generalized_metric_csv.head())