In [1]:
# Enable IPython autoreload in mode 2: all imported modules are reloaded before
# each cell executes, so edits to the project's utils/ package are picked up
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
import sys, os

# Get absolute path to the project root (two levels above the notebook's cwd).
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "../.."))

# Cluster workspace that holds utils/, scripts/ and results/.
WORKSPACE_DIR = "/cluster/home/herminea/mental_health_project/workspace"

# Append each search path only once so re-running this cell does not keep
# growing sys.path with duplicates.
for _extra_path in (
    PROJECT_ROOT,
    os.path.join(WORKSPACE_DIR, "utils"),
    os.path.join(WORKSPACE_DIR, "scripts"),
):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)
del _extra_path

from collections import defaultdict
from utils.io_results import load_results, group_results_by_subject
from utils.plot.plot_decomp import plot_example_subjects

# --- Load combined FCs ---
# Directory of per-subject VLMD FC result files consumed by load_results().
RES_DIR = os.path.join(WORKSPACE_DIR, "results", "fmri_prep", "vlmd", "fc")

In [None]:
from utils.plot.plot_base import set_mpl_style
# Apply the project's shared matplotlib style (helper defined in
# utils.plot.plot_base; its exact settings are not visible here).
set_mpl_style()


In [4]:
# Load all VLMD FC result files and group them per subject.
# NOTE(review): the saved output of this cell reports loading from
# .../test/results/fmri_prep/vlmd/fc, whereas RES_DIR points at
# .../workspace/results/fmri_prep/vlmd/fc. Either the stored output is stale or
# load_results rewrites the path — confirm and re-run before trusting figures.
vlmd_results = load_results(RES_DIR)

# Fail fast on an empty or wrong directory instead of silently producing empty
# spectra several cells later. (Assumes load_results returns a sized
# collection — TODO confirm.)
if len(vlmd_results) == 0:
    raise FileNotFoundError(f"No FC result files loaded from {RES_DIR}")

vlmd_subjects_dict = group_results_by_subject(vlmd_results)

[Load] Loaded 165 FC result files from /cluster/home/herminea/mental_health_project/test/results/fmri_prep/vlmd/fc


In [5]:
# Slow-frequency bands in Hz as (low, high) edges. Insertion order is the
# shading/labelling order used by the plotting helpers. A Slow-6 band
# (0.005-0.01 Hz) is intentionally not included here.
freq_bands = dict([
    ("Slow-5", (0.01, 0.027)),
    ("Slow-4", (0.027, 0.073)),
    ("Slow-3", (0.073, 0.198)),
    ("Slow-2", (0.198, 0.25)),
])

In [8]:
# Subject -> diagnostic group label ("MDD" / "HC"), read from the "group"
# field of each subject's first result entry.
vlmd_groups = dict(
    (subject_id, subject_entries[0]["group"])
    for subject_id, subject_entries in vlmd_subjects_dict.items()
)

In [None]:
import pandas as pd
import numpy as np

# --- Load ROI -> network mapping file ---
# NOTE(review): roi_to_net is later used positionally against the ROI axis (R)
# of the IMF arrays, so the CSV rows are assumed to be in that ROI order — confirm.
ATLAS_CSV = "/cluster/home/herminea/mental_health_project/workspace/atlas/roi_to_net_434.csv"
atlas_df = pd.read_csv(ATLAS_CSV)

# --- Extract columns cleanly ---
# Coerce to numeric, but stop with a clear message on missing/unparsable IDs:
# casting the resulting NaNs to int would otherwise fail with an opaque error.
net_id_numeric = pd.to_numeric(atlas_df["NetworkID"], errors="coerce")
bad_id_rows = net_id_numeric.isna()
if bad_id_rows.any():
    raise ValueError(
        f"{int(bad_id_rows.sum())} row(s) in {ATLAS_CSV} have a missing or "
        f"non-numeric NetworkID, e.g. {atlas_df.loc[bad_id_rows, 'NetworkID'].head(5).tolist()}"
    )
atlas_df["NetworkID"] = net_id_numeric.astype(int)
atlas_df["NetworkName"] = atlas_df["NetworkName"].astype(str)

# ROI → network mapping (numeric, expected 0..n_networks-1)
roi_to_net = atlas_df["NetworkID"].values
unique_net_ids = np.unique(roi_to_net)
n_networks = len(unique_net_ids)

# Later cells iterate range(n_networks) and map network names to IDs by list
# position, which is only correct when the IDs are exactly 0..n_networks-1.
if not np.array_equal(unique_net_ids, np.arange(n_networks)):
    raise ValueError(
        f"NetworkID values must be contiguous 0..{n_networks - 1}; "
        f"got {unique_net_ids.tolist()}"
    )

# Network names ordered by NetworkID (groupby sorts its keys), so
# net_names_21[i] is the name of network i; first name seen per ID is used.
net_names_21 = (
    atlas_df.groupby("NetworkID")["NetworkName"]
    .first()
    .to_list()
)


In [None]:
def compute_network_marginal_spectrum(inst_amp, inst_freq, roi_to_net,
                                      net_idx, fs, fmax=0.25, nbins=200,
                                      drop_out_of_range=True):
    """
    Time-averaged marginal Hilbert spectrum of one network.

    For every IMF and time point, the instantaneous amplitude and frequency are
    first averaged over the ROIs assigned to ``net_idx``. The averaged amplitude
    is then accumulated (summed over IMFs and time) into ``nbins`` equal-width
    frequency bins spanning [0, fmax], and finally divided by the number of
    time points.

    Parameters
    ----------
    inst_amp, inst_freq : ndarray, shape (K, T, R)
        Instantaneous amplitude and frequency (Hz) per IMF k, time t, ROI r.
    roi_to_net : array-like, shape (R,)
        Network ID of each ROI.
    net_idx : int
        Network ID to compute the spectrum for.
    fs : float
        Sampling rate in Hz. Not used by the computation; kept so existing
        call sites continue to work.
    fmax : float, default 0.25
        Upper edge of the frequency axis in Hz (the last bin is closed, so a
        frequency exactly equal to ``fmax`` is counted).
    nbins : int, default 200
        Number of frequency bins.
    drop_out_of_range : bool, default True
        If True, samples whose network-averaged frequency is non-finite or
        outside [0, fmax] are ignored. If False, the previous behaviour is
        reproduced: such samples are clipped into the first/last bin, which
        piles spurious power up at the spectrum edges.

    Returns
    -------
    freqs : ndarray, shape (nbins,), or None
        Bin centres in Hz; None if no ROI belongs to ``net_idx``.
    marginal : ndarray, shape (nbins,), or None
        Marginal spectrum for the network; None if no ROI belongs to it.
    """
    # Unpacking also enforces the 3-D (K, T, R) layout.
    _, n_time, _ = inst_amp.shape
    freq_bins = np.linspace(0.0, fmax, nbins + 1)

    # ROIs belonging to this network
    mask_roi = np.asarray(roi_to_net) == net_idx
    if not np.any(mask_roi):
        return None, None

    # Average amplitude and frequency across ROIs of this network, then flatten
    # IMFs x time into one sample axis.
    amp_net = inst_amp[:, :, mask_roi].mean(axis=2).ravel()
    frq_net = inst_freq[:, :, mask_roi].mean(axis=2).ravel()

    # 0-based bin index; values >= fmax (and NaN) map one past the last bin.
    inds = np.digitize(frq_net, freq_bins) - 1

    if drop_out_of_range:
        keep = np.isfinite(frq_net) & (frq_net >= 0.0) & (frq_net <= fmax)
        # frq == fmax exactly lands at index nbins; fold it into the last bin.
        inds = np.minimum(inds[keep], nbins - 1)
        weights = amp_net[keep]
    else:
        inds = inds.clip(0, nbins - 1)
        weights = amp_net

    # Sum amplitudes per bin, then average over time (discrete integral over t).
    H = np.bincount(inds, weights=weights, minlength=nbins).astype(float) / n_time

    # Frequency bin centers
    freqs = 0.5 * (freq_bins[:-1] + freq_bins[1:])
    return freqs, H


In [11]:
from utils.decomp import compute_hht

# HHT / marginal-spectrum settings.
fs = 1 / 0.8   # sampling rate in Hz (TR = 0.8 s)
fmax = 0.25    # upper edge of the spectrum (Hz)
nbins = 200    # number of frequency bins on [0, fmax]

# net_marginals[net_idx][group] -> list of per-subject marginal spectra.
net_marginals = {idx: {"MDD": [], "HC": []} for idx in range(n_networks)}

# Bin centres depend only on fmax and nbins, so the first non-empty result is
# kept as the shared frequency axis for plotting.
freqs_ref = None

for subj, subj_entries in vlmd_subjects_dict.items():
    # Only the first result entry of each subject is used.
    subj_imfs = subj_entries[0]["imfs"]  # (K, T, R)
    subj_group = vlmd_groups[subj]       # "MDD" or "HC"

    amp, frq = compute_hht(subj_imfs, fs=fs, smooth_sigma=1)

    for net_idx in range(n_networks):
        centres, spectrum = compute_network_marginal_spectrum(
            amp, frq, roi_to_net, net_idx, fs=fs, fmax=fmax, nbins=nbins
        )
        if centres is None:
            # No ROI is assigned to this network.
            continue
        if freqs_ref is None:
            freqs_ref = centres
        net_marginals[net_idx][subj_group].append(spectrum)


In [22]:
import seaborn as sns
import matplotlib.pyplot as plt
# Seaborn "paper" theme first, then re-apply the project style on top
# (presumably so the project's settings take precedence — confirm intent).
sns.set_theme(context="paper", style="white", font_scale=1.0)
set_mpl_style()

def add_freq_bands(ax, freq_bands, alpha=0.15, colors=None):
    """
    Shade each frequency band as a vertical span on ``ax`` (drawn at zorder 0).

    Parameters
    ----------
    ax : matplotlib Axes (anything with an ``axvspan`` method)
    freq_bands : dict[str, tuple[float, float]]
        Band name -> (low, high) edges; bands are shaded in insertion order.
    alpha : float, default 0.15
        Transparency of the shading.
    colors : sequence of colour specs, optional
        Shading colours. Defaults to a 4-colour palette (Slow-5..Slow-2).
        Colours are reused cyclically when there are more bands than colours,
        so no band is silently left unshaded.
    """
    if colors is None:
        colors = ("#a6cee3", "#b2df8a", "#fb9a99", "#fdbf6f")
    if len(colors) == 0:
        raise ValueError("colors must contain at least one colour")
    for i, (lo, hi) in enumerate(freq_bands.values()):
        ax.axvspan(lo, hi, color=colors[i % len(colors)], alpha=alpha, zorder=0)

def add_band_labels(ax, freq_bands, fontsize=8):
    """
    Write each band's name, rotated 90 degrees, at the centre of the band.

    Labels are anchored just below the top of the axes using the x-axis
    blended transform (x in data coordinates, y in axes fraction), so their
    vertical position does not depend on the data y-limits.
    """
    for name, edges in freq_bands.items():
        lo, hi = edges
        centre = (lo + hi) / 2
        ax.text(
            centre,
            0.98,
            name,
            transform=ax.get_xaxis_transform(),
            ha="center",
            va="top",
            fontsize=fontsize,
            rotation=90,
        )

import warnings

# Networks to plot; every name must be present in net_names_21.
net_list = ["LimbicA", "LimbicB", "MTL", "Thalamus", "Cerebellum", "Striatum", "SomMotB", "DefaultA"]

# Overleaf-friendly: use double-column-ish width
FIGSIZE = (6.8, 3.2)
# Fixed axis ranges shared by all networks so the figures are comparable.
# Curves above YLIM[1] are cut off; raise it if a network exceeds the range.
XLIM = (0.0, 0.25)
YLIM = (0.0, 4)
# Group -> line colour; dict order is the drawing and legend order.
GROUP_COLORS = {"MDD": "red", "HC": "blue"}


def _group_mean_se(spectra):
    """Return (mean, standard error) across subjects of a list of 1-D spectra.

    The SE uses ddof=1. With a single subject it is undefined, so zeros are
    returned (only the mean curve is visible) instead of NaN plus a
    RuntimeWarning.
    """
    stacked = np.vstack(spectra)  # (n_subjects, nbins)
    n_subj = stacked.shape[0]
    mean = stacked.mean(axis=0)
    if n_subj < 2:
        return mean, np.zeros_like(mean)
    return mean, stacked.std(axis=0, ddof=1) / np.sqrt(n_subj)


for net_name in net_list:
    if net_name not in net_names_21:
        raise ValueError(
            f"Network {net_name!r} not found in atlas; available: {net_names_21}"
        )
    net_idx = net_names_21.index(net_name)
    spectra_by_group = net_marginals[net_idx]

    # np.vstack([]) fails with an unhelpful message; skip such networks explicitly.
    empty_groups = [g for g in GROUP_COLORS if not spectra_by_group[g]]
    if empty_groups:
        warnings.warn(f"Skipping {net_name}: no spectra for group(s) {empty_groups}")
        continue

    fig, ax = plt.subplots(figsize=FIGSIZE)

    # frequency band shading FIRST (so lines are on top)
    add_freq_bands(ax, freq_bands, alpha=0.20)
    ax.axvspan(0.01, 0.25, color="lightgrey", alpha=0.03, zorder=0)

    # mean ± SE curve per group
    for group_name, line_color in GROUP_COLORS.items():
        mean, se = _group_mean_se(spectra_by_group[group_name])
        ax.plot(freqs_ref, mean, label=group_name, color=line_color)
        ax.fill_between(freqs_ref, mean - se, mean + se, color=line_color, alpha=0.2)

    # band labels (robust placement)
    add_band_labels(ax, freq_bands, fontsize=8)

    ax.set_xlim(*XLIM)
    ax.set_ylim(*YLIM)
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Marginal amplitude")
    ax.set_title(f"Marginal Hilbert Spectrum – {net_name}")
    ax.legend(frameon=False)

    sns.despine(ax=ax)
    fig.tight_layout()

    fig.savefig(f"marginal_hilbert_{net_name}.pdf", bbox_inches="tight")
    plt.close(fig)
