# RNN LSTM Explainer

The goal is to walk through recurrent neural networks (RNNs). Two common gated flavors are GRUs and LSTMs; we'll focus on Long Short-Term Memory (LSTM). LSTMs process data sequentially, taking in both the current input and the prior state at each step, which makes them a good fit for sequential data like time series or streamed data. The core unit of the LSTM takes in the data you want to use to make an inference along with any previous context. The LSTM's memory cell uses three “gates”: an input gate, a forget gate, and an output gate. These gates allow the model to decide which information to keep, which to forget, and when to use it. GRUs use only two gates, an update gate and a reset gate. RNNs do come with a few dangers: improper setup and training can lead to vanishing or exploding gradients, and their sequential nature can consume a lot of resources. The LSTM structure was built to fight the vanishing gradient problem, and coupling it with gradient clipping helps with exploding gradients as well.
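To make the gradient clipping idea concrete, here is a minimal plain-Python sketch of clipping by global norm. The helper name `clip_by_global_norm` and the toy gradient values are assumptions made for illustration; in PyTorch this job is normally handled by `torch.nn.utils.clip_grad_norm_`, and this sketch is not that function's implementation.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Rescale gradient values so their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads), total_norm  # already small enough, leave untouched
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm

# A large ("exploding") gradient with norm 5.0 is scaled back to norm 1.0.
# Only its length changes; its direction is preserved.
clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm_before, clipped)
```

Note that clipping only caps gradients that are too large. It does nothing for gradients that shrink toward zero, which is the part the LSTM's gating is meant to help with.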

To show how LSTMs work, we'll use the opening text from the [linear algebra wiki page](https://en.wikipedia.org/wiki/Linear_algebra) and the [LU decomposition wiki page](https://en.wikipedia.org/wiki/LU_decomposition), as the topic is fitting and the text shows us some non-standard patterns. Compared to many other notebooks in this repository, we'll take on a simpler task: given a few sentences of text, predict the next token.

## Text Prep/Tokenization

We'll start with a common preprocessing step: tokenizing the data. This converts the string text into an array of numbers that can be used during the training loop. I've built a very simple byte-pair encoding (BPE) tokenizer whose vocabulary has each unique character that appears plus the top 5 merges. This keeps our vocab size small and manageable for this example. Production vocabularies are typically far larger, often in the 100K+ range. A great library for this is `tiktoken`. Tokenization greedily merges adjacent characters into the patterns learned during training and replaces each resulting piece with an integer that represents it. This way we turn the text into a numeric array to simplify computing.

In [1]:
import torch
from collections import Counter

In [2]:
class SimpleBPETokenizer:
    def __init__(self, num_merges=5, eot_token='<|endoftext|>'):
        self.num_merges = num_merges
        self.eot_token = eot_token
        self.eot_id = None
        self.merges = []
        self.pair_ranks = {}
        self.vocab = {}
        self.id_to_token = {}

    def _add_token(self, tok):
        if tok in self.vocab:
            return self.vocab[tok]
        i = len(self.vocab)
        self.vocab[tok] = i
        self.id_to_token[i] = tok
        return i

    def _get_bigrams(self, seq):
        for i in range(len(seq) - 1):
            yield (seq[i], seq[i + 1])

    def _merge_once(self, seq, pair):
        a, b = pair
        out = []
        i = 0
        while i < len(seq):
            if i < len(seq) - 1 and seq[i] == a and seq[i + 1] == b:
                out.append(a + b)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        return out

    def train(self, corpus):
        # corpus: list[str]
        text = ''.join(corpus).lower()
        seq = list(text)
        merges = []
        for _ in range(self.num_merges):
            counts = Counter(self._get_bigrams(seq))
            if not counts: break
            best_pair, _ = counts.most_common(1)[0]
            merges.append(best_pair)
            seq = self._merge_once(seq, best_pair)
        self.merges = merges
        self.pair_ranks = {p: i for i, p in enumerate(self.merges)}

        self.vocab = {}
        self.id_to_token = {}
        for ch in sorted(set(text)):
            self._add_token(ch)
        for a, b in self.merges:
            self._add_token(a + b)
        self.eot_id = self._add_token(self.eot_token)

    def encode(self, text, force_last_eot=True):
        # treat literal eot marker as special; remove it from content
        if self.eot_token in text:
            text = text.replace(self.eot_token, '')
        seq = list(text)

        # make sure all seen base chars exist
        for ch in set(seq):
            if ch not in self.vocab:
                self._add_token(ch)

        # greedy BPE using learned pair ranks
        if self.merges:
            while True:
                best_pair, best_rank = None, None
                for p in self._get_bigrams(seq):
                    r = self.pair_ranks.get(p)
                    if r is not None and (best_rank is None or r < best_rank):
                        best_pair, best_rank = p, r
                if best_pair is None:
                    break
                seq = self._merge_once(seq, best_pair)

        # ensure all tokens in seq exist in vocab (e.g., if new chars appeared)
        for tok in seq:
            if tok not in self.vocab:
                self._add_token(tok)

        ids = [self.vocab[tok] for tok in seq]

        # FORCE: append EOT id if not already last
        if force_last_eot:
            if not ids or ids[-1] != self.eot_id:
                ids.append(self.eot_id)

        return ids

    def decode(self, ids):
        # drop trailing EOT if present
        if ids and self.eot_id is not None and ids[-1] == self.eot_id:
            ids = ids[:-1]
        toks = [self.id_to_token[i] for i in ids]
        return ''.join(toks)


In [3]:
raw_example_1 = r'''Linear algebra is central to almost all areas of mathematics. For instance, linear algebra is fundamental in modern presentations of geometry, including for defining basic objects such as lines, planes and rotations. Also, functional analysis, a branch of mathematical analysis, may be viewed as the application of linear algebra to function spaces.'''
raw_example_2 = r'''In numerical analysis and linear algebra, lower–upper (LU) decomposition or factorization factors a matrix as the product of a lower triangular matrix and an upper triangular matrix (see matrix multiplication and matrix decomposition).'''


In [4]:
tok = SimpleBPETokenizer(num_merges=5)
tok.train([raw_example_1,raw_example_2])
tok.merges

[(' ', 'a'), ('a', 't'), ('i', 'n'), (' ', 'm'), ('i', 'o')]

In [5]:
tok.vocab

{' ': 0,
 '(': 1,
 ')': 2,
 ',': 3,
 '.': 4,
 'a': 5,
 'b': 6,
 'c': 7,
 'd': 8,
 'e': 9,
 'f': 10,
 'g': 11,
 'h': 12,
 'i': 13,
 'j': 14,
 'l': 15,
 'm': 16,
 'n': 17,
 'o': 18,
 'p': 19,
 'r': 20,
 's': 21,
 't': 22,
 'u': 23,
 'v': 24,
 'w': 25,
 'x': 26,
 'y': 27,
 'z': 28,
 '–': 29,
 ' a': 30,
 'at': 31,
 'in': 32,
 ' m': 33,
 'io': 34,
 '<|endoftext|>': 35}

In [6]:
vocab_size = len(tok.vocab)
vocab_size

36

In [7]:
eot = tok.eot_id
tokens = []
for example in [raw_example_1, raw_example_2]:
    tokens.extend([eot])
    tokens.extend(tok.encode(example.lower()))
all_tokens = torch.tensor(tokens, dtype=torch.long)
all_tokens

tensor([35, 15, 32,  9,  5, 20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0,  7,
         9, 17, 22, 20,  5, 15,  0, 22, 18, 30, 15, 16, 18, 21, 22, 30, 15, 15,
        30, 20,  9,  5, 21,  0, 18, 10, 33, 31, 12,  9, 16, 31, 13,  7, 21,  4,
         0, 10, 18, 20,  0, 32, 21, 22,  5, 17,  7,  9,  3,  0, 15, 32,  9,  5,
        20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0, 10, 23, 17,  8,  5, 16,
         9, 17, 22,  5, 15,  0, 32, 33, 18,  8,  9, 20, 17,  0, 19, 20,  9, 21,
         9, 17, 22, 31, 34, 17, 21,  0, 18, 10,  0, 11,  9, 18, 16,  9, 22, 20,
        27,  3,  0, 32,  7, 15, 23,  8, 32, 11,  0, 10, 18, 20,  0,  8,  9, 10,
        32, 32, 11,  0,  6,  5, 21, 13,  7,  0, 18,  6, 14,  9,  7, 22, 21,  0,
        21, 23,  7, 12, 30, 21,  0, 15, 32,  9, 21,  3,  0, 19, 15,  5, 17,  9,
        21, 30, 17,  8,  0, 20, 18, 22, 31, 34, 17, 21,  4, 30, 15, 21, 18,  3,
         0, 10, 23, 17,  7, 22, 34, 17,  5, 15, 30, 17,  5, 15, 27, 21, 13, 21,
         3, 30,  0,  6, 20,  5, 17,  7, 

# Modeling

A machine learning model's forward pass takes the token IDs, runs several layers of linear algebra on them, and then "predicts" the next token. When the weights are untrained and noisy (as you will see in this example), this process results in gibberish. Training turns that noise into patterns during the "backward pass", as you'll see. We'll show 3 steps that are focused on training:
1. **Data Loading** `x, y = train_loader.next_batch()` - this step pulls enough tokens from the raw data to complete a forward and backward pass. If the model is inference only, this step is replaced with taking in the inference input and preparing it the same way for the forward pass.
2. **Forward Pass** `logits, loss = model(x, y)` - uses the data and the model architecture to predict the next token. When training we also compare against the expected token to get the loss, but in inference we use the logits to complete the inference task.
3. **Back Propagation, aka Backward Pass & Training** `loss.backward(); optimizer.step()` - uses gradients to measure how much each parameter contributed to the loss (the gap between the forward pass's prediction and the correct token from the data loading step), and then makes very minor adjustments to the impactful parameters with the hope that it improves future predictions.

Then we'll show a final **Forward Pass** with the weights updated in step 3.
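The three steps plus the final forward pass can be sketched end to end as follows. This is a minimal illustration under stated assumptions, not the model built in this notebook: `TinyLSTM` is a hypothetical stand-in that uses PyTorch's built-in `nn.LSTM`, random token IDs replace `train_loader.next_batch()`, and the SGD learning rate is an arbitrary choice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, n_embd, hidden_size = 36, 4, 3  # sizes used in this notebook
B, T = 2, 8

class TinyLSTM(nn.Module):  # hypothetical stand-in model
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.lstm = nn.LSTM(n_embd, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, y=None):
        h, _ = self.lstm(self.wte(x))   # (B, T, hidden_size)
        logits = self.head(h)           # (B, T, vocab_size)
        loss = None
        if y is not None:
            loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))
        return logits, loss

model = TinyLSTM()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 1. Data loading: random ids stand in for train_loader.next_batch()
tokens = torch.randint(0, vocab_size, (B * T + 1,))
x, y = tokens[:-1].view(B, T), tokens[1:].view(B, T)

# 2. Forward pass
logits, loss = model(x, y)

# 3. Backward pass and parameter update
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Final forward pass with the updated weights
_, loss_after = model(x, y)
print(loss.item(), loss_after.item())
```

A single small update on one batch usually lowers the loss slightly, but that is not guaranteed, so the sketch only prints both values.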

## Data Loading

To start, we need to get enough data to run the forward and backward passes. In real practice the total dataset is likely too big to hold in memory all at once, so we would read just enough of it to run the passes, leaving memory and compute for the passes themselves instead of holding static data.
We first have to pick the batch size and the model context length to determine how much data we need. These two dimensions also form 2 of the 3 dimensions in the initial matrix.
- **Batch Size (B)** - This is the number of examples you'll train on in a single pass. 
- **Context Length (T)** - This is the max number of tokens that a model can use in a single pass to generate the next token. If an example is below this length, it can be padded.
  
*Ideally both B and T are powers of 2 to work nicely with chip architecture. This is a common theme across the board.*

In [8]:
B = 2 # Batch
T = 8 # context length

Next, we need to pull enough tokens for the forward pass from our long `all_tokens` tensor. To train on `B` examples of context length `T`, we need `B*T` tokens. Since training predicts the next token at every position, we also need 1 more token at the end so that the last position in the last batch row has a next token to validate against.

In [9]:
current_position = 0
tok_for_training = all_tokens[current_position:current_position + B*T +1 ]
tok_for_training

tensor([35, 15, 32,  9,  5, 20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0])

Now that we have our initial tokens to train on, we need to convert them to a matrix that's ready for training. In this step we'll create our batches and set up two different arrays: 1/ the input tokens, `x`, and 2/ the output tokens, `y`, that they should produce. To create each example in the batch, every `T` tokens will be placed into its own row.

Recall that training takes in a string of tokens the length of the context and then predicts the next token, and that when we extracted `tok_for_training` we added 1 extra token so that we can evaluate the prediction for the last example. Because of this, the input, `x`, will be all of the tokens except the last one, `[:-1]`.

It might be natural to think the output `y` would then just be the last token. But this would waste valuable training signal. Yes, there is the example that fills the context `T`, but `tok_for_training` also has enough tokens that any context length `n` where `n<T` can be used for training, since we have the `n+1` token available. You can think of the following example:

sentence: `Hi I am learning`. This sentence contains the following "next tokens" that can be learned:
1. x: Hi I am  | y: learning
2. x: Hi I     | y: am
3. x: Hi       | y: I

Because we have this triangle to create, our `y` can be much larger. We can start with the second token and go all the way to the extra element we added for the last example, `[1:]`.
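The triangle of training pairs falls out of a one-position shift between `x` and `y`. Here is a small plain-Python sketch using the words of the example sentence as stand-in tokens (the real notebook works on integer token IDs):

```python
tokens = ["Hi", "I", "am", "learning"]

x = tokens[:-1]  # every token except the last
y = tokens[1:]   # every token except the first (shifted by one)

# Position n of y is the target for the first n + 1 tokens of x.
pairs = [(x[: n + 1], y[n]) for n in range(len(x))]
for context, target in pairs:
    print(context, "->", target)
```

Each position of `y` supplies one target, so a single row of `T` tokens yields `T` training examples rather than one.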


We will now put this together and do the following:
1. Extract the input `x` and then split it into an example for each batch `B`
2. Extract the output `y` and then split it into an example for each batch `B`

*Note: View can take `-1` which allows the matrix to infer the dimension so we do not need to pass in `T`, but given how many matrices we'll work with we want to make sure we're controlling the dimensions or erroring out if they do not match our expectations.*

In [10]:
x=tok_for_training[:-1].view(B, T)
x

tensor([[35, 15, 32,  9,  5, 20, 30, 15],
        [11,  9,  6, 20,  5,  0, 13, 21]])

In [11]:
y=tok_for_training[1:].view(B, T)
y

tensor([[15, 32,  9,  5, 20, 30, 15, 11],
        [ 9,  6, 20,  5,  0, 13, 21,  0]])

## Forward pass

<img src="explainer_screenshots/lstm/full_network.png" width="200">

During training, the forward pass takes a string of tokens in and scores how likely each vocabulary token is to come next at every position. This step as we'll look at it is focused on training, where we'll pass in the input `x`, carry that input through the layers, and generate a matrix of unnormalized scores for each token being the next one, something we call `logits` (a softmax turns them into probabilities). Since this is an RNN, we pass each example through the recurrent unit one token at a time. For example, if we have 3 tokens in the example and we are deriving the 4th, we would hit our LSTM block 3 times:
1. First with the embedding for the token `[0]` and no previous hidden state
2. Then with the embedding for the token `[1]` and the output of the LSTM unit's processing of token `[0]` as the previous hidden state $H_{prev}$
3. Finally with the embedding for the token `[2]` and the output of the LSTM unit's processing of token `[1]` as the previous hidden state $H_{prev}$

At the end of the forward pass we then compare the predicted probabilities to the actual next token in `y` and calculate the `loss` based on the difference.
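The three-step recursion described above can be sketched as a loop. This is an illustration only: it uses PyTorch's built-in `nn.LSTMCell` as a stand-in for the hand-built unit in this notebook, and random values stand in for the token embeddings.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, n_embd, hidden_size = 2, 3, 4, 3  # 3 tokens of context, as in the list above

cell = nn.LSTMCell(n_embd, hidden_size)  # built-in cell, used here as a stand-in
x = torch.randn(B, T, n_embd)            # stand-in for embedded tokens

# No previous state exists at the first step, so start from zeros.
h = torch.zeros(B, hidden_size)
c = torch.zeros(B, hidden_size)

outputs = []
for t in range(T):
    # Each step receives the current token and the state left by the previous step.
    h, c = cell(x[:, t, :], (h, c))
    outputs.append(h)

hidden_states = torch.stack(outputs, dim=1)  # (B, T, hidden_size)
print(hidden_states.shape)
```

The hidden state saved at step `t` has seen tokens `0..t`, which is what lets one pass over the row produce a prediction for every position.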

*Note that we will initialize layers to constant values to simplify following along. In reality layers are often initialized from a normal distribution, with some adjustments made for parameter sizes to keep the weights properly noisy. We will not cover initialization in this series.*

We first rederive the batch size and context size based on the input to improve flexibility.

In [12]:
import torch.nn as nn

In [13]:
B_batch, T_context = x.size()
B_batch, T_context

(2, 8)

In [14]:
n_embd = 4 # embedding dimension of input tokens
n_embd, vocab_size

(4, 36)

### Input Layer

<img src="explainer_screenshots/lstm/input_layer.png" width="200">

We'll first create an initial **embedding layer** for our input tokens. Since the RNN processes input tokens in order, position embedding info is not needed, though it can always be added to give the model more depth. The embedding weights are `vocab_size X n_embd` and simply store a learnable vector for each token. The larger the embedding dimension, the more complex the patterns the model can represent. After the embedding layer we'll then perform dropout. The **dropout layer** randomly zeroes each element with the configured probability and scales the surviving elements up to compensate. Since RNNs recycle the same parameters across timesteps, features can co-adapt and overfit quickly. Dropout (including variants such as variational dropout) breaks these co-adaptations and adds needed noise where there's little implicit regularization. Other model architectures like transformers have strong built-in stabilizers (LayerNorm, residuals, attention) and massive pretraining, so they generally need far less dropout.
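An embedding layer is essentially a lookup table indexed by token ID. The plain-Python sketch below uses small made-up sizes and fills row `i` with the value `i` so the lookup is easy to see; both choices are assumptions for illustration, not the notebook's values.

```python
vocab_size, n_embd = 5, 4  # small illustrative sizes, not the notebook's 36 x 4

# One row of n_embd weights per token id; row i is filled with i to make lookups visible.
table = [[float(token_id)] * n_embd for token_id in range(vocab_size)]

token_ids = [[3, 0, 4], [1, 1, 2]]  # shape (B=2, T=3)

# Embedding every token replaces each id with its row, giving shape (B, T, n_embd).
embedded = [[table[token_id] for token_id in row] for row in token_ids]
print(embedded[0][0])  # the row stored for token id 3
```

Repeated token IDs always map to the same row, and training adjusts the rows themselves.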



**Embedding** 

To start we'll initialize our embeddings with a weight of 1.000 so that all inputs are equally weighted. We'll also set our embedding dimension to 4 to allow for some levels of complexity. Since our initial weights are all equal, we expect every embedding output to be identical for now.

In [15]:
wte = nn.Embedding(vocab_size, n_embd)
torch.nn.init.ones_(wte.weight)
wte.weight

Parameter containing:
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]], requires_grad=True)

In [16]:
x = wte(x)
x.shape, x

(torch.Size([2, 8, 4]),
 tensor([[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],
 
         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]], grad_fn=<EmbeddingBackward0>))

### Dropout

Now, before our recurrent unit, we'll perform dropout. Dropout randomly zeroes out individual values, effectively removing that specific node from impacting the prediction. Since this is Bernoulli-based dropout, in addition to zeroing out activations, the surviving entries are scaled by $1/(1-p)$, which is why the surviving 1s become $1/0.9 \approx 1.1111$ here. During training this helps with generalization and fights co-adaptation. You can quickly see the dropout's impact on the embeddings.
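The zero-or-rescale behavior can be sketched in plain Python. The helper `inverted_dropout` and the seeded generator are assumptions made for illustration; `nn.Dropout` draws its own random mask, so the positions of the zeros will not match the notebook output.

```python
import random

def inverted_dropout(values, p, rng):
    """Zero each value with probability p and scale the survivors by 1 / (1 - p)."""
    keep_scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * keep_scale for v in values]

rng = random.Random(0)  # seeded so the sketch is repeatable
out = inverted_dropout([1.0] * 32, p=0.1, rng=rng)
print(out)
```

Every entry ends up as either `0.0` or `1 / 0.9`, which keeps the expected value of each entry at its original `1.0`.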

In [17]:
dropout = nn.Dropout(0.1)
dropout

Dropout(p=0.1, inplace=False)

In [18]:
x = dropout(x)
x

tensor([[[1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 0.0000, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 0.0000, 0.0000],
         [1.1111, 1.1111, 1.1111, 1.1111]],

        [[1.1111, 0.0000, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [0.0000, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111]]], grad_fn=<MulBackward0>)

### Recurrent Unit - LSTM

<img src="explainer_screenshots/lstm/lstm_details.png" width="600">

The recurrent unit we'll explore in this notebook is the Long Short-Term Memory unit, or LSTM. Like other recurrent units, it acts iteratively on specific time points of X labeled $X_t$. If there are more time points, for us meaning more tokens in the string, the unit takes in the hidden state it just produced, represented as $H_{prev}$, along with the next time point, represented again as $X_t$, and reruns, until all time points are exhausted.

Our recurrent unit in this case creates 2 outputs: the *memory cell* $C_t$ and the *hidden state* $H_t$. During recursion, both are fed back in alongside our input context $X_t$, with the memory cell acting as our long-term memory, but once recursion is complete only the hidden state moves forward.

To manage the impact of the input context $X_t$, memory cell $C_t$, and hidden state $H_t$, LSTMs use 3 gates: input gate $I_t$, forget gate $F_t$, and output gate $O_t$. Each gate is created by applying a linear layer to the input context and to the incoming hidden state, summing the results, and wrapping the sum in a sigmoid function that pushes the values to between $\left(0,1\right)$. The gates therefore act as learnable proportions that control how much of each of the input, memory, and hidden state is maintained in a recursion. By combining these 3 gates with the Hadamard product, which applies them element-wise, RNNs can learn during training the impact of a token or pattern whether it's in the current input or was seen previously. You can think of the gates as learnable switches that say "when making a prediction of Y, how much of X should I use and how much of the history should I use". With LSTMs, as the name implies, we have 2 types of history: the long-term memory cell $C_t$ and the short-term hidden state $H_t$.

Mentally, the **Input Gate** $I_t$ determines how much of a given input, which we'll call the input state $\tilde{C}$, should be added to the memory cell $C_t$ and, by extension, the hidden state $H_t$. The input state for us will be a combination of the input context $X_t$ and incoming hidden state $H_t$, as the point of RNNs is to use both the incoming context and the longer-term context to make a prediction for the output. We build it by applying a linear layer to each of the input context and incoming hidden state, summing the results, and then squashing with $\tanh$ to squeeze the values to $\left(-1,1\right)$.

The remaining part of the memory cell update uses the **Forget Gate** $F_t$. The forget gate controls how much of the current longer-term memory is retained or forgotten during the memory cell update. The forget gate is multiplied, using the Hadamard product, by the incoming memory cell $C_t$. This product is then summed with the Hadamard product of the input gate and input state to produce our updated memory cell $C_{t+1}$.

The **Output Gate** $O_t$ then determines how much of the memory cell is used to influence the current time step's updated hidden state $H_{t+1}$. To get the hidden state update, we take the Hadamard product of the output gate with a $\tanh$ squashing of the memory cell, which pulls the values to $\left(-1,1\right)$.

All together this results in the following:

$$
\begin{align}
I_t &= \sigma\!\left(X_t W_{xi} + H_{t} W_{hi} + b_i\right)\ \ \ \ \text{Input Gate},\\
F_t &= \sigma\!\left(X_t W_{xf} + H_{t} W_{hf} + b_f\right)\ \ \ \ \text{Forget Gate},\\
O_t &= \sigma\!\left(X_t W_{xo} + H_{t} W_{ho} + b_o\right)\ \ \ \ \text{Output Gate},\\
\\
\tilde{C} &= \tanh\!\left(X_t W_{xc} + H_{t} W_{hc} + b_c\right),\\
C_{t+1} &= F_t\odot C_t + I_t\odot \tilde{C},\\
H_{t+1} &= O_t\odot \tanh\left(C_{t+1}\right).
\end{align}
$$

Mathematically this gives us a nice simple set of formulas, but we need to actually write this as a program. If we look closely, each input has 4 sets of weights acting on it, 3 for the gates and 1 for the input state:
* the input $X_t$: $\{W_{xi}, W_{xf}, W_{xo}, W_{xc}\}$. *Notice the subscript is "x\<gate or state>"*
* the previous $H_{prev}$: $\{W_{hi}, W_{hf}, W_{ho}, W_{hc}\}$.

To simplify these weights, we'll create one linear layer per input whose output is 4x the hidden size, split that output into the four pieces, and then complete the rest of the calculations. Computing all four projections in one matrix multiply is more efficient and improves our code readability. The algorithmic representation then looks as follows:

```
xi, xf, xo, xc = x2g(x_t).split(hidden_size, dim=-1)
hi, hf, ho, hc = h2g(h_prev).split(hidden_size, dim=-1)

#input, forget, and output gate 
input_gate =  torch.sigmoid(xi + hi)
forget_gate =  torch.sigmoid(xf + hf)
output_gate =  torch.sigmoid(xo + ho)

input_state = torch.tanh(xc + hc)

c_next = (forget_gate*c_t) + (input_gate*input_state)
h_next = output_gate * torch.tanh(c_next)
```
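The algorithmic representation above can be wrapped into a runnable function. This is a sketch under stated assumptions: the function name `lstm_step`, the all-ones input, and the default (random) `nn.Linear` initialization are illustrative choices, not the constant initialization used in the walkthrough.

```python
import torch
import torch.nn as nn

def lstm_step(x_t, h_prev, c_prev, x2g, h2g, hidden_size):
    """One LSTM step: returns the next hidden state and memory cell."""
    # Each linear layer yields 4 * hidden_size values, split into i, f, o, c pieces.
    xi, xf, xo, xc = x2g(x_t).split(hidden_size, dim=-1)
    hi, hf, ho, hc = h2g(h_prev).split(hidden_size, dim=-1)

    input_gate = torch.sigmoid(xi + hi)
    forget_gate = torch.sigmoid(xf + hf)
    output_gate = torch.sigmoid(xo + ho)
    input_state = torch.tanh(xc + hc)

    c_next = forget_gate * c_prev + input_gate * input_state
    h_next = output_gate * torch.tanh(c_next)
    return h_next, c_next

torch.manual_seed(0)
B, n_embd, hidden_size = 2, 4, 3
x2g = nn.Linear(n_embd, 4 * hidden_size)
h2g = nn.Linear(hidden_size, 4 * hidden_size)

x_t = torch.ones(B, n_embd)          # stand-in for one time step of embeddings
h0 = torch.zeros(B, hidden_size)     # no previous hidden state
c0 = torch.zeros(B, hidden_size)     # no previous memory

h1, c1 = lstm_step(x_t, h0, c0, x2g, h2g, hidden_size)
print(h1.shape, c1.shape)
```

Both linear layers carry a bias here, so the effective gate bias is the sum of the two. That is equivalent to the single $b$ term in the equations.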
**Starting Example** While LSTMs are iterative, we'll first walk through 2 passes of the LSTM step by step. After that we'll show the loop for the rest of the context window/examples. This will help us understand what is happening inside a pass and how the impact of the weights adds up over the context length. To keep things simple we'll set the `hidden_size` to 3. Let's start by initializing our weights. To visualize the impact of each set of weights, I'll initialize the current-context weights (`x2g`) and the previous-context weights (`h2g`) to two different values.

In [19]:
hidden_size = 3

In [20]:
x2g = nn.Linear(n_embd, 4 * hidden_size, bias=True)
torch.nn.init.constant_(x2g.weight, 0.500)
torch.nn.init.zeros_(x2g.bias)
x2g.weight.size(), x2g.weight, x2g.bias.size(), x2g.bias

(torch.Size([12, 4]),
 Parameter containing:
 tensor([[0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000]], requires_grad=True),
 torch.Size([12]),
 Parameter containing:
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True))

In [21]:
h2g = nn.Linear(hidden_size, 4 * hidden_size, bias=True)
torch.nn.init.constant_(h2g.weight, 0.250)
torch.nn.init.zeros_(h2g.bias)
h2g.weight.size(), h2g.weight, h2g.bias.size(), h2g.bias

(torch.Size([12, 3]),
 Parameter containing:
 tensor([[0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500]], requires_grad=True),
 torch.Size([12]),
 Parameter containing:
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True))

#### LSTM - First Pass

##### LSTM - Initialize the hidden state and memory cell
Let's start by setting up our starting hidden state and memory cell. Since we are at 'time 0' in training, there are no previous states or memory yet. Because of this we'll set them to 0's so that there is no impact from a previous state and all of the prediction comes from `x_t`.

In [22]:
h_next = torch.zeros(B_batch, hidden_size) 
c_next = torch.zeros(B_batch, hidden_size) 

##### LSTM - Weight calculation and incoming linearization

Now we'll calculate the impact of the weights on our input `x_t` and hidden state `h_t`. We'll pass these variables into the linear layers to perform $Ax+b$ and split each result into our 4 different linear units. Notice that because our hidden state is currently 0, the hidden-state outputs are all 0. In future recursions, they will no longer be 0.

Currently our incoming `x` holds all 8 time steps of the context window for each of the 2 batch rows. With LSTMs we can only feed in 1 time step (1 token) at a time, since each step depends on the one before it, though we can process every row of the batch at once. Luckily our examples are incremental (example 1 is token 1, example 2 is tokens 1 and 2), so we can iteratively run the LSTM and use the hidden state at each step as the result for that example. To start, we'll extract just the first time step, `t=0`, into `x_t` for each batch row and calculate our weights. Since every weight is 1/2, each output is half the sum of that row's inputs, so all outputs within a row are identical. If dropout zeroed an input, the values can differ between batch rows, but they'll be the same within a row.

Finally, we'll set our memory cell `c_t` to equal the `c_next` we initialized. As a reminder, we've set it to 0.

In [23]:
t = 0
x_t = x[:, t, :]
x_t.size(), x_t

(torch.Size([2, 4]),
 tensor([[1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 0.0000, 1.1111, 1.1111]], grad_fn=<SliceBackward0>))

In [24]:
h_t = h_next
h_t.size(), h_t

(torch.Size([2, 3]),
 tensor([[0., 0., 0.],
         [0., 0., 0.]]))

In [25]:
c_t = c_next
c_t.size(), c_t

(torch.Size([2, 3]),
 tensor([[0., 0., 0.],
         [0., 0., 0.]]))

**Linear weight calculation** Now we'll calculate all 4 linear outputs for both the input and the hidden state. Note that the memory cell does not get its own linear layer, as it's driven by the input and forget gates instead.

In [26]:
xi, xf, xo, xc = x2g(x_t).split(hidden_size, dim=-1)
'xi',xi,xi.size(),'xf',xf,xf.size(),'xo',xo,xo.size(),'xc',xc,xc.size()


('xi',
 tensor([[2.2222, 2.2222, 2.2222],
         [1.6667, 1.6667, 1.6667]], grad_fn=<SplitBackward0>),
 torch.Size([2, 3]),
 'xf',
 tensor([[2.2222, 2.2222, 2.2222],
         [1.6667, 1.6667, 1.6667]], grad_fn=<SplitBackward0>),
 torch.Size([2, 3]),
 'xo',
 tensor([[2.2222, 2.2222, 2.2222],
         [1.6667, 1.6667, 1.6667]], grad_fn=<SplitBackward0>),
 torch.Size([2, 3]),
 'xc',
 tensor([[2.2222, 2.2222, 2.2222],
         [1.6667, 1.6667, 1.6667]], grad_fn=<SplitBackward0>),
 torch.Size([2, 3]))
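The `x2g` values above can be checked by hand. With every weight at 0.5 and a zero bias, each output is simply half the sum of that row's inputs. The sketch below assumes the `t=0` rows shown earlier, where the second batch row had one entry zeroed by dropout.

```python
p = 0.1
kept = 1.0 / (1.0 - p)  # value of a surviving embedding entry after dropout
weight = 0.5            # every x2g weight; the bias is 0

x_rows = [
    [kept, kept, kept, kept],  # batch row 0: nothing dropped at t=0
    [kept, 0.0, kept, kept],   # batch row 1: one entry dropped at t=0
]

# Each of the 12 outputs in a row is the same weighted sum, so one value per row suffices.
pre_activations = [sum(weight * v for v in row) for row in x_rows]
print([round(v, 4) for v in pre_activations])
```

The results are approximately 2.2222 and 1.6667, in line with the two rows printed for `xi`, `xf`, `xo`, and `xc`.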

In [27]:
hi, hf, ho, hc = h2g(h_t).split(hidden_size, dim=-1)
'hi',hi,hi.size(),'hf',hf,hf.size(),'ho',ho,ho.size(),'hc',hc,hc.size()


('hi',
 tensor([[0., 0., 0.],
         [0., 0., 0.]], grad_fn=<SplitBackward0>),
 torch.Size([2, 3]),
 'hf',
 tensor([[0., 0., 0.],
         [0., 0., 0.]], grad_fn=<SplitBackward0>),
 torch.Size([2, 3]),
 'ho',
 tensor([[0., 0., 0.],
         [0., 0., 0.]], grad_fn=<SplitBackward0>),
 torch.Size([2, 3]),
 'hc',
 tensor([[0., 0., 0.],
         [0., 0., 0.]], grad_fn=<SplitBackward0>),
 torch.Size([2, 3]))

##### LSTM - Input $I_{t}$, Forget $F_{t}$, and Output $O_{t}$ Gates
Recall that the gates in an LSTM control how much the input state and memory cell persist into the next memory cell and hidden state. We derive the gates using:
$$
\begin{align}
I_t &= \sigma\!\left(X_t W_{xi} + H_{t} W_{hi} + b_i\right) \\
F_t &= \sigma\!\left(X_t W_{xf} + H_{t} W_{hf} + b_f\right) \\
O_t &= \sigma\!\left(X_t W_{xo} + H_{t} W_{ho} + b_o\right)
\end{align}
$$

These gates become learnable weights that act as proportions, since we use the sigmoid function to pull the values to between $\left(0,1\right)$ [source](https://docs.pytorch.org/docs/stable/generated/torch.sigmoid.html). This is achieved with the following function:

$\sigma(x) = \frac{1}{1 + e^{-x}}$

You'll notice that all our gates have identical values currently. This is because we used the same constant initialization for every weight. If there's variance between rows at this point, it's due to dropout.
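As a quick check of the sigmoid formula, the sketch below applies it to the two pre-activation values printed earlier for `xi` (the hidden-side terms are 0 on this first pass, so they add nothing). The helper name `sigmoid` is just a local definition for illustration.

```python
import math

def sigmoid(x):
    """Logistic function: squashes any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

pre_activations = [2.2222, 1.6667]  # xi + hi for batch rows 0 and 1
gates = [sigmoid(v) for v in pre_activations]
print([round(g, 4) for g in gates])
```

The values come out at roughly 0.9022 and 0.8411, so with these initial weights each gate lets most of its signal through.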

In [28]:
input_gate =  torch.sigmoid(xi + hi)
input_gate.size(), input_gate

(torch.Size([2, 3]),
 tensor([[0.9022, 0.9022, 0.9022],
         [0.8411, 0.8411, 0.8411]], grad_fn=<SigmoidBackward0>))

In [29]:
forget_gate =  torch.sigmoid(xf + hf)
forget_gate.size(), forget_gate

(torch.Size([2, 3]),
 tensor([[0.9022, 0.9022, 0.9022],
         [0.8411, 0.8411, 0.8411]], grad_fn=<SigmoidBackward0>))

In [30]:
output_gate =  torch.sigmoid(xo + ho)
output_gate.size(), output_gate

(torch.Size([2, 3]),
 tensor([[0.9022, 0.9022, 0.9022],
         [0.8411, 0.8411, 0.8411]], grad_fn=<SigmoidBackward0>))

##### LSTM - Input State $\tilde{C}$

Next we calculate the input state. The input state combines the impact of this time point's input context with the impact of the incoming hidden state. It's calculated as follows:

$\tilde{C} = \tanh\!\left(X_t W_{xc} + H_{t} W_{hc} + b_c\right)$

Note that while this looks similar to our gates, we are using tanh instead of the sigmoid. Tanh pulls the values to between $\left(-1,1\right)$ [source](https://docs.pytorch.org/docs/stable/generated/torch.tanh.html). This is achieved with the following function:

$\tanh(x) = \frac{e^x-e^{-x}}{e^x+e^{-x}}$

Since the network is currently untrained, the values are the same across each row. In fact, if we looked at the values without the tanh, they would be identical to the raw gate values before the sigmoid. With training these values will diverge.

In [31]:
input_state = torch.tanh(xc + hc)
input_state.size(), input_state

(torch.Size([2, 3]),
 tensor([[0.9768, 0.9768, 0.9768],
         [0.9311, 0.9311, 0.9311]], grad_fn=<TanhBackward0>))

##### LSTM - Memory Cell $C_t$
Now we have all the values needed to calculate our updated memory cell. Recall that the memory cell acts as the long-term memory. It is updated by multiplying the incoming memory cell by the forget gate and adding the input state scaled by the input gate. The calculation is as follows:

$C_{next} = F_t\odot C_t + I_t\odot \tilde{C}$

Recall that our memory cell is currently `0` since we are on the first iteration.  This means that despite having non-zero forget gate values, there is nothing to remember yet, so the new memory cell is simply the input state scaled by the input gate. 

In [32]:
i_in = input_gate*input_state
i_in.size(), i_in

(torch.Size([2, 3]),
 tensor([[0.8813, 0.8813, 0.8813],
         [0.7832, 0.7832, 0.7832]], grad_fn=<MulBackward0>))

In [33]:
f_c = forget_gate*c_t
f_c.size(), f_c

(torch.Size([2, 3]),
 tensor([[0., 0., 0.],
         [0., 0., 0.]], grad_fn=<MulBackward0>))

In [34]:
c_next = f_c + i_in	
c_next.size(), c_next

(torch.Size([2, 3]),
 tensor([[0.8813, 0.8813, 0.8813],
         [0.7832, 0.7832, 0.7832]], grad_fn=<AddBackward0>))

##### LSTM - Hidden State $H_t$ / Recursion output
We now have the updated memory cell, which combines the incoming memory cell with the input state, itself a combination of the current input and the incoming hidden state. We've also already calculated the output gate, which controls how much of the memory cell reaches the next hidden state and recursion output. To generate the hidden state we do the following: 

$H_{next} = O_t\odot \tanh\left(C_{next}\right)$

We'll first calculate the tanh of the updated memory cell. Recall that tanh squeezes the memory cell values into $\left(-1,1\right)$ with the following function:

$\tanh(x) = \frac{e^x-e^{-x}}{e^x+e^{-x}}$

We'll then take the product of the output gate with that result to generate our recursion output. Since this is not the final recursion, this output becomes the next recursion's incoming hidden state. You'll again notice very consistent behavior since our weights are initialized to the same constant across all gates. 

In [35]:
tan_c_n = torch.tanh(c_next)
tan_c_n.size(), tan_c_n

(torch.Size([2, 3]),
 tensor([[0.7071, 0.7071, 0.7071],
         [0.6545, 0.6545, 0.6545]], grad_fn=<TanhBackward0>))

In [36]:
h_next = output_gate*tan_c_n
h_next.size(), h_next

(torch.Size([2, 3]),
 tensor([[0.6379, 0.6379, 0.6379],
         [0.5505, 0.5505, 0.5505]], grad_fn=<MulBackward0>))
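
To tie the four formulas together, here is a minimal scalar sketch of one LSTM step for a single hidden unit. The helper `lstm_step_scalar` is hypothetical (it is not part of the notebook), and the numbers fed to it are assumptions: an x-path contribution of `2.2222` to every gate, h2g weights of `0.25` across a hidden size of 3, and zero biases.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def lstm_step_scalar(pre_i, pre_f, pre_o, pre_c, c_prev):
    """One LSTM step for a single hidden unit, given the gate pre-activations."""
    i = sigmoid(pre_i)                  # input gate
    f = sigmoid(pre_f)                  # forget gate
    o = sigmoid(pre_o)                  # output gate
    c_tilde = math.tanh(pre_c)          # input state
    c_next = f * c_prev + i * c_tilde   # memory cell update
    h_next = o * math.tanh(c_next)      # hidden state / recursion output
    return c_next, h_next

x_part = 2.2222        # assumed x-path contribution, identical for every gate
h_weight_sum = 3 * 0.25  # assumed: hidden_size * constant h2g weight

# First pass: hidden state and memory cell both start at 0.
c1, h1 = lstm_step_scalar(x_part, x_part, x_part, x_part, c_prev=0.0)

# Second pass: the previous hidden state now feeds every gate.
z2 = x_part + h_weight_sum * h1
c2, h2 = lstm_step_scalar(z2, z2, z2, z2, c_prev=c1)

print(round(c1, 4), round(h1, 4), round(c2, 4), round(h2, 4))
```

Under these assumptions the first step lands on the memory cell and hidden state printed above for the first example to about three decimal places, and the second step shows how a non-zero `c_prev` brings the forget gate into play.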

#### LSTM -  Managing examples in a batch
We now have the hidden state from our recurrent unit for the first time step `t=0`. Since we have more context, we need to recur through the rest of it.  Before the next cycle, recall that for training we want to take advantage of the fact that each example in the batch gives us a prediction target at every position up to our context length.  We want to train across inputs of various lengths, so we'll store the current hidden state, which represents an input of just a single token of context.  As you'll see in the recursion, we'll store the hidden state at the end of each pass and build up a tensor of hidden states for each example in the batch to take to our output layer. 

To start we'll create a zero-filled tensor of shape `(B x T x hidden_size)`. We'll then write `h_next` into the `[:,t,:]` entry on each recursion, which keeps the link back to that recursion for the backward pass. 

In [37]:
hs = x.new_zeros(B_batch, T_context, hidden_size)
hs.size(), hs

(torch.Size([2, 8, 3]),
 tensor([[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
 
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]))

Now we'll write our first entry at `t=0`.

In [38]:
hs[:, 0, :] = h_next   
hs

tensor([[[0.6379, 0.6379, 0.6379],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]],

        [[0.5505, 0.5505, 0.5505],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]]], grad_fn=<CopySlices>)

#### LSTM - Second pass 

Now that we've done the first pass, let's see how the second pass changes things. The main difference is that the previous pass' outputs, `c_next` and `h_next`, become the new incoming memory cell `c_t` and hidden state `h_t`.  This quickly shows up in the gates and calculations, as these values are no longer 0 while our X weights are still constant. We'll start by updating `h_t` and `c_t` and extracting the next set of inputs for `t=1`.

One thing to notice is that despite moving to a different time step, the values of the input `x_t` remain consistent across examples since our initial weights are uniform.  

In [39]:
t = 1
x_t = x[:, t, :]
x_t.size(), x_t

(torch.Size([2, 4]),
 tensor([[1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111]], grad_fn=<SliceBackward0>))

In [40]:
h_t = h_next
h_t.size(), h_t

(torch.Size([2, 3]),
 tensor([[0.6379, 0.6379, 0.6379],
         [0.5505, 0.5505, 0.5505]], grad_fn=<MulBackward0>))

In [41]:
c_t = c_next
c_t.size(), c_t

(torch.Size([2, 3]),
 tensor([[0.8813, 0.8813, 0.8813],
         [0.7832, 0.7832, 0.7832]], grad_fn=<AddBackward0>))

**LSTM second pass - recalculating weights**

Notice that with the updated non-zero hidden state we now see contributions from the dot product of the hidden state weights and hidden state. 

In [42]:
xi, xf, xo, xc = x2g(x_t).split(hidden_size, dim=-1)
'xi',xi,xi.size(),'xf',xf,xf.size(),'xo',xo,xo.size(),'xc',xc,xc.size()

('xi',
 tensor([[2.2222, 2.2222, 2.2222],
         [2.2222, 2.2222, 2.2222]], grad_fn=<SplitBackward0>),
 torch.Size([2, 3]),
 'xf',
 tensor([[2.2222, 2.2222, 2.2222],
         [2.2222, 2.2222, 2.2222]], grad_fn=<SplitBackward0>),
 torch.Size([2, 3]),
 'xo',
 tensor([[2.2222, 2.2222, 2.2222],
         [2.2222, 2.2222, 2.2222]], grad_fn=<SplitBackward0>),
 torch.Size([2, 3]),
 'xc',
 tensor([[2.2222, 2.2222, 2.2222],
         [2.2222, 2.2222, 2.2222]], grad_fn=<SplitBackward0>),
 torch.Size([2, 3]))

In [43]:
hi, hf, ho, hc = h2g(h_t).split(hidden_size, dim=-1)
'hi',hi,hi.size(),'hf',hf,hf.size(),'ho',ho,ho.size(),'hc',hc,hc.size()


('hi',
 tensor([[0.4784, 0.4784, 0.4784],
         [0.4129, 0.4129, 0.4129]], grad_fn=<SplitBackward0>),
 torch.Size([2, 3]),
 'hf',
 tensor([[0.4784, 0.4784, 0.4784],
         [0.4129, 0.4129, 0.4129]], grad_fn=<SplitBackward0>),
 torch.Size([2, 3]),
 'ho',
 tensor([[0.4784, 0.4784, 0.4784],
         [0.4129, 0.4129, 0.4129]], grad_fn=<SplitBackward0>),
 torch.Size([2, 3]),
 'hc',
 tensor([[0.4784, 0.4784, 0.4784],
         [0.4129, 0.4129, 0.4129]], grad_fn=<SplitBackward0>),
 torch.Size([2, 3]))

**LSTM second pass - Updated Input $I_{t}$, Forget $F_{t}$, and Output $O_{t}$ Gates**

Once again you'll notice the gates are all the same within an example (though they may differ between the two examples in the batch). This is because we initialized with a constant weight and have not yet made any updates to improve our prediction. Even with RNNs, the weights only change during back-propagation. 

In [44]:
input_gate =  torch.sigmoid(xi + hi)
input_gate.size(), input_gate

(torch.Size([2, 3]),
 tensor([[0.9371, 0.9371, 0.9371],
         [0.9331, 0.9331, 0.9331]], grad_fn=<SigmoidBackward0>))

In [45]:
forget_gate =  torch.sigmoid(xf + hf)
forget_gate.size(), forget_gate

(torch.Size([2, 3]),
 tensor([[0.9371, 0.9371, 0.9371],
         [0.9331, 0.9331, 0.9331]], grad_fn=<SigmoidBackward0>))

In [46]:
output_gate =  torch.sigmoid(xo + ho)
output_gate.size(), output_gate

(torch.Size([2, 3]),
 tensor([[0.9371, 0.9371, 0.9371],
         [0.9331, 0.9331, 0.9331]], grad_fn=<SigmoidBackward0>))

**LSTM second pass - Input State $\tilde{C}$**

We'll continue the trend of consistent values for all the same reasons. 

In [47]:
input_state = torch.tanh(xc + hc)
input_state.size(), input_state

(torch.Size([2, 3]),
 tensor([[0.9910, 0.9910, 0.9910],
         [0.9898, 0.9898, 0.9898]], grad_fn=<TanhBackward0>))

**LSTM second pass - Memory Cell $C_t$**

This time through, we can see that the forget gate comes into play. Since our incoming memory cell is non-zero, we can now see that the updated memory cell is in part the input state and in part the previous memory cell, as determined by the forget gate.  With training this mix will be learned down to the entry level. 

In [48]:
i_in = input_gate*input_state
i_in.size(), i_in

(torch.Size([2, 3]),
 tensor([[0.9287, 0.9287, 0.9287],
         [0.9235, 0.9235, 0.9235]], grad_fn=<MulBackward0>))

In [49]:
f_c = forget_gate*c_t
f_c.size(), f_c

(torch.Size([2, 3]),
 tensor([[0.8258, 0.8258, 0.8258],
         [0.7308, 0.7308, 0.7308]], grad_fn=<MulBackward0>))

In [50]:
c_next = f_c + i_in	
c_next.size(), c_next

(torch.Size([2, 3]),
 tensor([[1.7545, 1.7545, 1.7545],
         [1.6543, 1.6543, 1.6543]], grad_fn=<AddBackward0>))

**LSTM second pass - Hidden State $H_t$ / Recursion output**

In [51]:
tan_c_n = torch.tanh(c_next)
tan_c_n.size(), tan_c_n

(torch.Size([2, 3]),
 tensor([[0.9419, 0.9419, 0.9419],
         [0.9294, 0.9294, 0.9294]], grad_fn=<TanhBackward0>))

In [52]:
h_next = output_gate*tan_c_n
h_next.size(), h_next

(torch.Size([2, 3]),
 tensor([[0.8826, 0.8826, 0.8826],
         [0.8673, 0.8673, 0.8673]], grad_fn=<MulBackward0>))

In [53]:
hs[:, 1, :] = h_next   
hs

tensor([[[0.6379, 0.6379, 0.6379],
         [0.8826, 0.8826, 0.8826],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]],

        [[0.5505, 0.5505, 0.5505],
         [0.8673, 0.8673, 0.8673],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]]], grad_fn=<CopySlices>)

#### LSTM - Recursion 
Now that we've seen what 2 passes through the LSTM look like, let's write a loop to do the remaining passes.  The loop will start at `t=2` and run through the rest of the context length.  We'll capture each pass's hidden state in `hs` so that we have every position's hidden state, along with the link back through the recursion needed for back-propagation. 

In [54]:
for t in range(2,T_context):
    
    # set input and previous
    h_t = h_next
    c_t = c_next
    x_t = x[:, t, :]

    # gate pre-activations from the input path and the hidden path
    xi, xf, xo, xc = x2g(x_t).split(hidden_size, dim=-1)
    hi, hf, ho, hc = h2g(h_t).split(hidden_size, dim=-1)
    
    #input, forget, and output gate 
    input_gate =  torch.sigmoid(xi + hi)
    forget_gate =  torch.sigmoid(xf + hf)
    output_gate =  torch.sigmoid(xo + ho)
    
    # input state
    input_state = torch.tanh(xc + hc)
    
    # memory cell
    c_next = (forget_gate*c_t) + (input_gate*input_state)	
    
    # hidden state / recursion output
    h_next = output_gate*torch.tanh(c_next)

    # save this step's hidden state
    hs[:, t, :] = h_next

hs

tensor([[[0.6379, 0.6379, 0.6379],
         [0.8826, 0.8826, 0.8826],
         [0.8988, 0.8988, 0.8988],
         [0.9451, 0.9451, 0.9451],
         [0.9488, 0.9488, 0.9488],
         [0.9494, 0.9494, 0.9494],
         [0.8609, 0.8609, 0.8609],
         [0.9462, 0.9462, 0.9462]],

        [[0.5505, 0.5505, 0.5505],
         [0.8673, 0.8673, 0.8673],
         [0.9340, 0.9340, 0.9340],
         [0.9465, 0.9465, 0.9465],
         [0.9489, 0.9489, 0.9489],
         [0.9150, 0.9150, 0.9150],
         [0.9482, 0.9482, 0.9482],
         [0.9495, 0.9495, 0.9495]]], grad_fn=<CopySlices>)
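
The loop above can be packaged into a single function and checked against PyTorch's built-in `nn.LSTM`. This is a minimal sketch, not the notebook's own code: the names `x2g_demo`, `h2g_demo`, `lstm_forward`, and `x_demo` are hypothetical, the input is random rather than our embedded tokens, and the constant weights (`0.5`, `0.25`, zero biases) are assumed to mirror the initialization used in this walkthrough.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, n_embd, hidden_size = 2, 8, 4, 3   # shapes used in the walkthrough

# One linear map per path, with the 4 gate blocks stacked on the output dim.
x2g_demo = nn.Linear(n_embd, 4 * hidden_size)
h2g_demo = nn.Linear(hidden_size, 4 * hidden_size)
for layer, w in ((x2g_demo, 0.5), (h2g_demo, 0.25)):
    nn.init.constant_(layer.weight, w)
    nn.init.zeros_(layer.bias)

def lstm_forward(x):
    """Run the recurrence over every time step and return (B, T, hidden_size)."""
    h = x.new_zeros(x.size(0), hidden_size)
    c = x.new_zeros(x.size(0), hidden_size)
    outputs = []
    for t in range(x.size(1)):
        pre = x2g_demo(x[:, t, :]) + h2g_demo(h)
        pre_i, pre_f, pre_o, pre_c = pre.split(hidden_size, dim=-1)
        c = torch.sigmoid(pre_f) * c + torch.sigmoid(pre_i) * torch.tanh(pre_c)
        h = torch.sigmoid(pre_o) * torch.tanh(c)
        outputs.append(h)
    return torch.stack(outputs, dim=1)

x_demo = torch.randn(B, T, n_embd)
hs_manual = lstm_forward(x_demo)

# Reference: the built-in LSTM with the same constant weights and zero biases.
ref = nn.LSTM(n_embd, hidden_size, batch_first=True)
with torch.no_grad():
    ref.weight_ih_l0.fill_(0.5)
    ref.weight_hh_l0.fill_(0.25)
    ref.bias_ih_l0.zero_()
    ref.bias_hh_l0.zero_()
hs_ref, _ = ref(x_demo)

max_diff = (hs_manual - hs_ref).abs().max().item()
print(hs_manual.shape, max_diff)
```

The sketch collects the hidden states in a list and stacks them instead of writing into a preallocated tensor; both approaches keep the autograd link to every step. Note that `nn.LSTM` stacks its gates in the order input, forget, cell, output, so with non-constant weights the gate blocks would have to be reordered before the two versions could be compared.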

### Final Dropout 

After completing our recurrent unit, we'll perform another round of dropout. Dropout after the LSTM regularizes the layer output (the readout) without corrupting the recurrent dynamics; injecting fresh noise inside the recurrence at each step tends to destabilize long-range memory.  Often this step uses “locked” dropout, meaning a single mask is sampled once and broadcast across time, which avoids i.i.d. noise at every timestep.  We, however, will use default dropout that is not locked, to simplify our code. 

As a reminder, we perform Bernoulli-based dropout, which zeroes each entry with probability $p$ and scales the surviving entries by $1/(1-p)$. 

In [55]:
x = dropout(hs)
x.size(), x

(torch.Size([2, 8, 3]),
 tensor([[[0.7088, 0.7088, 0.0000],
          [0.9807, 0.9807, 0.9807],
          [0.9987, 0.9987, 0.9987],
          [1.0501, 1.0501, 1.0501],
          [1.0542, 1.0542, 1.0542],
          [0.0000, 1.0549, 1.0549],
          [0.9565, 0.9565, 0.9565],
          [1.0513, 1.0513, 1.0513]],
 
         [[0.6117, 0.6117, 0.6117],
          [0.9636, 0.9636, 0.9636],
          [0.0000, 1.0377, 0.0000],
          [1.0517, 1.0517, 1.0517],
          [1.0543, 1.0543, 1.0543],
          [1.0167, 1.0167, 1.0167],
          [1.0536, 1.0536, 1.0536],
          [1.0550, 1.0550, 1.0550]]], grad_fn=<MulBackward0>))
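
To make the difference between default and “locked” dropout concrete, here is a minimal sketch that builds both masks by hand. The drop probability `p = 0.1` is an assumption inferred from the scaling visible in the output above, and `hs_demo` is a hypothetical stand-in of ones rather than our real hidden states.

```python
import torch

torch.manual_seed(0)
p = 0.1                        # assumed drop probability
B, T, H = 2, 8, 3
hs_demo = torch.ones(B, T, H)  # stand-in for the stacked hidden states

# Default dropout: an independent Bernoulli keep-mask for every entry.
keep = torch.bernoulli(torch.full((B, T, H), 1 - p))
standard = hs_demo * keep / (1 - p)

# Locked dropout: one mask per (example, feature), broadcast across time.
keep_locked = torch.bernoulli(torch.full((B, 1, H), 1 - p))
locked = hs_demo * keep_locked / (1 - p)

print(standard[0, :3], locked[0, :3])
```

In both cases every surviving entry is scaled by `1 / (1 - p)`, which keeps the expected value of each entry unchanged. With the locked mask, a feature that is dropped at `t=0` stays dropped for the whole sequence.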

### Output Layers AKA Model Head.

The LSTM unit is our recurrent unit for this example. In practice, this unit can be scaled by increasing the `hidden_size` dimension or by stacking more, or more complex, recurrent units. Once those layers are complete and we've applied dropout, the forward pass moves on to the output step that produces `logits`, unnormalized scores that represent how likely each token is to be the next token given the input.  

This layer is also known as the model **head**. The name reflects that it is a small, task-specific module attached to a model’s shared backbone that maps hidden features to the final outputs.  In our example, this is a linear layer mapping the backbone to vocab logits. The benefit of this structure is that you can reuse the shared hidden features and train different heads for different tasks without starting from scratch. Other examples include a classifier head or a policy head in RL.

<img src="explainer_screenshots/lstm/output_layer.png" width="200">

We'll use a head similar to the one we've used in other examples, focused on predicting the next token. To generate `logits` we take the dropout output `x` of the LSTM and project it, using a linear layer, onto the vocabulary, resulting in a `B, T, vocab_size` tensor known as `logits`.  

In training, the `logits` are compared with `y` to see how close the model is to predicting the correct next token. For inference, the `logits` drive sampling, which is how the next token is derived. 

Contrary to our transformer example, we won't be doing weight tying and will instead use a separate linear layer.  We'll start by setting all of its weights equal to 1, meaning every token gets the same logit and therefore equal probability.  This will highlight just how impactful back-propagation and training are.  

In [56]:
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
torch.nn.init.ones_(lm_head.weight)
lm_head.weight.size(), lm_head.weight

(torch.Size([36, 3]),
 Parameter containing:
 tensor([[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]], requires_grad=True))

#### Output layer - LM Head aka logits
We now project `x` onto the vocabulary resulting in a `B x T x vocab_size` final array `logits`.  This output reflects how likely each output token is given the input context.  The best way to read this is:

- (dimension 0) we have 2 examples in our batch B,
- (dimension 1) each example has a prediction for every context length between 1 and T,
- (dimension 2) for each prediction we see the logit of each token in our vocabulary.

Since our output head has the same weight everywhere, we fully expect that at each position our logits will be equal across the whole vocabulary.  In practice this means that our model gives every token the same probability of being the 'next token' regardless of the preceding text, meaning it's shit. Luckily back-propagation has a way of updating this so that with enough data and time the probabilities change. 

Note that if we wanted to run inference, after calculating the logits we'd run a softmax to turn them into a probability distribution and then sample a token from it. 

In [57]:
logits = lm_head(x)

logits.shape, logits

(torch.Size([2, 8, 36]),
 tensor([[[1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176,
           1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176,
           1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176,
           1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176,
           1.4176, 1.4176, 1.4176, 1.4176],
          [2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420,
           2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420,
           2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420,
           2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420,
           2.9420, 2.9420, 2.9420, 2.9420],
          [2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962,
           2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962,
           2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962,
           2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9
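
For inference, the softmax-then-sample step mentioned above could look like the sketch below. `last_logits` is a hypothetical stand-in for the logits at the final position (`logits[:, -1, :]`), filled with an arbitrary constant to mimic the untrained head.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 36

# Stand-in for logits[:, -1, :]: every token gets the same (arbitrary) logit.
last_logits = torch.full((2, vocab_size), 2.9962)

probs = F.softmax(last_logits, dim=-1)                 # each row sums to 1
next_token = torch.multinomial(probs, num_samples=1)   # one sampled id per example

print(probs[0, :3], next_token.squeeze(-1))
```

With identical logits the softmax assigns every token a probability of `1/36`, so the sampled token is a uniform random draw from the vocabulary, which is exactly the untrained behavior described above.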

## Loss calculation

Now we have to see how good our ~shit~ prediction is.  Since we haven't done training, and we saw that every token received the same logit at each position, we can expect it's bad. That said, we need to know how bad. For this example we'll use cross entropy, also known as the negative log likelihood of the softmax.  For each position $i$ with logits $z_i$ and target token $y_i$, the loss is:

$$
\ell_i=-\log\big(\mathrm{softmax}(z_i)_{y_i}\big)
= -z_{i,y_i}+\log\!\sum_{c=1}^C e^{z_{i,c}}
$$


To calculate loss we'll pass in the calculated `logits` and our next tokens stored in `y`. `F.cross_entropy` expects the class dimension to come second, for example logits of shape `(N, C)` with targets of shape `(N)`, so we'll flatten the `B` and `T` dimensions together for both `logits` and `y`.

In [58]:
import torch.nn.functional as F

In [59]:
y_flat = y.view(-1)
y_flat.shape, y_flat

(torch.Size([16]),
 tensor([15, 32,  9,  5, 20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0]))

In [60]:
logits_flat = logits.view(-1, logits.size(-1))
logits_flat.shape, logits_flat

(torch.Size([16, 36]),
 tensor([[1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176,
          1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176,
          1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176,
          1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176, 1.4176],
         [2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420,
          2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420,
          2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420,
          2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420, 2.9420],
         [2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962,
          2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962,
          2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962,
          2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962, 2.9962,

In [61]:
loss = F.cross_entropy(logits_flat, y_flat)
loss.shape, loss

(torch.Size([]), tensor(3.5835, grad_fn=<NllLossBackward0>))
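
The loss value above can be reproduced by hand from the formula. This minimal sketch uses only the standard library; the logit value and target id are taken from the first position printed above, but any constant logit and any target give the same result.

```python
import math

vocab_size = 36
logit = 1.4176   # the identical logit every token received at the first position
target = 15      # the first entry of y_flat

logits_i = [logit] * vocab_size
log_sum_exp = math.log(sum(math.exp(z) for z in logits_i))
loss_i = -logits_i[target] + log_sum_exp

print(round(loss_i, 4), round(math.log(vocab_size), 4))
```

When every logit is the same, the constant cancels and the loss collapses to $\log(\text{vocab size})$, which is why an untrained model with uniform predictions scores $\log 36 \approx 3.5835$ regardless of the targets.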

## Back Propagation

We now know just how ~well~ terribly the current model, with its weights and biases, predicts the next token given the input context. We now need to know how to change the different weights and biases to improve the prediction.  We could do this by guessing, making minor changes and seeing what improves, or we can think through this more critically.

If you review the chain of layers above, you can see that it's a series of formulas.  We can think of this as $f(g(x))$, except with many, many more layers and complexities.  Since this is a formula, we can dig into our math toolbox and find a better way to determine which parts need to update.  Recall from calculus that differentiation tells us the rate of change of a function.  So if we treat the loss function $\mathcal{L}$ as $\mathcal{L}(f(g(x)))$, taking the partial derivative 

$\delta=\partial \mathcal{L}/\partial h$

at each layer will give us the impact of each weight/bias on our final loss. Since we want the loss to go down, each parameter will be moved in the direction opposite to its gradient. 

Lucky for us, each parameter of our model already has a placeholder for its partial derivative called the **gradient** (the `.grad` field). We'll use this field to store it.  We'll start by zeroing out the gradients. We do this because of how partial derivatives are handled when a value has multiple dependencies. Recall that in multiple places we had a formula structure of 

$a+b=c ; a+c= d$

In this case $a$ reaches $d$ along 2 paths, and determining the partial derivative $\partial d / \partial a$ requires following both the direct path from $d$ and the path through $c$.  To determine the true impact of $a$ we sum both partial derivatives together.  Because of this property, the tool we use, the built-in `.backward()`, automatically sums gradients with `+=`, so if we do not reset the gradient before a new backward pass we end up with erroneous gradients. 
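
The two-path example and the `+=` behavior can be seen directly in a tiny sketch. The values `2.0` and `3.0` are arbitrary and chosen only for illustration.

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

# First backward pass: d depends on a directly and again through c.
c = a + b
d = a + c
d.backward()
first = a.grad.item()        # sum of both paths: 1 (direct) + 1 (via c)

# Second backward pass WITHOUT zeroing: the new gradient is added on top.
c = a + b
d = a + c
d.backward()
accumulated = a.grad.item()

# Resetting the gradient before the next pass avoids the stale sum.
a.grad = None
print(first, accumulated, a.grad)
```

The first pass gives `2.0` for `a`, one contribution per path. The second pass doubles it to `4.0`, which is the erroneous result we avoid by zeroing gradients before each backward pass.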

Finally, we call `.backward()` on the `loss`, not on `logits`. Our goal is to minimize the loss, so we need to look at every calculation that impacts it, which includes the whole forward pass that generates the prediction `logits_flat`.  If we think of it as $\mathcal{L}(f(x))$ where $f(x)$ is the forward pass that generates the logits, then a simple chain rule applies:

$\frac{\partial \mathcal{L}}{\partial x} = \mathcal{L}'(f(x))\, f'(x)$

Let's start by zeroing the gradients and leaning on PyTorch to calculate the gradients for us. We'll also validate that the gradients are `None`.

In [62]:
lm_head.zero_grad()
h2g.zero_grad()
x2g.zero_grad()
wte.zero_grad()


# validate gradients
lm_head.weight.grad, wte.weight.grad

(None, None)

**Auto-Diff** - Now let's watch the gradients populate.  This magic is called auto-differentiation, or auto-diff for short. It saves us from writing many layers of nasty code to do the differentiation ourselves, but, if you're a masochist, you can surely find people who have written out that code (it's not too bad since you just do one layer at a time). 

In [63]:
loss.backward()

In [64]:
lm_head.weight.grad, wte.weight.grad

(tensor([[-0.1082, -0.1046, -0.1076],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [-0.1077, -0.1041, -0.1071],
         [-0.0366, -0.0330, -0.0360],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [-0.0770, -0.0734, -0.0764],
         [ 0.0236,  0.0273,  0.0242],
         [-0.0421, -0.0385, -0.0415],
         [ 0.0236,  0.0273,  0.0242],
         [-0.0399, -0.0363, -0.0393],
         [ 0.0236,  0.0273,  0.0242],
         [-0.0805, -0.0768, -0.0356],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [-0.0423, -0.1035, -0.0417],
         [-0.0422, -0.0386, -0.0416],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [ 0
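
For a taste of what auto-diff is doing under the hood, the gradient of the mean cross-entropy loss with respect to the logits has a simple closed form: the softmax probabilities minus the one-hot targets, divided by the number of positions. The sketch below checks that against autograd on random data; `demo_logits`, `demo_targets`, and `demo_loss` are hypothetical stand-ins, not the notebook's tensors.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C = 16, 36   # flattened B*T positions and vocab size, as in this example

demo_logits = torch.randn(N, C, requires_grad=True)
demo_targets = torch.randint(0, C, (N,))

demo_loss = F.cross_entropy(demo_logits, demo_targets)   # mean over N positions
demo_loss.backward()

# Hand-written gradient: (softmax - one_hot) / N
with torch.no_grad():
    one_hot = F.one_hot(demo_targets, num_classes=C).float()
    manual_grad = (F.softmax(demo_logits, dim=-1) - one_hot) / N

max_diff = (demo_logits.grad - manual_grad).abs().max().item()
print(max_diff)
```

This closed form also explains the sign pattern in the `lm_head` gradient above: rows for tokens that appear as targets get a negative gradient (probability minus 1), while every other row gets a small positive one.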

## Gradient Clipping

Recall that we discussed at the start an issue called exploding gradients. Exploding gradients occur because back-propagation through time multiplies gradients by a chain of recurrent Jacobians as follows:

$\nabla_{h_{t-k}} \mathcal{L} = \left(\prod_{i=1}^{k} \frac{\partial h_{t-i+1}}{\partial h_{t-i}}\right)\nabla_{h_t}\mathcal{L}$

This means that if the spectral norm of those factors stays above 1, the gradient norm grows roughly exponentially with sequence length. Even with LSTM gates, poorly conditioned hidden-to-hidden dynamics or large activations can push singular values WELL above 1, even though the sigmoid and tanh layers keep the activations themselves bounded by 1. To combat this, we apply gradient clipping. Gradient clipping caps the gradient norm, in our case at `1.0`, to prevent runaway steps, numerical overflow, and unstable updates. Note that the cell below clips each layer's gradients separately rather than clipping one global norm across all parameters.

*Note that this will only solve exploding gradients and does not address vanishing gradients. The LSTM structure and use of memory helps fight vanishing gradients*

Since none of our gradient norms are currently above 1, we will not see any impact from gradient clipping, but I did want to introduce the concept as it is an important component. 

In [65]:
nn.utils.clip_grad_norm_(lm_head.parameters(), 1.0)
nn.utils.clip_grad_norm_(h2g.parameters(), 1.0)
nn.utils.clip_grad_norm_(x2g.parameters(), 1.0)
nn.utils.clip_grad_norm_(wte.parameters(), 1.0)
lm_head.weight.grad, wte.weight.grad

(tensor([[-0.1082, -0.1046, -0.1076],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [-0.1077, -0.1041, -0.1071],
         [-0.0366, -0.0330, -0.0360],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [-0.0770, -0.0734, -0.0764],
         [ 0.0236,  0.0273,  0.0242],
         [-0.0421, -0.0385, -0.0415],
         [ 0.0236,  0.0273,  0.0242],
         [-0.0399, -0.0363, -0.0393],
         [ 0.0236,  0.0273,  0.0242],
         [-0.0805, -0.0768, -0.0356],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [-0.0423, -0.1035, -0.0417],
         [-0.0422, -0.0386, -0.0416],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [ 0.0236,  0.0273,  0.0242],
         [ 0
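
Since our real gradients are too small to trigger clipping, here is a minimal sketch with deliberately large gradients so the effect is visible. The layers `lin1` and `lin2` and the constant gradient of `2.0` are hypothetical, and the sketch clips one global norm across all parameters rather than per layer as in the cell above.

```python
import itertools
import torch
import torch.nn as nn

# Hypothetical stand-in layers; their shapes are arbitrary.
lin1, lin2 = nn.Linear(4, 3), nn.Linear(3, 2)
params = list(itertools.chain(lin1.parameters(), lin2.parameters()))

# Fake an "exploded" gradient: every one of the 23 entries is 2.0.
for p in params:
    p.grad = torch.full(p.shape, 2.0)

# clip_grad_norm_ rescales all gradients in place and returns the norm BEFORE clipping.
norm_before = nn.utils.clip_grad_norm_(params, max_norm=1.0)
norm_after = torch.sqrt(sum((p.grad ** 2).sum() for p in params))

print(float(norm_before), float(norm_after))
```

Every gradient is scaled by the same factor, so the direction of the update is preserved and only its length is capped at roughly `1.0`.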

## Learning

The process of learning now requires us to update our weights based on this gradient. To really feel the "back propagation" we'll start with the last layer and work backwards, though, since we have all of the gradients calculated already, the order does not matter. Recall that we are minimizing the loss, so we step against the gradient: a negative gradient means that increasing the parameter lowers the loss, and vice versa. The gradient tells us the direction and relative size of each parameter's update, but we still need to decide how much of that gradient to apply to our weights. This "how much" is referred to as the *learning rate*. In modern training, optimizers and learning rate schedulers are used to vary the rate and its application by parameter and by training round.  We'll keep it simple and use an astronomically high learning rate of `5.000`, multiplying each gradient by it and applying the result directly to the weights via a `-=`. The gradients for the weights and the biases are different, as the partial derivative with respect to each is different, so in the layers with a bias we need to remember to update both.

*Note that since our vocab is very small, our context is small, and our batch is small, relatively our model is very deep so we will see a lot of exceptionally small gradients*

In [66]:
## Huge learning rate to emphasize
learning_rate = 5.000

### Output Layer
Let's start with our output layer.  Recall that we initialized the weights to `1.000` so we can quickly see the impact of the gradient update on the weights. Most notably, the rows corresponding to tokens that are present as targets are upweighted and the others are downweighted.  Additionally you'll notice that the gradient uses all of the `hidden_size` dimensions, with a different update on each dimension. Finally, since this is the output layer, we did not include any bias. 

In [67]:
with torch.no_grad():
    lm_head.weight -= learning_rate * lm_head.weight.grad
lm_head.weight

Parameter containing:
tensor([[1.5410, 1.5229, 1.5380],
        [0.8819, 0.8637, 0.8789],
        [0.8819, 0.8637, 0.8789],
        [0.8819, 0.8637, 0.8789],
        [0.8819, 0.8637, 0.8789],
        [1.5387, 1.5205, 1.5357],
        [1.1830, 1.1649, 1.1800],
        [0.8819, 0.8637, 0.8789],
        [0.8819, 0.8637, 0.8789],
        [1.3851, 1.3670, 1.3821],
        [0.8819, 0.8637, 0.8789],
        [1.2104, 1.1923, 1.2074],
        [0.8819, 0.8637, 0.8789],
        [1.1996, 1.1814, 1.1966],
        [0.8819, 0.8637, 0.8789],
        [1.4023, 1.3841, 1.1778],
        [0.8819, 0.8637, 0.8789],
        [0.8819, 0.8637, 0.8789],
        [0.8819, 0.8637, 0.8789],
        [0.8819, 0.8637, 0.8789],
        [1.2113, 1.5175, 1.2083],
        [1.2111, 1.1930, 1.2081],
        [0.8819, 0.8637, 0.8789],
        [0.8819, 0.8637, 0.8789],
        [0.8819, 0.8637, 0.8789],
        [0.8819, 0.8637, 0.8789],
        [0.8819, 0.8637, 0.8789],
        [0.8819, 0.8637, 0.8789],
        [0.8819, 0.8637, 0
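
The manual `-=` update is plain stochastic gradient descent, so it should match what `torch.optim.SGD` does with the same learning rate. This minimal sketch checks that on a 2x3 stand-in weight; the gradient values are copied from the first two rows of the `lm_head` gradient printed earlier, and `w_manual`, `w_sgd`, and `lr` are hypothetical names.

```python
import torch

lr = 5.0
grad = torch.tensor([[-0.1082, -0.1046, -0.1076],
                     [ 0.0236,  0.0273,  0.0242]])

# Manual update, as in the cell above.
w_manual = torch.ones(2, 3, requires_grad=True)
w_manual.grad = grad.clone()
with torch.no_grad():
    w_manual -= lr * w_manual.grad

# Same update through the optimizer (no momentum, no weight decay).
w_sgd = torch.ones(2, 3, requires_grad=True)
w_sgd.grad = grad.clone()
torch.optim.SGD([w_sgd], lr=lr).step()

print(w_manual)
print(w_sgd)
```

Both versions produce the same weights, and they agree with the first two rows printed above up to the rounding of the displayed gradients. In a real training loop the optimizer also handles iterating over every parameter, which is why the per-layer `-=` cells here are only for illustration.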

### Recurrent Unit - LSTM
Next we'll update the two linear layers inside the LSTM.  Recall that we initialized the h2g weights to `0.2500`, the x2g weights to `0.5000`, and the biases to `0`.  Since our gradients are so small (1e-9 or so), the changes won't show up in the weights when we print them, but we can see them in the biases given they started at 0. 

In [68]:
with torch.no_grad():
    h2g.weight -= learning_rate * h2g.weight.grad
    h2g.bias -= learning_rate * h2g.bias.grad
h2g.weight, h2g.bias

(Parameter containing:
 tensor([[0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500]], requires_grad=True),
 Parameter containing:
 tensor([-1.0529e-08, -1.0816e-08, -8.5233e-09, -1.7035e-09, -1.8827e-09,
         -1.7154e-09, -6.5187e-08, -7.1495e-08, -6.5682e-08, -7.1856e-09,
         -7.3534e-09, -6.2211e-09], requires_grad=True))

In [69]:
with torch.no_grad():
    x2g.weight -= learning_rate * x2g.weight.grad
    x2g.bias -= learning_rate * x2g.bias.grad
x2g.weight, x2g.bias

(Parameter containing:
 tensor([[0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000]], requires_grad=True),
 Parameter containing:
 tensor([-1.0529e-08, -1.0816e-08, -8.5233e-09, -1.7035e-09, -1.8827e-09,
         -1.7154e-09, -6.5187e-08, -7.1495e-08, -6.5682e-08, -7.1856e-09,
         -7.3534e-09, -6.2211e-09], requires_grad=True))
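
The manual updates above are plain gradient descent steps. As a rough sketch, the same rule can be checked against `torch.optim.SGD` on a toy layer. The layer sizes, the toy loss, and the `learning_rate` value here are assumptions for illustration, not the notebook's actual settings.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
learning_rate = 0.1  # assumed value, for illustration only

# Two identical toy layers: one updated by hand, one by the optimizer.
layer_manual = nn.Linear(4, 3)
layer_optim = nn.Linear(4, 3)
layer_optim.load_state_dict(layer_manual.state_dict())

# Same input and same toy loss for both layers, so the gradients match.
inputs = torch.randn(5, 4)
layer_manual(inputs).pow(2).mean().backward()
layer_optim(inputs).pow(2).mean().backward()

# Manual update, written the same way as the cells above.
with torch.no_grad():
    for param in layer_manual.parameters():
        param -= learning_rate * param.grad

# Optimizer update with plain SGD (no momentum, no weight decay).
optimizer = torch.optim.SGD(layer_optim.parameters(), lr=learning_rate)
optimizer.step()

same = all(
    torch.allclose(a, b)
    for a, b in zip(layer_manual.parameters(), layer_optim.parameters())
)
print(same)
```

If a second backward pass followed, the gradients would accumulate on top of the existing ones unless they were zeroed first (for example with `optimizer.zero_grad()`).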

### Input Layer
Finally we'll update our input layer. Recall that we initialized it to `1.000`. You can see that the gradients are once again extremely small. Also note that the gradients only flow back to the rows for tokens that were present in our example. This differs from our LM_head: the embedding lookup selects only those rows out of the `wte` weights, so when we differentiate back through the recurrence only those rows can be affected.

In [70]:
wte.weight.grad

tensor([[0.0000e+00, 1.3955e-09, 1.3955e-09, 1.3955e-09],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [2.7494e-09, 2.7494e-09, 2.7494e-09, 2.7494e-09],
        [4.0139e-10, 4.0139e-10, 4.0139e-10, 4.0139e-10],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [3.4687e-09, 3.4687e-09, 3.4687e-09, 3.4687e-09],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [5.8984e-09, 0.0000e+00, 5.8984e-09, 5.8984e-09],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [9.1376e-10, 9.1376e-10, 9.1376e-10, 9.1376e-10],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [3.1085e-09, 3.1085e-09, 3.1085e-09, 3.1085e-09],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.000

In [71]:
with torch.no_grad():
    wte.weight -= learning_rate * wte.weight.grad
wte.weight

Parameter containing:
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]], requires_grad=True)
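
The row-wise sparsity of the embedding gradient described in this section can be reproduced in isolation. This is a minimal sketch with made-up toy sizes and token ids (`vocab_size`, `n_embd`, and `tokens` are hypothetical, not the notebook's values).

```python
import torch
import torch.nn as nn

vocab_size, n_embd = 6, 4  # toy sizes, assumed for illustration
emb = nn.Embedding(vocab_size, n_embd)

# Only tokens 0, 2 and 5 appear in this toy batch (2 appears twice).
tokens = torch.tensor([[0, 2, 2, 5]])
emb(tokens).sum().backward()

# Rows of the embedding table that received any gradient at all.
rows_with_grad = emb.weight.grad.abs().sum(dim=1).nonzero().flatten().tolist()
print(rows_with_grad)  # [0, 2, 5]
```

Rows for tokens that never appear in the batch keep a gradient of exactly zero, which is the same pattern as the all-zero rows in `wte.weight.grad`.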

## Forward Pass with Updated Weights

Now that we have the updated weights for each layer, let's do another forward pass and compare the loss. Since each layer was previously explained we will instead focus on just showing the outputs of the different layers and the final loss. If you want, you can check the previous outputs in the cached cell outputs above and compare them to see how the weight changes impacted the values at each layer. 

A quick sign that the weights were updated is that the values, especially in the later layers, are no longer as predictable and repetitive.

### Data Re-loading
We'll re-pull the inputs into a new `x_2` and keep `y` to emphasize that the same examples are being used.

In [72]:
x_2 = tok_for_training[:-1].view(B, T)
x_2, y

(tensor([[35, 15, 32,  9,  5, 20, 30, 15],
         [11,  9,  6, 20,  5,  0, 13, 21]]),
 tensor([[15, 32,  9,  5, 20, 30, 15, 11],
         [ 9,  6, 20,  5,  0, 13, 21,  0]]))

### Input Layer
Note that because the gradient impact on `wte` was so small, you can hardly see a difference when the tensor is printed. In fact, the change is likely so small that we wouldn't see it unless we selected a single value and printed it at full precision. With RNNs, because of the recurrent layers, we would need long training and long context to see the input layer move away from its initialization. Alternatively, a common step to help improve the input embedding is weight-tying between the input and output layers. This ensures that the input and output layers embed in the same space and helps stabilize learning, but weight-tying goes in and out of favor often. Because of this we'll cover weight-tying in the GPT notebook I write.

Since our weights were barely tweaked, we'll again see very flat values at 1.000.

In [73]:
x = wte(x_2)
x.shape, x 

(torch.Size([2, 8, 4]),
 tensor([[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],
 
         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]], grad_fn=<EmbeddingBackward0>))
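
One reason an update this small can be invisible is floating-point precision. The sketch below assumes an update of magnitude `1e-9` applied to a weight of `1.0` (an illustrative value, not one read from the notebook) and compares float32 with float64.

```python
import torch

tiny_update = 1e-9  # assumed update magnitude, for illustration only

# float32 (PyTorch's default dtype): the spacing between representable
# values just below 1.0 is about 6e-8, so a 1e-9 change rounds away.
w32 = torch.ones(1, dtype=torch.float32)
w32_after = w32 - tiny_update
unchanged_in_float32 = torch.equal(w32, w32_after)

# float64 has enough precision to keep the same update.
w64 = torch.ones(1, dtype=torch.float64)
w64_after = w64 - tiny_update
changed_in_float64 = not torch.equal(w64, w64_after)

print(unchanged_in_float32, changed_in_float64)
print(torch.finfo(torch.float32).eps)  # machine epsilon for float32
```

Whether a particular update survives depends on the learning rate and on the size of the weight being updated, so this is only a way to reason about the printed values, not a statement about every entry in `wte`.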

### Dropout 
Once again we'll perform dropout. As explained above, the impact of training is beyond the visible decimals.

In [74]:
x = dropout(x)
x

tensor([[[1.1111, 1.1111, 1.1111, 0.0000],
         [0.0000, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 0.0000],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111]],

        [[1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [0.0000, 1.1111, 1.1111, 0.0000],
         [0.0000, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111]]], grad_fn=<MulBackward0>)
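
The `1.1111` values come from dropout rescaling the surviving entries by `1 / (1 - p)`. The sketch below assumes `p = 0.1`, which is inferred from the printed values (1 / 0.9 ≈ 1.1111) rather than taken from the notebook's setup cell.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p = 0.1  # assumed drop probability, inferred from the 1.1111 values
dropout = nn.Dropout(p)

ones = torch.ones(2, 8, 4)
dropped = dropout(ones)  # modules start in training mode by default

# Every entry is either zeroed out or scaled up to 1 / (1 - p).
scale = 1.0 / (1.0 - p)
only_zero_or_scaled = bool(
    ((dropped == 0) | torch.isclose(dropped, torch.tensor(scale))).all()
)

# In eval mode dropout is the identity, so nothing is zeroed or scaled.
dropout.eval()
identity_in_eval = torch.equal(dropout(ones), ones)

print(only_zero_or_scaled, identity_in_eval)
```

The rescaling keeps the expected value of each activation the same between training and evaluation.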

### Recurrent Unit - LSTM

This time through we'll run the recurrent layer over all time steps without breaking out a specific step. We still have to reset our hidden state and our memory cell to `0` since our training batch reset. Also, since the weight and bias updates were so small, we expect the output `hs` to still show a lot of uniformity.

In [75]:
h_next = torch.zeros(B_batch, hidden_size) 
c_next = torch.zeros(B_batch, hidden_size) 
hs = x.new_empty(B_batch, T_context, hidden_size)

In [76]:
for t in range(T_context):
    
    # set input and previous
    h_t = h_next
    c_t = c_next
    x_t = x[:, t, :]

    # gate pre-activations from the input and the previous hidden state
    xi, xf, xo, xc = x2g(x_t).split(hidden_size, dim=-1)
    hi, hf, ho, hc = h2g(h_t).split(hidden_size, dim=-1)
    
    #input, forget, and output gate 
    input_gate =  torch.sigmoid(xi + hi)
    forget_gate =  torch.sigmoid(xf + hf)
    output_gate =  torch.sigmoid(xo + ho)

    # input state
    input_state = torch.tanh(xc + hc)

    # memory cell
    c_next = (forget_gate*c_t) + (input_gate*input_state)

    # hidden state / recursion output
    h_next = output_gate*torch.tanh(c_next)

    # save the new hidden state for time step t
    hs[:, t, :] = h_next

hs

tensor([[[0.5505, 0.5505, 0.5505],
         [0.8134, 0.8134, 0.8134],
         [0.9292, 0.9292, 0.9292],
         [0.9103, 0.9103, 0.9103],
         [0.9473, 0.9473, 0.9473],
         [0.9493, 0.9493, 0.9493],
         [0.9495, 0.9495, 0.9495],
         [0.9495, 0.9495, 0.9495]],

        [[0.6379, 0.6379, 0.6379],
         [0.8826, 0.8826, 0.8826],
         [0.8380, 0.8380, 0.8380],
         [0.9039, 0.9039, 0.9039],
         [0.9468, 0.9468, 0.9468],
         [0.9492, 0.9492, 0.9492],
         [0.9495, 0.9495, 0.9495],
         [0.9495, 0.9495, 0.9495]]], grad_fn=<CopySlices>)
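
As a sanity check on the hand-written recurrence, a single step can be compared against PyTorch's built-in `nn.LSTMCell`. This is a sketch under assumptions: the layers here are freshly and randomly initialized (not the notebook's constant-initialized ones), and `manual_step` is a hypothetical helper that mirrors the loop body above. The one thing to watch is gate order: the loop splits its gates as input, forget, output, cell candidate, while `nn.LSTMCell` stores them as input, forget, cell candidate, output, so the rows are permuted before copying.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_embd, hidden_size, batch = 4, 3, 2  # toy sizes matching the shapes above

x2g = nn.Linear(n_embd, 4 * hidden_size)
h2g = nn.Linear(hidden_size, 4 * hidden_size)

def manual_step(x_t, h_t, c_t):
    """One LSTM step with the same gate order as the loop above."""
    xi, xf, xo, xc = x2g(x_t).split(hidden_size, dim=-1)
    hi, hf, ho, hc = h2g(h_t).split(hidden_size, dim=-1)
    input_gate = torch.sigmoid(xi + hi)
    forget_gate = torch.sigmoid(xf + hf)
    output_gate = torch.sigmoid(xo + ho)
    input_state = torch.tanh(xc + hc)
    c_next = forget_gate * c_t + input_gate * input_state
    h_next = output_gate * torch.tanh(c_next)
    return h_next, c_next

# Reorder rows from (i, f, o, c) to nn.LSTMCell's (i, f, g, o) layout.
idx = torch.arange(4 * hidden_size).view(4, hidden_size)
perm = idx[[0, 1, 3, 2]].reshape(-1)

cell = nn.LSTMCell(n_embd, hidden_size)
with torch.no_grad():
    cell.weight_ih.copy_(x2g.weight[perm])
    cell.weight_hh.copy_(h2g.weight[perm])
    cell.bias_ih.copy_(x2g.bias[perm])
    cell.bias_hh.copy_(h2g.bias[perm])

x_t = torch.randn(batch, n_embd)
h0 = torch.randn(batch, hidden_size)
c0 = torch.randn(batch, hidden_size)

h_manual, c_manual = manual_step(x_t, h0, c0)
h_ref, c_ref = cell(x_t, (h0, c0))

matches = torch.allclose(h_manual, h_ref, atol=1e-5) and torch.allclose(
    c_manual, c_ref, atol=1e-5
)
print(matches)
```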

### Final Dropout
As before, we'll run dropout again before running the model head. We should similarly see some values zeroed out and the surviving values scaled up by `1 / (1 - p)`.

In [77]:
x = dropout(hs)
x

tensor([[[0.6117, 0.6117, 0.0000],
         [0.9038, 0.9038, 0.9038],
         [1.0325, 1.0325, 1.0325],
         [1.0115, 1.0115, 1.0115],
         [1.0525, 1.0525, 1.0525],
         [1.0547, 1.0547, 1.0547],
         [1.0550, 1.0550, 1.0550],
         [1.0550, 1.0550, 1.0550]],

        [[0.7088, 0.0000, 0.7088],
         [0.9807, 0.9807, 0.0000],
         [0.9312, 0.0000, 0.9312],
         [1.0043, 1.0043, 1.0043],
         [1.0521, 1.0521, 1.0521],
         [1.0547, 1.0547, 1.0547],
         [1.0549, 1.0549, 1.0549],
         [1.0550, 1.0550, 1.0550]]], grad_fn=<MulBackward0>)

### Output Layers AKA Model Head
This is the layer where we expect to see the most updates. Recall that the weights on the model head had the most notable changes: the values for tokens that were present in our training data shifted more positive, and the values for tokens that were not present shifted more negative. As such, if you look at the entries for each example you should see that the values are no longer equal for each token.

In [78]:
logits = lm_head(x)
logits.shape, logits

(torch.Size([2, 8, 36]),
 tensor([[[1.8742, 1.0678, 1.0678, 1.0678, 1.0678, 1.8714, 1.4362, 1.0678,
           1.0678, 1.6835, 1.0678, 1.4698, 1.0678, 1.4565, 1.0678, 1.7045,
           1.0678, 1.0678, 1.0678, 1.0678, 1.6693, 1.4706, 1.0678, 1.0678,
           1.0678, 1.0678, 1.0678, 1.0678, 1.0678, 1.0678, 1.2695, 1.0678,
           1.4428, 1.0678, 1.0678, 1.0678],
          [4.1591, 2.3719, 2.3719, 2.3719, 2.3719, 4.1527, 3.1884, 2.3719,
           2.3719, 3.7364, 2.3719, 3.2627, 2.3719, 3.2333, 2.3719, 3.5827,
           2.3719, 2.3719, 2.3719, 2.3719, 3.5583, 3.2646, 2.3719, 2.3719,
           2.3719, 2.3719, 2.3719, 2.3719, 2.3719, 2.3719, 2.9678, 2.3719,
           3.2028, 2.3719, 2.3719, 2.3719],
          [4.7515, 2.7098, 2.7098, 2.7098, 2.7098, 4.7442, 3.6425, 2.7098,
           2.7098, 4.2686, 2.7098, 3.7274, 2.7098, 3.6939, 2.7098, 4.0931,
           2.7098, 2.7098, 2.7098, 2.7098, 4.0651, 3.7296, 2.7098, 2.7098,
           2.7098, 2.7098, 2.7098, 2.7098, 2.7098, 2.7098, 3.3

### Updated Loss calculation

Now we'll calculate the updated loss.  Our first pass's loss was 3.5835. Since we're passing through the same example and used a fairly high learning rate we should see a significant improvement with just 1 learning pass. 

In [79]:
loss

tensor(3.5835, grad_fn=<NllLossBackward0>)

In [80]:
y_flat = y.view(-1)
logits_flat = logits.view(-1, logits.size(-1))
updated_loss = F.cross_entropy(logits_flat, y_flat)
print(updated_loss.shape, updated_loss)

torch.Size([]) tensor(2.9012, grad_fn=<NllLossBackward0>)
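
The first-pass loss of 3.5835 is consistent with a model that assigns the same logit to every token: with a vocabulary of 36 (the last dimension of `logits`), uniform predictions give a cross-entropy of ln(36). The sketch below uses made-up targets, since the value does not depend on which tokens are correct when the logits are uniform.

```python
import math

import torch
import torch.nn.functional as F

vocab_size = 36  # matches the last dimension of the logits above
num_positions = 16  # B * T = 2 * 8

# Identical logits for every token at every position.
uniform_logits = torch.zeros(num_positions, vocab_size)
targets = torch.randint(0, vocab_size, (num_positions,))  # arbitrary targets

uniform_loss = F.cross_entropy(uniform_logits, targets)
print(f"{uniform_loss.item():.4f}", f"{math.log(vocab_size):.4f}")
```

Any loss below ln(36) therefore means the model is doing better than guessing uniformly over the vocabulary.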


In [81]:
f'1 round of training resulted in a loss improvement of {loss.item() - updated_loss.item():.4f}'

'1 round of training resulted in a loss improvement of 0.6823'

# SUCCESS!
Our training improved the loss by about **19%** (0.6823 / 3.5835; the amount may vary since we didn't set a seed). There are flaws with this, mainly passing the same example through a second time and using a high learning rate, but it helps show the fundamentals of what learning does inside an LSTM-style RNN model.
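
The run-to-run variation comes from the random dropout masks. As a small sketch (the seed value and `p = 0.1` are assumptions, not the notebook's settings), calling `torch.manual_seed` before the forward pass makes the mask, and therefore the loss, repeatable on CPU.

```python
import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.1)  # assumed drop probability
x = torch.ones(2, 8, 4)

# Re-seeding before each call reproduces the same dropout mask.
torch.manual_seed(42)
first = dropout(x)
torch.manual_seed(42)
second = dropout(x)

reproducible = torch.equal(first, second)
print(reproducible)
```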