In [1]:
from components.gptmodel import GPTModel

In [2]:
import lightning as L

In [3]:
# Hyperparameters passed to components.gptmodel.GPTModel; the meaning of each
# key is defined there. The "124M" label is nominal: with this configuration
# the Lightning model summary reports ~162M parameters.
GPT_CONFIG_124M = dict(
    vocab_size=50257,
    context_length=256,  # also used as the window length / stride of the data loaders
    emb_dim=768,
    n_heads=12,
    n_layers=12,
    drop_rate=0.1,
    qkv_bias=False,
)

In [4]:
import torch.nn as nn
import torch

In [5]:
class LitGPTModel(L.LightningModule):
    """Lightning wrapper that trains a GPT-style model with token-level cross-entropy.

    Both the training and the validation loss are logged through ``self.log``,
    so they appear in the progress bar and in the Trainer's logger.

    Args:
        GPTModel: the ``nn.Module`` to train. It is called as ``GPTModel(x)``
            and must return logits that ``loss`` can flatten (presumably
            ``(batch, seq_len, vocab_size)`` -- TODO confirm against
            ``components.gptmodel``). NOTE: inside ``__init__`` this parameter
            shadows the imported ``GPTModel`` class. The name is kept so
            existing keyword callers keep working.
        lr: AdamW learning rate (default ``1e-4``, the previously hard-coded value).
        weight_decay: AdamW weight decay (default ``0.1``, the previously
            hard-coded value).
    """

    def __init__(self, GPTModel, lr=1e-4, weight_decay=0.1):
        super().__init__()
        # Store lr / weight_decay in self.hparams (and thus in checkpoints).
        # The wrapped module is excluded because its weights are already part
        # of the state dict.
        self.save_hyperparameters(ignore=["GPTModel"])
        self.model = GPTModel

    def _shared_step(self, batch):
        """Run the model on an ``(inputs, targets)`` batch and return the loss."""
        x, y = batch
        logits = self.model(x)
        return self.loss(logits, y)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss (returned for backprop)."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss.

        Previously the loss was computed and discarded, so the validation loop
        produced no observable metric. By default Lightning aggregates values
        logged here over the validation epoch.
        """
        loss = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def loss(self, output, expected):
        """Mean cross-entropy over every token position.

        ``output`` is flattened over its first two dims and ``expected`` over
        all dims, so each position becomes one classification example.
        Assumes ``output`` is (batch, seq, vocab) logits and ``expected`` is
        (batch, seq) integer token ids -- TODO confirm.
        """
        return nn.functional.cross_entropy(
            output.flatten(0, 1), expected.flatten()
        )

    def configure_optimizers(self):
        """AdamW over all parameters, using the lr / weight_decay given at construction."""
        return torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )

In [6]:
from components.data import create_dataloader_v1

In [7]:
# Training split: read the whole file, then keep only the first 1/1000th of
# its characters so that one epoch stays short while experimenting.
train_file = "../data/TinyStories/TinyStoriesV2-GPT4-train.txt"
with open(train_file, "r", encoding="utf-8") as f:
    train_text = f.read()

train_len = len(train_text)
train_text = train_text[: train_len // 1000]

# Window length and stride are both the model's context length. Shuffling
# and drop_last are enabled for training; assumes create_dataloader_v1
# forwards these options to a torch DataLoader -- TODO confirm.
context_length = GPT_CONFIG_124M["context_length"]
train_loader_kwargs = dict(
    batch_size=2,
    max_length=context_length,
    stride=context_length,
    drop_last=True,
    shuffle=True,
    num_workers=11,
)
train_loader = create_dataloader_v1(train_text, **train_loader_kwargs)

In [8]:
# Validation split: same 1/1000th character subsample as the training data.
val_file = "../data/TinyStories/TinyStoriesV2-GPT4-valid.txt"
with open(val_file, "r", encoding="utf-8") as f:
    val_text = f.read()

val_len = len(val_text)
val_text = val_text[: val_len // 1000]

# Unlike the training loader: no shuffling, and the final partial batch is
# kept so that every validation window is evaluated.
context_length = GPT_CONFIG_124M["context_length"]
val_loader_kwargs = dict(
    batch_size=2,
    max_length=context_length,
    stride=context_length,
    drop_last=False,
    shuffle=False,
    num_workers=11,
)
val_loader = create_dataloader_v1(val_text, **val_loader_kwargs)

In [9]:
# Seed Python, NumPy and torch (CPU and CUDA) before building the model so
# that weight initialisation, dropout and (presumably) the shuffled
# training-loader order are reproducible across Restart & Run All.
SEED = 123
L.seed_everything(SEED, workers=True)

model = GPTModel(GPT_CONFIG_124M)
litmodel = LitGPTModel(model)

In [10]:
# Single pass over the subsampled training data, with the validation loader
# supplied so that Lightning also runs the validation loop.
trainer = L.Trainer(max_epochs=1)
trainer.fit(
    litmodel,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/gebrial/miniforge3/envs/fromscratch/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type     | Params | Mode 
-------------------------------------------
0 | model | GPTModel | 162 M  | train
---------------------------------

Epoch 0: 100%|█████████████████████████████████████████████████████████████| 1072/1072 [03:09<00:00,  5.64it/s, v_num=5]
Validation: |                                                                                     | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                                                                | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                   | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   9%|█████▎                                                     | 1/11 [00:00<00:00, 27.02it/s][A
Validation DataLoader 0:  18%|██████████▋                                                | 2/11 [00:00<00:00, 30.59it/s][A
Validation DataLoader 0:  27%|████████████████                                           | 3/11 [00:00<00:00, 30.81it/s][A
Validation DataLoader 0:  36%|█████████████████████▍                                     | 4/11 [00:00<00:00, 30.33it/s][A
Validation 

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|█████████████████████████████████████████████████████████████| 1072/1072 [03:12<00:00,  5.56it/s, v_num=5]
