<a href="https://colab.research.google.com/github/jcheigh/mech-interp/blob/main/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformer

### Background:

The transformer architecture is fundamental to understanding deep learning. Here I have code for a basic GPT-2 style transformer, following [Neel Nanda's GPT-2 From Scratch](https://www.youtube.com/watch?v=dsjUDacBw8o&list=PL7m7hLIqA0hoIUPhC26ASCVs_VrqcDpAz&index=2).

Most of this is really an unvectorized step by step implementation. Rather than a bunch of matrix multiplications, there are more loops and dictionaries. My opinion is that this is more intuitive, but that this helps me differentiate between the fundamental part of the model (attention as a way to move info from one position to another etc.) from efficient steps (the fancy attention formula you see).

### What is Language Modeling?

Language models predict the next token. That is, given some history $h$ and a (fixed) vocabulary $V$, a language model $P_\theta$ models the distribution over $V$, conditioned on $h$. One can use this to generate text by making it *autoregressive*: 1) start with some prompt, 2) sample from the language model, 3) append this to the prompt, 4) repeat until sampling something akin to a <EOS> (end of sentence) token.

### Why Language Modeling?

Next token prediction seems like a relatively simple task, and on the surface seems there is no way for any model trained in this way to have superhuman performance. However, one shouldn't underestimate this task. There are perhaps linguistic or even philisophical arguments one can give for this, but I think [Ilya's](https://www.youtube.com/watch?v=YEUclZdj_Sc) statement here is well put.

In [None]:
### code from Nanda
%pip install git+https://github.com/neelnanda-io/Easy-Transformer.git@clean-transformer-demo
!curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
%pip install git+https://gith"ub.com/neelnanda-io/PySvelte.git
%pip install fancy_einsum
%pip install einops


In [2]:
import einops
from fancy_einsum import einsum
from dataclasses import dataclass
from easy_transformer import EasyTransformer
import torch
import torch.nn as nn
import numpy as np
import math
from easy_transformer.utils import get_corner, gelu_new, tokenize_and_concatenate
import tqdm.auto as tqdm

### Part 1- Understanding Inputs and Outputs

Transformers (in our case) are language models. They input tokens and output logits. So, before we deal with an input sequence like "roses are red, violets are ", we first must create a *vocabulary*, i.e. a finite set of tokens. From there, we can *tokenize* our sequence, converting it from a string to a list of tokens, which map to a vector of token indices. We can then run our model to generate logits, from which case we can use softmax to turn this into a distribution over the vocabulary.

In [None]:
gpt2 = EasyTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)

Step 1: sequence -> tokens (tokenize)


In [4]:
input_sequence = 'hi, my name is justin'         ### 1- input sequence
tokens         = gpt2.to_tokens(input_sequence)  ### 2- to batch x pos tensor of tokens indices

print(f'Input Sequence: {input_sequence}')
print(f'Tokens: {tokens}')
print(f'Tokens Shape: {tokens.shape} = batch x position')
print(f'Tokens (str): {gpt2.to_str_tokens(tokens)}')

Input Sequence: hi, my name is justin
Tokens: tensor([[50256,  5303,    11,   616,  1438,   318,   655,   259]])
Tokens Shape: torch.Size([1, 8]) = batch x position
Tokens (str): ['<|endoftext|>', 'hi', ',', ' my', ' name', ' is', ' just', 'in']


Note the shape of the tokens tensor is batch x position. One can see this more clearly if we input multiple sequences, i.e. process in parallel.

In [5]:
sequences  = ['hi my name is justin', 'hi']
batch_tkns = gpt2.to_tokens(sequences)

print(f'Sequences: {sequences}')
print(f'Tokens: {batch_tkns}')
print(f'Tokens Shape: {batch_tkns.shape}')
print(f"Tokens for string 'hi' : {gpt2.to_str_tokens(batch_tkns[1])}")

Sequences: ['hi my name is justin', 'hi']
Tokens: tensor([[50256,  5303,   616,  1438,   318,   655,   259],
        [50256,  5303, 50256, 50256, 50256, 50256, 50256]])
Tokens Shape: torch.Size([2, 7])
Tokens for string 'hi' : ['<|endoftext|>', 'hi', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>']


Step 2: tokens -> logits (this is the transformer!)

In [6]:
logits, cache = gpt2.run_with_cache(tokens) ### 3- run model, return logits
print(logits.shape)                         # batch x position x vocab, i.e. for *each* token will be a dstn over vocab

torch.Size([1, 8, 50257])


Step 3: logits -> prediction (sample then decode)

In [7]:
log_probs  = logits.log_softmax(dim=-1)     ### 4- softmax to product prod distribution
next_token = log_probs[0,-1].argmax(dim=-1) ### 5- take argmax to produce next token
print(f'Next Token- Index: {next_token}, Token: {gpt2.tokenizer.decode(next_token)}')

next_tokens = torch.cat([tokens, torch.tensor(next_token, dtype=torch.int64)[None, None]], dim=-1)

Next Token- Index: 11, Token: ,


  next_tokens = torch.cat([tokens, torch.tensor(next_token, dtype=torch.int64)[None, None]], dim=-1)


With the above basic loop, we can generate using GPT-2.

In [8]:
def generate(input_sequence, n_gen, model=gpt2):
  print(f'Generating {n_gen} tokens, starting with: {input_sequence}')
  tokens = model.to_tokens(input_sequence)
  for _ in range(n_gen):
    sequence   = model.to_str_tokens(tokens)
    logits, _  = model.run_with_cache(tokens)
    log_probs  = logits.log_softmax(dim=-1)
    next_token = log_probs[0,-1].argmax(dim=-1)
    print(f"{''.join(sequence)} {model.tokenizer.decode(next_token)}")
    tokens = torch.cat([tokens, torch.tensor(next_token, dtype=torch.int64)[None, None]], dim=-1)

  return model.to_str_tokens(tokens)

results = generate(input_sequence, 10)

Generating 10 tokens, starting with: hi, my name is justin
<|endoftext|>hi, my name is justin ,


  tokens = torch.cat([tokens, torch.tensor(next_token, dtype=torch.int64)[None, None]], dim=-1)


<|endoftext|>hi, my name is justin,  i
<|endoftext|>hi, my name is justin, i  am
<|endoftext|>hi, my name is justin, i am  a
<|endoftext|>hi, my name is justin, i am a  student
<|endoftext|>hi, my name is justin, i am a student  at
<|endoftext|>hi, my name is justin, i am a student at  the
<|endoftext|>hi, my name is justin, i am a student at the  university
<|endoftext|>hi, my name is justin, i am a student at the university  of
<|endoftext|>hi, my name is justin, i am a student at the university of  u


### Aside on Tokenization

We've gone over the high level generation process, but this requires us to have access to a vocabulary. The typical method for this is Byte Pair Encoding (BPE).


At a high level, we may use UTF-8 to map each Unicode codepoint to 1-4 bytes. To further enable us to specify a vocab size, we can iteratively merge the most frequent token pair, appending this new pair to our vocabulary. For many reasons this is a highly flawed process.

In [9]:
from collections import Counter

def find_most_common_pair(tokens):
    return Counter(zip(tokens, tokens[1:])).most_common(1)[0][0]

def merge(ids, pair, idx):
    ### list of ints (ids), replace(pair, idx)
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i +=1
    return newids

def get_vocab(text, vocab_size):
  ### BPE implementation to get vocab using 256 ascii char as initial token set
  num_merges = vocab_size - 256
  tokens     = text.encode('utf-8')
  ids        = list(tokens)
  merges     = {}
  for i in range(num_merges):
    pair = find_most_common_pair(ids)
    idx = 256 + i
    print(f'merging {pair} into token {idx}')
    ids = merge(ids, pair, idx)
    merges[pair] = idx
  return merges

### Part 2- Step by Step Unvectorized Implementation

To be clear this is not actually a full implementation, as we skip over things like positional embeddings, layernorm, etc.

For a more high level discussion of transformer check out Neel Nanda's implementation/notes.

### Step 1- Preprocessing (tokenization)

1. Preprocessing: an input prompt to a batch x position x d_model tensor, where each token (each batch x position) has associated with it a residual stream that starts with it being purely the embedding.
  - Break string into a tensor of tokens (batch = 1 for this analysis)
  - Map tokens ('a') to indices ({a:0,...}) to get a tensor of indices
  - Use embedding lookup table ({0 : [...]}) to get a tensor of embeddings

We'll ignore the BPE stuff and just deal with character level tokenization for now.

In [36]:
print_ln = lambda n=20 : print(f'=' * n)
printT   = lambda matrix : np.round(matrix, 2)

### 1- start with input
print_ln()
input_sequence = 'hello, there.'
print(f'Input sequence: {input_sequence}')

### 2- standard preprocessing (for us this is really simple!)
print_ln()
input_sequence = input_sequence.lower()
input_sequence = input_sequence.replace('.', ' .')
input_sequence = input_sequence.replace(',', ' ,')
print(f'Processed input sequence: {input_sequence}')

### 3- define vocab and tokenize
print_ln()
vocab = list('abcdefghijklmnopqrstuvwxyz., ')
vocab.append('<EOS>') ### either eos/sos
indices = list(range(len(vocab)))
tkn_to_idx = {t:i for i, t in enumerate(vocab)}
idx_to_tkn = {i:t for i, t in enumerate(vocab)}
print(f'Vocabulary: {vocab}')

def tokenize(sequence):
  tokens = []
  for t in sequence:
    if t not in tkn_to_idx:
      raise Exception(f'Token {t} not in vocab idiot (jk...)')

    tokens.append(tkn_to_idx[t])
  return tokens + [tkn_to_idx['<EOS>']]

def to_str(tokens):
  return ''.join([idx_to_tkn[i] for i in tokens])

tokens   = tokenize(input_sequence)
sequence = to_str(tokens)
print_ln()
print(f'Tokens: {tokens}')
print(f'Num Tokens: {len(tokens)}')
assert sequence == input_sequence + '<EOS>'

d_vocab  = len(vocab)
d_model  = 6
position = len(tokens)

print_ln()
print(f'd_vocab: {d_vocab}')
print(f'd_model: {d_model}')
print(f'position: {position}')

embedding_map = {i : np.random.standard_normal(d_model) for i in indices}

for i in range(4):
  print_ln(20)
  print(f'Token {i}: {idx_to_tkn[i]}')
  print(f'Embedding: {embedding_map[i][:5]}')

### 4- embed using lookup table
def embed(sequence):
  tokens = tokenize(sequence)
  return np.array([embedding_map[i] for i in tokens]), tokens

res_stream_tensor, _ = embed(input_sequence)
print_ln()
print(f'Transformer Input:\n {printT(res_stream_tensor)}')
print(f'Transformer Input Shape: {res_stream_tensor.shape} (position x d_model)')
assert res_stream_tensor.shape == (len(tokens), d_model)

Input sequence: hello, there.
Processed input sequence: hello , there .
Vocabulary: ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '.', ',', ' ', '<EOS>']
Tokens: [7, 4, 11, 11, 14, 28, 27, 28, 19, 7, 4, 17, 4, 28, 26, 29]
Num Tokens: 16
d_vocab: 30
d_model: 6
position: 16
Token 0: a
Embedding: [-1.83168225 -1.05130991  1.49743168  1.85770234 -0.10343122]
Token 1: b
Embedding: [ 2.0958435   1.59444239  0.67894699 -0.81240235 -0.04902934]
Token 2: c
Embedding: [ 0.3308513   1.45100144  0.87941704 -1.07737608  1.37649645]
Token 3: d
Embedding: [ 0.6870658   1.46654086 -1.11158025 -0.03582558 -0.53145455]
Transformer Input:
 [[-0.3   1.31  0.32  0.19 -1.27  0.29]
 [ 0.35  2.51 -1.84 -0.03  0.64  0.12]
 [ 1.26  1.15 -0.97  1.01  0.34 -0.23]
 [ 1.26  1.15 -0.97  1.01  0.34 -0.23]
 [-0.18  0.4   0.69  0.73 -0.99 -0.28]
 [-0.57  0.33  0.93 -0.22  1.07  1.45]
 [ 0.53  0.84  1.22  1.06  0.87 -0.32]
 [-0.57  0.33

### Step 2- Transformer Block

At this point, we have a residual stream (some sort of global memory) for each token in the sequence. However, this residual stream is initialized to simply be the embedding, i.e. no information about the past context. We therefore need to
1. Move information from each prior token to the current token (attention)
2. Perform computation on this aggregate of information (MLP)

Let's dive a bit deeper into attention

### Attention

Here are basically Neel's notes. Attention moves information from prior positions in the sequence to the current token.
- We do this for each token in parallel with the same parameters, with the only difference being we look backwards only to avoid cheating
- This is the only part of the transformer that moves info between positions
- We do this with $n_{\text{heads}}$ heads, each with their own parameter, attention pattern (will discuss later), and own way of copying information from src/past token to destination/current token
  - heads act independently and additively, we add their outputs back into the stream
- Each head:
  - Produces an attention pattern for each destination token, that is a probability distribution of prior source tokens (including the current one), weighing how much information to copy
    - Does this for each pair of tokens
    - Copies info in the same way from each source token
      - Note the info we copy depends on the source token's residual stream, which can be more than simply the info of what text token is at the source token's position
    - Note which source token to copy info from (query/key dot product) is distinct from how to copy that info (value mlp)
    - $d_\text{head} = \frac{d_{\text{model}}}{n_{\text{heads}}}$

### Mathematically

Input: X = residual stream tensor with shape batch x position x d_model

Define $W_Q, W_K$ with shape n_heads x d_model x d_head, then multiply to get query tensor $Q$, key tensor $K$, where both are shape batch x position x n_heads x d_head. Here each token has n_heads, each of which has a d_head vector. The query vector is asking for information, the key vector is telling what info the residual stream of that token contains.

We then run essentially the following loop: for each batch, for each head, for each query vector, we compute the dot product of that query vector with each key vector, and this attention score is aggregated in a matrix. So we have a batch x n_head x position x position, where the position x position matrix is a query_pos x key_pos matrix, where the (i,j)th entry is the similarity between the ith token's query and the jth token's key.

To make this causal, the ith token should only be able to query token j with $j \le i$, therefore we will mask out anything where $j > i$. We want to make this into a probability distribution, so we apply softmax, which gives us the attention pattern (some normalization here also).

We then get a value tensor $V$ with shape batch, position, n_heads, d_head (the position here is key_position) and multiply along the key_position dimension. This means that, for each (destination) token there is an attention pattern over source tokens, as well as a value vector indicating, and the dot product of these is basically the new info to copy, giving us a batch x query_pos x n_heads x d_head tensor.

We then apply a linear map to get a batch x position x n_head x d_model tensor and sum over all heads to get our result.

In [38]:
### map tkn to residual stream
res_stream_map = {tkn : res_stream_tensor[i] for i, tkn in enumerate(tokens)}

In [40]:
tkn = tokens[2]
print(f'Token {tkn} = {idx_to_tkn[tkn]}')
residual_stream = res_stream_map[tkn]
print(f'Initial Residual Stream: {residual_stream}')

### residual stream starts as embedding
assert np.all(residual_stream == embedding_map[tkn])

Token 11 = l
Initial Residual Stream: [ 1.25923324  1.14644087 -0.97337884  1.00713334  0.33614496 -0.22747626]


In [41]:
n_heads = 2
d_head  = int(d_model / n_heads)
print(f'n_heads: {n_heads}')
print(f'd_head: {d_head}')

### 1- get query, key vector for each token
query_vec_map = {(tkn, head) : np.random.standard_normal(d_head) for tkn in tokens for head in range(n_heads)}
key_vec_map   = {(tkn, head) : np.random.standard_normal(d_head) for tkn in tokens for head in range(n_heads)}
print_ln()
head = 1
print(f'Query Vector for Token {tkn}, Head {head}: {query_vec_map[(tkn, head)]}')
print(f'Key Vector for Token {tkn}, Head {head}: {key_vec_map[(tkn, head)]}')

n_heads: 2
d_head: 3
Query Vector for Token 11, Head 1: [-2.31614386  0.8580142  -0.82819627]
Key Vector for Token 11, Head 1: [ 1.17171791 -1.17617334  0.4875603 ]


In [42]:
### 2- get attn scores- for each head, query_pos x key_pos matrix, where entry (i,j) = dot(query(i), key(j))
### or, for each head, for each (destination) token, there is a vector of scores with each position
### then, scale by 1/sqrt(d_head), and mask where j > i
attn_score_map = {
    head : np.zeros((position, position))
    for head in range(n_heads)
    }

print_rng = 2

print_ln(50)
print(f'Computing Attention Scores')

for head in range(n_heads):
  if head == 0:
    print_ln(40)
    print(f'Head {head}')
  for i, dest_tkn in enumerate(tokens):
    query = query_vec_map[(dest_tkn, head)]
    if head == 0 and i < print_rng:
      print(f'  ===========================')
      print(f'  Destination Token {dest_tkn}')
      print(f'  Query Vector: {printT(query)}')

    for j, src_tkn in enumerate(tokens):

      key = key_vec_map[(src_tkn, head)]
      if head == 0 and max(i,j) < print_rng:
        print(f'    ================')
        print(f'    Source Token {src_tkn}')
        print(f'    Key Vector: {printT(key)}')

      attn_score           = np.dot(query, key)
      ### mask out with j > i and normalize
      processed_attn_score = attn_score / np.sqrt(d_head) if j <= i else -1e8

      if head == 0 and max(i,j) < print_rng:
        print(f'    ================')
        print(f'    Attention Score for between Dest {dest_tkn}, Src {src_tkn}: {round(attn_score,2)}')
        print(f'    Processed Attention Score: {round(processed_attn_score, 2)}')

      attn_score_map[head][i][j] = processed_attn_score

Computing Attention Scores
Head 0
  Destination Token 7
  Query Vector: [-0.81  0.82  0.23]
    Source Token 7
    Key Vector: [-0.72 -0.2  -0.45]
    Attention Score for between Dest 7, Src 7: 0.33
    Processed Attention Score: 0.19
    Source Token 4
    Key Vector: [ 0.95  0.24 -1.53]
    Attention Score for between Dest 7, Src 4: -0.92
    Processed Attention Score: -100000000.0
  Destination Token 4
  Query Vector: [-0.15 -0.33 -1.04]
    Source Token 7
    Key Vector: [-0.72 -0.2  -0.45]
    Attention Score for between Dest 4, Src 7: 0.64
    Processed Attention Score: 0.37
    Source Token 4
    Key Vector: [ 0.95  0.24 -1.53]
    Attention Score for between Dest 4, Src 4: 1.38
    Processed Attention Score: 0.79


In [43]:
for head in range(n_heads):
  print_ln()
  attn_scores = attn_score_map[head]
  if_mask     = (attn_scores + 1e8) < 0.01
  to_print    = np.where(if_mask, "___", np.round(attn_scores, 2))[:7,:7]
  print(f'Attention Score Matrix for Head {head}: \n{to_print}')

Attention Score Matrix for Head 0: 
[['0.19' '___' '___' '___' '___' '___' '___']
 ['0.37' '0.79' '___' '___' '___' '___' '___']
 ['-0.02' '1.48' '0.56' '___' '___' '___' '___']
 ['-0.02' '1.48' '0.56' '0.56' '___' '___' '___']
 ['0.62' '-1.79' '-0.24' '-0.24' '-1.43' '___' '___']
 ['-0.08' '-0.11' '-0.17' '-0.17' '-0.05' '0.11' '___']
 ['1.0' '1.36' '0.85' '0.85' '-0.47' '1.23' '-0.8']]
Attention Score Matrix for Head 1: 
[['-0.17' '___' '___' '___' '___' '___' '___']
 ['0.6' '0.16' '___' '___' '___' '___' '___']
 ['1.16' '0.33' '-2.38' '___' '___' '___' '___']
 ['1.16' '0.33' '-2.38' '-2.38' '___' '___' '___']
 ['0.08' '-0.08' '0.18' '0.18' '-0.11' '___' '___']
 ['-0.54' '-0.09' '0.37' '0.37' '0.29' '-0.3' '___']
 ['0.1' '0.08' '-1.26' '-1.26' '2.01' '-1.62' '0.48']]


In [44]:
### 3- apply softmax row-wise to each attn_score matrix to get attn pattern
softmax = lambda vec : np.exp(vec) / np.sum(np.exp(vec))
attn_pattern_map = {
    head : np.zeros((position, position))
    for head in range(n_heads)
    }

for head in range(n_heads):
  attn_score_mat = attn_score_map[head]
  for i, dest_tkn in enumerate(tokens):
    attn_score_vec   = attn_score_mat[i]
    attn_pattern_vec = softmax(attn_score_vec)
    attn_pattern_map[head][i] = attn_pattern_vec

print_ln(50)
for head in range(n_heads):
  print_ln()
  attn_pattern = attn_pattern_map[head]
  print(f'Attention Pattern Matrix for Head {head}: \n{printT(attn_pattern[:7, :7])}')

Attention Pattern Matrix for Head 0: 
[[1.   0.   0.   0.   0.   0.   0.  ]
 [0.4  0.6  0.   0.   0.   0.   0.  ]
 [0.14 0.62 0.25 0.   0.   0.   0.  ]
 [0.11 0.5  0.2  0.2  0.   0.   0.  ]
 [0.48 0.04 0.2  0.2  0.06 0.   0.  ]
 [0.17 0.16 0.15 0.15 0.17 0.2  0.  ]
 [0.17 0.25 0.15 0.15 0.04 0.22 0.03]]
Attention Pattern Matrix for Head 1: 
[[1.   0.   0.   0.   0.   0.   0.  ]
 [0.61 0.39 0.   0.   0.   0.   0.  ]
 [0.68 0.3  0.02 0.   0.   0.   0.  ]
 [0.67 0.29 0.02 0.02 0.   0.   0.  ]
 [0.2  0.17 0.23 0.23 0.17 0.   0.  ]
 [0.09 0.14 0.22 0.22 0.21 0.11 0.  ]
 [0.09 0.09 0.02 0.02 0.62 0.02 0.13]]


In [45]:
### 4- get value vector for each token (specifically each src token)
### attn_pattern.shape = query_pos x key_pos, value_mat.shape = key_pos x d_head, z = query_pos x d_head
value_map = {
    head : np.random.standard_normal((position, d_head))
    for head in range(n_heads)
    }
z = {
    head : np.matmul(attn_pattern_map[head], value_map[head])
    for head in range(n_heads)
    }

print_ln(50)
head = 0
print(f'Z matrix for Head {head}: \n{printT(z[head])}, shape: {z[head].shape}')

Z matrix for Head 0: 
[[-0.56 -0.13  1.67]
 [-0.79  0.92  0.47]
 [-0.33  0.63  0.18]
 [-0.41  0.23  0.18]
 [-0.13 -0.62  0.87]
 [-0.   -0.35  0.2 ]
 [-0.19 -0.1   0.44]
 [ 0.2  -0.3   0.24]
 [ 0.21 -0.45 -0.03]
 [ 0.18 -0.48  0.24]
 [ 0.09 -0.06  0.29]
 [ 0.31 -0.15  0.22]
 [ 0.29 -0.18  0.23]
 [ 0.19 -0.24  0.43]
 [ 0.19 -0.37  0.29]
 [ 0.11 -0.16  0.31]], shape: (16, 3)


In [46]:
### 5- take z (query_pos x d_head) and apply W_O (d_head, d_model) to get result (pos x d_model)
W_O = {head : np.random.standard_normal((d_head, d_model)) for head in range(n_heads)}
result = {
    head : np.matmul(z[head], W_O[head])
    for head in range(n_heads)
    }

for head in range(n_heads):
  print_ln()
  print(f'Attention Result for Head {head}: \n{printT(result[head])}')
  assert result[head].shape == (position, d_model)

Attention Result for Head 0: 
[[ 0.34  1.17 -0.7  -2.62 -0.81  0.43]
 [ 1.43  0.98 -2.55 -0.39 -2.28 -0.46]
 [ 0.62  0.52 -1.41  0.01 -1.6  -0.35]
 [ 0.77  0.36 -0.99 -0.27 -0.44 -0.09]
 [-0.17  0.29  0.66 -1.63  0.94  0.57]
 [-0.1  -0.04  0.48 -0.49  0.75  0.26]
 [ 0.19  0.3  -0.18 -0.75 -0.    0.16]
 [-0.53 -0.05  0.74 -0.45  0.4   0.23]
 [-0.43 -0.29  0.95 -0.16  1.06  0.27]
 [-0.49 -0.13  0.95 -0.56  0.92  0.35]
 [-0.34  0.13  0.24 -0.41 -0.21  0.09]
 [-0.74 -0.03  0.71 -0.29 -0.08  0.13]
 [-0.72 -0.03  0.73 -0.33  0.02  0.15]
 [-0.61  0.09  0.66 -0.68  0.08  0.23]
 [-0.54 -0.05  0.83 -0.56  0.56  0.29]
 [-0.37  0.08  0.41 -0.49  0.05  0.16]]
Attention Result for Head 1: 
[[-3.16  1.85  3.33 -0.78  1.46 -0.3 ]
 [-0.22  0.04  0.96  0.14  0.67 -0.74]
 [-0.88  0.46  1.48 -0.08  0.84 -0.62]
 [-0.81  0.42  1.42 -0.06  0.81 -0.62]
 [ 1.06 -0.55 -0.41  0.27  0.09 -0.46]
 [ 1.24 -0.72 -0.66  0.38 -0.05 -0.45]
 [ 1.89 -1.23 -0.95  0.75 -0.06 -0.85]
 [ 0.66 -0.47 -0.28  0.31  0.01 -0.37]
 [-

In [47]:
### 6- sum over heads to get pos x d_model attn_out
attn_out = np.sum(list(result.values()), axis=0)
print_ln()
print(f'Attention Out: \n{printT(attn_out)}')

Attention Out: 
[[-2.82  3.02  2.63 -3.41  0.65  0.13]
 [ 1.21  1.03 -1.59 -0.25 -1.61 -1.21]
 [-0.26  0.97  0.07 -0.07 -0.76 -0.98]
 [-0.04  0.78  0.43 -0.33  0.37 -0.72]
 [ 0.89 -0.26  0.26 -1.37  1.03  0.11]
 [ 1.14 -0.76 -0.18 -0.11  0.7  -0.19]
 [ 2.08 -0.93 -1.13  0.   -0.06 -0.7 ]
 [ 0.13 -0.52  0.46 -0.13  0.41 -0.14]
 [-0.63 -0.46  1.59  0.21  1.45 -0.39]
 [-0.74 -0.24  1.19 -0.31  0.97  0.12]
 [-1.06 -0.02  1.69  0.21  0.58 -1.11]
 [-0.05 -0.84 -0.44  0.34 -0.79  0.19]
 [-0.97 -0.44  1.5   0.37  0.44 -0.82]
 [-0.4  -0.25  0.78 -0.31  0.22 -0.25]
 [-0.18 -0.51  0.6  -0.14  0.48 -0.05]
 [ 1.21 -1.17 -1.34  0.3  -0.79  0.08]]


In [48]:
### 7- add back to original residual stream
res_stream_tensor = res_stream_tensor + attn_out
print_ln()
print(f'Updated Residual Stream: \n{printT(res_stream_tensor)}, {res_stream_tensor.shape}')

Updated Residual Stream: 
[[-3.12  4.33  2.95 -3.21 -0.62  0.42]
 [ 1.56  3.54 -3.43 -0.28 -0.97 -1.08]
 [ 1.    2.12 -0.9   0.94 -0.42 -1.2 ]
 [ 1.22  1.92 -0.54  0.67  0.71 -0.94]
 [ 0.71  0.14  0.95 -0.63  0.05 -0.18]
 [ 0.57 -0.43  0.76 -0.33  1.76  1.26]
 [ 2.61 -0.1   0.09  1.06  0.81 -1.01]
 [-0.45 -0.18  1.39 -0.35  1.48  1.31]
 [-1.62 -0.5   0.78 -0.91  1.7  -0.18]
 [-1.04  1.08  1.51 -0.12 -0.3   0.41]
 [-0.71  2.49 -0.15  0.18  1.22 -0.99]
 [ 0.55 -1.47 -0.67  1.08 -0.37  0.46]
 [-0.63  2.07 -0.34  0.33  1.08 -0.7 ]
 [-0.97  0.09  1.72 -0.53  1.29  1.2 ]
 [ 0.43 -1.36  3.09  0.12  1.76 -0.2 ]
 [ 0.29 -2.01 -1.48  0.61 -1.32  1.43]], (16, 6)


Let's put this together.

In [60]:
def attention_is_all_you_needify(tokens, res_stream_tensor, n_heads=2):
  position = len(tokens)
  assert position == res_stream_tensor.shape[0]
  d_model  = res_stream_tensor.shape[1]
  d_head   = int(d_model / n_heads)

  query_vec_map = {(tkn, head) : np.random.standard_normal(d_head) for tkn in tokens for head in range(n_heads)}
  key_vec_map   = {(tkn, head) : np.random.standard_normal(d_head) for tkn in tokens for head in range(n_heads)}

  ### compute attn scores = CausalMask(QK^T/sqrt(d_head)), query_pos x key_pos mat for each head
  attn_score_map = {
    head : np.zeros((position, position))
    for head in range(n_heads)
    }

  for head in range(n_heads):
    for i, dest_tkn in enumerate(tokens):
      query = query_vec_map[(dest_tkn, head)]
      for j, src_tkn in enumerate(tokens):
        key = key_vec_map[(src_tkn, head)]
        attn_score = np.dot(query, key)
        processed_attn_score = attn_score / np.sqrt(d_head) if j <= i else -1e8
      attn_score_map[head][i][j] = processed_attn_score

  ## get attn pattern = Softmax(attn_scores)
  softmax = lambda vec : np.exp(vec) / np.sum(np.exp(vec))
  attn_pattern_map = {
      head : np.zeros((position, position))
      for head in range(n_heads)
      }

  for head in range(n_heads):
    attn_score_mat = attn_score_map[head]
    for i, dest_tkn in enumerate(tokens):
      attn_score_vec   = attn_score_mat[i]
      attn_pattern_vec = softmax(attn_score_vec)
      attn_pattern_map[head][i] = attn_pattern_vec

  ### compute z = AttnPattern @ value, shape query_pos x d_head
  value_map = {
      head : np.random.standard_normal((position, d_head))
      for head in range(n_heads)
      }
  z = {
      head : np.matmul(attn_pattern_map[head], value_map[head])
      for head in range(n_heads)
      }

  ### compute result = z @ W_O (d_head, d_model) to get result (pos x d_model)
  W_O = {head : np.random.standard_normal((d_head, d_model)) for head in range(n_heads)}
  result = {
      head : np.matmul(z[head], W_O[head])
      for head in range(n_heads)
      }
  ### sum over heads to get pos x d_model attn_out
  attn_out = np.sum(list(result.values()), axis=0)

  ### add back to original residual stream
  res_stream_tensor = res_stream_tensor + attn_out
  return res_stream_tensor

input_sequence = 'sup'
print(f'Input Sequence: {input_sequence}')

res_stream_tensor, tokens = embed(input_sequence)
print(f'Tokens: {tokens}')
print_ln()
print(f'Initial Residual Stream: \n{printT(res_stream_tensor)}')

res_stream_tensor = attention_is_all_you_needify(tokens, res_stream_tensor)
print_ln()
print(f'Residual Stream post Attn: \n{printT(res_stream_tensor)}')

Input Sequence: sup
Tokens: [18, 20, 15, 29]
Initial Residual Stream: 
[[-0.34  0.37 -0.12  0.46 -0.5  -0.18]
 [ 0.78 -0.43  0.41  0.55  0.44 -0.83]
 [ 1.19  0.89 -1.37  0.35 -1.98  0.05]
 [-0.92 -0.84 -0.14  0.31 -0.52  1.35]]
Residual Stream post Attn: 
[[ 1.34  0.46  0.07 -0.81  1.33 -1.16]
 [ 2.46 -0.34  0.6  -0.72  2.27 -1.8 ]
 [ 2.87  0.98 -1.18 -0.92 -0.15 -0.93]
 [-0.39 -1.45  0.25 -1.16  0.16  2.24]]


Great! So we've gone through taking the residual stream and applying attention, enabling us to move information from prior src tokens to the current destination token.

We'll now perform some computation on this, using an MLP layer.

### Step 3- MLP

For each batch x position, we now have a residual stream vector that includes meaningful information about past tokens. We wish to now perform computation on this meaningful residual stream.

Specifically, we have for each token a d_model residual stream. Our MLP will first project this vector into a dimension d_mlp = 4 * d_model. We'll then apply a ReLU (should be GeLU but who cares) and then project back down.

In [61]:
def mlpify(res_stream_tensor, d_mlp_multiplier=4, verbose=False):
  ### res_stream_tensor = position x d_model
  d_model = res_stream_tensor.shape[1]
  d_mlp   = d_mlp_multiplier * d_model

  W1 = np.random.standard_normal((d_model, d_mlp))
  linear_out = np.matmul(res_stream_tensor, W1)
  ### relu(x) = max(0, x)
  relu_out = np.maximum(0, linear_out)

  W2 = np.random.standard_normal((d_mlp, d_model))
  result = np.matmul(relu_out, W2)
  if verbose:
    print_ln()
    print(f'Res Stream Tensor Input: \n{printT(res_stream_tensor)}')
    print(f'Linear Out: \n{printT(linear_out)}')
    print(f'Post ReLU: \n{printT(relu_out)}')
    print(f'Result: \n{printT(result)}')
  return result

res_stream_tensor = mlpify(res_stream_tensor, d_mlp_multiplier=2, verbose=True)

Res Stream Tensor Input: 
[[ 1.34  0.46  0.07 -0.81  1.33 -1.16]
 [ 2.46 -0.34  0.6  -0.72  2.27 -1.8 ]
 [ 2.87  0.98 -1.18 -0.92 -0.15 -0.93]
 [-0.39 -1.45  0.25 -1.16  0.16  2.24]]
Linear Out: 
[[-2.19 -0.7   1.58 -3.76 -0.44  3.09  0.3   4.    1.71 -2.65 -2.4  -0.04]
 [-0.77  0.96  2.02 -5.22 -2.    7.14  1.94  4.81  1.64 -2.67 -4.18 -0.36]
 [-7.29  0.82  3.48 -4.32 -1.67  4.75  2.27  3.05  4.25 -6.1   2.25 -2.1 ]
 [ 8.47  3.17  4.1   0.45  2.87  3.91  5.34 -3.33 -3.98  1.51 -5.8  -1.81]]
Post ReLU: 
[[0.   0.   1.58 0.   0.   3.09 0.3  4.   1.71 0.   0.   0.  ]
 [0.   0.96 2.02 0.   0.   7.14 1.94 4.81 1.64 0.   0.   0.  ]
 [0.   0.82 3.48 0.   0.   4.75 2.27 3.05 4.25 0.   2.25 0.  ]
 [8.47 3.17 4.1  0.45 2.87 3.91 5.34 0.   0.   1.51 0.   0.  ]]
Result: 
[[  2.97  -6.26  -2.17   1.38  -2.56  -6.25]
 [  6.2   -7.44  -3.37  -0.64   3.25  -6.79]
 [ 11.22 -11.68   2.64   4.26  -2.76 -13.45]
 [  4.86  -8.05  -9.7   19.41  16.06  -5.96]]


This combination of self attention mechanism followed by a mlp layer is the essence of a *transformer block*.

In [76]:
def apply_transformer_block(tokens, res_stream_tensor):
  attn_out   = attention_is_all_you_needify(tokens, res_stream_tensor)
  resid_mid  = attn_out + res_stream_tensor
  mlp_out    = mlpify(resid_mid)
  resid_post = resid_mid + mlp_out
  return resid_post

input_sequence = 'sup'
res_stream_tensor, tokens = embed(input_sequence)
print(f'position = {res_stream_tensor.shape[0]}, d_model = {res_stream_tensor.shape[1]}')
output = apply_transformer_block(tokens, res_stream_tensor)
print_ln()
print(f'Output: \n{printT(output)}')
print(f'Output Shape: {output.shape}')

position = 4, d_model = 6
Output: 
[[ 16.87 -17.59  23.44   5.68 -16.58 -25.74]
 [ 12.89  -3.88  14.11   3.86 -11.4  -27.36]
 [ 23.3   -2.59  29.    -8.73  -9.69 -36.8 ]
 [  3.2  -12.66   2.28  -3.33  -6.05  -8.96]]
Output Shape: (4, 6)


Step 4- Unembed

After applying n_blocks transformer blocks, we presumably have a residual stream that is super meaningful at each token position. From here, we must unembed, i.e. turn into a probability distribution over tokens. That is, given our position x d_model tensor, we will apply a linear map to get to a position x d_vocab tensor. At this point we will have our logits, from which one can apply softmax and have the distribution over tokens for each position.

In [81]:
def apply_unembed(res_stream_tensor, d_vocab=d_vocab):
  W_out = np.random.standard_normal((d_model, d_vocab))
  unembed_tensor = np.matmul(res_stream_tensor, W_out)
  return unembed_tensor

def run_transformer(sequence, n_blocks=2, verbose=False):
  res_stream_tensor, tokens = embed(sequence)
  if verbose:
    print(f"d_vocab: {d_vocab}")
    print(f"d_model: {d_model}")
    print(f"position: {len(tokens)}")
    print(f'n_blocks: {n_blocks}')
    print(f'd_mlp: {4 * d_model}')

  for i in range(n_blocks):
    res_stream_tensor = apply_transformer_block(tokens, res_stream_tensor)

  return apply_unembed(res_stream_tensor)

input_sequence = 'wow we are done with this transformer implementation'
logits = run_transformer(input_sequence, verbose=True)
print_ln()
print(f'Logits: \n{printT(logits)}')
print(f'Logits Shape: {logits.shape}')

d_vocab: 30
d_model: 6
position: 53
n_blocks: 2
d_mlp: 24
Logits: 
[[ 106.86   37.81  -48.24 ...  334.94  434.98 -388.64]
 [-415.22  236.5    44.28 ... -261.52  -90.99  195.33]
 [ 106.86   37.81  -48.24 ...  334.94  434.98 -388.64]
 ...
 [-415.22  236.5    44.28 ... -261.52  -90.99  195.33]
 [ 218.37    6.04 -128.69 ...  371.32  547.3  -436.89]
 [-340.05  -45.77  163.73 ... -196.01  -76.     36.5 ]]
Logits Shape: (53, 30)
