In [None]:
import config
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display
import datetime
from utils.data_exploration_utils import drop_unnamedcolumn,  investigate_data, plot_hist, scatterplot, missing_from_df

In [None]:
# --- Run configuration -------------------------------------------------------
today = datetime.date.today()

base_dir, proc_dir = config.RAW_DATA_PATH, config.PROC_DATA_PATH

# Set `folder` to None to work in a folder named after today's date instead.
folder = "2025-08-11_data_exploration"
df_filename = "inmodi_data_questionnaire_kl_woSC.csv"

# Which oversampled file to load; the id indexes into `smote_types`.
smote_types = ['SMOTE', 'DBSMOTE', 'Borderline_SMOTE1', 'Borderline_SMOTE2']
smote_type_id = 3
smote_type = smote_types[smote_type_id]

# --- Resolve the working directories ----------------------------------------
if folder is None:
    save_dir = os.path.join(proc_dir, f"{today}_data_exploration")
    save_dir2 = os.path.join(proc_dir, '2025-07-14_data_exploration')
else:
    save_dir = save_dir2 = os.path.join(proc_dir, folder)

os.makedirs(save_dir, exist_ok=True)

# --- Load the questionnaire data and the oversampled data --------------------
questionnaire_csv = os.path.join(save_dir, df_filename)
smote_csv = os.path.join(save_dir, f'smote_oversampled_data_{smote_type}.csv')

df = pd.read_csv(questionnaire_csv)
smote = pd.read_csv(smote_csv)

In [None]:
# Echo the resolved working directory the input CSVs were loaded from.
print(save_dir)

In [None]:
# Load the knee annotation table from the raw-data directory (config.RAW_DATA_PATH).
kl = pd.read_csv(os.path.join(base_dir, 'brul_knee_annotations.csv'))

In [None]:
# `name` is underscore-delimited: field 0 -> record_id, field 1 -> visit, field 2 -> leg.
name_parts = kl['name'].str.split('_')
for target, position in (('record_id', 0), ('leg', 2), ('visit', 1)):
    kl[target] = name_parts.str[position]

In [None]:
# Number of distinct record IDs in the annotation table.
len(kl['record_id'].unique())

In [None]:
# Distinct record IDs that have a visit "1" annotation (visit is a string field).
len(kl.loc[kl['visit'] == '1', 'record_id'].unique())

In [None]:
# Distinct record IDs that have a visit "2" annotation (visit is a string field).
len(kl.loc[kl['visit'] == '2', 'record_id'].unique())

In [None]:
# Total number of annotation rows.
len(kl)

In [None]:
# Number of annotation rows per (record_id, visit) pair, smallest counts first.
# NOTE(review): sorting by 'count' relies on the column name that
# value_counts().reset_index() produces in pandas >= 2.0 — confirm the pandas version in use.
kl[['record_id', 'visit']].value_counts().reset_index().sort_values(by=['count'])

In [None]:
# Inspect all annotation rows belonging to three specific record IDs.
kl[kl['record_id'].isin(['IM3009', 'IM3012', 'IM3002'])]

In [None]:
# Left-join the KL-Score annotations onto the questionnaire frame using the 'name' key.
# NOTE(review): later cells read df['KL-Score'] directly; if df already carries that
# column, this merge yields 'KL-Score_x' / 'KL-Score_y' (pandas default suffixes).
dfkl = df.merge(kl[['name', 'KL-Score']], on='name', how='left')

In [None]:
# Row count after the left merge; a value larger than len(df) would mean
# some 'name' keys occur more than once in kl.
len(dfkl)

In [None]:
# Light preprocessing of the questionnaire frame.

# 0/1 indicator: 1 where gender == 'male', 0 for every other value.
df['is_male'] = df['gender'].apply(lambda value: int(value == 'male'))

# Columns that must be present for a row to be kept:
# patient info, every KOOS item (per subscale), every OKS item, and the sex flag.
koos_items_per_subscale = (('s', 7), ('p', 9), ('a', 17), ('sp', 5), ('q', 4))

cols = ['pain', 'age', 'ce_bmi', 'ce_fm']
cols += [
    f'koos_{subscale}{item}'
    for subscale, n_items in koos_items_per_subscale
    for item in range(1, n_items + 1)
]
cols += [f'oks_q{item}' for item in range(1, 13)]
cols.append('is_male')

# Drop every row with a missing value in any of the required columns.
df = df.dropna(axis=0, how='any', subset=cols)

In [None]:
# Compare the sizes of the NaN-filtered questionnaire frame and the SMOTE-oversampled frame.
print(df.shape)
print(smote.shape)

# KL-Score Overall Distribution

In [None]:
# KL-Score class counts in the questionnaire frame ...
display(df['KL-Score'].value_counts().reset_index())

# ... and in the SMOTE-oversampled frame, for side-by-side comparison.
display(smote['KL-Score'].value_counts().reset_index())

In [None]:
# List the columns available in the SMOTE frame.
smote.columns

In [None]:
# Bar chart of KL-Score class counts in the questionnaire frame (`df`).
sns.set_theme(style="whitegrid", font_scale=1.2)
fig, ax = plt.subplots(figsize=(8, 5))

# Passing `palette` without `hue` is deprecated in seaborn 0.13; assigning the
# x variable to `hue` and suppressing the legend keeps one Set3 colour per bar.
sns.countplot(data=df, x='KL-Score', hue='KL-Score', palette='Set3', legend=False, ax=ax)

# Write the count above every bar.
for container in ax.containers:
    ax.bar_label(container, padding=3)

ax.set_xlabel("KL-Score")
ax.set_ylabel("Count")
# Name the data set in the title so this figure is distinguishable from the SMOTE one.
ax.set_title("Distribution of KL-Scores (Org. Data)")

plt.show()

In [None]:
# Bar chart of KL-Score class counts in the SMOTE-oversampled frame (`smote`).
sns.set_theme(style="whitegrid", font_scale=1.2)
fig, ax = plt.subplots(figsize=(8, 5))

# Passing `palette` without `hue` is deprecated in seaborn 0.13; assigning the
# x variable to `hue` and suppressing the legend keeps one Set3 colour per bar.
sns.countplot(data=smote, x='KL-Score', hue='KL-Score', palette='Set3', legend=False, ax=ax)

# Write the count above every bar.
for container in ax.containers:
    ax.bar_label(container, padding=3)

ax.set_xlabel("KL-Score")
ax.set_ylabel("Count")
# Name the data set in the title so this figure is distinguishable from the original-data one.
ax.set_title("Distribution of KL-Scores (SMOTE Data)")

plt.show()

# Histograms

## PI

In [None]:
def plot_hist(smote, column, colname, title = None, xlabel = None, y_label = "Frequency", stat = 'frequency', figsize=(10, 6), hue= None, multiple='dodge', bins = 30, kde=False):
    """Draw a histogram of one column of a data frame.

    Note: this definition shadows the ``plot_hist`` imported from
    ``utils.data_exploration_utils`` in the notebook's import cell.

    Parameters
    ----------
    smote : frame to plot from (any frame may be passed, not only SMOTE data).
    column : name of the column to plot.
    colname : display name used for the default title and x-axis label.
    title : figure title; defaults to ``"Distribution of {colname}"``.
    xlabel : x-axis label; defaults to ``colname``.
    y_label : y-axis label.
    stat, bins, kde, multiple : forwarded to ``sns.histplot``
        (``multiple`` is only used when ``hue`` is given).
    figsize : figure size passed to ``plt.figure``.
    hue : optional column name used to split the histogram into groups.
    """
    sns.set_theme(style="whitegrid", font_scale=1.2)
    plt.figure(figsize=figsize)
    if hue is None:
        sns.histplot(smote[column], bins=bins, stat=stat, kde=kde)
    else:
        # Plot the frame that was passed in; the previous version read the
        # notebook-global `df` here and silently ignored the argument.
        sns.histplot(data=smote, x=column, bins=bins, stat=stat, hue=hue, multiple=multiple, kde=kde)
    plt.title(title if title else f"Distribution of {colname}")
    plt.xlabel(xlabel if xlabel else colname)
    plt.ylabel(y_label)
    plt.show()

In [None]:
# Columns to plot and their display names (kept as two parallel lists).
lcols = ['pain', 'age', 'ce_bmi', 'ce_fm', 'is_male', 'KL-Score']
namecol = ['Pain', 'Age', 'BMI', 'Body Fat Percentage', 'Sex', 'KL-Score']

# Density histogram of every column in the SMOTE frame.
# The sex flag gets two bins and no KDE; every other column gets ten bins plus a KDE curve.
for col, label in zip(lcols, namecol):
    is_sex_flag = col == 'is_male'
    plot_hist(
        smote,
        col,
        colname=label,
        figsize=(10, 6),
        stat='density',
        y_label='Density',
        bins=2 if is_sex_flag else 10,
        kde=not is_sex_flag,
    )

In [None]:
# Derive a string 'gender' label from the numeric is_male column: values above 0.5
# become 'male', everything else (including NaN, which fails the comparison) 'female'.
smote['gender'] = smote['is_male'].apply(lambda x: 'male' if x > 0.5 else 'female')

In [None]:
# Class balance of the derived gender label in the SMOTE frame.
smote['gender'].value_counts()

In [None]:
# lcols = ['pain', 'age',
#     'ce_bmi', 'ce_fm', 'is_male', 'KL-Score']

# for col in lcols:
#     plot_hist(df, col, title=f"Org. Data {col}")
#     plot_hist(smote, col, title=f"SMOTE Data {col}")

# Get original dataset

In [None]:
# KOOS item columns are named koos_<subscale><n>; build them subscale by subscale.
koos_items = []
for subscale, n_items in (("s", 7), ("p", 9), ("a", 17), ("sp", 5), ("q", 4)):
    koos_items.extend(f"koos_{subscale}{i}" for i in range(1, n_items + 1))

# Feature columns grouped by origin.
feature_groups = {
    "pi": ['pain', 'age', 'ce_bmi', 'ce_fm'],
    "koos": koos_items,
    "oks": [f"oks_q{i}" for i in range(1, 13)],
    "gender": ['gender'],
}

# Toggle a group off here to leave its columns out of `cols`.
flags = {
    "pi": True,
    "koos": True,
    "oks": True,
    "gender": True,
}

# Flatten the active groups (in `flags` order), then append the key and target columns.
cols = []
for key, active in flags.items():
    if active:
        cols.extend(feature_groups[key])
cols.extend(['name', 'KL-Score'])

In [None]:
# Reload the questionnaire CSV into a separate frame, `org`.
folder = "2025-08-11_data_exploration"
df_filename = "inmodi_data_questionnaire_kl_woSC.csv"

org_csv = os.path.join(proc_dir, folder, df_filename)
org = pd.read_csv(org_csv)

# Column restriction is currently disabled; all columns are kept:
# org = org[cols].copy()

print("Dataframe before dropping NaN values: ", org.shape)

# Unlike the `subset=` filter applied to `df`, this drops a row if ANY column is missing.
org = org.dropna(axis=0, how='any')

print()
print("Dataframe after dropping NaN values: ", org.shape)


In [None]:
# List the columns of the reloaded original frame.
org.columns

## Questionnaire Aggregation

In [None]:
# Item columns per questionnaire subscale, selected by name prefix from the original frame.
# "koos_s" is also a prefix of the sport items ("koos_sp"), so those are excluded
# explicitly from the symptoms list.
item_columns = list(df.columns)

koos_adl_cols = [name for name in item_columns if name.startswith("koos_a")]
koos_pain_cols = [name for name in item_columns if name.startswith("koos_p")]
koos_sport_cols = [name for name in item_columns if name.startswith("koos_sp")]
koos_symptoms_cols = [
    name for name in item_columns
    if name.startswith("koos_s") and not name.startswith("koos_sp")
]
koos_qol_cols = [name for name in item_columns if name.startswith("koos_q")]
oks_cols = [name for name in item_columns if name.startswith("oks_")]

# KOOS subscale score: row-wise mean of the raw items, rescaled to 0–100
# (a mean item value of 0 maps to 100, a mean of 4 maps to 0).
koos_subscale_items = {
    "KOOS_adl": koos_adl_cols,
    "KOOS_pain": koos_pain_cols,
    "KOOS_sport": koos_sport_cols,
    "KOOS_symptoms": koos_symptoms_cols,
    "KOOS_qol": koos_qol_cols,
}
for score_col, items in koos_subscale_items.items():
    item_mean = smote[items].mean(axis=1)
    smote[score_col] = 100 * (4 - item_mean) / 4

# OKS score: plain row-wise sum of the raw items.
smote["OKS_score"] = smote[oks_cols].sum(axis=1)

In [None]:
# Confirm the aggregate score columns were added to the SMOTE frame.
smote.columns

In [None]:
# Names of the aggregate score columns computed on `smote`.
koos_col =  ['KOOS_adl', 'KOOS_pain', 'KOOS_sport', 'KOOS_symptoms', 'KOOS_qol', 'OKS_score']

# Peek at rows 10-24 of the aggregate scores (the iloc end bound is exclusive).
display(smote[koos_col].iloc[10:25])

In [None]:
# Summary statistics of the original frame.
org.describe()

In [None]:
# Summary statistics of the SMOTE frame, for comparison with the original above.
smote.describe()

In [None]:
# Same row slice of the aggregate scores for the original data.
# NOTE(review): this notebook only computes the aggregate columns on `smote`, so they
# must already be present in the CSV that `org` was read from — confirm.
org[koos_col].iloc[10:25]

# Boxplots

## By KL-Score

In [None]:
# Patient-information columns to plot, mapped to their display names.
pi_labels = {
    'pain': 'Pain',
    'age': 'Age',
    'ce_bmi': 'BMI',
    'ce_fm': 'Body Fat Percentage',
    'gender': 'Sex',
}
cols = list(pi_labels.keys())
col_names = list(pi_labels.values())

In [None]:
# One figure per column: original data on the left, SMOTE data on the right,
# each as a box plot grouped by KL-Score.
sns.set_theme(style="whitegrid", font_scale=1.2)

panels = (('Org. Data', df), ('SMOTE Data', smote))

for i, col in enumerate(cols):
    label = col_names[i]
    plt.figure(figsize=(12, 5))
    for position, (prefix, frame) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        sns.boxplot(x='KL-Score', y=col, data=frame, palette='Set3')
        plt.title(f'{prefix} {label} by KL-Score')
    plt.show()

In [None]:
# Aggregate questionnaire scores to plot: the OKS total followed by the five KOOS subscales.
koos_subscales = ('pain', 'symptoms', 'sport', 'adl', 'qol')
cols = ['OKS_score'] + [f'KOOS_{subscale}' for subscale in koos_subscales]


In [None]:
# One figure per aggregate score: original data on the left, SMOTE data on the right,
# each as a box plot grouped by KL-Score. The raw column name is used in the titles.
sns.set_theme(style="whitegrid", font_scale=1.2)

panels = (('Org. Data', df), ('SMOTE Data', smote))

for col in cols:
    plt.figure(figsize=(12, 5))
    for position, (prefix, frame) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        sns.boxplot(x='KL-Score', y=col, data=frame, palette='Set3')
        plt.title(f'{prefix} {col} by KL-Score')
    plt.show()

In [None]:
# Rows whose OKS total is below 15, shown with record ID, KL-Score and the aggregate scores.
df.loc[df['OKS_score'] < 15, ['record_id', 'KL-Score'] + cols]

In [None]:
# KL-Score value(s) recorded for one specific record ID.
df.loc[df['record_id'] == 'IM1545', 'KL-Score']

## By MRI findings

In [None]:
# Load the MRI table from the raw-data directory (config.RAW_DATA_PATH).
mri = pd.read_csv(os.path.join(base_dir, '2025-09-25_mrismall.csv'))

In [None]:
# Preview the first rows of the MRI table.
mri.head()

In [None]:
# Left-join the MRI table onto the questionnaire frame (df.name == mri.id).
# Columns that exist in both frames keep df's name; the MRI copy gets the '_mri' suffix.
# NOTE(review): this overwrites `df`, so re-running the cell merges the MRI columns
# onto the already-merged frame — restart and run all instead of re-executing it.
df = df.merge(mri, how='left', left_on='name', right_on = 'id', suffixes=('', '_mri'))

In [None]:
# KL-Scores of the rows that have no value for mri_cart_yn after the left merge.
df.loc[df['mri_cart_yn'].isna(), 'KL-Score']

In [None]:
# Distribution of the pain value among rows where mri_cart_yn equals 1.
df.loc[df['mri_cart_yn'] == 1, 'pain'].value_counts()

In [None]:
import math
def plot_mri_grid(df, mri_colnames, cols, col_names, mri_names, ncols=2):
    """
    Draw one figure per MRI column, each holding a grid of box plots
    (one panel per entry of `cols`, grouped by that MRI column).

    df: dataframe
    mri_colnames: list of MRI binary columns (x-axis)
    cols: list of numeric columns to plot (y-axis)
    col_names: pretty names for cols
    mri_names: pretty names for mri_colnames
    ncols: number of columns in the subplot grid (default 2)

    Returns None; figures are shown with plt.show(). Nothing is drawn when
    `cols` is empty.
    """
    sns.set_theme(style="whitegrid", font_scale=1.2)

    nplots = len(cols)
    if nplots == 0:
        # An empty grid cannot be created; there is nothing to plot.
        return
    nrows = math.ceil(nplots / ncols)

    for y, mcol in enumerate(mri_colnames):

        # squeeze=False always yields a 2-D array of axes, so flatten() also
        # works for a 1x1 grid (where subplots would otherwise return a bare Axes).
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 5*nrows), squeeze=False)
        axes = axes.flatten()  # so you can index them linearly

        for i, col in enumerate(cols):
            ax = axes[i]
            sns.boxplot(x=mcol, y=col, data=df, palette='Set3', ax=ax)
            ax.set_title(f"{col_names[i]} by {mri_names[y]}")
            ax.set_xlabel(mri_names[y])
            ax.set_ylabel(col_names[i])

        # hide unused axes in case plots < grid size
        for unused_ax in axes[nplots:]:
            unused_ax.set_visible(False)

        plt.tight_layout()
        plt.show()

In [None]:
# Patient-information columns (y-axis) and MRI indicator columns (x-axis),
# each mapped to its display name.
pi_labels = {
    'pain': 'Pain',
    'age': 'Age',
    'ce_bmi': 'BMI',
    'ce_fm': 'Body Fat Percentage',
}
mri_labels = {
    'mri_cart_yn': 'Cartilage Loss',
    'mri_osteo_yn': 'Osteophytes',
    'mri_bml_yn': 'Bone Marrow Lesions',
}

cols, col_names = list(pi_labels.keys()), list(pi_labels.values())
mri_colnames, mri_names = list(mri_labels.keys()), list(mri_labels.values())

# One figure per MRI indicator, with a 2-column grid of box plots.
plot_mri_grid(
    df,
    mri_colnames=mri_colnames,
    cols=cols,
    col_names=col_names,
    mri_names=mri_names,
    ncols=2,
)


In [None]:
# gender => sex

In [None]:
def plot_cols_by_mri(df, cols, col_names, mri_colnames, mri_names, ncols=3,
                     savepath=None):
    """
    Draw one figure per entry of `cols`, each holding a grid of box plots
    (one panel per MRI column, with the MRI column on the x-axis).

    df: dataframe
    cols: list of y variables (numeric cols)
    col_names: pretty names for cols
    mri_colnames: list of x variables (binary MRI cols)
    mri_names: pretty names for MRI columns
    ncols: number of subplot columns in the grid (default = 3)
    savepath: optional directory; when given, each figure is also written to
        "<savepath>/<col>_by_mri.png". The directory is created if missing.

    Returns None; figures are shown with plt.show(). Nothing is drawn when
    `mri_colnames` is empty.
    """

    nplots = len(mri_colnames)
    if nplots == 0:
        # An empty grid cannot be created; there is nothing to plot.
        return
    nrows = math.ceil(nplots / ncols)

    if savepath:
        # savefig does not create directories, so make sure the target exists.
        os.makedirs(savepath, exist_ok=True)

    # loop over each numerical col (y-variable)
    for idx, col in enumerate(cols):

        # squeeze=False always yields a 2-D array of axes, so flatten() also
        # works for a 1x1 grid (where subplots would otherwise return a bare Axes).
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(5*ncols, 4*nrows), squeeze=False)
        axes = axes.flatten()

        # loop over MRI columns (x-variable)
        for i, mcol in enumerate(mri_colnames):
            ax = axes[i]
            sns.boxplot(x=mcol, y=col, data=df, palette='Set3', ax=ax)
            ax.set_title(f"{col_names[idx]} by {mri_names[i]}")
            ax.set_xlabel(mri_names[i])
            ax.set_ylabel(col_names[idx])

        # hide unused axes
        for unused_ax in axes[nplots:]:
            unused_ax.set_visible(False)

        fig.suptitle(f"Distribution of {col_names[idx]} across MRI Indicators", fontsize=16)
        plt.tight_layout()
        if savepath:
            plt.savefig(os.path.join(savepath, f"{col}_by_mri.png"))
        plt.show()

In [None]:
# Aggregate questionnaire scores (y-axis) against each MRI indicator (x-axis).
cols = ['OKS_score', 'KOOS_pain', 'KOOS_symptoms', 'KOOS_sport', 'KOOS_adl', 'KOOS_qol']

mri_labels = {
    'mri_cart_yn': 'Cartilage Loss',
    'mri_osteo_yn': 'Osteophytes',
    'mri_bml_yn': 'Bone Marrow Lesions',
}
mri_colnames, mri_names = list(mri_labels.keys()), list(mri_labels.values())

# The score column names double as their display names here.
plot_cols_by_mri(
    df,
    cols,
    cols,
    mri_colnames,
    mri_names,
    ncols=3,
    savepath=os.path.join(proc_dir, 'outputs', 'figures_dataexploration'),
)

In [None]:
# Patient-information columns (y-axis) against each MRI indicator (x-axis).
pi_labels = {
    'pain': 'Pain',
    'age': 'Age',
    'ce_bmi': 'BMI',
    'ce_fm': 'Body Fat Percentage',
}
mri_labels = {
    'mri_cart_yn': 'Cartilage Loss',
    'mri_osteo_yn': 'Osteophytes',
    'mri_bml_yn': 'Bone Marrow Lesions',
}

cols, col_names = list(pi_labels.keys()), list(pi_labels.values())
mri_colnames, mri_names = list(mri_labels.keys()), list(mri_labels.values())

plot_cols_by_mri(
    df,
    cols,
    col_names,
    mri_colnames,
    mri_names,
    ncols=3,
    savepath=os.path.join(proc_dir, 'outputs', 'figures_dataexploration'),
)

In [None]:
# Three side-by-side count plots (in percent): sex distribution per MRI indicator.
plt.figure(figsize=(20, 6))

handles = labels = None

for i, col in enumerate(mri_colnames):
    ax = plt.subplot(1, 3, i + 1)
    sns.countplot(data=df, x=col, hue='gender', palette='Set3', stat='percent', ax=ax)
    ax.set_title(f"Sex Distribution by {mri_names[i]}")

    # Remember the legend entries of this panel, then drop the per-panel legend
    # so a single shared one can be drawn afterwards.
    panel_legend = ax.get_legend()
    if panel_legend is None:
        continue
    handles, labels = ax.get_legend_handles_labels()
    panel_legend.remove()

# Draw one legend for the whole row, unless no panel produced one.
if handles is None or labels is None:
    print("No legend labels were found. Check if 'gender' has more than 1 category.")
else:
    plt.legend(
        handles,
        labels,
        loc='center right',
        bbox_to_anchor=(1.05, 0.5),
        title='Gender',
    )

plt.tight_layout()
plt.show()


# Pairplots

## KL-Score

In [None]:
# Pairwise scatter plots of the patient-information columns, coloured by KL-Score:
# first for the questionnaire frame, then for the SMOTE frame.
col_cat = ['KL-Score']
col_num = ['pain', 'age', 'ce_bmi', 'ce_fm']
cols = [*col_cat, *col_num]

hue_column = col_cat[0]

sns.pairplot(df[cols], hue=hue_column)
sns.pairplot(smote[cols], hue=hue_column)
