# GPT Explainer

the purpose of this is to walk through what happens during the forward pass and backward pass of GPT-2 (generative pretrained transformer) like models.  To help display the transformation, we'll use the first sentence from the [linear algebra wiki page](https://en.wikipedia.org/wiki/Linear_algebra) and [lu decomposition wiki page](https://en.wikipedia.org/wiki/LU_decomposition) as the topic is fitting and it shows us some non-standard patterns.  

## Text Prep/Tokenization

we'll start with common preprocessing step of tokenizing the data.  This converts the string text into an array of numbers that can be used during the training loop.  I've built a very subtle byte-pair encdoing that has each unique character that appears and the top 5 merges. This keeps our vocab size small and managable for this example. Typically the vocab size is in the 100K+ range. A great library for this is `tiktoken`. Tokenization simply finds the longest pattern of characters that's in common with what was trained and replaces it with an integer that represents it.  This way we turn the text into a numeric array to simplify computing. 

In [1]:
import torch
from collections import Counter

In [2]:
class SimpleBPETokenizer:
    def __init__(self, num_merges=5, eot_token='<|endoftext|>'):
        self.num_merges = num_merges
        self.eot_token = eot_token
        self.eot_id = None
        self.merges = []
        self.pair_ranks = {}
        self.vocab = {}
        self.id_to_token = {}

    def _add_token(self, tok):
        if tok in self.vocab:
            return self.vocab[tok]
        i = len(self.vocab)
        self.vocab[tok] = i
        self.id_to_token[i] = tok
        return i

    def _get_bigrams(self, seq):
        for i in range(len(seq) - 1):
            yield (seq[i], seq[i + 1])

    def _merge_once(self, seq, pair):
        a, b = pair
        out = []
        i = 0
        while i < len(seq):
            if i < len(seq) - 1 and seq[i] == a and seq[i + 1] == b:
                out.append(a + b)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        return out

    def train(self, corpus):
        # corpus: list[str]
        text = ''.join(corpus).lower()
        seq = list(text)
        merges = []
        for _ in range(self.num_merges):
            counts = Counter(self._get_bigrams(seq))
            if not counts: break
            best_pair, _ = counts.most_common(1)[0]
            merges.append(best_pair)
            seq = self._merge_once(seq, best_pair)
        self.merges = merges
        self.pair_ranks = {p: i for i, p in enumerate(self.merges)}

        self.vocab = {}
        self.id_to_token = {}
        for ch in sorted(set(text)):
            self._add_token(ch)
        for a, b in self.merges:
            self._add_token(a + b)
        self.eot_id = self._add_token(self.eot_token)

    def encode(self, text, force_last_eot=True):
        # treat literal eot marker as special; remove it from content
        if self.eot_token in text:
            text = text.replace(self.eot_token, '')
        seq = list(text)

        # make sure all seen base chars exist
        for ch in set(seq):
            if ch not in self.vocab:
                self._add_token(ch)

        # greedy BPE using learned pair ranks
        if self.merges:
            while True:
                best_pair, best_rank = None, None
                for p in self._get_bigrams(seq):
                    r = self.pair_ranks.get(p)
                    if r is not None and (best_rank is None or r < best_rank):
                        best_pair, best_rank = p, r
                if best_pair is None:
                    break
                seq = self._merge_once(seq, best_pair)

        # ensure all tokens in seq exist in vocab (e.g., if new chars appeared)
        for tok in seq:
            if tok not in self.vocab:
                self._add_token(tok)

        ids = [self.vocab[tok] for tok in seq]

        # FORCE: append EOT id if not already last
        if force_last_eot:
            if not ids or ids[-1] != self.eot_id:
                ids.append(self.eot_id)

        return ids

    def decode(self, ids):
        # drop trailing EOT if present
        if ids and self.eot_id is not None and ids[-1] == self.eot_id:
            ids = ids[:-1]
        toks = [self.id_to_token[i] for i in ids]
        return ''.join(toks)


In [3]:
raw_example_1 = r'''Linear algebra is central to almost all areas of mathematics. For instance, linear algebra is fundamental in modern presentations of geometry, including for defining basic objects such as lines, planes and rotations. Also, functional analysis, a branch of mathematical analysis, may be viewed as the application of linear algebra to function spaces.'''
raw_example_2 = r'''In numerical analysis and linear algebra, lower–upper (LU) decomposition or factorization factors a matrix as the product of a lower triangular matrix and an upper triangular matrix (see matrix multiplication and matrix decomposition).'''


In [4]:
tok = SimpleBPETokenizer(num_merges=5)
tok.train([raw_example_1,raw_example_2])
tok.merges

[(' ', 'a'), ('a', 't'), ('i', 'n'), (' ', 'm'), ('i', 'o')]

In [5]:
tok.vocab

{' ': 0,
 '(': 1,
 ')': 2,
 ',': 3,
 '.': 4,
 'a': 5,
 'b': 6,
 'c': 7,
 'd': 8,
 'e': 9,
 'f': 10,
 'g': 11,
 'h': 12,
 'i': 13,
 'j': 14,
 'l': 15,
 'm': 16,
 'n': 17,
 'o': 18,
 'p': 19,
 'r': 20,
 's': 21,
 't': 22,
 'u': 23,
 'v': 24,
 'w': 25,
 'x': 26,
 'y': 27,
 'z': 28,
 '–': 29,
 ' a': 30,
 'at': 31,
 'in': 32,
 ' m': 33,
 'io': 34,
 '<|endoftext|>': 35}

In [6]:
vocab_size = len(tok.vocab)
vocab_size

36

In [7]:
eot = tok.eot_id
tokens = []
for example in [raw_example_1, raw_example_2]:
    tokens.extend([eot])
    tokens.extend(tok.encode(example.lower()))
all_tokens = torch.tensor(tokens, dtype=torch.long)
all_tokens

tensor([35, 15, 32,  9,  5, 20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0,  7,
         9, 17, 22, 20,  5, 15,  0, 22, 18, 30, 15, 16, 18, 21, 22, 30, 15, 15,
        30, 20,  9,  5, 21,  0, 18, 10, 33, 31, 12,  9, 16, 31, 13,  7, 21,  4,
         0, 10, 18, 20,  0, 32, 21, 22,  5, 17,  7,  9,  3,  0, 15, 32,  9,  5,
        20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0, 10, 23, 17,  8,  5, 16,
         9, 17, 22,  5, 15,  0, 32, 33, 18,  8,  9, 20, 17,  0, 19, 20,  9, 21,
         9, 17, 22, 31, 34, 17, 21,  0, 18, 10,  0, 11,  9, 18, 16,  9, 22, 20,
        27,  3,  0, 32,  7, 15, 23,  8, 32, 11,  0, 10, 18, 20,  0,  8,  9, 10,
        32, 32, 11,  0,  6,  5, 21, 13,  7,  0, 18,  6, 14,  9,  7, 22, 21,  0,
        21, 23,  7, 12, 30, 21,  0, 15, 32,  9, 21,  3,  0, 19, 15,  5, 17,  9,
        21, 30, 17,  8,  0, 20, 18, 22, 31, 34, 17, 21,  4, 30, 15, 21, 18,  3,
         0, 10, 23, 17,  7, 22, 34, 17,  5, 15, 30, 17,  5, 15, 27, 21, 13, 21,
         3, 30,  0,  6, 20,  5, 17,  7, 

# Modeling

A machine learning model forward pass now uses the tokenization information, runs several layers of linear algebra on it, and then "predicts" the next token. When it is noisy (like you will see in this example), this process results in gibberish.  The training process changes the noise to pattern during the "backward pass" as you'll see.    We'll show 3 steps that are focused on training:
1. **Data Loading** `x, y = train_loader.next_batch()` - this step pulls from the raw data enough tokens to complete a forward and backward pass.  If the model is inference only, this step is replaced with taking in the inference input and preparing it similarly as the forward pass.
2. **Forward Pass** `logits, loss = model(x, y)` - using the data and the model architecture to predict the next token. When training we also compare against the expected to get loss, but in infrerence, we use the logits to complete the inference task.
3. **Backward Pass & Training** `loss.backward(); optimizer.step()` - using differentials to understand what parameters most impact the forward pass' impact on its prediction, comparing that against what is actually right based on the data loading step, and then making very minor adjustments to the impactful parameters with the hope it improves future predictions.

The we'll show a final **Forward Pass** with the updated weights we did in #3. 

## Data Loading

To start, we need to get enough data to run the forward and backward passes.  Since our total dataset is likely too big to hold all at once in real practice, we would read just enough file information into memory so that we can run the passes, leaving memory and compute to be used on the passes instead of static data holding. 
To start, we have to identify the batch size and the model context length to determine how much data we need.  Consequently, these dimensions also form 2 of the 3 dimensions in the initial matrix.
- **Batch Size (B)** - This is the number of examples you'll train on in a single pass. 
- **Context Length (T)** - This is the max number of tokens that a model can use in a single pass to generat the next token. If an example is below this length, it can be padded.
  
*Ideally both B and T are multiples of 2 to work nicely with chip architecture. This is a common theme across the board*

In [8]:
B = 2 # Batch
T = 8 # context length

To start, we need to pull from our long raw_token list enough tokens for the forward pass. To be able to satisfy training `B` Batches `T` Context length, we need to pull out `B*T` tokens to slide the context window across the examples enough to satisfy the batch size.  Since the training will attempt to predict the last token given the previous tokens in context, we also need 1 more token at the end so that the last training example in the last batch can have the next token to validate against. 

In [9]:
current_position = 0
tok_for_training = all_tokens[current_position:current_position + B*T +1 ]
tok_for_training

tensor([35, 15, 32,  9,  5, 20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0])

Now that we have our initial tokens to train on, we now need to convert it to a matrix that's ready for training. In this step we'll need to create our batches and setup two different arrays: 1/ the input, `x`, tokens that will result in 2/ the output `y` tokens. To create each example in the batch, every `T` tokens will be placed into it's own row. 

Recall that training takes in a string of tokens the length of the context and then predicts the next token. Recall that when we extracted `tok_for_training` we added 1 extra token so that we can evaluate the prediction for the last example. Because of this, the input, `x`, will be all of the tokens up to the second to last element `[:-1]`.  

It might be natural to think the output `y` would then just be the last token.But this is actually wasting valuable training loops.  Yes, there is the example that fills the context `T`, but we also have enough tokens in `tok_for_training` where any context length of `n` where `n<T` can also be used for inference since we have the `n+1` token available.  You can think of the following example:

sentence: `Hi I am learning`. This sentence contains the following "next tokens" that can be learned:
1. x: Hi I am  | y: learning
2. x: Hi I     | y: am
3. x: Hi       | y: I

Because we have this triangle to create, our `y` can be much larger.  We can start with the second token and, go all the way to the last element we added for the last example `[1:'`.   


We will now put this together and do the following:
1. Extract the input `x` and then split it into an example for each batch `B`
2. Extract the output `y` and then split it into an example for each batch `B`

*Note: View can take `-1` which allows the matrix to infer the dimension so we do not need to pass in `T`, but given how many matrices we'll work with we want to make sure we're controlling the dimensions or erroring out if they do not match our expectations.*

In [10]:
tok_for_training

tensor([35, 15, 32,  9,  5, 20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0])

In [11]:
x=tok_for_training[:-1].view(B, T)
x

tensor([[35, 15, 32,  9,  5, 20, 30, 15],
        [11,  9,  6, 20,  5,  0, 13, 21]])

In [12]:
y=tok_for_training[1:].view(B, T)
y

tensor([[15, 32,  9,  5, 20, 30, 15, 11],
        [ 9,  6, 20,  5,  0, 13, 21,  0]])

## Forward pass

<img src="explainer_screenshots/gpt/full_network.png" width="200">

The forward pass takes a string of tokens in and predicts the next "n" tokens.  This step as we'll look at it is focused on training where we'll pass in the input `x`, carry that input through the layers, and generate a matrix of the probability of each token being the next one, something we call `logits`. During the forward pass at the end we then compare the probability to the actual next token in `y` and calculate `loss` based on the difference. 

*Note that we will do some layer initialization to simplify following along.  In reality layers are often initialized to normal distribution with some adjustments made for parameter sizes to keep the weights properly noisy.  We will not cover initialization in this series*

We first rederive the batch size and context size based on the input to improve flexibility.

In [13]:
import torch.nn as nn

In [14]:
B, T = x.size()
B,T

(2, 8)

The first layer of our network creates an embedding representation of our input sequence. 

In [15]:
block_size = T # max sequence length/context
vocab_size = vocab_size # 36 
n_embd = 4 # 

### Input Layer

<img src="explainer_screenshots/gpt/input_layer.png" width="200">
We'll first create an initiation for 2 of our input matrices: position and token embedding.  Both of these are a table of weigths that have `n_embd` number of columns to store information about the position or token. The more columns you add, the more complex information can be stored but the more compute is needed.  For now we'll let each position or column store up to 4 channels of information.  Before starting though we need to initialize the layer with a set of data.

In [16]:
# weighted token embedding
wte = nn.Embedding(vocab_size, n_embd)
with torch.no_grad(): # initilize to W[i,j] = 0.001*(1+i+j) for easy following 
    vs, d = wte.num_embeddings, wte.embedding_dim
    rows = torch.arange(vs).unsqueeze(1)  # (vs,1)
    cols = torch.arange(d).unsqueeze(0)  # (1,d)
    pattern = 0.01*(1 + rows + cols)  # W[i,j] = 0.001*(1+i+j)
    wte.weight.copy_(pattern)
# weighted position embedding
wpe = nn.Embedding(block_size, n_embd)
with torch.no_grad(): # initilize to W[i,j] = 0.001*(1+i+j) for easy following 
    vs, d = wpe.num_embeddings, wpe.embedding_dim
    rows = torch.arange(vs).unsqueeze(1)  # (vs,1)
    cols = torch.arange(d).unsqueeze(0)  # (1,d)
    pattern = 0.01*(1 + rows + cols)  # W[i,j] = 0.001*(1+i+j)
    wpe.weight.copy_(pattern)
wte.weight, wpe.weight

(Parameter containing:
 tensor([[0.0100, 0.0200, 0.0300, 0.0400],
         [0.0200, 0.0300, 0.0400, 0.0500],
         [0.0300, 0.0400, 0.0500, 0.0600],
         [0.0400, 0.0500, 0.0600, 0.0700],
         [0.0500, 0.0600, 0.0700, 0.0800],
         [0.0600, 0.0700, 0.0800, 0.0900],
         [0.0700, 0.0800, 0.0900, 0.1000],
         [0.0800, 0.0900, 0.1000, 0.1100],
         [0.0900, 0.1000, 0.1100, 0.1200],
         [0.1000, 0.1100, 0.1200, 0.1300],
         [0.1100, 0.1200, 0.1300, 0.1400],
         [0.1200, 0.1300, 0.1400, 0.1500],
         [0.1300, 0.1400, 0.1500, 0.1600],
         [0.1400, 0.1500, 0.1600, 0.1700],
         [0.1500, 0.1600, 0.1700, 0.1800],
         [0.1600, 0.1700, 0.1800, 0.1900],
         [0.1700, 0.1800, 0.1900, 0.2000],
         [0.1800, 0.1900, 0.2000, 0.2100],
         [0.1900, 0.2000, 0.2100, 0.2200],
         [0.2000, 0.2100, 0.2200, 0.2300],
         [0.2100, 0.2200, 0.2300, 0.2400],
         [0.2200, 0.2300, 0.2400, 0.2500],
         [0.2300, 0.2400, 0.250

**Positional Embeddings** - Now we need to pluck the weight of each position out of the position embedding.  Since we are creating a simple left to right, position 1 to n, we can just create an array from 0 to n based on the context, `T`, then pluck those rows out.  The resulting matrix from this operation is a `T, n_embd` based vector. 

In [17]:
pos = torch.arange(0, T, dtype=torch.long)
pos

tensor([0, 1, 2, 3, 4, 5, 6, 7])

for each element, look up the row in `wpe` and pluck it out. Since the position is just `[0:T]` we can see we pluck out the position array.

In [18]:
pos_emb = wpe(pos)
pos_emb.shape, pos_emb

(torch.Size([8, 4]),
 tensor([[0.0100, 0.0200, 0.0300, 0.0400],
         [0.0200, 0.0300, 0.0400, 0.0500],
         [0.0300, 0.0400, 0.0500, 0.0600],
         [0.0400, 0.0500, 0.0600, 0.0700],
         [0.0500, 0.0600, 0.0700, 0.0800],
         [0.0600, 0.0700, 0.0800, 0.0900],
         [0.0700, 0.0800, 0.0900, 0.1000],
         [0.0800, 0.0900, 0.1000, 0.1100]], grad_fn=<EmbeddingBackward0>))

**Word Embeddings** - Similarly we need to pluck out the rows from the token table, `wte` for the tokens in our example. Since our example is already represented as indices, we can simple use `x` directly. The resulting matrix from this operation is a `B,T, n_embd` based vector since `x` is `B,T` and `tok_emb` is `vocab_size,n_embd` and when we index `wte` by `x` each entry in x replaces `n_embd` based vector at that position in `wte`.

In [19]:
x

tensor([[35, 15, 32,  9,  5, 20, 30, 15],
        [11,  9,  6, 20,  5,  0, 13, 21]])

for each position pull out the the row of weights that corresponds to the token. You can see in the print out that the rows are not in the same order as the layer is initiliazed as the token ids are not sequential.  

In [20]:
tok_emb = wte(x)
tok_emb.shape, tok_emb

(torch.Size([2, 8, 4]),
 tensor([[[0.3600, 0.3700, 0.3800, 0.3900],
          [0.1600, 0.1700, 0.1800, 0.1900],
          [0.3300, 0.3400, 0.3500, 0.3600],
          [0.1000, 0.1100, 0.1200, 0.1300],
          [0.0600, 0.0700, 0.0800, 0.0900],
          [0.2100, 0.2200, 0.2300, 0.2400],
          [0.3100, 0.3200, 0.3300, 0.3400],
          [0.1600, 0.1700, 0.1800, 0.1900]],
 
         [[0.1200, 0.1300, 0.1400, 0.1500],
          [0.1000, 0.1100, 0.1200, 0.1300],
          [0.0700, 0.0800, 0.0900, 0.1000],
          [0.2100, 0.2200, 0.2300, 0.2400],
          [0.0600, 0.0700, 0.0800, 0.0900],
          [0.0100, 0.0200, 0.0300, 0.0400],
          [0.1400, 0.1500, 0.1600, 0.1700],
          [0.2200, 0.2300, 0.2400, 0.2500]]], grad_fn=<EmbeddingBackward0>))

**Impact of the position and token together**
To ensure that the position and token together impact the next token prediction, we sum the two so that the weight of each token is impacted by the weight of its relative position. To do this we sum `tok_emb` and `pos_emb` together. Quickly we can see the dimensions don't match as 
* `tok_emb` > `B,T,n_embd`
* `pos_emb` >   `T,n_embd`

Since we have multiple examples with the same ordering, we simply add pos_emb at the same level to each entry on the `B` dimension, something that pytorch does for us automatically resulting in a `B,T,n_embd` output

In [21]:
x = tok_emb + pos_emb
x

tensor([[[0.3700, 0.3900, 0.4100, 0.4300],
         [0.1800, 0.2000, 0.2200, 0.2400],
         [0.3600, 0.3800, 0.4000, 0.4200],
         [0.1400, 0.1600, 0.1800, 0.2000],
         [0.1100, 0.1300, 0.1500, 0.1700],
         [0.2700, 0.2900, 0.3100, 0.3300],
         [0.3800, 0.4000, 0.4200, 0.4400],
         [0.2400, 0.2600, 0.2800, 0.3000]],

        [[0.1300, 0.1500, 0.1700, 0.1900],
         [0.1200, 0.1400, 0.1600, 0.1800],
         [0.1000, 0.1200, 0.1400, 0.1600],
         [0.2500, 0.2700, 0.2900, 0.3100],
         [0.1100, 0.1300, 0.1500, 0.1700],
         [0.0700, 0.0900, 0.1100, 0.1300],
         [0.2100, 0.2300, 0.2500, 0.2700],
         [0.3000, 0.3200, 0.3400, 0.3600]]], grad_fn=<AddBackward0>)

### Transformer Layers

<img src="explainer_screenshots/gpt/transformer.png" width="200">

The transformer block is multiple parallel repetitions of the same matrix operations done independently.  This adds both depth and breadth to the computation.  Each block is the same steps of
1. Layer normalization
2. Causal self attention
3. Layer normalization (again)
4. Multi-layer perceptron (MLP)

Both steps 2 and 4 are also multi-layered so we'll go through each layer independently.  You'll notice the arrows in the diagram bypassing the causal self attention and the MLP.  This is to ensure that the weights of any one layer do not get overweighted. We achieve this by simply adding the input with the layer's calculations together, as you'll see. 

#### Transformer - Layer Normalization

With Layer normalization, we review the row and adjust based on how far away it is from the mean. This means an array of `[1,2,3,4]` and `[2,4,6,8]` will actually have the same normalized entries after layer normalization.  This layer adds regularization which helps with overall learning speed. The formula applied is:

$y = \frac{x - \mathbb{E}[x]}{\sqrt{\operatorname{Var}[x] + \epsilon}}$

Layer normalization is applied on the input matrix in kind and creates default weights of 1's across the dimension to equally weight all values in the normalization. We'll keep this as is and not change the initialization.  

*Note that even though we will do layer normalization again, we keep this as a separate layer so that its impact can be adjusted independent of other normalization layers.*

**Example** Let's see a quick example of how layer normalizaiton operates with an array of `[1,2,3,4]` and `[2,4,6,8]` 

In [22]:
example_ln = nn.LayerNorm(n_embd)

## Example 
example_ln(torch.tensor(
    [
        [1.0,2.0,3.0,4.0],
        [2.0,4.0,6.0,8.0]
    ]))

tensor([[-1.3416, -0.4472,  0.4472,  1.3416],
        [-1.3416, -0.4472,  0.4472,  1.3416]],
       grad_fn=<NativeLayerNormBackward0>)

**Normalize** Now let's apply it to x.  We'll save the output to a new variable `x_norm` so that we retain the input `x` for the *skip connection* (more details below).  Since our position embeddings are still in their initiaion step where each entries values are similarly distributed `[n,n+2,n+4,n+6]`  we'll see that with layer normalization the resulting matrix has all equal rows. 

In [23]:
ln_1 = nn.LayerNorm(n_embd)
x_norm = ln_1(x)
x_norm.shape, x_norm

(torch.Size([2, 8, 4]),
 tensor([[[-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284]],
 
         [[-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284]]],
        grad_fn=<NativeLayerNormBackward0>))

#### Transformer - Flash Attnetion / Causal Self Attention

<img src="explainer_screenshots/gpt/flash_attention.png" width="600">

Causal self-attention is self-attention with a strictly upper-triangular mask, so each position $i$ can attend only to positions $j≤i$ (no look-ahead). For each position in a sequence, the model decides which earlier tokens matter and blends their information to generate the upcoming token, but the model is not allowed to peek at tokens to the right. Recall that we created an input `x` that has all of the tokens in the sequence. With causal self attention we ensure that the prediction of `y` is not learned in an example from tokens to the right of it. 

Causal self attention uses 3 linear projections to create representations of 3 concepts to help derive the next token: query `Q`, key `K`, and value `V`. FlashAttention takes the input `x` of shape $(B,T,C)$, linearly projects the last dimension to `Q, K, V` by multiplying each token vector $x_{b,t,:}$ with the learned weight matrices `Q, K,V` (often via one combined QKV linear matrix) and adding biases. Then we compute the causal dot-product attention with dropout and softmax per head, concatenate the heads, and apply a final linear layer to produce the output.  Linear layers use the projection of $y=xA^\top+b$

The Q,K,V matrix do the following: 
* **Query** - the current position’s “search request.” Every layer and head issues queries that can look for many things (antecedent, rhythm, agreement, etc.).
* **Key** - a matching tag/address for each allowed position. It is compared with the query to produce relevance scores
* **Value** - the payload you actually mix in once something matches. It’s a learned projection of the token’s representation so the model can copy the right kind of information.

This layer starts by taking a linear layer that has `n_embd` rows and then  `n_embd` columns for each: the query, key and value.  This allows the projection of the `x_norm` to be split into the 3 components, selected from them. We then use `scaled_dot_product_attention` to provide the attention masking and normalization, including an upper triangle mask to prevent look-ahead, and final project that onto a `n_embd x n_embd` matrix to return the output. 

We start by creating the 2 default layers. Also we'll set heads to 2.  The heads allow the layer to specialize in different concepts since each head creates its own Q,K,V.

*for the linear layers we'll do some special initiations, mainly pyramidal so that we can see unique numbers. We also initialize the bias to 0 so there linear layer is now just $y=xA^\top$*

In [24]:
from torch.nn import functional as F
import math
n_head = 2

In [25]:
c_attn = nn.Linear(n_embd, 3 * n_embd)
with torch.no_grad():
    out, inp = c_attn.weight.shape  # (3*n_embd, n_embd)
    r = torch.arange(1, out + 1).unsqueeze(1)  # [out,1], 1-indexed
    c = torch.arange(1, inp + 1).unsqueeze(0)   # [1,inp], 1-indexed
    base = r * c                          # rc
    tri = r * (r - 1) / 2                 # T_{r-1} = r(r-1)/2, shape [out,1]
    mask = (c >= 2)           # add T_{r-1} only from column 2 onward
    pattern = 1e-3 * (base + tri * mask) # matches [[.001,.002,.003],[.002,.005,.007],[.003,.009,.012],...]
    c_attn.weight.copy_(pattern)

    c_attn.bias.zero_()
    
c_attn, c_attn.weight, c_attn.bias

(Linear(in_features=4, out_features=12, bias=True),
 Parameter containing:
 tensor([[0.0010, 0.0020, 0.0030, 0.0040],
         [0.0020, 0.0050, 0.0070, 0.0090],
         [0.0030, 0.0090, 0.0120, 0.0150],
         [0.0040, 0.0140, 0.0180, 0.0220],
         [0.0050, 0.0200, 0.0250, 0.0300],
         [0.0060, 0.0270, 0.0330, 0.0390],
         [0.0070, 0.0350, 0.0420, 0.0490],
         [0.0080, 0.0440, 0.0520, 0.0600],
         [0.0090, 0.0540, 0.0630, 0.0720],
         [0.0100, 0.0650, 0.0750, 0.0850],
         [0.0110, 0.0770, 0.0880, 0.0990],
         [0.0120, 0.0900, 0.1020, 0.1140]], requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True))

In [26]:
c_proj = nn.Linear(n_embd, n_embd)
with torch.no_grad():
    out, inp = c_proj.weight.shape  # (3*n_embd, n_embd)
    r = torch.arange(1, out + 1).unsqueeze(1)  # [out,1], 1-indexed
    c = torch.arange(1, inp + 1).unsqueeze(0)   # [1,inp], 1-indexed
    base = r * c                          # rc
    tri = r * (r - 1) / 2                 # T_{r-1} = r(r-1)/2, shape [out,1]
    mask = (c >= 2)           # add T_{r-1} only from column 2 onward
    pattern = 1e-3 * (base + tri * mask) # matches [[.001,.002,.003],[.002,.005,.007],[.003,.009,.012],...]
    c_proj.weight.copy_(pattern)
    
    c_proj.bias.zero_()
c_proj, c_proj.weight

(Linear(in_features=4, out_features=4, bias=True),
 Parameter containing:
 tensor([[0.0010, 0.0020, 0.0030, 0.0040],
         [0.0020, 0.0050, 0.0070, 0.0090],
         [0.0030, 0.0090, 0.0120, 0.0150],
         [0.0040, 0.0140, 0.0180, 0.0220]], requires_grad=True))

**Flash Attention - Creating Query, Key, Value**

We'll now create the query, key, and value matrices. Let's also revisualize `x_norm` to make it easy to connect the dots in this complex layer.

In [27]:
B, T, C = x_norm.size()
B, T, C

(2, 8, 4)

In [28]:
x_norm.shape, x_norm

(torch.Size([2, 8, 4]),
 tensor([[[-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284]],
 
         [[-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284],
          [-1.3284, -0.4428,  0.4428,  1.3284]]],
        grad_fn=<NativeLayerNormBackward0>))

Now we'll take `x_norm` and the dot product of the weights in `c_attn` $y=x\_norm \cdot c\_attn^\top+bias$. This results in the the `qkv` combined matrix leading to a 3x size increase. Since every entry on the 3rd dimension in `x_norm` is the same, we will see that every row has the same repeated values given that's how dot products work. 

In [29]:
qkv = c_attn(x_norm)
qkv.shape, qkv

(torch.Size([2, 8, 12]),
 tensor([[[0.0044, 0.0102, 0.0173, 0.0257, 0.0354, 0.0465, 0.0589, 0.0726,
           0.0877, 0.1041, 0.1218, 0.1408],
          [0.0044, 0.0102, 0.0173, 0.0257, 0.0354, 0.0465, 0.0589, 0.0726,
           0.0877, 0.1041, 0.1218, 0.1408],
          [0.0044, 0.0102, 0.0173, 0.0257, 0.0354, 0.0465, 0.0589, 0.0726,
           0.0877, 0.1041, 0.1218, 0.1408],
          [0.0044, 0.0102, 0.0173, 0.0257, 0.0354, 0.0465, 0.0589, 0.0726,
           0.0877, 0.1041, 0.1218, 0.1408],
          [0.0044, 0.0102, 0.0173, 0.0257, 0.0354, 0.0465, 0.0589, 0.0726,
           0.0877, 0.1041, 0.1218, 0.1408],
          [0.0044, 0.0102, 0.0173, 0.0257, 0.0354, 0.0465, 0.0589, 0.0726,
           0.0877, 0.1041, 0.1218, 0.1408],
          [0.0044, 0.0102, 0.0173, 0.0257, 0.0354, 0.0465, 0.0589, 0.0726,
           0.0877, 0.1041, 0.1218, 0.1408],
          [0.0044, 0.0102, 0.0173, 0.0257, 0.0354, 0.0465, 0.0589, 0.0726,
           0.0877, 0.1041, 0.1218, 0.1408]],
 
         [[0.0044, 0

**Flash Attention - Attention Head**

Now, we will split up qkv to create the 3 separate matrices, one for the query, key, and value. Together they to work to create complex concept embeddings. We also then have to split up each matrix into its own heads (shown as columns `dim=3`). For each head we create its own Q,K,V, compute the causal dot-product attention with softmax per head and then eventually concatenate the heads. During back-propogation, different heads can get updates in different ways based on the examples and gradients.  This backprop and separation is what allows each attention head to specialize in different concepts.

In [30]:
q,k,v = qkv.split(n_embd, dim=2)
q = q.view(B, T, n_head, C // n_head).transpose(1, 2)
k = k.view(B, T, n_head, C // n_head).transpose(1, 2)
v = v.view(B, T, n_head, C // n_head).transpose(1, 2)
'q', q.shape, q, 'k', k.shape, k, 'v', v.shape, v

('q',
 torch.Size([2, 2, 8, 2]),
 tensor([[[[0.0044, 0.0102],
           [0.0044, 0.0102],
           [0.0044, 0.0102],
           [0.0044, 0.0102],
           [0.0044, 0.0102],
           [0.0044, 0.0102],
           [0.0044, 0.0102],
           [0.0044, 0.0102]],
 
          [[0.0173, 0.0257],
           [0.0173, 0.0257],
           [0.0173, 0.0257],
           [0.0173, 0.0257],
           [0.0173, 0.0257],
           [0.0173, 0.0257],
           [0.0173, 0.0257],
           [0.0173, 0.0257]]],
 
 
         [[[0.0044, 0.0102],
           [0.0044, 0.0102],
           [0.0044, 0.0102],
           [0.0044, 0.0102],
           [0.0044, 0.0102],
           [0.0044, 0.0102],
           [0.0044, 0.0102],
           [0.0044, 0.0102]],
 
          [[0.0173, 0.0257],
           [0.0173, 0.0257],
           [0.0173, 0.0257],
           [0.0173, 0.0257],
           [0.0173, 0.0257],
           [0.0173, 0.0257],
           [0.0173, 0.0257],
           [0.0173, 0.0257]]]], grad_fn=<TransposeBackwa

**Flash Attention - Cross Attention / Dot-product attention**

With the separated query, key, value, we now need to calculate the attention. The attention starts with the mutliplying query and key together to get `TxT` matrix showing how the context length embeddings interact, or the *unnormalized attention*. We also multiply by the $\sqrt{C // n\_head}$ since the dot product will grow in relation to the head widht and this scaling allows us to normalize the variance back down. After this we apply the **causal** filter that prevents the model from learning about a token from the future tokens beyond it.  We do this by using a upper triangle maskl, meaning we set all the values in the upper triangle to $-\infty$.  This means that when we got to the softmax, these values are set to 0 meaning there's no weight on the prediction.   After applying our max, we softmax by the last dimension.  Softmax turns a real-valued vector into a probability distribution by exponentiating each component and normalizing by the sum so all outputs are nonnegative and sum to 1.  It does this by applying the following: 

$\frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}$

After softmax, we apply a dropout layer which simply applies 0's out randomly up to the defined level percentage. Dropout can be applied in a few places and mainly acts to prevent co-adaptation and improve generalization. Dropout is only used during training and would not be run during inference. 

Finally to get the weigthed sum of the values by the attention we take the dot product of the raw attention and the values to bring the Query, Key, and Value impact per head back together.  After this we'll then collapse the heads and project the attention one final time to a linear layer. 

Let's start by first initializing our scaling, dropout, our our mask.  You'll see the mask creates an upper triangle matrix. 

*Note that in practice this step is run as `F.scaled_dot_product_attention` to avoid materializing the TxT matrix*

In [31]:
scale = 1.0 / math.sqrt(C // n_head)
dropout = nn.Dropout(0.1)
mask = torch.tril(torch.ones(block_size, block_size, dtype=torch.bool))
scale, dropout, mask

(0.7071067811865475,
 Dropout(p=0.1, inplace=False),
 tensor([[ True, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True]]))

Now we'll compute the unnormalized attention per head.  This takes the matrix multiple of the query and key and uses the scale factor to manage the variance on the last dimension. You'll see that the `TxT` matrix for each example at this point have all the same values because the `Q` and `K` had uniform values in each head.  

In [32]:
att = q @ k.transpose(-2, -1) * scale
att.shape, att

(torch.Size([2, 2, 8, 8]),
 tensor([[[[0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004],
           [0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004],
           [0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004],
           [0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004],
           [0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004],
           [0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004],
           [0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004],
           [0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004]],
 
          [[0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020],
           [0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020],
           [0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020],
           [0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020],
           [0.0020, 0.0020, 0.0020, 0.0020, 0.

Now we have to apply the triangular mask.  This is where you can see causal attention removing the ability of each example looking at tokens beyond it's context to make predictions.

In [33]:
att = att.masked_fill(mask[:T, :T] == 0, float('-inf'))
att.shape, att

(torch.Size([2, 2, 8, 8]),
 tensor([[[[0.0004,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
           [0.0004, 0.0004,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
           [0.0004, 0.0004, 0.0004,   -inf,   -inf,   -inf,   -inf,   -inf],
           [0.0004, 0.0004, 0.0004, 0.0004,   -inf,   -inf,   -inf,   -inf],
           [0.0004, 0.0004, 0.0004, 0.0004, 0.0004,   -inf,   -inf,   -inf],
           [0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004,   -inf,   -inf],
           [0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004,   -inf],
           [0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004]],
 
          [[0.0020,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
           [0.0020, 0.0020,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
           [0.0020, 0.0020, 0.0020,   -inf,   -inf,   -inf,   -inf,   -inf],
           [0.0020, 0.0020, 0.0020, 0.0020,   -inf,   -inf,   -inf,   -inf],
           [0.0020, 0.0020, 0.0020, 0.0020, 0.

Now we'll apply the softmax. Recall that softmax normalizes the probability per row by running the following: 

$\frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}$

Because of this we can expect that the $-\infty$ will become 0 due to their extreme weight, showing the power of the mask.  Also, now you'll see that each row changes in value since you have `n+1` elements in each row to distribute across. 

In [34]:
att = F.softmax(att, dim=-1)
att

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
          [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
          [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
          

After the softmax, we'll add dropout. This dropout will randomly zero out any value effectively removing that specific node from impacting prediction. To balance out the dropped out values, the surviving entries are scaled by $1/(1-p)$. During training this helps with generalizaiton and fights fixation. You'll need to look closely to identify the droped-out values. As a reminder, during inference we would skip dropout. 

In [35]:
att = dropout(att)
att.shape, att

(torch.Size([2, 2, 8, 8]),
 tensor([[[[1.1111, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
           [0.5556, 0.5556, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
           [0.3704, 0.3704, 0.3704, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
           [0.2778, 0.0000, 0.2778, 0.2778, 0.0000, 0.0000, 0.0000, 0.0000],
           [0.2222, 0.0000, 0.2222, 0.2222, 0.2222, 0.0000, 0.0000, 0.0000],
           [0.1852, 0.1852, 0.1852, 0.1852, 0.1852, 0.1852, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.1587, 0.1587, 0.1587, 0.1587, 0.1587, 0.0000],
           [0.1389, 0.1389, 0.1389, 0.1389, 0.1389, 0.1389, 0.1389, 0.1389]],
 
          [[1.1111, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
           [0.5556, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
           [0.3704, 0.3704, 0.3704, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
           [0.2778, 0.2778, 0.2778, 0.2778, 0.0000, 0.0000, 0.0000, 0.0000],
           [0.0000, 0.2222, 0.2222, 0.2222, 0.

Finally we now take the raw attention and use it to weight the values.  We'll take the matrix multiple of the attention with the values to create the flash-attention output. While we would expect the values for each exampl'es head to be the same given the input consistency of the query, key, and value, you can see that the dropout created enough noise to shift certain entries away from consistency. 

In [36]:
fa = att @ v
fa.shape, fa

(torch.Size([2, 2, 8, 2]),
 tensor([[[[0.0974, 0.1156],
           [0.0974, 0.1156],
           [0.0974, 0.1156],
           [0.0731, 0.0867],
           [0.0779, 0.0925],
           [0.0974, 0.1156],
           [0.0696, 0.0826],
           [0.0974, 0.1156]],
 
          [[0.1353, 0.1565],
           [0.0677, 0.0782],
           [0.1353, 0.1565],
           [0.1353, 0.1565],
           [0.1082, 0.1252],
           [0.1353, 0.1565],
           [0.1353, 0.1565],
           [0.1353, 0.1565]]],
 
 
         [[[0.0974, 0.1156],
           [0.0974, 0.1156],
           [0.0974, 0.1156],
           [0.0731, 0.0867],
           [0.0779, 0.0925],
           [0.0974, 0.1156],
           [0.0974, 0.1156],
           [0.0974, 0.1156]],
 
          [[0.0000, 0.0000],
           [0.1353, 0.1565],
           [0.1353, 0.1565],
           [0.0677, 0.0782],
           [0.1353, 0.1565],
           [0.1128, 0.1304],
           [0.0966, 0.1118],
           [0.1184, 0.1369]]]], grad_fn=<UnsafeViewBackward0>)

**Flash Attention - collapse heads**  Now we'll reshape back to remove the heads

*Note we use `contiguous` here to force `transpose()` to create a new matrix in memory. This allows the heads to learn independently*

In [37]:
fa = fa.transpose(1, 2).contiguous().view(B, T, C)
fa.shape, fa

(torch.Size([2, 8, 4]),
 tensor([[[0.0974, 0.1156, 0.1353, 0.1565],
          [0.0974, 0.1156, 0.0677, 0.0782],
          [0.0974, 0.1156, 0.1353, 0.1565],
          [0.0731, 0.0867, 0.1353, 0.1565],
          [0.0779, 0.0925, 0.1082, 0.1252],
          [0.0974, 0.1156, 0.1353, 0.1565],
          [0.0696, 0.0826, 0.1353, 0.1565],
          [0.0974, 0.1156, 0.1353, 0.1565]],
 
         [[0.0974, 0.1156, 0.0000, 0.0000],
          [0.0974, 0.1156, 0.1353, 0.1565],
          [0.0974, 0.1156, 0.1353, 0.1565],
          [0.0731, 0.0867, 0.0677, 0.0782],
          [0.0779, 0.0925, 0.1353, 0.1565],
          [0.0974, 0.1156, 0.1128, 0.1304],
          [0.0974, 0.1156, 0.0966, 0.1118],
          [0.0974, 0.1156, 0.1184, 0.1369]]], grad_fn=<ViewBackward0>))

**Flash Attention - final projection and output**
Finally, we will now project the cross attention matrix on another final linear layer `n_embd x n_embd` that we initialized as `c_proj`.  Once agian, because our rows are identical (where not impacted by dropout) and we use `c_attn` $y=x\_norm \cdot c\_proj^\top + b$ we will result in the same values, except for where dropout has impacted the values. Note that currently bias is 0 so it is not impacting the values. 

In [38]:
x_norm = c_proj(fa)
x_norm.shape, x_norm

(torch.Size([2, 8, 4]),
 tensor([[[0.0014, 0.0031, 0.0053, 0.0079],
          [0.0008, 0.0020, 0.0033, 0.0049],
          [0.0014, 0.0031, 0.0053, 0.0079],
          [0.0013, 0.0029, 0.0050, 0.0074],
          [0.0011, 0.0025, 0.0042, 0.0063],
          [0.0014, 0.0031, 0.0053, 0.0079],
          [0.0013, 0.0029, 0.0049, 0.0073],
          [0.0014, 0.0031, 0.0053, 0.0079]],
 
         [[0.0003, 0.0008, 0.0013, 0.0020],
          [0.0014, 0.0031, 0.0053, 0.0079],
          [0.0014, 0.0031, 0.0053, 0.0079],
          [0.0008, 0.0018, 0.0030, 0.0044],
          [0.0013, 0.0030, 0.0050, 0.0075],
          [0.0012, 0.0027, 0.0046, 0.0069],
          [0.0011, 0.0025, 0.0042, 0.0062],
          [0.0012, 0.0028, 0.0048, 0.0072]]], grad_fn=<ViewBackward0>))

#### Transformer - Residual (skip) connection

<img src="explainer_screenshots/gpt/skip_layer.png" width="200">
Modern networks also use skip connections, meaning they allow for pathways to bypass around "boxes", passing through gradients during the backward pass.  This attribute ensures that the impact of each layer and head is normalized against the input embeddings themselves. Recall in the diagram we had the arrow that bypassed "masked multiheaded attention".  Functionally this is represented as

$y = f(x) + x$

To achieve this we simply sum the projection matrix `x` with the flash attention output `x_norm`.  As a reminder we'll print out X.  As you can see, because `x` was based on the tokens, it has a different value per row, so even though `x_norm` has mostly the same value per row, we'll result in a diverse set of weights.  With this you can see the power of skip connections passing through weights.

In [39]:
x

tensor([[[0.3700, 0.3900, 0.4100, 0.4300],
         [0.1800, 0.2000, 0.2200, 0.2400],
         [0.3600, 0.3800, 0.4000, 0.4200],
         [0.1400, 0.1600, 0.1800, 0.2000],
         [0.1100, 0.1300, 0.1500, 0.1700],
         [0.2700, 0.2900, 0.3100, 0.3300],
         [0.3800, 0.4000, 0.4200, 0.4400],
         [0.2400, 0.2600, 0.2800, 0.3000]],

        [[0.1300, 0.1500, 0.1700, 0.1900],
         [0.1200, 0.1400, 0.1600, 0.1800],
         [0.1000, 0.1200, 0.1400, 0.1600],
         [0.2500, 0.2700, 0.2900, 0.3100],
         [0.1100, 0.1300, 0.1500, 0.1700],
         [0.0700, 0.0900, 0.1100, 0.1300],
         [0.2100, 0.2300, 0.2500, 0.2700],
         [0.3000, 0.3200, 0.3400, 0.3600]]], grad_fn=<AddBackward0>)

In [40]:
x = x + x_norm
x.shape, x

(torch.Size([2, 8, 4]),
 tensor([[[0.3714, 0.3931, 0.4153, 0.4379],
          [0.1808, 0.2020, 0.2233, 0.2449],
          [0.3614, 0.3831, 0.4053, 0.4279],
          [0.1413, 0.1629, 0.1850, 0.2074],
          [0.1111, 0.1325, 0.1542, 0.1763],
          [0.2714, 0.2931, 0.3153, 0.3379],
          [0.3813, 0.4029, 0.4249, 0.4473],
          [0.2414, 0.2631, 0.2853, 0.3079]],
 
         [[0.1303, 0.1508, 0.1713, 0.1920],
          [0.1214, 0.1431, 0.1653, 0.1879],
          [0.1014, 0.1231, 0.1453, 0.1679],
          [0.2508, 0.2718, 0.2930, 0.3144],
          [0.1113, 0.1330, 0.1550, 0.1775],
          [0.0712, 0.0927, 0.1146, 0.1369],
          [0.2111, 0.2325, 0.2542, 0.2762],
          [0.3012, 0.3228, 0.3448, 0.3672]]], grad_fn=<AddBackward0>))

#### Transformer - Layer Normalization 2

We'll run another round of normalization now on the outputs of the masked multi-head attention and skip connection to ensure our values are not too spread apart. This layer will run the same normalization formula as before, but is it's own independent layer as it has different inputs. Recall the formula is: 

$y = \frac{x - \mathbb{E}[x]}{\sqrt{\operatorname{Var}[x] + \epsilon}}\$

Because we will do a skip connection again for the next layer, MLP, we'll once again branch `x` for the normalization and MLP and then sum it back togehter with `x`. 

While before normalization brought the values back to the same identical values, now you'll see ever so slight variance.  This is the impact of dropout starting to show. 

In [41]:
ln_2 = nn.LayerNorm(n_embd)

In [42]:
x_norm_2 = ln_2(x)
x_norm_2.shape, x_norm_2

(torch.Size([2, 8, 4]),
 tensor([[[-1.3227, -0.4518,  0.4355,  1.3390],
          [-1.3246, -0.4488,  0.4379,  1.3354],
          [-1.3227, -0.4518,  0.4355,  1.3390],
          [-1.3231, -0.4512,  0.4359,  1.3383],
          [-1.3238, -0.4501,  0.4368,  1.3370],
          [-1.3227, -0.4518,  0.4355,  1.3390],
          [-1.3231, -0.4511,  0.4360,  1.3382],
          [-1.3227, -0.4518,  0.4355,  1.3390]],
 
         [[-1.3266, -0.4455,  0.4405,  1.3316],
          [-1.3227, -0.4518,  0.4355,  1.3390],
          [-1.3227, -0.4518,  0.4355,  1.3390],
          [-1.3250, -0.4481,  0.4384,  1.3347],
          [-1.3230, -0.4513,  0.4358,  1.3384],
          [-1.3233, -0.4508,  0.4363,  1.3378],
          [-1.3238, -0.4501,  0.4368,  1.3370],
          [-1.3231, -0.4510,  0.4361,  1.3381]]],
        grad_fn=<NativeLayerNormBackward0>))

#### Transformer - Feed Forward (aka Multi-layer Perceptron)

<img src="explainer_screenshots/gpt/mlp.png" width="400">

The feed-forward sublayer consists of two-layer mirroring a multi-layer perceptron, MLP.  These layers mix the features within each token vector but never across time. The output `x` from multi-headed attention starts as `B x T x C`.  Feed forward 
1. Calculates `4C` using a $XA^\top + B$ linear layer 
2. Normalizes the data using a `tanh` based GELU layer. This layer pushes extreme values to +/- 1
3. Projects back down to `C` with a final $XW^\top + B$ linear layer.

The MLP nonlinearly re-expresses each token's channel features before being aggregated across the hidden layers and passed to the output layer. 

We'll first start by creating the 3 different layers:
1. `mlp_fc` - Linear layer to project up to `4C`
2. `mlp_gelu` - gelu approximation using tanh layer
3. `mlp_proj` - Linear layer to project down to `C`

We'll do similar initiation to our linear layers as before. Since the tanh step is a calculation per row (similar to layer normalization), we will not do initiation for it.

In [43]:
mlp_fc = nn.Linear(n_embd, 4 * n_embd)

with torch.no_grad():
    out, inp = mlp_fc.weight.shape  # (3*n_embd, n_embd)
    r = torch.arange(1, out + 1).unsqueeze(1)  # [out,1], 1-indexed
    c = torch.arange(1, inp + 1).unsqueeze(0)   # [1,inp], 1-indexed
    base = r * c                          # rc
    tri = r * (r - 1) / 2                 # T_{r-1} = r(r-1)/2, shape [out,1]
    mask = (c >= 2)           # add T_{r-1} only from column 2 onward
    pattern = 1e-3 * (base + tri * mask) # matches [[.001,.002,.003],[.002,.005,.007],[.003,.009,.012],...]
    mlp_fc.weight.copy_(pattern)

mlp_fc.weight.shape, mlp_fc.weight

(torch.Size([16, 4]),
 Parameter containing:
 tensor([[0.0010, 0.0020, 0.0030, 0.0040],
         [0.0020, 0.0050, 0.0070, 0.0090],
         [0.0030, 0.0090, 0.0120, 0.0150],
         [0.0040, 0.0140, 0.0180, 0.0220],
         [0.0050, 0.0200, 0.0250, 0.0300],
         [0.0060, 0.0270, 0.0330, 0.0390],
         [0.0070, 0.0350, 0.0420, 0.0490],
         [0.0080, 0.0440, 0.0520, 0.0600],
         [0.0090, 0.0540, 0.0630, 0.0720],
         [0.0100, 0.0650, 0.0750, 0.0850],
         [0.0110, 0.0770, 0.0880, 0.0990],
         [0.0120, 0.0900, 0.1020, 0.1140],
         [0.0130, 0.1040, 0.1170, 0.1300],
         [0.0140, 0.1190, 0.1330, 0.1470],
         [0.0150, 0.1350, 0.1500, 0.1650],
         [0.0160, 0.1520, 0.1680, 0.1840]], requires_grad=True))

In [44]:
mlp_gelu = nn.GELU(approximate='tanh')
mlp_gelu

GELU(approximate='tanh')

In [45]:
mlp_proj = nn.Linear(4 * n_embd, n_embd)

with torch.no_grad():
    out, inp = mlp_proj.weight.shape  # (3*n_embd, n_embd)
    r = torch.arange(1, out + 1).unsqueeze(1)  # [out,1], 1-indexed
    c = torch.arange(1, inp + 1).unsqueeze(0)   # [1,inp], 1-indexed
    base = r * c                          # rc
    tri = r * (r - 1) / 2                 # T_{r-1} = r(r-1)/2, shape [out,1]
    mask = (c >= 2)           # add T_{r-1} only from column 2 onward
    pattern = 1e-3 * (base + tri * mask) # matches [[.001,.002,.003],[.002,.005,.007],[.003,.009,.012],...]
    mlp_proj.weight.copy_(pattern)

mlp_proj.weight.shape, mlp_proj.weight

(torch.Size([4, 16]),
 Parameter containing:
 tensor([[0.0010, 0.0020, 0.0030, 0.0040, 0.0050, 0.0060, 0.0070, 0.0080, 0.0090,
          0.0100, 0.0110, 0.0120, 0.0130, 0.0140, 0.0150, 0.0160],
         [0.0020, 0.0050, 0.0070, 0.0090, 0.0110, 0.0130, 0.0150, 0.0170, 0.0190,
          0.0210, 0.0230, 0.0250, 0.0270, 0.0290, 0.0310, 0.0330],
         [0.0030, 0.0090, 0.0120, 0.0150, 0.0180, 0.0210, 0.0240, 0.0270, 0.0300,
          0.0330, 0.0360, 0.0390, 0.0420, 0.0450, 0.0480, 0.0510],
         [0.0040, 0.0140, 0.0180, 0.0220, 0.0260, 0.0300, 0.0340, 0.0380, 0.0420,
          0.0460, 0.0500, 0.0540, 0.0580, 0.0620, 0.0660, 0.0700]],
        requires_grad=True))

**MLP - 4c projection**  - Now we'll take `x_norm_2` and apply the first linear layer that projects upward to `4C`.  Reminder that the  calculation is $x\_mlp = x\_norm\_2 \cdot mlp\_fc^\top + b$. Once again we'll see a repetition of many values in `x_mlp` as the rows in `x_norm_2` are close to the same.

In [46]:
x_mlp = mlp_fc(x_norm_2)
x_mlp.shape, x_mlp

(torch.Size([2, 8, 16]),
 tensor([[[ 0.1248,  0.3857,  0.3960, -0.2152, -0.0158,  0.0739,  0.1779,
            0.2909,  0.5819, -0.1378, -0.0966,  0.4029,  0.1428,  0.0509,
            0.2433,  0.1538],
          [ 0.1248,  0.3857,  0.3960, -0.2152, -0.0158,  0.0739,  0.1779,
            0.2909,  0.5819, -0.1377, -0.0965,  0.4030,  0.1429,  0.0511,
            0.2434,  0.1540],
          [ 0.1248,  0.3857,  0.3960, -0.2152, -0.0158,  0.0739,  0.1779,
            0.2909,  0.5819, -0.1378, -0.0966,  0.4029,  0.1428,  0.0509,
            0.2433,  0.1538],
          [ 0.1248,  0.3857,  0.3960, -0.2152, -0.0158,  0.0739,  0.1779,
            0.2909,  0.5819, -0.1378, -0.0966,  0.4029,  0.1429,  0.0510,
            0.2433,  0.1538],
          [ 0.1248,  0.3857,  0.3960, -0.2152, -0.0158,  0.0739,  0.1779,
            0.2909,  0.5819, -0.1378, -0.0966,  0.4030,  0.1429,  0.0510,
            0.2434,  0.1539],
          [ 0.1248,  0.3857,  0.3960, -0.2152, -0.0158,  0.0739,  0.1779,
           

**MLP - tanh**  - Now we'll apply the tanh approximation (GELU - tanh) which smoothly gates each input `x_mlp`. The formula applied is 

$\tanh(x)=\frac{\exp(x)-\exp(-x)}{\exp(x)+\exp(-x)}$

The Tanh formula pushes large positive numbers to 1 and large negative numbers to -1.  `tanh` is applied element wise across the full `x_mlp`.  This will further smooth out the differences we've seen given how minor they are. 

In [47]:
x_mlp = mlp_gelu(x_mlp)
x_mlp.shape, x_mlp

(torch.Size([2, 8, 16]),
 tensor([[[ 0.0686,  0.2507,  0.2590, -0.0893, -0.0078,  0.0391,  0.1015,
            0.1787,  0.4187, -0.0613, -0.0446,  0.2645,  0.0795,  0.0265,
            0.1450,  0.0863],
          [ 0.0686,  0.2507,  0.2590, -0.0893, -0.0078,  0.0391,  0.1015,
            0.1788,  0.4188, -0.0613, -0.0446,  0.2646,  0.0796,  0.0266,
            0.1451,  0.0864],
          [ 0.0686,  0.2507,  0.2590, -0.0893, -0.0078,  0.0391,  0.1015,
            0.1787,  0.4187, -0.0613, -0.0446,  0.2645,  0.0795,  0.0265,
            0.1450,  0.0863],
          [ 0.0686,  0.2507,  0.2590, -0.0893, -0.0078,  0.0391,  0.1015,
            0.1788,  0.4187, -0.0613, -0.0446,  0.2645,  0.0795,  0.0265,
            0.1450,  0.0863],
          [ 0.0686,  0.2507,  0.2590, -0.0893, -0.0078,  0.0391,  0.1015,
            0.1788,  0.4188, -0.0613, -0.0446,  0.2645,  0.0796,  0.0265,
            0.1451,  0.0863],
          [ 0.0686,  0.2507,  0.2590, -0.0893, -0.0078,  0.0391,  0.1015,
           

**MLP - down projection**  - Finally we'll take `x_mlp` and project `4C` back down to `C` using the weights in `mlp_proj`. Recall that we apply $x\_mlp = x\_mlp\_2 \cdot mlp\_proj^\top$.  Even though the layer weights are `4x16` the transpose in $XW^\top + B$ allows to project back down. We'll recover a lot of the row repetion though this process as we get further from the dropout.

In [48]:
x_mlp = mlp_proj(x_mlp)
x_mlp.shape, x_mlp

(torch.Size([2, 8, 4]),
 tensor([[[0.1032, 0.1714, 0.1054, 0.0140],
          [0.1032, 0.1715, 0.1055, 0.0140],
          [0.1032, 0.1714, 0.1054, 0.0140],
          [0.1032, 0.1714, 0.1054, 0.0140],
          [0.1032, 0.1714, 0.1054, 0.0140],
          [0.1032, 0.1714, 0.1054, 0.0140],
          [0.1032, 0.1714, 0.1054, 0.0140],
          [0.1032, 0.1714, 0.1054, 0.0140]],
 
         [[0.1032, 0.1715, 0.1055, 0.0141],
          [0.1032, 0.1714, 0.1054, 0.0140],
          [0.1032, 0.1714, 0.1054, 0.0140],
          [0.1032, 0.1715, 0.1055, 0.0141],
          [0.1032, 0.1714, 0.1054, 0.0140],
          [0.1032, 0.1714, 0.1054, 0.0140],
          [0.1032, 0.1714, 0.1054, 0.0140],
          [0.1032, 0.1714, 0.1054, 0.0140]]], grad_fn=<ViewBackward0>))

#### Transformer - Residual (skip) connection 2

Once again our transformer uses a skip connection to allow for passing gradients through the Feed Forward, aka MLP, layer. Just like the first skip connection, functionally this is represented as

$y = f(x) + x$

To achieve this we simply sum the MLP input matrix `x` with the MLP output `x_mlp`.  As a reminder we'll print out `x`.  As you can see, because `x` was based on the tokens, it has a different value per row, so even though `x_mlp` has the same value per row, we'll result in a diverse set of weights.  With this you can see the power of skip connections passing through weights.

In [49]:
x

tensor([[[0.3714, 0.3931, 0.4153, 0.4379],
         [0.1808, 0.2020, 0.2233, 0.2449],
         [0.3614, 0.3831, 0.4053, 0.4279],
         [0.1413, 0.1629, 0.1850, 0.2074],
         [0.1111, 0.1325, 0.1542, 0.1763],
         [0.2714, 0.2931, 0.3153, 0.3379],
         [0.3813, 0.4029, 0.4249, 0.4473],
         [0.2414, 0.2631, 0.2853, 0.3079]],

        [[0.1303, 0.1508, 0.1713, 0.1920],
         [0.1214, 0.1431, 0.1653, 0.1879],
         [0.1014, 0.1231, 0.1453, 0.1679],
         [0.2508, 0.2718, 0.2930, 0.3144],
         [0.1113, 0.1330, 0.1550, 0.1775],
         [0.0712, 0.0927, 0.1146, 0.1369],
         [0.2111, 0.2325, 0.2542, 0.2762],
         [0.3012, 0.3228, 0.3448, 0.3672]]], grad_fn=<AddBackward0>)

In [50]:
x = x_mlp + x

x.shape, x

(torch.Size([2, 8, 4]),
 tensor([[[0.4746, 0.5646, 0.5207, 0.4519],
          [0.2841, 0.3734, 0.3288, 0.2590],
          [0.4646, 0.5546, 0.5107, 0.4419],
          [0.2445, 0.3344, 0.2904, 0.2214],
          [0.2143, 0.3040, 0.2597, 0.1903],
          [0.3746, 0.4646, 0.4207, 0.3519],
          [0.4845, 0.5744, 0.5304, 0.4613],
          [0.3446, 0.4346, 0.3907, 0.3219]],
 
         [[0.2336, 0.3222, 0.2768, 0.2061],
          [0.2246, 0.3146, 0.2707, 0.2019],
          [0.2046, 0.2946, 0.2507, 0.1819],
          [0.3540, 0.4432, 0.3984, 0.3285],
          [0.2145, 0.3044, 0.2605, 0.1915],
          [0.1744, 0.2642, 0.2201, 0.1509],
          [0.3143, 0.4039, 0.3596, 0.2902],
          [0.4045, 0.4943, 0.4502, 0.3812]]], grad_fn=<AddBackward0>))

#### Transformer - Final Layer Normalization 

The final step in the transformer is to aggregate and normalize before calculating the final projections. This layer is similar to the previous normalization layers. This layer will run the same normalization formula as before, but is its own independent layer as it has different inputs. Recall the formula is: 

$y = \frac{x - \mathbb{E}[x]}{\sqrt{\operatorname{Var}[x] + \epsilon}}$

Since this is the final layer, we will not have a residual connection so we do not need to branch `x`.  As with the previous normalization layers, we'll once again see the rows diverge in value as the dropout impact shown in `x` comes back.  

In [51]:
ln_f = nn.LayerNorm(n_embd)
ln_f

LayerNorm((4,), eps=1e-05, elementwise_affine=True)

In [52]:
x = ln_f(x)
x.shape, x

(torch.Size([2, 8, 4]),
 tensor([[[-0.6522,  1.4170,  0.4091, -1.1739],
          [-0.6216,  1.4171,  0.3985, -1.1941],
          [-0.6522,  1.4170,  0.4091, -1.1739],
          [-0.6469,  1.4171,  0.4073, -1.1775],
          [-0.6357,  1.4171,  0.4034, -1.1849],
          [-0.6522,  1.4170,  0.4091, -1.1739],
          [-0.6462,  1.4171,  0.4070, -1.1780],
          [-0.6522,  1.4170,  0.4091, -1.1739]],
 
         [[-0.5912,  1.4167,  0.3880, -1.2134],
          [-0.6522,  1.4170,  0.4091, -1.1739],
          [-0.6522,  1.4170,  0.4091, -1.1739],
          [-0.6163,  1.4170,  0.3968, -1.1975],
          [-0.6480,  1.4171,  0.4077, -1.1768],
          [-0.6420,  1.4171,  0.4056, -1.1807],
          [-0.6347,  1.4171,  0.4031, -1.1855],
          [-0.6445,  1.4171,  0.4065, -1.1790]]],
        grad_fn=<NativeLayerNormBackward0>))

### Output Layers AKA Model Head.

The combination of masked multi-head attention and feed forward, along with the normalization and residual connections is considered the "transformer".  In practice this layer is horizontally scaled to run many layers in parallel.  Once those layers are complete during the forward pass we then start the output process that results in `logits` which is a representation of the probability of each token being the next token given the input.  

This layer is also known as the model **head**, not to be confused with attention heads. This layer is called this because it is a small, task-specific module attached to a model’s shared backbone that maps hidden features to the final outputs.  In our example case, this is a linear layer mapping the backbone to vocab logits. The benefit of this structure is that you can use the shared hidden features and train different heads for different tasks without starting from scratch. An example would be a classifier head, or policy head in RL.

<img src="explainer_screenshots/gpt/output_layer.png" width="200">

For our head we want to map to a predicted token which we'll look at as `logits`. In the process to generate `logits` we take the normalized output `x` of the transformers, then project, using a linear layer, to the vocabulary resulting in a `B, T, vocab_size` matrix known as `logits`.  

In training, the `logits` are then compared with `y` to see how close the  model is to predicting the correct next token. For inference, the `logits` are then used to drive sampling which is how the next token is then derived. 


Instead of initializing weights this time around, we'll do **Weight Tying**.  Weight tying sets the output softmax matrix equal to the transpose of the input embedding matrix $W_{\text{out}} = W_e^\top$, forcing the model to “read” and “predict” in the same token space. This reduces parameters and acts as a useful prior, improving sample-efficiency and often perplexity by aligning input–output geometry. Modern LLMs have seemed to ditch this though to gain the extra capacity, but, for our example, we'll maintain it.





In [53]:
lm_head = nn.Linear(n_embd, vocab_size, bias=False)
# weight sharing scheme
lm_head.weight = wte.weight

lm_head, lm_head.weight


(Linear(in_features=4, out_features=36, bias=False),
 Parameter containing:
 tensor([[0.0100, 0.0200, 0.0300, 0.0400],
         [0.0200, 0.0300, 0.0400, 0.0500],
         [0.0300, 0.0400, 0.0500, 0.0600],
         [0.0400, 0.0500, 0.0600, 0.0700],
         [0.0500, 0.0600, 0.0700, 0.0800],
         [0.0600, 0.0700, 0.0800, 0.0900],
         [0.0700, 0.0800, 0.0900, 0.1000],
         [0.0800, 0.0900, 0.1000, 0.1100],
         [0.0900, 0.1000, 0.1100, 0.1200],
         [0.1000, 0.1100, 0.1200, 0.1300],
         [0.1100, 0.1200, 0.1300, 0.1400],
         [0.1200, 0.1300, 0.1400, 0.1500],
         [0.1300, 0.1400, 0.1500, 0.1600],
         [0.1400, 0.1500, 0.1600, 0.1700],
         [0.1500, 0.1600, 0.1700, 0.1800],
         [0.1600, 0.1700, 0.1800, 0.1900],
         [0.1700, 0.1800, 0.1900, 0.2000],
         [0.1800, 0.1900, 0.2000, 0.2100],
         [0.1900, 0.2000, 0.2100, 0.2200],
         [0.2000, 0.2100, 0.2200, 0.2300],
         [0.2100, 0.2200, 0.2300, 0.2400],
         [0.2200, 0.2

now let's check that the values are the same and that the underlying objects `data_ptr()` are

In [54]:
lm_head.weight is wte.weight, lm_head.weight.data_ptr() == wte.weight.data_ptr()

(True, True)

#### Output layer - LM Head aka logits
We now project `x` onto the vocabulary resulting in a `B X T X vocab_size` final array `logits`.  This output correlates with the 
probabilty of each output token given the input context.  The best way to read  this is:

(dimension 0) we have 2 batches B, 
(dimension 1) each batch has an example for each value between 1 and context length T 
(dimension 2) for each example we see the probability of each token in our vocabulary

Since our `x` at this point has similar values across the row, and we are using weight tying between the head and inputlayer, we fully expect that our logits will have similar values.  In practice this means that our model will have close to the same probability of a token output as the 'next token' regardless of the preceeding text, meaning it's shit. Luckily backpropogation has a way of updating this so that with enough data and time the probabilities change. 

In [55]:
x

tensor([[[-0.6522,  1.4170,  0.4091, -1.1739],
         [-0.6216,  1.4171,  0.3985, -1.1941],
         [-0.6522,  1.4170,  0.4091, -1.1739],
         [-0.6469,  1.4171,  0.4073, -1.1775],
         [-0.6357,  1.4171,  0.4034, -1.1849],
         [-0.6522,  1.4170,  0.4091, -1.1739],
         [-0.6462,  1.4171,  0.4070, -1.1780],
         [-0.6522,  1.4170,  0.4091, -1.1739]],

        [[-0.5912,  1.4167,  0.3880, -1.2134],
         [-0.6522,  1.4170,  0.4091, -1.1739],
         [-0.6522,  1.4170,  0.4091, -1.1739],
         [-0.6163,  1.4170,  0.3968, -1.1975],
         [-0.6480,  1.4171,  0.4077, -1.1768],
         [-0.6420,  1.4171,  0.4056, -1.1807],
         [-0.6347,  1.4171,  0.4031, -1.1855],
         [-0.6445,  1.4171,  0.4065, -1.1790]]],
       grad_fn=<NativeLayerNormBackward0>)

In [56]:
logits = lm_head(x)

logits.shape, logits

(torch.Size([2, 8, 36]),
 tensor([[[-0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129,
           -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129,
           -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129,
           -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129,
           -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129,
           -0.0129],
          [-0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137,
           -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137,
           -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137,
           -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137,
           -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137,
           -0.0137],
          [-0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129,
           -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129,
           -0.0129, -0.0129, -0.0129, -0.0129

## Loss calculation
Now we have to see how good our ~shit~ prediction is.  Since we haven't done training, and we saw that regardless of example we had the same exact logit values, we can expect it's bad. That said, we need to know how bad. For this example we'll use cross entropy, also known as the negative log likelihood of the softmax.  Our loss calculates

$$
\ell_i=-\log\big(\mathrm{softmax}(z_i)\_{y_i}\big)
= -z_{i,y_i}+\log\!\sum_{c=1}^C e^{z_{i,c}},
$$


To calculate loss we'll pass in the calculated `logits` and our next tokens stored in `y`. The cross entropy function does not respect batches so we'll flatten the `B` dimension for both `logits` and `y`

In [57]:
y_flat = y.view(-1)
y_flat.shape, y_flat

(torch.Size([16]),
 tensor([15, 32,  9,  5, 20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0]))

In [58]:
logits_flat = logits.view(-1, logits.size(-1))
logits_flat.shape, logits_flat

(torch.Size([16, 36]),
 tensor([[-0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129,
          -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129,
          -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129,
          -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129,
          -0.0129, -0.0129, -0.0129, -0.0129],
         [-0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137,
          -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137,
          -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137,
          -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137, -0.0137,
          -0.0137, -0.0137, -0.0137, -0.0137],
         [-0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129,
          -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129,
          -0.0129, -0.0129, -0.0129, -0.0129, -0.0129, -0.0129,

In [59]:
loss = F.cross_entropy(logits_flat, y_flat)
loss.shape, loss

(torch.Size([]), tensor(3.5835, grad_fn=<NllLossBackward0>))

## Back Propagation

We now know just how well the current model, with its weights and biases, does predicting the next token given the input context. We now need to know how to change the different weights and biases to improve the formula.  We could do this by guessing through making minor changes and seeing what improves, or we can think through this more critically.

If you review the chain of layers above, you can see that it's a series of formulas.  We can think of this as $f(g(x))$, except with many many more layers and complexities.  Since this is a formula, we can dig into our math toolbox and find a better way to determine what parts need to update.  Recall that in our calculus we learned that differentiation tells us the rate of change in a graph.  So if we treat the loss function $\mathcal{L}$ as $\mathcal{L}(f(g(x)))$ taking the partial differential 

$\delta=\partial \mathcal{L}/\partial h$

at each layer will give us the impact of each weight/bias on our final output (albeit the inverse since our loss function is the negative log likelihood). 

Lucky for us, each layer of our model already has a placeholder for the partial differential called the **Gradient**. We'll use this field to store it.  We'll start by first zeroing out the gradients. We do this because of the nature of handling partial differentials for multiple dependencies. Recall that in multiple places we had a formula structure of 

$a+b=c ; a+c= d$

In this case $a$ has 2 dependencies and determining the partial derivative of $\partial d / \partial a$ requires understanding both the path from $d$ and $c$.  To determine the true impact of a we would sum both partial derivatives together.  Because of this property, the tool we use, the built in `.backwards()` automatically sums gradients, `+=`, so if we do not set the gradient to `0` we then end up with erroneous gradients. 

Finally, we start `.backwards` from the `loss`, not `logits` as our goal is to minimize loss, we need to ensure we are looking at the calculations that impact loss which requires the whole forward pass to be able to generate the prediction `logits_flat`.  If we think of it as $\mathcal{L}(f(x))$ where $f(x)$ is the forward pass to generate logits, then a simple chain rule is applied:

${\partial}/{\partial x} =  \mathcal{L}'(f(x)) f'(x)$

Lets start by zeroing the gradients and leaning on pytorch to calculate the gradients for us. We'll also validate the gradients were `none`.

In [60]:
lm_head.zero_grad()
ln_f.zero_grad()
mlp_proj.zero_grad()
mlp_fc.zero_grad()
ln_2.zero_grad()
c_proj.zero_grad()
c_attn.zero_grad()
ln_1.zero_grad()
wpe.zero_grad()
wte.zero_grad()

# validate gradients
lm_head.weight.grad, ln_f.weight.grad

(None, None)

In [61]:
loss.backward()

**Auto-Diff** - Now let's see the magic of the gradients populate.  This magic is called auto-differentiation, or auto-diff for short. This allows us to not have to write many layers of nasty code to do the differentiation for us, but, if you're a sadist, you can surely find people who have written out that code (it's not too bad since you just do one layer at a time). 

In [62]:
lm_head.weight.grad, ln_f.weight.grad

(tensor([[ 0.0630, -0.1378, -0.0396,  0.1144],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [ 0.0612, -0.1378, -0.0390,  0.1156],
         [ 0.0230, -0.0492, -0.0143,  0.0405],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [ 0.0599, -0.1377, -0.0386,  0.1164],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [ 0.0230, -0.0492, -0.0143,  0.0405],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [ 0.0223, -0.0492, -0.0141,  0.0410],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [ 0.0634, -0.1378, -0.0398,  0.1142],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [ 0.0627, -0.1378, -0.0395,  0.1146],
         [ 0.

## Learning
The process of learning now requires us to update our weights based on this gradient. To really feel the "back propagation" we'll start with the last layer and work backwards, though, since we have all of the gradients calculated already, the order does not matter. Recall that our loss function is the negative log likelihood ratio so our gradient signs are flipped.  If a parameter is important, the gradient will be more negative, and vice versa. The gradients are a ratio of importance of each parameter and we need to know how much of that gradient to apply to our weights. This "how much" is referred to as the *learning rate*. In modern training learning rate schedulers and optimizers are used to vary the rate and application by layer and by training round.  We'll keep it simple and use an astronomically high learning rate of `5.000` which applies the gradient directly to the weights via a `-=`. Gradient for the weights and the biases is different as the partial differential with respect to each is different. We need to remember in the layers with bias to apply it to both.

*Note that since our vocab is very small, our context is small, and our batch is small, relatively our model is very deep so we will see a lot of exceptionally small gradients*

In [63]:
## Huge learning rate to emphasize
learning_rate = 5.000 

### Output Layer
Let's start with our output layer.  If you recall we did *weight tying* of our head `lm_head` to our token projection `wte`.  Because of this applying loss to the `lm_head` will automatically apply it to `wte`. Since we're doing this manually and not through an optimizer we need to be careful not to apply it twice. Since we're starting with the last layer first, the head, we'll see that the gradients are equal and that updating just the `lm_head` weights updates `wte` weights also.

In [64]:
lm_head.weight.grad is wte.weight.grad, lm_head.weight.grad, wte.weight.grad

(True,
 tensor([[ 0.0630, -0.1378, -0.0396,  0.1144],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [ 0.0612, -0.1378, -0.0390,  0.1156],
         [ 0.0230, -0.0492, -0.0143,  0.0405],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [ 0.0599, -0.1377, -0.0386,  0.1164],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [ 0.0230, -0.0492, -0.0143,  0.0405],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [ 0.0223, -0.0492, -0.0141,  0.0410],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [ 0.0634, -0.1378, -0.0398,  0.1142],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [-0.0178,  0.0394,  0.0112, -0.0328],
         [ 0.0627, -0.1378, -0.0395,  0.1146],
      

In [65]:
with torch.no_grad():
    lm_head.weight -= learning_rate * lm_head.weight.grad
lm_head.weight is wte.weight, lm_head.weight, wte.weight

(True,
 Parameter containing:
 tensor([[-0.3050,  0.7089,  0.2282, -0.5320],
         [ 0.1089, -0.1668, -0.0162,  0.2142],
         [ 0.1189, -0.1568, -0.0062,  0.2242],
         [ 0.1289, -0.1468,  0.0038,  0.2342],
         [ 0.1389, -0.1368,  0.0138,  0.2442],
         [-0.2459,  0.7589,  0.2750, -0.4880],
         [-0.0449,  0.3260,  0.1616, -0.1027],
         [ 0.1689, -0.1068,  0.0438,  0.2742],
         [ 0.1789, -0.0968,  0.0538,  0.2842],
         [-0.1997,  0.7987,  0.3129, -0.4519],
         [ 0.1989, -0.0768,  0.0738,  0.3042],
         [ 0.0051,  0.3760,  0.2116, -0.0527],
         [ 0.2189, -0.0568,  0.0938,  0.3242],
         [ 0.0283,  0.3960,  0.2305, -0.0348],
         [ 0.2389, -0.0368,  0.1138,  0.3442],
         [-0.1568,  0.8589,  0.3788, -0.3808],
         [ 0.2589, -0.0168,  0.1338,  0.3642],
         [ 0.2689, -0.0068,  0.1438,  0.3742],
         [ 0.2789,  0.0032,  0.1538,  0.3842],
         [ 0.2889,  0.0132,  0.1638,  0.3942],
         [-0.1036,  0.9089,  0

### Transformer - Final Layer Normalization
Now let's move to the layer normalization.  In our forward pass we did not change the weights, but in training we still will as the normalization can impact the outputs.  Additionally layer normalization has bias so we have to remember to add it. 

In [66]:
ln_f.weight, ln_f.bias

(Parameter containing:
 tensor([1., 1., 1., 1.], requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0., 0.], requires_grad=True))

In [67]:
ln_f.weight.grad, ln_f.bias.grad

(tensor([-0.0277,  0.0611,  0.0175, -0.0509]),
 tensor([0.0431, 0.0431, 0.0431, 0.0431]))

In [68]:
with torch.no_grad():
    ln_f.weight -= learning_rate * ln_f.weight.grad
    ln_f.bias   -= learning_rate * ln_f.bias.grad
ln_f.weight, ln_f.bias

(Parameter containing:
 tensor([1.1383, 0.6945, 0.9126, 1.2547], requires_grad=True),
 Parameter containing:
 tensor([-0.2156, -0.2156, -0.2156, -0.2156], requires_grad=True))

### Tranformer - Feed Froward Updates
In feed forward we had 3 layers: the upward projection `mlp_fc`, the tanh layer `mlp_gelu`, and the downward projection `mlp_proj`.  The tanh layer normalizes each element using tanh so it has no trainable parameters, or weights, so there is no updates to be made, the gradient passes through it. But the other two layers have both weights and biases that we can update.  

One thing you may notice is that the gradients are exceptionally small in some cases, leading to minimal updates

In [69]:
mlp_proj.weight.grad, mlp_proj.bias.grad, mlp_fc.weight.grad,mlp_fc.bias.grad

(tensor([[ 2.3517e-08,  8.5932e-08,  8.8750e-08, -3.0591e-08, -2.6721e-09,
           1.3407e-08,  3.4789e-08,  6.1262e-08,  1.4351e-07, -2.1024e-08,
          -1.5280e-08,  9.0650e-08,  2.7256e-08,  9.0820e-09,  4.9701e-08,
           2.9570e-08],
         [-3.8855e-08, -1.4197e-07, -1.4663e-07,  5.0542e-08,  4.4148e-09,
          -2.2151e-08, -5.7478e-08, -1.0121e-07, -2.3711e-07,  3.4736e-08,
           2.5245e-08, -1.4977e-07, -4.5032e-08, -1.5005e-08, -8.2114e-08,
          -4.8855e-08],
         [-8.9469e-09, -3.2692e-08, -3.3764e-08,  1.1638e-08,  1.0166e-09,
          -5.1003e-09, -1.3235e-08, -2.3306e-08, -5.4596e-08,  7.9989e-09,
           5.8137e-09, -3.4485e-08, -1.0368e-08, -3.4536e-09, -1.8906e-08,
          -1.1247e-08],
         [ 3.2975e-08,  1.2049e-07,  1.2444e-07, -4.2894e-08, -3.7468e-09,
           1.8799e-08,  4.8780e-08,  8.5899e-08,  2.0123e-07, -2.9480e-08,
          -2.1426e-08,  1.2711e-07,  3.8217e-08,  1.2733e-08,  6.9687e-08,
           4.1460e-08]]),
 t

In [70]:
with torch.no_grad():
    mlp_proj.weight -= learning_rate * mlp_proj.weight.grad
    mlp_proj.bias -= learning_rate * mlp_proj.bias.grad
    mlp_fc.weight -= learning_rate * mlp_fc.weight.grad
    mlp_fc.bias -= learning_rate * mlp_fc.bias.grad
mlp_proj.weight, mlp_proj.bias, mlp_fc.weight,mlp_fc.bias

(Parameter containing:
 tensor([[0.0010, 0.0020, 0.0030, 0.0040, 0.0050, 0.0060, 0.0070, 0.0080, 0.0090,
          0.0100, 0.0110, 0.0120, 0.0130, 0.0140, 0.0150, 0.0160],
         [0.0020, 0.0050, 0.0070, 0.0090, 0.0110, 0.0130, 0.0150, 0.0170, 0.0190,
          0.0210, 0.0230, 0.0250, 0.0270, 0.0290, 0.0310, 0.0330],
         [0.0030, 0.0090, 0.0120, 0.0150, 0.0180, 0.0210, 0.0240, 0.0270, 0.0300,
          0.0330, 0.0360, 0.0390, 0.0420, 0.0450, 0.0480, 0.0510],
         [0.0040, 0.0140, 0.0180, 0.0220, 0.0260, 0.0300, 0.0340, 0.0380, 0.0420,
          0.0460, 0.0500, 0.0540, 0.0580, 0.0620, 0.0660, 0.0700]],
        requires_grad=True),
 Parameter containing:
 tensor([ 0.0891,  0.1415,  0.0581, -0.0524], requires_grad=True),
 Parameter containing:
 tensor([[0.0010, 0.0020, 0.0030, 0.0040],
         [0.0020, 0.0050, 0.0070, 0.0090],
         [0.0030, 0.0090, 0.0120, 0.0150],
         [0.0040, 0.0140, 0.0180, 0.0220],
         [0.0050, 0.0200, 0.0250, 0.0300],
         [0.0060, 0.027

### Tranformer - Layer Norm 2
Now on to the next layer normalization.  Similarly we have both weights and biases to update again.  We can see once again these are exceptionally small gradients. 

In [71]:
ln_2.weight.grad, ln_2.bias.grad

(tensor([-1.1258e-09, -2.9141e-09,  3.1795e-09,  1.0917e-08]),
 tensor([8.5118e-10, 6.4505e-09, 7.3016e-09, 8.1528e-09]))

In [72]:
with torch.no_grad():
    ln_2.weight -= learning_rate * ln_2.weight.grad
    ln_2.bias   -= learning_rate * ln_2.bias.grad
ln_2.weight, ln_2.bias

(Parameter containing:
 tensor([1.0000, 1.0000, 1.0000, 1.0000], requires_grad=True),
 Parameter containing:
 tensor([-4.2559e-09, -3.2252e-08, -3.6508e-08, -4.0764e-08],
        requires_grad=True))

### Tranformer - Flash Attention
Even though flash attention was excpetionally complicated with multiple attnetion heads and the query, key, and value, we only have 2 sets of weights, and biases, to update: `c_attn`, `c_proj`.  Back propogation aggregated the different impacts across on each.  

Similar to what we've seen this far, these gradiens are very small but we can see large order of magnitudedifferences within the values starting to suggest how they impact the ouputs. 

In [73]:
c_attn.weight.grad,c_attn.bias.grad, c_proj.weight.grad, c_proj.bias.grad,

(tensor([[ 5.9886e-19,  1.9962e-19, -1.9962e-19, -5.9886e-19],
         [ 5.6699e-19,  1.8900e-19, -1.8900e-19, -5.6699e-19],
         [-1.4730e-17, -4.9100e-18,  4.9100e-18,  1.4730e-17],
         [-1.8661e-17, -6.2203e-18,  6.2202e-18,  1.8661e-17],
         [-3.4765e-20,  3.3945e-20, -2.9754e-20, -3.3095e-20],
         [-4.8207e-20,  8.9532e-20, -8.4130e-20, -9.7029e-20],
         [ 1.3764e-18,  3.7843e-18,  1.9764e-18, -7.2143e-19],
         [ 1.9788e-18,  5.6329e-18,  2.8696e-18, -1.0349e-18],
         [-1.0604e-09, -3.5347e-10,  3.5347e-10,  1.0604e-09],
         [-4.8019e-09, -1.6006e-09,  1.6006e-09,  4.8019e-09],
         [-6.1121e-09, -2.0374e-09,  2.0374e-09,  6.1121e-09],
         [-7.2319e-09, -2.4106e-09,  2.4106e-09,  7.2319e-09]]),
 tensor([-4.5080e-19, -4.2681e-19,  1.1088e-17,  1.4047e-17, -2.8376e-20,
         -8.6397e-20, -4.3368e-19, -9.7578e-19,  7.9824e-10,  3.6148e-09,
          4.6010e-09,  5.4440e-09]),
 tensor([[ 2.6019e-08,  3.0881e-08,  3.8607e-08,  4.4644e

In [74]:
with torch.no_grad():
    c_attn.weight -= learning_rate * c_attn.weight.grad
    c_attn.bias -= learning_rate * c_attn.bias.grad
    c_proj.weight -= learning_rate * c_proj.weight.grad
    c_proj.bias -= learning_rate * c_proj.bias.grad
c_attn.weight,c_attn.bias, c_proj.weight, c_proj.bias,

(Parameter containing:
 tensor([[0.0010, 0.0020, 0.0030, 0.0040],
         [0.0020, 0.0050, 0.0070, 0.0090],
         [0.0030, 0.0090, 0.0120, 0.0150],
         [0.0040, 0.0140, 0.0180, 0.0220],
         [0.0050, 0.0200, 0.0250, 0.0300],
         [0.0060, 0.0270, 0.0330, 0.0390],
         [0.0070, 0.0350, 0.0420, 0.0490],
         [0.0080, 0.0440, 0.0520, 0.0600],
         [0.0090, 0.0540, 0.0630, 0.0720],
         [0.0100, 0.0650, 0.0750, 0.0850],
         [0.0110, 0.0770, 0.0880, 0.0990],
         [0.0120, 0.0900, 0.1020, 0.1140]], requires_grad=True),
 Parameter containing:
 tensor([ 2.2540e-18,  2.1341e-18, -5.5442e-17, -7.0236e-17,  1.4188e-19,
          4.3199e-19,  2.1684e-18,  4.8789e-18, -3.9912e-09, -1.8074e-08,
         -2.3005e-08, -2.7220e-08], requires_grad=True),
 Parameter containing:
 tensor([[0.0010, 0.0020, 0.0030, 0.0040],
         [0.0020, 0.0050, 0.0070, 0.0090],
         [0.0030, 0.0090, 0.0120, 0.0150],
         [0.0040, 0.0140, 0.0180, 0.0220]], requires_grad=T

### Tranformer - Layer Norm 1
Now on to the first layer normalization in a transfomer.  Similarly to the other layer normalizations we have both weights and biases to update again.  We can see once again these are exceptionally small gradients. 

In [75]:
ln_1.weight.grad, ln_1.bias.grad

(tensor([-2.1158e-10, -4.9696e-10,  5.6749e-10,  1.9140e-09]),
 tensor([1.5927e-10, 1.1223e-09, 1.2816e-09, 1.4408e-09]))

In [76]:
with torch.no_grad():
    ln_1.weight -= learning_rate * ln_1.weight.grad
    ln_1.bias   -= learning_rate * ln_1.bias.grad
ln_1.weight, ln_1.bias

(Parameter containing:
 tensor([1., 1., 1., 1.], requires_grad=True),
 Parameter containing:
 tensor([-7.9635e-10, -5.6115e-09, -6.4079e-09, -7.2042e-09],
        requires_grad=True))

### Input Layer 
Now on to the input layer.  Recall that this layer had 2 components: the positional embeddings `wpe` and the token embeddings `wte`. We used *weight tying* to tie our token project `wte` with our head `lm_head` which has already been updated so we do not need to touch it.  We do however still need to update our positional embeddings.   Remember that `wpe` as an embedding has only weights and no biases. We'll see that once again we have relatively small gradients but we can see that they do very a few orders of magnitude showing that even with our small vocab and context, the position are starting to show some impact. 

*Note that this is the final layer of update we need to make.  Tokenization and data loading is separate from the model and does not impact loss*

In [77]:
wpe.weight.grad

tensor([[-3.6269e-08,  6.3913e-08,  4.5561e-08, -9.1831e-08],
        [ 1.3911e-07, -3.2269e-07, -9.0829e-08,  3.0421e-07],
        [-2.9739e-08,  6.3188e-08,  1.3497e-08, -8.4200e-08],
        [-2.0299e-08,  3.2760e-08,  1.1684e-08, -5.3948e-08],
        [ 7.8811e-08, -1.9212e-08, -2.3554e-08,  9.8066e-08],
        [ 1.6379e-07, -3.6291e-07, -8.4069e-08,  3.1299e-07],
        [ 2.1207e-08, -2.1866e-08, -8.8038e-10, -2.1867e-09],
        [-4.5591e-08,  9.0521e-08,  2.2656e-08, -4.5235e-08]])

In [78]:
with torch.no_grad():
    wpe.weight -= learning_rate * wpe.weight.grad
wpe.weight

Parameter containing:
tensor([[0.0100, 0.0200, 0.0300, 0.0400],
        [0.0200, 0.0300, 0.0400, 0.0500],
        [0.0300, 0.0400, 0.0500, 0.0600],
        [0.0400, 0.0500, 0.0600, 0.0700],
        [0.0500, 0.0600, 0.0700, 0.0800],
        [0.0600, 0.0700, 0.0800, 0.0900],
        [0.0700, 0.0800, 0.0900, 0.1000],
        [0.0800, 0.0900, 0.1000, 0.1100]], requires_grad=True)

## Forward Pass with Updated Weights

Now that we have the updated weights for each layer, let's do another forward pass and compare the loss. Since each layer was previously explained we will instead focus on just showing the outputs of the different layers and the final loss. If you want, you can check the previous outputs in the cached cell outputs above and compare them to see how the weight changes impacted the values at each layer. 

One key sign that our weights were updated is that you'll see quickly that the values at each layer are no longer repeated.  

### Data Re-loading
Repulling to a new `x_2`. We'll keep `y` to emphasize the same examples are being used. 

In [79]:
x_2 = tok_for_training[:-1].view(B, T)
x_2, y

(tensor([[35, 15, 32,  9,  5, 20, 30, 15],
         [11,  9,  6, 20,  5,  0, 13, 21]]),
 tensor([[15, 32,  9,  5, 20, 30, 15, 11],
         [ 9,  6, 20,  5,  0, 13, 21,  0]]))

### Input Layer
Note that in `wte` you can already see the impact of updated weights with the negative values. This was not originally present. 

In [80]:
pos = torch.arange(0, T, dtype=torch.long)
pos_emb = wpe(pos)
print(pos_emb.shape, pos_emb)

torch.Size([8, 4]) tensor([[0.0100, 0.0200, 0.0300, 0.0400],
        [0.0200, 0.0300, 0.0400, 0.0500],
        [0.0300, 0.0400, 0.0500, 0.0600],
        [0.0400, 0.0500, 0.0600, 0.0700],
        [0.0500, 0.0600, 0.0700, 0.0800],
        [0.0600, 0.0700, 0.0800, 0.0900],
        [0.0700, 0.0800, 0.0900, 0.1000],
        [0.0800, 0.0900, 0.1000, 0.1100]], grad_fn=<EmbeddingBackward0>)


In [81]:
tok_emb = wte(x_2)
print(tok_emb.shape, tok_emb)

torch.Size([2, 8, 4]) tensor([[[ 0.4489,  0.1732,  0.3238,  0.5542],
         [-0.1568,  0.8589,  0.3788, -0.3808],
         [ 0.2246,  0.5860,  0.4183,  0.1510],
         [-0.1997,  0.7987,  0.3129, -0.4519],
         [-0.2459,  0.7589,  0.2750, -0.4880],
         [-0.1036,  0.9089,  0.4277, -0.3330],
         [ 0.1951,  0.5660,  0.4016,  0.1373],
         [-0.1568,  0.8589,  0.3788, -0.3808]],

        [[ 0.0051,  0.3760,  0.2116, -0.0527],
         [-0.1997,  0.7987,  0.3129, -0.4519],
         [-0.0449,  0.3260,  0.1616, -0.1027],
         [-0.1036,  0.9089,  0.4277, -0.3330],
         [-0.2459,  0.7589,  0.2750, -0.4880],
         [-0.3050,  0.7089,  0.2282, -0.5320],
         [ 0.0283,  0.3960,  0.2305, -0.0348],
         [ 0.1106,  0.4760,  0.3097,  0.0437]]], grad_fn=<EmbeddingBackward0>)


In [82]:
x = tok_emb + pos_emb
print(x)

tensor([[[ 0.4589,  0.1932,  0.3538,  0.5942],
         [-0.1368,  0.8889,  0.4188, -0.3308],
         [ 0.2546,  0.6260,  0.4683,  0.2110],
         [-0.1597,  0.8487,  0.3729, -0.3819],
         [-0.1959,  0.8189,  0.3450, -0.4080],
         [-0.0436,  0.9789,  0.5077, -0.2430],
         [ 0.2651,  0.6460,  0.4916,  0.2373],
         [-0.0768,  0.9489,  0.4788, -0.2708]],

        [[ 0.0151,  0.3960,  0.2416, -0.0127],
         [-0.1797,  0.8287,  0.3529, -0.4019],
         [-0.0149,  0.3660,  0.2116, -0.0427],
         [-0.0636,  0.9589,  0.4877, -0.2630],
         [-0.1959,  0.8189,  0.3450, -0.4080],
         [-0.2450,  0.7789,  0.3082, -0.4420],
         [ 0.0983,  0.4760,  0.3205,  0.0652],
         [ 0.1906,  0.5660,  0.4097,  0.1537]]], grad_fn=<AddBackward0>)


### Transformer - Layer Normalization
If you recall, the values here were all uniform. The updated weights and changed input now result in changes, showing the impact of our learning.

In [83]:
x_norm = ln_1(x)
print(x_norm.shape, x_norm)

torch.Size([2, 8, 4]) tensor([[[ 0.4014, -1.4095, -0.3151,  1.3233],
         [-0.7243,  1.4176,  0.4360, -1.1294],
         [-0.8080,  1.4090,  0.4675, -1.0685],
         [-0.6893,  1.4190,  0.4241, -1.1538],
         [-0.7019,  1.4186,  0.4284, -1.1451],
         [-0.7176,  1.4179,  0.4338, -1.1341],
         [-0.8614,  1.4028,  0.4850, -1.0264],
         [-0.7243,  1.4176,  0.4360, -1.1294]],

        [[-0.8614,  1.4028,  0.4850, -1.0264],
         [-0.6893,  1.4190,  0.4241, -1.1538],
         [-0.8614,  1.4028,  0.4850, -1.0264],
         [-0.7176,  1.4179,  0.4338, -1.1341],
         [-0.7019,  1.4186,  0.4284, -1.1451],
         [-0.7206,  1.4178,  0.4348, -1.1320],
         [-0.8436,  1.4051,  0.4792, -1.0407],
         [-0.8310,  1.4066,  0.4751, -1.0507]]],
       grad_fn=<NativeLayerNormBackward0>)


### Transformer - Flash Attention

In [84]:
B, T, C = x_norm.size()
qkv = c_attn(x_norm)
print(qkv.shape, qkv)

torch.Size([2, 8, 12]) tensor([[[ 1.9300e-03,  3.4587e-03,  4.5860e-03,  5.3119e-03,  5.6364e-03,
           5.5596e-03,  5.0814e-03,  4.2018e-03,  2.9209e-03,  1.2385e-03,
          -8.4517e-04, -3.3303e-03],
         [-1.0984e-03, -1.4724e-03, -1.1222e-03, -4.7605e-05,  1.7513e-03,
           4.2744e-03,  7.5219e-03,  1.1494e-02,  1.6190e-02,  2.1610e-02,
           2.7755e-02,  3.4624e-02],
         [-8.6149e-04, -9.1494e-04, -1.6036e-04,  1.4023e-03,  3.7729e-03,
           6.9516e-03,  1.0938e-02,  1.5733e-02,  2.1336e-02,  2.7747e-02,
           3.4965e-02,  4.2992e-02],
         [-1.1943e-03, -1.6994e-03, -1.5151e-03, -6.4166e-04,  9.2108e-04,
           3.1731e-03,  6.1143e-03,  9.7448e-03,  1.4065e-02,  1.9074e-02,
           2.4772e-02,  3.1160e-02],
         [-1.1600e-03, -1.6181e-03, -1.3743e-03, -4.2868e-04,  1.2188e-03,
           3.5682e-03,  6.6194e-03,  1.0372e-02,  1.4827e-02,  1.9984e-02,
           2.5843e-02,  3.2403e-02],
         [-1.1168e-03, -1.5159e-03, -1.197

In [85]:
q,k,v = qkv.split(n_embd, dim=2)
q = q.view(B, T, n_head, C // n_head).transpose(1, 2)
k = k.view(B, T, n_head, C // n_head).transpose(1, 2)
v = v.view(B, T, n_head, C // n_head).transpose(1, 2)
att = att.masked_fill(mask[:T, :T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = dropout(att)
fa = att @ v
fa = fa.transpose(1, 2).contiguous().view(B, T, C)
x_norm = c_proj(fa)
print(x_norm.shape, x_norm)

torch.Size([2, 8, 4]) tensor([[[0.0003, 0.0007, 0.0012, 0.0017],
         [0.0003, 0.0007, 0.0011, 0.0017],
         [0.0003, 0.0007, 0.0013, 0.0019],
         [0.0003, 0.0008, 0.0013, 0.0019],
         [0.0003, 0.0007, 0.0011, 0.0017],
         [0.0003, 0.0007, 0.0012, 0.0017],
         [0.0003, 0.0008, 0.0013, 0.0019],
         [0.0003, 0.0006, 0.0011, 0.0016]],

        [[0.0003, 0.0007, 0.0012, 0.0018],
         [0.0003, 0.0006, 0.0011, 0.0016],
         [0.0003, 0.0007, 0.0011, 0.0016],
         [0.0004, 0.0008, 0.0014, 0.0021],
         [0.0003, 0.0007, 0.0011, 0.0017],
         [0.0003, 0.0006, 0.0010, 0.0015],
         [0.0003, 0.0007, 0.0011, 0.0017],
         [0.0004, 0.0008, 0.0014, 0.0020]]], grad_fn=<ViewBackward0>)


### Transformer - Residual connection

In [86]:
x = x + x_norm
print(x.shape, x)

torch.Size([2, 8, 4]) tensor([[[ 0.4592,  0.1939,  0.3549,  0.5959],
         [-0.1366,  0.8895,  0.4199, -0.3292],
         [ 0.2550,  0.6268,  0.4696,  0.2129],
         [-0.1594,  0.8495,  0.3742, -0.3800],
         [-0.1956,  0.8195,  0.3462, -0.4063],
         [-0.0433,  0.9795,  0.5088, -0.2413],
         [ 0.2654,  0.6468,  0.4929,  0.2392],
         [-0.0766,  0.9495,  0.4799, -0.2692]],

        [[ 0.0154,  0.3967,  0.2428, -0.0109],
         [-0.1794,  0.8293,  0.3539, -0.4003],
         [-0.0146,  0.3667,  0.2127, -0.0411],
         [-0.0632,  0.9597,  0.4891, -0.2609],
         [-0.1956,  0.8195,  0.3462, -0.4063],
         [-0.2448,  0.7795,  0.3092, -0.4405],
         [ 0.0986,  0.4767,  0.3216,  0.0669],
         [ 0.1909,  0.5669,  0.4111,  0.1557]]], grad_fn=<AddBackward0>)


### Transformer - Layer Normalization 2

In [87]:
x_norm_2 = ln_2(x)
print(x_norm_2.shape, x_norm_2)

torch.Size([2, 8, 4]) tensor([[[ 0.3960, -1.4089, -0.3132,  1.3260],
         [-0.7259,  1.4175,  0.4366, -1.1282],
         [-0.8131,  1.4085,  0.4692, -1.0646],
         [-0.6911,  1.4189,  0.4248, -1.1526],
         [-0.7035,  1.4185,  0.4290, -1.1440],
         [-0.7193,  1.4179,  0.4343, -1.1329],
         [-0.8664,  1.4021,  0.4867, -1.0224],
         [-0.7258,  1.4175,  0.4366, -1.1283]],

        [[-0.8661,  1.4021,  0.4866, -1.0226],
         [-0.6908,  1.4189,  0.4246, -1.1528],
         [-0.8658,  1.4022,  0.4865, -1.0229],
         [-0.7196,  1.4178,  0.4344, -1.1327],
         [-0.7035,  1.4185,  0.4290, -1.1440],
         [-0.7220,  1.4177,  0.4352, -1.1310],
         [-0.8482,  1.4045,  0.4807, -1.0370],
         [-0.8365,  1.4059,  0.4769, -1.0463]]],
       grad_fn=<NativeLayerNormBackward0>)


### Transformer - Feed Forward

In [88]:
x_mlp = mlp_fc(x_norm_2)
x_mlp = mlp_gelu(x_mlp)
x_mlp = mlp_proj(x_mlp)
print(x_mlp.shape, x_mlp)

torch.Size([2, 8, 4]) tensor([[[ 0.0930,  0.1502,  0.0723, -0.0318],
         [ 0.0955,  0.1554,  0.0803, -0.0209],
         [ 0.0961,  0.1566,  0.0822, -0.0182],
         [ 0.0953,  0.1549,  0.0795, -0.0219],
         [ 0.0954,  0.1550,  0.0798, -0.0215],
         [ 0.0955,  0.1553,  0.0801, -0.0211],
         [ 0.0965,  0.1574,  0.0834, -0.0165],
         [ 0.0955,  0.1554,  0.0803, -0.0209]],

        [[ 0.0965,  0.1574,  0.0834, -0.0165],
         [ 0.0953,  0.1549,  0.0795, -0.0219],
         [ 0.0965,  0.1574,  0.0834, -0.0165],
         [ 0.0955,  0.1553,  0.0802, -0.0210],
         [ 0.0954,  0.1550,  0.0798, -0.0215],
         [ 0.0955,  0.1553,  0.0802, -0.0210],
         [ 0.0963,  0.1571,  0.0830, -0.0171],
         [ 0.0963,  0.1569,  0.0828, -0.0174]]], grad_fn=<ViewBackward0>)


### Transformer - Residual (skip) connection 2

In [89]:
x = x_mlp + x
print(x.shape, x)

torch.Size([2, 8, 4]) tensor([[[ 0.5522,  0.3441,  0.4272,  0.5640],
         [-0.0411,  1.0449,  0.5002, -0.3500],
         [ 0.3511,  0.7834,  0.5518,  0.1947],
         [-0.0641,  1.0044,  0.4537, -0.4019],
         [-0.1002,  0.9746,  0.4260, -0.4278],
         [ 0.0522,  1.1348,  0.5890, -0.2623],
         [ 0.3619,  0.8041,  0.5763,  0.2227],
         [ 0.0189,  1.1048,  0.5602, -0.2901]],

        [[ 0.1119,  0.5541,  0.3262, -0.0274],
         [-0.0841,  0.9842,  0.4335, -0.4222],
         [ 0.0818,  0.5240,  0.2961, -0.0576],
         [ 0.0322,  1.1149,  0.5692, -0.2820],
         [-0.1002,  0.9746,  0.4260, -0.4278],
         [-0.1493,  0.9348,  0.3894, -0.4615],
         [ 0.1949,  0.6338,  0.4047,  0.0498],
         [ 0.2872,  0.7238,  0.4939,  0.1383]]], grad_fn=<AddBackward0>)


### Transformer - Final Layer Normalization

In [90]:
x = ln_f(x)
print(x.shape, x)

torch.Size([2, 8, 4]) tensor([[[ 0.7862, -1.1882, -0.6622,  1.0513],
         [-0.9204,  0.7712,  0.1474, -1.7207],
         [-0.8302,  0.7696,  0.1216, -1.7819],
         [-0.8828,  0.7708,  0.1369, -1.7470],
         [-0.8962,  0.7710,  0.1406, -1.7377],
         [-0.9132,  0.7712,  0.1454, -1.7258],
         [-0.8842,  0.7709,  0.1368, -1.7456],
         [-0.9203,  0.7712,  0.1473, -1.7208]],

        [[-0.8839,  0.7709,  0.1367, -1.7458],
         [-0.8824,  0.7708,  0.1368, -1.7472],
         [-0.8836,  0.7709,  0.1366, -1.7460],
         [-0.9136,  0.7712,  0.1455, -1.7255],
         [-0.8962,  0.7710,  0.1406, -1.7377],
         [-0.9162,  0.7712,  0.1462, -1.7237],
         [-0.8657,  0.7706,  0.1316, -1.7582],
         [-0.8537,  0.7703,  0.1282, -1.7663]]],
       grad_fn=<NativeLayerNormBackward0>)


### Output Layers AKA Model Head

In [91]:
logits = lm_head(x)
print(logits.shape, logits)

torch.Size([2, 8, 36]) tensor([[[-1.7925,  0.5197,  0.5196,  0.5195,  0.5193, -1.7902, -0.6377,
           0.5189,  0.5188, -1.7883,  0.5185, -0.6383,  0.5183, -0.6376,
           0.5180, -1.7950,  0.5178,  0.5176,  0.5175,  0.5174, -1.7946,
          -0.6379,  0.5170,  0.5169,  0.5167,  0.5166,  0.5165,  0.5163,
           0.5162,  0.5161, -0.6408,  0.5158, -0.6380,  0.5156,  0.5154,
           0.5153],
         [ 1.7765, -0.5998, -0.6170, -0.6342, -0.6514,  1.6918,  0.4933,
          -0.7031, -0.7203,  1.6235, -0.7548,  0.4072, -0.7892,  0.3733,
          -0.8237,  1.5178, -0.8582, -0.8754, -0.8926, -0.9098,  1.4322,
           0.2359, -0.9615, -0.9787, -0.9960, -1.0132, -1.0304, -1.0476,
          -1.0649, -1.0821,  0.0799, -1.1165,  0.0470, -1.1510, -1.1682,
          -1.1854],
         [ 1.7745, -0.6024, -0.6196, -0.6368, -0.6540,  1.6912,  0.4908,
          -0.7056, -0.7228,  1.6238, -0.7572,  0.4048, -0.7917,  0.3714,
          -0.8261,  1.5158, -0.8605, -0.8777, -0.8949, -0.912

### Updated Loss calculation

Now we'll calculate the updated loss.  Our first pass's loss was 3.5835. Since we're passing through the same example and used a fairly high learning rate we should see a significant improvement with just 1 learning pass. 

In [92]:
loss

tensor(3.5835, grad_fn=<NllLossBackward0>)

In [93]:
y_flat = y.view(-1)
logits_flat = logits.view(-1, logits.size(-1))
updated_loss = F.cross_entropy(logits_flat, y_flat)
print(updated_loss.shape, updated_loss)

torch.Size([]) tensor(2.8740, grad_fn=<NllLossBackward0>)


In [94]:
f'1 round of training resulted in an loss improvment of {loss.item() - updated_loss.item():.4f}'

'1 round of training resulted in an loss improvment of 0.7095'

# SUCCESS!
Our training improved the loss by about **~26%** (amount may vary since we didn't set a seed). There are flaws with this, mainly passing the same example through a second time, but this helps show the fundamentals of what learning does inside a GPT-2 style model. 