In [None]:
from contextlib import nullcontext
import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from model import DecoderTransformer, DecoderTransformerConfig
from shakespeare_dataset import ShakespeareDataset

# Select the compute backend in priority order: CUDA, then Apple MPS, falling back to CPU.
# The MPS probe is wrapped in a lambda so it is only evaluated when CUDA is unavailable.
device = next(
    (
        backend
        for backend, probe in (
            ("cuda", torch.cuda.is_available),
            ("mps", lambda: torch.backends.mps.is_available()),
        )
        if probe()
    ),
    "cpu",
)

# Mixed-precision context for the forward pass; on MPS a no-op context is used instead of autocast.
if device == "mps":
    context = nullcontext()
else:
    context = torch.autocast(device)

print(f"using {device} device")

In [None]:
OUT_DIR = "out"

SEED = 1337
# Seed torch's global RNG so model initialization and (by default) DataLoader shuffling are reproducible.
torch.manual_seed(SEED)

MAX_ITERS = 1000
EVAL_INTERVAL = 1  # print averaged training losses every EVAL_INTERVAL iterations

BLOCK_SIZE = 128  # sequence length passed to both the dataset and the model config
BATCH_SIZE = 256

# Learning-rate schedule used by get_lr: linear warmup from 0 to MAX_LR over WARMUP_ITERS steps,
# cosine decay from MAX_LR to MIN_LR until LR_DECAY_ITERS, then held at MIN_LR.
# These two values were previously swapped (MIN_LR > MAX_LR), which turned the "decay" into a ramp-up.
MIN_LR = 1e-5
MAX_LR = 1e-4
assert MIN_LR <= MAX_LR, "MIN_LR must not exceed MAX_LR"
WARMUP_ITERS = MAX_ITERS // 100
LR_DECAY_ITERS = MAX_ITERS - WARMUP_ITERS

In [None]:
# Training corpus, cut into BLOCK_SIZE-length examples by the project's dataset class,
# wrapped in a loader that reshuffles the examples on every pass.
dataset = ShakespeareDataset(
    "data/shakespeare.txt",
    block_size=BLOCK_SIZE,
)
data_loader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

In [None]:
@torch.no_grad()
def evaluate_loss(
    model: "DecoderTransformer",
    data_loader: DataLoader,
    max_iters=100,
    device=None,
) -> float:
    """Average the model's loss over up to ``max_iters`` batches, weighted by batch size.

    The model is put in eval mode for the duration of the call, so train-only behavior
    (e.g. dropout, if the model has any) is disabled. Its previous train/eval mode is
    restored afterwards, even if an exception is raised.

    Args:
        model: called as ``model(x, y)``; must return ``(_, loss)`` where ``loss`` is a
            scalar tensor (presumably the per-batch mean loss — confirm in model.py).
        data_loader: iterable of ``(x, y)`` batches.
        max_iters: maximum number of batches to evaluate.
        device: where each batch is moved. Defaults to the device of the model's first
            parameter, or ``"cpu"`` if the model has no parameters.

    Returns:
        Batch-size-weighted mean loss as a Python float.

    Raises:
        ValueError: if no batches were evaluated (empty loader or ``max_iters <= 0``).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    was_training = model.training
    model.eval()
    try:
        loss_sum = 0.0
        cnt = 0
        for i, (x, y) in enumerate(data_loader):
            if i >= max_iters:
                break

            x = x.to(device)
            y = y.to(device)
            _, loss = model(x, y)
            # Weight by batch size so a smaller final batch does not skew the average.
            loss_sum += loss.item() * len(x)
            cnt += len(x)
    finally:
        model.train(was_training)

    if cnt == 0:
        raise ValueError(
            "evaluate_loss: no batches were evaluated (empty data_loader or max_iters <= 0)"
        )
    return loss_sum / cnt

In [None]:
def get_lr(
    iter_num: int,
    *,
    warmup_iters: "int | None" = None,
    decay_iters: "int | None" = None,
    min_lr: "float | None" = None,
    max_lr: "float | None" = None,
) -> float:
    """Learning rate for step ``iter_num``: linear warmup, then cosine decay, then a floor.

    - ``iter_num < warmup_iters``: linear ramp ``max_lr * iter_num / warmup_iters``
      (so step 0 uses a learning rate of 0).
    - ``warmup_iters <= iter_num <= decay_iters``: cosine decay from ``max_lr`` to ``min_lr``.
    - ``iter_num > decay_iters``: constant ``min_lr``.

    Each keyword argument defaults to the corresponding notebook constant
    (WARMUP_ITERS, LR_DECAY_ITERS, MIN_LR, MAX_LR) when left as ``None``.
    """
    warmup_iters = WARMUP_ITERS if warmup_iters is None else warmup_iters
    decay_iters = LR_DECAY_ITERS if decay_iters is None else decay_iters
    min_lr = MIN_LR if min_lr is None else min_lr
    max_lr = MAX_LR if max_lr is None else max_lr

    if iter_num < warmup_iters:
        return float(max_lr * iter_num / warmup_iters)

    if iter_num > decay_iters:
        return float(min_lr)

    span = decay_iters - warmup_iters
    if span <= 0:
        # Degenerate window (iter_num == warmup_iters == decay_iters): the decay has not
        # started yet, so return the peak instead of dividing by zero.
        return float(max_lr)

    decay_ratio = (iter_num - warmup_iters) / span
    assert 0.0 <= decay_ratio <= 1.0
    coeff = 0.5 * (1.0 + np.cos(np.pi * decay_ratio))  # goes 1 -> 0 as decay_ratio goes 0 -> 1
    return float(min_lr + coeff * (max_lr - min_lr))

In [None]:
def _build_model_and_optimizer(is_causal: bool):
    """Build one DecoderTransformer with the shared hyperparameters, plus its optimizer.

    The two models compared in this notebook differ only in ``is_causal``.
    Returns a ``(config, model, optimizer)`` tuple.
    """
    config = DecoderTransformerConfig(
        block_size=BLOCK_SIZE,
        vocab_size=dataset.vocab_size,
        n_layer=2,
        n_head=4,
        n_embd=512,
        is_causal=is_causal,
    )
    model = DecoderTransformer(config).to(device)
    # learning_rate is only an initial value: the training loop overwrites every
    # param group's "lr" each step using get_lr.
    optimizer = model.configure_optimizers(
        weight_decay=0.1,
        learning_rate=MIN_LR,
        betas=(0.9, 0.99),
        device_type=device,
    )
    return config, model, optimizer


causal_config, causal_model, causal_optimizer = _build_model_and_optimizer(is_causal=True)
noncausal_config, noncausal_model, noncausal_optimizer = _build_model_and_optimizer(is_causal=False)

# Order matters: index 0 is causal, index 1 is noncausal (mirrored by `losses`).
models_and_optimizers = [(causal_model, causal_optimizer), (noncausal_model, noncausal_optimizer)]

In [None]:
# Mutable training-loop state, kept in its own cell so training can be resumed by
# re-running only the loop cell. losses[k] is the loss history of models_and_optimizers[k].
losses = [[] for _ in range(2)]
t0 = time.time()
i = 0

In [None]:
# Train both models in lockstep on identical batches so their loss curves are directly comparable.
if len(data_loader) == 0:
    # Without this guard the outer while-loop would spin forever (i never advances).
    raise ValueError("data_loader yields no batches; training cannot make progress")

model_names = ("causal", "noncausal")  # same order as models_and_optimizers and losses

while i < MAX_ITERS:
    for x, y in data_loader:
        if i >= MAX_ITERS:
            break

        x = x.to(device)
        y = y.to(device)

        lr = get_lr(i)
        for k, (model, optimizer) in enumerate(models_and_optimizers):
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

            with context:
                # NOTE(review): there is no explicit loss.backward() here, so backward=True
                # presumably makes the model run the backward pass itself — confirm in model.py.
                # Also, CUDA autocast defaults to float16 and no GradScaler is used; verify
                # gradients do not underflow.
                _, loss = model(x, y, backward=True)

            optimizer.step()
            optimizer.zero_grad()

            # Store a plain Python float: it works directly with np.mean / plt.plot,
            # whatever dtype autocast produced for the loss tensor.
            losses[k].append(loss.item())

        if (i + 1) % EVAL_INTERVAL == 0:
            dt = time.time() - t0
            t0 = time.time()
            print(f"{f'[{i + 1}]':8}", end="")
            # Previously both entries printed losses[0]; each model now reports its own history.
            for name, history in zip(model_names, losses):
                print(f"{name} loss: {np.mean(history[-EVAL_INTERVAL:]):.3f}", end=", ")
            print(f"time: {dt:.1f}s")

        i += 1

In [None]:
# Per-iteration training loss of both models (losses[0] = causal, losses[1] = noncausal).
fig, ax = plt.subplots()
ax.plot(losses[0], label="causal")
ax.plot(losses[1], label="noncausal")
ax.set(
    title="Training loss: causal vs. noncausal model",
    xlabel="Iteration",
    ylabel="Loss",
)
ax.legend()
plt.show()