<a href="https://colab.research.google.com/github/grabuffo/BrainStim_ANN_fMRI_HCP/blob/main/notebooks/TMS_fMRI_ANN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TMS-fMRI ANN stimulation prediction (population model)

This notebook:
- loads the **population ANN** trained on task-rest
- loads the original TMS-fMRI `dataset_tian50_schaefer400_allruns.pkl`
- loops over **task-stim** runs and **stimulation events**
- extracts the **pre-stim window** (S=3 volumes)
- predicts the next state with the ANN:
  - **baseline** (no perturbation)
  - **perturbed** (+`PERTURB_AMP` z-scored units — currently 0.5, see Config — added to the stimulated parcel at the last time point of the window)
- also computes a 2-step autoregressive rollout (saved so the comparison horizon can be chosen later, since the training and stimulation runs have different TRs)
- saves results to:

`/content/drive/MyDrive/Colab Notebooks/Brain_Stim_ANN/data/preprocessed_subjects_tms_fmri/ANN_vs_tms_fmri/`


In [1]:

# =========================
# 0) Mount Google Drive
# =========================
# Colab-only: exposes the user's Drive under /content/drive, where every
# input/output path used by the cells below lives.
from google.colab import drive

drive.mount("/content/drive")


Mounted at /content/drive


In [2]:

# =========================
# 1) Clone repo + imports
# =========================
import os, sys, pickle, json
import subprocess
import numpy as np
import pandas as pd
import torch

REPO_URL = "https://github.com/grabuffo/BrainStim_ANN_fMRI_HCP.git"
REPO_DIR = "/content/BrainStim_ANN_fMRI_HCP"

if not os.path.exists(REPO_DIR):
    # Clone explicitly into the absolute REPO_DIR. A bare `!git clone URL`
    # clones into the current working directory, which only equals REPO_DIR's
    # parent when cwd happens to be /content. Fail loudly on git errors instead
    # of letting the `src.NPI` import below fail with a confusing message.
    result = subprocess.run(
        ["git", "clone", REPO_URL, REPO_DIR],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        raise RuntimeError(f"git clone failed:\n{result.stderr}")
    print("Cloned repo into", REPO_DIR)
else:
    print("Repo already exists ✅")

# Prepend (once) so this repo's `src` package takes precedence over any other
# `src` on sys.path, and re-running the cell does not add duplicate entries.
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

from src.NPI import build_model, device  # uses GPU if available
print("Torch device:", device)


Cloning into 'BrainStim_ANN_fMRI_HCP'...
remote: Enumerating objects: 494, done.[K
remote: Counting objects: 100% (144/144), done.[K
remote: Compressing objects: 100% (134/134), done.[K
remote: Total 494 (delta 64), reused 10 (delta 10), pack-reused 350 (from 2)[K
Receiving objects: 100% (494/494), 63.21 MiB | 13.60 MiB/s, done.
Resolving deltas: 100% (172/172), done.
Torch device: cuda:0


In [3]:

# =========================
# 2) Paths (EDIT IF NEEDED)
# =========================
BASE = "/content/drive/MyDrive/Colab Notebooks/Brain_Stim_ANN/data"

# Input: the original dictionary containing both rest and stim runs.
DATASET_PKL = os.path.join(BASE, "TMS_fMRI", "dataset_tian50_schaefer400_allruns.pkl")

# Root holding the preprocessed rest data and the trained population model.
PREPROC_ROOT = os.path.join(BASE, "preprocessed_subjects_tms_fmri")

# Output folder for the ANN-vs-empirical stimulation dictionaries (created if missing).
OUT_DIR = os.path.join(PREPROC_ROOT, "ANN_vs_tms_fmri")
os.makedirs(OUT_DIR, exist_ok=True)

# Population model: the first candidate file that exists on disk is used
# (adjust the stem/extensions if your filename differs).
# NOTE(review): ".ptc" is an unusual extension — confirm it is intended.
MODEL_DIR = os.path.join(PREPROC_ROOT, "trained_models_MLP_tms_fmri")
MODEL_PATH_CANDIDATES = [
    os.path.join(MODEL_DIR, "population_MLP_tms_fmri" + ext)
    for ext in (".pt", ".pth", ".ptc")
]
MODEL_PATH = next(filter(os.path.exists, MODEL_PATH_CANDIDATES), None)

print("Dataset PKL:", DATASET_PKL, "| exists:", os.path.exists(DATASET_PKL))
print("Model dir   :", MODEL_DIR, "| exists:", os.path.exists(MODEL_DIR))
print("Model path  :", MODEL_PATH)


Dataset PKL: /content/drive/MyDrive/Colab Notebooks/Brain_Stim_ANN/data/TMS_fMRI/dataset_tian50_schaefer400_allruns.pkl | exists: True
Model dir   : /content/drive/MyDrive/Colab Notebooks/Brain_Stim_ANN/data/preprocessed_subjects_tms_fmri/trained_models_MLP_tms_fmri | exists: True
Model path  : /content/drive/MyDrive/Colab Notebooks/Brain_Stim_ANN/data/preprocessed_subjects_tms_fmri/trained_models_MLP_tms_fmri/population_MLP_tms_fmri.pt


In [4]:

# =========================
# 3) Config
# =========================
# Input window length fed to the ANN; passed as `using_steps` when the model
# is built, so it must match training.
S: int = 3
# Parcels per volume (Tian50 + Schaefer400).
N: int = 450
# Amount (z-scored units) added to the targeted parcel at the last step of the window.
PERTURB_AMP: float = 0.5
# Number of empirical post-stimulus volumes stored per event (t+1 .. t+MAX_EMP_HORIZON).
MAX_EMP_HORIZON: int = 3
# Also compute and store the ANN 2-step rollout (recommended).
SAVE_2STEP: bool = True

# TR mismatch:
#   * model trained on rest runs with TR ~= 2.0 s
#   * stimulation runs may have TR ~= 2.4 s
# Both 1-step and 2-step predictions are saved so the comparison can be chosen later.


In [5]:

# =========================
# 4) Load dataset dict + model
# =========================
import warnings

# Install the filter BEFORE unpickling: previously it was set after
# `pickle.load`, so it could not silence warnings raised during that load.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# NOTE(security): pickle.load can execute arbitrary code stored in the file;
# only load datasets you produced yourself or otherwise trust.
with open(DATASET_PKL, "rb") as f:
    dataset = pickle.load(f)

print("Loaded subjects:", len(dataset))

if MODEL_PATH is None:
    raise FileNotFoundError("Could not find population model file in MODEL_DIR. Please update MODEL_PATH.")

# Build model (must match training)
METHOD = "MLP"
model = build_model(METHOD, ROI_num=N, using_steps=S).to(device)

# ---- Robust load (PyTorch 2.6+ safe default change) ----
# Only the *deserialization* is retried with weights_only=False. Errors raised
# while applying a state_dict (key/shape mismatch) now propagate directly
# instead of silently triggering a second, unsafe load of the same file.
try:
    state = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    load_mode = "weights_only=True"
except Exception as e:
    print("weights_only=True failed, falling back to weights_only=False.")
    print("Reason:", repr(e))
    # Full pickle load: may execute code from the checkpoint. Trust your own file only.
    state = torch.load(MODEL_PATH, map_location=device, weights_only=False)
    load_mode = "weights_only=False"

if isinstance(state, dict) and "state_dict" in state:
    model.load_state_dict(state["state_dict"])
    print(f"Loaded state['state_dict'] ({load_mode}).")
elif isinstance(state, dict):
    model.load_state_dict(state)
    print(f"Loaded state_dict ({load_mode}).")
elif isinstance(state, torch.nn.Module):
    # Checkpoint was a pickled full model object; it replaces the freshly built one.
    model = state.to(device)
    print(f"Loaded full model object ({load_mode}).")
else:
    raise TypeError(f"Unsupported checkpoint content: {type(state).__name__}")

model.eval()
print("Model ready.")


  dataset = pickle.load(f)


Loaded subjects: 46
weights_only=True failed, falling back to weights_only=False.
Reason: UnpicklingError('Weights only load failed. This file can still be loaded, to do so you have two options, \x1b[1mdo those steps only if you trust the source of the checkpoint\x1b[0m. \n\t(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.\n\t(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.\n\tWeightsUnpickler error: Unsupported global: GLOBAL src.NPI.ANN_MLP was not an allowed global by default. Please use `torch.serialization.add_safe_globals([src.NPI.ANN_MLP])` or the `torch.serialization.safe_globals([src.NPI.ANN_MLP])` context manager to allowlist this global if you trust this clas

In [6]:
# =========================
# 5) Helpers
# =========================
def get_onset_column(df: pd.DataFrame):
    """Pick the column holding stimulus onsets (seconds) in an events table.

    A known onset-like column name is preferred. Otherwise the first column
    with a numeric dtype is used. Returns None when `df` is None, has no rows,
    or contains no qualifying column.
    """
    if df is None or len(df) == 0:
        return None
    preferred = ("onset", "Onset", "stim_onset", "t", "time", "onset_s", "onset_sec", "seconds")
    named = next((name for name in preferred if name in df.columns), None)
    if named is not None:
        return named
    # No known name: fall back to the first numeric column.
    numeric_cols = [col for col in df.columns if pd.api.types.is_numeric_dtype(df[col])]
    return numeric_cols[0] if numeric_cols else None

@torch.no_grad()
def predict_from_window(window_SxN: np.ndarray):
    """One-step ANN prediction from an input window.

    The (S, N) window is flattened row by row into a single (1, S*N) float32
    sample on `device`. It is passed through the module-level `model`, and the
    (N,) output is returned as a numpy array.
    """
    flat_input = window_SxN.reshape(1, -1)
    batch = torch.tensor(flat_input, dtype=torch.float32, device=device)  # (1, S*N)
    output = model(batch)  # (1, N)
    return output.detach().cpu().numpy().squeeze(0)

def rollout_2step(window_SxN: np.ndarray):
    """Two-step autoregressive rollout.

    Predicts t+1 from the window, then drops the oldest row, appends that
    prediction as the newest row, and predicts t+2.
    Returns (pred_t1, pred_t2), each of shape (N,).
    """
    first = predict_from_window(window_SxN)
    shifted = np.concatenate((window_SxN[1:], first[np.newaxis, :]), axis=0)
    second = predict_from_window(shifted)
    return first, second

def safe_target_idx(target_vec):
    """Return the 0-based index of the stimulated parcel from a one-hot vector.

    The vector is flattened and cast to int. It is accepted only if it has
    exactly one non-zero entry and that entry equals 1. The previous
    `sum() == 1` check also accepted non-one-hot vectors such as [1, 1, -1]
    or [2, -1] and returned a meaningless argmax.

    Parameters
    ----------
    target_vec : array-like or None
        Per-parcel target indicator (any shape; flattened).

    Returns
    -------
    int or None
        Index of the single 1, or None for None/empty/all-zero/multi-target
        or otherwise malformed input.
    """
    if target_vec is None:
        return None
    v = np.asarray(target_vec).astype(int).ravel()
    nonzero = np.flatnonzero(v)
    if nonzero.size != 1 or v[nonzero[0]] != 1:
        return None
    return int(nonzero[0])


In [7]:
# =========================
# 6) Main loop: build stimulation dictionaries
# =========================
# stim_effects[sub_id][target_idx] = list of per-event dicts (see `ev` below).
# Every subject with a "task-stim" entry gets a (possibly empty) dict; a target
# key is created only once at least one of its events is kept.
stim_effects = {}

n_runs = 0
n_runs_skipped = 0    # whole runs dropped by the run-level checks below
n_events_total = 0    # events listed in the event tables of usable runs
n_events_used = 0
n_events_skipped = 0  # single events dropped (non-finite onset / window out of range)


def _predict_t1_t2(window_SxN: np.ndarray):
    """ANN prediction(s) for one (S, N) window.

    Returns (pred_t1, pred_t2). pred_t2 is the 2-step rollout when SAVE_2STEP
    is True, otherwise None.
    """
    if SAVE_2STEP:
        return rollout_2step(window_SxN)
    return predict_from_window(window_SxN), None


for sub_id, sub_data in dataset.items():
    if "task-stim" not in sub_data:
        continue

    stim_effects[sub_id] = {}
    stim_runs = sub_data["task-stim"]  # int-keyed dict of runs

    for run_idx, run in stim_runs.items():
        n_runs += 1

        ts = run.get("time series", None)
        md = run.get("metadata", {}) or {}
        target_vec = run.get("target", None)
        events_df = run.get("stim time", None)

        # --- run-level checks: a failure drops the whole run (counted as a run, not an event)
        if ts is None or not isinstance(ts, np.ndarray) or ts.ndim != 2:
            n_runs_skipped += 1
            continue
        if ts.shape[1] != N:
            raise ValueError(f"{sub_id} run {run_idx}: expected N={N}, got {ts.shape[1]}")

        target_idx = safe_target_idx(target_vec)
        if target_idx is None:
            n_runs_skipped += 1
            continue

        if not isinstance(events_df, pd.DataFrame) or len(events_df) == 0:
            n_runs_skipped += 1
            continue

        onset_col = get_onset_column(events_df)
        if onset_col is None:
            n_runs_skipped += 1
            continue

        # Default 2.4 s when the TR is missing or explicitly None (float(None) used
        # to crash). A non-finite or non-positive TR cannot map onsets to volumes.
        tr_raw = md.get("tr_s")
        tr_s = float(tr_raw) if tr_raw is not None else 2.4
        if not np.isfinite(tr_s) or tr_s <= 0:
            n_runs_skipped += 1
            continue

        onsets = events_df[onset_col].astype(float).values
        n_events_total += len(onsets)

        for onset_s in onsets:
            # int(np.floor(nan)) raises, so drop events with missing onsets.
            if not np.isfinite(onset_s):
                n_events_skipped += 1
                continue

            # k_pre: last volume BEFORE stimulation onset
            k_pre = int(np.floor(onset_s / tr_s) - 1)

            # need window [k_pre-S+1 .. k_pre] and empirical up to k_pre+MAX_EMP_HORIZON
            if (k_pre - (S - 1)) < 0 or (k_pre + MAX_EMP_HORIZON) >= ts.shape[0]:
                n_events_skipped += 1
                continue

            w_pre = ts[k_pre - (S - 1): k_pre + 1, :]  # (S,N)

            # perturbed copy: add PERTURB_AMP to the target parcel at the last step only
            w_stim = w_pre.copy()
            w_stim[-1, target_idx] += PERTURB_AMP

            pred_base_t1, pred_base_t2 = _predict_t1_t2(w_pre)
            pred_stim_t1, pred_stim_t2 = _predict_t1_t2(w_stim)

            ev = {
                "sub_id": sub_id,
                "run_idx": int(run_idx),
                "session": md.get("session", None),
                "stim_mni_xyz": md.get("stim_mni_xyz", None),
                "tr_s": tr_s,
                "onset_s": float(onset_s),
                "k_pre": int(k_pre),
                "S": int(S),
                "target_idx": int(target_idx),
                "perturb_amp": float(PERTURB_AMP),

                "pre_window": w_pre.astype(np.float32),
            }
            # empirical post states t+1 .. t+MAX_EMP_HORIZON (volumes); the bounds
            # check above guarantees these rows exist. astype() returns a copy.
            for h in range(1, MAX_EMP_HORIZON + 1):
                ev[f"emp_t{h}"] = ts[k_pre + h, :].astype(np.float32)

            ev["pred_base_t1"] = pred_base_t1.astype(np.float32)
            ev["pred_stim_t1"] = pred_stim_t1.astype(np.float32)
            if SAVE_2STEP:
                ev["pred_base_t2"] = pred_base_t2.astype(np.float32)
                ev["pred_stim_t2"] = pred_stim_t2.astype(np.float32)

            stim_effects[sub_id].setdefault(target_idx, []).append(ev)
            n_events_used += 1

print("Runs processed:", n_runs)
print("Runs skipped:", n_runs_skipped)
print("Events total in TSVs:", n_events_total)
print("Events saved:", n_events_used)
print("Events skipped:", n_events_skipped)


Runs processed: 432
Events total in TSVs: 29376
Events saved: 29376
Events skipped: 0


In [8]:
# =========================
# 7) Save outputs
# =========================
import datetime

# 1) Full per-event results (including all arrays) as one pickle.
out_pkl = os.path.join(OUT_DIR, "stim_effects_ann_vs_emp.pkl")
with open(out_pkl, "wb") as fh:
    pickle.dump(stim_effects, fh, protocol=pickle.HIGHEST_PROTOCOL)

# 2) Lightweight CSV index: one row of scalar metadata per saved event.
rows = [
    {
        "sub_id": sub_id,
        "target_idx": target_idx,
        "run_idx": event["run_idx"],
        "session": event.get("session"),
        "onset_s": event["onset_s"],
        "k_pre": event["k_pre"],
        "tr_s": event["tr_s"],
        "stim_mni_xyz": str(event.get("stim_mni_xyz")),
    }
    for sub_id, targets in stim_effects.items()
    for target_idx, events in targets.items()
    for event in events
]
idx_csv = os.path.join(OUT_DIR, "stim_effects_index.csv")
pd.DataFrame(rows).to_csv(idx_csv, index=False)

# 3) Run configuration, for provenance.
cfg = dict(
    created_at=datetime.datetime.now().isoformat(),
    dataset_pkl=DATASET_PKL,
    model_path=MODEL_PATH,
    method=METHOD,
    S=S,
    N=N,
    perturb_amp=PERTURB_AMP,
    max_emp_horizon=MAX_EMP_HORIZON,
    save_2step=SAVE_2STEP,
    out_dir=OUT_DIR,
)
cfg_json = os.path.join(OUT_DIR, "config.json")
with open(cfg_json, "w") as fh:
    json.dump(cfg, fh, indent=2)

print("Saved:", out_pkl)
print("Saved:", idx_csv)
print("Saved:", cfg_json)


Saved: /content/drive/MyDrive/Colab Notebooks/Brain_Stim_ANN/data/preprocessed_subjects_tms_fmri/ANN_vs_tms_fmri/stim_effects_ann_vs_emp.pkl
Saved: /content/drive/MyDrive/Colab Notebooks/Brain_Stim_ANN/data/preprocessed_subjects_tms_fmri/ANN_vs_tms_fmri/stim_effects_index.csv
Saved: /content/drive/MyDrive/Colab Notebooks/Brain_Stim_ANN/data/preprocessed_subjects_tms_fmri/ANN_vs_tms_fmri/config.json


In [9]:
# =========================
# 8) Quick sanity peek
# =========================
# Show the layout of the first stored event of the first subject/target.
sub0 = next(iter(stim_effects), None)
if sub0:
    targets = stim_effects[sub0]
    t0 = next(iter(targets), None)
    print("Example subject:", sub0)
    print("Targets:", list(targets)[:10], "...")
    if t0 is not None and targets[t0]:
        e0 = targets[t0][0]
        print("Example event keys:", list(e0))
        print(f"{'pre_window'} shape:", e0["pre_window"].shape)
        print(f"{'pred_base_t1'} shape:", e0["pred_base_t1"].shape)
        if "pred_base_t2" in e0:
            print(f"{'pred_base_t2'} shape:", e0["pred_base_t2"].shape)


Example subject: sub-NTHC1001
Targets: [401] ...
Example event keys: ['sub_id', 'run_idx', 'session', 'stim_mni_xyz', 'tr_s', 'onset_s', 'k_pre', 'S', 'target_idx', 'perturb_amp', 'pre_window', 'emp_t1', 'emp_t2', 'emp_t3', 'pred_base_t1', 'pred_stim_t1', 'pred_base_t2', 'pred_stim_t2']
pre_window shape: (3, 450)
pred_base_t1 shape: (450,)
pred_base_t2 shape: (450,)


In [10]:
# Subject IDs in stim_effects: one key per subject that has a "task-stim" entry in the dataset.
stim_effects.keys()

dict_keys(['sub-NTHC1001', 'sub-NTHC1003', 'sub-NTHC1009', 'sub-NTHC1015', 'sub-NTHC1016', 'sub-NTHC1019', 'sub-NTHC1021', 'sub-NTHC1022', 'sub-NTHC1023', 'sub-NTHC1024', 'sub-NTHC1026', 'sub-NTHC1027', 'sub-NTHC1028', 'sub-NTHC1029', 'sub-NTHC1032', 'sub-NTHC1035', 'sub-NTHC1036', 'sub-NTHC1037', 'sub-NTHC1038', 'sub-NTHC1039', 'sub-NTHC1040', 'sub-NTHC1043', 'sub-NTHC1047', 'sub-NTHC1049', 'sub-NTHC1050', 'sub-NTHC1052', 'sub-NTHC1053', 'sub-NTHC1055', 'sub-NTHC1056', 'sub-NTHC1057', 'sub-NTHC1061', 'sub-NTHC1062', 'sub-NTHC1064', 'sub-NTHC1065', 'sub-NTHC1066', 'sub-NTHC1068', 'sub-NTHC1073', 'sub-NTHC1075', 'sub-NTHC1097', 'sub-NTHC1098', 'sub-NTHC1099', 'sub-NTHC1101', 'sub-NTHC1102', 'sub-NTHC1105', 'sub-NTHC1107', 'sub-NTHC1108'])

In [11]:
# Spot-check the event-dict keys for one hard-coded subject / target index.
# NOTE(review): raises KeyError/IndexError if that subject or target is absent or has no events.
stim_effects['sub-NTHC1108'][392][0].keys()

dict_keys(['sub_id', 'run_idx', 'session', 'stim_mni_xyz', 'tr_s', 'onset_s', 'k_pre', 'S', 'target_idx', 'perturb_amp', 'pre_window', 'emp_t1', 'emp_t2', 'emp_t3', 'pred_base_t1', 'pred_stim_t1', 'pred_base_t2', 'pred_stim_t2'])