In [1]:
from datasets import load_dataset
from glob import glob
from random import shuffle

train_files = glob("data/train/docs_*.jsonl")
shuffle(train_files)

test_files = glob("data/test/docs_*.jsonl")

data_files = {
    "train": train_files,
    "test": test_files
}

dataset = load_dataset("json", data_files=data_files, streaming=True)
train_dataset = dataset["train"]
test_dataset = dataset["test"]

In [2]:
from utils.streaming_dataset import StreamingTokenDataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
# from transformers import AutoModel
# model = AutoModel.from_pretrained("allegro/herbert-base-cased")

tokenizer = AutoTokenizer.from_pretrained("allegro/herbert-base-cased")

CONTEXT_LENGTH = 128
BATCH_SIZE = 32

train_loader = DataLoader(StreamingTokenDataset(train_dataset, tokenizer, context_size=CONTEXT_LENGTH), batch_size=BATCH_SIZE)
test_loader = DataLoader(StreamingTokenDataset(test_dataset, tokenizer, context_size=CONTEXT_LENGTH), batch_size=BATCH_SIZE)


In [3]:
from architectures.lstm import SimpleLSTM

vocab_size = 50_000   # number of tokens
embed_dim = 384     # embedding dimension
hidden_dim = 384     # LSTM hidden size
num_layers = 2

lstm = SimpleLSTM(vocab_size, embed_dim, hidden_dim, num_layers)
param_count = sum(p.numel() for p in lstm.parameters() if p.requires_grad)
print(f"LSTM has {param_count} trainable params")

LSTM has 21615440 trainable params


In [4]:
import torch

def choose_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"

In [6]:
train_batch_count = 0
for _ in train_loader:
    train_batch_count += 1

test_batch_count = 0
for _ in test_loader:
    test_batch_count += 1

Token indices sequence length is longer than the specified maximum sequence length for this model (2746 > 512). Running this sequence through the model will result in indexing errors


In [7]:
train_batch_count, test_batch_count

(5421, 56)

In [8]:
import torch.nn as nn
from tqdm import tqdm

epochs = 25
learning_rate = 1e-3
weight_decay = 1e-2
grad_clip = 1.0
device = torch.device(choose_device())

print(f"Training on device: {device}")

lstm.to(device)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.AdamW(lstm.parameters(), lr=learning_rate, weight_decay=weight_decay)

for epoch in range(epochs):

    lstm.train()
    total_loss = 0.0

    progress = tqdm(enumerate(train_loader), total=train_batch_count, desc=f"Epoch {epoch + 1}/{epochs}")

    for i, (batch_x, batch_y) in progress:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        out, _ = lstm(batch_x)

        # Flatten for CrossEntropyLoss
        loss = criterion(out.view(-1, out.size(-1)), batch_y.view(-1))
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(lstm.parameters(), grad_clip)

        optimizer.step()

        total_loss += loss.item()
        avg_loss = total_loss / (i + 1)

        progress.set_postfix({"loss": f"{avg_loss:.4f}", "lr": optimizer.param_groups[0]["lr"]})

    if epoch % 3 == 2:
        torch.save(lstm.state_dict(), f"lstm_epoch_{epoch + 1}.pt")


    print(f"Epoch {epoch} done | Average training loss: {avg_loss:.4f}")
    print(f"Perplexity on training data: {torch.math.exp(avg_loss)}\n")

    with torch.no_grad():
        progress = tqdm(enumerate(test_loader), total=test_batch_count, desc=f"Epoch {epoch + 1}/{epochs}")

        for i, (batch_x, batch_y) in progress:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            out, _ = lstm(batch_x)
            loss = criterion(out.view(-1, out.size(-1)), batch_y.view(-1))

            total_loss += loss.item()
            avg_loss = total_loss / (i + 1)

    print(f"Average loss on held-out_dataset: {avg_loss:.4f}")
    print(f"Perplexity on held-out data: {torch.math.exp(avg_loss)}\n")

torch.save(lstm.state_dict(), "lstm_final.pt")
print("Training complete. Model saved to lstm_final.pt")


Training on device: cuda


Epoch 1/25:   0%|          | 0/5421 [00:00<?, ?it/s]

Epoch 1/25: 100%|██████████| 5421/5421 [09:36<00:00,  9.41it/s, loss=6.5358, lr=0.001]


Epoch 0 done | Average training loss: 6.5358
Perplexity on training data: 689.35808301941



Epoch 1/25: 100%|██████████| 56/56 [00:02<00:00, 22.89it/s]


Average loss on held-out_dataset: 639.0138
Perplexity on held-out data: 3.312678996727248e+277



Epoch 2/25: 100%|██████████| 5421/5421 [09:35<00:00,  9.42it/s, loss=5.4830, lr=0.001]


Epoch 1 done | Average training loss: 5.4830
Perplexity on training data: 240.56699836636878



Epoch 2/25: 100%|██████████| 56/56 [00:02<00:00, 22.81it/s]


Average loss on held-out_dataset: 536.6581
Perplexity on held-out data: 1.1685187795037949e+233



Epoch 3/25: 100%|██████████| 5421/5421 [09:35<00:00,  9.42it/s, loss=5.1249, lr=0.001]


Epoch 2 done | Average training loss: 5.1249
Perplexity on training data: 168.15665420322944



Epoch 3/25: 100%|██████████| 56/56 [00:02<00:00, 22.66it/s]


Average loss on held-out_dataset: 501.7632
Perplexity on held-out data: 8.184164273942625e+217



Epoch 4/25: 100%|██████████| 5421/5421 [09:35<00:00,  9.42it/s, loss=4.9082, lr=0.001]


Epoch 3 done | Average training loss: 4.9082
Perplexity on training data: 135.3907686075897



Epoch 4/25: 100%|██████████| 56/56 [00:02<00:00, 22.79it/s]


Average loss on held-out_dataset: 480.6393
Perplexity on held-out data: 5.482558244748784e+208



Epoch 5/25: 100%|██████████| 5421/5421 [09:35<00:00,  9.42it/s, loss=4.7547, lr=0.001]


Epoch 4 done | Average training loss: 4.7547
Perplexity on training data: 116.12593737672567



Epoch 5/25: 100%|██████████| 56/56 [00:02<00:00, 22.57it/s]


Average loss on held-out_dataset: 465.6829
Perplexity on held-out data: 1.7519019723817946e+202



Epoch 6/25: 100%|██████████| 5421/5421 [09:35<00:00,  9.42it/s, loss=4.6362, lr=0.001]


Epoch 5 done | Average training loss: 4.6362
Perplexity on training data: 103.15481524176217



Epoch 6/25: 100%|██████████| 56/56 [00:02<00:00, 22.48it/s]


Average loss on held-out_dataset: 454.1432
Perplexity on held-out data: 1.7057057527369424e+197



Epoch 7/25: 100%|██████████| 5421/5421 [09:35<00:00,  9.42it/s, loss=4.5400, lr=0.001]


Epoch 6 done | Average training loss: 4.5400
Perplexity on training data: 93.68912053177239



Epoch 7/25: 100%|██████████| 56/56 [00:02<00:00, 22.74it/s]


Average loss on held-out_dataset: 444.7696
Perplexity on held-out data: 1.448776222724714e+193



Epoch 8/25: 100%|██████████| 5421/5421 [09:35<00:00,  9.42it/s, loss=4.4589, lr=0.001]


Epoch 7 done | Average training loss: 4.4589
Perplexity on training data: 86.39520085897892



Epoch 8/25: 100%|██████████| 56/56 [00:02<00:00, 22.81it/s]


Average loss on held-out_dataset: 436.8799
Perplexity on held-out data: 5.42655218030364e+189



Epoch 9/25: 100%|██████████| 5421/5421 [09:35<00:00,  9.42it/s, loss=4.3896, lr=0.001]


Epoch 8 done | Average training loss: 4.3896
Perplexity on training data: 80.60529408604664



Epoch 9/25: 100%|██████████| 56/56 [00:02<00:00, 22.64it/s]


Average loss on held-out_dataset: 430.1303
Perplexity on held-out data: 6.356591424791664e+186



Epoch 10/25: 100%|██████████| 5421/5421 [09:37<00:00,  9.38it/s, loss=4.3292, lr=0.001]


Epoch 9 done | Average training loss: 4.3292
Perplexity on training data: 75.88063883492683



Epoch 10/25: 100%|██████████| 56/56 [00:02<00:00, 22.69it/s]


Average loss on held-out_dataset: 424.2555
Perplexity on held-out data: 1.7857591761691283e+184



Epoch 11/25: 100%|██████████| 5421/5421 [09:36<00:00,  9.40it/s, loss=4.2760, lr=0.001]


Epoch 10 done | Average training loss: 4.2760
Perplexity on training data: 71.9534705577334



Epoch 11/25: 100%|██████████| 56/56 [00:02<00:00, 22.61it/s]


Average loss on held-out_dataset: 419.0898
Perplexity on held-out data: 1.0195382730622643e+182



Epoch 12/25: 100%|██████████| 5421/5421 [09:37<00:00,  9.39it/s, loss=4.2288, lr=0.001]


Epoch 11 done | Average training loss: 4.2288
Perplexity on training data: 68.63328274034514



Epoch 12/25: 100%|██████████| 56/56 [00:02<00:00, 22.60it/s]


Average loss on held-out_dataset: 414.4993
Perplexity on held-out data: 1.0346064854574707e+180



Epoch 13/25: 100%|██████████| 5421/5421 [09:36<00:00,  9.40it/s, loss=4.1866, lr=0.001]


Epoch 12 done | Average training loss: 4.1866
Perplexity on training data: 65.79994427978521



Epoch 13/25: 100%|██████████| 56/56 [00:02<00:00, 22.70it/s]


Average loss on held-out_dataset: 410.4052
Perplexity on held-out data: 1.7246448842347808e+178



Epoch 14/25: 100%|██████████| 5421/5421 [09:36<00:00,  9.40it/s, loss=4.1488, lr=0.001]


Epoch 13 done | Average training loss: 4.1488
Perplexity on training data: 63.3608088255846



Epoch 14/25: 100%|██████████| 56/56 [00:02<00:00, 22.49it/s]


Average loss on held-out_dataset: 406.7383
Perplexity on held-out data: 4.4077735254442326e+176



Epoch 15/25: 100%|██████████| 5421/5421 [09:36<00:00,  9.40it/s, loss=4.1148, lr=0.001]


Epoch 14 done | Average training loss: 4.1148
Perplexity on training data: 61.24102798818527



Epoch 15/25: 100%|██████████| 56/56 [00:02<00:00, 22.56it/s]


Average loss on held-out_dataset: 403.4368
Perplexity on held-out data: 1.623251458137053e+175



Epoch 16/25: 100%|██████████| 5421/5421 [09:38<00:00,  9.38it/s, loss=4.0842, lr=0.001]


Epoch 15 done | Average training loss: 4.0842
Perplexity on training data: 59.396163729442925



Epoch 16/25: 100%|██████████| 56/56 [00:02<00:00, 22.71it/s]


Average loss on held-out_dataset: 400.4711
Perplexity on held-out data: 8.36321606249902e+173



Epoch 17/25: 100%|██████████| 5421/5421 [09:37<00:00,  9.38it/s, loss=4.0566, lr=0.001]


Epoch 16 done | Average training loss: 4.0566
Perplexity on training data: 57.77901624106715



Epoch 17/25: 100%|██████████| 56/56 [00:02<00:00, 22.69it/s]


Average loss on held-out_dataset: 397.7946
Perplexity on held-out data: 5.754250015408218e+172



Epoch 18/25: 100%|██████████| 5421/5421 [09:37<00:00,  9.39it/s, loss=4.0316, lr=0.001]


Epoch 17 done | Average training loss: 4.0316
Perplexity on training data: 56.352226582443514



Epoch 18/25: 100%|██████████| 56/56 [00:02<00:00, 22.67it/s]


Average loss on held-out_dataset: 395.3723
Perplexity on held-out data: 5.105306600436695e+171



Epoch 19/25: 100%|██████████| 5421/5421 [09:38<00:00,  9.37it/s, loss=4.0091, lr=0.001]


Epoch 18 done | Average training loss: 4.0091
Perplexity on training data: 55.09844279750342



Epoch 19/25: 100%|██████████| 56/56 [00:02<00:00, 22.65it/s]


Average loss on held-out_dataset: 393.1932
Perplexity on held-out data: 5.776156450325911e+170



Epoch 20/25: 100%|██████████| 5421/5421 [09:37<00:00,  9.39it/s, loss=3.9888, lr=0.001]


Epoch 19 done | Average training loss: 3.9888
Perplexity on training data: 53.98834914671257



Epoch 20/25: 100%|██████████| 56/56 [00:02<00:00, 22.69it/s]


Average loss on held-out_dataset: 391.2217
Perplexity on held-out data: 8.043530590894051e+169



Epoch 21/25: 100%|██████████| 5421/5421 [09:36<00:00,  9.40it/s, loss=3.9703, lr=0.001]


Epoch 20 done | Average training loss: 3.9703
Perplexity on training data: 53.00219571332485



Epoch 21/25: 100%|██████████| 56/56 [00:02<00:00, 22.57it/s]


Average loss on held-out_dataset: 389.4374
Perplexity on held-out data: 1.350584374960451e+169



Epoch 22/25: 100%|██████████| 5421/5421 [09:36<00:00,  9.40it/s, loss=3.9537, lr=0.001]


Epoch 21 done | Average training loss: 3.9537
Perplexity on training data: 52.127255631397986



Epoch 22/25: 100%|██████████| 56/56 [00:02<00:00, 22.58it/s]


Average loss on held-out_dataset: 387.8247
Perplexity on held-out data: 2.6923828694951028e+168



Epoch 23/25: 100%|██████████| 5421/5421 [09:37<00:00,  9.38it/s, loss=3.9384, lr=0.001]


Epoch 22 done | Average training loss: 3.9384
Perplexity on training data: 51.3387978495124



Epoch 23/25: 100%|██████████| 56/56 [00:02<00:00, 22.45it/s]


Average loss on held-out_dataset: 386.3497
Perplexity on held-out data: 6.159404776775216e+167



Epoch 24/25: 100%|██████████| 5421/5421 [09:39<00:00,  9.35it/s, loss=3.9247, lr=0.001]


Epoch 23 done | Average training loss: 3.9247
Perplexity on training data: 50.639348504856756



Epoch 24/25: 100%|██████████| 56/56 [00:02<00:00, 22.04it/s]


Average loss on held-out_dataset: 385.0223
Perplexity on held-out data: 1.6333364711154528e+167



Epoch 25/25: 100%|██████████| 5421/5421 [09:39<00:00,  9.35it/s, loss=3.9120, lr=0.001]


Epoch 24 done | Average training loss: 3.9120
Perplexity on training data: 49.99686635682195



Epoch 25/25: 100%|██████████| 56/56 [00:02<00:00, 22.36it/s]


Average loss on held-out_dataset: 383.7878
Perplexity on held-out data: 4.7525825485872903e+166

Training complete. Model saved to lstm_final.pt
