<a href="https://colab.research.google.com/github/grabuffo/BrainStim_ANN_fMRI_HCP/blob/main/notebooks/TMS_fMRI_ANN_Simulate_Sessions_dataset_simulated_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Simulate TMS-fMRI Sessions with Population ANN

Generate a synthetic TMS-fMRI dataset using the population ANN model trained on task-rest data.

**Workflow:**
1. Load empirical dataset and population model
2. Generate synthetic rest + stim sessions for all subjects
3. Validate with subject-specific ΔFC analysis (empirical vs simulated)

In [None]:
# =========================
# SETUP
# =========================
# Mount Google Drive at /content/drive; the data paths defined later in this
# notebook live under that mount point. force_remount=True re-mounts even if
# a previous mount is still present.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

import os, sys, pickle, json, math
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from scipy.stats import pearsonr

# Clone repo + add to path
REPO_DIR = "/content/BrainStim_ANN_fMRI_HCP"
if not os.path.exists(REPO_DIR):
    # `!` is an IPython shell escape, so this cell only runs inside a notebook.
    # The clone is skipped when REPO_DIR already exists.
    !git clone https://github.com/grabuffo/BrainStim_ANN_fMRI_HCP.git
else:
    print("Repo already exists ✅")

# Put the repo root on sys.path before importing from its `src` package.
sys.path.append(REPO_DIR)
# build_model is used below to instantiate the network; device is the torch
# device that tensors and the model are moved to.
from src.NPI import build_model, device
print(f"PyTorch device: {device}")

In [None]:
# =========================
# PATHS
# =========================
# Root of all notebook data on the mounted Drive.
BASE = "/content/drive/MyDrive/Colab Notebooks/Brain_Stim_ANN/data"

# --- Inputs ---------------------------------------------------------------
PREPROC_ROOT = os.path.join(BASE, "preprocessed_subjects_tms_fmri")
DATASET_EMP_PKL = os.path.join(BASE, "TMS_fMRI", "dataset_tian50_schaefer400_allruns.pkl")
MODEL_DIR = os.path.join(PREPROC_ROOT, "trained_models_MLP_tms_fmri")
MODEL_PATH = os.path.join(MODEL_DIR, "population_MLP_tms_fmri.pt")

# Stop here if the population model checkpoint is missing; nothing below
# (including creation of the output directory) runs without it.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Population model not found: {MODEL_PATH}")

# --- Outputs --------------------------------------------------------------
OUT_DIR = os.path.join(PREPROC_ROOT, "ANN_vs_tms_fmri")
OUT_PKL = os.path.join(OUT_DIR, "dataset_simulated_populationANN.pkl")
RESULTS_JSON = os.path.join(OUT_DIR, "deltafc_validation_results.json")
os.makedirs(OUT_DIR, exist_ok=True)

# Echo the resolved locations so a wrong mount/path is visible immediately.
print(f"✓ Dataset: {DATASET_EMP_PKL}")
print(f"✓ Model: {MODEL_PATH}")
print(f"✓ Output: {OUT_PKL}")

In [None]:
# =========================
# CONFIG
# =========================
# --- Simulation hyper-parameters ---
S = 3                          # Window length (steps)
N = 450                        # Number of ROIs
TR_MODEL = 2.0                 # Model TR (seconds)
BURN_IN = 10                   # Burn-in steps
NOISE_SIGMA = 0.3              # Input noise magnitude
STIM_AMP = 1.0                 # Stimulation amplitude
RHO_MM = 10.0                  # Gaussian spread (mm)
MAP_MODE = "round"             # Onset mapping mode

# Single seeded generator shared by every noisy prediction in the notebook.
rng = np.random.default_rng(42)

# --- Spatial stimulation kernel ---
# W[i, j] = exp(-D[i, j]^2 / (2 * RHO_MM^2)), stored as float32.
DIST_PATH = os.path.join(BASE, "TMS_fMRI", "atlases", "distance_matrix_450x450_Tian50_Schaefer400.npy")
D = np.load(DIST_PATH)
_two_rho_sq = 2.0 * (RHO_MM ** 2)
W = np.exp(-(D ** 2) / _two_rho_sq).astype(np.float32)

# Divide each row by its own diagonal entry so the targeted ROI gets weight 1;
# the small epsilon guards against a zero diagonal.
_diag = np.arange(N)
W /= (W[_diag, _diag][:, None] + 1e-8)

print(f"Config: S={S}, N={N}, TR={TR_MODEL}s, noise={NOISE_SIGMA}, stim_amp={STIM_AMP}")
print(f"Distance matrix: {D.min():.1f}-{D.max():.1f} mm | RHO_MM={RHO_MM}")

In [None]:
# =========================
# HELPER FUNCTIONS
# =========================

def get_onset_column(df):
    """Return the name of the column that holds stimulus onsets, or None.

    Well-known onset column names are tried in a fixed priority order. If
    none of them is present, the first column with a numeric dtype is used
    as a fallback. ``None`` is returned for a missing/empty frame or when
    no candidate column exists.
    """
    if df is None or len(df) == 0:
        return None

    preferred_names = (
        "onset", "Onset", "stim_onset", "onset_s",
        "onset_sec", "time", "t", "seconds",
    )
    by_name = next((name for name in preferred_names if name in df.columns), None)
    if by_name is not None:
        return by_name

    # No recognised name: fall back to the first numeric column, if any.
    return next(
        (name for name in df.columns if pd.api.types.is_numeric_dtype(df[name])),
        None,
    )

def map_onsets_to_steps(onsets_s, tr_model=TR_MODEL, mode=MAP_MODE):
    """Map stimulus onsets (seconds) to unique, non-negative model step indices.

    Args:
        onsets_s: Scalar or array-like of onset times in seconds.
        tr_model: Duration of one model step in seconds.
        mode: How a fractional step is converted to an integer:
            ``"round"`` (nearest, ties to even via ``np.rint``),
            ``"floor"`` or ``"ceil"``.

    Returns:
        Sorted 1-D integer array of unique step indices >= 0.

    Raises:
        ValueError: If ``mode`` is not one of round|floor|ceil.
    """
    onsets = np.asarray(onsets_s, dtype=float).ravel()
    # Drop NaN/inf onsets (e.g. missing values in an events table): casting
    # them to int is undefined and yields platform-dependent step indices.
    onsets = onsets[np.isfinite(onsets)]
    x = onsets / float(tr_model)

    rounders = {"round": np.rint, "floor": np.floor, "ceil": np.ceil}
    if mode not in rounders:
        raise ValueError("mode must be round|floor|ceil")

    steps = rounders[mode](x).astype(int)
    # Negative steps would precede the start of the run; np.unique also sorts.
    return np.unique(steps[steps >= 0])

def safe_target_idx(target_vec):
    """Return the index of the single active entry of a one-hot vector.

    The input is cast to int and flattened. It is accepted only when exactly
    one entry is non-zero and that entry equals 1; checking the sum alone
    would also accept vectors such as ``[2, -1, 0]``.

    Args:
        target_vec: Array-like one-hot target vector, or None.

    Returns:
        The position of the 1 as a Python int, or None when the input is
        None, empty, or not a valid one-hot vector.
    """
    if target_vec is None:
        return None
    v = np.asarray(target_vec).astype(int).ravel()
    hot = np.flatnonzero(v)
    if hot.size != 1 or v[hot[0]] != 1:
        return None
    return int(hot[0])

@torch.no_grad()
def predict_next(model, window_SxN):
    """Run one noisy forward pass of ``model`` on a flattened input window.

    The window is flattened to a single float32 vector, perturbed with
    zero-mean Gaussian noise of standard deviation ``NOISE_SIGMA`` drawn from
    the module-level ``rng``, and fed to the model as a batch of one.

    Returns:
        The model output for that single sample as a NumPy array on the CPU,
        with the batch axis removed.
    """
    flat = window_SxN.reshape(-1).astype(np.float32)
    jitter = NOISE_SIGMA * rng.normal(0.0, 1.0, size=flat.shape).astype(np.float32)
    noisy = flat + jitter

    # Add a leading batch axis and move the sample to the model's device.
    batch = torch.tensor(noisy[None, :], dtype=torch.float32, device=device)
    prediction = model(batch)
    return prediction.detach().cpu().numpy().squeeze(0)

def simulate_run(model, init_window_SxN, n_steps, stim_steps=None, target_idx=None, W=None):
    """Simulate a time series autoregressively, with optional stimulation.

    The model is rolled forward from ``init_window_SxN``: ``BURN_IN``
    unrecorded steps first, then ``n_steps`` recorded steps. At every recorded
    step listed in ``stim_steps`` the most recent row of the *input copy* is
    perturbed before prediction; the rolling window itself only ever contains
    unperturbed initial values and model outputs.

    Args:
        model: Network passed through to ``predict_next``.
        init_window_SxN: Array-like of shape ``(S, N)`` seeding the window.
        n_steps: Number of recorded output steps.
        stim_steps: Iterable (list, set, ndarray, ...) of integer step
            indices at which to stimulate; None or empty disables stimulation.
        target_idx: Index of the stimulated ROI; None disables stimulation.
        W: Optional matrix whose row ``target_idx`` gives the spatial profile
            added to all ROIs (scaled by ``STIM_AMP``). When None, only
            ``target_idx`` itself receives ``STIM_AMP``.

    Returns:
        Tuple ``(out, meta_sim)`` where ``out`` is a float32 array of shape
        ``(n_steps, N)`` and ``meta_sim`` is a dict of simulation settings.

    Raises:
        ValueError: If ``init_window_SxN`` does not have shape ``(S, N)``.
    """
    init_window_SxN = np.asarray(init_window_SxN, dtype=np.float32)
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if init_window_SxN.shape != (S, N):
        raise ValueError(
            f"init_window_SxN must have shape ({S}, {N}), got {init_window_SxN.shape}"
        )

    # `stim_steps or []` would raise on a multi-element ndarray (ambiguous
    # truth value), so test for None explicitly.
    stim_steps = set() if stim_steps is None else {int(s) for s in stim_steps}
    do_stim = (target_idx is not None) and (len(stim_steps) > 0)
    w = init_window_SxN.copy()

    # Burn-in: advance the window without recording or stimulating.
    for _ in range(BURN_IN):
        y = predict_next(model, w)
        w = np.vstack([w[1:], y[None, :]])

    # Recorded simulation.
    out = np.zeros((n_steps, N), dtype=np.float32)
    for t in range(n_steps):
        w_in = w.copy()
        if do_stim and (t in stim_steps):
            if W is None:
                w_in[-1, target_idx] += STIM_AMP
            else:
                w_in[-1, :] += STIM_AMP * W[target_idx, :]
        y = predict_next(model, w_in)
        out[t] = y
        # Slide the window using the unperturbed history plus the new output.
        w = np.vstack([w[1:], y[None, :]])

    meta_sim = {
        "tr_model_s": float(TR_MODEL),
        "burn_in_steps": int(BURN_IN),
        "noise_input_sigma": float(NOISE_SIGMA),
        "stim_amp": float(STIM_AMP),
        "stim_steps_modelTR": sorted(stim_steps) if do_stim else [],
        "stim_mapping_mode": MAP_MODE,
    }
    return out, meta_sim

print("✓ Helper functions defined")

In [None]:
# =========================
# LOAD MODEL
# =========================
print(f"Loading model from {MODEL_PATH}...")
model = build_model("MLP", ROI_num=N, using_steps=S).to(device)

# Only the deserialisation step is retried. Keeping load_state_dict() outside
# the try block means a genuine key/shape mismatch is reported as such instead
# of being mislabelled as a weights_only failure and retried.
try:
    state = torch.load(MODEL_PATH, map_location=device, weights_only=True)
except Exception as e:
    # NOTE(security): weights_only=False unpickles arbitrary Python objects and
    # can execute code embedded in the file. Only use checkpoints you trust.
    print(f"weights_only=True failed, using weights_only=False: {e}")
    state = torch.load(MODEL_PATH, map_location=device, weights_only=False)

if isinstance(state, dict):
    # Accept either a bare state_dict or a checkpoint wrapping one.
    model.load_state_dict(state.get("state_dict", state))
elif isinstance(state, torch.nn.Module):
    # The file holds a fully pickled model: use it in place of the fresh one.
    model = state.to(device)
else:
    raise RuntimeError(f"Unexpected checkpoint format: {type(state).__name__}")

model.eval()
print("✓ Model loaded and ready")

In [None]:
# =========================
# LOAD EMPIRICAL DATASET
# =========================
print(f"Loading empirical dataset from {DATASET_EMP_PKL}...")

# NOTE(review): pickle.load can execute code embedded in the file; this is
# acceptable only because the dataset is a locally produced, trusted artefact.
with open(DATASET_EMP_PKL, mode="rb") as emp_file:
    dataset_emp = pickle.load(emp_file)

# The top-level container is keyed by subject, so its length is the subject count.
print(f"✓ Loaded {len(dataset_emp)} subjects")

In [None]:
# =========================
# GENERATE SYNTHETIC DATASET
# =========================
print("Generating synthetic dataset...\n")


def _usable_timeseries(ts):
    """Return True if ts is a 2-D ndarray with N columns and at least S rows.

    The S-row minimum is needed because the first S samples seed the model's
    input window; shorter runs cannot be simulated and are skipped.
    """
    return (
        isinstance(ts, np.ndarray)
        and ts.ndim == 2
        and ts.shape[1] == N
        and ts.shape[0] >= S
    )


def _duration_and_steps(ts, tr_emp):
    """Return (empirical duration in seconds, model steps needed to cover it)."""
    dur_s = ts.shape[0] * tr_emp
    return dur_s, int(math.ceil(dur_s / TR_MODEL))


dataset_sim = {}
n_sim_rest = 0
n_sim_stim = 0

# Iteration order (subjects, then rest runs, then stim runs) determines how
# the shared random stream is consumed, so it must not be reordered.
for sub_id, sub_data in dataset_emp.items():
    dataset_sim[sub_id] = {"task-rest": {}, "task-stim": {}}

    # ---- SIMULATE REST ----
    for run_idx, run in sub_data.get("task-rest", {}).items():
        ts_emp = run.get("time series", None)
        md_emp = run.get("metadata", {}) or {}

        if not _usable_timeseries(ts_emp):
            continue

        # NOTE(review): rest falls back to 2.0 s and stim to 2.4 s when the
        # metadata has no "tr_s" -- presumably the acquisition TRs; confirm.
        tr_emp = float(md_emp.get("tr_s", 2.0))
        dur_s, n_steps = _duration_and_steps(ts_emp, tr_emp)

        # Seed the model with the first S empirical samples of this run.
        init_window = ts_emp[:S].copy()
        sim_ts, meta_sim = simulate_run(model, init_window, n_steps)

        md_out = dict(md_emp)
        md_out.update({
            "simulated": True,
            "duration_emp_s": float(dur_s),
            "n_steps_model": int(n_steps),
            **meta_sim
        })

        dataset_sim[sub_id]["task-rest"][int(run_idx)] = {
            "time series": sim_ts,
            "metadata": md_out
        }
        n_sim_rest += 1

    # ---- SIMULATE STIM ----
    for run_idx, run in sub_data.get("task-stim", {}).items():
        ts_emp = run.get("time series", None)
        md_emp = run.get("metadata", {}) or {}
        target_vec = run.get("target", None)
        events_df = run.get("stim time", None)

        if not _usable_timeseries(ts_emp):
            continue

        # A stim run needs both a valid one-hot target and an onset column.
        target_idx = safe_target_idx(target_vec)
        if target_idx is None:
            continue

        onset_col = get_onset_column(events_df) if isinstance(events_df, pd.DataFrame) else None
        if onset_col is None:
            continue

        onsets_s = events_df[onset_col].astype(float).values
        stim_steps = list(map_onsets_to_steps(onsets_s))

        tr_emp = float(md_emp.get("tr_s", 2.4))
        dur_s, n_steps = _duration_and_steps(ts_emp, tr_emp)

        init_window = ts_emp[:S].copy()
        sim_ts, meta_sim = simulate_run(model, init_window, n_steps,
                                        stim_steps=stim_steps, target_idx=target_idx, W=W)

        md_out = dict(md_emp)
        md_out.update({
            "simulated": True,
            "duration_emp_s": float(dur_s),
            "n_steps_model": int(n_steps),
            "target_idx": int(target_idx),
            **meta_sim
        })

        # Carry the empirical target and event table along with the simulation.
        dataset_sim[sub_id]["task-stim"][int(run_idx)] = {
            "time series": sim_ts,
            "metadata": md_out,
            "target": target_vec,
            "stim time": events_df,
        }
        n_sim_stim += 1

print(f"✓ Generated {n_sim_rest} rest runs, {n_sim_stim} stim runs\n")

In [None]:
# =========================
# SAVE SYNTHETIC DATASET
# =========================
print(f"Saving synthetic dataset to {OUT_PKL}...")

# Serialise the whole nested dict in one shot using the newest pickle protocol
# available to this interpreter.
with open(OUT_PKL, mode="wb") as sim_file:
    pickle.dump(dataset_sim, sim_file, pickle.HIGHEST_PROTOCOL)

print("✓ Saved")

# Validation: Subject-Specific ΔFC Analysis

In [None]:
# =========================
# VALIDATION FUNCTIONS
# =========================

def fc_from_timeseries(ts, cortical_only=True, n_subcortical=50):
    """Compute the Pearson functional-connectivity matrix of a time series.

    Args:
        ts: Array-like of shape (time, ROIs).
        cortical_only: When True, the first ``n_subcortical`` columns are
            dropped before correlating.
        n_subcortical: Number of leading columns treated as subcortical
            (default 50: skip Tian 50, keep Schaefer 400).

    Returns:
        float32 array of shape (ROIs_kept, ROIs_kept) with pairwise Pearson
        correlations between columns.
    """
    ts = np.asarray(ts)
    if cortical_only:
        ts = ts[:, n_subcortical:]
    # rowvar=False: columns are the variables (ROIs), rows the observations.
    return np.corrcoef(ts, rowvar=False).astype(np.float32)

def upper_tri_vec(mat, k=1):
    """Flatten the upper triangle of a square matrix into a 1-D vector.

    ``k`` is the diagonal offset passed to ``np.triu_indices``: the default
    ``k=1`` excludes the main diagonal, ``k=0`` includes it. Entries are
    returned in row-major order.
    """
    n_rows = mat.shape[0]
    rows, cols = np.triu_indices(n_rows, k=k)
    return mat[rows, cols]

print("✓ Validation functions defined")

In [None]:
# =========================
# COMPUTE PER-SUBJECT ΔFC
# =========================
print("Computing per-subject ΔFC correlations...\n")


def _mean_cortical_fc(runs):
    """Average the cortical FC matrix over all usable runs.

    A run is usable when its "time series" is a 2-D ndarray with exactly N
    columns, so every FC matrix has the same shape and can be stacked.
    Returns None when no run qualifies.
    """
    fc_list = []
    for run in runs.values():
        ts = run.get("time series", None)
        if isinstance(ts, np.ndarray) and ts.ndim == 2 and ts.shape[1] == N:
            fc_list.append(fc_from_timeseries(ts, cortical_only=True))
    if not fc_list:
        return None
    return np.mean(np.stack(fc_list), axis=0)


results = {
    'subject_deltafc_corr': {},
    'subject_info': {},
}

for sub_id in sorted(dataset_emp.keys()):
    # Check if we have both empirical and simulated data
    if sub_id not in dataset_sim:
        continue

    sub_emp = dataset_emp[sub_id]
    sub_sim = dataset_sim[sub_id]

    rest_runs_emp = sub_emp.get("task-rest", {})
    stim_runs_emp = sub_emp.get("task-stim", {})
    rest_runs_sim = sub_sim.get("task-rest", {})
    stim_runs_sim = sub_sim.get("task-stim", {})

    if len(rest_runs_emp) == 0 or len(stim_runs_emp) == 0:
        continue

    # Run-averaged FC for each (condition, source) pair.
    FC_rest_emp = _mean_cortical_fc(rest_runs_emp)
    FC_rest_sim = _mean_cortical_fc(rest_runs_sim)
    if FC_rest_emp is None or FC_rest_sim is None:
        continue

    FC_stim_emp = _mean_cortical_fc(stim_runs_emp)
    FC_stim_sim = _mean_cortical_fc(stim_runs_sim)
    if FC_stim_emp is None or FC_stim_sim is None:
        continue

    # Compute ΔFC
    deltaFC_emp = FC_stim_emp - FC_rest_emp
    deltaFC_sim = FC_stim_sim - FC_rest_sim

    # Correlate upper triangles
    vec_emp = upper_tri_vec(deltaFC_emp, k=1)
    vec_sim = upper_tri_vec(deltaFC_sim, k=1)

    # A constant ROI time series makes its FC entries NaN. Correlate only the
    # edges that are finite in both vectors so one bad ROI does not turn the
    # whole subject's r into NaN.
    finite = np.isfinite(vec_emp) & np.isfinite(vec_sim)
    if finite.sum() >= 2:
        r = pearsonr(vec_emp[finite], vec_sim[finite])[0]
    else:
        r = float("nan")

    emp_mag = float(np.abs(deltaFC_emp).mean())
    sim_mag = float(np.abs(deltaFC_sim).mean())

    results['subject_deltafc_corr'][sub_id] = r
    results['subject_info'][sub_id] = {
        'n_rest_runs': len(rest_runs_emp),
        'n_stim_runs': len(stim_runs_emp),
        'deltafc_emp_magnitude': emp_mag,
        'deltafc_sim_magnitude': sim_mag,
    }

    print(f"{sub_id}: r_ΔFC = {r:.4f} | emp_mag={emp_mag:.4f} | sim_mag={sim_mag:.4f}")

print(f"\n✓ Computed correlations for {len(results['subject_deltafc_corr'])} subjects")

In [None]:
# =========================
# SUMMARY STATISTICS
# =========================
corrs = np.array(list(results['subject_deltafc_corr'].values()))
# Keep only finite correlations; NaN entries would poison every statistic.
corrs_valid = corrs[np.isfinite(corrs)]

# Fail with a clear message instead of the cryptic "zero-size array to
# reduction operation" error that min()/max() raise on an empty array.
if corrs_valid.size == 0:
    raise ValueError(
        "No subject has a finite ΔFC correlation; cannot compute summary statistics."
    )

print("\n" + "="*70)
print("SUBJECT-SPECIFIC ΔFC VALIDATION RESULTS")
print("="*70)
print(f"\nN subjects: {len(corrs_valid)}")
print(f"Mean r(ΔFC):   {corrs_valid.mean():.4f}")
print(f"Median r(ΔFC): {np.median(corrs_valid):.4f}")
print(f"Std r(ΔFC):    {corrs_valid.std():.4f}")
print(f"Min r(ΔFC):    {corrs_valid.min():.4f}")
print(f"Max r(ΔFC):    {corrs_valid.max():.4f}")
print(f"\nCorrelations by subject:")
# The per-subject listing includes subjects whose r is NaN, unlike the stats above.
for sub_id, r in sorted(results['subject_deltafc_corr'].items()):
    info = results['subject_info'][sub_id]
    print(f"  {sub_id}: r={r:.4f} | emp_ΔFC_mag={info['deltafc_emp_magnitude']:.4f} | sim_ΔFC_mag={info['deltafc_sim_magnitude']:.4f}")

In [None]:
# =========================
# HISTOGRAM: Per-Subject ΔFC Correlations
# =========================
fig, ax = plt.subplots(figsize=(10, 6))

# Distribution of per-subject correlations with mean/median reference lines.
ax.hist(corrs_valid, bins=15, color='steelblue', edgecolor='black', alpha=0.7, linewidth=1.5)
ax.axvline(corrs_valid.mean(), color='red', linestyle='--', linewidth=2.5, label=f'Mean = {corrs_valid.mean():.3f}')
ax.axvline(np.median(corrs_valid), color='green', linestyle='--', linewidth=2.5, label=f'Median = {np.median(corrs_valid):.3f}')

ax.set_xlabel('Per-Subject ΔFC Correlation (r)', fontsize=13)
ax.set_ylabel('Number of Subjects', fontsize=13)
ax.set_title('Subject-Specific ΔFC Validation: Empirical vs. Simulated', fontsize=14, fontweight='bold')
ax.legend(fontsize=12, loc='upper right')
ax.grid(axis='y', alpha=0.3)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.tight_layout()

# The original cell reported "Histogram saved" without writing anything.
# Save before plt.show(), since showing may release the figure.
HIST_PNG = os.path.join(OUT_DIR, "deltafc_correlation_histogram.png")
fig.savefig(HIST_PNG, dpi=150, bbox_inches="tight")
plt.show()

print(f"Histogram saved to {HIST_PNG}")

In [None]:
# =========================
# SAVE VALIDATION RESULTS
# =========================
def _json_safe(obj):
    """Recursively replace non-finite floats with None.

    json.dump would otherwise emit bare NaN/Infinity tokens, which are not
    valid JSON and are rejected by strict parsers.
    """
    if isinstance(obj, dict):
        return {key: _json_safe(value) for key, value in obj.items()}
    if isinstance(obj, float):
        return obj if math.isfinite(obj) else None
    return obj


summary = {
    'n_subjects': len(corrs_valid),
    'mean_r_deltafc': float(corrs_valid.mean()),
    'median_r_deltafc': float(np.median(corrs_valid)),
    'std_r_deltafc': float(corrs_valid.std()),
    'min_r_deltafc': float(corrs_valid.min()),
    'max_r_deltafc': float(corrs_valid.max()),
    # Includes every subject, also those whose correlation is not finite
    # (written as null) and therefore excluded from the statistics above.
    'per_subject_correlations': {k: float(v) for k, v in results['subject_deltafc_corr'].items()},
    'per_subject_info': results['subject_info'],
}

with open(RESULTS_JSON, "w") as f:
    # allow_nan=False guarantees the file is strict JSON.
    json.dump(_json_safe(summary), f, indent=2, allow_nan=False)

print(f"✓ Saved results to {RESULTS_JSON}")
print(f"\n✅ ANALYSIS COMPLETE")