# Naive Long Short-Term Memory (LSTM) Unit Implementation
This notebook implements a naive LSTM unit (`NaiveLSTM`) from scratch in PyTorch, for use in NLP / deep-learning text applications.

In [17]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.optim as optim

from typing import *
from pathlib import Path

from enum import IntEnum

class Dim(IntEnum):
    """Named axis indices for tensors laid out as (batch, sequence, feature)."""
    batch = 0
    seq = 1
    feature = 2


class NaiveLSTM(nn.Module):
    """Single-layer LSTM written out gate by gate with explicit parameters.

    For every time step ``t`` the layer computes::

        i_t = sigmoid(x_t @ W_ii + h_{t-1} @ W_hi + b_i)   # input gate
        f_t = sigmoid(x_t @ W_if + h_{t-1} @ W_hf + b_f)   # forget gate
        g_t = tanh(x_t @ W_ig + h_{t-1} @ W_hg + b_g)      # candidate cell state
        o_t = sigmoid(x_t @ W_io + h_{t-1} @ W_ho + b_o)   # output gate
        c_t = f_t * c_{t-1} + i_t * g_t
        h_t = o_t * tanh(c_t)

    Args:
        input_sz: size of the feature dimension of the input.
        hidden_sz: size of the hidden and cell states.
    """

    def __init__(self, input_sz: int, hidden_sz: int):
        super().__init__()
        self.input_size = input_sz
        self.hidden_size = hidden_sz
        # Parameters are allocated uninitialised here and filled in by
        # init_weights(); the registration order is kept stable because it
        # determines both the state_dict layout and the order of random draws.
        # input gate
        self.W_ii = Parameter(torch.empty(input_sz, hidden_sz))
        self.W_hi = Parameter(torch.empty(hidden_sz, hidden_sz))
        self.b_i = Parameter(torch.empty(hidden_sz))
        # forget gate
        self.W_if = Parameter(torch.empty(input_sz, hidden_sz))
        self.W_hf = Parameter(torch.empty(hidden_sz, hidden_sz))
        self.b_f = Parameter(torch.empty(hidden_sz))
        # candidate cell state (the tanh-activated "g" term added to the cell)
        self.W_ig = Parameter(torch.empty(input_sz, hidden_sz))
        self.W_hg = Parameter(torch.empty(hidden_sz, hidden_sz))
        self.b_g = Parameter(torch.empty(hidden_sz))
        # output gate
        self.W_io = Parameter(torch.empty(input_sz, hidden_sz))
        self.W_ho = Parameter(torch.empty(hidden_sz, hidden_sz))
        self.b_o = Parameter(torch.empty(hidden_sz))

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform initialise every weight matrix and zero every bias."""
        for p in self.parameters():
            if p.dim() >= 2:
                nn.init.xavier_uniform_(p)
            else:
                nn.init.zeros_(p)

    def forward(self, x: torch.Tensor,
                init_states: Optional[Tuple[torch.Tensor, torch.Tensor]]=None
               ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the LSTM over a whole batch of sequences.

        Args:
            x: input of shape (batch, sequence, feature).
            init_states: optional ``(h_0, c_0)`` pair used as the starting
                hidden and cell state. When omitted, both start as zeros of
                shape (batch, hidden_size) with the dtype and device of ``x``.

        Returns:
            ``(hidden_seq, (h_n, c_n))`` where ``hidden_seq`` has shape
            (batch, sequence, hidden_size) and ``h_n`` / ``c_n`` are the hidden
            and cell state after the last time step. For an empty sequence
            ``hidden_seq`` has a zero-length sequence axis and the initial
            states are returned unchanged.
        """
        bs, seq_sz, _ = x.size()
        if init_states is None:
            # Build the state with its full (batch, hidden) shape and with the
            # dtype/device of the input, so the result does not rely on
            # broadcasting and the layer also works after .double()/.half().
            h_t = x.new_zeros(bs, self.hidden_size)
            c_t = x.new_zeros(bs, self.hidden_size)
        else:
            h_t, c_t = init_states

        hidden_seq = []
        for t in range(seq_sz):  # iterate over the time steps
            x_t = x[:, t, :]
            i_t = torch.sigmoid(x_t @ self.W_ii + h_t @ self.W_hi + self.b_i)
            f_t = torch.sigmoid(x_t @ self.W_if + h_t @ self.W_hf + self.b_f)
            g_t = torch.tanh(x_t @ self.W_ig + h_t @ self.W_hg + self.b_g)
            o_t = torch.sigmoid(x_t @ self.W_io + h_t @ self.W_ho + self.b_o)
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t)

        if hidden_seq:
            # Stack the per-step (batch, hidden) states along the sequence axis,
            # giving (batch, sequence, hidden) directly.
            hidden_seq = torch.stack(hidden_seq, dim=Dim.seq)
        else:
            # Zero-length sequence: nothing to stack, return an empty output.
            hidden_seq = x.new_zeros(bs, 0, self.hidden_size)
        return hidden_seq, (h_t, c_t)
    
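# [1] Glorot & Bengio (2010), "Understanding the difficulty of training deep feedforward neural networks" -- the Xavier initialization used in init_weights: http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf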

In [18]:
# Toy dimensions for a smoke test of the layer.
bs = 5          # batch size
seq_len = 10    # time steps per sequence
feat_sz = 32    # input features per time step
hidden_sz = 16  # LSTM hidden/cell state size

# Random input batch laid out as (batch, sequence, feature).
arr = torch.randn((bs, seq_len, feat_sz))
lstm = NaiveLSTM(input_sz=feat_sz, hidden_sz=hidden_sz)

In [19]:
# Forward pass with no initial state supplied: `hs` holds the hidden state of
# every time step, `hn` / `cn` the hidden and cell state after the last step.
hs, (hn, cn) = lstm(x=arr)

In [20]:
# Expect (batch, sequence, hidden) = (bs, seq_len, hidden_sz).
hs.size()

torch.Size([5, 10, 16])

In [21]:
!curl http://www.sls.hawaii.edu/bley-vroman/brown.txt -o {"brown.txt"}

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  1 6040k    1  103k    0     0  66259      0  0:01:33  0:00:01  0:01:32 66259
 14 6040k   14  892k    0     0   351k      0  0:00:17  0:00:02  0:00:15  351k
 66 6040k   66 4021k    0     0  1119k      0  0:00:05  0:00:03  0:00:02 1119k
100 6040k  100 6040k    0     0  1674k      0  0:00:03  0:00:03 --:--:-- 1674k


In [22]:
!pip install allennlp

Collecting allennlp
  Downloading https://files.pythonhosted.org/packages/3f/bc/e30325523363215c503171822f09436adcfbc74f426ad62496276f1ac4c0/allennlp-0.8.5-py3-none-any.whl (7.4MB)
Collecting tensorboardX>=1.2 (from allennlp)
  Downloading https://files.pythonhosted.org/packages/c3/12/dcaf67e1312475b26db9e45e7bb6f32b540671a9ee120b3a72d9e09bc517/tensorboardX-1.8-py2.py3-none-any.whl (216kB)
Collecting numpydoc>=0.8.0 (from allennlp)
  Downloading https://files.pythonhosted.org/packages/6a/f3/7cfe4c616e4b9fe05540256cc9c6661c052c8a4cec2915732793b36e1843/numpydoc-0.9.1.tar.gz
Collecting flaky (from allennlp)
  Downloading https://files.pythonhosted.org/packages/fe/12/0f169abf1aa07c7edef4855cca53703d2e6b7ecbded7829588ac7e7e3424/flaky-3.6.1-py2.py3-none-any.whl
Collecting sqlparse>=0.2.4 (from allennlp)
  Downloading https://files.pythonhosted.org/packages/ef/53/900f7d2a54557c6a37886585a91336520e5539e3ae2423ff1102daf4f3a7/sqlparse-0.3.0-py2.py3-none-any.whl
Collecting flask-cors>=3.0.7 (from

Collecting regex (from pytorch-pretrained-bert>=0.6.0->allennlp)
  Downloading https://files.pythonhosted.org/packages/f1/2f/f586e982712ffee5681ca149d54480dbb04ff533e9e4638c5e28ae76bdb5/regex-2019.08.19-cp37-none-win_amd64.whl (325kB)
Collecting greenlet>=0.4.14; platform_python_implementation == "CPython" (from gevent>=1.3.6->allennlp)
  Downloading https://files.pythonhosted.org/packages/90/a3/da8593df08ee2efeb86ccf3201508a1fd2a3749e2735b7cadb7dd00416c6/greenlet-0.4.15-cp37-cp37m-win_amd64.whl
Collecting sentencepiece (from pytorch-transformers==1.1.0->allennlp)
  Downloading https://files.pythonhosted.org/packages/ce/16/17838ebf03ee21daa3b4da0ca5c344bd060bc2963a7567a071cd7008e996/sentencepiece-0.1.83-cp37-cp37m-win_amd64.whl (1.2MB)
Collecting sphinxcontrib-devhelp (from sphinx>=1.6.5->numpydoc>=0.8.0->allennlp)
  Downloading https://files.pythonhosted.org/packages/b0/a3/fea98741f0b2f2902fbf6c35c8e91b22cd0dd13387291e81d457f9a93066/sphinxcontrib_devhelp-1.0.1-py2.py3-none-any.whl (84

Installing collected packages: tensorboardX, sphinxcontrib-devhelp, sphinxcontrib-jsmath, sphinxcontrib-htmlhelp, sphinxcontrib-qthelp, sphinxcontrib-serializinghtml, babel, snowballstemmer, alabaster, sphinxcontrib-applehelp, imagesize, sphinx, numpydoc, flaky, sqlparse, flask-cors, overrides, parsimonious, unidecode, word2number, regex, pytorch-pretrained-bert, responses, ftfy, greenlet, gevent, conllu, jsonpickle, sentencepiece, pytorch-transformers, editdistance, allennlp
Successfully installed alabaster-0.7.12 allennlp-0.8.5 babel-2.7.0 conllu-1.3.1 editdistance-0.5.3 flaky-3.6.1 flask-cors-3.0.8 ftfy-5.6 gevent-1.4.0 greenlet-0.4.15 imagesize-1.1.0 jsonpickle-1.2 numpydoc-0.9.1 overrides-1.9 parsimonious-0.8.1 pytorch-pretrained-bert-0.6.2 pytorch-transformers-1.1.0 regex-2019.8.19 responses-0.10.6 sentencepiece-0.1.83 snowballstemmer-1.9.1 sphinx-2.2.0 sphinxcontrib-applehelp-1.0.1 sphinxcontrib-devhelp-1.0.1 sphinxcontrib-htmlhelp-1.0.2 sphinxcontrib-jsmath-1.0.1 sphinxcontrib-