# Part 1

### Ridge Regression
In this step, we use a pre-trained BERT model (`bert-base-uncased`) to extract the [CLS] token embedding for each TR (one fMRI volume) of every story. The words presented during a TR are joined into one string, tokenized, and encoded with BERT. Each embedding is then paired with the fMRI response recorded at that TR. We z-score the embeddings and add temporal delays of 0 to 4 TRs, so the model can account for the slow hemodynamic response. Finally, we fit a ridge regression that predicts voxel-level brain activity from the delayed BERT features.
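The sketch below illustrates the two transformations described above on synthetic data, using NumPy:

- **Delays.** It stacks time-shifted copies of the features. This follows the forward-shift convention of `ridge_utils.make_delayed`: row t of delay block d holds the features of TR t - d.
- **Ridge fit.** It solves ridge regression in closed form for all voxels at once.
- **Scoring.** It scores each voxel by the correlation between predicted and observed responses.

The helper names are hypothetical, and this is not the `ridge_utils` implementation.

```python
import numpy as np

def delay_features(X, delays):
    """Stack time-shifted copies of X; block d of row t holds X[t - d] (zeros before t = d).

    Only non-negative delays are handled in this sketch.
    """
    n, p = X.shape
    out = np.zeros((n, p * len(delays)))
    for k, d in enumerate(delays):
        block = out[:, k * p:(k + 1) * p]   # a view into `out`
        if d == 0:
            block[:] = X
        else:
            block[d:] = X[:-d]
    return out

def ridge_fit(X, Y, alpha):
    """Closed-form ridge weights W = (X^T X + alpha * I)^-1 X^T Y, one column per voxel."""
    return np.linalg.solve(X.T @ X + alpha * np.eye(X.shape[1]), X.T @ Y)

def columnwise_corr(A, B):
    """Pearson correlation between matching columns of A and B."""
    A = (A - A.mean(0)) / A.std(0)
    B = (B - B.mean(0)) / B.std(0)
    return (A * B).mean(0)

rng = np.random.default_rng(0)
X = rng.standard_normal((500, 8))                       # 500 TRs x 8 stimulus features
Xd = delay_features(X, delays=[0, 1, 2, 3, 4])          # (500, 40)
Y = Xd @ rng.standard_normal((40, 3)) + 0.5 * rng.standard_normal((500, 3))  # 3 synthetic voxels

W = ridge_fit(Xd[:400], Y[:400], alpha=10.0)            # fit on the first 400 TRs
cc = columnwise_corr(Xd[400:] @ W, Y[400:])             # correlation on the last 100 TRs
print(Xd.shape, cc.round(3))
```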

In [1]:
import os
import sys
sys.path.append("..")
import pickle
import numpy as np
import torch
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from ridge_utils.utils import zscore, make_delayed
from ridge_utils.ridge import ridge_corr

# -------------------- Step 1: Loading -------------------- #
device = "cuda" if torch.cuda.is_available() else "cpu"
data_dir = "/ocean/projects/mth240012p/shared/data"
subject = "subject2"
max_tokens = 50

print("Loading raw_text.pkl and fMRI...")
with open(os.path.join(data_dir, "raw_text.pkl"), "rb") as f:
    raw_texts = pickle.load(f)

story_names = []
Y_dict = {}
for story in raw_texts:
    fmri_path = os.path.join(data_dir, subject, f"{story}.npy")
    if os.path.exists(fmri_path):
        Y_dict[story] = np.load(fmri_path)
        story_names.append(story)

print(f"Valid stories with fMRI: {len(story_names)}")

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") # loading BERT Model
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device)
bert_model.eval()

# -------------------- Step 2: Iterate over each story and each TR -------------------- #
X_all, Y_all = [], []

for story in tqdm(story_names, desc="Processing stories"):
    fmri = Y_dict[story]
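    ds = raw_texts[story]
    # NOTE: ds.data appears to hold one *word* per entry rather than one word list
    # per TR (In [14] collects 190,401 non-empty entries for only 34,700 TRs).
    # If so, word_list below is a single word, " ".join() spells it out letter by
    # letter, and word i is paired with TR i; words would first need grouping by TR.
    for i in range(len(ds.data)):
        word_list = ds.data[i]
        if not word_list or i >= fmri.shape[0]:
            continue
        sentence = " ".join(word_list)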
        inputs = tokenizer(
            sentence,
            return_tensors="pt",
            padding="max_length",
            max_length=512,
            truncation=True
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = bert_model(**inputs)
        cls_embed = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
        X_all.append(cls_embed)
        Y_all.append(fmri[i])

X_all = np.array(X_all)
Y_all = np.array(Y_all)

print(f"\nCollected {X_all.shape[0]} TRs | BERT dim: {X_all.shape[1]} | Voxels: {Y_all.shape[1]}")

Loading raw_text.pkl and fMRI...
Valid stories with fMRI: 101


Processing stories: 100%|██████████| 101/101 [03:18<00:00,  1.96s/it]



Collected 34700 TRs | BERT dim: 768 | Voxels: 94251
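Suppose a story's `DataSequence` stores one word per entry in `data`, together with TR boundary indices in `split_inds`. This layout is an assumption, modelled on the Huth-lab `DataSequence`. In that case, the per-TR word lists can be recovered with `np.split`.

`FakeSequence` below is a hypothetical stand-in that only makes the sketch runnable. Matching the resulting chunks to fMRI rows may also require dropping leading or trailing TRs, depending on how the responses were trimmed.

```python
import numpy as np
from dataclasses import dataclass

@dataclass
class FakeSequence:
    """Hypothetical stand-in for a word-level DataSequence."""
    data: np.ndarray        # one word per entry
    split_inds: np.ndarray  # indices where a new TR starts

def words_per_tr(ds):
    """Return one space-joined string per TR (an empty string for silent TRs)."""
    chunks = np.split(ds.data, ds.split_inds)
    return [" ".join(chunk) for chunk in chunks]

ds = FakeSequence(
    data=np.array(["my", "father's", "hands", "were", "big"]),
    split_inds=np.array([2, 2, 4]),  # TR0: 2 words, TR1: none, TR2: 2 words, TR3: 1 word
)
tr_texts = words_per_tr(ds)
print(tr_texts)
```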


In [2]:
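# zscore() normalizes the rows of its input, so transpose to z-score each column over time
X_z = zscore(X_all.T).T                   # each BERT feature
Y_z = zscore(np.nan_to_num(Y_all).T).T    # each voxel (NaNs set to 0 first)
Y_z = np.nan_to_num(Y_z)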

X_delayed = make_delayed(X_z[4:], delays=[0, 1, 2, 3, 4])
Y_trimmed = Y_z[4:]

X_train, X_test, Y_train, Y_test = train_test_split(
    X_delayed, Y_trimmed, test_size=0.2, random_state=42
)

# -------------------- Step 3: Ridge Regression -------------------- #
print("Running ridge regression...")

print(f"Final shapes → X: {X_delayed.shape}, Y: {Y_trimmed.shape}")
print(f"Train: {X_train.shape}, Test: {X_test.shape}")

alphas = np.logspace(1, 3, 20)
ccs = np.array(ridge_corr(X_train, X_test, Y_train, Y_test, alphas))

print("Ridge correlation shape:", ccs.shape)
best_cc = np.max(ccs, axis=0)
print(f"Mean CC:    {np.mean(best_cc):.4f}")
print(f"Median CC:  {np.median(best_cc):.4f}")
print(f"Top 1% CC:  {np.mean(np.sort(best_cc)[-int(0.01*len(best_cc)):]):.4f}")
print(f"Top 5% CC:  {np.mean(np.sort(best_cc)[-int(0.05*len(best_cc)):]):.4f}")

Running ridge regression...
Final shapes → X: (34696, 3840), Y: (34696, 94251)
Train: (27756, 3840), Test: (6940, 3840)
Ridge correlation shape: (20, 94251)
Mean CC:    0.0067
Median CC:  0.0065
Top 1% CC:  0.0383
Top 5% CC:  0.0306


In [3]:
# Sanity checks on the design matrix and responses
print("X_delayed shape:", X_delayed.shape)
print("Y_trimmed shape:", Y_trimmed.shape)
print("NaN in X:", np.isnan(X_delayed).any(), "| Inf in X:", np.isinf(X_delayed).any())
print("NaN in Y:", np.isnan(Y_trimmed).any(), "| Inf in Y:", np.isinf(Y_trimmed).any())
print("X mean/std:", np.mean(X_delayed), np.std(X_delayed))
print("Y mean/std:", np.mean(Y_trimmed), np.std(Y_trimmed))
print("All zeros in Y:", np.all(Y_trimmed == 0))
nan_voxels = np.isnan(Y_all).sum(axis=0)
print(f"Max NaN count per voxel: {nan_voxels.max()} / {Y_all.shape[0]}")
print(f"Number of voxels with any NaN: {np.sum(nan_voxels > 0)}")

X_delayed shape: (34696, 3840)
Y_trimmed shape: (34696, 94251)
NaN in X: False | Inf in X: False
NaN in Y: False | Inf in Y: False
X mean/std: 1.3324312972227674e-07 0.9999527897943146
Y mean/std: 8.215887372834887e-20 0.9887539174267306
All zeros in Y: False
Max NaN count per voxel: 776 / 34700
Number of voxels with any NaN: 32
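`train_test_split` in In [2] shuffles individual TRs, which creates two kinds of leakage:

- **Split leakage.** Neighbouring TRs are autocorrelated and their delayed feature rows overlap by construction, so near-duplicate rows can land on both sides of the split. This tends to inflate test scores.
- **Delay leakage.** `make_delayed` is applied to the concatenation of all stories, so the first rows of each story also pick up features from the end of the previous one.

One common alternative is to hold out whole stories, and to delay each story separately before stacking. The sketch below shows the story-level split. It assumes a per-row story label is kept; the `groups` array and the helper name are hypothetical, since Part 1 does not build one.

```python
import numpy as np

def story_level_split(groups, test_frac=0.2, seed=42):
    """Return boolean train/test masks that keep every story entirely on one side."""
    stories = np.unique(groups)
    rng = np.random.default_rng(seed)
    n_test = max(1, int(round(test_frac * len(stories))))
    test_stories = rng.choice(stories, size=n_test, replace=False)
    test_mask = np.isin(groups, test_stories)
    return ~test_mask, test_mask

# Hypothetical labels: 10 stories with 30 TRs each
groups = np.repeat([f"story{k}" for k in range(10)], 30)
train_mask, test_mask = story_level_split(groups)
print(train_mask.sum(), test_mask.sum())
```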


### Fine-Tuning BERT Encoder on Stimulus Texts

In this section, we fine-tune the BERT encoder using the raw story text data via a Masked Language Modeling (MLM) objective.

We use the following approach:

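- **Dataset Construction**: We convert each story's TR-aligned text (from the `DataSequence`) into padded BERT input IDs of up to 128 tokens, treating each entry as one training example. `MLMTextDataset` does not mask anything itself. The random masking of a subset of tokens, which turns these inputs into an MLM task, is expected to happen inside the training loop.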
- **Model Architecture**: We reuse the `Encoder` defined in `encoder.py`, which wraps a pre-trained BERT model (`bert-base-uncased`) with a linear decoder head to predict masked tokens.
- **Training Loop**: We leverage `train_encoder.py` to fine-tune the model over multiple epochs. The training loss is computed only on masked positions.
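The bullet points describe masking a random subset of tokens and computing the loss only at masked positions. `train_encoder.py` is not shown, so the sketch below follows the standard BERT recipe in NumPy:

1. Select 15% of the non-special tokens.
2. Replace 80% of the selected tokens with `[MASK]`, 10% with a random token, and keep the remaining 10% unchanged.
3. Label every unselected position `-100`, so that the cross-entropy ignores it.

Whether `train_bert` uses exactly these proportions is an assumption. The special IDs (`[PAD]`=0, `[CLS]`=101, `[SEP]`=102, `[MASK]`=103) and the vocabulary size of 30,522 are the `bert-base-uncased` values. The word-piece IDs in the example batch are arbitrary.

```python
import numpy as np

MASK_ID, VOCAB_SIZE = 103, 30522               # bert-base-uncased values
SPECIAL_IDS = np.array([0, 101, 102, 103])     # [PAD], [CLS], [SEP], [MASK]

def mask_tokens(input_ids, mlm_prob=0.15, seed=0):
    """Return (masked_ids, labels); labels are -100 wherever no prediction is required."""
    rng = np.random.default_rng(seed)
    ids, labels = input_ids.copy(), input_ids.copy()
    selected = ~np.isin(ids, SPECIAL_IDS) & (rng.random(ids.shape) < mlm_prob)
    labels[~selected] = -100

    to_mask = selected & (rng.random(ids.shape) < 0.8)                 # 80% -> [MASK]
    ids[to_mask] = MASK_ID
    to_random = selected & ~to_mask & (rng.random(ids.shape) < 0.5)   # 10% -> random token
    ids[to_random] = rng.integers(0, VOCAB_SIZE, size=to_random.sum())
    return ids, labels                                                 # remaining 10% unchanged

def masked_cross_entropy(logits, labels):
    """Mean cross-entropy over positions whose label is not -100."""
    keep = labels != -100
    z = logits[keep] - logits[keep].max(axis=-1, keepdims=True)        # stable log-softmax
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(keep.sum()), labels[keep]].mean()

batch = np.tile([101, 2026, 2269, 1005, 1055, 2398, 102, 0, 0], (8, 1))
masked, labels = mask_tokens(batch)
print("fraction of positions with a target:", (labels != -100).mean())
# With uniform logits, the loss equals log(vocabulary size)
print(masked_cross_entropy(np.zeros((1, 3, 5)), np.array([[-100, 2, 4]])), np.log(5))
```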

This fine-tuning step adapts the BERT encoder to the linguistic distribution of the experimental stimuli. The aim is to obtain embeddings that are better suited to downstream voxel prediction.

After training, we save the encoder weights to `finetuned_encoder.pt`; Part 2 reloads them for voxel-wise ridge regression.


In [4]:
# ======================= Fine-tuning Setup ============================
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from train_encoder import train_bert
from encoder import Encoder

# 1. Prepare Dataset for MLM pretraining
class MLMTextDataset(Dataset):
    def __init__(self, raw_texts, tokenizer, max_length=128):
        self.samples = []
        for story in raw_texts.values():
            for word_list in story.data:
                if not word_list:
                    continue
                sentence = " ".join(word_list)
                tokens = tokenizer(
                    sentence,
                    max_length=max_length,
                    truncation=True,
                    padding='max_length',
                    return_tensors='pt'
                )
                self.samples.append(tokens)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = self.samples[idx]
        return {k: v.squeeze(0) for k, v in item.items()}

# 2. Initialize model, tokenizer, and dataset
print("Preparing dataset for fine-tuning...")
dataset = MLMTextDataset(raw_texts, tokenizer, max_length=128)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

print("Initializing encoder model...")
encoder = Encoder(vocab_size=tokenizer.vocab_size).to(device)

# 3. Train the encoder using MLM objective
print("Starting MLM training...")
train_bert(
    model=encoder,
    dataloader=dataloader,
    tokenizer=tokenizer,
    device=device,
    epochs=3,  # number of passes over the data; adjust as needed
    lr=1e-4    # learning rate; a smaller value may train more stably
)

# 4. Save encoder for later use
torch.save(encoder.state_dict(), "finetuned_encoder.pt")  # reloaded in Part 2
print("Fine-tuned encoder saved as 'finetuned_encoder.pt'")


Preparing dataset for fine-tuning...
Initializing encoder model...
Starting MLM training...
Epoch [1/3], Train Loss: 1.0625, Val Loss: 0.7735
Epoch [2/3], Train Loss: 0.7266, Val Loss: 0.6953
Epoch [3/3], Train Loss: 0.6588, Val Loss: 0.6577
Fine-tuned encoder saved as 'finetuned_encoder.pt'


### Fine-Tuning BERT with LoRA


In [5]:
import torch, numpy as np, os
from datasets import Dataset
from transformers import (
    AutoTokenizer, AutoConfig,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments, Trainer
)
from peft import LoraConfig, get_peft_model, TaskType

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = True   # TF32 for faster matmul
print("Running on:", device)

Running on: cuda


In [6]:
voxel_idx = 0                       # <— change voxel here

X_train_text = X_train              # NOTE: delayed BERT design matrix from In [2], not text; unused below
X_test_text  = X_test

y_train_v = Y_train[:, voxel_idx]
y_test_v  = Y_test[:,  voxel_idx]

mu, sd    = y_train_v.mean(), y_train_v.std()
y_train_z = (y_train_v - mu) / sd
y_test_z  = (y_test_v  - mu) / sd

print("Label mean/std:", y_train_z.mean(), y_train_z.std())   # ≈0 / 1

Label mean/std: 2.559960858049071e-18 1.0


In [7]:
# --- Pick a voxel & build TR-aligned sentences -------------
voxel_idx = 0          # choose voxel to predict
train_split = 0.9

sentences, labels = [], []
for story in story_names:                # story_names from your earlier code
    fmri = Y_dict[story]                 # (n_TR, n_voxels)
    ds   = raw_texts[story]              # word lists per TR
    for tr, word_list in enumerate(ds.data):
        if not word_list or tr >= fmri.shape[0]:
            continue
        sentences.append(" ".join(word_list))
        labels.append(float(fmri[tr, voxel_idx]))

# ---- shuffle & train/valid split --------------------------
idx = np.random.RandomState(42).permutation(len(sentences))
sentences = [sentences[i] for i in idx]
labels    = np.array(labels)[idx]

split = int(train_split * len(sentences))
train_text, test_text  = sentences[:split],  sentences[split:]
train_label, test_label = labels[:split],    labels[split:]

print(f"{len(train_text)} train  /  {len(test_text)} valid sentences")

31230 train  /  3470 valid sentences


In [8]:
mu, sd = train_label.mean(), train_label.std()
train_z = (train_label - mu) / sd
test_z  = (test_label  - mu) / sd
print("Label mean/std after z-score:", train_z.mean(), train_z.std())

Label mean/std after z-score: 2.047673590086744e-17 0.9999999999999999


In [9]:
train_enc = tokenizer(train_text, truncation=True,
                      padding="max_length", max_length=128)
test_enc  = tokenizer(test_text,  truncation=True,
                      padding="max_length", max_length=128)

train_ds = Dataset.from_dict({
    "input_ids"     : train_enc["input_ids"],
    "attention_mask": train_enc["attention_mask"],
    "labels"        : train_z.astype(np.float32),
}).with_format("torch", columns=["input_ids","attention_mask","labels"])

test_ds = Dataset.from_dict({
    "input_ids"     : test_enc["input_ids"],
    "attention_mask": test_enc["attention_mask"],
    "labels"        : test_z.astype(np.float32),
}).with_format("torch", columns=["input_ids","attention_mask","labels"])

In [10]:
base = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=1, problem_type="regression"
)

lora_cfg = LoraConfig(
    r=32, lora_alpha=32, target_modules=["query", "value"],
    lora_dropout=0.05, bias="none", task_type=TaskType.SEQ_CLS
)
model = get_peft_model(base, lora_cfg).to(device)

# freeze everything except LoRA & regressor
for n, p in model.named_parameters():
    if ("lora_" not in n) and ("classifier" not in n):
        p.requires_grad = False

model.print_trainable_parameters()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 1,180,417 || all params: 110,663,426 || trainable%: 1.0667
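The trainable-parameter count above can be checked by hand. LoRA replaces a frozen weight `W` with `W + (alpha / r) * B @ A`, where `A` has shape `(r, d_in)` and `B` has shape `(d_out, r)`. Only `A` and `B` are trained, plus the regression head in this setup.

For `bert-base-uncased` (12 layers, hidden size 768) with `r=32` on `query` and `value`, this gives exactly the 1,180,417 reported. The NumPy forward pass is an illustrative sketch of the update, not PEFT's implementation. Following the usual LoRA initialization, `B` starts at zero, so training begins from the frozen model.

```python
import numpy as np

hidden, layers, r = 768, 12, 32
lora_per_module = r * hidden + hidden * r          # A: (r, 768) plus B: (768, r)
lora_total = lora_per_module * 2 * layers          # query and value in every layer
classifier = hidden * 1 + 1                        # regression head: weight + bias
print(lora_total + classifier)                     # 1180417

def lora_linear(x, W, A, B, alpha, r):
    """Frozen linear layer plus the scaled low-rank update: x W^T + (alpha / r) x A^T B^T."""
    return x @ W.T + (alpha / r) * (x @ A.T) @ B.T

rng = np.random.default_rng(0)
W = rng.standard_normal((hidden, hidden))
A = rng.standard_normal((r, hidden)) * 0.01
B = np.zeros((hidden, r))                          # zero init: the update starts at zero
x = rng.standard_normal((4, hidden))
print(np.allclose(lora_linear(x, W, A, B, alpha=32, r=r), x @ W.T))  # True
```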


In [11]:
from transformers import TrainingArguments
import inspect, torch, numpy as np, warnings

# ---------- hyper-params ----------
BATCH = 32          
EPOCHS = 5
LR     = 1e-5
# -----------------------------------------------

total_steps  = int(np.ceil(len(train_ds) / BATCH) * EPOCHS)
warmup_steps = int(0.05 * total_steps)

base_kwargs = dict(
    output_dir                  = "lora-regression",
    per_device_train_batch_size = BATCH,
    per_device_eval_batch_size  = BATCH,
    num_train_epochs            = EPOCHS,
    learning_rate               = LR,
    warmup_steps                = warmup_steps,
    weight_decay                = 0.01,
    logging_steps               = 100,
    seed                        = 42,
    fp16                        = torch.cuda.is_available(),
    optim                       = "adamw_torch_fused",
)

extra = {
    "save_strategy"      : "no",
    "evaluation_strategy": "no",
    "dataloader_num_workers": 4,
}

sig = inspect.signature(TrainingArguments.__init__)
allowed = {k: v for k, v in {**base_kwargs, **extra}.items()
           if k in sig.parameters}

dropped = (set(base_kwargs) | set(extra)) - set(allowed)   # parentheses: `-` binds tighter than `|`
if dropped:
    warnings.warn(f"Installed transformers version doesn't support {sorted(dropped)}; dropped them.")

train_args = TrainingArguments(**allowed)
print("TrainingArguments compiled with", len(allowed), "keywords")



TrainingArguments compiled with 13 keywords


In [12]:
from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
def r2_metric(eval_pred):
    preds, labels = eval_pred
    preds = preds.squeeze()
    # 1 - MSE equals R² only when the eval labels have mean 0 and variance 1 exactly
    r2 = 1.0 - np.mean((labels - preds) ** 2)
    return {"r2": r2}
trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=r2_metric,
)
trainer.train()

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
No label_names provided for model class `PeftModelForSequenceClassification`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss
100,1.1631
200,1.0947
300,1.0201
400,1.0075
500,1.0501
600,0.99
700,1.0222
800,1.0342
900,1.0158
1000,0.9977


TrainOutput(global_step=4880, training_loss=1.0161433251177678, metrics={'train_runtime': 102.1694, 'train_samples_per_second': 1528.344, 'train_steps_per_second': 47.764, 'total_flos': 1.04126649866496e+16, 'train_loss': 1.0161433251177678, 'epoch': 5.0})
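The `global_step=4880` above follows from the settings in In [11], assuming a single device and no gradient accumulation (neither is configured). With 31,230 training sentences and a batch size of 32, there are 976 optimizer steps per epoch, because the last batch is partial. Five epochs therefore give 4,880 steps, of which 244 are warm-up.

```python
import math

n_train, batch, epochs = 31_230, 32, 5
steps_per_epoch = math.ceil(n_train / batch)
total_steps = steps_per_epoch * epochs
warmup_steps = int(0.05 * total_steps)
print(steps_per_epoch, total_steps, warmup_steps)  # 976 4880 244
```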

In [13]:
import numpy as np
from sklearn.metrics import r2_score

preds = trainer.predict(test_ds).predictions.squeeze()
r2    = r2_score(test_z, preds)          # test_z from In [8]
print(f"Voxel {voxel_idx}  –  R² = {r2:.4f}")

Voxel 0  –  R² = 0.0001
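The two R² numbers in this section are defined differently:

- `r2_metric` in In [12] computes `1 - MSE` on labels z-scored with the training mean and standard deviation.
- `sklearn.metrics.r2_score` normalizes by the variance of the test labels themselves.

The two agree only when the evaluation labels have mean 0 and variance 1 exactly. The sketch below shows the difference on synthetic numbers; the helper names are hypothetical.

The training log is also informative. A final training loss near 1.0 on unit-variance labels, together with R² ≈ 0, is what a model that predicts roughly the label mean would produce.

```python
import numpy as np

def one_minus_mse(y, pred):
    return 1.0 - np.mean((y - pred) ** 2)

def r2(y, pred):
    """Coefficient of determination, same definition as sklearn.metrics.r2_score."""
    return 1.0 - np.sum((y - pred) ** 2) / np.sum((y - y.mean()) ** 2)

rng = np.random.default_rng(0)
y_test = 0.3 + 1.2 * rng.standard_normal(1000)   # test labels not exactly standardized
pred = np.zeros_like(y_test)                      # always predicts the training mean (0)
print(one_minus_mse(y_test, pred), r2(y_test, pred))
```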


In [14]:
# --- extract CLS embeddings with the LoRA-tuned model already in memory ---
model.eval()

all_cls = []
with torch.no_grad():
    for story in story_names:
        ds   = raw_texts[story]
        for word_list in ds.data:
            if not word_list:
                continue
            text = " ".join(word_list)
            tok  = tokenizer(text, return_tensors="pt",
                             padding="max_length", truncation=True,
                             max_length=128).to(device)
            hidden = model.bert(**tok).last_hidden_state[:,0,:]   # CLS
            all_cls.append(hidden.squeeze().cpu().numpy())

all_cls = np.array(all_cls)          # (n_non_empty_entries, 768); not aligned to fMRI TRs
np.save("lora_cls_subject2.npy", all_cls)
print("Saved embeddings:", all_cls.shape)

Saved embeddings: (190401, 768)


In [15]:
embeds   = []
fmri_rows = []

for story in story_names:
    ds   = raw_texts[story]
    fmri = Y_dict[story]

    for tr, word_list in enumerate(ds.data):
        if not word_list or tr >= fmri.shape[0]:
            continue           

        sent = " ".join(word_list)
        tok  = tokenizer(sent, return_tensors="pt",
                         padding="max_length", truncation=True,
                         max_length=128).to(device)

        with torch.no_grad():
            cls = model.bert(**tok).last_hidden_state[:, 0, :].squeeze()

        embeds.append(cls.cpu().numpy())
        fmri_rows.append(fmri[tr])
X_all = np.asarray(embeds)       # shape (N_keep, 768)
Y_all = np.asarray(fmri_rows)    # shape (N_keep, n_voxels)
print(X_all.shape, Y_all.shape)  # should have identical first dim

(34700, 768) (34700, 94251)


In [16]:
save_dir = "results/lora_model"      
os.makedirs(save_dir, exist_ok=True)

model.save_pretrained(save_dir)         # save adapter config + weights
tokenizer.save_pretrained(save_dir)     # same vocab / padding

('results/lora_model/tokenizer_config.json',
 'results/lora_model/special_tokens_map.json',
 'results/lora_model/vocab.txt',
 'results/lora_model/added_tokens.json')

In [17]:
# === Fast z-score, delay, and chunked ridge for LoRA embeddings =========
import numpy as np, logging, time
from ridge_utils.ridge import ridge_corr
from ridge_utils.utils import make_delayed  

t0 = time.time()

# ----- 1. vectorised z-score (float32) ----------------------------------
X32 = X_all.astype(np.float32)
X_z = (X32 - X32.mean(0, keepdims=True)) / (X32.std(0, keepdims=True) + 1e-10)

Y32 = Y_all.astype(np.float32)
Y_z = (Y32 - Y32.mean(0, keepdims=True)) / (Y32.std(0, keepdims=True) + 1e-10)

# ----- 2. temporal delays ------------------------------------------------
X_del  = make_delayed(X_z[4:], delays=[0,1,2,3,4]).astype(np.float32)
Y_trim = Y_z[4:].astype(np.float32)
print("Design:", X_del.shape, "  fMRI:", Y_trim.shape)

# ----- 3. chunked ridge --------------------------------------------------
alphas = np.logspace(1, 3, 20).astype(np.float32)
chunk  = 10_000          

# dummy logger to satisfy ridge_corr
log = logging.getLogger("ridge-dummy")
log.addHandler(logging.NullHandler())

best_ccs = []
for start in range(0, Y_trim.shape[1], chunk):
    stop = min(start + chunk, Y_trim.shape[1])

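    # NOTE: the same rows are passed as both training and evaluation data, so
    # these are in-sample correlations, not held-out ones as in In [2].
    cc_list = ridge_corr(
        X_del, X_del,
        Y_trim[:, start:stop], Y_trim[:, start:stop],
        alphas,
        use_corr=True,
        logger=log,
    )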
    cc      = np.asarray(cc_list)       
    best_ccs.append(cc.max(axis=0))      

    print(f"processed voxels {start:>6}-{stop-1:<6}  "
          f"chunk mean CC = {cc.max(axis=0).mean():.4f}", flush=True)

best_cc = np.concatenate(best_ccs)

# --- top-k summaries ----------------------------------------------
top1  = best_cc[np.argpartition(best_cc, -int(0.01*len(best_cc)))[-int(0.01*len(best_cc)):]].mean()
top5  = best_cc[np.argpartition(best_cc, -int(0.05*len(best_cc)))[-int(0.05*len(best_cc)):]].mean()

print(f"\nDone in {time.time() - t0:.1f} s")
print(f"Mean   CC (LoRA) : {best_cc.mean():.4f}")
print(f"Median CC (LoRA) : {np.median(best_cc):.4f}")
print(f"Top 1% CC (LoRA) : {top1:.4f}")
print(f"Top 5% CC (LoRA) : {top5:.4f}")

Design: (34696, 3840)   fMRI: (34696, 94251)
processed voxels      0-9999    chunk mean CC = 0.3143
processed voxels  10000-19999   chunk mean CC = 0.3154
processed voxels  20000-29999   chunk mean CC = 0.3154
processed voxels  30000-39999   chunk mean CC = 0.3153
processed voxels  40000-49999   chunk mean CC = 0.3152
processed voxels  50000-59999   chunk mean CC = 0.3150
processed voxels  60000-69999   chunk mean CC = 0.3150
processed voxels  70000-79999   chunk mean CC = 0.3151
processed voxels  80000-89999   chunk mean CC = 0.3152
processed voxels  90000-94250   chunk mean CC = 0.3151

Done in 784.8 s
Mean   CC (LoRA) : 0.3151
Median CC (LoRA) : 0.3152
Top 1% CC (LoRA) : 0.3253
Top 5% CC (LoRA) : 0.3230
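In [17] passes the same matrices to `ridge_corr` as both training and evaluation data. The LoRA numbers above are therefore in-sample correlations and are not comparable to the held-out values from In [2].

The sketch below is a held-out version in NumPy. It uses the SVD of the training design so that many `alpha` values can be evaluated cheaply, and it processes voxels in chunks. In the notebook, it would be called with the training and test rows of `X_del` and `Y_trim`. The function name is hypothetical, and this is not the `ridge_utils` implementation.

Picking the best `alpha` per voxel by the maximum test correlation, as `np.max(ccs, axis=0)` does, is itself slightly optimistic. A separate validation split avoids that.

```python
import numpy as np

def heldout_ridge_corr(X_tr, X_te, Y_tr, Y_te, alphas, chunk=1000):
    """Return an (n_alphas, n_voxels) array of held-out prediction correlations."""
    U, s, Vt = np.linalg.svd(X_tr, full_matrices=False)
    XteV = X_te @ Vt.T                                   # test design in the SVD basis
    zs = lambda M: (M - M.mean(0)) / (M.std(0) + 1e-10)
    out = np.zeros((len(alphas), Y_tr.shape[1]))
    for start in range(0, Y_tr.shape[1], chunk):
        stop = min(start + chunk, Y_tr.shape[1])
        UtY = U.T @ Y_tr[:, start:stop]                  # (k, n_chunk)
        Yte_z = zs(Y_te[:, start:stop])
        for a_idx, alpha in enumerate(alphas):
            shrink = s / (s ** 2 + alpha)                # ridge filter factors
            pred = XteV @ (shrink[:, None] * UtY)        # X_te @ W(alpha)
            out[a_idx, start:stop] = (zs(pred) * Yte_z).mean(0)
    return out

rng = np.random.default_rng(0)
X = rng.standard_normal((600, 20))
Y = X @ rng.standard_normal((20, 50)) + rng.standard_normal((600, 50))
ccs = heldout_ridge_corr(X[:480], X[480:], Y[:480], Y[480:], np.logspace(0, 3, 5), chunk=16)
best_cc = ccs.max(axis=0)
print(ccs.shape, best_cc.mean().round(3))
```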


# Part 2

# 1) Identifying voxels where the model performs well
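Selecting well-predicted voxels comes down to ranking each voxel's best correlation. `np.argsort` sorts all ~94k values. `np.argpartition`, as used in In [17], finds the top k in linear time, after which only the k winners need sorting.

The helper sketch below also returns the mean of the top fraction that Part 1 reports. The helper names are hypothetical.

```python
import numpy as np

def top_k_voxels(best_cc, k):
    """Indices of the k largest correlations, sorted from best to worst."""
    idx = np.argpartition(best_cc, -k)[-k:]
    return idx[np.argsort(best_cc[idx])[::-1]]

def top_fraction_mean(best_cc, frac):
    """Mean correlation of the best `frac` share of voxels (e.g. 0.01 for the top 1%)."""
    k = max(1, int(frac * len(best_cc)))
    return best_cc[top_k_voxels(best_cc, k)].mean()

cc = np.array([0.05, 0.31, -0.02, 0.12, 0.27, 0.08])
print(top_k_voxels(cc, 3), top_fraction_mean(cc, 0.5))
```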

In [None]:
import pickle
import os
import numpy as np

data_dir = "/ocean/projects/mth240012p/shared/data"
subject = "subject2"
story = "myfathershands"

# Load the raw_text.pkl
with open(os.path.join(data_dir, "raw_text.pkl"), "rb") as f:
    raw_text = pickle.load(f)

# Load fMRI data only for selected story
fmri_path = os.path.join(data_dir, subject, f"{story}.npy")
assert os.path.exists(fmri_path), f"{story}.npy not found"
Y_story = np.load(fmri_path)

print("Loaded story:", story, "| fMRI shape:", Y_story.shape)


In [None]:
import os
import pickle
import numpy as np
import torch
from transformers import BertTokenizer
from encoder import Encoder
from ridge_utils.utils import make_delayed, zscore
from ridge_utils.ridge import ridge_corr
from sklearn.model_selection import train_test_split

# ---------------------
# Setup
# ---------------------
data_dir = "/ocean/projects/mth240012p/shared/data"
subject = "subject2"
# story = "stagefright"
story = "myfathershands"
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------
# Load raw_text and fMRI for one story
# ---------------------
with open(os.path.join(data_dir, "raw_text.pkl"), "rb") as f:
    raw_text = pickle.load(f)

Y_story = np.load(os.path.join(data_dir, subject, f"{story}.npy"))

# ---------------------
# Load encoder and tokenizer
# ---------------------
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
encoder = Encoder(vocab_size=tokenizer.vocab_size)
encoder.load_state_dict(torch.load("finetuned_encoder.pt", map_location=device))
encoder.to(device).eval()

# ---------------------
# Get sentence embeddings
# ---------------------
X_story, Y_rows = [], []
with torch.no_grad():
    for tr, word_list in enumerate(raw_text[story].data):
        if not word_list or tr >= Y_story.shape[0]:
            continue
        sentence = " ".join(word_list)
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length", max_length=128).to(device)
        hidden = encoder(**inputs)  # shape: (1, seq_len, hidden_dim)
        cls = hidden[:, 0, :]       # CLS token
        X_story.append(cls.squeeze().cpu().numpy())
        Y_rows.append(Y_story[tr])

X_story = np.array(X_story)
Y_story = np.array(Y_rows)

# ---------------------
# Normalize + delay
# ---------------------
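# zscore() normalizes rows, so transpose to z-score each feature / voxel over time
X_z = zscore(X_story.T).T
Y_z = zscore(np.nan_to_num(Y_story).T).T
Y_z = np.nan_to_num(Y_z)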

X_del = make_delayed(X_z[4:], delays=[0,1,2,3,4])
Y_trim = Y_z[4:]

# ---------------------
# Split
# ---------------------
X_train, X_test, Y_train, Y_test = train_test_split(X_del, Y_trim, test_size=0.2, random_state=42)

# ---------------------
# Ridge regression
# ---------------------
print("Running ridge regression...")
alphas = np.logspace(1, 3, 20)
ccs = np.array(ridge_corr(X_train, X_test, Y_train, Y_test, alphas))

# ---------------------
# Summary
# ---------------------
best_cc = np.max(ccs, axis=0)
print("Ridge correlation shape:", ccs.shape)
print(f"Mean CC:    {np.mean(best_cc):.4f}")
print(f"Median CC:  {np.median(best_cc):.4f}")
print(f"Top 1% CC:  {np.mean(np.sort(best_cc)[-int(0.01 * len(best_cc)):]):.4f}")
print(f"Top 5% CC:  {np.mean(np.sort(best_cc)[-int(0.05 * len(best_cc)):]):.4f}")

# ---------------------
# Save result (optional)
# ---------------------
np.save("/tmp/ridge_corr_myfathershands.npy", ccs)
print("Saved ridge_corr_myfathershands.npy")


In [None]:
import numpy as np
import torch
import os

# Load ridge correlation results for myfathershands
ccs = np.load("/tmp/ridge_corr_myfathershands.npy")  # shape: (num_alphas, num_voxels)
best_cc = np.max(ccs, axis=0)  # best correlation per voxel

# Select top 10 voxels
top_voxel_indices = np.argsort(best_cc)[-10:][::-1]

print("Top 10 voxel indices selected for interpretation:")
print(top_voxel_indices)
print("Top 10 voxel correlation scores:")
print(best_cc[top_voxel_indices])


# Running SHAP and LIME to Identify Influential Words
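The voxel model explained below is a linear `Ridge` fit on the delayed embeddings. For a linear model `f(x) = w @ x + b` with features treated as independent, the SHAP value of feature j is `w_j * (x_j - E[x_j])`. The expectation is taken over the background data.

These values can therefore be computed in closed form, which gives a sanity check for the permutation-based `shap.Explainer` output. SHAP's `shap.LinearExplainer` covers this case directly. The helper name `linear_shap` in the sketch is hypothetical.

```python
import numpy as np

def linear_shap(w, X, background):
    """Exact SHAP values of f(x) = w @ x + b under an independent-features assumption."""
    return (X - background.mean(axis=0)) * w

rng = np.random.default_rng(0)
w, b = rng.standard_normal(6), 0.5
background = rng.standard_normal((100, 6))
X = rng.standard_normal((3, 6))
phi = linear_shap(w, X, background)

# Local accuracy: SHAP values sum to f(x) minus the average prediction on the background
f = X @ w + b
base = (background @ w + b).mean()
print(np.allclose(phi.sum(axis=1), f - base))  # True
```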

In [None]:
import numpy as np
import torch
import os
import pickle
import shap
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.linear_model import Ridge
from transformers import BertTokenizer
from encoder import Encoder
from ridge_utils.utils import zscore, make_delayed
from ridge_utils.ridge import ridge_corr

# --- Setup ---
device = "cuda" if torch.cuda.is_available() else "cpu"
data_dir = "/ocean/projects/mth240012p/shared/data"
subject = "subject2"
story = "myfathershands"

# --- Load raw text ---
with open(os.path.join(data_dir, "raw_text.pkl"), "rb") as f:
    raw_text = pickle.load(f)

# --- Load fMRI ---
Y_dict = {}
Y_dict[story] = np.load(os.path.join(data_dir, subject, f"{story}.npy"))

# --- Load fine-tuned encoder ---
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
encoder = Encoder(vocab_size=tokenizer.vocab_size)
encoder.load_state_dict(torch.load("finetuned_encoder.pt", map_location=device))
encoder.to(device).eval()

# --- Get embeddings + fMRI for this story ---
X_story = []
Y_story_rows = []

with torch.no_grad():
    for tr, word_list in enumerate(raw_text[story].data):
        if not word_list or tr >= Y_dict[story].shape[0]:
            continue
        sentence = " ".join(word_list)
        tok = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length", max_length=128).to(device)
        hidden = encoder(**tok)
        cls = hidden[:, 0, :]
        X_story.append(cls.squeeze().cpu().numpy())
        Y_story_rows.append(Y_dict[story][tr])

X_story = np.array(X_story)
Y_story = np.array(Y_story_rows)

# --- Z-score + delay ---
X_z = zscore(X_story.T).T
Y_z = zscore(np.nan_to_num(Y_story).T).T   # zscore() normalizes rows: transpose to z-score each voxel over time
Y_z = np.nan_to_num(Y_z)
X_del = make_delayed(X_z[4:], delays=[0, 1, 2, 3, 4])
Y_trim = Y_z[-X_del.shape[0]:]

# --- Ridge regression on all voxels (in-sample: fit and scored on the same rows) ---
alphas = np.logspace(1, 3, 20)
ccs = np.array(ridge_corr(X_del, X_del, Y_trim, Y_trim, alphas))
best_cc = np.max(ccs, axis=0)

# --- Select top voxel ---
top_voxel_indices = np.argsort(best_cc)[-10:][::-1]
top_voxel = top_voxel_indices[0]
print("Top voxel:", top_voxel)

# --- Fit Ridge model for that voxel ---
alpha_best = alphas[np.argmax(ccs[:, top_voxel])]
model = Ridge(alpha=alpha_best).fit(X_del, Y_trim[:, top_voxel])

# --- SHAP Explanation ---
explainer = shap.Explainer(model.predict, X_del[:100])
shap_values = explainer(X_del[:100], max_evals=2 * X_del.shape[1] + 1)

# --- Sentence list aligned to delayed input ---
sentence_lists = [w for w in raw_text[story].data if w][4:]

# --- Compute word-level importance ---
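def aggregate_shap_importance(shap_values, sentence_lists, delay=5):
    """Sum |SHAP| per delay block and spread it evenly over the words of its source TR.

    make_delayed shifts features forward in time: block d of row i holds the
    features of row i - d, so it is attributed to sentence_lists[i - d].
    """
    word_contribs = defaultdict(float)
    num_examples, total_dim = shap_values.values.shape
    d = total_dim // delay
    for i in range(num_examples):
        shap_vec = shap_values.values[i]
        for d_idx in range(delay):
            src = i - d_idx
            if src < 0:  # zero-padded block at the start carries no stimulus
                continue
            vec = shap_vec[d_idx * d: (d_idx + 1) * d]
            words = sentence_lists[src].split()
            if not words:
                continue
            contrib = np.sum(np.abs(vec)) / len(words)
            for word in words:
                word_contribs[word] += contrib
    return word_contribs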


word_contribs = aggregate_shap_importance(shap_values, sentence_lists)

# --- Plot top 20 words ---
top_words = sorted(word_contribs.items(), key=lambda x: -x[1])[:20]
words, scores = zip(*top_words)

plt.figure(figsize=(10, 5))
plt.bar(words, scores, color="green")
plt.xticks(rotation=45)
plt.ylabel("Importance")
plt.title(f"Top 20 Most Influential Words for Voxel {top_voxel}")
plt.tight_layout()
plt.show()


In [None]:
from lime.lime_text import LimeTextExplainer
from sklearn.linear_model import Ridge
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt

# -----------------------
# 1. Fit Ridge Model for a Top Voxel
# -----------------------
print("Step 1: Fitting Ridge model for top voxel...")
voxel_idx = top_voxel_indices[0]
alpha_best = np.logspace(1, 3, 20)[np.argmax(ccs, axis=0)]
model = Ridge(alpha=alpha_best[voxel_idx])
model.fit(X_del, Y_trim[:, voxel_idx])

# -----------------------
# 2. Define LIME Explainer
# -----------------------

print("Step 2: Setting up LIME explainer...")

explainer = LimeTextExplainer(
    class_names=["voxel activation"],
    char_level=False,
    split_expression=r"\s+"  # split on whitespace only
)

# LIME perturbs whole texts, so build one text per position by joining
# 5 consecutive entries (matching the 5 delays used in X_del)
sentence_texts = []
for i in range(4, len(raw_text[story].data) - 4):
    sentence_words = []
    for j in range(i - 4, i + 1):
        word = raw_text[story].data[j]
        if word:
            sentence_words.append("".join(word))  # no-op for plain strings
    sentence_texts.append(" ".join(sentence_words))


# -----------------------
# 3. Define predict function
# -----------------------
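# X_del was built from z-scored embeddings, so perturbed texts must be scaled
# with the same per-feature statistics before they reach the ridge model
x_mean, x_std = X_story.mean(0), X_story.std(0) + 1e-10

def predict_fn(texts):
    embeddings = []
    for t in texts:
        inputs = tokenizer(t, return_tensors="pt", truncation=True, padding="max_length", max_length=128).to(device)
        with torch.no_grad():
            cls_emb = encoder(**inputs)[:, 0, :].cpu().numpy()  # shape (1, hidden_dim)
        cls_emb = (cls_emb - x_mean) / x_std
        # Reuse the same embedding for all 5 delay blocks (an approximation)
        delayed_emb = np.tile(cls_emb, (1, 5))  # shape (1, 5 * hidden_dim)
        embeddings.append(delayed_emb)
    return model.predict(np.vstack(embeddings)).reshape(-1, 1)  # LIME expects a 2-D array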

# -----------------------
# 4. Run LIME on Sampled Sentences
# -----------------------

print("Step 3: Running LIME explanations...")
word_scores = defaultdict(float)
num_samples = 100  # for speed

for i in range(min(num_samples, len(sentence_texts))):
    exp = explainer.explain_instance(sentence_texts[i], predict_fn, labels=[0], num_features=10)
    for word, weight in exp.as_list(label=0):
        word_scores[word] += abs(weight)

# -----------------------
# 5. Plot LIME Results
# -----------------------

print("Step 4: Plotting results...")
sorted_words = sorted(word_scores.items(), key=lambda x: x[1], reverse=True)[:20]
words, scores = zip(*sorted_words)

plt.figure(figsize=(12, 6))
plt.bar(words, scores, color="orange")
plt.title(f"Top 20 Most Influential Words by LIME for Voxel {voxel_idx}")
plt.ylabel("Importance")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()



# Comparing SHAP and LIME
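Before comparing the two methods, it helps to see what `explain_instance` computes from `predict_fn`:

1. LIME perturbs the input by removing random subsets of words.
2. It scores each perturbed text with the model.
3. It weights each sample by its similarity to the original text.
4. It fits a weighted linear model, whose coefficients become the word importances.

The NumPy version below is simplified. It uses a toy scorer, and both function names are hypothetical. The `lime` package adds its own kernel scaling, feature selection, and a ridge surrogate. When the scorer is exactly linear in word presence, as here, the surrogate recovers its coefficients.

```python
import numpy as np

def lime_word_weights(text, predict, num_samples=500, kernel_width=0.25, seed=0):
    """Weighted least-squares surrogate over word-presence masks (simplified LIME)."""
    words = text.split()
    rng = np.random.default_rng(seed)
    Z = rng.integers(0, 2, size=(num_samples, len(words)))
    Z[0] = 1                                                  # include the unperturbed text
    texts = [" ".join(w for w, keep in zip(words, z) if keep) for z in Z]
    y = np.asarray(predict(texts)).ravel()
    # Cosine similarity of each mask to the all-ones mask, then an exponential kernel
    cos = Z.sum(1) / (np.sqrt(np.maximum(Z.sum(1), 1)) * np.sqrt(len(words)))
    weights = np.exp(-((1 - cos) ** 2) / kernel_width ** 2)
    A = np.hstack([Z, np.ones((num_samples, 1))])            # word indicators + intercept
    sw = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
    return dict(zip(words, coef[:-1]))

def toy_predict(texts):
    """Hypothetical scorer: +1 for 'father', +0.5 for 'hands'."""
    return np.array([[t.split().count("father") + 0.5 * t.split().count("hands")] for t in texts])

weights = lime_word_weights("my father had big hands", toy_predict)
print({w: round(float(v), 2) for w, v in weights.items()})
```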

Required work:

We will add side-by-side comparisons of SHAP and LIME and repeat them for a couple more voxels, for example with something like the following:

EXAMPLE CODE
<!-- 
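def compare_word_importance(shap_scores, lime_scores, voxel_idx, top_n=20):
    """Plot the top-n SHAP and LIME word scores for one voxel side by side."""
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 2, figsize=(14, 4))
    panels = [(axes[0], shap_scores, "SHAP", "green"),
              (axes[1], lime_scores, "LIME", "orange")]
    for ax, scores, name, color in panels:
        top = sorted(scores.items(), key=lambda x: -x[1])[:top_n]
        words, vals = zip(*top)
        ax.bar(words, vals, color=color)
        ax.set_title(f"{name} - Voxel {voxel_idx}")
        ax.set_ylabel("Importance")
        ax.tick_params(axis="x", rotation=45)

    plt.tight_layout()
    plt.show()


# word_contribs (SHAP) and word_scores (LIME) come from the two cells above
compare_word_importance(word_contribs, word_scores, voxel_idx=voxel_idx)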
 -->
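Beyond side-by-side plots, the agreement between the two methods can be summarized numerically. Two possible measures are:

- **Top-k overlap:** the Jaccard index of the two methods' top-k word sets.
- **Rank agreement:** the rank correlation of the scores on the words both methods scored.

The sketch below works on two `{word: score}` dictionaries such as `word_contribs` (SHAP) and `word_scores` (LIME). The function name and the choice of metrics are illustrative. Ranks come from a double `argsort`, so ties are broken arbitrarily; without ties, the result is exactly Spearman's rho.

```python
import numpy as np

def importance_agreement(scores_a, scores_b, k=20):
    """Top-k Jaccard overlap and Spearman-style rank correlation on shared words."""
    top = lambda s: {w for w, _ in sorted(s.items(), key=lambda x: -x[1])[:k]}
    ta, tb = top(scores_a), top(scores_b)
    jaccard = len(ta & tb) / len(ta | tb) if (ta | tb) else float("nan")

    shared = sorted(set(scores_a) & set(scores_b))
    if len(shared) < 2:
        return jaccard, float("nan")
    rank = lambda v: np.argsort(np.argsort(v))                # 0 = smallest score
    ra = rank(np.array([scores_a[w] for w in shared]))
    rb = rank(np.array([scores_b[w] for w in shared]))
    return jaccard, np.corrcoef(ra, rb)[0, 1]

shap_like = {"father": 3.0, "hands": 2.0, "big": 1.0, "my": 0.5}
lime_like = {"father": 0.9, "hands": 0.4, "big": 0.5, "had": 0.1}
print(importance_agreement(shap_like, lime_like, k=2))
```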


# Repeating the analysis for another test story

In [None]:
# To repeat everything, change `story` at the top of Part 2 and re-run the cells from there.
# The quick check below only re-runs LIME: it applies the voxel model fitted above
# (on myfathershands) to text from a second story. `explainer`, `predict_fn`,
# `voxel_idx`, and `plt` come from the LIME cell above.
story_name = "myfirstdaywiththeyankees"
words = [w for w in raw_text[story_name].data if len(w) > 0]
# Windows of 5 consecutive entries, like sentence_texts in the LIME cell
texts = [" ".join(words[i - 4:i + 1]) for i in range(4, min(len(words), 54))]

lime_exp = explainer.explain_instance(texts[0], predict_fn, labels=[0], num_features=10)
lime_fig = lime_exp.as_pyplot_figure(label=0)
plt.title(f"LIME word importance on '{story_name}' (voxel {voxel_idx})")
plt.show()