<a href="https://colab.research.google.com/github/grabuffo/BrainStim_ANN_fMRI_HCP/blob/main/notebooks/Simulate_TMS_fMRI_ANN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Simulate TMS-fMRI Sessions with Population ANN

Generate a synthetic TMS-fMRI dataset using the population ANN model trained on task-rest data.

## Workflow
1. Load the empirical dataset and the trained population model
2. For each subject, simulate REST and STIM runs using autoregressive prediction
3. Inject TMS stimulation with spatial spread via a Gaussian kernel
4. Save the synthetic dataset with the same structure as the empirical data

## Setup

In [None]:
# Mount Google Drive at /content/drive; every data path below lives under it.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

import os, sys, pickle, json, math
import numpy as np
import pandas as pd
import torch
from scipy.stats import pearsonr

# Clone repo + add to path
# `!git clone` is an IPython shell escape, so this cell only runs inside a
# notebook kernel, not as a plain Python script.
# NOTE(review): the clone target is relative to the current working directory,
# while REPO_DIR is absolute — this assumes the kernel's cwd is /content.
REPO_DIR = "/content/BrainStim_ANN_fMRI_HCP"
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/grabuffo/BrainStim_ANN_fMRI_HCP.git
else:
    print("Repo already exists ✅")

# Make the repository importable, then pull in the model factory and the
# torch device object that the repository's src.NPI module exposes.
sys.path.append(REPO_DIR)
from src.NPI import build_model, device
print(f"PyTorch device: {device}")

## Define Paths

In [None]:
# Root of the project data on the mounted Drive.
BASE = "/content/drive/MyDrive/Colab Notebooks/Brain_Stim_ANN/data"

# Inputs: the empirical dataset pickle and the tree holding the trained models.
DATASET_EMP_PKL = os.path.join(BASE, "TMS_fMRI", "dataset_tian50_schaefer400_allruns.pkl")
PREPROC_ROOT = os.path.join(BASE, "preprocessed_subjects_tms_fmri")
MODEL_DIR = os.path.join(PREPROC_ROOT, "trained_models_MLP_tms_fmri")
MODEL_PATH = os.path.join(MODEL_DIR, "population_MLP_tms_fmri.pt")

# Stop here when the population checkpoint is missing, before anything is created.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Population model not found: {MODEL_PATH}")

# Output: the directory is created on demand, the pickle is written later.
OUT_DIR = os.path.join(PREPROC_ROOT, "ANN_vs_tms_fmri")
os.makedirs(OUT_DIR, exist_ok=True)
OUT_PKL = os.path.join(OUT_DIR, "dataset_simulated_populationANN.pkl")

# Echo the resolved locations so a wrong Drive layout is visible immediately.
print(f"✓ Dataset: {DATASET_EMP_PKL}")
print(f"✓ Model: {MODEL_PATH}")
print(f"✓ Output: {OUT_PKL}")

## Configuration

In [None]:
# ---- Simulation parameters ----
S = 3                          # Number of past steps fed to the model as input
N = 450                        # Number of ROIs (Tian 50 + Schaefer 400)
TR_MODEL = 2.0                 # Seconds represented by one model step
BURN_IN = 10                   # Model steps run and discarded before recording
NOISE_SIGMA = 0.28             # Scale of the Gaussian noise added to every model input
STIM_AMP = 10.0                # Amplitude added to the input on stimulated steps
STIM_DURATION_S = TR_MODEL     # Length of one TMS pulse (seconds)
RHO_MM = 10.0                  # Width of the Gaussian spatial spread (mm)
MAP_MODE = "round"             # How onsets are quantised to steps (round|floor|ceil)

# Single seeded generator shared by all noise draws.
rng = np.random.default_rng(42)

# ---- Spatial kernel for the TMS spread ----
# W[i, j] = exp(-d_ij^2 / (2 * rho^2)), built from the ROI-to-ROI distance matrix.
DIST_PATH = os.path.join(BASE, "TMS_fMRI", "atlases", "distance_matrix_450x450_Tian50_Schaefer400.npy")
D = np.load(DIST_PATH)
W = np.exp(-(D ** 2) / (2.0 * (RHO_MM ** 2))).astype(np.float32)

# Divide every row by its own diagonal entry so the stimulated ROI gets weight 1;
# the small constant guards against a zero diagonal.
W /= (W[np.arange(N), np.arange(N)].reshape(-1, 1) + 1e-8)

print(f"Config: S={S}, N={N}, TR={TR_MODEL}s, noise_sigma={NOISE_SIGMA}, stim_amp={STIM_AMP}")
print(f"Distance matrix: {D.min():.1f}-{D.max():.1f} mm | RHO_MM={RHO_MM}")
print(f"Gaussian kernel W: shape={W.shape}, range=[{W.min():.4f}, {W.max():.4f}]")

## Helper Functions

In [None]:
def get_onset_column(df):
    """Pick the column of `df` that holds stimulus onsets.

    A fixed list of conventional names is tried in priority order; if none
    is present, the first column with a numeric dtype is used instead.

    Returns:
        The column label, or None when `df` is None, has no rows, or has
        neither a known onset name nor any numeric column.
    """
    if df is None or len(df) == 0:
        return None

    known_names = ("onset", "Onset", "stim_onset", "onset_s", "onset_sec", "time", "t", "seconds")
    by_name = [name for name in known_names if name in df.columns]
    if by_name:
        return by_name[0]

    # Fallback: first numeric column, in the dataframe's own column order.
    for label in df.columns:
        if pd.api.types.is_numeric_dtype(df[label]):
            return label
    return None

def map_onsets_to_steps(onsets_s, tr_model=TR_MODEL, mode=MAP_MODE):
    """Convert stimulus onsets in seconds to model step indices.

    Args:
        onsets_s: onset times in seconds (anything np.asarray accepts).
        tr_model: seconds per model step.
        mode: "round" (np.rint), "floor" or "ceil".

    Returns:
        Sorted array of unique, non-negative integer steps.

    Raises:
        ValueError: if `mode` is not one of round|floor|ceil.
    """
    fractional = np.asarray(onsets_s, dtype=float) / float(tr_model)

    quantisers = (("round", np.rint), ("floor", np.floor), ("ceil", np.ceil))
    quantise = next((fn for name, fn in quantisers if mode == name), None)
    if quantise is None:
        raise ValueError("mode must be round|floor|ceil")

    idx = quantise(fractional).astype(int)
    # Onsets that map before step 0 are discarded; np.unique also sorts.
    return np.unique(idx[idx >= 0])

def _round_half_up(x):
    """Round to the nearest integer, sending exact .5 ties upward."""
    return np.floor(x + 0.5)


def map_onsets_to_steps_with_duration(onsets_s, duration_s=STIM_DURATION_S, tr_model=TR_MODEL, mode=MAP_MODE):
    """Map stimulus onsets and a duration (seconds) to model step indices.

    Each finite onset contributes the half-open range of steps
    [q(onset / tr_model), q((onset + duration_s) / tr_model)), where q is the
    quantiser selected by `mode`. The range is widened to one step whenever it
    would otherwise be empty, so a stimulus is never silently dropped.

    In "round" mode exact .5 ties are rounded upward instead of to the nearest
    even integer (np.rint). With ties-to-even, a pulse whose onset falls
    exactly halfway between two steps could have its start and end rounded in
    opposite directions, producing zero steps or two steps for a one-TR pulse.

    Args:
        onsets_s: onset times in seconds (scalar or array-like); non-finite
            entries are ignored.
        duration_s: stimulus duration in seconds.
        tr_model: seconds per model step.
        mode: "round", "floor" or "ceil".

    Returns:
        Sorted 1-D integer array of unique, non-negative step indices
        (empty integer array when there are no usable onsets).

    Raises:
        ValueError: if `mode` is not one of round|floor|ceil.
    """
    if mode == "round":
        quantise = _round_half_up
    elif mode == "floor":
        quantise = np.floor
    elif mode == "ceil":
        quantise = np.ceil
    else:
        raise ValueError("mode must be round|floor|ceil")

    tr = float(tr_model)
    stim_steps = set()

    for onset in np.asarray(onsets_s, dtype=float).ravel():
        if not np.isfinite(onset):
            # Missing onsets (NaN) carry no timing information.
            continue
        first = int(quantise(onset / tr))
        last = int(quantise((onset + duration_s) / tr))
        # Guarantee at least one stimulated step per pulse.
        last = max(last, first + 1)
        # Steps before the start of the run are not simulated.
        stim_steps.update(range(max(first, 0), last))

    return np.array(sorted(stim_steps), dtype=int)

def safe_target_idx(target_vec):
    """Extract the target region index from a one-hot vector.

    The input is cast to integers and flattened. It is accepted only when
    exactly one entry is non-zero and that entry equals 1; a sum-only check
    would also accept vectors such as [2, -1, 0].

    Args:
        target_vec: array-like one-hot target vector, or None.

    Returns:
        Index of the single 1 entry, or None when `target_vec` is None,
        empty, or not one-hot.
    """
    if target_vec is None:
        return None
    v = np.asarray(target_vec).astype(int).ravel()
    # Exactly one non-zero entry that sums to 1 means that entry is 1.
    if v.size == 0 or np.sum(v != 0) != 1 or v.sum() != 1:
        return None
    return int(np.argmax(v))

@torch.no_grad()
def predict_next(model, window_SxN):
    """Run one noisy forward pass of `model` on a flattened input window.

    The window is flattened to a float32 vector, perturbed with Gaussian
    noise scaled by NOISE_SIGMA (drawn from the module-level `rng`), and
    passed through the model as a batch of one.

    Returns:
        The model output as a NumPy array with the batch axis removed.
    """
    flat = window_SxN.reshape(-1).astype(np.float32)
    jitter = rng.normal(0.0, 1.0, size=flat.shape).astype(np.float32)
    noisy = flat + NOISE_SIGMA * jitter

    batch = torch.tensor(noisy[None, :], dtype=torch.float32, device=device)
    prediction = model(batch)
    return prediction.detach().cpu().numpy().squeeze(0)

def simulate_run(model, init_window_SxN, n_steps, stim_steps=None, target_idx=None, W=None):
    """Simulate a time series autoregressively, with optional TMS stimulation.

    The model is first iterated for BURN_IN steps whose outputs are discarded,
    then for `n_steps` recorded steps. On a stimulated step the perturbation
    is added to the most recent row of the model *input* only; the rolling
    window itself is always advanced with the model's own prediction.

    Args:
        model: trained ANN model, called through `predict_next`.
        init_window_SxN: (S, N) initial window.
        n_steps: number of recorded simulation steps.
        stim_steps: iterable of integer step indices at which to stimulate
            (set, list or NumPy array), counted from the first recorded step,
            i.e. after burn-in. Indices outside [0, n_steps) have no effect.
        target_idx: index of the stimulated region; None disables stimulation.
        W: optional (N, N) spatial kernel. When given, row `target_idx`
            scaled by STIM_AMP is added to all regions; otherwise STIM_AMP is
            added to the target region alone.

    Returns:
        sim_ts: (n_steps, N) float32 simulated time series.
        meta_sim: dict of simulation metadata.

    Raises:
        AssertionError: if the initial window is not of shape (S, N).
    """
    init_window_SxN = np.asarray(init_window_SxN, dtype=np.float32)
    assert init_window_SxN.shape == (S, N), (
        f"init_window_SxN must have shape {(S, N)}, got {init_window_SxN.shape}"
    )

    # `stim_steps or []` would raise for NumPy arrays (ambiguous truth value),
    # so test for None explicitly.
    stim_steps = set() if stim_steps is None else {int(s) for s in stim_steps}
    do_stim = (target_idx is not None) and bool(stim_steps)
    w = init_window_SxN.copy()

    # The kernel row is the same on every stimulated step: compute it once.
    spread = STIM_AMP * W[target_idx, :] if (do_stim and W is not None) else None

    # Burn-in: stabilize initial state
    for _ in range(BURN_IN):
        y = predict_next(model, w)
        w = np.vstack([w[1:], y[None, :]])

    # Simulation loop with optional stimulation
    out = np.zeros((n_steps, N), dtype=np.float32)
    for t in range(n_steps):
        if do_stim and (t in stim_steps):
            # Perturb a copy so the stored window keeps the unperturbed state.
            w_in = w.copy()
            if spread is None:
                # Direct stimulation: single target
                w_in[-1, target_idx] += STIM_AMP
            else:
                # Spatial spread: apply Gaussian kernel
                w_in[-1, :] += spread
        else:
            w_in = w
        y = predict_next(model, w_in)
        out[t] = y
        w = np.vstack([w[1:], y[None, :]])

    meta_sim = {
        "tr_model_s": float(TR_MODEL),
        "burn_in_steps": int(BURN_IN),
        "noise_input_sigma": float(NOISE_SIGMA),
        "stim_amp": float(STIM_AMP),
        "stim_steps_modelTR": sorted(stim_steps) if do_stim else [],
        "stim_mapping_mode": MAP_MODE,
    }
    return out, meta_sim

# Confirmation that this cell ran to the end.
print("✓ Helper functions defined")

## Load Model

In [None]:
print(f"Loading model from {MODEL_PATH}...")
model = build_model("MLP", ROI_num=N, using_steps=S).to(device)

# Only deserialisation is retried. Applying the weights happens outside the
# try block so that a state-dict mismatch surfaces as its own error instead of
# being reported as a weights_only failure and triggering a second load.
try:
    state = torch.load(MODEL_PATH, map_location=device, weights_only=True)
except Exception as e:
    # NOTE: weights_only=False unpickles arbitrary Python objects and can run
    # code embedded in the file. Only use it with a checkpoint you trust.
    print(f"weights_only=True failed, using weights_only=False: {e}")
    state = torch.load(MODEL_PATH, map_location=device, weights_only=False)

if isinstance(state, dict):
    # Either a bare state dict or a checkpoint dict wrapping one.
    model.load_state_dict(state["state_dict"] if "state_dict" in state else state)
else:
    # The file holds a whole pickled model object: use it directly.
    model = state.to(device)

model.eval()
print("✓ Model loaded and ready")

## Load Empirical Dataset

In [None]:
print(f"Loading empirical dataset from {DATASET_EMP_PKL}...")
# NOTE: pickle.load can execute code embedded in the file; only load datasets
# produced by a trusted pipeline.
with open(DATASET_EMP_PKL, "rb") as f:
    dataset_emp = pickle.load(f)

print(f"✓ Loaded {len(dataset_emp)} subjects")
# An empty dataset has no first subject to report; skip the line rather than crash.
if dataset_emp:
    print(f"  First subject: {next(iter(dataset_emp))}")

## Generate Synthetic Dataset

For each subject, simulate REST and STIM runs using the trained model.

In [None]:
print("Generating synthetic dataset...\n")


def _usable_time_series(ts):
    """Return True when `ts` is a 2-D array with N columns and at least S rows."""
    return (
        isinstance(ts, np.ndarray)
        and ts.ndim == 2
        and ts.shape[1] == N
        and ts.shape[0] >= S  # the first S rows seed the simulation window
    )


def _model_steps(ts_emp, md_emp, default_tr_s):
    """Return (empirical duration in seconds, number of model steps covering it)."""
    tr_emp = float(md_emp.get("tr_s", default_tr_s))
    dur_s = ts_emp.shape[0] * tr_emp
    return dur_s, int(math.ceil(dur_s / TR_MODEL))


dataset_sim = {}
n_sim_rest = 0
n_sim_stim = 0
n_skipped = 0  # runs left out because an input was missing or malformed

for sub_id, sub_data in dataset_emp.items():
    dataset_sim[sub_id] = {"task-rest": {}, "task-stim": {}}

    # ---- SIMULATE REST ----
    for run_idx, run in sub_data.get("task-rest", {}).items():
        ts_emp = run.get("time series", None)
        md_emp = run.get("metadata", {}) or {}

        if not _usable_time_series(ts_emp):
            n_skipped += 1
            continue

        # NOTE(review): the fallback TR is 2.0 s here but 2.4 s for stim runs
        # below — confirm that the difference is intentional.
        dur_s, n_steps = _model_steps(ts_emp, md_emp, 2.0)

        # Seed the simulation with the first S empirical samples.
        init_window = ts_emp[:S].copy()
        sim_ts, meta_sim = simulate_run(model, init_window, n_steps)

        md_out = dict(md_emp)
        md_out.update({
            "simulated": True,
            "duration_emp_s": float(dur_s),
            "n_steps_model": int(n_steps),
            **meta_sim
        })

        dataset_sim[sub_id]["task-rest"][int(run_idx)] = {
            "time series": sim_ts,
            "metadata": md_out
        }
        n_sim_rest += 1

    # ---- SIMULATE STIM ----
    for run_idx, run in sub_data.get("task-stim", {}).items():
        ts_emp = run.get("time series", None)
        md_emp = run.get("metadata", {}) or {}
        target_vec = run.get("target", None)
        events_df = run.get("stim time", None)

        if not _usable_time_series(ts_emp):
            n_skipped += 1
            continue

        target_idx = safe_target_idx(target_vec)
        if target_idx is None:
            n_skipped += 1
            continue

        onset_col = get_onset_column(events_df) if isinstance(events_df, pd.DataFrame) else None
        if onset_col is None:
            n_skipped += 1
            continue

        onsets_s = events_df[onset_col].astype(float).values
        # Passed on as a plain list of step indices.
        stim_steps = list(map_onsets_to_steps_with_duration(onsets_s, duration_s=STIM_DURATION_S))

        dur_s, n_steps = _model_steps(ts_emp, md_emp, 2.4)

        init_window = ts_emp[:S].copy()
        sim_ts, meta_sim = simulate_run(model, init_window, n_steps,
                                        stim_steps=stim_steps, target_idx=target_idx, W=W)

        md_out = dict(md_emp)
        md_out.update({
            "simulated": True,
            "duration_emp_s": float(dur_s),
            "n_steps_model": int(n_steps),
            "target_idx": int(target_idx),
            **meta_sim
        })

        dataset_sim[sub_id]["task-stim"][int(run_idx)] = {
            "time series": sim_ts,
            "metadata": md_out,
            "target": target_vec,
            "stim time": events_df,
        }
        n_sim_stim += 1

print(f"✓ Generated {n_sim_rest} rest runs, {n_sim_stim} stim runs\n")
if n_skipped:
    # Make dropped runs visible instead of silently shrinking the dataset.
    print(f"  Skipped {n_skipped} runs with missing or invalid inputs\n")

## Save Synthetic Dataset

In [None]:
print(f"Saving synthetic dataset to {OUT_PKL}...")
# Serialise the whole nested dict in one go with the newest pickle protocol.
with open(OUT_PKL, "wb") as out_file:
    pickle.dump(dataset_sim, out_file, protocol=pickle.HIGHEST_PROTOCOL)
print("✓ Saved")

# Report the on-disk size in decimal gigabytes (bytes / 1e9).
print(f"\nOutput file size: {os.path.getsize(OUT_PKL) / 1e9:.2f} GB")

## Summary

The synthetic dataset has been successfully created with the same hierarchical structure as the empirical data:

```
dataset_sim[sub_id]['task-rest'][run_idx] = {
    'time series': (n_steps, 450),
    'metadata': {...}
}

dataset_sim[sub_id]['task-stim'][run_idx] = {
    'time series': (n_steps, 450),
    'metadata': {...},
    'target': one-hot(450,),
    'stim time': DataFrame
}
```

Ready for validation analysis!