# Load in data and configure environment

In [1]:
import glob
import json
import os
from os.path import join as ospj

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

# Make Georgia the font for all plots
plt.rcParams['font.family'] = 'Georgia'

# Specify WM and GM atlases
wm_atlas = "HCP1065"
gm_atlas = "4S156"

# Root of the project tree; every path below is derived from it
project_dir = "/mnt/sauce/littlab/users/mjaskir/structural_tractometry"
metadata_dir = f"{project_dir}/data/metadata"

# Define directory to GAM outputs
gam_outputs_wm_dir = f"{project_dir}/derivatives/gam/outputs/tracts/{wm_atlas}"
gam_outputs_gm_dir = f"{project_dir}/derivatives/gam/outputs/regions/{gm_atlas}"

# Define directory to save z-score summaries
# (the trailing slash is kept so the string is identical to the previous hard-coded value)
gam_outputs_group_summaries_dir = f"{project_dir}/derivatives/gam/outputs/"
# exist_ok avoids the separate existence check and is a no-op when the directory is already there
os.makedirs(gam_outputs_group_summaries_dir, exist_ok=True)

# Measures files
measures_json_path = f"{metadata_dir}/scalar_labels_to_filenames.json"
colors_json_path = f"{metadata_dir}/scalar_labels_to_colors.json"
human_labels_json_path = f"{metadata_dir}/scalar_labels_to_human.json"

# Load clinical data
clinical_file = f"{project_dir}/derivatives/metadata/clinical_penn_epilepsy_qsirecon.csv"
clinical_df = pd.read_csv(clinical_file)

## WM ROIs and metadata

In [2]:
# Load wm metadata (Fields: label, name, hemi, type)
wm_metadata_file = "/mnt/sauce/littlab/users/mjaskir/structural_tractometry/data/atlases/HCP1065/HCP1065_tract_metadata.csv"
wm_metadata_df = pd.read_csv(wm_metadata_file)

# Remove cranial nerve tracts
wm_metadata_df = wm_metadata_df[~wm_metadata_df["label"].str.startswith("CN")]

# Remove any bilateral tracts
wm_metadata_df = wm_metadata_df[wm_metadata_df["hemi"] != "bilateral"]

# Remove any cerebellar tracts
# wm_metadata_df = wm_metadata_df[wm_metadata_df["type"] != "cerebellar"]

# wm_metadata_df and wm_rois are module-level names, so later cells and functions
# can read them directly; no `global` statement is needed (or legal) here.
wm_rois = wm_metadata_df["label"].tolist()

## GM ROIs and metadata

In [3]:
# Load gm metadata (Fields: index, label, network_label, network_label_17network)
# Reminder: subcortical, thalamic, and cerebellar regions do not have network assignments
gm_metadata_file = "/mnt/sauce/littlab/users/mjaskir/structural_tractometry/data/atlases/4S/atlas-4S156Parcels_dseg.tsv"
gm_metadata_df = pd.read_csv(gm_metadata_file, sep="\t")

# Add a hemi column inferred from the label prefix: "LH" -> left, "RH" -> right, anything else -> bilateral
gm_metadata_df["hemi"] = gm_metadata_df["label"].apply(
    lambda label: "left" if label.startswith("LH")
    else "right" if label.startswith("RH")
    else "bilateral"
)

# Remove cerebellar ROIs from gm_rois
# gm_metadata_df = gm_metadata_df[~gm_metadata_df["label"].str.startswith("Cerebellar")]

# gm_metadata_df and gm_rois are module-level names, so later cells and functions
# can read them directly; no `global` statement is needed (or legal) here.
gm_rois = gm_metadata_df["label"].tolist()

# Define epilepsy subgroups

In [4]:
def _subs_where(mask):
    """Return the `sub` IDs of the `clinical_df` rows selected by the boolean `mask`."""
    return clinical_df.loc[mask, "sub"].tolist()


def _print_group_sizes(title, labelled_groups):
    """Print `title`, one indented `Label (n=size)` line per (label, members) pair, then a blank line."""
    print(title)
    for label, members in labelled_groups:
        print(f"    {label} (n={len(members)})")
    print("")


# ILAE outcome classes treated as a good outcome throughout this cell
good_ilae_classes = ["1a", "1b", "2"]

# Options: left, right, bilateral, left > right, right > left, generalized, inconclusive
_lateralization = clinical_df["seizure_lateralization"]
subs_seizure_lateralization = clinical_df[_lateralization.notna()]
subs_seizure_lateralization_L = _subs_where(_lateralization == "left")
subs_seizure_lateralization_R = _subs_where(_lateralization == "right")
subs_seizure_lateralization_B = _subs_where(_lateralization == "bilateral")
subs_seizure_lateralization_L_R = _subs_where(_lateralization == "left > right")
subs_seizure_lateralization_R_L = _subs_where(_lateralization == "right > left")
subs_seizure_lateralization_G = _subs_where(_lateralization == "generalized")
subs_seizure_lateralization_I = _subs_where(_lateralization == "inconclusive")
_print_group_sizes("Seizure lateralization:", [
    ("Left", subs_seizure_lateralization_L),
    ("Right", subs_seizure_lateralization_R),
    ("Bilateral", subs_seizure_lateralization_B),
    ("Left > Right", subs_seizure_lateralization_L_R),
    ("Right > Left", subs_seizure_lateralization_R_L),
    ("Generalized", subs_seizure_lateralization_G),
    ("Inconclusive", subs_seizure_lateralization_I),
])

# Options: frontal, temporal, parietal, occipital, generalized, central, multifocal, nonlocalizable, insular
_localization = clinical_df["seizure_localization"]
subs_seizure_localization = clinical_df[_localization.notna()]
subs_seizure_localization_F = _subs_where(_localization == "frontal")
subs_seizure_localization_T = _subs_where(_localization == "temporal")
subs_seizure_localization_P = _subs_where(_localization == "parietal")
subs_seizure_localization_O = _subs_where(_localization == "occipital")
subs_seizure_localization_G = _subs_where(_localization == "generalized")
subs_seizure_localization_C = _subs_where(_localization == "central")
subs_seizure_localization_M = _subs_where(_localization == "multifocal")
subs_seizure_localization_N = _subs_where(_localization == "nonlocalizable")
subs_seizure_localization_I = _subs_where(_localization == "insular")
_print_group_sizes("Seizure localization:", [
    ("Frontal", subs_seizure_localization_F),
    ("Temporal", subs_seizure_localization_T),
    ("Parietal", subs_seizure_localization_P),
    ("Occipital", subs_seizure_localization_O),
    ("Generalized", subs_seizure_localization_G),
    ("Central", subs_seizure_localization_C),
    ("Multifocal", subs_seizure_localization_M),
    ("Nonlocalizable", subs_seizure_localization_N),
    ("Insular", subs_seizure_localization_I),
])

# Strictly unilateral temporal / frontal onset (lobe and side must both match)
subs_seizure_temporal_L = _subs_where((_localization == "temporal") & (_lateralization == "left"))
subs_seizure_temporal_R = _subs_where((_localization == "temporal") & (_lateralization == "right"))
subs_seizure_frontal_L = _subs_where((_localization == "frontal") & (_lateralization == "left"))
subs_seizure_frontal_R = _subs_where((_localization == "frontal") & (_lateralization == "right"))
_print_group_sizes("Seizure lateralization and localization", [
    ("Left temporal", subs_seizure_temporal_L),
    ("Right temporal", subs_seizure_temporal_R),
    ("Left frontal", subs_seizure_frontal_L),
    ("Right frontal", subs_seizure_frontal_R),
])

# Patients with an intervention_laterality value, split by side
_intervention_side = clinical_df["intervention_laterality"]
subs_resection_or_ablation = clinical_df[_intervention_side.notna()]
subs_resection_or_ablation_L = _subs_where(_intervention_side == "left")
subs_resection_or_ablation_R = _subs_where(_intervention_side == "right")
_print_group_sizes("Patients receiving resection or ablation:", [
    ("Left", subs_resection_or_ablation_L),
    ("Right", subs_resection_or_ablation_R),
])

# Good outcome at 1 year: ilae_category_pecclinical is ILAE class 2 or better (1a, 1b, 2), split by intervention side
_good_1yr = clinical_df["ilae_category_pecclinical"].isin(good_ilae_classes)
subs_resection_or_ablation_good_outcomes_1yr = clinical_df[_good_1yr]
subs_resection_or_ablation_good_outcomes_1yr_L = _subs_where(_good_1yr & (_intervention_side == "left"))
subs_resection_or_ablation_good_outcomes_1yr_R = _subs_where(_good_1yr & (_intervention_side == "right"))
_print_group_sizes("Good outcomes at 1 year:", [
    ("Left", subs_resection_or_ablation_good_outcomes_1yr_L),
    ("Right", subs_resection_or_ablation_good_outcomes_1yr_R),
])

# Good outcome at 2 years: ilae_category_2_pecclinical is ILAE class 2 or better (1a, 1b, 2), split by intervention side
_good_2yr = clinical_df["ilae_category_2_pecclinical"].isin(good_ilae_classes)
subs_resection_or_ablation_good_outcomes_2yr = clinical_df[_good_2yr]
subs_resection_or_ablation_good_outcomes_2yr_L = _subs_where(_good_2yr & (_intervention_side == "left"))
subs_resection_or_ablation_good_outcomes_2yr_R = _subs_where(_good_2yr & (_intervention_side == "right"))
_print_group_sizes("Good outcomes at 2 years:", [
    ("Left", subs_resection_or_ablation_good_outcomes_2yr_L),
    ("Right", subs_resection_or_ablation_good_outcomes_2yr_R),
])

# Temporal intervention with good outcomes at 1 year
_temporal_intervention = clinical_df["intervention_lobe"] == "temporal"
subs_temporal_intervention_good_outcomes_1yr = clinical_df[_temporal_intervention & _good_1yr]
subs_temporal_intervention_good_outcomes_1yr_L = _subs_where(
    _temporal_intervention & _good_1yr & (_intervention_side == "left")
)
subs_temporal_intervention_good_outcomes_1yr_R = _subs_where(
    _temporal_intervention & _good_1yr & (_intervention_side == "right")
)
_print_group_sizes("Temporal intervention with good outcomes at 1 year:", [
    ("Left", subs_temporal_intervention_good_outcomes_1yr_L),
    ("Right", subs_temporal_intervention_good_outcomes_1yr_R),
])

# Temporal intervention with good outcomes at 2 years
subs_temporal_intervention_good_outcomes_2yr = clinical_df[_temporal_intervention & _good_2yr]
subs_temporal_intervention_good_outcomes_2yr_L = _subs_where(
    _temporal_intervention & _good_2yr & (_intervention_side == "left")
)
subs_temporal_intervention_good_outcomes_2yr_R = _subs_where(
    _temporal_intervention & _good_2yr & (_intervention_side == "right")
)
_print_group_sizes("Temporal intervention with good outcomes at 2 years:", [
    ("Left", subs_temporal_intervention_good_outcomes_2yr_L),
    ("Right", subs_temporal_intervention_good_outcomes_2yr_R),
])

# Neuromodulation (RNS, VNS or DBS); subs_neuromodulation is the list of subject IDs
subs_neuromodulation = _subs_where(clinical_df["intervention_type"].isin(["RNS", "VNS", "DBS"]))
print("Neuromodulation:")
print(f"    (n={len(subs_neuromodulation)})")
print("")

Seizure lateralization:
    Left (n=47)
    Right (n=30)
    Bilateral (n=8)
    Left > Right (n=8)
    Right > Left (n=4)
    Generalized (n=0)
    Inconclusive (n=10)

Seizure localization:
    Frontal (n=7)
    Temporal (n=77)
    Parietal (n=0)
    Occipital (n=0)
    Generalized (n=0)
    Central (n=2)
    Multifocal (n=7)
    Nonlocalizable (n=7)
    Insular (n=0)

Seizure lateralization and localization
    Left temporal (n=36)
    Right temporal (n=24)
    Left frontal (n=2)
    Right frontal (n=3)

Patients receiving resection or ablation:
    Left (n=18)
    Right (n=15)

Good outcomes at 1 year:
    Left (n=9)
    Right (n=5)

Good outcomes at 2 years:
    Left (n=8)
    Right (n=3)

Temporal intervention with good outcomes at 1 year:
    Left (n=7)
    Right (n=5)

Temporal intervention with good outcomes at 2 years:
    Left (n=7)
    Right (n=3)

Neuromodulation:
    (n=13)



# Specify ROIs, stats, and measures to plot

In [5]:
def get_input_specs(input_data_type="all"):
    """
    Returns roi_dict, stats, and measures based on the input_data_type.

    Args:
        input_data_type (str): One of
            "all"     - every WM and GM ROI, every measure listed in the measures JSON
            "wm"      - every WM ROI, every measure
            "gm"      - every GM ROI, every measure
            "test"    - a small fixed set of WM and GM ROIs, measure "dti_md" only
            "test_wm" - the small WM set only
            "test_gm" - the small GM set only

    Returns:
        tuple: (roi_dict, stats, measures) where roi_dict maps "wm"/"gm" to
        {"dir": <GAM output directory>, "rois": <list of ROI labels>},
        stats is ["mean"], and measures is a list of measure labels.

    Raises:
        ValueError: If input_data_type is not one of the options above.
    """
    # input_data_type -> (ROI types in output order, whether to use the reduced test spec)
    layouts = {
        "all": (("wm", "gm"), False),
        "gm": (("gm",), False),
        "wm": (("wm",), False),
        "test": (("wm", "gm"), True),
        "test_gm": (("gm",), True),
        "test_wm": (("wm",), True),
    }
    if input_data_type not in layouts:
        raise ValueError(f"Unknown input_data_type: {input_data_type}")
    roi_types, use_test_spec = layouts[input_data_type]

    test_wm_rois = ["F_L", "F_R", "UF_L", "UF_R", "C_PH_L", "C_PH_R", "ILF_L", "ILF_R"]
    if input_data_type == "test":
        # NOTE(review): "LH-NAC" (hyphen) is inconsistent with "RH_NAC" in this list and with
        # "LH_NAC" in the "test_gm" spec; kept as-is — confirm the spelling against the atlas labels.
        test_gm_rois = ["LH_Hippocampus", "RH_Hippocampus", "LH_Amygdala", "RH_Amygdala", "LH-NAC", "RH_NAC"]
    else:
        test_gm_rois = ["LH_Hippocampus", "RH_Hippocampus", "LH_Amygdala", "RH_Amygdala", "LH_NAC", "RH_NAC"]

    def build_entry(roi_type):
        # Globals are looked up only for the ROI types that were actually requested
        if roi_type == "wm":
            return {"dir": gam_outputs_wm_dir, "rois": test_wm_rois if use_test_spec else wm_rois}
        return {"dir": gam_outputs_gm_dir, "rois": test_gm_rois if use_test_spec else gm_rois}

    roi_dict = {roi_type: build_entry(roi_type) for roi_type in roi_types}
    stats = ["mean"]
    if use_test_spec:
        measures = ["dti_md"]
    else:
        # Context manager so the JSON file handle is closed deterministically
        with open(measures_json_path) as f:
            measures = list(json.load(f).keys())
    return roi_dict, stats, measures

# Plot age and sex distributions by group

In [6]:
def plot_age_and_sex_histograms_by_group(df, groups=None, bins=30, alpha=0.8):
    """
    Draw one row per group: an age histogram (left) and a male/female count bar plot (right).

    Column titles ('Age distributions', 'Sex distributions') and a bold per-row group title
    with the group size are drawn as figure-level text above the axes.

    Args:
        df (pd.DataFrame): Must contain 'group', 'age' and 'sex' columns. 'sex' values are
            upper-cased and "MALE"/"FEMALE" are mapped to "M"/"F" before counting.
        groups (list or None): Groups to plot, in row order. If None, the known groups present
            in df are plotted in a fixed order, followed by any other groups found in df.
        bins (int): Number of age bins; bin edges are shared by all rows and span df's age range.
        alpha (float): Opacity of the histogram and bar faces.

    Raises:
        ValueError: If there are no groups to plot.
    """
    if groups is None:
        default_order = ["hcpya", "hcpaging", "penn_controls", "penn_epilepsy"]
        present_groups = df['group'].unique()
        groups = [g for g in default_order if g in present_groups]
        groups += [g for g in present_groups if g not in groups]
    if len(groups) == 0:
        # plt.subplots cannot create a grid with zero rows; fail with a clear message instead
        raise ValueError("No groups to plot: `groups` is empty or df has no 'group' values.")

    group_colors = {
        "hcpya": "#1f77b4",         # blue
        "hcpaging": "#2ca02c",      # green
        "penn_controls": "#ff7f0e", # orange
        "penn_epilepsy": "#d62728"  # red
    }
    group_display = {
        "hcpya": "HCP-Young-Adults",
        "hcpaging": "HCP-Aging",
        "penn_controls": "Penn-Controls",
        "penn_epilepsy": "Penn-Epilepsy"
    }
    sex_colors = {
        "M": "#377eb8",  # blue
        "F": "#e41a1c"   # red
    }
    sex_display = {
        "M": "Male",
        "F": "Female"
    }

    # Shared bin edges so the histograms of all groups are directly comparable
    bin_edges = np.linspace(df['age'].min(), df['age'].max(), bins + 1)

    n_groups = len(groups)
    # Keep the figure compact, but allow enough vertical space between rows
    fig_height = max(2.0 * n_groups + 0.7, 3.5)
    fig_width = 7.5
    fig, axes = plt.subplots(
        n_groups, 2,
        figsize=(fig_width, fig_height),
        gridspec_kw={'width_ratios': [2.0, 1], 'wspace': 0.22, 'hspace': 0.55},
        sharex='col',
        squeeze=False,  # always a 2-D (n_groups, 2) array, even for a single group
    )
    # The layout is fixed explicitly here. tight_layout() is deliberately not called afterwards:
    # it would move the axes after the figure-level titles below were positioned from them.
    plt.subplots_adjust(left=0.10, right=0.98, top=0.88, bottom=0.13)

    # Column titles above the top row, outside the axes
    col_titles = ["Age distributions", "Sex distributions"]
    for j, col_title in enumerate(col_titles):
        top_bbox = axes[0, j].get_position()
        fig.text(
            top_bbox.x0 + top_bbox.width / 2,
            top_bbox.y1 + 0.06,
            col_title,
            ha='center', va='bottom',
            fontsize=15, color="#222222"
        )

    for i, group in enumerate(groups):
        age_ax, sex_ax = axes[i, 0], axes[i, 1]
        group_df = df[df['group'] == group]
        n = len(group_df)

        # Group title centred over both columns, just above the row
        left_bbox = age_ax.get_position()
        right_bbox = sex_ax.get_position()
        # The name is rendered bold through mathtext, where "_" starts a subscript and plain
        # spaces are dropped; escape both so fallback (raw) group names render literally.
        title_name = str(group_display.get(group, group)).replace("_", r"\_").replace(" ", r"\ ")
        fig.text(
            (left_bbox.x0 + right_bbox.x1) / 2,
            left_bbox.y1 + 0.0125,
            r"$\bf{" + title_name + "}$" + f"   n = {n}",
            ha='center', va='bottom',
            fontsize=15, color="#222222"
        )

        # --- Age histogram (left column) ---
        counts, _, _ = age_ax.hist(
            group_df['age'],
            bins=bin_edges,
            alpha=alpha,
            color=group_colors.get(group, "#888888"),
            edgecolor='black'
        )
        age_ax.set_ylabel("N", fontsize=12)
        max_count = int(np.max(counts)) if len(counts) > 0 else 1
        age_ax.yaxis.set_major_locator(mticker.MaxNLocator(integer=True))
        # Leave at least one count of headroom above the tallest bin
        age_ax.set_ylim(0, max(max_count + 1, age_ax.get_ylim()[1]))
        age_ax.tick_params(axis='both', which='major', labelsize=10)
        age_ax.grid(axis='y', linestyle=':', alpha=0.4)

        # --- Sex bar plot (right column) ---
        sex_col = group_df['sex'].astype(str).str.upper().replace({"MALE": "M", "FEMALE": "F"})
        sex_counts = sex_col.value_counts()
        bar_labels = [sex_display["M"], sex_display["F"]]
        bar_heights = [sex_counts.get("M", 0), sex_counts.get("F", 0)]
        bar_colors = [sex_colors["M"], sex_colors["F"]]
        bars = sex_ax.bar(
            bar_labels, bar_heights, color=bar_colors, edgecolor='black', alpha=alpha, width=0.6
        )
        sex_ax.set_ylabel("N", fontsize=12)
        sex_ax.yaxis.set_major_locator(mticker.MaxNLocator(integer=True))
        max_bar_height = max(bar_heights + [1, sex_ax.get_ylim()[1]])
        # 18% headroom so the count annotations fit above the bars
        sex_ax.set_ylim(0, max_bar_height * 1.18)
        sex_ax.tick_params(axis='both', which='major', labelsize=10)
        sex_ax.set_xticks([0, 1])
        sex_ax.set_xticklabels(bar_labels, fontsize=11)
        sex_ax.grid(axis='y', linestyle=':', alpha=0.4)
        # Annotate counts on top of bars
        for bar, count in zip(bars, bar_heights):
            sex_ax.annotate(
                str(count),
                xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                xytext=(0, 4), textcoords="offset points",
                ha='center', va='bottom', fontsize=10, fontweight='bold', color='black'
            )

        # Remove top/right spines for a cleaner look
        for ax in (age_ax, sex_ax):
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

    axes[-1, 0].set_xlabel("Age", fontsize=13)
    for ax in axes[:, 1]:
        ax.set_xlabel("")
    plt.show()

# Example usage: load one GAM output table as input for the histogram function.
# The table is read when this cell runs; uncomment the call below to draw the figure.
example_df = pd.read_csv("/mnt/sauce/littlab/users/mjaskir/structural_tractometry/derivatives/gam/outputs/regions/4S156/Cerebellar_Region1/Cerebellar_Region1_mean_dki_ad_gam.csv")
# plot_age_and_sex_histograms_by_group(example_df)


# Plot observations by group

In [7]:
def plot_gam_by_group(roi_dict, stats, measures, groups=None, highlight_epilepsy=False, z_abs_thr=None):
    """
    Plot normative model predictions and patient observations for the specified ROIs, stats, and measures.

    One figure is drawn for every (ROI, stat, measure) combination whose GAM output CSV exists,
    with a female panel on the left and a male panel on the right (shared y-axis). Combinations
    whose file is missing are reported with a printed message and skipped.

    Args:
        roi_dict (dict): Dictionary specifying ROI types, directories, and ROI lists.
        stats (list): List of statistics to consider (e.g., ["mean"]).
        measures (list): List of measures to consider (e.g., ["dti_md"]).
        groups (list): List of group names to plot (e.g., ["hcpya", "hcpaging", "penn_controls", "penn_epilepsy"]).
            Defaults to those four groups.
        highlight_epilepsy (bool): If True, plot penn_epilepsy observations in red and all others in grey.
        z_abs_thr (float or None): Absolute z-score threshold; penn_epilepsy points whose rounded
            z-score exceeds it are labelled with "z=<value>". None (the default) disables labelling.
    """
    # Default group list if not provided
    if groups is None:
        groups = ["hcpya", "hcpaging", "penn_controls", "penn_epilepsy"]

    # Group display names and colors
    group_display = {
        "hcpya": "HCP-YA",
        "hcpaging": "HCP-Aging",
        "penn_controls": "Penn Controls",
        "penn_epilepsy": "Penn Epilepsy"
    }
    group_colors = {
        "hcpya": "#1f77b4",         # blue
        "hcpaging": "#2ca02c",      # green
        "penn_controls": "#ff7f0e", # orange
        "penn_epilepsy": "#d62728"  # red
    }

    # Both lookup tables are loop-invariant, so read each JSON file once up front
    with open(colors_json_path, "r") as f:
        scalar_labels_to_colors = json.load(f)
    with open(human_labels_json_path, "r") as f:
        scalar_labels_to_human = json.load(f)

    def get_human_readable_label(roi, roi_type):
        """Turn an ROI label into a display name with an optional ' (left)' / ' (right)' suffix."""
        if roi_type == "wm":
            # Look the tract up by its full label first, then without a trailing _L/_R
            roi_base = roi[:-2] if roi.endswith(("_L", "_R")) else roi
            row = wm_metadata_df[wm_metadata_df["label"] == roi]
            if row.empty:
                row = wm_metadata_df[wm_metadata_df["label"] == roi_base]
            if not row.empty:
                name = row.iloc[0]["name"].replace("_", " ")
                hemi = row.iloc[0]["hemi"]
                # Drop a trailing hemisphere letter from the name; the suffix below conveys it
                if name.endswith((" L", " R")):
                    name = name[:-2]
                hemi_str = f" ({hemi.lower()})" if hemi.lower() in ["left", "right"] else ""
                return f"{name}{hemi_str}"
            # Fallback when the tract is not in the metadata: prettify the raw label
            if roi.endswith("_L"):
                pretty, hemi_str = roi[:-2], " (left)"
            elif roi.endswith("_R"):
                pretty, hemi_str = roi[:-2], " (right)"
            else:
                pretty, hemi_str = roi, ""
            return f"{pretty.replace('_', ' ')}{hemi_str}"
        if roi_type == "gm":
            # Strip an RH-/RH_/LH-/LH_ prefix and remember the hemisphere
            if roi.startswith(("RH-", "RH_")):
                base, hemi = roi[3:], "right"
            elif roi.startswith(("LH-", "LH_")):
                base, hemi = roi[3:], "left"
            else:
                base, hemi = roi, ""
            base = base.replace("_", " ").replace("-", " ")
            hemi_str = f" ({hemi})" if hemi else ""
            return f"{base}{hemi_str}"
        return roi

    def plot_prediction_curve(ax, group, group_df, measure):
        """Draw the age-sorted `<measure>_pred` curve for every group except the two Penn groups."""
        if group in ["penn_epilepsy", "penn_controls"]:
            return
        if f"{measure}_pred" not in group_df.columns or group_df.empty:
            return
        sorted_df = group_df.sort_values("age")
        ax.plot(
            sorted_df["age"],
            sorted_df[f"{measure}_pred"],
            color="black",
            linewidth=2
        )

    def label_z_outliers(ax, obs_df, measure):
        """Write 'z=<value>' next to every point whose rounded |z| exceeds z_abs_thr."""
        if z_abs_thr is None or "z" not in obs_df.columns:
            return
        for _, row in obs_df.iterrows():
            if pd.notnull(row["z"]):
                z_rounded = round(row["z"], 1)
                if np.abs(z_rounded) > z_abs_thr:
                    ax.text(
                        row["age"], row[measure], f"z={z_rounded}",
                        fontsize=10, color="black", va="bottom", ha="left", fontweight="bold"
                    )

    # Set unified marker, alpha, and size for all groups
    marker = "o"
    alpha = 0.5
    size = 15

    for roi_type, roi_info in roi_dict.items():
        roi_dir = roi_info["dir"]
        rois = roi_info["rois"]
        for roi in rois:
            for stat in stats:
                for measure in measures:
                    gam_outputs_file = ospj(roi_dir, roi, f"{roi}_{stat}_{measure}_gam.csv")
                    if not os.path.exists(gam_outputs_file):
                        print(f"GAM outputs file not found for {roi_type} {roi} {stat} {measure}")
                        continue
                    gam_outputs_df = pd.read_csv(gam_outputs_file)

                    fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
                    # Panel counts cover every row in the file with that sex code
                    n_females = len(gam_outputs_df[gam_outputs_df['sex'] == 'F'])
                    n_males = len(gam_outputs_df[gam_outputs_df['sex'] == 'M'])
                    sexes = [("F", f"Females n={n_females}"), ("M", f"Males n={n_males}")]

                    # Per-figure constants: ROI title and y-axis label text/color
                    roi_label = get_human_readable_label(roi, roi_type)
                    y_color = scalar_labels_to_colors.get(measure, "#000000")
                    measure_human = scalar_labels_to_human.get(measure, measure)

                    for ax, (sex_code, sex_label) in zip(axes, sexes):
                        is_sex = gam_outputs_df["sex"] == sex_code
                        # Legend entries collected while plotting this panel
                        legend_handles = []
                        legend_labels = []

                        if highlight_epilepsy:
                            # Prediction curves first, then all non-epilepsy points in grey, then epilepsy in red
                            for group in groups:
                                group_df = gam_outputs_df[(gam_outputs_df["group"] == group) & is_sex]
                                plot_prediction_curve(ax, group, group_df, measure)

                            control_groups = [g for g in groups if g != "penn_epilepsy"]
                            non_epi_df = gam_outputs_df[gam_outputs_df["group"].isin(control_groups) & is_sex]
                            if not non_epi_df.empty:
                                scatter_non_epi = ax.scatter(
                                    non_epi_df["age"],
                                    non_epi_df[measure],
                                    color="#bbbbbb",
                                    alpha=alpha,
                                    s=size,
                                    marker=marker,
                                    label="Control Groups"
                                )
                                legend_handles.append(scatter_non_epi)
                                legend_labels.append("Control Groups")

                            epi_df = gam_outputs_df[(gam_outputs_df["group"] == "penn_epilepsy") & is_sex]
                            if not epi_df.empty:
                                scatter_epi = ax.scatter(
                                    epi_df["age"],
                                    epi_df[measure],
                                    color="#d62728",
                                    alpha=alpha,
                                    s=size,
                                    marker=marker,
                                    label="Penn Epilepsy"
                                )
                                legend_handles.append(scatter_epi)
                                legend_labels.append("Penn Epilepsy")
                                label_z_outliers(ax, epi_df, measure)
                        else:
                            for group in groups:
                                group_df = gam_outputs_df[(gam_outputs_df["group"] == group) & is_sex]
                                display_name = group_display.get(group, group)

                                plot_prediction_curve(ax, group, group_df, measure)

                                # Observed data points (same marker, alpha, and size for all groups)
                                if group_df.empty:
                                    continue
                                scatter = ax.scatter(
                                    group_df["age"],
                                    group_df[measure],
                                    color=group_colors.get(group, None),
                                    alpha=alpha,
                                    s=size,
                                    marker=marker,
                                )
                                # Add to legend only once per group per subplot
                                if display_name not in legend_labels:
                                    legend_handles.append(scatter)
                                    legend_labels.append(display_name)
                                if group == "penn_epilepsy":
                                    label_z_outliers(ax, group_df, measure)

                        ax.set_title(f"{roi_label}\n{sex_label}", fontsize=16)
                        ax.set_xlabel("Age", fontsize=16)
                        # The y-axis label is colored by measure
                        ax.set_ylabel(f"{measure_human}", color=y_color, fontsize=16)
                        ax.tick_params(axis='y', labelsize=16)
                        ax.tick_params(axis='x', labelsize=16)
                        # Pin the current y ticks and format their labels with two decimals
                        y_ticks = ax.get_yticks()
                        y_ticklabels = [f"{ytick:.2f}" for ytick in y_ticks]
                        ax.set_yticks(y_ticks)
                        ax.set_yticklabels(y_ticklabels, fontsize=14)
                        # Legend: just explain the point colors with group labels
                        ax.legend(legend_handles, legend_labels, loc="lower right", fontsize=10, title="Group", title_fontsize=12)

                    plt.tight_layout()
                    plt.show()

# Build the reduced "test" input spec; uncomment one of the calls below to plot it.
roi_dict, stats, measures = get_input_specs(input_data_type="test")
# plot_gam_by_group(roi_dict, stats, measures, groups=["hcpya", "hcpaging", "penn_controls"])
# plot_gam_by_group(roi_dict, stats, measures, groups=["hcpya", "hcpaging", "penn_controls", "penn_epilepsy"], highlight_epilepsy=True)
# plot_gam_by_group(roi_dict, stats, measures, groups=["hcpya", "hcpaging", "penn_controls", "penn_epilepsy"], highlight_epilepsy=True, z_abs_thr=5)

# Plot harmonization effects by group

In [8]:
def plot_harmonization(roi_dict, stats, measures, groups=None, scanner_ids=False):
    """
    Plot raw values before and after harmonization for the specified ROIs, stats, and measures.

    For every ROI/stat/measure whose `<roi>_<stat>_<measure>_gam.csv` file exists, a 2x2 grid of
    age scatter plots is drawn: the top row plots the `<measure>_orig` column, the bottom row plots
    the `<measure>` column, and the two columns split the rows into females ("F") and males ("M").
    All four panels share one y-axis. Missing CSV files are reported and skipped.

    Args:
        roi_dict (dict): Dictionary specifying ROI types, directories, and ROI lists
            (each value needs the keys "dir" and "rois").
        stats (list): List of statistics to consider (e.g., ["mean"]).
        measures (list): List of measures to consider (e.g., ["dti_md"]).
        groups (list): List of group names to plot (e.g., ["hcpya", "hcpaging", "penn_controls", "penn_epilepsy"]).
            Defaults to all four of those groups.
        scanner_ids (bool): If True, color by scanner IDs ("bat" variable) instead of group.

    Returns:
        None. Each figure is shown with plt.show().

    Notes:
        - The "n=" in each panel title counts only rows that belong to the plotted groups.
        - With scanner_ids=True, each scanner keeps the same color in all four panels of a figure.
    """
    # Default group list if not provided
    if groups is None:
        groups = ["hcpya", "hcpaging", "penn_controls", "penn_epilepsy"]

    # Group display names and colors
    group_display = {
        "hcpya": "HCP-YA",
        "hcpaging": "HCP-Aging",
        "penn_controls": "Penn Controls",
        "penn_epilepsy": "Penn Epilepsy"
    }
    group_colors = {
        "hcpya": "#1f77b4",         # blue
        "hcpaging": "#2ca02c",      # green
        "penn_controls": "#ff7f0e", # orange
        "penn_epilepsy": "#d62728"  # red
    }

    # Both lookup tables are loop-invariant, so each JSON file is read exactly once
    with open(colors_json_path, "r") as f:
        scalar_labels_to_colors = json.load(f)
    with open(human_labels_json_path, "r") as f:
        scalar_labels_to_human = json.load(f)

    def get_human_readable_label(roi, roi_type):
        """Turn an ROI identifier into a display name, appending "(left)"/"(right)" when known."""
        if roi_type == "wm":
            # Metadata may store the label with or without the trailing _L/_R
            roi_base = roi[:-2] if roi.endswith(("_L", "_R")) else roi
            row = wm_metadata_df[wm_metadata_df["label"] == roi]
            if row.empty:
                row = wm_metadata_df[wm_metadata_df["label"] == roi_base]
            if not row.empty:
                name = row.iloc[0]["name"].replace("_", " ")
                # str() keeps a missing hemisphere (NaN) from raising on .lower()
                hemi = str(row.iloc[0]["hemi"]).lower()
                if name.endswith((" L", " R")):
                    name = name[:-2]
                hemi_str = f" ({hemi})" if hemi in ["left", "right"] else ""
                return f"{name}{hemi_str}"
            # Fallback: prettify the raw ROI string
            pretty = roi
            if pretty.endswith("_L"):
                pretty = pretty[:-2]
                hemi_str = " (left)"
            elif pretty.endswith("_R"):
                pretty = pretty[:-2]
                hemi_str = " (right)"
            else:
                hemi_str = ""
            pretty = pretty.replace("_", " ")
            return f"{pretty}{hemi_str}"
        elif roi_type == "gm":
            # Strip an RH-/RH_/LH-/LH_ prefix and remember the hemisphere it encoded
            if roi.startswith(("RH-", "RH_")):
                base = roi[3:]
                hemi = "right"
            elif roi.startswith(("LH-", "LH_")):
                base = roi[3:]
                hemi = "left"
            else:
                base = roi
                hemi = ""
            base = base.replace("_", " ").replace("-", " ")
            hemi_str = f" ({hemi})" if hemi else ""
            return f"{base}{hemi_str}"
        else:
            return roi

    # Set unified marker, alpha, and size for all groups
    marker = "o"
    alpha = 0.5
    size = 15

    # Discrete color names for scanner IDs; the list is cycled if there are more scanners than colors
    scanner_color_names = [
        "tab:blue", "tab:orange", "tab:green", "tab:red", "tab:purple", "tab:brown", "tab:pink", "tab:gray", "tab:olive",
        "gold", "lime", "crimson", "indigo", "saddlebrown", "hotpink", "slategray", "yellowgreen", "teal"
    ]

    sexes = [("F", "Females"), ("M", "Males")]

    for roi_type, roi_info in roi_dict.items():
        roi_dir = roi_info["dir"]
        rois = roi_info["rois"]
        for roi in rois:
            for stat in stats:
                for measure in measures:
                    gam_outputs_file = ospj(roi_dir, roi, f"{roi}_{stat}_{measure}_gam.csv")
                    print(f"[{roi_type.upper()}] {gam_outputs_file}")
                    if not os.path.exists(gam_outputs_file):
                        print(f"GAM outputs file not found for {roi_type} {roi} {stat} {measure}")
                        continue
                    gam_outputs_df = pd.read_csv(gam_outputs_file)

                    fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharey=True)
                    # Row 0: original values; row 1: values stored under the measure name itself
                    y_columns = [f"{measure}_orig", measure]

                    # Get human-readable label for this ROI
                    roi_label = get_human_readable_label(roi, roi_type)

                    # Only rows from the requested groups are drawn, so only those are counted
                    plotted_df = gam_outputs_df[gam_outputs_df["group"].isin(groups)]
                    sample_sizes = {
                        sex_code: int((plotted_df["sex"] == sex_code).sum())
                        for sex_code, _ in sexes
                    }

                    # One scanner -> color map per figure (groups in the given order, IDs sorted
                    # within each group) so a scanner looks the same in all four panels
                    scanner_id_to_color = {}
                    if scanner_ids:
                        for group in groups:
                            group_scanners = plotted_df.loc[plotted_df["group"] == group, "bat"].dropna().unique()
                            for sid in sorted(group_scanners):
                                if sid not in scanner_id_to_color:
                                    color_idx = len(scanner_id_to_color) % len(scanner_color_names)
                                    scanner_id_to_color[sid] = scanner_color_names[color_idx]

                    # Y-axis label text and color depend only on the measure
                    y_color = scalar_labels_to_colors.get(measure, "#000000")
                    measure_human = scalar_labels_to_human.get(measure, measure)

                    # X ticks every 20 years, starting at 20 and covering the oldest subject in the file
                    all_ages = gam_outputs_df["age"].dropna()
                    xticks = None
                    if not all_ages.empty:
                        min_age = 20
                        max_age = int(np.ceil(all_ages.max() / 20.0) * 20)
                        xticks = np.arange(min_age, max_age + 1, 20)

                    for row_idx, y_col in enumerate(y_columns):
                        for col_idx, (sex_code, sex_label) in enumerate(sexes):
                            ax = axes[row_idx, col_idx]
                            legend_handles = []
                            legend_labels = []

                            for group in groups:
                                group_df = plotted_df[(plotted_df["group"] == group) & (plotted_df["sex"] == sex_code)]
                                # Only plot if the column exists and there is data
                                if y_col not in group_df.columns or group_df.empty:
                                    continue

                                if scanner_ids:
                                    # One scatter per scanner ID within this group
                                    for sid in sorted(group_df["bat"].dropna().unique()):
                                        sid_df = group_df[group_df["bat"] == sid]
                                        display_name = f"{sid}"
                                        scatter = ax.scatter(
                                            sid_df["age"],
                                            sid_df[y_col],
                                            color=scanner_id_to_color[sid],
                                            alpha=alpha,
                                            s=size,
                                            marker=marker,
                                            label=display_name
                                        )
                                        if display_name not in legend_labels:
                                            legend_handles.append(scatter)
                                            legend_labels.append(display_name)
                                else:
                                    # Color by group (default)
                                    display_name = group_display.get(group, group)
                                    scatter = ax.scatter(
                                        group_df["age"],
                                        group_df[y_col],
                                        color=group_colors.get(group, None),
                                        alpha=alpha,
                                        s=size,
                                        marker=marker,
                                    )
                                    if display_name not in legend_labels:
                                        legend_handles.append(scatter)
                                        legend_labels.append(display_name)

                            # Title: ROI label, sex, and number of plotted rows for that sex
                            n_this_sex = sample_sizes.get(sex_code, 0)
                            title = (
                                f"{roi_label}\n"
                                f"{sex_label} (n={n_this_sex})"
                            )
                            ax.set_title(title, fontsize=14)

                            # Only the left column carries the (measure-colored) y-axis label
                            if col_idx == 0:
                                ax.set_ylabel(f"{measure_human}", color=y_color, fontsize=16)
                            else:
                                ax.set_ylabel("")
                            ax.set_xlabel("Age", fontsize=16)
                            ax.tick_params(axis='y', labelsize=14)
                            ax.tick_params(axis='x', labelsize=14)

                            if xticks is not None:
                                ax.set_xticks(xticks)
                                ax.set_xticklabels([str(int(x)) for x in xticks], fontsize=14)

                            # Always show a compact legend in every subplot
                            if scanner_ids:
                                ax.legend(
                                    legend_handles, legend_labels,
                                    loc="lower right",
                                    fontsize=8, title="Scanner ID", title_fontsize=10, ncol=4,
                                    framealpha=0.7,
                                    handletextpad=0.5, borderpad=0.3, labelspacing=0.3, columnspacing=0.7, borderaxespad=0.3
                                )
                            else:
                                ax.legend(
                                    legend_handles, legend_labels,
                                    loc="lower right",
                                    fontsize=8, title="Group", title_fontsize=10,
                                    framealpha=0.7,
                                    handletextpad=0.5, borderpad=0.3, labelspacing=0.3, columnspacing=0.7, borderaxespad=0.3
                                )

                    plt.tight_layout()
                    plt.show()

# Fetch the ROI dictionary, stats, and measures for the "test" input set
# (get_input_specs is defined elsewhere in this notebook).
roi_dict, stats, measures = get_input_specs(input_data_type="test")
# Example invocations, left disabled: points colored by group (scanner_ids=False) or by scanner ID (scanner_ids=True).
# plot_harmonization(roi_dict, stats, measures, groups=["hcpya", "hcpaging", "penn_controls"], scanner_ids=False)
# plot_harmonization(roi_dict, stats, measures, groups=["hcpya", "hcpaging", "penn_controls"], scanner_ids=True)

# Ranking ROIs by model deviance

In [9]:
def get_hemi(row):
    """
    Look up the hemisphere of an ROI in the tissue-specific metadata table.

    `row` must provide "tissue" ("gm" or "wm") and "roi". Returns the "hemi" value of the
    first metadata row whose "label" equals the ROI, or None when the tissue is unknown,
    the metadata table is None, no row matches, or the table has no "hemi" column.
    """
    tissue = row["tissue"]
    # Pick the table lazily so only the one that is actually needed gets touched
    if tissue == "gm":
        metadata = gm_metadata_df
    elif tissue == "wm":
        metadata = wm_metadata_df
    else:
        return None

    if metadata is None:
        return None

    hits = metadata[metadata["label"] == row["roi"]]
    if hits.empty or "hemi" not in hits.columns:
        return None
    return hits["hemi"].values[0]

def get_display_label(row, ipsi=None):
    """
    Build the label shown on the plot for one ROI row.

    Hemisphere markers are stripped only when `ipsi` is given:
      - GM: a leading "LH-", "RH-", "LH_" or "RH_" is dropped from the ROI label.
      - WM: the metadata 'name' is preferred (trailing "_L"/"_R" dropped); if the ROI is not
        in the metadata, the ROI label itself is used with the same suffix rule.
    Any other tissue returns the ROI label unchanged.
    """
    roi_label = row["roi"]
    tissue = row["tissue"]
    strip_hemi = ipsi is not None

    if tissue == "gm":
        if strip_hemi and roi_label.startswith(("LH-", "RH-", "LH_", "RH_")):
            return roi_label[3:]
        return roi_label

    if tissue != "wm":
        return roi_label

    # WM: prefer the descriptive name from the tract metadata
    has_lookup = (
        wm_metadata_df is not None
        and all(col in wm_metadata_df.columns for col in ("label", "name"))
    )
    if has_lookup:
        hits = wm_metadata_df[wm_metadata_df["label"] == roi_label]
        if not hits.empty:
            tract_name = hits["name"].values[0]
            if strip_hemi and tract_name.endswith(("_L", "_R")):
                tract_name = tract_name[:-2]
            return tract_name

    # Fallback: work from the raw ROI label
    if strip_hemi and roi_label.endswith(("_L", "_R")):
        return roi_label[:-2]
    return roi_label

def get_network_label(row):
    """
    Look up the grouping label of an ROI in the tissue-specific metadata table.

    GM rows use the "network_label" column of the GM metadata; WM rows use the "type" column
    of the WM metadata. Returns "Unknown" when the tissue is neither "gm" nor "wm", the
    metadata table is None, or no metadata row has a "label" equal to the ROI.
    """
    tissue = row["tissue"]
    # Select table and column together; the table is only referenced for its own tissue
    if tissue == "gm":
        metadata, column = gm_metadata_df, "network_label"
    elif tissue == "wm":
        metadata, column = wm_metadata_df, "type"
    else:
        return "Unknown"

    if metadata is None:
        return "Unknown"

    hits = metadata[metadata["label"] == row["roi"]]
    if hits.empty:
        return "Unknown"
    return hits[column].values[0]

def plot_roi_zscores(
    roi_dict, stats, measures, method="mean", z_type="abs", n_plot=10, subs=None, networks=False, ipsi=None, title=None
):
    """
    Plot the top n_plot ROIs (or networks if networks=True) ranked by summarized z-scores across measures and stats.
    If 'title' is specified, it will be used as the first line of the plot title.

    Without networks, two subplots are drawn side by side: GM ROIs on the left, WM ROIs on the
    right, with equally scaled y-axes when z_type is "abs".

    Only rows whose "group" is "penn_epilepsy" are used. For each ROI the "z" column of every
    available `<roi>_<stat>_<measure>_gam.csv` is pooled over subjects, stats and measures and
    summarized with `method`. Missing files are skipped; files that cannot be read or lack the
    expected columns are skipped with a warning. Missing (NaN) z-scores are ignored so one gap
    cannot turn an ROI's summary into NaN.

    Args:
        roi_dict (dict): tissue type ("gm"/"wm") -> {"dir": GAM output directory, "rois": list of ROI labels}.
        stats (list): Statistics to include (e.g., ["mean"]).
        measures (list): Measures to include (e.g., ["dti_md"]).
        method (str): "mean" or "sum" — how pooled z-scores are summarized.
        z_type (str): "abs" to summarize |z|, "raw" to summarize signed z.
        n_plot (int): Number of ROIs per tissue (or networks) to plot.
        subs (iterable or None): Subject IDs to include; None uses every penn_epilepsy subject found.
        networks (bool): If True, average ROI summaries within each tissue/network and plot those.
        ipsi (str or None): "left" or "right"; colors ipsilateral bars green and strips hemisphere
            markers from labels. None disables both.
        title (str or None): Optional figure title.

    Returns:
        pandas.DataFrame: with networks=True, the plotted network table (top n_plot rows);
        otherwise the full ROI table sorted by "summary_z" descending. An empty DataFrame
        is returned when no z-scores were found.

    Raises:
        ValueError: if method, z_type or ipsi has an unsupported value.
    """
    import warnings

    if method not in ["sum", "mean"]:
        raise ValueError(f"Unknown method: {method}. Only 'sum' and 'mean' are allowed.")
    if z_type not in ["raw", "abs"]:
        raise ValueError(f"Unknown z_type: {z_type}. Only 'raw' and 'abs' are allowed.")
    if ipsi is not None and ipsi not in ["left", "right"]:
        raise ValueError(f"Unknown ipsi: {ipsi}. Only 'left', 'right', or None are allowed.")

    # None means "no subject restriction", so the files need to be read only once
    subs_filter = set(subs) if subs is not None else None

    # Summarize z-scores per ROI across all subjects, stats, and measures
    roi_scores = []
    for tissue_type, info in roi_dict.items():
        gam_outputs_dir = info["dir"]
        rois = info["rois"]
        print(f"Summarizing z-scores for {tissue_type}")
        for roi in tqdm(rois):
            # (sub, stat, measure) -> z; keying by subject keeps the last row if a subject repeats
            z_by_key = {}
            for stat in stats:
                for measure in measures:
                    gam_outputs_path = ospj(gam_outputs_dir, roi, f"{roi}_{stat}_{measure}_gam.csv")
                    if not os.path.exists(gam_outputs_path):
                        continue
                    try:
                        gam_outputs_df = pd.read_csv(gam_outputs_path)
                        gam_outputs_df = gam_outputs_df[gam_outputs_df["group"] == "penn_epilepsy"]
                        if subs_filter is not None:
                            gam_outputs_df = gam_outputs_df[gam_outputs_df["sub"].isin(subs_filter)]
                        if "z" not in gam_outputs_df.columns:
                            continue
                        gam_outputs_df = gam_outputs_df.dropna(subset=["z"])
                        for sub, z_val in zip(gam_outputs_df["sub"], gam_outputs_df["z"]):
                            z_by_key[(sub, stat, measure)] = z_val
                    except Exception as exc:
                        # Best effort: keep going, but say which file was dropped and why
                        warnings.warn(f"Skipping {gam_outputs_path}: {exc!r}", stacklevel=2)
                        continue

            if not z_by_key:
                continue
            all_zs = np.array(list(z_by_key.values()))
            if z_type == "abs":
                all_zs = np.abs(all_zs)
            summary_z = np.sum(all_zs) if method == "sum" else np.mean(all_zs)
            roi_scores.append({
                "tissue": tissue_type,
                "roi": roi,
                "summary_z": summary_z,
                "n_zscores": len(all_zs)
            })

    roi_scores_df = pd.DataFrame(roi_scores)
    if roi_scores_df.empty:
        print("No z-scores found for the given ROIs and measures.")
        return roi_scores_df

    # --- Make text much larger relative to the bars ---
    title_fontsize = 24
    label_fontsize = 24
    tick_fontsize = 20
    legend_fontsize = 20

    # --- Consistent y-axis for abs z-scores ---
    FIXED_ABS_Z_YLIM = 2.5

    # --- Helper for fixed axes area ---
    def set_fixed_axes_area(fig, ax, axes_width_in=8, axes_height_in=8):
        """
        Rescale the figure so that the given axes' plotting area has the requested size in inches.
        This expands the figure as needed to accommodate long labels.
        """
        # Draw the canvas to get the correct positions
        fig.canvas.draw()
        renderer = fig.canvas.get_renderer()
        bbox = ax.get_window_extent(renderer=renderer)
        dpi = fig.dpi
        axes_width = bbox.width / dpi
        axes_height = bbox.height / dpi

        # Scale the figure by the ratio of desired to current axes size
        fig_width, fig_height = fig.get_size_inches()
        width_ratio = axes_width_in / axes_width
        height_ratio = axes_height_in / axes_height
        new_fig_width = fig_width * width_ratio
        new_fig_height = fig_height * height_ratio

        fig.set_size_inches(new_fig_width, new_fig_height, forward=True)
        fig.tight_layout(rect=[0, 0, 0.88, 1])  # Re-tighten after resizing

    if networks:
        roi_scores_df["network"] = roi_scores_df.apply(get_network_label, axis=1)
        net_scores = (
            roi_scores_df.groupby(["tissue", "network"])
            .agg({"summary_z": "mean", "n_zscores": "sum"})
            .reset_index()
            .sort_values("summary_z", ascending=False)
        )
        top_df = net_scores.head(n_plot)
        color_map = {"wm": "#1f77b4", "gm": "#ff7f0e"}
        colors = [color_map.get(t, "gray") for t in top_df["tissue"]]
        fig, ax = plt.subplots(figsize=(12, 10))
        # Bars go at integer positions so they line up with the ticks set below even when
        # the same network name occurs for both tissues
        positions = list(range(len(top_df)))
        ax.bar(positions, top_df["summary_z"], color=colors)
        ax.set_xlabel("Network", fontsize=label_fontsize)
        ax.set_ylabel(f"Mean summarized z-score ({method}, {z_type})", fontsize=label_fontsize)
        # Compose title
        plot_title = f"Top {n_plot} Networks by mean summarized z-score ({method}, {z_type})"
        if title is not None:
            plot_title = f"{title}\n{plot_title}"
        ax.set_title(plot_title, fontsize=title_fontsize, pad=18)
        ax.set_xticks(positions)
        ax.set_xticklabels(top_df["network"], rotation=45, ha="right", fontsize=tick_fontsize)
        ax.tick_params(axis='y', labelsize=tick_fontsize)
        if z_type == "abs":
            ax.set_ylim(0, FIXED_ABS_Z_YLIM)
        plt.tight_layout(rect=[0, 0, 0.88, 1])
        set_fixed_axes_area(fig, ax, axes_width_in=8, axes_height_in=8)
        plt.show()
        return top_df

    # Not networks: create two subplots side by side for GM and WM
    roi_scores_df = roi_scores_df.sort_values("summary_z", ascending=False)

    # Separate GM and WM data
    gm_scores = roi_scores_df[roi_scores_df["tissue"] == "gm"].head(n_plot).copy()
    wm_scores = roi_scores_df[roi_scores_df["tissue"] == "wm"].head(n_plot).copy()

    # Add hemisphere information
    gm_scores["hemi"] = gm_scores.apply(get_hemi, axis=1)
    wm_scores["hemi"] = wm_scores.apply(get_hemi, axis=1)

    # Create figure with two subplots side by side
    fig, (ax_gm, ax_wm) = plt.subplots(1, 2, figsize=(22, 12))

    # Add more space between the two subplots
    fig.subplots_adjust(wspace=0.35)

    def create_subplot(ax, scores_df, tissue_type):
        """Draw the ranked bar chart for one tissue type, or a placeholder if there is no data."""
        if scores_df.empty:
            ax.text(0.5, 0.5, f'No {tissue_type.upper()} data available',
                   ha='center', va='center', transform=ax.transAxes, fontsize=16)
            ax.set_title(f"{tissue_type.upper()} ROIs", fontsize=title_fontsize, pad=30)
            return

        # Bar color: green for ipsi, grey for contra/unknown
        bar_colors = [
            "green" if ipsi is not None and str(row["hemi"]).lower() == ipsi else "grey"
            for _, row in scores_df.iterrows()
        ]

        # Prepare display labels
        display_labels = [get_display_label(row, ipsi=ipsi).replace("_", " ") for _, row in scores_df.iterrows()]

        # Create bars
        bars = ax.bar(
            range(len(scores_df)),
            scores_df["summary_z"],
            color=bar_colors,
            edgecolor="none"
        )

        # Add outline to GM bars (only for GM subplot)
        if tissue_type == "gm":
            for bar in bars:
                bar.set_edgecolor("black")
                bar.set_linewidth(2)

        if z_type == "abs":
            ax.set_ylabel(f"{method.capitalize()} |z|", fontsize=label_fontsize)
        else:
            ax.set_ylabel(f"{method.capitalize()} z-score", fontsize=label_fontsize)

        ax.set_title(f"{tissue_type.upper()} ROIs", fontsize=title_fontsize, pad=30)
        ax.set_xticks(range(len(scores_df)))
        ax.set_xticklabels(display_labels, rotation=40, ha="right", fontsize=tick_fontsize)
        ax.tick_params(axis='y', labelsize=tick_fontsize)

        # Set y-axis limits for abs z-scores
        if z_type == "abs":
            ax.set_ylim(0, FIXED_ABS_Z_YLIM)

    # Create both subplots
    create_subplot(ax_gm, gm_scores, "gm")
    create_subplot(ax_wm, wm_scores, "wm")

    # Set equal y-axis limits for both subplots (also covers a placeholder panel)
    if z_type == "abs":
        ax_gm.set_ylim(0, FIXED_ABS_Z_YLIM)
        ax_wm.set_ylim(0, FIXED_ABS_Z_YLIM)

    # Custom legend for color coding (only if ipsi is specified)
    if ipsi is not None:
        from matplotlib.lines import Line2D
        legend_elements = [
            Line2D([0], [0], color="green", lw=6, label=f"Ipsilateral ({ipsi.title()})"),
            Line2D([0], [0], color="grey", lw=6, label="Contralateral")
        ]
        # Place legend just right of the plots
        fig.legend(
            handles=legend_elements,
            loc="center left",
            bbox_to_anchor=(0.92, 0.5),
            fontsize=legend_fontsize,
            borderaxespad=0.0
        )

    # Set main title if provided
    if title is not None:
        fig.suptitle(title, fontsize=title_fontsize + 2, y=0.98)

    plt.tight_layout(rect=[0, 0, 0.86, 0.95])  # Leave space for legend and title

    # Apply fixed axes area to both subplots
    set_fixed_axes_area(fig, ax_gm, axes_width_in=8, axes_height_in=8)
    set_fixed_axes_area(fig, ax_wm, axes_width_in=8, axes_height_in=8)

    plt.show()
    return roi_scores_df

## Seizure lateralization and localization

### Left temporal seizures (n=36)

In [10]:
# roi_dict, stats, measures = get_input_specs(input_data_type="all")
# left_temporal_seizures_abs_z_means = plot_roi_zscores(roi_dict, stats, measures, subs=subs_seizure_temporal_L, ipsi="left", title="Left temporal seizures only during EMU stay (n=36)")

# # Print top 10 where tissue_type is gm
# print(left_temporal_seizures_abs_z_means[left_temporal_seizures_abs_z_means["tissue"] == "gm"].sort_values(by="summary_z", ascending=False).head(10))

# # Print top 10 where tissue_type is wm
# print(left_temporal_seizures_abs_z_means[left_temporal_seizures_abs_z_means["tissue"] == "wm"].sort_values(by="summary_z", ascending=False).head(10))

### Right temporal seizures (n=24)

In [11]:
# roi_dict, stats, measures = get_input_specs(input_data_type="all")
# right_temporal_seizures_abs_z_means_all = plot_roi_zscores(roi_dict, stats, measures, subs=subs_seizure_temporal_R,ipsi="right",title="Right temporal seizures only during EMU stay (n=24)")

# # Print top 10 where tissue_type is gm
# print(right_temporal_seizures_abs_z_means_all[right_temporal_seizures_abs_z_means_all["tissue"] == "gm"].sort_values(by="summary_z", ascending=False).head(10))

# # Print top 10 where tissue_type is wm
# print(right_temporal_seizures_abs_z_means_all[right_temporal_seizures_abs_z_means_all["tissue"] == "wm"].sort_values(by="summary_z", ascending=False).head(10))

### Neuromodulation (n=13)

In [12]:
# roi_dict, stats, measures = get_input_specs(input_data_type="all")
# plot_roi_zscores(roi_dict, stats, measures, subs=subs_neuromodulation, title="Neuromodulation (n=13)")

# roi_dict, stats, measures = get_input_specs(input_data_type="gm")
# neuromodulation_abs_z_means_gm = plot_roi_zscores(roi_dict, stats, measures, subs=subs_neuromodulation, title="Neuromodulation (n=13)")
# neuromodulation_abs_z_means_gm.to_csv(ospj(gam_outputs_group_summaries_dir, "neuromodulation_abs_z_means_gm.csv"), index=False)

# roi_dict, stats, measures = get_input_specs(input_data_type="wm")
# neuromodulation_abs_z_means_wm = plot_roi_zscores(roi_dict, stats, measures, subs=subs_neuromodulation, title="Neuromodulation (n=13)")
# neuromodulation_abs_z_means_wm.to_csv(ospj(gam_outputs_group_summaries_dir, "neuromodulation_abs_z_means_wm.csv"), index=False)

# Asymmetries

In [None]:
def plot_roi_zscores_asym(
    roi_dict, stats, measures, method="mean", z_type="abs", n_plot=5, subs=None, networks=False, title=None, ipsi=None
):
    """
    Plot the top n_plot base ROIs (or networks if networks=True) ranked by summarized z-score asymmetry.

    For every subject, stat and measure, the asymmetry of a base ROI is the ipsilateral z-score minus the
    contralateral z-score (absolute value if z_type="abs"). All asymmetries of a base ROI are then summed or
    averaged (see 'method') into one 'summary_asym' value.

    Parameters
    ----------
    roi_dict : dict
        Maps tissue type ("gm" or "wm") to a dict with keys "dir" (GAM output directory) and "rois" (ROI names).
        Z-scores are read from "<dir>/<roi>/<roi>_<stat>_<measure>_gam.csv"; only rows with group == "penn_epilepsy"
        are used, and the columns "sub" and "z" are read from them.
    stats, measures : iterable
        Stat and measure labels used to build the file names above.
    method : {"mean", "sum"}
        How asymmetries are summarized per base ROI.
    z_type : {"abs", "raw"}
        Whether to take the absolute value of each ipsi - contra difference.
    n_plot : int
        Number of bars (ROIs or networks) to plot.
    subs : iterable or None
        Subjects to include. If None, every "penn_epilepsy" subject found in the GAM outputs is used.
    networks : bool
        If True, base ROIs are grouped with get_network_label and the mean 'summary_asym' per network is plotted.
    title : str or None
        With networks=True it becomes the first line of the plot title; otherwise it is the whole plot title.
        If None, a default title is used.
    ipsi : {"left", "right", None}
        Hemisphere that is ipsilateral for all subjects. If None, sidedness is taken per subject from
        clinical_df['seizure_lateralization']; subjects whose value is not "left" or "right" are skipped.

    Returns
    -------
    pandas.DataFrame
        networks=False: all base ROIs sorted by 'summary_asym' (columns: tissue, base_roi, summary_asym, n_asyms).
        networks=True: the top n_plot networks (columns: tissue, network, summary_asym, n_asyms).
        An empty DataFrame if no asymmetry could be computed.

    Raises
    ------
    ValueError
        If method, z_type or ipsi has an unsupported value.
    RuntimeError
        If ipsi is None and clinical_df is not defined in the global scope.
    """
    if method not in ["sum", "mean"]:
        raise ValueError(f"Unknown method: {method}. Only 'sum' and 'mean' are allowed.")
    if z_type not in ["raw", "abs"]:
        raise ValueError(f"Unknown z_type: {z_type}. Only 'raw' and 'abs' are allowed.")
    if ipsi not in [None, "left", "right"]:
        raise ValueError(f"Unknown ipsi: {ipsi}. Only 'left', 'right', or None are allowed.")

    # Subjects requested by the caller (None means: use every subject found in the files)
    requested_subs = sorted(set(subs)) if subs is not None else None
    discovered_subs = set()

    # Nested dict: subj -> tissue_type -> roi -> stat -> measure -> z
    subj_zscores = {}
    # Files that exist but could not be processed; reported once after the loop
    skipped_files = []

    # Single pass over the GAM outputs: discover subjects (if needed) and collect z-scores
    for tissue_type, info in roi_dict.items():
        gam_outputs_dir = info["dir"]
        rois = info["rois"]
        print(f"Collecting z-scores for {tissue_type}")
        for roi in tqdm(rois):
            for stat in stats:
                for measure in measures:
                    gam_outputs_path = ospj(gam_outputs_dir, roi, f"{roi}_{stat}_{measure}_gam.csv")
                    if not os.path.exists(gam_outputs_path):
                        continue
                    try:
                        gam_outputs_df = pd.read_csv(gam_outputs_path)
                        gam_outputs_df = gam_outputs_df[gam_outputs_df["group"] == "penn_epilepsy"]
                        if requested_subs is None:
                            discovered_subs.update(gam_outputs_df["sub"].unique())
                        else:
                            gam_outputs_df = gam_outputs_df[gam_outputs_df["sub"].isin(requested_subs)]
                        if "z" not in gam_outputs_df.columns:
                            continue
                        for sub, z_val in zip(gam_outputs_df["sub"], gam_outputs_df["z"]):
                            stat_level = (
                                subj_zscores.setdefault(sub, {})
                                .setdefault(tissue_type, {})
                                .setdefault(roi, {})
                                .setdefault(stat, {})
                            )
                            stat_level[measure] = z_val
                    except Exception as exc:
                        # Best effort: keep going, but do not hide the problem from the reader
                        skipped_files.append((gam_outputs_path, exc))

    if skipped_files:
        first_path, first_exc = skipped_files[0]
        print(
            f"Warning: {len(skipped_files)} GAM output file(s) could not be processed and were skipped "
            f"(first: {first_path}: {first_exc!r})."
        )

    all_subs = requested_subs if requested_subs is not None else sorted(discovered_subs)

    # If ipsi is None, infer sidedness from clinical_df for each subject
    # clinical_df must be available in the global scope
    if ipsi is None:
        if "clinical_df" not in globals():
            raise RuntimeError("clinical_df must be available in the global scope to infer sidedness when ipsi=None.")
        clinical_lateralization = dict(zip(clinical_df["sub"], clinical_df["seizure_lateralization"]))
        # Only keep subjects for which we have lateralization info
        all_subs_with_lat = [sub for sub in all_subs if sub in clinical_lateralization]
        if len(all_subs_with_lat) < len(all_subs):
            print(f"Warning: {len(all_subs) - len(all_subs_with_lat)} subjects missing seizure_lateralization in clinical_df and will be skipped.")
        all_subs = all_subs_with_lat
        # Subjects whose lateralization is neither "left" nor "right" cannot be oriented and are dropped
        sub_sides = [
            (sub, clinical_lateralization[sub])
            for sub in all_subs
            if clinical_lateralization[sub] in ("left", "right")
        ]
    else:
        sub_sides = [(sub, ipsi) for sub in all_subs]

    # Now compute asymmetries for each base ROI (aggregate left/right)
    base_roi_asym_scores = {}
    for tissue_type, info in roi_dict.items():
        print(f"Computing asymmetries for {tissue_type}")
        hemisphere_pairs = _pair_hemisphere_rois(info["rois"], tissue_type)

        for base_roi, (left_roi, right_roi) in hemisphere_pairs.items():
            all_asyms = []
            for sub, side in sub_sides:
                if side == "left":
                    ipsi_roi, contra_roi = left_roi, right_roi
                else:
                    ipsi_roi, contra_roi = right_roi, left_roi
                tissue_level = subj_zscores.get(sub, {}).get(tissue_type, {})
                ipsi_by_stat = tissue_level.get(ipsi_roi, {})
                contra_by_stat = tissue_level.get(contra_roi, {})

                for stat in stats:
                    for measure in measures:
                        ipsi_z = ipsi_by_stat.get(stat, {}).get(measure)
                        contra_z = contra_by_stat.get(stat, {}).get(measure)
                        # Compute asymmetry only if both hemispheres have a value
                        if ipsi_z is None or contra_z is None:
                            continue
                        # NOTE(review): NaN z-scores are kept and will propagate into the summary - confirm
                        # whether they should be dropped instead.
                        asym = ipsi_z - contra_z
                        if z_type == "abs":
                            asym = np.abs(asym)
                        all_asyms.append(asym)

            if all_asyms:
                all_asyms = np.array(all_asyms)
                summary_asym = np.sum(all_asyms) if method == "sum" else np.mean(all_asyms)
                # Only store the bilateral ROI label (base_roi) and the summary_asym
                base_roi_asym_scores[(tissue_type, base_roi)] = {
                    "tissue": tissue_type,
                    "base_roi": base_roi,
                    "summary_asym": summary_asym,
                    "n_asyms": len(all_asyms)
                }

    roi_asym_scores_df = pd.DataFrame(list(base_roi_asym_scores.values()))
    if roi_asym_scores_df.empty:
        print("No asymmetry scores found for the given ROIs and measures.")
        return roi_asym_scores_df

    if networks:
        roi_asym_scores_df["network"] = roi_asym_scores_df.apply(get_network_label, axis=1)
        net_scores = (
            roi_asym_scores_df.groupby(["tissue", "network"])
            .agg({"summary_asym": "mean", "n_asyms": "sum"})
            .reset_index()
            .sort_values("summary_asym", ascending=False)
        )
        top_df = net_scores.head(n_plot)
        color_map = {"wm": "#1f77b4", "gm": "#ff7f0e"}
        colors = [color_map.get(t, "gray") for t in top_df["tissue"]]
        plt.figure(figsize=(12, 6))
        plt.bar(top_df["network"], top_df["summary_asym"], color=colors)
        plt.xlabel("Network")
        plt.ylabel(f"Mean summarized asymmetry ({method}, {z_type})")
        # Compose title
        plot_title = f"Top {n_plot} Networks by mean summarized asymmetry ({method}, {z_type})"
        if title is not None:
            plot_title = f"{title}\n{plot_title}"
        plt.title(plot_title)
        plt.xticks(rotation=45, ha="right")
        plt.tight_layout()
        plt.show()
        return top_df

    # Not networks: plot individual base ROIs (bilateral only)
    roi_asym_scores_df = roi_asym_scores_df.sort_values("summary_asym", ascending=False)
    top_rois = roi_asym_scores_df.head(n_plot).copy()

    plt.figure(figsize=(12, 6))
    bars = plt.bar(
        range(len(top_rois)),
        top_rois["summary_asym"],
        color="blue",
        edgecolor="none"
    )

    # Add outline to GM bars
    for bar, tissue in zip(bars, top_rois["tissue"]):
        if tissue == "gm":
            bar.set_edgecolor("black")
            bar.set_linewidth(3)
        else:
            bar.set_edgecolor("none")

    # Set x-tick labels using base ROI names (bilateral only)
    display_labels = top_rois["base_roi"].tolist()
    plt.xlabel("ROI")
    plt.ylabel(f"{method.capitalize()} {z_type} asymmetry (ipsi - contra)")
    # A caller-supplied title is used as-is; otherwise fall back to a default so the
    # title is always defined (previously title=None raised UnboundLocalError here).
    if title is not None:
        plot_title = f"{title}"
    else:
        plot_title = f"Top {n_plot} ROIs by {method} {z_type} asymmetry (ipsi - contra)"
    plt.title(plot_title)
    plt.xticks(range(len(top_rois)), display_labels, rotation=45, ha="right")

    # Custom legend
    from matplotlib.lines import Line2D
    legend_elements = [
        Line2D([0], [0], color="blue", lw=6, label="Base ROI"),
        Line2D([0], [0], color="black", lw=2, label="GM region")
    ]
    plt.legend(handles=legend_elements, loc="best")

    plt.tight_layout()
    plt.show()
    return roi_asym_scores_df


def _pair_hemisphere_rois(rois, tissue_type):
    """
    Map each base ROI name to its (left_roi, right_roi) pair.

    GM ROIs are assigned a hemisphere from an "LH-"/"LH_" or "RH-"/"RH_" prefix, WM ROIs from an "_L" or "_R"
    suffix. ROIs without such a marker, other tissue types, and base ROIs found in only one hemisphere are dropped.
    """
    sides_by_base = {}
    for roi in rois:
        base_roi = get_base_roi_name(roi, tissue_type)
        if base_roi is None:
            continue
        if tissue_type == "gm":
            if roi.startswith(("LH-", "LH_")):
                side = "left"
            elif roi.startswith(("RH-", "RH_")):
                side = "right"
            else:
                continue
        elif tissue_type == "wm":
            if roi.endswith("_L"):
                side = "left"
            elif roi.endswith("_R"):
                side = "right"
            else:
                continue
        else:
            continue
        sides_by_base.setdefault(base_roi, {})[side] = roi

    # Only keep base ROIs with both left and right
    return {
        base_roi: (sides["left"], sides["right"])
        for base_roi, sides in sides_by_base.items()
        if "left" in sides and "right" in sides
    }


def get_base_roi_name(roi, tissue_type):
    """
    Strip the hemisphere marker from an ROI name.

    For tissue_type "gm" a leading "LH-", "RH-", "LH_" or "RH_" is removed; for tissue_type "wm" a trailing
    "_L" or "_R" is removed. Names without such a marker, and any other tissue_type, are returned unchanged.
    """
    gm_hemi_prefixes = ("LH-", "RH-", "LH_", "RH_")
    wm_hemi_suffixes = ("_L", "_R")

    if tissue_type == "gm" and roi.startswith(gm_hemi_prefixes):
        # Every GM prefix is three characters long
        return roi[3:]
    if tissue_type == "wm" and roi.endswith(wm_hemi_suffixes):
        # Every WM suffix is two characters long
        return roi[:-2]
    return roi


def get_contralateral_roi(roi, tissue_type):
    """
    Return the name of the opposite-hemisphere counterpart of an ROI.

    For tissue_type "gm" the leading "LH"/"RH" of an "LH-", "LH_", "RH-" or "RH_" prefix is swapped; for
    tissue_type "wm" a trailing "_L"/"_R" is swapped. Only that leading or trailing marker is changed, so
    an "LH"/"RH" token that happens to occur later in the name is left untouched.

    Returns None when the name carries no hemisphere marker (e.g. bilateral ROIs) or when tissue_type is
    neither "gm" nor "wm".
    """
    if tissue_type == "gm":
        # Swap only the two-letter hemisphere code at the start; keep the separator and the rest as-is
        if roi.startswith(("LH-", "LH_")):
            return "RH" + roi[2:]
        if roi.startswith(("RH-", "RH_")):
            return "LH" + roi[2:]
        return None  # Can't determine contralateral for bilateral ROIs
    if tissue_type == "wm":
        # Swap _L and _R suffixes
        if roi.endswith("_L"):
            return roi[:-2] + "_R"
        if roi.endswith("_R"):
            return roi[:-2] + "_L"
        return None  # Can't determine contralateral for bilateral ROIs
    return None

### Temporal seizures (n=60)

In [None]:
# subs_seizure_temporal = subs_seizure_temporal_L + subs_seizure_temporal_R
# roi_dict, stats, measures = get_input_specs(input_data_type="all")
# temporal_seizures_abs_z_means = plot_roi_zscores(roi_dict, stats, measures, subs=subs_seizure_temporal, title="Temporal seizures during EMU stay (n=60)")
# print(temporal_seizures_abs_z_means.head(10))


### Left temporal seizures (n=36)

In [None]:
# roi_dict, stats, measures = get_input_specs(input_data_type="gm")
# left_temporal_seizures_abs_z_means_gm_asym = plot_roi_zscores_asym(roi_dict, stats, measures, subs=subs_seizure_temporal_L, ipsi="left", title="Left temporal seizures only during EMU stay (n=36)")
# print(left_temporal_seizures_abs_z_means_gm_asym.head(10))
# # left_temporal_seizures_abs_z_means_gm_asym.to_csv(ospj(gam_outputs_group_summaries_dir, "left_temporal_seizures_abs_z_means_gm_asym.csv"), index=False)
# # print(left_temporal_seizures_abs_z_means_gm_asym)

# roi_dict, stats, measures = get_input_specs(input_data_type="wm")
# left_temporal_seizures_abs_z_means_wm_asym = plot_roi_zscores_asym(roi_dict, stats, measures, subs=subs_seizure_temporal_L, ipsi="left", title="Left temporal seizures only during EMU stay (n=36)")
# print(left_temporal_seizures_abs_z_means_wm_asym.head(10))
# # left_temporal_seizures_abs_z_means_wm_asym.to_csv(ospj(gam_outputs_group_summaries_dir, "left_temporal_seizures_abs_z_means_wm_asym.csv"), index=False)
# # print(left_temporal_seizures_abs_z_means_wm_asym)

### Right temporal seizures (n=24)

In [None]:
# roi_dict, stats, measures = get_input_specs(input_data_type="gm")
# right_temporal_seizures_abs_z_means_gm_asym = plot_roi_zscores_asym(roi_dict, stats, measures, subs=subs_seizure_temporal_R, ipsi="right", title="Right temporal seizures only during EMU stay (n=24)")
# print(right_temporal_seizures_abs_z_means_gm_asym.head(10))
# # right_temporal_seizures_abs_z_means_gm_asym.to_csv(ospj(gam_outputs_group_summaries_dir, "right_temporal_seizures_abs_z_means_gm_asym.csv"), index=False)

# roi_dict, stats, measures = get_input_specs(input_data_type="wm")
# right_temporal_seizures_abs_z_means_wm_asym = plot_roi_zscores_asym(roi_dict, stats, measures, subs=subs_seizure_temporal_R, ipsi="right", title="Right temporal seizures only during EMU stay (n=24)")
# print(right_temporal_seizures_abs_z_means_wm_asym.head(10))
# right_temporal_seizures_abs_z_means_wm_asym.to_csv(ospj(gam_outputs_group_summaries_dir, "right_temporal_seizures_abs_z_means_wm_asym.csv"), index=False)

### Neuromodulation (n=13)

In [None]:
# roi_dict, stats, measures = get_input_specs(input_data_type="gm")
# neuromodulation_abs_z_means_gm_asym = plot_roi_zscores_asym(roi_dict, stats, measures, subs=subs_neuromodulation, title="Neuromodulation (n=13)")
# neuromodulation_abs_z_means_gm_asym.to_csv(ospj(gam_outputs_group_summaries_dir, "neuromodulation_abs_z_means_gm_asym.csv"), index=False)

# roi_dict, stats, measures = get_input_specs(input_data_type="wm")
# neuromodulation_abs_z_means_wm_asym = plot_roi_zscores_asym(roi_dict, stats, measures, subs=subs_neuromodulation, title="Neuromodulation (n=13)")
# neuromodulation_abs_z_means_wm_asym.to_csv(ospj(gam_outputs_group_summaries_dir, "neuromodulation_abs_z_means_wm_asym.csv"), index=False)


# Ranking parameters by model deviance