# Part1: Fine-tuning

In [2]:
import sys
import numpy as np
sys.path.append('..') 
from ridge_utils.DataSequence import DataSequence
import torch
import pickle
import os
from transformers import BertTokenizerFast, BertModel
from torch.utils.data import DataLoader
from encoder import Encoder
from data import TextDataset
from train_encoder import train_bert
from data import *
from encoder import *
from preprocessing import *
from train_encoder import *

In [3]:
# === Step 1: Load raw tokenized text ===

data_path = "/ocean/projects/mth240012p/shared/data"
raw_text_path = os.path.join(data_path, "raw_text.pkl")
with open(raw_text_path, "rb") as f:
    raw_text = pickle.load(f)

# 提取故事文本
story_names = []
story_texts = []
for name, ds in raw_text.items():
    if isinstance(ds, DataSequence) and hasattr(ds, "data"):
        story_names.append(name)
        story_texts.append(" ".join(ds.data))  # Concatenate token list

print(f"Loaded {len(story_names)} story texts.")
print("Example:", story_names[0])
print(story_texts[0][:200], "...")

Loaded 109 story texts.
Example: sweetaspie
 i embarked on a journey toward the sea of matrimony at the perilous age of forty one yeah you'd think forty one a trip to marriage would be pretty smooth but nobody had told my family my sister calle ...


In [4]:
# === Step 2: Extract BERT CLS embeddings ===
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained('bert-base-uncased')
bert.eval()

max_length = 512
all_token_embeddings = {}
valid_names = []

with torch.no_grad():
    for name, text in zip(story_names, story_texts):
        try:
            words = text.split()  
            chunks = [words[i:i+max_length] for i in range(0, len(words), max_length)]

            story_embeds = []
            for chunk in chunks:
                inputs = tokenizer(
                    chunk,
                    is_split_into_words=True,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=max_length
                )
                outputs = bert(**inputs)
                emb = outputs.last_hidden_state.squeeze(0).cpu().numpy()  # (512, 768)

                actual_len = len(chunk)
                story_embeds.append(emb[:actual_len])

            full_emb = np.vstack(story_embeds)
            if full_emb.ndim == 2:
                tr_time_length = len(wordseqs[name].tr_times)
                if full_emb.shape[0] > tr_time_length:
                    full_emb = full_emb[:tr_time_length, :]
                elif full_emb.shape[0] < tr_time_length:
                    padding = np.zeros((tr_time_length - full_emb.shape[0], full_emb.shape[1]))
                    full_emb = np.vstack([full_emb, padding])

                all_token_embeddings[name] = full_emb
                valid_names.append(name)
        except Exception as e:
            print(f"Skipping {name} due to error: {e}")

Skipping sweetaspie due to error: name 'wordseqs' is not defined
Skipping thatthingonmyarm due to error: name 'wordseqs' is not defined
Skipping tildeath due to error: name 'wordseqs' is not defined
Skipping indianapolis due to error: name 'wordseqs' is not defined
Skipping lawsthatchokecreativity due to error: name 'wordseqs' is not defined
Skipping golfclubbing due to error: name 'wordseqs' is not defined
Skipping jugglingandjesus due to error: name 'wordseqs' is not defined
Skipping shoppinginchina due to error: name 'wordseqs' is not defined
Skipping cocoonoflove due to error: name 'wordseqs' is not defined
Skipping hangtime due to error: name 'wordseqs' is not defined
Skipping beneaththemushroomcloud due to error: name 'wordseqs' is not defined
Skipping dialogue4 due to error: name 'wordseqs' is not defined
Skipping thepostmanalwayscalls due to error: name 'wordseqs' is not defined
Skipping stumblinginthedark due to error: name 'wordseqs' is not defined
Skipping kiksuya due to err

In [5]:
# === Step 3: Get wordseqs from raw_text.pkl ===
wordseqs = {
    name: ds for name, ds in raw_text.items()
    if isinstance(ds, DataSequence) and name in all_token_embeddings
}
print(f"Retrieved {len(wordseqs)} DataSequence objects matching BERT stories.")

Retrieved 0 DataSequence objects matching BERT stories.


In [6]:
# === Step 4: Downsample embeddings to TR-aligned signals ===
from preprocessing import downsample_word_vectors, make_delayed
    
X_ds = downsample_word_vectors(valid_names, all_token_embeddings, wordseqs)

# Create lagged features for time delay
X_lagged = {
    s: make_delayed(X_ds[s], delays=[1, 2, 3, 4])
    for s in valid_names if s in X_ds
}

print(f"Downsampled and delayed embeddings for {len(X_lagged)} stories.")

Downsampled and delayed embeddings for 0 stories.
