In [None]:
import torch
from torch import nn


class LSTMCell(nn.Module):
    """One step of a standard LSTM, with a separate linear layer per gate.

    Every gate sees the same vector: the current input concatenated with the
    previous hidden state along dim 1. Each layer therefore maps
    ``input_size + hidden_size`` features to ``hidden_size`` features.

    Args:
        input_size: Number of features in one input step.
        hidden_size: Number of features in the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        fan_in = input_size + hidden_size
        # Registration order f -> i -> c -> o is kept as-is: it fixes the
        # parameter order (state_dict layout) and the order in which the
        # layers draw from the RNG during initialisation.
        self.Wf = nn.Linear(fan_in, hidden_size)  # forget gate
        self.Wi = nn.Linear(fan_in, hidden_size)  # input gate
        self.Wc = nn.Linear(fan_in, hidden_size)  # candidate cell values
        self.Wo = nn.Linear(fan_in, hidden_size)  # output gate

    def forward(self, input, init_states):
        """Compute the next hidden and cell state.

        Args:
            input: Current input step, concatenated with the previous hidden
                state along dim 1 (presumably (batch, input_size) -- confirm
                against callers).
            init_states: Pair ``(h_prev, c_prev)`` holding the previous hidden
                and cell states.

        Returns:
            Pair ``(h_next, c_next)`` where
            ``c_next = sigmoid(Wf z) * c_prev + sigmoid(Wi z) * tanh(Wc z)`` and
            ``h_next = sigmoid(Wo z) * tanh(c_next)``, with
            ``z = cat(input, h_prev)``.
        """
        h_prev, c_prev = init_states
        z = torch.cat((input, h_prev), 1)

        forget = self.Wf(z).sigmoid()
        update = self.Wi(z).sigmoid()
        candidate = self.Wc(z).tanh()
        c_next = forget * c_prev + update * candidate

        exposure = self.Wo(z).sigmoid()
        h_next = exposure * c_next.tanh()
        return h_next, c_next
    
    
class LSTMNetwork(nn.Module):
    """Unrolled LSTM that maps a sequence to a single output vector.

    Runs an ``LSTMCell`` over every step of the sequence, starting from zero
    hidden and cell states. It then applies a linear head to the final
    hidden state.

    Args:
        input_size: Number of features in each sequence step.
        hidden_size: Width of the LSTM hidden and cell states.
        output_size: Number of features produced by the linear head.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # The cell's state width must be hidden_size: forward() seeds the
        # states with hidden_size columns, and the head below consumes
        # hidden_size features.
        self.cell = LSTMCell(input_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run the LSTM over ``x`` and project the last hidden state.

        Args:
            x: Input sequence. Dim 0 is the batch and dim 1 is iterated as
                the time axis via ``x[:, t]``. Presumably the shape is
                (batch, seq_len, input_size) -- TODO confirm with callers.

        Returns:
            Output of the linear head applied to the final hidden state,
            shape (batch, output_size). If ``x.size(1) == 0`` the head is
            applied to the zero initial state.
        """
        batch_size = x.size(0)
        # Allocate the initial states on x's device and with x's dtype so the
        # model also works for GPU / non-default-precision inputs.
        hidden = x.new_zeros(batch_size, self.hidden_size)
        cell_state = x.new_zeros(batch_size, self.hidden_size)

        for t in range(x.size(1)):
            hidden, cell_state = self.cell(x[:, t], (hidden, cell_state))

        return self.linear(hidden)
        
        