# Implementing an LSTM from scratch in PyTorch

Reference: https://towardsdatascience.com/building-a-lstm-by-hand-on-pytorch-59c02a4ec091

In [1]:
import math
import torch
import torch.nn as nn

### LSTM without vectorization (separate weights for each gate)

In [8]:
class NaiveCustomLSTM(nn.Module):
  """Single-layer LSTM with a separate weight matrix for every gate.

  This is the unvectorized reference version. Each of the four gates (input i,
  forget f, candidate c -- a.k.a. the "gate gate" g -- and output o) has its
  own input-to-hidden weight ``U_*`` of shape (input_sz, hidden_sz), its own
  hidden-to-hidden weight ``V_*`` of shape (hidden_sz, hidden_sz) and its own
  bias ``b_*`` of shape (hidden_sz,).
  """

  def __init__(self, input_sz: int, hidden_sz: int):
    super().__init__()
    self.input_size = input_sz
    self.hidden_size = hidden_sz

    # nn.Parameter is a Tensor subclass. Assigning one as a module attribute
    # registers it, so it appears in self.parameters() and is therefore
    # initialized by init_weights() and seen by optimizers.
    # torch.Tensor(*shape) allocates uninitialized memory; init_weights()
    # overwrites every value.

    # i_t: input gate
    self.U_i = nn.Parameter(torch.Tensor(self.input_size, self.hidden_size))
    self.V_i = nn.Parameter(torch.Tensor(self.hidden_size, self.hidden_size))
    self.b_i = nn.Parameter(torch.Tensor(hidden_sz))

    # f_t: forget gate
    self.U_f = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
    self.V_f = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
    self.b_f = nn.Parameter(torch.Tensor(hidden_sz))

    # c_t: candidate cell update ("gate gate" g_t in forward())
    self.U_c = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
    self.V_c = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
    self.b_c = nn.Parameter(torch.Tensor(hidden_sz))

    # o_t: output gate
    self.U_o = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
    self.V_o = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
    self.b_o = nn.Parameter(torch.Tensor(hidden_sz))

    self.init_weights()

  def init_weights(self):
    """Fill every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)).

    This is the same scheme torch.nn.LSTM uses by default
    (https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html). It is NOT
    Xavier/Glorot initialization, whose bound depends on fan_in + fan_out.
    """
    stdv = 1.0 / math.sqrt(self.hidden_size)
    for weight in self.parameters():
      weight.data.uniform_(-stdv, stdv)

  def forward(self, x, init_states=None):
    """Run the LSTM over every timestep of ``x``.

    Args:
      x: batch-first input of shape (bs, seq_sz, input_size).
      init_states: optional ``(h_0, c_0)`` pair. Defaults to zero tensors of
        shape (bs, hidden_size) on ``x``'s device.

    Returns:
      ``(hidden_seq, (h_t, c_t))``. ``hidden_seq`` has shape
      (bs, seq_sz, hidden_size) and holds the hidden state at each timestep.
      ``h_t``/``c_t`` are the final hidden and cell states.
    """
    bs, seq_sz, _ = x.size()
    hidden_seq = []

    if init_states is None:
      h_t = torch.zeros(bs, self.hidden_size, device=x.device)
      c_t = torch.zeros(bs, self.hidden_size, device=x.device)
    else:
      h_t, c_t = init_states

    # Each step depends on the previous (h_t, c_t), so timesteps are processed
    # sequentially. Every sequence in the batch shares the same seq_sz.
    for t in range(seq_sz):
      x_t = x[:, t, :]  # an integer index drops the time axis -> (bs, input_size)

      i_t = torch.sigmoid(x_t @ self.U_i + h_t @ self.V_i + self.b_i)  # input gate
      f_t = torch.sigmoid(x_t @ self.U_f + h_t @ self.V_f + self.b_f)  # forget gate
      g_t = torch.tanh(x_t @ self.U_c + h_t @ self.V_c + self.b_c)     # candidate ("gate gate")
      o_t = torch.sigmoid(x_t @ self.U_o + h_t @ self.V_o + self.b_o)  # output gate
      c_t = f_t * c_t + i_t * g_t  # c_t = f_t * c_{t-1} + i_t * g_t
      h_t = o_t * torch.tanh(c_t)  # (bs, hidden_size)

      hidden_seq.append(h_t)

    # Stacking on dim=1 inserts the time axis directly. The result is already a
    # contiguous (bs, seq_sz, hidden_size) tensor, so no unsqueeze/cat/
    # transpose/contiguous sequence is needed.
    hidden_seq = torch.stack(hidden_seq, dim=1)
    return hidden_seq, (h_t, c_t)  # per-timestep hidden states, (final hidden, final cell)

### LSTM with vectorization (all four gates computed in one matrix multiply)

In [14]:
class CustomLSTM(nn.Module):
  """Single-layer LSTM whose four gates come out of one fused matmul.

  ``W`` (input_sz, 4*hidden) and ``U`` (hidden, 4*hidden) hold the weights of
  all gates side by side, in column order input | forget | candidate | output.
  ``bias`` is the matching (4*hidden,) vector.
  """

  def __init__(self, input_sz, hidden_sz):
    super().__init__()
    self.input_sz = input_sz
    self.hidden_size = hidden_sz
    gate_cols = 4 * hidden_sz  # four gates packed column-wise
    self.W = nn.Parameter(torch.Tensor(input_sz, gate_cols))
    self.U = nn.Parameter(torch.Tensor(hidden_sz, gate_cols))
    self.bias = nn.Parameter(torch.Tensor(gate_cols))
    self.init_weights()

  def init_weights(self):
    """Draw every parameter from U(-1/sqrt(hidden), +1/sqrt(hidden))."""
    bound = 1.0 / math.sqrt(self.hidden_size)
    with torch.no_grad():
      for param in self.parameters():
        param.uniform_(-bound, bound)

  def forward(self, x, init_states = None):
    """Unroll the LSTM over a batch-first input of shape (batch, steps, input_sz).

    ``init_states`` is an optional ``(h_0, c_0)`` pair. When omitted, both are
    zeros of shape (batch, hidden_size) on ``x``'s device. Returns the
    (batch, steps, hidden_size) stack of hidden states together with the final
    ``(h, c)``.
    """
    batch, steps, _ = x.size()
    if init_states is not None:
      h_t, c_t = init_states
    else:
      h_t = torch.zeros(batch, self.hidden_size, device=x.device)
      c_t = torch.zeros(batch, self.hidden_size, device=x.device)

    outputs = []
    for step in range(steps):
      pre = x[:, step, :] @ self.W + h_t @ self.U + self.bias  # (batch, 4*hidden)
      i_pre, f_pre, g_pre, o_pre = pre.chunk(4, dim=1)
      c_t = torch.sigmoid(f_pre) * c_t + torch.sigmoid(i_pre) * torch.tanh(g_pre)
      h_t = torch.sigmoid(o_pre) * torch.tanh(c_t)
      outputs.append(h_t)

    # stacking on dim=1 yields (batch, steps, hidden) directly, already contiguous
    return torch.stack(outputs, dim=1), (h_t, c_t)

### Bidirectional LSTM with vectorization

In [37]:
# bidirectional version
class BiCustomLSTM(nn.Module):
  """Single-layer bidirectional LSTM built from raw parameters.

  The forward (``*f``) and backward (``*b``) directions have independent
  weights, because the two directions model different dependencies. Gate
  columns in every weight/bias are ordered input | forget | candidate | output.
  """

  def __init__(self, input_sz, hidden_sz):
    super(BiCustomLSTM, self).__init__()
    self.input_size = input_sz   # embedding / feature size
    self.hidden_size = hidden_sz  # per-direction hidden size
    # One (U, V, b) set per direction.
    self.Uf = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
    self.Ub = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
    self.Vf = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
    self.Vb = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
    self.bf = nn.Parameter(torch.Tensor(hidden_sz * 4))
    self.bb = nn.Parameter(torch.Tensor(hidden_sz * 4))

    self.init_weights()

  def init_weights(self):
    """Fill every parameter from U(-1/sqrt(hidden), 1/sqrt(hidden))."""
    stdv = 1.0 / math.sqrt(self.hidden_size)
    for weight in self.parameters():
      weight.data.uniform_(-stdv, stdv)

  def _cell(self, x_t, h_t, c_t, U, V, b):
    """Advance one LSTM step with weights (U, V, b); return the new (h_t, c_t)."""
    HS = self.hidden_size
    gates = x_t @ U + h_t @ V + b  # (bs, 4*hidden)
    i_t = torch.sigmoid(gates[:, :HS])
    f_t = torch.sigmoid(gates[:, HS:2 * HS])
    g_t = torch.tanh(gates[:, 2 * HS:3 * HS])
    # Must be an open-ended slice: gates[:, 3*HS] would select ONE column,
    # giving a (bs,) output gate that broadcasts incorrectly against c_t.
    o_t = torch.sigmoid(gates[:, 3 * HS:])
    c_t = f_t * c_t + i_t * g_t
    h_t = o_t * torch.tanh(c_t)  # (bs, hidden)
    return h_t, c_t

  def _per_direction(self, state):
    """Split a state into (forward, backward) parts.

    Accepts either a (forward, backward) pair, or a (bs, 2*hidden) tensor in
    the concatenated layout that forward() returns.
    """
    HS = self.hidden_size
    if torch.is_tensor(state) and state.dim() == 2 and state.size(-1) == 2 * HS:
      return state[:, :HS], state[:, HS:]
    fwd, bwd = state
    return fwd, bwd

  def forward(self, x, init_states = None):
    """Run both directions over a batch-first input of shape (bs, seq_sz, input_size).

    Args:
      x: input tensor of shape (bs, seq_sz, input_size).
      init_states: optional initial states, given as either
        ``((h_f, h_b), (c_f, c_b))`` or the ``(h, c)`` pair returned by a
        previous call (each (bs, 2*hidden)). Defaults to zeros.

    Returns:
      ``(hidden_seq, (h_t, c_t))``:
        - ``hidden_seq`` has shape (bs, seq_sz, 2*hidden). At every timestep it
          holds the forward state followed by the backward state.
        - ``h_t``/``c_t`` have shape (bs, 2*hidden) and hold
          [forward final | backward final]. The forward pass ends at the last
          timestep; the backward pass ends at timestep 0.
    """
    bs, seq_sz, _ = x.size()

    if init_states is None:
      hf_t = torch.zeros(bs, self.hidden_size, device=x.device)
      cf_t = torch.zeros(bs, self.hidden_size, device=x.device)
      hb_t = torch.zeros(bs, self.hidden_size, device=x.device)
      cb_t = torch.zeros(bs, self.hidden_size, device=x.device)
    else:
      h_0, c_0 = init_states
      hf_t, hb_t = self._per_direction(h_0)
      cf_t, cb_t = self._per_direction(c_0)

    # Forward direction: t = 0 .. seq_sz-1
    hidden_seq_f = []
    for t in range(seq_sz):
      hf_t, cf_t = self._cell(x[:, t, :], hf_t, cf_t, self.Uf, self.Vf, self.bf)
      hidden_seq_f.append(hf_t)

    # Backward direction: t = seq_sz-1 .. 0, with its own weights
    hidden_seq_b = []
    for t in reversed(range(seq_sz)):
      hb_t, cb_t = self._cell(x[:, t, :], hb_t, cb_t, self.Ub, self.Vb, self.bb)
      hidden_seq_b.append(hb_t)
    hidden_seq_b.reverse()  # realign so index t refers to input timestep t

    # Each direction: (bs, seq_sz, hidden) -> concatenated (bs, seq_sz, 2*hidden)
    hidden_seq = torch.cat(
        [torch.stack(hidden_seq_f, dim=1), torch.stack(hidden_seq_b, dim=1)],
        dim=-1)

    h_t = torch.cat([hf_t, hb_t], dim=-1)  # (bs, 2*hidden)
    c_t = torch.cat([cf_t, cb_t], dim=-1)  # (bs, 2*hidden)

    return hidden_seq, (h_t, c_t)

In [24]:
# let's try it out! Embed a single 4-token sentence and run the fused-gate LSTM over it.
emb = nn.Embedding(50, 10)                 # vocabulary of 50 ids -> 10-dim vectors
x = emb(torch.tensor([[10, 20, 30, 12]]))  # ids shaped (batch=1, seq_len=4) -> (1, 4, 10)
lstm = CustomLSTM(10, 5)                   # input size 10, hidden size 5
hiddens, (h_fin, c_fin) = lstm(x)

In [25]:
x.size()  # (batch, seq_len, emb_dim): one sentence of 4 token ids, each embedded into 10 dims

torch.Size([1, 4, 10])

In [26]:
print(lstm)  # shows just "CustomLSTM()": W, U and bias are nn.Parameter attributes, not child modules

CustomLSTM()


In [27]:
hiddens.size()  # (batch, seq_len, hidden_size): one hidden state per timestep

torch.Size([1, 4, 5])

In [28]:
h_fin  # final hidden state; identical to the last timestep of `hiddens`

tensor([[-0.1585, -0.1943,  0.1581, -0.0865, -0.0030]], grad_fn=<MulBackward0>)

In [29]:
c_fin  # final cell state, shape (batch, hidden_size)

tensor([[-0.3514, -0.3162,  0.3134, -0.1480, -0.0099]], grad_fn=<AddBackward0>)

In [30]:
hiddens  # hidden state at every timestep, shape (batch, seq_len, hidden_size)

tensor([[[-0.0299, -0.0535, -0.1645,  0.2020, -0.0797],
         [-0.0626, -0.1986, -0.0399,  0.0957, -0.1124],
         [ 0.0499, -0.3435,  0.0904,  0.3283, -0.0740],
         [-0.1585, -0.1943,  0.1581, -0.0865, -0.0030]]],
       grad_fn=<TransposeBackward0>)

In [38]:
# let's try it out! Same 4-token sentence, now through the bidirectional LSTM.
emb = nn.Embedding(50, 10)                 # vocabulary of 50 ids -> 10-dim vectors
x = emb(torch.tensor([[10, 20, 30, 12]]))  # ids shaped (batch=1, seq_len=4) -> (1, 4, 10)
bilstm = BiCustomLSTM(10, 5)               # input size 10, hidden size 5 per direction
hiddens, (h_fin, c_fin) = bilstm(x)

In [45]:
hiddens.size()  # (batch, seq_len, 2 * hidden_size): forward and backward states concatenated

torch.Size([1, 4, 10])

In [44]:
h_fin.size()  # (batch, 2 * hidden_size): [forward final | backward final]

torch.Size([1, 10])

In [43]:
c_fin.size()  # (batch, 2 * hidden_size): [forward final cell | backward final cell]

torch.Size([1, 10])

In [49]:
hiddens[:,-1,:5], hiddens[:, 0, 5:]  # forward dir ends at the last step (first 5 units); backward dir ends at step 0 (last 5 units)

(tensor([[-0.3111, -0.4424, -0.0711, -0.5527, -0.4322]],
        grad_fn=<SliceBackward0>),
 tensor([[ 0.0270, -0.0345, -0.2361, -0.0316, -0.0786]],
        grad_fn=<SliceBackward0>))

In [48]:
h_fin  # should equal cat(hiddens[:, -1, :5], hiddens[:, 0, 5:]): forward final then backward final

tensor([[-0.3111, -0.4424, -0.0711, -0.5527, -0.4322,  0.0270, -0.0345, -0.2361,
         -0.0316, -0.0786]], grad_fn=<CatBackward0>)