In [1]:
from datasets import load_dataset
from glob import glob
from random import shuffle

train_files = glob("data/train/docs_*.jsonl")
shuffle(train_files)

test_files = glob("data/test/docs_*.jsonl")

data_files = {
    "train": train_files,
    "test": test_files
}

dataset = load_dataset("json", data_files=data_files, streaming=True)
train_dataset = dataset["train"]
test_dataset = dataset["test"]

In [2]:
from utils.streaming_dataset import StreamingTokenDataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

# from transformers import AutoModel
# model = AutoModel.from_pretrained("allegro/herbert-base-cased")

tokenizer = AutoTokenizer.from_pretrained("allegro/herbert-base-cased")

CONTEXT_LENGTH = 128
BATCH_SIZE = 32
train_loader = DataLoader(StreamingTokenDataset(train_dataset, tokenizer, context_size=CONTEXT_LENGTH), batch_size=BATCH_SIZE)
test_loader = DataLoader(StreamingTokenDataset(test_dataset, tokenizer, context_size=CONTEXT_LENGTH), batch_size=BATCH_SIZE)

In [3]:
from architectures.gpt import GPTDecoder

vocab_size = 50_000
embed_dim = 256
num_heads = 8
ff_hidden_dim = 2048
num_layers = 6
context_length = 128
dropout = 0.1

gpt = GPTDecoder(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_heads=num_heads,
    ff_hidden_dim=ff_hidden_dim,
    num_layers=num_layers,
    context_length=context_length,
    dropout=dropout
)

In [4]:
sum(p.numel() for p in gpt.parameters())

20690944

In [10]:
train_batch_count = 0
for _ in train_loader:
    train_batch_count += 1

test_batch_count = 0
for _ in test_loader:
    test_batch_count += 1

Token indices sequence length is longer than the specified maximum sequence length for this model (965 > 512). Running this sequence through the model will result in indexing errors


In [6]:
import torch

def choose_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"

In [None]:
import torch.nn as nn
from tqdm import tqdm

epochs = 25
learning_rate = 1e-3
weight_decay = 1e-2
grad_clip = 1.0
device = torch.device(choose_device())

print(f"Training on device: {device}")

gpt.to(device)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.AdamW(gpt.parameters(), lr=learning_rate, weight_decay=weight_decay)

# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)

for epoch in range(1, epochs + 1):
    gpt.train()
    total_loss = 0.0

    progress = tqdm(enumerate(train_loader), total=train_batch_count, desc=f"Epoch {epoch}/{epochs}")

    for i, (batch_x, batch_y) in progress:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        out = gpt(batch_x)

        # Flatten for CrossEntropyLoss
        loss = criterion(out.view(-1, out.size(-1)), batch_y.view(-1))
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(gpt.parameters(), grad_clip)

        optimizer.step()

        total_loss += loss.item()
        avg_loss = total_loss / (i + 1)

        progress.set_postfix({"loss": f"{avg_loss:.4f}", "lr": optimizer.param_groups[0]["lr"]})

    # scheduler.step()

    if epoch % 3 == 2:
        torch.save(gpt.state_dict(), f"gpt_epoch_{epoch}.pt")

    print(f"Epoch {epoch} done | Average training loss: {avg_loss:.4f}")
    print(f"Perplexity on training data: {torch.math.exp(avg_loss)}\n")

    with torch.no_grad():
        progress = tqdm(enumerate(test_loader), total=test_batch_count, desc=f"Epoch {epoch + 1}/{epochs}")

        for i, (batch_x, batch_y) in progress:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            out = gpt(batch_x)
            loss = criterion(out.view(-1, out.size(-1)), batch_y.view(-1))

            total_loss += loss.item()
            avg_loss = total_loss / (i + 1)

    try:
        print(f"Average loss on held-out_dataset: {avg_loss:.4f}")
        print(f"Perplexity on held-out data: {torch.math.exp(avg_loss)}\n")
    except OverflowError:
        pass

torch.save(gpt.state_dict(), "gpt_final.pt")
print("Training complete. Model saved to gpt_final.pt")


Training on device: cuda


Epoch 1/25: 100%|██████████| 5421/5421 [09:56<00:00,  9.09it/s, loss=6.1191, lr=0.001]


Epoch 1 done | Average training loss: 6.1191
Perplexity on training data: 454.47427380295625



Epoch 2/25: 100%|██████████| 56/56 [00:02<00:00, 22.61it/s]


Average loss on held-out_dataset: 598.7742
Perplexity on held-out data: 1.107500732691838e+260



Epoch 2/25: 100%|██████████| 5421/5421 [09:57<00:00,  9.07it/s, loss=5.6236, lr=0.001]


Epoch 2 done | Average training loss: 5.6236
Perplexity on training data: 276.8910556676884



Epoch 3/25: 100%|██████████| 56/56 [00:02<00:00, 22.51it/s]


Average loss on held-out_dataset: 550.4529
Perplexity on held-out data: 1.144648415516723e+239



Epoch 3/25: 100%|██████████| 5421/5421 [09:57<00:00,  9.07it/s, loss=5.3035, lr=0.001]


Epoch 3 done | Average training loss: 5.3035
Perplexity on training data: 201.0327705610222



Epoch 4/25: 100%|██████████| 56/56 [00:02<00:00, 22.70it/s]


Average loss on held-out_dataset: 519.2071
Perplexity on held-out data: 3.0816780409419756e+225



Epoch 4/25: 100%|██████████| 5421/5421 [09:57<00:00,  9.08it/s, loss=5.0703, lr=0.001]


Epoch 4 done | Average training loss: 5.0703
Perplexity on training data: 159.2276981164478



Epoch 5/25: 100%|██████████| 56/56 [00:02<00:00, 22.14it/s]


Average loss on held-out_dataset: 496.4597
Perplexity on held-out data: 4.071223061931786e+215



Epoch 5/25: 100%|██████████| 5421/5421 [09:57<00:00,  9.07it/s, loss=4.8906, lr=0.001]


Epoch 5 done | Average training loss: 4.8906
Perplexity on training data: 133.03109184512587



Epoch 6/25: 100%|██████████| 56/56 [00:02<00:00, 22.32it/s]


Average loss on held-out_dataset: 478.9145
Perplexity on held-out data: 9.770871489726995e+207



Epoch 6/25: 100%|██████████| 5421/5421 [09:58<00:00,  9.06it/s, loss=4.7480, lr=0.001]


Epoch 6 done | Average training loss: 4.7480
Perplexity on training data: 115.35831623115871



Epoch 7/25: 100%|██████████| 56/56 [00:02<00:00, 22.56it/s]


Average loss on held-out_dataset: 465.0011
Perplexity on held-out data: 8.859624384451608e+201



Epoch 7/25: 100%|██████████| 5421/5421 [09:57<00:00,  9.07it/s, loss=4.6313, lr=0.001]


Epoch 7 done | Average training loss: 4.6313
Perplexity on training data: 102.64470364723263



Epoch 8/25: 100%|██████████| 56/56 [00:02<00:00, 22.35it/s]


Average loss on held-out_dataset: 453.6085
Perplexity on held-out data: 9.992181855629549e+196



Epoch 8/25: 100%|██████████| 5421/5421 [09:57<00:00,  9.07it/s, loss=4.5335, lr=0.001]


Epoch 8 done | Average training loss: 4.5335
Perplexity on training data: 93.08228392679311



Epoch 9/25: 100%|██████████| 56/56 [00:02<00:00, 22.61it/s]


Average loss on held-out_dataset: 444.0702
Perplexity on held-out data: 7.198460021912333e+192



Epoch 9/25: 100%|██████████| 5421/5421 [09:57<00:00,  9.07it/s, loss=4.4494, lr=0.001]


Epoch 9 done | Average training loss: 4.4494
Perplexity on training data: 85.5744573344069



Epoch 10/25: 100%|██████████| 56/56 [00:02<00:00, 21.97it/s]


Average loss on held-out_dataset: 435.8631
Perplexity on held-out data: 1.963035620363912e+189



Epoch 10/25: 100%|██████████| 5421/5421 [09:57<00:00,  9.07it/s, loss=4.3763, lr=0.001]


Epoch 10 done | Average training loss: 4.3763
Perplexity on training data: 79.54286352732507



Epoch 11/25: 100%|██████████| 56/56 [00:02<00:00, 21.71it/s]


Average loss on held-out_dataset: 428.7360
Perplexity on held-out data: 1.5763930404349336e+186



Epoch 11/25: 100%|██████████| 5421/5421 [09:59<00:00,  9.05it/s, loss=4.3121, lr=0.001]


Epoch 11 done | Average training loss: 4.3121
Perplexity on training data: 74.59869739683089



Epoch 12/25: 100%|██████████| 56/56 [00:02<00:00, 22.28it/s]


Average loss on held-out_dataset: 422.4798
Perplexity on held-out data: 3.024426542762975e+183



Epoch 12/25: 100%|██████████| 5421/5421 [09:58<00:00,  9.06it/s, loss=4.2548, lr=0.001]


Epoch 12 done | Average training loss: 4.2548
Perplexity on training data: 70.44577676975874



Epoch 13/25: 100%|██████████| 56/56 [00:02<00:00, 21.79it/s]


Average loss on held-out_dataset: 416.8899
Perplexity on held-out data: 1.1297161136343163e+181



Epoch 13/25: 100%|██████████| 5421/5421 [09:58<00:00,  9.06it/s, loss=4.2037, lr=0.001]


Epoch 13 done | Average training loss: 4.2037
Perplexity on training data: 66.93264563194188



Epoch 14/25: 100%|██████████| 56/56 [00:02<00:00, 22.66it/s]


Average loss on held-out_dataset: 411.9077
Perplexity on held-out data: 7.748986891055825e+178



Epoch 14/25: 100%|██████████| 5421/5421 [09:58<00:00,  9.05it/s, loss=4.1574, lr=0.001]


Epoch 14 done | Average training loss: 4.1574
Perplexity on training data: 63.905693324218724



Epoch 15/25: 100%|██████████| 56/56 [00:02<00:00, 22.40it/s]


Average loss on held-out_dataset: 407.3924
Perplexity on held-out data: 8.477476915590407e+176



Epoch 15/25: 100%|██████████| 5421/5421 [09:58<00:00,  9.06it/s, loss=4.1166, lr=0.001]


Epoch 15 done | Average training loss: 4.1166
Perplexity on training data: 61.34744427823111



Epoch 16/25: 100%|██████████| 56/56 [00:02<00:00, 22.46it/s]


Average loss on held-out_dataset: 403.4143
Perplexity on held-out data: 1.5870892849560245e+175



Epoch 16/25: 100%|██████████| 5421/5421 [09:58<00:00,  9.06it/s, loss=4.0793, lr=0.001]


Epoch 16 done | Average training loss: 4.0793
Perplexity on training data: 59.10620133900921



Epoch 17/25: 100%|██████████| 56/56 [00:02<00:00, 22.48it/s]


Average loss on held-out_dataset: 399.7832
Perplexity on held-out data: 4.2035538333060633e+173



Epoch 17/25: 100%|██████████| 5421/5421 [09:58<00:00,  9.06it/s, loss=4.0455, lr=0.001]


Epoch 17 done | Average training loss: 4.0455
Perplexity on training data: 57.139036255694826



Epoch 18/25: 100%|██████████| 56/56 [00:02<00:00, 22.53it/s]


Average loss on held-out_dataset: 396.4869
Perplexity on held-out data: 1.5562334411483108e+172



Epoch 18/25: 100%|██████████| 5421/5421 [09:58<00:00,  9.06it/s, loss=4.0151, lr=0.001]


Epoch 18 done | Average training loss: 4.0151
Perplexity on training data: 55.42796906548439



Epoch 19/25: 100%|██████████| 56/56 [00:02<00:00, 22.24it/s]


Average loss on held-out_dataset: 393.5277
Perplexity on held-out data: 8.070315018724316e+170



Epoch 19/25: 100%|██████████| 5421/5421 [09:58<00:00,  9.06it/s, loss=3.9874, lr=0.001]


Epoch 19 done | Average training loss: 3.9874
Perplexity on training data: 53.91548640432788



Epoch 20/25: 100%|██████████| 56/56 [00:02<00:00, 22.38it/s]


Average loss on held-out_dataset: 390.8337
Perplexity on held-out data: 5.4566723150578266e+169



Epoch 20/25: 100%|██████████| 5421/5421 [09:58<00:00,  9.06it/s, loss=3.9625, lr=0.001]


Epoch 20 done | Average training loss: 3.9625
Perplexity on training data: 52.588846260420766



Epoch 21/25: 100%|██████████| 56/56 [00:02<00:00, 22.01it/s]


Average loss on held-out_dataset: 388.4104
Perplexity on held-out data: 4.836049510814389e+168



Epoch 21/25: 100%|██████████| 5421/5421 [09:58<00:00,  9.06it/s, loss=3.9395, lr=0.001]


Epoch 21 done | Average training loss: 3.9395
Perplexity on training data: 51.39504228411632



Epoch 22/25: 100%|██████████| 56/56 [00:02<00:00, 22.74it/s]


Average loss on held-out_dataset: 386.1744
Perplexity on held-out data: 5.169064651936253e+167



Epoch 22/25: 100%|██████████| 5421/5421 [09:58<00:00,  9.06it/s, loss=3.9191, lr=0.001]


Epoch 22 done | Average training loss: 3.9191
Perplexity on training data: 50.35297990356954



Epoch 23/25: 100%|██████████| 56/56 [00:02<00:00, 22.70it/s]


Average loss on held-out_dataset: 384.1814
Perplexity on held-out data: 7.044856970295205e+166



Epoch 23/25: 100%|██████████| 5421/5421 [09:57<00:00,  9.07it/s, loss=3.9000, lr=0.001]


Epoch 23 done | Average training loss: 3.9000
Perplexity on training data: 49.400601164196964



Epoch 24/25: 100%|██████████| 56/56 [00:02<00:00, 22.29it/s]


Average loss on held-out_dataset: 382.3197
Perplexity on held-out data: 1.0948331129072706e+166



Epoch 24/25: 100%|██████████| 5421/5421 [09:58<00:00,  9.06it/s, loss=3.8828, lr=0.001]


Epoch 24 done | Average training loss: 3.8828
Perplexity on training data: 48.5581408800383



Epoch 25/25: 100%|██████████| 56/56 [00:02<00:00, 21.02it/s]


Average loss on held-out_dataset: 380.6472
Perplexity on held-out data: 2.0557874263186437e+165



Epoch 25/25: 100%|██████████| 5421/5421 [09:58<00:00,  9.06it/s, loss=3.8668, lr=0.001]


Epoch 25 done | Average training loss: 3.8668
Perplexity on training data: 47.78930563885979



Epoch 26/25: 100%|██████████| 56/56 [00:02<00:00, 22.35it/s]


Average loss on held-out_dataset: 379.0924
Perplexity on held-out data: 4.3424431318853286e+164

Training complete. Model saved to gpt_final.pt
