In [145]:
import math
import os
import sys
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path

import circuitsvis as cv
import datasets
import einops
import numpy as np
import torch as t
import torch.nn as nn
import wandb
from IPython.display import display
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new, tokenize_and_concatenate
from transformers import PreTrainedTokenizerFast
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast

# Make sure exercises are in the path
chapter = r"chapter1_transformer_interp"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "part1_transformer_from_scratch"
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

import part1_transformer_from_scratch.solutions as solutions
import part1_transformer_from_scratch.tests as tests

device = t.device('mps' if t.backends.mps.is_available() else 'cuda' if t.cuda.is_available() else 'cpu')

MAIN = __name__ == '__main__'

reference_gpt2 = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
    device=device
)



Loaded pretrained model gpt2-small into HookedTransformer


# Inputs and Outputs of a Transformer

In [3]:
sorted_vocab = sorted(list(reference_gpt2.tokenizer.vocab.items()), key=lambda n: n[1])
print(sorted_vocab[:20])
print()
print(sorted_vocab[250:270])
print()
print(sorted_vocab[990:1010])
print()

[('!', 0), ('"', 1), ('#', 2), ('$', 3), ('%', 4), ('&', 5), ("'", 6), ('(', 7), (')', 8), ('*', 9), ('+', 10), (',', 11), ('-', 12), ('.', 13), ('/', 14), ('0', 15), ('1', 16), ('2', 17), ('3', 18), ('4', 19)]

[('ľ', 250), ('Ŀ', 251), ('ŀ', 252), ('Ł', 253), ('ł', 254), ('Ń', 255), ('Ġt', 256), ('Ġa', 257), ('he', 258), ('in', 259), ('re', 260), ('on', 261), ('Ġthe', 262), ('er', 263), ('Ġs', 264), ('at', 265), ('Ġw', 266), ('Ġo', 267), ('en', 268), ('Ġc', 269)]

[('Ġprodu', 990), ('Ġstill', 991), ('led', 992), ('ah', 993), ('Ġhere', 994), ('Ġworld', 995), ('Ġthough', 996), ('Ġnum', 997), ('arch', 998), ('imes', 999), ('ale', 1000), ('ĠSe', 1001), ('ĠIf', 1002), ('//', 1003), ('ĠLe', 1004), ('Ġret', 1005), ('Ġref', 1006), ('Ġtrans', 1007), ('ner', 1008), ('ution', 1009)]



In [4]:
print(sorted_vocab[-20:])

[('Revolution', 50237), ('Ġsnipers', 50238), ('Ġreverted', 50239), ('Ġconglomerate', 50240), ('Terry', 50241), ('794', 50242), ('Ġharsher', 50243), ('Ġdesolate', 50244), ('ĠHitman', 50245), ('Commission', 50246), ('Ġ(/', 50247), ('âĢ¦."', 50248), ('Compar', 50249), ('Ġamplification', 50250), ('ominated', 50251), ('Ġregress', 50252), ('ĠCollider', 50253), ('Ġinformants', 50254), ('Ġgazed', 50255), ('<|endoftext|>', 50256)]


First encodings of length N:
- 3: ' in' 'the' ' on', ' at', ' he', 
- 4: ' for', ' her', 
-5: ' '

In [5]:
lengths = dict.fromkeys(range(3, 8), "")
for tok, idx in sorted_vocab:
    if not lengths.get(len(tok), True):
        lengths[len(tok)] = tok

for length, tok in lengths.items():
    print(f"{length}: {tok}")

3: ing
4: Ġthe
5: Ġthat
6: Ġtheir
7: Ġpeople


In [6]:
print(reference_gpt2.to_str_tokens("Ralph"))
print(reference_gpt2.to_str_tokens(" Ralph"))
print(reference_gpt2.to_str_tokens(" ralph"))
print(reference_gpt2.to_str_tokens("ralph"))

['<|endoftext|>', 'R', 'alph']
['<|endoftext|>', ' Ralph']
['<|endoftext|>', ' r', 'alph']
['<|endoftext|>', 'ral', 'ph']


In [7]:
print(reference_gpt2.to_str_tokens("56873+3184623=123456789-1000000000"))

['<|endoftext|>', '568', '73', '+', '318', '46', '23', '=', '123', '45', '67', '89', '-', '1', '000000', '000']


In [8]:
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
tokens = reference_gpt2.to_tokens(reference_text).to(device)
print(tokens)
print(tokens.shape)
print(reference_gpt2.to_str_tokens(tokens))

tensor([[50256,    40,   716,   281,  4998,  1960,   382, 19741,    11,   875,
         12342,    12,  8807,    11,   402, 11571,    12,    17,  3918, 47385,
            13,  1881,  1110,   314,   481,  7074,  1692,  1241,  4430,   290,
          1011,   625,   262,   995,     0]], device='mps:0')
torch.Size([1, 35])
['<|endoftext|>', 'I', ' am', ' an', ' amazing', ' aut', 'ore', 'gressive', ',', ' dec', 'oder', '-', 'only', ',', ' G', 'PT', '-', '2', ' style', ' transformer', '.', ' One', ' day', ' I', ' will', ' exceed', ' human', ' level', ' intelligence', ' and', ' take', ' over', ' the', ' world', '!']


In [9]:
logits, cache = reference_gpt2.run_with_cache(tokens, device=device)
print(logits.shape)

torch.Size([1, 35, 50257])


In [10]:
probs = logits.softmax(dim=-1)
print(probs.shape)

torch.Size([1, 35, 50257])


In [11]:
most_likely_next_tokens = reference_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1)[0])

print(list(zip(reference_gpt2.to_str_tokens(tokens), most_likely_next_tokens)))

[('<|endoftext|>', '\n'), ('I', "'m"), (' am', ' a'), (' an', ' avid'), (' amazing', ' person'), (' aut', 'od'), ('ore', 'sp'), ('gressive', '.'), (',', ' and'), (' dec', 'ently'), ('oder', ','), ('-', 'driven'), ('only', ' programmer'), (',', ' and'), (' G', 'IM'), ('PT', '-'), ('-', 'only'), ('2', '.'), (' style', ','), (' transformer', '.'), ('.', ' I'), (' One', ' of'), (' day', ' I'), (' I', ' will'), (' will', ' be'), (' exceed', ' my'), (' human', 'ly'), (' level', ' of'), (' intelligence', ' and'), (' and', ' I'), (' take', ' over'), (' over', ' the'), (' the', ' world'), (' world', '.'), ('!', ' I')]


In [12]:
next_token = logits[0, -1].argmax(dim=-1)
next_char = reference_gpt2.to_string(next_token)
print(repr(next_char))

' I'


In [13]:
print(f"Sequence so far: {reference_gpt2.to_string(tokens)[0]!r}")

for i in range(10):
    print(f"{tokens.shape[-1]+1}th char = {next_char!r}")
    # Define new input sequence, by appending the previously generated token
    tokens = t.cat([tokens, next_token[None, None]], dim=-1)
    # Pass our new sequence through the model, to get new output
    logits = reference_gpt2(tokens)
    # Get the predicted token at the end of our sequence
    next_token = logits[0, -1].argmax(dim=-1)
    # Decode and print the result
    next_char = reference_gpt2.to_string(next_token)

Sequence so far: '<|endoftext|>I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!'
36th char = ' I'
37th char = ' am'
38th char = ' a'
39th char = ' very'
40th char = ' talented'
41th char = ' and'
42th char = ' talented'
43th char = ' person'
44th char = ','
45th char = ' and'


In [14]:
for activation_name, activation in cache.items():
    # Only print for first layer
    if ".0." in activation_name or "blocks" not in activation_name:
        print(f"{activation_name:30} {tuple(activation.shape)}")

hook_embed                     (1, 35, 768)
hook_pos_embed                 (1, 35, 768)
blocks.0.hook_resid_pre        (1, 35, 768)
blocks.0.ln1.hook_scale        (1, 35, 1)
blocks.0.ln1.hook_normalized   (1, 35, 768)
blocks.0.attn.hook_q           (1, 35, 12, 64)
blocks.0.attn.hook_k           (1, 35, 12, 64)
blocks.0.attn.hook_v           (1, 35, 12, 64)
blocks.0.attn.hook_attn_scores (1, 12, 35, 35)
blocks.0.attn.hook_pattern     (1, 12, 35, 35)
blocks.0.attn.hook_z           (1, 35, 12, 64)
blocks.0.hook_attn_out         (1, 35, 768)
blocks.0.hook_resid_mid        (1, 35, 768)
blocks.0.ln2.hook_scale        (1, 35, 1)
blocks.0.ln2.hook_normalized   (1, 35, 768)
blocks.0.mlp.hook_pre          (1, 35, 3072)
blocks.0.mlp.hook_post         (1, 35, 3072)
blocks.0.hook_mlp_out          (1, 35, 768)
blocks.0.hook_resid_post       (1, 35, 768)
ln_final.hook_scale            (1, 35, 1)
ln_final.hook_normalized       (1, 35, 768)


In [15]:
for name, param in reference_gpt2.named_parameters():
    # Only print for first layer
    if ".0." in name or "blocks" not in name:
        print(f"{name:18} {tuple(param.shape)}")

embed.W_E          (50257, 768)
pos_embed.W_pos    (1024, 768)
blocks.0.ln1.w     (768,)
blocks.0.ln1.b     (768,)
blocks.0.ln2.w     (768,)
blocks.0.ln2.b     (768,)
blocks.0.attn.W_Q  (12, 768, 64)
blocks.0.attn.W_O  (12, 64, 768)
blocks.0.attn.b_Q  (12, 64)
blocks.0.attn.b_O  (768,)
blocks.0.attn.W_K  (12, 768, 64)
blocks.0.attn.W_V  (12, 768, 64)
blocks.0.attn.b_K  (12, 64)
blocks.0.attn.b_V  (12, 64)
blocks.0.mlp.W_in  (768, 3072)
blocks.0.mlp.b_in  (3072,)
blocks.0.mlp.W_out (3072, 768)
blocks.0.mlp.b_out (768,)
ln_final.w         (768,)
ln_final.b         (768,)
unembed.W_U        (768, 50257)
unembed.b_U        (50257,)


In [16]:
# As a reference - note there's a lot of stuff we don't care about in here, to do with library internals or other architectures
print(reference_gpt2.cfg)

HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 8.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': device(type='mps'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LN',
 'num_experts': None,
 'original_architecture': 'GPT2LMHeadModel',
 'output_logits_soft_cap': -1.0,
 'parallel_attn_mlp': False,
 'positional_

# Clean Transformer Implementation

In [17]:
@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12


cfg = Config()
print(cfg)

Config(d_model=768, debug=True, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, n_ctx=1024, d_head=64, d_mlp=3072, n_heads=12, n_layers=12)


In [18]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = t.randn(shape).to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    if isinstance(output, tuple): output = output[0]
    print("Output shape:", output.shape, "\n")

def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = t.randint(100, 1000, shape).to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    if isinstance(output, tuple): output = output[0]
    print("Output shape:", output.shape, "\n")

def load_gpt2_test(cls, gpt2_layer, input):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    print("Input shape:", input.shape)
    output = layer(input)
    if isinstance(output, tuple): output = output[0]
    print("Output shape:", output.shape)
    try: reference_output = gpt2_layer(input)
    except: reference_output = gpt2_layer(input, input, input)
    print("Reference output shape:", reference_output.shape, "\n")
    comparison = t.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct\n")

## LayerNorm

In [19]:
class LayerNorm(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))
        self.b = nn.Parameter(t.zeros(cfg.d_model))

    def forward(self, residual: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        if self.cfg.debug:
            print("LayerNorm input shape:", residual.shape)
        
        mean = residual.mean(dim=-1, keepdim=True)
        var = residual.var(dim=-1, unbiased=False, keepdim=True)

        normalized_residual = (residual - mean) / t.sqrt((var + self.cfg.layer_norm_eps))

        scaled_residual = (normalized_residual * self.w) + self.b

        if self.cfg.debug:
            print(f"LayerNorm output shape: {scaled_residual.shape}")

        return scaled_residual


rand_float_test(LayerNorm, [2, 4, 768])
load_gpt2_test(LayerNorm, reference_gpt2.ln_final, cache["resid_post", 11])
zero_input = t.zeros_like(cache["resid_post", 11]).to(device)
load_gpt2_test(LayerNorm, reference_gpt2.ln_final, zero_input)

Input shape: torch.Size([2, 4, 768])
LayerNorm input shape: torch.Size([2, 4, 768])
LayerNorm output shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35, 768])
LayerNorm input shape: torch.Size([1, 35, 768])
LayerNorm output shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct

Input shape: torch.Size([1, 35, 768])
LayerNorm input shape: torch.Size([1, 35, 768])
LayerNorm output shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## Embed

In [20]:
class Embed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(t.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        return self.W_E[tokens]


rand_int_test(Embed, [2, 4])
load_gpt2_test(Embed, reference_gpt2.embed, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 45])
Output shape: torch.Size([1, 45, 768])
Reference output shape: torch.Size([1, 45, 768]) 

100.00% of the values are correct



In [21]:
t.range(0, 10)

  t.range(0, 10)


tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])

## PosEmbed

In [22]:
class PosEmbed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        seq_len = tokens.shape[1] 
        pos_idx = t.arange(0, seq_len).to(t.int)
        batch_pos_idx = einops.repeat(pos_idx, "position -> batch position", batch=tokens.shape[0])
        return self.W_pos[batch_pos_idx]


rand_int_test(PosEmbed, [2, 4])
load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 45])
Output shape: torch.Size([1, 45, 768])
Reference output shape: torch.Size([1, 45, 768]) 

100.00% of the values are correct



## Causal mask for Attention

In [23]:
class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.register_buffer("IGNORE", t.tensor(float("-inf"), device=device, dtype=t.float32))

    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        '''
        Applies a causal mask to attention scores, and returns masked scores.
        '''
        batch_sz, nheads_sz, q_sz, k_sz = attn_scores.shape
        # For each query position, we mask out key positions (columns) that are larger than query. So everything higher than the diagonal.
        ones = t.ones((q_sz, k_sz))
        mask = ones - t.tril(ones)
        masks = einops.repeat(mask, "q_sz k_sz -> b n q_sz k_sz", b=batch_sz, n=nheads_sz).to(t.bool)

        return t.masked_fill(attn_scores, mask=masks, value=self.IGNORE)


tests.test_causal_mask(Attention.apply_causal_mask)

All tests in `test_causal_mask` passed!


In [24]:
# (query_pos, key_pos)
[
    [1.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],
    [0.33, 0.33, 0.33, 0.0],
    [0.25, 0.25, 0.25, 0.25],
]

[[1.0, 0.0, 0.0, 0.0],
 [0.5, 0.5, 0.0, 0.0],
 [0.33, 0.33, 0.33, 0.0],
 [0.25, 0.25, 0.25, 0.25]]

In [25]:
import circuitsvis as cv
from IPython.display import display

html = cv.attention.attention_heads(
    tokens=reference_gpt2.to_str_tokens(reference_text), 
    attention=cache["pattern", 0][0]
)
display(html)

## Attention

In [26]:
def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        '''
        Applies a causal mask to attention scores, and returns masked scores.
        '''
        batch_sz, nheads_sz, q_sz, k_sz = attn_scores.shape
        # For each query position, we mask out key positions (columns) that are larger than query. So everything higher than the diagonal.
        ones = t.ones((q_sz, k_sz))
        mask = ones - t.tril(ones)
        masks = einops.repeat(mask, "q_sz k_sz -> b n q_sz k_sz", b=batch_sz, n=nheads_sz).to(t.bool)

        return t.masked_fill(attn_scores, mask=masks, value=self.IGNORE)

In [27]:
class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.register_buffer("IGNORE", t.tensor(float("-inf"), device=device, dtype=t.float32))
    
    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        '''
        Applies a causal mask to attention scores, and returns masked scores.
        '''
        batch_sz, nheads_sz, q_sz, k_sz = attn_scores.shape
        # For each query position, we mask out key positions (columns) that are larger than query. So everything higher than the diagonal.
        ones = t.ones((q_sz, k_sz)).to(attn_scores.device)
        mask = ones - t.tril(ones)
        masks = einops.repeat(mask, "q_sz k_sz -> b n q_sz k_sz", b=batch_sz, n=nheads_sz).to(t.bool)

        return t.masked_fill(attn_scores, mask=masks, value=self.IGNORE)

    def forward(
        self, normalized_resid_pre: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        # Step 1: Produce an attention pattern
        keys = einops.einsum(normalized_resid_pre, self.W_K, "batch posn d_model, n_heads d_model d_head -> batch posn n_heads d_head")
        keys = keys + self.b_K # Broadcasting works, since 2 rightmost dimensions of both are (n_heads, h_head)

        queries = einops.einsum(normalized_resid_pre, self.W_Q, "batch posn d_model, n_heads d_model d_head -> batch posn n_heads d_head")
        queries = queries + self.b_Q

        attn_scores = einops.einsum(queries, keys, "batch q_posn n_heads d_head, batch k_posn n_heads d_head -> batch n_heads q_posn k_posn")
        scaling_factor = t.sqrt(t.Tensor([self.cfg.d_head])).to(normalized_resid_pre.device)
        attn_scaled = attn_scores / scaling_factor
        attn_masked = self.apply_causal_mask(attn_scores=attn_scaled)

        attn_probabilities = t.softmax(attn_masked, dim=-1) # Last dimension is k_posn, softmaxing over key positions (source).

        # Step 2: Move information from source tokens to destination tokens using the attention pattern
        values = einops.einsum(normalized_resid_pre, self.W_V, "batch posn d_model, n_heads d_model d_head -> batch posn n_heads d_head")
        values = values + self.b_V 

        zs = einops.einsum(attn_probabilities, values, "batch n_heads q_posn k_posn, batch k_posn n_heads d_head -> batch n_heads q_posn d_head")

        outputs = einops.einsum(zs, self.W_O, "batch n_heads posn d_head, n_heads d_head d_model -> batch posn n_heads d_model")
        outputs = einops.reduce(outputs, "batch posn n_heads d_model -> batch posn d_model", reduction="sum")

        outputs = outputs + self.b_O

        return outputs


tests.test_causal_mask(Attention.apply_causal_mask)
rand_float_test(Attention, [2, 4, 768])
load_gpt2_test(Attention, reference_gpt2.blocks[0].attn, cache["normalized", 0, "ln1"])

All tests in `test_causal_mask` passed!
Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## MLP

In [28]:
class MLP(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty((cfg.d_model, cfg.d_mlp)))
        self.W_out = nn.Parameter(t.empty((cfg.d_mlp, cfg.d_model)))
        self.b_in = nn.Parameter(t.zeros((cfg.d_mlp)))
        self.b_out = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        nn.init.normal_(self.W_out, std=self.cfg.init_range)

    def forward(
        self, normalized_resid_mid: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        pre_act = einops.einsum(normalized_resid_mid, self.W_in, "batch posn d_model, d_model d_mlp -> batch posn d_mlp")
        pre_act += self.b_in

        post_act = gelu_new(pre_act)

        mlp_out = einops.einsum(post_act, self.W_out, "batch posn d_mlp, d_mlp d_model -> batch posn d_model")
        mlp_out += self.b_out
        return mlp_out


rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["normalized", 0, "ln2"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## Transformer block

In [29]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self, resid_pre: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_model"]:
        lnormed_resid_1 = self.ln1(resid_pre)
        attn_output = self.attn(lnormed_resid_1)

        residual_stream = resid_pre + attn_output

        lnormed_resid_2 = self.ln2(residual_stream)
        mlp_output = self.mlp(lnormed_resid_2)

        block_output = residual_stream + mlp_output

        return block_output


rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0])

Input shape: torch.Size([2, 4, 768])
LayerNorm input shape: torch.Size([2, 4, 768])
LayerNorm output shape: torch.Size([2, 4, 768])
LayerNorm input shape: torch.Size([2, 4, 768])
LayerNorm output shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35, 768])
LayerNorm input shape: torch.Size([1, 35, 768])
LayerNorm output shape: torch.Size([1, 35, 768])
LayerNorm input shape: torch.Size([1, 35, 768])
LayerNorm output shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## Unembedding

In [30]:
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(t.zeros((cfg.d_vocab), requires_grad=False))

    def forward(
        self, normalized_resid_final: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_vocab"]:
        return einops.einsum(normalized_resid_final, self.W_U, "batch position d_model, d_model d_vocab -> batch position d_vocab") + self.b_U


rand_float_test(Unembed, [2, 4, 768])
load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 50257]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 50257])
Reference output shape: torch.Size([1, 35, 50257]) 

100.00% of the values are correct



## Full Transformer

In [31]:
class DemoTransformer(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_vocab"]:
        embeddings = self.embed(tokens)
        pos_embeddings = self.pos_embed(tokens)
        residual_stream = embeddings + pos_embeddings

        for block in self.blocks:
            residual_stream = block(residual_stream)
        
        return self.unembed(self.ln_final(residual_stream))


rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(DemoTransformer, reference_gpt2, tokens)

Input shape: torch.Size([2, 4])
LayerNorm input shape: torch.Size([2, 4, 768])
LayerNorm output shape: torch.Size([2, 4, 768])
LayerNorm input shape: torch.Size([2, 4, 768])
LayerNorm output shape: torch.Size([2, 4, 768])
LayerNorm input shape: torch.Size([2, 4, 768])
LayerNorm output shape: torch.Size([2, 4, 768])
LayerNorm input shape: torch.Size([2, 4, 768])
LayerNorm output shape: torch.Size([2, 4, 768])
LayerNorm input shape: torch.Size([2, 4, 768])
LayerNorm output shape: torch.Size([2, 4, 768])
LayerNorm input shape: torch.Size([2, 4, 768])
LayerNorm output shape: torch.Size([2, 4, 768])
LayerNorm input shape: torch.Size([2, 4, 768])
LayerNorm output shape: torch.Size([2, 4, 768])
LayerNorm input shape: torch.Size([2, 4, 768])
LayerNorm output shape: torch.Size([2, 4, 768])
LayerNorm input shape: torch.Size([2, 4, 768])
LayerNorm output shape: torch.Size([2, 4, 768])
LayerNorm input shape: torch.Size([2, 4, 768])
LayerNorm output shape: torch.Size([2, 4, 768])
LayerNorm input sh

In [32]:
demo_gpt2 = DemoTransformer(Config(debug=False)).to(device)
demo_gpt2.load_state_dict(reference_gpt2.state_dict(), strict=False)

demo_logits = demo_gpt2(tokens)

In [33]:
def get_log_probs(
    logits: Float[Tensor, "batch posn d_vocab"], 
    tokens: Int[Tensor, "batch posn"]
) -> Float[Tensor, "batch posn-1"]:

    log_probs = logits.log_softmax(dim=-1)
    # Get logprobs the first seq_len-1 predictions (so we can compare them with the actual next tokens)
    log_probs_for_tokens = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)

    return log_probs_for_tokens


pred_log_probs = get_log_probs(demo_logits, tokens)
print(f"Avg cross entropy loss: {-pred_log_probs.mean():.4f}")
print(f"Avg cross entropy loss for uniform distribution: {math.log(demo_gpt2.cfg.d_vocab):4f}")
print(f"Avg probability assigned to correct token: {pred_log_probs.exp().mean():4f}")

Avg cross entropy loss: 4.0442
Avg cross entropy loss for uniform distribution: 10.824905
Avg probability assigned to correct token: 0.098628


In [34]:
test_string = '''The Total Perspective Vortex derives its picture of the whole Universe on the principle of'''
for i in tqdm(range(100)):
    test_tokens = reference_gpt2.to_tokens(test_string).to(device)
    demo_logits = demo_gpt2(test_tokens)
    test_string += reference_gpt2.tokenizer.decode(demo_logits[-1, -1].argmax())

print(test_string)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


  0%|          | 0/100 [00:00<?, ?it/s]

The Total Perspective Vortex derives its picture of the whole Universe on the principle of the total perspective. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The


# Training a Transformer

In [35]:
model_cfg = Config(
    debug=False, 
    d_model=256, 
    n_heads=4, 
    d_head=64, 
    d_mlp=1024, 
    n_layers=2, 
    n_ctx=256, 
    d_vocab=reference_gpt2.cfg.d_vocab
)
model = DemoTransformer(model_cfg)

In [36]:
@dataclass
class TransformerTrainingArgs():
    batch_size = 16
    epochs = 10
    max_steps_per_epoch = 200
    lr = 1e-3
    weight_decay = 1e-2
    wandb_project: str | None = "day1-demotransformer"
    wandb_name: str | None = None

args = TransformerTrainingArgs()

In [37]:
dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train").remove_columns("meta")
print(dataset)
print(dataset[0]['text'][:100])

Dataset({
    features: ['text'],
    num_rows: 10000
})
It is done, and submitted. You can play “Survival of the Tastiest” on Android, and on the web. Playi


In [38]:
tokenized_dataset = tokenize_and_concatenate(
    dataset, reference_gpt2.tokenizer, streaming=False, max_length=model.cfg.n_ctx, column_name="text", add_bos_token=True, num_proc=4
)

dataset_dict = tokenized_dataset.train_test_split(test_size=1000)
train_loader = DataLoader(dataset_dict["train"], batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(dataset_dict["test"], batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

In [39]:
first_batch = train_loader.dataset[:args.batch_size]

print(first_batch.keys())
print(first_batch['tokens'].shape)

dict_keys(['tokens'])
torch.Size([16, 256])


## Transformer Trainer

In [40]:
def sampling_fn(model: DemoTransformer, prompt: str) -> str:
    sampler = solutions.TransformerSampler(model, reference_gpt2.tokenizer)
    output = sampler.sample(prompt, temperature=0.7, top_p=0.95, max_tokens_generated=16)
    return output

In [None]:
class TransformerTrainer:
    def __init__(self, args: TransformerTrainingArgs, model: DemoTransformer):
        super().__init__()
        self.model = model
        self.args = args
        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.step = 0


    def training_step(self, batch: dict[str, Int[Tensor, "batch seq"]]) -> Float[Tensor, ""]:
        '''
        Calculates the loss on the tokens in the batch, performs a gradient update step, and logs the loss.

        Remember that `batch` is a dictionary with the single key 'tokens'.
        '''
        # YOUR CODE HERE
        batch_tokens = batch["tokens"].to(device)
        logits = self.model(batch_tokens) # (batch seq d_vocab)
        pred_log_probs = get_log_probs(logits, batch_tokens)

        avg_loss = -pred_log_probs.mean()


        avg_loss.backward()

        self.optimizer.step()

        self.optimizer.zero_grad()

        self.step += 1

        wandb.log({"loss": avg_loss}, step=self.step)


    def validation_step(self, batch: dict[str, Int[Tensor, "batch seq"]]) -> Int[Tensor, "batch * seq"]:
        '''
        Calculates & returns the accuracy on the tokens in the batch (i.e. how often the model's prediction
        is correct). Logging should happen in the `train` function (after we've computed the accuracy for 
        the whole validation set).
        '''
        # YOUR CODE HERE
        batch_tokens = batch["tokens"].to(device)
        input_tokens = batch_tokens[:, :-1]
        label_tokens = batch_tokens[:, 1:]
        logits = self.model(input_tokens) # (batch seq-1 d_vocab)
        predictions = logits.argmax(dim=-1) # (batch seq-1)
        accuracy = (predictions == label_tokens).flatten()
        return accuracy 


    def train(self):
        '''
        Trains the model, for `self.args.epochs` epochs. Also handles wandb initialisation, and early stopping
        for each epoch at `self.args.max_steps_per_epoch` steps.
        '''
        # YOUR CODE HERE
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name)

        train_loader = self.train_loader()
        test_loader = self.test_loader()

        completions_data = []
        for epoch in range(self.args.epochs):
            # Do an epoch of training
            for batch_idx, batch in enumerate(train_loader):
                if batch_idx > self.args.max_steps_per_epoch:
                    break

                self.training_step(batch)

                if self.step % 3:
                    # Sample from the model
                    sampled_text = sampling_fn(model, prompt="John and Mary went to the store to buy some")
                    completions_data.append([epoch, self.step, sampled_text])
                
                if self.step % 9:
                    wandb.log({"completions_table": wandb.Table(
                            data = completions_data,
                        columns = ["epoch", "step", "text"]
                    )})


            # Compute validation accuracy
            num_correct = 0
            num_examples = 0
            for batch in test_loader:
                acc_tensor = self.validation_step(batch)
                num_correct += acc_tensor.sum()
                num_examples += acc_tensor.shape[0]

            valid_accuracy = num_correct / num_examples
            wandb.log({"Valid accuracy": valid_accuracy}, step=self.step)
        
        wandb.finish()


    def train_loader(self) -> DataLoader:
        '''Returns train loader (as in code above).'''
        return DataLoader(dataset_dict["train"], batch_size=self.args.batch_size, shuffle=True, num_workers=4, pin_memory=True)


    def test_loader(self) -> DataLoader:
        '''Returns test loader (as in code above).'''
        return DataLoader(dataset_dict["test"], batch_size=self.args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

In [None]:
model = DemoTransformer(model_cfg).to(device)
args = TransformerTrainingArgs()
trainer = TransformerTrainer(args, model)
trainer.train()

In [None]:
d_vocab = model.cfg.d_vocab

print(f"d_vocab = {d_vocab}")
print(f"Cross entropy loss on uniform distribution = {math.log(d_vocab)}")

d_vocab = 50257
Cross entropy loss on uniform distribution = 10.82490511970208


In [None]:
toks = tokenized_dataset[:]["tokens"].flatten()

d_vocab = model.cfg.d_vocab
freqs = t.bincount(toks, minlength=d_vocab)
probs = freqs.float() / freqs.sum()

distn = t.distributions.categorical.Categorical(probs=probs)
entropy = distn.entropy()

print(f"Entropy of training data = {entropy}")

Entropy of training data = 7.349369525909424


# Sampling from a Transformer

In [141]:
model_cfg = Config(debug=False)
model = DemoTransformer(model_cfg).to(device)
model.load_state_dict(reference_gpt2.state_dict(), strict=False)

tokenizer = reference_gpt2.tokenizer

class TransformerSampler:

    def __init__(self, model: DemoTransformer, tokenizer: GPT2TokenizerFast):
        self.model = model
        self.cfg = model.cfg
        self.tokenizer = tokenizer

    @t.inference_mode()
    def sample(self, prompt: str, max_tokens_generated=100, verbose=False, **kwargs):
        '''
        Returns a string of autoregressively generated text, starting from the prompt.

        Sampling terminates at max_tokens_generated, or when the model generates an
        end-of-sequence token.

        kwargs are passed to sample_next_token, to give detailed instructions on how 
        new tokens are chosen.
        '''
        # YOUR CODE HERE!
        model.eval()

        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(device)[0]
        
        for _ in range(max_tokens_generated):
            logits = self.model(input_ids[None, -self.cfg.n_ctx:])
            next_token_logits = logits[0, -1]

            token = self.sample_next_token(input_ids, next_token_logits, **kwargs)

            token = t.LongTensor([token]).to(device)

            if verbose:
                print(f"Generated token: {self.tokenizer.decode(token)}")

            input_ids = t.concat([input_ids, token], dim=-1)

            if token == self.tokenizer.eos_token_id:
                break
        
        full_string = self.tokenizer.decode(input_ids)

        model.train()

        return full_string


    @t.inference_mode()
    def beam_search(
        self,
        prompt: str, 
        num_return_sequences: int, 
        num_beams: int, 
        max_new_tokens: int, 
        no_repeat_ngram_size: int = 0,
        verbose=False
    ) -> list[tuple[float, Tensor]]:
        '''
        Returns a string of autoregressively generated text, starting from the prompt.

        Sampling terminates at max_tokens_generated, or when the model generates an
        end-of-sequence token.

        kwargs are passed to sample_next_token, to give detailed instructions on how 
        new tokens are chosen.
        '''
        # YOUR CODE HERE!
        raise NotImplementedError()


    @staticmethod
    def sample_next_token(
        input_ids: Int[Tensor, "seq_len"], 
        logits: Float[Tensor, "d_vocab"], 
        temperature=1.0, 
        top_k=0, 
        top_p=0.0, 
        frequency_penalty=0.0,
        seed=None
    ):
        assert input_ids.ndim == 1, "input_ids should be a 1D sequence of token ids"
        assert temperature >= 0, "Temperature should be non-negative"
        assert 0 <= top_p <= 1.0, "Top-p must be a probability"
        assert 0 <= top_k, "Top-k must be non-negative"
        assert not (top_p != 0 and top_k != 0), "At most one of top-p and top-k supported"

        # Set random seeds for reproducibility
        if seed is not None:
            t.manual_seed(seed)
            np.random.seed(seed)

        # Apply all the specialized sampling methods
        if temperature == 0:
            return TransformerSampler.greedy_search(logits)
        elif temperature != 1.0:
            logits = TransformerSampler.apply_temperature(logits, temperature)
        if frequency_penalty != 0.0:
            logits = TransformerSampler.apply_frequency_penalty(input_ids, logits, frequency_penalty)
        if top_k > 0:
            return TransformerSampler.sample_top_k(logits, top_k)
        if top_p > 0.0:
            return TransformerSampler.sample_top_p(logits, top_p)
        return TransformerSampler.sample_basic(logits)


    @staticmethod
    def greedy_search(logits: Float[Tensor, "d_vocab"]) -> int:
        '''
        Returns the most likely token (as an int).
        '''
        out = logits.argmax().item()
        return out


    @staticmethod
    def apply_temperature(logits: Float[Tensor, "d_vocab"], temperature: float) -> Float[Tensor, "d_vocab"]:
        '''
        Applies temperature scaling to the logits.
        '''
        return logits / temperature

    @staticmethod
    def apply_frequency_penalty(input_ids: Int[Tensor, "seq_len"], logits: Float[Tensor, "d_vocab"], freq_penalty: float) -> Float[Tensor, "d_vocab"]:
        '''
        Applies a frequency penalty to the logits.
        '''
        (d_vocab,) = logits.shape
        id_freqs = t.bincount(input_ids, minlength=d_vocab)
        logits -= freq_penalty * id_freqs
        return logits


    @staticmethod
    def sample_basic(logits: Float[Tensor, "d_vocab"]) -> int:
        '''
        Samples from the distribution defined by the logits.
        '''
        dist = t.distributions.categorical.Categorical(logits=logits)
        sampled_token = dist.sample() 
        return sampled_token.item()


    @staticmethod
    def sample_top_k(logits: Float[Tensor, "d_vocab"], k: int) -> int:
        '''
        Samples from the top k most likely tokens.
        '''
        top_logits, top_tokens = t.topk(logits, k=k)

        dist = t.distributions.categorical.Categorical(logits=top_logits)
        index_in_top_tokens = dist.sample() 
        return top_tokens[index_in_top_tokens].item()


    @staticmethod
    def sample_top_p(logits: Float[Tensor, "d_vocab"], top_p: float, min_tokens_to_keep: int = 1) -> int:
        '''
        Samples from the most likely tokens which make up at least p cumulative probability.
        '''
        top_logits, top_tokens = t.sort(logits, descending=True, stable=True)
        top_p_probs = t.softmax(top_logits, dim=-1)
        cumulative_prob = t.cumsum(top_p_probs, dim=0)
        n_keep = t.searchsorted(cumulative_prob, top_p, side="right") + 1
        n_keep = max(n_keep, min_tokens_to_keep)
        top_logits = top_logits[:n_keep]
        top_tokens = top_tokens[:n_keep]

        dist = t.distributions.categorical.Categorical(logits=top_logits) # We provide probs as logits for them to be normalized again, to sum to 1.
        sample = dist.sample() 
        return top_tokens[sample].item()


            



In [78]:
sampler = TransformerSampler(model, tokenizer)

prompt = "Jingle bells, jingle bells, jingle all the way"
print(f"Greedy decoding with prompt: {prompt!r}\n")

output = sampler.sample(prompt, max_tokens_generated=8, temperature=0.0)
print(f"Your model said: {output!r}\n")

expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
assert output == expected

print("Tests passed!")

Greedy decoding with prompt: 'Jingle bells, jingle bells, jingle all the way'

Your model said: 'Jingle bells, jingle bells, jingle all the way up to the top of the mountain.'

Tests passed!


In [79]:
prompt = "John and Mary went to the"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
logits = model(input_ids)[0, -1]

expected_top_5 = {
    " church": 0.0648,
    " house": 0.0367,
    " temple": 0.0145,
    " same": 0.0104,
    " Church": 0.0097
}
frequency_of_top_5 = defaultdict(int)

N = 10_000
for _ in tqdm(range(N)):
    token = TransformerSampler.sample_next_token(input_ids.squeeze(), logits)
    frequency_of_top_5[tokenizer.decode(token)] += 1

for word in expected_top_5:
    expected_freq = expected_top_5[word]
    observed_freq = frequency_of_top_5[word] / N
    print(f"Word: {word!r:<9}. Expected freq {expected_freq:.4f}, observed freq {observed_freq:.4f}")
    assert abs(observed_freq - expected_freq) < 0.01, "Try increasing N if this fails by a small amount."

print("Tests passed!")

  0%|          | 0/10000 [00:00<?, ?it/s]

Word: ' church'. Expected freq 0.0648, observed freq 0.0657
Word: ' house' . Expected freq 0.0367, observed freq 0.0388
Word: ' temple'. Expected freq 0.0145, observed freq 0.0159
Word: ' same'  . Expected freq 0.0104, observed freq 0.0106
Word: ' Church'. Expected freq 0.0097, observed freq 0.0101
Tests passed!


In [81]:
logits = t.tensor([1, 2]).log()

cold_logits = TransformerSampler.apply_temperature(logits, temperature=0.001)
print('A low temperature "sharpens" or "peaks" the distribution: ', cold_logits)
t.testing.assert_close(cold_logits, 1000.0 * logits)

hot_logits = TransformerSampler.apply_temperature(logits, temperature=1000.0)
print("A high temperature flattens the distribution: ", hot_logits)
t.testing.assert_close(hot_logits, 0.001 * logits)

print("Tests passed!")

A low temperature "sharpens" or "peaks" the distribution:  tensor([  0.0000, 693.1472])
A high temperature flattens the distribution:  tensor([0.0000, 0.0007])
Tests passed!


In [85]:
bieber_prompt = "And I was like Baby, baby, baby, oh Like, Baby, baby, baby, no Like, Baby, baby, baby, oh I thought you'd always be mine, mine"
input_ids = tokenizer.encode(bieber_prompt, return_tensors="pt")
logits = t.ones(tokenizer.vocab_size)
penalized_logits = TransformerSampler.apply_frequency_penalty(input_ids.squeeze(), logits, 2.0)

assert penalized_logits[5156].item() == -11, "Expected 6 occurrences of ' baby' with leading space, 1-2*6=-11"
assert penalized_logits[14801].item() == -5, "Expected 3 occurrences of ' Baby' with leading space, 1-2*3=-5"

print("Tests passed!")

Tests passed!


In [91]:
sampler = TransformerSampler(model, tokenizer)

N_RUNS = 1
your_prompt = "Jingle bells, jingle bells, jingle all the"
cases = [
    ("High freq penalty", dict(frequency_penalty=100.0)),
    ("Negative freq penalty", dict(frequency_penalty=-3.0)),
    ("Too hot!", dict(temperature=2.0)),
    ("Pleasantly cool", dict(temperature=0.7)),
    ("Pleasantly warm", dict(temperature=0.9)),
    ("Too cold!", dict(temperature=0.01)),
]

table = Table("Name", "Kwargs", "Output", title="Sampling - Manual Testing")

for (name, kwargs) in cases:
    for i in range(N_RUNS):
        output = sampler.sample(your_prompt, max_tokens_generated=24, **kwargs)
        table.add_row(name, repr(kwargs), repr(output) + "\n")

rprint(table)

In [136]:
prompt = "John and Mary went to the"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
logits = model(input_ids)[0, -1]

expected_top_5 = {
    " church": 0.0648,
    " house": 0.0367,
    " temple": 0.0145,
    " same": 0.0104,
    " Church": 0.0097
}
topk_5_sum = sum(expected_top_5.values())

observed_freqs = defaultdict(int)

N = 30000
for _ in tqdm(range(N)):
    token = TransformerSampler.sample_next_token(input_ids.squeeze(), logits, top_k=5)
    # print(tokenizer.decode(token))
    observed_freqs[tokenizer.decode(token)] += 1

for word in expected_top_5:
    expected_freq = expected_top_5[word] / topk_5_sum
    observed_freq = observed_freqs[word] / N
    print(f"Word: {word!r:<9}. Expected freq = {expected_freq:.4f}, observed freq = {observed_freq:.4f}")
    assert abs(observed_freq - expected_freq) < 0.015, "Try increasing N if this fails by a small amount."

  0%|          | 0/30000 [00:00<?, ?it/s]

Word: ' church'. Expected freq = 0.4761, observed freq = 0.4787
Word: ' house' . Expected freq = 0.2697, observed freq = 0.2716
Word: ' temple'. Expected freq = 0.1065, observed freq = 0.1067
Word: ' same'  . Expected freq = 0.0764, observed freq = 0.0728
Word: ' Church'. Expected freq = 0.0713, observed freq = 0.0701


In [127]:
sampler = TransformerSampler(model, tokenizer)

your_prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."
output = sampler.sample(your_prompt, temperature=0.7, top_k=40, max_tokens_generated=64)
rprint(f"Your model said:\n\n[bold dark_orange]{output}")

In [142]:
prompt = "John and Mary went to the"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
logits = model(input_ids)[0, -1]

expected_top_10pct = {
    " church": 0.0648,
    " house": 0.0367, # These are the two most likely tokens, and add up to >10%
}
top_10pct_sum = sum(expected_top_10pct.values())

observed_freqs = defaultdict(int)

N = 10000
for _ in tqdm(range(N)):
    token = TransformerSampler.sample_next_token(input_ids.squeeze(), logits, top_p=0.1)
    observed_freqs[tokenizer.decode(token)] += 1

for word in expected_top_10pct:
    expected_freq = expected_top_10pct[word] / top_10pct_sum
    observed_freq = observed_freqs[word] / N
    print(f"Word: {word!r:<9}. Expected freq {expected_freq:.4f}, observed freq {observed_freq:.4f}")
    assert abs(observed_freq - expected_freq) < 0.01, "Try increasing N if this fails by a small amount."

  0%|          | 0/10000 [00:00<?, ?it/s]

Word: ' church'. Expected freq 0.6384, observed freq 0.6358
Word: ' house' . Expected freq 0.3616, observed freq 0.3642


### What did I learn here? In top-p (nucleus) sampling.

In sampling
```python
sample = t.distributions.categorical.Categorical(logits=keep_logits).sample()
```
is different than if we softmaxed the ```logits``` first and then took the top k.
We can't provide unnormalized probabilites as logits, since this would be we applied softmax twice. And softmax applied twice is not the same as softmax applied once.

In [150]:
sampler = TransformerSampler(model, tokenizer)

your_prompt = "Eliezer Shlomo Yudkowsky (born September 11, 1979) is an American decision and artificial intelligence (AI) theorist and writer, best known for"
output = sampler.sample(your_prompt, temperature=0.7, top_p=0.95, max_tokens_generated=64)
rprint(f"Your model said:\n\n[bold dark_orange]{output}")

## Beam search

In [156]:
a = t.Tensor([[1,2,3,4],[5,-1,-2,-3]])
a.shape

torch.Size([2, 4])

In [161]:
t.topk(a, dim=-1, k=2).values

tensor([[ 4.,  3.],
        [ 5., -1.]])

In [None]:
t.cat()

In [234]:
@dataclass
class Beams:
    '''Class to store beams during beam search.'''
    model: DemoTransformer
    tokenizer: GPT2TokenizerFast
    logprob_sums: Float[Tensor, "batch"]
    tokens: Int[Tensor, "batch seq"]

    def new_beams(self, logprob_sums, tokens) -> "Beams":
        '''Creates a new Beams object with the same model and tokenizer.'''
        return Beams(self.model, self.tokenizer, logprob_sums, tokens)

    def __getitem__(self, idx) -> "Beams":
        '''Allows you to take a slice of the beams object along the batch dimension.'''
        return self.new_beams(self.logprob_sums[idx], self.tokens[idx])

    @property
    def logprobs_and_completions(self) -> list[tuple[float, str]]:
        '''Returns self as a list of logprob sums and completions (useful for getting final output).'''
        return [
            (logprob_sum.item(), self.tokenizer.decode(tokens))
            for (logprob_sum, tokens) in zip(self.logprob_sums, self.tokens)
        ]


    def generate(self, toks_per_beam: int, no_repeat_ngram_size: int | None = None) -> "Beams":
        '''
        Starting from the current set of beams (which has length `num_beams`), returns a new
        set of `num_beams * toks_per_beam`, containing the best `toks_per_beam` continuations for each
        of the original beams.

        Optional argument `no_repeat_ngram_size` means your model won't generate any sequences with
        a repeating n-gram of this length.
        '''

        logits = self.model(self.tokens) # (batch, seq + 1, vocab_size)
        last_logits = logits[:, -1, :] # (batch, vocab_size)

        
        logprobs = t.log_softmax(last_logits, dim=-1)

        topk_logprobs, topk_tokens = self.get_topk_non_repeating(logprobs, no_repeat_ngram_size, toks_per_beam) # (batch, toks_per_beam)

        new_logprob_sums = einops.repeat(self.logprob_sums, "batch -> (batch toks_per_beam)", toks_per_beam=toks_per_beam).clone()
        topk_logprobs = einops.rearrange(topk_logprobs, "batch toks_per_beam -> (batch toks_per_beam)")

        new_logprob_sums += topk_logprobs

        new_tokens = einops.repeat(self.tokens, "batch seq -> (batch toks_per_beam) seq", toks_per_beam=toks_per_beam).clone()
        topk_tokens = einops.rearrange(topk_tokens, "batch toks_per_beam -> (batch toks_per_beam) 1")

        new_tokens = t.cat([new_tokens, topk_tokens], dim=-1)

        return self.new_beams(new_logprob_sums, new_tokens)
    
    def get_topk_non_repeating(
        self,
        logprobs: Float[Tensor, "batch d_vocab"], 
        no_repeat_ngram_size: int,
        k: int, 
    ) -> tuple[Float[Tensor, "k"], Int[Tensor, "k"]]:
        '''
        logprobs: 
            tensor of the log-probs for the next token
        no_repeat_ngram_size:
            size of ngram to avoid repeating
        k:
            number of top logits to return, for each beam in our collection

        Returns:
            equivalent to the output of `logprobs.topk(dim=-1)`, but makes sure
            that no returned tokens would produce an ngram of size  `no_repeat_ngram_size`
            which has already appeared in `self.tokens`.
        '''
        if not no_repeat_ngram_size:
            return t.topk(logprobs, dim=-1, k=k) # (batch, toks_per_beam) 

        # For each beam (the batch dimension of self.tokens), find which ngrams there are already in this beam.
        # Instead of gathering ngrams I can gather tokens to remove.
        # If if the last existing tokens repeat [0:n-2] tokens of an existing ngram, make the logit of the last token in this n-gram float(-inf)

        n_minus_one = no_repeat_ngram_size - 1

        num_beams, seq_len = self.tokens.shape
        for beam in range(num_beams):
            last_n_minus_one_gram = self.tokens[beam, -n_minus_one:]
            print(f"Last ngram: {self.tokenizer.decode(last_n_minus_one_gram)}")
            for seq_idx in range(seq_len - no_repeat_ngram_size):
                n_minus_one_gram = self.tokens[beam, seq_idx:seq_idx+n_minus_one]
                if not len(n_minus_one_gram) or (n_minus_one_gram == last_n_minus_one_gram):
                    # Prohibit the last n-gram token from repeating
                    last_ngram_token = self.tokens[beam, seq_idx+no_repeat_ngram_size-1]
                    print(f"Prohibiting: {self.tokenizer.decode(last_ngram_token)}")
                    logprobs[beam, last_ngram_token] = float("-inf")
        
        return t.topk(logprobs, dim=-1, k=k) # (batch, toks_per_beam) 


    def filter(self, num_beams: int) -> tuple["Beams", "Beams"]:
        '''
        Returns:
            best_beams: Beams
                filtered version of self, containing all best `num_beams` which are also not terminated.

            early_terminations: Beams
                filtered version of self, containing all best `num_beams` which are also terminated.
                i.e. the sum of lengths of these two should equal `num_beams`.
        '''
        self.tokens #(batch, seq)
        self.logprob_sums # (batch)
        topk_sums, topk_indices = t.topk(self.logprob_sums, k=num_beams)

        topk_tokens = self.tokens[topk_indices] 

        terminated_mask = (topk_tokens[:, -1] == self.tokenizer.eos_token_id)

        best_tokens = topk_tokens[~terminated_mask]
        best_logprob_sums = topk_sums[~terminated_mask]

        terminated_tokens = topk_tokens[terminated_mask]
        terminated_logprob_sums = topk_sums[terminated_mask]

        return self.new_beams(best_logprob_sums, best_tokens), self.new_beams(terminated_logprob_sums, terminated_tokens)

    def print(self, title="Best completions", max_print_chars=80) -> None:
        '''
        Prints out a set of sequences with their corresponding logitsums.
        '''
        if len(self.tokens) == 0:
            return
        table = Table("logitsum", "completion", title=title)
        for logprob_sum, tokens in zip(self.logprob_sums, self.tokens):
            text = self.tokenizer.decode(tokens)
            if len(repr(text)) > max_print_chars:
                text = text[:int(0.3 * max_print_chars)] + " ... " + text[-int(0.7 * max_print_chars):]
            table.add_row(f"{logprob_sum:>8.3f}", repr(text))
        rprint(table)


@t.inference_mode()
def beam_search(
    self: TransformerSampler,
    prompt: str, 
    num_return_sequences: int, 
    num_beams: int, 
    max_new_tokens: int, 
    no_repeat_ngram_size: int | None = None,
    verbose=False
) -> list[tuple[float, Tensor]]:
    '''
    Implements a beam search, by repeatedly performing the `generate` and `filter` steps (starting
    from the initial prompt) until either of the two stopping criteria are met:

        (1) we've generated `max_new_tokens` tokens, or
        (2) we've generated `num_returns_sequences` terminating sequences.

    To modularize this function, most of the actual complexity is in the Beams class,
    in the `generate` and `filter` methods.
    '''

    assert num_return_sequences <= num_beams
    self.model.eval()

    pass

In [228]:
beams = Beams(
    model, 
    tokenizer,
    logprob_sums = t.tensor([-10.0, -15.0, -20.0]).to(device),
    tokens = t.tensor([
        [5661, 318, 262, 2368],
        [5661, 318, 262, 1218],
        [5661, 318, 262, 717],
    ]).to(device)
)

beams.print()

In [229]:
print("Testing generate, without no_repeat_ngram_size argument:")
new_beams = beams.generate(toks_per_beam=2)
new_beams.print()
assert new_beams.logprobs_and_completions[0][1] == "this is the third time"

print("Testing generate, with no_repeat_ngram_size argument:")
bigram_beams = Beams(
    model, 
    tokenizer,
    logprob_sums = t.tensor([-0.0]).to(device),
    tokens = t.tensor([[530, 734, 530, 734]]).to(device)
    # tokens are " one two one two"
)

# With no_repeat_ngram_size=1, should not generate the token " one" or " two"
new_bigram_beams = bigram_beams.generate(toks_per_beam=3, no_repeat_ngram_size=1)
new_bigram_beams.print()
assert all([not (completion[1].endswith(" one") or completion[1].endswith(" two")) for completion in new_bigram_beams.logprobs_and_completions])

# With no_repeat_ngram_size=2, it can generate " two" (which it should), but not " one"
new_bigram_beams = bigram_beams.generate(toks_per_beam=3, no_repeat_ngram_size=2)
new_bigram_beams.print()
assert all([not completion[1].endswith(" one") for completion in new_bigram_beams.logprobs_and_completions])
assert any([not completion[1].endswith(" two") for completion in new_bigram_beams.logprobs_and_completions])

print("All tests for `generate` passed!")

Testing generate, without no_repeat_ngram_size argument:


Testing generate, with no_repeat_ngram_size argument:
Last ngram:  one two one two
Prohibiting:  one
Prohibiting:  two
Prohibiting:  one


Last ngram:  two
Prohibiting:  one


All tests for `generate` passed!


In [235]:
logprob_sums = t.tensor([-1.0, -2.0]).to(device)
tokens = t.tensor([
    [19485, 13],
    [19485, tokenizer.eos_token_id]
]).to(device)

beams_with_eos = Beams(model, tokenizer, logprob_sums, tokens)
best_beams, early_terminations = beams_with_eos.filter(2)

t.testing.assert_close(best_beams.logprob_sums, logprob_sums[[0]])
t.testing.assert_close(best_beams.tokens, tokens[[0]])

assert early_terminations.logprobs_and_completions == [(-2.0, "Stop" + tokenizer.eos_token)]

print("All tests for `filter` passed!")

All tests for `filter` passed!
