In [1]:
import torch
import torch.nn as nn
import einops
from dataclasses import dataclass
import math
from easy_transformer import EasyTransformer

: 

In [118]:
reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)

Moving model to device:  cpu
Finished loading pretrained model gpt2-small into EasyTransformer!


In [119]:
@dataclass
class Config:
    d_model: int = 768
    vocab_size: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    ln_eps: float = 1e-5
    n_heads: int = 12
    d_head: int = 64
    d_mlp: int = 3072
    n_layers: int = 12
    debug: bool = True

cfg = Config()
print(cfg)

Config(d_model=768, vocab_size=50257, init_range=0.02, n_ctx=1024, ln_eps=1e-05, n_heads=12, d_head=64, d_mlp=3072, n_layers=12, debug=True)


In [120]:
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
tokens = reference_gpt2.to_tokens(reference_text)
print(tokens.shape)

torch.Size([1, 35])


In [121]:
logits_reference = reference_gpt2(tokens)
print(logits_reference.shape)


torch.Size([1, 35, 50257])


In [122]:
logits, cache = reference_gpt2.run_with_cache(tokens)
print(logits.shape)

torch.Size([1, 35, 50257])


In [123]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg)
    random_input = torch.randn(shape)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output

def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg)
    random_input = torch.randint(100, 1000, shape)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output

def load_gpt2_test(cls, gpt2_layer, input_name, cache_dict=cache.cache_dict):
    cfg = Config(debug=True)
    layer = cls(cfg)
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    # Allow inputs of strings or tensors
    if isinstance(input_name, str):
        reference_input = cache_dict[input_name]
    else:
        reference_input = input_name
    print("Input shape:", reference_input.shape)
    output = layer(reference_input)
    print("Output shape:", output.shape)
    reference_output = gpt2_layer(reference_input)
    print("Reference output shape:", reference_output.shape)

    comparison = torch.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct")
    return output

Transformes building blocks are:

* Embeddings
* Positional embeddings
* LayerNorm
* Attention
* MLP
* Unembed

## Tokens and Embeddings

### Tokens
* Tokens are subunits of text that are used to build the vocabulary of the model. Tokens are built with tokenizers, which are usually a BPE tokenizer.
* BPE or byte pair encoding is an algorithm that builds a vocab based on the most frequent pairs of tokens in the text. It starts with a vocab of single characters and then merges the most frequent pairs of tokens until the desired vocab size is reached.
* After the vocab is built, you'll have a mapping from tokens to integers.

So given a text: "Let's play this new game"

The tokenizer will output a list of integers: [1, 4, 39, 2, 45]

Where 1 is the token for "Let's", 4 is the token for "play", 39 is the token for "this", 2 is the token for "new" and 45 is the token for "game". We're using a simple tokenizer ignoring punctuation, subwords and spaces. In a more realistic example, the word "Let's" would be tokenized as "Let" and "'s".

### Embeddings
* Embeddings are a lookup table that maps tokens to vectors in the embedding space
* We have a vector in the embedding space for each token. So the embedding lookup table is a matrix $E \in \mathbb{R}^{V \times d_{model}}$ where $V$ is the vocabulary size and $d_{model}$ is the embedding dimension.

So in the example above we have a list of integers which we want to map to a list of vectors in the embedding space. Therefore tokens = [1, 4, 39, 2, 45] becomes embeddings = [e1, e4, e39, e2, e45] where $e_i \in \mathbb{R}^{d_{model}}$, where embeddings should have a shape of (sequence_length, d_model).



In [124]:
class Embeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.W_E = nn.Parameter(torch.empty((config.vocab_size, config.d_model)))
        nn.init.normal_(self.W_E, std=config.init_range)
    
    def forward(self, tokens):
        embeddings = self.W_E[tokens]
        if self.config.debug: print("Embeddings:", embeddings.shape)
        return embeddings

embeddings_layer = Embeddings(Config())
tokens = torch.tensor([[1, 4, 39, 2, 45]])
embeddings = embeddings_layer(tokens)
print(embeddings.shape)


Embeddings: torch.Size([1, 5, 768])
torch.Size([1, 5, 768])


In [125]:
rand_int_test(Embeddings, [2, 4])
load_gpt2_test(Embeddings, reference_gpt2.embed, tokens)

Input shape: torch.Size([2, 4])
Embeddings: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 5])
Embeddings: torch.Size([1, 5, 768])
Output shape: torch.Size([1, 5, 768])
Reference output shape: torch.Size([1, 5, 768])
100.00% of the values are correct


tensor([[[ 0.0403, -0.0486,  0.0462,  ...,  0.0861,  0.0025,  0.0432],
         [-0.0506, -0.1111,  0.1058,  ..., -0.1149,  0.0664,  0.0574],
         [ 0.0746,  0.0573,  0.2295,  ..., -0.1254, -0.0506, -0.1640],
         [-0.1275,  0.0479,  0.1841,  ...,  0.0899, -0.1297, -0.0879],
         [-0.0240, -0.0773,  0.1979,  ..., -0.0814, -0.0701,  0.0289]]],
       grad_fn=<IndexBackward0>)

### Positional Embeddings
* Positional embeddings are meant to give the model information about the order of the tokens in the sequence. 
* Intuitively, nearby tokens are more likely to be related to each other than distant tokens.
* Positional embeddings are a lookup table that maps positions to vectors in the embedding space
* This lookup table is a matrix $P \in \mathbb{R}^{n_{ctx} \times d_{model}}$ where $n_{ctx}$ is the maximum sequence length and $d_{model}$ is the embedding dimension.
* There are many ways to define positional embeddings:
    * Learned embeddings: the positional embeddings are learned during training.
    * Sine and cosine embeddings: the positional embeddings are defined as a function of the position index.


In [126]:
class PositionalEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.W_pos = nn.Parameter(torch.empty(config.n_ctx, config.d_model))
        nn.init.normal_(self.W_pos, std=config.init_range)
    
    def forward(self, tokens):
        pos_embeddings = self.W_pos[:tokens.size(1)]
        pos_embeddings = einops.repeat(
            pos_embeddings,
            'position d_model -> batch position d_model',
            batch=tokens.size(0)
        )
        return pos_embeddings

rand_int_test(PositionalEmbeddings, [2, 4])
load_gpt2_test(PositionalEmbeddings, reference_gpt2.pos_embed, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 5])
Output shape: torch.Size([1, 5, 768])
Reference output shape: torch.Size([1, 5, 768])
100.00% of the values are correct


tensor([[[-1.8821e-02, -1.9742e-01,  4.0267e-03,  ..., -4.3044e-02,
           2.8267e-02,  5.4490e-02],
         [ 2.3959e-02, -5.3792e-02, -9.4879e-02,  ...,  3.4170e-02,
           1.0172e-02, -1.5573e-04],
         [ 4.2161e-03, -8.4764e-02,  5.4515e-02,  ...,  1.9745e-02,
           1.9325e-02, -2.1424e-02],
         [-2.8337e-04, -7.3803e-02,  1.0553e-01,  ...,  1.0157e-02,
           1.7659e-02, -7.0854e-03],
         [ 7.6374e-03, -2.5090e-02,  1.2696e-01,  ...,  8.4643e-03,
           9.8542e-03, -7.0117e-03]]], grad_fn=<ExpandBackward0>)

### LayerNorm
* LayerNorm is a normalization layer that normalizes each vector in the input on the embedding dimension to have zero mean and unit variance
* LayerNorm also applies a learnable scaling factor and bias to the normalized vector, enabling the model to learn scaling and shifting best suited for the data it's processing
* LayerNorm helps with the stability of the training process and with the convergence of the model.
    * To-do: Explore LayerNorm in more detail.

In [127]:
class LayerNorm(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.w = nn.Parameter(torch.ones(config.d_model))
        self.b = nn.Parameter(torch.zeros(config.d_model))
    
    def forward(self, residual):
        residual -= einops.reduce(
            residual, 
            'batch position d_model -> batch position 1',
            'mean'
        )
        std = (
            einops.reduce(
                residual.pow(2),
                "batch position d_model -> batch position 1", "mean"
            ) + self.config.ln_eps
        ).sqrt()
        residual = residual / std
        residual = residual * self.w + self.b
        if self.config.debug: print("Normalized:", residual.shape)
        return residual

In [128]:
layer_norm_layer = LayerNorm(Config())
residual = torch.randn(1, 1024, 768)
normalized_residual = layer_norm_layer(residual)
print(normalized_residual.shape)
print(normalized_residual.mean(dim=-1), normalized_residual.std(dim=-1))

Normalized: torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
tensor([[-4.9671e-09, -2.3594e-08, -8.6923e-09,  ...,  1.1176e-08,
          9.9341e-09,  1.3659e-08]], grad_fn=<MeanBackward1>) tensor([[1.0006, 1.0006, 1.0006,  ..., 1.0006, 1.0006, 1.0006]],
       grad_fn=<StdBackward0>)


In [129]:
_ = rand_float_test(LayerNorm, [2, 4, 768])
_ = load_gpt2_test(LayerNorm, reference_gpt2.ln_final, "blocks.11.hook_resid_post")

Input shape: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Normalized: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


### Attention
* The attention is a mechanism that allows each position in the sequence to consider the relevance of every other position in the sequence when computing a representation of that sequence. In other words, it allows the model to focus on the most relevant parts of the input sequence to predict the next token.
* It does this by:
    * KQ - Here we're generating the attention scores, which encode what source tokens are relevant to the destination tokens.
    * We do this by applying a linear map from the input (embedding space) to the key space (K) and a linear map from the input (embedding space) to the query space (Q). Where the key space encodes the information we have and the query space encodes the information we're looking for.
    * After generating K and Q, we compute the attention scores by taking the dot product of K and Q. Which measures how much the source tokens are relevant to the destination tokens.
    * The next step is to apply a softmax function to the attention scores to get the attention weights. However, before applying softmax, since we're working with sequences, we need to mask the attention scores for the "future" tokens. Making sure that the model is causal.
    * attn_scores @ V - Finally, we apply a linear map from the attention scores to the value space. Where the value space encodes the information we want to communicate to each destination token.
        * K = Key -> Key is what do we contain
            * Key is a linear map from the input space to the key space
        * Q = Query -> Query is what we're looking for
            * Query is a linear map from the input space to the query space
        * V = Value -> Value is what we're communicating to each destination token
            * Value is a linear map from the input space to the value space


In [130]:
class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.W_K = nn.Parameter(torch.empty(config.n_heads, config.d_model, config.d_head))
        nn.init.normal_(self.W_K, std=config.init_range)
        self.b_K = nn.Parameter(torch.zeros(config.n_heads, config.d_head))

        self.W_Q = nn.Parameter(torch.empty(config.n_heads, config.d_model, config.d_head))
        nn.init.normal_(self.W_Q, std=config.init_range)
        self.b_Q = nn.Parameter(torch.zeros(config.n_heads, config.d_head))

        self.W_V = nn.Parameter(torch.empty(config.n_heads, config.d_model, config.d_head))
        nn.init.normal_(self.W_V, std=config.init_range)
        self.b_V = nn.Parameter(torch.zeros(config.n_heads, config.d_head))

        self.W_O = nn.Parameter(torch.empty(config.n_heads, config.d_head, config.d_model))
        nn.init.normal_(self.W_O, std=config.init_range)
        self.b_O = nn.Parameter(torch.zeros(config.d_model))

    def forward(self, normalized_resid_pre):
        # normalized_resid_pre: [batch, position, d_model]
        # linear map from batch, position, d_model -> batch, position, n_heads, d_head
        q = einops.einsum(
            normalized_resid_pre, self.W_Q,
            'batch position d_model, n_heads d_model d_head -> batch position n_heads d_head'
        ) + self.b_Q
        k = einops.einsum(
            normalized_resid_pre, self.W_K,
            'batch position d_model, n_heads d_model d_head -> batch position n_heads d_head'
        ) + self.b_K
        # compute the attention scores
        # batch, query_pos, n_heads, d_head, batch, key_pos, n_heads, d_head -> batch, query_pos, n_heads, key_pos
        attn_scores = einops.einsum(
            q, k,
            'batch query_pos n_heads d_head, batch key_pos n_heads d_head -> batch n_heads query_pos key_pos'
        )
        attn_scores = attn_scores / math.sqrt(self.config.d_head)
        mask = torch.triu(torch.ones(attn_scores.size(-2), attn_scores.size(-1)), diagonal=1).bool()
        attn_scores = attn_scores.masked_fill(mask, -float('inf'))
        attn_weights = attn_scores.softmax(dim=-1)

        v = einops.einsum(
            normalized_resid_pre, self.W_V,
            'batch position d_model, n_heads d_model d_head -> batch position n_heads d_head'
        ) + self.b_V
        z = einops.einsum(
            attn_weights, v,
            'batch n_heads query_pos key_pos, batch key_pos n_heads d_head -> batch query_pos n_heads d_head'
        )
        attn_out = einops.einsum(
            z, self.W_O,
            'batch key_pos n_heads d_head, n_heads d_head d_model -> batch key_pos d_model'
        ) + self.b_O
        return attn_out

attention_layer = Attention(Config())
normalized_resid_pre = torch.randn(1, 5, 768)
attn_weights = attention_layer(normalized_resid_pre)
print(attn_weights.shape)

torch.Size([1, 5, 768])


In [131]:
rand_float_test(Attention, [2, 4, 768])
load_gpt2_test(Attention, reference_gpt2.blocks[0].attn, cache["blocks.0.ln1.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[ 7.9663e-01,  1.6985e-02,  3.4781e-02,  ...,  3.3119e-02,
          -2.3129e-02,  1.8103e-01],
         [ 1.3179e-03,  1.5750e-01, -1.4059e-01,  ..., -8.1998e-03,
           5.3074e-03,  1.3511e-01],
         [ 8.9738e-02, -7.2411e-01, -6.9866e-01,  ...,  5.5321e-02,
           2.7958e-03,  9.0785e-02],
         ...,
         [-3.0286e-01,  4.9638e-02, -6.0990e-01,  ..., -3.7084e-02,
          -4.9527e-04, -8.6008e-03],
         [-1.0844e+00, -6.1457e-02,  2.2966e-01,  ..., -2.6688e-02,
          -1.4368e-02,  3.3245e-02],
         [ 3.7947e-01, -4.9886e-01,  2.6434e-01,  ..., -2.7894e-02,
          -8.9027e-03,  4.8796e-02]]], grad_fn=<AddBackward0>)

### MLP - Multi Layer Perceptron
* The MLP is a standard feedforward neural network, that performs a linear map, followed by a non-linearity, followed by another linear map.
* So it's a single hidden layer neural network with a GeLu activation function. 
* Mapping from embedding space -> hidden space -> embedding space with the non-linearity in the middle.

In [163]:
from easy_transformer.utils import gelu_new

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.W_in = nn.Parameter(torch.empty((config.d_model, config.d_mlp)))
        nn.init.normal_(self.W_in, std=config.init_range)
        self.b_in = nn.Parameter(torch.zeros((config.d_mlp)))

        self.W_out = nn.Parameter(torch.empty((config.d_mlp, config.d_model)))
        nn.init.normal_(self.W_out, std=config.init_range)
        self.b_out = nn.Parameter(torch.zeros((config.d_model)))
    
    def forward(self, normalized_resid_pre):
        mlp_in = einops.einsum(
            normalized_resid_pre, self.W_in,
            'batch position d_model, d_model d_mlp -> batch position d_mlp'
        ) + self.b_in
        mlp_in = gelu_new(mlp_in)
        mlp_out = einops.einsum(
            mlp_in, self.W_out, 
            'batch position d_mlp, d_mlp d_model -> batch position d_model'
        ) + self.b_out
        return mlp_out

mlp_layer = MLP(Config())
normalized_resid_pre = torch.randn(1, 5, 768)
mlp_out = mlp_layer(normalized_resid_pre)
print(mlp_out.shape)

torch.Size([1, 5, 768])


In [164]:
rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["blocks.0.ln2.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[-0.4380,  0.3624,  0.5117,  ...,  1.7227,  1.5761,  0.0368],
         [-1.0766, -0.0438,  0.3276,  ..., -0.5437,  0.4033,  0.3717],
         [-1.2182, -1.5481, -0.9702,  ...,  1.0737,  0.7199,  0.5080],
         ...,
         [-0.4004,  0.8475,  0.2047,  ...,  0.3789,  0.0455, -0.4744],
         [-0.0862,  0.7839,  0.9046,  ..., -0.2174, -0.5953,  0.8555],
         [ 0.8448, -0.3743,  1.0397,  ...,  0.0296,  0.3405,  0.3585]]],
       grad_fn=<AddBackward0>)

### Unembed
* The unembed layer is a linear map from the embedding space to the logits space.
* The logits space is the space of the unnormalized probabilities of the next token.
* So the unembed layer outputs the logits for each token in the sequence.

In [150]:
class Unembed(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.W_U = nn.Parameter(torch.empty((config.d_model, config.vocab_size)))
        nn.init.normal_(self.W_U, std=config.init_range)
        self.b_U = nn.Parameter(torch.zeros((config.vocab_size), requires_grad=False))
    
    def forward(self, normalized_resid_post):
        logits = einops.einsum(
            normalized_resid_post, self.W_U,
            'batch position d_model, d_model vocab_size -> batch position vocab_size'
        ) + self.b_U
        return logits

unembed_layer = Unembed(Config())
normalized_resid_post = torch.randn(1, 5, 768)
logits = unembed_layer(normalized_resid_post)
print(logits.shape)


torch.Size([1, 5, 50257])


In [151]:
rand_float_test(Unembed, [2, 4, 768])
load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 50257])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 50257])
Reference output shape: torch.Size([1, 35, 50257])
100.00% of the values are correct


tensor([[[ -43.4317,  -39.8364,  -43.0660,  ...,  -54.0877,  -54.3452,
           -42.3644],
         [-128.0392, -127.9936, -130.7011,  ..., -136.7121, -129.9261,
          -129.3965],
         [-119.8521, -121.0064, -123.8820,  ..., -128.5181, -126.6028,
          -121.9061],
         ...,
         [-112.9815, -112.7750, -117.0633,  ..., -121.2914, -117.6574,
          -114.5005],
         [ -98.6725, -104.4889, -108.7361,  ..., -118.3552, -113.8766,
          -106.3604],
         [-126.8285, -128.9596, -128.3941,  ..., -140.1969, -138.5882,
          -122.3697]]], grad_fn=<AddBackward0>)

In [166]:
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config)
        self.attn = Attention(config)
        self.ln_2 = LayerNorm(config)
        self.mlp = MLP(config)
    
    def forward(self, resid_pre):
        normalized_resid_pre = self.ln_1(resid_pre)
        attn_out = self.attn(normalized_resid_pre)
        resid_mid = resid_pre + attn_out
        normalized_resid_mid = self.ln_2(resid_mid)
        mlp_out = self.mlp(normalized_resid_mid)
        resid_post = resid_mid + mlp_out
        return resid_post

transformer_block = TransformerBlock(Config())
resid_pre = torch.randn(1, 5, 768)
resid_post = transformer_block(resid_pre)
print(resid_post.shape)

Normalized: torch.Size([1, 5, 768])
Normalized: torch.Size([1, 5, 768])
torch.Size([1, 5, 768])


In [167]:
class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed = Embeddings(config)
        self.pos_embed = PositionalEmbeddings(config)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.ln_final = LayerNorm(config)
        self.unembed = Unembed(config)
    
    def forward(self, tokens):
        embeds = self.embed(tokens)
        pos_embeds = self.pos_embed(tokens)
        resid_pre = embeds + pos_embeds
        for block in self.blocks:
            resid_pre = block(resid_pre)
        normalized_resid_final = self.ln_final(resid_pre)
        out = self.unembed(normalized_resid_final)
        return out

rand_int_test(Transformer, [2, 4])
load_gpt2_test(Transformer, reference_gpt2, tokens)

Input shape: torch.Size([2, 4])
Embeddings: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 

tensor([[[-64.2467, -62.7545, -65.4784,  ..., -72.5686, -70.3090, -64.5301],
         [-21.9314, -21.9484, -23.8217,  ..., -28.2983, -27.8309, -22.8705],
         [-47.7269, -47.1128, -50.2303,  ..., -55.1695, -53.4897, -47.8364],
         ...,
         [-88.1346, -86.2741, -90.8221,  ..., -97.1845, -94.9073, -89.3229],
         [-60.6255, -59.4294, -63.5714,  ..., -68.8870, -67.3016, -61.2635],
         [-74.9402, -73.2983, -77.2700,  ..., -83.5924, -80.9157, -75.2907]]],
       grad_fn=<AddBackward0>)

In [168]:
my_transformer = Transformer(Config(debug=False))
my_transformer.load_state_dict(reference_gpt2.state_dict(), strict=False)

_IncompatibleKeys(missing_keys=['blocks.0.ln_1.w', 'blocks.0.ln_1.b', 'blocks.0.ln_2.w', 'blocks.0.ln_2.b', 'blocks.1.ln_1.w', 'blocks.1.ln_1.b', 'blocks.1.ln_2.w', 'blocks.1.ln_2.b', 'blocks.2.ln_1.w', 'blocks.2.ln_1.b', 'blocks.2.ln_2.w', 'blocks.2.ln_2.b', 'blocks.3.ln_1.w', 'blocks.3.ln_1.b', 'blocks.3.ln_2.w', 'blocks.3.ln_2.b', 'blocks.4.ln_1.w', 'blocks.4.ln_1.b', 'blocks.4.ln_2.w', 'blocks.4.ln_2.b', 'blocks.5.ln_1.w', 'blocks.5.ln_1.b', 'blocks.5.ln_2.w', 'blocks.5.ln_2.b', 'blocks.6.ln_1.w', 'blocks.6.ln_1.b', 'blocks.6.ln_2.w', 'blocks.6.ln_2.b', 'blocks.7.ln_1.w', 'blocks.7.ln_1.b', 'blocks.7.ln_2.w', 'blocks.7.ln_2.b', 'blocks.8.ln_1.w', 'blocks.8.ln_1.b', 'blocks.8.ln_2.w', 'blocks.8.ln_2.b', 'blocks.9.ln_1.w', 'blocks.9.ln_1.b', 'blocks.9.ln_2.w', 'blocks.9.ln_2.b', 'blocks.10.ln_1.w', 'blocks.10.ln_1.b', 'blocks.10.ln_2.w', 'blocks.10.ln_2.b', 'blocks.11.ln_1.w', 'blocks.11.ln_1.b', 'blocks.11.ln_2.w', 'blocks.11.ln_2.b'], unexpected_keys=['blocks.0.ln1.w', 'blocks.0.ln

In [170]:
text = "Breaking News: President Trump has been impeached by the House of Representatives for abuse of power and obstruction of Congress. The vote was 230 to 197, with 10 Republicans joining all Democrats in voting to impeach. The president is now only the third in American history to be impeached, and the first to be impeached twice. The House will now send the articles of impeachment to the Senate, where a trial will be held to determine whether to remove the president from office. The Senate is expected to begin the trial on"
for i in range(100):
    tokens = reference_gpt2.to_tokens(text)
    logits_my_transformer = my_transformer(tokens)
    next_token = logits_my_transformer[-1, -1].argmax()
    text += reference_gpt2.tokenizer.decode(next_token)
print(text)


Breaking News: President Trump has been impeached by the House of Representatives for abuse of power and obstruction of Congress. The vote was 230 to 197, with 10 Republicans joining all Democrats in voting to impeach. The president is now only the third in American history to be impeached, and the first to be impeached twice. The House will now send the articles of impeachment to the Senate, where a trial will be held to determine whether to remove the president from office. The Senate is expected to begin the trial on the the the the the the the the the the the the the un, the un the un, the un, the un, the un, the un a the the un, the un the un the un the the the the the un the the the the the the a the the the the the the the the def a the un a the the the the the un the the the the the the un the the the the the the a the a a the a a a a a a a a


In [139]:
print(text)

Breaking News: President Trump has been impeached by the House of Representatives for abuse of power and obstruction of Congress. The vote was 230 to 197, with 10 Republicans joining all Democrats in voting to impeach. The president is now only the third in American history to be impeached, and the first to be impeached twice. The House will now send the articles of impeachment to the Senate, where a trial will be held to determine whether to remove the president from office. The Senate is expected to begin the trial on the the, the the the the the the the


In [140]:
print(my_transformer)

Transformer(
  (embeddings): Embeddings()
  (pos_embeddings): PositionalEmbeddings()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln_1): LayerNorm()
      (attn): Attention()
      (ln_2): LayerNorm()
      (mlp): MLP()
    )
  )
  (ln_final): LayerNorm()
  (unembed): Unembed()
)
