In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset
class FinancialDatasetBERT(Dataset):
    def __init__(self, text_path, tokenizer_name, max_len=128):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_len = max_len
        self.data = self.load_data(text_path)

    def load_data(self, path):
        with open(path, 'r') as f:
            text = f.read().split('\n')
        return [line.strip() for line in text if line.strip() and not line.startswith('---')]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sentence = self.data[idx]
        encoding = self.tokenizer(sentence, padding='max_length', truncation=True, max_length=self.max_len, return_tensors="pt")
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # Create target ids shifted by one (language modeling)
        target_ids = input_ids.clone()
        target_ids[:-1] = input_ids[1:]
        target_ids[-1] = self.tokenizer.pad_token_id

        return input_ids, attention_mask, target_ids


In [2]:
# Collate function
def collate_fn_bert(batch):
    input_ids, attention_masks, target_ids = zip(*batch)
    return (
        torch.stack(input_ids),
        torch.stack(attention_masks),
        torch.stack(target_ids)
    )

In [4]:
# Model with unidirectional LSTM
class LSTMWithBERT(nn.Module):
    def __init__(self, bert_model_name, hidden_dim, vocab_size):
        super(LSTMWithBERT, self).__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        # Set bidirectional=False to use a unidirectional LSTM
        self.lstm = nn.LSTM(self.bert.config.hidden_size, hidden_dim, batch_first=True, bidirectional=False)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():  # Freeze BERT for efficiency (optional)
            bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        lstm_out, _ = self.lstm(bert_output.last_hidden_state)
        output = self.fc(lstm_out)
        return output


bert_model_name = "bert-base-uncased"
max_len = 128
batch_size = 8
hidden_dim = 512
num_epochs = 3
checkpoint_path = "bert_lstm_checkpoint.pth"
text_file_path = "/content/tokenized_output.txt"

# Dataset and DataLoader
dataset = FinancialDatasetBERT(text_file_path, bert_model_name, max_len)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_bert)

# Model setup
vocab_size = dataset.tokenizer.vocab_size
model = LSTMWithBERT(bert_model_name, hidden_dim, vocab_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=dataset.tokenizer.pad_token_id)

# Load checkpoint if exists
start_epoch = 0
if os.path.exists(checkpoint_path):
    print("Loading checkpoint...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming from epoch {start_epoch}")

# Training loop
for epoch in range(start_epoch, num_epochs):
    model.train()
    total_loss = 0

    for input_ids, attention_mask, target_ids in dataloader:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        target_ids = target_ids.to(device)

        optimizer.zero_grad()
        output = model(input_ids, attention_mask)
        loss = criterion(output.view(-1, vocab_size), target_ids.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f"📘 Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")

    # Save checkpoint
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
    }, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch+1}")


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

📘 Epoch 1/3, Loss: 9.548226
Checkpoint saved at epoch 1
📘 Epoch 2/3, Loss: 6.260621
Checkpoint saved at epoch 2
📘 Epoch 3/3, Loss: 3.707253
Checkpoint saved at epoch 3


In [5]:
def calculate_perplexity(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for input_ids, attention_mask, target_ids in dataloader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            target_ids = target_ids.to(device)

            output = model(input_ids, attention_mask)
            loss = criterion(output.view(-1, vocab_size), target_ids.view(-1))
            total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    perplexity = np.exp(avg_loss)
    return perplexity


In [7]:
def generate_text(model, tokenizer, prompt, max_new_tokens=20):
    model.eval()
    input_encoding = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
    input_ids = input_encoding["input_ids"].to(device)
    attention_mask = input_encoding["attention_mask"].to(device)

    generated_ids = input_ids.clone()

    with torch.no_grad():
        for _ in range(max_new_tokens):
            output = model(generated_ids, attention_mask)
            next_token_logits = output[:, -1, :]  # Get the last time step
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
            generated_ids = torch.cat((generated_ids, next_token_id), dim=1)

            # Update attention mask
            attention_mask = torch.ones_like(generated_ids).to(device)

            # Stop if [SEP] or [PAD] generated
            if next_token_id.item() in [tokenizer.sep_token_id, tokenizer.pad_token_id]:
                break

    generated_text = tokenizer.decode(generated_ids.squeeze(), skip_special_tokens=True)
    return generated_text

# 🔍 Run generation on the prompt
prompt = "what are the global trends in finance"
output_text = generate_text(model, dataset.tokenizer, prompt)
print(f"Prompt: {prompt}\n Model Output: {output_text}")

Prompt: what are the global trends in finance
 Model Output: what are the global trends in finance
