## Compute and plot cerebellar volumes


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from pathlib import Path
import json

### Load data (paths are specified in a local config file)

In [None]:
local_config_f = Path('../../local_config.json')
if not local_config_f.exists():
    # Every later cell reads `local_config`; stop here with an actionable
    # message instead of failing below with a NameError.
    raise FileNotFoundError(
        'Specify a local_config.json with path to nipoppy DATASET_DIR '
        f'(expected at {local_config_f.resolve()})'
    )

with open(local_config_f) as f:
    local_config = json.load(f)

print('local_config:', local_config)

# Diagnosis colours as a positional list: PD first, then control.
# Plots using this palette must order their hue levels the same way.
dx_color_palette = local_config['plot_styles']['DX_COLOR_PALETTE']
palette = [dx_color_palette["PD"], dx_color_palette["control"]]

sns.palplot(palette)

In [None]:
dataset_dir = local_config['DATASET_DIR']
current_release = local_config['DATASET_RELEASE']

pipeline = "maget_brain"
pipeline_version = "1.0.0"
session = "ses-01"

# Directory variables below end with "/", so names are appended to them
# directly (no extra separator) to avoid "//" in the resulting paths.

# Current nipoppy manifest
manifest_csv = f"{dataset_dir}/manifest.csv"

# tabular data
tabular_dir = f"{dataset_dir}/tabular/"

# demographics
demographics_csv = f"{tabular_dir}demographics.csv"

# Dx
dx_csv = f"{tabular_dir}assessments/diagnosis.csv"

mri_sessions_csv = f"{tabular_dir}mri_info/mri_sessions.csv"

# derivative data
derivatives_dir = f"{dataset_dir}/derivatives/"

# IDPs
idp_dir = f"{derivatives_dir}{pipeline}/{pipeline_version}/idp/"

CB_vol_csv = f"{idp_dir}{session}/cerebellar_volumes.csv"

# save dirs (figures are grouped by the session analysed above)
results_dir = f"{dataset_dir}/results/"
figs_dir = f"{results_dir}{session}/cerebellum/figs/"

print('results_dir:', results_dir)
print('figs_dir:', figs_dir)

### Dx data

In [None]:
# Read the diagnosis table and keep only the baseline REDCap event.
dx_df = pd.read_csv(dx_csv)
dx_df = dx_df.loc[dx_df["redcap_event_name"] == "Baseline (Arm 1: C-OPN)"]

# Unique participant IDs for each diagnosis group.
control_participants = dx_df.loc[
    dx_df["diagnosis_group_for_analysis"] == "control", "participant_id"
].unique()
PD_participants = dx_df.loc[
    dx_df["diagnosis_group_for_analysis"] == "PD", "participant_id"
].unique()

# Controls first, then PD (a participant listed in both groups is counted twice).
all_participants = [*control_participants, *PD_participants]

print(f"PD + control: {len(all_participants)}")
print(f"Control: {len(control_participants)}")
print(f"PD: {len(PD_participants)}")

dx_df.head()

In [None]:
def merge_CB_dx_df(CB_df, dx_df, cerebellar_lobules):
    """Merge cerebellar volumes with demographics and/or diagnosis, then stack hemispheres.

    Parameters
    ----------
    CB_df : pandas.DataFrame
        Cerebellar volume table with per-hemisphere columns ``L_<lobule>`` and
        ``R_<lobule>``. If it has a ``Subject`` column, ``bids_id`` and
        ``participant_id`` are derived from it and ``Subject`` is dropped;
        otherwise it must already contain ``participant_id``. Not modified.
    dx_df : pandas.DataFrame
        Demographics/diagnosis table with a ``participant_id`` column.
    cerebellar_lobules : list of str
        Lobule names without the ``L_``/``R_`` hemisphere prefix.

    Returns
    -------
    pandas.DataFrame
        Left rows followed by right rows. Each row has the non-lobule columns,
        one un-prefixed column per lobule, ``hemi`` ("left"/"right") and
        ``total_CB_vol`` (sum of that hemisphere's lobule volumes).
        Only participants present in both inputs are kept (inner join).
    """
    left_lobules = [f"L_{lobule}" for lobule in cerebellar_lobules]
    right_lobules = [f"R_{lobule}" for lobule in cerebellar_lobules]
    lh_lobule_dict = dict(zip(left_lobules, cerebellar_lobules))
    rh_lobule_dict = dict(zip(right_lobules, cerebellar_lobules))

    # Work on a copy so the caller's DataFrame does not gain new columns.
    CB_df = CB_df.copy()
    if "Subject" in CB_df.columns:
        # Assumes values like "<dir>/sub-<label>_<...>" (TODO confirm format):
        # take the basename, cut at the first "_" for the BIDS id, then drop
        # the "sub-" entity. ``.str[-1]`` also handles values without "/".
        basename = CB_df["Subject"].str.rsplit("/", n=1).str[-1]
        CB_df["bids_id"] = basename.str.split("_", n=1).str[0]
        CB_df["participant_id"] = CB_df["bids_id"].str.split("-", n=1).str[1]
        CB_df = CB_df.drop(columns=["Subject"])

    # merge with demo/diagnosis
    CB_demo_df = pd.merge(CB_df, dx_df, on="participant_id", how="inner")

    # Columns shared by both hemispheres, kept in their original order
    # (a set difference would make the column order non-deterministic).
    hemi_lobule_cols = set(left_lobules) | set(right_lobules)
    demo_cols = [col for col in CB_demo_df.columns if col not in hemi_lobule_cols]

    hemi_dfs = []
    for hemi, lobules, rename_dict in (
        ("left", left_lobules, lh_lobule_dict),
        ("right", right_lobules, rh_lobule_dict),
    ):
        hemi_df = CB_demo_df[demo_cols + lobules].copy()
        hemi_df["hemi"] = hemi
        # add total CB vol column (this hemisphere only)
        hemi_df["total_CB_vol"] = hemi_df[lobules].sum(axis=1)
        # strip the hemisphere prefix so both halves stack vertically
        hemi_dfs.append(hemi_df.rename(columns=rename_dict))

    return pd.concat(hemi_dfs, axis=0)

### Merge cerebellar volumes with diagnosis data

In [None]:
cerebellar_lobules = [
    'I_II', 'III', 'IV', 'V', 'VI', 'Crus_I', 'Crus_II',
    'VIIB', 'VIIIA', 'VIIIB', 'IX', 'X', 'CM',
]

# Raw cerebellar volume table; count unique entries of the subject key
# ("Subject" when the file has it, otherwise "participant_id").
CB_vols_df = pd.read_csv(f"{CB_vol_csv}")
n_CB_participants = CB_vols_df[
    "Subject" if "Subject" in CB_vols_df.columns else "participant_id"
].nunique()

print(f"n_CB_participants: {n_CB_participants}")

startification_col = "diagnosis_group_for_analysis"
demo_cols = ["participant_id", startification_col]

# Restrict the diagnosis table to the two analysis groups and the needed columns.
dx_df = dx_df.loc[dx_df[startification_col].isin(["control", "PD"]), demo_cols]
participants_per_group = dx_df.groupby([startification_col])["participant_id"].nunique()
print(f"participants per group: {participants_per_group}")

CB_demo_vols_df = merge_CB_dx_df(CB_vols_df, dx_df[demo_cols], cerebellar_lobules)
CB_demo_vols_df.head()

### Remove outliers
Outlier thresholds are structure-specific (they need to be QCed visually).

In [None]:
remove_outliers = True

# Structure-specific volume bounds (should be QCed visually). A row -- one
# participant-hemisphere from the merged table -- is dropped when any ROI lies
# outside its open (min, max) interval; the other hemisphere's row is kept.
# NOTE(review): the 14th bound pair (32000 / 80000) is assumed to target the
# per-hemisphere total_CB_vol column, the only volume column besides the
# lobules -- confirm. Pairing it explicitly stops zip() from dropping it.
outlier_rois = cerebellar_lobules + ["total_CB_vol"]
min_vol_thresh_list = [25, 400, 900, 2000, 4500, 6000, 4000, 2000, 2000, 1500, 1500, 200, 4000, 32000]
max_vol_thresh_list = [200, 1300, 3500, 6000, 11000, 18000, 13000, 7000, 8000, 5000, 6000, 800, 13000, 80000]

# zip() truncates silently, so make sure every ROI has exactly one bound of each kind.
if not len(outlier_rois) == len(min_vol_thresh_list) == len(max_vol_thresh_list):
    raise ValueError(
        f"Expected {len(outlier_rois)} thresholds per bound, got "
        f"{len(min_vol_thresh_list)} min and {len(max_vol_thresh_list)} max"
    )

outlier_min_thesh_dict = dict(zip(outlier_rois, min_vol_thresh_list))
outlier_max_thesh_dict = dict(zip(outlier_rois, max_vol_thresh_list))

if remove_outliers:

    print("Removing outliers")
    for roi, thresh in outlier_min_thesh_dict.items():
        n_participants = CB_demo_vols_df["participant_id"].nunique()
        print(f"roi: {roi}, n_participants: {n_participants}")
        CB_demo_vols_df = CB_demo_vols_df[CB_demo_vols_df[roi] > thresh].copy()
        n_participants = CB_demo_vols_df["participant_id"].nunique()
        print(f"n_participants after outlier removal: {n_participants}")

    for roi, thresh in outlier_max_thesh_dict.items():
        n_participants = CB_demo_vols_df["participant_id"].nunique()
        print(f"roi: {roi}, n_participants: {n_participants}")
        CB_demo_vols_df = CB_demo_vols_df[CB_demo_vols_df[roi] < thresh].copy()
        n_participants = CB_demo_vols_df["participant_id"].nunique()
        print(f"n_participants after outlier removal: {n_participants}")


### Plots

In [None]:
save_fig = True

# Long format: every non-id column of the merged table (the lobules and
# total_CB_vol) becomes one ROI panel.
CB_vol_df_melt = CB_demo_vols_df.melt(
    id_vars=demo_cols + ["bids_id", "hemi"],
    var_name="ROI",
    value_name="volume",
)

# rename() returns a new frame, so no explicit copy is needed.
plot_df = CB_vol_df_melt.rename(columns={"diagnosis_group_for_analysis": "group"})

n_participants = plot_df["participant_id"].nunique()
print(f"n_participants: {n_participants}")

# `palette` is a positional list (PD colour, then control colour). Without an
# explicit hue_order seaborn uses order of appearance in the data, which can
# swap the two groups' colours.
group_order = ["PD", "control"]

sns.set_theme(font_scale=4)
with sns.axes_style("whitegrid"):
    g = sns.catplot(
        data=plot_df,
        y="volume",
        x="hemi",
        hue="group",
        hue_order=group_order,
        col="ROI",
        kind="box",
        col_wrap=7,
        palette=palette,
        aspect=1,
        height=10,
        sharey=False,
    )

if save_fig:
    # savefig does not create missing directories.
    Path(figs_dir).mkdir(parents=True, exist_ok=True)
    g.savefig(f"{figs_dir}/CB_vol.png")