In [None]:
import math

import torch
from datasets import load_dataset
from torch import nn
from torchinfo import summary
from tqdm import trange

from omni.models.gpt import GPT

In [None]:
# Fetch the dataset from the Hugging Face hub.
ds = load_dataset("noanabeshima/TinyStoriesV2")

# Collapse each split into one long string, with a blank line between the
# individual "text" entries.
train_data, val_data = (
    "\n\n".join(ds[split]["text"]) for split in ("train", "validation")
)

In [None]:
class Tokenizer:
    """Character-level tokenizer built from the characters of a reference text.

    Ids are assigned contiguously starting at 0: the special tokens come first
    (in the order given), followed by every distinct character of ``text`` in
    sorted order.
    """

    def __init__(self, text: str, special_tokens: "list[str] | None" = None):
        """Build the character <-> id lookup tables.

        Args:
            text: Corpus whose distinct characters form the vocabulary.
            special_tokens: Tokens that receive the lowest ids. Defaults to
                ``["<UNK>", "<SOS>", "<EOS>"]``. ``encode`` needs ``"<UNK>"``
                to be among them.
        """
        # ``None`` sentinel rather than a list default: a mutable default is
        # created once and shared by every call.
        if special_tokens is None:
            special_tokens = ["<UNK>", "<SOS>", "<EOS>"]

        # Drop repeated specials (keeping first-seen order) and characters that
        # are already specials. Otherwise a repeated entry would overwrite its
        # own id, leaving gaps so that ``len(char2idx)`` is smaller than the
        # largest id handed out.
        specials = list(dict.fromkeys(special_tokens))
        chars = sorted(set(text) - set(specials))

        self.char2idx = {token: idx for idx, token in enumerate(specials + chars)}
        self.idx2char = {idx: token for token, idx in self.char2idx.items()}

    def encode(self, text: str) -> list[int]:
        """Map each character of ``text`` to its id.

        Characters that are not in the vocabulary map to the ``"<UNK>"`` id.

        Raises:
            KeyError: If ``"<UNK>"`` was not registered as a special token.
        """
        # Resolve the fallback id once instead of once per character.
        unk_idx = self.char2idx["<UNK>"]
        lookup = self.char2idx.get
        return [lookup(char, unk_idx) for char in text]

    def decode(self, tokens: list[int]) -> str:
        """Map ids back to text; ids outside the vocabulary become ``"<UNK>"``.

        Each id is passed through ``int()`` first so integer-like scalars
        (e.g. elements of a tensor) hit the same dictionary keys as plain ints.
        """
        return "".join(self.idx2char.get(int(token), "<UNK>") for token in tokens)


# Fit the character vocabulary on the training text only; validation characters
# that never occur in training are encoded as <UNK>.
tokenizer = Tokenizer(train_data)
vocab_size = len(tokenizer.char2idx)

# Keep the encoded corpora as 1-D int64 tensors rather than Python lists: the
# batching code slices them and passes the slices to torch.stack, which only
# accepts tensors.
train_tokens = torch.tensor(tokenizer.encode(train_data), dtype=torch.long)
val_tokens = torch.tensor(tokenizer.encode(val_data), dtype=torch.long)

print(f"Vocabulary size: {vocab_size}")
print(f"Train tokens: {len(train_tokens):,}")
print(f"Validation tokens: {len(val_tokens):,}")

In [3]:
def roundTo32(x, base=32):
    """Round ``x`` up to the nearest multiple of ``base`` and return it as an int.

    Despite the name, ``base`` is configurable; 32 is only the default.
    """
    multiples_needed = math.ceil(float(x) / base)
    return int(base * multiples_needed)

In [4]:
max_seq_length = 128

# Pick the best available accelerator instead of hard-coding "mps", so the
# notebook also runs on machines without Apple's Metal backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = GPT(
    vocab_size=vocab_size,
    d_model=32,
    n_heads=8,
    n_kv_heads=2,  # fewer key/value heads than query heads
    m=2,  # NOTE(review): meaning of `m` is defined by omni's GPT — confirm
    num_layers=4,
    max_seq_length=max_seq_length,
    tie_weights=True,
    num_experts=16,
    top_k=1,
).to(device)

# optimizer = torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0.99)
optimizer = torch.optim.AdamW(model.parameters(), lr=8e-4)
criterion = nn.CrossEntropyLoss()

# Last expression of the cell: displays the per-layer parameter table.
summary(model)

Layer (type:depth-idx)                        Param #
GPT                                           1,024
├─Embedding: 1-1                              3,072
├─ModuleList: 1-2                             --
│    └─TransformerBlock: 2-1                  --
│    │    └─RMSNorm: 3-1                      32
│    │    └─GroupedQueryAttention: 3-2        2,640
│    │    └─RMSNorm: 3-3                      32
│    │    └─MixtureOfExperts: 3-4             67,584
│    └─TransformerBlock: 2-2                  --
│    │    └─RMSNorm: 3-5                      32
│    │    └─GroupedQueryAttention: 3-6        2,640
│    │    └─RMSNorm: 3-7                      32
│    │    └─MixtureOfExperts: 3-8             67,584
│    └─TransformerBlock: 2-3                  --
│    │    └─RMSNorm: 3-9                      32
│    │    └─GroupedQueryAttention: 3-10       2,640
│    │    └─RMSNorm: 3-11                     32
│    │    └─MixtureOfExperts: 3-12            67,584
│    └─TransformerBlock: 2-4         

In [12]:
def greedy(logits):
    """Return the index of the largest logit along the last dimension.

    The reduced dimension is kept with size 1, so the result can be
    concatenated onto a ``(batch, seq)`` tensor of token ids.
    """
    return logits.argmax(dim=-1, keepdim=True)


def sample(logits, temperature=1.0):
    """Draw one index per row from the temperature-scaled softmax of ``logits``.

    A temperature of exactly 0 falls back to ``greedy`` (arg-max) decoding,
    since dividing by zero is undefined.
    """
    if temperature != 0:
        scaled_logits = logits / temperature
        distribution = scaled_logits.softmax(dim=-1)
        return torch.multinomial(distribution, num_samples=1)
    return greedy(logits)


@torch.inference_mode()
def generate_text(model, tokenizer, prompt, max_length=100, sample=greedy):
    """Extend ``prompt`` by ``max_length`` tokens and return the decoded text.

    Args:
        model: Callable taking ``(1, seq)`` token ids and returning logits
            indexed as ``[batch, position, vocab]``. It is switched to eval
            mode and left that way.
        tokenizer: Object with ``encode(str)`` and ``decode(list[int])``.
        prompt: Non-empty text to condition on.
        max_length: Number of new tokens to generate.
        sample: Callable mapping ``(batch, vocab)`` logits to ``(batch, 1)``
            token ids. Defaults to ``greedy``.

    Returns:
        The prompt plus the generated continuation, decoded by ``tokenizer``.

    Raises:
        ValueError: If ``prompt`` encodes to zero tokens.
    """
    model.eval()

    # ``encode`` returns a plain list of ids, so build the tensor explicitly
    # before adding the batch dimension. ``as_tensor`` also accepts a tokenizer
    # that already returns a tensor.
    prompt_ids = torch.as_tensor(
        tokenizer.encode(prompt), dtype=torch.long, device=device
    )
    if prompt_ids.numel() == 0:
        raise ValueError("prompt must encode to at least one token")
    input_ids = prompt_ids.unsqueeze(0)

    # NOTE(review): the whole running sequence is fed back every step and is
    # never truncated — confirm len(prompt) + max_length stays within the
    # model's maximum sequence length.
    for _ in range(max_length):
        logits = model(input_ids)
        next_token_logits = logits[:, -1, :]  # only the last position matters
        next_token_id = sample(next_token_logits)

        input_ids = torch.cat([input_ids, next_token_id], dim=1)

    # Drop only the batch dimension (a bare squeeze() would also collapse a
    # length-1 sequence) and hand the tokenizer plain Python ints.
    return tokenizer.decode(input_ids.squeeze(0).tolist())


# Quick qualitative check: greedy continuation of a story opening.
print(generate_text(model, tokenizer, prompt="Once upon a time, there was a"))

Once upon a time, there was and the the the the the the the the the the the the the the the the the the ther the the wa there the


In [5]:
def get_batch(tokens, batch_size, seq_length, target_device=None):
    """Sample a random batch of next-token-prediction examples.

    Args:
        tokens: 1-D sequence of token ids (a tensor or a plain list of ints).
        batch_size: Number of windows to sample (with replacement).
        seq_length: Length of each input window.
        target_device: Device the batch is moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        ``(x_batch, y_batch)``, both int64 of shape ``(batch_size, seq_length)``;
        ``y_batch`` is ``x_batch`` shifted one token to the right.

    Raises:
        ValueError: If ``tokens`` has fewer than ``seq_length + 1`` elements.
    """
    if target_device is None:
        target_device = device

    # A window starting at i reads tokens[i : i + seq_length + 1], so the valid
    # starts are 0 .. len(tokens) - seq_length - 1. randint's upper bound is
    # exclusive, hence len(tokens) - seq_length.
    num_starts = len(tokens) - seq_length
    if num_starts < 1:
        raise ValueError(
            f"need at least {seq_length + 1} tokens to build a window, "
            f"got {len(tokens)}"
        )
    start_indices = torch.randint(0, num_starts, (batch_size,)).tolist()

    # as_tensor is a no-op for tensor slices and converts list slices, so both
    # kinds of ``tokens`` can be stacked.
    x_batch = torch.stack(
        [
            torch.as_tensor(tokens[i : i + seq_length], dtype=torch.long)
            for i in start_indices
        ]
    )
    y_batch = torch.stack(
        [
            torch.as_tensor(tokens[i + 1 : i + seq_length + 1], dtype=torch.long)
            for i in start_indices
        ]
    )
    return x_batch.to(target_device), y_batch.to(target_device)


num_epochs = 5
batch_size = 64

# Batches are drawn at random, so an "epoch" here means as many batches as it
# takes to cover the number of tokens in the split. Both counts are the same
# for every epoch, so compute them once.
tokens_per_batch = batch_size * max_seq_length
num_batches = len(train_tokens) // tokens_per_batch
# At least one validation batch, so the average below cannot divide by zero
# when the validation split is smaller than one batch.
num_val_batches = max(1, len(val_tokens) // tokens_per_batch)

for epoch in range(num_epochs):
    # generate_text() at the end of the previous epoch leaves the model in
    # eval mode, so training mode is re-enabled every epoch.
    model.train()

    pbar = trange(num_batches, desc=f"Epoch {epoch + 1}/{num_epochs}")
    for _ in pbar:
        x_batch, y_batch = get_batch(train_tokens, batch_size, max_seq_length)
        optimizer.zero_grad()
        logits = model(x_batch)
        # Flatten (batch, seq, vocab) logits and (batch, seq) targets so the
        # loss is computed per token.
        loss = criterion(logits.view(-1, vocab_size), y_batch.view(-1))
        loss.backward()
        optimizer.step()

        pbar.set_postfix({"loss": loss.item()})

    # Validation
    model.eval()
    with torch.no_grad():
        val_loss = 0.0
        for _ in range(num_val_batches):
            x_batch, y_batch = get_batch(val_tokens, batch_size, max_seq_length)
            logits = model(x_batch)
            loss = criterion(logits.view(-1, vocab_size), y_batch.view(-1))
            val_loss += loss.item()

        avg_val_loss = val_loss / num_val_batches
        print(f"Validation Loss: {avg_val_loss:.4f}")
        # Qualitative check of the model after this epoch.
        print(generate_text(model, tokenizer, "First Citizen"))

Epoch 1/5: 100%|██████████| 720/720 [01:32<00:00,  7.77it/s, loss=1.95]


Validation Loss: 1.9777


NameError: name 'generate_text' is not defined