In [5]:
# Selects which tokenizer implementation the notebook instantiates further down.
TOKENIZER = "wp"  # supported values: "sp", "gpt", "wp"

In [6]:
from datasets import load_from_disk
import sys

# Add the parent directory to the import search path so later imports can
# resolve modules located there.
sys.path.append("../")

# Load the dataset stored on disk and keep handles to the two splits used below.
ds = load_from_disk("data/tinystories")
train, val = ds["train"], ds["validation"]

In [7]:
from tokenizer import GPTTokenizer, Tokenizer, SPTokenizer, WPTokenizer

# Build only the tokenizer named by TOKENIZER; any other value is rejected.
if TOKENIZER == "sp":
    tokenizer = SPTokenizer('data/tokenizer_bpe.model')
elif TOKENIZER == "gpt":
    tokenizer = GPTTokenizer()
elif TOKENIZER == "wp":
    tokenizer = WPTokenizer.from_json("data/custom_vocab.json")
else:
    raise ValueError(f"Unknown tokenizer: {TOKENIZER}")

# Sanity check: encode a sample string and decode it back; the result is the
# cell's displayed output.
tokenizer.decode(tokenizer.encode("Test sentence"))

'test sentence'

In [8]:
# Display the tokenizer object's repr to confirm which class was instantiated.
tokenizer

<tokenizer.tokenizer.WPTokenizer at 0x152c8b560>

Tokenów: 471_872_517
Dokumentów: 2_119_719

In [9]:
import torch
from torch.utils.data import IterableDataset, DataLoader
from datasets.arrow_dataset import Dataset
from typing import Generator

class StreamingTokenDataset(IterableDataset):
    """Stream fixed-length next-token-prediction pairs from a text dataset.

    Every example's ``"text"`` field is tokenized, a separator token is
    appended after each example, and the resulting flat token stream is cut
    into windows of ``context_size + 1`` tokens. Each window yields
    ``(input_tokens, pred_tokens)`` where ``pred_tokens`` is ``input_tokens``
    shifted left by one position. Consecutive windows overlap by exactly one
    token, so the last target of one window is the first input of the next.

    Args:
        dataset: Iterable of mapping-like examples exposing a ``"text"`` key.
        tokenizer: Object with an ``encode(text)`` method returning an
            iterable of integer token ids.
        context_size: Number of tokens in each input/target tensor (>= 1).
        buffer_size: Stored on the instance for callers; no method of this
            class reads it.
        eos_token_id: Token id emitted after every example to separate
            documents. Defaults to ``0``.

    Raises:
        ValueError: If ``context_size`` is smaller than 1.

    Notes:
        * Tokens left over at the end of the stream (fewer than
          ``context_size + 1``) are dropped.
        * ``__iter__`` does not consult ``torch.utils.data.get_worker_info()``,
          so with ``DataLoader(num_workers > 0)`` every worker would iterate
          the full stream and samples would be duplicated.
    """

    def __init__(
            self,
            dataset: "Dataset",
            tokenizer: "Tokenizer",
            context_size=128,
            buffer_size=10_000,
            eos_token_id=0,
        ) -> None:
        super().__init__()

        # A non-positive window would never drain the buffer and would yield
        # empty tensors forever, so reject it up front.
        if context_size < 1:
            raise ValueError(f"context_size must be >= 1, got {context_size}")

        self.dataset = dataset
        self.tokenizer = tokenizer

        self.context_size = context_size
        self.buffer_size = buffer_size
        self.eos_token_id = eos_token_id

    def _token_stream(self) -> Generator[int, None, None]:
        """Yield token ids of all examples, each followed by the separator id."""
        for example in self.dataset:
            yield from self.tokenizer.encode(example["text"])
            yield self.eos_token_id

    def _chunk_stream(self) -> Generator[tuple[torch.Tensor, torch.Tensor], None, None]:
        """Yield ``(input_tokens, pred_tokens)`` ``torch.long`` tensor pairs."""
        window = self.context_size + 1
        buf = []
        for token in self._token_stream():
            buf.append(token)
            if len(buf) == window:
                # Inputs are the first context_size tokens; targets are the
                # same window shifted by one.
                yield (
                    torch.tensor(buf[:-1], dtype=torch.long),
                    torch.tensor(buf[1:], dtype=torch.long),
                )
                # Carry the last token over so it becomes the next window's
                # first input.
                buf = buf[-1:]

    def __iter__(self):
        yield from self._chunk_stream()

In [10]:
# Wrap each split in a streaming dataset that tokenizes on the fly.
train_dataset = StreamingTokenDataset(train, tokenizer)
val_dataset = StreamingTokenDataset(val, tokenizer)

# Batch the streamed (input, target) pairs.
train_loader = DataLoader(train_dataset, batch_size=4)
# The held-out loader reads the validation split so evaluation never reuses
# training data.
test_loader = DataLoader(val_dataset, batch_size=4)

In [11]:
from lab1.architectures.gpt import GPTDecoder

# Model hyperparameters; the vocabulary size comes from the loaded tokenizer.
vocab_size = tokenizer.vocab_size()
embed_dim, num_heads = 256, 8
ff_hidden_dim, num_layers = 2048, 6
context_length = 128
dropout = 0.1

# Instantiate the decoder with the hyperparameters defined above.
gpt = GPTDecoder(
    vocab_size=vocab_size, embed_dim=embed_dim, num_heads=num_heads,
    ff_hidden_dim=ff_hidden_dim, num_layers=num_layers,
    context_length=context_length, dropout=dropout,
)

In [12]:
# Checkpoint path depends on the tokenizer selected at the top of the notebook.
model_path = f"trained_models/{TOKENIZER}_tokenizer.pt"

# Load the saved weights onto the CPU and copy them into the model; the result
# of load_state_dict is the cell's displayed output.
gpt.load_state_dict(
    torch.load(model_path, map_location="cpu", weights_only=True)
)

<All keys matched successfully>

In [14]:
def choose_device() -> str:
    """Return the name of the preferred torch device available on this machine.

    CUDA is chosen when available, otherwise MPS when available, otherwise CPU.

    Returns:
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    return "mps" if torch.backends.mps.is_available() else "cpu"

In [None]:
import math

import torch.nn as nn
from tqdm import tqdm

epochs = 1
grad_clip = 10.0
# The streaming dataset defines no __len__, so the progress bar is sized with
# a fixed estimate of the number of batches per epoch.
progress_total = 900_000
device = torch.device(choose_device())

print(f"Training on device: {device}")

gpt.to(device)
# Targets equal to 0 are excluded from the loss. 0 is also the id the dataset
# appends after every document, so those positions do not contribute either.
# NOTE(review): confirm that ignoring the document separator is intended.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.AdamW(gpt.parameters())

for epoch in range(1, epochs + 1):
    gpt.train()
    total_loss = 0.0
    # Defined up front so the epoch summary below cannot hit a NameError when
    # the loader yields no batches.
    avg_loss = float("nan")

    progress = tqdm(enumerate(train_loader), total=progress_total, desc=f"Epoch {epoch}/{epochs}")

    for i, (batch_x, batch_y) in progress:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        out = gpt(batch_x)
        # Flatten all leading dimensions so each position is scored as one
        # classification over the last (logit) dimension.
        loss = criterion(out.view(-1, out.size(-1)), batch_y.view(-1))
        loss.backward()

        # Cap the global gradient norm before the optimizer step.
        torch.nn.utils.clip_grad_norm_(gpt.parameters(), grad_clip)

        optimizer.step()

        # Running mean of the per-batch loss over the epoch so far.
        total_loss += loss.item()
        avg_loss = total_loss / (i + 1)

        progress.set_postfix({"loss": f"{avg_loss:.4f}", "lr": optimizer.param_groups[0]["lr"]})

    torch.save(gpt.state_dict(), f"data/{TOKENIZER}_epoch_{epoch}.pt")
    print(f"Epoch {epoch} done | Average training loss: {avg_loss:.4f}")
    # Perplexity is exp(mean loss); use the stdlib math module directly rather
    # than reaching it through the torch namespace.
    print(f"Perplexity on training data: {math.exp(avg_loss)}\n")


final_path = f"data/{TOKENIZER}_final.pt"
torch.save(gpt.state_dict(), final_path)
# Report the path that was actually written.
print(f"Training complete. Model saved to {final_path}")