In [1]:
import torch
import torch.nn as nn

In [2]:
class LSTMCell_v1(nn.Module):
    """One LSTM time step written with a separate pair of linear layers per gate.

    Every gate owns an input projection (``W_i*``) and a hidden-state
    projection (``W_h*``); the two are summed before the gate's non-linearity.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Forget gate: how much of the previous cell state is kept.
        self.W_if = nn.Linear(input_size, hidden_size)
        self.W_hf = nn.Linear(hidden_size, hidden_size)

        # Input gate (W_ii / W_hi) and candidate values (W_ig / W_hg).
        self.W_ii = nn.Linear(input_size, hidden_size)
        self.W_hi = nn.Linear(hidden_size, hidden_size)
        self.W_ig = nn.Linear(input_size, hidden_size)
        self.W_hg = nn.Linear(hidden_size, hidden_size)

        # Output gate: how much of the squashed cell state becomes h.
        self.W_io = nn.Linear(input_size, hidden_size)
        self.W_ho = nn.Linear(hidden_size, hidden_size)

    def forget_gate(self, x, h):
        """Return sigmoid(W_if(x) + W_hf(h))."""
        return torch.sigmoid(self.W_if(x) + self.W_hf(h))

    def input_gate(self, x, h):
        """Return sigmoid(W_ii(x) + W_hi(h))."""
        return torch.sigmoid(self.W_ii(x) + self.W_hi(h))

    def gate_gate(self, x, h):
        """Return the candidate values tanh(W_ig(x) + W_hg(h))."""
        return torch.tanh(self.W_ig(x) + self.W_hg(h))

    def output_gate(self, x, h):
        """Return sigmoid(W_io(x) + W_ho(h))."""
        return torch.sigmoid(self.W_io(x) + self.W_ho(h))

    def forward(self, x, hx):
        """Advance the cell by one step.

        Args:
            x: input for this step; its last dimension is ``input_size``.
            hx: indexable pair where ``hx[0]`` is the previous hidden state
                and ``hx[1]`` is the previous cell state.

        Returns:
            ``(h, c)``: the new hidden state and the new cell state.
        """
        h_prev = hx[0]

        keep = self.forget_gate(x, h_prev)
        write = self.input_gate(x, h_prev)
        candidate = self.gate_gate(x, h_prev)

        # New cell state: retained old memory plus gated candidate values.
        c = keep * hx[1] + write * candidate

        expose = self.output_gate(x, h_prev)
        h = expose * torch.tanh(c)

        return h, c

In [3]:
# Smoke test: run a single step of LSTMCell_v1 (input_size=3, hidden_size=2)
# on one unbatched input vector.
cell = LSTMCell_v1(3, 2)

# NOTE(review): the name `input` shadows the Python builtin of the same name.
input = torch.tensor([1.0, 2.0, 3.0])

# Zero initial hidden state (h0) and cell state (c0), one value per hidden unit.
h0 = torch.zeros(2)
c0 = torch.zeros(2)

# The cell returns the pair (h, c).
output = cell(input, (h0, c0))

# No seed is set in the visible cells, so the displayed numbers differ between runs.
output

(tensor([-0.0136,  0.0510], grad_fn=<MulBackward0>),
 tensor([-0.0877,  0.0929], grad_fn=<AddBackward0>))

# LSTM Cell with Fused (Efficient) Linear Layers

In [4]:
class LSTMCell(nn.Module):
    """One LSTM time step using two fused linear layers.

    ``W_x`` and ``W_h`` each emit ``4 * hidden_size`` features; their sum is
    split along the last dimension into the forget, input, candidate ("gate")
    and output pre-activations, in that order.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        fused_width = 4 * hidden_size  # forget, input, candidate and output stacked together
        self.W_x = nn.Linear(input_size, fused_width)
        self.W_h = nn.Linear(hidden_size, fused_width)

    def forward(self, x, hx):
        """Advance the cell by one step.

        Args:
            x: input for this step; its last dimension must equal
                ``W_x.in_features``.
            hx: indexable pair where ``hx[0]`` is the previous hidden state
                (last dimension ``W_x.out_features // 4``) and ``hx[1]`` is
                the previous cell state.

        Returns:
            ``(h, c)``: the new hidden state and the new cell state.

        Raises:
            AssertionError: if the trailing dimension of ``x`` or ``hx[0]``
                does not match the layer sizes.
        """
        assert x.shape[-1] == self.W_x.in_features, "Input size mismatch"
        h_prev = hx[0]
        assert h_prev.shape[-1] == self.W_x.out_features // 4, "Output size mismatch"

        # One matmul per source instead of one per gate, then slice per gate.
        pre_f, pre_i, pre_g, pre_o = (self.W_x(x) + self.W_h(h_prev)).chunk(4, dim=-1)

        keep = torch.sigmoid(pre_f)
        write = torch.sigmoid(pre_i)
        candidate = torch.tanh(pre_g)
        expose = torch.sigmoid(pre_o)

        c = keep * hx[1] + write * candidate
        h = expose * torch.tanh(c)
        return h, c

In [5]:
# Smoke test: run a single step of the fused LSTMCell (input_size=5, hidden_size=3)
# on one unbatched input vector.
cell = LSTMCell(5, 3)

# NOTE(review): the name `input` shadows the Python builtin of the same name.
input = torch.tensor([1., 2., 3., 4., 5.,])

# Zero initial hidden state (h0) and cell state (c0), one value per hidden unit.
h0 = torch.zeros(3)
c0 = torch.zeros(3)

# The cell returns the pair (h, c).
output = cell(input, (h0, c0))

# No seed is set in the visible cells, so the displayed numbers differ between runs.
output

(tensor([0.1761, 0.0148, 0.4325], grad_fn=<MulBackward0>),
 tensor([0.7827, 0.1725, 0.5500], grad_fn=<AddBackward0>))

In [6]:
# Draw a random 3-D tensor of shape (3, 4, 5) and display it.
# Unseeded, so the values change on every run.
input = torch.randn(3, 4, 5)
input

tensor([[[ 0.1195, -0.2427, -0.9611,  0.5320, -2.3102],
         [ 0.4877,  1.0431,  0.8586, -0.3078,  2.6141],
         [ 0.2589, -0.9733, -1.6025,  1.4277,  0.6385],
         [-0.7495,  0.5077, -0.1007, -0.6688,  0.1554]],

        [[ 1.0665,  0.2208,  1.8834, -1.1158, -0.2876],
         [ 1.5418, -1.6930, -1.5466, -0.6099, -0.0300],
         [ 0.0327, -0.7160,  1.0028,  0.9985,  0.7949],
         [-2.0103,  0.2070, -0.9969, -1.3104, -0.7697]],

        [[-0.2051,  0.0206, -0.3492,  0.2328,  1.5664],
         [ 0.5154,  0.4913, -0.6668, -0.8935, -0.5981],
         [ 0.7033, -0.9606,  1.4943,  0.4578, -0.7708],
         [ 0.6048, -0.2460,  1.2659,  0.9342, -0.8268]]])

In [7]:
# Run one step of the fused LSTMCell on a batch.
cell = LSTMCell(5, 3)

# Shape (3, 4, 5): presumably (batch=3, sequence_length=4, feature/embedding size=5).
# The last dimension matches the cell's input_size of 5.
# input = torch.tensor([[1., 2., 3., 4., 5.,], [1., 2., 3., 4., 5.,], [1., 2., 3., 4., 5.,]])
input = torch.randn(3, 4, 5)

# The initial states are 1-D (hidden_size,). Inside LSTMCell they are added to and
# multiplied with batched tensors, so they broadcast across the batch dimension.
h0 = torch.zeros(3)
c0 = torch.zeros(3)

# input[:, 0] takes index 0 along dim 1 for every batch row -> shape (3, 5).
ht, ct = cell(input[:,0], (h0, c0))

# One hidden vector per batch row: shape (3, 3).
ht

tensor([[-0.1999, -0.0150, -0.1023],
        [-0.1042, -0.0923, -0.0475],
        [-0.0583,  0.0196,  0.0155]], grad_fn=<MulBackward0>)

In [8]:
# Select the last hidden unit of `ht` for every batch row (one value per row).
ht[:, -1]

tensor([-0.1023, -0.0475,  0.0155], grad_fn=<SelectBackward0>)

# LSTM

In [94]:
class LSTM_v1(nn.Module):
    """Multi-layer, optionally bidirectional LSTM unrolled with ``LSTMCell``.

    Input is batch-first: ``x`` has shape ``(batch, seq_length, input_size)``.

    Differences from ``torch.nn.LSTM`` that callers should know about:

    * The forward and backward directions are two independent stacks: layer
      ``i > 0`` of a direction only sees the outputs of layer ``i - 1`` of the
      *same* direction (each cell above the first takes ``hidden_size`` inputs).
    * ``h_n`` / ``c_n`` (and a user supplied ``h0``) are ordered as all forward
      layers first, then all backward layers.

    Args:
        input_size: number of features in each input time step.
        hidden_size: number of features in the hidden and cell states.
        num_layers: number of stacked layers per direction.
        dropout: dropout probability applied to the output of every layer
            except the last one, in training mode only.
        bidirectional: if true, a second stack processes the reversed sequence.
        device: if given, the module's parameters are moved to this device.

    Attributes:
        num_layers: ``num_layers * num_directions``, i.e. the size of the
            leading dimension of ``h_n`` and ``c_n``.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0, bidirectional=False, device=None):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.device = device
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        self.layers = torch.nn.ModuleList()
        if self.bidirectional:
            self.backward_layers = torch.nn.ModuleList()

        for i in range(num_layers):
            # Only the first layer consumes raw inputs; deeper layers consume hidden states.
            in_features = self.input_size if i == 0 else self.hidden_size
            self.layers.append(LSTMCell(in_features, self.hidden_size))
            if self.bidirectional:
                self.backward_layers.append(LSTMCell(in_features, self.hidden_size))

        self.num_layers = num_layers * self.num_directions

        # Honour the `device` argument instead of silently ignoring it.
        if device is not None:
            self.to(device)

    def _unroll(self, cell, seq, ht, ct):
        """Run ``cell`` over dim 1 of ``seq``; return (all hidden states, final h, final c)."""
        steps = []
        for j in range(seq.shape[1]):
            ht, ct = cell(seq[:, j], (ht, ct))
            steps.append(ht)
        if steps:
            # Single stack instead of repeated torch.cat onto a growing tensor.
            stacked = torch.stack(steps, dim=1)
        else:
            # Empty sequence: keep a well-formed (batch, 0, hidden) tensor.
            stacked = seq.new_zeros(seq.shape[0], 0, self.hidden_size)
        return stacked, ht, ct

    def _between_layers(self, layer_out):
        """Apply inter-layer dropout (a no-op when ``dropout == 0`` or in eval mode)."""
        if self.dropout > 0:
            return nn.functional.dropout(layer_out, p=self.dropout, training=self.training)
        return layer_out

    def forward(self, x: torch.Tensor, h0: "tuple[torch.Tensor, torch.Tensor] | None" = None):
        """Run the whole sequence through every layer.

        Args:
            x: tensor of shape ``(batch, seq_length, input_size)``.
            h0: optional ``(h_0, c_0)`` pair, each of shape
                ``(num_layers * num_directions, batch, hidden_size)`` with the
                forward layers first. Defaults to zeros on ``x``'s device/dtype.

        Returns:
            ``(output, h_n, c_n)`` where ``output`` has shape
            ``(batch, seq_length, hidden_size * num_directions)`` and holds the
            last layer's hidden state for every time step (forward features
            first, backward features second, both aligned to the original time
            index), and ``h_n`` / ``c_n`` have shape
            ``(num_layers * num_directions, batch, hidden_size)``.

        Raises:
            AssertionError: on an input/hidden size mismatch or if the
                produced tensors do not have the documented shapes.
        """
        assert x.shape[-1] == self.input_size, "Input size mismatch"
        if h0 is not None:
            assert h0[0].shape[-1] == self.hidden_size, "Output size mismatch"

        batch_size = x.shape[0]
        seq_length = x.shape[1]
        depth = len(self.layers)  # layers per direction

        if h0 is None:
            # new_zeros keeps the states on the same device and dtype as the input.
            hx = x.new_zeros(self.num_layers, batch_size, self.hidden_size)
            cx = x.new_zeros(self.num_layers, batch_size, self.hidden_size)
        else:
            hx, cx = h0

        fwd_seq = x
        # The backward stack simply reads the sequence in reversed time order.
        bwd_seq = x.flip(dims=(1,)) if self.bidirectional else None

        h_fwd, c_fwd, h_bwd, c_bwd = [], [], [], []
        fwd_out = None
        bwd_out = None

        for i in range(depth):
            is_last = i == depth - 1

            fwd_out, ht, ct = self._unroll(self.layers[i], fwd_seq, hx[i], cx[i])
            h_fwd.append(ht)
            c_fwd.append(ct)
            if not is_last:
                # No detach here: gradients must reach the lower layers.
                fwd_seq = self._between_layers(fwd_out)

            if self.bidirectional:
                # Backward states live after the `depth` forward states.
                bwd_out, hb, cb = self._unroll(
                    self.backward_layers[i], bwd_seq, hx[depth + i], cx[depth + i]
                )
                h_bwd.append(hb)
                c_bwd.append(cb)
                if not is_last:
                    bwd_seq = self._between_layers(bwd_out)

        output = fwd_out
        if self.bidirectional:
            # bwd_out is in reversed time order; flip it back so that position t
            # of both halves refers to the same input time step.
            output = torch.cat([fwd_out, bwd_out.flip(dims=(1,))], dim=2)

        # Forward layers first, then backward layers, for both h_n and c_n.
        h_n = torch.stack(h_fwd + h_bwd, dim=0)
        c_n = torch.stack(c_fwd + c_bwd, dim=0)

        assert output.shape == (batch_size, seq_length, self.hidden_size * self.num_directions), "Output shape mismatch"
        assert h_n.shape == (depth * self.num_directions, batch_size, self.hidden_size), "Hidden state mismatch"
        assert c_n.shape == h_n.shape, "Cell state mismatch"
        return output, h_n, c_n

In [95]:
# Shape check for a 2-layer bidirectional LSTM_v1 on random, unseeded input.
batch_size = 2
input_size = 6
hidden_size = 4
num_layers = 2
seq_length = 3

lstm = LSTM_v1(input_size, hidden_size, num_layers=num_layers, bidirectional=True)
# Batch-first input: (batch_size, seq_length, input_size).
inputs = torch.randn(batch_size, seq_length, input_size)
# hx = torch.zeros(4)
# cx = torch.zeros(4)

# output, h_n, c_n = lstm(inputs, (hx, cx))
# No initial state is passed, so LSTM_v1 builds a zero state itself.
output, h_n, c_n = lstm(inputs)
# Only the shapes are inspected here; c_n is not displayed.
output.shape, h_n.shape

(torch.Size([2, 3, 8]), torch.Size([4, 2, 4]))

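- The `output` shape should be `(batch_size, seq_length, hidden_size * D)` → `(2, 3, 4 * 2)`, which matches the result above.
- The hidden-state (`h_n`) shape should be `(num_layers * D, batch_size, hidden_size)` → `(2 * 2, 2, 4)`, which also matches. Here `D` is the number of directions (2 when bidirectional).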