In [14]:
import sys
sys.path.append('..') 
from ridge_utils.DataSequence import DataSequence
import torch
import pickle
import os
from transformers import BertTokenizerFast
from torch.utils.data import DataLoader
from encoder import Encoder
from data import TextDataset
from train_encoder import train_bert

In [15]:
# Load the podcast text data
data_path = "/ocean/projects/mth240012p/shared/data"
raw_text_path = os.path.join(data_path, "raw_text.pkl")

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# this is acceptable only because raw_text.pkl is a trusted, shared project file.
with open(raw_text_path, "rb") as f:
    raw_text = pickle.load(f)  # mapping of story name -> DataSequence

print(f"Loaded {len(raw_text)} podcast stories")

# Build the training texts for masked language modeling.
# Each story's `.data` is a flat list of individual words. Feeding those words
# to TextDataset one at a time would make every training example a single
# word, leaving the MLM objective no surrounding context to learn from.
# Instead, consecutive words from the SAME story (windows never cross a story
# boundary) are grouped into fixed-size windows and joined into one string.
# NOTE(review): 20 words is chosen to fit roughly within the max_len=32
# wordpiece budget used below; TextDataset is assumed to truncate longer
# inputs -- confirm against data.TextDataset.
words_per_chunk = 20
all_sentences = []
for story in raw_text.values():
    words = list(story.data)
    for start in range(0, len(words), words_per_chunk):
        window = words[start:start + words_per_chunk]
        all_sentences.append(" ".join(str(w) for w in window))

# Tokenizer, dataset and batching setup.
MAX_SEQ_LEN = 32  # max_len handed to TextDataset
BATCH_SIZE = 32   # examples per DataLoader batch

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
dataset = TextDataset(all_sentences, tokenizer, max_len=MAX_SEQ_LEN)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Pick the compute device: use the GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The encoder's embedding vocabulary is sized to match the BERT tokenizer;
# .to() moves its parameters onto the chosen device and returns the module.
encoder = Encoder(vocab_size=tokenizer.vocab_size).to(device)

# Pretrain the encoder with the masked-language-modeling routine in train_encoder.
# Tuning knobs live in one dict so they can be adjusted in a single place.
mlm_config = {
    "epochs": 20,  # number of training epochs passed to train_bert
    "lr": 5e-4,    # learning rate passed to train_bert
}
train_bert(
    model=encoder,
    dataloader=dataloader,
    tokenizer=tokenizer,
    device=device,
    **mlm_config,
)

# Persist only the learned parameters (the state_dict), written to a path
# relative to the current working directory.
torch.save(encoder.state_dict(), "pretrained_encoder.pt")
print("Pretrained encoder saved as pretrained_encoder.pt")

Loaded 109 podcast stories
Pretrained encoder saved as pretrained_encoder.pt
