# A simple implementation of LSTM from scratch

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

import lightning as L
from torch.utils.data import TensorDataset, DataLoader

In [None]:
# Show the diagram of a simple LSTM cell as this cell's output.
# "lstm.png" is a relative path, so the image file must be reachable from the
# directory the notebook is run in.
from IPython import display
display.Image("lstm.png")

: 

In [None]:
#Define the LSTM NN class

class tiny_lstm(L.LightningModule):
    """A single LSTM unit written out by hand with scalar parameters.

    Every stage of the unit owns three scalars: ``h_*`` multiplies the
    previous short-term memory, ``iw_*`` multiplies the current input and
    ``b_*`` is the bias. The suffixes are ``f`` (percent of long-term memory
    to keep), ``pr`` (percent of the potential memory to keep), ``p`` (the
    potential memory itself) and ``o`` (percent of the output to pass on).
    """

    def __init__(self):
        super().__init__()

        # Weights are sampled from a normal distribution with mean 0 and
        # standard deviation 1; biases start at zero.
        mean = torch.tensor(0.0)
        std = torch.tensor(1.0)

        def random_weight():
            return nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)

        def zero_bias():
            return nn.Parameter(torch.tensor(0.), requires_grad=True)

        # Percent of long-term memory to remember.
        self.h_f = random_weight()
        self.iw_f = random_weight()
        self.b_f = zero_bias()

        # Percent of potential memory to remember.
        self.h_pr = random_weight()
        self.iw_pr = random_weight()
        self.b_pr = zero_bias()

        # Potential memory.
        self.h_p = random_weight()
        self.iw_p = random_weight()
        self.b_p = zero_bias()

        # Output.
        self.h_o = random_weight()
        self.iw_o = random_weight()
        self.b_o = zero_bias()

    def lstm_units(self, input_value, longmem_value, shortmem_value):
        """Run the unit for one time step.

        Args:
            input_value: value fed to the unit at this step.
            longmem_value: long-term memory carried in from the previous step.
            shortmem_value: short-term memory carried in from the previous step.

        Returns:
            A two-element list ``[new_long_memory, new_short_memory]``.
        """

        def weighted_sum(input_weight, memory_weight, bias):
            # Shared pre-activation: input term + short-term-memory term + bias.
            return (input_value * input_weight) + (shortmem_value * memory_weight) + bias

        keep_fraction = torch.sigmoid(weighted_sum(self.iw_f, self.h_f, self.b_f))
        write_fraction = torch.sigmoid(weighted_sum(self.iw_pr, self.h_pr, self.b_pr))
        candidate_mem = torch.tanh(weighted_sum(self.iw_p, self.h_p, self.b_p))
        emit_fraction = torch.sigmoid(weighted_sum(self.iw_o, self.h_o, self.b_o))

        # Keep part of the old long-term memory and add part of the candidate.
        new_long_mem = (longmem_value * keep_fraction) + (write_fraction * candidate_mem)
        # The new short-term memory is a gated, squashed view of the long-term memory.
        new_short_mem = torch.tanh(new_long_mem) * emit_fraction

        return [new_long_mem, new_short_mem]

    def forward(self, input):
        """Feed the values of ``input`` through the unit one at a time.

        Both memories start at 0. The short-term memory after the last value
        is returned as the model output.
        """
        long_mem, short_mem = 0, 0
        for step_value in input:
            long_mem, short_mem = self.lstm_units(step_value, long_mem, short_mem)
        return short_mem

    def configure_optimizers(self):
        """Return the Adam optimizer over all parameters (learning rate 0.001)."""
        return Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        """Compute and log the squared error for one batch.

        ``batch`` is an ``(inputs, labels)`` pair; only the first sequence in
        ``inputs`` is run through the model. The prediction is logged under
        ``"out_0"`` when the label equals 0 and under ``"out_1"`` otherwise.
        """
        # batch contains the different data from the two companies
        input_i, label_i = batch
        output_i = self.forward(input_i[0])

        loss = (output_i - label_i).pow(2)
        self.log("train loss", loss)

        output_key = "out_0" if (label_i == 0) else "out_1"
        self.log(output_key, output_i)

        return loss

    # To plot the logs in TensorBoard, run `tensorboard --logdir=lightning_logs/`
    # from the directory containing the log folder created by Lightning.

In [None]:
# Train the LSTM
model = tiny_lstm()

# One row per company (see training_step): a sequence of four values that is
# fed to the model one value at a time. The matching entry in `labels` is the
# value the model output is compared against.
# Fixed: the second value of the first row was `5`, a dropped decimal point;
# both rows share the same middle values and differ only in the first one.
inputs = torch.tensor([[0., 0.5, 0.25, 1.], [1., 0.5, 0.25, 1.]])
labels = torch.tensor([0., 1.])

dataset = TensorDataset(inputs, labels)
# DataLoader's default batch size is 1, which training_step relies on when it
# takes `input_i[0]` as the single sequence of the batch.
dataloader = DataLoader(dataset)

trainer = L.Trainer(max_epochs=2000)
trainer.fit(model, train_dataloaders=dataloader)


In [None]:
""" # why we don't need to implicity call the forward method
class Module:
    def __init__(self):
        pass
    def __call__(self, data):
        self.forward(data)

    def forward(self, data):
        print("forward function, data =", data)

net = Module()
net([1,2,3])
# forward function, data = [1, 2, 3]
Now that we have our Module class, let's create another Net class that inherits from it

# Net inherits from Module
class Net(Module):
    def __init__(self):
        super(Net, self).__init__()

    def forward(self, data):
        print("Net.forward, data =", data)

net = Net()
net([1,2,3,4])
# Net.forward, data = [1, 2, 3, 4] """

: 

In [None]:
#Using Pytorch's LSTM module
class tiny_lstm_pytorch(L.LightningModule):
    """LSTM model backed by ``torch.nn.LSTM`` with one input feature and one hidden unit."""

    def __init__(self):
        # Fixed: the original called `__init()`, which does not exist and
        # raised AttributeError as soon as the class was instantiated.
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=1)

    def forward(self, input):
        """Run the sequence ``input`` through the LSTM and return the last output.

        Args:
            input: tensor whose elements are reshaped to ``(len(input), 1)``,
                i.e. one row per time step with a single feature.

        Returns:
            The final row of the LSTM output.
        """
        # Reshape to (sequence_length, input_size) so each value is one time step.
        input_trans = input.view(len(input), 1)
        # The second return value (final hidden and cell state) is not needed here.
        lstm_out, _ = self.lstm(input_trans)
        prediction = lstm_out[-1]
        return prediction