In [1]:
import os; os.environ['ACCELERATE_DISABLE_RICH'] = "1"
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import sys
import einops
from dataclasses import dataclass
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new, tokenize_and_concatenate
import torch as t
from torch import Tensor
import torch.nn as nn
import numpy as np
import math
from tqdm.notebook import tqdm
from typing import Tuple, List, Optional, Dict
from jaxtyping import Float, Int
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from collections import defaultdict
from rich.table import Table
from rich import print as rprint
import datasets
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb
from pathlib import Path
import webbrowser

# Make sure exercises are in the path.
# The exercises directory is derived from the current working directory: everything before the
# first occurrence of `chapter` in the cwd is kept and "<chapter>/exercises" is appended.
# NOTE(review): if the cwd does not contain `chapter`, split() returns the whole cwd and the
# path is built underneath it — presumably the notebook is always launched from inside the chapter.
chapter = r"chapter1_transformers"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = (exercises_dir / "part1_transformer_from_scratch").resolve()
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

# Must come after the sys.path patch above (presumably plotly_utils lives in exercises_dir)
from plotly_utils import imshow
# import part1_transformer_from_scratch.solutions as solutions

# Use the GPU when one is available, otherwise fall back to the CPU
device = t.device("cuda" if t.cuda.is_available() else "cpu")

MAIN = __name__ == '__main__'

# Pretrained GPT-2 small used as the reference implementation throughout the notebook;
# the fold_ln / center_unembed / center_writing_weights weight-processing options are all disabled.
reference_gpt2 = HookedTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)

  warn(f"Failed to load image Python extension: {e}")


Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


# Input and output

## Tokenisation

In GPT-2, the End of Sequence (EOS), Beginning of Sequence (BOS) and Padding (PAD) tokens are all the same token, `<|endoftext|>`, with index 50256.

In [5]:
# (token string, token id) pairs from the tokenizer vocabulary, ordered by ascending token id
vocab = sorted(reference_gpt2.tokenizer.vocab.items(), key=lambda pair: pair[1])

In [7]:
# Calling the model directly on a string returns its logits (one row per token position)
reference_gpt2("Hello World")

tensor([[[-43.4317, -39.8365, -43.0660,  ..., -54.0877, -54.3452, -42.3645],
         [-61.0647, -68.0792, -72.3322,  ..., -73.9381, -74.3353, -68.3130],
         [-43.5787, -50.2103, -52.9083,  ..., -57.6991, -57.5098, -49.6761]]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [8]:
# Token ids for the string; by default the BOS token (50256) is prepended
reference_gpt2.to_tokens("Hello World")

tensor([[50256, 15496,  2159]], device='cuda:0')

In [9]:
# Same string without the prepended BOS token
reference_gpt2.to_tokens("Hello World", prepend_bos=False)

tensor([[15496,  2159]], device='cuda:0')

In [11]:
# Capitalisation, a leading space and long digit strings all change how text is split into tokens
for text in ("Ralph", " Ralph", " ralph", "ralph", "56873+3184623=123456789-1000000000"):
    print(reference_gpt2.to_str_tokens(text))

['<|endoftext|>', 'R', 'alph']
['<|endoftext|>', ' Ralph']
['<|endoftext|>', ' r', 'alph']
['<|endoftext|>', 'ral', 'ph']
['<|endoftext|>', '568', '73', '+', '318', '46', '23', '=', '123', '45', '67', '89', '-', '1', '000000', '000']


## Text generation

In [24]:
# Tokenise a short reference sentence and show the ids, their shape and the string tokens
reference_text = "I am an amazing transformer."
tokens = reference_gpt2.to_tokens(reference_text).to(device)
print(tokens)
print(tokens.shape)  # (batch, seq_len)
print(reference_gpt2.to_str_tokens(tokens))

tensor([[50256,    40,   716,   281,  4998, 47385,    13]], device='cuda:0')
torch.Size([1, 7])
['<|endoftext|>', 'I', ' am', ' an', ' amazing', ' transformer', '.']


In [25]:
# Forward pass that also returns a cache of intermediate activations
logits, cache = reference_gpt2.run_with_cache(tokens)
print(logits.shape)  # (batch, seq_len, d_vocab)

torch.Size([1, 7, 50257])


In [26]:
# Softmax over the vocabulary axis turns logits into a distribution at every position
probs = logits.softmax(dim=-1)
print(probs.shape)

torch.Size([1, 7, 50257])


In [27]:
# Greedy choice: the highest-logit token at the final position, decoded back to text
next_token = logits[0, -1].argmax(dim=-1)
next_char = reference_gpt2.to_string(next_token)
print(repr(next_char))

' I'


In [28]:
# Highest-logit token at every position of the first batch element, decoded to strings
next_tokens = reference_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1)[0])

In [29]:
# Pair each input token with the model's top prediction at that position
list(zip(reference_gpt2.to_str_tokens(tokens), next_tokens))

[('<|endoftext|>', '\n'),
 ('I', "'m"),
 (' am', ' a'),
 (' an', ' avid'),
 (' amazing', ' person'),
 (' transformer', '.'),
 ('.', ' I')]

## Generate output

In [33]:
# The greedy next-token id computed above (a 0-dim tensor)
next_token

tensor(314, device='cuda:0')

In [58]:
from einops import pack, rearrange

# Greedy decoding: ten times over, append the single most likely next token to the prompt
prompt = "import numpy as"
print(prompt)
tokens = reference_gpt2.to_tokens(prompt).to(device)

for step in range(10):
    logits: Float[Tensor, "1 seq_len n_toks"] = reference_gpt2(tokens)
    # Highest-logit token at the final position
    next_token = logits[0, -1].argmax(dim=-1)
    prompt = prompt + reference_gpt2.to_string(next_token)
    print(prompt)
    # Append the new token, reshaped to (1, 1), along the sequence axis
    tokens = t.cat([tokens, next_token.reshape(1, 1)], dim=-1)

import numpy as
import numpy as np
import numpy as np import
import numpy as np import time
import numpy as np import time import
import numpy as np import time import time
import numpy as np import time import time.
import numpy as np import time import time.time
import numpy as np import time import time.time import
import numpy as np import time import time.time import n
import numpy as np import time import time.time import numpy


In [41]:
# NOTE(review): this cell's execution count (41) predates the generation loop above (58), and
# the displayed value matches the earlier `reference_text` tokens rather than the loop's final
# `tokens` — the output is stale and will differ after Restart & Run All.
tokens

tensor([[50256,    40,   716,   281,  4998, 47385,    13]], device='cuda:0')

In [42]:
# Reshape the 0-dim token id to shape (1, 1) so it can be appended to a (batch, seq) tensor.
# NOTE(review): executed out of order (count 42), so the shown value comes from an earlier run.
rearrange(next_token, "-> () ()")

tensor([[314]], device='cuda:0')

# Intuitions / notes

- MLP as key-value pairs: each neuron fires when the dot product of its 'key' (its input weight vector) with its input is >0, and then writes its 'value' (its output weight vector), scaled by how strongly the key matched
- MLP as knowledge storage
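- Unembedding: sometimes we set $W_E = W_U^T$ (tied embeddings) -- this seems OK at first, until you realise that the direct path involving only embedding and unembedding ($W_E W_U$) should approximate _bigram_ frequencies, which are asymmetric, whereas $W_E W_E^T$ is necessarily symmetric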
    - Note: should this necessarily hold?
- Positional encoding
    - We add rather than concatenate... residual stream is shared memory and under significant superposition [??]

In [60]:
# Example dimensions; the model sizes match the shapes printed for the reference model below
batch = 1
position = 35
n_layers = 12
n_heads = 12
d_model = 768
d_head = d_model // n_heads  # 64
d_mlp = 4 * d_model  # 3072

In [61]:
# Re-run the reference model on a short prompt; this rebinds `logits` and `cache`, and the
# cached activations are the inputs to the layer-by-layer comparison tests further down
logits, cache = reference_gpt2.run_with_cache("I am an amazing transformer!")

In [65]:
# One line per cached activation: hook name (padded to 30 characters) followed by its shape
for hook_name, cached_act in cache.items():
    act_shape = tuple(cached_act.shape)
    print(f"{hook_name:30}", act_shape)

hook_embed                     (1, 7, 768)
hook_pos_embed                 (1, 7, 768)
blocks.0.hook_resid_pre        (1, 7, 768)
blocks.0.ln1.hook_scale        (1, 7, 1)
blocks.0.ln1.hook_normalized   (1, 7, 768)
blocks.0.attn.hook_q           (1, 7, 12, 64)
blocks.0.attn.hook_k           (1, 7, 12, 64)
blocks.0.attn.hook_v           (1, 7, 12, 64)
blocks.0.attn.hook_attn_scores (1, 12, 7, 7)
blocks.0.attn.hook_pattern     (1, 12, 7, 7)
blocks.0.attn.hook_z           (1, 7, 12, 64)
blocks.0.hook_attn_out         (1, 7, 768)
blocks.0.hook_resid_mid        (1, 7, 768)
blocks.0.hook_mlp_in           (1, 7, 768)
blocks.0.ln2.hook_scale        (1, 7, 1)
blocks.0.ln2.hook_normalized   (1, 7, 768)
blocks.0.mlp.hook_pre          (1, 7, 3072)
blocks.0.mlp.hook_post         (1, 7, 3072)
blocks.0.hook_mlp_out          (1, 7, 768)
blocks.0.hook_resid_post       (1, 7, 768)
blocks.1.hook_resid_pre        (1, 7, 768)
blocks.1.ln1.hook_scale        (1, 7, 1)
blocks.1.ln1.hook_normalized   (1, 7, 768)

In [67]:
# List parameter names and shapes, showing only block 0 out of the repeated blocks
for param_name, weight in reference_gpt2.named_parameters():
    belongs_to_later_block = "blocks" in param_name and ".0." not in param_name
    if belongs_to_later_block:
        continue
    print(f"{param_name:18} {tuple(weight.shape)}")

embed.W_E          (50257, 768)
pos_embed.W_pos    (1024, 768)
blocks.0.ln1.w     (768,)
blocks.0.ln1.b     (768,)
blocks.0.ln2.w     (768,)
blocks.0.ln2.b     (768,)
blocks.0.attn.W_Q  (12, 768, 64)
blocks.0.attn.W_K  (12, 768, 64)
blocks.0.attn.W_V  (12, 768, 64)
blocks.0.attn.W_O  (12, 64, 768)
blocks.0.attn.b_Q  (12, 64)
blocks.0.attn.b_K  (12, 64)
blocks.0.attn.b_V  (12, 64)
blocks.0.attn.b_O  (768,)
blocks.0.mlp.W_in  (768, 3072)
blocks.0.mlp.b_in  (3072,)
blocks.0.mlp.W_out (3072, 768)
blocks.0.mlp.b_out (768,)
ln_final.w         (768,)
ln_final.b         (768,)
unembed.W_U        (768, 50257)
unembed.b_U        (50257,)


In [69]:
@dataclass
class Config:
    """Hyperparameters shared by every module of the from-scratch transformer.

    The defaults agree with the parameter shapes printed for the reference model above.
    """
    d_model: int = 768  # size of each residual-stream vector (LayerNorm w/b, second dim of W_E / W_pos)
    debug: bool = True  # not read by any module visible in this notebook
    layer_norm_eps: float = 1e-5  # added to the variance inside LayerNorm before the square root
    d_vocab: int = 50257  # number of rows of W_E / columns of W_U
    init_range: float = 0.02  # std of the normal distribution used to initialise weight matrices
    n_ctx: int = 1024  # number of rows of W_pos
    d_head: int = 64  # per-head dimension of W_Q / W_K / W_V
    d_mlp: int = 3072  # hidden width of the MLP
    n_heads: int = 12  # number of attention heads
    n_layers: int = 12  # number of TransformerBlocks in DemoTransformer


# Default configuration; printing it shows every field via the dataclass repr
cfg = Config()
print(cfg)

Config(d_model=768, debug=True, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, n_ctx=1024, d_head=64, d_mlp=3072, n_heads=12, n_layers=12)


# Implementation

In [70]:
def _run_random_input_test(cls, make_input):
    """Build `cls` from a debug Config, run it on a freshly made input and print both shapes.

    Args:
        cls: module class whose constructor takes a single `Config`.
        make_input: zero-argument callable returning the random input tensor. It is called
            after the layer has been constructed, so the order of random draws is the same
            as when the two public helpers each had their own copy of this body.
    """
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = make_input().to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    # Layers that return a tuple are judged on their first element only
    if isinstance(output, tuple): output = output[0]
    print("Output shape:", output.shape, "\n")


def rand_float_test(cls, shape):
    """Smoke-test `cls` on standard-normal float input of the given shape (prints shapes only)."""
    _run_random_input_test(cls, lambda: t.randn(shape))


def rand_int_test(cls, shape):
    """Smoke-test `cls` on random integer input in [100, 1000) of the given shape (prints shapes only)."""
    _run_random_input_test(cls, lambda: t.randint(100, 1000, shape))

def load_gpt2_test(cls, gpt2_layer, input):
    """Load the reference layer's weights into `cls` and report how closely the outputs agree.

    Args:
        cls: module class whose constructor takes a single `Config`.
        gpt2_layer: the corresponding module of the reference model; its state dict is
            copied into the new layer and its output is used as ground truth.
        input: tensor fed to both layers.

    Prints the input/output/reference shapes and the fraction of output elements that are
    close to the reference (atol=1e-4, rtol=1e-3). Nothing is returned.
    """
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    # strict=False: keys present on only one side are ignored instead of raising
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    print("Input shape:", input.shape)
    output = layer(input)
    if isinstance(output, tuple): output = output[0]
    print("Output shape:", output.shape)
    try:
        reference_output = gpt2_layer(input)
    except TypeError:
        # Signature mismatch: some reference layers take three inputs instead of one, so the
        # same tensor is passed for each. Only TypeError is caught so that genuine failures
        # (and KeyboardInterrupt) are no longer swallowed by a bare `except`.
        reference_output = gpt2_layer(input, input, input)
    print("Reference output shape:", reference_output.shape, "\n")
    comparison = t.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct\n")

### LayerNorm

In [89]:
from einops import reduce
import torch


class LayerNorm(nn.Module):
    """Layer normalisation over the last (d_model) axis with a learned elementwise scale and shift.

    Parameters:
        w: scale, initialised to ones, shape (d_model,).
        b: shift, initialised to zeros, shape (d_model,).
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))
        self.b = nn.Parameter(t.zeros(cfg.d_model))

    def forward(
        self, residual: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        """Normalise each d_model vector to mean 0 / variance 1, then apply `w` and `b`.

        Uses the biased variance (unbiased=False) and adds `cfg.layer_norm_eps` to it before
        the square root. Only tensor methods with keepdim=True are used, so the cell does not
        depend on `rearrange` or `torch` having been imported by another cell, and any number
        of leading axes is accepted.
        """
        mean = residual.mean(dim=-1, keepdim=True)
        var = residual.var(dim=-1, keepdim=True, unbiased=False)
        normalized = (residual - mean) / (var + self.cfg.layer_norm_eps).sqrt()
        return normalized * self.w + self.b


# Shape check on random input, then compare against the reference model's final LayerNorm
# using the cached residual stream after block 11
rand_float_test(LayerNorm, [2, 4, 768])
load_gpt2_test(LayerNorm, reference_gpt2.ln_final, cache["resid_post", 11])


Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 7, 768])
Output shape: torch.Size([1, 7, 768])
Reference output shape: torch.Size([1, 7, 768]) 

100.00% of the values are correct



### Embedding

In [105]:
import torch.nn.functional as F

class Embed(nn.Module):
    """Token embedding: a learned table `W_E` with one d_model-sized row per vocabulary entry."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # Fill the table from N(0, init_range^2) first, then register it as a parameter
        embedding_table = t.empty((cfg.d_vocab, cfg.d_model))
        nn.init.normal_(embedding_table, std=cfg.init_range)
        self.W_E = nn.Parameter(embedding_table)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        """Return the row of `W_E` selected by each token id."""
        embedded = self.W_E[tokens]
        return embedded


# Shape check on random token ids, then compare against the reference embedding.
# NOTE(review): `tokens` is whatever the global currently holds — the (1, 15) shape printed
# below suggests it is the tensor left behind by the generation loop, not `reference_text`.
rand_int_test(Embed, [2, 4])
load_gpt2_test(Embed, reference_gpt2.embed, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 15])
Output shape: torch.Size([1, 15, 768])
Reference output shape: torch.Size([1, 15, 768]) 

100.00% of the values are correct



### Positional encoding

In [109]:
from einops import repeat


class PosEmbed(nn.Module):
    """Learned absolute positional embedding: row i of `W_pos` is the vector for position i."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        """Return the first `position` rows of `W_pos`, broadcast across the batch.

        Only the shape of `tokens` is used; the token ids themselves are ignored.

        Raises:
            IndexError: if the sequence is longer than the number of rows in `W_pos`.
        """
        batch_size, seq_len = tokens.shape
        max_len = self.W_pos.shape[0]
        # A slice would silently truncate an over-long sequence, so keep the failure explicit
        if seq_len > max_len:
            raise IndexError(
                f"Sequence length {seq_len} exceeds the maximum context length {max_len}"
            )
        # Slicing avoids building an index tensor on the CPU on every forward pass;
        # expand() broadcasts over the batch without copying.
        return self.W_pos[:seq_len].unsqueeze(0).expand(batch_size, -1, -1)


# Shape check on random token ids, then compare against the reference positional embedding
# (uses the current global `tokens`)
rand_int_test(PosEmbed, [2, 4])
load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 15])
Output shape: torch.Size([1, 15, 768])
Reference output shape: torch.Size([1, 15, 768]) 

100.00% of the values are correct



### Attention

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/transformer-attn-30.png" width="800">

In [122]:
from einops import einsum


class Attention(nn.Module):
    """Multi-head causal self-attention with per-head Q/K/V/O weight tensors.

    Parameters:
        W_Q, W_K, W_V: (n_heads, d_model, d_head), initialised from N(0, init_range^2).
        W_O: (n_heads, d_head, d_model), initialised from N(0, init_range^2).
        b_Q, b_K, b_V: (n_heads, d_head), zeros.  b_O: (d_model,), zeros.
    Buffers:
        IGNORE: scalar -1e5 written into masked (future) positions before the softmax.
    """
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        # The buffer is created without an explicit device: buffers follow the module through
        # `.to(...)`, so the class no longer depends on the notebook-level `device` global.
        self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32))

    def forward(
        self, normalized_resid_pre: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        """Apply causal multi-head attention to the (already layer-normed) residual stream."""
        # Per-head query and key projections
        queries: Float[Tensor, "b q_pos n_h d_h"] = (
            einsum(
                self.W_Q,
                normalized_resid_pre,
                "n_h d_m d_h, b q_pos d_m -> b q_pos n_h d_h",
            )
            + self.b_Q
        )
        keys: Float[Tensor, "b k_pos n_h d_h"] = (
            einsum(
                self.W_K,
                normalized_resid_pre,
                "n_h d_m d_h, b k_pos d_m -> b k_pos n_h d_h",
            )
            + self.b_K
        )

        # Scaled dot-product scores, causally masked, then softmax over key positions
        attn_scores: Float[Tensor, "b n_h q_pos k_pos"] = einsum(
            queries, keys, "b q_pos n_h d_h, b k_pos n_h d_h -> b n_h q_pos k_pos"
        )
        masked_scores = self.apply_causal_mask(attn_scores / math.sqrt(self.cfg.d_head))
        attn_pattern: Float[Tensor, "b n_h q_pos k_pos"] = masked_scores.softmax(dim=-1)

        # Per-head values, mixed across key positions according to the attention pattern
        values: Float[Tensor, "b k_pos n_h d_h"] = (
            einsum(
                self.W_V,
                normalized_resid_pre,
                "n_h d_m d_h, b k_pos d_m -> b k_pos n_h d_h",
            )
            + self.b_V
        )
        weighted_values: Float[Tensor, "b q_pos n_h d_h"] = einsum(
            attn_pattern, values, "b n_h q_pos k_pos, b k_pos n_h d_h -> b q_pos n_h d_h"
        )

        # Project every head back to d_model and sum over heads in one contraction, which
        # avoids materialising the (b, q_pos, n_h, d_model) per-head output tensor.
        attn_out: Float[Tensor, "b q_pos d_m"] = (
            einsum(
                self.W_O,
                weighted_values,
                "n_h d_h d_m, b q_pos n_h d_h -> b q_pos d_m",
            )
            + self.b_O
        )
        return attn_out

    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        """
        Applies a causal mask to attention scores, and returns masked scores.

        Entries with key_pos > query_pos are replaced by `self.IGNORE`; all other entries are
        returned unchanged. `masked_fill` is used rather than multiplying by a 0/1 mask so
        that an infinite or NaN score in a masked position cannot leak through as NaN.
        """
        q_len, k_len = attn_scores.shape[-2:]
        future_positions = t.ones(q_len, k_len, device=attn_scores.device).triu(diagonal=1).bool()
        return attn_scores.masked_fill(future_positions, self.IGNORE)


# Shape check on random input, then compare against the reference model's block-0 attention
# using the cached ln1-normalised input
rand_float_test(Attention, [2, 4, 768])
load_gpt2_test(Attention, reference_gpt2.blocks[0].attn, cache["normalized", 0, "ln1"])


Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 7, 768])
Output shape: torch.Size([1, 7, 768])
Reference output shape: torch.Size([1, 7, 768]) 

100.00% of the values are correct



In [111]:
import circuitsvis as cv
from IPython.display import display

# Visualisation of the cached layer-0 attention patterns for the first batch element.
# NOTE(review): the token labels come from `reference_text` (which ends in "."), whereas `cache`
# was last produced from "I am an amazing transformer!" — the final label may not match the
# cached run. Confirm which text is intended.
html = cv.attention.attention_patterns(
    tokens=reference_gpt2.to_str_tokens(reference_text), 
    attention=cache["pattern", 0][0]
)
display(html)

### MLP

In [124]:
class MLP(nn.Module):
    """Position-wise feed-forward layer: d_model -> d_mlp -> d_model with a `gelu_new` non-linearity."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty((cfg.d_model, cfg.d_mlp)))
        self.W_out = nn.Parameter(t.empty((cfg.d_mlp, cfg.d_model)))
        self.b_in = nn.Parameter(t.zeros((cfg.d_mlp)))
        self.b_out = nn.Parameter(t.zeros((cfg.d_model)))
        # Weight matrices are drawn from N(0, init_range^2); biases stay at zero
        for weight in (self.W_in, self.W_out):
            nn.init.normal_(weight, std=cfg.init_range)

    def forward(
        self, normalized_resid_mid: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        """Project up to d_mlp, apply `gelu_new`, and project back down to d_model."""
        pre_activation: Float[Tensor, "b p_n d_h"] = (
            einsum(self.W_in, normalized_resid_mid, "d_m d_h, b p_n d_m -> b p_n d_h")
            + self.b_in
        )
        post_activation: Float[Tensor, "b p_n d_h"] = gelu_new(pre_activation)
        projected = einsum(self.W_out, post_activation, "d_h d_m, b p_n d_h -> b p_n d_m")
        return projected + self.b_out


# Shape check on random input, then compare against the reference model's block-0 MLP
# using the cached ln2-normalised input
rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["normalized", 0, "ln2"])


Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 7, 768])
Output shape: torch.Size([1, 7, 768])
Reference output shape: torch.Size([1, 7, 768]) 

100.00% of the values are correct



### Transformer block

In [125]:
class TransformerBlock(nn.Module):
    """One transformer layer: attention then MLP, each applied to a layer-normed copy of the
    residual stream and added back onto it."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # Sub-modules are created in the order they are applied: ln1 -> attn -> ln2 -> mlp
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self, resid_pre: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_model"]:
        """Return the residual stream after the attention and MLP updates."""
        resid_mid = resid_pre + self.attn(self.ln1(resid_pre))
        return resid_mid + self.mlp(self.ln2(resid_mid))

# Shape check on random input, then compare against the reference model's block 0
# using the cached residual stream entering that block
rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 7, 768])
Output shape: torch.Size([1, 7, 768])
Reference output shape: torch.Size([1, 7, 768]) 

100.00% of the values are correct



### Unembed

In [126]:
class Unembed(nn.Module):
    """Unembedding: a linear map from the final residual stream to vocabulary logits.

    Parameters:
        W_U: (d_model, d_vocab), initialised from N(0, init_range^2).
        b_U: (d_vocab,), zeros, frozen (requires_grad=False).
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        # requires_grad must be given to nn.Parameter: passing it to t.zeros has no effect,
        # because nn.Parameter defaults to requires_grad=True and would leave the bias trainable.
        self.b_U = nn.Parameter(t.zeros(cfg.d_vocab), requires_grad=False)

    def forward(
        self, normalized_resid_final: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_vocab"]:
        """Return logits over the vocabulary for every position."""
        return einsum(self.W_U, normalized_resid_final, "d_m d_v, b p d_m -> b p d_v") + self.b_U


# Shape check on random input, then compare against the reference unembedding
# using the cached output of the final LayerNorm
rand_float_test(Unembed, [2, 4, 768])
load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 50257]) 

Input shape: torch.Size([1, 7, 768])
Output shape: torch.Size([1, 7, 50257])
Reference output shape: torch.Size([1, 7, 50257]) 

100.00% of the values are correct



### Demo transformer

In [129]:
class DemoTransformer(nn.Module):
    """Full model: token + positional embeddings, `n_layers` transformer blocks, a final
    LayerNorm and the unembedding to vocabulary logits."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList(TransformerBlock(cfg) for _ in range(cfg.n_layers))
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_vocab"]:
        """Map token ids to logits over the vocabulary at every position."""
        # The residual stream starts as the sum of the token and positional embeddings
        residual = self.embed(tokens) + self.pos_embed(tokens)
        for layer in self.blocks:
            residual = layer(residual)
        logits = self.unembed(self.ln_final(residual))
        return logits

# Shape check on random token ids, then compare the whole model against the reference model
# (uses the current global `tokens`)
rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(DemoTransformer, reference_gpt2, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 50257]) 

Input shape: torch.Size([1, 15])
Output shape: torch.Size([1, 15, 50257])
Reference output shape: torch.Size([1, 15, 50257]) 

100.00% of the values are correct



In [130]:
# Build the from-scratch model and copy in the reference model's weights.
# strict=False: keys present on only one side are ignored instead of raising.
demo_gpt2 = DemoTransformer(Config(debug=False)).to(device)
demo_gpt2.load_state_dict(reference_gpt2.state_dict(), strict=False)

# Logits of the from-scratch model on the current global `tokens`
demo_logits = demo_gpt2(tokens)

In [131]:
def get_log_probs(
    logits: Float[Tensor, "batch posn d_vocab"], 
    tokens: Int[Tensor, "batch posn"]
) -> Float[Tensor, "batch posn-1"]:
    """Log-probability the model assigned to each token that actually came next.

    The prediction at position i is scored against the token at position i+1, so the final
    prediction and the first token have no partner and are dropped.
    """
    # Log-probabilities for the first seq_len-1 predictions
    predictions = logits.log_softmax(dim=-1)[:, :-1]
    # The actual next tokens, with a trailing axis so they can index the vocabulary dimension
    targets = tokens[:, 1:, None]
    return predictions.gather(dim=-1, index=targets)[..., 0]


# Score the from-scratch model's predictions against the actual next tokens.
# All three numbers use `.4f`; two of the format specs previously read `:4f`, which sets a
# minimum width of 4 rather than 4 decimal places.
pred_log_probs = get_log_probs(demo_logits, tokens)
print(f"Avg cross entropy loss: {-pred_log_probs.mean():.4f}")
print(f"Avg cross entropy loss for uniform distribution: {math.log(demo_gpt2.cfg.d_vocab):.4f}")
print(f"Avg probability assigned to correct token: {pred_log_probs.exp().mean():.4f}")

Avg cross entropy loss: 1.7251
Avg cross entropy loss for uniform distribution: 10.824905
Avg probability assigned to correct token: 0.472316


In [132]:
# Greedy generation with the from-scratch model: re-tokenise the growing string, take the
# highest-logit token at the final position and append its text, 100 times over.
test_string = '''The Total Perspective Vortex derives its picture of the whole Universe on the principle of'''
# Pure inference, so autograd is switched off to avoid building a graph on every step
with t.no_grad():
    for _step in tqdm(range(100)):
        test_tokens = reference_gpt2.to_tokens(test_string).to(device)
        demo_logits = demo_gpt2(test_tokens)
        test_string += reference_gpt2.tokenizer.decode(demo_logits[-1, -1].argmax())

print(test_string)

  0%|          | 0/100 [00:00<?, ?it/s]

The Total Perspective Vortex derives its picture of the whole Universe on the principle of the total perspective. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The
