# Long Short-Term Memory
In this exercise, we will implement an LSTM. In class, we have already seen the definition of the LSTM update rules at time step $t$ (here $\sigma$ is the logistic sigmoid and $*$ denotes element-wise multiplication):

$$
\begin{align}
f_t &= \sigma(W_f h_{t-1} + U_f x_t + b_f) \\
i_t &= \sigma(W_i h_{t-1} + U_i x_t + b_i) \\
o_t &= \sigma(W_o h_{t-1} + U_o x_t + b_o) \\
\tilde{c}_t &= \tanh(W_c h_{t-1} + U_c x_t + b_c) \\
c_t &= f_t * c_{t-1} + i_t * \tilde{c}_t \\
h_t &= o_t * \tanh(c_t)
\end{align}
$$

In [2]:
import torch
import torch.nn as nn

Implement this original version of the LSTM as an `LSTMCell`.

In [14]:
class LSTMCell(nn.Module):
    """Single LSTM cell implementing the classic update rules for one time step.

    Every parameter is created as an all-zero tensor and the constructor does
    not call ``reset_parameters``; call it (or load trained weights) before
    use if random initial weights are wanted.

    Args:
        input_dim: Size of the input vector ``x``.
        hidden_dim: Size of the hidden state and of the cell state.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # hidden_state weights
        self.Wf = nn.Parameter(torch.zeros(hidden_dim, hidden_dim))
        self.Wi = nn.Parameter(torch.zeros(hidden_dim, hidden_dim))
        self.Wo = nn.Parameter(torch.zeros(hidden_dim, hidden_dim))
        self.Wc = nn.Parameter(torch.zeros(hidden_dim, hidden_dim))

        # input weights
        self.Uf = nn.Parameter(torch.zeros(hidden_dim, input_dim))
        self.Ui = nn.Parameter(torch.zeros(hidden_dim, input_dim))
        self.Uo = nn.Parameter(torch.zeros(hidden_dim, input_dim))
        self.Uc = nn.Parameter(torch.zeros(hidden_dim, input_dim))

        # bias
        self.bf = nn.Parameter(torch.zeros(hidden_dim))
        self.bi = nn.Parameter(torch.zeros(hidden_dim))
        self.bo = nn.Parameter(torch.zeros(hidden_dim))
        self.bc = nn.Parameter(torch.zeros(hidden_dim))

    @staticmethod
    def _affine(W, U, b, hidden_state, x):
        """Return ``W h + U x + b`` with the feature axis kept last."""
        # Right-multiplying by the transposed weights gives the same result
        # as ``W @ h`` for a 1-D vector, and additionally accepts inputs
        # with a leading batch dimension.
        return hidden_state @ W.t() + x @ U.t() + b

    def forward(self, x, hidden_state, cell_state):
        """Advance the cell by one time step.

        Args:
            x: Input of shape ``(input_dim,)`` or ``(batch, input_dim)``.
            hidden_state: Previous hidden state, ``(hidden_dim,)`` or
                ``(batch, hidden_dim)``.
            cell_state: Previous cell state, same shape as ``hidden_state``.

        Returns:
            Tuple ``(ht, ct)`` holding the new hidden state and cell state.
        """
        ft = torch.sigmoid(self._affine(self.Wf, self.Uf, self.bf, hidden_state, x))
        it = torch.sigmoid(self._affine(self.Wi, self.Ui, self.bi, hidden_state, x))
        ot = torch.sigmoid(self._affine(self.Wo, self.Uo, self.bo, hidden_state, x))
        # Candidate cell state (c-tilde in the update equations).
        c_t = torch.tanh(self._affine(self.Wc, self.Uc, self.bc, hidden_state, x))
        ct = ft * cell_state + it * c_t
        ht = ot * torch.tanh(ct)
        return ht, ct

    def reset_parameters(self):
        """Re-draw every parameter in place, uniformly from [-1, 1]."""
        for weight in self.parameters():
            nn.init.uniform_(weight, -1, 1)

Create a 2-layer LSTM from your LSTMCell base class and run a forward pass with a random input sequence to test that all your dimensions are correct.

In [15]:
class LSTM(nn.Module):
    """Two stacked ``LSTMCell`` layers advanced one time step per call.

    The first layer maps ``input_dim`` to ``hidden_dim``; the second layer
    consumes the first layer's hidden state and keeps ``hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.l1 = LSTMCell(input_dim, hidden_dim)
        self.l2 = LSTMCell(hidden_dim, hidden_dim)

    def forward(self, x, h1, c1, h2, c2):
        """Run one time step through both layers.

        Args:
            x: Input for this time step.
            h1, c1: Previous hidden and cell state of the first layer.
            h2, c2: Previous hidden and cell state of the second layer.

        Returns:
            Tuple ``(h1, c1, h2, c2)`` with the updated states of both layers.
        """
        lower_h, lower_c = self.l1(x, h1, c1)
        # The second layer takes the first layer's fresh hidden state as input.
        upper_h, upper_c = self.l2(lower_h, h2, c2)
        return lower_h, lower_c, upper_h, upper_c

    def reset_parameters(self):
        """Re-draw every parameter of both layers uniformly from [-1, 1]."""
        for param in self.parameters():
            nn.init.uniform_(param, a=-1, b=1)

In [16]:
# Problem sizes; later cells of the notebook reuse these names.
input_dim, hidden_dim, seq_length = 10, 20, 5

lstm = LSTM(input_dim, hidden_dim)
lstm.reset_parameters()

# Random initial states, drawn in the order h1, h2, c1, c2, then the input.
h1, h2, c1, c2 = (torch.randn(hidden_dim) for _ in range(4))
x = torch.randn(seq_length, input_dim)

result_hidden, result_cell = [], []

# Feed the sequence one time step at a time, carrying the states forward.
for x_t in x:
    h1, c1, h2, c2 = lstm(x_t, h1, c1, h2, c2)
    result_hidden.append((h1, h2))
    result_cell.append((c1, c2))

# Expect one entry per time step, each state of size hidden_dim.
for collected in (result_hidden, result_cell):
    print(len(collected))
    print(collected[0][0].shape)
    

5
torch.Size([20])
5
torch.Size([20])


Implement a subclass of your LSTM that uses a coupled forget and input gate, i.e. the cell state update becomes:

$$c_t = f_t * c_{t-1} + (1-f_t) * \tilde{c}_t$$

In [18]:
class CoupledLSTMCell(LSTMCell):
    """LSTM cell whose input gate is tied to the forget gate as ``1 - f_t``.

    The input-gate parameters ``Wi``, ``Ui`` and ``bi`` are still registered
    by ``LSTMCell.__init__`` but are not read by this ``forward``.
    """

    def forward(self, x, hidden_state, cell_state):
        """Advance the cell by one time step.

        Args:
            x: Input of shape ``(input_dim,)`` or ``(batch, input_dim)``.
            hidden_state: Previous hidden state, ``(hidden_dim,)`` or
                ``(batch, hidden_dim)``.
            cell_state: Previous cell state, same shape as ``hidden_state``.

        Returns:
            Tuple ``(ht, ct)`` holding the new hidden state and cell state.
        """
        # ``h @ W.t()`` equals ``W @ h`` for a 1-D vector and also accepts a
        # leading batch dimension, because the feature axis stays last.
        ft = torch.sigmoid(hidden_state @ self.Wf.t() + x @ self.Uf.t() + self.bf)
        ot = torch.sigmoid(hidden_state @ self.Wo.t() + x @ self.Uo.t() + self.bo)
        # Candidate cell state (c-tilde in the update equations).
        c_t = torch.tanh(hidden_state @ self.Wc.t() + x @ self.Uc.t() + self.bc)
        # Coupled gates: whatever share is not kept from the old cell state
        # is taken from the candidate.
        ct = ft * cell_state + (1 - ft) * c_t
        ht = ot * torch.tanh(ct)
        return ht, ct
    

In [19]:
class CoupledLSTM(LSTM):
    """Two-layer LSTM built from ``CoupledLSTMCell`` layers.

    ``forward`` and ``reset_parameters`` are inherited from ``LSTM``.
    """

    def __init__(self, input_dim, hidden_dim):
        # Deliberately skip LSTM.__init__ so that no plain LSTMCell layers
        # are built; only the nn.Module base initialisation is run.
        super(LSTM, self).__init__()
        self.l1 = CoupledLSTMCell(input_dim, hidden_dim)
        self.l2 = CoupledLSTMCell(hidden_dim, hidden_dim)
    

In [20]:
lstm = CoupledLSTM(input_dim, hidden_dim)
lstm.reset_parameters()

# Random initial states, drawn in the order h1, h2, c1, c2, then the input.
h1, h2, c1, c2 = (torch.randn(hidden_dim) for _ in range(4))
x = torch.randn(seq_length, input_dim)

result_hidden, result_cell = [], []

# Feed the sequence one time step at a time, carrying the states forward.
for x_t in x:
    h1, c1, h2, c2 = lstm(x_t, h1, c1, h2, c2)
    result_hidden.append((h1, h2))
    result_cell.append((c1, c2))

# Expect one entry per time step, each state of size hidden_dim.
for collected in (result_hidden, result_cell):
    print(len(collected))
    print(collected[0][0].shape)

5
torch.Size([20])
5
torch.Size([20])


**Bonus:** Implement *peephole connections* as described at the start of the Section *Variants on Long Short Term Memory* in [this blog post explaining LSTMs](https://colah.github.io/posts/2015-08-Understanding-LSTMs/).

The gate update definitions get an additional term that looks at the cell state:
$$
\begin{align}
f_t &= \sigma(W_f h_{t-1} + U_f x_t + b_f \boldsymbol{+ V_f c_{t-1}}) \\
i_t &= \sigma(W_i h_{t-1} + U_i x_t + b_i \boldsymbol{+ V_i c_{t-1}}) \\
o_t &= \sigma(W_o h_{t-1} + U_o x_t + b_o \boldsymbol{+ V_o c_t})
\end{align}
$$

To make the task a bit easier, we will instead use the cell state of the previous time step, $c_{t-1}$, in the last equation: $$o_t = \sigma(W_o h_{t-1} + U_o x_t + b_o \boldsymbol{+ V_o c_{t-1}})$$

In [21]:
class PeepholeLSTMCell(LSTMCell):
    """LSTM cell whose three gates also look at the previous cell state.

    Each gate gets an extra ``V c_{t-1}`` term with a full
    ``(hidden_dim, hidden_dim)`` matrix. The output gate uses the previous
    cell state as well, not the freshly updated one.

    Like the inherited parameters, ``Vf``, ``Vi`` and ``Vo`` start as zeros.

    Args:
        input_dim: Size of the input vector ``x``.
        hidden_dim: Size of the hidden state and of the cell state.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__(input_dim, hidden_dim)
        # cell_state weights
        self.Vf = nn.Parameter(torch.zeros(hidden_dim, hidden_dim))
        self.Vi = nn.Parameter(torch.zeros(hidden_dim, hidden_dim))
        self.Vo = nn.Parameter(torch.zeros(hidden_dim, hidden_dim))

    def forward(self, x, hidden_state, cell_state):
        """Advance the cell by one time step.

        Args:
            x: Input of shape ``(input_dim,)`` or ``(batch, input_dim)``.
            hidden_state: Previous hidden state, ``(hidden_dim,)`` or
                ``(batch, hidden_dim)``.
            cell_state: Previous cell state, same shape as ``hidden_state``.

        Returns:
            Tuple ``(ht, ct)`` holding the new hidden state and cell state.
        """
        # ``v @ M.t()`` equals ``M @ v`` for a 1-D vector and also accepts a
        # leading batch dimension, because the feature axis stays last.
        ft = torch.sigmoid(
            hidden_state @ self.Wf.t() + x @ self.Uf.t() + self.bf
            + cell_state @ self.Vf.t()
        )
        it = torch.sigmoid(
            hidden_state @ self.Wi.t() + x @ self.Ui.t() + self.bi
            + cell_state @ self.Vi.t()
        )
        # Candidate cell state (c-tilde); it has no peephole term.
        c_t = torch.tanh(hidden_state @ self.Wc.t() + x @ self.Uc.t() + self.bc)
        ct = ft * cell_state + it * c_t
        # The output gate peeks at the *previous* cell state (cell_state),
        # not at the ct computed just above.
        ot = torch.sigmoid(
            hidden_state @ self.Wo.t() + x @ self.Uo.t() + self.bo
            + cell_state @ self.Vo.t()
        )
        ht = ot * torch.tanh(ct)
        return ht, ct

In [22]:
class PeepholeLSTM(LSTM):
    """Two-layer LSTM built from ``PeepholeLSTMCell`` layers.

    ``forward`` and ``reset_parameters`` are inherited from ``LSTM``.
    """

    def __init__(self, input_dim, hidden_dim):
        # Deliberately skip LSTM.__init__ so that no plain LSTMCell layers
        # are built; only the nn.Module base initialisation is run.
        super(LSTM, self).__init__()
        self.l1 = PeepholeLSTMCell(input_dim, hidden_dim)
        self.l2 = PeepholeLSTMCell(hidden_dim, hidden_dim)

In [23]:
lstm = PeepholeLSTM(input_dim, hidden_dim)
lstm.reset_parameters()

# Random initial states, drawn in the order h1, h2, c1, c2, then the input.
h1, h2, c1, c2 = (torch.randn(hidden_dim) for _ in range(4))
x = torch.randn(seq_length, input_dim)

result_hidden, result_cell = [], []

# Feed the sequence one time step at a time, carrying the states forward.
for x_t in x:
    h1, c1, h2, c2 = lstm(x_t, h1, c1, h2, c2)
    result_hidden.append((h1, h2))
    result_cell.append((c1, c2))

# Expect one entry per time step, each state of size hidden_dim.
for collected in (result_hidden, result_cell):
    print(len(collected))
    print(collected[0][0].shape)

5
torch.Size([20])
5
torch.Size([20])
