In [None]:
import math
import torch
import torch.nn as nn


class Cust_LSTM(nn.Module):
    """Single-layer LSTM written out gate by gate with explicit parameters.

    Each gate ``k`` in ``{f, i, o, g}`` owns an input projection ``U_k`` of
    shape ``(input_size, hidden_size)``, a recurrent projection ``V_k`` of
    shape ``(hidden_size, hidden_size)`` and a bias ``b_k`` of shape
    ``(hidden_size,)``.  Per time step the cell computes::

        f_t = sigmoid(x_t @ U_f + h_{t-1} @ V_f + b_f)
        i_t = sigmoid(x_t @ U_i + h_{t-1} @ V_i + b_i)
        o_t = sigmoid(x_t @ U_o + h_{t-1} @ V_o + b_o)
        g_t = tanh(x_t @ U_g + h_{t-1} @ V_g + b_g)
        c_t = f_t * c_{t-1} + i_t * g_t
        h_t = o_t * tanh(c_t)

    Args:
        input_size: Number of features in each input vector ``x_t``.
        hidden_size: Number of features in the hidden and cell states.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        # Parameters are registered in the order f, i, o, g; the attribute
        # names are part of the module's state_dict and must not change.

        # Forget gate f_t
        self.U_f = nn.Parameter(torch.empty(input_size, hidden_size))
        self.V_f = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_f = nn.Parameter(torch.empty(hidden_size))

        # Input gate i_t
        self.U_i = nn.Parameter(torch.empty(input_size, hidden_size))
        self.V_i = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_i = nn.Parameter(torch.empty(hidden_size))

        # Output gate o_t
        self.U_o = nn.Parameter(torch.empty(input_size, hidden_size))
        self.V_o = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_o = nn.Parameter(torch.empty(hidden_size))

        # Candidate cell update g_t
        self.U_g = nn.Parameter(torch.empty(input_size, hidden_size))
        self.V_g = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_g = nn.Parameter(torch.empty(hidden_size))

        # torch.empty hands back uninitialised memory (possibly NaN/Inf), so
        # the parameters must be filled before the module is usable.
        self.init_weights()

    def init_weights(self):
        """Fill every parameter in place from U(-stdv, stdv).

        ``stdv`` is ``1 / sqrt(hidden_size)``.  Called automatically by
        ``__init__``; calling it again re-draws all parameters.
        """
        # Guard the degenerate hidden_size == 0 case so construction does not
        # raise ZeroDivisionError.
        stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0.0

        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x, init_states=None):
        """Run the LSTM over a whole batch of sequences.

        Args:
            x: Input tensor, assumed to have shape
                ``(batch_size, sequence_size, input_size)``.
            init_states: Optional ``(h_0, c_0)`` pair, each of shape
                ``(batch_size, hidden_size)``.  When ``None`` both states
                start as zeros with the same dtype and device as ``x``.

        Returns:
            A tuple ``(hidden_seq, (h_t, c_t))`` where ``hidden_seq`` has
            shape ``(batch_size, sequence_size, hidden_size)`` and holds the
            hidden state of every step, and ``(h_t, c_t)`` are the states
            after the last step.  For an empty sequence ``hidden_seq`` has a
            zero-length time dimension and the initial states are returned
            unchanged.
        """
        bs, seq_size, _ = x.size()

        if init_states is None:
            # new_zeros inherits dtype and device from x, so the module keeps
            # working after .double()/.half()/.to(device).
            h_t = x.new_zeros(bs, self.hidden_size)
            c_t = x.new_zeros(bs, self.hidden_size)
        else:
            h_t, c_t = init_states

        if seq_size == 0:
            # Nothing to iterate over; stacking an empty list would raise.
            return h_t.new_empty((bs, 0, self.hidden_size)), (h_t, c_t)

        hidden_seq = []
        for t in range(seq_size):
            x_t = x[:, t, :]

            f_t = torch.sigmoid(x_t @ self.U_f + h_t @ self.V_f + self.b_f)
            i_t = torch.sigmoid(x_t @ self.U_i + h_t @ self.V_i + self.b_i)
            o_t = torch.sigmoid(x_t @ self.U_o + h_t @ self.V_o + self.b_o)
            g_t = torch.tanh(x_t @ self.U_g + h_t @ self.V_g + self.b_g)
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)

            hidden_seq.append(h_t)

        # Stack along dim=1 to get a contiguous (batch, seq, hidden) tensor.
        hidden_seq = torch.stack(hidden_seq, dim=1)

        return hidden_seq, (h_t, c_t)
    





