In [26]:
import os
import pickle
import numpy as np
import torch
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from ridge_utils.utils import zscore, make_delayed
from ridge_utils.ridge import ridge_corr

# -------------------- Step 1: 配置 -------------------- #
device = "cuda" if torch.cuda.is_available() else "cpu"
data_dir = "/ocean/projects/mth240012p/shared/data"
subject = "subject2"
max_tokens = 50

# -------------------- Step 2: 加载原始文本和 fMRI -------------------- #
print("Loading raw_text.pkl and fMRI...")
with open(os.path.join(data_dir, "raw_text.pkl"), "rb") as f:
    raw_texts = pickle.load(f)

story_names = []
Y_dict = {}
for story in raw_texts:
    fmri_path = os.path.join(data_dir, subject, f"{story}.npy")
    if os.path.exists(fmri_path):
        Y_dict[story] = np.load(fmri_path)
        story_names.append(story)

print(f"Valid stories with fMRI: {len(story_names)}")

# -------------------- Step 3: 加载 BERT 模型 -------------------- #
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device)
bert_model.eval()

# -------------------- Step 4: 遍历每个 story → 每个 TR -------------------- #
X_all, Y_all = [], []

for story in tqdm(story_names, desc="Processing stories"):
    fmri = Y_dict[story]
    ds = raw_texts[story]  # ✅ DataSequence 对象
    for i in range(len(ds.data)):
        word_list = ds.data[i]  # ✅ 兼容 __getitem__
        if not word_list or i >= fmri.shape[0]:
            continue  # 跳过空 TR 或 fMRI 数据缺失
        sentence = " ".join(word_list)
        inputs = tokenizer(
            sentence,
            return_tensors="pt",
            padding="max_length",
            max_length=512,
            truncation=True
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = bert_model(**inputs)
        cls_embed = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
        X_all.append(cls_embed)
        Y_all.append(fmri[i])

X_all = np.array(X_all)
Y_all = np.array(Y_all)

print(f"\nCollected {X_all.shape[0]} TRs | BERT dim: {X_all.shape[1]} | Voxels: {Y_all.shape[1]}")

# -------------------- Step 5: Z-score 和划分 Train/Test -------------------- #
X_z = zscore(X_all.T).T
Y_z = zscore(Y_all)

X_train, X_test, Y_train, Y_test = train_test_split(X_z, Y_z, test_size=0.2, random_state=42)

# -------------------- Step 6: Ridge Regression -------------------- #
print("Running ridge regression...")
alphas = np.logspace(1, 3, 20)
ccs = np.array(ridge_corr(make_delayed(X_train, [0,1,2,3,4]),
                          make_delayed(X_test, [0,1,2,3,4]),
                          Y_train, Y_test, alphas))

# -------------------- Step 7: 输出结果 -------------------- #
print("Ridge correlation shape:", ccs.shape)
best_cc = np.max(ccs, axis=0)
print(f"Mean CC:    {np.mean(best_cc):.4f}")
print(f"Median CC:  {np.median(best_cc):.4f}")
print(f"Top 1% CC:  {np.mean(np.sort(best_cc)[-int(0.01*len(best_cc)):]):.4f}")
print(f"Top 5% CC:  {np.mean(np.sort(best_cc)[-int(0.05*len(best_cc)):]):.4f}")


Loading raw_text.pkl and fMRI...
Valid stories with fMRI: 101


Processing stories: 100%|██████████| 101/101 [40:25<00:00, 24.01s/it]



Collected 34700 TRs | BERT dim: 768 | Voxels: 94251
Running ridge regression...
Ridge correlation shape: (20, 94251)
Mean CC:    0.0000
Median CC:  0.0000
Top 1% CC:  0.0000
Top 5% CC:  0.0000
