In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import config
import os

In [None]:
# Load every W&B export CSV from the raw-data directory and stack them
# into a single frame `df`.
raw_path = config.RAW_DATA_PATH

# First export, followed by the additional exports that get appended to it.
path = os.path.join(raw_path, 'wandb_export_2025-12-06T16_47_29.458+01_00.csv')
other_paths = ['wandb_export_2025-12-06T18_23_35.335+01_00.csv', 'wandb_export_2025-12-06T18_24_22.357+01_00.csv']

df = pd.read_csv(path)
print(df.shape)  # shape of the first export alone

# Read the remaining exports in list order, keeping the first frame in front.
dfs = [df] + [
    pd.read_csv(os.path.join(raw_path, other_path))
    for other_path in other_paths
]

# ignore_index gives a fresh 0..n-1 index; sort=False leaves the column
# order as encountered instead of sorting it alphabetically.
df = pd.concat(dfs, ignore_index=True, sort=False)
print(df.shape)  # shape after stacking all exports

In [None]:
# Keep only the rows that clear both lower bounds. The cut-offs are
# hand-picked values — NOTE(review): record why these were chosen.
MSE_MIN = 0.028
SILHOUETTE_MIN = 0.2

# One mask is enough: mse > 0.028 already implies the looser mse > 0.02
# bound, so a separate pre-filter on 0.02 would select nothing extra.
# Rows with NaN in either column compare False and are therefore dropped.
keep = (df['silhouette_score'] > SILHOUETTE_MIN) & (df['mse'] > MSE_MIN)
df = df[keep]

In [None]:
# Grid of scatter plots with a fitted regression line: each metric on the
# x-axis against the MSE column on the y-axis, one panel per metric.
mse_col = "mse"  # update if name differs

metrics = [
    "silhouette_score",
    "calinski_harabasz_score_adjusted",
    "noise_count",
    "auc_mid_kl",
    "auc_sev_kl",
    "spearman_correlation_klscore",
    "f1_score_mri_cart_yn",
    "f1_score_mri_osteo_yn"
]

# Clean plotting style
sns.set(style="whitegrid")

# Three panels per row; the row count is the ceiling of len(metrics) / n_cols.
n_cols = 3
n_rows = (len(metrics) + n_cols - 1) // n_cols

fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 10))
axes = axes.flatten()

# zip stops at the shorter sequence, so only the first len(metrics) panels
# are drawn on.
for ax, metric in zip(axes, metrics):
    sns.regplot(
        data=df,
        x=metric,
        y=mse_col,
        ax=ax,
        scatter_kws={"alpha": 0.6},
        line_kws={"linewidth": 2},
    )
    ax.set_title(f"MSE vs {metric}", fontsize=12)
    ax.set_xlabel(metric)
    ax.set_ylabel("MSE")

# Hide the leftover panels when the grid has more cells than metrics.
for spare_ax in axes[len(metrics):]:
    spare_ax.set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# One separate joint plot per metric (metric on x, "mse" on y).
# Options shared by every figure are gathered once up front.
joint_opts = dict(
    data=df,
    y="mse",
    kind="reg",      # or "hex", or "kde"
    height=5,
)

for metric_col in metrics:
    sns.jointplot(x=metric_col, **joint_opts)

In [None]:
# 3x2 grid of MSE-vs-metric scatter plots with a LOWESS trend line,
# labelled with display names instead of raw column names.
metrics = [
    "silhouette_score",
    "calinski_harabasz_score_adjusted",
    "noise_count",
    "auc_mid_kl",
    "auc_sev_kl",
    "spearman_correlation_klscore",
]

# Display labels, in the same order as `metrics`.
metric_names = [
    'Silhouette Score',
    'Calinski Harabasz Score (Adjusted)',
    'Noise Count',
    'AUC_KL',
    'AUC_KL>3',
    'SRC_KL',
]

fig, axes = plt.subplots(3, 2, figsize=(16, 18))
axes = axes.flatten()

# Walk panels, columns and labels in lockstep (six of each).
for ax, metric, label in zip(axes, metrics, metric_names):
    sns.regplot(
        data=df,
        x=metric,
        y="mse",
        lowess=True,  # locally weighted smoother instead of a straight-line fit
        scatter_kws={"alpha": 0.5},
        line_kws={"color": "red"},
        ax=ax,
    )
    ax.set(title=f"MSE vs {label}", xlabel=label, ylabel="MSE")

plt.tight_layout()
plt.show()
