In [1]:
import json
import math
from typing import Optional

import torch
from torch import nn
from torchinfo import summary
from tqdm import trange

In [2]:
# Load every story from the JSONL corpus and join them into one training string,
# separating consecutive stories with two blank lines ("\n\n\n").
# The file is opened explicitly as UTF-8 (it contains characters such as "é" and
# "’") and closed deterministically; blank lines in the JSONL are skipped.
with open("../data/stories.jsonl", "r", encoding="utf-8") as stories_file:
    data = "\n\n\n".join(
        json.loads(line)["story"] for line in stories_file if line.strip()
    )


class Tokenizer:
    """Character-level tokenizer whose vocabulary is the set of characters in a corpus.

    Substring ``swaps`` are applied (in insertion order) to the corpus before the
    vocabulary is built and to every string passed to ``encode``, so the swapped-out
    substrings never need vocabulary entries. Indices are assigned in sorted
    character order.
    """

    def __init__(
        self,
        data: str,
        initial_vocab: Optional[list[str]] = None,
        swaps: Optional[dict[str, str]] = None,
    ):
        """Build the character <-> index mappings.

        Args:
            data: Text whose characters (after swaps) make up the vocabulary.
            initial_vocab: Extra entries to include even if absent from ``data``
                (not passed through ``swaps``).
            swaps: Mapping of substring -> replacement applied before tokenizing.
        """
        # Copy so that later mutation of the caller's dict cannot change encoding.
        self.swaps = dict(swaps) if swaps is not None else {}
        data = self._apply_swaps(data)

        vocab = set(data)
        if initial_vocab is not None:
            vocab.update(initial_vocab)
        unique_chars = sorted(vocab)

        self.char2idx = {ch: idx for idx, ch in enumerate(unique_chars)}
        self.idx2char = {idx: ch for ch, idx in self.char2idx.items()}

    def _apply_swaps(self, text: str) -> str:
        """Apply every ``self.swaps`` replacement to ``text`` in insertion order."""
        for old, new in self.swaps.items():
            text = text.replace(old, new)
        return text

    def encode(self, text: str) -> torch.Tensor:
        """Return a 1-D ``torch.long`` tensor of vocabulary indices for ``text``.

        Raises:
            KeyError: if ``text`` (after swaps) contains a character that is not
                in the vocabulary.
        """
        text = self._apply_swaps(text)
        try:
            ids = [self.char2idx[ch] for ch in text]
        except KeyError as exc:
            raise KeyError(
                f"character {exc.args[0]!r} is not in the tokenizer vocabulary"
            ) from None
        return torch.tensor(ids, dtype=torch.long)

    def decode(self, indices: "torch.Tensor | list[int]") -> str:
        """Map indices back to text.

        Accepts a tensor of any shape (including 0-d; it is flattened in row-major
        order) or any iterable of integers.
        """
        if isinstance(indices, torch.Tensor):
            # Flatten first: iterating a 0-d tensor raises, and rows of a 2-D
            # tensor are not single indices.
            indices = indices.reshape(-1).tolist()
        return "".join(self.idx2char[int(idx)] for idx in indices)


# Contiguous 90/10 character split: the first 90% of the joined corpus is used
# for training and the tail for validation (no shuffling).
split_idx = int(0.9 * len(data))
train_data, val_data = data[:split_idx], data[split_idx:]

# Vocabulary = printable ASCII (codes 32..126) plus anything else left in the
# training text after the typographic swaps below.
tokenizer = Tokenizer(
    train_data,
    initial_vocab=[chr(code) for code in range(32, 127)],
    swaps={"é": "e", "–": "-", "—": "-", "’": "'"},
)
vocab_size = len(tokenizer.char2idx)
train_tokens, val_tokens = (tokenizer.encode(part) for part in (train_data, val_data))

print(f"Vocabulary size: {vocab_size}")
print(f"Train tokens: {len(train_tokens):,}")
print(f"Validation tokens: {len(val_tokens):,}")

Vocabulary size: 96
Train tokens: 5,875,117
Validation tokens: 652,791


In [3]:
def roundTo32(x, base=32):
    """Round ``x`` up to the nearest multiple of ``base`` and return it as an ``int``.

    When both arguments are ints, exact integer ceil-division is used; going
    through ``float`` would silently lose precision once ``|x|`` exceeds 2**53.
    Other numeric inputs are rounded via ``math.ceil`` on the float quotient.
    """
    if isinstance(x, int) and isinstance(base, int):
        # -(-x // base) == ceil(x / base) using only integer arithmetic.
        return -(-x // base) * base
    return int(base * math.ceil(float(x) / base))

In [5]:
from omni.models.gpt import GPT

# Seed torch's global RNG before the model is built so that later draws from it
# (presumably GPT's parameter initialisation, and the torch.randint-based batch
# sampling used for training) are reproducible after a kernel restart.
SEED = 0
torch.manual_seed(SEED)

# Context window the model is built for; the training cell also uses it as the
# length of every sampled window.
max_seq_length = 128
device = torch.device("cpu")

model = GPT(
    vocab_size=vocab_size,
    d_model=16,
    n_heads=4,
    n_kv_heads=1,  # fewer key/value heads than query heads (grouped-query attention)
    m=2,  # NOTE(review): meaning not visible here (FFN width multiplier?) - confirm in omni.models.gpt
    num_layers=6,
    max_seq_length=max_seq_length,
    num_experts=8,
    top_k=2,  # presumably experts routed per token in the MoE layer - confirm
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=8e-4)
criterion = nn.CrossEntropyLoss()

summary(model)

Layer (type:depth-idx)                        Param #
GPT                                           2,048
├─Embedding: 1-1                              1,536
├─ModuleList: 1-2                             --
│    └─TransformerBlock: 2-1                  --
│    │    └─RMSNorm: 3-1                      16
│    │    └─GroupedQueryAttention: 3-2        680
│    │    └─RMSNorm: 3-3                      16
│    │    └─MixtureOfExperts: 3-4             8,704
│    └─TransformerBlock: 2-2                  --
│    │    └─RMSNorm: 3-5                      16
│    │    └─GroupedQueryAttention: 3-6        680
│    │    └─RMSNorm: 3-7                      16
│    │    └─MixtureOfExperts: 3-8             8,704
│    └─TransformerBlock: 2-3                  --
│    │    └─RMSNorm: 3-9                      16
│    │    └─GroupedQueryAttention: 3-10       680
│    │    └─RMSNorm: 3-11                     16
│    │    └─MixtureOfExperts: 3-12            8,704
│    └─TransformerBlock: 2-4                  

In [6]:
def get_batch(
    tokens: torch.Tensor,
    batch_size: int,
    seq_length: int,
    target_device: Optional[torch.device] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample ``batch_size`` random (input, target) windows from a 1-D token tensor.

    Windows are drawn uniformly with replacement; ``y`` is ``x`` shifted left by
    one position, i.e. the next-token targets.

    Args:
        tokens: 1-D tensor of token ids.
        batch_size: Number of windows to sample.
        seq_length: Length of each window.
        target_device: Device for the returned tensors; defaults to the
            notebook-level ``device``.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, seq_length)``.

    Raises:
        ValueError: if ``tokens`` holds fewer than ``seq_length + 1`` ids.
    """
    # A window starting at i needs tokens[i : i + seq_length + 1], so valid starts
    # are 0 .. len(tokens) - seq_length - 1; randint's upper bound is exclusive.
    num_starts = len(tokens) - seq_length
    if num_starts < 1:
        raise ValueError(
            f"need at least seq_length + 1 = {seq_length + 1} tokens, got {len(tokens)}"
        )
    start_indices = torch.randint(0, num_starts, (batch_size,), device=tokens.device)
    # (batch_size, seq_length) gather indices; one fancy-index replaces a Python loop.
    positions = start_indices.unsqueeze(1) + torch.arange(seq_length, device=tokens.device)
    x_batch = tokens[positions]
    y_batch = tokens[positions + 1]
    if target_device is None:
        target_device = device
    return x_batch.to(target_device), y_batch.to(target_device)


num_epochs = 5
batch_size = 128
step = 0  # optimizer-step counter, accumulated across epochs

# Each "epoch" draws as many random windows as would tile the split once; windows
# are sampled with replacement, so this is not an exact pass over the data.
tokens_per_batch = batch_size * max_seq_length
num_batches = len(train_tokens) // tokens_per_batch
# At least one validation batch, so the average below can never divide by zero.
num_val_batches = max(1, len(val_tokens) // tokens_per_batch)


def batch_loss(x_batch, y_batch):
    """Mean next-token cross-entropy of ``model`` on one batch of windows."""
    logits = model(x_batch)
    # reshape (not view) so this also works if the model returns non-contiguous logits.
    return criterion(logits.reshape(-1, vocab_size), y_batch.reshape(-1))


for epoch in range(num_epochs):
    model.train()

    pbar = trange(num_batches, desc=f"Epoch {epoch + 1}/{num_epochs}")
    for _ in pbar:
        x_batch, y_batch = get_batch(train_tokens, batch_size, max_seq_length)
        optimizer.zero_grad()
        loss = batch_loss(x_batch, y_batch)
        loss.backward()
        optimizer.step()

        pbar.set_postfix({"loss": loss.item()})
        step += 1

    # Validation on freshly sampled random windows of the held-out split.
    model.eval()
    with torch.no_grad():
        val_loss = 0.0
        for _ in range(num_val_batches):
            x_batch, y_batch = get_batch(val_tokens, batch_size, max_seq_length)
            val_loss += batch_loss(x_batch, y_batch).item()

        avg_val_loss = val_loss / num_val_batches
        print(f"Validation Loss: {avg_val_loss:.4f}")

Epoch 1/5:   3%|▎         | 11/358 [00:12<06:24,  1.11s/it, loss=4.31]


KeyboardInterrupt: 

In [None]:
@torch.inference_mode()
def generate_text(model, tokenizer, prompt, max_length=24, context_length=None):
    """Greedily extend ``prompt`` by ``max_length`` tokens and return the decoded text.

    Args:
        model: Module mapping a ``(1, T)`` tensor of token ids to logits; the
            scores at the last position (``logits[:, -1, :]``) pick the next token.
        tokenizer: Object with ``encode(str) -> 1-D tensor`` and ``decode(ids) -> str``.
        prompt: Text to continue; must be encodable by ``tokenizer``.
        max_length: Number of new tokens to append.
        context_length: If a positive int, only the last ``context_length`` tokens
            are fed to the model at each step (e.g. the model's maximum sequence
            length); ``None`` feeds the whole sequence.

    Returns:
        The decoded prompt tokens followed by the generated continuation.
    """
    model.eval()
    # Run on the device holding the model's parameters (CPU if it has none).
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    input_ids = tokenizer.encode(prompt).unsqueeze(0).to(target_device)

    for _ in range(max_length):
        context = input_ids if context_length is None else input_ids[:, -context_length:]
        logits = model(context)
        # keepdim keeps the (batch, 1) shape needed to append along dim 1.
        next_token_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token_id], dim=1)

    # Index the single batch row rather than squeeze(): a length-1 result stays 1-D.
    return tokenizer.decode(input_ids[0])


# Greedy sample from the current model, seeded with a short prompt.
print(generate_text(model, tokenizer, prompt="First"))