# CNN ResNet Explainer

Convolutional Neural Networks, or CNNs, learn patterns in data by sliding small learnable filters across the input to spot local patterns, like short “n-gram” features, and turning them into higher-level signals. CNNs are most commonly used for analyzing spatially structured data, like images or videos, because they efficiently learn local patterns such as edges, textures, and shapes. They are also used in natural language processing and time-series tasks, where the same sliding-filter idea captures local dependencies in text or signal data. Modern architectures extend CNNs to higher-level tasks such as object detection, segmentation, and even audio or biological sequence modeling.

For our example, we'll take text and use our embedding layer to add a second dimension for the CNN to learn across. Recall that a discrete convolution of two sequences $a$ and $b$ sums the products of all element pairs whose indices add up to $n$:
$$
(a * b)_n = \sum_{\substack{i,j\\ i+j=n}} a_i \cdot b_j
$$
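As a quick sanity check of the formula above, here is a minimal sketch in plain Python that evaluates $(a*b)_n$ by looping over every index pair with $i+j=n$ and compares the result with NumPy's `np.convolve`. The sequences `a` and `b` are made-up values for illustration.

```python
import numpy as np

def conv1d_full(a, b):
    """Full discrete convolution: out[n] = sum over i + j = n of a[i] * b[j]."""
    out = [0.0] * (len(a) + len(b) - 1)
    for i, a_i in enumerate(a):
        for j, b_j in enumerate(b):
            out[i + j] += a_i * b_j  # each pair contributes to index n = i + j
    return out

a = [1.0, 2.0, 3.0]
b = [0.0, 1.0, 0.5]
manual = conv1d_full(a, b)
print(manual)             # [0.0, 1.0, 2.5, 4.0, 1.5]
print(np.convolve(a, b))  # same values as a NumPy array
```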

In our example, the embedding of the input sequence is $A$ and the learnable kernel weights are $B$ (written $W$ in the formulas below, since $B$ also denotes the batch size). We pad $A$ so the output has the same length as the input, and each output value is the dot product between a local patch of $A$ and $B$. We also use the stride, which controls how far the kernel window moves along $A$ at each step, to show how we can downsample $A$.

Strictly speaking, this sliding dot product is not a convolution: a convolution flips the kernel before sliding it, while our calculation does not. What we actually run is the closely related 2-D discrete cross-correlation, which is also what PyTorch's convolution layers compute. With the input reshaped to $[B,C,1,T]$ and a $1\times k$ kernel, each output token index $t$ (stride 1, no padding) is
$$
y_{t}=\sum_{c=1}^{C}\sum_{u=0}^{k-1} W_{c,u}\, x_{c,t+u}
$$
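To make the difference concrete, the sketch below computes $y_t$ from the formula above with explicit loops (stride 1, no padding) and checks it against PyTorch's `F.conv2d`, which implements cross-correlation. It then computes a true convolution by flipping the kernel along the token axis. The sizes (`C=2`, `T=5`, `k=3`) and random values are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C, T, k = 2, 5, 3
x = torch.randn(C, T)  # one example: C channels, T tokens
W = torch.randn(C, k)  # a single output channel: one 1 x k kernel row per input channel

# y_t = sum_c sum_u W[c, u] * x[c, t + u]   (cross-correlation: kernel used as-is)
T_out = T - k + 1
y_manual = torch.zeros(T_out)
for t in range(T_out):
    for c in range(C):
        for u in range(k):
            y_manual[t] += W[c, u] * x[c, t + u]

# F.conv2d expects input [B, C, H, W] and weight [C_out, C_in, kH, kW]
y_torch = F.conv2d(x.view(1, C, 1, T), W.view(1, C, 1, k)).view(T_out)
print(torch.allclose(y_manual, y_torch, atol=1e-5))  # True

# A true convolution flips the kernel first: y_conv[t] = sum_c sum_u W[c, k-1-u] * x[c, t+u]
y_conv = F.conv2d(x.view(1, C, 1, T), W.flip(-1).view(1, C, 1, k)).view(T_out)
```

Since the kernel is learned either way, the flip makes no practical difference to what a CNN can represent, which is why frameworks skip it.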


To help show how CNNs work, we'll use the C-major note letters for 5 popular songs: [Hot Cross Buns](https://en.wikipedia.org/wiki/Hot_Cross_Buns_(song)), [Twinkle Twinkle Little Star](https://en.wikipedia.org/wiki/Twinkle,_Twinkle,_Little_Star), [Happy Birthday To You](https://en.wikipedia.org/wiki/Happy_Birthday_to_You), [Mary Had a Little Lamb](https://en.wikipedia.org/wiki/Mary_Had_a_Little_Lamb), and [Frère Jacques](https://en.wikipedia.org/wiki/Fr%C3%A8re_Jacques).

In today's notebook we'll take in 2 examples and predict the next note for each. In other notebooks you might have seen us make many predictions per batch inside a loop. Since we're using our `input × embedding` plane as the 2-D input, each example here produces just a single next-note prediction.

## Text Prep/Tokenization

We'll start with a common preprocessing step: tokenizing the data. This converts the string text into an array of numbers that can be used during the training loop. I've built a very small byte-pair encoding (BPE) whose vocab holds each unique character that appears, the top 6 merges, and an end-of-text token, for a total of 15 tokens. This keeps our vocab small and manageable for this example; production tokenizers typically have vocabularies of tens of thousands to 100K+ tokens. A great library for this is `tiktoken`. Our tokenizer repeatedly applies the learned merges, in the order they were learned, and then replaces each resulting piece with the integer that represents it. This turns the text into a numeric array to simplify computing.

In [1]:
import torch
from collections import Counter
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class SimpleBPETokenizer:
    def __init__(self, num_merges=5, eot_token='<|endoftext|>'):
        self.num_merges = num_merges
        self.eot_token = eot_token
        self.eot_id = None
        self.merges = []
        self.pair_ranks = {}
        self.vocab = {}
        self.id_to_token = {}

    def _add_token(self, tok):
        if tok in self.vocab:
            return self.vocab[tok]
        i = len(self.vocab)
        self.vocab[tok] = i
        self.id_to_token[i] = tok
        return i

    def _get_bigrams(self, seq):
        for i in range(len(seq) - 1):
            yield (seq[i], seq[i + 1])

    def _merge_once(self, seq, pair):
        a, b = pair
        out = []
        i = 0
        while i < len(seq):
            if i < len(seq) - 1 and seq[i] == a and seq[i + 1] == b:
                out.append(a + b)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        return out

    def train(self, corpus):
        # corpus: list[str]
        text = ''.join(corpus).lower()
        seq = list(text)
        merges = []
        for _ in range(self.num_merges):
            counts = Counter(self._get_bigrams(seq))
            if not counts: break
            best_pair, _ = counts.most_common(1)[0]
            merges.append(best_pair)
            seq = self._merge_once(seq, best_pair)
        self.merges = merges
        self.pair_ranks = {p: i for i, p in enumerate(self.merges)}

        self.vocab = {}
        self.id_to_token = {}
        for ch in sorted(set(text)):
            self._add_token(ch)
        for a, b in self.merges:
            self._add_token(a + b)
        self.eot_id = self._add_token(self.eot_token)

    def encode(self, text, force_last_eot=True):
        # treat literal eot marker as special; remove it from content
        if self.eot_token in text:
            text = text.replace(self.eot_token, '')
        seq = list(text)

        # make sure all seen base chars exist
        for ch in set(seq):
            if ch not in self.vocab:
                self._add_token(ch)

        # greedy BPE using learned pair ranks
        if self.merges:
            while True:
                best_pair, best_rank = None, None
                for p in self._get_bigrams(seq):
                    r = self.pair_ranks.get(p)
                    if r is not None and (best_rank is None or r < best_rank):
                        best_pair, best_rank = p, r
                if best_pair is None:
                    break
                seq = self._merge_once(seq, best_pair)

        # ensure all tokens in seq exist in vocab (e.g., if new chars appeared)
        for tok in seq:
            if tok not in self.vocab:
                self._add_token(tok)

        ids = [self.vocab[tok] for tok in seq]

        # FORCE: append EOT id if not already last
        if force_last_eot:
            if not ids or ids[-1] != self.eot_id:
                ids.append(self.eot_id)

        return ids

    def decode(self, ids):
        # drop trailing EOT if present
        if ids and self.eot_id is not None and ids[-1] == self.eot_id:
            ids = ids[:-1]
        toks = [self.id_to_token[i] for i in ids]
        return ''.join(toks)


In [3]:
twinkle_twinkle = r'CCGGAAG,FFEEDDC,GGFFEED,GGFFEED,CCGGAAG,FFEEDDC'
hot_cross_buns = r'EDC,EDC,CCCC,DDDD,EDC'
happy_birthday = r'GGAGCB,GGAGDC,GGGECBA,FFECDC'
mary_had_a_little_lamb = r'EDCDEEE,DDD,EGG,EDCDEEE,EDD,EDC'
frere_jacques = r'CDEC,CDEC,EFG,EFG,GAGFEC,GAGFEC,CGC,CGC'

In [4]:
tok = SimpleBPETokenizer(num_merges=6)
examples = [twinkle_twinkle,hot_cross_buns, happy_birthday, mary_had_a_little_lamb, frere_jacques]
tok.train(examples)
tok.merges

[('e', 'd'), ('c', ','), ('g', 'g'), ('f', 'e'), ('a', 'g'), ('c', 'c')]

In [5]:
tok.vocab

{',': 0,
 'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'ed': 8,
 'c,': 9,
 'gg': 10,
 'fe': 11,
 'ag': 12,
 'cc': 13,
 '<|endoftext|>': 14}

In [6]:
vocab_size = len(tok.vocab)
vocab_size

15

In [7]:
eot = tok.eot_id
tokens = []
for example in examples:
    tokens.extend([eot])
    tokens.extend(tok.encode(example.lower()))
all_tokens = torch.tensor(tokens, dtype=torch.long)
all_tokens

tensor([14, 13, 10,  1, 12,  0,  6, 11,  8,  4,  9, 10,  6, 11,  8,  0, 10,  6,
        11,  8,  0, 13, 10,  1, 12,  0,  6, 11,  8,  4,  3, 14, 14,  8,  9,  8,
         9, 13,  3,  9,  4,  4,  4,  4,  0,  8,  3, 14, 14, 10, 12,  3,  2,  0,
        10, 12,  4,  9, 10,  7,  5,  3,  2,  1,  0,  6, 11,  3,  4,  3, 14, 14,
         8,  3,  4,  5,  5,  5,  0,  4,  4,  4,  0,  5, 10,  0,  8,  3,  4,  5,
         5,  5,  0,  8,  4,  0,  8,  3, 14, 14,  3,  4,  5,  9,  3,  4,  5,  9,
         5,  6,  7,  0,  5,  6,  7,  0,  7, 12, 11,  9,  7, 12, 11,  9,  3,  7,
         9,  3,  7,  3, 14])

# Modeling

A machine learning model's forward pass takes the token IDs, runs several layers of linear algebra on them, and then "predicts" the probability that each token in the vocab is next. While the weights are still noise (as you will see in this example), the predictions are gibberish. During the "backward pass", training gradually turns that noise into useful patterns. We'll show 3 steps that are focused on training:
1. **Data Loading** `x, y = train_loader.next_batch()` - this step pulls enough tokens from the raw data to complete a forward pass and loss calculation. If the model is inference only, this step is replaced with taking in the inference input and preparing it for the forward pass in the same way.
2. **Forward Pass** `logits, loss = model(x, y)` - using the data and the model architecture, we run a prediction for the tokens. When training, we also compare the prediction against the expected tokens to get the loss; in inference, we use the logits to complete the inference task.
3. **Back Propagation, aka Backward Pass & Training** `loss.backward(); optimizer.step()` - we compute gradients (the derivatives of the loss with respect to each parameter) to measure how much each parameter contributed to the error between the prediction and the right answer from the data loading step, then make very small adjustments to the parameters in the direction that reduces the loss, with the hope that it improves future predictions.

Then we'll show a final **Forward Pass** with the weights updated in step 3.
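The three training steps map onto a handful of PyTorch calls. Below is a minimal sketch, assuming a stand-in model (an embedding followed by a linear layer over the flattened context, not the CNN built in this notebook) and random token IDs in place of `train_loader.next_batch()`. It only shows how data loading, the forward pass, and the backward pass fit together.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, n_embd, B, T = 15, 6, 2, 8

# Hypothetical stand-in model: embed each token, flatten the context, score every vocab token
model = nn.Sequential(
    nn.Embedding(vocab_size, n_embd),
    nn.Flatten(),                       # [B, T, n_embd] -> [B, T * n_embd]
    nn.Linear(T * n_embd, vocab_size),  # -> [B, vocab_size]
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)

# 1. Data loading (random IDs stand in for train_loader.next_batch())
x = torch.randint(0, vocab_size, (B, T))  # [B, T] context tokens
y = torch.randint(0, vocab_size, (B,))    # [B] next token per example (cross_entropy wants [B], not [B, 1])

losses = []
for step in range(20):
    # 2. Forward pass: logits are unnormalized scores over the vocab
    logits = model(x)
    loss = F.cross_entropy(logits, y)
    # 3. Backward pass and parameter update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"loss before: {losses[0]:.3f}, after: {losses[-1]:.3f}")
```

Because the same tiny batch is reused every step, the loss drops quickly; with real data each step would pull a fresh batch.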

## Data Loading

To start, we need to get enough data to run the forward and backward passes. Since in real practice the full dataset is usually too big to hold in memory at once, we read just enough data to run the passes, leaving memory and compute for the passes themselves rather than for holding static data.
To know how much data we need, we first choose the batch size and the model's context length. These two values also form 2 of the 3 dimensions of the tensor the model works on (`[B, T, C]` once we add the embedding).
- **Batch Size (B)** - This is the number of examples you'll train on in a single pass. 
- **Context Length (T)** - This is the max number of tokens that a model can use in a single pass to generate the next token. If an example is below this length, it can be padded.
  
*Ideally both B and T are powers of 2 (or at least multiples of 8), which tends to map well onto GPU hardware. This is a theme you'll see across deep learning.*

In [8]:
B_batch = 2 # Batch
T_context = 8 # context length

Next, we pull enough tokens from our long `all_tokens` list for the forward pass. To fill `B_batch` examples of `T_context` tokens each, we need `B_batch*T_context` tokens. Since training predicts the token that follows each context, we also need 1 more token at the end so that the last example has a next token to validate against.

In [9]:
current_position = 0
tok_for_training = all_tokens[current_position:current_position + B_batch*T_context +1 ]
tok_for_training

tensor([14, 13, 10,  1, 12,  0,  6, 11,  8,  4,  9, 10,  6, 11,  8,  0, 10])

Now that we have our initial tokens, we need to convert them into matrices that are ready for training. In this step we'll create our batches and set up two arrays: 1/ the input `x` tokens, which lead to 2/ the output `y` tokens. To create each example in the batch, every `T_context` tokens are placed into their own row.

Recall that training takes in a sequence of tokens the length of the context and then predicts the next token, and that when we extracted `tok_for_training` we added 1 extra token so that we can evaluate the prediction for the last example. Because of this, the input `x` is every token except that extra last one, `[:-1]`.


Finally, for `y` we need one target token per example: the token immediately following that example's context. These sit at positions `T_context, 2*T_context, ..., B_batch*T_context` of `tok_for_training` (0-indexed), which is exactly the slice `tok_for_training[T_context::T_context]`.

We will now put this together and do the following:
1. Extract the input `x` and then split it into an example for each batch `B`
2. Extract the output `y` and then split it into an example for each batch `B`

*Note: `view` can take `-1`, which lets PyTorch infer that dimension, so we wouldn't need to pass in `T`. Given how many tensors we'll work with, though, we want to control the dimensions explicitly and error out if they don't match our expectations.*

In [10]:
x=tok_for_training[:-1].view(B_batch, T_context)
x.size(), x

(torch.Size([2, 8]),
 tensor([[14, 13, 10,  1, 12,  0,  6, 11],
         [ 8,  4,  9, 10,  6, 11,  8,  0]]))

In [11]:
tok_for_training

tensor([14, 13, 10,  1, 12,  0,  6, 11,  8,  4,  9, 10,  6, 11,  8,  0, 10])

In [12]:
y=tok_for_training[T_context::T_context].view(B_batch, 1)
y.size(), y

(torch.Size([2, 1]),
 tensor([[ 8],
         [10]]))
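The slicing above can be wrapped in a small loader. This is a hypothetical sketch: the class name `NextNoteLoader` and its wrap-around behavior are our own choices, not something defined earlier in the notebook. It exposes the `next_batch()` call mentioned in the training steps and reuses the `[:-1]` and `[T::T]` slicing from the cells above.

```python
import torch

class NextNoteLoader:
    """Hypothetical loader: returns x of shape [B, T] and y of shape [B, 1]."""

    def __init__(self, tokens, B, T):
        self.tokens, self.B, self.T = tokens, B, T
        self.position = 0

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.position : self.position + B * T + 1]  # B*T inputs + 1 extra target
        x = buf[:-1].view(B, T)   # every token except the last, one row per example
        y = buf[T::T].view(B, 1)  # token right after each row: positions T, 2T, ..., B*T
        self.position += B * T
        # wrap around when the next batch would run past the end of the data
        if self.position + B * T + 1 > len(self.tokens):
            self.position = 0
        return x, y

tokens = torch.tensor([14, 13, 10, 1, 12, 0, 6, 11, 8, 4, 9, 10, 6, 11, 8, 0, 10])
loader = NextNoteLoader(tokens, B=2, T=8)
x, y = loader.next_batch()
print(x.shape, y.view(-1).tolist())  # torch.Size([2, 8]) [8, 10]
```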

## Forward pass

<img src="explainer_screenshots/cnn/full_network.png" width="200">


During training, the forward pass of the CNN we've built takes a sequence of tokens in and predicts the likelihood of the next token for each example in the batch. This differs from the other models we've used: each example produces a single prediction (for the token after its context) rather than one prediction per position. This is mainly because CNNs do best with multi-dimensional data, so for this explainer we're repurposing our text input by using the `text × embedding` plane as our 2 dimensions, instead of an image or other 2-D data.

This explainer for the forward pass is focused on training: we'll pass in the input `x`, carry it through the layers, and generate a matrix of scores for each token being the next one, called `logits` (a softmax turns these scores into probabilities). Since this is a CNN, we'll pass each example through several convolution layers and also show downsampling, which reduces our matrix size.

At the end of the forward pass we then compare the probability in the logits to the actual next token in `y` and calculate `loss` based on the difference. This difference is what we'll then use in the backprop/training steps.  

*Note that we will use simple, patterned layer initializations to make the numbers easy to follow. In practice, layers are initialized randomly, commonly with Kaiming (He) initialization, which draws weights from a normal (or uniform) distribution scaled by the number of inputs to the layer (its fan-in) so that activations stay well-behaved. We will not cover initialization in this series.*

In [13]:
B_batch, T_context

(2, 8)

### Input Layer

<img src="explainer_screenshots/cnn/input_layer.png" width="200">

We'll first create an initial **embedding layer** for our input tokens. Recall that this is the layer that adds the second dimension to our text examples. We only use token embeddings here; if we wanted more learning capability, we could also add position embeddings. However, CNNs generally take in multi-dimensional examples and use multi-dimensional patches in the convolutional layers with the goal of learning patterns regardless of where they appear, so position information is generally avoided. We'll make the embedding dimension larger than 1 so the convolutions are easy to visualize. The embedding table has shape `vocab_size × n_embd`, so each row stores the weights for one token. The larger the embedding dimension, the more complex the patterns the model can learn.

After the embedding layer, we'll insert a fourth dimension of size 1 to suit our convolutional layers.

#### Input Layer - Embedding Projection

To start, we'll initialize our embeddings with an incrementing pattern, `W[i,j] = 0.01*(1+i+j)`, so that we can see how the values change through our convolutions. We'll also set our embedding dimension to 6 so we can see how our convolution combines values across the embedding channels. Because `x` plucks out a different embedding row for each token, the embedded input quickly moves away from the nicely ordered initial embeddings.

In [14]:
n_embd = 6 # level of embedding of input tokens
n_embd, vocab_size

(6, 15)

In [15]:
wte = nn.Embedding(vocab_size, n_embd)
with torch.no_grad(): # initialize to W[i,j] = 0.01*(1+i+j) for easy following
    vs, d = wte.num_embeddings, wte.embedding_dim
    rows = torch.arange(vs).unsqueeze(1)  # (vs,1)
    cols = torch.arange(d).unsqueeze(0)  # (1,d)
    pattern = 0.01*(1 + rows + cols)  # W[i,j] = 0.01*(1+i+j)
    wte.weight.copy_(pattern)
wte.weight

Parameter containing:
tensor([[0.0100, 0.0200, 0.0300, 0.0400, 0.0500, 0.0600],
        [0.0200, 0.0300, 0.0400, 0.0500, 0.0600, 0.0700],
        [0.0300, 0.0400, 0.0500, 0.0600, 0.0700, 0.0800],
        [0.0400, 0.0500, 0.0600, 0.0700, 0.0800, 0.0900],
        [0.0500, 0.0600, 0.0700, 0.0800, 0.0900, 0.1000],
        [0.0600, 0.0700, 0.0800, 0.0900, 0.1000, 0.1100],
        [0.0700, 0.0800, 0.0900, 0.1000, 0.1100, 0.1200],
        [0.0800, 0.0900, 0.1000, 0.1100, 0.1200, 0.1300],
        [0.0900, 0.1000, 0.1100, 0.1200, 0.1300, 0.1400],
        [0.1000, 0.1100, 0.1200, 0.1300, 0.1400, 0.1500],
        [0.1100, 0.1200, 0.1300, 0.1400, 0.1500, 0.1600],
        [0.1200, 0.1300, 0.1400, 0.1500, 0.1600, 0.1700],
        [0.1300, 0.1400, 0.1500, 0.1600, 0.1700, 0.1800],
        [0.1400, 0.1500, 0.1600, 0.1700, 0.1800, 0.1900],
        [0.1500, 0.1600, 0.1700, 0.1800, 0.1900, 0.2000]], requires_grad=True)

In [16]:
x = wte(x)
x.shape, x

(torch.Size([2, 8, 6]),
 tensor([[[0.1500, 0.1600, 0.1700, 0.1800, 0.1900, 0.2000],
          [0.1400, 0.1500, 0.1600, 0.1700, 0.1800, 0.1900],
          [0.1100, 0.1200, 0.1300, 0.1400, 0.1500, 0.1600],
          [0.0200, 0.0300, 0.0400, 0.0500, 0.0600, 0.0700],
          [0.1300, 0.1400, 0.1500, 0.1600, 0.1700, 0.1800],
          [0.0100, 0.0200, 0.0300, 0.0400, 0.0500, 0.0600],
          [0.0700, 0.0800, 0.0900, 0.1000, 0.1100, 0.1200],
          [0.1200, 0.1300, 0.1400, 0.1500, 0.1600, 0.1700]],
 
         [[0.0900, 0.1000, 0.1100, 0.1200, 0.1300, 0.1400],
          [0.0500, 0.0600, 0.0700, 0.0800, 0.0900, 0.1000],
          [0.1000, 0.1100, 0.1200, 0.1300, 0.1400, 0.1500],
          [0.1100, 0.1200, 0.1300, 0.1400, 0.1500, 0.1600],
          [0.0700, 0.0800, 0.0900, 0.1000, 0.1100, 0.1200],
          [0.1200, 0.1300, 0.1400, 0.1500, 0.1600, 0.1700],
          [0.0900, 0.1000, 0.1100, 0.1200, 0.1300, 0.1400],
          [0.0100, 0.0200, 0.0300, 0.0400, 0.0500, 0.0600]]],
        gra
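The "plucking" of rows can be checked directly: an embedding lookup is equivalent to multiplying a one-hot encoding of the token IDs by the embedding table. The sketch below rebuilds the same `0.01*(1+i+j)` table and compares both paths on the first row of token IDs from above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, n_embd = 15, 6
wte = nn.Embedding(vocab_size, n_embd)
with torch.no_grad():
    rows = torch.arange(vocab_size).unsqueeze(1)
    cols = torch.arange(n_embd).unsqueeze(0)
    wte.weight.copy_(0.01 * (1 + rows + cols))  # W[i, j] = 0.01 * (1 + i + j)

ids = torch.tensor([[14, 13, 10, 1, 12, 0, 6, 11]])       # [B=1, T=8]
lookup = wte(ids)                                          # [1, 8, 6]: row ids[t] of the table
one_hot = F.one_hot(ids, num_classes=vocab_size).float()   # [1, 8, 15]
matmul = one_hot @ wte.weight                              # [1, 8, 15] @ [15, 6] -> [1, 8, 6]
print(torch.allclose(lookup, matmul))                      # True
print(lookup[0, 0])  # token 14 -> row 14: 0.15, 0.16, ..., 0.20
```

In practice the lookup is used because it skips multiplying by all the zeros in the one-hot vectors.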

#### Input Layer - Add Dimension

We projected our input tokens `x`, which had shape `[B×T]`, through the embedding to get `[B×T×C]`, so we now have a `T×C` matrix for each example. To run a 2-D convolution, though, we need the layout PyTorch expects: `nn.Conv2d` and `F.conv2d` take tensors shaped `[B, C, H, W]` (channels-first), where the kernel slides over `H,W` while mixing across all `C`. Because of this, we reorder the axes and add a singleton spatial dimension. With this process, the embedding dimension `C` becomes the channels and the token axis `T` becomes the width to slide across:

`[B, T, C]  →  [B, C, T]  →  [B, C, 1, T]`

The convolution we show is a `1×k` convolution which slides only along our tokens `T`, and aggregates over all `C` channels at each position.

In [17]:
x = x.permute(0,2,1) # [B,C,T]
x = x.unsqueeze(2)  # [B,C,1,T]
x.size(), x

(torch.Size([2, 6, 1, 8]),
 tensor([[[[0.1500, 0.1400, 0.1100, 0.0200, 0.1300, 0.0100, 0.0700, 0.1200]],
 
          [[0.1600, 0.1500, 0.1200, 0.0300, 0.1400, 0.0200, 0.0800, 0.1300]],
 
          [[0.1700, 0.1600, 0.1300, 0.0400, 0.1500, 0.0300, 0.0900, 0.1400]],
 
          [[0.1800, 0.1700, 0.1400, 0.0500, 0.1600, 0.0400, 0.1000, 0.1500]],
 
          [[0.1900, 0.1800, 0.1500, 0.0600, 0.1700, 0.0500, 0.1100, 0.1600]],
 
          [[0.2000, 0.1900, 0.1600, 0.0700, 0.1800, 0.0600, 0.1200, 0.1700]]],
 
 
         [[[0.0900, 0.0500, 0.1000, 0.1100, 0.0700, 0.1200, 0.0900, 0.0100]],
 
          [[0.1000, 0.0600, 0.1100, 0.1200, 0.0800, 0.1300, 0.1000, 0.0200]],
 
          [[0.1100, 0.0700, 0.1200, 0.1300, 0.0900, 0.1400, 0.1100, 0.0300]],
 
          [[0.1200, 0.0800, 0.1300, 0.1400, 0.1000, 0.1500, 0.1200, 0.0400]],
 
          [[0.1300, 0.0900, 0.1400, 0.1500, 0.1100, 0.1600, 0.1300, 0.0500]],
 
          [[0.1400, 0.1000, 0.1500, 0.1600, 0.1200, 0.1700, 0.1400, 0.0600]]]],
        gr
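Because the height is 1, a `1×k` 2-D convolution over `[B, C, 1, T]` should compute exactly what a 1-D convolution computes over `[B, C, T]`. The sketch below checks this with `F.conv1d` and `F.conv2d` on random data; the shapes mirror this notebook (`B=2`, `C=6`, `T=8`, `k=3`, padding 1), and the number of output channels is an arbitrary choice.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, T, k, C_out = 2, 6, 8, 3, 4
x_btc = torch.randn(B, T, C)    # [B, T, C] as produced by the embedding
x_bct = x_btc.permute(0, 2, 1)  # [B, C, T]
x_bc1t = x_bct.unsqueeze(2)     # [B, C, 1, T]

w1d = torch.randn(C_out, C, k)  # conv1d weight: [C_out, C_in, k]
w2d = w1d.unsqueeze(2)          # conv2d weight: [C_out, C_in, 1, k]

y1d = F.conv1d(x_bct, w1d, padding=1)        # [B, C_out, T]
y2d = F.conv2d(x_bc1t, w2d, padding=(0, 1))  # [B, C_out, 1, T]
print(y2d.shape, torch.allclose(y1d, y2d.squeeze(2), atol=1e-5))  # torch.Size([2, 4, 1, 8]) True
```

The 2-D form is used here because it generalizes directly to image-like inputs, where the height is larger than 1.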

### Convolution Block

<img src="explainer_screenshots/cnn/convolutional_layers.png" width="400">

As is common in CNNs, we use multiple convolution layers with normalization and nonlinearity to learn increasingly expressive features from the input. Each convolution “looks” at a local patch whose size and stride we choose, and stacking layers sequentially lets later layers see larger regions of the input. We also use a residual (skip) connection, which eases gradient flow and lets the block learn a correction on top of its input.

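In our model, the input to the convolution is $[B,C,1,T]$ with a $1\times k$ kernel, and the convolution runs as a 2-D discrete cross-correlation along the token axis. For output channel $m$ with stride $s$, padding $p$, and bias $b^{m}$,
$$
y^{m}_{t}=\sum_{c=1}^{C}\sum_{u=0}^{k-1} W^{m}_{c,u}\, x_{c,\,ts+u-p}+b^{m},
$$
where any index $ts+u-p$ that falls outside $[0, T-1]$ reads a zero from the padding.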
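To connect the formula to PyTorch, the sketch below evaluates $y^m_t$ with explicit loops, treating out-of-range indices $ts+u-p$ as zero padding, and compares the result with `F.conv2d` using the same stride $s$, padding $p$, and bias. The sizes and random values are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C, T, k, C_out, s, p = 3, 8, 3, 2, 2, 1
x = torch.randn(C, T)
W = torch.randn(C_out, C, k)
b = torch.randn(C_out)

T_out = (T + 2 * p - k) // s + 1
y = torch.zeros(C_out, T_out)
for m in range(C_out):
    for t in range(T_out):
        total = b[m].clone()
        for c in range(C):
            for u in range(k):
                idx = t * s + u - p  # input index for this kernel tap
                if 0 <= idx < T:     # outside the input = zero padding, contributes nothing
                    total += W[m, c, u] * x[c, idx]
        y[m, t] = total

y_ref = F.conv2d(x.view(1, C, 1, T), W.view(C_out, C, 1, k), b,
                 stride=(1, s), padding=(0, p)).view(C_out, T_out)
print(T_out, torch.allclose(y, y_ref, atol=1e-5))  # 4 True
```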

Under the hood we:

1. Build the matrix of local patches $P\in\mathbb{R}^{(Ck)\times L}$ by extracting all sliding $1\times k$ windows; $L$ is the number of output positions.
2. Flatten the kernel bank into $W_{\text{flat}}\in\mathbb{R}^{C_{\text{out}}\times (Ck)}$.
3. Compute all positions at once: $Y = W_{\text{flat}}\,P \in \mathbb{R}^{C_{\text{out}}\times L}$, independently for each batch element, then reshape back to $[B,C_{\text{out}},1,T_{\text{out}}]$.
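The three steps above can be sketched with `F.unfold`, PyTorch's built-in patch extractor. Assuming random input and weights with this notebook's shapes, a single batched matrix multiply reproduces `F.conv2d`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, T, k, C_out = 2, 6, 8, 3, 6
x = torch.randn(B, C, 1, T)      # input [B, C, 1, T]
W = torch.randn(C_out, C, 1, k)  # kernel bank [C_out, C, 1, k]

# 1. Patch matrix P: one column per output position -> [B, C*k, L]
P = F.unfold(x, kernel_size=(1, k), padding=(0, 1), stride=(1, 1))
# 2. Flatten the kernel bank -> [C_out, C*k] (same channel-then-kernel ordering as P)
W_flat = W.reshape(C_out, C * k)
# 3. Batched matmul (W_flat broadcasts over B), then reshape to [B, C_out, 1, T_out]
Y = (W_flat @ P).reshape(B, C_out, 1, -1)

print(P.shape, Y.shape)  # torch.Size([2, 18, 8]) torch.Size([2, 6, 1, 8])
print(torch.allclose(Y, F.conv2d(x, W, padding=(0, 1)), atol=1e-5))  # True
```

This "im2col" formulation trades extra memory (each input value is copied into up to `k` columns) for one large, fast matrix multiply.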

We interleave batch normalization and ReLU to stabilize activations, improve gradient flow, and add nonlinearity. 

The second convolution in the block downsamples with stride 2, reducing the token length $T\to\lceil T/2\rceil$. This both cuts compute and expands the effective receptive field of subsequent layers, helping the model capture longer-range patterns over the sequence.


Finally, as a nod to ResNets, the convolutional block also uses a residual path: we add a projected skip $S(x)$ to the main path $F(x)$, yielding $y=F(x)+S(x)$. Since the main path downsamples, the residual path uses a $1\times 1$ projection with stride 2 so that its output has the same shape as the main path's output.
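Here is a minimal sketch of such a block. It assumes the second convolution is also `1×3` with padding 1 (the text above only fixes its stride of 2), a final ReLU after the addition as in the original ResNet design, and no batch normalization on the skip path for brevity; the class name `ResidualConvBlock` is hypothetical. It shows how the `1×1`, stride-2 projection keeps $S(x)$ the same shape as $F(x)$.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualConvBlock(nn.Module):
    """Hypothetical block: y = ReLU(F(x) + S(x)), halving the token length."""

    def __init__(self, channels):
        super().__init__()
        # Main path F(x): 1x3 conv (stride 1) -> BN -> ReLU -> 1x3 conv (stride 2) -> BN
        self.conv1 = nn.Conv2d(channels, channels, (1, 3), stride=(1, 1), padding=(0, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, (1, 3), stride=(1, 2), padding=(0, 1), bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        # Skip path S(x): 1x1 projection with stride 2 so shapes match the main path
        self.skip = nn.Conv2d(channels, channels, (1, 1), stride=(1, 2), bias=False)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.skip(x))

block = ResidualConvBlock(channels=6)
x = torch.randn(2, 6, 1, 8)  # [B, C, 1, T]
print(block(x).shape)        # torch.Size([2, 6, 1, 4])
```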

#### Convolution Block - 1x3 Conv

##### 1x3 Conv - Initialize weights
Our first convolution layer uses a kernel size of `(1,3)`, a stride of `(1,1)`, and padding of 1 at both the start and end of the token dimension so that we can slide across every entry. For this first layer, we'll go step by step to show how the convolution is built.

To start, we configure the weights. Their shape is `[C_out, C_in, k_h, k_w]`: both channel dimensions currently equal our embedding dimension, and the last two match the kernel. Having a separate weight for every kernel position lets the layer learn which parts of the window matter most for our final prediction.

We'll also initialize the weights with a simple repeating pattern so that we can clearly see their impact as they interact with our input.

In [18]:
c1_kernel_height = 1
c1_kernel_width = 3
c1_stride_height = 1
c1_stride_width = 1
c1_padding_height = 0
c1_padding_width = 1
{'conv 1 kernel': (c1_kernel_height, c1_kernel_width),
 'conv 1 stride': (c1_stride_height, c1_stride_width),
 'conv 1 padding': (c1_padding_height, c1_padding_width)}


{'conv 1 kernel': (1, 3), 'conv 1 stride': (1, 1), 'conv 1 padding': (0, 1)}

In [19]:
## weight layer for convolution (similar to linear, just more explicit)
conv1 = nn.Parameter(
    torch.empty(n_embd, n_embd, c1_kernel_height, c1_kernel_width), 
    requires_grad=True)

In [20]:
# initialize every 1x3 kernel as [0.001, 0.002, 0.001] for an easier view of the weight impact
with torch.no_grad():
    c1_pattern = torch.tensor([0.001,0.002,0.001]).view(1,1,1,c1_kernel_width).expand(conv1.size()).clone()
    conv1.copy_(c1_pattern)
conv1.size(), conv1

(torch.Size([6, 6, 1, 3]),
 Parameter containing:
 tensor([[[[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]]],
 
 
         [[[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]]],
 
 
         [[[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]]],
 
 
         [[[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0.0020, 0.0010]],
 
          [[0.0010, 0

##### 1x3 Conv - Run convolution

Now we'll calculate the 2-D discrete cross-correlation of our weights and input `x`. Since our block has a residual connection, we'll branch `x` here and rejoin it after the convolutional block. For our convolutional layer, our step-by-step guide does the following:
1. Since we have padding, pad the token (width) dimension.
2. Flatten, or **unfold**, each sliding kernel-sized block within the spatial dimensions of the input into a column (i.e., along the last dimension) of a 3-D tensor of shape $(N,C\cdot k_h\cdot k_w,L)$.
3. Flatten the weights into $W_{flat}$ of shape $(C_{out}, C\cdot k_h\cdot k_w)$. The same weights are reused for every example and every position, so all of them contribute to learning.
4. Multiply the flattened weights by the unfolded input and reshape the result back to our batch and channels.

**1x3 Conv - Step-by-step unfolding**

In particular, we'll focus on step #2, since it creates the sliding, kernel-sized view across our input. Because each patch becomes a column, each entry of $W_{flat} \cdot X_{unfolded}$ is the dot product of one row of the weights (one output channel's kernel) with one column of the unfolded input (one patch). Mentally, **unfold** linearizes all local receptive fields so you can do per-patch operations with a single batched matrix multiply; a convolution is exactly this, with the same weights shared across every patch. After walking through it step by step, we'll introduce `F.unfold`, a function that does the padding and unfolding for you, and use it from there on out.

In [21]:
batch = B_batch
channel = n_embd
height = 1
width = T_context
x.size(),batch, channel, height, width, 

(torch.Size([2, 6, 1, 8]), 2, 6, 1, 8)

**Calculate expected unfolded dimensions**  

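To do our reshaping properly, we need to calculate the expected dimensions for our loop.
Recall that we expect to go from $(B,C,1,T)$ to $(B,C\cdot k_h\cdot k_w,L)$, where $L$ is the number of output positions, i.e. the output height times the output width:
$$
\begin{align}
height_{out} &= \left\lfloor \frac{height + 2\cdot pad_h - 1\cdot(kernel_h-1) - 1}{stride_{h}} \right\rfloor + 1\\
width_{out} &= \left\lfloor \frac{width + 2\cdot pad_w - 1\cdot(kernel_w-1) - 1}{stride_{w}} \right\rfloor + 1\\
L &= height_{out} \cdot width_{out}
\end{align}
$$
The factor of $1$ in front of $(kernel-1)$ is the dilation, which is 1 for all of our layers.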

In [22]:
c1_khw = channel*c1_kernel_height*c1_kernel_width

c1_height_out = (height + 2*c1_padding_height - 1*(c1_kernel_height-1) - 1)//c1_stride_height + 1   # = 1
c1_width_out = (width + 2*c1_padding_width - 1*(c1_kernel_width-1) - 1)//c1_stride_width + 1   # = 8
c1_L = c1_height_out * c1_width_out

print(f'width out {c1_width_out}, height out {c1_height_out}, final dimension ({batch},{c1_khw},{c1_L})')

width out 8, height out 1, final dimension (2,18,8)
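The formula is easy to reuse for other layers. The helper `conv_out_size` below is our own illustrative function (dilation defaults to 1, the `1\cdot` factor in the formula). It reproduces the sizes printed above and the stride-2 downsampling described earlier, and compares against the shape `F.conv2d` actually produces.

```python
import torch
import torch.nn.functional as F

def conv_out_size(size, kernel, stride, pad, dilation=1):
    """floor((size + 2*pad - dilation*(kernel-1) - 1) / stride) + 1"""
    return (size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

# conv 1: kernel (1,3), stride (1,1), padding (0,1) on a [*, *, 1, 8] input
h1 = conv_out_size(1, kernel=1, stride=1, pad=0)
w1 = conv_out_size(8, kernel=3, stride=1, pad=1)
# stride-2 downsampling with the same 1x3 kernel and padding
w2 = conv_out_size(8, kernel=3, stride=2, pad=1)
print(h1, w1, w2)  # 1 8 4

x = torch.randn(2, 6, 1, 8)
w = torch.randn(6, 6, 1, 3)
print(F.conv2d(x, w, stride=(1, 2), padding=(0, 1)).shape)  # torch.Size([2, 6, 1, 4])
```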


**Padding** 

We start by padding. Since we're using a `(1,3)` kernel with a stride of 1, we pad both the start and end of the tokens so that the kernel can be centered on every token and the output keeps all 8 positions. Padding simply adds `0`, though we could use other values if we wanted. When we pad on both sides we get an output of `[2, 6, 1+0, 8+2]`.

In [23]:
# pad last dim by (width, width) and 2nd to last by (height, height). width = 1, height = 0
c1_x_pad = F.pad(x, pad=(c1_padding_width,c1_padding_width,c1_padding_height,c1_padding_height))

c1_x_pad.size(), c1_x_pad # total size and the padded tensor

(torch.Size([2, 6, 1, 10]),
 tensor([[[[0.0000, 0.1500, 0.1400, 0.1100, 0.0200, 0.1300, 0.0100, 0.0700,
            0.1200, 0.0000]],
 
          [[0.0000, 0.1600, 0.1500, 0.1200, 0.0300, 0.1400, 0.0200, 0.0800,
            0.1300, 0.0000]],
 
          [[0.0000, 0.1700, 0.1600, 0.1300, 0.0400, 0.1500, 0.0300, 0.0900,
            0.1400, 0.0000]],
 
          [[0.0000, 0.1800, 0.1700, 0.1400, 0.0500, 0.1600, 0.0400, 0.1000,
            0.1500, 0.0000]],
 
          [[0.0000, 0.1900, 0.1800, 0.1500, 0.0600, 0.1700, 0.0500, 0.1100,
            0.1600, 0.0000]],
 
          [[0.0000, 0.2000, 0.1900, 0.1600, 0.0700, 0.1800, 0.0600, 0.1200,
            0.1700, 0.0000]]],
 
 
         [[[0.0000, 0.0900, 0.0500, 0.1000, 0.1100, 0.0700, 0.1200, 0.0900,
            0.0100, 0.0000]],
 
          [[0.0000, 0.1000, 0.0600, 0.1100, 0.1200, 0.0800, 0.1300, 0.1000,
            0.0200, 0.0000]],
 
          [[0.0000, 0.1100, 0.0700, 0.1200, 0.1300, 0.0900, 0.1400, 0.1100,
            0.0300, 0.0000]],

**Manual Unfolding - First Stride** 

Now we will manually unfold our padded input.  The process of unfolding flattens each sliding kernel-sized block within the spatial dimensions of input into a column (i.e., last dimension) of a 3-D output tensor of shape $(N,C*k_h*k_w,L)$ 

We'll start by pulling out the first patch. Since we have a kernel of `(1,3)`, we pull out the first 3 positions along the token dimension for every channel of every example in the batch.

In [24]:
step = 0
patch = c1_x_pad[:, :, 0:c1_kernel_height, step:step+c1_kernel_width]  # height is fixed; only the token axis moves
patch.size(), patch

(torch.Size([2, 6, 1, 3]),
 tensor([[[[0.0000, 0.1500, 0.1400]],
 
          [[0.0000, 0.1600, 0.1500]],
 
          [[0.0000, 0.1700, 0.1600]],
 
          [[0.0000, 0.1800, 0.1700]],
 
          [[0.0000, 0.1900, 0.1800]],
 
          [[0.0000, 0.2000, 0.1900]]],
 
 
         [[[0.0000, 0.0900, 0.0500]],
 
          [[0.0000, 0.1000, 0.0600]],
 
          [[0.0000, 0.1100, 0.0700]],
 
          [[0.0000, 0.1200, 0.0800]],
 
          [[0.0000, 0.1300, 0.0900]],
 
          [[0.0000, 0.1400, 0.1000]]]], grad_fn=<SliceBackward0>))

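Now we need to stack our channels together. We flatten each example's patch (every channel × kernel position) into a single vector of length $C\cdot k_h\cdot k_w = 18$. This uses the same (channel, kernel position) ordering as the flattened weights, so when we later take the dot product of the weights and the input, each weight lines up with the input value it multiplies.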

In [25]:
col = patch.reshape(batch, c1_khw)
col.size(), col

(torch.Size([2, 18]),
 tensor([[0.0000, 0.1500, 0.1400, 0.0000, 0.1600, 0.1500, 0.0000, 0.1700, 0.1600,
          0.0000, 0.1800, 0.1700, 0.0000, 0.1900, 0.1800, 0.0000, 0.2000, 0.1900],
         [0.0000, 0.0900, 0.0500, 0.0000, 0.1000, 0.0600, 0.0000, 0.1100, 0.0700,
          0.0000, 0.1200, 0.0800, 0.0000, 0.1300, 0.0900, 0.0000, 0.1400, 0.1000]],
        grad_fn=<UnsafeViewBackward0>))
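Because the patch is flattened in the same (channel, kernel position) order as the weights, one dot product per output channel gives one output value. The sketch below checks this on random data with this notebook's shapes: the first flattened patch times the flattened weights equals position 0 of `F.conv2d`'s output. The names are illustrative and independent of the cells above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, T, k = 2, 6, 8, 3
x = torch.randn(B, C, 1, T)
W = torch.randn(C, C, 1, k)                     # [C_out, C_in, 1, k]

x_pad = F.pad(x, (1, 1, 0, 0))                  # pad the token axis by 1 on each side
col = x_pad[:, :, 0:1, 0:k].reshape(B, C * k)   # first patch, flattened -> [B, 18]
W_flat = W.reshape(C, C * k)                    # same (channel, kernel) ordering -> [C_out, 18]
first_position = col @ W_flat.T                 # [B, C_out]

y = F.conv2d(x, W, padding=(0, 1))              # [B, C_out, 1, T]
print(torch.allclose(first_position, y[:, :, 0, 0], atol=1e-5))  # True
```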

Finally, we save this column, since it's just the first position of the patch. Let's store the columns in a list for now; after we complete all the strides, we can stack the list so that each flattened patch becomes a column of the unfolded output.

In [26]:
manual_cols = []
manual_cols.append(col)

**Manual Unfolding - Second Stride** 

We now need to move our patch by the stride amount, in this case `(1,1)`. Using a stride of 1 on both dimensions ensures that we continue covering every input token in the example; as you'll see in future convolutions, a larger stride downsamples the input. Let's start by again extracting the patch. You'll see that we just shifted one position to the right and took the next 3 columns of our input.

In [27]:
step = 1
patch = c1_x_pad[:, :, 0:c1_kernel_height, step:step+c1_kernel_width]
patch.size(), patch

(torch.Size([2, 6, 1, 3]),
 tensor([[[[0.1500, 0.1400, 0.1100]],
 
          [[0.1600, 0.1500, 0.1200]],
 
          [[0.1700, 0.1600, 0.1300]],
 
          [[0.1800, 0.1700, 0.1400]],
 
          [[0.1900, 0.1800, 0.1500]],
 
          [[0.2000, 0.1900, 0.1600]]],
 
 
         [[[0.0900, 0.0500, 0.1000]],
 
          [[0.1000, 0.0600, 0.1100]],
 
          [[0.1100, 0.0700, 0.1200]],
 
          [[0.1200, 0.0800, 0.1300]],
 
          [[0.1300, 0.0900, 0.1400]],
 
          [[0.1400, 0.1000, 0.1500]]]], grad_fn=<SliceBackward0>))

We'll flatten this the same way as before:

In [28]:
col = patch.reshape(batch, c1_khw)
col.size(), col

(torch.Size([2, 18]),
 tensor([[0.1500, 0.1400, 0.1100, 0.1600, 0.1500, 0.1200, 0.1700, 0.1600, 0.1300,
          0.1800, 0.1700, 0.1400, 0.1900, 0.1800, 0.1500, 0.2000, 0.1900, 0.1600],
         [0.0900, 0.0500, 0.1000, 0.1000, 0.0600, 0.1100, 0.1100, 0.0700, 0.1200,
          0.1200, 0.0800, 0.1300, 0.1300, 0.0900, 0.1400, 0.1400, 0.1000, 0.1500]],
        grad_fn=<UnsafeViewBackward0>))

Then we add it to our list. The list now holds the flattened patches for the first two steps:

In [29]:
manual_cols.append(col)
manual_cols

[tensor([[0.0000, 0.1500, 0.1400, 0.0000, 0.1600, 0.1500, 0.0000, 0.1700, 0.1600,
          0.0000, 0.1800, 0.1700, 0.0000, 0.1900, 0.1800, 0.0000, 0.2000, 0.1900],
         [0.0000, 0.0900, 0.0500, 0.0000, 0.1000, 0.0600, 0.0000, 0.1100, 0.0700,
          0.0000, 0.1200, 0.0800, 0.0000, 0.1300, 0.0900, 0.0000, 0.1400, 0.1000]],
        grad_fn=<UnsafeViewBackward0>),
 tensor([[0.1500, 0.1400, 0.1100, 0.1600, 0.1500, 0.1200, 0.1700, 0.1600, 0.1300,
          0.1800, 0.1700, 0.1400, 0.1900, 0.1800, 0.1500, 0.2000, 0.1900, 0.1600],
         [0.0900, 0.0500, 0.1000, 0.1000, 0.0600, 0.1100, 0.1100, 0.0700, 0.1200,
          0.1200, 0.0800, 0.1300, 0.1300, 0.0900, 0.1400, 0.1400, 0.1000, 0.1500]],
        grad_fn=<UnsafeViewBackward0>)]

**Manual Unfolding - Remaining Strides** 

We'll now loop through the remaining steps to fill in the rest of the list. Each iteration repeats the same extraction and flattening as above and appends to the same list. We start from step 2, since steps 0 and 1 are already done.

In [30]:
for step in range(2,c1_width_out): 
    print(f'executing stride {step}')
    # extract the window for this step
    patch = c1_x_pad[:, :, 0:c1_kernel_height, step:step+c1_kernel_width]        # (2,6,1,3)
    
    # flatten each batch entry's patch into a single row
    col = patch.reshape(batch, c1_khw) # shape to [2,18]

    manual_cols.append(col)

manual_cols

executing stride 2
executing stride 3
executing stride 4
executing stride 5
executing stride 6
executing stride 7


[tensor([[0.0000, 0.1500, 0.1400, 0.0000, 0.1600, 0.1500, 0.0000, 0.1700, 0.1600,
          0.0000, 0.1800, 0.1700, 0.0000, 0.1900, 0.1800, 0.0000, 0.2000, 0.1900],
         [0.0000, 0.0900, 0.0500, 0.0000, 0.1000, 0.0600, 0.0000, 0.1100, 0.0700,
          0.0000, 0.1200, 0.0800, 0.0000, 0.1300, 0.0900, 0.0000, 0.1400, 0.1000]],
        grad_fn=<UnsafeViewBackward0>),
 tensor([[0.1500, 0.1400, 0.1100, 0.1600, 0.1500, 0.1200, 0.1700, 0.1600, 0.1300,
          0.1800, 0.1700, 0.1400, 0.1900, 0.1800, 0.1500, 0.2000, 0.1900, 0.1600],
         [0.0900, 0.0500, 0.1000, 0.1000, 0.0600, 0.1100, 0.1100, 0.0700, 0.1200,
          0.1200, 0.0800, 0.1300, 0.1300, 0.0900, 0.1400, 0.1400, 0.1000, 0.1500]],
        grad_fn=<UnsafeViewBackward0>),
 tensor([[0.1400, 0.1100, 0.0200, 0.1500, 0.1200, 0.0300, 0.1600, 0.1300, 0.0400,
          0.1700, 0.1400, 0.0500, 0.1800, 0.1500, 0.0600, 0.1900, 0.1600, 0.0700],
         [0.0500, 0.1000, 0.1100, 0.0600, 0.1100, 0.1200, 0.0700, 0.1200, 0.1300,
          0

**Manual Unfolding - Flatten List** 

Now that we've extracted every patch, we have a list of eight `[2,18]` tensors, one per window position. We want a single tensor that keeps the batch of 2 and turns each 18-long row into a column. Stacking along `dim=2` does exactly that and gives a `(2,18,8)` tensor, matching the shape we calculated.

In [31]:
# turn all the rows in the list into columns while maintaining the batch
manual_unfold = torch.stack(manual_cols, dim=2)  # (N, 18, 8)
manual_unfold.size(), manual_unfold

(torch.Size([2, 18, 8]),
 tensor([[[0.0000, 0.1500, 0.1400, 0.1100, 0.0200, 0.1300, 0.0100, 0.0700],
          [0.1500, 0.1400, 0.1100, 0.0200, 0.1300, 0.0100, 0.0700, 0.1200],
          [0.1400, 0.1100, 0.0200, 0.1300, 0.0100, 0.0700, 0.1200, 0.0000],
          [0.0000, 0.1600, 0.1500, 0.1200, 0.0300, 0.1400, 0.0200, 0.0800],
          [0.1600, 0.1500, 0.1200, 0.0300, 0.1400, 0.0200, 0.0800, 0.1300],
          [0.1500, 0.1200, 0.0300, 0.1400, 0.0200, 0.0800, 0.1300, 0.0000],
          [0.0000, 0.1700, 0.1600, 0.1300, 0.0400, 0.1500, 0.0300, 0.0900],
          [0.1700, 0.1600, 0.1300, 0.0400, 0.1500, 0.0300, 0.0900, 0.1400],
          [0.1600, 0.1300, 0.0400, 0.1500, 0.0300, 0.0900, 0.1400, 0.0000],
          [0.0000, 0.1800, 0.1700, 0.1400, 0.0500, 0.1600, 0.0400, 0.1000],
          [0.1800, 0.1700, 0.1400, 0.0500, 0.1600, 0.0400, 0.1000, 0.1500],
          [0.1700, 0.1400, 0.0500, 0.1600, 0.0400, 0.1000, 0.1500, 0.0000],
          [0.0000, 0.1900, 0.1800, 0.1500, 0.0600, 0.1700, 0.05

**Unfolding - Efficiently**

While the manual version is useful for demonstration, it is slow and verbose. PyTorch provides `F.unfold` (`torch.nn.functional.unfold`), which performs the same steps in one call: padding, patch extraction, flattening, and stacking.

Let's set up the unfold of the original input `x`. We'll also compare the result with the earlier `manual_unfold` to confirm they are equal, so we can use `F.unfold` going forward.

In [32]:
c1_unfolded = F.unfold(x, 
		kernel_size=(c1_kernel_height, c1_kernel_width),  # (1,3)
		padding=(c1_padding_height, c1_padding_width), #(0,1)
		stride=(c1_stride_height, c1_stride_width))#(1,1)

print("manual equals unfold:", torch.allclose(c1_unfolded, manual_unfold))
c1_unfolded.size() , c1_unfolded

manual equals unfold: True


(torch.Size([2, 18, 8]),
 tensor([[[0.0000, 0.1500, 0.1400, 0.1100, 0.0200, 0.1300, 0.0100, 0.0700],
          [0.1500, 0.1400, 0.1100, 0.0200, 0.1300, 0.0100, 0.0700, 0.1200],
          [0.1400, 0.1100, 0.0200, 0.1300, 0.0100, 0.0700, 0.1200, 0.0000],
          [0.0000, 0.1600, 0.1500, 0.1200, 0.0300, 0.1400, 0.0200, 0.0800],
          [0.1600, 0.1500, 0.1200, 0.0300, 0.1400, 0.0200, 0.0800, 0.1300],
          [0.1500, 0.1200, 0.0300, 0.1400, 0.0200, 0.0800, 0.1300, 0.0000],
          [0.0000, 0.1700, 0.1600, 0.1300, 0.0400, 0.1500, 0.0300, 0.0900],
          [0.1700, 0.1600, 0.1300, 0.0400, 0.1500, 0.0300, 0.0900, 0.1400],
          [0.1600, 0.1300, 0.0400, 0.1500, 0.0300, 0.0900, 0.1400, 0.0000],
          [0.0000, 0.1800, 0.1700, 0.1400, 0.0500, 0.1600, 0.0400, 0.1000],
          [0.1800, 0.1700, 0.1400, 0.0500, 0.1600, 0.0400, 0.1000, 0.1500],
          [0.1700, 0.1400, 0.0500, 0.1600, 0.0400, 0.1000, 0.1500, 0.0000],
          [0.0000, 0.1900, 0.1800, 0.1500, 0.0600, 0.1700, 0.05
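The four steps `F.unfold` performs can be packaged into one function. This is a sketch, not PyTorch's implementation: `manual_unfold_1xk` is a hypothetical helper that assumes a height-1 input of shape `[B, C, 1, T]`, a `(1, k)` kernel, and zero padding only along the token axis. It repeats the padding, slicing, flattening, and stacking from the manual cells, with each window starting at `step * stride`, and checks the result against `F.unfold` for strides 1 and 2:

```python
import torch
import torch.nn.functional as F

def manual_unfold_1xk(x: torch.Tensor, k: int, stride: int, pad: int) -> torch.Tensor:
    """Unfold a [B, C, 1, T] tensor with a (1, k) kernel, mirroring the manual steps.

    Assumes kernel height 1 and zero padding only along the token (width) axis.
    Returns a [B, C*k, L] tensor, the same layout F.unfold produces.
    """
    B, C, H, T = x.shape
    assert H == 1, "this sketch only handles height-1 inputs"
    x_pad = F.pad(x, (pad, pad))                  # zero-pad the last (token) dim on both sides
    L = (T + 2 * pad - k) // stride + 1           # number of window positions
    cols = []
    for step in range(L):
        start = step * stride                     # the stride sets where each window begins
        patch = x_pad[:, :, :, start:start + k]   # [B, C, 1, k]
        cols.append(patch.reshape(B, C * k))      # flatten channel-major into one row
    return torch.stack(cols, dim=2)               # rows become columns: [B, C*k, L]

x_demo = torch.randn(2, 6, 1, 8)
for s in (1, 2):
    ref_cols = F.unfold(x_demo, kernel_size=(1, 3), padding=(0, 1), stride=(1, s))
    my_cols = manual_unfold_1xk(x_demo, k=3, stride=s, pad=1)
    print(s, tuple(my_cols.shape), torch.allclose(my_cols, ref_cols))
# 1 (2, 18, 8) True
# 2 (2, 18, 4) True
```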

##### 1x3 Conv - $W\cdot X_{unfolded}$

Now that we have the unfolded patches, the network can learn how much each part of a patch influences the output. To do this we multiply the weight by the unfolded input. The shapes don't line up yet, though: the weight is `[6,6,1,3]` while the unfolded input is `[2,18,8]`. We fix this by flattening the last three dimensions of the weight (input channels × kernel height × kernel width = 6 × 1 × 3 = 18), giving a `[6,18]` matrix whose columns line up with the 18 rows of each unfolded input.

You might now ask, "what about the batch dimension of 2?" Both sequences in the batch should share the same weights, so we don't add a batch dimension to the weight. Instead we rely on PyTorch's matrix-multiplication broadcasting: a `[6,18]` matrix multiplied by a `[2,18,8]` tensor applies the same matrix to each batch entry, producing `[2,6,8]`.
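The reshape-and-broadcast trick can be checked against PyTorch's built-in convolution. In this sketch the weights are random (`x_demo` and `w_demo` are illustrative stand-ins for `x` and `conv1`); the unfolded input multiplied by the flattened weight should match `F.conv2d`, which computes the same cross-correlation:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x_demo = torch.randn(2, 6, 1, 8)            # [B, C_in, 1, T]
w_demo = torch.randn(6, 6, 1, 3)            # [C_out, C_in, k_h, k_w]

cols_demo = F.unfold(x_demo, kernel_size=(1, 3), padding=(0, 1), stride=(1, 1))  # [2, 18, 8]
w_flat_demo = w_demo.view(6, -1)            # [6, 18], same channel-major order as the unfold rows

# [6, 18] @ [2, 18, 8]: the single weight matrix is broadcast over the batch -> [2, 6, 8]
y_demo = (w_flat_demo @ cols_demo).reshape(2, 6, 1, 8)

y_ref = F.conv2d(x_demo, w_demo, padding=(0, 1))   # PyTorch's built-in conv (a cross-correlation)
print(tuple(y_demo.shape), torch.allclose(y_demo, y_ref, atol=1e-5))   # (2, 6, 1, 8) True
```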

In [33]:
conv1_weigth = conv1.view(n_embd, -1) # [6,6,1,3] > [6,18]
conv1_weigth.size(), conv1_weigth

(torch.Size([6, 18]),
 tensor([[0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010,
          0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010],
         [0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010,
          0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010],
         [0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010,
          0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010],
         [0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010,
          0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010],
         [0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010,
          0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010],
         [0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010,
          0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.002

With the weights reshaped, we can multiply them by the unfolded input. Because every row of our weight matrix is identical, all six output channels come out identical, so each column of the result repeats one value down its rows. The 18-long dimension is summed away in the product, leaving an output of `[2,6,8]`, and the batch dimension is preserved because the weight is broadcast across the batch.

In [34]:
out = conv1_weigth @ c1_unfolded
out.size(), out

(torch.Size([2, 6, 8]),
 tensor([[[0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023],
          [0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023],
          [0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023],
          [0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023],
          [0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023],
          [0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023]],
 
         [[0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0025, 0.0011],
          [0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0025, 0.0011],
          [0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0025, 0.0011],
          [0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0025, 0.0011],
          [0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0025, 0.0011],
          [0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0025, 0.0011]]],
        grad_fn=<CloneBackward0>))

Finally we reshape the last dimension back into the output height and width, `(1, 8)`. Since the height is 1, this just inserts a dimension of size 1 and the values look the same.

In [35]:
out = out.view(batch,n_embd, c1_height_out, c1_width_out)
out.size(), out

(torch.Size([2, 6, 1, 8]),
 tensor([[[[0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023]],
 
          [[0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023]],
 
          [[0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023]],
 
          [[0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023]],
 
          [[0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023]],
 
          [[0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023]]],
 
 
         [[[0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0025, 0.0011]],
 
          [[0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0025, 0.0011]],
 
          [[0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0025, 0.0011]],
 
          [[0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0025, 0.0011]],
 
          [[0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0025, 0.0011]],
 
          [[0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0025, 0.0011]]]],
        gr

#### Convolution Block - First Batch Norm

With 2D batch normalization, we normalize each channel using statistics computed over the current mini-batch and all spatial positions. Concretely, for inputs of shape $[B,C,H,T]$, batch normalization computes a per-channel mean and variance over the $B$, $H$, and $T$ dimensions and standardizes that channel. This helps stabilize activations, acts as a mild regularizer, and speeds up training. If an entire channel is scaled uniformly across the batch and spatial locations, that scale is largely removed by the normalization: ignoring the small constant $\epsilon$, the sequences `[1,2,3,4]` and `[2,4,6,8]` normalize to the same values.

Batch normalization applies the following:
$$
y_{b,c,h,t}=\gamma_c \frac{x_{b,c,h,t}-\mu_c}{\sqrt{\sigma_c^2+\epsilon}}+\beta_c,
\quad
\mu_c=\mathbb{E}_{b,h,t}[x_{b,c,h,t}],\quad
\sigma_c^2=\operatorname{Var}_{b,h,t}[x_{b,c,h,t}].
$$

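Batch normalization is applied per channel on the feature map. At initialization, `nn.BatchNorm2d` sets its affine parameters to $\gamma_c=1$ and $\beta_c=0$, so the layer applies no extra scaling or shift. We'll keep this initialization; during training these parameters are updated from the gradient like any other weight. Because all six input channels are identical and share the same initial $\gamma_c$ and $\beta_c$, the six output channels stay identical, but the values now span both positive and negative numbers because they measure standardized distance from the channel mean over the batch and token positions. Note that the outputs are well under 1 in magnitude: our activations are around 0.002, so the batch variance (on the order of $10^{-7}$) is much smaller than PyTorch's default $\epsilon=10^{-5}$, and $\epsilon$ dominates the denominator.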

*Note that even though we apply batch normalization in more than one place, each placement gets its own `nn.BatchNorm2d` instance with its own $\gamma$, $\beta$, and running statistics, so each can be tuned independently.*
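The batch normalization formula can be checked by hand. This sketch (with a hypothetical `manual_bn` helper and random input) reduces over dimensions 0, 2, and 3 using the biased variance, which is what `nn.BatchNorm2d` does in training mode with its default $\epsilon = 10^{-5}$. It also shows that the scale invariance described above holds only when the variance is large compared with $\epsilon$:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x_demo = torch.randn(2, 6, 1, 8)                 # [B, C, H, T]
bn_demo = nn.BatchNorm2d(6)                      # gamma=1, beta=0, eps=1e-5 by default
bn_demo.train()                                  # training mode normalizes with batch statistics

def manual_bn(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # one mean/variance per channel, reduced over batch (0), height (2), and tokens (3)
    mu = x.mean(dim=(0, 2, 3), keepdim=True)
    var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)   # biased variance, as BatchNorm uses
    return (x - mu) / torch.sqrt(var + eps)                    # gamma=1, beta=0 omitted

print(torch.allclose(bn_demo(x_demo), manual_bn(x_demo), atol=1e-5))   # True

# Scaling by 2 barely changes the output when the variance dwarfs eps...
print(torch.allclose(manual_bn(x_demo), manual_bn(2 * x_demo), atol=1e-3))   # True
# ...but not for activations around 1e-3, like the convolution outputs above
small_demo = 1e-3 * x_demo
print(torch.allclose(manual_bn(small_demo), manual_bn(2 * small_demo), atol=1e-3))   # False
```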

In [36]:
bn_a = nn.BatchNorm2d(n_embd)
bn_a.weight, bn_a.bias

(Parameter containing:
 tensor([1., 1., 1., 1., 1., 1.], requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0., 0., 0., 0.], requires_grad=True))

In [37]:
out = bn_a(out)
out.size(), out

(torch.Size([2, 6, 1, 8]),
 tensor([[[[ 0.1805,  0.4135,  0.1153, -0.0711, -0.0524, -0.1829, -0.0897,
            -0.0617]],
 
          [[ 0.1805,  0.4135,  0.1153, -0.0711, -0.0524, -0.1829, -0.0897,
            -0.0617]],
 
          [[ 0.1805,  0.4135,  0.1153, -0.0711, -0.0524, -0.1829, -0.0897,
            -0.0617]],
 
          [[ 0.1805,  0.4135,  0.1153, -0.0711, -0.0524, -0.1829, -0.0897,
            -0.0617]],
 
          [[ 0.1805,  0.4135,  0.1153, -0.0711, -0.0524, -0.1829, -0.0897,
            -0.0617]],
 
          [[ 0.1805,  0.4135,  0.1153, -0.0711, -0.0524, -0.1829, -0.0897,
            -0.0617]]],
 
 
         [[[-0.2108, -0.0524,  0.0780,  0.1340,  0.0967,  0.1526, -0.0151,
            -0.4345]],
 
          [[-0.2108, -0.0524,  0.0780,  0.1340,  0.0967,  0.1526, -0.0151,
            -0.4345]],
 
          [[-0.2108, -0.0524,  0.0780,  0.1340,  0.0967,  0.1526, -0.0151,
            -0.4345]],
 
          [[-0.2108, -0.0524,  0.0780,  0.1340,  0.0967,  0.1526, -0.0

#### Convolution Block - ReLU

Next we apply the rectified linear unit (**ReLU**), an element-wise nonlinearity that keeps positive activations and zeroes out negative ones. It works on any input shape (e.g., $[B,C,H,T]$) and is commonly used after convolution and normalization to introduce sparsity and let the network model nonlinear relationships, while keeping stronger gradient flow than saturating activations such as sigmoid or tanh. Zeroing out negative responses can act as an implicit regularizer, making features more selective and potentially reducing overfitting. With sparsity-aware implementations, fewer nonzero values can also lower the effective compute and memory footprint. ReLU has no learnable parameters, so we apply it with `F.relu` instead of creating a layer for it.

The formula applied is:
$$
y = \max(0, x),
\quad
\frac{\partial y}{\partial x}=
\begin{cases}
1,& x>0\\
0,& x\leq 0
\end{cases}
$$

In CNN blocks, ReLU is typically placed after BatchNorm (e.g., Conv $\rightarrow$ BN $\rightarrow$ ReLU). In pre-activation designs, you may see BN $\rightarrow$ ReLU $\rightarrow$ Conv. A practical caveat is the “dying ReLU” phenomenon (units stuck at zero); if this becomes an issue, consider alternatives like LeakyReLU, ELU, or GELU.

*Keeping ReLU as its own step means it can be swapped (for example, for LeakyReLU) or configured (e.g., in-place vs. out-of-place) without changing the other components.*
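A short sketch of the ReLU formula and its gradient using autograd, with illustrative tensors `z` and `z2`. PyTorch uses 0 as the gradient at exactly $x=0$, matching the case split above. For comparison, `F.leaky_relu` keeps a small negative slope, which is how it avoids the dying-ReLU problem:

```python
import torch
import torch.nn.functional as F

z = torch.tensor([-1.0, 0.0, 0.5, 2.0], requires_grad=True)
relu_z = F.relu(z)                     # negatives (and 0) become 0
relu_z.sum().backward()
print(relu_z.detach().tolist())        # [0.0, 0.0, 0.5, 2.0]
print(z.grad.tolist())                 # [0.0, 0.0, 1.0, 1.0]: gradient is 0 for x <= 0

# LeakyReLU keeps a small slope for negative inputs, so those units still receive gradient
z2 = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
F.leaky_relu(z2, negative_slope=0.01).sum().backward()
print(z2.grad.tolist())                # approximately [0.01, 1.0, 1.0]
```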

In [38]:
out = F.relu(out) 
out.size(), out

(torch.Size([2, 6, 1, 8]),
 tensor([[[[0.1805, 0.4135, 0.1153, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
 
          [[0.1805, 0.4135, 0.1153, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
 
          [[0.1805, 0.4135, 0.1153, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
 
          [[0.1805, 0.4135, 0.1153, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
 
          [[0.1805, 0.4135, 0.1153, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
 
          [[0.1805, 0.4135, 0.1153, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
 
 
         [[[0.0000, 0.0000, 0.0780, 0.1340, 0.0967, 0.1526, 0.0000, 0.0000]],
 
          [[0.0000, 0.0000, 0.0780, 0.1340, 0.0967, 0.1526, 0.0000, 0.0000]],
 
          [[0.0000, 0.0000, 0.0780, 0.1340, 0.0967, 0.1526, 0.0000, 0.0000]],
 
          [[0.0000, 0.0000, 0.0780, 0.1340, 0.0967, 0.1526, 0.0000, 0.0000]],
 
          [[0.0000, 0.0000, 0.0780, 0.1340, 0.0967, 0.1526, 0.0000, 0.0000]],
 
          [[0.0000, 0.0000, 0.0780, 0.1340, 0.0967, 0.1526, 0.0000, 0.0000]]]],
        gr

#### Convolution Block - 1x3 Conv 1x2 Stride downsample

##### 1x3 Conv 2 Stride - Initialize weights
Our second convolution uses a kernel of `(1,3)`, a stride of `(1,2)`, and one column of padding at the start and end of the token dimension so the kernel can slide across every entry. The `(1,2)` stride downsamples the token dimension from 8 to 4, halving the output width. Even with a stride of 2, the kernel width of 3 means every input position is still seen by at least one window, and the learned weights determine which part of each patch most affects the output.
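To check the claim that every input position is still seen with stride 2, this sketch lists which original token indices each window covers for $T=8$, kernel width 3, padding 1, and stride 2 (the variable names are illustrative):

```python
n_tokens, k_w, s_w, p_w = 8, 3, 2, 1                 # the second convolution's settings
n_out = (n_tokens + 2 * p_w - k_w) // s_w + 1        # number of output positions -> 4

windows = []
for s in range(n_out):
    start = s * s_w - p_w                            # window start in original (unpadded) indices
    # keep only real tokens; index -1 on the left is zero padding
    windows.append([t for t in range(start, start + k_w) if 0 <= t < n_tokens])

print(windows)                                       # [[0, 1], [1, 2, 3], [3, 4, 5], [5, 6, 7]]
covered = sorted({t for win in windows for t in win})
print(covered == list(range(n_tokens)))              # True: every token lands in some window
```

Tokens 1, 3, and 5 sit on window boundaries and are seen twice, while every other token is seen once.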

To start, we create the weight tensor with shape `[out_channels, in_channels, k_h, k_w]`, here `[6,6,1,3]` since both channel counts equal our embedding size. Having one weight per input channel and kernel position lets the layer learn which positions within the window matter most for the final prediction.

We'll also initialize the weights with a simple repeating pattern, `[0.002, 0.001, 0.001]`, so we can see clearly how they interact with the input.

In [39]:
c2_kernel_height = 1
c2_kernel_width = 3
c2_stride_height = 1
c2_stride_width = 2
c2_padding_height = 0
c2_padding_width = 1
{'kernel': (c2_kernel_height, c2_kernel_width),
 'stride': (c2_stride_height, c2_stride_width),
 'padding': (c2_padding_height, c2_padding_width)}

{'kernel': (1, 3), 'stride': (1, 2), 'padding': (0, 1)}

In [40]:
## weight layer for convolution (similar to linear, just more explicit)
conv2 = nn.Parameter(
    torch.empty(n_embd, n_embd, c2_kernel_height, c2_kernel_width), 
    requires_grad=True)

In [41]:
# initialize every kernel row as [0.002, 0.001, 0.001] to make the weights' effect easy to see
with torch.no_grad():
    c2_pattern = torch.tensor([0.002,0.001,0.001]).view(1,1,1,c2_kernel_width).expand(conv2.size()).clone()
    conv2.copy_(c2_pattern)
conv2.size(), conv2

(torch.Size([6, 6, 1, 3]),
 Parameter containing:
 tensor([[[[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]]],
 
 
         [[[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]]],
 
 
         [[[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]]],
 
 
         [[[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0.0010, 0.0010]],
 
          [[0.0020, 0

##### 1x3 Conv 2 Stride - Run convolution

Now we'll calculate the 2-D discrete cross-correlation for our weight and ReLU output `out`. For our convolutional layer, we won't go through step by step and instead do the following: 
1. Unfold each sliding kernel_size-sized block within the spatial dimensions of input into a column (i.e., last dimension) of a 3-D output tensor of shape $(B,C*k_h*k_w,L)$.  This is the step that does our downsampling by using the `(1,2)` stride.
2. Reshape the weights into a $(C_{out}, C*k_h*k_w)$ matrix so their columns line up with the rows of the unfolded input and the same matrix can be reused across the batch.
3. Multiply the reshaped weights by the unfolded input and reshape the result back to our batch, channels, and output height and width.


*Let's start by checking the dimensions of our input. We deliberately reuse the existing `batch`, `channel`, `height`, and `width` variables instead of recomputing them, so later steps will error if the input doesn't match expectations.*

In [42]:
out.size(), batch, channel, height, width, 

(torch.Size([2, 6, 1, 8]), 2, 6, 1, 8)

**Calculate expected unfolded dimensions**  

To reshape this layer's output, we need the expected unfolded dimensions.
Recall that we expect to go from $(B,C,1,T)$ to $(B,C*k_h*k_w,L)$, where $L$ is the flattened product of the output height and width:
$$
\begin{align}
height_{out} &= (height + 2*pad_h - 1*(kernel_h-1) -1)\ //\ stride_{h} + 1\\ 
width_{out} &= (width + 2*pad_w - 1*(kernel_w-1) -1)\ //\ stride_{w} + 1\\
L &= height_{out} * width_{out}
\end{align}
$$

The main dimension that changes here is $L$. Since $width_{out}$ comes from dividing the padded width by $stride_{w}$, a stride of $stride_{w} = 2$ roughly halves the output width: $(8 + 2 - 2 - 1)\ //\ 2 + 1 = 4$, compared with 8 for the first convolution. This is **downsampling** in action.
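The same size arithmetic can be written as a small function. `conv_out_size` is a hypothetical helper using the same expression as `c2_width_out` in the next cell (floor division, then +1), with dilation defaulting to 1. It also shows that for an odd width the stride-2 output rounds up rather than exactly halving:

```python
def conv_out_size(size: int, k: int, stride: int, pad: int, dilation: int = 1) -> int:
    """Output length along one axis: (size + 2*pad - dilation*(k-1) - 1) // stride + 1."""
    return (size + 2 * pad - dilation * (k - 1) - 1) // stride + 1

print(conv_out_size(8, k=3, stride=1, pad=1))   # 8: first convolution keeps the width
print(conv_out_size(8, k=3, stride=2, pad=1))   # 4: stride 2 halves it
print(conv_out_size(9, k=3, stride=2, pad=1))   # 5: an odd width rounds up, ceil(9 / 2)
print(conv_out_size(1, k=1, stride=1, pad=0))   # 1: the height axis is unchanged
```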

In [43]:
c2_khw = channel*c2_kernel_height*c2_kernel_width

c2_height_out = (height + 2*c2_padding_height - 1*(c2_kernel_height-1) - 1)//c2_stride_height + 1   # = 1, 
c2_width_out = (width + 2*c2_padding_width - 1*(c2_kernel_width-1) - 1)//c2_stride_width + 1   # = 4
c2_L = c2_height_out * c2_width_out

print(f'First Conv: width out {c1_width_out}, height out {c1_height_out}, final dimension ({batch},{c1_khw},{c1_L})')
print(f'This  Conv: width out {c2_width_out}, height out {c2_height_out}, final dimension ({batch},{c2_khw},{c2_L})')

First Conv: width out 8, height out 1, final dimension (2,18,8)
This  Conv: width out 4, height out 1, final dimension (2,18,4)


In [44]:
out

tensor([[[[0.1805, 0.4135, 0.1153, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

         [[0.1805, 0.4135, 0.1153, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

         [[0.1805, 0.4135, 0.1153, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

         [[0.1805, 0.4135, 0.1153, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

         [[0.1805, 0.4135, 0.1153, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

         [[0.1805, 0.4135, 0.1153, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],


        [[[0.0000, 0.0000, 0.0780, 0.1340, 0.0967, 0.1526, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0780, 0.1340, 0.0967, 0.1526, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0780, 0.1340, 0.0967, 0.1526, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0780, 0.1340, 0.0967, 0.1526, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0780, 0.1340, 0.0967, 0.1526, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0780, 0.1340, 0.0967, 0.1526, 0.0000, 0.0000]]]],
       grad_fn=<ReluBackward0>)

**Unfold**  Now we do the actual unfolding. This time we run the single `F.unfold` call, and the output is "narrower": 4 window positions instead of 8. Notice that the ReLU step produced runs of zeros wherever batch normalization had output negative values, and even with our stride of $(1,2)$ some of those zeros persist as entire columns after unfolding. Why is a whole column zero? Each column holds one window of width 3 across all channels, so it is zero only when 3 consecutive input positions (counting padding) are all zero. Checking the ReLU output, that is exactly what happened. Since our inputs are non-negative and our weights are positive, we can expect those output positions to be 0 after we multiply by the weights.
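The all-zero windows can be found programmatically. This sketch copies one channel of the rounded ReLU output printed above (every channel is identical) into `relu_demo`, unfolds it with the same kernel, padding, and stride, and flags windows whose three entries are all zero. It then applies the `[0.002, 0.001, 0.001]` kernel to confirm those are exactly the positions that come out 0:

```python
import torch
import torch.nn.functional as F

# one channel of the ReLU output per sequence (rounded values copied from the cell above)
relu_demo = torch.tensor([
    [0.1805, 0.4135, 0.1153, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0780, 0.1340, 0.0967, 0.1526, 0.0, 0.0],
]).view(2, 1, 1, 8)                                  # [B, C=1, 1, T]

cols_demo = F.unfold(relu_demo, kernel_size=(1, 3), padding=(0, 1), stride=(1, 2))   # [2, 3, 4]

# a window (column) is all zero exactly when its 3 entries are zero
zero_windows = (cols_demo == 0).all(dim=1)           # [2, 4] booleans
print(zero_windows.tolist())   # [[False, False, True, True], [True, False, False, False]]

# with non-negative inputs and positive weights, the output is 0 exactly at those windows
w_demo = torch.tensor([0.002, 0.001, 0.001]).view(1, 1, 1, 3)
y_demo = F.conv2d(relu_demo, w_demo, padding=(0, 1), stride=(1, 2))   # [2, 1, 1, 4]
print(torch.equal(y_demo.view(2, 4) == 0, zero_windows))   # True
```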

In [45]:
c2_unfolded = F.unfold(out, 
		kernel_size=(c2_kernel_height, c2_kernel_width),  # (1,3)
		padding=(c2_padding_height, c2_padding_width), #(0,1)
		stride=(c2_stride_height, c2_stride_width))#(1,2)
c2_unfolded.size() , c2_unfolded

(torch.Size([2, 18, 4]),
 tensor([[[0.0000, 0.4135, 0.0000, 0.0000],
          [0.1805, 0.1153, 0.0000, 0.0000],
          [0.4135, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.4135, 0.0000, 0.0000],
          [0.1805, 0.1153, 0.0000, 0.0000],
          [0.4135, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.4135, 0.0000, 0.0000],
          [0.1805, 0.1153, 0.0000, 0.0000],
          [0.4135, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.4135, 0.0000, 0.0000],
          [0.1805, 0.1153, 0.0000, 0.0000],
          [0.4135, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.4135, 0.0000, 0.0000],
          [0.1805, 0.1153, 0.0000, 0.0000],
          [0.4135, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.4135, 0.0000, 0.0000],
          [0.1805, 0.1153, 0.0000, 0.0000],
          [0.4135, 0.0000, 0.0000, 0.0000]],
 
         [[0.0000, 0.0000, 0.1340, 0.1526],
          [0.0000, 0.0780, 0.0967, 0.0000],
          [0.0000, 0.1340, 0.1526, 0.0000],
          [0.0000, 0.0000, 0.1340, 0.1526],
    

##### 1x3 Conv 2 Stride - $W\cdot X_{unfolded}$

Now that we have the downsampled, unfolded patches, the network can learn how much each part of a patch influences the output. As before, we multiply the weight by the unfolded input. The weight is `[6,6,1,3]` while the unfolded input is now `[2,18,4]`, so we again flatten the last three dimensions of the weight (6 × 1 × 3 = 18) into a `[6,18]` matrix.

You might again ask, "what about the batch dimension of 2?" Just like in our first convolution, both sequences should share the same weights, so we don't add a batch dimension to the weight. PyTorch's matrix-multiplication broadcasting applies the same `[6,18]` matrix to each batch entry of the `[2,18,4]` input.
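Broadcasting here is equivalent to applying the same weight matrix to each sample in a loop. This sketch uses random stand-ins (`x_demo`, `w_demo`) for the ReLU output and `conv2`, compares the batched product with a per-sample loop, and checks both against `F.conv2d` with stride `(1, 2)`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x_demo = torch.relu(torch.randn(2, 6, 1, 8))    # stand-in for the ReLU output [B, C, 1, T]
w_demo = torch.randn(6, 6, 1, 3)                # [C_out, C_in, 1, 3]

cols_demo = F.unfold(x_demo, kernel_size=(1, 3), padding=(0, 1), stride=(1, 2))   # [2, 18, 4]
w_flat_demo = w_demo.view(6, -1)                                                  # [6, 18]

batched = w_flat_demo @ cols_demo               # one matrix broadcast over the batch: [2, 6, 4]
looped = torch.stack([w_flat_demo @ cols_demo[b] for b in range(cols_demo.shape[0])])  # per sample

y_ref = F.conv2d(x_demo, w_demo, padding=(0, 1), stride=(1, 2)).reshape(2, 6, 4)
print(torch.allclose(batched, looped, atol=1e-5), torch.allclose(batched, y_ref, atol=1e-5))   # True True
```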

In [46]:
# flatten each output channel's weights into one row: [C_out, C_in*k_h*k_w] = [6, 18], matching the 18 rows of the unfolded input
conv2_weigth = conv2.view(n_embd, -1) # [6,6,1,3] > [6,18]
conv2_weigth.size(), conv2_weigth

(torch.Size([6, 18]),
 tensor([[0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010,
          0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010],
         [0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010,
          0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010],
         [0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010,
          0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010],
         [0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010,
          0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010],
         [0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010,
          0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010],
         [0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010,
          0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.001

With the weights reshaped, we multiply them by the unfolded input. Because every row of the weight matrix is identical, all six output channels are identical, and the all-zero columns from the unfolding stay zero. The 18-long dimension is summed away, giving an output of `[2,6,4]`, and the batch dimension is preserved because the weight is broadcast across the batch.

In [47]:
out = conv2_weigth @ c2_unfolded
out.size(), out


(torch.Size([2, 6, 4]),
 tensor([[[0.0036, 0.0057, 0.0000, 0.0000],
          [0.0036, 0.0057, 0.0000, 0.0000],
          [0.0036, 0.0057, 0.0000, 0.0000],
          [0.0036, 0.0057, 0.0000, 0.0000],
          [0.0036, 0.0057, 0.0000, 0.0000],
          [0.0036, 0.0057, 0.0000, 0.0000]],
 
         [[0.0000, 0.0013, 0.0031, 0.0018],
          [0.0000, 0.0013, 0.0031, 0.0018],
          [0.0000, 0.0013, 0.0031, 0.0018],
          [0.0000, 0.0013, 0.0031, 0.0018],
          [0.0000, 0.0013, 0.0031, 0.0018],
          [0.0000, 0.0013, 0.0031, 0.0018]]], grad_fn=<CloneBackward0>))

Finally, as in the previous convolution, we reshape the last dimension back into the output height and width, `(1, 4)`. Since the height is 1, this just inserts a dimension of size 1 and the values look the same.

In [48]:
# restore the height dimension: [B, C, L] -> [B, C, 1, T/2], since the stride of 2 halved T
out = out.view(batch,n_embd, c2_height_out, c2_width_out)
out.size(), out

(torch.Size([2, 6, 1, 4]),
 tensor([[[[0.0036, 0.0057, 0.0000, 0.0000]],
 
          [[0.0036, 0.0057, 0.0000, 0.0000]],
 
          [[0.0036, 0.0057, 0.0000, 0.0000]],
 
          [[0.0036, 0.0057, 0.0000, 0.0000]],
 
          [[0.0036, 0.0057, 0.0000, 0.0000]],
 
          [[0.0036, 0.0057, 0.0000, 0.0000]]],
 
 
         [[[0.0000, 0.0013, 0.0031, 0.0018]],
 
          [[0.0000, 0.0013, 0.0031, 0.0018]],
 
          [[0.0000, 0.0013, 0.0031, 0.0018]],
 
          [[0.0000, 0.0013, 0.0031, 0.0018]],
 
          [[0.0000, 0.0013, 0.0031, 0.0018]],
 
          [[0.0000, 0.0013, 0.0031, 0.0018]]]], grad_fn=<ViewBackward0>))

#### Convolution Block - Second Batch Norm

Again, we apply another 2D batch normalization to the output of the convolutional layer. Batch normalization (2D) normalizes each channel using statistics computed over the current mini-batch and spatial positions. Concretely, for inputs of shape $[B,C,H,T]$, it computes a per-channel mean and variance over the $B$, $H$, and $T$ dimensions and standardizes that channel. This helps stabilize activations, acts as a mild regularizer, and speeds up training. If an entire channel is scaled uniformly across the batch and spatial locations, that scale is largely removed by the normalization: ignoring the small constant $\epsilon$, the sequences `[1,2,3,4]` and `[2,4,6,8]` normalize to the same values.

Batch normalization applies the following:
$$
y_{b,c,h,t}=\gamma_c \frac{x_{b,c,h,t}-\mu_c}{\sqrt{\sigma_c^2+\epsilon}}+\beta_c,
\quad
\mu_c=\mathbb{E}_{b,h,t}[x_{b,c,h,t}],\quad
\sigma_c^2=\operatorname{Var}_{b,h,t}[x_{b,c,h,t}].
$$

Batch normalization is again applied per channel on the feature map, and a new `nn.BatchNorm2d` starts with $\gamma_c=1$ and $\beta_c=0$, which leaves the normalized values unscaled and unshifted. We'll keep this initialization; during training these parameters are updated from the gradient. Because all six channels are still identical and share the same initial parameters, the output channels remain identical, while the values again span both positive and negative numbers since they measure standardized distance from the channel mean.

*Note that even though we may apply batch normalization in multiple places, we keep it as a separate layer so its effect can be tuned independently of other normalization layers.*

In [49]:
bn_b = nn.BatchNorm2d(n_embd)   
bn_b.weight, bn_b.bias

(Parameter containing:
 tensor([1., 1., 1., 1., 1., 1.], requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0., 0., 0., 0.], requires_grad=True))

In [50]:
out = bn_b(out)
out.size(), out

(torch.Size([2, 6, 1, 4]),
 tensor([[[[ 0.4422,  1.0070, -0.5211, -0.5211]],
 
          [[ 0.4422,  1.0070, -0.5211, -0.5211]],
 
          [[ 0.4422,  1.0070, -0.5211, -0.5211]],
 
          [[ 0.4422,  1.0070, -0.5211, -0.5211]],
 
          [[ 0.4422,  1.0070, -0.5211, -0.5211]],
 
          [[ 0.4422,  1.0070, -0.5211, -0.5211]]],
 
 
         [[[-0.5211, -0.1773,  0.3175, -0.0262]],
 
          [[-0.5211, -0.1773,  0.3175, -0.0262]],
 
          [[-0.5211, -0.1773,  0.3175, -0.0262]],
 
          [[-0.5211, -0.1773,  0.3175, -0.0262]],
 
          [[-0.5211, -0.1773,  0.3175, -0.0262]],
 
          [[-0.5211, -0.1773,  0.3175, -0.0262]]]],
        grad_fn=<NativeBatchNormBackward0>))

### Residual Connection

<img src="explainer_screenshots/cnn/resdiual_connection.png" width="200">

ResNets are a popular flavor of CNN that utilize a key component called a "residual connection". To showcase its value we'll also include one here. Residual connections, often also called "skip connections", give the signal a pathway that bypasses other layers, so during the backward pass gradients can flow around those layers instead of only through them. Because the input is added straight to the output, each block only has to learn a correction on top of its input, and the gradient reaching earlier layers always includes a direct identity term. Functionally this creates a connection represented by:

$$
y = f(x) + x
$$
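To see why the skip path helps gradients, here is a small autograd sketch. The branch `f` is a hypothetical stand-in (a `tanh` scaled by a tiny weight) for a block whose own gradient has nearly vanished; adding `x` back contributes a derivative of exactly 1, so $\partial y/\partial x = f'(x) + 1$:

```python
import torch

x = torch.tensor([0.5, -1.0, 2.0], requires_grad=True)
w = torch.tensor(1e-4)  # hypothetical tiny weight: the main branch passes almost no gradient

def f(inp):
    """Hypothetical main branch standing in for the convolution block."""
    return w * torch.tanh(inp)

# Without a skip connection: y = f(x)
(g_plain,) = torch.autograd.grad(f(x).sum(), x)

# With a skip connection: y = f(x) + x, so dy/dx = f'(x) + 1
(g_res,) = torch.autograd.grad((f(x) + x).sum(), x)

print(g_plain)  # on the order of 1e-4
print(g_res)    # roughly 1.0 for every element
```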

In our case, though, we can't simply sum the two branches together. Our current token embedding is `[2,6,1,8]`, but because of our downsampling, the output from our batch normalization is `[2,6,1,4]`. We have a couple of options: we could change the main branch so that it no longer shrinks the sequence (for example, relying on *dilation* rather than a stride to widen the receptive field), or we could downsample our residual connection to match the dimensions. We'll choose a downsampled projection of the residual connection, the same "projection shortcut" idea used in the original ResNet. This approach preserves the intended downsampling in the main (convolution block) branch, doesn't distort its receptive field or batch-norm statistics, and keeps the main branch’s computation and optimization behavior unchanged. We are not choosing dilation because it would alter the kernel’s geometry and typically increase compute without clear benefits.

Our *downsampling* is similar to the downsampling in our convolution block, but we just use a simple `(1,1)` kernel and a stride of `(1,2)`. Unlike the convolution layers thus far, this means every other token position in `x` never reaches our output through the residual connection. Overall, though, this is not a major concern, as we still keep the residual connection's main benefit: a path for gradients around the convolution block.

#### Residual Connection - Downsampling 1x1 Convolution 2 Stride

##### 1x1 Conv 2 Stride - Initialize weights
Our residual connection's convolution uses a kernel size of `(1,1)` and a stride of `(1,2)`, without padding. Padding is not needed because a `(1,1)` kernel never extends past the edge of the input, so no positions are lost to the kernel overhanging the border. The `(1,2)` stride downsamples the token (width) dimension from 8 to 4, reducing the size of our output to our desired target. Because the stride is `(1,2)` and the kernel is `(1,1)`, we skip every other input position during the downsampling.
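For comparison, the same shortcut could likely be expressed with PyTorch's built-in `nn.Conv2d`. This sketch (random input, weights set to 1 like `convRes`) checks that a `(1,1)` kernel with stride `(1,2)` is just a channel-mixing matrix applied to token positions 0, 2, 4, 6:

```python
import torch
import torch.nn as nn

n_embd = 6
x = torch.rand(2, n_embd, 1, 8)  # stand-in for the [B, C, 1, T] input

shortcut = nn.Conv2d(n_embd, n_embd, kernel_size=(1, 1), stride=(1, 2), padding=(0, 0), bias=False)
with torch.no_grad():
    nn.init.ones_(shortcut.weight)  # shape [6, 6, 1, 1], same as convRes

out = shortcut(x)
print(out.shape)  # torch.Size([2, 6, 1, 4])

# Keep every other token, then mix channels with the [6, 6] weight matrix
w = shortcut.weight.view(n_embd, n_embd)
manual = torch.einsum("oc,bcht->boht", w, x[..., ::2])
print(torch.allclose(out, manual, atol=1e-6))
```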

To start, we'll set up our weights with shape `[out_channels, in_channels, k_h, k_w]`, where both channel counts equal our embedding size `n_embd` and the kernel is `(1,1)`. We'll also initialize the weights to 1.0, since we want to demonstrate the benefit of the residual connection.

In [51]:
res_kernel_height = 1
res_kernel_width = 1
res_stride_height = 1
res_stride_width = 2
res_padding_height = 0
res_padding_width = 0
{'kernel': (res_kernel_height, res_kernel_width),
 'stride': (res_stride_height, res_stride_width),
 'padding': (res_padding_height, res_padding_width)}

{'kernel': (1, 1), 'stride': (1, 2), 'padding': (0, 0)}

In [52]:
convRes = nn.Parameter(
    torch.empty(n_embd, n_embd, res_kernel_height, res_kernel_width), 
    requires_grad=True)

In [53]:
with torch.no_grad():
    torch.nn.init.ones_(convRes)
convRes.size(), convRes

(torch.Size([6, 6, 1, 1]),
 Parameter containing:
 tensor([[[[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]]],
 
 
         [[[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]]],
 
 
         [[[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]]],
 
 
         [[[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]]],
 
 
         [[[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]]],
 
 
         [[[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]],
 
          [[1.]]]], requires_grad=True))

##### 1x1 Conv 2 Stride - Run convolution

Now we'll calculate the 2-D discrete cross-correlation between our weights and the residual connection's input `x`. For this convolution, we won't go through it step by step and will instead do the following:
1. Unfold each sliding kernel-sized block within the spatial dimensions of the input into a column (i.e., along the last dimension) of a 3-D output tensor of shape $(B,C \cdot k_h \cdot k_w,L)$. This is the step that does our downsampling, via the `(1,2)` stride.
2. Reshape our weights to match the unfolded input, so the same weights are reused across the batch.
3. Take the matrix product of the reshaped weights and the unfolded input, then reshape the result back to `[B, C, H_out, W_out]`.
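The three steps can be sketched end to end on a random stand-in input with hypothetical random weights, and compared against `F.conv2d`, which performs the same cross-correlation in one call:

```python
import torch
import torch.nn.functional as F

B, C, H, T = 2, 6, 1, 8
x = torch.rand(B, C, H, T)        # stand-in input
weight = torch.rand(C, C, 1, 1)   # hypothetical weights: [out_channels, in_channels, k_h, k_w]
stride = (1, 2)

# 1. Unfold to (B, C*k_h*k_w, L); with a 1x1 kernel this just keeps every other token
cols = F.unfold(x, kernel_size=(1, 1), stride=stride)   # [2, 6, 4]
# 2. Flatten the weights to [out_channels, C*k_h*k_w]; broadcasting reuses them per batch element
w2d = weight.view(C, -1)                                 # [6, 6]
# 3. Matrix-multiply, then fold L back into (H_out, W_out)
out = (w2d @ cols).view(B, C, 1, T // 2)                 # [2, 6, 1, 4]

print(torch.allclose(out, F.conv2d(x, weight, stride=stride), atol=1e-5))
print(torch.equal(cols, x[..., ::2].reshape(B, C, -1)))
```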


*Let's start by highlighting our input `x` and the dimensions of our input. We won't reinitialize as we want this step to error if there are dimensions not aligned with expectations.*

In [54]:
x

tensor([[[[0.1500, 0.1400, 0.1100, 0.0200, 0.1300, 0.0100, 0.0700, 0.1200]],

         [[0.1600, 0.1500, 0.1200, 0.0300, 0.1400, 0.0200, 0.0800, 0.1300]],

         [[0.1700, 0.1600, 0.1300, 0.0400, 0.1500, 0.0300, 0.0900, 0.1400]],

         [[0.1800, 0.1700, 0.1400, 0.0500, 0.1600, 0.0400, 0.1000, 0.1500]],

         [[0.1900, 0.1800, 0.1500, 0.0600, 0.1700, 0.0500, 0.1100, 0.1600]],

         [[0.2000, 0.1900, 0.1600, 0.0700, 0.1800, 0.0600, 0.1200, 0.1700]]],


        [[[0.0900, 0.0500, 0.1000, 0.1100, 0.0700, 0.1200, 0.0900, 0.0100]],

         [[0.1000, 0.0600, 0.1100, 0.1200, 0.0800, 0.1300, 0.1000, 0.0200]],

         [[0.1100, 0.0700, 0.1200, 0.1300, 0.0900, 0.1400, 0.1100, 0.0300]],

         [[0.1200, 0.0800, 0.1300, 0.1400, 0.1000, 0.1500, 0.1200, 0.0400]],

         [[0.1300, 0.0900, 0.1400, 0.1500, 0.1100, 0.1600, 0.1300, 0.0500]],

         [[0.1400, 0.1000, 0.1500, 0.1600, 0.1200, 0.1700, 0.1400, 0.0600]]]],
       grad_fn=<UnsqueezeBackward0>)

In [55]:
x.size(), batch, channel, height, width, 

(torch.Size([2, 6, 1, 8]), 2, 6, 1, 8)

**Calculate expected unfolded dimensions**  

To do our reshaping in this layer, we need to calculate the expected output dimensions.  
Recall that we expect to go from $(B,C,1,T)$ to $(B,C \cdot k_h \cdot k_w,L)$, where $L$ is the flattened product of our output height and width:
$$
\begin{align}
height_{out} &= (height + 2 \cdot pad_h - 1 \cdot (kernel_h-1) - 1)\ //\ stride_{h} + 1\\
width_{out} &= (width + 2 \cdot pad_w - 1 \cdot (kernel_w-1) - 1)\ //\ stride_{w} + 1\\
L &= height_{out} \cdot width_{out}
\end{align}
$$

One big difference you'll see here compared to the previous convolution layers is that, because our kernel is `(1,1)`, the unfolded channel dimension $C \cdot k_h \cdot k_w$ stays equal to $C$ (6 rather than 18). Additionally, we expect $L$ to be half of `height*width`, since the `(1,2)` stride downsamples.
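A small helper that mirrors the formula (the `1 \cdot` factor is PyTorch's dilation, which defaults to 1). `conv_out_size` is a hypothetical name, the shape check uses zero-valued stand-in tensors, and the `1x3` case is only an illustrative extra, not necessarily the notebook's earlier settings:

```python
import torch
import torch.nn.functional as F

def conv_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one dimension: (size + 2p - d(k-1) - 1) // s + 1."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# Residual shortcut: height 1 -> 1, width 8 -> 4
h_out = conv_out_size(1, kernel=1, stride=1, padding=0)
w_out = conv_out_size(8, kernel=1, stride=2, padding=0)
print(h_out, w_out, h_out * w_out)  # 1 4 4

# Cross-check against the shape PyTorch actually produces
x = torch.zeros(2, 6, 1, 8)
w = torch.zeros(6, 6, 1, 1)
print(F.conv2d(x, w, stride=(1, 2)).shape)  # torch.Size([2, 6, 1, 4])

# Illustrative extra case: a 1x3 kernel with padding 1 and stride 2 also maps 8 -> 4
print(conv_out_size(8, kernel=3, stride=2, padding=1))
```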

In [56]:
res_khw = channel*res_kernel_height*res_kernel_width

res_height_out = (height + 2*res_padding_height - 1*(res_kernel_height-1) - 1)//res_stride_height + 1   # = 1
res_width_out = (width + 2*res_padding_width - 1*(res_kernel_width-1) - 1)//res_stride_width + 1   # = 4
res_L = res_height_out * res_width_out

print(f'Second  Conv: width out {c2_width_out}, height out {c2_height_out}, final dimension ({batch},{c2_khw},{c2_L})')
print(f'This  Conv: width out {res_width_out}, height out {res_height_out}, final dimension ({batch},{res_khw},{res_L})')


Second  Conv: width out 4, height out 1, final dimension (2,18,4)
This  Conv: width out 4, height out 1, final dimension (2,6,4)


**Unfold**  Now we'll do the actual unfolding. Unlike the previous unfolding, this one just selects every other "column" (token positions 0, 2, 4, and 6) because of our kernel size and stride. If we go back to how we set up `x`, it was a rearrangement that led to `(B,C,1,T)`, so by selecting every other "column" we're selecting the channels corresponding to every other token.

In [57]:
x_unfolded = F.unfold(x,
                      kernel_size=(res_kernel_height, res_kernel_width),  # (1,1)
                      padding=(res_padding_height, res_padding_width),    # (0,0)
                      stride=(res_stride_height, res_stride_width))       # (1,2)
x_unfolded.size() , x_unfolded

(torch.Size([2, 6, 4]),
 tensor([[[0.1500, 0.1100, 0.1300, 0.0700],
          [0.1600, 0.1200, 0.1400, 0.0800],
          [0.1700, 0.1300, 0.1500, 0.0900],
          [0.1800, 0.1400, 0.1600, 0.1000],
          [0.1900, 0.1500, 0.1700, 0.1100],
          [0.2000, 0.1600, 0.1800, 0.1200]],
 
         [[0.0900, 0.1000, 0.0700, 0.0900],
          [0.1000, 0.1100, 0.0800, 0.1000],
          [0.1100, 0.1200, 0.0900, 0.1100],
          [0.1200, 0.1300, 0.1000, 0.1200],
          [0.1300, 0.1400, 0.1100, 0.1300],
          [0.1400, 0.1500, 0.1200, 0.1400]]], grad_fn=<Im2ColBackward0>))

##### 1x1 Conv 2 Stride - $W\cdot X_{unfolded}$

Now that we have unfolded the downsampled patches, we can let our network learn how much each input channel contributes to each output channel along the residual connection. To do this we take the matrix product of the weights with the unfolded input. We do have an issue, though: our weight is `[6,6,1,1]` but our input is `[2,6,4]`. We solve this by flattening everything after the output-channel dimension, turning `[6,6,1,1]` into a `[6,6]` tensor that we can multiply.

By now you hopefully also know that this weight matrix will be shared across both batch elements through broadcasting.
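A quick sketch of both points, using random stand-ins for `convRes` and `x_unfolded`: the `view` flattens the trailing `in_channels * k_h * k_w` dims, and the `[6,6]` matrix is broadcast across the batch, giving the same result as looping over batch elements:

```python
import torch

conv_res = torch.rand(6, 6, 1, 1)   # stand-in for convRes: [out_channels, in_channels, 1, 1]
w2d = conv_res.view(6, -1)          # flatten in_channels * k_h * k_w -> [6, 6]
cols = torch.rand(2, 6, 4)          # stand-in unfolded input: [B, C, L]

out = w2d @ cols                    # [6, 6] is broadcast across the batch dimension -> [2, 6, 4]

# The same computation, one batch element at a time
per_example = torch.stack([w2d @ cols[b] for b in range(cols.size(0))])
print(out.shape, torch.allclose(out, per_example))
```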

In [58]:
convRes_weigth = convRes.view(n_embd, -1) # [6,6,1,1] > [6,6]
convRes_weigth.size(), convRes_weigth

(torch.Size([6, 6]),
 tensor([[1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.]], grad_fn=<ViewBackward0>))

Now that the weights have the shape we want, we're ready to multiply them with the unfolded input. Because every entry in our weight matrix is currently 1, each output entry is the sum of all six input channels at that position, so every row (channel) within a batch element is identical. We also get a final output of `[2,6,4]`, matching our target dimensions.
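As an arithmetic check, this sketch rebuilds batch element 0 of `x` from the printed values, which follow the pattern channel $c$ = base + $0.01c$ (read off the output of `x` above), and reproduces the first row of `identity`:

```python
import torch

# Batch element 0 of x: channel c = base + 0.01 * c (values read from the printed tensor)
base = torch.tensor([0.15, 0.14, 0.11, 0.02, 0.13, 0.01, 0.07, 0.12])
x0 = base + 0.01 * torch.arange(6).unsqueeze(1)   # [6 channels, 8 tokens]

cols = x0[:, ::2]                     # unfolded: token positions 0, 2, 4, 6 -> [6, 4]
identity = torch.ones(6, 6) @ cols    # each output channel = sum of all 6 input channels
print(identity[0])                    # approximately [1.05, 0.81, 0.93, 0.57]
```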

In [59]:
identity = convRes_weigth @ x_unfolded
identity.size(), identity

(torch.Size([2, 6, 4]),
 tensor([[[1.0500, 0.8100, 0.9300, 0.5700],
          [1.0500, 0.8100, 0.9300, 0.5700],
          [1.0500, 0.8100, 0.9300, 0.5700],
          [1.0500, 0.8100, 0.9300, 0.5700],
          [1.0500, 0.8100, 0.9300, 0.5700],
          [1.0500, 0.8100, 0.9300, 0.5700]],
 
         [[0.6900, 0.7500, 0.5700, 0.6900],
          [0.6900, 0.7500, 0.5700, 0.6900],
          [0.6900, 0.7500, 0.5700, 0.6900],
          [0.6900, 0.7500, 0.5700, 0.6900],
          [0.6900, 0.7500, 0.5700, 0.6900],
          [0.6900, 0.7500, 0.5700, 0.6900]]], grad_fn=<CloneBackward0>))

Finally, as in the previous convolution, we need to reshape our last dimension, $L$, back into our target output height and width. Since our height is 1, this just inserts another dimension of size 1 without the values looking any different.

In [60]:
identity = identity.view(batch,n_embd, res_height_out, res_width_out)
identity.size(), identity

(torch.Size([2, 6, 1, 4]),
 tensor([[[[1.0500, 0.8100, 0.9300, 0.5700]],
 
          [[1.0500, 0.8100, 0.9300, 0.5700]],
 
          [[1.0500, 0.8100, 0.9300, 0.5700]],
 
          [[1.0500, 0.8100, 0.9300, 0.5700]],
 
          [[1.0500, 0.8100, 0.9300, 0.5700]],
 
          [[1.0500, 0.8100, 0.9300, 0.5700]]],
 
 
         [[[0.6900, 0.7500, 0.5700, 0.6900]],
 
          [[0.6900, 0.7500, 0.5700, 0.6900]],
 
          [[0.6900, 0.7500, 0.5700, 0.6900]],
 
          [[0.6900, 0.7500, 0.5700, 0.6900]],
 
          [[0.6900, 0.7500, 0.5700, 0.6900]],
 
          [[0.6900, 0.7500, 0.5700, 0.6900]]]], grad_fn=<ViewBackward0>))

#### Residual Connection - Sum 

Now that our token projection is downsampled to match the convolution block's output dimensions, we're ready to join the two branches. Recall that the residual connection performs:  

$y = f(x) + x$

where, with our projection shortcut, the $x$ term is the downsampled `identity` we just computed.

Because of this join, during the backward pass gradients can now flow both through the convolution layers and around them. We'll start by reprinting the output of the convolution block and then showing the sum. Notice that, because of how we initialized our weights, both the residual branch and the main branch currently produce values that are uniform across channels (visually, every row within a batch element is the same). Because of this, we expect the sum to maintain this behavior.
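For reference, the whole block could be packaged as a small `nn.Module`. This is a sketch under stated assumptions: the class name `DownsampleResBlock` is hypothetical, and the kernel sizes, padding, and which convolution carries the stride are illustrative choices, not necessarily the notebook's exact settings. Like the notebook, it sums the branches without a final activation; the original ResNet applies a ReLU after the addition.

```python
import torch
import torch.nn as nn

class DownsampleResBlock(nn.Module):
    """Hypothetical downsampling residual block: y = f(x) + W_s x."""

    def __init__(self, channels):
        super().__init__()
        # Main branch f(x); kernel/stride/padding choices are assumptions
        self.main = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1), bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False),
            nn.BatchNorm2d(channels),
        )
        # Projection shortcut W_s: 1x1 kernel, stride (1, 2) so shapes match the main branch
        self.shortcut = nn.Conv2d(channels, channels, kernel_size=(1, 1), stride=(1, 2), bias=False)

    def forward(self, x):
        return self.main(x) + self.shortcut(x)

block = DownsampleResBlock(6)
x = torch.rand(2, 6, 1, 8)
y = block(x)
print(y.shape)  # torch.Size([2, 6, 1, 4])
```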

In [61]:
out.size(), out

(torch.Size([2, 6, 1, 4]),
 tensor([[[[ 0.4422,  1.0070, -0.5211, -0.5211]],
 
          [[ 0.4422,  1.0070, -0.5211, -0.5211]],
 
          [[ 0.4422,  1.0070, -0.5211, -0.5211]],
 
          [[ 0.4422,  1.0070, -0.5211, -0.5211]],
 
          [[ 0.4422,  1.0070, -0.5211, -0.5211]],
 
          [[ 0.4422,  1.0070, -0.5211, -0.5211]]],
 
 
         [[[-0.5211, -0.1773,  0.3175, -0.0262]],
 
          [[-0.5211, -0.1773,  0.3175, -0.0262]],
 
          [[-0.5211, -0.1773,  0.3175, -0.0262]],
 
          [[-0.5211, -0.1773,  0.3175, -0.0262]],
 
          [[-0.5211, -0.1773,  0.3175, -0.0262]],
 
          [[-0.5211, -0.1773,  0.3175, -0.0262]]]],
        grad_fn=<NativeBatchNormBackward0>))

In [62]:
x = out + identity
x.size(), x

(torch.Size([2, 6, 1, 4]),
 tensor([[[[1.4922, 1.8170, 0.4089, 0.0489]],
 
          [[1.4922, 1.8170, 0.4089, 0.0489]],
 
          [[1.4922, 1.8170, 0.4089, 0.0489]],
 
          [[1.4922, 1.8170, 0.4089, 0.0489]],
 
          [[1.4922, 1.8170, 0.4089, 0.0489]],
 
          [[1.4922, 1.8170, 0.4089, 0.0489]]],
 
 
         [[[0.1689, 0.5727, 0.8875, 0.6638]],
 
          [[0.1689, 0.5727, 0.8875, 0.6638]],
 
          [[0.1689, 0.5727, 0.8875, 0.6638]],
 
          [[0.1689, 0.5727, 0.8875, 0.6638]],
 
          [[0.1689, 0.5727, 0.8875, 0.6638]],
 
          [[0.1689, 0.5727, 0.8875, 0.6638]]]], grad_fn=<AddBackward0>))

### Output Layers AKA Model Head.

We've now shown a common ResNet pattern: a downsampling convolutional block together with a residual connection. Once those layers finish during the forward pass, we start the output process that produces `logits`, unnormalized scores that a softmax turns into the probability of each token being the next token given the input.

<img src="explainer_screenshots/cnn/output_layer.png" width="200">

This layer is also known as the model **head**, and its learned weights are more task-specific than those of the rest of the model. Once a model is trained, there are processes to swap out the "head" and learn new tasks, e.g. going from next-token prediction to classification. In our example, this is a linear layer mapping the backbone's features to vocabulary logits.

Recall from the beginning that each batch element holds 1 example, and our goal is to predict the next token for that 1 specific example. To predict the tokens we'll generate `logits`: one unnormalized score per vocab entry, expressing how likely that entry is to be the next token.

We do have an issue, though. We still have an expanded T dimension: `[2,6,1,4]`. If we took our usual route of a linear projection, we'd end up with 4 sets of logits per example. Luckily, this is a common occurrence in CNNs, and one that does NOT need another strided convolution to resolve. Instead we use a technique called pooling that reduces the dimension for us. After pooling, we'll undo the dimension adjustments we made at the start, going from `[B,C,1,T]` $\rightarrow$ `[B,T,C]`, and then do our final linear projection to create our logits.

#### Output Layer - Adaptive Average Pooling

Pooling is a downsampling operation common in CNNs that reduces the spatial dimensions of a feature map by aggregating values within a sliding window. The differences from a convolution are that the windows typically don't overlap (as kernel windows may) and there are no weights to learn. We'll use a common version of pooling called adaptive average pooling.

Adaptive average pooling downsamples the input feature maps by dividing them into a grid of rectangular regions (windows) and replacing each region with its average value. The *adaptive* part comes from the function adjusting the window size automatically to produce a fixed-size output regardless of the input feature map's dimensions. In CNNs this is handy because a traditional modality like images may come in different resolutions, meaning different-sized inputs, but adaptive average pooling ensures they all have the same size after this step.
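A short sketch of the "adaptive" part on random stand-in inputs: different sequence lengths all pool to the same output size. The second half assumes PyTorch's window rule, start = floor(i·L/out) and end = ceil((i+1)·L/out), under which the windows may overlap:

```python
import torch
import torch.nn as nn

pool = nn.AdaptiveAvgPool2d((1, 1))
shapes = {}
for T in (4, 7, 16):
    shapes[T] = tuple(pool(torch.rand(2, 6, 1, T)).shape)
print(shapes)  # every length maps to (2, 6, 1, 1)

# With a larger target size the windows adapt to the input length:
# a length-5 input pooled to 2 outputs uses windows [0:3] and [2:5]
seq = torch.arange(5.0).view(1, 1, 5)   # [N, C, L]
pooled = nn.AdaptiveAvgPool1d(2)(seq)
print(pooled)                           # means of [0, 1, 2] and [2, 3, 4]
```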

In our example, each batch element holds just 1 example, so we need to squeeze the output `x` from `[2,6,1,4]` to `[2,6,1,1]`. As a reminder, at this point the column (T) values differ, but every channel (row) within a batch element holds the same values, so we expect the averages to be the same across channels. This is an early hint that our output at this point won't be great.

*Note that, for an output size of `(1, 1)`, this function is the same as running `x.mean(dim=(2,3), keepdim=True)`.*
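We can check that equivalence on the values printed for the residual sum (one row per batch element, repeated over the 6 identical channels):

```python
import torch
import torch.nn as nn

# Values from the residual-sum output; every channel holds the same row
rows = torch.tensor([[1.4922, 1.8170, 0.4089, 0.0489],
                     [0.1689, 0.5727, 0.8875, 0.6638]])
x = rows.view(2, 1, 1, 4).repeat(1, 6, 1, 1)   # [2, 6, 1, 4]

pooled = nn.AdaptiveAvgPool2d((1, 1))(x)
print(pooled.shape)             # torch.Size([2, 6, 1, 1])
print(pooled[:, 0].flatten())   # approximately [0.9418, 0.5732]
print(torch.allclose(pooled, x.mean(dim=(2, 3), keepdim=True)))
```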

In [63]:
avgPool = nn.AdaptiveAvgPool2d((1, 1))

In [64]:
x = avgPool(x)
x.size(), x

(torch.Size([2, 6, 1, 1]),
 tensor([[[[0.9418]],
 
          [[0.9418]],
 
          [[0.9418]],
 
          [[0.9418]],
 
          [[0.9418]],
 
          [[0.9418]]],
 
 
         [[[0.5732]],
 
          [[0.5732]],
 
          [[0.5732]],
 
          [[0.5732]],
 
          [[0.5732]],
 
          [[0.5732]]]], grad_fn=<MeanBackward1>))

#### Output Layer - Remove Dimension

Recall that during our pre-convolution steps we took our input embedding `[B, T, C]` and converted it to `[B, C, H, T]`. Now that we're producing our output, we need to undo this rearrangement so that we can properly project our logits as `[B,T,vocab_size]`.

To do this we'll first remove the height dimension and then flip our channels back. 

**Remove our third dimension**

First we'll remove our height dimension, currently the third dimension (index 2). Since this dimension has size 1, we won't see any change in the data, just a change in shape.

In [65]:
x = x.squeeze(2)
x.size(), x

(torch.Size([2, 6, 1]),
 tensor([[[0.9418],
          [0.9418],
          [0.9418],
          [0.9418],
          [0.9418],
          [0.9418]],
 
         [[0.5732],
          [0.5732],
          [0.5732],
          [0.5732],
          [0.5732],
          [0.5732]]], grad_fn=<SqueezeBackward1>))

**Permute Channel and Token**

Now we need to swap our channel and token dimensions back so that our output is ready for the final linear projection. Here our columns become a single row per batch element, matching our expectation of a single example, projected over the input channels, for each batch element.
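The two reshaping steps can be sketched on a random stand-in for the pooled tensor; for this layout (H = T = 1) they match a single `flatten` plus `transpose`:

```python
import torch

B, C = 2, 6
pooled = torch.rand(B, C, 1, 1)   # stand-in for the pooled [B, C, H=1, T=1] tensor

x = pooled.squeeze(2)             # drop H     -> [B, C, T] = [2, 6, 1]
x = x.permute(0, 2, 1)            # swap C, T  -> [B, T, C] = [2, 1, 6]
print(x.shape)

# Equivalent for this layout: merge the trailing dims, then swap
print(torch.equal(x, pooled.flatten(2).transpose(1, 2)))
```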

In [66]:
x = x.permute(0,2,1)
x.size(), x

(torch.Size([2, 1, 6]),
 tensor([[[0.9418, 0.9418, 0.9418, 0.9418, 0.9418, 0.9418]],
 
         [[0.5732, 0.5732, 0.5732, 0.5732, 0.5732, 0.5732]]],
        grad_fn=<PermuteBackward0>))

#### Output layer - LM Head aka logits

We can now project `x` onto the vocabulary, resulting in a final `[B,T,vocab_size]` array, `logits`. After a softmax, these scores give the probability of each output token given the input context. The best way to read this is:

1. (dimension 0) we have 2 batch elements, B;
2. (dimension 1) each batch element has one example (one position, T = 1);
3. (dimension 2) for each example we see a logit for every token in our vocabulary.


**LM Head - Weight initialization**

We'll first initialize our weights. We'll again initialize them to a constant (0.01), which means every token gets the same logit for a given input, i.e., an equal probability. This should change rapidly through back propagation.

In [67]:
lm_head = nn.Linear(n_embd, vocab_size, bias=False)
torch.nn.init.constant_(lm_head.weight,0.01)
lm_head.weight.size(), lm_head.weight

(torch.Size([15, 6]),
 Parameter containing:
 tensor([[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100],
         [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100],
         [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100],
         [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100],
         [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100],
         [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100],
         [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100],
         [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100],
         [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100],
         [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100],
         [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100],
         [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100],
         [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100],
         [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100],
         [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100]], requires_grad=True))

**LM Head - Output Projection**

Now we're ready for our final projection. Since our `x` values differ between the two examples but all of the LM head weights are equal, we expect the final output to differ between examples while every logit within an example is the same. This is best interpreted as the model assigning equal probability to every token in the vocabulary, i.e., a random "next token", meaning it's shit. Luckily, backpropagation updates this so that, with enough data and time, the probabilities change.
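A sketch of why the logits are flat, rebuilding the pooled features from the printed values: each logit is $6 \times 0.01 \times$ the feature value, and a softmax over 15 equal logits gives $1/15$ everywhere:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n_embd, vocab_size = 6, 15
lm_head = nn.Linear(n_embd, vocab_size, bias=False)
nn.init.constant_(lm_head.weight, 0.01)

# Pooled features from above, repeated over the 6 channels: [B, T, C] = [2, 1, 6]
x = torch.tensor([0.9418, 0.5732]).view(2, 1, 1).repeat(1, 1, n_embd)
logits = lm_head(x)               # each logit = 6 * 0.01 * feature value
probs = F.softmax(logits, dim=-1)

print(logits[:, 0, 0])            # approximately [0.0565, 0.0344]
print(probs[0, 0, :3])            # every entry is 1/15
```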

In [68]:
logits = lm_head(x)

logits.shape, logits

(torch.Size([2, 1, 15]),
 tensor([[[0.0565, 0.0565, 0.0565, 0.0565, 0.0565, 0.0565, 0.0565, 0.0565,
           0.0565, 0.0565, 0.0565, 0.0565, 0.0565, 0.0565, 0.0565]],
 
         [[0.0344, 0.0344, 0.0344, 0.0344, 0.0344, 0.0344, 0.0344, 0.0344,
           0.0344, 0.0344, 0.0344, 0.0344, 0.0344, 0.0344, 0.0344]]],
        grad_fn=<UnsafeViewBackward0>))

## Loss calculation

Now we have to see how good our ~shit~ prediction is. Since we haven't trained yet, and we saw that every logit within an example had the same value across the vocab, we can expect it's bad, basically random. That said, we need to know how bad. For this example we'll use cross entropy, also known as the negative log likelihood of the softmax. For example $i$ with logits $z_i$ and target token $y_i$, our loss calculates

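$$
\ell_i=-\log\big(\mathrm{softmax}(z_i)_{y_i}\big)
= -z_{i,y_i}+\log\!\sum_{v=1}^{V} e^{z_{i,v}},
\qquad
\mathcal{L}=\frac{1}{N}\sum_{i=1}^{N}\ell_i,
$$

where $V$ is the vocabulary size (15 here) and $N$ is the number of examples; `F.cross_entropy` averages over the examples by default.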
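The formula can be checked by hand against `F.cross_entropy`, using the printed logits and targets. Because all logits in a row are equal, every $\ell_i$ collapses to $\log 15$:

```python
import math
import torch
import torch.nn.functional as F

logits = torch.tensor([[0.0565] * 15, [0.0344] * 15])   # [N, V] with N = 2, V = 15
y = torch.tensor([8, 10])

# l_i = -z_{i, y_i} + log sum_v exp(z_{i, v}), then the mean over the N examples
per_example = -logits[torch.arange(2), y] + torch.logsumexp(logits, dim=1)
manual = per_example.mean()

print(manual, F.cross_entropy(logits, y))   # both approximately 2.7081
print(math.log(15))
```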


To calculate loss we'll pass in the calculated `logits` and our next tokens stored in `y`. `F.cross_entropy` expects logits of shape `(N, vocab_size)`, with the class scores in dimension 1, and targets of shape `(N,)`, so we'll flatten the `B` and `T` dimensions together for both `logits` and `y`.
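Flattening is one option; `F.cross_entropy` also accepts the class dimension at position 1 with extra trailing dimensions. This sketch on random stand-in logits (with T = 3 so the difference is visible) shows both give the same mean loss:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 2, 3, 15
logits = torch.randn(B, T, V)          # stand-in [B, T, vocab_size]
y = torch.randint(0, V, (B, T))        # stand-in targets [B, T]

# Option 1: flatten B and T into one "example" dimension N = B*T
loss_flat = F.cross_entropy(logits.view(-1, V), y.view(-1))
# Option 2: move the vocab (class) dimension to position 1 -> [B, V, T]
loss_perm = F.cross_entropy(logits.permute(0, 2, 1), y)

print(torch.allclose(loss_flat, loss_perm))
```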

In [69]:
y_flat = y.view(-1)
y_flat.shape, y_flat

(torch.Size([2]), tensor([ 8, 10]))

In [70]:
logits_flat = logits.view(-1, logits.size(-1))
logits_flat.shape, logits_flat

(torch.Size([2, 15]),
 tensor([[0.0565, 0.0565, 0.0565, 0.0565, 0.0565, 0.0565, 0.0565, 0.0565, 0.0565,
          0.0565, 0.0565, 0.0565, 0.0565, 0.0565, 0.0565],
         [0.0344, 0.0344, 0.0344, 0.0344, 0.0344, 0.0344, 0.0344, 0.0344, 0.0344,
          0.0344, 0.0344, 0.0344, 0.0344, 0.0344, 0.0344]],
        grad_fn=<ViewBackward0>))

In [71]:
loss = F.cross_entropy(logits_flat, y_flat)
loss.shape, loss

(torch.Size([]), tensor(2.7081, grad_fn=<NllLossBackward0>))

## Back Propagation

We now know just how ~well~ terribly the current model, with its weights and biases, predicts the next token given the input context. Next we need to know how to change the different weights and biases to improve it. We could do this by guessing: making minor changes and seeing what improves. Or we can think through this more critically.

If you review the chain of layers above, you can see that it's a series of functions. We can think of this as $f(g(x))$, except with many, many more layers and complexities. Since this is a composition of functions, we can dig into our math toolbox and find a better way to determine which parts need to update. Recall from calculus that differentiation tells us the rate of change of a function. So if we treat the loss function $\mathcal{L}$ as $\mathcal{L}(f(g(x)))$, taking the partial derivative

$\delta=\partial \mathcal{L}/\partial h$

with respect to each intermediate value or parameter $h$ tells us how that weight or bias affects our final loss. Because we want the loss to go *down*, we will move each parameter in the opposite direction of its gradient.

Lucky for us, each learnable parameter in our model already has a placeholder for its partial derivative, the `.grad` attribute, which holds the **gradient**. We'll use this field to store it. We'll start by zeroing out the gradients. We do this because of how partial derivatives combine when a value feeds into several others. Recall that in multiple places we had a formula structure of

$a+b=c ; a+c= d$

In this case $a$ reaches $d$ along two paths, directly and through $c$, so determining $\partial d / \partial a$ requires following both: $\partial d/\partial a = 1 + \partial c/\partial a = 2$. To determine the true impact of $a$ we sum the contributions from both paths. PyTorch's built-in `.backward()` does this summing automatically, and it also *accumulates* into each parameter's `.grad` with `+=` across calls instead of overwriting it. So if we do not reset the gradients first, leftovers from an earlier backward pass get added in and we end up with erroneous gradients.
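Here is the $a+b=c;\ a+c=d$ example run through autograd, with toy values chosen only for illustration. It shows both the two-path sum and the accumulation across backward calls that makes zeroing necessary:

```python
import torch

a = torch.tensor(3.0, requires_grad=True)
b = torch.tensor(4.0)

c = a + b
d = a + c                # a reaches d directly and through c
d.backward()
first = a.grad.clone()   # 1 (direct path) + 1 (through c) = 2
print(first)

# A second backward pass without clearing .grad adds to the old value
d2 = a + (a + b)
d2.backward()
second = a.grad.clone()  # stale 2 + new 2 = 4
print(second)

a.grad = None            # reset before the next pass
```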

Finally, we start `.backward()` from the `loss`, not the `logits`. Our goal is to minimize loss, so we need the gradients of the loss itself, which depend on the whole forward pass that generated the prediction `logits_flat`. If we think of it as $\mathcal{L}(f(x))$, where $f(x)$ is the forward pass that generates the logits, then a simple chain rule is applied:

$\dfrac{\partial \mathcal{L}}{\partial x} = \mathcal{L}'(f(x))\, f'(x)$

Let's start by zeroing the gradients and then lean on PyTorch to calculate the gradients for us. We'll also validate that the gradients are `None`.

### Back Propagation - Zero out gradients

Before we start, we have to zero out our gradients to make sure there's nothing in them. Recall that `.backward()` accumulates into `.grad` rather than overwriting it, so zeroing out ensures no erroneous values.

Two things to notice: 
1. Our pooling step is not included. Pooling is stateless, meaning it has no learnable parameters, so the layer has no gradient buffers to clear. The step itself is differentiable, though, so gradients flow through it back to upstream parameters.
2. Our manually built convolution weights use a special reset. Since we built them directly as `nn.Parameter` tensors rather than as `nn.Module` layers, they don't have the `.zero_grad()` method that layers have, even though they are learnable. They still have a `.grad` buffer, so we clear it manually by setting it to `None`.
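A minimal sketch of the two reset styles, using a toy `nn.Linear` layer and a bare `nn.Parameter` as stand-ins for our layers and `convRes`. Depending on the PyTorch version, `zero_grad()` either sets gradients to `None` or fills them with zeros:

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2, bias=False)          # a regular module
raw = nn.Parameter(torch.ones(2, 3, 1, 1))   # a bare parameter, like convRes

x = torch.rand(4, 3)
loss = layer(x).sum() + (raw.view(2, 3) @ x.T).sum()
loss.backward()
print(layer.weight.grad is not None, raw.grad is not None)   # True True

layer.zero_grad()   # a module clears every parameter it owns
raw.grad = None     # a bare Parameter has no zero_grad(), so clear .grad directly
```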

In [72]:
lm_head.zero_grad()
convRes.grad = None  # manually built nn.Parameter, not a torch layer, so clear .grad directly
# conv block
bn_b.zero_grad()
conv2.grad = None    # manually built nn.Parameter
bn_a.zero_grad()
conv1.grad = None    # manually built nn.Parameter

wte.zero_grad()


# validate gradients
lm_head.weight.grad,convRes.grad, wte.weight.grad

(None, None, None)

### Back Propagation - Auto Diff

Now let's see the magic of the gradients populating. This magic is called automatic differentiation, or auto-diff for short. It saves us from hand-writing many layers of nasty differentiation code, but, if you're a sadist, you can surely find people who have written out that code (it's not too bad since you just do one layer at a time).

In [73]:
loss.backward()

Now we'll print our gradients again and see that they contain values, just like we wanted.

In [74]:
lm_head.weight.grad,convRes.grad, wte.weight.grad

(tensor([[ 0.0505,  0.0505,  0.0505,  0.0505,  0.0505,  0.0505],
         [ 0.0505,  0.0505,  0.0505,  0.0505,  0.0505,  0.0505],
         [ 0.0505,  0.0505,  0.0505,  0.0505,  0.0505,  0.0505],
         [ 0.0505,  0.0505,  0.0505,  0.0505,  0.0505,  0.0505],
         [ 0.0505,  0.0505,  0.0505,  0.0505,  0.0505,  0.0505],
         [ 0.0505,  0.0505,  0.0505,  0.0505,  0.0505,  0.0505],
         [ 0.0505,  0.0505,  0.0505,  0.0505,  0.0505,  0.0505],
         [ 0.0505,  0.0505,  0.0505,  0.0505,  0.0505,  0.0505],
         [-0.4204, -0.4204, -0.4204, -0.4204, -0.4204, -0.4204],
         [ 0.0505,  0.0505,  0.0505,  0.0505,  0.0505,  0.0505],
         [-0.2361, -0.2361, -0.2361, -0.2361, -0.2361, -0.2361],
         [ 0.0505,  0.0505,  0.0505,  0.0505,  0.0505,  0.0505],
         [ 0.0505,  0.0505,  0.0505,  0.0505,  0.0505,  0.0505],
         [ 0.0505,  0.0505,  0.0505,  0.0505,  0.0505,  0.0505],
         [ 0.0505,  0.0505,  0.0505,  0.0505,  0.0505,  0.0505]]),
 tensor([[[[-1.2303e-11

## Learning

The process of learning now requires us to update our weights based on these gradients. To really feel the "back propagation" we'll start with the last layer and work backwards, though, since all of the gradients are already calculated, the order does not matter. The gradient tells us the direction and relative size of the change for each parameter, and we need to decide how much of that gradient to apply to our weights. This "how much" is referred to as the *learning rate*, $\eta$, and each update steps *against* the gradient, $w \leftarrow w - \eta\, \partial \mathcal{L}/\partial w$, because we want to decrease the loss: a negative gradient means increasing that parameter lowers the loss, so the parameter goes up, and vice versa. In modern training, learning rate schedulers and optimizers are used to vary the rate and its application by layer and by training step, with learning rates that are small (e.g. 1e-3) and decaying.

We, however, are trying to learn, and if you look at the gradients above, in most layers they're tiny (~1e-10). If we used a typical learning rate, then with our context, batch size, and just 1 pass, the weights after the update would be practically unchanged and we wouldn't see anything new. Because of this we'll do something super unorthodox. Our LM_head will use a learning rate of `1.000`, but the rest of the layers will use a learning rate of `1e8`. This INSANE learning rate would absolutely be too noisy for real training, as taking steps that large would mean the model could never find a good fit. But since we're doing 1 step, we don't care. Also, we're rerunning on the same example we used in the first loop, so even less of a care. But as a warning, DO NOT DO THIS IN REAL TRAINING. If you did, your model would most likely not converge.

In [75]:
## Huge learning rate to emphasize
head_learning_rate = 1.000
learning_rate = 1e8
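We apply the updates manually below, but in practice the same "two learning rates" setup would usually be written with an optimizer's parameter groups. A sketch with toy stand-ins (`conv_res` plays the role of `convRes`; the loss is artificial and the `1e8` rate is for demonstration only). Plain SGD without momentum applies exactly `p -= lr * p.grad` per group:

```python
import torch
import torch.nn as nn

lm_head = nn.Linear(6, 15, bias=False)
conv_res = nn.Parameter(torch.ones(6, 6, 1, 1))   # stand-in for convRes

# One optimizer, two learning rates (1e8 is for this demo only)
opt = torch.optim.SGD([
    {"params": lm_head.parameters(), "lr": 1.0},
    {"params": [conv_res], "lr": 1e8},
], lr=1.0)

x = torch.rand(2, 6)
loss = lm_head(x).sum() + 1e-10 * (conv_res.view(6, 6) @ x.T).sum()   # toy loss with tiny conv gradients
before_head = lm_head.weight.detach().clone()
before_conv = conv_res.detach().clone()

opt.zero_grad()
loss.backward()
opt.step()   # p <- p - lr * grad for each group
```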

### Output Layer
Let's start with our output layer. Recall that we initialized the weights to `0.0100`, so we can quickly see the impact of the gradient update on the weights. Most notably, the rows corresponding to the tokens present in `y`, our targets (8 and 10), are up-weighted and all others are down-weighted. This should already give you an intuitive sense that:
1. The model will be improved.
2. This is WAY too high of a learning rate if we wanted to generalize.

Additionally, you'll notice that the gradient, and therefore the update, is the same across all of the `n_embd` dimensions, because every channel of the pooled input `x` held the same value. Finally, since we created this layer with `bias=False`, there is no bias to update.
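The head's gradient can also be derived by hand: for a mean cross-entropy loss, $\partial\mathcal{L}/\partial W = (\mathrm{softmax}(z) - \mathrm{onehot}(y))^\top x / N$. This sketch rebuilds the pooled features from the printed values and reproduces the `0.0505`, `-0.4204`, and `-0.2361` rows, and shows why each row is constant across channels:

```python
import torch
import torch.nn.functional as F

V, C = 15, 6
x = torch.tensor([0.9418, 0.5732]).view(2, 1).repeat(1, C)   # pooled features [N, C]
y = torch.tensor([8, 10])
W = torch.full((V, C), 0.01, requires_grad=True)              # lm_head.weight at init

loss = F.cross_entropy(x @ W.T, y)
loss.backward()

# dL/dW = (softmax(z) - onehot(y))^T x / N  (1/N from the mean over examples)
p = F.softmax(x @ W.detach().T, dim=1)
analytic = (p - F.one_hot(y, V).float()).T @ x / x.size(0)

print(torch.allclose(W.grad, analytic, atol=1e-6))
print(W.grad[[0, 8, 10], 0])   # approximately [0.0505, -0.4204, -0.2361]
```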

In [76]:
with torch.no_grad():
    lm_head.weight -= head_learning_rate * lm_head.weight.grad
lm_head.weight

Parameter containing:
tensor([[-0.0405, -0.0405, -0.0405, -0.0405, -0.0405, -0.0405],
        [-0.0405, -0.0405, -0.0405, -0.0405, -0.0405, -0.0405],
        [-0.0405, -0.0405, -0.0405, -0.0405, -0.0405, -0.0405],
        [-0.0405, -0.0405, -0.0405, -0.0405, -0.0405, -0.0405],
        [-0.0405, -0.0405, -0.0405, -0.0405, -0.0405, -0.0405],
        [-0.0405, -0.0405, -0.0405, -0.0405, -0.0405, -0.0405],
        [-0.0405, -0.0405, -0.0405, -0.0405, -0.0405, -0.0405],
        [-0.0405, -0.0405, -0.0405, -0.0405, -0.0405, -0.0405],
        [ 0.4304,  0.4304,  0.4304,  0.4304,  0.4304,  0.4304],
        [-0.0405, -0.0405, -0.0405, -0.0405, -0.0405, -0.0405],
        [ 0.2461,  0.2461,  0.2461,  0.2461,  0.2461,  0.2461],
        [-0.0405, -0.0405, -0.0405, -0.0405, -0.0405, -0.0405],
        [-0.0405, -0.0405, -0.0405, -0.0405, -0.0405, -0.0405],
        [-0.0405, -0.0405, -0.0405, -0.0405, -0.0405, -0.0405],
        [-0.0405, -0.0405, -0.0405, -0.0405, -0.0405, -0.0405]],
       requires_grad=True)

### Residual Connection
Next we'll update the convolutional layer inside our residual connection. Recall that we initialized the weights to `1.000` and did not include a bias. The gradients on this layer are extremely small (~1e-11), but some are positive and some negative across the channels. This gives an intuitive sense that different channels affect our prediction differently.

With this layer you can start seeing the beauty of our insanely high learning rate. Even though the gradients were tiny, the weights shift visibly away from their initialization, suggesting some level of adaptation. The first four output channels moved above `1.000` (a negative gradient) while the last two moved below it (a positive gradient), and within each row the shift grows with the input-channel index.

In [77]:
with torch.no_grad():
    convRes -= learning_rate * convRes.grad
convRes

Parameter containing:
tensor([[[[1.0012]],

         [[1.0013]],

         [[1.0014]],

         [[1.0015]],

         [[1.0017]],

         [[1.0018]]],


        [[[1.0012]],

         [[1.0013]],

         [[1.0014]],

         [[1.0015]],

         [[1.0017]],

         [[1.0018]]],


        [[[1.0012]],

         [[1.0013]],

         [[1.0014]],

         [[1.0015]],

         [[1.0017]],

         [[1.0018]]],


        [[[1.0012]],

         [[1.0013]],

         [[1.0014]],

         [[1.0015]],

         [[1.0017]],

         [[1.0018]]],


        [[[0.9973]],

         [[0.9971]],

         [[0.9969]],

         [[0.9966]],

         [[0.9964]],

         [[0.9962]]],


        [[[0.9973]],

         [[0.9971]],

         [[0.9969]],

         [[0.9966]],

         [[0.9964]],

         [[0.9962]]]], requires_grad=True)

### Convolutional Block
Next we'll update the layers in our convolutional block. This includes our batch normalization layers and our two convolutional layers. Our ReLU layer is stateless, so it has no learnable parameters to update. Our batch norm layers include a bias, while our convolutional layers do not.

Recall that our batch norm layers were initialized with weights of `1` and biases of `0`. Our convolution layers were initialized with kernels repeated across every channel: `conv2 = [0.0020, 0.0010, 0.0010]` and `conv1 = [0.001, 0.002, 0.001]`. Thanks again to our large learning rate, `conv2` and `bn_b` moved clearly away from their initial values, while `conv1` and `bn_a` are essentially unchanged at four decimals (`bn_a`'s bias only reaches ~6.5e-11). Interestingly, `conv2` now has different weights at each kernel position (`0.0337`, `0.0401`, `0.0426`) and different weights for its last two output channels than for its first four, but it still uses the same weights for every input channel.

In [78]:
with torch.no_grad():
    bn_b.weight -= learning_rate * bn_b.weight.grad
    bn_b.bias -= learning_rate * bn_b.bias.grad
bn_b.weight, bn_b.bias

(Parameter containing:
 tensor([1.0012, 1.0012, 1.0012, 1.0012, 0.9976, 0.9976], requires_grad=True),
 Parameter containing:
 tensor([ 0.0106,  0.0106,  0.0106,  0.0106, -0.0233, -0.0233],
        requires_grad=True))

In [79]:
with torch.no_grad():
    conv2 -= learning_rate * conv2.grad
conv2

Parameter containing:
tensor([[[[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]]],


        [[[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]]],


        [[[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]]],


        [[[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401,  0.0426]],

         [[ 0.0337,  0.0401, 

In [80]:
with torch.no_grad():
    bn_a.weight -= learning_rate * bn_a.weight.grad
    bn_a.bias -= learning_rate * bn_a.bias.grad
bn_a.weight, bn_a.bias

(Parameter containing:
 tensor([1., 1., 1., 1., 1., 1.], requires_grad=True),
 Parameter containing:
 tensor([6.5325e-11, 6.5325e-11, 6.5325e-11, 6.5325e-11, 6.5325e-11, 6.5325e-11],
        requires_grad=True))

In [81]:
with torch.no_grad():
    conv1 -= learning_rate * conv1.grad
conv1

Parameter containing:
tensor([[[[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]]],


        [[[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]]],


        [[[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]]],


        [[[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0.0020, 0.0010]]],


        [[[0.0010, 0.0020, 0.0010]],

         [[0.0010, 0

### Input Layer
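Finally we'll update our input layer. Recall that we initialized it to increments of `0.0100`. Given our high learning rate, a keen eye can observe an interesting phenomenon: the gradient updates only touched rows for tokens that appear in our example. An embedding layer projects the example into the embedding space by looking up one row per token, so the rows of tokens that are present are the only ones the gradient can trace back to. In the other layers, every weight takes part in the dot products for every example, so the gradient reaches the whole tensor. Looking closer, only rows `6`, `8`, `9`, `10`, `12`, and `14` moved visibly. Those are exactly the tokens at even positions (0, 2, 4, 6), which the stride-2 residual convolution reads directly. Tokens that only sit at odd positions (`0`, `1`, `4`, `11`, `13`) receive gradient only through the convolution block, where it is too small to show at four decimals.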
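A minimal pure-Python sketch of both effects, using only our two `x_2` rows. `embedding_backward` scatter-adds each position's upstream gradient into the row of the token looked up there, which is how an embedding's backward pass distributes gradient; the uniform upstream gradient of `1.0` is a HYPOTHETICAL placeholder. `strided_positions` lists the positions that a 1x1, unpadded, stride-2 convolution (like our residual `convRes` with stride `(1,2)`) reads directly.

```python
def embedding_backward(token_ids, upstream, vocab_size, dim):
    """Scatter-add upstream gradients into the embedding rows that were looked up."""
    grad = [[0.0] * dim for _ in range(vocab_size)]
    for seq, seq_grads in zip(token_ids, upstream):
        for tok, g in zip(seq, seq_grads):
            for c in range(dim):
                grad[tok][c] += g[c]
    return grad

def strided_positions(length, stride=2):
    """Positions a 1x1 convolution with no padding reads at the given stride."""
    return list(range(0, length, stride))

x_2 = [[14, 13, 10, 1, 12, 0, 6, 11],
       [8, 4, 9, 10, 6, 11, 8, 0]]
vocab_size, n_embd = 15, 6

# HYPOTHETICAL upstream gradient of 1.0 everywhere, just to see which rows get touched.
upstream = [[[1.0] * n_embd for _ in seq] for seq in x_2]
grad = embedding_backward(x_2, upstream, vocab_size, n_embd)
touched = sorted(v for v in range(vocab_size) if any(grad[v]))
print("rows with gradient:", touched)

# Tokens the stride-2 residual path reads directly.
even = strided_positions(len(x_2[0]))
even_tokens = sorted({seq[t] for seq in x_2 for t in even})
print("even positions:", even, "tokens there:", even_tokens)
```

Every token in `x_2` gets some gradient, but only `6`, `8`, `9`, `10`, `12`, and `14` sit at even positions, which lines up with the rows that visibly moved in the `In [82]` output. Tokens `6` and `8` each appear at two even positions, and their rows moved roughly twice as far.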

In [82]:
with torch.no_grad():
    wte.weight -= learning_rate * wte.weight.grad
wte.weight

Parameter containing:
tensor([[0.0100, 0.0200, 0.0300, 0.0400, 0.0500, 0.0600],
        [0.0200, 0.0300, 0.0400, 0.0500, 0.0600, 0.0700],
        [0.0300, 0.0400, 0.0500, 0.0600, 0.0700, 0.0800],
        [0.0400, 0.0500, 0.0600, 0.0700, 0.0800, 0.0900],
        [0.0500, 0.0600, 0.0700, 0.0800, 0.0900, 0.1000],
        [0.0600, 0.0700, 0.0800, 0.0900, 0.1000, 0.1100],
        [0.0689, 0.0789, 0.0889, 0.0989, 0.1089, 0.1189],
        [0.0800, 0.0900, 0.1000, 0.1100, 0.1200, 0.1300],
        [0.0889, 0.0989, 0.1089, 0.1189, 0.1289, 0.1389],
        [0.0995, 0.1095, 0.1195, 0.1295, 0.1395, 0.1495],
        [0.1095, 0.1195, 0.1295, 0.1395, 0.1495, 0.1595],
        [0.1200, 0.1300, 0.1400, 0.1500, 0.1600, 0.1700],
        [0.1295, 0.1395, 0.1495, 0.1595, 0.1695, 0.1795],
        [0.1400, 0.1500, 0.1600, 0.1700, 0.1800, 0.1900],
        [0.1495, 0.1595, 0.1695, 0.1795, 0.1895, 0.1995]], requires_grad=True)

## Forward Pass with Updated Weights

Now that we have the updated weights for each layer, let's do another forward pass and compare the loss. Since each layer was previously explained we will instead focus on just showing the outputs of the different layers and the final loss. If you want, you can check the previous outputs in the cached cell outputs above and compare them to see how the weight changes impacted the values at each layer. 

You'll notice that, because of our high learning rate, we can see how each layer now shifts the embedding values as the input passes through it.

### Data Re-loading
We'll reload the inputs into a new `x_2` and keep `y` to emphasize that the same examples are being used.

In [83]:
x_2 = tok_for_training[:-1].view(B_batch, T_context)
x_2, y

(tensor([[14, 13, 10,  1, 12,  0,  6, 11],
         [ 8,  4,  9, 10,  6, 11,  8,  0]]),
 tensor([[ 8],
         [10]]))

### Input Layer

Note that in `wte` the gradient only updated the embeddings that map to tokens in our example. Lucky for us, we're passing the same example through, so we'll pull those same updated rows.

With a typical learning rate, given our vocab size and context length, along with the many layers in our CNN, we would need an extended training run to see changes in this input layer. Alternatively, a common way to help improve the input embedding is weight tying between the input and output layers. Shape-wise that would work here, since `wte.weight` and `lm_head.weight` are both `[15, 6]`, but weight tying is less common in CNNs, so we won't use it.

#### Input Layer - Embedding Projection

We can quickly see that, instead of our evenly incremented rows, some embeddings have been adjusted. We haven't yet learned much distinction across embedding dimensions, though: the step between neighboring columns is still a constant `0.0100`.

In [84]:
x = wte(x_2)
x.shape, x

(torch.Size([2, 8, 6]),
 tensor([[[0.1495, 0.1595, 0.1695, 0.1795, 0.1895, 0.1995],
          [0.1400, 0.1500, 0.1600, 0.1700, 0.1800, 0.1900],
          [0.1095, 0.1195, 0.1295, 0.1395, 0.1495, 0.1595],
          [0.0200, 0.0300, 0.0400, 0.0500, 0.0600, 0.0700],
          [0.1295, 0.1395, 0.1495, 0.1595, 0.1695, 0.1795],
          [0.0100, 0.0200, 0.0300, 0.0400, 0.0500, 0.0600],
          [0.0689, 0.0789, 0.0889, 0.0989, 0.1089, 0.1189],
          [0.1200, 0.1300, 0.1400, 0.1500, 0.1600, 0.1700]],
 
         [[0.0889, 0.0989, 0.1089, 0.1189, 0.1289, 0.1389],
          [0.0500, 0.0600, 0.0700, 0.0800, 0.0900, 0.1000],
          [0.0995, 0.1095, 0.1195, 0.1295, 0.1395, 0.1495],
          [0.1095, 0.1195, 0.1295, 0.1395, 0.1495, 0.1595],
          [0.0689, 0.0789, 0.0889, 0.0989, 0.1089, 0.1189],
          [0.1200, 0.1300, 0.1400, 0.1500, 0.1600, 0.1700],
          [0.0889, 0.0989, 0.1089, 0.1189, 0.1289, 0.1389],
          [0.0100, 0.0200, 0.0300, 0.0400, 0.0500, 0.0600]]],
        gra

#### Input Layer - Add Dimension

We'll now do the permutation and add the height dimension into our matrix. This step does not change our values. 

In [85]:
x = x.permute(0,2,1) # [B,C,T]
x = x.unsqueeze(2)  # [B,C,1,T]
x.size(), x

(torch.Size([2, 6, 1, 8]),
 tensor([[[[0.1495, 0.1400, 0.1095, 0.0200, 0.1295, 0.0100, 0.0689, 0.1200]],
 
          [[0.1595, 0.1500, 0.1195, 0.0300, 0.1395, 0.0200, 0.0789, 0.1300]],
 
          [[0.1695, 0.1600, 0.1295, 0.0400, 0.1495, 0.0300, 0.0889, 0.1400]],
 
          [[0.1795, 0.1700, 0.1395, 0.0500, 0.1595, 0.0400, 0.0989, 0.1500]],
 
          [[0.1895, 0.1800, 0.1495, 0.0600, 0.1695, 0.0500, 0.1089, 0.1600]],
 
          [[0.1995, 0.1900, 0.1595, 0.0700, 0.1795, 0.0600, 0.1189, 0.1700]]],
 
 
         [[[0.0889, 0.0500, 0.0995, 0.1095, 0.0689, 0.1200, 0.0889, 0.0100]],
 
          [[0.0989, 0.0600, 0.1095, 0.1195, 0.0789, 0.1300, 0.0989, 0.0200]],
 
          [[0.1089, 0.0700, 0.1195, 0.1295, 0.0889, 0.1400, 0.1089, 0.0300]],
 
          [[0.1189, 0.0800, 0.1295, 0.1395, 0.0989, 0.1500, 0.1189, 0.0400]],
 
          [[0.1289, 0.0900, 0.1395, 0.1495, 0.1089, 0.1600, 0.1289, 0.0500]],
 
          [[0.1389, 0.1000, 0.1495, 0.1595, 0.1189, 0.1700, 0.1389, 0.0600]]]],
        gr

### Convolution Block

We'll go through the convolution layers again. Given the high learning rate we will now see how different weights impact our convolution. 

#### Convolution Block - 1x3 Conv

This time through we'll do all the steps sequentially and just print the outputs. The `conv1` weights still read `[0.0010, 0.0020, 0.0010]` at four decimals, because their gradient was too small to move them visibly even at our learning rate. Since every output channel still uses the same weights, the values remain identical across the channel dimension, while they differ by token position because the updated embeddings differ.
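The `F.unfold` + matmul approach can be sketched in pure Python for a single example shaped `[C, T]` (the height-1 dimension dropped). `unfold_1d` and `conv_via_unfold` are hypothetical helpers that mimic, rather than call, PyTorch; `conv_direct` evaluates the cross-correlation $y_t=\sum_c\sum_u W_{c,u}\,x_{c,t+u}$ directly so the two can be compared. The input is the batch-0 embedding from `In [85]`, rebuilt from each token's first-column value plus `0.01` per channel.

```python
def unfold_1d(x, k, padding, stride):
    """Mimics F.unfold for a [C, T] input and a 1 x k kernel: returns [C*k, T_out] patch columns."""
    C, T = len(x), len(x[0])
    padded = [[0.0] * padding + list(row) + [0.0] * padding for row in x]
    t_out = (T + 2 * padding - k) // stride + 1
    cols = [[0.0] * t_out for _ in range(C * k)]
    for t in range(t_out):
        for c in range(C):
            for u in range(k):
                # Channel-major order, matching conv.view(C_out, -1) on a [C_out, C_in, 1, k] weight.
                cols[c * k + u][t] = padded[c][t * stride + u]
    return cols

def conv_via_unfold(x, weight, padding, stride):
    """weight is [C_out][C_in][k]: flatten each filter to length C_in*k, then multiply by the patches."""
    k = len(weight[0][0])
    cols = unfold_1d(x, k, padding, stride)
    flat = [[w for per_channel in filt for w in per_channel] for filt in weight]
    t_out = len(cols[0])
    return [[sum(f * col[t] for f, col in zip(row, cols)) for t in range(t_out)]
            for row in flat]

def conv_direct(x, weight, padding, stride):
    """Same result straight from y[o][t] = sum_c sum_u W[o][c][u] * x_pad[c][t*stride + u]."""
    k = len(weight[0][0])
    padded = [[0.0] * padding + list(row) + [0.0] * padding for row in x]
    t_out = (len(x[0]) + 2 * padding - k) // stride + 1
    return [[sum(filt[c][u] * padded[c][t * stride + u]
                 for c in range(len(x)) for u in range(k))
             for t in range(t_out)]
            for filt in weight]

# Batch-0 embeddings from the updated forward pass: channel c of each token is base + 0.01 * c.
base = [0.1495, 0.1400, 0.1095, 0.0200, 0.1295, 0.0100, 0.0689, 0.1200]
x = [[b + 0.01 * c for b in base] for c in range(6)]                     # [C=6, T=8]
conv1_w = [[[0.001, 0.002, 0.001] for _ in range(6)] for _ in range(6)]  # [6, 6, k=3]

y = conv_via_unfold(x, conv1_w, padding=1, stride=1)
print([round(v, 4) for v in y[0]])  # [0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023]
```

This matches the first row of the `In [87]` output. With `padding=1, stride=2` the same helpers turn 8 positions into 4 (the `conv2` downsampling), and with `k=1, padding=0, stride=2` they reduce to the residual 1x1 convolution.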

In [86]:
c1_unfolded = F.unfold(x, 
		kernel_size=(c1_kernel_height, c1_kernel_width),  # (1,3)
		padding=(c1_padding_height, c1_padding_width), #(0,1)
		stride=(c1_stride_height, c1_stride_width))#(1,1)
conv1_weigth = conv1.view(n_embd, -1) # [6,6,1,3] > [6,18]
conv1_weigth

tensor([[0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010,
         0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010],
        [0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010,
         0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010],
        [0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010,
         0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010],
        [0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010,
         0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010],
        [0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010,
         0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010],
        [0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010,
         0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010, 0.0010, 0.0020, 0.0010]],
       grad_fn=<ViewB

In [87]:
out = conv1_weigth @ c1_unfolded
out = out.view(batch,n_embd, c1_height_out, c1_width_out)
out.size(), out

(torch.Size([2, 6, 1, 8]),
 tensor([[[[0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023]],
 
          [[0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023]],
 
          [[0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023]],
 
          [[0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023]],
 
          [[0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023]],
 
          [[0.0031, 0.0038, 0.0029, 0.0023, 0.0023, 0.0019, 0.0022, 0.0023]]],
 
 
         [[[0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0024, 0.0011]],
 
          [[0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0024, 0.0011]],
 
          [[0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0024, 0.0011]],
 
          [[0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0024, 0.0011]],
 
          [[0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0024, 0.0011]],
 
          [[0.0018, 0.0023, 0.0028, 0.0029, 0.0028, 0.0030, 0.0024, 0.0011]]]],
        gr

#### Convolution Block - First Batch Norm

`bn_a`'s weight stayed at `1` and its bias only moved to ~6.5e-11, identically across channels, so this layer behaves essentially like a freshly initialized batch norm. Interestingly though, several token positions become negative before our ReLU step, so those positions will again get clipped to zero by the ReLU.

In [88]:
out = bn_a(out)
out.size(), out

(torch.Size([2, 6, 1, 8]),
 tensor([[[[ 0.1816,  0.4146,  0.1164, -0.0700, -0.0514, -0.1828, -0.0906,
            -0.0607]],
 
          [[ 0.1816,  0.4146,  0.1164, -0.0700, -0.0514, -0.1828, -0.0906,
            -0.0607]],
 
          [[ 0.1816,  0.4146,  0.1164, -0.0700, -0.0514, -0.1828, -0.0906,
            -0.0607]],
 
          [[ 0.1816,  0.4146,  0.1164, -0.0700, -0.0514, -0.1828, -0.0906,
            -0.0607]],
 
          [[ 0.1816,  0.4146,  0.1164, -0.0700, -0.0514, -0.1828, -0.0906,
            -0.0607]],
 
          [[ 0.1816,  0.4146,  0.1164, -0.0700, -0.0514, -0.1828, -0.0906,
            -0.0607]]],
 
 
         [[[-0.2118, -0.0524,  0.0781,  0.1320,  0.0947,  0.1517, -0.0161,
            -0.4334]],
 
          [[-0.2118, -0.0524,  0.0781,  0.1320,  0.0947,  0.1517, -0.0161,
            -0.4334]],
 
          [[-0.2118, -0.0524,  0.0781,  0.1320,  0.0947,  0.1517, -0.0161,
            -0.4334]],
 
          [[-0.2118, -0.0524,  0.0781,  0.1320,  0.0947,  0.1517, -0.0

#### Convolution Block - ReLU

We can see a lot more exaggeration now within our ReLU step. We start pushing some values relatively high and still see a set of token positions zero out. If we were training for a long time we'd need to keep an eye on this to make sure we don't end up with dead units that never activate (this is also one reason different initialization schemes are used).

In [89]:
out = F.relu(out) 
out.size(), out

(torch.Size([2, 6, 1, 8]),
 tensor([[[[0.1816, 0.4146, 0.1164, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
 
          [[0.1816, 0.4146, 0.1164, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
 
          [[0.1816, 0.4146, 0.1164, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
 
          [[0.1816, 0.4146, 0.1164, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
 
          [[0.1816, 0.4146, 0.1164, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
 
          [[0.1816, 0.4146, 0.1164, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
 
 
         [[[0.0000, 0.0000, 0.0781, 0.1320, 0.0947, 0.1517, 0.0000, 0.0000]],
 
          [[0.0000, 0.0000, 0.0781, 0.1320, 0.0947, 0.1517, 0.0000, 0.0000]],
 
          [[0.0000, 0.0000, 0.0781, 0.1320, 0.0947, 0.1517, 0.0000, 0.0000]],
 
          [[0.0000, 0.0000, 0.0781, 0.1320, 0.0947, 0.1517, 0.0000, 0.0000]],
 
          [[0.0000, 0.0000, 0.0781, 0.1320, 0.0947, 0.1517, 0.0000, 0.0000]],
 
          [[0.0000, 0.0000, 0.0781, 0.1320, 0.0947, 0.1517, 0.0000, 0.0000]]]],
        gr

#### Convolution Block - 1x3 Conv 1x2 Stride downsample

With our second convolution, you can see that its weights are no longer the same for every output channel. The last two output channels are equal to each other but differ from the first four (they are even negative), and the three kernel positions have moved to different values. Because of this, we can expect the convolution output to follow the same pattern: the first four channels share one value per token position, and the last two channels share a different one.

In [90]:
c2_unfolded = F.unfold(out, 
		kernel_size=(c2_kernel_height, c2_kernel_width),  # (1,3)
		padding=(c2_padding_height, c2_padding_width), #(0,1)
		stride=(c2_stride_height, c2_stride_width))#(1,2)

conv2_weigth = conv2.view(n_embd, -1) # [6,6,1,3] > [6,18]
conv2_weigth

tensor([[ 0.0337,  0.0401,  0.0426,  0.0337,  0.0401,  0.0426,  0.0337,  0.0401,
          0.0426,  0.0337,  0.0401,  0.0426,  0.0337,  0.0401,  0.0426,  0.0337,
          0.0401,  0.0426],
        [ 0.0337,  0.0401,  0.0426,  0.0337,  0.0401,  0.0426,  0.0337,  0.0401,
          0.0426,  0.0337,  0.0401,  0.0426,  0.0337,  0.0401,  0.0426,  0.0337,
          0.0401,  0.0426],
        [ 0.0337,  0.0401,  0.0426,  0.0337,  0.0401,  0.0426,  0.0337,  0.0401,
          0.0426,  0.0337,  0.0401,  0.0426,  0.0337,  0.0401,  0.0426,  0.0337,
          0.0401,  0.0426],
        [ 0.0337,  0.0401,  0.0426,  0.0337,  0.0401,  0.0426,  0.0337,  0.0401,
          0.0426,  0.0337,  0.0401,  0.0426,  0.0337,  0.0401,  0.0426,  0.0337,
          0.0401,  0.0426],
        [-0.0615, -0.0772, -0.0823, -0.0615, -0.0772, -0.0823, -0.0615, -0.0772,
         -0.0823, -0.0615, -0.0772, -0.0823, -0.0615, -0.0772, -0.0823, -0.0615,
         -0.0772, -0.0823],
        [-0.0615, -0.0772, -0.0823, -0.0615, -0.07

In [91]:
out = conv2_weigth @ c2_unfolded
out = out.view(batch,n_embd, c2_height_out, c2_width_out)
out.size(), out

(torch.Size([2, 6, 1, 4]),
 tensor([[[[ 0.1497,  0.1119,  0.0000,  0.0000]],
 
          [[ 0.1497,  0.1119,  0.0000,  0.0000]],
 
          [[ 0.1497,  0.1119,  0.0000,  0.0000]],
 
          [[ 0.1497,  0.1119,  0.0000,  0.0000]],
 
          [[-0.2888, -0.2068,  0.0000,  0.0000]],
 
          [[-0.2888, -0.2068,  0.0000,  0.0000]]],
 
 
         [[[ 0.0000,  0.0526,  0.0883,  0.0307]],
 
          [[ 0.0000,  0.0526,  0.0883,  0.0307]],
 
          [[ 0.0000,  0.0526,  0.0883,  0.0307]],
 
          [[ 0.0000,  0.0526,  0.0883,  0.0307]],
 
          [[ 0.0000, -0.1014, -0.1674, -0.0559]],
 
          [[ 0.0000, -0.1014, -0.1674, -0.0559]]]], grad_fn=<ViewBackward0>))

#### Convolution Block - Second Batch Norm

Our last batch normalization has a similar pattern to our second convolution: its weights are above `1.000` for the first four channels and below `1.000` for the last two, and its biases are positive for the first four and negative for the last two. These parameters don't change how the mean and variance are computed. Each channel is still normalized over the batch and token positions, and the weight and bias then rescale and shift that channel's normalized values. That said, we can see an interesting pattern start to emerge in how the downsampled token position and the channel start to impact our final prediction. Notice, though, that the first four channels are still identical to each other, as are the last two.
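To make the role of the batch norm weight and bias concrete, here is a minimal pure-Python sketch of training-mode batch norm: per channel, the mean and biased variance are taken over the batch and positions, with `eps=1e-5` (PyTorch's default). It runs channel 0 of the second convolution's output through the normalization twice, once with `gamma=1, beta=0` and once with `bn_b`'s first-channel values. Because the printed inputs are rounded to four decimals, the result only matches the `In [93]` output to about three.

```python
import math

def batch_norm_train(x, gamma, beta, eps=1e-5):
    """Training-mode batch norm for x shaped [B][C][T]: per channel, normalize over batch and positions."""
    B, C, T = len(x), len(x[0]), len(x[0][0])
    out = [[[0.0] * T for _ in range(C)] for _ in range(B)]
    for c in range(C):
        vals = [x[b][c][t] for b in range(B) for t in range(T)]
        mean = sum(vals) / len(vals)
        var = sum((v - mean) ** 2 for v in vals) / len(vals)  # biased variance
        for b in range(B):
            for t in range(T):
                x_hat = (x[b][c][t] - mean) / math.sqrt(var + eps)  # statistics ignore gamma/beta
                out[b][c][t] = gamma[c] * x_hat + beta[c]           # per-channel scale and shift
    return out

# Channel 0 of our second convolution's output (values as printed, rounded to 4 decimals).
x = [[[0.1497, 0.1119, 0.0, 0.0]], [[0.0, 0.0526, 0.0883, 0.0307]]]  # [B=2][C=1][T=4]
plain = batch_norm_train(x, gamma=[1.0], beta=[0.0])
scaled = batch_norm_train(x, gamma=[1.0012], beta=[0.0106])
print([[round(v, 4) for v in seq[0]] for seq in scaled])
```

The scaled output is exactly `gamma * plain + beta`: the mean and variance come from the inputs alone, and the learned parameters only act afterwards.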

In [92]:
bn_b.weight, bn_b.bias

(Parameter containing:
 tensor([1.0012, 1.0012, 1.0012, 1.0012, 0.9976, 0.9976], requires_grad=True),
 Parameter containing:
 tensor([ 0.0106,  0.0106,  0.0106,  0.0106, -0.0233, -0.0233],
        requires_grad=True))

In [93]:
out = bn_b(out)
out.size(), out

(torch.Size([2, 6, 1, 4]),
 tensor([[[[ 1.7928,  1.0873, -0.9990, -0.9990]],
 
          [[ 1.7928,  1.0873, -0.9990, -0.9990]],
 
          [[ 1.7928,  1.0873, -0.9990, -0.9990]],
 
          [[ 1.7928,  1.0873, -0.9990, -0.9990]],
 
          [[-1.8412, -1.0411,  0.9776,  0.9776]],
 
          [[-1.8412, -1.0411,  0.9776,  0.9776]]],
 
 
         [[[-0.9990, -0.0190,  0.6474, -0.4269]],
 
          [[-0.9990, -0.0190,  0.6474, -0.4269]],
 
          [[-0.9990, -0.0190,  0.6474, -0.4269]],
 
          [[-0.9990, -0.0190,  0.6474, -0.4269]],
 
          [[ 0.9776, -0.0118, -0.6568,  0.4318]],
 
          [[ 0.9776, -0.0118, -0.6568,  0.4318]]]],
        grad_fn=<NativeBatchNormBackward0>))

### Residual Connection
Now that we've run our convolution block, we have to pass through our residual connection. The residual connection includes the learnable downsampling 1x1 convolution. Interestingly, we saw that this layer learned weights that vary across the channels, a hint that the model is starting to pick up patterns beyond simple token frequencies.

#### Residual Connection - Downsampling 1x1 Convolution 2 Stride
We'll start with the downsampling convolution. Its weights are adjusted along both channel dimensions: they differ between the first four and the last two output channels, and within each row they vary with the input channel, hinting at some interesting patterns in our examples. Notice, though, that despite these weights the output still shows the same repeated pattern: the first four channels match each other, as do the last two.

In [94]:
x_unfolded = F.unfold(x, 
		kernel_size=(res_kernel_height, res_kernel_width),  # (1,1)
		padding=(res_padding_height, res_padding_width), #(0,0)
		stride=(res_stride_height, res_stride_width))#(1,2)

convRes_weigth = convRes.view(n_embd, -1) # [6,6,1,1] > [6,6]
convRes_weigth

tensor([[1.0012, 1.0013, 1.0014, 1.0015, 1.0017, 1.0018],
        [1.0012, 1.0013, 1.0014, 1.0015, 1.0017, 1.0018],
        [1.0012, 1.0013, 1.0014, 1.0015, 1.0017, 1.0018],
        [1.0012, 1.0013, 1.0014, 1.0015, 1.0017, 1.0018],
        [0.9973, 0.9971, 0.9969, 0.9966, 0.9964, 0.9962],
        [0.9973, 0.9971, 0.9969, 0.9966, 0.9964, 0.9962]],
       grad_fn=<ViewBackward0>)

In [95]:
identity = convRes_weigth @ x_unfolded
identity = identity.view(batch,n_embd, res_height_out, res_width_out)
identity.size(), identity

(torch.Size([2, 6, 1, 4]),
 tensor([[[[1.0484, 0.8080, 0.9282, 0.5644]],
 
          [[1.0484, 0.8080, 0.9282, 0.5644]],
 
          [[1.0484, 0.8080, 0.9282, 0.5644]],
 
          [[1.0484, 0.8080, 0.9282, 0.5644]],
 
          [[1.0433, 0.8041, 0.9237, 0.5617]],
 
          [[1.0433, 0.8041, 0.9237, 0.5617]]],
 
 
         [[[0.6846, 0.7479, 0.5644, 0.6846]],
 
          [[0.6846, 0.7479, 0.5644, 0.6846]],
 
          [[0.6846, 0.7479, 0.5644, 0.6846]],
 
          [[0.6846, 0.7479, 0.5644, 0.6846]],
 
          [[0.6813, 0.7443, 0.5617, 0.6813]],
 
          [[0.6813, 0.7443, 0.5617, 0.6813]]]], grad_fn=<ViewBackward0>))

#### Residual Connection - Sum

The residual connection output and our convolutional block output have very different values, so when we combine them we can see some entries becoming more muted, some flipping sign, and some being amplified in the direction one path suggested. Altogether this shows how the residual connection balances out the convolution layers, and how the network can learn to lean on each path for the outputs it's improving on. Since both the residual and convolution paths share the same pattern (the first four channels identical, and the last two identical), that pattern survives when the two paths merge.

In [96]:
x = out + identity
x.size(), x

(torch.Size([2, 6, 1, 4]),
 tensor([[[[ 2.8411,  1.8953, -0.0708, -0.4346]],
 
          [[ 2.8411,  1.8953, -0.0708, -0.4346]],
 
          [[ 2.8411,  1.8953, -0.0708, -0.4346]],
 
          [[ 2.8411,  1.8953, -0.0708, -0.4346]],
 
          [[-0.7979, -0.2370,  1.9013,  1.5393]],
 
          [[-0.7979, -0.2370,  1.9013,  1.5393]]],
 
 
         [[[-0.3144,  0.7289,  1.2118,  0.2577]],
 
          [[-0.3144,  0.7289,  1.2118,  0.2577]],
 
          [[-0.3144,  0.7289,  1.2118,  0.2577]],
 
          [[-0.3144,  0.7289,  1.2118,  0.2577]],
 
          [[ 1.6589,  0.7325, -0.0951,  1.1131]],
 
          [[ 1.6589,  0.7325, -0.0951,  1.1131]]]], grad_fn=<AddBackward0>))

### Output Layer

The output layer saw the largest gradients, so we used a more normal, though still large, learning rate. You'll notice that, regardless of the input, this layer's updated weights still dominate our logit prediction.

#### Output Layers - Adaptive Average Pooling
Before we do the final logit calculation, let's do our pooling. You'll see the same pattern maintained across the channels: the first four are identical, and so are the last two.

In [97]:
x = avgPool(x)
x.size(), x

(torch.Size([2, 6, 1, 1]),
 tensor([[[[1.0578]],
 
          [[1.0578]],
 
          [[1.0578]],
 
          [[1.0578]],
 
          [[0.6015]],
 
          [[0.6015]]],
 
 
         [[[0.4710]],
 
          [[0.4710]],
 
          [[0.4710]],
 
          [[0.4710]],
 
          [[0.8524]],
 
          [[0.8524]]]], grad_fn=<MeanBackward1>))

#### Output Layers - Remove Dimension
We'll again reshape without changing our values. 

In [98]:
x = x.squeeze(2)
x = x.permute(0,2,1)
x.size(), x

(torch.Size([2, 1, 6]),
 tensor([[[1.0578, 1.0578, 1.0578, 1.0578, 0.6015, 0.6015]],
 
         [[0.4710, 0.4710, 0.4710, 0.4710, 0.8524, 0.8524]]],
        grad_fn=<PermuteBackward0>))

#### Output Layers - LM Head aka logits
Now we do our final projection with the LM head. We know that our targets in `y` are `[8, 10]`. The first example did very well: token `8` has by far the highest logit. The second example did not do as well: its target, token `10`, only has the second-highest logit (`0.8833`), behind token `8` (`1.5445`). Both examples rank tokens `8` and `10` above everything else because each row of the updated head is a single repeated value, so every logit is just that value times the sum of the pooled features. This might be because of our weird learning rates, but that's fine, as we'd typically just run more training to fix this over time.
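Because every row of the updated `lm_head` (the `In [76]` output) repeats one value across its six columns, each logit collapses to that value times the sum of the pooled features. A quick pure-Python check, assuming the rounded pooled features printed above are close enough, should reproduce the logits in the next cell to about three decimals and shows that both examples rank token `8` first and token `10` second.

```python
vocab_size = 15
# Updated lm_head rows from the In [76] output: each row repeats one value across all 6 columns.
row_value = [-0.0405] * vocab_size
row_value[8], row_value[10] = 0.4304, 0.2461

def logits_for(features):
    # With a constant row, W[v] @ f == row_value[v] * sum(f).
    total = sum(features)
    return [row_value[v] * total for v in range(vocab_size)]

# Pooled features printed above, one list per example.
pooled = [[1.0578] * 4 + [0.6015] * 2,
          [0.4710] * 4 + [0.8524] * 2]

for f in pooled:
    logits = logits_for(f)
    top2 = sorted(range(vocab_size), key=lambda v: logits[v], reverse=True)[:2]
    print(round(logits[8], 4), round(logits[10], 4), round(logits[0], 4), top2)
```

Any input whose pooled features sum to a positive number gets this same ranking, so at this point the head alone can't tell the two contexts apart.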

In [99]:
logits = lm_head(x)
logits.shape, logits

(torch.Size([2, 1, 15]),
 tensor([[[-0.2201, -0.2201, -0.2201, -0.2201, -0.2201, -0.2201, -0.2201,
           -0.2201,  2.3387, -0.2201,  1.3374, -0.2201, -0.2201, -0.2201,
           -0.2201]],
 
         [[-0.1453, -0.1453, -0.1453, -0.1453, -0.1453, -0.1453, -0.1453,
           -0.1453,  1.5445, -0.1453,  0.8833, -0.1453, -0.1453, -0.1453,
           -0.1453]]], grad_fn=<UnsafeViewBackward0>))

### Updated Loss calculation

Now we'll calculate the updated loss. Our first pass's loss was 2.7081, which is exactly what random guessing gives: $\ln(15) \approx 2.7081$ for our 15-token vocabulary. Since we're passing through the same example and used a very high learning rate, we should see a significant improvement from just one learning step.
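A minimal pure-Python sketch of mean cross-entropy, computed from the logits printed above. It reproduces the first pass's `2.7081` as $\ln 15$, the loss of a uniform guess over our 15 tokens, and the updated loss to about three decimals (the printed logits are rounded). `make_logits` is a hypothetical helper that rebuilds each row of logits from its three distinct values.

```python
import math

def cross_entropy(logits_batch, targets):
    """Mean over the batch of -log softmax(logits)[target], via a stable log-sum-exp."""
    total = 0.0
    for logits, y in zip(logits_batch, targets):
        m = max(logits)
        log_z = m + math.log(sum(math.exp(v - m) for v in logits))
        total += log_z - logits[y]
    return total / len(targets)

vocab_size = 15

# Equal logits mean a uniform guess, so the loss is ln(vocab_size).
uniform_loss = cross_entropy([[0.0] * vocab_size], [0])

def make_logits(at_8, at_10, rest):
    # Our updated logits: one value each for tokens 8 and 10, and a shared value for the rest.
    logits = [rest] * vocab_size
    logits[8], logits[10] = at_8, at_10
    return logits

logits = [make_logits(2.3387, 1.3374, -0.2201), make_logits(1.5445, 0.8833, -0.1453)]
targets = [8, 10]
updated_loss = cross_entropy(logits, targets)
argmax = [max(range(vocab_size), key=lambda v: row[v]) for row in logits]
print(round(uniform_loss, 4), round(updated_loss, 4), argmax)
```

Note that the argmax is token `8` for both examples, even though the second example's target is `10`: the loss can drop without every prediction being right.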

In [100]:
loss

tensor(2.7081, grad_fn=<NllLossBackward0>)

In [101]:
y_flat = y.view(-1)
logits_flat = logits.view(-1, logits.size(-1))
updated_loss = F.cross_entropy(logits_flat, y_flat)
print(updated_loss.shape, updated_loss)

torch.Size([]) tensor(1.4453, grad_fn=<NllLossBackward0>)


In [102]:
f'1 round of training resulted in a loss improvement of {loss.item() - updated_loss.item():.4f}'

'1 round of training resulted in a loss improvement of 1.2628'

## Training SUCCESS!
Our training improved the loss by about **47%** (1.2628 / 2.7081; the amount may vary since we didn't set a seed). There are flaws with this, mainly passing the same example through a second time and using a huge learning rate, but it helps show the fundamentals of what learning does inside a ResNet-style CNN.

## Logit to Token

For our last piece of code in this notebook, we'll convert our logits into actual "next tokens", the goal of this head. To predict the next tokens we:
1. Convert our logits into probabilities.
2. Sample a token id based on those probabilities.
3. Convert the token ids back into tokens.

As you can see, the sampling is non-deterministic. This allows the next token to flow in a more realistic, conversational manner since, if we sample many tokens, we won't always pick the highest-probability id. Some people read this as creativity, or consciousness; others point to it as one of a few key reasons for hallucinations (which are also caused by the fact that models don't fully memorize every piece of training data). In reality, it's just statistics.

### Logit and Input shaping
Similar to other steps, we first have to get our tensors into the right dimensions. We'll start by flattening out our batch dimension so we have a tensor of shape `[B*T, vocab_size]`, i.e., one row of logits per example (here `T` is 1, so the result is `[2, 15]`). We'll also reload our input tokens in the `[B_batch, T_context]` shape so that we have a tensor ready to append to.

In [103]:
pred_logits = logits.flatten(0, 1) 
pred_logits.size(), pred_logits

(torch.Size([2, 15]),
 tensor([[-0.2201, -0.2201, -0.2201, -0.2201, -0.2201, -0.2201, -0.2201, -0.2201,
           2.3387, -0.2201,  1.3374, -0.2201, -0.2201, -0.2201, -0.2201],
         [-0.1453, -0.1453, -0.1453, -0.1453, -0.1453, -0.1453, -0.1453, -0.1453,
           1.5445, -0.1453,  0.8833, -0.1453, -0.1453, -0.1453, -0.1453]],
        grad_fn=<ViewBackward0>))

In [104]:
xgen = tok_for_training[:-1].view(B_batch, T_context)
xgen.size(), xgen

(torch.Size([2, 8]),
 tensor([[14, 13, 10,  1, 12,  0,  6, 11],
         [ 8,  4,  9, 10,  6, 11,  8,  0]]))

### Logit into probabilities. 
If you inspect the logits, you'll notice that they don't sum to `1.000`, and some are negative. Our goal is the probability of each token in our vocab being the next one, so we need to convert our logits to probabilities. The most common approach is Softmax, which rescales a tensor so that the elements of the n-dimensional output lie in the range `(0,1)` and sum to 1 along the specified dimension. It applies the following formula:

$$
\mathrm{Softmax}(x)_i = \frac{\exp(x_i)}{\sum_{j} \exp(x_j)}
$$

*Note that when the input Tensor is a sparse tensor then the unspecified values are treated as $-\infty$. This is handy in steps like attention masking.*
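The formula above, as a minimal pure-Python sketch applied to the first example's updated logits. Subtracting the maximum logit before exponentiating is a standard stability trick: the shift cancels in the ratio, so the result is unchanged, but `exp` can no longer overflow on large logits.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtracting the max doesn't change the result but avoids overflow."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# First example's updated logits: token 8, token 10, and 13 tied tokens.
logits = [-0.2201] * 15
logits[8], logits[10] = 2.3387, 1.3374
probs = softmax(logits)
print(round(probs[8], 4), round(probs[10], 4), round(probs[0], 4), round(sum(probs), 6))
```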

In [105]:
probs = F.softmax(pred_logits, dim=-1)
probs

tensor([[0.0326, 0.0326, 0.0326, 0.0326, 0.0326, 0.0326, 0.0326, 0.0326, 0.4213,
         0.0326, 0.1548, 0.0326, 0.0326, 0.0326, 0.0326],
        [0.0471, 0.0471, 0.0471, 0.0471, 0.0471, 0.0471, 0.0471, 0.0471, 0.2554,
         0.0471, 0.1318, 0.0471, 0.0471, 0.0471, 0.0471]],
       grad_fn=<SoftmaxBackward0>)

### Token Sampling 
Now that we have probabilities, we can perform our sampling. Because next-token sampling is a categorical draw from the model's softmax, `torch.multinomial` is the go-to choice for sampling from discrete distributions. It takes non-negative weights (e.g., $\mathrm{softmax}(\ell/T)$ or $\exp(\ell/T)$, where $\ell$ are the logits and $T$ here is the sampling temperature) and returns indices without having to write complex loops. You can also shape the distribution before sampling with different temperatures or top-k/top-p filtering. We'll keep things simple and just draw a single token.

In cases where you'd want to generate a lot of text (e.g., chatbot use cases), instead of drawing many tokens at once, you'd run the forward pass, sample a token, append it, and run another forward pass, iteratively.

Once we sample the tokens we'll then append them to our input X `xgen` to generate `T_context+1` long examples. 
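A minimal pure-Python sketch of temperature and top-k sampling plus the iterative generate loop, assuming we only have plain lists of logits. `random.choices` plays the role of `torch.multinomial` here (both accept unnormalized non-negative weights), temperature divides the logits before the softmax, and top-k sets everything below the k-th largest logit to $-\infty$. `toy_next_logits` is a HYPOTHETICAL stand-in for a forward pass, used only to show the forward/sample/append loop.

```python
import math
import random

def sample_next(logits, temperature=1.0, top_k=None, rng=random):
    """Draw one token id from softmax(logits / temperature), optionally restricted to the top_k logits."""
    scaled = [v / temperature for v in logits]
    if top_k is not None:
        cutoff = sorted(scaled, reverse=True)[top_k - 1]
        # Anything below the k-th largest logit gets -inf (ties at the cutoff are all kept).
        scaled = [v if v >= cutoff else -math.inf for v in scaled]
    m = max(scaled)
    weights = [math.exp(v - m) for v in scaled]  # unnormalized, like torch.multinomial's input
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

def generate(context, n_new, next_logits, rng=random, **sample_kwargs):
    """Autoregressive loop: run the model, sample one token, append it, and repeat."""
    tokens = list(context)
    for _ in range(n_new):
        tokens.append(sample_next(next_logits(tokens), rng=rng, **sample_kwargs))
    return tokens

def toy_next_logits(tokens, vocab_size=15):
    """HYPOTHETICAL stand-in for a forward pass: strongly favors (last token + 1) mod vocab_size."""
    logits = [0.0] * vocab_size
    logits[(tokens[-1] + 1) % vocab_size] = 5.0
    return logits

context = [14, 13, 10, 1, 12, 0, 6, 11]  # first row of xgen
# top_k=1 keeps only the argmax, so this prints [14, 13, 10, 1, 12, 0, 6, 11, 12, 13, 14]
print(generate(context, n_new=3, next_logits=toy_next_logits, rng=random.Random(0), top_k=1))
# Plain sampling usually follows the favored token, but not always.
print(generate(context, n_new=3, next_logits=toy_next_logits, rng=random.Random(0)))
```

Setting `top_k=1` (or a very low temperature) turns sampling into greedy decoding; leaving both alone keeps the randomness discussed above.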

In [106]:
xcol = torch.multinomial(probs, 1) 
xcol

tensor([[10],
        [11]])

In [107]:
xgen = torch.cat((xgen, xcol), dim=1)
xgen

tensor([[14, 13, 10,  1, 12,  0,  6, 11, 10],
        [ 8,  4,  9, 10,  6, 11,  8,  0, 11]])

### Converting into tokens. 
Now that we have our tokens generated and appended we just have to convert them back into the original text using our tokenizer. We'll also print out our original text to see how close it gets. 

**Original Text**

In [108]:
#expected:
print(tok.decode(tok_for_training[:9].tolist()))
print(tok.decode(tok_for_training[8:].tolist()))

<|endoftext|>ccggaag,ffeed
eddc,ggffeed,gg


**Predicted text**

In [109]:
for i in range(xgen.size()[0]):
    tokens = xgen[i,:].tolist()
    decoded = tok.decode(tokens)
    print(f'batch {i}: {decoded}')

batch 0: <|endoftext|>ccggaag,ffegg
batch 1: eddc,ggffeed,fe


## Conclusion 
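So you might look at this now and say "wait, the predicted text is different from our correct text, yet our loss decreased, what's going on?". To this I'd point to our loss function. Cross-entropy only measures how much probability the model puts on the correct token, so the loss decreases whenever that probability rises and the wrong tokens' probabilities fall. That holds even if the correct token isn't the most likely one (as in our second example) or sampling draws a different token (as in our first example, where the correct token had only about 42% probability). Unless the loss reaches 0, the model isn't perfect.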