<a href="https://colab.research.google.com/github/grabuffo/BrainStim_ANN_fMRI_HCP/blob/main/notebooks/TMS_fMRI_ANN_Simulate_Sessions_dataset_simulated.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Simulate TMS-fMRI sessions with population ANN (dataset_simulated)

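Creates a `dataset_simulated` dictionary (held in the variable `dataset_sim`) that mirrors the structure of the empirical `dataset`, with each run's time series replaced by an autoregressive simulation from the population ANN, and saves it as a pickle file on Google Drive.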


In [None]:

# =========================
# 0) Mount Google Drive
# =========================
# Colab-only step: mount the user's Google Drive at /content/drive.
# The data paths configured later in this notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive')


In [None]:

# =========================
# 1) Clone repo + imports
# =========================
import os, sys, pickle, json, math
import numpy as np
import pandas as pd
import torch

REPO_DIR = "/content/BrainStim_ANN_fMRI_HCP"
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/grabuffo/BrainStim_ANN_fMRI_HCP.git
else:
    print("Repo already exists âœ…")

sys.path.append(REPO_DIR)

from src.NPI import build_model, device
print("Torch device:", device)


In [None]:

# =========================
# 2) Paths (EDIT IF NEEDED)
# =========================
# Root of the project data on the mounted Drive.
BASE = "/content/drive/MyDrive/Colab Notebooks/Brain_Stim_ANN/data"

# Empirical dataset pickle whose structure the simulation mirrors.
DATASET_EMP_PKL = os.path.join(BASE, "TMS_fMRI", "dataset_tian50_schaefer400_allruns.pkl")

PREPROC_ROOT = os.path.join(BASE, "preprocessed_subjects_tms_fmri")
MODEL_DIR = os.path.join(PREPROC_ROOT, "trained_models_MLP_tms_fmri")

# First existing checkpoint wins (".pt" is tried before ".pth");
# MODEL_PATH stays None when neither file is present.
MODEL_PATH = next(
    (
        os.path.join(MODEL_DIR, fname)
        for fname in ("population_MLP_tms_fmri.pt", "population_MLP_tms_fmri.pth")
        if os.path.exists(os.path.join(MODEL_DIR, fname))
    ),
    None,
)

# Output folder is created on demand; the simulated dataset is written there.
OUT_DIR = os.path.join(PREPROC_ROOT, "ANN_vs_tms_fmri")
os.makedirs(OUT_DIR, exist_ok=True)
OUT_PKL = os.path.join(OUT_DIR, "dataset_simulated_populationANN.pkl")

print("Empirical dataset:", DATASET_EMP_PKL, "| exists:", os.path.exists(DATASET_EMP_PKL))
print("Model:", MODEL_PATH)
print("Will save to:", OUT_PKL)


In [None]:

# =========================
# 3) Load empirical dataset
# =========================
# Read the empirical dataset from DATASET_EMP_PKL.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point DATASET_EMP_PKL at a pickle you created or otherwise trust.
with open(DATASET_EMP_PKL, "rb") as f:
    dataset_emp = pickle.load(f)

# The top-level length of the loaded object is reported as the subject count.
print("Loaded subjects:", len(dataset_emp))


In [None]:

# =========================
# 4) Config
# =========================
# Model input geometry. The network is built with using_steps=S and ROI_num=N,
# and every simulation window is required to have shape (S, N).
# NOTE(review): N = 450 presumably corresponds to the Tian-50 + Schaefer-400
# parcellation named in the dataset file — confirm.
S = 3
N = 450

TR_MODEL = 2.0          # model step in seconds (trained on rest TR=2s)
BURN_IN = 100           # steps (not saved)
NOISE_SIGMA = 0.01      # z-scored units
STIM_AMP = 0.1          # z-scored units

MAP_MODE = "round"      # "round" | "floor" | "ceil" for mapping onset_s -> model steps

# Seeded generator that supplies the Gaussian noise added to each predicted
# step; the fixed seed makes a full top-to-bottom run reproducible.
rng = np.random.default_rng(0)


In [None]:

# =========================
# 5) Helpers
# =========================
def get_onset_column(df: pd.DataFrame):
    """Pick the column of an events table that holds the stimulus onsets.

    A fixed list of conventional column names is tried first, in priority
    order. If none of them is present, the first column with a numeric dtype
    is used instead.

    Args:
        df: Events table, or None.

    Returns:
        The chosen column label, or None when `df` is None, has no rows, or
        contains neither a known name nor a numeric column.
    """
    if df is None or len(df) == 0:
        return None

    known_names = ("onset", "Onset", "stim_onset", "onset_s", "onset_sec", "time", "t", "seconds")
    available = df.columns

    by_name = next((name for name in known_names if name in available), None)
    if by_name is not None:
        return by_name

    # Fallback: first numeric column, in the table's own column order.
    return next(
        (col for col in available if pd.api.types.is_numeric_dtype(df[col])),
        None,
    )

def map_onsets_to_steps(onsets_s, tr_model=TR_MODEL, mode=MAP_MODE):
    """Convert stimulus onsets in seconds to unique, sorted model-step indices.

    Each onset is divided by `tr_model` and snapped to an integer step with
    the rule selected by `mode`. Negative steps are discarded and duplicates
    are merged.

    Args:
        onsets_s: Scalar or array-like of onset times in seconds.
        tr_model: Duration of one model step in seconds; must be positive.
        mode: "round" (nearest, via np.rint), "floor" or "ceil".

    Returns:
        1-D integer NumPy array of non-negative step indices, sorted ascending.

    Raises:
        ValueError: If `mode` is not one of the three supported values, or if
            `tr_model` is not a positive number.
    """
    snap = {"round": np.rint, "floor": np.floor, "ceil": np.ceil}
    if mode not in snap:
        raise ValueError("mode must be round|floor|ceil")

    tr = float(tr_model)
    if not tr > 0:
        raise ValueError("tr_model must be a positive number of seconds")

    x = np.asarray(onsets_s, dtype=float).ravel() / tr
    # NaN/inf have no defined integer value: casting them yields
    # platform-dependent garbage that may survive the >= 0 filter, so drop
    # them before converting to int.
    x = x[np.isfinite(x)]

    steps = snap[mode](x).astype(int)
    return np.unique(steps[steps >= 0])

def safe_target_idx(target_vec):
    """Return the index of the stimulated region encoded by `target_vec`.

    The input is cast to integers and flattened. It is accepted only when it
    is non-empty and its entries sum to exactly 1; the position of the largest
    entry is then returned.

    Args:
        target_vec: Array-like target indicator, or None.

    Returns:
        The index as a Python int, or None when the input is None, empty, or
        its integer-cast entries do not sum to 1.
    """
    if target_vec is None:
        return None

    flat = np.asarray(target_vec).astype(int).ravel()
    has_single_hit = flat.size > 0 and flat.sum() == 1
    return int(flat.argmax()) if has_single_hit else None

def predict_next(model, window_SxN: np.ndarray):
    """Run one forward pass of `model` on a flattened window.

    The window is reshaped to a single row (batch of one), converted to a
    float32 tensor on `device`, and passed through the model with gradient
    tracking disabled.

    Args:
        model: Callable network taking a (1, S*N)-shaped tensor.
        window_SxN: NumPy array holding the current input window.

    Returns:
        The model output as a NumPy array with the leading batch axis removed.
    """
    flat_input = window_SxN.reshape(1, -1)
    with torch.no_grad():
        batch = torch.tensor(flat_input, dtype=torch.float32, device=device)
        prediction = model(batch)
    return prediction.detach().cpu().numpy().squeeze(0)

def simulate_run(model, init_window_SxN, n_steps, stim_steps=None, target_idx=None,
                 stim_amp=STIM_AMP, noise_sigma=NOISE_SIGMA, burn_in=BURN_IN):
    """Roll the model forward autoregressively and return the simulated series.

    At every step the window of the last S states is fed to `predict_next`,
    optional Gaussian noise is added to the prediction, and the window is
    shifted by one (oldest state dropped, new prediction appended). The first
    `burn_in` predictions are discarded; the following `n_steps` are recorded.

    Stimulation: at each recorded step `t` contained in `stim_steps`,
    `stim_amp` is added to region `target_idx` of the most recent state in a
    copy of the window used for that single prediction. The rolling window
    itself keeps the unperturbed state.

    Args:
        model: Network passed through to `predict_next`.
        init_window_SxN: Array-like of shape (S, N) used to seed the window.
        n_steps: Number of steps to record after burn-in.
        stim_steps: Optional iterable (list, set, NumPy array, ...) of step
            indices, relative to the first recorded step, at which to
            stimulate. None or empty means no stimulation.
        target_idx: Index of the stimulated region; None disables stimulation.
        stim_amp: Amplitude added to the target region at stimulated steps.
        noise_sigma: Standard deviation of the additive noise; values <= 0
            disable noise.
        burn_in: Number of initial steps that are simulated but not recorded.

    Returns:
        Tuple `(out, meta_sim)` where `out` is a float32 array of shape
        (n_steps, N) and `meta_sim` is a dict describing the simulation
        settings, including the stimulated steps actually requested.

    Raises:
        ValueError: If `init_window_SxN` does not have shape (S, N).
    """
    init_window_SxN = np.asarray(init_window_SxN, dtype=np.float32)
    if init_window_SxN.shape != (S, N):
        raise ValueError(
            f"init_window_SxN must have shape {(S, N)}, got {init_window_SxN.shape}"
        )

    # Callers pass the NumPy array returned by map_onsets_to_steps. An array
    # must never be truth-tested (`stim_steps or []`): that raises for more
    # than one element and treats a lone step 0 as "no stimulation". Compare
    # against None and iterate explicitly instead.
    if stim_steps is None:
        requested = ()
    elif isinstance(stim_steps, np.ndarray):
        requested = stim_steps.ravel().tolist()
    else:
        requested = stim_steps
    stim_steps = {int(s) for s in requested}
    do_stim = (target_idx is not None) and (len(stim_steps) > 0)

    def _advance(window_in):
        """Predict the next state from `window_in`, adding noise if enabled."""
        y = predict_next(model, window_in)
        if noise_sigma > 0:
            y = y + rng.normal(0, noise_sigma, size=y.shape).astype(np.float32)
        return y

    w = init_window_SxN.copy()

    # Burn-in: let the dynamics move away from the seed window; nothing saved.
    for _ in range(burn_in):
        y = _advance(w)
        w = np.vstack([w[1:], y[None, :]])

    # Recorded session.
    out = np.zeros((n_steps, N), dtype=np.float32)
    for t in range(n_steps):
        w_in = w.copy()
        if do_stim and (t in stim_steps):
            # Perturb only the input of this prediction, not the stored window.
            w_in[-1, target_idx] += stim_amp
        y = _advance(w_in)
        out[t] = y
        w = np.vstack([w[1:], y[None, :]])

    meta_sim = {
        "tr_model_s": float(TR_MODEL),
        "burn_in_steps": int(burn_in),
        "noise_sigma": float(noise_sigma),
        "stim_amp": float(stim_amp),
        "stim_steps_modelTR": sorted(stim_steps) if do_stim else [],
        "stim_mapping_mode": MAP_MODE,
    }
    return out, meta_sim


In [None]:

# =========================
# 6) Load model (robust to PyTorch 2.6+)
# =========================
# Abort early when the path cell found no checkpoint.
if MODEL_PATH is None:
    raise FileNotFoundError("Could not find population model checkpoint in MODEL_DIR.")

METHOD = "MLP"
model = build_model(METHOD, ROI_num=N, using_steps=S).to(device)

# A dict checkpoint either wraps the weights under "state_dict" or is the
# weights mapping itself; `.get("state_dict", state)` covers both layouts.
try:
    # Preferred path: restricted unpickling.
    state = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    if not isinstance(state, dict):
        raise RuntimeError("weights_only=True returned non-dict; falling back.")
    model.load_state_dict(state.get("state_dict", state))
    print("Loaded weights with weights_only=True")
except Exception as err:
    print("weights_only=True failed; using weights_only=False")
    print("Reason:", repr(err))
    # NOTE(review): weights_only=False performs a full pickle load, which can
    # execute arbitrary code from the file — only use with a trusted checkpoint.
    state = torch.load(MODEL_PATH, map_location=device, weights_only=False)
    if isinstance(state, dict):
        model.load_state_dict(state.get("state_dict", state))
    else:
        # The file stored a whole model object rather than a weights mapping.
        model = state.to(device)

model.eval()
print("Model ready.")


In [None]:

# =========================
# 7) Build dataset_simulated
# =========================
dataset_sim = {}
n_sim_runs = 0
n_skipped_runs = 0  # runs left out because their inputs were unusable


def _usable_time_series(ts):
    """Return True if `ts` is a 2-D array with N columns and at least S rows.

    Fewer than S rows cannot seed the (S, N) initial window, and a non-2-D
    array has no well-defined region axis, so such runs are skipped instead
    of crashing the whole loop.
    """
    return (
        isinstance(ts, np.ndarray)
        and ts.ndim == 2
        and ts.shape[1] == N
        and ts.shape[0] >= S
    )


def _simulated_metadata(md_emp, task, run_idx, dur_s, n_steps, meta_sim, **extra):
    """Copy the empirical metadata and add the simulation bookkeeping fields."""
    md_out = dict(md_emp)
    md_out.update({
        "simulated": True,
        "source_empirical_task": task,
        "source_run_idx": int(run_idx),
        "duration_emp_s": float(dur_s),
        "n_steps_model": int(n_steps),
        **extra,
        **meta_sim,
    })
    return md_out


for sub_id, sub_data in dataset_emp.items():
    dataset_sim[sub_id] = {"task-rest": {}, "task-stim": {}}

    # ---- REST ----
    for run_idx, run in (sub_data.get("task-rest") or {}).items():
        ts_emp = run.get("time series", None)
        md_emp = run.get("metadata", {}) or {}
        if not _usable_time_series(ts_emp):
            n_skipped_runs += 1
            continue

        # Cover the same wall-clock duration as the empirical run, expressed
        # in model steps of TR_MODEL seconds.
        tr_emp = float(md_emp.get("tr_s", 2.0))
        dur_s = ts_emp.shape[0] * tr_emp
        n_steps = int(math.ceil(dur_s / TR_MODEL))

        # Seed the simulation with the first S empirical frames.
        init_window = ts_emp[:S].copy()
        sim_ts, meta_sim = simulate_run(model, init_window, n_steps)

        dataset_sim[sub_id]["task-rest"][int(run_idx)] = {
            "time series": sim_ts,
            "metadata": _simulated_metadata(
                md_emp, "task-rest", run_idx, dur_s, n_steps, meta_sim
            ),
        }
        n_sim_runs += 1

    # ---- STIM ----
    for run_idx, run in (sub_data.get("task-stim") or {}).items():
        ts_emp = run.get("time series", None)
        md_emp = run.get("metadata", {}) or {}
        target_vec = run.get("target", None)
        events_df = run.get("stim time", None)

        if not _usable_time_series(ts_emp):
            n_skipped_runs += 1
            continue

        target_idx = safe_target_idx(target_vec)
        if target_idx is None:
            n_skipped_runs += 1
            continue

        onset_col = get_onset_column(events_df) if isinstance(events_df, pd.DataFrame) else None
        if onset_col is None:
            n_skipped_runs += 1
            continue

        onsets_s = events_df[onset_col].astype(float).values
        # Hand the steps over as a plain list of ints: a NumPy array must not
        # be truth-tested downstream (ambiguous for more than one element).
        stim_steps = map_onsets_to_steps(onsets_s, tr_model=TR_MODEL, mode=MAP_MODE).tolist()

        # NOTE(review): the fallback TR here is 2.4 s, whereas rest runs fall
        # back to 2.0 s — presumably the acquisition TRs differ; confirm.
        tr_emp = float(md_emp.get("tr_s", 2.4))
        dur_s = ts_emp.shape[0] * tr_emp
        n_steps = int(math.ceil(dur_s / TR_MODEL))

        init_window = ts_emp[:S].copy()
        sim_ts, meta_sim = simulate_run(model, init_window, n_steps,
                                        stim_steps=stim_steps, target_idx=target_idx)

        dataset_sim[sub_id]["task-stim"][int(run_idx)] = {
            "time series": sim_ts,
            "metadata": _simulated_metadata(
                md_emp, "task-stim", run_idx, dur_s, n_steps, meta_sim,
                target_idx=int(target_idx),
            ),
            "target": target_vec,
            "stim time": events_df,
        }
        n_sim_runs += 1

print("Simulated runs:", n_sim_runs)
print("Skipped runs:", n_skipped_runs)


In [None]:

# =========================
# 8) Save to Drive
# =========================
# Write the simulated dataset to OUT_PKL using the newest pickle protocol.
with open(OUT_PKL, "wb") as f:
    pickle.dump(dataset_sim, f, protocol=pickle.HIGHEST_PROTOCOL)

print("Saved dataset_simulated to:", OUT_PKL)

# Quick peek
some_sub = next(iter(dataset_sim), None)
# Test the dict for emptiness instead of the key's truthiness, so a falsy but
# valid subject key (e.g. 0 or "") is still previewed.
if dataset_sim:
    print("Example subject:", some_sub)
    print("task-rest runs:", list(dataset_sim[some_sub]["task-rest"].keys())[:5])
    print("task-stim runs:", list(dataset_sim[some_sub]["task-stim"].keys())[:5])
    if dataset_sim[some_sub]["task-stim"]:
        r0 = next(iter(dataset_sim[some_sub]["task-stim"].keys()))
        md0 = dataset_sim[some_sub]["task-stim"][r0]["metadata"]
        print("Example stim target_idx:", md0.get("target_idx"))
        print("Example stim mapped steps (first 10):", md0.get("stim_steps_modelTR", [])[:10])
