# LSTM

* LSTM networks rely on a gated cell to track information across many time steps.
* They maintain a cell state that is separate from the output (the hidden state).
* Gates control the flow of information:
    - **Forget gate** discards irrelevant information from the previous cell state.
    - **Store** relevant information from the current input.
    - Selectively **update** the cell state.
    - **Output gate** returns a filtered version of the cell state.
* Backpropagation through time with uninterrupted gradient flow through the cell state.
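
As a rough illustration of the last bullet, the sketch below builds one cell-state update from hand-picked numbers and asks autograd for the gradient of the new cell state with respect to the old one. All values are hypothetical and chosen only to make the gates easy to read; no learned weights are involved.

```python
import torch

# Hypothetical values for a single time step (illustration only).
c_prev = torch.tensor([2.0, -1.0, 0.5], requires_grad=True)   # previous cell state
forget_gate = torch.sigmoid(torch.tensor([4.0, 0.0, -4.0]))   # roughly [0.98, 0.50, 0.02]
input_gate = torch.sigmoid(torch.tensor([-4.0, 0.0, 4.0]))    # roughly [0.02, 0.50, 0.98]
candidate = torch.tanh(torch.tensor([1.0, 1.0, 1.0]))         # new content in (-1, 1)

# Keep a gated fraction of the old state, then add the gated new content.
c_new = forget_gate * c_prev + input_gate * candidate

# The gradient of c_new with respect to c_prev is the forget gate itself:
# no weight matrix and no squashing nonlinearity sits on this path.
c_new.sum().backward()
print(c_prev.grad)
print(forget_gate)
```

The two printed tensors should match, which is one way to read "uninterrupted gradient flow": along the cell-state path, the gradient is only rescaled elementwise by the forget gate.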

In [1]:
import torch
import torch.nn as nn

# Step by Step LSTM

- <font size=3>Step 1: Using the forget gate, decide what information to forget and what to keep</font><br>
<img src = "https://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-focus-f.png" height="720" width="480"><br>
<br>
- <font size=3>Step 2: Deciding what new information we will store<br></font>
<br>
<p align="left">
    <img src = "https://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-focus-i.png" height="520" width="440" align="left">
    <img src = "https://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-focus-C.png" height="520" width="440">
    <em>First, the input gate decides which values will be updated.</em>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;<em>Update the cell state</em><br>
        <em>Then, a tanh layer creates candidate values, which the input gate<br> scales by how much we decide to update each state value</em>
</p><br>

- <font size=3>Step 3: Finally, we calculate the output and new hidden state</font><br>
<img src = "https://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-focus-o.png" height="720" width="480"><br>

### LSTM equations:
Input vector:&emsp;&emsp;&emsp;$\large {X_t} $
<br><br>
Forget Gate: &ensp;&nbsp;&emsp;&emsp;$\large f_t = \sigma(W_{xf} X_t + W_{hf} H_{t-1}) $
<br><br>
Input Gate:&emsp;&emsp;&emsp;&emsp; $ \large i_t = \sigma(W_{xi} X_t + W_{hi} H_{t-1})$
<br><br>
Candidate: &emsp;&emsp;&emsp;&emsp; $ \large \tilde c_t = \tanh(W_{xc} X_t + W_{hc} H_{t-1})$
<br><br>
Update Cell state: &emsp;&emsp; $ \large c_t = f_t*c_{t-1} + i_t*\tilde c_t$
<br><br>
Output Gate: &emsp;&emsp;&emsp;&emsp; $ \large o_t = \sigma(W_{xo} X_t + W_{ho} H_{t-1})$
<br><br>
Update Hidden state: &emsp; $ \large H_t = o_t * \tanh(c_t)$
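
To make the equations concrete, here is a minimal scalar sketch that applies them once using only the standard library. The function name `lstm_step`, the weight dictionary, and the all-zero weights are assumptions made for illustration; with zero weights every gate evaluates to 0.5 and the candidate to 0, so the arithmetic can be checked by hand.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lstm_step(x_t, h_prev, c_prev, w):
    """One scalar LSTM step following the equations above (no bias terms)."""
    f_t = sigmoid(w["xf"] * x_t + w["hf"] * h_prev)          # forget gate
    i_t = sigmoid(w["xi"] * x_t + w["hi"] * h_prev)          # input gate
    c_tilde = math.tanh(w["xc"] * x_t + w["hc"] * h_prev)    # candidate
    c_t = f_t * c_prev + i_t * c_tilde                       # updated cell state
    o_t = sigmoid(w["xo"] * x_t + w["ho"] * h_prev)          # output gate
    h_t = o_t * math.tanh(c_t)                               # updated hidden state
    return h_t, c_t

# Hypothetical weights: all zero, so each gate is sigmoid(0) = 0.5
# and the candidate is tanh(0) = 0.
weights = {k: 0.0 for k in ("xf", "hf", "xi", "hi", "xc", "hc", "xo", "ho")}

h_t, c_t = lstm_step(x_t=1.0, h_prev=0.0, c_prev=2.0, w=weights)

# c_t = 0.5 * 2.0 + 0.5 * 0.0 = 1.0, and h_t = 0.5 * tanh(1.0)
print(c_t, h_t)
```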

In [19]:
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.w_xf = nn.Linear(input_size, hidden_size)
        self.w_hf = nn.Linear(hidden_size, hidden_size)
        
        self.w_xi = nn.Linear(input_size, hidden_size)
        self.w_hi = nn.Linear(hidden_size, hidden_size)
        
        self.w_xc = nn.Linear(input_size, hidden_size)
        self.w_hc = nn.Linear(hidden_size, hidden_size)
        
        self.w_xo = nn.Linear(input_size, hidden_size)
        self.w_ho = nn.Linear(hidden_size, hidden_size)
        
    def init_hidden_cell(self):
        # Initialize the hidden and cell states as tensors of zeros.
        # shape -> num_layers, batch_size, hidden_size
        # hidden_size is not stored on self, so read it from one of the layers.
        hidden_size = self.w_hf.out_features
        hidden = torch.zeros((1, 1, hidden_size), dtype=torch.float32)
        cell_state = torch.zeros((1, 1, hidden_size), dtype=torch.float32)
        return hidden, cell_state
        
    def forward(self, x, hidden, cell_state):   
        forget_gate = torch.sigmoid(self.w_xf(x) + self.w_hf(hidden))
        
        input_gate = torch.sigmoid(self.w_xi(x) + self.w_hi(hidden))
        
        candidate = torch.tanh(self.w_xc(x) + self.w_hc(hidden))
        
        cell_state = torch.mul(forget_gate, cell_state) + torch.mul(input_gate, candidate) #pointwise multiplication
        
        output_gate = torch.sigmoid(self.w_xo(x) + self.w_ho(hidden))
        
        hidden_state = torch.mul(output_gate, torch.tanh(cell_state))
        
        # Return the output plus the last slice (along dim 0) of the hidden and cell states.
        return hidden_state, hidden_state[-1], cell_state[-1]
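
One way to sanity-check a hand-written LSTM step is to compare it with PyTorch's built-in `nn.LSTMCell`. The sketch below does not reuse the class above. It defines a hypothetical helper, `manual_step`, that applies the same gate equations to the parameters of an `nn.LSTMCell`, then unrolls both over a short random sequence. Two assumptions to note: the sizes are arbitrary, and bias terms are included because `nn.LSTMCell` has them (`nn.Linear` also adds a bias by default, although the equations above omit it).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, batch_size = 3, 4, 2  # hypothetical sizes

cell = nn.LSTMCell(input_size, hidden_size)

def manual_step(x, h, c, cell):
    """Apply the LSTM equations by hand using the parameters of `cell`."""
    # nn.LSTMCell stacks its four gates in the order:
    # input, forget, candidate, output.
    gates = (x @ cell.weight_ih.T + cell.bias_ih
             + h @ cell.weight_hh.T + cell.bias_hh)
    i, f, g, o = gates.chunk(4, dim=1)
    c_new = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
    h_new = torch.sigmoid(o) * torch.tanh(c_new)
    return h_new, c_new

# Unroll over a short random sequence, carrying (h, c) from step to step.
seq = torch.randn(5, batch_size, input_size)  # (seq_len, batch, input_size)
h_ref = c_ref = torch.zeros(batch_size, hidden_size)
h_man = c_man = torch.zeros(batch_size, hidden_size)

with torch.no_grad():
    for x_t in seq:
        h_ref, c_ref = cell(x_t, (h_ref, c_ref))
        h_man, c_man = manual_step(x_t, h_man, c_man, cell)

print(torch.allclose(h_ref, h_man, atol=1e-5))
print(torch.allclose(c_ref, c_man, atol=1e-5))
```

The loop is the important part: the hidden and cell states returned at one time step are fed back in at the next, which is what makes the computation recurrent.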

References: [colah's blog](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) and
[MIT Lecture](https://www.youtube.com/watch?v=SEnXr6v2ifU&ab_channel=AlexanderAmini)