In [1]:
try:
  import google.colab
  IN_COLAB = True
  print("Running as a Colab notebook")
  %pip install transformer_lens
except ImportError:
  IN_COLAB = False
  print("Running as a Jupyter notebook!")

Running as a Jupyter notebook!


In [2]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import numpy as np
import math
import tqdm.auto as tqdm
import transformer_lens
import einops
from pprint import pprint
from fancy_einsum import einsum

In [3]:
model = transformer_lens.HookedTransformer.from_pretrained("gpt2")



Loaded pretrained model gpt2 into HookedTransformer


# Generative Pre-trained Transformer (GPT)

## What is a transformer?

![image of optimus prime](https://www.hdwallpapers.net/previews/optimus-prime-transformers-444.jpg)

A transformer models text! If you feed a model a sequence of, say "Hello, the weather is", the model will generate a probability distribution over what the next word might be. For this example, we might see something like:
- "nice": 33%
- "bad": 12%
- "sunny": 4%
- "rainy": 3%
- ...

This generation can be repeated: append the word that was just generated to the sentence, then guess the next word again. This is called *autoregressive* generation. Generation stops when the model produces a special token, denoted `<|endoftext|>`.
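The loop described above can be sketched in a few lines. This is a toy illustration only: `TOY_NEXT_WORD` is a made-up lookup table standing in for the model, and `generate` is a hypothetical helper name, not part of any library.

```python
# Toy stand-in for a language model: maps the last word to one "most likely"
# next word. A real model conditions on the whole sequence, not just one word.
TOY_NEXT_WORD = {
    "Hello,": "the",
    "the": "weather",
    "weather": "is",
    "is": "nice",
    "nice": "<|endoftext|>",
}


def generate(prompt_words, max_new_words=10):
    """Greedy autoregressive generation over whole words."""
    words = list(prompt_words)
    for _ in range(max_new_words):
        next_word = TOY_NEXT_WORD.get(words[-1], "<|endoftext|>")
        if next_word == "<|endoftext|>":
            break  # the special end-of-text token stops generation
        words.append(next_word)  # feed the prediction back in as input
    return words


print(generate(["Hello,", "the"]))
```

The important part is the last line of the loop: whatever was just predicted becomes part of the input for the next prediction.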

Because the model's goal is to guess the next token, you train a transformer by feeding it a large corpus, grabbing a chunk of it, then using the next token in the chunk as the correct answer. The actual training is a little more complicated than that, but we'll discuss that in a little bit.

### How is a transformer different from other kinds of language modelling?

If you are in the ML space, you will know that there have been other kinds of language models before. A transformer is different because it uses something called *self-attention*. Basically, this allows the model to move information between words, kind of deciding which words are most important for predicting the next token. This is very different from RNN-based models, because this mechanism can be run in parallel across all positions (much faster training), and information can be moved between any two words in the sentence regardless of distance.

![image of self-attention visualized](https://ar5iv.labs.arxiv.org/html/1904.02679/assets/images/example_combined.png)

### Overview of Architecture

<img src="assets/gpt2.png">

**Important**: We don't *just* output the next word probability for the final word in the sequence. We do that for *every* word in the sequence.

This costs us very little, because transformers process entire sequences in parallel rather than sequentially, which means all of the outputs are computed simultaneously.

This attribute is actually what allows transformers to process sequences of variable length. If you type into ChatGPT two sequences of different lengths, it can process both just fine.

This also has some implications regarding the training process. Basically, given a sequence like "At the store, she bought apples", the model is trained to predict the next word *for every word in the sequence*. The model will output the next word probability for "At", "At the" and so on. This actually makes training a transformer very efficient.
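To make the "one prediction per position" idea concrete, here is a minimal sketch that lists every (context, target) pair hidden inside a single sequence. The token split below is hand-written for illustration and is not the real GPT-2 tokenization; `next_token_pairs` is a hypothetical helper name.

```python
def next_token_pairs(tokens):
    """Return one (context, target) pair per position except the last."""
    return [(tokens[: i + 1], tokens[i + 1]) for i in range(len(tokens) - 1)]


# Hand-written split, assumed for illustration only.
tokens = ["At", " the", " store", ",", " she", " bought", " apples"]
for context, target in next_token_pairs(tokens):
    print(repr("".join(context)), "->", repr(target))
```

A sequence of 7 tokens gives 6 training examples from a single forward pass, which is where the efficiency comes from.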

Another implication is that the attention mechanism has to stop information from flowing from later words to earlier words, because this would essentially allow the model to "cheat". This is called *causal attention*.

## Problem: How do we input words?

A transformer isn't magic. It doesn't ingest whatever "words" are and perform magic on them to understand them. ML models have only ever been able to take in numbers. *How do we convert language to numbers?*

### Idea: Literally a dictionary

We have a dictionary that says:

```
"hi":1
"hang": 2
...
"fish": 100
```

This technically works, but it doesn't allow us to model complex behavior. It is also not flexible: the model cannot take in arbitrary text.

### Idea: Lookup table

We make a lookup table! This is called an *embedding*. Let's say the size of the model's vocabulary is 100 (it can understand up to 100 words). Each word is then represented by a vector with a 1 in the kth position (if the word is located at index k) and 0 everywhere else, like this: `[0 0 0 ... 1 ... 0 0 0 0]`. This style of encoding is called *one-hot encoding*.

![one-hot encoding](https://miro.medium.com/v2/resize:fit:1400/1*ggtP4a5YaRx6l09KQaYOnw.png)

This is useful if we want to model *categorical* data (think a set of class labels, or a set of color names), because each category gets a clear and unambiguous representation that does not imply any ordering. However, this gives up semantic relationships, since every pair of words is equally far apart, and it has the same problem as the idea above: the vocabulary is limited, making the model unable to process arbitrary text.
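A minimal sketch of one-hot encoding, assuming a tiny three-word vocabulary made up for this example (`one_hot` and `dot` are hypothetical helper names):

```python
def one_hot(index, vocab_size):
    """Return a list of zeros with a single 1 at `index`."""
    vec = [0] * vocab_size
    vec[index] = 1
    return vec


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


# Toy vocabulary, assumed for illustration only.
vocab = {"hi": 0, "hang": 1, "fish": 2}
hi = one_hot(vocab["hi"], len(vocab))
fish = one_hot(vocab["fish"], len(vocab))
print(hi, fish, dot(hi, fish))
```

The dot product between any two different one-hot vectors is always 0, which is one way to see that this encoding carries no notion of similarity between words.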

### Idea: Tokens

Dictionary-style encoding cannot cope with arbitrary text, but that's because we limit our unit to whole words.

If we set our vocabulary to, say, the 256 possible byte values (which cover the ASCII characters), and map those to integers, we can model any text, right?

Right! The problem is that we lose out on language structure. Some sequences of characters are more meaningful than others, think `hello` versus `weispdgdgb`.

*What actually happens is a super cursed hybrid between single characters and whole words, called byte-pair encoding (BPE).*

We begin with 256 base characters and map each to a number from 0 to 255. Then we find the most common pair of adjacent tokens in our training text, maybe something like ` t` (t with a space in front of it), merge the pair into a new token and add it to our vocabulary. Then we just repeat that about 50,000 times.
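The merge procedure can be sketched on a toy string. This is a simplified illustration that works on characters instead of bytes and skips details of the real GPT-2 tokenizer; `most_common_pair` and `merge_pair` are hypothetical helper names.

```python
from collections import Counter


def most_common_pair(tokens):
    """Return the most frequent pair of adjacent tokens."""
    pairs = Counter(zip(tokens, tokens[1:]))
    return pairs.most_common(1)[0][0]


def merge_pair(tokens, pair):
    """Replace every occurrence of `pair` with a single merged token."""
    merged, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            merged.append(tokens[i] + tokens[i + 1])
            i += 2  # skip both halves of the merged pair
        else:
            merged.append(tokens[i])
            i += 1
    return merged


tokens = list("the cat sat at the mat")  # start from single characters
for _ in range(3):  # the real tokenizer repeats this tens of thousands of times
    pair = most_common_pair(tokens)
    tokens = merge_pair(tokens, pair)
    print(pair, tokens)
```

Each round makes the token list shorter and the vocabulary one entry bigger, so frequent chunks end up as single tokens while rare strings stay split into small pieces.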

In [4]:
sorted_vocab = sorted(model.tokenizer.get_vocab().items(), key=lambda x: x[1])


# first 20 tokens


pprint(sorted_vocab[:20])

[('!', 0),
 ('"', 1),
 ('#', 2),
 ('$', 3),
 ('%', 4),
 ('&', 5),
 ("'", 6),
 ('(', 7),
 (')', 8),
 ('*', 9),
 ('+', 10),
 (',', 11),
 ('-', 12),
 ('.', 13),
 ('/', 14),
 ('0', 15),
 ('1', 16),
 ('2', 17),
 ('3', 18),
 ('4', 19)]


In [5]:
# Ġ is the token for space
pprint(sorted_vocab[250:270])

[('ľ', 250),
 ('Ŀ', 251),
 ('ŀ', 252),
 ('Ł', 253),
 ('ł', 254),
 ('Ń', 255),
 ('Ġt', 256),
 ('Ġa', 257),
 ('he', 258),
 ('in', 259),
 ('re', 260),
 ('on', 261),
 ('Ġthe', 262),
 ('er', 263),
 ('Ġs', 264),
 ('at', 265),
 ('Ġw', 266),
 ('Ġo', 267),
 ('en', 268),
 ('Ġc', 269)]


In [6]:
pprint(sorted_vocab[-20:])

[('Revolution', 50237),
 ('Ġsnipers', 50238),
 ('Ġreverted', 50239),
 ('Ġconglomerate', 50240),
 ('Terry', 50241),
 ('794', 50242),
 ('Ġharsher', 50243),
 ('Ġdesolate', 50244),
 ('ĠHitman', 50245),
 ('Commission', 50246),
 ('Ġ(/', 50247),
 ('âĢ¦."', 50248),
 ('Compar', 50249),
 ('Ġamplification', 50250),
 ('ominated', 50251),
 ('Ġregress', 50252),
 ('ĠCollider', 50253),
 ('Ġinformants', 50254),
 ('Ġgazed', 50255),
 ('<|endoftext|>', 50256)]


Side note: this is a major reason why GPT *sucks* at arithmetic. Numbers get chopped into inconsistent chunks:

In [7]:
model.to_str_tokens("1233212343+5832092-35983=29384000000000")

['<|endoftext|>',
 '12',
 '33',
 '212',
 '343',
 '+',
 '58',
 '320',
 '92',
 '-',
 '35',
 '98',
 '3',
 '=',
 '29',
 '384',
 '000000',
 '000']

## What do we output?

We output something called *logits*. As mentioned, the model produces a probability distribution over next tokens; logits are the raw, unnormalized scores that we turn into that distribution.

### Transformer weirdness: logits generation

Due to how the transformer architecture works (we'll get into it later, I promise), what is output at each position in the sequence is a high-dimensional vector (the size of the vocabulary) of raw scores for the next token at that position. How do we turn this vector into an actual probability distribution?

We apply something called softmax, $\sigma(\vec{z})_i=\frac{e^{z_i}}{\sum_{j=1}^K e^{z_j}}$, over the vector, where $K$ is the size of the vocabulary.
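As a minimal sketch, the formula can be written in plain Python. The logits below are made-up numbers, and subtracting the maximum before exponentiating is a common numerical-stability trick that does not change the result.

```python
import math


def softmax(logits):
    """Turn a list of raw scores into probabilities that sum to 1."""
    largest = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(z - largest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


toy_logits = [2.0, 1.0, 0.1]  # made-up scores for a 3-token vocabulary
probs = softmax(toy_logits)
print(probs, sum(probs))
```

Larger logits get larger probabilities, every probability is positive, and the order of the tokens is preserved.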

In [8]:
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
tokens = model.to_tokens(reference_text)
print(f"Shape of tokens: {tokens.shape}")
tokens = tokens.cuda()
logits, cache = model.run_with_cache(tokens)
# the shape of the logits is (batch_size, sequence_length, vocab_size)
# batch size is self explanatory, 1 in this case because we only have one sequence
# sequence length is the length of the input sequence, basically, for each token in the input sequence, we have a corresponding distribution over the vocabulary
# vocab size is the size of the vocabulary, in this case, 50257
print(f"Shape of logits: {logits.shape}")

Shape of tokens: torch.Size([1, 35])
Shape of logits: torch.Size([1, 35, 50257])


In [9]:
# this hacky code prints the most probable next token for each position in the sequence

pprint(
    list(
        zip(
            model.to_str_tokens(reference_text),
            model.tokenizer.batch_decode(logits.argmax(dim=-1)[0]),
        )
    )
)

[('<|endoftext|>', '\n'),
 ('I', "'m"),
 (' am', ' a'),
 (' an', ' avid'),
 (' amazing', ' person'),
 (' aut', 'od'),
 ('ore', 'sp'),
 ('gressive', '.'),
 (',', ' and'),
 (' dec', 'ently'),
 ('oder', ','),
 ('-', 'driven'),
 ('only', ' programmer'),
 (',', ' and'),
 (' G', 'IM'),
 ('PT', '-'),
 ('-', 'only'),
 ('2', '.'),
 (' style', ','),
 (' transformer', '.'),
 ('.', ' I'),
 (' One', ' of'),
 (' day', ' I'),
 (' I', ' will'),
 (' will', ' be'),
 (' exceed', ' my'),
 (' human', 'ly'),
 (' level', ' of'),
 (' intelligence', ' and'),
 (' and', ' I'),
 (' take', ' over'),
 (' over', ' the'),
 (' the', ' world'),
 (' world', '.'),
 ('!', ' I')]


## Key takeaways:

* Takes in language, predicts next token (for *each* token in a causal way)
* We convert language to a sequence of integers with a tokenizer.
* We convert integers to vectors with a lookup table.

* Output is a vector of logits (one for each input token), we convert to a probability distribution with a softmax, and can then convert this to a token (eg taking the largest logit, or sampling).

* We append this to the input + run again to generate more text (Jargon: *autoregressive*)

* Meta level point: Transformers are sequence operation models, they take in a sequence, do processing in parallel at each position, and use attention to move information between positions!
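The takeaway about turning logits into a token mentions two options, taking the largest logit or sampling. Here is a small sketch of both, using a made-up four-word vocabulary and made-up logits (`pick_greedy` and `pick_sample` are hypothetical helper names):

```python
import math
import random


def pick_greedy(logits):
    """Return the index of the largest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])


def pick_sample(logits, rng):
    """Sample an index with probability proportional to exp(logit)."""
    largest = max(logits)
    weights = [math.exp(z - largest) for z in logits]  # unnormalized softmax
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


toy_vocab = ["nice", "bad", "sunny", "rainy"]  # assumed for illustration
toy_logits = [3.0, 2.0, 0.9, 0.6]
rng = random.Random(0)  # seeded so repeated runs give the same samples
print("greedy: ", toy_vocab[pick_greedy(toy_logits)])
print("sampled:", [toy_vocab[pick_sample(toy_logits, rng)] for _ in range(5)])
```

Greedy decoding always returns the same token for the same input, while sampling can return any token, just with less likely ones showing up less often.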

# Code a Transformer

<img src="assets/transformer_overview.png">

In [10]:
@dataclass
class Config:
    debug: bool = True  # if we want to print debug information
    d_model: int = 768  # the dimension of the model
    n_head: int = 12  # the number of heads in the multiheadattention
    d_head: int = 64  # the dimension of the head in the multiheadattention
    n_layers: int = 12  # the number of residual blocks in the model
    d_vocab: int = 50257  # the size of the vocabulary
    n_ctx: int = 1024  # the size of the context window
    d_mlp: int = (
        3072  # the dimension of the intermediate layer in the feedforward block, this is usually 4 * d_model
    )
    layer_norm_epsilon: float = 1e-5  # the epsilon value for the layer normalization
    init_range: float = 0.02  # the std for the initialization of the weights


config = Config()

## Embedding Layer

A lookup table from token to high-dimensional vectors for the residual stream.

<img src="assets/embedding-layer.png">

In [11]:
class Embed(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # initialize the embedding matrix with a normal distribution
        # dim=(d_vocab, d_model)
        self.W_E = nn.Parameter(torch.empty((config.d_vocab, config.d_model)))
        nn.init.normal_(self.W_E, std=self.config.init_range)

    def forward(self, tokens: torch.Tensor):
        # tokens is a tensor of shape (batch_size, sequence_length)
        # we use the tokens as indices to get the corresponding embeddings
        if self.config.debug:
            print(f"tokens shape: {tokens.shape}")

        embed = self.W_E[tokens, :]  # shape: (batch_size, sequence_length, d_model)
        if self.config.debug:
            print(f"embed shape: {embed.shape}")

        return embed

In [12]:
sample_text = "hello world"
sample_tokens = model.to_tokens(sample_text).cuda()
embed = Embed(config).to("cuda")
embedded_tokens = embed(sample_tokens)

tokens shape: torch.Size([1, 3])
embed shape: torch.Size([1, 3, 768])


## Positional Embedding Layer

Again, another lookup table :)

This one works a little bit differently though. Instead of retrieving the vector within the matrix based on the token's vocabulary value, we retrieve based on the token's *position* within the sequence.

<img src="assets/pos-embedding-layer.png">

In [13]:
class PositionalEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # initialize the positional embedding matrix with a normal distribution
        # dim=(n_ctx, d_model)
        self.W_pos = nn.Parameter(torch.empty((config.n_ctx, config.d_model)))
        nn.init.normal_(self.W_pos, std=self.config.init_range)

    def forward(self, tokens: torch.Tensor):
        # tokens is a tensor of shape (batch_size, sequence_length)
        # we use the tokens as indices to get the corresponding positional embeddings
        if self.config.debug:
            print(f"tokens shape: {tokens.shape}")

        pos_embed = self.W_pos[: tokens.size(1), :]  # dim=(sequence, d_model)
        pos_embed = einops.repeat(
            pos_embed,
            "sequence d_model -> batch sequence d_model",
            batch=tokens.size(0),
        )
        if self.config.debug:
            print(f"pos_embed shape: {pos_embed.shape}")

        return pos_embed

In [14]:
posembed = PositionalEmbedding(config).to("cuda")
positional_embedded_tokens = posembed(sample_tokens)
pprint(embedded_tokens)

tokens shape: torch.Size([1, 3])
pos_embed shape: torch.Size([1, 3, 768])
tensor([[[ 0.0009,  0.0103, -0.0197,  ..., -0.0315, -0.0209, -0.0344],
         [ 0.0044, -0.0188, -0.0122,  ..., -0.0045,  0.0255, -0.0032],
         [-0.0104,  0.0001, -0.0092,  ..., -0.0251,  0.0070, -0.0091]]],
       device='cuda:0', grad_fn=<IndexBackward0>)


## LayerNorm

Before each attention and MLP layer inside a residual block (and once more before the unembedding), we normalize each embedding vector to have a mean of 0 and a variance of 1. To normalize a set of data, we take each data point and subtract the mean of all data points. Then, we divide the result by the data set's standard deviation.

Then we scale the vector in some way, and add a bias to the resulting vector. The scaling factor and adding bias are both learnable parameters.
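Here is the same recipe on a single small vector, as a plain-Python sketch. For simplicity the scale and bias are single numbers here; in the actual layer they are learnable vectors with one entry per dimension. The function name `layer_norm` and the input values are made up for this example.

```python
import math


def layer_norm(vector, scale=1.0, bias=0.0, eps=1e-5):
    """Normalize a vector to mean 0 and variance 1, then scale and shift."""
    mean = sum(vector) / len(vector)
    centered = [x - mean for x in vector]
    variance = sum(c * c for c in centered) / len(vector)
    std = math.sqrt(variance + eps)  # eps guards against division by zero
    return [scale * c / std + bias for c in centered]


normalized = layer_norm([1.0, 2.0, 3.0, 4.0])
print(normalized)
print("mean:", sum(normalized) / len(normalized))
```

With the default scale of 1 and bias of 0, the output has a mean of (almost exactly) 0 and a variance very close to 1.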

In [15]:
class LayerNorm(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # initialize the scale and bias parameters with 1 and 0 respectively
        self.w = nn.Parameter(torch.ones(config.d_model))
        self.b = nn.Parameter(torch.zeros(config.d_model))

    def forward(self, residual: torch.Tensor):
        # residual is a tensor of shape (batch_size, sequence_length, d_model)
        if self.config.debug:
            print(f"residual shape: {residual.shape}")

        residual = residual - einops.reduce(
            residual, "batch sequence d_model -> batch sequence 1", "mean"
        )
        # calculate the variance, add epsilon to avoid division by zero, then take the square root
        std = (
            einops.reduce(
                residual.pow(2), "batch sequence d_model -> batch sequence 1", "mean"
            )
            + self.config.layer_norm_epsilon
        ).sqrt()
        # normalize the residual
        normalized = residual / std
        # scale and bias the normalized residual
        normalized = normalized * self.w + self.b
        if self.config.debug:
            print(f"normalized shape: {normalized.shape}")

        return normalized

In [16]:
layer_norm = LayerNorm(config).to("cuda")
pprint(embedded_tokens)
layer_normed = layer_norm(embedded_tokens)
pprint(layer_normed)

tensor([[[ 0.0009,  0.0103, -0.0197,  ..., -0.0315, -0.0209, -0.0344],
         [ 0.0044, -0.0188, -0.0122,  ..., -0.0045,  0.0255, -0.0032],
         [-0.0104,  0.0001, -0.0092,  ..., -0.0251,  0.0070, -0.0091]]],
       device='cuda:0', grad_fn=<IndexBackward0>)
residual shape: torch.Size([1, 3, 768])
normalized shape: torch.Size([1, 3, 768])
tensor([[[ 0.0109,  0.4564, -0.9589,  ..., -1.5187, -1.0191, -1.6571],
         [ 0.1741, -0.9496, -0.6338,  ..., -0.2592,  1.1963, -0.1952],
         [-0.5475, -0.0399, -0.4888,  ..., -1.2516,  0.2927, -0.4838]]],
       device='cuda:0', grad_fn=<AddBackward0>)


## Multi-Head Attention

Now for the *really hard* part. Remember that our goal is to let the model decide which word in the sentence is important to another word, and let the model transfer information between those words. We also need to make sure that attention is *causal*.

### Create vectors

The first step in calculating self-attention is to create, for each word embedding input, a *Query* vector, a *Key* vector and a *Value* vector. The names are an abstraction that is useful for thinking about attention.

![illustration of creating self-attention vectors](https://jalammar.github.io/images/t/transformer_self_attention_vectors.png)

### Calculate attention score

The second step is to calculate a score that signifies how important one word is to another. So basically, for each pair of words in the sequence, we calculate a score that determines how much focus to place on the latter word, like how important "Machines" is to "Thinking".

This score is calculated by taking the dot product of the *query vector* of the word that we are concerned about and the *key vector* of the word that we want to score against. Note that a word is also scored against itself this way, since all pairs are computed together in parallel.

![visualization of second step](https://jalammar.github.io/images/t/transformer_self_attention_score.png)

### Get useful scores

The third step is to divide each of the scores by the square root of the dimension of the key vectors. This experimentally leads to more stable gradients.

The fourth step is to pass the result through softmax to make the scores positive and add up to 1.

![visualization of step 3 and 4](https://jalammar.github.io/images/t/self-attention_softmax.png)

### Information moving

The fifth step is to multiply each *value vector* by the softmax score. This keeps intact the values of the words we want to focus on and drowns out irrelevant words.

The sixth step is to sum up the weighted value vectors. This effectively moves all of the relevant semantic meaning from other words into the current word, producing the output.

![visualization of 5th and 6th step](https://jalammar.github.io/images/t/self-attention-output.png)

### Matrix calculation

In reality, the steps above are all done in matrix form for faster processing. We do this by first packing all word embeddings into a matrix (X) and multiplying it by the weight matrices to get the Query, Key and Value matrices.

![visualization of matrix query, key and value creation](https://jalammar.github.io/images/t/self-attention-matrix-calculation.png)

Steps two through six can be condensed into the following formula:

![condensed attention calculation through matrices](https://jalammar.github.io/images/t/self-attention-matrix-calculation-2.png)
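Steps two through six can also be sketched for a single head in plain Python, with causality built in. This is an illustrative sketch on tiny made-up vectors: it assumes the query, key and value vectors already exist (step one), and it enforces causality by only looking at positions up to the current one, which gives the same result as masking later positions with a large negative score before the softmax. `causal_attention` is a hypothetical helper name.

```python
import math


def softmax(scores):
    """Turn raw scores into positive weights that sum to 1."""
    largest = max(scores)
    exps = [math.exp(s - largest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention(queries, keys, values):
    """Single-head attention where position i only sees positions 0..i."""
    d_k = len(keys[0])
    outputs = []
    for pos, query in enumerate(queries):
        visible_keys = keys[: pos + 1]
        visible_values = values[: pos + 1]
        # steps 2 and 3: dot product of query and key, scaled by sqrt(d_k)
        scores = [
            sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k)
            for key in visible_keys
        ]
        weights = softmax(scores)  # step 4
        # steps 5 and 6: weighted sum of the visible value vectors
        outputs.append(
            [
                sum(w * value[j] for w, value in zip(weights, visible_values))
                for j in range(len(values[0]))
            ]
        )
    return outputs


# Made-up 2-dimensional vectors for a 2-token sequence.
toy_queries = [[1.0, 0.0], [0.0, 1.0]]
toy_keys = [[1.0, 0.0], [0.0, 1.0]]
toy_values = [[1.0, 2.0], [3.0, 4.0]]
toy_outputs = causal_attention(toy_queries, toy_keys, toy_values)
print(toy_outputs)
```

The first position can only attend to itself, so its output is exactly its own value vector. The second position blends both value vectors, leaning towards the one whose key matches its query.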

### Multi-headed

Oh, right, we have several attention heads instead of one, specifically *12* heads. This is better than single-headed attention because it allows the model to focus on different positions. This also allows the model to have more "representation" space because each attention head has its own set of Query/Key/Value matrices.

![multi attention head has multiple sets of matrices](https://jalammar.github.io/images/t/transformer_attention_heads_qkv.png)

However, passing the embedding matrix into 12 attention heads would yield 12 different output matrices in the end. This is a problem because the layer afterwards only expects one matrix, with one vector for each word. We need to condense them down to a single matrix. We do this by concatenating all the matrices, then multiplying the result by a trainable weight matrix.

![concatenate resulting matrices](https://jalammar.github.io/images/t/transformer_attention_heads_weight_matrix_o.png)
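One detail that can be confusing when moving from this picture to code: concatenating the head outputs and multiplying by one big output matrix gives the same result as multiplying each head by its own slice of that matrix and summing. This is a standard block-matrix identity; the symbols $z_i$, $W^O$ and $W^O_i$ are notation chosen here for illustration.

$$
\text{Concat}(z_1, \dots, z_h)\, W^O = \sum_{i=1}^{h} z_i\, W^O_i
$$

Here $z_i$ is the output of head $i$ and $W^O_i$ is the block of rows of $W^O$ that lines up with head $i$. With $h = 12$ heads of size 64, the concatenated width is $12 \times 64 = 768$, which matches `d_model`. The `Attention` class in this notebook stores `W_O` with shape `(n_head, d_head, d_model)` and lets the final einsum sum over heads, which corresponds to the right-hand side.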

### Overall visual

![visual capturing the whole process](https://jalammar.github.io/images/t/transformer_multi-headed_self-attention-recap.png)

That was... a lot of words. Time to actually code this layer!

The code is actually kinda big. Refer to the link:

https://colab.research.google.com/github/menamerai/gpt/blob/main/gpt.ipynb?authuser=1#scrollTo=kt72LGaqA8Zz&line=6&uniqifier=1

In [17]:
class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # initialize the query, key and value matrices with a normal distribution
        # dim=(n_head, d_model, d_head)
        self.W_Q = nn.Parameter(
            torch.empty((self.config.n_head, self.config.d_model, self.config.d_head))
        )
        nn.init.normal_(self.W_Q, std=self.config.init_range)
        self.b_Q = nn.Parameter(torch.zeros((self.config.n_head, self.config.d_head)))

        self.W_K = nn.Parameter(
            torch.empty((self.config.n_head, self.config.d_model, self.config.d_head))
        )
        nn.init.normal_(self.W_K, std=self.config.init_range)
        self.b_K = nn.Parameter(torch.zeros((self.config.n_head, self.config.d_head)))

        self.W_V = nn.Parameter(
            torch.empty((self.config.n_head, self.config.d_model, self.config.d_head))
        )
        nn.init.normal_(self.W_V, std=self.config.init_range)
        self.b_V = nn.Parameter(torch.zeros((self.config.n_head, self.config.d_head)))

        # initialize the output matrix with a normal distribution
        self.W_O = nn.Parameter(
            torch.empty((self.config.n_head, self.config.d_head, self.config.d_model))
        )
        nn.init.normal_(self.W_O, std=self.config.init_range)
        self.b_O = nn.Parameter(torch.zeros((self.config.d_model)))

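        # large negative value used to mask out future positions before the softmax;
        # no hard-coded device here, the buffer follows the module when `.to(...)` is called
        self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32))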

    def forward(self, normalized_residual_pre: torch.Tensor):
        # normalized_residual_pre is a tensor of shape (batch_size, sequence_length, d_model)
        if self.config.debug:
            print(f"normalized_residual_pre shape: {normalized_residual_pre.shape}")

        query = (
            einsum(
                "batch query_pos d_model, n_heads d_model d_head -> batch query_pos n_heads d_head",
                normalized_residual_pre,
                self.W_Q,
            )
            + self.b_Q
        )
        key = (
            einsum(
                "batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head",
                normalized_residual_pre,
                self.W_K,
            )
            + self.b_K
        )

        if self.config.debug:
            print(f"query shape: {query.shape}")
            print(f"key shape: {key.shape}")

        # calculate the attention scores
        attention_scores = einsum(
            "batch query_pos n_heads d_head, batch key_pos n_heads d_head -> batch n_heads query_pos key_pos",
            query,
            key,
        ) / math.sqrt(self.config.d_head)

        if self.config.debug:
            print(f"attention_scores shape: {attention_scores.shape}")

        # mask the attention scores
        attention_scores = self.apply_causal_mask(attention_scores)

        # apply the softmax function to the attention scores
        attention_scores = attention_scores.softmax(
            dim=-1
        )  # dim=(batch_size, n_heads, query_pos, key_pos)

        if self.config.debug:
            print(f"attention_scores shape: {attention_scores.shape}")

        # value still uses key_pos because it is all about the information from the source to the query position
        value = (
            einsum(
                "batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head",
                normalized_residual_pre,
                self.W_V,
            )
            + self.b_V
        )

        if self.config.debug:
            print(f"value shape: {value.shape}")

        # calculate the weighted sum of the values
        weighted_sum = einsum(
            "batch n_heads query_pos key_pos, batch key_pos n_heads d_head -> batch query_pos n_heads d_head",
            attention_scores,
            value,
        )

        if self.config.debug:
            print(f"weighted_sum shape: {weighted_sum.shape}")

        # calculate the output of the attention block
        output = (
            einsum(
                "batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model",
                weighted_sum,
                self.W_O,
            )
            + self.b_O
        )

        if self.config.debug:
            print(f"output shape: {output.shape}")

        return output

    def apply_causal_mask(self, attention_scores: torch.Tensor):
        # attention_scores is a tensor of shape (batch_size, n_heads, query_pos, key_pos)
        # uses triu to return the upper triangular part of the matrix
        mask = torch.triu(
            torch.ones(
                attention_scores.size(-2),
                attention_scores.size(-1),
                device=attention_scores.device,
            ),
            diagonal=1,
        ).bool()
        attention_scores.masked_fill_(mask, self.IGNORE)
        return attention_scores

In [18]:
attention = Attention(config).to("cuda")
pprint(layer_normed)
attention_output = attention(layer_normed)
pprint(attention_output)

tensor([[[ 0.0109,  0.4564, -0.9589,  ..., -1.5187, -1.0191, -1.6571],
         [ 0.1741, -0.9496, -0.6338,  ..., -0.2592,  1.1963, -0.1952],
         [-0.5475, -0.0399, -0.4888,  ..., -1.2516,  0.2927, -0.4838]]],
       device='cuda:0', grad_fn=<AddBackward0>)
normalized_residual_pre shape: torch.Size([1, 3, 768])
query shape: torch.Size([1, 3, 12, 64])
key shape: torch.Size([1, 3, 12, 64])
attention_scores shape: torch.Size([1, 12, 3, 3])
attention_scores shape: torch.Size([1, 12, 3, 3])
value shape: torch.Size([1, 3, 12, 64])
weighted_sum shape: torch.Size([1, 3, 12, 64])
output shape: torch.Size([1, 3, 768])
tensor([[[-0.1854,  0.0502,  0.0586,  ...,  0.3658,  0.0444, -0.0914],
         [ 0.0372,  0.1063,  0.2640,  ...,  0.1154,  0.0146, -0.2190],
         [ 0.0046,  0.2469,  0.1178,  ...,  0.0660,  0.0247, -0.1407]]],
       device='cuda:0', grad_fn=<AddBackward0>)


## Multi-Layer Perceptron

That's right. This is the good ol' neural network.

In [19]:
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # initialize the first linear layer with a normal distribution
        # dim=(d_model, d_mlp)

        self.W_in = nn.Parameter(torch.empty((config.d_model, config.d_mlp)))
        nn.init.normal_(self.W_in, std=config.init_range)
        self.b_in = nn.Parameter(torch.zeros((config.d_mlp)))

        self.gelu = nn.GELU()

        self.W_out = nn.Parameter(torch.empty((config.d_mlp, config.d_model)))
        nn.init.normal_(self.W_out, std=config.init_range)
        self.b_out = nn.Parameter(torch.zeros((config.d_model)))

    def forward(self, attention_output: torch.Tensor):
        # attention_output is a tensor of shape (batch_size, sequence_length, d_model)
        if self.config.debug:
            print(f"attention_output shape: {attention_output.shape}")

        pre = (
            einsum(
                "batch sequence d_model, d_model d_mlp -> batch sequence d_mlp",
                attention_output,
                self.W_in,
            )
            + self.b_in
        )
        # apply the GELU activation function
        # note: GPT-2 itself uses the tanh approximation ("gelu_new"),
        # which PyTorch exposes as nn.GELU(approximate="tanh")
        post = self.gelu(pre)
        mlp_output = (
            einsum(
                "batch sequence d_mlp, d_mlp d_model -> batch sequence d_model",
                post,
                self.W_out,
            )
            + self.b_out
        )
        if self.config.debug:
            print(f"mlp_output shape: {mlp_output.shape}")
        return mlp_output

In [20]:
mlp = MLP(config).to("cuda")
pprint(attention_output)
mlp_output = mlp(attention_output)
pprint(mlp_output)

tensor([[[-0.1854,  0.0502,  0.0586,  ...,  0.3658,  0.0444, -0.0914],
         [ 0.0372,  0.1063,  0.2640,  ...,  0.1154,  0.0146, -0.2190],
         [ 0.0046,  0.2469,  0.1178,  ...,  0.0660,  0.0247, -0.1407]]],
       device='cuda:0', grad_fn=<AddBackward0>)
attention_output shape: torch.Size([1, 3, 768])
mlp_output shape: torch.Size([1, 3, 768])
tensor([[[-0.1582, -0.0102, -0.1213,  ..., -0.0452, -0.1728, -0.0291],
         [-0.1459, -0.0583, -0.0315,  ...,  0.0394, -0.0390, -0.0395],
         [-0.0443, -0.0507, -0.0194,  ...,  0.0140, -0.0490, -0.0632]]],
       device='cuda:0', grad_fn=<AddBackward0>)


## Transformer Block

In [21]:
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.ln1 = LayerNorm(self.config)
        self.attn = Attention(self.config)
        self.ln2 = LayerNorm(self.config)
        self.mlp = MLP(self.config)

    def forward(self, resid_pre: torch.Tensor):
        # resid_pre is a tensor of shape (batch_size, sequence_length, d_model)
        if self.config.debug:
            print(f"resid_pre shape: {resid_pre.shape}")

        # apply the first layer normalization
        norm_resid_pre = self.ln1(resid_pre)

        # apply the attention block
        attention_output = self.attn(norm_resid_pre)

        # add the residual connection
        resid_mid = resid_pre + attention_output

        if self.config.debug:
            print(f"resid_mid shape: {resid_mid.shape}")

        # apply the second layer normalization
        norm_resid_mid = self.ln2(resid_mid)

        # apply the mlp block
        mlp_output = self.mlp(norm_resid_mid)

        # add the residual connection
        resid_post = resid_mid + mlp_output

        if self.config.debug:
            print(f"resid_post shape: {resid_post.shape}")

        return resid_post

## Unembed

Another linear map creating the output.

In [22]:
class Unembed(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # initialize the unembedding matrix with a normal distribution
        # dim=(d_model, d_vocab)
        self.W_U = nn.Parameter(torch.empty((config.d_model, config.d_vocab)))
        nn.init.normal_(self.W_U, std=config.init_range)
        self.b_U = nn.Parameter(torch.zeros((config.d_vocab)))

    def forward(self, normalized_result_fin: torch.Tensor):
        # normalized_result_fin is a tensor of shape (batch_size, sequence_length, d_model)
        if self.config.debug:
            print(f"normalized_result_fin shape: {normalized_result_fin.shape}")

        logits = (
            einsum(
                "batch sequence d_model, d_model d_vocab -> batch sequence d_vocab",
                normalized_result_fin,
                self.W_U,
            )
            + self.b_U
        )

        if self.config.debug:
            print(f"logits shape: {logits.shape}")

        return logits

## Full Transformer

In [23]:
class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed = Embed(config)
        self.pos_embed = PositionalEmbedding(config)
        self.blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )
        self.ln_final = LayerNorm(config)
        self.unembed = Unembed(config)

    def forward(self, tokens: torch.Tensor):
        # tokens is a tensor of shape (batch_size, sequence_length)
        if self.config.debug:
            print(f"tokens shape: {tokens.shape}")

        # get the embeddings
        embedded_tokens = self.embed(tokens)

        # get the positional embeddings
        positional_embedded_tokens = self.pos_embed(tokens)

        # add the embeddings and positional embeddings
        resid = embedded_tokens + positional_embedded_tokens

        if self.config.debug:
            print(f"resid shape: {resid.shape}")

        # apply the transformer blocks
        for block in self.blocks:
            resid = block(resid)

        # apply the final layer normalization
        norm_resid_fin = self.ln_final(resid)

        # get the logits
        logits = self.unembed(norm_resid_fin)

        if self.config.debug:
            print(f"logits shape: {logits.shape}")

        return logits

# Use the model

In [24]:
demo_config = Config(debug=False)
demo_model = Transformer(demo_config)
demo_model.load_state_dict(model.state_dict(), strict=False)
demo_model.cuda()

Transformer(
  (embed): Embed()
  (pos_embed): PositionalEmbedding()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNorm()
      (attn): Attention()
      (ln2): LayerNorm()
      (mlp): MLP(
        (gelu): GELU(approximate='none')
      )
    )
  )
  (ln_final): LayerNorm()
  (unembed): Unembed()
)

In [27]:
test_string = "Breaking News: President Trump has been impeached by the House of Representatives for abuse of power and obstruction of Congress. The vote was 230 to 197, with 10 Republicans joining all Democrats in voting to impeach. The president is now only the third in American history to be impeached, and the first to be impeached twice. The House will now send the articles of impeachment to the Senate, where a trial will be held to determine whether to remove the president from office. The Senate is expected to begin the trial on the"
for i in tqdm.tqdm(range(30)):
    test_tokens = model.to_tokens(test_string).cuda()
    test_logits = demo_model(test_tokens)
    test_string += model.tokenizer.decode(test_logits[-1, -1].argmax())

print(test_string)


Breaking News: President Trump has been impeached by the House of Representatives for abuse of power and obstruction of Congress. The vote was 230 to 197, with 10 Republicans joining all Democrats in voting to impeach. The president is now only the third in American history to be impeached, and the first to be impeached twice. The House will now send the articles of impeachment to the Senate, where a trial will be held to determine whether to remove the president from office. The Senate is expected to begin the trial on the day of the vote.


The House of Representatives is expected to vote on the impeachment of President Trump on Tuesday.


The House


In [26]:
test_string = "That's what I am talking about,"
# print out the top 10 predictions for the next token (values are raw logits, not probabilities)

test_tokens = model.to_tokens(test_string).cuda()
test_logits = demo_model(test_tokens)
top10 = torch.topk(test_logits[-1, -1], 10)
for token, logit in zip(top10.indices, top10.values):
    pprint(f"{model.tokenizer.decode(token)}: {logit.item()}")

' and: 13.326175689697266'
' right: 13.024477005004883'
' but: 12.671146392822266'
' because: 12.606410026550293'
' I: 12.379632949829102'
' not: 12.371942520141602'
' a: 12.048900604248047'
' the: 11.977973937988281'
' you: 11.909173011779785'
' so: 11.894983291625977'
