In [83]:
import torch
import torch.nn as nn
import numpy as np
torch.manual_seed(1)

# We have 5 items with length of 3, batchsize=1
x = torch.randn(1, 5, 3)
rnn = nn.LSTM(3, 3, 1, batch_first=True)

w_ih, w_hh, b_ih, b_hh = list(rnn.parameters())
#print(w_ih)  # W_ii, W_if, W_ig, W_io
#print(w_hh)  # W_hi, W_hf, W_hg, W_ho
#print(b_ih)  # b_ii, b_if, b_ig, b_io
#print(b_hh)  # b_hi, b_hf, b_hg, b_ho
out, (hn, cn) = rnn(x)
print(out)
print(hn)
print(cn)

tensor([[[ 0.0586,  0.0328,  0.1944],
         [ 0.1246, -0.0483,  0.3200],
         [ 0.0925, -0.0567,  0.2099],
         [ 0.2293, -0.0622,  0.3411],
         [ 0.2669, -0.0550,  0.4358]]], grad_fn=<TransposeBackward0>)
tensor([[[ 0.2669, -0.0550,  0.4358]]], grad_fn=<StackBackward0>)
tensor([[[ 0.7132, -0.2131,  0.9159]]], grad_fn=<StackBackward0>)


In [85]:
import leaf

# forget gate
# f = sigmoid(Uf*x + Vf*h + bf)

# input gate
# i = sigmoid(Ui*x + Vi*h + bi)

# output gate
# o = sigmoid(Uo*x + Vo*h + bo)

# candidate C
# g = tanh(Ug*x + Vg*h + bg)

# c_t = f * c_t-1 + i * g

# or optimized
# A = U*x + V*h + b
# 
# https://towardsdatascience.com/building-a-lstm-by-hand-on-pytorch-59c02a4ec091

class MyLSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(MyLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.U = nn.Parameter(w_ih.detach())
        self.V = nn.Parameter(w_hh.detach())
        self.bih = nn.Parameter(b_ih.detach())
        self.bhh = nn.Parameter(b_hh.detach())
    
    def forward(self, x):
        bs, seq_size, _ = x.size()
        hidden_seq = []
        
        h_t, c_t = (torch.zeros(bs, self.hidden_size),
                        torch.zeros(bs, self.hidden_size))
        
        HS = self.hidden_size
        for t in range(seq_size):
            x_t = x[:, t, :]
            gates = x_t @ self.U.T + h_t @ self.V.T + self.bih + self.bhh
            
            i_t, f_t, g_t, o_t = torch.split(gates, 3, dim=1)
            i_t = torch.sigmoid(i_t)
            f_t = torch.sigmoid(f_t)
            g_t = torch.tanh(g_t)
            o_t = torch.sigmoid(o_t)
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t.unsqueeze(0))
        hidden_seq = torch.cat(hidden_seq, dim=0)
        return hidden_seq, (h_t, c_t)

model = MyLSTM(3, 3)
model(x)
            
                

(tensor([[[[ 0.0586,  0.0328,  0.1944]]],
 
 
         [[[ 0.1246, -0.0483,  0.3200]]],
 
 
         [[[ 0.0925, -0.0567,  0.2099]]],
 
 
         [[[ 0.2293, -0.0622,  0.3411]]],
 
 
         [[[ 0.2669, -0.0550,  0.4358]]]], grad_fn=<StackBackward0>),
 (tensor([[ 0.2669, -0.0550,  0.4358]], grad_fn=<MulBackward0>),
  tensor([[ 0.7132, -0.2131,  0.9159]], grad_fn=<AddBackward0>)))