<a href="https://colab.research.google.com/github/grabuffo/BrainStim_ANN_fMRI_HCP/blob/main/notebooks/TMS_fMRI_ANN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TMS-fMRI ANN stimulation prediction (population model)

This notebook:
- loads the **population ANN** trained on task-rest
- loads the original TMS-fMRI `dataset_tian50_schaefer400_allruns.pkl`
- loops over **task-stim** runs and **stimulation events**
- extracts the **pre-stim window** (S=3 volumes)
- predicts the next state with the ANN:
  - **baseline** (no perturbation)
  - **perturbed** (+0.1 added to the stimulated parcel at the last time point of the window)
- also computes a 2-step rollout (saved for TR-mismatch flexibility)
- saves results to:

`/content/drive/MyDrive/Colab Notebooks/Brain_Stim_ANN/data/preprocessed_subjects_tms_fmri/ANN_vs_tms_fmri/`


In [None]:

# =========================
# 0) Mount Google Drive
# =========================
# Colab-only step: mount Google Drive at /content/drive.
# The data/model paths configured later in this notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive')


In [None]:

# =========================
# 1) Clone repo + imports
# =========================
import os, sys, pickle, json
import numpy as np
import pandas as pd
import torch

REPO_DIR = "/content/BrainStim_ANN_fMRI_HCP"

if not os.path.exists(REPO_DIR):
    !git clone https://github.com/grabuffo/BrainStim_ANN_fMRI_HCP.git
else:
    print("Repo already exists âœ…")

sys.path.append(REPO_DIR)

from src.NPI import build_model, device  # uses GPU if available
print("Torch device:", device)


In [None]:

# =========================
# 2) Paths (EDIT IF NEEDED)
# =========================
# Root of the project data on the mounted Drive.
BASE = "/content/drive/MyDrive/Colab Notebooks/Brain_Stim_ANN/data"

# Pickled dictionary that holds both the rest and the stim runs
DATASET_PKL = os.path.join(BASE, "TMS_fMRI", "dataset_tian50_schaefer400_allruns.pkl")

# Preprocessed rest data and the trained population model live below this folder
PREPROC_ROOT = os.path.join(BASE, "preprocessed_subjects_tms_fmri")

# Destination for the ANN-vs-empirical stimulation dictionaries (created if missing)
OUT_DIR = os.path.join(PREPROC_ROOT, "ANN_vs_tms_fmri")
os.makedirs(OUT_DIR, exist_ok=True)

# Population model: try each known filename in order and keep the first that exists
MODEL_DIR = os.path.join(PREPROC_ROOT, "trained_models_MLP_tms_fmri")
MODEL_PATH_CANDIDATES = [
    os.path.join(MODEL_DIR, fname)
    for fname in (
        "population_MLP_tms_fmri.pt",
        "population_MLP_tms_fmri.pth",
        "population_MLP_tms_fmri.ptc",
    )
]
# None when no candidate file is present on disk
MODEL_PATH = next(filter(os.path.exists, MODEL_PATH_CANDIDATES), None)

print("Dataset PKL:", DATASET_PKL, "| exists:", os.path.exists(DATASET_PKL))
print("Model dir   :", MODEL_DIR, "| exists:", os.path.exists(MODEL_DIR))
print("Model path  :", MODEL_PATH)


In [None]:

# =========================
# 3) Config
# =========================
# Run-wide settings; they are read when building the model, in the main event
# loop, and are written to config.json alongside the results.
S = 3                   # input window length in volumes (must match training)
N = 450                 # parcels (Tian50 + Schaefer400); runs with a different column count raise
PERTURB_AMP = 0.1       # z-scored units added to targeted parcel at last step of the window
MAX_EMP_HORIZON = 3     # post-stim volumes that must exist after k_pre for an event to be kept
SAVE_2STEP = True       # compute and save ANN 2-step rollout (recommended)

# Note on TR mismatch:
# - model trained on rest TR ~= 2.0s
# - stim runs TR may be ~= 2.4s (2.4 is also the fallback when a run's metadata has no "tr_s")
# We save 1-step and 2-step predictions for later comparison choices.


In [None]:

# =========================
# 4) Load dataset dict + model
# =========================
# Fail fast: without a model file there is no point in loading the dataset.
if MODEL_PATH is None:
    raise FileNotFoundError("Could not find population model file in MODEL_DIR. Please update MODEL_PATH.")

# NOTE: unpickling executes code embedded in the file; only load a dataset you trust.
with open(DATASET_PKL, "rb") as f:
    dataset = pickle.load(f)

print("Loaded subjects:", len(dataset))

# Build model (must match training architecture)
METHOD = "MLP"
model = build_model(METHOD, ROI_num=N, using_steps=S).to(device)

# Load weights
# NOTE(review): recent PyTorch releases (2.6+) default torch.load to
# weights_only=True, under which a checkpoint storing a whole pickled module is
# rejected here and never reaches the last branch below - confirm against the
# installed version if the model was saved that way.
state = torch.load(MODEL_PATH, map_location=device)

if isinstance(state, dict) and "state_dict" in state:
    # Checkpoint wrapper: weights stored under the "state_dict" key
    model.load_state_dict(state["state_dict"])
elif isinstance(state, dict):
    # Plain state_dict. A dict can never be moved with .to(device), so report
    # the real load_state_dict failure instead of masking it.
    try:
        model.load_state_dict(state)
    except Exception as e:
        raise RuntimeError(f"Unknown model save format at {MODEL_PATH}. Error: {e}") from e
else:
    # whole model object
    model = state.to(device)

model.eval()
print("Model loaded and set to eval().")


In [None]:

# =========================
# 5) Helpers
# =========================
def get_onset_column(df: pd.DataFrame):
    """Pick the column of an events table that holds stimulation onsets.

    Known onset column names are tried first, in priority order; if none is
    present, the first column with a numeric dtype is used instead.

    Returns the column label, or None when ``df`` is None, has no rows, or
    offers neither a known name nor a numeric column.
    """
    if df is None or len(df) == 0:
        return None

    preferred = ("onset", "Onset", "stim_onset", "t", "time", "onset_s", "onset_sec", "seconds")
    named = next((name for name in preferred if name in df.columns), None)
    if named is not None:
        return named

    # Nothing matched by name: fall back to the first numeric column, if any.
    return next(
        (col for col in df.columns if pd.api.types.is_numeric_dtype(df[col])),
        None,
    )

@torch.no_grad()
def predict_from_window(window_SxN: np.ndarray):
    """Run one forward pass of the global ``model`` on a single input window.

    ``window_SxN`` is expected to be an (S, N) array. It is flattened into a
    single row, converted to a float32 tensor on ``device`` and fed to the
    model. The output's leading axis of size 1 is dropped and the result is
    returned as a NumPy array on the CPU (N values, assuming the model emits
    one value per parcel).
    """
    flat_row = window_SxN.reshape(1, -1)  # (1, S*N)
    batch = torch.tensor(flat_row, dtype=torch.float32, device=device)
    output = model(batch)  # (1, N)
    output_np = output.detach().cpu().numpy()
    return output_np.squeeze(0)

def rollout_2step(window_SxN: np.ndarray):
    """Predict two consecutive steps by feeding the first prediction back in.

    Returns ``(first, second)``: ``first`` is the prediction from the given
    window; ``second`` is the prediction from that window with its oldest row
    dropped and ``first`` appended as the newest row.
    """
    first = predict_from_window(window_SxN)

    # Slide the window forward one step: discard row 0, append the prediction.
    newest_row = first[np.newaxis, :]
    shifted_window = np.vstack((window_SxN[1:], newest_row))

    second = predict_from_window(shifted_window)
    return first, second

def safe_target_idx(target_vec):
    """Return the 0-based index of the single 1 in a one-hot target vector.

    The input is converted to an integer array and flattened, so nested or
    column-shaped inputs are accepted (float entries are truncated by the
    integer cast).

    Returns None when ``target_vec`` is None, is empty, or is not one-hot,
    i.e. unless exactly one entry equals 1 and every other entry is 0.
    """
    if target_vec is None:
        return None
    v = np.asarray(target_vec).astype(int).ravel()

    # Require a genuine one-hot pattern. Testing only ``v.sum() == 1`` would
    # also accept vectors such as [2, -1] or [1, 1, -1].
    n_ones = int((v == 1).sum())
    n_nonzero = int((v != 0).sum())
    if n_ones != 1 or n_nonzero != 1:
        return None
    return int(np.argmax(v))


In [None]:

# =========================
# 6) Main loop: build stimulation dictionaries
# =========================
# For every stimulation event of every task-stim run: take the S volumes that
# precede the onset, predict the next state(s) with and without a perturbation
# of the targeted parcel, and store them next to the empirical post-stim volumes.
stim_effects = {}  # stim_effects[sub_id][target_idx] = list of event dicts

n_runs = 0
n_runs_skipped = 0     # whole runs rejected before any event was examined
n_events_total = 0     # events found in runs that passed run-level validation
n_events_used = 0
n_events_skipped = 0   # single events rejected inside an otherwise usable run
# Invariant after the loop: n_events_total == n_events_used + n_events_skipped

for sub_id, sub_data in dataset.items():
    if "task-stim" not in sub_data:
        continue

    stim_effects[sub_id] = {}
    stim_runs = sub_data["task-stim"]  # dict of runs; keys are cast with int() when stored

    for run_idx, run in stim_runs.items():
        n_runs += 1

        ts = run.get("time series", None)
        md = run.get("metadata", {}) or {}
        target_vec = run.get("target", None)
        events_df = run.get("stim time", None)

        # ---- run-level validation (counted as skipped runs, not skipped events)
        if ts is None or not isinstance(ts, np.ndarray) or ts.ndim != 2:
            n_runs_skipped += 1
            continue
        if ts.shape[1] != N:
            raise ValueError(f"{sub_id} run {run_idx}: expected N={N}, got {ts.shape[1]}")

        target_idx = safe_target_idx(target_vec)
        if target_idx is None:
            n_runs_skipped += 1
            continue

        if not isinstance(events_df, pd.DataFrame) or len(events_df) == 0:
            n_runs_skipped += 1
            continue

        onset_col = get_onset_column(events_df)
        if onset_col is None:
            n_runs_skipped += 1
            continue

        tr_s = float(md.get("tr_s", 2.4))  # fall back to 2.4 s when metadata has no TR
        onsets = events_df[onset_col].astype(float).values
        n_events_total += len(onsets)

        stim_effects[sub_id].setdefault(target_idx, [])

        for onset_s in onsets:
            # Onset expressed in volumes. A NaN/inf value (missing onset or a
            # zero TR) cannot be turned into an index, so skip it rather than
            # let int() raise.
            vol_pos = onset_s / tr_s
            if not np.isfinite(vol_pos):
                n_events_skipped += 1
                continue

            # k_pre: last volume BEFORE stimulation onset
            k_pre = int(np.floor(vol_pos) - 1)

            # need window [k_pre-S+1 .. k_pre] and empirical up to k_pre+MAX_EMP_HORIZON
            if (k_pre - (S - 1)) < 0:
                n_events_skipped += 1
                continue
            if (k_pre + MAX_EMP_HORIZON) >= ts.shape[0]:
                n_events_skipped += 1
                continue

            w_pre = ts[k_pre - (S - 1): k_pre + 1, :]  # (S,N) view into ts

            # perturbed copy: add the amplitude at the last step only
            w_stim = w_pre.copy()
            w_stim[-1, target_idx] += PERTURB_AMP

            # baseline and perturbed predictions
            if SAVE_2STEP:
                pred_base_t1, pred_base_t2 = rollout_2step(w_pre)
                pred_stim_t1, pred_stim_t2 = rollout_2step(w_stim)
            else:
                pred_base_t1 = predict_from_window(w_pre)
                pred_stim_t1 = predict_from_window(w_stim)
                pred_base_t2 = pred_stim_t2 = None

            ev = {
                "sub_id": sub_id,
                "run_idx": int(run_idx),
                "session": md.get("session", None),
                "stim_mni_xyz": md.get("stim_mni_xyz", None),
                "tr_s": tr_s,
                "onset_s": float(onset_s),
                "k_pre": int(k_pre),
                "S": int(S),
                "target_idx": int(target_idx),
                "perturb_amp": float(PERTURB_AMP),

                "pre_window": w_pre.astype(np.float32),
            }

            # empirical post states emp_t1 .. emp_t<MAX_EMP_HORIZON> (volumes);
            # astype() returns a fresh array, so ts itself is never aliased
            for h in range(1, MAX_EMP_HORIZON + 1):
                ev[f"emp_t{h}"] = ts[k_pre + h, :].astype(np.float32)

            ev["pred_base_t1"] = pred_base_t1.astype(np.float32)
            ev["pred_stim_t1"] = pred_stim_t1.astype(np.float32)
            if SAVE_2STEP:
                ev["pred_base_t2"] = pred_base_t2.astype(np.float32)
                ev["pred_stim_t2"] = pred_stim_t2.astype(np.float32)

            stim_effects[sub_id][target_idx].append(ev)
            n_events_used += 1

print("Runs processed:", n_runs)
print("Runs skipped:", n_runs_skipped)
print("Events total in TSVs:", n_events_total)
print("Events saved:", n_events_used)
print("Events skipped:", n_events_skipped)


In [None]:

# =========================
# 7) Save outputs
# =========================
import datetime

# Full per-event results (arrays included) go into a single pickle.
out_pkl = os.path.join(OUT_DIR, "stim_effects_ann_vs_emp.pkl")
with open(out_pkl, "wb") as f:
    pickle.dump(stim_effects, f, protocol=pickle.HIGHEST_PROTOCOL)

# Lightweight, array-free CSV index: one row per saved event.
rows = [
    {
        "sub_id": sub_id,
        "target_idx": target_idx,
        "run_idx": e["run_idx"],
        "session": e.get("session"),
        "onset_s": e["onset_s"],
        "k_pre": e["k_pre"],
        "tr_s": e["tr_s"],
        "stim_mni_xyz": str(e.get("stim_mni_xyz")),  # stringified so it fits one CSV cell
    }
    for sub_id, targets in stim_effects.items()
    for target_idx, events in targets.items()
    for e in events
]

idx_csv = os.path.join(OUT_DIR, "stim_effects_index.csv")
index_table = pd.DataFrame(rows)
index_table.to_csv(idx_csv, index=False)

# Record the settings used for this run next to the results.
cfg = {
    "created_at": datetime.datetime.now().isoformat(),
    "dataset_pkl": DATASET_PKL,
    "model_path": MODEL_PATH,
    "method": METHOD,
    "S": S,
    "N": N,
    "perturb_amp": PERTURB_AMP,
    "max_emp_horizon": MAX_EMP_HORIZON,
    "save_2step": SAVE_2STEP,
    "out_dir": OUT_DIR,
}
cfg_json = os.path.join(OUT_DIR, "config.json")
with open(cfg_json, "w") as f:
    json.dump(cfg, f, indent=2)

for saved_path in (out_pkl, idx_csv, cfg_json):
    print("Saved:", saved_path)


In [None]:

# =========================
# 8) Quick sanity peek
# =========================
# Print the first subject / target / event so the stored structure can be eyeballed.
sub0 = next(iter(stim_effects), None)
# Compare against None explicitly (as done for t0 below): a valid but falsy
# subject key such as 0 must not silently suppress the peek.
if sub0 is not None:
    print("Example subject:", sub0)
    targets = stim_effects[sub0]
    print("Targets:", list(targets.keys())[:10], "...")
    t0 = next(iter(targets), None)
    if t0 is not None and len(targets[t0]) > 0:
        e0 = targets[t0][0]
        print("Example event keys:", list(e0.keys()))
        print("pre_window shape:", e0["pre_window"].shape)
        print("pred_base_t1 shape:", e0["pred_base_t1"].shape)
        # 2-step predictions exist only when the rollout was saved
        if "pred_base_t2" in e0:
            print("pred_base_t2 shape:", e0["pred_base_t2"].shape)
