## Introduction & Overview

This notebook builds nanoGPT piece-by-piece, trains on the entire works of Shakespeare, and generates some
text-completion output.

In [1]:
"""Set notebook settings and imports."""

%load_ext autoreload
%autoreload 2

import json
import time
from pathlib import Path
from warnings import warn

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm
from torch.utils.data import TensorDataset, DataLoader, random_split
from tqdm import tqdm

In [None]:
""" Set environment variables for CUDA debugging."""

# Only uncomment to force synchronous CUDA operations for debugging
# %env CUDA_LAUNCH_BLOCKING=1

In [2]:
"""Set torch device."""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device=}\n")
if torch.cuda.is_available():
    print(f"Current GPU device: {torch.cuda.get_device_name(device)}\n")
    n_devices = torch.cuda.device_count()
    print("All GPU devices:")
    for i in range(n_devices):
        print(f"Device {i}: {torch.cuda.get_device_name(i)}")

device=device(type='cuda')

Current GPU device: NVIDIA GeForce RTX 3090

All GPU devices:
Device 0: NVIDIA GeForce RTX 3090
Device 1: NVIDIA GeForce RTX 2060


In [3]:
"""Read in data."""

filepath = Path.cwd() / "data/tiny_shakespeare.txt"
with open(filepath) as f:
    text = f.read()

In [4]:
"""View some info about text."""

n_chars = len(text)
tokens = sorted(set(text))
vocab_sz = len(tokens)

print(f"{n_chars=}")
print(f"\n{vocab_sz=}")
print(f"\nTokens: {''.join(tokens)}")
print(f"\nFirst 100 chars of text:\n---\n\n{text[:100]}\n\n---")

n_chars=1115394

vocab_sz=65

Tokens: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz

First 100 chars of text:
---

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You

---


In [5]:
"""Tokenize and save tokens to a file."""

token_to_int = {t: i for i, t in enumerate(tokens)}
int_to_token = {i: t for t, i in token_to_int.items()}
encode = lambda tokens: [token_to_int[t] for t in tokens]
decode = lambda ints: "".join([int_to_token[i] for i in ints])

# Example
print(encode("Hello, world!"))
print(decode(encode("Hello, world!")))

# Encode entire text dataset
data = torch.tensor(encode(text), dtype=torch.long)

# Save tokens to file
filepath = Path.cwd() / "data/tiny_shakespeare_tokens.txt"
with open(filepath, "w") as f:
    for tok in tokens:
        f.write("%s" % tok)

[20, 43, 50, 50, 53, 6, 1, 61, 53, 56, 50, 42, 2]
Hello, world!


In [6]:
"""Motivating self-attention."""

# We can get weighted aggregations of past elements by matmul of `x` by a lower triangular matrix (weights)

ctx_len = 4  # context length

x = torch.randn(ctx_len, ctx_len)  # input sequence
w = torch.zeros(ctx_len, ctx_len)  # attention weights
tril = torch.tril(torch.ones(ctx_len, ctx_len))  # triangular mask for our weights
w = w.masked_fill(tril == 0, float("-inf"))  # mask out upper triangle (we can't access future info)
print(f"{w=}")
w = F.softmax(w, dim=1)  # convert weight values to probs
print(f"{w=}")
print(f"{x=}")
attn_out = w @ x  # weighted aggregation (sum) of past elements (can think of this as self-attn output)
print(f"{attn_out=}")

w=tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])
w=tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500]])
x=tensor([[ 0.1234,  0.0689, -1.8504, -0.5551],
        [ 1.0161,  0.4646,  0.3244, -0.2380],
        [-0.0536,  1.3074,  0.8214, -1.3427],
        [ 0.9463,  1.3927,  0.7936, -0.6684]])
attn_out=tensor([[ 0.1234,  0.0689, -1.8504, -0.5551],
        [ 0.5698,  0.2668, -0.7630, -0.3965],
        [ 0.3620,  0.6137, -0.2349, -0.7119],
        [ 0.5081,  0.8084,  0.0222, -0.7010]])


In [7]:
"""Self-attention for a Decoder."""

# (Note: here we use a mask to prevent the Decoder from looking ahead (into future) in the sequence.
# This is not necessary for an Encoder.)

# <s Create layers for computing Query, Key, Value tensors
batch_sz, ctx_len, emb_dim = 4, 4, 8
# The head size is the dimensionality of the Query, Key, and Value tensors (previously we assumed
# this was just equal to `ctx_len`). Typically, we can think of `emb_dim` as getting fanned into
# a smaller `head_sz` dimensionality. The greater the `head_sz`, the more granular the information
# a head can capture.
head_sz = 2
# We can think of input `x` into query, key, and value layers with a rough analogy:
# For a given token, we can think of `x` as its private information, it's query as what it wants to know 
# about other tokens, the keys (of itself and all other tokens) as the summary of the tokens, and the 
# values as the detailed information / content about the tokens.
query = nn.Linear(emb_dim, head_sz, bias=False)
key = nn.Linear(emb_dim, head_sz, bias=False)
value = nn.Linear(emb_dim, head_sz, bias=False)
# (Note: we don't use biases for these layers because we want to ensure that the dot product of the
# query and key tensors is the only thing that determines the similarity between the query and key,
# to determine the weighting of values (i.e. the attention out))
# /s>

# <s Get Query, Key, Value tensors
x = torch.randn(batch_sz, ctx_len, emb_dim)
q = query(x)  # -> [batch_sz, ctx_len, head_sz]
k = key(x)  
v = value(x)
# /s>

# <s Get Key-Query attention weights
# First we compute the similarity between each token's query and all other tokens' keys as the dot 
# across `emb_dim` for each batch example over `ctx_len`, scaled by `head_sz` to preserve k, q variance.
k_q_sim = q @ k.transpose(2, 1) / np.sqrt(head_sz)
print(f"{k_q_sim.shape=}")
print(f"{k_q_sim=}")
tril = torch.tril(torch.ones(ctx_len, ctx_len))  # mask out upper triangle (we can't access future info)
k_q_sim = k_q_sim.masked_fill(tril == 0, float("-inf"))
attn_weights = F.softmax(k_q_sim, dim=2)
print(f"{attn_weights.shape=}")
print(f"{attn_weights=}")
# /s>
attn_out = attn_weights @ v  # weighted sum of values
print(f"{attn_out.shape=}")
print(f"{attn_out=}")

k_q_sim.shape=torch.Size([4, 4, 4])
k_q_sim=tensor([[[ 0.3487, -0.1135, -0.0213,  0.0302],
         [ 0.0026, -0.0246,  0.1513,  0.0830],
         [ 0.0861, -0.1113,  0.5255,  0.2975],
         [-0.3584,  0.1128,  0.0470, -0.0173]],

        [[ 0.4830,  0.3114,  0.3637, -0.4686],
         [ 0.0928,  0.0720,  0.0882, -0.1159],
         [ 0.0366,  0.1030,  0.1473, -0.2049],
         [-0.0879, -0.1774, -0.2481,  0.3428]],

        [[-0.1646,  0.0421,  0.0096,  0.0463],
         [-0.3919,  0.4353,  0.1338,  0.7327],
         [-0.1373, -0.4617, -0.1563, -0.8837],
         [-0.8719,  1.2711,  0.3978,  2.1920]],

        [[ 0.1248,  0.3058, -0.1433,  0.3543],
         [ 0.0079, -0.0271,  0.0883,  0.0300],
         [ 0.1354,  0.3058, -0.1007,  0.3888],
         [-0.2005, -0.4268,  0.0951, -0.5797]]], grad_fn=<DivBackward0>)
attn_weights.shape=torch.Size([4, 4, 4])
attn_weights=tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.5068, 0.4932, 0.0000, 0.0000],
         [0.2965, 0.2434, 0.4601

In [8]:
"""Create a class for a head of a self-attention unit for a Decoder."""

class Head(nn.Module):
    """Self-attention head."""

    def __init__(self, head_sz, emb_dim):
        """Initialize key, query, value."""
        super().__init__()
        self.head_sz, self.emb_dim = head_sz, emb_dim
        self.key = nn.Linear(emb_dim, head_sz, bias=False)
        self.query = nn.Linear(emb_dim, head_sz, bias=False)
        self.value = nn.Linear(emb_dim, head_sz, bias=False)

    def forward(self, x):
        """Compute self-attention output."""
        _batch_sz, ctx_len, _emb_dim = x.shape
        q = self.query(x)
        k = self.key(x)  # -> [batch_sz, ctx_len, head_sz]
        v = self.value(x)
        k_q_sim = q @ k.transpose(2, 1) / np.sqrt(self.head_sz)  # scaled attention to preserve k, q variance
        tril = torch.tril(torch.ones(ctx_len, ctx_len)).to(device)  # mask out upper triangle (we can't access future info)
        k_q_sim = k_q_sim.masked_fill(tril == 0, float("-inf"))
        attn_weights = F.softmax(k_q_sim, dim=2)
        attn_out = attn_weights @ v  # weighted sum of values
        # Note, if *not* using this in a MultiHead setting, we should project back to emb_dim
        #proj = nn.Linear(head_sz, emb_dim)
        #attn_out = proj(attn_out)
        return attn_out

In [9]:
"""Multi-head self-attention for a Decoder."""

# Multi-head attention is applying multiple self-attention heads in parallel, then concatenating their
# outputs, and projecting them back to `emb_dim`. This allows the model to attend to information from
# different subspaces (representations) of the input simultaneously.

n_heads = 2
heads = nn.ModuleList([Head(head_sz, emb_dim) for _ in range(n_heads)]).to(device)
attn_outs = [head(x.to(device)) for head in heads]  # -> n_heads x [batch_sz, ctx_len, head_sz]
attn_out = torch.cat(attn_outs, dim=2)  # -> [batch_sz, ctx_len, n_heads * head_sz]
proj = nn.Linear(n_heads * head_sz, emb_dim).to(device)  # project back to emb_dim
multi_attn_out = proj(attn_out)
print(f"{multi_attn_out.shape=}")

multi_attn_out.shape=torch.Size([4, 4, 8])


In [10]:
"""Create a class for multi-head self-attention for a Decoder."""

class MultiHead(nn.Module):
    """Multi-head self-attention."""

    def __init__(self, n_heads, head_sz, emb_dim):
        """Initialize heads."""
        super().__init__()
        self.n_heads, self.head_sz, self.emb_dim = n_heads, head_sz, emb_dim
        self.heads = nn.ModuleList([Head(head_sz, emb_dim) for _ in range(n_heads)])
        self.proj = nn.Linear(self.n_heads * self.head_sz, self.emb_dim)  # project back to `emb_dim`

    def forward(self, x):
        attn_outs = [head(x) for head in self.heads]
        attn_out = torch.cat(attn_outs, dim=2)  # concatenate across head dimension
        attn_out = self.proj(attn_out)
        return attn_out

In [11]:
"""Create a class for a Feedforward network that operates on attention outputs."""

# Self-attention operates on the entire input sequence at once, but feedforward layer(s) can be applied
# independently to each token's representation from the self-attention layer to allow the model to
# process and adjust features of each token individually, that might otherwise get diluted from the global
# attention mechanism.

class Feedforward(nn.Module):
    """Feedforward layer."""

    def __init__(self, emb_dim, ff_dim):
        """Initialize weights."""
        super().__init__()
        # Linear layer ReLU sandwich: dim fans out by factor of `ff_dim` and then back to `emb_dim`.
        # ("Position-wise Feed-Forward Networks" in "Attention is All You Need")
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, emb_dim * ff_dim), nn.ReLU(), nn.Linear(emb_dim * ff_dim, emb_dim)
        )

    def forward(self, x):
        return self.layers(x)

In [12]:
"""Create a Transformer block: communication -> computation (self-attention + feedforward + res + norm)."""

# Parts:
#  - Multi-head self-attention
#  - Position-wise feedforward network
#  - Residual connections
#  - Layer normalization (pre-norm formulation)
#  - ~ Weight normalization ~ (not for now)
#  - Dropout

class Block(nn.Module):
    """Transformer block: communication followed by computation."""

    def __init__(self, n_heads, head_sz, emb_dim, ff_dim, dropout):
        super().__init__()
        self.n_heads, self.head_sz, self.emb_dim, self.ff_dim = n_heads, head_sz, emb_dim, ff_dim
        self.self_attn_ln = nn.LayerNorm(emb_dim)  # layer norm pre self-attention
        self.self_attn = MultiHead(n_heads, head_sz, emb_dim)  # multi-head self-attention
        self.self_attn_dropout = nn.Dropout(dropout)  # dropout after self-attention
        self.ff_ln = nn.LayerNorm(emb_dim)  # layer norm pre feedforward
        self.ff = Feedforward(emb_dim, ff_dim)  # position-wise feedforward
        self.ff_dropout = nn.Dropout(dropout)  # dropout after feedforward

    def forward(self, x):
        # layer-norm -> self-attention -> dropout + residual
        x = x + self.self_attn_dropout(self.self_attn(self.self_attn_ln(x)))
        # layer-norm -> feedforward -> dropout + residual
        x = x + self.ff_dropout(self.ff(self.ff_ln(x)))
        return x

In [13]:
"""Create NanoGPT: Decoder-only Transformer."""

# In addition to our Transformer blocks, we need token embedding and positional embedding layers, to compute
# the positional encodings that get passed to the attention units in the transformer blocks.

# We'll also apply weight init.

# We want our output to be [batch_sz, ctx_len, n_tokens], because we want to predict the next token for each 
# token in the context.


class NanoGPT(nn.Module):
    """NanoGPT: Decoder-only Transformer."""

    def __init__(
        self,
        n_tokens=vocab_sz,
        ctx_len=512,
        n_blocks=8,
        n_heads=10,
        head_sz=64,
        emb_dim=512,
        ff_dim=4,
        dropout=0.1,
    ):
        super().__init__()
        (
            self.n_tokens,
            self.ctx_len,
            self.n_blocks,
            self.n_heads,
            self.head_sz,
            self.emb_dim,
            self.ff_dim,
        ) = (n_tokens, ctx_len, n_blocks, n_heads, head_sz, emb_dim, ff_dim)
        if (emb_dim / n_heads / head_sz) != 1:
            warn(f"Ratio of n_heads and head_sz to emb_dim ({emb_dim / n_heads / head_sz}) is not 1")
        self.tok_emb = nn.Embedding(n_tokens, emb_dim)  # to learn token embeddings
        self.pos_emb = nn.Embedding(ctx_len, emb_dim)  # to learn positional embeddings
        self.blocks = nn.Sequential(  # Transformer blocks
            *[Block(n_heads, head_sz, emb_dim, ff_dim, dropout) for _ in range(n_blocks)]
        )
        self.f_ln = nn.LayerNorm(emb_dim)  # final layer norm
        self.f_dropout = nn.Dropout(dropout)  # final dropout
        self.out = nn.Linear(emb_dim, n_tokens)
        self.apply(self.xavier_init)

    @staticmethod
    def xavier_init(module, gain=1):
        """Applies Xavier initialization to all linear and embedding layer weights."""
        if isinstance(module, nn.Linear) or isinstance(module, nn.Embedding):
            nn.init.xavier_normal_(module.weight, gain=gain)

    def forward(self, x):
        batch_sz, ctx_len = x.shape
        # Compute positional encodings
        tok_emb = self.tok_emb(x)  # -> [batch_sz, ctx_len, emb_dim]
        pos_emb = self.pos_emb.weight[0:ctx_len]  # -> [ctx_len, emb_dim]
        pos_enc = tok_emb + pos_emb  # -> [batch_sz, ctx_len, emb_dim]
        # Go through transformer blocks and final linear layer
        logits = self.out(self.f_dropout(self.f_ln(self.blocks(pos_enc))))
        return logits

In [14]:
"""Create a function for generating output from the model."""

def generate(model, tokens, in_txt=None, n_tokens=100, temp=1.0, top_k=None, seed=42, print_gen=True):
    """Generate text from a nanoGPT model."""
    # Set a random seed for generation
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Create token_to_int, int_to_token dicts.
    token_to_int = {t: i for i, t in enumerate(tokens)}
    int_to_token = {i: t for t, i in token_to_int.items()}

    # Process input_text if provided, else start with "\n".
    if in_txt is not None:
        # Convert input text to tokens and encode.
        encode = lambda tokens: [token_to_int[t] for t in tokens]
        in_tkns = encode(in_txt)
        input_len = len(in_tkns)
        # Initialize output starting with input text.
        x = torch.zeros((input_len + n_tokens,), dtype=torch.long).to(device)
        x[:input_len] = torch.tensor(in_tkns, dtype=torch.long).to(device)
    else:
        # Initialize output starting with "\n".
        x = torch.zeros((1 + n_tokens,), dtype=torch.long).to(device)
        x[0] = token_to_int["\n"]
        input_len = 1
    assert len(x) <= model.ctx_len, (
        f"Generated length {len(x) + n_tokens} would exceed model context length {model.ctx_len}."
    )

    # Run inference (generation) in eval mode
    model.eval()
    with torch.no_grad():
        first_gen_idx, last_gen_idx = input_len - 1, input_len + n_tokens - 1
        for i in range(first_gen_idx, last_gen_idx):  # start gen after `input_len`
            model_first_ctx = 0 if i < model.ctx_len else i - model.ctx_len + 1
            logits = model(x[model_first_ctx:(i + 1)].unsqueeze(0))  # feed in `x` with a batch_sz of 1
            # Get logits for just `len(tokens)` (squeeze out ctx_len), and scale by temp
            logits = logits[:, -1, :] / temp
            if top_k is not None:  # limit to top_k most likely tokens
                top_vals, top_idxs = logits.topk(top_k, dim=1)
                probs = F.softmax(top_vals, dim=1)  # compute top_k probs
                next_tkn_int = top_idxs.gather(1, torch.multinomial(probs, 1))  # sample top_k probs
            else:
                probs = F.softmax(logits, dim=1)  # compute probs for all tokens
                next_tkn_int = torch.multinomial(probs, 1)  # sample from probs
            x[i + 1] = next_tkn_int
            if print_gen:
                print(int_to_token[next_tkn_int.item()], end="")

    # Decode `x` and return it.
    decode = lambda ints: "".join([int_to_token[i] for i in ints])
    return decode(x.tolist())

In [15]:
"""Build model, view its layers and parameters, and sample generation."""

nanogpt = NanoGPT().to(device)

print(nanogpt)
n_params_tot = 0
for name, parameter in nanogpt.named_parameters():
    if not parameter.requires_grad:
        continue
    n_params = parameter.numel()
    print(f"{name=}: {n_params=}")
    n_params_tot += n_params
print(f"\n{n_params_tot / 1e6} M parameters total\n")

print("Generating sample...\n")
in_txt = (
    "Wherefore art thou, Romeo? "
    "We are such stuff as dreams are made on. " 
    "The course of true love never did run smooth."
)
gen = generate(nanogpt, tokens, in_txt=in_txt, n_tokens=200)

  warn(f"Ratio of n_heads and head_sz to emb_dim {emb_dim / n_heads / head_sz} is not 1")


NanoGPT(
  (tok_emb): Embedding(65, 512)
  (pos_emb): Embedding(512, 512)
  (blocks): Sequential(
    (0): Block(
      (self_attn_ln): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (self_attn): MultiHead(
        (heads): ModuleList(
          (0-9): 10 x Head(
            (key): Linear(in_features=512, out_features=64, bias=False)
            (query): Linear(in_features=512, out_features=64, bias=False)
            (value): Linear(in_features=512, out_features=64, bias=False)
          )
        )
        (proj): Linear(in_features=640, out_features=512, bias=True)
      )
      (self_attn_dropout): Dropout(p=0.1, inplace=False)
      (ff_ln): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ff): Feedforward(
        (layers): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
      )
      (ff_dropout): Dropout(p=0.1, inplace=Fa

In [17]:
"""Create a DataLoader."""

def build_dataset(txtfile, ctx_len):
    """Build dataset from text file."""
    with open(txtfile) as f:
        text = f.read()
    tokens = sorted(list(set(text)))
    token_to_int = {t: i for i, t in enumerate(tokens)}
    encode = lambda tokens: [token_to_int[t] for t in tokens]
    data = torch.tensor(encode(text), dtype=torch.long)
    n_chars = len(text)
    n_examples = n_chars - ctx_len
    idxs = torch.arange(ctx_len + 1).unsqueeze(0) + torch.arange(n_examples).unsqueeze(1)
    X, Y = data[idxs[:, :-1]], data[idxs[:, 1:]]
    return X, Y


X, Y = build_dataset(Path.cwd() / "data/tiny_shakespeare.txt", nanogpt.ctx_len)
dataset = TensorDataset(X, Y)
train_data, test_data, val_data = random_split(dataset, [0.9, 0.05, 0.05])
batch_sz = 32
train_loader = DataLoader(train_data, batch_size=batch_sz, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_sz, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_sz, shuffle=True)

In [18]:
"""Create a train function."""

def apply_gradient_centralization(optimizer):
    """Applies gradient centralization to the optimizer.

    This function should be called before optimizer.step() in the training loop.
    """
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                # Compute the mean of the gradient
                grad_mean = param.grad.data.mean(dim=tuple(range(1, len(param.grad.shape))), keepdim=True)
                # Centralize the gradient
                param.grad.data -= grad_mean


def train(
    model: nn.Module,  # model
    train_loader: DataLoader,  # batched dataset for training
    val_loader: DataLoader,  # batched dataset for validation
    optimizer: optim,  # optimizer
    loss_fn: nn.modules.loss,  # loss function
    max_epochs: int = 2,  # max n training epochs
    max_batches: int = 1e9,  # max n batches to train
    val_chk_interval: int = 200,  # check val loss every `val_chk_interval` batches and print losses
    val_iter: int = 5,  # number of batches on val_loader to run and avg when computing val loss
    patience_thresh: int = 1e9,  # consecutive batches without val loss decrease for early stopping
    save_chkpt_dir: str = "",  # dir to save model checkpoint
    save_chkpt_thresh: float = 0.5,  # save model checkpoint every `save_chkpt_interval` loss decrease
) -> tuple[torch.Tensor, np.ndarray, np.ndarray]:  # -> loss, train_losses, val_losses
    """Trains a model, returns loss."""
    # <s Nested helper functions to make `train` more readable.
    def print_losses(epoch, batch_i, train_losses_avg, val_losses_avg):
        """Print current average losses."""
        print(
            f"Epoch {epoch + 1}: Batch {batch_i + 1}:  "
            f"Loss = {train_losses_avg[-1]:.3f}, Val Loss = {val_losses_avg[-1]:.3f}"
        )

    @torch.no_grad()
    def estimate_losses(model, val_loader, val_losses, val_losses_avg, train_losses, train_losses_avg):
        """Estimate losses on val_loader, and return val loss and train loss avg."""
        model.eval()
        for val_i, (x_val, y_val) in enumerate(val_loader):
            logits = model(x_val.to(device))
            val_loss = loss_fn(logits.view(-1, n_tokens), y_val.to(device).view(-1))
            val_losses.append(val_loss.item())
            if val_i >= (val_iter - 1):
                break
        val_losses_avg.append(np.mean(val_losses[-val_iter:]))
        train_losses_avg.append(np.mean(train_losses[-val_chk_interval:]))
        model.train()
    # /s>

    # <s Trackers
    ctx_len, n_tokens  = model.ctx_len, model.n_tokens
    batch_sz, n_batches = train_loader.batch_size, len(train_loader)
    batch_lim = min(max_batches, n_batches * max_epochs)
    patience_thresh *= val_chk_interval  # convert to batches within model validation block
    train_losses, val_losses, train_losses_avg, val_losses_avg = [], [], [], []
    init_loss, best_val_loss = float("inf"), float("inf")
    patience_ct = 0
    # /s>

    # <s Training loop
    for epoch in range(max_epochs):
        pbar = tqdm(enumerate(train_loader), total=batch_lim, desc="Batch progression")  # tqdm progress bar
        for batch_i, (x_train, y_train) in pbar:
            # <ss Model training.
            optimizer.zero_grad()
            logits = model(x_train.to(device))  # -> [batch_sz, ctx_len, n_tokens], but...
            # must reshape to compare against batch_sz vector of targets for cross-entropy loss.
            loss = loss_fn(logits.view(-1, n_tokens), y_train.to(device).view(-1))
            loss.backward()
            apply_gradient_centralization(optimizer)
            optimizer.step()
            train_losses.append(loss.item())
            # /ss>
            # <ss Model validation.
            if val_chk_interval and batch_i % val_chk_interval == 0:
                # Estimate and print losses.
                estimate_losses(
                    model, val_loader, val_losses, val_losses_avg, train_losses, train_losses_avg
                )
                print_losses(epoch, batch_i, train_losses_avg, val_losses_avg)
                pbar.set_postfix_str(f"Total Batch {(batch_i + 1) * (epoch + 1)} / {batch_lim}")
                # Patience check for early stopping.
                patience_ct = (
                    0 if val_losses_avg[-1] < best_val_loss else patience_ct + val_chk_interval
                )
                best_val_loss = min(best_val_loss, val_losses_avg[-1])
                if patience_ct >= patience_thresh:
                    print("Early stopping.")
                    print_losses(epoch, batch_i, train_losses_avg, val_losses_avg)
                    return loss, train_losses_avg, val_losses_avg
            # Max batch check.
            if (batch_i + 1) * (epoch + 1) >= max_batches:
                print("Finished training:")
                print_losses(epoch, batch_i, train_losses_avg, val_losses_avg)
                return loss, train_losses_avg, val_losses_avg
            # Save checkpoint check.
            if (Path(save_chkpt_dir).exists()) and (init_loss - loss.item() > save_chkpt_thresh):
                torch.save(model.state_dict(), Path(save_chkpt_dir) / f"model_chkpt_loss{loss.item():.3f}.pth")
                init_loss = loss.item()
            # /ss> /s>

    print("Finished training:")
    print_losses(epoch, batch_i, train_losses_avg, val_losses_avg)
    return loss, train_losses_avg, val_losses_avg

In [19]:
"""Train and eval for just a few batches."""

save_chkpt_dir = Path.cwd() / "models/shakespeare_chkpts"
adam = torch.optim.AdamW(nanogpt.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
loss, train_losses, val_losses = train(
    nanogpt, train_loader, val_loader, adam, loss_fn, max_batches=1000, save_chkpt_dir=save_chkpt_dir
)

Batch progression:   0%|          | 1/1000 [00:01<29:12,  1.75s/it, Total Batch 1 / 1000]

Epoch 1: Batch 1:  Loss = 5.083, Val Loss = 5.909


Batch progression:  20%|██        | 201/1000 [01:30<09:18,  1.43it/s, Total Batch 201 / 1000]

Epoch 1: Batch 201:  Loss = 2.982, Val Loss = 2.504


Batch progression:  40%|████      | 401/1000 [02:59<06:46,  1.47it/s, Total Batch 401 / 1000]

Epoch 1: Batch 401:  Loss = 2.500, Val Loss = 2.436


Batch progression:  60%|██████    | 601/1000 [04:27<04:35,  1.45it/s, Total Batch 601 / 1000]

Epoch 1: Batch 601:  Loss = 2.413, Val Loss = 2.318


Batch progression:  80%|████████  | 801/1000 [05:56<02:19,  1.43it/s, Total Batch 801 / 1000]

Epoch 1: Batch 801:  Loss = 2.236, Val Loss = 2.036


Batch progression: 100%|█████████▉| 999/1000 [07:22<00:00,  2.26it/s, Total Batch 801 / 1000]

Finished training:
Epoch 1: Batch 1000:  Loss = 2.236, Val Loss = 2.036





In [20]:
"""Post-training generation."""

in_txt = (
    "Wherefore art thou, Romeo? "
    "We are such stuff as dreams are made on. "
    "The course of true love never did run smooth."
)

gen = generate(nanogpt, tokens, in_txt=in_txt, n_tokens=200, top_k=50, temp=0.5, seed=42)



KING RICHARD IIIII:
Why, hould come to him word the worldow,
The man so for the crion of your buther.

KING HENRY VI:
Not a more so the will the dough all the see.

GLOUCESTER:
Nor what more the con

In [21]:
"""Save / load model."""

# Save
torch.save(nanogpt.state_dict(), Path.cwd() / "models/nanogpt_shakespeare.pth")
with open(Path.cwd() / "models/nanogpt_shakespeare_config.json", "w") as f:
    json.dump(model_config, f)

# Wait a sec, then try and load
time.sleep(1)
with open(Path.cwd() / "models/nanogpt_shakespeare_config.json", "w") as f:
    model_config = json.load(f)
nanogpt = NanoGPT(
    n_tokens=model_config["n_tokens"],
    ctx_len=model_config["ctx_len"],
    n_blocks=model_config["n_blocks"],
    n_heads=model_config["n_heads"],
    head_sz=model_config["head_sz"],
    emb_dim=model_config["emb_dim"],
    ff_dim=model_config["ff_dim"],
    dropout=model_config["dropout"],
).to(device)
nanogpt.load_state_dict(torch.load((Path.cwd() / "models/nanogpt_shakespeare.pth")))

  warn(f"Ratio of n_heads and head_sz to emb_dim {emb_dim / n_heads / head_sz} is not 1")


<All keys matched successfully>