In [0]:
import torch
from torch import nn
import torch.nn.functional as F

In [0]:
class LSTMCell(nn.Module):
  """One LSTM layer that processes a single time step per call.

  The recurrent state is kept on the module itself (`self.hidden` and
  `self.cell`), so consecutive calls to `forward` continue the same sequence.
  Call `reset()` before starting a new sequence.

  Args:
    input_depth: size of the feature dimension of each input token.
    hidden_depth: size of the hidden and cell state vectors.
  """

  def __init__(self, input_depth, hidden_depth):
    super().__init__()

    self.input_depth = input_depth
    self.hidden_depth = hidden_depth

    # Every gate reads the token concatenated with the previous hidden state.
    concat_depth = input_depth + hidden_depth

    # Gates squash to (0, 1) with a sigmoid; the candidate cell uses tanh.
    self.forget = nn.Sequential(nn.Linear(concat_depth, hidden_depth), nn.Sigmoid())
    self.input = nn.Sequential(nn.Linear(concat_depth, hidden_depth), nn.Sigmoid())

    self.new_cell = nn.Sequential(nn.Linear(concat_depth, hidden_depth), nn.Tanh())
    self.output = nn.Sequential(nn.Linear(concat_depth, hidden_depth), nn.Sigmoid())

    # Recurrent state; created lazily on the first forward call because the
    # batch size is only known once a token arrives.
    self.cell = None
    self.hidden = None

  def reset(self):
    """Drop the stored hidden and cell state so the next call starts fresh."""
    self.cell = None
    self.hidden = None

  def forward(self, token):
    """Advance the cell by one time step.

    Args:
      token: tensor of shape (batch_size, input_depth).

    Returns:
      Tuple `(hidden, cell)`, each of shape (batch_size, hidden_depth). The
      same tensors are stored on the module for the next step.
    """
    batch_size = token.shape[0]

    # new_zeros inherits device and dtype from the token, so the initial state
    # lives wherever the input (and hence the model) lives instead of always
    # being a CPU / default-dtype tensor.
    if self.hidden is None:
      self.hidden = token.new_zeros(batch_size, self.hidden_depth)

    if self.cell is None:
      self.cell = token.new_zeros(batch_size, self.hidden_depth)

    concat = torch.cat([token, self.hidden], dim=1)

    forget = self.forget(concat)
    inpt = self.input(concat)
    new_cell = self.new_cell(concat)

    # Keep the gated part of the old cell state and add the gated candidate.
    self.cell = self.cell * forget + new_cell * inpt

    output = self.output(concat)

    self.hidden = torch.tanh(self.cell) * output

    return self.hidden, self.cell



In [0]:
class LSTM(nn.Module):
  """Stack of `LSTMCell` layers unrolled over an input sequence.

  Args:
    input_depth: feature size of each input token.
    hidden_depth: hidden/cell state size of every layer.
    num_layers: number of stacked cells.
  """

  def __init__(self, input_depth, hidden_depth, num_layers):
    super().__init__()

    self.input_depth = input_depth
    self.hidden_depth = hidden_depth
    self.num_layers = num_layers

    # nn.ModuleList (rather than a plain list) registers the cells as
    # submodules, so their weights show up in parameters()/state_dict() and
    # follow .to()/.cuda().
    # Only the first layer sees the raw input; each deeper layer consumes the
    # hidden state of the layer below, whose size is hidden_depth.
    self.layers = nn.ModuleList([
      LSTMCell(input_depth if index == 0 else hidden_depth, hidden_depth)
      for index in range(num_layers)
    ])

  def reset(self):
    """Clear the recurrent state of every layer."""
    for layer in self.layers:
      layer.reset()

  def forward(self, inp):
    """Run the whole sequence through the stack, starting from a zero state.

    Args:
      inp: tensor of shape (sequence_length, batch_size, input_depth).

    Returns:
      A list with one tensor per time step, each holding the top layer's
      hidden state with shape (batch_size, hidden_depth).
    """
    # Every call is treated as a new sequence.
    self.reset()

    hiddens = []

    for token in inp:
      carry = token
      # Feed each layer's hidden state to the layer above it.
      for layer in self.layers:
        carry, _ = layer(carry)

      hiddens.append(carry)

    return hiddens


In [0]:
# Two stacked cells; input tokens and hidden states both have 10 features.
l = LSTM(input_depth=10, hidden_depth=10, num_layers=2)

In [0]:
# Random input of shape (sequence_length=3, batch_size=2, input_depth=10);
# the call returns one hidden-state tensor per time step.
l(torch.randn((3, 2, 10)))

[tensor([[-0.0193, -0.0183, -0.0331, -0.0252,  0.0037,  0.0604,  0.0005,  0.0384,
           0.0047,  0.0080],
         [-0.0284, -0.0091, -0.0374, -0.0258,  0.0147,  0.0492, -0.0038,  0.0198,
           0.0003,  0.0077]], grad_fn=<MulBackward0>),
 tensor([[-5.0068e-02, -4.2120e-02, -3.7415e-02, -2.2011e-02,  3.0578e-02,
           6.6465e-02, -3.6155e-03,  3.3229e-02,  9.7587e-03, -6.9730e-05],
         [-3.8661e-02, -1.3510e-02, -2.3087e-02, -2.8316e-02,  1.8454e-02,
           4.9066e-02,  8.1171e-03,  2.3087e-02, -1.3406e-03,  2.1450e-02]],
        grad_fn=<MulBackward0>),
 tensor([[-0.0304, -0.0007, -0.0285, -0.0291, -0.0003,  0.0617, -0.0069,  0.0173,
          -0.0024, -0.0054],
         [-0.0303, -0.0411, -0.0392, -0.0303,  0.0118,  0.0637,  0.0016,  0.0330,
           0.0088,  0.0100]], grad_fn=<MulBackward0>)]