In [1]:
from datasets import load_dataset
from glob import glob
from random import shuffle

train_files = glob("data/train/docs_*.jsonl")
shuffle(train_files)

test_files = glob("data/test/docs_*.jsonl")

data_files = {
    "train": train_files,
    "test": test_files
}

dataset = load_dataset("json", data_files=data_files, streaming=True)
train_dataset = dataset["train"]
test_dataset = dataset["test"]

In [2]:
from utils.streaming_dataset import StreamingTokenDataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
# from transformers import AutoModel
# model = AutoModel.from_pretrained("allegro/herbert-base-cased")

tokenizer = AutoTokenizer.from_pretrained("allegro/herbert-base-cased")

CONTEXT_LENGTH = 128
BATCH_SIZE = 64

train_loader = DataLoader(StreamingTokenDataset(train_dataset, tokenizer, context_size=CONTEXT_LENGTH), batch_size=BATCH_SIZE)
test_loader = DataLoader(StreamingTokenDataset(test_dataset, tokenizer, context_size=CONTEXT_LENGTH), batch_size=BATCH_SIZE)


In [3]:
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


vocab_size = 50_000   # number of tokens
embed_dim = 768     # embedding dimension
hidden_dim = 768     # LSTM hidden size
num_layers = 2


class SimpleLSTM(nn.Module):
    def __init__(
            self, 
            vocab_size: int, 
            embed_dim: int, 
            hidden_dim: int, 
            num_layers: int,
            embedding: Optional[nn.Embedding] = None
        ):
        super().__init__()

        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.fc.weight = self.embed.weight

    def forward(self, x, hidden=None):
        x = self.embed(x)            # [batch, seq_len, embed_dim]
        out, hidden = self.lstm(x, hidden)  # [batch, seq_len, hidden_dim]
        logits = self.fc(out)        # [batch, seq_len, vocab_size]
        return logits, hidden

In [4]:
lstm = SimpleLSTM(vocab_size, embed_dim, hidden_dim, num_layers)
param_count = sum(p.numel() for p in lstm.parameters() if p.requires_grad)
print(f"LSTM has {param_count} trainable params")

LSTM has 47899472 trainable params


In [5]:
import torch

def choose_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"

In [6]:
train_batch_count = 0
for _ in train_loader:
    train_batch_count += 1

test_batch_count = 0
for _ in test_loader:
    test_batch_count += 1

Token indices sequence length is longer than the specified maximum sequence length for this model (8868 > 512). Running this sequence through the model will result in indexing errors


In [7]:
train_batch_count, test_batch_count

(1419, 28)

In [11]:
import torch.nn as nn
from tqdm import tqdm

epochs = 10
learning_rate = 1e-3
weight_decay = 1e-2
grad_clip = 1.0
device = torch.device(choose_device())

print(f"Training on device: {device}")

lstm.to(device)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.AdamW(lstm.parameters(), lr=learning_rate, weight_decay=weight_decay)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)

for epoch in range(epochs):

    lstm.train()
    total_loss = 0.0

    progress = tqdm(enumerate(train_loader), total=train_batch_count, desc=f"Epoch {epoch + 1}/{epochs}")

    for i, (batch_x, batch_y) in progress:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        out, _ = lstm(batch_x)

        # Flatten for CrossEntropyLoss
        loss = criterion(out.view(-1, out.size(-1)), batch_y.view(-1))
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(lstm.parameters(), grad_clip)

        optimizer.step()

        total_loss += loss.item()
        avg_loss = total_loss / (i + 1)

        progress.set_postfix({"loss": f"{avg_loss:.4f}", "lr": optimizer.param_groups[0]["lr"]})

        break

    scheduler.step()

    if epoch < epochs - 1:
        torch.save(lstm.state_dict(), f"lstm_epoch_{epoch}.pt")


    print(f"Epoch {epoch} done | Average training loss: {avg_loss:.4f}")
    print(f"Perplexity on training data: {torch.math.exp(avg_loss)}\n")

    with torch.no_grad():
        progress = tqdm(enumerate(test_loader), total=test_batch_count, desc=f"Epoch {epoch + 1}/{epochs}")

        for i, (batch_x, batch_y) in progress:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            out, _ = lstm(batch_x)
            loss = criterion(out.view(-1, out.size(-1)), batch_y.view(-1))

            total_loss += loss.item()
            avg_loss = total_loss / (i + 1)

    print(f"Average loss on held-out_dataset: {avg_loss:.4f}")
    print(f"Perplexity on held-out data: {torch.math.exp(avg_loss)}\n")

torch.save(lstm.state_dict(), "lstm_final.pt")
print("Training complete. Model saved to lstm_final.pt")


Training on device: mps


Epoch 1/10:   0%|          | 0/1419 [00:00<?, ?it/s, loss=5.9067, lr=0.001]


Epoch 0 done | Average training loss: 5.9067
Perplexity on training data: 367.5097045863052



Epoch 1/10: 100%|██████████| 28/28 [00:04<00:00,  5.82it/s]


Average loss on held-out_dataset: 11.3972
Perplexity on held-out data: 89074.39988224805



Epoch 2/10:   0%|          | 0/1419 [00:00<?, ?it/s, loss=5.4826, lr=0.000976]


Epoch 1 done | Average training loss: 5.4826
Perplexity on training data: 240.47002454843087



Epoch 2/10: 100%|██████████| 28/28 [00:04<00:00,  5.86it/s]


Average loss on held-out_dataset: 11.1095
Perplexity on held-out data: 66800.54730864591



Epoch 3/10:   0%|          | 0/1419 [00:00<?, ?it/s, loss=4.8194, lr=0.000905]


KeyboardInterrupt: 

In [12]:
def generate_text(model, tokenizer, prompt, max_new_tokens=20, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model.to(device)
    model.eval()

    # Encode prompt
    tokens = tokenizer.encode(prompt, add_special_tokens=False)
    input_ids = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)

    generated_tokens = tokens.copy()

    with torch.no_grad():
        hidden = None
        for _ in range(max_new_tokens):
            out, hidden = model(input_ids, hidden)
            last_logits = out[0, -1, :]  # last token
            probs = torch.softmax(last_logits, dim=-1)
            predicted_id = torch.argmax(probs).item()

            # Append predicted token
            generated_tokens.append(predicted_id)
            
            # Prepare next input
            input_ids = torch.tensor([[predicted_id]], dtype=torch.long).to(device)
    
    # Decode full sequence
    text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return text


In [14]:
# generate_text(lstm, tokenizer)