## Demographic tables

### Imports

In [None]:
import json
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns

### Local dataset configs

In [None]:
# JSON config describing where the dataset store lives and how each
# dataset's tabular files are laid out.
ds_config_path = Path("../../PD-challenge-ds-store/ds_config.json")

# Parse the file first; the individual settings are extracted afterwards so
# the file handle is only held open for the read itself.
with ds_config_path.open("r") as f:
    ds_config = json.load(f)

ds_store_path = Path(ds_config["ds_store_path"])
data_sources = ds_config["data_sources"]

# Names of all datasets declared in the config (a dict keys view).
PD_challenge_datasets = data_sources.keys()
print("PD_challenge_datasets:", PD_challenge_datasets)

In [None]:
# Inspect the config entry of the "qpn" dataset (shown as the cell output).
data_sources["qpn"]

### Demographic tables

In [None]:
def load_tabular_data(data_sources, ds_name):
    tabular_data = data_sources[ds_name]["tabular_data"]
    # print("Tabular files:", tabular_data)

    index_columns = data_sources[ds_name]["index_columns"]
    # print("Index columns:", index_columns)

    # merge all tabular files
    demographics_df = pd.DataFrame()
    for table_name in tabular_data.keys():
        table_info = tabular_data[table_name]
        table_file_path = ds_store_path / ds_name / table_info["file_name"]
        table_column_map = table_info.get("column_map", {})
        table_value_map = table_info.get("value_map", {})
        # invert the column map to get original column names
        inverted_column_map = {v: k for k, v in table_column_map.items()}
        # read the table
        df = pd.read_csv(table_file_path)[index_columns + list(table_column_map.values())]
        # rename columns to original names
        df = df.rename(columns=inverted_column_map)

        # map values if value_map is provided
        for col, val_map in table_value_map.items():
            if col in df.columns:
                df[col] = df[col].map(val_map).fillna(df[col])  
                
        # merge with demographics_df
        if demographics_df.empty:
            demographics_df = df
        else:
            demographics_df = pd.merge(demographics_df, df, on=index_columns, how='inner')
            
    return demographics_df

In [None]:
# Datasets to include, in the order they are loaded (and later plotted).
ds_list = ["qpn", "calgary", "nimhans-serb", "nimhans-metal", "nimhans-ylo"]

# Collect the per-dataset tables and concatenate once at the end. Growing a
# DataFrame with pd.concat inside the loop re-copies all previous rows on
# every iteration and starts from an empty frame.
per_dataset_frames = []
for ds in ds_list:
    print(f"\nLoading demographics for dataset: {ds}")
    _df = load_tabular_data(data_sources, ds)
    _df["dataset"] = ds  # add a column to indicate the dataset
    per_dataset_frames.append(_df)

# pd.concat raises on an empty list, so keep an empty frame in that case.
demographics_df = (
    pd.concat(per_dataset_frames, ignore_index=True)
    if per_dataset_frames
    else pd.DataFrame()
)

# Number of participants per group, pooled over all datasets.
group_counts = demographics_df["group"].value_counts()
print(f"\nGroup counts across all datasets:\n{group_counts}\n")


demographics_df.head()

In [None]:
## QC value distributions
# Ages outside this inclusive range are treated as invalid and removed.
possible_age_range = (40, 100)
min_age, max_age = possible_age_range

# Build ONE validity mask and use it for both the report and the filter, so
# that every dropped row is also reported. Missing or non-numeric ages are
# coerced to NaN, which `between` treats as out of range; previously such
# rows were dropped without appearing in the report.
age_values = pd.to_numeric(demographics_df["age"], errors="coerce")
valid_age_mask = age_values.between(min_age, max_age)

invalid_age_entries = demographics_df[~valid_age_mask]
print("Invalid age entries:")
print(invalid_age_entries)

# drop invalid age entries (copy so later column assignments do not act on a
# slice of the unfiltered frame)
demographics_df = demographics_df[valid_age_mask].copy()


### Plots

In [None]:
from enum import Enum

# Poster colors
class ds_colors(Enum):
    """Hex color for every dataset/group combination used on the poster.

    Members are named ``<DATASET>_<GROUP>``; within each dataset the CONTROL
    shade is listed before the PD shade.
    """

    # NIM_* cohorts
    NIM_SERB_CONTROL = "#B5E48C"
    NIM_SERB_PD = "#76C893"
    NIM_METAL_CONTROL = "#34A0A4"
    NIM_METAL_PD = "#1E6091"
    NIM_YLO_CONTROL = "#003566"
    NIM_YLO_PD = "#001d3d"

    # QPN cohort
    QPN_CONTROL = "#ffb627"
    QPN_PD = "#e2711d"

    # Calgary cohort
    CALGARY_CONTROL = "#FF758F"
    CALGARY_PD = "#C9184A"

    

# Two-color palette used for the group (hue) dimension of the plots.
group_color_list = ["#ffb627", "#e2711d"]
group_palette = sns.color_palette(palette=group_color_list)

# Flat list with the color of every dataset/group combination; within each
# dataset the control color precedes the PD color.
ds_color_list = [
    member.value
    for member in (
        ds_colors.NIM_SERB_CONTROL, ds_colors.NIM_SERB_PD,
        ds_colors.NIM_METAL_CONTROL, ds_colors.NIM_METAL_PD,
        ds_colors.NIM_YLO_CONTROL, ds_colors.NIM_YLO_PD,
        ds_colors.QPN_CONTROL, ds_colors.QPN_PD,
        ds_colors.CALGARY_CONTROL, ds_colors.CALGARY_PD,
    )
]

# Palette built from the two QPN colors (PD first, then control).
ds_palette = sns.color_palette(
    palette=[ds_colors.QPN_PD.value, ds_colors.QPN_CONTROL.value]
)

# Preview the group palette as a strip of color swatches.
sns.palplot(group_palette)

In [None]:
# Prepare and plot group counts and age distributions per dataset


def _show_plot_error(ax, what, err):
    """Blank out *ax* and show a centered message that *what* could not be plotted."""
    # Axes-fraction coordinates keep the message centered in the panel no
    # matter what data limits the failed plot call left behind.
    ax.text(0.5, 0.5, f'Could not plot {what}:\n{err}', ha='center', va='center',
            transform=ax.transAxes)
    ax.set_axis_off()


# Work on a copy so the plotting-only columns do not leak into demographics_df.
try:
    df = demographics_df.copy()
except NameError as err:
    raise NameError("demographics_df not found. Run the earlier cells that build demographics_df before running this cell.") from err
df["ds_group"] = df["dataset"] + df["group"]

# Ensure age column exists and is numeric
if 'age' not in df.columns:
    print("No 'age' column found. Age plots will be empty.")
    df['age'] = np.nan
else:
    df['age'] = pd.to_numeric(df['age'], errors='coerce')

# Plot the datasets in the order in which they are listed in ds_list
dataset_order = ds_list

sns.set_style(style='whitegrid')

# Both panels put the datasets on the y axis, so that axis is shared.
fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)

# Left: number of participants per dataset, split by group
ax = axes[0]
try:
    sns.countplot(data=df, y='dataset', hue='group', order=dataset_order, ax=ax, palette=group_palette)
    ax.set_title('Group counts per dataset', fontsize=24)
    ax.set_ylabel('Dataset', fontsize=18)
    ax.set_xlabel('Count', fontsize=18)
    ax.legend(title='Group')
    for label in ax.get_yticklabels():
        label.set_rotation(45)
except Exception as e:
    # Best effort: show the failure inside the panel instead of aborting the figure.
    _show_plot_error(ax, 'group counts', e)

# Restrict despine to this panel; without `ax` it acts on every axes of the
# figure, so the later call for the right panel would also strip this
# panel's left spine.
sns.despine(ax=ax, top=True, right=True, left=False, bottom=False)

# Right: age distribution per dataset, one half-violin per group
ax = axes[1]
try:
    sns.violinplot(data=df, y='dataset', x='age', hue='group', order=dataset_order, ax=ax, palette=group_palette, legend=False, split=True)
    ax.set_title('Age distribution by dataset and group', fontsize=24)
    ax.set_ylabel('Dataset', fontsize=18)
    ax.set_xlabel('Age', fontsize=18)
    # legend=False asks seaborn not to draw a legend. If labeled artists are
    # present anyway, show one entry per group only.
    handles, labels = ax.get_legend_handles_labels()
    n_groups = df['group'].nunique()  # ignores missing group values
    if handles:
        ax.legend(handles[:n_groups], labels[:n_groups], title='group')
    for label in ax.get_xticklabels():
        label.set_rotation(45)
except Exception as e:
    # Best effort: show the failure inside the panel instead of aborting the figure.
    _show_plot_error(ax, 'age distribution', e)

# The right panel shares its y axis with the left one, so drop its left spine.
sns.despine(ax=ax, top=True, right=True, left=True, bottom=False)
plt.tight_layout()
plt.show()