# Fine-tuning BERT to generate word embeddings

In [None]:
import os

# Create the directory if it doesn't exist:
scratch_cache = "/ocean/projects/mth250011p/smazioud/huggingface_cache"
os.makedirs(scratch_cache, exist_ok=True)

# Point every HuggingFace / PyTorch cache location at the scratch directory.
# These must be set before the libraries are imported to take effect.
for cache_var in (
    'HF_HOME',
    'TRANSFORMERS_CACHE',
    'HF_DATASETS_CACHE',
    'HF_MODULES_CACHE',
    'HF_METRICS_CACHE',
    'TORCH_HOME',
):
    os.environ[cache_var] = scratch_cache

# Offline mode: only files that are already in the cache are used.
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"

from huggingface_hub import login

# login() has to contact the Hub to validate the token, which cannot succeed
# while offline mode is switched on, so only attempt it when online.
# NOTE(review): "bert-base-uncased" is a public checkpoint, so authentication
# is presumably not needed at all - confirm before re-enabling.
if os.environ["HF_HUB_OFFLINE"] == "1":
    print("HF_HUB_OFFLINE=1: skipping Hugging Face login (no network access).")
else:
    login()

print("HF cache directory:", scratch_cache)

In [None]:
import os
import sys
import pickle
import matplotlib.pyplot as plt
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast, BertForMaskedLM, BertModel
from datasets import Dataset
from fine_tune_bert import mask_tokens, train_bert, apply_lora_to_bert
from data_cleaning import clean_data

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Using device:", DEVICE)


In [None]:
# Make the project checkout importable.
# NOTE(review): fine_tune_bert and data_cleaning are imported in an earlier
# cell, i.e. before this path is appended - confirm they resolve on a fresh
# kernel.
repo_root = "/jet/home/smazioud/stat215a_final"
sys.path.append(repo_root)

raw_path = "/ocean/projects/mth250011p/shared/215a/final_project/data/raw_text.pkl"

# Project-shared pickle; pickle must never be used on untrusted files.
raw = pickle.loads(Path(raw_path).read_bytes())

# clean_data returns the story texts together with their matching ids.
story_texts, story_ids = clean_data(raw)

# Quick sanity check of what was loaded.
print(f"\nNumber of stories: {len(story_texts)}")
print(f"Example story ID: {story_ids[0]}")
print(f"Example text (first 200 chars):\n{story_texts[0][:200]}\n")


### Load tokenizer

In [None]:
# Checkpoint name; reused below when the masked-LM model is loaded.
model_name = "bert-base-uncased"
# Fast tokenizer variant: its encodings expose word_ids(), which the
# word-embedding extraction later in this notebook relies on.
tokenizer = BertTokenizerFast.from_pretrained(model_name)
print(f"Tokenizer loaded: {model_name}")

### Tokenize dataset

In [None]:
from torch.utils.data import Dataset, DataLoader
import torch

max_seq_len = 128  # fixed window length for BERT

all_input_ids = []
all_attention_masks = []

for story in story_texts:
    # Encode the whole story as one flat list of token ids: no special
    # tokens, no truncation, no padding.
    story_token_ids = tokenizer(
        story,
        add_special_tokens=False,
        return_attention_mask=False,
        return_tensors=None,
    )["input_ids"]

    # Walk the ids in consecutive, non-overlapping windows of max_seq_len.
    window_starts = range(0, len(story_token_ids), max_seq_len)
    for lo in window_starts:
        window = story_token_ids[lo:lo + max_seq_len]

        # Only the last window of a story can be short; pad it to the fixed
        # length and get the matching attention mask from the tokenizer.
        batch = tokenizer.pad(
            {"input_ids": [window]},
            padding="max_length",
            max_length=max_seq_len,
            return_attention_mask=True,
            return_tensors="pt",
        )

        # The batch holds a single row; keep it as a 1-D tensor.
        all_input_ids.append(batch["input_ids"][0])
        all_attention_masks.append(batch["attention_mask"][0])

# One row per window: (num_chunks, max_seq_len).
input_ids = torch.stack(all_input_ids)
attention_mask = torch.stack(all_attention_masks)

n_chunks, seq_len = input_ids.shape
print("Number of chunks:", n_chunks)
print("Sequence length:", seq_len)


In [None]:
class TextDataset(Dataset):
    """Map-style dataset over two equally shaped tensors.

    Item ``idx`` is a dict holding row ``idx`` of ``input_ids`` and of
    ``attention_mask``.
    """

    def __init__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        # The two tensors must line up element for element.
        assert input_ids.shape == attention_mask.shape
        self.input_ids, self.attention_mask = input_ids, attention_mask

    def __len__(self):
        """Return the number of rows (size of the first dimension)."""
        return self.input_ids.shape[0]

    def __getitem__(self, idx):
        """Return row ``idx`` of each tensor, keyed by field name."""
        fields = ("input_ids", "attention_mask")
        return {name: getattr(self, name)[idx] for name in fields}

dataset = TextDataset(input_ids, attention_mask)

# Simple train/val split over chunks.
# NOTE(review): the split is per chunk, so chunks of one story can end up in
# both train and validation - confirm this is acceptable for the reported
# validation loss.
val_frac = 0.2
n_total = len(dataset)
n_val = int(val_frac * n_total)
n_train = n_total - n_val

# Seed the split so the same chunks land in train/val on every run; without
# a generator the partition changes each time the cell is executed.
split_seed = 42
train_ds, val_ds = torch.utils.data.random_split(
    dataset,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(split_seed),
)

batch_size = 16

# Shuffle only the training batches; validation order stays fixed.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")


## Fine-tune

In [None]:
# Load BERT and wrap it with LoRA adapters

from transformers import BertForMaskedLM

# Pretrained masked-LM checkpoint, moved to the training device in one step.
base_model = BertForMaskedLM.from_pretrained(model_name).to(DEVICE)

print(f"Base model parameters: {sum(p.numel() for p in base_model.parameters()):,}")

# LoRA hyper-parameters handed to the project helper.
lora_kwargs = dict(r=8, alpha=16, dropout=0.1)
lora_model = apply_lora_to_bert(base_model, **lora_kwargs).to(DEVICE)

print(f"LoRA-wrapped model is on device: {DEVICE}")


In [None]:
# Optimisation settings.
epochs, lr = 5, 5e-5

# Fine-tune the LoRA-wrapped model with the project training routine.
train_kwargs = dict(
    model=lora_model,
    train_loader=train_loader,
    tokenizer=tokenizer,
    val_loader=val_loader,
    epochs=epochs,
    lr=lr,
    device=DEVICE,
)
trained_model = train_bert(**train_kwargs)

# Persist model and tokenizer side by side inside the repository checkout.
output_dir = Path(repo_root).joinpath("saved_models", "bert_lora_finetuned")
output_dir.mkdir(parents=True, exist_ok=True)

for artifact in (trained_model, tokenizer):
    artifact.save_pretrained(output_dir)

print(f"Fine-tuned LoRA model saved to: {output_dir}")

## Embeddings

In [None]:
from transformers import BertModel

# Directory the fine-tuned model was saved to.
output_dir = Path(repo_root).joinpath("saved_models", "bert_lora_finetuned")

# NOTE(review): this loads a plain BertModel from a directory written by
# save_pretrained() on the LoRA-wrapped masked-LM model. Whether the
# fine-tuned weights actually end up in the encoder depends on what
# apply_lora_to_bert returns - TODO confirm the adapters are merged/loaded.
finetuned_bert = BertModel.from_pretrained(
    output_dir,
    output_hidden_states=True,
).to(DEVICE)
finetuned_bert.eval()  # inference mode: disables dropout

print("Loaded finetuned encoder for embeddings.")

In [None]:
def get_bert_word_embeddings_for_words(
    words,
    tokenizer,
    model: BertModel,
    device=DEVICE,
    chunk_size: int = 256,
):
    """
    Compute one embedding per word by running BERT over consecutive chunks.

    Each chunk of up to ``chunk_size`` words is tokenized with
    ``is_split_into_words=True`` and the last-layer hidden states of a word's
    subword tokens are averaged into that word's vector.

    Args:
        words: Sequence of word strings.
        tokenizer: Tokenizer whose encodings provide ``word_ids()``.
        model: Encoder whose output exposes ``last_hidden_state``; it is
            switched to eval mode here.
        device: Device the encoded inputs are moved to.
        chunk_size: Maximum number of words sent through the model at once.

    Returns:
        ``np.ndarray`` of shape ``(len(words), hidden_size)`` and dtype
        float32. A word that yields no tokens keeps an all-zero row.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    model.eval()
    hidden_dim = model.config.hidden_size
    n_words = len(words)
    embs = np.zeros((n_words, hidden_dim), dtype=np.float32)

    # Token count at which truncation=True cuts the input, when the
    # tokenizer declares one.
    max_len = getattr(tokenizer, "model_max_length", None)

    start = 0
    with torch.no_grad():
        while start < n_words:
            end = min(start + chunk_size, n_words)
            # A plain list keeps the tokenizer happy for array-like inputs.
            chunk_words = list(words[start:end])

            enc = tokenizer(
                chunk_words,
                is_split_into_words=True,
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(device)

            outputs = model(**enc)
            last_hidden = outputs.last_hidden_state[0].cpu().numpy()  # (L, H)

            # Group token positions by the chunk-local word they belong to;
            # special tokens carry a word id of None and are skipped.
            word_ids = enc.word_ids(batch_index=0)
            token_rows = {}
            for tok_idx, w_id in enumerate(word_ids):
                if w_id is not None:
                    token_rows.setdefault(w_id, []).append(tok_idx)

            # Average subword embeddings per word.
            for local_w_id, rows in token_rows.items():
                embs[start + local_w_id] = last_hidden[rows].mean(axis=0)

            # If truncation cut the chunk short, the trailing words were never
            # seen by the model. Resume from the first such word instead of
            # silently leaving zero vectors behind. (A word split exactly at
            # the cut keeps the average of its surviving pieces.)
            next_start = end
            if token_rows:
                first_missing = max(token_rows) + 1
                hit_limit = max_len is not None and len(word_ids) >= max_len
                if hit_limit and first_missing < len(chunk_words):
                    next_start = start + first_missing
            start = next_start

    return embs


In [None]:
bert_embeddings = {}  # story_id -> dict with words, timings and embeddings

for story_id, seq in raw.items():
    # Fields of the per-story sequence object.
    # NOTE(review): presumably a list of words plus word / TR time arrays -
    # confirm against the class that produced raw_text.pkl.
    words, word_times, tr_times = seq.data, seq.data_times, seq.tr_times

    print(f"Processing story: {story_id}, #words = {len(words)}")

    # One embedding row per word, from the fine-tuned encoder.
    embs = get_bert_word_embeddings_for_words(
        words,
        tokenizer,
        finetuned_bert,
        device=DEVICE,
        chunk_size=256,
    )

    print("  embeddings shape:", embs.shape)

    bert_embeddings[story_id] = dict(
        words=words,
        word_times=word_times,
        tr_times=tr_times,
        embeddings=embs,
    )

# Store all stories in a single pickle next to the saved model.
out_path = output_dir / "bert_lora_finetuned_word_embeddings.pkl"

with open(out_path, "wb") as f:
    pickle.dump(bert_embeddings, f)

## Visualizations

### PCA

In [None]:
# Reload the saved word embeddings so this cell can run on its own.
emb_path = Path(output_dir) / "bert_lora_finetuned_word_embeddings.pkl"

with open(emb_path, "rb") as f:
    bert_embeddings = pickle.load(f)

# NOTE(review): index 1 selects the SECOND story in sorted order even though
# the variable is called first_sid - confirm which story is intended.
first_sid = sorted(bert_embeddings.keys())[1]
story_data = bert_embeddings[first_sid]

embs = story_data["embeddings"]   # one row per word
words = story_data["words"]       # words aligned with the rows of embs
print(first_sid, embs.shape)

T = embs.shape[0]   # total number of word-level embeddings

# Keep at most 300 evenly spaced words so the plot stays readable.
if T <= 300:
    idx = np.arange(T)
else:
    idx = np.linspace(0, T-1, 300, dtype=int)

embs_ds = embs[idx]          # (min(T, 300), hidden_dim)
words_ds = [words[i] for i in idx]


from sklearn.decomposition import PCA

# Project the sampled embeddings onto their first two principal components.
pca = PCA(n_components=2)
embs_pca = pca.fit_transform(embs_ds)

plt.figure(figsize=(12, 8))
plt.scatter(embs_pca[:, 0], embs_pca[:, 1], s=12, alpha=0.7)

# Label first 80 words
num_labels = min(80, len(words_ds))
for (pc1, pc2), label in zip(embs_pca[:num_labels], words_ds):
    plt.text(
        pc1,
        pc2,
        label,
        fontsize=8,
        color="black",
        alpha=0.9
    )

# The dash is written as an escape: the literal character had been corrupted
# into mojibake by an encoding round-trip.
plt.title("PCA (downsampled to 300 words) \u2013 First 80 words labeled")
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.show()



In [None]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# These three names are not defined anywhere else in this notebook, so a
# fresh kernel would stop here with a NameError. Bind them to the story that
# the PCA cell above selected from the saved embeddings.
example_story_id = first_sid
example_words = bert_embeddings[example_story_id]["words"]
finetuned_embeddings = bert_embeddings[example_story_id]["embeddings"]

# Downsample words if the story is long
max_words_for_viz = 300
idxs = np.linspace(0, len(example_words) - 1,
                   num=min(len(example_words), max_words_for_viz)).astype(int)

viz_words = [example_words[i] for i in idxs]
viz_vecs = finetuned_embeddings[idxs]

# 2-D PCA projection of the sampled word vectors.
pca = PCA(n_components=2)
vecs_2d = pca.fit_transform(viz_vecs)


plt.figure(figsize=(10, 10))
# Only the first 80 sampled words are drawn and labelled to avoid clutter.
for (x, y), w in zip(vecs_2d, viz_words[:80]):
    plt.scatter(x, y, s=5)
    plt.text(x + 0.01, y + 0.01, w, fontsize=8)
plt.title(f"PCA with word labels (subset)\nStory {example_story_id}")
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.grid(True)
plt.show()


### t-SNE

In [None]:
from sklearn.manifold import TSNE

# t-SNE settings, collected in one place.
tsne_params = dict(
    n_components=2,
    perplexity=30,
    learning_rate=200,
    init='pca',        # start from the PCA layout instead of a random one
    random_state=42,   # fixed seed for a repeatable layout
    n_jobs=-1,         # use all available cores
)
tsne = TSNE(**tsne_params)

# Embed the same downsampled vectors that were used for the labelled PCA plot.
emb2d_tsne = tsne.fit_transform(viz_vecs)

plt.figure(figsize=(10,10))
# Draw and label only the first 80 words to avoid clutter.
for (x, y), w in zip(emb2d_tsne, viz_words[:80]):
    plt.scatter(x, y, s=5)
    plt.text(x + 0.01, y + 0.01, w, fontsize=8)
plt.title("t-SNE of BERT Embeddings")
plt.xlabel("Dim 1")
plt.ylabel("Dim 2")
plt.show()


## Pre-processing

In [None]:
from preprocessing import downsample_word_vectors, make_delayed

BASE_DIR = Path("/ocean/projects/mth250011p/shared/215a/final_project")
TEXT_PATH = BASE_DIR / "data" / "raw_text.pkl"

# Per-subject directories holding one BOLD .npy file per story.
BOLD_BASE = BASE_DIR / "data"
SUBJECT_DIRS = {
    2: BOLD_BASE / "subject2",
    3: BOLD_BASE / "subject3",
}

_EMB_FILENAME = "bert_lora_finetuned_word_embeddings.pkl"
BERT_EMB_PATH = Path(repo_root) / "embeddings" / "bert_lora_finetuned" / _EMB_FILENAME

# The embedding cell of this notebook writes the pickle under
# saved_models/bert_lora_finetuned, not under embeddings/. Use that copy when
# the expected file is absent, so the notebook runs end to end without a
# manual file move.
_saved_models_copy = Path(repo_root) / "saved_models" / "bert_lora_finetuned" / _EMB_FILENAME
if not BERT_EMB_PATH.is_file() and _saved_models_copy.is_file():
    print(f"[INFO] {BERT_EMB_PATH} not found, using {_saved_models_copy}")
    BERT_EMB_PATH = _saved_models_copy

# DataSequence (wordseqs)
with open(TEXT_PATH, "rb") as f:
    wordseqs = pickle.load(f)   # dict: story_id -> DataSequence

print("wordseqs stories:", list(wordseqs.keys())[:5])

# BERT embedding
with open(BERT_EMB_PATH, "rb") as f:
    bert_emb = pickle.load(f)   # dict: story_id -> {..., "embeddings": array}

print("bert_emb stories:", list(bert_emb.keys())[:5])

# Keep only the stories present in both sources, in a stable sorted order.
stories = sorted(set(wordseqs.keys()) & set(bert_emb.keys()))
print("num stories:", len(stories))
stories[:5]

### Downsample

In [None]:
# Per-story word embeddings as float32, keyed by story id, in the form
# expected by downsample_word_vectors.
word_vectors = {
    sid: bert_emb[sid]["embeddings"].astype("float32")
    for sid in stories
}

# Resample the word-level vectors with the project helper.
downsampled_semanticseqs = downsample_word_vectors(
    stories=stories,
    word_vectors=word_vectors,
    wordseqs=wordseqs,
)

### Trim and delay

In [None]:
# Output directory for the per-story design matrices. The first component
# must be a Path: dividing two plain strings with "/" raises a TypeError.
OUT_DIR = Path("/ocean/projects/mth250011p/smazioud") / "preprocessing" / "bert_lora_finetuned"
OUT_DIR.mkdir(parents=True, exist_ok=True)

def preprocess_subject_streaming(subject_id, delays=None):
    """
    Build and save the delayed stimulus matrix for every story of one subject.

    For each story in the module-level ``stories`` list, the downsampled
    stimulus is trimmed to the length of the subject's BOLD array, expanded
    with ``make_delayed`` and pickled together with the BOLD data to
    ``OUT_DIR / "subject{subject_id}_{sid}_Xdelayed.pkl"``. Stories are handled
    one at a time and their arrays released afterwards, so only one story is
    held in memory at once.

    Args:
        subject_id: Key into ``SUBJECT_DIRS`` selecting the BOLD directory.
        delays: Delays passed on to ``make_delayed``. Required; the ``None``
            default only exists to keep the signature unchanged.

    Returns:
        None. Results are written to disk and progress is printed.

    Raises:
        ValueError: If ``delays`` is None.
        AssertionError: If the subject directory is missing, or a story's
            stimulus length differs from its number of TR times.
    """
    # Fail before touching any file: every story needs the delays.
    if delays is None:
        raise ValueError("delays must be provided if only X_delayed is saved.")

    subj_dir = SUBJECT_DIRS[subject_id]
    assert subj_dir.is_dir(), f"{subj_dir} does not exist"

    missing_stories = []

    for sid in stories:
        tr_times = wordseqs[sid].tr_times
        stim_tr = downsampled_semanticseqs[sid]

        assert stim_tr.shape[0] == len(tr_times), (
            f"{sid}: stimulus has {stim_tr.shape[0]} rows "
            f"but there are {len(tr_times)} TR times"
        )

        bold_path = subj_dir / f"{sid}.npy"
        if not bold_path.is_file():
            print(f"[WARN] Subject {subject_id}: missing BOLD for story '{sid}', skipping.")
            missing_stories.append(sid)
            continue

        bold = np.load(bold_path)

        n_stim = stim_tr.shape[0]
        n_bold = bold.shape[0]

        if n_stim < n_bold:
            print(f"[WARN] {sid}: stim shorter than bold, skipping.")
            missing_stories.append(sid)
            continue

        # TR trimming to match BOLD length: one third of the surplus rows is
        # dropped from the start and the remainder from the end (diff >= 0).
        diff = n_stim - n_bold
        drop_start = diff // 3
        drop_end = diff - drop_start

        stim_trim = stim_tr[drop_start : n_stim - drop_end]

        # Defensive check; the arithmetic above should always give n_bold rows.
        if stim_trim.shape[0] != n_bold:
            print(f"[WARN] {sid}: mismatch after trim, skipping.")
            missing_stories.append(sid)
            continue

        X_delayed = make_delayed(stim_trim, delays=delays).astype("float32")
        bold = bold.astype("float32")

        result = {
            "X_delayed": X_delayed,   # delayed stimulus, one row per BOLD row
            "bold": bold,             # BOLD data as loaded, cast to float32
        }

        out_file = OUT_DIR / f"subject{subject_id}_{sid}_Xdelayed.pkl"
        with open(out_file, "wb") as f:
            pickle.dump(result, f)

        print(
            f"[SAVE] Subject {subject_id}, story {sid}: "
            f"X_delayed {X_delayed.shape}, bold {bold.shape}, saved"
        )

        # Release the large arrays before the next story is loaded.
        del bold, stim_trim, X_delayed, result

    if missing_stories:
        print(f"\n[INFO] Subject {subject_id} skipped stories:")
        for s in missing_stories:
            print("  -", s)
    else:
        print(f"\n[INFO] Subject {subject_id}: all stories processed.")


In [None]:
# Delays applied to the stimulus, shared by both subjects.
delays = [1, 2, 3, 4]

# Same preprocessing for each subject, one after the other.
for subject_id in (2, 3):
    preprocess_subject_streaming(subject_id, delays=delays)