# RNN LSTM Explainer

The goal is to walk through recurrent neural networks (RNNs). Two common gated flavors are GRUs and LSTMs; we'll focus on Long Short-Term Memory (LSTM).

To help display the transformations, we'll use the opening text from the [linear algebra wiki page](https://en.wikipedia.org/wiki/Linear_algebra) and the [LU decomposition wiki page](https://en.wikipedia.org/wiki/LU_decomposition), as the topic is fitting and the text shows us some non-standard patterns.

## Text Prep/Tokenization

We'll start with the common preprocessing step of tokenizing the data. This converts the string text into an array of integers that can be used during the training loop. I've built a very simple byte-pair encoding (BPE) tokenizer whose vocabulary holds each unique character that appears plus the top 5 merges. This keeps our vocab size small and manageable for this example. Production vocabularies are typically much larger, often in the 100K+ range; a great library for that is `tiktoken`. Roughly, tokenization finds the longest character patterns that were learned during training and replaces each one with the integer that represents it. This way we turn the text into a numeric array to simplify computing.

In [1]:
import torch
from collections import Counter

In [2]:
class SimpleBPETokenizer:
    def __init__(self, num_merges=5, eot_token='<|endoftext|>'):
        self.num_merges = num_merges
        self.eot_token = eot_token
        self.eot_id = None
        self.merges = []
        self.pair_ranks = {}
        self.vocab = {}
        self.id_to_token = {}

    def _add_token(self, tok):
        if tok in self.vocab:
            return self.vocab[tok]
        i = len(self.vocab)
        self.vocab[tok] = i
        self.id_to_token[i] = tok
        return i

    def _get_bigrams(self, seq):
        for i in range(len(seq) - 1):
            yield (seq[i], seq[i + 1])

    def _merge_once(self, seq, pair):
        a, b = pair
        out = []
        i = 0
        while i < len(seq):
            if i < len(seq) - 1 and seq[i] == a and seq[i + 1] == b:
                out.append(a + b)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        return out

    def train(self, corpus):
        # corpus: list[str]
        text = ''.join(corpus).lower()
        seq = list(text)
        merges = []
        for _ in range(self.num_merges):
            counts = Counter(self._get_bigrams(seq))
            if not counts: break
            best_pair, _ = counts.most_common(1)[0]
            merges.append(best_pair)
            seq = self._merge_once(seq, best_pair)
        self.merges = merges
        self.pair_ranks = {p: i for i, p in enumerate(self.merges)}

        self.vocab = {}
        self.id_to_token = {}
        for ch in sorted(set(text)):
            self._add_token(ch)
        for a, b in self.merges:
            self._add_token(a + b)
        self.eot_id = self._add_token(self.eot_token)

    def encode(self, text, force_last_eot=True):
        # treat literal eot marker as special; remove it from content
        if self.eot_token in text:
            text = text.replace(self.eot_token, '')
        seq = list(text)

        # make sure all seen base chars exist
        for ch in set(seq):
            if ch not in self.vocab:
                self._add_token(ch)

        # greedy BPE using learned pair ranks
        if self.merges:
            while True:
                best_pair, best_rank = None, None
                for p in self._get_bigrams(seq):
                    r = self.pair_ranks.get(p)
                    if r is not None and (best_rank is None or r < best_rank):
                        best_pair, best_rank = p, r
                if best_pair is None:
                    break
                seq = self._merge_once(seq, best_pair)

        # ensure all tokens in seq exist in vocab (e.g., if new chars appeared)
        for tok in seq:
            if tok not in self.vocab:
                self._add_token(tok)

        ids = [self.vocab[tok] for tok in seq]

        # FORCE: append EOT id if not already last
        if force_last_eot:
            if not ids or ids[-1] != self.eot_id:
                ids.append(self.eot_id)

        return ids

    def decode(self, ids):
        # drop trailing EOT if present
        if ids and self.eot_id is not None and ids[-1] == self.eot_id:
            ids = ids[:-1]
        toks = [self.id_to_token[i] for i in ids]
        return ''.join(toks)


In [3]:
raw_example_1 = r'''Linear algebra is central to almost all areas of mathematics. For instance, linear algebra is fundamental in modern presentations of geometry, including for defining basic objects such as lines, planes and rotations. Also, functional analysis, a branch of mathematical analysis, may be viewed as the application of linear algebra to function spaces.'''
raw_example_2 = r'''In numerical analysis and linear algebra, lower–upper (LU) decomposition or factorization factors a matrix as the product of a lower triangular matrix and an upper triangular matrix (see matrix multiplication and matrix decomposition).'''


In [4]:
tok = SimpleBPETokenizer(num_merges=5)
tok.train([raw_example_1,raw_example_2])
tok.merges

[(' ', 'a'), ('a', 't'), ('i', 'n'), (' ', 'm'), ('i', 'o')]
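
To make one merge iteration concrete, here is a minimal standalone sketch of a single BPE merge on a toy string. The helper names `most_common_pair` and `merge_pair` and the string `"banana band"` are illustrative assumptions, not part of the tokenizer above, although the logic mirrors `_get_bigrams` and `_merge_once`.

```python
from collections import Counter

def most_common_pair(seq):
    """Return the most frequent adjacent pair and its count."""
    counts = Counter(zip(seq, seq[1:]))
    return counts.most_common(1)[0]

def merge_pair(seq, pair):
    """Replace every left-to-right occurrence of `pair` with one merged token."""
    a, b = pair
    out, i = [], 0
    while i < len(seq):
        if i < len(seq) - 1 and seq[i] == a and seq[i + 1] == b:
            out.append(a + b)
            i += 2
        else:
            out.append(seq[i])
            i += 1
    return out

seq = list("banana band")
pair, count = most_common_pair(seq)
merged = merge_pair(seq, pair)
print(pair, count)   # ('a', 'n') 3
print(merged)        # ['b', 'an', 'an', 'a', ' ', 'b', 'an', 'd']
```

Repeating this `num_merges` times, recounting pairs on the merged sequence each time, is what `train` does.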

In [5]:
tok.vocab

{' ': 0,
 '(': 1,
 ')': 2,
 ',': 3,
 '.': 4,
 'a': 5,
 'b': 6,
 'c': 7,
 'd': 8,
 'e': 9,
 'f': 10,
 'g': 11,
 'h': 12,
 'i': 13,
 'j': 14,
 'l': 15,
 'm': 16,
 'n': 17,
 'o': 18,
 'p': 19,
 'r': 20,
 's': 21,
 't': 22,
 'u': 23,
 'v': 24,
 'w': 25,
 'x': 26,
 'y': 27,
 'z': 28,
 '–': 29,
 ' a': 30,
 'at': 31,
 'in': 32,
 ' m': 33,
 'io': 34,
 '<|endoftext|>': 35}

In [6]:
vocab_size = len(tok.vocab)
vocab_size

36

In [7]:
eot = tok.eot_id
tokens = []
for example in [raw_example_1, raw_example_2]:
    tokens.extend([eot])
    tokens.extend(tok.encode(example.lower()))
all_tokens = torch.tensor(tokens, dtype=torch.long)
all_tokens

tensor([35, 15, 32,  9,  5, 20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0,  7,
         9, 17, 22, 20,  5, 15,  0, 22, 18, 30, 15, 16, 18, 21, 22, 30, 15, 15,
        30, 20,  9,  5, 21,  0, 18, 10, 33, 31, 12,  9, 16, 31, 13,  7, 21,  4,
         0, 10, 18, 20,  0, 32, 21, 22,  5, 17,  7,  9,  3,  0, 15, 32,  9,  5,
        20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0, 10, 23, 17,  8,  5, 16,
         9, 17, 22,  5, 15,  0, 32, 33, 18,  8,  9, 20, 17,  0, 19, 20,  9, 21,
         9, 17, 22, 31, 34, 17, 21,  0, 18, 10,  0, 11,  9, 18, 16,  9, 22, 20,
        27,  3,  0, 32,  7, 15, 23,  8, 32, 11,  0, 10, 18, 20,  0,  8,  9, 10,
        32, 32, 11,  0,  6,  5, 21, 13,  7,  0, 18,  6, 14,  9,  7, 22, 21,  0,
        21, 23,  7, 12, 30, 21,  0, 15, 32,  9, 21,  3,  0, 19, 15,  5, 17,  9,
        21, 30, 17,  8,  0, 20, 18, 22, 31, 34, 17, 21,  4, 30, 15, 21, 18,  3,
         0, 10, 23, 17,  7, 22, 34, 17,  5, 15, 30, 17,  5, 15, 27, 21, 13, 21,
         3, 30,  0,  6, 20,  5, 17,  7, 

# Modeling

A machine learning model's forward pass takes the token ids, runs several layers of linear algebra on them, and then "predicts" the next token. When the weights are untrained (as you will see in this example), this process produces gibberish. Training turns that noise into pattern during the "backward pass," as you'll see. We'll show 3 steps that are focused on training:
1. **Data Loading** `x, y = train_loader.next_batch()` - this step pulls enough tokens from the raw data to complete a forward and backward pass. If the model is inference only, this step is replaced with taking in the inference input and preparing it in the same way for the forward pass.
2. **Forward Pass** `logits, loss = model(x, y)` - uses the data and the model architecture to predict the next token. When training, we also compare the prediction against the expected token to get the loss; in inference, we use the logits to complete the inference task.
3. **Backward Pass & Training** `loss.backward(); optimizer.step()` - uses derivatives (gradients) to measure how much each parameter contributed to the loss computed in the forward pass, and then makes very minor adjustments to the most impactful parameters in the hope of improving future predictions.

Then we'll show a final **Forward Pass** with the weights updated in step 3.

## Data Loading

To start, we need to get enough data to run the forward and backward passes. In real practice the total dataset is likely too big to hold in memory all at once, so we read just enough of the file to run the passes, leaving memory and compute for the passes themselves rather than for holding static data.
First, we have to pick the batch size and the model context length to determine how much data we need. These also form 2 of the 3 dimensions of the initial tensor.
- **Batch Size (B)** - the number of examples you'll train on in a single pass.
- **Context Length (T)** - the max number of tokens that the model can use in a single pass to generate the next token. If an example is shorter than this, it can be padded.
  
*Ideally both B and T are multiples of 2 (and typically powers of 2) to work nicely with chip architecture. This is a common theme across the board.*

In [8]:
B = 2 # Batch
T = 8 # context length

In [9]:
current_position = 0
tok_for_training = all_tokens[current_position:current_position + B*T +1 ]
tok_for_training

tensor([35, 15, 32,  9,  5, 20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0])

In [11]:
x=tok_for_training[:-1].view(B, T)
x

tensor([[35, 15, 32,  9,  5, 20, 30, 15],
        [11,  9,  6, 20,  5,  0, 13, 21]])

In [12]:
y=tok_for_training[1:].view(B, T)
y

tensor([[15, 32,  9,  5, 20, 30, 15, 11],
        [ 9,  6, 20,  5,  0, 13, 21,  0]])
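
The `train_loader.next_batch()` call mentioned earlier is not defined in this walkthrough. The sketch below is one hypothetical way it could work, using plain Python lists in place of tensors. The class name `SimpleLoader` and the wrap-around policy are assumptions made for illustration only.

```python
class SimpleLoader:
    """Hypothetical sequential batch loader over a flat list of token ids."""

    def __init__(self, tokens, B, T):
        self.tokens, self.B, self.T = tokens, B, T
        self.position = 0

    def next_batch(self):
        B, T = self.B, self.T
        # Wrap around (an assumed policy) when too few tokens remain.
        if self.position + B * T + 1 > len(self.tokens):
            self.position = 0
        # B*T inputs need B*T + 1 tokens, because targets are shifted by one.
        buf = self.tokens[self.position:self.position + B * T + 1]
        x = [buf[r * T:(r + 1) * T] for r in range(B)]          # inputs
        y = [buf[r * T + 1:(r + 1) * T + 1] for r in range(B)]  # next-token targets
        self.position += B * T
        return x, y

# First 17 token ids from the walkthrough above.
tokens = [35, 15, 32, 9, 5, 20, 30, 15, 11, 9, 6, 20, 5, 0, 13, 21, 0]
loader = SimpleLoader(tokens, B=2, T=8)
x_demo, y_demo = loader.next_batch()
print(x_demo)
print(y_demo)
```

Each row of `y_demo` is the matching row of `x_demo` shifted left by one token, which is what makes this a next-token prediction task.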

## Forward pass

In [13]:
import torch.nn as nn

In [14]:
B_batch, T_context = x.size()
B_batch, T_context

(2, 8)

In [15]:
n_embd = 4 # embedding dimension of each input token
n_embd, vocab_size

(4, 36)

**Embedding input**

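This is the same token-embedding lookup used in a transformer: each token id selects one row of a learned `(vocab_size, n_embd)` table. The weights are set to ones here so the arithmetic in later steps is easy to follow.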

In [16]:
wte = nn.Embedding(vocab_size, n_embd)
torch.nn.init.ones_(wte.weight)
wte.weight

Parameter containing:
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]], requires_grad=True)

In [17]:
x = wte(x)
x.shape, x

(torch.Size([2, 8, 4]),
 tensor([[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],
 
         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]], grad_fn=<EmbeddingBackward0>))
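
Because the all-ones table hides what the lookup does, here is a small sketch with a made-up 3-token table. It shows that an embedding lookup is just row selection, and that it gives the same result as multiplying a one-hot vector by the table. The function names and table values are illustrative assumptions.

```python
def embed(table, ids):
    """Select one row of `table` per token id (a plain lookup)."""
    return [table[i] for i in ids]

def one_hot_matmul(table, ids):
    """Same result via an explicit one-hot vector times the table."""
    vocab, dim = len(table), len(table[0])
    out = []
    for i in ids:
        one_hot = [1.0 if v == i else 0.0 for v in range(vocab)]
        out.append([sum(one_hot[v] * table[v][d] for v in range(vocab))
                    for d in range(dim)])
    return out

# Made-up 3-token vocabulary with 2-dimensional embeddings.
table = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
ids = [2, 0, 2]
looked_up = embed(table, ids)
print(looked_up)  # [[0.5, 0.6], [0.1, 0.2], [0.5, 0.6]]
assert looked_up == one_hot_matmul(table, ids)
```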

**Dropout**

In [18]:
dropout = nn.Dropout(0.1)
dropout

Dropout(p=0.1, inplace=False)

In [19]:
x = dropout(x)
x

tensor([[[1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 0.0000, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111]],

        [[1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 0.0000],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [0.0000, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 0.0000, 1.1111, 1.1111]]], grad_fn=<MulBackward0>)
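
The surviving values above are 1.1111 rather than 1.0 because dropout in training mode rescales what it keeps by `1 / (1 - p)`, so the expected value of each element is unchanged. A minimal plain-Python sketch of that behavior follows; the function name `inverted_dropout` and the seed are assumptions for illustration.

```python
import random

def inverted_dropout(values, p, rng):
    """Zero each value with probability p and scale survivors by 1 / (1 - p)."""
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in values]

rng = random.Random(0)  # arbitrary seed, only for repeatability
out = inverted_dropout([1.0] * 8, p=0.1, rng=rng)
print(round(1.0 / (1.0 - 0.1), 4))  # 1.1111
print(out)  # every entry is either 0.0 or about 1.1111
```

In evaluation mode, `nn.Dropout` passes values through unchanged, so no rescaling is needed at inference time.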

**Recurrent Block**



In an LSTM at time $t$: $i_t$ is the input gate, $f_t$ is the forget gate, $o_t$ is the output gate, and $g_t$ is the candidate cell content (sometimes written $\tilde{c}_t$). $c_t$ is the cell state after the update, and $h_t$ is the hidden state (the exposed output).

Using the usual affine “gate packs,” $x2g$ denotes the learned input-to-gates projection $W_x x_t + b_x \in \mathbb{R}^{4H}$ and $h2g$ the hidden-to-gates projection $W_h h_{t-1} + b_h \in \mathbb{R}^{4H}$. Their sum is split into four $H$-wide chunks to get the preactivations for the four gate blocks.

The standard equations, with $x_t \in \mathbb{R}^{B\times d}$ and $h_{t-1}, h_t, c_{t-1}, c_t \in \mathbb{R}^{B\times H}$, are:

$$
\begin{aligned}
[z_i, z_f, z_g, z_o] &= x2g(x_t) + h2g(h_{t-1}) \quad \in \mathbb{R}^{B\times 4H},\\
i_t &= \sigma(z_i),\qquad f_t=\sigma(z_f),\qquad g_t=\tanh(z_g),\qquad o_t=\sigma(z_o),\\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t,\\
h_t &= o_t \odot \tanh(c_t).
\end{aligned}
$$

Intuition: $i_t$ decides how much new information $g_t$ to write, $f_t$ decides how much of the old cell $c_{t-1}$ to keep, and $o_t$ decides how much of the updated cell to expose through $h_t$.
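
The equations above can be sketched as a single function in plain Python for one example (no batch dimension). This is a minimal illustration, not the PyTorch implementation; the names `affine` and `lstm_step` and the tiny sizes are assumptions, and the gate order i, f, g, o follows the chunk order in the equations.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def affine(W, b, v):
    """Compute W @ v + b for a nested-list W and flat lists b, v."""
    return [sum(w * x for w, x in zip(row, v)) + bias for row, bias in zip(W, b)]

def lstm_step(x_t, h_prev, c_prev, W_x, b_x, W_h, b_h):
    """One LSTM time step for a single example; returns (h_t, c_t)."""
    H = len(h_prev)
    z = [p + q for p, q in zip(affine(W_x, b_x, x_t), affine(W_h, b_h, h_prev))]
    # Split the 4H preactivations into the four gate blocks, in order i, f, g, o.
    z_i, z_f, z_g, z_o = (z[k * H:(k + 1) * H] for k in range(4))
    i_t = [sigmoid(v) for v in z_i]
    f_t = [sigmoid(v) for v in z_f]
    g_t = [math.tanh(v) for v in z_g]
    o_t = [sigmoid(v) for v in z_o]
    c_t = [f * c + i * g for f, c, i, g in zip(f_t, c_prev, i_t, g_t)]
    h_t = [o * math.tanh(c) for o, c in zip(o_t, c_t)]
    return h_t, c_t

# Tiny made-up sizes: d = 2 inputs, H = 3 hidden units, constant weights.
d, H = 2, 3
W_x = [[0.5] * d for _ in range(4 * H)]
W_h = [[0.25] * H for _ in range(4 * H)]
b_x = [0.0] * (4 * H)
b_h = [0.0] * (4 * H)
h_demo, c_demo = lstm_step([1.0, 1.0], [0.0] * H, [0.0] * H, W_x, b_x, W_h, b_h)
print(h_demo, c_demo)
```

Calling `lstm_step` repeatedly, feeding each returned `(h_t, c_t)` back in as `(h_prev, c_prev)`, gives the recurrence over a sequence.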


In [20]:
hidden_size = 5

In [21]:
x2g = nn.Linear(n_embd, 4 * hidden_size, bias=True)
torch.nn.init.constant_(x2g.weight, 0.500)
torch.nn.init.zeros_(x2g.bias)
x2g.weight.size(), x2g.weight, x2g.bias.size(), x2g.bias

(torch.Size([20, 4]),
 Parameter containing:
 tensor([[0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000]], requires_grad=True),
 torch.Size([20]),
 Parameter containing:
 tensor([0., 0., 0., 0., 0., 0.

In [22]:
h2g = nn.Linear(hidden_size, 4 * hidden_size, bias=True)
torch.nn.init.constant_(h2g.weight, 0.250)
torch.nn.init.zeros_(h2g.bias)
h2g.weight.size(), h2g.weight, h2g.bias.size(), h2g.bias

(torch.Size([20, 5]),
 Parameter containing:
 tensor([[0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2

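Since this is the first pass, the incoming hidden state $h_{t-1}$ and cell state $c_{t-1}$ are zeroed.

We'll walk through the first time step by hand; the later time steps will run in a loop.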

In [23]:
x_0 = x[:, 0, :]
x_0

tensor([[1.1111, 1.1111, 1.1111, 1.1111],
        [1.1111, 1.1111, 1.1111, 1.1111]], grad_fn=<SliceBackward0>)

In [24]:
c_t = torch.zeros(B_batch, hidden_size) 
h_t = torch.zeros(B_batch, hidden_size) 
c_t, h_t

(tensor([[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]),
 tensor([[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]))

In [25]:
h_prev, c_prev = h_t, c_t

In [26]:
gi = x2g(x_0)
gi.size(), gi

(torch.Size([2, 20]),
 tensor([[2.2222, 2.2222, 2.2222, 2.2222, 2.2222, 2.2222, 2.2222, 2.2222, 2.2222,
          2.2222, 2.2222, 2.2222, 2.2222, 2.2222, 2.2222, 2.2222, 2.2222, 2.2222,
          2.2222, 2.2222],
         [2.2222, 2.2222, 2.2222, 2.2222, 2.2222, 2.2222, 2.2222, 2.2222, 2.2222,
          2.2222, 2.2222, 2.2222, 2.2222, 2.2222, 2.2222, 2.2222, 2.2222, 2.2222,
          2.2222, 2.2222]], grad_fn=<AddmmBackward0>))

In [27]:
gh = h2g(h_prev)
gh.size(), gh

(torch.Size([2, 20]),
 tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
        grad_fn=<AddmmBackward0>))

In [28]:
g = gi + gh
i_t, f_t, g_t, o_t = g.chunk(4, dim=-1)

g.size()

torch.Size([2, 20])

In [29]:
print('i_t', i_t, i_t.size())
i_t = torch.sigmoid(i_t)
i_t

i_t tensor([[2.2222, 2.2222, 2.2222, 2.2222, 2.2222],
        [2.2222, 2.2222, 2.2222, 2.2222, 2.2222]], grad_fn=<SplitBackward0>) torch.Size([2, 5])


tensor([[0.9022, 0.9022, 0.9022, 0.9022, 0.9022],
        [0.9022, 0.9022, 0.9022, 0.9022, 0.9022]], grad_fn=<SigmoidBackward0>)

In [30]:
print('f_t', f_t, f_t.size())
f_t = torch.sigmoid(f_t)
f_t

f_t tensor([[2.2222, 2.2222, 2.2222, 2.2222, 2.2222],
        [2.2222, 2.2222, 2.2222, 2.2222, 2.2222]], grad_fn=<SplitBackward0>) torch.Size([2, 5])


tensor([[0.9022, 0.9022, 0.9022, 0.9022, 0.9022],
        [0.9022, 0.9022, 0.9022, 0.9022, 0.9022]], grad_fn=<SigmoidBackward0>)

In [31]:
print('g_t', g_t, g_t.size())
g_t = torch.tanh(g_t)
g_t

g_t tensor([[2.2222, 2.2222, 2.2222, 2.2222, 2.2222],
        [2.2222, 2.2222, 2.2222, 2.2222, 2.2222]], grad_fn=<SplitBackward0>) torch.Size([2, 5])


tensor([[0.9768, 0.9768, 0.9768, 0.9768, 0.9768],
        [0.9768, 0.9768, 0.9768, 0.9768, 0.9768]], grad_fn=<TanhBackward0>)

In [32]:
print('o_t', o_t, o_t.size())
o_t = torch.sigmoid(o_t)
o_t

o_t tensor([[2.2222, 2.2222, 2.2222, 2.2222, 2.2222],
        [2.2222, 2.2222, 2.2222, 2.2222, 2.2222]], grad_fn=<SplitBackward0>) torch.Size([2, 5])


tensor([[0.9022, 0.9022, 0.9022, 0.9022, 0.9022],
        [0.9022, 0.9022, 0.9022, 0.9022, 0.9022]], grad_fn=<SigmoidBackward0>)

In [33]:
c_t = (f_t * c_prev) + (i_t * g_t)
c_t.size(), c_t

(torch.Size([2, 5]),
 tensor([[0.8813, 0.8813, 0.8813, 0.8813, 0.8813],
         [0.8813, 0.8813, 0.8813, 0.8813, 0.8813]], grad_fn=<AddBackward0>))

In [34]:
c_t_tanh = torch.tanh(c_t)
c_t_tanh.size(), c_t_tanh

(torch.Size([2, 5]),
 tensor([[0.7071, 0.7071, 0.7071, 0.7071, 0.7071],
         [0.7071, 0.7071, 0.7071, 0.7071, 0.7071]], grad_fn=<TanhBackward0>))

In [35]:
h_t = o_t * c_t_tanh

h_t.size(), h_t

(torch.Size([2, 5]),
 tensor([[0.6379, 0.6379, 0.6379, 0.6379, 0.6379],
         [0.6379, 0.6379, 0.6379, 0.6379, 0.6379]], grad_fn=<MulBackward0>))

## Now do it recurrently for the rest of the context length

First, we'll need to keep track of `h_t` at every time step for training, so let's start collecting them.

In [36]:
hs = []
hs.append(h_t.unsqueeze(1))
hs

[tensor([[[0.6379, 0.6379, 0.6379, 0.6379, 0.6379]],
 
         [[0.6379, 0.6379, 0.6379, 0.6379, 0.6379]]],
        grad_fn=<UnsqueezeBackward0>)]

Start at 1, since we already did the `t = 0` pass above.

In [37]:
for t in range(1,T):
    x_t = x[:, t, :]
    h_prev, c_prev = h_t, c_t
    gi = x2g(x_t)
    gh = h2g(h_prev)
    g = gi + gh
    i_t, f_t, g_t, o_t = g.chunk(4, dim=-1)
    i_t = torch.sigmoid(i_t)
    f_t = torch.sigmoid(f_t)
    g_t = torch.tanh(g_t)
    o_t = torch.sigmoid(o_t)

    c_t = (f_t * c_prev) + (i_t * g_t)
    
    c_t_tanh = torch.tanh(c_t)
    h_t = o_t * c_t_tanh

    print(f't: {t}')
    print(c_t, h_t)
    hs.append(h_t.unsqueeze(1))

t: 1
tensor([[1.7892, 1.7892, 1.7892, 1.7892, 1.7892],
        [1.7205, 1.7205, 1.7205, 1.7205, 1.7205]], grad_fn=<AddBackward0>) tensor([[0.9017, 0.9017, 0.9017, 0.9017, 0.9017],
        [0.8644, 0.8644, 0.8644, 0.8644, 0.8644]], grad_fn=<MulBackward0>)
t: 2
tensor([[2.6213, 2.6213, 2.6213, 2.6213, 2.6213],
        [2.6214, 2.6214, 2.6214, 2.6214, 2.6214]], grad_fn=<AddBackward0>) tensor([[0.9324, 0.9324, 0.9324, 0.9324, 0.9324],
        [0.9544, 0.9544, 0.9544, 0.9544, 0.9544]], grad_fn=<MulBackward0>)
t: 3
tensor([[3.5008, 3.5008, 3.5008, 3.5008, 3.5008],
        [3.4190, 3.4190, 3.4190, 3.4190, 3.4190]], grad_fn=<AddBackward0>) tensor([[0.9656, 0.9656, 0.9656, 0.9656, 0.9656],
        [0.9438, 0.9438, 0.9438, 0.9438, 0.9438]], grad_fn=<MulBackward0>)
t: 4
tensor([[4.3574, 4.3574, 4.3574, 4.3574, 4.3574],
        [4.2744, 4.2744, 4.2744, 4.2744, 4.2744]], grad_fn=<AddBackward0>) tensor([[0.9683, 0.9683, 0.9683, 0.9683, 0.9683],
        [0.9674, 0.9674, 0.9674, 0.9674, 0.9674]], grad
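
Because every weight in each layer shares one constant value, all hidden units in a batch row stay equal, so the recurrence collapses to one scalar per row. The sketch below replays batch row 0 under that assumption, using the number of inputs that survived dropout at each time step as read from the dropout output above (4, 4, 3, 4, 4, 4, 4, 4). The variable names are illustrative.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

hidden_size = 5
w_x, w_h = 0.5, 0.25          # constant x2g / h2g weights from the cells above
x_val = 1.0 / (1.0 - 0.1)     # embedding value 1.0 after dropout scaling
# Inputs that survived dropout at each time step for batch row 0.
kept = [4, 4, 3, 4, 4, 4, 4, 4]

h, c = 0.0, 0.0
history = []
for t, k in enumerate(kept):
    # Every gate sees the same preactivation because all weights are equal.
    z = w_x * k * x_val + w_h * hidden_size * h
    i = f = o = sigmoid(z)
    g = math.tanh(z)
    c = f * c + i * g
    h = o * math.tanh(c)
    history.append((c, h))
    print(t, round(c, 4), round(h, 4))
```

The first few `(c, h)` pairs should agree with the row 0 values printed by the loop to about four decimals. Row 1 differs only through its own dropout mask, which is why its values diverge from row 0 at `t = 1`.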

In [38]:
hs

[tensor([[[0.6379, 0.6379, 0.6379, 0.6379, 0.6379]],
 
         [[0.6379, 0.6379, 0.6379, 0.6379, 0.6379]]],
        grad_fn=<UnsqueezeBackward0>),
 tensor([[[0.9017, 0.9017, 0.9017, 0.9017, 0.9017]],
 
         [[0.8644, 0.8644, 0.8644, 0.8644, 0.8644]]],
        grad_fn=<UnsqueezeBackward0>),
 tensor([[[0.9324, 0.9324, 0.9324, 0.9324, 0.9324]],
 
         [[0.9544, 0.9544, 0.9544, 0.9544, 0.9544]]],
        grad_fn=<UnsqueezeBackward0>),
 tensor([[[0.9656, 0.9656, 0.9656, 0.9656, 0.9656]],
 
         [[0.9438, 0.9438, 0.9438, 0.9438, 0.9438]]],
        grad_fn=<UnsqueezeBackward0>),
 tensor([[[0.9683, 0.9683, 0.9683, 0.9683, 0.9683]],
 
         [[0.9674, 0.9674, 0.9674, 0.9674, 0.9674]]],
        grad_fn=<UnsqueezeBackward0>),
 tensor([[[0.9686, 0.9686, 0.9686, 0.9686, 0.9686]],
 
         [[0.9686, 0.9686, 0.9686, 0.9686, 0.9686]]],
        grad_fn=<UnsqueezeBackward0>),
 tensor([[[0.9687, 0.9687, 0.9687, 0.9687, 0.9687]],
 
         [[0.9687, 0.9687, 0.9687, 0.9687, 0.9687]]],
   

Combine our per-step hidden states into one tensor of shape `(B, T, hidden_size)`, with one sequence per batch row.

In [39]:
x = torch.cat(hs, dim=1) 
x

tensor([[[0.6379, 0.6379, 0.6379, 0.6379, 0.6379],
         [0.9017, 0.9017, 0.9017, 0.9017, 0.9017],
         [0.9324, 0.9324, 0.9324, 0.9324, 0.9324],
         [0.9656, 0.9656, 0.9656, 0.9656, 0.9656],
         [0.9683, 0.9683, 0.9683, 0.9683, 0.9683],
         [0.9686, 0.9686, 0.9686, 0.9686, 0.9686],
         [0.9687, 0.9687, 0.9687, 0.9687, 0.9687],
         [0.9687, 0.9687, 0.9687, 0.9687, 0.9687]],

        [[0.6379, 0.6379, 0.6379, 0.6379, 0.6379],
         [0.8644, 0.8644, 0.8644, 0.8644, 0.8644],
         [0.9544, 0.9544, 0.9544, 0.9544, 0.9544],
         [0.9438, 0.9438, 0.9438, 0.9438, 0.9438],
         [0.9674, 0.9674, 0.9674, 0.9674, 0.9674],
         [0.9686, 0.9686, 0.9686, 0.9686, 0.9686],
         [0.9687, 0.9687, 0.9687, 0.9687, 0.9687],
         [0.9467, 0.9467, 0.9467, 0.9467, 0.9467]]], grad_fn=<CatBackward0>)

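**Dropout** as regularization against overfitting (vanishing / exploding gradients are addressed by the LSTM gating and by gradient clipping, not by dropout)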

In [40]:
x = dropout(x)
x.size(), x

(torch.Size([2, 8, 5]),
 tensor([[[0.7088, 0.7088, 0.7088, 0.0000, 0.7088],
          [0.0000, 1.0018, 0.0000, 1.0018, 1.0018],
          [1.0360, 1.0360, 1.0360, 1.0360, 1.0360],
          [1.0728, 1.0728, 1.0728, 1.0728, 1.0728],
          [0.0000, 1.0759, 1.0759, 1.0759, 1.0759],
          [1.0763, 0.0000, 1.0763, 1.0763, 1.0763],
          [1.0763, 1.0763, 1.0763, 1.0763, 1.0763],
          [1.0764, 0.0000, 1.0764, 1.0764, 1.0764]],
 
         [[0.7088, 0.7088, 0.7088, 0.7088, 0.0000],
          [0.9604, 0.9604, 0.9604, 0.9604, 0.0000],
          [1.0604, 1.0604, 1.0604, 1.0604, 1.0604],
          [1.0487, 1.0487, 1.0487, 1.0487, 1.0487],
          [1.0749, 1.0749, 1.0749, 1.0749, 1.0749],
          [1.0762, 1.0762, 0.0000, 1.0762, 1.0762],
          [1.0763, 1.0763, 1.0763, 1.0763, 1.0763],
          [0.0000, 1.0519, 1.0519, 1.0519, 1.0519]]], grad_fn=<MulBackward0>))

**Output Head**
Projects from the hidden size to the vocab size to give us logits.

In [41]:
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
torch.nn.init.ones_(lm_head.weight)
lm_head.weight.size(), lm_head.weight

(torch.Size([36, 5]),
 Parameter containing:
 tensor([[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1.,

In [42]:
logits = lm_head(x)

logits.shape, logits

(torch.Size([2, 8, 36]),
 tensor([[[2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352,
           2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352,
           2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352,
           2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352,
           2.8352, 2.8352, 2.8352, 2.8352],
          [3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055,
           3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055,
           3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055,
           3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055,
           3.0055, 3.0055, 3.0055, 3.0055],
          [5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801,
           5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801,
           5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801,
           5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1

**Loss**

In [43]:
import torch.nn.functional as F

In [44]:
y_flat = y.view(-1)
y_flat.shape, y_flat

(torch.Size([16]),
 tensor([15, 32,  9,  5, 20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0]))

In [45]:
logits_flat = logits.view(-1, logits.size(-1))
logits_flat.shape, logits_flat

(torch.Size([16, 36]),
 tensor([[2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352,
          2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352,
          2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352,
          2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352, 2.8352],
         [3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055,
          3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055,
          3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055,
          3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055, 3.0055],
         [5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801,
          5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801,
          5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801,
          5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801, 5.1801,

In [46]:
loss = F.cross_entropy(logits_flat, y_flat)
loss.shape, loss

(torch.Size([]), tensor(3.5835, grad_fn=<NllLossBackward0>))
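
The loss of 3.5835 is not a coincidence: it equals $\ln(36)$. Every row of `lm_head` is all ones, so the 36 logits at each position are identical, the softmax is uniform, and the cross-entropy is $\ln(\text{vocab\_size})$ whatever the target is. A small plain-Python check, assuming the first two logit rows printed above (the helper name `cross_entropy` is illustrative):

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy of one position: -log softmax(logits)[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]

vocab_size = 36
# Identical logits per position, as in the first two rows printed above.
rows = [[2.8352] * vocab_size, [3.0055] * vocab_size]
targets = [15, 32]
# Mean over positions, matching the default reduction of F.cross_entropy.
loss_demo = sum(cross_entropy(r, t) for r, t in zip(rows, targets)) / len(rows)
print(round(loss_demo, 4), round(math.log(vocab_size), 4))  # 3.5835 3.5835
```

This is the loss of a model that guesses uniformly at random, which is a useful baseline: training should push the loss below it.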

## Backpropagation

In [47]:
lm_head.zero_grad()
h2g.zero_grad()
x2g.zero_grad()
wte.zero_grad()


# validate gradients
lm_head.weight.grad, wte.weight.grad

(None, None)

In [48]:
loss.backward()
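
For the output head, the backward pass starts from a simple formula: the gradient of the cross-entropy with respect to each logit is `softmax - one_hot(target)` (divided by the number of positions when the loss is averaged). The sketch below works this out for a single position in plain Python. It also shows a side effect of this walkthrough's initialization: with all rows of `lm_head` identical, the gradient flowing back into the hidden state sums to zero, which would explain why the `x2g` and `h2g` weights print unchanged in the Learning section. The variable names are illustrative assumptions.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

vocab_size, hidden_size = 36, 5
logits = [2.8352] * vocab_size      # one position with identical logits
target = 15

# Gradient of the cross-entropy with respect to each logit (single position).
probs = softmax(logits)
d_logits = [p - (1.0 if v == target else 0.0) for v, p in enumerate(probs)]

# lm_head rows are all ones here, so d_hidden[j] = sum_v d_logits[v] * W[v][j].
W = [[1.0] * hidden_size for _ in range(vocab_size)]
d_hidden = [sum(d_logits[v] * W[v][j] for v in range(vocab_size))
            for j in range(hidden_size)]

print(round(d_logits[0], 4), round(d_logits[target], 4))  # 0.0278 -0.9722
print([abs(round(g, 12)) for g in d_hidden])              # [0.0, 0.0, 0.0, 0.0, 0.0]
```

With a random (non-constant) `lm_head` initialization the rows would differ, and a nonzero gradient would reach the recurrent weights.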

In [49]:
lm_head.weight.grad, wte.weight.grad

(tensor([[-0.0445, -0.1086, -0.1084, -0.1061, -0.1077],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [-0.1099, -0.1082, -0.1081, -0.1057, -0.1074],
         [-0.0374, -0.0357, -0.0355, -0.0332,  0.0252],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [-0.0864, -0.0847, -0.0846, -0.0822, -0.0396],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [-0.0446,  0.0244, -0.0428, -0.0404, -0.0421],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [-0.0446, -0.0429,  0.0245, -0.0404, -0.0421],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [-0.0889, -0.0872, -0.0871, -0.0404, -0.0864],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0

**Gradient clipping**
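Gradient clipping helps with exploding gradients: `clip_grad_norm_` rescales a set of gradients whenever their combined norm exceeds the threshold (1.0 here, applied to each module separately). It does nothing for vanishing gradients; the LSTM's gated cell state is what helps there. In this run the printed `lm_head` gradients match the ones before clipping, which suggests their norm was already under the threshold.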

In [51]:
nn.utils.clip_grad_norm_(lm_head.parameters(), 1.0)
nn.utils.clip_grad_norm_(h2g.parameters(), 1.0)
nn.utils.clip_grad_norm_(x2g.parameters(), 1.0)
nn.utils.clip_grad_norm_(wte.parameters(), 1.0)
lm_head.weight.grad, wte.weight.grad

(tensor([[-0.0445, -0.1086, -0.1084, -0.1061, -0.1077],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [-0.1099, -0.1082, -0.1081, -0.1057, -0.1074],
         [-0.0374, -0.0357, -0.0355, -0.0332,  0.0252],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [-0.0864, -0.0847, -0.0846, -0.0822, -0.0396],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [-0.0446,  0.0244, -0.0428, -0.0404, -0.0421],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [-0.0446, -0.0429,  0.0245, -0.0404, -0.0421],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [-0.0889, -0.0872, -0.0871, -0.0404, -0.0864],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0.0252],
         [ 0.0227,  0.0244,  0.0245,  0.0269,  0
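
Conceptually, clipping by norm computes one L2 norm over all the gradients it is given and scales them down together only if that norm is too large. The following is a simplified plain-Python sketch of that idea, not PyTorch's implementation (which also adds a small epsilon to the denominator); the function name `clip_by_norm` and the sample values are assumptions.

```python
import math

def clip_by_norm(grads, max_norm):
    """Rescale a flat list of gradients so its L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads), total_norm   # already small enough: unchanged
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm

small, small_norm = clip_by_norm([0.25, -0.5], max_norm=1.0)
large, large_norm = clip_by_norm([3.0, -4.0], max_norm=1.0)
print(small, small_norm)   # gradients returned unchanged
print(large, large_norm)   # direction kept, norm scaled from 5.0 down to 1.0
```

Scaling every element by the same factor preserves the gradient's direction, which is why clipping by norm is usually preferred over clamping each element independently.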

## Learning 

In [52]:
# Deliberately huge learning rate to make the update easy to see
learning_rate = 5.000

In [53]:
with torch.no_grad():
    lm_head.weight -= learning_rate * lm_head.weight.grad
lm_head.weight

Parameter containing:
tensor([[1.2226, 1.5428, 1.5422, 1.5303, 1.5386],
        [0.8867, 0.8782, 0.8776, 0.8657, 0.8740],
        [0.8867, 0.8782, 0.8776, 0.8657, 0.8740],
        [0.8867, 0.8782, 0.8776, 0.8657, 0.8740],
        [0.8867, 0.8782, 0.8776, 0.8657, 0.8740],
        [1.5497, 1.5412, 1.5405, 1.5287, 1.5370],
        [1.1868, 1.1783, 1.1777, 1.1658, 0.8740],
        [0.8867, 0.8782, 0.8776, 0.8657, 0.8740],
        [0.8867, 0.8782, 0.8776, 0.8657, 0.8740],
        [1.4320, 1.4235, 1.4228, 1.4109, 1.1978],
        [0.8867, 0.8782, 0.8776, 0.8657, 0.8740],
        [1.2231, 0.8782, 1.2139, 1.2020, 1.2104],
        [0.8867, 0.8782, 0.8776, 0.8657, 0.8740],
        [1.2230, 1.2145, 0.8776, 1.2020, 1.2103],
        [0.8867, 0.8782, 0.8776, 0.8657, 0.8740],
        [1.4446, 1.4361, 1.4354, 1.2020, 1.4319],
        [0.8867, 0.8782, 0.8776, 0.8657, 0.8740],
        [0.8867, 0.8782, 0.8776, 0.8657, 0.8740],
        [0.8867, 0.8782, 0.8776, 0.8657, 0.8740],
        [0.8867, 0.8782, 0.8
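
The update above is plain gradient descent, `w = w - learning_rate * grad`. A quick hand check of the first column of the first two rows, assuming the rounded gradient values printed earlier (so the results agree with the tensor above only to about three decimals); `sgd_step` is an illustrative helper name.

```python
learning_rate = 5.0

def sgd_step(weights, grads, lr):
    """Plain gradient descent: move each weight against its gradient."""
    return [w - lr * g for w, g in zip(weights, grads)]

# lm_head started at 1.0; gradients are the rounded values printed above.
old = [1.0, 1.0]
grads = [-0.0445, 0.0227]
new = sgd_step(old, grads, learning_rate)
print([round(w, 4) for w in new])  # [1.2225, 0.8865]
```

Rows for tokens that appear as targets in this batch have negative gradients and therefore grow, while rows for tokens that never appear as targets shrink.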

In [54]:
with torch.no_grad():
    h2g.weight -= learning_rate * h2g.weight.grad
    h2g.bias -= learning_rate * h2g.bias.grad
h2g.weight, h2g.bias

(Parameter containing:
 tensor([[0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
        

In [55]:
with torch.no_grad():
    x2g.weight -= learning_rate * x2g.weight.grad
    x2g.bias -= learning_rate * x2g.bias.grad
x2g.weight, x2g.bias

(Parameter containing:
 tensor([[0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000]], requires_grad=True),
 Parameter containing:
 tensor([-7.2200e-09, -8.3450e-09, -7.2256e-09, -6.3360e-09, -4.3217e-09

In [56]:
with torch.no_grad():
    wte.weight -= learning_rate * wte.weight.grad
wte.weight

Parameter containing:
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]], requires_grad=True)
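The manual `weight -= learning_rate * weight.grad` updates above are plain stochastic gradient descent. As a minimal sketch, the same step can be reproduced with `torch.optim.SGD` using no momentum or weight decay. The small `demo_layer` and its random input are made up for illustration and are not part of the notebook's model.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in layer; the notebook's lm_head, h2g, x2g and wte are not reused here.
demo_layer = nn.Linear(4, 3)
demo_copy = copy.deepcopy(demo_layer)
learning_rate = 5.0

inputs = torch.randn(2, 4)

# Manual update, in the same style as the cells above.
demo_layer(inputs).sum().backward()
with torch.no_grad():
    for param in demo_layer.parameters():
        param -= learning_rate * param.grad

# Equivalent update through the optimizer API.
optimizer = torch.optim.SGD(demo_copy.parameters(), lr=learning_rate)
demo_copy(inputs).sum().backward()
optimizer.step()
optimizer.zero_grad()  # clear gradients so they do not accumulate into the next step

weights_match = torch.allclose(demo_layer.weight, demo_copy.weight)
biases_match = torch.allclose(demo_layer.bias, demo_copy.bias)
print(weights_match, biases_match)
```

Clearing the gradients matters if the loop continues, because PyTorch accumulates gradients across `backward()` calls.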

## Forward Pass with Updated Weights

In [57]:
x_2 = tok_for_training[:-1].view(B, T)
x_2, y

(tensor([[35, 15, 32,  9,  5, 20, 30, 15],
         [11,  9,  6, 20,  5,  0, 13, 21]]),
 tensor([[15, 32,  9,  5, 20, 30, 15, 11],
         [ 9,  6, 20,  5,  0, 13, 21,  0]]))

## Input projection

In [58]:
x = wte(x_2)
x.shape, x

(torch.Size([2, 8, 4]),
 tensor([[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],
 
         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]], grad_fn=<EmbeddingBackward0>))

**Dropout**

In [59]:
x = dropout(x)
x


tensor([[[1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [0.0000, 1.1111, 1.1111, 1.1111],
         [1.1111, 0.0000, 0.0000, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111]],

        [[1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 0.0000, 1.1111],
         [1.1111, 1.1111, 1.1111, 0.0000],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 0.0000, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111],
         [1.1111, 1.1111, 1.1111, 1.1111]]], grad_fn=<MulBackward0>)
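The surviving entries become 1.1111 because dropout in training mode rescales the values it keeps by 1 / (1 - p). The sketch below assumes `p = 0.1`, which is inferred from the printed 1.1111 (1 / 0.9) rather than read from the notebook's `dropout` definition. The `demo_dropout` name is illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed drop probability; 1 / (1 - 0.1) = 1.1111 matches the values printed above.
assumed_p = 0.1
demo_dropout = nn.Dropout(p=assumed_p)
demo_dropout.train()  # dropout is only active in training mode

ones = torch.ones(2, 8, 4)
dropped = demo_dropout(ones)

keep_scale = 1.0 / (1.0 - assumed_p)
# Every element is either zeroed or scaled up by 1 / (1 - p).
is_zero = dropped == 0
is_scaled = torch.isclose(dropped, torch.full_like(dropped, keep_scale))
all_valid = bool((is_zero | is_scaled).all())

# In eval mode dropout is the identity.
demo_dropout.eval()
unchanged = torch.equal(demo_dropout(ones), ones)
print(keep_scale, all_valid, unchanged)
```

The rescaling keeps the expected value of each activation the same with and without dropout, which is why no extra correction is needed at inference time.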

**Recurrent Block** Collapsed Together

`h_t` and `c_t` are still reset to 0 at the start of the pass.

In [60]:
c_t = torch.zeros(B_batch, hidden_size) 
h_t = torch.zeros(B_batch, hidden_size) 
hs = []
c_t, h_t

(tensor([[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]),
 tensor([[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]))

In [61]:
for t in range(T):
    x_t = x[:, t, :]
    h_prev, c_prev = h_t, c_t
    gi = x2g(x_t)
    gh = h2g(h_prev)
    g = gi + gh
    i_t, f_t, g_t, o_t = g.chunk(4, dim=-1)
    i_t = torch.sigmoid(i_t)
    f_t = torch.sigmoid(f_t)
    g_t = torch.tanh(g_t)
    o_t = torch.sigmoid(o_t)

    c_t = (f_t * c_prev) + (i_t * g_t)
    
    c_t_tanh = torch.tanh(c_t)
    h_t = o_t * c_t_tanh

    print(f't: {t}')
    print(c_t, h_t)
    hs.append(h_t.unsqueeze(1))

t: 0
tensor([[0.8813, 0.8813, 0.8813, 0.8813, 0.8813],
        [0.8813, 0.8813, 0.8813, 0.8813, 0.8813]], grad_fn=<AddBackward0>) tensor([[0.6379, 0.6379, 0.6379, 0.6379, 0.6379],
        [0.6379, 0.6379, 0.6379, 0.6379, 0.6379]], grad_fn=<MulBackward0>)
t: 1
tensor([[1.7892, 1.7892, 1.7892, 1.7892, 1.7892],
        [1.7205, 1.7205, 1.7205, 1.7205, 1.7205]], grad_fn=<AddBackward0>) tensor([[0.9017, 0.9017, 0.9017, 0.9017, 0.9017],
        [0.8644, 0.8644, 0.8644, 0.8644, 0.8644]], grad_fn=<MulBackward0>)
t: 2
tensor([[2.6213, 2.6213, 2.6213, 2.6213, 2.6213],
        [2.5489, 2.5489, 2.5489, 2.5489, 2.5489]], grad_fn=<AddBackward0>) tensor([[0.9324, 0.9324, 0.9324, 0.9324, 0.9324],
        [0.9283, 0.9283, 0.9283, 0.9283, 0.9283]], grad_fn=<MulBackward0>)
t: 3
tensor([[3.2654, 3.2654, 3.2654, 3.2654, 3.2654],
        [3.4301, 3.4301, 3.4301, 3.4301, 3.4301]], grad_fn=<AddBackward0>) tensor([[0.9043, 0.9043, 0.9043, 0.9043, 0.9043],
        [0.9651, 0.9651, 0.9651, 0.9651, 0.9651]], grad
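The `t: 0` values can be checked by hand. As a sketch, assume every `x2g` weight is 0.5 (as printed earlier), the biases are effectively zero, and the first input row after dropout is four copies of 1.1111. With `h_prev` and `c_prev` at zero, each gate pre-activation is then 4 × 0.5 × 1.1111 ≈ 2.2222.

```python
import torch

# Assumptions taken from the printed tensors: x2g weights are all 0.5,
# biases are ~0, and the t = 0 input row is four copies of 1 / 0.9.
x_value = 1.0 / 0.9
pre_activation = torch.tensor(4 * 0.5 * x_value)  # same for the i, f, g and o gates

i_gate = torch.sigmoid(pre_activation)
f_gate = torch.sigmoid(pre_activation)
g_gate = torch.tanh(pre_activation)
o_gate = torch.sigmoid(pre_activation)

c_prev = torch.tensor(0.0)
c_0 = f_gate * c_prev + i_gate * g_gate  # forget term vanishes because c_prev is 0
h_0 = o_gate * torch.tanh(c_0)

print(f"c_0 = {c_0.item():.4f}, h_0 = {h_0.item():.4f}")
# c_0 = 0.8813, h_0 = 0.6379
```

Both numbers agree with the first printed step. Later steps differ between the two batch rows only because dropout zeroed different input entries.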

In [62]:
hs

[tensor([[[0.6379, 0.6379, 0.6379, 0.6379, 0.6379]],
 
         [[0.6379, 0.6379, 0.6379, 0.6379, 0.6379]]],
        grad_fn=<UnsqueezeBackward0>),
 tensor([[[0.9017, 0.9017, 0.9017, 0.9017, 0.9017]],
 
         [[0.8644, 0.8644, 0.8644, 0.8644, 0.8644]]],
        grad_fn=<UnsqueezeBackward0>),
 tensor([[[0.9324, 0.9324, 0.9324, 0.9324, 0.9324]],
 
         [[0.9283, 0.9283, 0.9283, 0.9283, 0.9283]]],
        grad_fn=<UnsqueezeBackward0>),
 tensor([[[0.9043, 0.9043, 0.9043, 0.9043, 0.9043]],
 
         [[0.9651, 0.9651, 0.9651, 0.9651, 0.9651]]],
        grad_fn=<UnsqueezeBackward0>),
 tensor([[[0.9657, 0.9657, 0.9657, 0.9657, 0.9657]],
 
         [[0.9682, 0.9682, 0.9682, 0.9682, 0.9682]]],
        grad_fn=<UnsqueezeBackward0>),
 tensor([[[0.9685, 0.9685, 0.9685, 0.9685, 0.9685]],
 
         [[0.9466, 0.9466, 0.9466, 0.9466, 0.9466]]],
        grad_fn=<UnsqueezeBackward0>),
 tensor([[[0.9687, 0.9687, 0.9687, 0.9687, 0.9687]],
 
         [[0.9679, 0.9679, 0.9679, 0.9679, 0.9679]]],
   

Concatenate the per-timestep hidden states back into a single `(B, T, hidden_size)` tensor.

In [63]:
x = torch.cat(hs, dim=1) 
x.size(), x

(torch.Size([2, 8, 5]),
 tensor([[[0.6379, 0.6379, 0.6379, 0.6379, 0.6379],
          [0.9017, 0.9017, 0.9017, 0.9017, 0.9017],
          [0.9324, 0.9324, 0.9324, 0.9324, 0.9324],
          [0.9043, 0.9043, 0.9043, 0.9043, 0.9043],
          [0.9657, 0.9657, 0.9657, 0.9657, 0.9657],
          [0.9685, 0.9685, 0.9685, 0.9685, 0.9685],
          [0.9687, 0.9687, 0.9687, 0.9687, 0.9687],
          [0.9687, 0.9687, 0.9687, 0.9687, 0.9687]],
 
         [[0.6379, 0.6379, 0.6379, 0.6379, 0.6379],
          [0.8644, 0.8644, 0.8644, 0.8644, 0.8644],
          [0.9283, 0.9283, 0.9283, 0.9283, 0.9283],
          [0.9651, 0.9651, 0.9651, 0.9651, 0.9651],
          [0.9682, 0.9682, 0.9682, 0.9682, 0.9682],
          [0.9466, 0.9466, 0.9466, 0.9466, 0.9466],
          [0.9679, 0.9679, 0.9679, 0.9679, 0.9679],
          [0.9687, 0.9687, 0.9687, 0.9687, 0.9687]]], grad_fn=<CatBackward0>))
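Concatenating tensors that were each given a length-1 time axis with `unsqueeze(1)` is equivalent to stacking the raw `(B, hidden_size)` tensors along a new dimension. A small sketch with made-up values (the `demo_states` list is illustrative only):

```python
import torch

batch_size, num_steps, hidden = 2, 8, 5

# Stand-ins for the per-timestep hidden states, each of shape (batch_size, hidden).
demo_states = [torch.full((batch_size, hidden), float(t)) for t in range(num_steps)]

# Approach used in the notebook: add a time axis, then concatenate along it.
via_cat = torch.cat([h.unsqueeze(1) for h in demo_states], dim=1)

# Equivalent alternative: stack creates the new axis itself.
via_stack = torch.stack(demo_states, dim=1)

print(via_cat.shape, torch.equal(via_cat, via_stack))
```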

**Dropout**

In [64]:
x = dropout(x)
x


tensor([[[0.7088, 0.7088, 0.7088, 0.7088, 0.7088],
         [1.0018, 1.0018, 1.0018, 1.0018, 1.0018],
         [0.0000, 1.0360, 1.0360, 1.0360, 1.0360],
         [1.0048, 1.0048, 1.0048, 1.0048, 1.0048],
         [1.0730, 1.0730, 1.0730, 1.0730, 1.0730],
         [1.0761, 1.0761, 1.0761, 1.0761, 1.0761],
         [1.0763, 1.0763, 1.0763, 1.0763, 1.0763],
         [1.0764, 1.0764, 1.0764, 1.0764, 1.0764]],

        [[0.7088, 0.7088, 0.7088, 0.7088, 0.7088],
         [0.9604, 0.9604, 0.9604, 0.9604, 0.9604],
         [0.0000, 1.0315, 1.0315, 1.0315, 1.0315],
         [1.0724, 1.0724, 1.0724, 1.0724, 1.0724],
         [0.0000, 1.0758, 1.0758, 1.0758, 1.0758],
         [1.0518, 1.0518, 1.0518, 1.0518, 1.0518],
         [1.0754, 1.0754, 1.0754, 1.0754, 1.0754],
         [1.0763, 1.0763, 1.0763, 1.0763, 1.0763]]], grad_fn=<MulBackward0>)

**Head**

In [65]:
logits = lm_head(x)
logits.shape, logits

(torch.Size([2, 8, 36]),
 tensor([[[5.2286, 3.1061, 3.1061, 3.1061, 3.1061, 5.4557, 3.9571, 3.1061,
           3.1061, 4.8816, 3.1061, 4.0598, 3.1061, 4.0597, 3.1061, 4.9262,
           3.1061, 3.1061, 3.1061, 3.1061, 5.2338, 4.2982, 3.1061, 3.1061,
           3.1061, 3.1061, 3.1061, 3.1061, 3.1061, 3.1061, 4.0597, 3.1061,
           3.7719, 3.1061, 3.1061, 3.1061],
          [7.3902, 4.3903, 4.3903, 4.3903, 4.3903, 7.7112, 5.5930, 4.3903,
           4.3903, 6.8997, 4.3903, 5.7382, 4.3903, 5.7380, 4.3903, 6.9628,
           4.3903, 4.3903, 4.3903, 4.3903, 7.3975, 6.0751, 4.3903, 4.3903,
           4.3903, 4.3903, 4.3903, 4.3903, 4.3903, 4.3903, 5.7381, 4.3903,
           5.3312, 4.3903, 4.3903, 4.3903],
          [6.3757, 3.6214, 3.6214, 3.6214, 3.6214, 6.3688, 4.5543, 3.6214,
           3.6214, 5.6516, 3.6214, 4.6669, 3.6214, 4.6667, 3.6214, 5.7038,
           3.6214, 3.6214, 3.6214, 3.6214, 6.3880, 5.0153, 3.6214, 3.6214,
           3.6214, 3.6214, 3.6214, 3.6214, 3.6214, 3.6214, 4.6
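Each logit is a dot product between a position's hidden vector and one row of `lm_head.weight`. As a rough check, the sketch below takes a weight row printed after the update (`[0.8867, 0.8782, 0.8776, 0.8657, 0.8740]`) and the first post-dropout hidden vector (five copies of 0.7088). It assumes the head has no bias, or a bias near zero. The match is consistent with that assumption but does not prove it.

```python
import torch

# Values copied from the printed tensors above (rounded to 4 decimals).
weight_row = torch.tensor([0.8867, 0.8782, 0.8776, 0.8657, 0.8740])
hidden_vec = torch.full((5,), 0.7088)

# Assumed: no bias term in the head.
demo_logit = torch.dot(weight_row, hidden_vec)
print(f"{demo_logit.item():.4f}")
```

The result lands on the 3.1061 that appears repeatedly in the first row of logits, which belongs to the tokens whose weight rows all received the same update.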

### Updated Loss calculation

Now we'll calculate the updated loss. The first pass's loss was 3.5835. Because we're passing the same example through again and used a fairly high learning rate, we should see a significant improvement after just one update step.

In [66]:
loss

tensor(3.5835, grad_fn=<NllLossBackward0>)
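The first-pass loss of 3.5835 is what cross-entropy gives for a uniform prediction over the 36-token vocabulary, since ln(36) ≈ 3.5835. This suggests the initial logits were identical across tokens, though that is an inference from the number rather than something stated here. A minimal sketch with made-up targets:

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 36      # last dimension of the logits printed above
num_positions = 16   # B * T = 2 * 8

# Identical logits for every token give a uniform softmax.
uniform_logits = torch.zeros(num_positions, vocab_size)
# Arbitrary targets; with uniform predictions the loss does not depend on them.
demo_targets = torch.randint(0, vocab_size, (num_positions,))

uniform_loss = F.cross_entropy(uniform_logits, demo_targets)
print(f"{uniform_loss.item():.4f} vs ln(36) = {math.log(vocab_size):.4f}")
```

Any loss below ln(36) therefore means the model is doing better than guessing uniformly.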

In [67]:
y_flat = y.view(-1)
logits_flat = logits.view(-1, logits.size(-1))
updated_loss = F.cross_entropy(logits_flat, y_flat)
print(updated_loss.shape, updated_loss)

torch.Size([]) tensor(2.6797, grad_fn=<NllLossBackward0>)


In [68]:
f'1 round of training resulted in a loss improvement of {loss.item() - updated_loss.item():.4f}'

'1 round of training resulted in a loss improvement of 0.9038'

# SUCCESS!
Our training improved the loss by about **25%** (the exact amount may vary since we didn't set a seed). This setup has flaws, mainly that it passes the same example through a second time, but it shows the fundamentals of what learning does inside an LSTM language model.
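The 25% figure is the absolute improvement divided by the original loss. Using the values printed in this run:

```python
loss_before = 3.5835  # first-pass loss printed above
loss_after = 2.6797   # loss after one update, as printed in this run

absolute_improvement = loss_before - loss_after
relative_improvement = absolute_improvement / loss_before

print(f"absolute: {absolute_improvement:.4f}, relative: {relative_improvement:.1%}")
# absolute: 0.9038, relative: 25.2%
```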