# RNN GRU Explainer

This notebook walks through recurrent neural networks (RNNs), focusing on the Gated Recurrent Unit (GRU); the other widely used gated variant is the Long Short-Term Memory (LSTM) network. RNNs have a unique structure. They process data sequentially, taking in both the current input and a hidden state carried over from prior steps, which makes them great for sequential data like time series or streamed data. In other words, the core unit of the RNN takes in both the data you want to use to make an inference and any previous context. While this is great, RNNs come with a few dangers: improper setup and training can lead to vanishing or exploding gradients, and because each step depends on the one before it, their sequential nature can consume a lot of time and resources. As you go through this notebook you can hopefully see why this would be the case.

To help show how the RNN works, we'll use a few sentences from the [linear algebra wiki page](https://en.wikipedia.org/wiki/Linear_algebra) and the [LU decomposition wiki page](https://en.wikipedia.org/wiki/LU_decomposition), as the topic is fitting and the text contains some non-standard patterns. We'll take on a task similar to many other notebooks in this repository: given a few sentences of text, predict other text, in this case the next token.

## Text Prep/Tokenization

We'll start with the common preprocessing step of tokenizing the data. This converts the string text into an array of integers that can be used during the training loop. I've built a very simple byte-pair encoding (BPE) tokenizer whose vocabulary contains each unique character that appears in the text plus the top 5 merges (the most frequent adjacent pairs). This keeps our vocab size small and manageable for this example; production tokenizers typically have vocabularies in the tens of thousands to 100K+ range (GPT-2's has 50,257 tokens). A great library for this is `tiktoken`. To encode text, this tokenizer splits it into characters, repeatedly applies the learned merges (earliest-learned merge first), and then replaces each resulting token with the integer that represents it. This way we turn the text into a numeric array to simplify computing.

In [1]:
import torch
from collections import Counter

In [2]:
class SimpleBPETokenizer:
    def __init__(self, num_merges=5, eot_token='<|endoftext|>'):
        self.num_merges = num_merges
        self.eot_token = eot_token
        self.eot_id = None
        self.merges = []
        self.pair_ranks = {}
        self.vocab = {}
        self.id_to_token = {}

    def _add_token(self, tok):
        if tok in self.vocab:
            return self.vocab[tok]
        i = len(self.vocab)
        self.vocab[tok] = i
        self.id_to_token[i] = tok
        return i

    def _get_bigrams(self, seq):
        for i in range(len(seq) - 1):
            yield (seq[i], seq[i + 1])

    def _merge_once(self, seq, pair):
        a, b = pair
        out = []
        i = 0
        while i < len(seq):
            if i < len(seq) - 1 and seq[i] == a and seq[i + 1] == b:
                out.append(a + b)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        return out

    def train(self, corpus):
        # corpus: list[str]
        text = ''.join(corpus).lower()
        seq = list(text)
        merges = []
        for _ in range(self.num_merges):
            counts = Counter(self._get_bigrams(seq))
            if not counts: break
            best_pair, _ = counts.most_common(1)[0]
            merges.append(best_pair)
            seq = self._merge_once(seq, best_pair)
        self.merges = merges
        self.pair_ranks = {p: i for i, p in enumerate(self.merges)}

        self.vocab = {}
        self.id_to_token = {}
        for ch in sorted(set(text)):
            self._add_token(ch)
        for a, b in self.merges:
            self._add_token(a + b)
        self.eot_id = self._add_token(self.eot_token)

    def encode(self, text, force_last_eot=True):
        # treat literal eot marker as special; remove it from content
        if self.eot_token in text:
            text = text.replace(self.eot_token, '')
        seq = list(text)

        # make sure all seen base chars exist
        for ch in set(seq):
            if ch not in self.vocab:
                self._add_token(ch)

        # greedy BPE using learned pair ranks
        if self.merges:
            while True:
                best_pair, best_rank = None, None
                for p in self._get_bigrams(seq):
                    r = self.pair_ranks.get(p)
                    if r is not None and (best_rank is None or r < best_rank):
                        best_pair, best_rank = p, r
                if best_pair is None:
                    break
                seq = self._merge_once(seq, best_pair)

        # ensure all tokens in seq exist in vocab (e.g., if new chars appeared)
        for tok in seq:
            if tok not in self.vocab:
                self._add_token(tok)

        ids = [self.vocab[tok] for tok in seq]

        # FORCE: append EOT id if not already last
        if force_last_eot:
            if not ids or ids[-1] != self.eot_id:
                ids.append(self.eot_id)

        return ids

    def decode(self, ids):
        # drop trailing EOT if present
        if ids and self.eot_id is not None and ids[-1] == self.eot_id:
            ids = ids[:-1]
        toks = [self.id_to_token[i] for i in ids]
        return ''.join(toks)


In [3]:
raw_example_1 = r'''Linear algebra is central to almost all areas of mathematics. For instance, linear algebra is fundamental in modern presentations of geometry, including for defining basic objects such as lines, planes and rotations. Also, functional analysis, a branch of mathematical analysis, may be viewed as the application of linear algebra to function spaces.'''
raw_example_2 = r'''In numerical analysis and linear algebra, lower–upper (LU) decomposition or factorization factors a matrix as the product of a lower triangular matrix and an upper triangular matrix (see matrix multiplication and matrix decomposition).'''


In [4]:
tok = SimpleBPETokenizer(num_merges=5)
tok.train([raw_example_1,raw_example_2])
tok.merges

[(' ', 'a'), ('a', 't'), ('i', 'n'), (' ', 'm'), ('i', 'o')]
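To see where a merge like `(' ', 'a')` comes from, here is a standalone sketch of a single merge round. It mirrors what `SimpleBPETokenizer.train` does with `_get_bigrams`, `Counter.most_common`, and `_merge_once`, but the function names and the toy string `"aaabdaaabac"` are illustrative assumptions, not part of the notebook.

```python
from collections import Counter

def most_frequent_pair(seq):
    # count every adjacent (left, right) pair of tokens
    counts = Counter(zip(seq, seq[1:]))
    pair, _ = counts.most_common(1)[0]
    return pair

def merge_pair(seq, pair):
    # replace each non-overlapping occurrence of `pair`, scanning left to right
    a, b = pair
    out, i = [], 0
    while i < len(seq):
        if i < len(seq) - 1 and seq[i] == a and seq[i + 1] == b:
            out.append(a + b)
            i += 2
        else:
            out.append(seq[i])
            i += 1
    return out

seq = list("aaabdaaabac")
pair = most_frequent_pair(seq)
merged = merge_pair(seq, pair)
print(pair)    # ('a', 'a')
print(merged)  # ['aa', 'a', 'b', 'd', 'aa', 'a', 'b', 'a', 'c']
```

Repeating this round `num_merges` times, each time on the already-merged sequence, is what produces the ordered `tok.merges` list shown above.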

In [5]:
tok.vocab

{' ': 0,
 '(': 1,
 ')': 2,
 ',': 3,
 '.': 4,
 'a': 5,
 'b': 6,
 'c': 7,
 'd': 8,
 'e': 9,
 'f': 10,
 'g': 11,
 'h': 12,
 'i': 13,
 'j': 14,
 'l': 15,
 'm': 16,
 'n': 17,
 'o': 18,
 'p': 19,
 'r': 20,
 's': 21,
 't': 22,
 'u': 23,
 'v': 24,
 'w': 25,
 'x': 26,
 'y': 27,
 'z': 28,
 '–': 29,
 ' a': 30,
 'at': 31,
 'in': 32,
 ' m': 33,
 'io': 34,
 '<|endoftext|>': 35}

In [6]:
vocab_size = len(tok.vocab)
vocab_size

36

In [7]:
eot = tok.eot_id
tokens = []
for example in [raw_example_1, raw_example_2]:
    tokens.extend([eot])
    tokens.extend(tok.encode(example.lower()))
all_tokens = torch.tensor(tokens, dtype=torch.long)
all_tokens

tensor([35, 15, 32,  9,  5, 20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0,  7,
         9, 17, 22, 20,  5, 15,  0, 22, 18, 30, 15, 16, 18, 21, 22, 30, 15, 15,
        30, 20,  9,  5, 21,  0, 18, 10, 33, 31, 12,  9, 16, 31, 13,  7, 21,  4,
         0, 10, 18, 20,  0, 32, 21, 22,  5, 17,  7,  9,  3,  0, 15, 32,  9,  5,
        20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0, 10, 23, 17,  8,  5, 16,
         9, 17, 22,  5, 15,  0, 32, 33, 18,  8,  9, 20, 17,  0, 19, 20,  9, 21,
         9, 17, 22, 31, 34, 17, 21,  0, 18, 10,  0, 11,  9, 18, 16,  9, 22, 20,
        27,  3,  0, 32,  7, 15, 23,  8, 32, 11,  0, 10, 18, 20,  0,  8,  9, 10,
        32, 32, 11,  0,  6,  5, 21, 13,  7,  0, 18,  6, 14,  9,  7, 22, 21,  0,
        21, 23,  7, 12, 30, 21,  0, 15, 32,  9, 21,  3,  0, 19, 15,  5, 17,  9,
        21, 30, 17,  8,  0, 20, 18, 22, 31, 34, 17, 21,  4, 30, 15, 21, 18,  3,
         0, 10, 23, 17,  7, 22, 34, 17,  5, 15, 30, 17,  5, 15, 27, 21, 13, 21,
         3, 30,  0,  6, 20,  5, 17,  7, 

# Modeling

The model's forward pass now takes the tokenized input, runs it through several layers of linear algebra, and then "predicts" the next token. While the weights are still noise (as you will see in this example), this produces gibberish. Training turns that noise into patterns during the "backward pass", as you'll see. We'll show 3 steps that are focused on training:
1. **Data Loading** `x, y = train_loader.next_batch()` - this step pulls enough tokens from the raw data to complete a forward and backward pass. If the model is used for inference only, this step is replaced by taking in the inference input and preparing (tokenizing) it the same way for the forward pass.
2. **Forward Pass** `logits, loss = model(x, y)` - uses the data and the model architecture to predict the next token. When training we also compare the predictions against the expected tokens to get the loss, but in inference, we use the logits to complete the inference task.
3. **Backpropagation, aka Backward Pass & Training** `loss.backward(); optimizer.step()` - uses gradients (derivatives of the loss with respect to each parameter) to work out how much each parameter contributed to the error in the prediction, and then makes very small adjustments to those parameters in the direction that reduces the loss, with the hope it improves future predictions.

Then we'll show a final **Forward Pass** with the weights updated in step 3.

## Data Loading

To start, we need enough data to run the forward and backward passes. In real practice the total dataset is usually too big to hold in memory all at once, so we read just enough of it to run the passes, leaving memory and compute for the passes instead of for holding static data.
To do that, we need the batch size and the model's context length, which together determine how much data we need. These two values also become 2 of the 3 dimensions of the tensor that enters the model (the third, the embedding dimension, is added later).
- **Batch Size (B)** - This is the number of examples you'll train on in a single pass.
- **Context Length (T)** - This is the max number of tokens that a model can use in a single pass to generate the next token. If an example is shorter than this, it can be padded.
  
*Ideally both B and T are powers of 2 (or at least divisible by a large power of 2) to work nicely with chip architecture. This is a common theme across the board.*

In [8]:
B = 2 # Batch
T = 8 # context length

Next, we pull enough tokens from our long `all_tokens` tensor for one forward pass. To fill `B` rows (batches) of `T` tokens (context length) each, we need `B*T` tokens. Since training predicts each token from the tokens before it in the context, we also need 1 more token at the end, so that the last position of the last example has a next token to validate against. That makes `B*T + 1` tokens in total (17 here).

In [9]:
current_position = 0
tok_for_training = all_tokens[current_position:current_position + B*T +1 ]
tok_for_training

tensor([35, 15, 32,  9,  5, 20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0])

Now that we have our initial tokens to train on, we need to convert them into matrices that are ready for training. In this step we'll create our batches and set up two different tensors: 1/ the input tokens `x`, which should result in 2/ the output tokens `y`. To create each example in the batch, every `T` consecutive tokens are placed into their own row.

Recall that training takes in a sequence of tokens up to the length of the context and then predicts the next token, and that when we extracted `tok_for_training` we added 1 extra token so that we can evaluate the prediction for the last example. Because of this, the input `x` is every token except the last one (`[:-1]`).

It might be natural to think the output `y` would then just be the last token, but that would waste valuable training signal. Yes, there is the example that fills the full context `T`, but every shorter context of length `n` (where `n < T`) can also be used for training, since we have its `n+1`th token available. Consider the following example:

sentence: `Hi I am learning`. This sentence contains the following "next tokens" that can be learned:
1. x: Hi I am  | y: learning
2. x: Hi I     | y: am
3. x: Hi       | y: I

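Because every prefix gives us a training pair, `y` is simply the input shifted left by one token: it starts at the second token and goes all the way to the extra token we added for the last example (`[1:]`). Position `i` of `y` is the target for the prefix that ends at position `i` of `x`.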
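Here is the "Hi I am learning" example expressed with that shift. This is a word-level toy sketch (words stand in for tokens); the variable names are illustrative only.

```python
# Word-level toy version of the "Hi I am learning" example (words stand in for tokens)
words = ["Hi", "I", "am", "learning"]
x = words[:-1]  # inputs: every token except the last
y = words[1:]   # targets: the same sequence shifted left by one

# position i of y is the target for the prefix x[:i + 1]
pairs = [(x[: i + 1], y[i]) for i in range(len(x))]
for context, target in pairs:
    print(" ".join(context), "->", target)
# Hi -> I
# Hi I -> am
# Hi I am -> learning
```

One shifted pair of sequences therefore yields as many training examples as there are positions, which is why `x` and `y` have the same shape.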


We will now put this together and do the following:
1. Extract the input `x` and then split it into an example for each batch `B`
2. Extract the output `y` and then split it into an example for each batch `B`

*Note: `view` can take `-1`, which lets PyTorch infer that dimension so we would not need to pass in `T`. But given how many tensors we'll work with, we want to control the dimensions explicitly and error out if they do not match our expectations.*

In [10]:
x=tok_for_training[:-1].view(B, T)
x

tensor([[35, 15, 32,  9,  5, 20, 30, 15],
        [11,  9,  6, 20,  5,  0, 13, 21]])

In [11]:
y=tok_for_training[1:].view(B, T)
y

tensor([[15, 32,  9,  5, 20, 30, 15, 11],
        [ 9,  6, 20,  5,  0, 13, 21,  0]])
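The slicing above can be wrapped into a small loader that plays the role of `train_loader.next_batch()` from the overview. This is a minimal sketch: the class name `SimpleDataLoader` and the wrap-around behavior are assumptions for illustration, not something this notebook defines.

```python
import torch

class SimpleDataLoader:
    """Hypothetical minimal loader: serves (x, y) batches of shape (B, T) in order."""

    def __init__(self, tokens, B, T):
        self.tokens = tokens
        self.B, self.T = B, T
        self.current_position = 0

    def next_batch(self):
        B, T = self.B, self.T
        # wrap around to the start if fewer than B*T + 1 tokens remain
        if self.current_position + B * T + 1 > len(self.tokens):
            self.current_position = 0
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)   # targets, shifted by one token
        # advance by B*T so the extra target token becomes the next batch's first input
        self.current_position += B * T
        return x, y

loader = SimpleDataLoader(torch.arange(40), B=2, T=8)
x, y = loader.next_batch()
print(x)
print(y)
```

Advancing by `B*T` rather than `B*T + 1` means no token is skipped as an input between consecutive batches.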

## Forward pass

<img src="explainer_screenshots/gru/full_network_gru.png" width="200">

During training, the forward pass takes a sequence of tokens in and, at every position, scores how likely each vocabulary token is to come next. This step as we'll look at it is focused on training: we pass in the input `x`, carry that input through the layers, and produce a tensor of unnormalized scores for every token in the vocabulary, something we call `logits` (a softmax turns them into probabilities). Since this is an RNN, we pass each example through one token at a time, feeding the hidden state from each step into the next. For example, if we have 3 tokens in the example and we are deriving the 4th, we would run our GRU block 3 times:
1. First with the embedding for token `[0]` and no previous hidden state (we start from zeros)
2. Then with the embedding for token `[1]` and the hidden state the GRU produced for token `[0]` as the previous $H_{t-1}$
3. Finally with the embedding for token `[2]` and the hidden state the GRU produced for token `[1]` as the previous $H_{t-1}$

At the end of the forward pass, we compare the predicted scores to the actual next tokens in `y` and calculate the cross-entropy `loss`.
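PyTorch's built-in `nn.GRU` runs this same step-by-step loop internally. The sketch below only checks the shapes involved, using the notebook's sizes and random stand-in embeddings (an assumption; the notebook builds its own GRU by hand instead).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, n_embd, hidden_size = 2, 8, 4, 2
gru = nn.GRU(input_size=n_embd, hidden_size=hidden_size, batch_first=True)
x = torch.randn(B, T, n_embd)  # random stand-in for the embedded tokens

# h0 defaults to zeros when omitted, matching "no previous hidden state"
out, h_n = gru(x)
print(out.shape)  # torch.Size([2, 8, 2]): one hidden state per time step
print(h_n.shape)  # torch.Size([1, 2, 2]): (num_layers, B, hidden_size)
# for a single-layer, one-direction GRU, the last step's output is the final hidden state
print(torch.allclose(out[:, -1, :], h_n[0]))  # True
```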

*Note that we initialize the layers to constant values to make the numbers easy to follow. In reality, layers are usually initialized randomly (for example, from a normal distribution), with some scaling based on the layer's size to keep the weights properly noisy. We will not cover initialization in this series.*

We first re-derive the batch size and context length from the input's shape, so the code works for any input size.

In [12]:
import torch.nn as nn

In [13]:
B_batch, T_context = x.size()
B_batch, T_context

(2, 8)

In [14]:
n_embd = 4 # level of embedding of input tokens
n_embd, vocab_size

(4, 36)

### Input Layer

<img src="explainer_screenshots/gru/input_layer.png" width="200">

We'll first create an **embedding layer** for our input tokens. Since the RNN processes the input one token at a time, order is built in and positional embeddings are not needed, though they can always be added to give the model more depth. The embedding weights are `vocab_size x n_embd`: each row is simply the learned vector for one token. The larger the embedding dimension, the more information each token's vector can hold, so the model can learn more complex patterns.

After the embedding layer we'll perform dropout. The **dropout layer** randomly sets each value to 0 with probability `p` and scales the surviving values by $1/(1-p)$ so that the expected value is unchanged. Since RNNs recycle the same parameters across timesteps, features can co-adapt and overfit quickly. Variational dropout, which reuses one dropout mask across all timesteps of a sequence, is a common way to break these temporal co-adaptations and add needed noise where there's little implicit regularization; this notebook uses PyTorch's standard `nn.Dropout`, which samples a new mask for every element. Other model architectures like transformers have strong built-in stabilizers (LayerNorm, residual connections) and massive pretraining, so they generally need far less dropout.
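For contrast with the standard `nn.Dropout` used below, here is a minimal sketch of variational dropout, assuming inputs shaped `(B, T, E)`. The function name `variational_dropout` is hypothetical; it is not a PyTorch API.

```python
import torch

def variational_dropout(x, p=0.1, training=True):
    """Hypothetical helper: one dropout mask per (sequence, feature), shared over time.

    x has shape (B, T, E). A minimal sketch, not PyTorch's nn.Dropout.
    """
    if not training or p == 0.0:
        return x
    B, T, E = x.shape
    # sample a keep-mask of shape (B, 1, E); broadcasting repeats it over all T steps
    keep = (torch.rand(B, 1, E, device=x.device) >= p).to(x.dtype)
    return x * keep / (1.0 - p)

torch.manual_seed(0)
x = torch.ones(2, 8, 4)
out = variational_dropout(x, p=0.5)
# every time step in a sequence has the same features zeroed out
print(torch.equal(out[:, :1, :].expand_as(out), out))  # True
```

The only difference from ordinary inverted dropout is the mask's shape: `(B, 1, E)` instead of `(B, T, E)`.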



**Embedding** 

To start, we'll initialize every embedding weight to 1.0 so that all tokens are weighted equally. We'll also set our embedding dimension to 4 to allow for some level of complexity. Since all the initial weights are equal, we expect every token to map to the same embedding vector, `[1, 1, 1, 1]`, for now.

In [15]:
wte = nn.Embedding(vocab_size, n_embd)
torch.nn.init.ones_(wte.weight)
wte.weight

Parameter containing:
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]], requires_grad=True)

In [16]:
x = wte(x)
x.shape, x

(torch.Size([2, 8, 4]),
 tensor([[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],
 
         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]], grad_fn=<EmbeddingBackward0>))
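Under the hood, an embedding lookup is just row indexing into the weight matrix. This sketch uses PyTorch's default random initialization (instead of the all-ones weights above) so the rows are distinguishable; the token IDs are a few taken from `x`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, n_embd = 36, 4
wte = nn.Embedding(vocab_size, n_embd)  # default (random normal) init this time
ids = torch.tensor([[35, 15, 32], [11, 9, 6]])  # token IDs, shape (B=2, T=3)

looked_up = wte(ids)       # shape (2, 3, 4)
indexed = wte.weight[ids]  # plain row indexing into the weight matrix
print(looked_up.shape)                  # torch.Size([2, 3, 4])
print(torch.equal(looked_up, indexed))  # True
```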

**Dropout**

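Now, before our recurrent unit, we'll perform dropout. Dropout randomly zeroes each value with probability `p`, effectively removing that feature from this pass's prediction. Since this is Bernoulli-based (inverted) dropout, in addition to zeroing out values, the surviving entries are scaled by $1/(1-p)$. During training this helps with generalization and keeps the model from relying too heavily on any single feature. You can quickly see dropout's impact on the embeddings. With `p = 0.1`, one application turns each surviving 1.0 into $1/0.9 \approx 1.1111$; the 1.2346 ($= 1/0.9^2$) in the output below indicates the dropout cell was executed twice, so dropout was applied twice.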

In [19]:
dropout = nn.Dropout(0.1)
dropout

Dropout(p=0.1, inplace=False)

In [20]:
x = dropout(x)
x

tensor([[[1.2346, 1.2346, 1.2346, 1.2346],
         [0.0000, 1.2346, 1.2346, 1.2346],
         [1.2346, 1.2346, 0.0000, 1.2346],
         [0.0000, 1.2346, 1.2346, 1.2346],
         [1.2346, 1.2346, 1.2346, 0.0000],
         [1.2346, 1.2346, 1.2346, 1.2346],
         [0.0000, 1.2346, 1.2346, 1.2346],
         [0.0000, 1.2346, 1.2346, 0.0000]],

        [[1.2346, 1.2346, 1.2346, 1.2346],
         [1.2346, 1.2346, 1.2346, 1.2346],
         [1.2346, 1.2346, 1.2346, 0.0000],
         [1.2346, 1.2346, 1.2346, 1.2346],
         [1.2346, 1.2346, 0.0000, 1.2346],
         [1.2346, 1.2346, 1.2346, 1.2346],
         [1.2346, 1.2346, 1.2346, 1.2346],
         [0.0000, 1.2346, 1.2346, 1.2346]]], grad_fn=<MulBackward0>)
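The same inverted dropout can be written by hand, which also shows where the 1.1111 vs. 1.2346 values come from. This is a sketch of the math, not `nn.Dropout`'s internal implementation; the last lines show that in `eval()` mode `nn.Dropout` passes inputs through unchanged.

```python
import torch

torch.manual_seed(0)
p = 0.1
x = torch.ones(2, 8, 4)

# inverted (Bernoulli) dropout by hand: keep each element with probability 1 - p
keep = (torch.rand_like(x) >= p).to(x.dtype)
once = x * keep / (1 - p)
print(once[once != 0].unique())  # survivors are 1 / 0.9, about 1.1111

# applying dropout a second time compounds the scaling to 1 / 0.9**2, about 1.2346
keep2 = (torch.rand_like(x) >= p).to(x.dtype)
twice = once * keep2 / (1 - p)
print(twice[twice != 0].unique())

# in eval mode nn.Dropout is the identity
drop = torch.nn.Dropout(p)
drop.eval()
print(torch.equal(drop(x), x))  # True
```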

### Recurrent Unit - GRU

**Recurrent Block**

<img src="explainer_screenshots/gru/gru_details.png" width="600">

From Dive into Deep Learning (https://d2l.ai/chapter_recurrent-modern/gru.html):

> Here, the LSTM's three gates are replaced by two: the reset gate and the update gate. As with LSTMs, these gates are given sigmoid activations, forcing their values to lie in the interval $(0, 1)$. Intuitively, the reset gate controls how much of the previous state we might still want to remember. Likewise, an update gate would allow us to control how much of the new state is just a copy of the old one.


In the equations below, $z$ is the update gate, $r$ the reset gate, and $\tilde{h}_t$ the candidate hidden state.

The code splits the gate computations across two linear layers:

* `x2g` = "input-to-gates" affine layer: maps the current input $x_t \in \mathbb{R}^{B\times E}$ (with $E$ = `n_embd`) to a concatenated vector of length $3H$ (with $H$ = `hidden_size`). Splitting it gives $z_x, r_x, n_x \in \mathbb{R}^{B\times H}$.
* `h2g` = "hidden-to-gates" affine layer: maps the previous hidden state $h_{t-1} \in \mathbb{R}^{B\times H}$ to another vector of length $3H$. Splitting it gives $z_h, r_h, n_h \in \mathbb{R}^{B\times H}$.
* `zx`, `rx`, `nx` are the pre-activation contributions from the input to the GRU's update, reset, and "new state" (candidate) parts, respectively.
* `zh`, `rh`, `nh` are the corresponding contributions from the previous hidden state.

Putting it together, the code below computes:

$$
\begin{align}
z_x, r_x, n_x &= \mathrm{x2g}(x_t)\ \text{chunked into 3 parts}, \\
z_h, r_h, n_h &= \mathrm{h2g}(h_{t-1})\ \text{chunked into 3 parts}, \\
z &= \sigma(z_x + z_h), \\
r &= \sigma(r_x + r_h), \\
r_{nh} &= r \odot n_h, \\
\tilde{h}_t &= \tanh(n_x + r_{nh}), \\
h_t &= (1 - z)\odot \tilde{h}_t + z\odot h_{t-1}.
\end{align}
$$

Notes: $z$ is the update gate, $r$ the reset gate, and $\tilde{h}_t$ is the candidate hidden state. Each "chunk" has shape $(B, H)$, and the "2g" in `x2g` and `h2g` reads as "to gates."


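For comparison, the d2l chapter linked above writes the GRU as follows, using matrix products for the weight terms ($\odot$ denotes only element-wise products):

$$
\begin{align}
\mathbf{Z}_t &= \sigma\!\left(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z\right),\\
\mathbf{R}_t &= \sigma\!\left(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r\right),\\
\tilde{\mathbf{H}}_t &= \tanh\!\left(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h\right),\\
\mathbf{H}_t &= \mathbf{Z}_t \odot \mathbf{H}_{t-1} + \left(\mathbf{1}-\mathbf{Z}_t\right)\odot \tilde{\mathbf{H}}_t .
\end{align}
$$

One small difference: here the reset gate is applied to $\mathbf{H}_{t-1}$ before multiplying by $\mathbf{W}_{hh}$, while the code applies it after the hidden projection ($r \odot n_h$), which is the convention PyTorch's `nn.GRU` uses.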
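As a sanity check that the chunked `x2g`/`h2g` math matches PyTorch, the sketch below recomputes one step by hand from an `nn.GRUCell`'s own weights and compares the result. One detail to watch: PyTorch stacks its gate weights in (r, z, n) order, whereas the notebook chunks its own layers as (z, r, n). The sizes and random inputs are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
E, H, B = 4, 2, 3
cell = nn.GRUCell(E, H)
x = torch.randn(B, E)
h = torch.randn(B, H)

def manual_gru_cell(x, h, cell):
    # PyTorch stacks gate weights as (r, z, n) along dim 0
    W_ir, W_iz, W_in = cell.weight_ih.chunk(3, dim=0)
    W_hr, W_hz, W_hn = cell.weight_hh.chunk(3, dim=0)
    b_ir, b_iz, b_in = cell.bias_ih.chunk(3)
    b_hr, b_hz, b_hn = cell.bias_hh.chunk(3)
    r = torch.sigmoid(x @ W_ir.T + b_ir + h @ W_hr.T + b_hr)   # reset gate
    z = torch.sigmoid(x @ W_iz.T + b_iz + h @ W_hz.T + b_hz)   # update gate
    n = torch.tanh(x @ W_in.T + b_in + r * (h @ W_hn.T + b_hn))  # candidate, r applied after projection
    return (1 - z) * n + z * h

with torch.no_grad():
    ours = manual_gru_cell(x, h, cell)
    ref = cell(x, h)
print(torch.allclose(ours, ref, atol=1e-6))  # True
```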

In [None]:
hidden_size = 2

In [None]:
x2g = nn.Linear(n_embd, 3 * hidden_size, bias=True)
torch.nn.init.constant_(x2g.weight, 0.500)
torch.nn.init.zeros_(x2g.bias)
x2g.weight.size(), x2g.weight, x2g.bias.size(), x2g.bias

In [None]:
h2g = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
torch.nn.init.constant_(h2g.weight, 0.250)
torch.nn.init.zeros_(h2g.bias)
h2g.weight.size(), h2g.weight, h2g.bias.size(), h2g.bias

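We'll walk through the first time step ($t = 0$) by hand; the remaining time steps are handled in a loop afterward. First, we select the embeddings at $t = 0$ and project them into the three gate pre-activations.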

In [None]:
x_0 = x[:, 0, :]
x_0

In [None]:
x_t = x2g(x_0)
zx, rx, nx = x_t.chunk(3, dim=-1)

x_t.size(), 'zx', zx, zx.size(), 'rx',rx, rx.size(),'nx',nx, nx.size()

Now we compute the contributions from the previous hidden state. Since this is time step 0, nothing precedes the first token, so we initialize the previous hidden state `h_init` to zeros.

In [None]:
h_init = torch.zeros(B_batch, hidden_size) 
h_init

In [None]:
h_t = h2g(h_init)
zh, rh, nh = h_t.chunk(3, dim=-1)

h_t.size(), 'zh', zh, zh.size(), 'rh',rh, rh.size(),'nh',nh, nh.size()

Next, we combine the contributions from the input and from the previous hidden state. The gates use a sigmoid so their values fall between 0 and 1:

https://docs.pytorch.org/docs/stable/generated/torch.nn.Sigmoid.html

In [None]:
z = torch.sigmoid(zx + zh)
z.size(), z

In [None]:
r = torch.sigmoid(rx + rh)
r.size(), r

In [None]:
r_nh = r*nh
n = torch.tanh(nx + r_nh)

n.size(), n

In [None]:
h_t = (1 - z) * n + z * h_init
h_t.size(), h_t
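Because the weights are constants, time step 0 can be checked by hand if we skip dropout (an assumption: with dropout active, the numbers above will differ from run to run). With an all-ones embedding, each input pre-activation is $4 \times 0.5 \times 1 = 2$ and each hidden contribution is 0, so $h_0 = (1 - \sigma(2))\tanh(2) \approx 0.1192 \times 0.9640 \approx 0.1149$:

```python
import torch
import torch.nn as nn

B_batch, n_embd, hidden_size = 2, 4, 2

# same constant initialization as the notebook's x2g / h2g
x2g = nn.Linear(n_embd, 3 * hidden_size)
h2g = nn.Linear(hidden_size, 3 * hidden_size)
nn.init.constant_(x2g.weight, 0.5)
nn.init.zeros_(x2g.bias)
nn.init.constant_(h2g.weight, 0.25)
nn.init.zeros_(h2g.bias)

x_0 = torch.ones(B_batch, n_embd)           # all-ones embedding, dropout skipped
h_init = torch.zeros(B_batch, hidden_size)  # no previous hidden state

with torch.no_grad():
    zx, rx, nx = x2g(x_0).chunk(3, dim=-1)     # each entry: 4 * 0.5 * 1.0 = 2.0
    zh, rh, nh = h2g(h_init).chunk(3, dim=-1)  # all zeros
    z = torch.sigmoid(zx + zh)                 # sigmoid(2), about 0.8808
    r = torch.sigmoid(rx + rh)                 # sigmoid(2), about 0.8808
    n = torch.tanh(nx + r * nh)                # tanh(2), about 0.9640
    h_t = (1 - z) * n + z * h_init             # about 0.1192 * 0.9640 = 0.1149

print(h_t)
```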

## Now loop over the remaining time steps

First, we need to keep the hidden state from every time step, since each one is used to predict the next token during training, so let's start collecting them.

In [None]:
hs = []
hs.append(h_t.unsqueeze(1))
hs

The loop starts at `t = 1` because we already computed time step 0 by hand.

In [21]:
for t in range(1,T):
    x_t = x[:, t, :]
    h_prev = h_t
    zx, rx, nx = x2g(x_t).chunk(3, dim=-1)
    zh, rh, nh = h2g(h_prev).chunk(3, dim=-1)
    z = torch.sigmoid(zx + zh)
    r = torch.sigmoid(rx + rh)
    n = torch.tanh(nx + r * nh)
    h_t = (1 - z) * n + z * h_prev
    
    print(f't: {t}')
    print(h_t)
    hs.append(h_t.unsqueeze(1))


In [None]:
hs

Concatenate the per-step hidden states along the time dimension, giving one tensor of shape `(B, T, hidden_size)` for both examples in the batch.

In [None]:
x = torch.cat(hs, dim=1) 
x

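**Dropout** on the GRU outputs, again for regularization (dropout does not fix vanishing or exploding gradients; the GRU's gating and gradient clipping are the usual tools for those)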

In [None]:
x = dropout(x)
x.size(), x

**Output Head**
Projects each hidden state from `hidden_size` up to `vocab_size`, giving one logit per vocabulary token.

In [None]:
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
torch.nn.init.ones_(lm_head.weight)
lm_head.weight.size(), lm_head.weight

In [None]:
logits = lm_head(x)

logits.shape, logits

**Loss**

In [None]:
import torch.nn.functional as F

In [None]:
y_flat = y.view(-1)
y_flat.shape, y_flat

In [None]:
logits_flat = logits.view(-1, logits.size(-1))
logits_flat.shape, logits_flat

In [None]:
loss = F.cross_entropy(logits_flat, y_flat)
loss.shape, loss
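Because every row of `lm_head.weight` is all ones, every logit in a row is the same sum of the hidden state. The softmax is therefore uniform ($1/36$ per token), and the loss is exactly $\ln(36) \approx 3.5835$ no matter what the targets or dropout masks are. The sketch below reproduces that with arbitrary constant logits and random targets (both assumptions standing in for the notebook's tensors):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, n_rows = 36, 16  # 16 = B * T positions in the batch
logits_flat = torch.full((n_rows, vocab_size), 7.3)  # any constant value works
y_flat = torch.randint(0, vocab_size, (n_rows,))     # arbitrary targets

loss = F.cross_entropy(logits_flat, y_flat)
print(loss.item(), math.log(vocab_size))  # both about 3.5835
```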

## Backpropagation

In [None]:
lm_head.zero_grad()
h2g.zero_grad()
x2g.zero_grad()
wte.zero_grad()


# validate gradients
lm_head.weight.grad, wte.weight.grad

In [None]:
loss.backward()

In [None]:
lm_head.weight.grad, wte.weight.grad
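The all-ones `lm_head` has a side effect worth checking. The gradient of the loss with respect to each row of logits is `softmax - one_hot`, whose entries sum to 0; multiplying by an all-ones weight matrix sums those entries, so (up to rounding) no gradient reaches the hidden states. If that holds here, `wte.weight.grad` and the `x2g`/`h2g` gradients should come out as zeros, and only `lm_head` actually changes in the update below. A minimal sketch with random stand-in hidden states:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, vocab_size, n_rows = 2, 36, 16
h = torch.randn(n_rows, hidden_size, requires_grad=True)  # stand-in for GRU outputs
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
nn.init.ones_(lm_head.weight)
y_flat = torch.randint(0, vocab_size, (n_rows,))

loss = F.cross_entropy(lm_head(h), y_flat)
loss.backward()

# d(loss)/d(logits) = softmax - one_hot sums to 0 per row; an all-ones head sums it,
# so essentially nothing flows back into h
print(h.grad.abs().max())               # ~0 (floating-point rounding only)
print(lm_head.weight.grad.abs().max())  # clearly non-zero
```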

## Learning 

In [None]:
## Huge learning rate to emphasize
learning_rate = 5.000

In [None]:
with torch.no_grad():
    lm_head.weight -= learning_rate * lm_head.weight.grad
lm_head.weight

In [None]:
with torch.no_grad():
    h2g.weight -= learning_rate * h2g.weight.grad
    h2g.bias -= learning_rate * h2g.bias.grad
h2g.weight, h2g.bias

In [None]:
with torch.no_grad():
    x2g.weight -= learning_rate * x2g.weight.grad
    x2g.bias -= learning_rate * x2g.bias.grad
x2g.weight, x2g.bias

In [None]:
with torch.no_grad():
    wte.weight -= learning_rate * wte.weight.grad
wte.weight
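The manual `weight -= learning_rate * grad` updates above are exactly what plain SGD does in `optimizer.step()` from the overview. The sketch below checks this on a small stand-in model (the model, data, and squared-output loss are illustrative assumptions):

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
learning_rate = 0.5
model_a = nn.Linear(3, 2)
model_b = copy.deepcopy(model_a)  # identical starting weights
x = torch.randn(4, 3)

# manual update, as in the cells above
model_a(x).pow(2).mean().backward()
with torch.no_grad():
    for p in model_a.parameters():
        p -= learning_rate * p.grad

# the same step through the optimizer API (plain SGD: no momentum, no weight decay)
optimizer = torch.optim.SGD(model_b.parameters(), lr=learning_rate)
optimizer.zero_grad()
model_b(x).pow(2).mean().backward()
optimizer.step()

same = all(torch.allclose(pa, pb) for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(same)  # True
```

Real training loops usually prefer the optimizer API because it handles every parameter at once and makes it easy to swap in variants like momentum or Adam.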

## Forward Pass with Updated Weights

In [None]:
x_2 = tok_for_training[:-1].view(B, T)
x_2, y

## Input projection

In [None]:
x = wte(x_2)
x.shape, x

**Dropout**

In [None]:
x = dropout(x)
x


**Recurrent Block** Collapsed Together

As before, the hidden state `h_t` starts from zeros for each new sequence.

In [None]:
h_t = torch.zeros(B_batch, hidden_size) 
hs = []
h_t

In [None]:
for t in range(T):
    x_t = x[:, t, :]
    h_prev = h_t
    zx, rx, nx = x2g(x_t).chunk(3, dim=-1)
    zh, rh, nh = h2g(h_prev).chunk(3, dim=-1)
    z = torch.sigmoid(zx + zh)
    r = torch.sigmoid(rx + rh)
    n = torch.tanh(nx + r * nh)
    h_t = (1 - z) * n + z * h_prev
    
    print(f't: {t}')
    print(h_t)
    hs.append(h_t.unsqueeze(1))

hs

Concatenate the hidden states back into a single `(B, T, hidden_size)` tensor.

In [None]:
x = torch.cat(hs, dim=1) 
x.size(), x

**Dropout**

In [None]:
x = dropout(x)
x


**Head**

In [None]:
logits = lm_head(x)
logits.shape, logits

### Updated Loss calculation

Now we'll calculate the updated loss.  Our first pass's loss was 3.5835. Since we're passing through the same example and used a fairly high learning rate we should see a significant improvement with just 1 learning pass. 

In [None]:
loss

In [None]:
y_flat = y.view(-1)
logits_flat = logits.view(-1, logits.size(-1))
updated_loss = F.cross_entropy(logits_flat, y_flat)
print(updated_loss.shape, updated_loss)

In [None]:
f'1 round of training resulted in a loss improvement of {loss.item() - updated_loss.item():.4f}'

# SUCCESS!
Our training improved the loss by about **6%** (the exact amount may vary since we didn't set a seed). There are flaws with this, mainly that we pass the same example through a second time, but it shows the fundamentals of what learning does inside a recurrent (GRU) language model.