In [1]:
import torch
import torch.nn as nn

In [2]:
class LSTMCell(nn.Module):
    """One LSTM time step computed from two fused linear projections.

    Args:
        ni: number of input features per time step.
        nh: number of hidden units (size of both h and c).

    Parameters are stored as ``ih`` (input -> gates) and ``hh``
    (hidden -> gates). Each maps to ``4 * nh`` pre-activations, one
    slice per gate.
    """

    def __init__(self, ni, nh):
        super().__init__()
        # Both projections emit all four gate pre-activations in one matmul.
        self.ih = nn.Linear(ni, 4 * nh)
        self.hh = nn.Linear(nh, 4 * nh)

    def forward(self, input, state):
        """Advance the recurrence by one step.

        Args:
            input: tensor whose last dim is ``ni``. Assumed 2-D
                (batch, ni), since the gates are split along dim 1.
            state: pair ``(h, c)`` of previous hidden and cell states.
                Assumed (batch, nh) each.

        Returns:
            ``(h_new, (h_new, c_new))``. The new hidden state is returned
            both as the step output and as part of the next state.
        """
        h_prev, c_prev = state
        # One big multiplication per source is cheaper than four per-gate ones;
        # the result is split into four equal slices in the order i, f, o, g.
        fused = self.ih(input) + self.hh(h_prev)
        pre_in, pre_forget, pre_out, pre_cell = fused.chunk(4, 1)

        in_gate = torch.sigmoid(pre_in)
        forget_gate = torch.sigmoid(pre_forget)
        out_gate = torch.sigmoid(pre_out)
        cell_candidate = torch.tanh(pre_cell)

        c_next = (forget_gate * c_prev) + (in_gate * cell_candidate)
        h_next = out_gate * torch.tanh(c_next)
        return h_next, (h_next, c_next)

In [16]:
class LSTMLayer(nn.Module):
    """Unroll a recurrent cell over every step of a sequence.

    Args:
        cell: a cell class (e.g. ``LSTMCell``), instantiated as
            ``cell(*cell_args)``.
        *cell_args: positional arguments forwarded to the cell constructor.
    """

    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    def forward(self, input, state):
        """Feed ``input`` to the cell one step at a time.

        Args:
            input: tensor whose dim 1 is the time/sequence dimension
                (e.g. (batch, seq_len, features)).
            state: initial recurrent state, passed to the cell unchanged
                for the first step and threaded through subsequent steps.

        Returns:
            ``(outputs, state)``. ``outputs`` is the per-step cell outputs
            stacked along dim 1; ``state`` is the state after the last step.

        Note:
            ``torch.stack`` needs at least one tensor, so a zero-length
            sequence dimension raises.
        """
        outputs = []
        # unbind(1) yields one tensor per time step, with dim 1 removed.
        for step_input in input.unbind(1):
            out, state = self.cell(step_input, state)
            outputs.append(out)
        return torch.stack(outputs, dim=1), state

In [17]:
# Seed the global RNG so both the nn.Linear weight init inside the layer and
# the random input below are reproducible across Restart & Run All.
torch.manual_seed(0)

# Toy problem sizes: 64 sequences of 70 steps, 300 input features, 300 hidden units.
BATCH_SIZE, SEQ_LEN, N_INPUT, N_HIDDEN = 64, 70, 300, 300

lstm = LSTMLayer(LSTMCell, N_INPUT, N_HIDDEN)

x = torch.randn(BATCH_SIZE, SEQ_LEN, N_INPUT)
# Initial (h, c) state: zeros, one row per sequence in the batch.
h = (torch.zeros(BATCH_SIZE, N_HIDDEN), torch.zeros(BATCH_SIZE, N_HIDDEN))

In [21]:
y, h1 = lstm(x, h)
# Invariants: the output keeps the input's (batch, seq_len) dims, and the
# final (h, c) have the same shapes as the initial state passed in.
assert y.shape[:2] == x.shape[:2]
assert all(s_new.shape == s_init.shape for s_new, s_init in zip(h1, h))
print(y.shape)
# Outputs for a single sequence (first batch element).
print(y[0].shape)

torch.Size([64, 70, 300])
torch.Size([70, 300])
