# Chapter 10: Modern Recurrent Neural Networks

### Dhuvi karthikeyan

03/01/2023

* Long short-term memory (LSTM), introduced in 1997, is one of the two most important developments in sequence modeling.
* Bidirectional recurrent neural networks, also introduced in 1997, are the second, and they brought a large boost in performance.

## 10.1 Long Short-Term Memory (LSTM)

While gradient clipping helps control exploding gradients, vanishing gradients, and with them the difficulty of modeling long-term dependencies in sequence learning, required a more elaborate architectural change.

### 10.1.1 Gated Memory Cells

Core to this advance is the memory cell, which replaces the standard recurrent node. A recurrent node is the unit of computation that takes the hidden state from time step $t-1$ and the input token at time $t$, and produces an output and a new hidden state. RNNs share parameters across time: the recurrent node uses the same weights at every time step.
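A minimal sketch of the standard recurrent node for comparison, assuming a tanh activation and the `(time_steps, batch, features)` layout used in this notebook's code cells; all sizes and names are illustrative. The point is that the same `W_xh`, `W_hh`, and `b_h` are reused at every time step.

```python
import torch

torch.manual_seed(0)
n, d, h, T = 2, 3, 4, 5  # batch size, input dim, hidden dim, time steps (illustrative)

# One set of parameters, reused at every time step (parameter sharing)
W_xh = torch.randn(d, h) * 0.1
W_hh = torch.randn(h, h) * 0.1
b_h = torch.zeros(h)

X = torch.randn(T, n, d)  # (time_steps, batch, input_dim)
H = torch.zeros(n, h)     # initial hidden state H_0
hiddens = []
for X_t in X:
    # H_t = tanh(X_t W_xh + H_{t-1} W_hh + b_h), always with the same weights
    H = torch.tanh(X_t @ W_xh + H @ W_hh + b_h)
    hiddens.append(H)
hiddens = torch.stack(hiddens)  # (time_steps, batch, hidden)
```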

Memory Cells consist of:
* An internal state $C_t$ whose information flow is controlled by three gates:
    1. Input gate: decides how much of the input enters the internal state
    2. Forget gate: decides how much of the internal state to keep, down to resetting it to 0
    3. Output gate: decides how much the internal state affects the output


**Note:** Gating gives the model explicit mechanisms for deciding when the hidden state should be updated, left unchanged, or reset.

#### Input Gate, Forget Gate, Output Gate:

Each gate receives the input token $X_t$ and the hidden state of the previous time step $H_{t-1}$. The forget, input, and output gates are all learnable, which means each has its own weights and biases. In practice, each gate is a single fully connected layer with a sigmoid activation.

* Input node: a tanh layer (defined below) that proposes candidate values for the internal state; the input gate then decides how much of this candidate to include.

$$ I_t = \sigma (X_tW_{xi} + H_{t-1}W_{hi} + b_i)$$
$$ F_t = \sigma (X_tW_{xf} + H_{t-1}W_{hf} + b_f)$$
$$ O_t = \sigma (X_tW_{xo} + H_{t-1}W_{ho} + b_o)$$

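For a minibatch input $X_t \in \mathbb{R}^{n \times d}$ (batch size $n$, input dimension $d$) and previous hidden state $H_{t-1} \in \mathbb{R}^{n \times h}$ (hidden dimension $h$), all three gates $I_t, F_t, O_t$ are in $\mathbb{R}^{n \times h}$. The input weight matrices are all in $\mathbb{R}^{d \times h}$, the hidden-state weight matrices are all in $\mathbb{R}^{h \times h}$, and the biases are in $\mathbb{R}^{1 \times h}$. The sigmoid squashes every gate value into $(0, 1)$.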
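A quick shape check of the three gate equations, with arbitrary illustrative sizes and random weights (the helper `gate` is hypothetical, not part of the later implementation). Each gate is an affine map of $X_t$ and $H_{t-1}$ followed by a sigmoid, so every entry lands in $(0, 1)$.

```python
import torch

torch.manual_seed(0)
n, d, h = 2, 3, 4  # batch size, input dim, hidden dim (illustrative)
X_t = torch.randn(n, d)
H_prev = torch.randn(n, h)

def gate(X_t, H_prev, W_x, W_h, b):
    # sigma(X_t W_x + H_{t-1} W_h + b): an affine map followed by a sigmoid
    return torch.sigmoid(X_t @ W_x + H_prev @ W_h + b)

# Each gate has its own W_x (d x h), W_h (h x h) and b (h,)
params = {k: (torch.randn(d, h) * 0.5, torch.randn(h, h) * 0.5, torch.zeros(h)) for k in "ifo"}
I_t, F_t, O_t = (gate(X_t, H_prev, *params[k]) for k in "ifo")
```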

#### Input Node

The input node, written $\hat{C_t}$, can be thought of as a candidate context. It uses a tanh activation function, which maps values to $(-1, 1)$.

$$ \hat{C_t} = \tanh(X_t W_{xc} + H_{t-1}W_{hc} + b_c)$$

At each time step, the input and the previous hidden state are passed to the input node and to all three gates. Each computes the same kind of affine combination of $X_t$ and $H_{t-1}$, producing an $h$-dimensional output per example; the gates then apply a sigmoid, while the input node applies tanh. Because tanh outputs lie in $(-1, 1)$, the input node can both scale the incoming information and flip its sign.

#### Memory Cell Internal State

$$ C_t = F_t \odot C_{t-1} + I_t \odot \hat{C_t}$$

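The input gate $I_t$ controls how much of the candidate $\hat{C_t}$, the new information from the current input, enters the internal state, while the forget gate $F_t$ controls how much of the previous internal state $C_{t-1}$ is kept. Both products are elementwise ($\odot$). If $F_t$ stays close to 1 and $I_t$ close to 0, the past internal state is carried forward unchanged, which is how the memory cell mitigates vanishing gradients and retains long-range information.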
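A toy illustration of the internal state update, with gate values fixed by hand rather than produced by the sigmoid layers (an assumption purely for illustration). With $F_t = 1$ and $I_t = 0$ the state survives 100 steps untouched, whereas a forget gate of 0.9 shrinks it by $0.9^{100} \approx 2.7 \times 10^{-5}$.

```python
import torch

C_0 = torch.tensor([0.5, -1.0, 2.0, 0.25])  # some internal state (illustrative values)
candidate = torch.ones(4)                  # \hat{C}_t: arbitrary new information
C_keep, C_leak = C_0.clone(), C_0.clone()

for t in range(100):
    # F_t = 1, I_t = 0: C_t = 1 * C_{t-1} + 0 * \hat{C}_t, the state passes through untouched
    C_keep = 1.0 * C_keep + 0.0 * candidate
    # F_t = 0.9, I_t = 0: the state shrinks by a factor of 0.9 every step
    C_leak = 0.9 * C_leak + 0.0 * candidate
```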

#### Hidden State

The memory cell outputs both its internal state $C_t$ and the hidden state $H_t$, which is computed as:

$$ H_t = O_t \odot \tanh(C_t) $$

Because $\tanh(C_t) \in (-1, 1)$ and the output gate lies in $(0, 1)$, the values of $H_t$ always lie in $(-1, 1)$. The output gate acts as a switch: when $O_t$ is close to 1, the internal state affects subsequent layers almost unchanged, and when it is close to 0, the internal state is hidden from the rest of the network at that time step. A memory cell can therefore accumulate information over many time steps without affecting the network (output gate near 0) and then suddenly influence it once the output gate flips from near 0 to near 1.
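A hand-picked illustration of the output gate acting as a switch on the same internal state; the values are arbitrary and chosen only for illustration.

```python
import torch

C_t = torch.tensor([3.0, -0.5, 1.0])  # internal state (illustrative values)
O_closed = torch.zeros(3)              # output gate near 0
O_open = torch.ones(3)                 # output gate near 1

H_closed = O_closed * torch.tanh(C_t)  # internal state hidden from the rest of the network
H_open = O_open * torch.tanh(C_t)      # internal state exposed, squashed into (-1, 1)
```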

### 10.1.4 Summary

LSTMs are the archetypal latent variable autoregressive model with nontrivial state control, although training them can be costly because of long-range dependencies in the sequence. Proposed in 1997, they took off in the 2000s and were the de facto sequence model for much of the 2010s, until transformers. The LSTM layer produces both the internal state and the hidden state, but only the hidden state is passed to the output layer; the internal state stays inside the cell.
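For comparison, PyTorch's built-in `nn.LSTM` exposes this split: it returns the top layer's hidden state at every step plus the final `(h_n, c_n)` pair. In this sketch the sizes are illustrative and the `nn.Linear` layer is a hypothetical stand-in for whatever output layer a model uses; only the hidden states are fed to it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, n, d, h, q = 5, 3, 8, 16, 10  # time steps, batch, input dim, hidden dim, output dim (illustrative)

lstm = nn.LSTM(input_size=d, hidden_size=h, num_layers=2)  # (time, batch, features) layout by default
output_layer = nn.Linear(h, q)                              # hypothetical output layer

X = torch.randn(T, n, d)
H_all, (H_n, C_n) = lstm(X)  # H_all: top-layer hidden state at every time step
Y = output_layer(H_all)      # only hidden states reach the output layer; C_n stays internal
```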

```python
import torch
import torch.nn as nn


class LSTMcell(nn.Module):
    """
    Implement the LSTM cell from "scratch".
    s/o to Torch community for handling the
    heavy lifting.
    """
    def __init__(self, input_dim, hidden_dim, sigma=.01):
        super(LSTMcell, self).__init__()
        # Save H-params
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.sigma = sigma
        # Initialize Net Parameters: (W_x*, W_h*, b_*) for each gate and the input node
        self.W_xi, self.W_hi, self.b_i = self.gate_init()  # input gate
        self.W_xf, self.W_hf, self.b_f = self.gate_init()  # forget gate
        self.W_xo, self.W_ho, self.b_o = self.gate_init()  # output gate
        self.W_xc, self.W_hc, self.b_c = self.gate_init()  # input node

    def gate_init(self):
        # W_x: (input_dim, hidden_dim), W_h: (hidden_dim, hidden_dim), b: (hidden_dim,)
        gate_params = (self.weight_init(self.input_dim, self.hidden_dim),
                       self.weight_init(self.hidden_dim, self.hidden_dim),
                       self.bias_init())
        return gate_params

    def weight_init(self, *shape):
        return nn.Parameter(torch.randn(shape) * self.sigma)

    def bias_init(self):
        return nn.Parameter(torch.zeros(self.hidden_dim))

    def init_hidden(self):
        # (H_0, C_0); these broadcast across the batch dimension in forward()
        return (torch.zeros(self.hidden_dim), torch.zeros(self.hidden_dim))

    def forward(self, x_t, state):
        # x_t: (batch, input_dim); state: (H_{t-1}, C_{t-1}), each (batch, hidden_dim)
        h_t, c_t = state

        input_gate = torch.sigmoid(torch.matmul(x_t, self.W_xi)
                                   + torch.matmul(h_t, self.W_hi)
                                   + self.b_i)
        forget_gate = torch.sigmoid(torch.matmul(x_t, self.W_xf)
                                    + torch.matmul(h_t, self.W_hf)
                                    + self.b_f)
        output_gate = torch.sigmoid(torch.matmul(x_t, self.W_xo)
                                    + torch.matmul(h_t, self.W_ho)
                                    + self.b_o)
        # The input node uses tanh (not sigmoid), so candidates lie in (-1, 1)
        input_node = torch.tanh(torch.matmul(x_t, self.W_xc)
                                + torch.matmul(h_t, self.W_hc)
                                + self.b_c)

        # C_t = F_t * C_{t-1} + I_t * C_hat_t, then H_t = O_t * tanh(C_t)
        c_t = forget_gate * c_t + input_gate * input_node
        h_t = output_gate * torch.tanh(c_t)

        return (h_t, c_t)
```
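One way to sanity-check the cell equations, including the tanh input node, is to compare a functional version against PyTorch's `nn.LSTMCell`. This standalone sketch assumes PyTorch's documented layout: `weight_ih` and `weight_hh` stack four blocks (input gate, forget gate, input node, output gate) along the first dimension, weights are stored transposed relative to the row-vector convention used here, and the two bias vectors `bias_ih` and `bias_hh` together play the role of each $b$. The function name `lstm_step` is hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d, h, n = 10, 20, 3  # input dim, hidden dim, batch size (illustrative)
ref = nn.LSTMCell(d, h)

def lstm_step(x, H, C, W_ih, W_hh, b_ih, b_hh):
    # PyTorch stores weights as (4h, d) and (4h, h); transpose for the X W convention
    gates = x @ W_ih.T + H @ W_hh.T + b_ih + b_hh
    i, f, g, o = gates.chunk(4, dim=1)  # order: input gate, forget gate, input node, output gate
    I, F, O = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    C_hat = torch.tanh(g)               # input node: tanh, not sigmoid
    C_new = F * C + I * C_hat
    H_new = O * torch.tanh(C_new)
    return H_new, C_new

x = torch.randn(n, d)
H0, C0 = torch.randn(n, h), torch.randn(n, h)
with torch.no_grad():
    H_ref, C_ref = ref(x, (H0, C0))
    H_ours, C_ours = lstm_step(x, H0, C0, ref.weight_ih, ref.weight_hh, ref.bias_ih, ref.bias_hh)
```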

```python
class LSTM(nn.Module):
    """
    Implement a Deep LSTM that can take an arbitrary number of LSTM cells.
    """
    def __init__(self, input_dim, hidden_dim, m_layers):
        super(LSTM, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.m_layers = m_layers
        # nn.ModuleList registers every cell's parameters; building a fresh cell per
        # layer (instead of `[cell] * n`) keeps the layers from sharing one set of weights
        self.layers = nn.ModuleList(
            [LSTMcell(input_dim, hidden_dim)]
            + [LSTMcell(hidden_dim, hidden_dim) for _ in range(m_layers - 1)])

    def forward(self, X):
        # X: (time_steps, batch, input_dim)
        outs = []
        # One (H, C) tuple per layer, each tensor of shape (batch, hidden_dim)
        hiddens = [(torch.zeros((X.shape[1], self.hidden_dim)),
                    torch.zeros((X.shape[1], self.hidden_dim)))
                   for _ in range(self.m_layers)]

        # Perform recurrent pass
        for x_t in X:
            for j, layer in enumerate(self.layers):
                h_t, c_t = layer(x_t, hiddens[j])
                # Update the hidden states for that layer @this timestep
                hiddens[j] = (h_t, c_t)
                # The hidden state is the input to the next layer @this timestep
                x_t = h_t
            # After the top layer, append its hidden state (output)
            outs.append(h_t)
        return torch.stack(outs, dim=0), (h_t, c_t)
```

```python
# Test the LSTMcell
rnn = LSTMcell(10, 20)  # (input_size, hidden_size)
inputs = torch.randn(5, 3, 10)  # (time_steps, batch, input_size)
hx = torch.randn(3, 20)  # (batch, hidden_size)
cx = torch.randn(3, 20)
output = []

for i in range(inputs.size()[0]):
    hx, cx = rnn(inputs[i], (hx, cx))
    output.append(hx)
output = torch.stack(output, dim=0)
```

```python
# Test the arbitrary LSTM
inputs = torch.randn(5, 3, 8)
lm = LSTM(8, 16, 2)
lm(inputs)
```



## 10.2 Gated Recurrence Units (GRU)

Proposed in 2014, the gated recurrent unit (GRU) is a streamlined version of the LSTM memory cell, designed to speed up computation.

### 10.2.1 Reset and Update Gates

The GRU uses two gates instead of the LSTM's three: a reset gate $R_t$ and an update gate $Z_t$. Both use a sigmoid, which restricts their values to $(0, 1)$; they determine how much the previous hidden state matters.

$$ R_t = \sigma (X_tW_{xr} + H_{t-1}W_{hr} + b_r)$$
$$ Z_t = \sigma (X_tW_{xz} + H_{t-1}W_{hz} + b_z)$$

### 10.2.2 Candidate Hidden State

$$ \hat{H_t} = \tanh(X_tW_{xh} + (R_t \odot H_{t-1})W_{hh} + b_h) $$

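The candidate hidden state is computed from the input and the previous hidden state, with the latter scaled elementwise by the reset gate. When $R_t$ is close to 1 this reduces to a vanilla RNN update; when $R_t$ is close to 0, the previous hidden state is ignored and the candidate depends only on $X_t$. The last step combines the candidate with the update gate to obtain the actual hidden state of the GRU cell.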
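A small numerical check of the two extremes, with random illustrative weights (the helper `candidate` is hypothetical): with $R_t = 1$ the candidate equals a vanilla RNN update, and with $R_t = 0$ it depends on $X_t$ alone.

```python
import torch

torch.manual_seed(0)
n, d, h = 2, 3, 4  # batch size, input dim, hidden dim (illustrative)
X_t, H_prev = torch.randn(n, d), torch.randn(n, h)
W_xh, W_hh, b_h = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)

def candidate(R_t):
    # \hat{H}_t = tanh(X_t W_xh + (R_t * H_{t-1}) W_hh + b_h)
    return torch.tanh(X_t @ W_xh + (R_t * H_prev) @ W_hh + b_h)

# R_t = 1: identical to a vanilla RNN update
H_rnn = torch.tanh(X_t @ W_xh + H_prev @ W_hh + b_h)
# R_t = 0: H_{t-1} is ignored, leaving a one-layer MLP on X_t
H_mlp = torch.tanh(X_t @ W_xh + b_h)
```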

### 10.2.3 Hidden State

$$ H_t = Z_t \odot H_{t-1} + (1 - Z_t) \odot \hat{H_t} $$

The new hidden state is an elementwise convex combination of the previous hidden state and the candidate, weighted by $Z_t$ and $1 - Z_t$. When $Z_t$ is close to 1, the old state is kept and the current input is largely ignored; when $Z_t$ is close to 0, $H_t$ approaches the candidate $\hat{H_t}$. The network must learn these fractions $Z_t$.

Reset gates help capture short-term dependencies, while update gates help capture long-term dependencies.
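A hand-picked example of the update (values are illustrative, not learned): each coordinate of $Z_t$ independently chooses how much of the old state to keep.

```python
import torch

H_prev = torch.tensor([1.0, -1.0, 0.5])   # H_{t-1}
H_cand = torch.tensor([-0.2, 0.4, 0.9])   # \hat{H}_t
Z_t = torch.tensor([1.0, 0.0, 0.25])      # keep all, keep none, keep a quarter

# H_t = Z_t * H_{t-1} + (1 - Z_t) * \hat{H}_t
H_t = Z_t * H_prev + (1 - Z_t) * H_cand   # [1.0, 0.4, 0.8] up to float rounding
```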


```python
class GRUcell(nn.Module):
    """
    Implement the GRU cell from "scratch".
    s/o to Torch community for handling the
    heavy lifting.
    """
    def __init__(self, input_dim, hidden_dim, sigma=.01):
        super(GRUcell, self).__init__()
        # Save H-params
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.sigma = sigma
        # Initialize Net Parameters
        self.W_xr, self.W_hr, self.b_r = self.gate_init()  # reset gate
        self.W_xz, self.W_hz, self.b_z = self.gate_init()  # update gate
        self.W_xh, self.W_hh, self.b_h = self.gate_init()  # candidate hidden state

    def gate_init(self):
        # W_x: (input_dim, hidden_dim), W_h: (hidden_dim, hidden_dim), b: (hidden_dim,)
        gate_params = (self.weight_init(self.input_dim, self.hidden_dim),
                       self.weight_init(self.hidden_dim, self.hidden_dim),
                       self.bias_init())
        return gate_params

    def weight_init(self, *shape):
        return nn.Parameter(torch.randn(shape) * self.sigma)

    def bias_init(self):
        return nn.Parameter(torch.zeros(self.hidden_dim))

    def init_hidden(self):
        # A GRU keeps a single hidden state (no separate internal state)
        return torch.zeros(self.hidden_dim)

    def forward(self, x_t, h_t):
        # x_t: (batch, input_dim); h_t: H_{t-1}, (batch, hidden_dim)
        reset_gate = torch.sigmoid(torch.matmul(x_t, self.W_xr)
                                   + torch.matmul(h_t, self.W_hr)
                                   + self.b_r)
        update_gate = torch.sigmoid(torch.matmul(x_t, self.W_xz)
                                    + torch.matmul(h_t, self.W_hz)
                                    + self.b_z)
        # The reset gate (not the update gate) scales H_{t-1} inside the candidate
        candidate_h = torch.tanh(torch.matmul(x_t, self.W_xh)
                                 + torch.matmul(reset_gate * h_t, self.W_hh)
                                 + self.b_h)

        # Convex combination of the previous state and the candidate
        h_t = update_gate * h_t + (1 - update_gate) * candidate_h

        return h_t
```

```python
class GRU(nn.Module):
    """
    Implement a Deep GRU that can take an arbitrary number of GRU cells.
    """
    def __init__(self, input_dim, hidden_dim, m_layers):
        super(GRU, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.m_layers = m_layers
        # nn.ModuleList registers every cell's parameters; building a fresh cell per
        # layer (instead of `[cell] * n`) keeps the layers from sharing one set of weights
        self.layers = nn.ModuleList(
            [GRUcell(input_dim, hidden_dim)]
            + [GRUcell(hidden_dim, hidden_dim) for _ in range(m_layers - 1)])

    def forward(self, X):
        # X: (time_steps, batch, input_dim)
        outs = []
        # One hidden state per layer, each of shape (batch, hidden_dim)
        hiddens = [torch.zeros((X.shape[1], self.hidden_dim)) for _ in range(self.m_layers)]

        # Perform recurrent pass
        for x_t in X:
            for j, layer in enumerate(self.layers):
                h_t = layer(x_t, hiddens[j])
                # Update the hidden state for that layer @this timestep
                hiddens[j] = h_t
                # The hidden state is the input to the next layer @this timestep
                x_t = h_t
            # After the top layer, append its hidden state (output)
            outs.append(h_t)
        return torch.stack(outs, dim=0), h_t
```

```python
# Test the arbitrary GRU
inputs = torch.randn(5, 3, 8)
lm = GRU(8, 16, 2)
lm(inputs)
```



## 10.3 Deep Recurrent Neural Networks

Depth in a recurrent neural network is often thought of as the length of the unrolled RNN, which is $O(T)$ for a sequence of length $T$. However, depth can also be added along the path from input to output within a single time step, by stacking several recurrent layers. The unrolled network then becomes a two-dimensional grid: the hidden state of layer $l$ at time step $t$ is passed both to the same layer at the next time step and, as input, to layer $l+1$ at the same time step. Each layer keeps its own hidden state, and the top layer's hidden state serves as the output at each time step.


$$ H_t^{(l)} = \phi_l (H_t^{(l-1)}W_{xh}^{(l)} + H_{t-1}^{(l)}W_{hh}^{(l)} + b_h^{(l)})$$

where $H_t^{(0)} = X_t$ and $\phi_l$ is the activation function of layer $l$.
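A sketch that implements this recurrence directly with $\phi_l = \tanh$ and checks it against PyTorch's stacked `nn.RNN(num_layers=2)`. It assumes PyTorch's parameter names (`weight_ih_l{k}`, `weight_hh_l{k}`, `bias_ih_l{k}`, `bias_hh_l{k}`), its transposed weight storage, and that its two bias vectors sum to $b_h^{(l)}$; the sizes and the `deep_rnn` helper are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, n, d, h, L = 5, 3, 8, 16, 2  # time steps, batch, input dim, hidden dim, layers (illustrative)
rnn = nn.RNN(d, h, num_layers=L)  # tanh nonlinearity in every layer by default
X = torch.randn(T, n, d)

def deep_rnn(X):
    H = [torch.zeros(n, h) for _ in range(L)]  # H_0^{(l)} for each layer l
    outs = []
    for X_t in X:
        below = X_t  # H_t^{(0)} = X_t
        for l in range(L):
            W_xh = getattr(rnn, f"weight_ih_l{l}").T
            W_hh = getattr(rnn, f"weight_hh_l{l}").T
            b_h = getattr(rnn, f"bias_ih_l{l}") + getattr(rnn, f"bias_hh_l{l}")
            # H_t^{(l)} = tanh(H_t^{(l-1)} W_xh^{(l)} + H_{t-1}^{(l)} W_hh^{(l)} + b_h^{(l)})
            H[l] = torch.tanh(below @ W_xh + H[l] @ W_hh + b_h)
            below = H[l]  # passed up to layer l + 1 at the same time step
        outs.append(below)  # top layer's hidden state at time t
    return torch.stack(outs), torch.stack(H)

with torch.no_grad():
    out_ref, h_ref = rnn(X)
    out_ours, h_ours = deep_rnn(X)
```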

## 10.4 Bidirectional Recurrent Neural Networks

A unidirectional RNN conditions only on the tokens that precede each position. To address this limitation, the bidirectional RNN was introduced in 1997 as an intuitive fix. A bi-RNN is simply the combination of two unidirectional RNNs: the first reads the sequence forwards and the second reads it in reverse. This lets the representation at each time step embed both the preceding and the following context. Rather than reducing the two outputs at each time step (for example, by summing them), the bi-RNN concatenates them, which keeps the two directions' representations separate so that they do not destructively interfere with each other.
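A sketch showing that a bidirectional layer is just two unidirectional ones: copy the weights of PyTorch's `nn.GRU(bidirectional=True)` into two plain `nn.GRU`s, run one on the flipped sequence, flip its outputs back, and concatenate. This assumes PyTorch's `_reverse` suffix for the backward-direction parameter names; the sizes are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, n, d, h = 5, 3, 8, 16  # time steps, batch, input dim, hidden dim (illustrative)
bi = nn.GRU(d, h, bidirectional=True)
fwd, bwd = nn.GRU(d, h), nn.GRU(d, h)

# Copy the built-in model's two directions into two separate unidirectional GRUs
with torch.no_grad():
    for name in ["weight_ih_l0", "weight_hh_l0", "bias_ih_l0", "bias_hh_l0"]:
        getattr(fwd, name).copy_(getattr(bi, name))
        getattr(bwd, name).copy_(getattr(bi, name + "_reverse"))

X = torch.randn(T, n, d)
with torch.no_grad():
    out_f, _ = fwd(X)                        # reads x_1, ..., x_T
    out_b, _ = bwd(torch.flip(X, dims=[0]))  # reads x_T, ..., x_1
    out_b = torch.flip(out_b, dims=[0])      # realign so index t refers to x_t
    ours = torch.cat((out_f, out_b), dim=2)  # (T, n, 2h): concatenated, not summed
    ref, _ = bi(X)
```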


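```python
class BiDirectional(nn.Module):
    """
    Take any implementation of a deep RNN and return the concatenated outputs
    """
    def __init__(self, RNN, input_dim, hidden_dim, n_layers):
        super(BiDirectional, self).__init__()
        # Two independent deep RNNs: one reads the sequence forwards, one in reverse
        self.fc = RNN(input_dim, hidden_dim, n_layers)
        self.rc = RNN(input_dim, hidden_dim, n_layers)
        self.hidden_dim = hidden_dim

    def forward(self, X):
        # X: (time_steps, batch, input_dim)
        forward_outs, H_f = self.fc(X)
        # Run the second RNN on the time-reversed sequence...
        reverse_outs, H_r = self.rc(torch.flip(X, dims=[0]))
        # ...and flip its outputs back so index t in both halves refers to the same token
        concat_outs = torch.cat((forward_outs, torch.flip(reverse_outs, dims=[0])), dim=2)
        return concat_outs, (H_f, H_r)
```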

```python
#biLSTM = BiDirectional(LSTM, 8, 16, 10)
#biLSTM(inputs)[0].shape

biGRU = BiDirectional(GRU, 8, 16, 10)
biGRU(inputs)[0].shape
```

torch.Size([5, 3, 32])

```python
torch.cat((lm(inputs)[0], lm(inputs)[0]), dim=2).shape
```

torch.Size([5, 3, 32])

Dynamic algorithms to get the probabilistic sequence:

* almost always get EOS
* are smaller TCR sequences more likely?
* what is the probability of TCR generation
* top-k