#### LSTM Cell implementation

In [55]:
import torch
from torch import nn

def make_gate(features, num_units=1, bias=True):
    """Build a gate: a linear projection followed by a sigmoid.

    Args:
        features: Size of the input vector fed to the gate.
        num_units: Number of gate outputs.
        bias: Whether the linear projection has a bias term.

    Returns:
        ``nn.Sequential(Linear, Sigmoid)``; its outputs lie in (0, 1).
    """
    unit = nn.Sequential(
        nn.Linear(in_features=features, out_features=num_units, bias=bias),
        nn.Sigmoid())
    return unit


class LSTMCell(nn.Module):
    """A single LSTM unit with a one-dimensional cell state.

    The cell receives the current input concatenated with the previous hidden
    vector of the whole layer, so its input size is
    ``input_features + layer_size``. The cell state is stored on the module
    and carried across ``forward`` calls; call :meth:`set_batch_size` to reset
    it before each new sequence.

    Args:
        input_features: Number of features in each time step's input.
        layer_size: Size of the recurrent hidden vector concatenated to the input.
    """

    def __init__(self, input_features, layer_size):
        super().__init__()
        self.total_inputs = input_features + layer_size

        self.forget_gate = make_gate(self.total_inputs)
        self.remember_gate = make_gate(self.total_inputs)
        self.output_gate = make_gate(self.total_inputs)
        self.output_activation = nn.Tanh()

        # Candidate cell value. The standard LSTM squashes it with tanh, keeping
        # each step's contribution to the state in (-1, 1); an unbounded ReLU
        # candidate lets the cell state grow without limit.
        self.input_unit = nn.Sequential(
            nn.Linear(in_features=self.total_inputs, out_features=1),
            nn.Tanh()
        )

        # Placeholder until set_batch_size() is called; shape (1,) broadcasts
        # against a (batch, 1) gate output.
        self.state = torch.zeros(1)

    def set_batch_size(self, size):
        """Reset the cell state to zeros of shape ``(size, 1)``.

        The state is created on the same device and with the same dtype as
        the cell's parameters, so the cell keeps working after
        ``.to(device)`` or ``.double()``.
        """
        weight = self.input_unit[0].weight
        self.state = torch.zeros((size, 1), device=weight.device, dtype=weight.dtype)

    def forward(self, x):
        """Advance the cell by one time step.

        Args:
            x: Tensor whose last dimension is ``total_inputs``.

        Returns:
            The cell's hidden output, ``tanh(state) * output_gate``, with a
            trailing dimension of 1.
        """
        input_act = self.input_unit(x)
        remember_act = self.remember_gate(x)
        forget_act = self.forget_gate(x)
        output_gate_act = self.output_gate(x)

        self.state = (self.state * forget_act) + (input_act * remember_act)
        return self.output_activation(self.state) * output_gate_act


class LSTMLayer(nn.Module):
    """An LSTM layer built from ``layer_size`` independent single-unit LSTMCells.

    At each time step every cell receives the current input concatenated with
    the full hidden vector from the previous step and produces one unit of the
    new hidden vector.

    Args:
        input_features: Number of features per time step.
        layer_size: Number of cells, i.e. the hidden size.
    """

    def __init__(self, input_features, layer_size):
        super().__init__()
        self.layer_size = layer_size
        self.input_features = input_features
        self.cells = nn.ModuleList(
            [LSTMCell(input_features, layer_size) for _ in range(layer_size)])

    def forward(self, x_batch):
        """Run the layer over a sequence and return the final hidden vector.

        Args:
            x_batch: Time-major tensor ``(seq_len, batch, input_features)``;
                the loop iterates over the first dimension.

        Returns:
            Tensor ``(batch, layer_size)`` holding the hidden vector after the
            last time step (all zeros for an empty sequence).
        """
        batch_size = x_batch.shape[1]
        for cell in self.cells:
            cell.set_batch_size(batch_size)
            # Put each cell's fresh state on the input's device/dtype so the
            # layer works after .to(device) / .double().
            cell.state = cell.state.to(x_batch)
        # Same device/dtype as the input (matches LSTMLayerMatrix.forward).
        last_h = x_batch.new_zeros((batch_size, self.layer_size))
        for time_step in x_batch:
            cell_input = torch.cat([time_step, last_h], dim=1)
            last_h = torch.cat([cell(cell_input) for cell in self.cells], dim=1)
        return last_h


# Example: a per-cell LSTM layer with 3 input features and 4 hidden units.
lstm = LSTMLayer(3, 4)



In [59]:
import torch
from torch import nn

def make_gate(features, num_units=1, bias=True):
    """Build a gate: a linear projection followed by a sigmoid.

    Args:
        features: Size of the input vector fed to the gate.
        num_units: Number of gate outputs (one per hidden unit here).
        bias: Whether the linear projection has a bias term.

    Returns:
        ``nn.Sequential(Linear, Sigmoid)``; its outputs lie in (0, 1).
    """
    linear = nn.Linear(in_features=features, out_features=num_units, bias=bias)
    return nn.Sequential(linear, nn.Sigmoid())


class LSTMLayerMatrix(nn.Module):
    """Vectorised LSTM layer: each gate covers all hidden units in one Linear.

    Args:
        in_features: Number of features per time step.
        num_hidden: Hidden size.
        batch_size: Only shapes the placeholder ``state`` created here;
            ``forward`` starts every sequence from a fresh zero state sized to
            the batch it receives.
    """

    def __init__(self, in_features, num_hidden, batch_size):
        super().__init__()
        self.total_inputs = in_features + num_hidden
        self.num_hidden = num_hidden
        self.forget_gate = make_gate(self.total_inputs, num_units=num_hidden)
        self.remember_gate = make_gate(self.total_inputs, num_units=num_hidden)
        self.output_gate = make_gate(self.total_inputs, num_units=num_hidden)
        self.output_activation = nn.Tanh()
        self.input_unit = nn.Linear(in_features=self.total_inputs, out_features=num_hidden)
        # Cell state after the most recent forward call (placeholder until then).
        self.state = torch.zeros((batch_size, num_hidden))

    def forward(self, x_batch):
        """Run the layer over a sequence and return the final hidden vector.

        Args:
            x_batch: Time-major tensor ``(seq_len, batch, in_features)``; the
                loop iterates over the first dimension.

        Returns:
            Tensor ``(batch, num_hidden)`` with the hidden vector after the last
            time step (all zeros for an empty sequence). The final cell state is
            also stored in ``self.state``.
        """
        batch_size = x_batch.shape[1]
        last_h = x_batch.new_zeros((batch_size, self.num_hidden))
        # Start every sequence from a zero cell state, like last_h. Reusing a
        # stored state would leak information between batches, keep the previous
        # call's autograd graph (breaking a second backward()), and fail for
        # batch sizes other than the constructor's.
        state = x_batch.new_zeros((batch_size, self.num_hidden))
        for time_step in x_batch:
            cell_input = torch.cat([time_step, last_h], dim=1)

            forget = self.forget_gate(cell_input)
            remember = self.remember_gate(cell_input)
            should_output = self.output_gate(cell_input)

            input_activation = torch.tanh(self.input_unit(cell_input))
            state = (forget * state) + (remember * input_activation)
            last_h = should_output * self.output_activation(state)
        self.state = state
        return last_h


In [60]:
# Random time-major batch: (sequence length, batch, features).
seq, batch0, num_feat = 23, 16, 3

mini_batch = torch.rand(seq, batch0, num_feat)

lstm = LSTMLayerMatrix(in_features=num_feat, num_hidden=5, batch_size=batch0)

lstm(mini_batch).shape

torch.Size([16, 5])

In [57]:
# Per-cell implementation with the same sizes (3 features -> 5 hidden units).
# NOTE(review): mini_batch is created in the In [60] cell; this In [57] cell
# only works if that cell was run first.
lstm2 = LSTMLayer(3, 5)
lstm2(mini_batch).shape

torch.Size([16, 5])