# Recurrent neural network (RNN) architectures

(Built on section 8 of Zhang, A., Lipton, Z. C., Li, M. & Smola, A. J. Dive into Deep Learning. 2021. https://d2l.ai/)


## Latent variable models for sequences

$$
P(x_t | x_{t-1}, \ldots, x_1) \approx P(x_t | h_{t-1}), \qquad
h_t = f(x_t, h_{t-1})
$$

$h_t$: hidden state

## MLP

$\mathbf{X} \in \mathbb{R}^{n \times d}$: minibatch of $n$ examples (instances) of $d$-dimensional inputs

$\mathbf{H} \in \mathbb{R}^{n \times m}$: hidden layer for a minibatch of $n$ instances with $m$ dimensions

$\mathbf{O} \in \mathbb{R}^{n \times q}$: output layer for a minibatch of $n$ instances with $q$ dimensions (e.g. regression $q=1$, 10-class classification $q=10$)

$$\mathbf{H} = \phi(\mathbf{XW}_{dm} + \mathbf{b}_m)$$

$\mathbf{W}_{dm}$: weight matrix, $\mathbf{b}_m$: bias vector

$$\mathbf{O} = \mathbf{HW}_{mq} + \mathbf{b}_q$$

$\mathbf{W}_{mq}$: weight matrix, $\mathbf{b}_q$: bias vector
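A minimal PyTorch sketch of the two MLP equations above, mainly to check the shapes. The sizes $n=4, d=5, m=3, q=2$ and the choice $\phi = \text{ReLU}$ are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
n, d, m, q = 4, 5, 3, 2  # illustrative sizes: batch, input, hidden, output

X = torch.randn(n, d)                          # minibatch of n inputs
W_dm, b_m = torch.randn(d, m), torch.zeros(m)  # hidden-layer parameters
W_mq, b_q = torch.randn(m, q), torch.zeros(q)  # output-layer parameters

H = torch.relu(X @ W_dm + b_m)  # phi = ReLU (assumption); bias is broadcast over the n rows
O = H @ W_mq + b_q              # no nonlinearity on the output layer

print(H.shape, O.shape)  # torch.Size([4, 3]) torch.Size([4, 2])
```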

![Mlp](mlp.png)


## RNN

$\mathbf{X}_t \in \mathbb{R}^{n \times d}$: minibatch of $n$ examples (instances) of $d$-dimensional inputs **at time step $t$**

$\mathbf{H}_t \in \mathbb{R}^{n \times m}$: hidden variable (state) for a minibatch of $n$ instances with $m$ dimensions **at time step $t$**

$\mathbf{O}_t \in \mathbb{R}^{n \times q}$: output layer for a minibatch of $n$ instances with $q$ dimensions **at time step $t$** (e.g. regression $q=1$, 10-class classification $q=10$)

$$\mathbf{H}_t = \phi(\mathbf{X}_t \mathbf{W}_{dm} + \mathbf{H}_{t-1} \mathbf{W}_{mm} + \mathbf{b}_m)$$

$\mathbf{W}_{dm}$, $\mathbf{W}_{mm}$: weight matrices, $\mathbf{b}_m$: bias vector

$$\mathbf{O}_t = \mathbf{H}_t \mathbf{W}_{mq} + \mathbf{b}_q$$

$\mathbf{W}_{mq}$: weight matrix, $\mathbf{b}_q$: bias vector

![Rnn](rnn.png)


## Character prediction

Predict the next character based on the previous ones.

Batch size $n=1$, input sequence "machine" tokenized into *characters* (26-dimensional one-hot vectors).
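A sketch of how the training pairs for "machine" could be built, assuming a vocabulary of the 26 lowercase letters in alphabetical order (the index order is an illustrative choice). Each input is the one-hot vector of a character, each target is the index of the character that follows it.

```python
import string
import torch
import torch.nn.functional as F

vocab = list(string.ascii_lowercase)  # assumed vocabulary: 'a'..'z'
char_to_idx = {c: i for i, c in enumerate(vocab)}

text = "machine"
idx = torch.tensor([char_to_idx[c] for c in text])

# inputs: every character except the last; targets: the same sequence shifted by one
inputs = F.one_hot(idx[:-1], num_classes=len(vocab)).float()  # (6, 26)
inputs = inputs.unsqueeze(1)                                   # (num_steps, batch_size=1, 26)
targets = idx[1:]                                              # (6,) next-character indices

for x, y in zip(text[:-1], text[1:]):
    print(f"{x} -> {y}")
```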


![Rnn-character](rnn-train.png)

## Implementation

* organize **inputs** $\mathbf{X}$ as (number of sequence steps, batch size, vocabulary size), which makes it easy to loop over sequence steps
* initialize shared **parameters** for the hidden layer $\mathbf{W}_{dm}, \mathbf{W}_{mm}, \mathbf{b}_m$ and the output $\mathbf{W}_{mq}, \mathbf{b}_q$
* initialize the **hidden state** $\mathbf{H}_0$ as zeros
* **rnn** - loop over steps
    * update hidden state $\mathbf{H}_t = \phi(\mathbf{X}_t \mathbf{W}_{dm} + \mathbf{H}_{t-1} \mathbf{W}_{mm} + \mathbf{b}_m)$
    * get output $\mathbf{O}_t = \mathbf{H}_t \mathbf{W}_{mq} + \mathbf{b}_q$
* **loss**: cross-entropy averaged over all steps of the sequence
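The steps above can be sketched from scratch as follows. This is a minimal version under stated assumptions: $\phi = \tanh$, weights drawn from $\mathcal{N}(0, 0.01^2)$, random token ids standing in for real text, and the helper names `init_params` and `rnn` are hypothetical.

```python
import torch
import torch.nn.functional as F

def init_params(vocab_size, num_hiddens):
    """Shared parameters W_dm, W_mm, b_m (hidden) and W_mq, b_q (output); here d = q = vocab size."""
    d = q = vocab_size
    def normal(*shape):
        return (torch.randn(*shape) * 0.01).requires_grad_()
    W_dm, W_mm = normal(d, num_hiddens), normal(num_hiddens, num_hiddens)
    b_m = torch.zeros(num_hiddens, requires_grad=True)
    W_mq = normal(num_hiddens, q)
    b_q = torch.zeros(q, requires_grad=True)
    return [W_dm, W_mm, b_m, W_mq, b_q]

def rnn(X, H, params):
    """X: (num_steps, batch_size, vocab_size), H: (batch_size, num_hiddens)."""
    W_dm, W_mm, b_m, W_mq, b_q = params
    outputs = []
    for X_t in X:  # loop over sequence steps (first dimension)
        H = torch.tanh(X_t @ W_dm + H @ W_mm + b_m)  # update hidden state, phi = tanh
        outputs.append(H @ W_mq + b_q)               # output at step t
    return torch.stack(outputs), H  # (num_steps, batch_size, q), last state

vocab_size, num_hiddens, num_steps, batch_size = 26, 32, 5, 2
params = init_params(vocab_size, num_hiddens)
tokens = torch.randint(vocab_size, (num_steps + 1, batch_size))  # random token ids
X = F.one_hot(tokens[:-1], vocab_size).float()  # inputs
Y = tokens[1:]                                  # next-token targets
H0 = torch.zeros(batch_size, num_hiddens)       # initial hidden state of zeros

O, H = rnn(X, H0, params)
# cross-entropy averaged over all steps (and all examples in the batch)
loss = F.cross_entropy(O.reshape(-1, vocab_size), Y.reshape(-1))
loss.backward()
print(O.shape, H.shape, loss.item())
```

With such small initial weights the outputs are nearly uniform, so the initial loss is close to $\log 26$.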

## Few practical tricks

* **prefix**: feed a few initial tokens to *warm up* the hidden state
* **clip gradients** to avoid exploding gradients caused by long products in backpropagation:
$\mathbf{g} \leftarrow \min \left(1, \frac{\theta}{\|\mathbf{g}\|}\right)\mathbf{g}$
* training over ordered (consecutive) batches: keep the state from the previous batch as $\mathbf{H}_0$, but detach it so gradients do not backpropagate into the previous batch
* training with shuffled batches: initialize $\mathbf{H}_0$ for each batch
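Gradient clipping as written above, as a minimal sketch. The function name `grad_clipping`, the dummy data and the scaled-up dummy loss are illustrative assumptions; PyTorch also ships `torch.nn.utils.clip_grad_norm_`, which performs essentially the same rescaling.

```python
import torch
from torch import nn

def grad_clipping(params, theta):
    """g <- min(1, theta / ||g||) * g, where g concatenates the gradients of all params."""
    params = [p for p in params if p.grad is not None]
    norm = torch.sqrt(sum((p.grad ** 2).sum() for p in params))
    if norm > theta:
        for p in params:
            p.grad.mul_(theta / norm)  # same direction, norm rescaled to theta
    return norm

torch.manual_seed(0)
rnn = nn.RNN(input_size=8, hidden_size=16)
X = torch.randn(20, 4, 8)      # (num_steps, batch_size, input_size), dummy data
out, _ = rnn(X)
(100 * out.sum()).backward()   # dummy loss, scaled up to produce large gradients

norm_before = grad_clipping(rnn.parameters(), theta=1.0)
norm_after = torch.sqrt(sum((p.grad ** 2).sum() for p in rnn.parameters()))
print(f"before: {norm_before.item():.1f}, after: {norm_after.item():.3f}")
```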


```python
# PyTorch implementation
import torch
from torch import nn
import torch.nn.functional as F

# RNN layer: one-hot inputs of size len(vocab) -> hidden state of size hidden_dim
hidden_dim = 256
vocab = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
input_dim = len(vocab)
rnn_layer = nn.RNN(input_dim, hidden_dim)

# initial hidden state H_0 of shape (num_layers, batch_size, hidden_dim)
rnn_layers = 1
batch_size = 32
state = torch.zeros((rnn_layers, batch_size, hidden_dim))
print(f'State: {state.shape}')
```

```
State: torch.Size([1, 32, 256])
```


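```python
# pass data through the RNN layer

# random token ids of shape (num_steps, batch_size), one-hot encoded
num_steps = 10
X = torch.randint(len(vocab), size=(num_steps, batch_size))
X = F.one_hot(X, num_classes=len(vocab)).float()  # (num_steps, batch_size, vocab_size)
print(f'X: {X.shape}')

# out: last-layer hidden states H_t for every step t, shape (num_steps, batch_size, hidden_dim);
#      it can be passed to another RNN layer or to an output layer
# state_new: final hidden state H_T of every layer, shape (num_layers, batch_size, hidden_dim)
out, state_new = rnn_layer(X, state)
print(f'out: {out.shape}, state: {state_new.shape}')
```

```
X: torch.Size([10, 32, 8])
out: torch.Size([10, 32, 256]), state: torch.Size([1, 32, 256])
```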
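To connect `nn.RNN` with $\mathbf{H}_t = \phi(\mathbf{X}_t \mathbf{W}_{dm} + \mathbf{H}_{t-1} \mathbf{W}_{mm} + \mathbf{b}_m)$, the sketch below recomputes its forward pass by hand. It relies on PyTorch's parameter names `weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0`, `bias_hh_l0`: PyTorch stores the weights transposed and keeps two bias vectors whose sum plays the role of $\mathbf{b}_m$. Sizes are small illustrative choices.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_dim, hidden_dim, num_steps, batch_size = 8, 16, 5, 3
rnn_layer = nn.RNN(input_dim, hidden_dim)  # default nonlinearity: tanh
X = torch.randn(num_steps, batch_size, input_dim)

W_dm = rnn_layer.weight_ih_l0.T  # stored as (hidden, input)
W_mm = rnn_layer.weight_hh_l0.T  # stored as (hidden, hidden)
b_m = rnn_layer.bias_ih_l0 + rnn_layer.bias_hh_l0

H = torch.zeros(batch_size, hidden_dim)
manual = []
for X_t in X:
    H = torch.tanh(X_t @ W_dm + H @ W_mm + b_m)
    manual.append(H)
manual = torch.stack(manual)

out, state = rnn_layer(X)  # initial state defaults to zeros when omitted
print(torch.allclose(out, manual, atol=1e-5))        # True
print(torch.allclose(state[0], manual[-1], atol=1e-5))  # True
```

This also shows that `out` stacks the hidden states of every step, while `state` holds only the last one.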


## Backpropagation through time

* Ordinary backpropagation applied to a sequential model
* Expand (unroll) the computational graph of the RNN through the sequence steps

### Unrolling computational graph

![Rnn-bptt](rnn-bptt.png)

## Backpropagation through time - math

Simplified notation:
$$h_t = f(x_t, h_{t-1}, w_h) \qquad o_t = g(h_t, w_o)$$

After running RNN forward we have:
$$[(x_1, h_1, o_1), (x_2, h_2, o_2), \ldots, (x_T, h_T, o_T)]$$

Objective function:
$$L(x_1, \ldots, x_T, y_1, \ldots, y_T, o_1, \ldots, o_T, w_h, w_o) = \frac{1}{T}\sum_{t=1}^{T} l(y_t, o_t)$$

Backpropagation via chain rule:
$$\frac{\partial L}{\partial w_h} = \frac{1}{T}\sum_{t=1}^{T} \frac{\partial l(y_t, o_t)}{\partial w_h}
= \frac{1}{T}\sum_{t=1}^{T} \frac{\partial l(y_t, o_t)}{\partial o_t} \frac{\partial g(h_t, w_o)}{\partial h_t}
\frac{\partial h_t}{\partial w_h}
$$

**However**, $h_t$ depends on $w_h$ both directly and through $h_{t-1}$:
$$
\begin{aligned}
\frac{\partial h_t}{\partial w_h} &= \frac{\partial f(x_t, h_{t-1}, w_h)}{\partial w_h} + \frac{\partial f(x_t, h_{t-1}, w_h)}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial w_h} \\
\frac{\partial h_{t-1}}{\partial w_h} &= \frac{\partial f(x_{t-1}, h_{t-2}, w_h)}{\partial w_h} + \frac{\partial f(x_{t-1}, h_{t-2}, w_h)}{\partial h_{t-2}}\frac{\partial h_{t-2}}{\partial w_h} \\
&\ldots
\end{aligned}
$$

**Full recurrence**:
$$\frac{\partial h_t}{\partial w_h} = \frac{\partial f(x_t, h_{t-1}, w_h)}{\partial w_h} +
\sum_{i=1}^{t-1} \left( \prod_{j=i+1}^t \frac{\partial f(x_j, h_{j-1}, w_h)}{\partial h_{j-1}} \right) \frac{\partial f(x_i, h_{i-1}, w_h)}{\partial w_h}
$$
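A numerical check of the recurrence on a scalar toy model, assuming $f(x_t, h_{t-1}, w_h) = \tanh(x_t + w_h h_{t-1})$ and a fixed $h_0$ (both illustrative choices). The loop applies the recursive form step by step, which expands into the full sum of products above, and its result is compared against autograd on the unrolled graph.

```python
import torch

torch.manual_seed(0)
T = 6
x = torch.randn(T)
w_h = torch.tensor(0.8, requires_grad=True)
h0 = torch.tensor(0.5)  # fixed initial state, does not depend on w_h

# forward pass: h_t = tanh(x_t + w_h * h_{t-1})
hs = [h0]
for t in range(T):
    hs.append(torch.tanh(x[t] + w_h * hs[-1]))

# autograd: dh_T / dw_h through the unrolled graph
(auto_grad,) = torch.autograd.grad(hs[-1], w_h)

# manual recurrence: dh_t/dw = df/dw + df/dh_{t-1} * dh_{t-1}/dw, with dh_0/dw = 0
with torch.no_grad():
    dh_dw = torch.tensor(0.0)
    for t in range(1, T + 1):
        s = 1 - hs[t] ** 2       # tanh' evaluated at step t
        df_dw = s * hs[t - 1]    # partial of f w.r.t. w_h (h_{t-1} held fixed)
        df_dh = s * w_h          # partial of f w.r.t. h_{t-1}
        dh_dw = df_dw + df_dh * dh_dw

print(auto_grad.item(), dh_dw.item())
```

Each `df_dh` factor is one term of the product $\prod_j \partial f / \partial h_{j-1}$, which is exactly where the numerical trouble discussed next comes from.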




## Vanishing / exploding gradients

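* long products of many terms (the factors $\partial f / \partial h_{j-1}$ in the full recurrence): if the terms are $<1$ in magnitude the product $\to 0$ (vanishing gradient); if they are $>1$ the product $\to \infty$ (exploding gradient)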
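A tiny illustration with autograd, assuming a linear scalar recurrence $h_t = w h_{t-1} + 0.1$ (an illustrative choice without nonlinearity), so that each factor $\partial h_t / \partial h_{t-1}$ equals $w$ and $\partial h_T / \partial h_0 = w^T$.

```python
import torch

T = 50  # number of time steps (illustrative)
grads = {}
for w in (0.9, 1.0, 1.1):
    h0 = torch.tensor(1.0, requires_grad=True)
    h = h0
    for t in range(T):
        h = w * h + 0.1  # each step multiplies the backpropagated gradient by w
    h.backward()
    grads[w] = h0.grad.item()
    print(f"w={w}: dh_T/dh_0 = {grads[w]:.3g}")
# w=0.9: dh_T/dh_0 = 0.00515
# w=1.0: dh_T/dh_0 = 1
# w=1.1: dh_T/dh_0 = 117
```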

### Truncated backpropagation through time

* truncate the backward gradient computation to just a few steps (terminate with $\partial h_{t-\tau} / \partial w_h$)
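A minimal sketch of truncation in practice: a long sequence is processed in chunks of $\tau$ steps and the hidden state is detached between chunks, so gradients never flow back more than $\tau$ steps. The sizes, the linear output head, the SGD optimizer and the dummy regression data are all illustrative assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)
num_steps, tau, batch_size = 100, 10, 4  # long sequence, truncation length tau
rnn = nn.RNN(input_size=8, hidden_size=16)
head = nn.Linear(16, 1)
opt = torch.optim.SGD(list(rnn.parameters()) + list(head.parameters()), lr=0.01)

X = torch.randn(num_steps, batch_size, 8)  # dummy inputs
Y = torch.randn(num_steps, batch_size, 1)  # dummy regression targets
state = torch.zeros(1, batch_size, 16)

losses = []
for start in range(0, num_steps, tau):
    # keep the state's value but cut the graph: backprop stops at the chunk boundary
    state = state.detach()
    out, state = rnn(X[start:start + tau], state)
    loss = nn.functional.mse_loss(head(out), Y[start:start + tau])
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())
```

Without `detach()`, the second `backward()` would try to traverse the previous chunk's graph, which has already been freed, and raise an error.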


# Modern RNN architectures (LSTM/GRU)

* **RNN major issue: numerical instability of BPTT (vanishing / exploding gradients)**
* remedy: more sophisticated cell designs with gating mechanisms

### Wish list
* an important piece of information at the beginning of the sequence:
    * a) large gradient so it impacts all future steps (effect on the multiplication)
    * b) **memory cell** to store vital information for later

* some tokens carry no information:
    * a) small gradient (effect on the multiplication)
    * b) **skipping mechanism**

* a logical break in the sequence:
    * a) prevent the gradient from passing through
    * b) **resetting mechanism**

## Gated Recurrent Units (GRU)

Cho, K., Van Merriënboer, B., Bahdanau, D., & Bengio, Y. (2014). On the properties of neural machine
translation: encoder-decoder approaches. arXiv preprint arXiv:1409.1259

### 1) Reset and update gates

![Gru-gates](gru-1.png)

minibatch of inputs $\mathbf{X}_t \in \mathbb{R}^{n \times d}$, previous hidden state $\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}$, *reset gate* $\mathbf{R}_t \in \mathbb{R}^{n \times h}$, *update gate* $\mathbf{Z}_t \in \mathbb{R}^{n \times h}$

$$
\begin{aligned}
\mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{dh}^{(r)} + \mathbf{H}_{t-1} \mathbf{W}_{hh}^{(r)} + \mathbf{b}_h^{(r)}),\\
\mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{dh}^{(z)} + \mathbf{H}_{t-1} \mathbf{W}_{hh}^{(z)} + \mathbf{b}_h^{(z)}),
\end{aligned}
$$

### sigmoid nonlinearity $\to \mathbf{R}_t, \mathbf{Z}_t \in (0, 1)$!


### 2) Candidate hidden state

![Gru-candidate](gru-2.png)

*candidate hidden state*
$\tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h}$ at time step $t$: *effect of the reset gate $\mathbf{R}_t$ on the hidden state*

$$\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{dh}^{(h)} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh}^{(h)} + \mathbf{b}_h^{(h)}),$$

$\odot$ - elementwise product

### tanh nonlinearity $\to \tilde{\mathbf{H}}_t \in (-1, 1)$!

### 3) Hidden state

![Gru-hidden](gru-3.png)

*effect of the update gate $\mathbf{Z}_t$*: how much comes from the *old* state $\mathbf{H}_{t-1}$ and how much from the *new candidate* $\tilde{\mathbf{H}}_t$, an elementwise convex combination of the two.

$$\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1}  + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.$$


* when the update gate $\mathbf{Z}_t$ is close to 1: retain the old state $\mathbf{H}_{t-1}$ and ignore the new token $\mathbf{X}_t$
* when $\mathbf{Z}_t$ is close to 0: the new hidden state $\mathbf{H}_t$ mainly uses the candidate state $\tilde{\mathbf{H}}_t$
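One GRU step written directly from the three equations above, as a minimal sketch. The helper names `gru_step` and `triple`, the sizes and the $0.1$-scaled random weights are illustrative assumptions. The last part forces the update gate towards 1 with a large bias $\mathbf{b}_h^{(z)}$ to show that the old state is then passed through almost unchanged.

```python
import torch

def gru_step(X_t, H_prev, params):
    """One GRU step: reset gate, update gate, candidate state, convex combination."""
    (W_dh_r, W_hh_r, b_r), (W_dh_z, W_hh_z, b_z), (W_dh_h, W_hh_h, b_h) = params
    R = torch.sigmoid(X_t @ W_dh_r + H_prev @ W_hh_r + b_r)             # reset gate, in (0, 1)
    Z = torch.sigmoid(X_t @ W_dh_z + H_prev @ W_hh_z + b_z)             # update gate, in (0, 1)
    H_tilde = torch.tanh(X_t @ W_dh_h + (R * H_prev) @ W_hh_h + b_h)    # candidate, in (-1, 1)
    H = Z * H_prev + (1 - Z) * H_tilde                                  # elementwise convex combination
    return H, R, Z, H_tilde

torch.manual_seed(0)
n, d, h = 4, 8, 16  # batch size, input size, hidden size (illustrative)

def triple():
    return (torch.randn(d, h) * 0.1, torch.randn(h, h) * 0.1, torch.zeros(h))

params = [triple(), triple(), triple()]  # reset, update, candidate parameters
X_t, H_prev = torch.randn(n, d), torch.zeros(n, h)
H, R, Z, H_tilde = gru_step(X_t, H_prev, params)
print(H.shape)  # torch.Size([4, 16])

# update gate forced close to 1 (large bias b_z): the old state is kept
H_prev = torch.randn(n, h)
params[1] = (params[1][0], params[1][1], torch.full((h,), 20.0))
H_keep, *_ = gru_step(X_t, H_prev, params)
print(torch.allclose(H_keep, H_prev, atol=1e-3))  # True
```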

```python
# PyTorch implementation - trivial :)
nn.GRU(input_dim, hidden_dim)
```

```
GRU(8, 256)
```
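`nn.GRU` is used exactly like `nn.RNN`, as in the sketch below with dummy inputs and the sizes from the cells above. One hedge: the PyTorch documentation writes the candidate state as $\tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn}))$, i.e. the reset gate is applied after the multiplication by the hidden-to-hidden weights, so a from-scratch version of the formula above will not match `nn.GRU` number for number.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_dim, hidden_dim, num_steps, batch_size = 8, 256, 10, 32
gru_layer = nn.GRU(input_dim, hidden_dim)
X = torch.randn(num_steps, batch_size, input_dim)  # dummy inputs
state = torch.zeros(1, batch_size, hidden_dim)     # H_0
out, state_new = gru_layer(X, state)               # same interface as nn.RNN
print(out.shape, state_new.shape)  # torch.Size([10, 32, 256]) torch.Size([1, 32, 256])

# parameters of the reset gate, update gate and candidate are stacked: 3 * hidden_dim rows
print(gru_layer.weight_ih_l0.shape)  # torch.Size([768, 8])
```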

## Long Short-Term Memory (LSTM)

Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation, 9(8), 1735–1780.

Much older than GRUs and quite a bit more complicated

**memory cell**: output gate, input gate, forget gate

![lstm](lstm-3.png)
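In PyTorch the LSTM follows the same pattern as `nn.RNN` and `nn.GRU`, except that its state is a pair: the hidden state and the memory cell. A minimal sketch with dummy inputs and the sizes used earlier:

```python
import torch
from torch import nn

torch.manual_seed(0)
input_dim, hidden_dim, num_steps, batch_size = 8, 256, 10, 32
lstm_layer = nn.LSTM(input_dim, hidden_dim)
X = torch.randn(num_steps, batch_size, input_dim)  # dummy inputs

# state = (hidden state H, memory cell C), both initialized to zeros here
H0 = torch.zeros(1, batch_size, hidden_dim)
C0 = torch.zeros(1, batch_size, hidden_dim)
out, (H_n, C_n) = lstm_layer(X, (H0, C0))
print(out.shape, H_n.shape, C_n.shape)
# torch.Size([10, 32, 256]) torch.Size([1, 32, 256]) torch.Size([1, 32, 256])

# input gate, forget gate, cell candidate and output gate parameters are stacked: 4 * hidden_dim rows
print(lstm_layer.weight_ih_l0.shape)  # torch.Size([1024, 8])
```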