In [None]:
import config
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy import stats
from IPython.display import display
import json
import datetime
import math

from utils.data_exploration_utils import plot_hist, scatterplot, barplots, check_img_resp_cluster_klscore

In [None]:
# Which HDBSCAN run to analyse; its outputs live in <PROC_DATA_PATH>/<folder>/questionnaire/<run>/
folder = "2025-08-23_hdbscan"
run = "run2"
folder_date = folder.partition('_')[0]   # leading date part of the folder name

today = datetime.date.today()
proc_dir = config.PROC_DATA_PATH
filepath = os.path.join(proc_dir, folder, "questionnaire", run)

# image root handed to check_img_resp_cluster_klscore further below
img_path = config.SCHULTHESS_DATAPATH
# img_path = os.path.join(img_path, "600x600_imgs")

# run output table; later cells read its 'cluster_label' and 'KL-Score' columns
df = pd.read_csv(os.path.join(filepath, f'questionnaire_{run}_umap_hdbscan_scaled_wKL.csv'))
display(df.head())

In [None]:
ids = pd.read_csv(os.path.join(filepath, "x_umap_ids.csv"))
print(ids.shape)

# Some Exploration

In [None]:
df['KL-Score'].value_counts()

In [None]:
df['cluster_label'].value_counts().reset_index().sort_values('cluster_label')

In [None]:
values = df['cluster_label'].value_counts().reset_index().sort_values(by='cluster_label')

plt.bar(values['cluster_label'], values['count'], color = 'skyblue')
plt.xlabel('Cluster Label')
plt.ylabel('Count')
plt.show()

# Plot

In [None]:
from utils.hdbscan_utils import plot_hdbscan

In [None]:
folder2 = "2025-08-11_data_exploration"
df_filename = "inmodi_data_questionnaire_kl_woSC.csv"

df2 = pd.read_csv(os.path.join(proc_dir, folder2, df_filename))

df2 = ids.merge(df2, right_on='name', left_on='id', how='inner')

print(df2.shape)

In [None]:
embeddings_path = os.path.join(filepath, "X_umap_embeddings.npy")

X_umap = np.load(embeddings_path)

print(X_umap.shape)

In [None]:
df2.columns

In [None]:
df.columns

In [None]:
# Attach the HDBSCAN cluster label from `df` to every row of df2 (left join on 'name').
# Any existing 'cluster_label' is dropped first so that re-running this cell does not
# produce 'cluster_label_x' / 'cluster_label_y'. errors='ignore' keeps the drop of the
# leftover CSV index column 'Unnamed: 0' from raising KeyError on a re-run.
df2 = (
    df2.drop(columns=['Unnamed: 0', 'cluster_label'], errors='ignore')
       .merge(df[['name', 'cluster_label']], on="name", how='left')
)
print(df2.shape)

In [None]:
# Rows that got no cluster label from the left merge are treated as HDBSCAN noise (-1).
# Assign the result rather than calling fillna(inplace=True) on the column: that is
# chained assignment, deprecated in pandas 2.x and a no-op under Copy-on-Write.
df2['cluster_label'] = df2['cluster_label'].fillna(-1)

# Give df2 a clean 0..n-1 row index; it is written to the CSV as the 'index' column.
df2 = df2.reset_index(drop=True)
display(df2.head())

df2.to_csv(os.path.join(filepath, f"questionnaire_{run}_umap_hdbscan_scaled_wKL_v2.csv"), index=True, index_label = 'index')

In [None]:
display(df2['cluster_label'].value_counts().reset_index().sort_values('cluster_label'))

In [None]:
# plot_hdbscan(X=X_umap,
#              labels = df2['cluster_label'].to_numpy(),
# )

In [None]:
# def plot_hdbscan_wy(
#     X,
#     labels,
#     probabilities=None,
#     parameters=None,
#     ground_truth=False,
#     ax=None,
#     save_path=None,
#     size_min=8,
#     size_max=80,
#     use_first_three_dims=True,
# ):
#     """
#     Auto-plots 2D or 3D depending on X shape. If X has >=3 features, uses 3D.
#     - sizes scale with `probabilities` (in [0,1]); noise gets fixed small size.
#     - noise label is -1 (black X markers).
#     """
#     X = np.asarray(X)
#     labels = np.asarray(labels)
#     n, d = X.shape

#     # choose 2D vs 3D
#     is_3d = d >= 3
#     if is_3d and use_first_three_dims:
#         Xp = X[:, :3]
#     else:
#         if d < 2:
#             raise ValueError("X must have at least 2 features to plot.")
#         Xp = X[:, :2]

#     # probabilities -> sizes
#     if probabilities is None:
#         probabilities = np.ones(n, dtype=float)
#     else:
#         probabilities = np.asarray(probabilities, dtype=float)
#         # make sure it's in a sane range
#         pmin, pmax = probabilities.min(), probabilities.max()
#         if pmax > 1.0 or pmin < 0.0:
#             # normalize to 0..1 if needed
#             probabilities = (probabilities - pmin) / (pmax - pmin + 1e-12)

#     sizes = size_min + (size_max - size_min) * probabilities

#     # figure / axes
#     if ax is None:
#         fig = plt.figure(figsize=(9, 5))
#         if is_3d:
#             ax = fig.add_subplot(111, projection='3d')
#         else:
#             ax = fig.add_subplot(111)

#     unique_labels = np.unique(labels)
#     # color map per cluster (exclude noise for color count)
#     n_colors = len(unique_labels) - (1 if -1 in unique_labels else 0)
#     # fall back to at least 1 color to avoid linspace errors
#     n_colors = max(n_colors, 1)
#     color_list = [plt.cm.Spectral(t) for t in np.linspace(0, 1, n_colors)]

#     # build a deterministic color map for non-noise clusters
#     non_noise = [lab for lab in unique_labels if lab != -1]
#     color_map = {lab: color_list[i % len(color_list)] for i, lab in enumerate(sorted(non_noise))}

#     # plot each cluster once (vectorized scatter)
#     handles = []
#     labels_for_legend = []

#     for k in sorted(unique_labels, key=lambda x: (x == -1, x)):
#         mask = labels == k
#         if not np.any(mask):
#             continue

#         if k == -1:
#             # noise: black 'x', fixed size
#             if is_3d:
#                 h = ax.scatter(Xp[mask, 0], Xp[mask, 1], Xp[mask, 2],
#                                marker='x', c='k', s=size_min, linewidths=0.8, alpha=0.9)
#             else:
#                 h = ax.scatter(Xp[mask, 0], Xp[mask, 1],
#                                marker='x', c='k', s=size_min, linewidths=0.8, alpha=0.9)
#             handles.append(h); labels_for_legend.append("Noise")
#         else:
#             col = color_map[k]
#             if is_3d:
#                 h = ax.scatter(Xp[mask, 0], Xp[mask, 1], Xp[mask, 2],
#                                marker='o', c=[col], s=sizes[mask], edgecolors='k', linewidths=0.2, alpha=0.9)
#             else:
#                 h = ax.scatter(Xp[mask, 0], Xp[mask, 1],
#                                marker='o', c=[col], s=sizes[mask], edgecolors='k', linewidths=0.2, alpha=0.9)
#             handles.append(h); labels_for_legend.append(f"Cluster {k}")

#     # title
#     n_clusters_ = len(non_noise)
#     pre = "True" if ground_truth else "Estimated"
#     title = f"{pre} number of clusters: {n_clusters_}"
#     if parameters is not None and isinstance(parameters, dict) and len(parameters):
#         param_str = ", ".join(f"{k}={v}" for k, v in parameters.items())
#         title += f" | {param_str}"
#     ax.set_title(title)

#     # axes labels
#     if is_3d:
#         ax.set_xlabel("dim 0"); ax.set_ylabel("dim 1"); ax.set_zlabel("dim 2")
#         # a gentle view angle
#         ax.view_init(elev=18, azim=35)
#     else:
#         ax.set_xlabel("dim 0"); ax.set_ylabel("dim 1")

#     # legend (avoid too many items)
#     if len(handles) <= 20:
#         ax.legend(handles, labels_for_legend, title="Cluster Labels", fontsize='small', loc="best")

#     plt.tight_layout()
#     if save_path is not None:
#         plt.savefig(save_path, bbox_inches='tight', dpi=150)

## Plot UMAP Embeddings

## Plot HDBSCAN

In [None]:
def add_jitter(X, scale=0.02, rng=None):
    """Return ``X`` plus i.i.d. Gaussian noise, to spread out overlapping points in a plot.

    Parameters
    ----------
    X : array-like
        Coordinates to perturb; noise is drawn with shape ``np.shape(X)``.
    scale : float, default 0.02
        Standard deviation of the noise.
    rng : None, int or numpy.random.Generator, optional
        Source of randomness. ``None`` (default) draws from NumPy's global random
        state, as before; pass a seed or a Generator to make the jitter reproducible.

    Returns
    -------
    ``X + noise``; the input object itself is not modified.
    """
    shape = np.shape(X)
    if rng is None:
        noise = np.random.normal(0, scale, shape)
    else:
        # default_rng returns an existing Generator unchanged and seeds a new one from an int
        noise = np.random.default_rng(rng).normal(0, scale, shape)
    return X + noise

In [None]:
def plot_hdbscan_highlight_kl(
    X,
    labels,
    y_kl,                 # array-like of KL-scores per point
    focus_kl,             # the KL value to highlight (e.g., 0,1,2,3,4)
    probabilities=None,
    parameters=None,
    ground_truth=False,
    ax=None,
    save_path=None,
    size_min=8,
    size_max=80,
    use_first_three_dims=True,
    gray_alpha=0.75,      # transparency for non-focused points
    gray_size_factor=1,  # size multiplier for gray points
    color_alpha = 0.3,
    global_color_map=None,  # if provided, use this color map for clusters
    jitter_scale=0.05,    # std of the Gaussian jitter added to X before plotting (0 disables it)
):
    """
    Scatter-plot a clustering, colouring only the points whose KL-score equals `focus_kl`.

    Points with ``y_kl != focus_kl`` are drawn first in uniform light gray (behind).
    Focused points are drawn on top, coloured by cluster label; focused noise points
    (label -1) are drawn as black 'x' markers. Gaussian jitter with standard deviation
    `jitter_scale` is added to `X` first so that coincident points become visible.

    Parameters
    ----------
    X : array-like, shape (n, d)
        Coordinates, d >= 2. With d >= 3 and ``use_first_three_dims=True`` the first
        three columns are drawn on a 3D axes, otherwise the first two on a 2D axes.
    labels : array-like, shape (n,)
        Cluster label per point; -1 marks noise.
    y_kl : array-like, shape (n,)
        KL-score per point.
    focus_kl :
        KL value whose points are highlighted.
    probabilities : array-like, shape (n,), optional
        Scale the marker size of focused cluster points between `size_min` and
        `size_max`; min-max rescaled to [0, 1] if any value falls outside [0, 1].
    parameters : dict, optional
        Appended to the title as ``key=value`` pairs.
    ground_truth : bool
        Title says "True" instead of "Estimated".
    ax : matplotlib Axes, optional
        Axes to draw into (should be a 3D axes when a 3D plot is made). A new figure
        is created when None.
    save_path : str, optional
        Only used when this function created the figure (i.e. `ax` was None).
    global_color_map : dict, optional
        ``{label: colour}`` shared across calls so a cluster keeps its colour in every
        KL panel; must contain every non-noise label that is focused. If None, the
        focused clusters are spread over ``plt.cm.Spectral``.
    jitter_scale : float
        Standard deviation of the jitter (0.05 by default, the former hard-coded value).

    Returns
    -------
    The Axes that was drawn into.

    Raises
    ------
    ValueError
        If `X` has fewer than 2 columns.
    """
    X = np.asarray(X)
    X = add_jitter(X, scale=jitter_scale)
    labels = np.asarray(labels)
    y_kl = np.asarray(y_kl)
    n, d = X.shape

    # Choose 2D vs 3D. 3D requires both >= 3 dims AND the caller wanting them: testing
    # d >= 3 alone built a 3D plot from a 2-column slice when use_first_three_dims=False
    # and then crashed on Xp[:, 2].
    if d < 2:
        raise ValueError("X must have at least 2 features to plot.")
    is_3d = d >= 3 and use_first_three_dims
    Xp = X[:, :3] if is_3d else X[:, :2]

    # probabilities -> sizes
    if probabilities is None:
        probabilities = np.ones(n, dtype=float)
    else:
        probabilities = np.asarray(probabilities, dtype=float)
        pmin, pmax = probabilities.min(), probabilities.max()
        if pmax > 1.0 or pmin < 0.0:
            # min-max rescale to [0, 1]; the epsilon guards against pmax == pmin
            probabilities = (probabilities - pmin) / (pmax - pmin + 1e-12)
    sizes = size_min + (size_max - size_min) * probabilities

    # figure / axes
    created_fig = False
    if ax is None:
        fig = plt.figure(figsize=(9, 5))
        created_fig = True
        ax = fig.add_subplot(111, projection='3d') if is_3d else fig.add_subplot(111)
    # layout/saving must act on the figure owning `ax`, not on pyplot's "current" figure
    fig = ax.figure

    def _scatter(mask, **kwargs):
        """Scatter the rows of Xp selected by `mask` using all plotted dimensions (2 or 3)."""
        coords = [Xp[mask, j] for j in range(Xp.shape[1])]
        return ax.scatter(*coords, **kwargs)

    # masks
    focus_mask = (y_kl == focus_kl)
    other_mask = ~focus_mask

    # --- 1) NON-focused points in uniform light gray, drawn first so they sit behind
    if np.any(other_mask):
        gray_sizes = (size_min + (size_max - size_min) * 0.3) * gray_size_factor
        _scatter(other_mask, marker='o', c='lightgray', s=gray_sizes, alpha=gray_alpha)

    # --- 2) FOCUSED points, coloured by cluster
    unique_labels = np.unique(labels[focus_mask]) if np.any(focus_mask) else np.array([])
    non_noise = [lab for lab in unique_labels if lab != -1]

    if global_color_map is None:
        n_colors = max(len(non_noise), 1)
        color_list = [plt.cm.Spectral(t) for t in np.linspace(0, 1, n_colors)]
        color_map = {lab: color_list[i % len(color_list)] for i, lab in enumerate(sorted(non_noise))}
    else:
        color_map = global_color_map

    handles, labels_for_legend = [], []

    # clusters in ascending order, noise last
    for k in sorted(set(unique_labels), key=lambda x: (x == -1, x)):
        mask = focus_mask & (labels == k)
        if not np.any(mask):
            continue

        if k == -1:
            h = _scatter(mask, marker='x', c='k', s=size_min, linewidths=0.8, alpha=0.9)
            handles.append(h); labels_for_legend.append(f"Noise (KL={focus_kl})")
        else:
            h = _scatter(mask, marker='o', c=[color_map[k]], s=sizes[mask],
                         edgecolors='k', linewidths=0.2, alpha=color_alpha)
            handles.append(h); labels_for_legend.append(f"Cluster {k} (KL={focus_kl})")

    # title
    pre = "True" if ground_truth else "Estimated"
    n_clusters_ = len(non_noise)
    title = f"{pre} clusters in KL={focus_kl}: {n_clusters_}"
    if isinstance(parameters, dict) and parameters:
        param_str = ", ".join(f"{k}={v}" for k, v in parameters.items())
        title += f" | {param_str}"
    ax.set_title(title)

    # axes labels
    if is_3d:
        ax.set_xlabel("dim 0"); ax.set_ylabel("dim 1"); ax.set_zlabel("dim 2")
        ax.view_init(elev=18, azim=35)
    else:
        ax.set_xlabel("dim 0"); ax.set_ylabel("dim 1")

    # skip the legend when empty or too crowded to read
    if 0 < len(handles) <= 20:
        ax.legend(handles, labels_for_legend, title="Focused clusters", fontsize='small', loc="best")

    fig.tight_layout()
    if save_path is not None and created_fig:
        fig.savefig(save_path, bbox_inches='tight', dpi=150)
    return ax

In [None]:
def make_cluster_color_map(labels, cmap=None):
    """Map every cluster label to a fixed colour so that several plots colour clusters identically.

    Parameters
    ----------
    labels : iterable
        Cluster labels (duplicates allowed); -1 is treated as noise.
    cmap : callable, optional
        Matplotlib colormap, or any callable taking a float in [0, 1] and returning a
        colour. Defaults to ``plt.cm.tab20``. Previously this argument was accepted but
        ignored and ``plt.cm.Spectral`` was always used; pass ``plt.cm.Spectral`` to get
        the old colours back.

    Returns
    -------
    dict
        ``{label: colour}`` for each non-noise label (sorted labels spread evenly over
        the colormap), plus ``-1: (0, 0, 0, 1)`` (black) for noise.
    """
    if cmap is None:
        # resolved lazily so defining this function does not require pyplot
        cmap = plt.cm.tab20
    cluster_labels = sorted(set(labels) - {-1})   # exclude noise
    positions = np.linspace(0, 1, max(len(cluster_labels), 1))
    color_map = {lab: cmap(t) for lab, t in zip(cluster_labels, positions)}
    color_map[-1] = (0, 0, 0, 1)  # black for noise
    return color_map

In [None]:
color_map = make_cluster_color_map(df2['cluster_label'].unique())

In [None]:
kl = list(df2['KL-Score'].unique())

# Row labels of df2 for each KL-score
kl_indexes = {}
for i in kl:
    index_ids = df2[df2['KL-Score'] == i].index
    kl_indexes[i] = index_ids
    print(len(index_ids))

print("To Double check that the indexing works, should only see on KL-Score in Value Counts:")
for i in kl:
    # kl_indexes holds index *labels*, so select with .loc; .iloc (positions) only
    # gives the right rows by accident while df2 has a default 0..n-1 index.
    kl_subset = df2['KL-Score'].loc[kl_indexes[i]]
    print(kl_subset.value_counts())
    assert kl_subset.nunique() <= 1, f"rows selected for KL-Score {i} contain other KL-Scores"

In [None]:
for i in kl:
    plot_hdbscan_highlight_kl(X = X_umap, labels = df2['cluster_label'], y_kl=df2['KL-Score'], focus_kl=i, global_color_map=color_map,
                              gray_alpha=0.5)

# Correlation

## Correlation with Questionnaire Scores

**Caution: a correlation coefficient is not meaningful here, because cluster labels are nominal (unordered) categories. Relating ordinal features (questionnaire items) to categorical values (cluster labels) requires a different kind of test, e.g. the Kruskal-Wallis test below.**

### Kruskal-Wallis (non-parametric ANOVA)

Tests whether the distribution of a feature differs significantly across clusters. It only tests for the presence of a difference, not for its effect size.

- $H_0$: all clusters have the same central tendency (equal mean ranks), i.e. the samples originate from the same distribution.
- $H_1$: at least one cluster differs in central tendency, i.e. at least one sample stochastically dominates at least one other sample.

In [None]:
from scipy.stats import kruskal

def kruskal_wallis(df, feature, cluster_col = 'cluster_label', dropna=True):
    """Kruskal-Wallis H-test of whether `feature` is distributed differently across clusters.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns `feature` and `cluster_col`.
    feature : str
        Column whose values are compared between the groups.
    cluster_col : str, default 'cluster_label'
        Column defining the groups, one per distinct value; rows with a NaN label are ignored.
    dropna : bool, default True
        Drop rows whose `feature` is NaN before grouping. With scipy's default
        nan_policy='propagate', a single missing value would otherwise turn the whole
        result into NaN.

    Returns
    -------
    (float, float)
        H statistic and p-value from ``scipy.stats.kruskal``.

    Raises
    ------
    ValueError
        If fewer than two non-empty groups remain.
    """
    data = df[[cluster_col, feature]]
    if dropna:
        data = data.dropna(subset=[feature])
    # groupby skips NaN labels; observed=True avoids empty groups for unused categories
    groups = [grp[feature].to_numpy() for _, grp in data.groupby(cluster_col, observed=True)]
    if len(groups) < 2:
        raise ValueError(
            f"Kruskal-Wallis needs at least two non-empty groups in '{cluster_col}' "
            f"for feature '{feature}', got {len(groups)}."
        )
    stat, p = kruskal(*groups)
    return stat, p

In [None]:
# Questionnaire items (OKS and KOOS subscales) compared across clusters.
columns_corr = [
    'oks_q1', 'oks_q2', 'oks_q3', 'oks_q4',
    'oks_q5', 'oks_q6', 'oks_q7', 'oks_q8', 'oks_q9', 'oks_q10', 'oks_q11',
    'oks_q12', 'koos_s1',
    'koos_s2', 'koos_s3', 'koos_s4', 'koos_s5', 'koos_s6',
    'koos_s7', 'koos_p1', 'koos_p2', 'koos_p3', 'koos_p4', 'koos_p5',
    'koos_p6', 'koos_p7', 'koos_p8', 'koos_p9', 'koos_a1', 'koos_a2',
    'koos_a3', 'koos_a4', 'koos_a5', 'koos_a6', 'koos_a7', 'koos_a8',
    'koos_a9', 'koos_a10', 'koos_a11', 'koos_a12', 'koos_a13', 'koos_a14',
    'koos_a15', 'koos_a16', 'koos_a17', 'koos_sp1', 'koos_sp2', 'koos_sp3',
    'koos_sp4', 'koos_sp5',
    'koos_q1', 'koos_q2', 'koos_q3', 'koos_q4'
]

results = []
for feature in columns_corr:
    stat, p = kruskal_wallis(df, feature, cluster_col = 'cluster_label')
    results.append({'feature': feature, 'H-statistic': stat, 'p-value': p})

results_df = pd.DataFrame(results)
# Every item is tested at once, so raw p-values overstate significance.
# Add a Bonferroni-adjusted column (the raw 'p-value' column is unchanged).
results_df['p-value (Bonferroni)'] = (results_df['p-value'] * len(results_df)).clip(upper=1.0)

display(results_df.sort_values('p-value').head())
results_df.to_csv(os.path.join(filepath, f"kruskal_wallis_results_{run}.csv"), index=False)

# Items without a significant difference between clusters (raw p >= 0.05).
# display() is required: a bare expression mid-cell is evaluated and discarded.
display(results_df[results_df['p-value'] >= 0.05])

plt.figure(figsize=(10, 6))
sns.barplot(data = results_df, x='feature', y='H-statistic')
plt.title("Kruskal-Wallis H-statistic per questionnaire item")
plt.xticks(rotation=90)
plt.show()

H-statistic:
* small: if all clusters have similar distribution, their average rank will be similar
* large: at least one group's distribution is shifted, relative to the others.

In [None]:
# cols = ['koos_s1', 
#        'koos_s2', 'koos_s3', 'koos_s4', 'koos_s5', 'koos_s6',
#        'koos_s7'
#        , 'cluster_label']

# corr_types = ['spearman']
# for corr in corr_types:
#     print(f"Calculating {corr} correlation...")

#     df_corr = df[columns_corr].corr(method=corr)
#     plt.figure(figsize=(12, 8))
#     sns.heatmap(df_corr, annot=True, fmt=".2f", cmap='coolwarm', square=True, cbar_kws={"shrink": .8})
#     plt.title(f"{corr.capitalize()} Correlation Heatmap")
#     #plt.savefig(os.path.join(img_save_dir, f"{corr}corr.png"))
#     plt.show()

In [None]:
# cols = [
#        'koos_p1', 'koos_p2', 'koos_p3', 'koos_p4', 'koos_p5',
#        'koos_p6', 'koos_p7', 'koos_p8', 'koos_p9'
#        , 'cluster_label']
# corr_types = ['spearman']
# for corr in corr_types:
#     print(f"Calculating {corr} correlation...")

#     df_corr = df[columns_corr].corr(method=corr)
#     plt.figure(figsize=(12, 8))
#     sns.heatmap(df_corr, annot=True, fmt=".2f", cmap='coolwarm', square=True, cbar_kws={"shrink": .8})
#     plt.title(f"{corr.capitalize()} Correlation Heatmap")
#     #plt.savefig(os.path.join(img_save_dir, f"{corr}corr.png"))
#     plt.show()

In [None]:
# cols = [
#        'koos_a1', 'koos_a2',
#        'koos_a3', 'koos_a4', 'koos_a5', 'koos_a6', 'koos_a7', 'koos_a8',
#        'koos_a9', 'koos_a10', 'koos_a11', 'koos_a12', 'koos_a13', 'koos_a14',
#        'koos_a15', 'koos_a16', 'koos_a17'
#        , 'cluster_label']
# corr_types = ['spearman']
# for corr in corr_types:
#     print(f"Calculating {corr} correlation...")

#     df_corr = df[columns_corr].corr(method=corr)
#     plt.figure(figsize=(12, 8))
#     sns.heatmap(df_corr, annot=True, fmt=".2f", cmap='coolwarm', square=True, cbar_kws={"shrink": .8})
#     plt.title(f"{corr.capitalize()} Correlation Heatmap")
#     #plt.savefig(os.path.join(img_save_dir, f"{corr}corr.png"))
#     plt.show()

In [None]:
# cols = [
#        'koos_sp1', 'koos_sp2', 'koos_sp3',
#        'koos_sp4', 'koos_sp5', 
#        'koos_q1', 'koos_q2', 'koos_q3', 'koos_q4'
#        , 'cluster_label']
# corr_types = ['spearman']
# for corr in corr_types:
#     print(f"Calculating {corr} correlation...")

#     df_corr = df[columns_corr].corr(method=corr)
#     plt.figure(figsize=(12, 8))
#     sns.heatmap(df_corr, annot=True, fmt=".2f", cmap='coolwarm', square=True, cbar_kws={"shrink": .8})
#     plt.title(f"{corr.capitalize()} Correlation Heatmap")
#     #plt.savefig(os.path.join(img_save_dir, f"{corr}corr.png"))
#     plt.show()

## Correlation with KL-Score and Pain

### Kruskal-Wallis

In [None]:

# Sanity check before the Kruskal-Wallis tests: missing values per feature, and per
# cluster the count, number of unique values and range of the non-missing values.
for feature in ['pain', 'age', 'ce_bmi', 'ce_fm']:
    print(f"For Feature {feature}")
    # the NaN count depends only on the feature, so report it once, not once per cluster
    print(f"NaN values: {df[feature].isna().sum()}")
    for c in sorted(df["cluster_label"].unique()):
        vals = df.loc[df["cluster_label"]==c, feature].dropna()
        print(f"Cluster {c}: n={len(vals)}, unique={vals.nunique()}, min={vals.min()}, max={vals.max()}")
    print()


In [None]:
# Clinical / demographic variables compared across clusters.
columns_corr = ['pain', 'age', 'ce_bmi', 'ce_fm']

results = []
for feature in columns_corr:
    # drop rows missing this feature so one NaN does not make the test return NaN
    df_wonan = df.dropna(subset=[feature])
    stat, p = kruskal_wallis(df_wonan, feature, cluster_col = 'cluster_label')
    results.append({'feature': feature, 'H-statistic': stat, 'p-value': p})

results_df = pd.DataFrame(results)

display(results_df.sort_values('p-value').head())
# Own file name: the questionnaire results above are already saved as
# kruskal_wallis_results_{run}.csv and were previously overwritten by this cell.
results_df.to_csv(os.path.join(filepath, f"kruskal_wallis_results_clinical_{run}.csv"), index=False)

# Variables without a significant difference between clusters (p >= 0.05).
display(results_df[results_df['p-value'] >= 0.05])

plt.figure(figsize=(10, 6))
sns.barplot(data = results_df, x='feature', y='H-statistic')
plt.title("Kruskal-Wallis H-statistic per clinical variable")
plt.xticks(rotation=90)
plt.show()

### Plots

In [None]:
columns_corr =  ['pain', 'KL-Score'] 
barplots(df, y_list=columns_corr, x='cluster_label', hue=None, figsize = (6, 6), savepath=None)

#### Boxplot

In [None]:
def boxplot(
    df: pd.DataFrame,
    y_list: list,
    x: str,
    hue: str | None = None,
    title: str | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    figsize_per_panel=(5.5, 4.0),
    n_cols: int | None = None,
    order: list | None = None,
    hue_order: list | None = None,
    sharex: bool = False,
    sharey: bool = False,
    show_points: bool = True,
    points_alpha: float = 0.35,
    rotate_xticks: int = 30,
    showfliers: bool = False,
    whis: tuple | float = (5, 95),
    tight_rect=(0, 0, 0.92, 0.95),
    savepath: str | None = None,
    filename: str | None = None,
):
    # --- Category order handling (respects CategoricalDtype if present) ---
    if order is None:
        if pd.api.types.is_categorical_dtype(df[x]):
            order = list(df[x].cat.categories)
        else:
            order = list(pd.unique(df[x].dropna()))
    if hue is not None and hue_order is None:
        if pd.api.types.is_categorical_dtype(df[hue]):
            hue_order = list(df[hue].cat.categories)
        else:
            hue_order = list(pd.unique(df[hue].dropna()))

    # --- Grid geometry ---
    n = len(y_list)
    if n_cols is None:
        n_cols = 2 if n <= 4 else 3  # sensible default
    n_rows = math.ceil(n / n_cols)
    fig_w = figsize_per_panel[0] * n_cols
    fig_h = figsize_per_panel[1] * n_rows

    # --- Style (lightweight, readable) ---
    sns.set_context("talk")
    sns.set_style("whitegrid", {"axes.grid": True, "grid.linestyle": "--", "grid.alpha": 0.35})

    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(fig_w, fig_h), sharex=sharex, sharey=sharey,
        constrained_layout=False
    )
    axes = np.atleast_1d(axes).ravel()

    legend_handles, legend_labels = None, None

    for i, y in enumerate(y_list):
        ax = axes[i]

        # Boxplot
        sns.boxplot(
            data=df, x=x, y=y, hue=hue, order=order, hue_order=hue_order,
            ax=ax, dodge=True, showfliers=showfliers, whis=whis
        )

        # Optional jittered points overlay (helps see sample size & spread)
        if show_points:
            # stripplot is faster / less overplotty than swarm for big n
            sns.stripplot(
                data=df, x=x, y=y, hue=hue, order=order, hue_order=hue_order,
                ax=ax, dodge=True if hue else False, alpha=points_alpha, jitter=0.18,
                linewidth=0
            )

        # Collect legend once (we'll add a single figure legend)
        if hue and legend_handles is None:
            legend_handles, legend_labels = ax.get_legend_handles_labels()

        # Clean up duplicate legends in each subplot
        if hue:
            ax.legend_.remove()

        # Labels & ticks
        ax.set_title(f"{y} by {x}", fontsize=12)
        ax.set_xlabel(xlabel if xlabel else x, fontsize=10)
        ax.set_ylabel(ylabel if ylabel else y, fontsize=10)
        ax.tick_params(axis="x", rotation=rotate_xticks)

        # A bit of visual polish
        sns.despine(ax=ax, left=False, bottom=False)

    # Remove any unused axes
    for j in range(i + 1, len(axes)):
        fig.delaxes(axes[j])

    # Global title
    if title:
        fig.suptitle(title, fontsize=14)

    # Single shared legend (if hue)
    if hue and legend_handles:
        fig.legend(legend_handles[:len(hue_order) if hue_order else None],
                   legend_labels[:len(hue_order) if hue_order else None],
                   loc="center left", bbox_to_anchor=(0.99, 0.5), frameon=False, title=hue)

    plt.tight_layout(rect=tight_rect)

    # Saving
    if savepath is not None:
        os.makedirs(savepath, exist_ok=True)
        if filename is None:
            base = f"box_{x}_vs_{len(y_list)}y"
            if hue:
                base += f"_by_{hue}"
            filename = base + ".png"
        fig.savefig(os.path.join(savepath, filename), dpi=160, bbox_inches="tight")

    return fig, axes

In [None]:
columns_corr = [  
       'oks_q1', 'oks_q2', 'oks_q3', 'oks_q4',
       'oks_q5', 'oks_q6', 'oks_q7', 'oks_q8', 'oks_q9', 'oks_q10', 'oks_q11',
       'oks_q12', 'koos_s1', 
       'koos_s2', 'koos_s3', 'koos_s4', 'koos_s5', 'koos_s6',
       'koos_s7', 'koos_p1', 'koos_p2', 'koos_p3', 'koos_p4', 'koos_p5',
       'koos_p6', 'koos_p7', 'koos_p8', 'koos_p9', 'koos_a1', 'koos_a2',
       'koos_a3', 'koos_a4', 'koos_a5', 'koos_a6', 'koos_a7', 'koos_a8',
       'koos_a9', 'koos_a10', 'koos_a11', 'koos_a12', 'koos_a13', 'koos_a14',
       'koos_a15', 'koos_a16', 'koos_a17',  'koos_sp1', 'koos_sp2', 'koos_sp3',
       'koos_sp4', 'koos_sp5', 
       'koos_q1', 'koos_q2', 'koos_q3', 'koos_q4', 'pain', 'KL-Score', 'age', 'ce_bmi', 'ce_fm'
       #, 'cluster_label'
       ] 


boxplot(df, y_list=columns_corr, x='cluster_label', hue=None, 
         savepath=None)

In [None]:
barplots(df, y_list=columns_corr, x='cluster_label', hue='gender', figsize = (6, 6), savepath=None)

# KL-Score Visualization

In [None]:
# KL-score distribution inside each cluster, clusters in ascending order.
labels = sorted(df['cluster_label'].unique())

for i in labels:
    df_temp = df[df['cluster_label'] == i]
    print(f"For label {i}:")
    kl_counts = df_temp['KL-Score'].value_counts().reset_index()
    display(kl_counts.sort_values(by="KL-Score"))
    print()

In [None]:
k=2

## Cluster 0

### KL score 0

In [None]:
_ = check_img_resp_cluster_klscore(df, cluster_label=0, klscore=0, img_path=img_path, k=k)

### KL-Score 3

In [None]:
# pass the shared `k` like every other call in this section
img = check_img_resp_cluster_klscore(df, cluster_label=0, klscore=3, img_path=img_path, k=k)

## Cluster 1

### KL score 0

In [None]:
_ = check_img_resp_cluster_klscore(df, cluster_label=1, klscore=0, img_path=img_path, k=k)

### KL-Score 1

In [None]:
# use the shared `k` (set above) instead of a hard-coded 2 so one edit changes every call
_ = check_img_resp_cluster_klscore(df, cluster_label=1, klscore=1, img_path=img_path, k=k)

## Cluster 2

### KL score 0

In [None]:
_ = check_img_resp_cluster_klscore(df, cluster_label=2, klscore=0, img_path=img_path, k=k)

### KL-Score 3

In [None]:
_ = check_img_resp_cluster_klscore(df, cluster_label=2, klscore=3, img_path=img_path, k=k)

## Cluster 3

### KL-Score 0

In [None]:
_ = check_img_resp_cluster_klscore(df, cluster_label=3, klscore=0, img_path=img_path, k=k)

### KL-Score 4

In [None]:
_ = check_img_resp_cluster_klscore(df, cluster_label=3, klscore=4, img_path=img_path, k=k)

# Get examples

In [None]:
# # Give me a df showing which clusters have which min max KL-Score
# kl_diffs = df_merged.groupby('cluster_label')['KL-Score'].agg(['min', 'max'])
# kl_diffs.sort_values('max', ascending=False, inplace=True)

In [None]:
# kl_diffs

In [None]:
# #for each cluster label give me 2 examples with different KL-Score

# clusters = df_merged['cluster_label'].unique()
# sorted_clusters = sorted(clusters)
# for cluster in sorted_clusters:
#     print(f"Cluster {cluster}:")
#     cluster_df = df_merged[df_merged['cluster_label'] == cluster]
    
#     if len(cluster_df) > 0:
#         for kl_score in cluster_df['KL-Score'].unique():
#             subset = cluster_df[cluster_df['KL-Score'] == kl_score]
#             if len(subset) >= 2:
#                 examples = subset.sample(n=2, random_state=42)
#                 print(f"  KL-Score {kl_score}:")
#                 display(examples[['name', 'id', 'KL-Score', 'pain', 'age', 'ce_bmi', 'ce_fm']])
#             else:
#                 examples = subset
#                 print(f"  KL-Score {kl_score}:")
#                 display(examples[['name', 'id', 'KL-Score', 'pain', 'age', 'ce_bmi', 'ce_fm']])
#     else:
#         print("  No data available for this cluster.")
