In [70]:
import torch
import torch.nn as nn

In [71]:
class LSTMCell_v1(nn.Module):
    """Single-step LSTM cell with one pair of ``nn.Linear`` layers per gate.

    Readable reference implementation: each gate projects the input ``x`` and
    the previous hidden state ``h`` separately and sums the two projections
    (each ``nn.Linear`` carries its own bias).

    Args:
        input_size: size of the last dimension of ``x``.
        hidden_size: size of the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Plain ints: not parameters, so they do not change the state_dict or
        # the order in which the Linear layers below are initialised.
        self.input_size = input_size
        self.hidden_size = hidden_size

        # forget gate
        self.W_if = nn.Linear(input_size, hidden_size)
        self.W_hf = nn.Linear(hidden_size, hidden_size)

        # input gate (i) and candidate cell values ("gate gate", g)
        self.W_ii = nn.Linear(input_size, hidden_size)
        self.W_hi = nn.Linear(hidden_size, hidden_size)
        self.W_ig = nn.Linear(input_size, hidden_size)
        self.W_hg = nn.Linear(hidden_size, hidden_size)

        # output gate
        self.W_io = nn.Linear(input_size, hidden_size)
        self.W_ho = nn.Linear(hidden_size, hidden_size)

    def forget_gate(self, x, h):
        """f = sigmoid(W_if x + W_hf h): how much of the previous cell state is kept."""
        return torch.sigmoid(self.W_if(x) + self.W_hf(h))

    def input_gate(self, x, h):
        """i = sigmoid(W_ii x + W_hi h): how much of the candidate ``g`` is written."""
        return torch.sigmoid(self.W_ii(x) + self.W_hi(h))

    def gate_gate(self, x, h):
        """g = tanh(W_ig x + W_hg h): candidate values for the new cell state."""
        return torch.tanh(self.W_ig(x) + self.W_hg(h))

    def output_gate(self, x, h):
        """o = sigmoid(W_io x + W_ho h): how much of tanh(c) is exposed as ``h``."""
        return torch.sigmoid(self.W_io(x) + self.W_ho(h))

    def forward(self, x, hx=None):
        """Advance the cell by one time step.

        Args:
            x: input tensor whose last dimension is ``input_size``.
            hx: optional pair where ``hx[0]`` is the previous hidden state (h0)
                and ``hx[1]`` the previous cell state (c0), each with last
                dimension ``hidden_size``. When omitted, both default to zeros
                with the leading shape of ``x`` (as ``torch.nn.LSTMCell`` does).

        Returns:
            ``(h, c)``: the new hidden state and the new cell state.
        """
        if hx is None:
            zeros = x.new_zeros((*x.shape[:-1], self.hidden_size))
            hx = (zeros, zeros)
        h_prev, c_prev = hx[0], hx[1]

        f = self.forget_gate(x, h_prev)
        i = self.input_gate(x, h_prev)
        g = self.gate_gate(x, h_prev)
        o = self.output_gate(x, h_prev)

        # Forget part of the old memory, then add the gated candidate.
        c = f * c_prev + i * g
        h = o * torch.tanh(c)
        return h, c

In [72]:
# Fix the seed so the randomly initialised weights -- and therefore the
# printed states below -- are reproducible on Restart & Run All.
torch.manual_seed(0)

cell = LSTMCell_v1(3, 2)

# Named `x` rather than `input`, which would shadow the Python builtin.
x = torch.tensor([1.0, 2.0, 3.0])

h0 = torch.zeros(2)  # previous hidden state
c0 = torch.zeros(2)  # previous cell state

output = cell(x, (h0, c0))  # (new hidden state, new cell state)
assert output[0].shape == output[1].shape == (2,)

output

(tensor([0.0026, 0.4531], grad_fn=<MulBackward0>),
 tensor([0.0057, 0.7365], grad_fn=<AddBackward0>))

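# LSTM Cell with Fused (Efficient) Linear Layers

Instead of eight separate `nn.Linear` layers, project `x` and `h` once each into a `4 * hidden_size` vector, then split the sum into the four gate pre-activations. The math matches `LSTMCell_v1`, but each step runs two matrix multiplies instead of eight.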

In [73]:
class LSTMCell(nn.Module):
    """LSTM cell that fuses the four gate projections into two ``nn.Linear`` layers.

    ``W_x`` maps the input and ``W_h`` maps the previous hidden state to
    ``4 * hidden_size`` features. Their sum is split along the last dimension
    into the forget (f), input (i), candidate (g) and output (o)
    pre-activations, in that order.

    NOTE: ``torch.nn.LSTMCell`` stacks its gates as (i, f, g, o), so its
    weights cannot be copied into ``W_x``/``W_h`` without reordering.

    Args:
        input_size: size of the last dimension of ``x``.
        hidden_size: size of the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # 4 * hidden_size: one slice each for the forget, input, gate and output gates
        self.W_x = nn.Linear(input_size, 4 * hidden_size)
        self.W_h = nn.Linear(hidden_size, 4 * hidden_size)

    def forward(self, x, hx=None):
        """Advance the cell by one time step.

        Args:
            x: input tensor whose last dimension is ``input_size``.
            hx: optional pair where ``hx[0]`` is the previous hidden state (h0)
                and ``hx[1]`` the previous cell state (c0), each with last
                dimension ``hidden_size``. When omitted, both default to zeros
                with the leading shape of ``x`` (as ``torch.nn.LSTMCell`` does).

        Returns:
            ``(h, c)``: the new hidden state and the new cell state.

        Raises:
            AssertionError: if ``x``, ``h0`` or ``c0`` has the wrong last dimension.
        """
        if hx is None:
            zeros = x.new_zeros((*x.shape[:-1], self.hidden_size))
            hx = (zeros, zeros)
        h_prev, c_prev = hx[0], hx[1]

        assert x.shape[-1] == self.input_size, "Input size mismatch"
        assert h_prev.shape[-1] == self.hidden_size, "Hidden state size mismatch"
        # Without this check a wrong-sized c0 (e.g. shape (1,)) would silently
        # broadcast in `f * c_prev` instead of failing.
        assert c_prev.shape[-1] == self.hidden_size, "Cell state size mismatch"

        gates = self.W_x(x) + self.W_h(h_prev)
        f_pre, i_pre, g_pre, o_pre = torch.chunk(gates, 4, dim=-1)
        f = torch.sigmoid(f_pre)
        i = torch.sigmoid(i_pre)
        g = torch.tanh(g_pre)
        o = torch.sigmoid(o_pre)

        c = f * c_prev + i * g
        h = o * torch.tanh(c)
        return h, c

In [76]:
# Fix the seed so the randomly initialised weights -- and therefore the
# printed states below -- are reproducible on Restart & Run All.
torch.manual_seed(0)

cell = LSTMCell(5, 3)

# Named `x` rather than `input`, which would shadow the Python builtin.
x = torch.tensor([1., 2., 3., 4., 5.])

h0 = torch.zeros(3)  # previous hidden state
c0 = torch.zeros(3)  # previous cell state

output = cell(x, (h0, c0))  # (new hidden state, new cell state)
assert output[0].shape == output[1].shape == (3,)

output

(tensor([-0.1013,  0.2584,  0.3161], grad_fn=<MulBackward0>),
 tensor([-0.8177,  0.2755,  0.3938], grad_fn=<AddBackward0>))