## Clean GPT2

# What is GPT-2?

## Introduction to GPT-2

[GPT-2](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf), which stands for Generative Pre-trained Transformer 2, was released by OpenAI in February 2019 and marked a pivotal moment in natural language processing (NLP) and the broader AI community. This transformer-based language model demonstrated unprecedented capabilities in generating coherent, human-like text, setting new standards for what AI could achieve in language understanding and generation.

## How GPT-2 Works
At its core, GPT-2 operates like a highly sophisticated text prediction system. Given a sequence of words, it predicts the most likely next word. The model achieves this through:

1. A transformer architecture that processes text using [attention](https://arxiv.org/pdf/1706.03762) mechanisms, allowing it to consider relationships between words regardless of their distance in the text
2. **Unsupervised learning** on a massive dataset of internet text, enabling it to capture patterns in human language without explicit training labels
3. A large-scale architecture (up to 1.5 billion parameters) that can store and utilize complex language patterns.

- Transformer-based: GPT-2 is built upon the [Transformer architecture](https://arxiv.org/pdf/1706.03762), a neural network design that excels at understanding context and relationships within sequences of data, like text. This architecture allows it to process and generate text much more effectively than previous models.
- Generative: The "G" in GPT-2 stands for "generative." This means it's designed to generate new text that is coherent and often remarkably similar to human-written text. It can continue a story, answer questions, write poems, translate languages, and perform various other text-based tasks.
- Pre-trained: The "P" stands for "pre-trained." GPT-2 was trained on a massive dataset of text from the internet (estimated to be around 40GB of text from 8 million web pages). This pre-training allows it to learn the nuances of language, including grammar, facts about the world, and even different writing styles.
- Unsupervised Learning: It learns patterns and structures in the data without explicit human labeling, simply by trying to predict the next word in a sequence, repeatedly extending the text one token at a time ("autoregressive").
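
The "autoregressive" loop described above can be sketched in a few lines. This is a toy illustration only: `TOY_SCORES`, `predict_next`, and `generate` are hypothetical names, the scores are invented, and the toy model looks only at the last token, whereas GPT-2 conditions on the whole preceding context.

```python
# Toy "model": maps the previous token to scores for each candidate next token.
# The scores are invented for illustration; GPT-2 computes them with a transformer.
TOY_SCORES = {
    "the": {"cat": 2.0, "dog": 1.0},
    "cat": {"sat": 3.0, "ran": 0.5},
    "sat": {"down": 1.5},
}


def predict_next(tokens):
    """Return the highest-scoring next token, or None if the toy model has no entry."""
    candidates = TOY_SCORES.get(tokens[-1])
    if not candidates:
        return None
    return max(candidates, key=candidates.get)


def generate(prompt, max_new_tokens):
    """Autoregressive loop: predict one token, append it, feed the longer sequence back in."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        next_token = predict_next(tokens)
        if next_token is None:
            break
        tokens.append(next_token)
    return tokens


print(generate(["the"], max_new_tokens=5))  # ['the', 'cat', 'sat', 'down']
```

Each new token becomes part of the input for the next prediction, which is the property that the causal mask later in this notebook has to preserve.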


Significance of GPT-2:

- **Unprecedented Text Generation Quality:** GPT-2 was a significant leap forward in the quality of text generated by AI. Its output was often surprisingly coherent, fluent, and contextually relevant, sometimes even difficult to distinguish from text written by a human. This sparked excitement and concern about the potential implications of such technology.
- **Phased Release and Ethical Considerations:** OpenAI initially withheld the full version of GPT-2 due to concerns about potential misuse. They were worried about its ability to generate convincing fake news, spam, and other forms of malicious content. This phased release, where progressively larger versions were released over time, was a novel approach in the AI community and highlighted the growing ethical considerations surrounding powerful AI models. It triggered a wider discussion about responsible AI development and deployment.
- **Impact on NLP Research:** GPT-2 served as a catalyst for further research in natural language processing. It demonstrated the power of large-scale, pre-trained Transformer models and inspired the development of even larger and more sophisticated models like GPT-3, BERT, LaMDA, and others. It shifted the focus of NLP research towards scaling up models and datasets.
- **Zero-Shot and Few-Shot Learning:** GPT-2 showed promising results in zero-shot and few-shot learning.
    - Zero-shot learning: The model could perform tasks it wasn't explicitly trained for, given only a task description or a natural prompt, with no worked examples.
    - Few-shot learning: With just a few examples included in the prompt, the model could adapt to new tasks without any additional training. This ability to generalize to new tasks with minimal task-specific data was a significant advancement.
- **Applications and Potential**: GPT-2, and subsequent models, have paved the way for numerous applications, including:
    - Chatbots and conversational AI: More engaging and human-like conversations.
    - Content creation: Assisting with writing articles, stories, scripts, and marketing copy.
    - Code generation: Helping programmers write and debug code.
    - Language translation: Improving the accuracy and fluency of machine translation.
    - Text summarization: Generating concise summaries of large amounts of text.


## Historical Significance

GPT-2's release was notable for several reasons:

- It demonstrated that scaling up model size and training data could lead to predictably better performance, a trend later formalized as "[Scaling Laws](https://arxiv.org/pdf/2001.08361#page=3&org=openai)"
- It challenged the prevailing wisdom in AI research by showing that simple architectures at scale could outperform complex architectural innovations
- Its capabilities were significant enough that OpenAI initially delayed the full release due to concerns about potential misuse
- It helped establish the foundation for modern language models and showed the potential of unsupervised learning for NLP tasks

## Impact on AI Development
GPT-2's success triggered a series of transformative changes in the field:
- **The Scaling Race** - GPT-2 sparked an industry-wide competition to build increasingly larger models. This "scaling race" led to rapid advancements, with organizations like Google, OpenAI, and Anthropic pushing boundaries in model size and capability. The focus shifted from architecture innovation to scaling existing architectures effectively.

- **Emergence of Meta Learning** - One of GPT-2's most surprising discoveries was its ability to perform "[few-shot learning](https://arxiv.org/pdf/2005.14165)" – adapting to new tasks with minimal explicit instruction. This phenomenon, later explored more deeply with GPT-3 and GPT-4, suggested that large language models could develop *meta-learning* capabilities, learning how to learn during pre-training.

- **Emergent Capabilities** - GPT-2 began revealing what we now call "emergent abilities" – capabilities that appear suddenly above certain scale thresholds. This observation, formally documented in the [GPT-4 technical report](https://arxiv.org/pdf/2303.08774), suggested that scaling language models could lead to qualitatively new behaviors that are difficult to predict in advance. For instance, the ability to perform basic arithmetic or follow implicit reasoning steps emerged without explicit training for these tasks.


## Broader Implications

- The success of GPT-2 influenced the development of multimodal models, showing how scaling could benefit domains beyond text.
- It sparked important discussions about AI safety and ethics, leading to more thoughtful release strategies for powerful AI systems.
- The model demonstrated that unsupervised pre-training could capture significant world knowledge, laying groundwork for future work in knowledge representation and reasoning.

These developments fundamentally changed how researchers and organizations approach AI development, shifting focus from small, specialized models to large, general-purpose systems capable of emergent behaviors and meta-learning.

# Setup (don't read)

In [76]:
import math
import os
import sys
import webbrowser
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Callable

import datasets
import einops
import numpy as np
import torch as t
import torch.nn as nn
import wandb
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new, tokenize_and_concatenate

# Run on "mps" (Mac M series) or "cuda" (NVIDIA GPUs) if available, else run on "cpu"
device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")

# Preprocessing: The Tokenizer 
GPT-2's input is natural language (i.e. a sequence of characters), but ML models usually take vectors as input. To convert natural language into vectors, the **tokenizer** splits the language into units called **tokens**, and then converts the list of tokens into vectors.

### Splitting language to tokens
A token is a substring that is a member of the **vocabulary** set. But what is a good way to build a **vocabulary**?

Could we take the set of every word in every dictionary ever made and make each word a token? No, this would not let us handle arbitrary text (e.g. typos, punctuation, URLs).

Could we just use every character available on the keyboard? No, this loses the meaning carried by whole words (e.g. "language" is more meaningful than "gangeula").

The most common practice is **Byte-Pair Encoding** (BPE). It addresses both problems above by providing a general way of splitting language that is also efficient. However, it is far from a perfect system and is the source of many quirks (e.g. models being bad at counting).

High-Level algorithm:
1. Start with an initial vocabulary of all individual characters as tokens
2. Find the most common pair of tokens in the text, merge this pair into a new token, and re-tokenize the text with the new token
3. Repeat step 2 until you reach a desired vocabulary size or no more pairs can be merged
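
The three steps above can be sketched as follows. This is a minimal character-level illustration, not GPT-2's tokenizer: the real tokenizer works on bytes and pre-splits text before merging, and the helper names (`most_common_pair`, `merge_pair`, `train_bpe`) are hypothetical.

```python
from collections import Counter


def most_common_pair(tokens):
    """Count adjacent token pairs and return the most frequent one (None if there are no pairs)."""
    pairs = Counter(zip(tokens, tokens[1:]))
    return pairs.most_common(1)[0][0] if pairs else None


def merge_pair(tokens, pair):
    """Replace every non-overlapping occurrence of `pair` with a single merged token."""
    merged, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            merged.append(tokens[i] + tokens[i + 1])
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged


def train_bpe(text, num_merges):
    """Run up to `num_merges` rounds of step 2; return the learned merges and final tokens."""
    tokens = list(text)  # step 1: every character is its own token
    merges = []
    for _ in range(num_merges):
        pair = most_common_pair(tokens)
        if pair is None:  # step 3: stop when nothing is left to merge
            break
        merges.append(pair)
        tokens = merge_pair(tokens, pair)
    return merges, tokens


merges, tokens = train_bpe("aaabdaaabac", num_merges=3)
print(merges)  # the first merge is ('a', 'a'), the most frequent adjacent pair
print(tokens)  # fewer, longer tokens that still concatenate back to the input
```

The list of merges is the learned vocabulary: to tokenize new text, the same merges are replayed in the order they were learned.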

<details>
<summary>Note: Space (" ") counts as a character and therefore merges with space are very common</summary>

```python
import tiktoken 

tokenizer = tiktoken.get_encoding('gpt2')

print(tokenizer.encode(" a")) # [257]
print(tokenizer.encode("a")) # [64]
print(tokenizer.encode("a ")) # [64, 220]
print(tokenizer.encode(" i")) # [1312]
print(tokenizer.encode("i")) # [72]
print(tokenizer.encode("i ")) # [72, 220]
```

</details>

### Converting tokens into vectors
This process is pretty straightforward. We can convert each token to a **one-hot encoding** over the vocabulary. A **one-hot encoding** vector is filled with zeros at every position except the position corresponding to the token's index in the vocabulary.

A key intuition about **one-hot encodings** is that they let you treat each token ID independently, without implying any ordering or similarity between IDs.

$$
\begin{aligned}
t_i &= (0, \dots, 0, 1, 0, \dots, 0) \quad \text{is the one-hot encoding for the }i\text{th token (length }d_{vocab}\text{)} \\
\\
\end{aligned}
$$
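
One consequence worth seeing concretely: multiplying a one-hot vector by a matrix simply selects one row of that matrix. The sketch below checks this in plain Python with a made-up 4-token, 3-dimensional matrix; `one_hot` and `vec_matmul` are hypothetical helpers, not part of the notebook.

```python
def one_hot(index, d_vocab):
    """Vector of zeros with a single 1 at `index`."""
    return [1.0 if i == index else 0.0 for i in range(d_vocab)]


def vec_matmul(vec, matrix):
    """Multiply a row vector (length n) by an n x m matrix given as a list of rows."""
    n_cols = len(matrix[0])
    return [sum(vec[r] * matrix[r][c] for r in range(len(vec))) for c in range(n_cols)]


# Hypothetical toy embedding matrix: d_vocab = 4 tokens, d_model = 3 dimensions.
W_E = [
    [0.1, 0.2, 0.3],
    [1.0, 1.1, 1.2],
    [2.0, 2.1, 2.2],
    [3.0, 3.1, 3.2],
]

token_id = 2
via_one_hot = vec_matmul(one_hot(token_id, d_vocab=4), W_E)
via_lookup = W_E[token_id]
print(via_one_hot == via_lookup)  # True
```

This is why an embedding layer can be written as an index lookup into its weight matrix rather than as an explicit matrix multiplication over one-hot vectors.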

<details>
<summary>Not ideal things about tokenization</summary>

**Capitalization and Leading spaces matter** 

```python
import tiktoken 

tokenizer = tiktoken.get_encoding('gpt2')

print(tokenizer.encode("Michael")) # [13256]
print(tokenizer.encode(" Michael")) # [3899]
print(tokenizer.encode(" michael")) # [285, 40302]
print(tokenizer.encode("michael")) # [76, 40302]
```

**Arithmetic does not make sense**
Common numbers are bundled together into single tokens, so digits are split inconsistently.

```python
import tiktoken 

tokenizer = tiktoken.get_encoding('gpt2')

print(tokenizer.encode("56873+3184623=123456789-1000000000")) # [49211, 4790, 10, 36042, 3510, 1954, 28, 10163, 2231, 3134, 4531, 12, 16, 10535, 830]
```

</details>

In this notebook, we will not be implementing the tokenizer from scratch. Instead we will import OpenAI's [tiktoken](https://github.com/openai/tiktoken) library to use their official tokenizer.

To see a full walkthrough of implementing a tokenizer, check out Karpathy's video: [Let's build the GPT Tokenizer](https://www.youtube.com/watch?v=zduSFxRajkE).

In [77]:
import tiktoken

tokenizer = tiktoken.get_encoding('gpt2')

reference_text = "A day without laughter is a day" # "A day without laughter is a day wasted" - Charlie Chaplin
print("Reference text: " + reference_text)

tokens = tokenizer.encode(reference_text)
print("Tokenized sequence: " + str(tokens))

reconstructed_reference_text = tokenizer.decode(tokens)
print("Reconstructed reference text: " + reconstructed_reference_text)

Reference text: A day without laughter is a day
Tokenized sequence: [32, 1110, 1231, 20263, 318, 257, 1110]
Reconstructed reference text: A day without laughter is a day


# Config

In [78]:
@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12


cfg = Config()

In [79]:
reference_gpt2 = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,  # you'll learn about these arguments later!
)

for name, param in reference_gpt2.named_parameters():
    # Only print for first layer
    if ".0." in name or "blocks" not in name:
        print(f"{name:18} {tuple(param.shape)}")

Loaded pretrained model gpt2-small into HookedTransformer
embed.W_E          (50257, 768)
pos_embed.W_pos    (1024, 768)
blocks.0.ln1.w     (768,)
blocks.0.ln1.b     (768,)
blocks.0.ln2.w     (768,)
blocks.0.ln2.b     (768,)
blocks.0.attn.W_Q  (12, 768, 64)
blocks.0.attn.W_O  (12, 64, 768)
blocks.0.attn.b_Q  (12, 64)
blocks.0.attn.b_O  (768,)
blocks.0.attn.W_K  (12, 768, 64)
blocks.0.attn.W_V  (12, 768, 64)
blocks.0.attn.b_K  (12, 64)
blocks.0.attn.b_V  (12, 64)
blocks.0.mlp.W_in  (768, 3072)
blocks.0.mlp.b_in  (3072,)
blocks.0.mlp.W_out (3072, 768)
blocks.0.mlp.b_out (768,)
ln_final.w         (768,)
ln_final.b         (768,)
unembed.W_U        (768, 50257)
unembed.b_U        (50257,)
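
As a sanity check on the shapes printed above, the parameter count can be tallied by hand. The sketch below is plain arithmetic over those shapes; `ShapeConfig` and `count_params` are hypothetical names introduced here, with field values copied from the `Config` cell.

```python
from dataclasses import dataclass


@dataclass
class ShapeConfig:  # hypothetical stand-in for the notebook's Config
    d_model: int = 768
    d_vocab: int = 50257
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12


def count_params(c: ShapeConfig) -> dict[str, int]:
    """Tally parameters per component using the shapes printed above."""
    layer_norm = 2 * c.d_model                           # w and b
    attn = (
        4 * c.n_heads * c.d_model * c.d_head             # W_Q, W_K, W_V, W_O
        + 3 * c.n_heads * c.d_head                       # b_Q, b_K, b_V
        + c.d_model                                      # b_O
    )
    mlp = 2 * c.d_model * c.d_mlp + c.d_mlp + c.d_model  # W_in, W_out, b_in, b_out
    block = 2 * layer_norm + attn + mlp                  # ln1, ln2, attn, mlp
    counts = {
        "embed": c.d_vocab * c.d_model,
        "pos_embed": c.n_ctx * c.d_model,
        "blocks": c.n_layers * block,
        "ln_final": layer_norm,
        "unembed": c.d_model * c.d_vocab + c.d_vocab,    # W_U and b_U
    }
    counts["total"] = sum(counts.values())
    return counts


counts = count_params(ShapeConfig())
for name, n in counts.items():
    print(f"{name:10} {n:>12,}")
```

With a separate unembedding matrix the total comes to 163,087,441. Leaving out the unembedding gives 124,439,808, which matches the roughly 124M usually quoted for GPT-2 small, where the output layer shares its weights with the token embedding.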


# Embedding
## Token Embedding

Simply a lookup table from token to embedding vector.

In [80]:
class Embed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(t.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        return self.W_E[tokens]


## Positional Embedding

Similar to token embeddings, positional embeddings are a lookup table where the indices are simply the positions (0, 1, 2, ...) of tokens in the sequence rather than token IDs.

In [81]:
class PosEmbed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        batch, seq_len = tokens.shape
        return einops.repeat(self.W_pos[:seq_len], "seq d_model -> batch seq d_model", batch=batch)

## Layer Norm

Layer normalization is a technique used to stabilize training in deep neural networks. It normalizes each position's residual vector by subtracting its mean and dividing by its standard deviation (computed across the `d_model` dimension), then applies a learned scale `w` and bias `b`. This keeps activations and gradients in a stable range.

In [82]:
class LayerNorm(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))
        self.b = nn.Parameter(t.zeros(cfg.d_model))

    def forward(self, residual: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        residual_mean = residual.mean(dim=-1, keepdim=True)
        residual_std = (residual.var(dim=-1, keepdim=True, unbiased=False) + self.cfg.layer_norm_eps).sqrt()

        residual = (residual - residual_mean) / residual_std
        return residual * self.w + self.b
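
To make the arithmetic concrete, here is a plain-Python sketch of the same computation on a single 4-dimensional vector. It is a hand calculation under the assumption that `w` is all ones and `b` is all zeros (their initial values); the standalone `layer_norm` function is a hypothetical helper, not the class above.

```python
import math


def layer_norm(x, w, b, eps=1e-5):
    """Normalize one vector to zero mean and unit variance, then scale by w and shift by b."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance, like unbiased=False
    std = math.sqrt(var + eps)
    return [(v - mean) / std * wi + bi for v, wi, bi in zip(x, w, b)]


x = [1.0, 2.0, 3.0, 4.0]
out = layer_norm(x, w=[1.0] * 4, b=[0.0] * 4)

out_mean = sum(out) / len(out)
out_var = sum((v - out_mean) ** 2 for v in out) / len(out)
print([round(v, 3) for v in out])  # symmetric around zero
print(abs(out_mean) < 1e-9, round(out_var, 3))
```

The output variance is very slightly below 1 because `eps` is added inside the square root, which is also what keeps the division safe when a vector has near-zero variance.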

# Attention

## Causal Mask

The causal mask ensures that each position in the sequence can only attend to previous positions and itself.
This is crucial for maintaining the autoregressive property during training and inference.

For example, when predicting the 3rd token, the model should only look at tokens 1 and 2,
not at token 3 or anything after it, which would leak information from the future.

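The mask can be pictured as a lower-triangular matrix where:
- The diagonal and lower triangle contain 1's (allowing attention)
- The upper triangle contains 0's (blocking attention)

The code below builds the complement instead: a boolean matrix that is `True` strictly above the diagonal, marking the positions to block.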

During attention score calculation, the 0's are converted to negative infinity,
which become 0 after the softmax operation, effectively preventing attention to future tokens.
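
A small worked example may help here. The sketch below applies a causal mask to a 3x3 matrix of identical scores in plain Python and then takes a row-wise softmax; `causal_mask` and `softmax` are hypothetical standalone helpers, not the notebook's `apply_causal_mask` method.

```python
import math


def causal_mask(scores):
    """Set scores[q][k] to -inf wherever key position k comes after query position q."""
    return [
        [s if k <= q else float("-inf") for k, s in enumerate(row)]
        for q, row in enumerate(scores)
    ]


def softmax(row):
    """Numerically stable softmax; exp(-inf) is 0.0, so masked entries get zero weight."""
    top = max(row)
    exps = [math.exp(s - top) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


scores = [[0.0, 0.0, 0.0] for _ in range(3)]
pattern = [softmax(row) for row in causal_mask(scores)]
for row in pattern:
    print([round(p, 3) for p in row])
# [1.0, 0.0, 0.0]
# [0.5, 0.5, 0.0]
# [0.333, 0.333, 0.333]
```

Every row still sums to 1, but the weight is spread only over the current and earlier positions.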


In [83]:
class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.register_buffer("IGNORE", t.tensor(float("-inf"), dtype=t.float32, device=device))

    def apply_causal_mask(
        self,
        attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"],
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        """
        Applies a causal mask to attention scores, and returns masked scores.
        """
        # Define a mask that is True for all positions we want to set probabilities to zero for
        all_ones = t.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device)
        mask = t.triu(all_ones, diagonal=1).bool()
        # Apply the mask to attention scores, then return the masked scores
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

## Attention Heads

Each attention head performs the following steps:

1. Create attention pattern:
   - Map input to queries and keys (shape: [batch, seq_pos, head_idx, d_head])
   - Compute attention scores by taking dot product of queries and keys
   - Scale scores by dividing by sqrt(d_head) so the softmax does not saturate (which would cause vanishing gradients)
   - Apply causal mask to ensure tokens only attend to past/present
   - Apply softmax to get attention probabilities

2. Use attention pattern to aggregate information:
   - Map input to values (shape: [batch, seq_pos, head_idx, d_head]) 
   - Weight and sum values according to attention probabilities
   - Combine results across heads to get final output (shape: [batch, seq_pos, d_model])

The attention mechanism allows each token to dynamically focus on relevant past tokens,
with the scaling and masking ensuring stable training and causality.
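
The steps above can be sketched for a single head in plain Python. This is a simplified illustration under stated assumptions: the queries, keys, and values are given directly as made-up lists (the learned projections and the final output projection are skipped), and `single_head_attention` is a hypothetical helper.

```python
import math


def softmax(row):
    """Numerically stable softmax; masked (-inf) entries come out as 0.0."""
    top = max(row)
    exps = [math.exp(s - top) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def single_head_attention(q, k, v):
    """Causal scaled dot-product attention for one head.

    q, k, v are lists of per-position vectors of length d_head.
    Returns one output vector per query position.
    """
    d_head = len(q[0])
    outputs = []
    for q_pos, q_vec in enumerate(q):
        # Dot product of this query with every key, scaled by sqrt(d_head)
        scores = [sum(a * b for a, b in zip(q_vec, k_vec)) / math.sqrt(d_head) for k_vec in k]
        # Causal mask: keys after the query position are excluded
        scores = [s if k_pos <= q_pos else float("-inf") for k_pos, s in enumerate(scores)]
        probs = softmax(scores)
        # Weighted sum of value vectors according to the attention probabilities
        outputs.append([sum(p * v_vec[i] for p, v_vec in zip(probs, v)) for i in range(d_head)])
    return outputs


# Toy inputs (invented): 3 positions, d_head = 2
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

out = single_head_attention(q, k, v)
print(out[0])  # [1.0, 0.0]: position 0 can only attend to itself, so it returns v[0]
```

The einsum-based implementation later in this notebook does the same thing for all heads and batch elements at once.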

In [84]:
import circuitsvis as cv
from IPython.display import display

reference_text = "I am a powerful language model trained on massive amounts of text data. Soon I will learn to understand and generate human-like content!"
tokens = reference_gpt2.to_tokens(reference_text).to(device)
logits, cache = reference_gpt2.run_with_cache(tokens)

# Attention patterns for every head in layer 0 (first item in the batch)
display(
    cv.attention.attention_patterns(
        tokens=reference_gpt2.to_str_tokens(reference_text), attention=cache["pattern", 0][0]
    )
)

![Attention Architecture](../nanoGPT/transformer-attention-architecture.png)

In [85]:
class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.register_buffer("IGNORE", t.tensor(float("-inf"), dtype=t.float32, device=device))

    def forward(self, normalized_resid_pre: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        # Calculate query, key and value vectors
        q = (
            einops.einsum(
                normalized_resid_pre, self.W_Q, "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head"
            )
            + self.b_Q
        )
        k = (
            einops.einsum(
                normalized_resid_pre, self.W_K, "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head"
            )
            + self.b_K
        )
        v = (
            einops.einsum(
                normalized_resid_pre, self.W_V, "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head"
            )
            + self.b_V
        )

        # Calculate attention scores, then scale and mask, and apply softmax to get probabilities
        attn_scores = einops.einsum(
            q, k, "batch posn_Q nheads d_head, batch posn_K nheads d_head -> batch nheads posn_Q posn_K"
        )
        attn_scores_masked = self.apply_causal_mask(attn_scores / self.cfg.d_head**0.5)
        attn_pattern = attn_scores_masked.softmax(-1)

        # Take weighted sum of value vectors, according to attention probabilities
        z = einops.einsum(
            v, attn_pattern, "batch posn_K nheads d_head, batch nheads posn_Q posn_K -> batch posn_Q nheads d_head"
        )

        # Calculate output (by applying matrix W_O and summing over heads, then adding bias b_O)
        attn_out = (
            einops.einsum(z, self.W_O, "batch posn_Q nheads d_head, nheads d_head d_model -> batch posn_Q d_model")
            + self.b_O
        )

        return attn_out

    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        """
        Applies a causal mask to attention scores, and returns masked scores.
        """
        # Define a mask that is True for all positions we want to set probabilities to zero for
        all_ones = t.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device)
        mask = t.triu(all_ones, diagonal=1).bool()
        # Apply the mask to attention scores, then return the masked scores
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

# MLP (Multi-Layer Perceptron)

The MLP component implements a standard feedforward neural network architecture with a single hidden layer and GELU activation function. Following transformer conventions, the hidden dimension is typically 4x the model dimension (d_mlp = 4 * d_model).

A fundamental property of MLPs is that they process each position in the residual stream independently and identically - unlike attention layers, they don't transfer information between positions. This makes them ideal for processing information that attention has already gathered to specific positions.

MLPs can be understood through multiple lenses:
1. As key-value memory systems: Input weights act as "keys" detecting specific features, while output weights serve as "values" that get activated
2. As knowledge storage: MLPs store learned patterns and information, processing inputs to write derived information into the residual stream
3. As memory managers: Certain neurons may help manage the residual stream's capacity by selectively erasing specific vector components

The exact implementation uses GELU activation which empirically performs well, though the specific activation function isn't conceptually critical. The architecture follows historical precedent set by early transformer models.
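
For reference, the activation can be written out directly. The sketch below compares the exact GELU with its tanh approximation in plain Python, on the assumption that `gelu_new` implements the tanh form used by GPT-2; `gelu_exact` and `gelu_tanh` are hypothetical names.

```python
import math


def gelu_exact(x):
    """GELU(x) = x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    """Tanh approximation of GELU."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


for x in [-3.0, -1.0, 0.0, 1.0, 3.0]:
    print(f"x={x:+.1f}  exact={gelu_exact(x):+.4f}  tanh={gelu_tanh(x):+.4f}")
```

The two versions agree closely: large positive inputs pass through almost unchanged, large negative inputs are pushed toward zero, and the transition near zero is smooth rather than the hard cutoff of ReLU.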

![MLP Architecture](../nanoGPT/mlp-architecture.png)



In [86]:
class MLP(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty((cfg.d_model, cfg.d_mlp)))
        self.W_out = nn.Parameter(t.empty((cfg.d_mlp, cfg.d_model)))
        self.b_in = nn.Parameter(t.zeros((cfg.d_mlp)))
        self.b_out = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        nn.init.normal_(self.W_out, std=self.cfg.init_range)

    def forward(self, normalized_resid_mid: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        pre = (
            einops.einsum(
                normalized_resid_mid, self.W_in, "batch position d_model, d_model d_mlp -> batch position d_mlp"
            )
            + self.b_in
        )
        post = gelu_new(pre)
        mlp_out = (
            einops.einsum(post, self.W_out, "batch position d_mlp, d_mlp d_model -> batch position d_model")
            + self.b_out
        )
        return mlp_out

# Transformer Block

![Transformer Block](../nanoGPT/transformer-architecture.png)

In [87]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(self, resid_pre: Float[Tensor, "batch position d_model"]) -> Float[Tensor, "batch position d_model"]:
        resid_mid = self.attn(self.ln1(resid_pre)) + resid_pre
        resid_post = self.mlp(self.ln2(resid_mid)) + resid_mid
        return resid_post


# Unembedding

The unembedding layer is a simple linear layer that maps the model's final residual stream back to vocabulary space. For each position it produces a vector of `d_vocab` logits, one per token; applying softmax to these logits gives the predicted probability distribution over the next token.

In [88]:
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(t.zeros((cfg.d_vocab), requires_grad=False))

    def forward(
        self, normalized_resid_final: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_vocab"]:
        return (
            einops.einsum(
                normalized_resid_final,
                self.W_U,
                "batch posn d_model, d_model d_vocab -> batch posn d_vocab",
            )
            + self.b_U
        )

# Full GPT Model

![Full GPT Model](../nanoGPT/high_level_architecture.png)

In [89]:
class GPT(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_vocab"]:
        residual = self.embed(tokens) + self.pos_embed(tokens)
        for block in self.blocks:
            residual = block(residual)
        logits = self.unembed(self.ln_final(residual))
        return logits

In [90]:
demo_gpt2 = GPT(Config(debug=False)).to(device)
demo_gpt2.load_state_dict(reference_gpt2.state_dict(), strict=False)

demo_logits = demo_gpt2(tokens)

In [91]:
def get_log_probs(
    logits: Float[Tensor, "batch posn d_vocab"], tokens: Int[Tensor, "batch posn"]
) -> Float[Tensor, "batch posn-1"]:
    log_probs = logits.log_softmax(dim=-1)
    # Get logprobs for the first seq_len-1 predictions (so we can compare them with the actual next tokens)
    log_probs_for_tokens = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)

    return log_probs_for_tokens


pred_log_probs = get_log_probs(demo_logits, tokens)
print(f"Avg cross entropy loss: {-pred_log_probs.mean():.4f}")
print(f"Avg cross entropy loss for uniform distribution: {math.log(demo_gpt2.cfg.d_vocab):4f}")
print(f"Avg probability assigned to correct token: {pred_log_probs.exp().mean():4f}")

Avg cross entropy loss: 4.2215
Avg cross entropy loss for uniform distribution: 10.824905
Avg probability assigned to correct token: 0.128396
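
The three numbers above can be related with a little arithmetic. The sketch below uses only the values printed above plus one invented two-token example, and shows why the average probability (about 0.128) is much larger than `exp(-loss)`: the loss averages log-probabilities, so exponentiating it gives a geometric mean rather than an arithmetic one.

```python
import math

d_vocab = 50257
uniform_loss = math.log(d_vocab)  # loss if every token were given probability 1/d_vocab
print(f"{uniform_loss:.4f}")      # 10.8249

measured_loss = 4.2215            # average cross entropy printed above
geometric_mean_prob = math.exp(-measured_loss)
print(f"{geometric_mean_prob:.4f}")  # roughly 0.0147

# Invented example: one confident prediction and one poor one
probs = [0.9, 0.01]
arithmetic_mean = sum(probs) / len(probs)
geometric_mean = math.exp(sum(math.log(p) for p in probs) / len(probs))
print(f"{arithmetic_mean:.3f} vs {geometric_mean:.3f}")  # 0.455 vs 0.095
```

A few badly predicted tokens pull the geometric mean down sharply while barely moving the arithmetic mean, which is why cross entropy is the stricter of the two measures.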


In [92]:
test_string = """The Total Perspective Vortex derives its picture of the whole Universe on the principle of"""
for i in tqdm(range(100)):
    test_tokens = reference_gpt2.to_tokens(test_string).to(device)
    demo_logits = demo_gpt2(test_tokens)
    test_string += reference_gpt2.tokenizer.decode(demo_logits[-1, -1].argmax())

print(test_string)

The Total Perspective Vortex derives its picture of the whole Universe on the principle of the total perspective. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The


The model appears to be working reasonably well:
1. The cross entropy loss (4.22) is far below that of a uniform distribution over the vocabulary (10.82)
2. The model assigns meaningful probabilities to correct tokens
3. The generated text is grammatical and on topic, though greedy (argmax) decoding quickly falls into a repetitive loop
4. The model successfully loaded the pretrained weights and produces sensible outputs


# Training

## Configurations

The next code block defines model configuration and training arguments:
- Creates a smaller GPT model with 256 dim embeddings, 4 attention heads, 2 layers
- Sets up training hyperparameters like batch size, learning rate, etc.
- Configures optional Weights & Biases logging



In [93]:
model_cfg = Config(
    debug=False,
    d_model=256,
    n_heads=4,
    d_head=64,
    d_mlp=1024,
    n_layers=2,
    n_ctx=256,
    d_vocab=50257,
)
model = GPT(model_cfg)

@dataclass
class TransformerTrainingArgs:
    batch_size: int = 16
    epochs: int = 20
    max_steps_per_epoch: int = 200
    lr: float = 1e-3
    weight_decay: float = 1e-2
    wandb_project: str | None = "cleangpt2"
    wandb_name: str | None = None


args = TransformerTrainingArgs()

## Data Processing

The code below loads a small subset of The Pile dataset (10k examples) from the Hugging Face Hub.
This provides the training data for our small GPT-2-style model, which is trained from scratch.
The dataset contains text samples that we'll use to train the model on next-token prediction.

In [94]:
dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train").remove_columns("meta")
print(dataset)
print(dataset[0]["text"][:100])

Dataset({
    features: ['text'],
    num_rows: 10000
})
It is done, and submitted. You can play “Survival of the Tastiest” on Android, and on the web. Playi


The next cell:
1. Tokenizes the dataset using the GPT-2 tokenizer
2. Concatenates sequences to the model's context length (n_ctx)
3. Splits data into train and test sets
4. Creates DataLoader objects for efficient batched training
   - Uses multiple workers for parallel data loading
   - Enables memory pinning for faster GPU transfer
   - Shuffles training data while keeping test data ordered
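
Steps 1 and 2 can be pictured with a simplified sketch. This is not the `tokenize_and_concatenate` implementation from `transformer_lens`, only an illustration of the idea under stated assumptions: documents are already tokenized, a separator token is appended after each one, and every output row starts with a beginning-of-sequence token. The function name and token IDs are made up.

```python
def concatenate_and_chunk(token_lists, n_ctx, bos_token, eos_token):
    """Join documents into one token stream, then cut it into rows of exactly n_ctx tokens.

    Each row starts with bos_token, so it carries n_ctx - 1 tokens of text.
    Leftover tokens that do not fill a final row are dropped.
    """
    stream = []
    for toks in token_lists:
        stream.extend(toks)
        stream.append(eos_token)  # separator between documents
    body = n_ctx - 1
    n_rows = len(stream) // body
    return [[bos_token] + stream[i * body:(i + 1) * body] for i in range(n_rows)]


docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]  # made-up token IDs
rows = concatenate_and_chunk(docs, n_ctx=4, bos_token=0, eos_token=0)
print(rows)  # [[0, 1, 2, 3], [0, 0, 4, 5], [0, 0, 6, 7], [0, 8, 9, 0]]
```

Packing documents this way means every row has the same length, so no padding is needed and every position contributes to the loss.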

In [95]:
tokenized_dataset = tokenize_and_concatenate(
    dataset,
    reference_gpt2.tokenizer,
    streaming=False,
    max_length=model.cfg.n_ctx,
    column_name="text",
    add_bos_token=True,
    num_proc=4,
)

dataset_dict = tokenized_dataset.train_test_split(test_size=1000)
train_loader = DataLoader(
    dataset_dict["train"], batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True
)
test_loader = DataLoader(
    dataset_dict["test"], batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True
)

Token indices sequence length is longer than the specified maximum sequence length for this model (80023 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (101051 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (155995 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (229134 > 1024). Running this sequence through the model will result in indexing errors
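
The warnings above appear to come from tokenizing long stretches of text in a single call; they should be harmless here, because the token stream is cut into rows of `max_length` afterwards. The concatenate-and-chunk idea can be sketched in plain Python as follows. This is a simplified illustration, not the library's implementation: `chunk_token_stream`, `bos_id`, and `eos_id` are hypothetical names, and the toy token ids are made up.

```python
def chunk_token_stream(docs, seq_len, bos_id, eos_id):
    """Join tokenized documents into one stream and cut it into fixed-length rows.

    Each row starts with a BOS token, so only seq_len - 1 stream tokens fit per row.
    Leftover tokens that do not fill a whole row are dropped.
    """
    stream = []
    for doc in docs:
        stream.extend(doc)
        stream.append(eos_id)  # mark the document boundary

    body_len = seq_len - 1
    n_rows = len(stream) // body_len
    return [
        [bos_id] + stream[r * body_len:(r + 1) * body_len]
        for r in range(n_rows)
    ]


# Toy example: three "documents" of made-up token ids, context length 4.
docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
rows = chunk_token_stream(docs, seq_len=4, bos_id=0, eos_id=99)
for row in rows:
    print(row)
```

Every row has exactly `seq_len` tokens, so batches can be stacked without padding, at the cost of documents being split across row boundaries.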


## Training loop

The code below defines a `TransformerTrainer` class for training the GPT model. The class:
- Initializes the optimizer and the data loaders for training and testing
- Defines a `training_step` method that:
  - Computes the model's predictions and the loss
  - Performs backpropagation and an optimizer step
  - Logs the training loss to Weights & Biases
- Defines an `evaluate` method that computes next-token accuracy on the test set
- Defines a `train` method that runs the training steps for each epoch and evaluates after every epoch
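
The training loss is the mean negative log-probability that the model assigns to each correct next token. The sketch below works through that calculation in plain Python on a toy vocabulary of four tokens. It assumes that `get_log_probs` (defined earlier in the notebook) returns the log-probability of each actual next token; `log_softmax` and `next_token_loss` are illustrative helpers, not part of the notebook's code.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def next_token_loss(logits_per_position, tokens):
    """Mean negative log-probability of tokens[i + 1] given the logits at position i.

    The last position has no next token, so it is not scored.
    """
    log_probs = [
        log_softmax(logits_per_position[i])[tokens[i + 1]]
        for i in range(len(tokens) - 1)
    ]
    return -sum(log_probs) / len(log_probs)


tokens = [0, 2, 1]

# Uniform logits: every token is equally likely, so the loss equals ln(vocab_size).
uniform_logits = [[0.0] * 4 for _ in tokens]
loss_uniform = next_token_loss(uniform_logits, tokens)

# Logits that strongly favour the correct next token give a much smaller loss.
confident_logits = [
    [0.0, 0.0, 5.0, 0.0],  # position 0 should predict token 2
    [0.0, 5.0, 0.0, 0.0],  # position 1 should predict token 1
    [0.0, 0.0, 0.0, 0.0],  # last position is ignored
]
loss_confident = next_token_loss(confident_logits, tokens)

print(loss_uniform, loss_confident)
```

The uniform case gives a useful baseline: a model that has learned nothing scores a loss of about the natural log of the vocabulary size.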


In [99]:
class TransformerTrainer:
    def __init__(self, args: TransformerTrainingArgs, model: GPT):
        super().__init__()
        self.model = model
        self.args = args

        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.step = 0

        self.train_loader = DataLoader(
            dataset_dict["train"], batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True
        )
        self.test_loader = DataLoader(
            dataset_dict["test"], batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True
        )

    def training_step(self, batch: dict[str, Int[Tensor, "batch seq"]]) -> Float[Tensor, ""]:
        """
        Calculates the loss on the tokens in the batch, performs a gradient update step, and logs the loss.

        Remember that `batch` is a dictionary with the single key 'tokens'.
        """
        tokens = batch["tokens"].to(device)
        logits = self.model(tokens)
        loss = -get_log_probs(logits, tokens).mean()
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.step += 1
        wandb.log({"train_loss": loss}, step=self.step)
        return loss

    @t.inference_mode()
    def evaluate(self) -> float:
        """
        Evaluate the model on the test set and return the accuracy.
        """
        self.model.eval()
        total_correct, total_samples = 0, 0

        for batch in tqdm(self.test_loader, desc="Evaluating"):
            tokens = batch["tokens"].to(device)
            logits: Tensor = self.model(tokens)[:, :-1]
            predicted_tokens = logits.argmax(dim=-1)
            total_correct += (predicted_tokens == tokens[:, 1:]).sum().item()
            total_samples += tokens.size(0) * (tokens.size(1) - 1)

        accuracy = total_correct / total_samples
        wandb.log({"accuracy": accuracy}, step=self.step)
        self.model.train()  # restore training mode so later training steps do not run in eval mode
        return accuracy

    def train(self):
        """
        Trains the model, for `self.args.epochs` epochs. Also handles wandb initialisation, and early stopping
        for each epoch at `self.args.max_steps_per_epoch` steps.
        """
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
        accuracy = np.nan

        progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

        for epoch in range(self.args.epochs):
            for i, batch in enumerate(self.train_loader):
                loss = self.training_step(batch)
                progress_bar.update()
                progress_bar.set_description(f"Epoch {epoch+1}, loss: {loss:.3f}, accuracy: {accuracy:.3f}")
                # i is zero-based, so stop once max_steps_per_epoch steps have run
                if i + 1 >= self.args.max_steps_per_epoch:
                    break

            accuracy = self.evaluate()

        wandb.finish()

Now, let's train the model:

In [100]:
model = GPT(model_cfg).to(device)
args = TransformerTrainingArgs()
trainer = TransformerTrainer(args, model)
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mmichaelyliu6[0m ([33mmichaelyliu6-none[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


  0%|          | 0/4000 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/63 [00:00<?, ?it/s]

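Run history:

| Metric | History |
|---|---|
| accuracy | ▁▃▃▄▄▅▅▆▆▆▆▇▇▇▇▇████ |
| train_loss | █▅█▇▅▇▅▅▅▅▄▅▃▄▅▃▂▄▂▃▃▄▂▃▄▂▂▄▃▂▃▃▃▄▃▂▁▃▂▃ |

Run summary:

| Metric | Final value |
|---|---|
| accuracy | 0.31804 |
| train_loss | 4.23696 |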
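
To put the final numbers in context, the training loss can be converted to perplexity and compared with the loss of uniform guessing. The sketch below assumes the loss is a mean cross-entropy per token in natural-log units and that the model uses GPT-2's 50,257-token vocabulary; the variable names are illustrative.

```python
import math

final_train_loss = 4.23696  # last logged train_loss from the run summary
vocab_size = 50257          # assumption: GPT-2 vocabulary size

# Perplexity is the exponential of the mean cross-entropy.
perplexity = math.exp(final_train_loss)

# A model that guesses uniformly over the vocabulary has loss ln(vocab_size).
uniform_loss = math.log(vocab_size)

print(f"perplexity: {perplexity:.1f}")
print(f"uniform-guess loss: {uniform_loss:.2f}")
```

A perplexity of roughly 69 is far below the vocabulary size, and the loss is well under the uniform baseline of about 10.8, so the model has clearly learned something about the data. The logged `train_loss` comes from a single batch, so it is a noisy estimate; the accuracy of about 0.318 means the top prediction matches the actual next token on roughly one in three test positions.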


# Sampling