In [None]:
from typing import Generator
from contextlib import nullcontext

import torch
from torch import nn, Tensor
import torch.nn.functional as F

import time
import numpy as np
import matplotlib.pyplot as plt

from model import DecoderTransformer, DecoderTransformerConfig

# Choose the compute backend: CUDA when available, then MPS, otherwise CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

# Context manager wrapped around the forward pass: torch.autocast everywhere
# except on MPS, where a no-op context is used instead.
if device == "mps":
    context = nullcontext()
else:
    context = torch.autocast(device)

print(f"using {device} device")

In [None]:
# Output directory name.
OUT_DIR = "out"

# Total number of optimisation steps, and how often (in steps) progress is
# printed by the training loop.
MAX_ITERS = 1000
EVAL_INTERVAL = 1

# Length of each training window (in tokens) and number of windows per batch.
BLOCK_SIZE = 128
BATCH_SIZE = 256

# Learning-rate schedule bounds. get_lr() warms up linearly to MAX_LR and then
# follows a half cosine from MAX_LR to MIN_LR, so MIN_LR has to be the smaller
# of the two (the values were previously swapped, which made the "decay"
# phase increase the learning rate).
MIN_LR = 1e-5
MAX_LR = 1e-4
# Warmup lasts 1% of training; the cosine phase ends WARMUP_ITERS before the end.
WARMUP_ITERS = MAX_ITERS // 100
LR_DECAY_ITERS = MAX_ITERS - WARMUP_ITERS

In [None]:
# Read the entire training corpus into memory as one UTF-8 decoded string.
with open("data/shakespeare.txt", encoding="utf-8") as f:
    text = f.read()

In [None]:
# Character-level vocabulary: every distinct character of the corpus, sorted,
# is assigned its position in that ordering as an integer id.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables in both directions between ids and characters.
int2char = dict(enumerate(chars))
char2int = {c: i for i, c in int2char.items()}


def encode(s: str) -> list[int]:
    """Convert a string into a list of character ids.

    Characters that have no entry in ``char2int`` are skipped silently.
    """
    ids = []
    for ch in s:
        if ch in char2int:
            ids.append(char2int[ch])
    return ids


def decode(y: list[int] | np.ndarray | Tensor) -> str:
    return "".join([int2char[int(i)] for i in y if int(i) in int2char])

In [None]:
# Encode the whole corpus once as a 1-D tensor of int64 token ids.
full_data = torch.tensor(encode(text), dtype=torch.long)

# Hold out the first tenth of the tokens for validation; train on the rest.
val_size = len(full_data) // 10
val_data, train_data = full_data[:val_size], full_data[val_size:]

In [None]:
def block_data(data: Tensor, block_size=BLOCK_SIZE) -> tuple[Tensor, Tensor]:
    """Slice a token sequence into overlapping input/target windows.

    Row ``i`` of the inputs is ``data[i:i + block_size]`` and row ``i`` of the
    targets is the same window shifted one position to the right, i.e.
    ``data[i + 1:i + 1 + block_size]``.

    Args:
        data: Tensor of token ids, windowed along its first dimension.
        block_size: Length of every window; must be at least 1.

    Returns:
        ``(x, y)``, each with ``len(data) - block_size - 1`` rows of length
        ``block_size``. The row count is kept as it always was, which is one
        fewer than the maximum number of window pairs ``data`` could provide.
        Both results are strided views into ``data`` (no per-window copies),
        so they must not be modified in place.

    Raises:
        ValueError: If ``block_size`` is smaller than 1 or ``data`` is too
            short to produce a single window pair.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")

    n_blocks = len(data) - block_size - 1
    if n_blocks <= 0:
        raise ValueError(
            f"data of length {len(data)} is too short for block_size={block_size}; "
            f"at least {block_size + 2} elements are required"
        )

    # unfold() exposes every length-`block_size` window as a view instead of
    # materialising one Python slice per position and stacking them. It puts
    # the window dimension last; movedim() moves it right after the row
    # dimension, matching what stacking the slices produced (a no-op for the
    # usual 1-D input).
    windows = data.unfold(0, block_size, 1).movedim(-1, 1)
    return windows[:n_blocks], windows[1:n_blocks + 1]

In [None]:
# Window the training split first, then the validation split, into
# (input, target) tensors.
(x_train, y_train), (x_val, y_val) = map(block_data, (train_data, val_data))

In [None]:
def batch_iterate(
    x: Tensor,
    y: Tensor,
    batch_size=BATCH_SIZE,
    device=device,
) -> Generator[tuple[Tensor, Tensor], None, None]:
    """Yield one shuffled pass over ``(x, y)`` in mini-batches.

    A fresh random ordering of ``len(y)`` indices is drawn from NumPy's global
    random state on every call. The same indices select rows from both ``x``
    and ``y``, and each selected batch is moved to ``device`` before being
    yielded. The final batch is shorter when ``len(y)`` is not a multiple of
    ``batch_size``.
    """
    n_samples = len(y)
    order = torch.from_numpy(np.random.permutation(n_samples))
    for start in range(0, n_samples, batch_size):
        chosen = order[start:start + batch_size]
        yield x[chosen].to(device), y[chosen].to(device)

In [None]:
@torch.no_grad()
def evaluate_loss(model: nn.Module, x: Tensor, y: Tensor, max_iters=100) -> float:
    """Estimate the mean cross-entropy loss of ``model`` on ``(x, y)``.

    Up to ``max_iters`` shuffled batches of ``BATCH_SIZE`` samples are drawn
    with ``batch_iterate``. Each batch loss is weighted by the number of
    samples in that batch, so a short final batch does not skew the average.
    The model is put into eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards.

    Args:
        model: Module called as ``model(bx)`` to obtain logits.
        x: Input rows.
        y: Target rows, aligned with ``x``.
        max_iters: Maximum number of batches to evaluate.

    Returns:
        The sample-weighted mean loss as a Python float.

    Raises:
        ValueError: If no batch was evaluated (empty data or
            ``max_iters <= 0``).
    """
    was_training = model.training
    model.eval()

    loss_sum = 0.0
    cnt = 0
    try:
        for i, (bx, by) in enumerate(batch_iterate(x, y, BATCH_SIZE)):
            if i >= max_iters:
                break

            # NOTE(review): assumes model(bx) returns a logits tensor in a
            # layout F.cross_entropy accepts together with `by` - confirm
            # against the model's forward() signature.
            logits = model(bx)
            batch_len = len(bx)
            loss_sum += F.cross_entropy(logits, by).item() * batch_len
            cnt += batch_len
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    if cnt == 0:
        raise ValueError("evaluate_loss: no batches were evaluated")
    return loss_sum / cnt

In [None]:
def get_lr(iter_num: int) -> float:
    """Return the scheduled learning rate for iteration ``iter_num``.

    The schedule has three phases:

    * ``iter_num < WARMUP_ITERS``: linear ramp from 0 towards ``MAX_LR``.
    * ``iter_num > LR_DECAY_ITERS``: constant ``MIN_LR``.
    * otherwise: half-cosine interpolation that starts at ``MAX_LR`` and ends
      at ``MIN_LR``.
    """
    in_warmup = iter_num < WARMUP_ITERS
    if in_warmup:
        return MAX_LR * iter_num / WARMUP_ITERS

    past_decay = iter_num > LR_DECAY_ITERS
    if past_decay:
        return MIN_LR

    # Fraction of the cosine phase completed so far, in [0, 1].
    progress = (iter_num - WARMUP_ITERS) / (LR_DECAY_ITERS - WARMUP_ITERS)
    assert 0 <= progress <= 1

    # Cosine weight goes from 1 (start of the phase) down to 0 (end).
    cosine_weight = 0.5 * (1.0 + np.cos(np.pi * progress))
    return MIN_LR + cosine_weight * (MAX_LR - MIN_LR)

In [None]:
# Build two models that share every hyper-parameter and differ only in the
# `is_causal` flag, each with its own optimizer. The causal model and its
# optimizer are created first, then the non-causal pair.
causal_config = DecoderTransformerConfig(
    vocab_size=vocab_size,
    block_size=BLOCK_SIZE,
    n_embd=512,
    n_head=4,
    n_layer=2,
    is_causal=True,
)
causal_model = DecoderTransformer(causal_config).to(device)
causal_optimizer = causal_model.configure_optimizers(
    learning_rate=MIN_LR,
    weight_decay=0.1,
    betas=(0.9, 0.99),
    device_type=device,
)

noncausal_config = DecoderTransformerConfig(
    vocab_size=vocab_size,
    block_size=BLOCK_SIZE,
    n_embd=512,
    n_head=4,
    n_layer=2,
    is_causal=False,
)
noncausal_model = DecoderTransformer(noncausal_config).to(device)
noncausal_optimizer = noncausal_model.configure_optimizers(
    learning_rate=MIN_LR,
    weight_decay=0.1,
    betas=(0.9, 0.99),
    device_type=device,
)

# Order matters downstream: index 0 is the causal pair, index 1 the non-causal.
models_and_optimizers = [
    (causal_model, causal_optimizer),
    (noncausal_model, noncausal_optimizer),
]

In [None]:
# Training state: global iteration counter, start time of the current timing
# window, and one loss history per model (index 0: causal, index 1: non-causal).
i, t0 = 0, time.time()
losses = [[] for _ in range(2)]

In [None]:
# Train both models in lock-step on identical batches so that their loss
# curves are directly comparable. Each pass of the outer loop starts a new,
# freshly shuffled sweep over the training data; training stops once
# MAX_ITERS optimisation steps have been taken.
while i < MAX_ITERS:
    for x, y in batch_iterate(x_train, y_train):
        if i >= MAX_ITERS:
            break

        # Both optimizers use the same scheduled learning rate this step.
        lr = get_lr(i)
        for k, (model, optimizer) in enumerate(models_and_optimizers):
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

            # NOTE(review): backward=True presumably makes the model run the
            # backward pass itself, since no loss.backward() is called here -
            # confirm in model.py.
            with context:
                _, loss = model(x, y, backward=True)

            optimizer.step()
            optimizer.zero_grad()

            # Store plain Python floats so that np.mean and plotting do not
            # depend on tensor-to-array conversion.
            losses[k].append(loss.item())

        if (i + 1) % EVAL_INTERVAL == 0:
            dt = time.time() - t0
            t0 = time.time()
            print(f"{f'[{i + 1}]':8}", end="")
            print(f"causal loss: {np.mean(losses[0][-EVAL_INTERVAL:]):.3f}", end=", ")
            # Index 1 is the non-causal model (this previously re-printed the
            # causal loss from index 0).
            print(f"noncausal loss: {np.mean(losses[1][-EVAL_INTERVAL:]):.3f}", end=", ")
            print(f"time: {dt:.1f}s")

        i += 1

In [None]:
# Plot the recorded per-iteration training loss of both models on one figure.
for loss_curve, curve_label in zip(losses, ("causal", "noncausal")):
    plt.plot(loss_curve, label=curve_label)
plt.ylabel("Loss")
plt.xlabel("Iteration")
plt.legend()
plt.show()