# Mini GPT with Tiny Shakespeare - Karpathy Tutorial

This notebook is a thoroughly documented and commented version of mini GPT, a decoder-only Transformer model. We aim to show that even a small model like this can already be quite powerful. First and foremost, though, the goal is learning.

In [13]:
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

torch.manual_seed(0)

<torch._C.Generator at 0x106ee1630>

In [14]:
batch_size = 64 # How many independent sequences will we process in parallel?
block_size = 256 # What is the maximum context length for predictions?
max_iters = 5000 # The total number of training iterations.
eval_interval = 500 # How often to evaluate the model's performance.
learning_rate = 3e-4 # The step size for our optimizer.
eval_iters = 200 # Number of batches to average for loss estimation.
n_embd = 384 # The dimensionality of the token embeddings.
n_head = 6 # The number of attention heads.
n_layer = 6 # The number of transformer blocks.
dropout = 0.2 # The probability of dropping out neurons during training.

# Automatically select the best available device (CUDA, MPS, or CPU).
device = None
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using device: {device}")

Using device: mps


## Data Loading and Preparation

We'll use the Tiny Shakespeare dataset. We first need to load the text and create a vocabulary of all unique characters. Then, we'll create functions to encode a string into a sequence of integers (tokens) and decode a sequence of tokens back into a string.

In [15]:
# You may need to download the data first. The -O flag saves it under the filename opened below.
# !wget -O tiny-shakespeare.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('tiny-shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Find all unique characters in the text.
chars = sorted(list(set(text)))
vocab_size = len(chars)

# Create a mapping from characters to integers (stoi) and vice-versa (itos).
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # Encoder: takes a string, outputs a list of integers.
decode = lambda l: ''.join([itos[i] for i in l]) # Decoder: takes a list of integers, outputs a string.

# Convert the entire dataset into a tensor of tokens.
data = torch.tensor(encode(text), dtype=torch.long)

# Split the data into training (90%) and validation (10%) sets.
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]
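
As a quick check of the encode/decode idea, here is a minimal sketch on a toy string. The names `toy_text`, `toy_encode`, and `toy_decode` are made up for illustration and are not part of the notebook; the construction mirrors the cell above.

```python
# Toy corpus standing in for the Tiny Shakespeare text (an assumption for illustration).
toy_text = "hello world"

# Same construction as above: the sorted unique characters define the vocabulary.
toy_chars = sorted(set(toy_text))
toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}
toy_itos = {i: ch for i, ch in enumerate(toy_chars)}

def toy_encode(s):
    """Map a string to a list of integer token ids."""
    return [toy_stoi[c] for c in s]

def toy_decode(ids):
    """Map a list of integer token ids back to a string."""
    return "".join(toy_itos[i] for i in ids)

ids = toy_encode("hello")
print(toy_chars)        # [' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w']
print(ids)              # [3, 2, 4, 4, 5]
print(toy_decode(ids))  # hello
```

Encoding followed by decoding returns the original string, which is the property the real `encode` and `decode` rely on.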

## Data Batching

This function, `get_batch`, generates a small, random batch of data. For each sequence in the batch, the input `x` is a chunk of text, and the target `y` is the same chunk shifted by one character. This is how the model learns to predict the next character.

In [16]:
def get_batch(split):
    # Select the appropriate dataset (train or val).
    data = train_data if split == 'train' else val_data
    # Generate random starting indices for the batches.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    # Create the input sequences (x).
    x = torch.stack([data[i:i+block_size] for i in ix])
    # Create the target sequences (y), which are shifted by one position.
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    # Move the data to the selected device.
    x, y = x.to(device), y.to(device)
    return x, y
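
To make the one-character shift concrete, the sketch below uses a toy tensor and a fixed start index instead of a random one. `toy_data`, `toy_block_size`, and `i` are hypothetical values chosen for illustration.

```python
import torch

toy_data = torch.arange(10)   # stand-in for the encoded text: 0, 1, ..., 9
toy_block_size = 4            # hypothetical small context length
i = 3                         # fixed start index instead of a random one

x = toy_data[i : i + toy_block_size]          # tensor([3, 4, 5, 6])
y = toy_data[i + 1 : i + toy_block_size + 1]  # tensor([4, 5, 6, 7])

# Each position t yields one training example: predict y[t] from x[:t+1].
for t in range(toy_block_size):
    context = x[: t + 1].tolist()
    target = y[t].item()
    print(f"context {context} -> target {target}")
```

A single chunk of length `block_size` therefore packs `block_size` training examples, one per context length.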

## Loss Estimation

To avoid noisy loss measurements, we estimate the loss by averaging it over multiple batches. This function is decorated with `@torch.no_grad()` to tell PyTorch not to calculate gradients, which saves memory and computation since we're only evaluating, not training.

In [17]:
@torch.no_grad()
def estimate_loss():
    out = {}
    # Set the model to evaluation mode.
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    # Set the model back to training mode.
    model.train()
    return out

## The Transformer Model: A Deep Dive

Now we'll build the Transformer model, piece by piece.

### Self-Attention Head

Self-attention is the core mechanism of the Transformer. It lets tokens exchange information and weigh how relevant the other tokens are. Each token produces a **Query** (what I'm looking for), a **Key** (what I contain), and a **Value** (what I'll communicate). The attention score is the dot product of a token's Query with the Key of every token in the context, itself included. The scores are scaled by $1/\sqrt{d_k}$ and passed through a softmax to obtain weights, and the output is the weighted sum of the tokens' Values. In a decoder, a causal mask additionally restricts each token to itself and earlier positions.

The mathematical formula is: $$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

In [18]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        # Linear projections for Key, Query, and Value.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # A buffer for the triangular mask, not a model parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)

        # Compute attention scores ("affinities").
        # The dot product between queries and keys determines the attention weights.
        # We scale by 1/sqrt(d_k), with d_k = head_size, so the scores keep roughly unit variance and the softmax does not saturate.
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)

        # Apply the causal mask to prevent tokens from attending to future tokens.
        # This is crucial for a decoder-style language model.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        # Normalize the scores to get weights.
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)

        # Perform the weighted aggregation of the values.
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out
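
The steps in `forward` can be traced on a tiny example. The sketch below is an illustration under simplifying assumptions: it uses random `q`, `k`, and `v` tensors in place of the learned projections, skips dropout, and picks toy sizes `T` and `hs`.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
T, hs = 4, 8                 # toy sequence length and head size (assumptions)
q = torch.randn(1, T, hs)    # random stand-ins for the learned projections
k = torch.randn(1, T, hs)
v = torch.randn(1, T, hs)

# Scaled dot-product scores, shape (1, T, T).
scores = q @ k.transpose(-2, -1) * hs**-0.5

# Causal mask: position t may only attend to positions <= t.
mask = torch.tril(torch.ones(T, T))
scores = scores.masked_fill(mask == 0, float("-inf"))

# Softmax turns each row into weights that sum to 1; masked entries become 0.
weights = F.softmax(scores, dim=-1)
out = weights @ v            # (1, T, hs)

print(weights[0])
```

The printed matrix is lower triangular. The first row has a single non-zero weight, so the first token's output is just its own value vector.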

### Multi-Head Attention

Instead of a single attention mechanism, Transformers use multiple attention "heads" in parallel. Each head can learn to focus on different types of relationships between tokens. The outputs of all heads are concatenated and projected back to the original embedding dimension.

In [19]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        # Create a list of attention heads.
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # A linear layer to project the concatenated head outputs.
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Concatenate the outputs of each head along the last dimension.
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        # Project the result back to the embedding dimension.
        out = self.dropout(self.proj(out))
        return out
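
The shape bookkeeping behind the concatenation can be sketched as follows. Random tensors stand in for the per-head outputs, and the `_demo` names are hypothetical; the sizes match the hyperparameters set at the top of the notebook.

```python
import torch

n_embd_demo, n_head_demo = 384, 6
head_size_demo = n_embd_demo // n_head_demo   # 384 // 6 = 64

B, T = 2, 5                                   # toy batch size and sequence length
# Random stand-ins for the output of each attention head, each (B, T, head_size).
head_outputs = [torch.randn(B, T, head_size_demo) for _ in range(n_head_demo)]

# Concatenating along the last dimension restores the embedding width.
concat = torch.cat(head_outputs, dim=-1)
print(head_size_demo, concat.shape)
```

Because `head_size * num_heads` equals `n_embd` here, the projection layer maps from 384 to 384 dimensions.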

### Feed-Forward Network

After the attention mechanism gathers information, a simple feed-forward network processes this information for each token independently. It consists of two linear layers with a ReLU activation in between. This allows the model to perform more complex computations on the aggregated information.

In [20]:
class FeedFoward(nn.Module):
    """ two linear layers with a ReLU non-linearity in between """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd), # The inner layer is typically 4x larger.
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd), # Project back to the embedding dimension.
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

### Transformer Block

A Transformer block combines multi-head attention and a feed-forward network. It also includes two important features: residual connections and layer normalization. 

- **Residual Connections**: The input to a sub-layer (like attention) is added to its output. This helps prevent the vanishing gradient problem in deep networks.
- **Layer Normalization**: This stabilizes the training by normalizing the features for each token across the embedding dimension.

In [21]:
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # The input is first normalized, then passed through self-attention.
        # A residual connection adds the original input to the attention output.
        x = x + self.sa(self.ln1(x))
        # The result is normalized again and passed through the feed-forward network.
        # Another residual connection is applied.
        x = x + self.ffwd(self.ln2(x))
        return x
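
What layer normalization does to each token can be checked numerically. This is a small sketch with toy shapes; a freshly constructed `nn.LayerNorm` has scale 1 and shift 0, so its output is the plain normalization.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(2, 3, 16) * 5 + 10   # (B, T, C) with a large mean and scale
ln = nn.LayerNorm(16)                # normalizes over the last (embedding) dimension
y = ln(x)

# Statistics are computed per token, across the embedding dimension.
print(y.mean(dim=-1))                 # close to 0 for every token
print(y.var(dim=-1, unbiased=False))  # close to 1 for every token
```

Each of the `B * T` tokens is normalized independently, so the result does not depend on the other sequences in the batch.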

### GPT Model

Finally, we assemble all the components into the full GPT model. This includes:

- **Token Embedding Table**: Converts input token indices into dense vectors (embeddings).
- **Positional Embedding Table**: Self-attention on its own has no notion of token order (it is permutation-equivariant: shuffling the input tokens just shuffles the outputs), so we add positional embeddings to give the model information about the order of tokens.
- **A Sequence of Transformer Blocks**: The core of the model where the processing happens.
- **A Final Layer Norm and Linear Head**: To produce the final output logits over the vocabulary.
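
The need for positional embeddings can be illustrated on a toy example. The sketch below is a simplification: `plain_attention` is a hypothetical helper that uses unmasked attention with identity projections, not the `Head` class from this notebook.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
T, C = 5, 8
x = torch.randn(T, C)   # T token vectors without any positional information

def plain_attention(x):
    # Unmasked self-attention with identity projections (a simplification).
    w = F.softmax(x @ x.T * x.shape[-1] ** -0.5, dim=-1)
    return w @ x

perm = torch.randperm(T)
out_then_perm = plain_attention(x)[perm]   # attend first, then shuffle the outputs
perm_then_out = plain_attention(x[perm])   # shuffle the tokens first, then attend

same = torch.allclose(out_then_perm, perm_then_out, atol=1e-5)
print(same)
```

Both orders give the same result up to floating-point error, so this layer cannot tell where a token sits in the sequence unless position is added to its embedding.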

In [22]:
class GPT(nn.Module):

    def __init__(self):
        super().__init__()
        # Token embedding table: maps each token index to an n_embd-dimensional vector.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # A sequence of transformer blocks.
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # Final layer norm.
        self.lm_head = nn.Linear(n_embd, vocab_size) # The head that produces logits.

        self.apply(self._init_weights)

    def _init_weights(self, module):
        # A common practice for initializing weights in transformer models.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # Get token and position embeddings.
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T,C)
        # Add them together to give the model positional information.
        x = tok_emb + pos_emb # (B,T,C)
        # Pass through the transformer blocks.
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        # Get the final logits.
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            # Reshape logits and targets for the cross-entropy loss function.
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

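    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        # This method generates new tokens autoregressively.
        # Gradients are not needed for sampling, so the decorator disables them.
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens, because the position
            # embedding table only has block_size entries.
            idx_cond = idx[:, -block_size:]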
            # Get the predictions for the next token.
            logits, loss = self(idx_cond)
            # Focus only on the last time step's logits.
            logits = logits[:, -1, :] # becomes (B, C)
            # Apply softmax to get probabilities.
            probs = F.softmax(logits, dim=-1) # (B, C)
            # Sample from the distribution to get the next token.
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # Append the sampled index to the running sequence.
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
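
The reshape before `F.cross_entropy` in `forward` is there because the function expects the class dimension second. As a small sketch with toy sizes, flattening to `(B*T, C)` gives the same loss as moving the class dimension with `permute`, which `F.cross_entropy` also accepts.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
B, T, C = 2, 3, 5                      # toy batch, time, and vocabulary sizes
logits = torch.randn(B, T, C)
targets = torch.randint(C, (B, T))

# Variant used in the model: flatten batch and time into one dimension.
loss_flat = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))

# Equivalent variant: put the class dimension second, shape (B, C, T).
loss_perm = F.cross_entropy(logits.permute(0, 2, 1), targets)

print(loss_flat.item(), loss_perm.item())
```

Either way the loss is the mean negative log-likelihood over all `B * T` positions.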

## Model Training

Now we can instantiate the model and the optimizer. We'll use the AdamW optimizer, which is a standard choice for training Transformers. The training loop will repeatedly sample a batch of data, calculate the loss, and update the model's parameters.

In [23]:
model = GPT()
m = model.to(device)
# Print the number of parameters in the model.
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# Create a PyTorch optimizer.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

10.788929 M parameters
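
The printed count can be reproduced by hand from the hyperparameters. The sketch below assumes a vocabulary of 65 characters, which is not printed anywhere in this notebook; the remaining values come from the hyperparameter cell.

```python
vocab = 65  # assumption: number of unique characters in Tiny Shakespeare
n_embd, n_head, n_layer, block_size = 384, 6, 6, 256
head_size = n_embd // n_head

# Token and position embedding tables.
embeddings = vocab * n_embd + block_size * n_embd

# One transformer block.
attn_heads = n_head * 3 * n_embd * head_size   # key/query/value projections, no bias
attn_proj = n_embd * n_embd + n_embd           # output projection with bias
feed_forward = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)
layer_norms = 2 * 2 * n_embd                   # two LayerNorms, weight and bias each
per_block = attn_heads + attn_proj + feed_forward + layer_norms

# Final LayerNorm plus the language-model head.
final = 2 * n_embd + (n_embd * vocab + vocab)

total = embeddings + n_layer * per_block + final
print(total / 1e6, "M parameters")
```

Under that assumption the total is 10,788,929, which agrees with the count printed above. Almost all parameters sit in the six blocks, and the `tril` buffers are not counted because they are not parameters.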


In [None]:
for iter in tqdm(range(max_iters)):

    # Every once in a while, evaluate the loss on train and val sets.
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # Sample a batch of data.
    xb, yb = get_batch('train')

    # Evaluate the loss.
    logits, loss = model(xb, yb)
    # Reset gradients from the previous iteration.
    optimizer.zero_grad(set_to_none=True)
    # Compute gradients for this batch (backpropagation).
    loss.backward()
    # Update the model's parameters.
    optimizer.step()

step 0: train loss 4.2134, val loss 4.2147


  2%|▏         | 122/5000 [02:49<58:15,  1.40it/s]  
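
A useful sanity check on the step 0 loss: an untrained model should predict roughly uniformly over the vocabulary, which gives a cross-entropy of about $\ln(\text{vocab size})$. The sketch assumes 65 unique characters, consistent with the parameter count but not printed in this notebook.

```python
import math

vocab = 65  # assumption: number of unique characters in Tiny Shakespeare

# Cross-entropy of a uniform prediction: -log(1 / vocab) = log(vocab).
uniform_loss = -math.log(1.0 / vocab)
print(f"{uniform_loss:.4f}")  # 4.1744
```

The reported losses of about 4.21 at step 0 are close to this value, which suggests the initialization is reasonable.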

## Text Generation

After training, we can use our model to generate new text. We start with a single token (a newline character in this case) and let the model predict the next token, which we then feed back into the model to predict the next one, and so on.

In [None]:
# Switch to evaluation mode so that dropout is disabled while sampling.
m.eval()
# Start generation with a single token (0 is the newline character).
context = torch.zeros((1, 1), dtype=torch.long, device=device)
# Generate and decode the output.
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))
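
The autoregressive loop can be tried without a trained model. In the sketch below, `toy_logits` is a hypothetical stand-in that returns random logits, and `toy_vocab` and `toy_block` are made-up sizes. Only the cropping, sampling, and appending steps mirror `generate`.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
toy_vocab, toy_block = 5, 4   # hypothetical vocabulary size and context length

def toy_logits(idx):
    # Stand-in for the trained model: random logits of shape (B, T, vocab).
    B, T = idx.shape
    return torch.randn(B, T, toy_vocab)

idx = torch.zeros((1, 1), dtype=torch.long)   # start from token 0
for _ in range(6):
    idx_cond = idx[:, -toy_block:]            # keep at most toy_block tokens of context
    logits = toy_logits(idx_cond)[:, -1, :]   # logits for the last time step only
    probs = F.softmax(logits, dim=-1)         # convert logits to probabilities
    idx_next = torch.multinomial(probs, num_samples=1)  # sample one token id
    idx = torch.cat((idx, idx_next), dim=1)   # append it to the running sequence

print(idx.shape)   # torch.Size([1, 7])
```

The sequence grows by one token per step, while the context fed to the model never exceeds `toy_block` tokens.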