# Part 1

### Ridge Regression
In this step, we use a pre-trained BERT model (bert-base-uncased) to extract CLS token embeddings for each TR (timepoint) in every story. Each TR contains a list of words, which are tokenized and encoded using BERT. These embeddings are then aligned with their corresponding fMRI responses. We apply z-score normalization and temporal delays to the BERT embeddings and fit a ridge regression model to predict voxel-level brain activity from the BERT representations.

In [1]:
import os
import sys
sys.path.append("..")
import pickle
import numpy as np
import torch
import tqdm
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from ridge_utils.utils import zscore, make_delayed
from ridge_utils.ridge import ridge_corr

# -------------------- Step 1: Loading -------------------- #
device = "cuda" if torch.cuda.is_available() else "cpu"
data_dir = "/ocean/projects/mth240012p/shared/data"
subject = "subject2"
max_tokens = 50

print("Loading raw_text.pkl and fMRI...")
with open(os.path.join(data_dir, "raw_text.pkl"), "rb") as f:
    raw_texts = pickle.load(f)

story_names = []
Y_dict = {}
for story in raw_texts:
    fmri_path = os.path.join(data_dir, subject, f"{story}.npy")
    if os.path.exists(fmri_path):
        Y_dict[story] = np.load(fmri_path)
        story_names.append(story)

print(f"Valid stories with fMRI: {len(story_names)}")

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") # loading BERT Model
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device)
bert_model.eval()

# -------------------- Step 2: Iterate over each story and each TR -------------------- #
X_all, Y_all = [], []

for story in tqdm(story_names, desc="Processing stories"):
    fmri = Y_dict[story]
    ds = raw_texts[story] 
    for i in range(len(ds.data)):
        word_list = ds.data[i] 
        if not word_list or i >= fmri.shape[0]:
            continue  
        sentence = " ".join(word_list)
        inputs = tokenizer(
            sentence,
            return_tensors="pt",
            padding="max_length",
            max_length=512,
            truncation=True
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = bert_model(**inputs)
        cls_embed = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
        X_all.append(cls_embed)
        Y_all.append(fmri[i])

X_all = np.array(X_all)
Y_all = np.array(Y_all)

print(f"\nCollected {X_all.shape[0]} TRs | BERT dim: {X_all.shape[1]} | Voxels: {Y_all.shape[1]}")

Loading raw_text.pkl and fMRI...
Valid stories with fMRI: 101


Processing stories: 100%|██████████| 101/101 [03:18<00:00,  1.97s/it]



Collected 34700 TRs | BERT dim: 768 | Voxels: 94251


In [6]:
X_z = zscore(X_all.T).T
Y_z = zscore(Y_all)
Y_z = np.nan_to_num(Y_z)

X_delayed = make_delayed(X_z[4:], delays=[0, 1, 2, 3, 4])
Y_trimmed = Y_z[4:]

X_train, X_test, Y_train, Y_test = train_test_split(
    X_delayed, Y_trimmed, test_size=0.2, random_state=42
)

# -------------------- Step 3: Ridge Regression -------------------- #
print("Running ridge regression...")

print(f"Final shapes → X: {X_delayed.shape}, Y: {Y_trimmed.shape}")
print(f"Train: {X_train.shape}, Test: {X_test.shape}")

alphas = np.logspace(1, 3, 20)
ccs = np.array(ridge_corr(X_train, X_test, Y_train, Y_test, alphas))

print("Ridge correlation shape:", ccs.shape)
best_cc = np.max(ccs, axis=0)
print(f"Mean CC:    {np.mean(best_cc):.4f}")
print(f"Median CC:  {np.median(best_cc):.4f}")
print(f"Top 1% CC:  {np.mean(np.sort(best_cc)[-int(0.01*len(best_cc)):]):.4f}")
print(f"Top 5% CC:  {np.mean(np.sort(best_cc)[-int(0.05*len(best_cc)):]):.4f}")

Running ridge regression...
Final shapes → X: (34696, 3840), Y: (34696, 94251)
Train: (27756, 3840), Test: (6940, 3840)


KeyboardInterrupt: 

In [None]:
# Check the Result 
print("X_delayed shape:", X_delayed.shape)
print("Y_trimmed shape:", Y_trimmed.shape)
print("NaN in X:", np.isnan(X_delayed).any(), "| Inf in X:", np.isinf(X_delayed).any())
print("NaN in Y:", np.isnan(Y_trimmed).any(), "| Inf in Y:", np.isinf(Y_trimmed).any())
print("X mean/std:", np.mean(X_delayed), np.std(X_delayed))
print("Y mean/std:", np.mean(Y_trimmed), np.std(Y_trimmed))
print("All zeros in Y:", np.all(Y_trimmed == 0))
nan_voxels = np.isnan(Y_all).sum(axis=0)
print(f"Max NaN count per voxel: {nan_voxels.max()} / {Y_all.shape[0]}")
print(f"Number of voxels with any NaN: {np.sum(nan_voxels > 0)}")

### Fine-Tuning BERT Encoder on Stimulus Texts

In this section, we fine-tune the BERT encoder using the raw story text data via a Masked Language Modeling (MLM) objective.

We use the following approach:

- **Dataset Construction**: We convert each story's TR-aligned text (from the `DataSequence`) into input IDs for BERT, masking a subset of tokens randomly to form an MLM task. Each TR is treated as one training example.
- **Model Architecture**: We reuse the `Encoder` defined in `encoder.py`, which wraps a pre-trained BERT model (`bert-base-uncased`) with a linear decoder head to predict masked tokens.
- **Training Loop**: We leverage `train_encoder.py` to fine-tune the model over multiple epochs. The training loss is computed only on masked positions.

This fine-tuning step helps the BERT encoder adapt to the linguistic distribution of the experimental stimuli, improving the quality of the learned embeddings for downstream voxel prediction.

After training, we will save the encoder weights, which can later be loaded and used for voxel-wise ridge regression in Part 2.


In [None]:
# ======================= Fine-tuning Setup ============================
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from train_encoder import train_bert
from encoder import Encoder

# 1. Prepare Dataset for MLM pretraining
class MLMTextDataset(Dataset):
    def __init__(self, raw_texts, tokenizer, max_length=128):
        self.samples = []
        for story in raw_texts.values():
            for word_list in story.data:
                if not word_list:
                    continue
                sentence = " ".join(word_list)
                tokens = tokenizer(
                    sentence,
                    max_length=max_length,
                    truncation=True,
                    padding='max_length',
                    return_tensors='pt'
                )
                self.samples.append(tokens)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = self.samples[idx]
        return {k: v.squeeze(0) for k, v in item.items()}

# 2. Initialize model, tokenizer, and dataset
print("Preparing dataset for fine-tuning...")
dataset = MLMTextDataset(raw_texts, tokenizer, max_length=128)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

print("Initializing encoder model...")
encoder = Encoder(vocab_size=tokenizer.vocab_size).to(device)

# 3. Train the encoder using MLM objective
print("Starting MLM training...")
train_bert(
    model=encoder,
    dataloader=dataloader,
    tokenizer=tokenizer,
    device=device,
    epochs=3, # Change this if you want
    lr=1e-4 # Can change into smaller
)

# 4. Save encoder for later use
torch.save(encoder.state_dict(), "finetuned_encoder.pt") # use this model for new BERT Model
print("Fine-tuned encoder saved as 'finetuned_encoder.pt'")


## Step 3 Fine-tuning with LoRa


In [7]:
from transformers import AutoTokenizer, BertModel, BertPreTrainedModel, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType
from datasets import Dataset
import torch.nn as nn
import torch
import numpy as np

In [8]:
# Select a voxel to predict
voxel_index = 0  # change this to try other voxels

# Input text and corresponding voxel response
X_train_text = X_train  # list of story segments (strings)
X_test_text = X_test
y_train_voxel = Y_train[:, voxel_index]
y_test_voxel = Y_test[:, voxel_index]

In [9]:
import torch
print(torch.cuda.is_available())

True


In [20]:
voxel_index = 0
X_train_text = X_train
X_test_text = X_test
y_train_voxel = Y_train[:, voxel_index]
y_test_voxel = Y_test[:, voxel_index]

# Select subset first
# train_dataset = Dataset.from_dict({"text": X_train_text, "label": y_train_voxel.tolist()}).select(range(200))
# test_dataset = Dataset.from_dict({"text": X_test_text, "label": y_test_voxel.tolist()}).select(range(50)) 


def preprocess(example):
    tokens = tokenizer(str(example["text"]), truncation=True, padding="max_length", max_length=128)
    tokens["label"] = example["label"]
    return tokens

train_dataset = train_dataset.map(preprocess)
test_dataset = test_dataset.map(preprocess)

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

In [19]:
# --- Imports ---
import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer, AutoConfig, BertModel, BertPreTrainedModel,
    TrainingArguments, Trainer
)
from transformers.modeling_outputs import SequenceClassifierOutput
from peft import get_peft_model, LoraConfig, TaskType

# --- Model Definition with Correct Output ---
class BertRegression(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(0.1)
        self.regressor = nn.Linear(config.hidden_size, 1)
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs.pooler_output)
        logits = self.regressor(pooled_output).squeeze(-1)

        loss = None
        if labels is not None:
            loss = nn.MSELoss()(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

# --- Tokenizer & Config ---
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
config = AutoConfig.from_pretrained("bert-base-uncased")

# --- Model + LoRA ---
model = BertRegression.from_pretrained("bert-base-uncased", config=config)

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.FEATURE_EXTRACTION
)

model = get_peft_model(model, lora_config).to("cuda")

# --- TrainingArguments (minimal) ---
training_args = TrainingArguments(
    output_dir="./lora-regression",
    per_device_train_batch_size=8,
    num_train_epochs=3,
    learning_rate=2e-5,
    logging_dir="./logs",
    logging_steps=5,
)

Some weights of BertRegression were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['regressor.bias', 'regressor.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

trainer.train()

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
No label_names provided for model class `PeftModelForFeatureExtraction`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss
5,0.6314
10,1.97
15,0.5798
20,1.1758
25,0.845
30,0.7648
35,0.6042
40,1.6414
45,1.1182
50,0.7957


TrainOutput(global_step=75, training_loss=1.0274275652567546, metrics={'train_runtime': 1.6011, 'train_samples_per_second': 374.734, 'train_steps_per_second': 46.842, 'total_flos': 39602199398400.0, 'train_loss': 1.0274275652567546, 'epoch': 3.0})

In [22]:
torch.save(model.state_dict(), "lora_finetuned_model.pt")

In [23]:
# Proper unpacking
raw_preds = trainer.predict(test_dataset).predictions
preds = raw_preds[0] if isinstance(raw_preds, tuple) else raw_preds

# Flatten if needed
preds = preds.flatten()

# Compute R²
r2 = np.corrcoef(preds, y_test_voxel[:len(preds)])[0, 1] ** 2
print(f"LoRA R² for voxel {voxel_index}: {r2:.4f}")

LoRA R² for voxel 0: 0.0848


In [None]:
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
import torch

def get_lora_bert_embeddings(sentences, model, tokenizer, batch_size=32):
    model.eval()
    embeddings = []

    with torch.no_grad():
        for i in range(0, len(sentences), batch_size):
            batch = sentences[i:i+batch_size]
            inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=128).to("cuda")
            outputs = model.bert(**inputs)
            pooled = outputs.pooler_output.cpu().numpy()
            embeddings.append(pooled)

    return np.vstack(embeddings)  # shape: (n_samples, hidden_size)

# 1. Extract LoRA fine-tuned embeddings
X_train_emb = get_lora_bert_embeddings(X_train, model, tokenizer)
X_test_emb = get_lora_bert_embeddings(X_test, model, tokenizer)

# 2. Fit ridge regression for each voxel
ridge = Ridge(alpha=1.0)
ridge.fit(X_train_emb, Y_train[:, voxel_index])
y_pred = ridge.predict(X_test_emb)

# 3. Evaluate R²
r2_ridge_lora = r2_score(Y_test[:, voxel_index], y_pred)
print(f"Ridge R² using LoRA-embedding for voxel {voxel_index}: {r2_ridge_lora:.4f}")

# Part 2

# 1) Identifying voxels where the model performs well

In [1]:
import numpy as np
import torch
import pickle
import os

# Load the fine-tuned encoder
ccs = np.load("ridge_corr_subject2.npy")  # shape: (num_alphas, num_voxels)
best_cc = np.max(ccs, axis=0)  # shape: (num_voxels,)

# Pick top-performing voxels (e.g., top 5%)
top_n = int(0.05 * len(best_cc))
top_voxel_indices = np.argsort(best_cc)[-top_n:]
print(f"Top {top_n} voxel indices selected for interpretation.")

FileNotFoundError: [Errno 2] No such file or directory: 'ridge_corr_subject2.npy'

# Running SHAP and LIME to Identigy Influential Words

In [None]:
import shap
from lime.lime_text import LimeTextExplainer
from transformers import BertTokenizer, BertModel
from encoder import Encoder 
from train_encoder import train_bert
from torch.utils.data import Dataset, DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer and fine-tuned encoder
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
encoder = Encoder(vocab_size=tokenizer.vocab_size)
encoder.load_state_dict(torch.load("finetuned_encoder.pt"))
encoder.to(device).eval()

# Load raw text + extract one test story
with open("/ocean/projects/mth240012p/shared/data/raw_text.pkl", "rb") as f:
    raw_text = pickle.load(f)

story_name = "losing_my_legs"
sentences = [" ".join(w) for w in raw_text[story_name].data if len(w) > 0][:50]  # Trimmed for speed!!

# Return voxel predictions from BERT CLS token
def predict_voxel_activation(texts):
    inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=128)
    input_ids = inputs['input_ids'].to(device)
    token_type_ids = inputs['token_type_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        hidden = encoder(input_ids, token_type_ids, attention_mask)
        cls_embeds = hidden[:, 0, :]  # CLS token

    # Use pretrained ridge weights per voxel (top_voxel_indices)
    ridge_weights = np.load("ridge_weights_top_voxels.npy")  # (voxel_count, cls_dim)
    preds = cls_embeds.cpu().numpy() @ ridge_weights.T        # (num_sentences, top_voxels)
    return preds

# SHAP 
shap_explainer = shap.Explainer(predict_voxel_activation, tokenizer)
shap_values = shap_explainer(sentences)

# LIME
lime_explainer = LimeTextExplainer(class_names=[f"Voxel {i}" for i in top_voxel_indices])
lime_exp = lime_explainer.explain_instance(sentences[0], predict_voxel_activation, num_features=10)


# Comparing SHAP and LIME

Required work:

We will add side-by-side comparisons and try across a couple more voxels. could have something like the following:

EXAMPLE CODE
<!-- 
def compare_word_importance(shap_values, lime_exp, voxel_idx):
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    # SHAP bar plot
    shap.plots.bar(shap_values[:, :, voxel_idx].mean(0), show=False, ax=axes[0])
    axes[0].set_title(f"SHAP - Voxel {voxel_idx}")
    
    # LIME bar plot
    lime_exp.as_pyplot_figure(axes[1])
    axes[1].set_title(f"LIME - Voxel {voxel_idx}")
    
    plt.tight_layout()
    plt.show()


compare_word_importance(shap_values, lime_exp, voxel_idx=0)
 -->


In [None]:
import matplotlib.pyplot as plt

# SHAP plot
shap.plots.text(shap_values[0]) 
shap.plots.bar(shap_values[:, :, 0].mean(0))  # average over sentences for Voxel 0

# LIME plot
lime_fig = lime_exp.as_pyplot_figure()
plt.title("LIME Word Importance for One Sentence (Voxel-Level)")
plt.tight_layout()
plt.show()

# Repeat all analysis for another test-story

In [None]:
# Just change the story name and re-run Section 2 onward??? not sure
story_name = "myfirstdaywiththeyankees"
sentences = [" ".join(w) for w in raw_text[story_name].data if len(w) > 0][:50]

# Rerun: predict_voxel_activation(), SHAP, and LIME
shap_values = shap_explainer(sentences)
shap.plots.text(shap_values[0])
lime_exp = lime_explainer.explain_instance(sentences[0], predict_voxel_activation, num_features=10)
lime_fig = lime_exp.as_pyplot_figure()
plt.title("LIME Word Importance for Another Sentence (Voxel-Level)")
plt.show()