# Building an LSTM from Scratch

Adapted from https://towardsdatascience.com/building-a-lstm-by-hand-on-pytorch-59c02a4ec091
And https://gist.github.com/piEsposito/a05bc12cd107fdec68e35ad61302da4c

In [1]:
import math
import numpy as np
import time
import torch
import torch.nn as nn

from itertools import chain

np.set_printoptions(formatter={'float_kind': "{:.3f}".format})

In [2]:
class CustomLSTM(nn.Module):
    """
    Written by Pi Esposito
    https://gist.github.com/piEsposito/a05bc12cd107fdec68e35ad61302da4c
    """
    def __init__(self, input_sz, hidden_sz):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
        self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
        self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
        self.init_weights()
                
    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, x, 
                init_states=None):
        """Assumes x is of shape (batch, sequence, feature)"""
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        if init_states is None:
            h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device), 
                        torch.zeros(bs, self.hidden_size).to(x.device))
        else:
            h_t, c_t = init_states
         
        HS = self.hidden_size
        for t in range(seq_sz):
            x_t = x[:, t, :]
            # batch the computations into a single matrix multiplication
            gates = x_t @ self.W + h_t @ self.U + self.bias
            i_t, f_t, g_t, o_t = (
                torch.sigmoid(gates[:, :HS]),     # input
                torch.sigmoid(gates[:, HS:HS*2]), # forget
                torch.tanh(gates[:, HS*2:HS*3]),  # g(X, h)
                torch.sigmoid(gates[:, HS*3:]),   # output
            )
            c_t = f_t * c_t + i_t * g_t  # forget old, add new
            h_t = o_t * torch.tanh(c_t)  # produce output
            hidden_seq.append(h_t.unsqueeze(0))

        # reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = torch.cat(hidden_seq, dim=0)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()
        return hidden_seq, (h_t, c_t)
    
    def verbose_forward(self, x):
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device), 
                    torch.zeros(bs, self.hidden_size).to(x.device))
         
        (i_t_vals, f_t_vals, g_t_vals,
         c_t_vals, h_t_vals, o_t_vals) = ([], [], [], [], [], [])
        
        HS = self.hidden_size
        for t in range(seq_sz):
            x_t = x[:, t, :]
            # batch the computations into a single matrix multiplication
            gates = x_t @ self.W + h_t @ self.U + self.bias
            i_t, f_t, g_t, o_t = (
                torch.sigmoid(gates[:, :HS]),     # input
                torch.sigmoid(gates[:, HS:HS*2]), # forget
                torch.tanh(gates[:, HS*2:HS*3]),  # g(X, h)
                torch.sigmoid(gates[:, HS*3:]),   # output
            )
            c_t = f_t * c_t + i_t * g_t  # forget old, add new
            h_t = o_t * torch.tanh(c_t)  # produce output
            
            i_t_vals.append(i_t.detach().numpy())
            f_t_vals.append(f_t.detach().numpy())
            g_t_vals.append(g_t.detach().numpy())
            c_t_vals.append(c_t.detach().numpy())
            h_t_vals.append(h_t.detach().numpy())
            o_t_vals.append(o_t.detach().numpy())

        return {"i_t": i_t_vals, "f_t": f_t_vals,
                "g_t": g_t_vals, "c_t": c_t_vals,
                "h_t": h_t_vals, "o_t": o_t_vals,
                }

In [3]:
def make_onehot(arr, vocab_size):
    # Convert to one-hot representation
    # https://en.wikipedia.org/wiki/One-hot

    n, max_length = arr.shape
    onehot_data = np.zeros([n, max_length, vocab_size])
    for v in range(vocab_size):
        onehot_row = np.zeros([vocab_size])
        onehot_row[v] = 1
        onehot_data[arr == v] = onehot_row

    return onehot_data

class CountingDataset(torch.utils.data.Dataset):
    def __init__(self, n, max_length=8, vocab_size=8):
        
        assert vocab_size > 2
        self.n = n
        self.vocab_size = vocab_size
        seq_lengths = np.random.randint(max_length // 2, max_length, n)
        data = np.random.randint(0, vocab_size, [n, max_length])
        
        # Replace elements past the sequence length with -1
        for i in range(n):
            data[i, slice(seq_lengths[i] + 1, None)] = -1
            
        onehot_data = make_onehot(data, vocab_size)
        
        # Label is whether ones outnumber twos in the sequence
        num_ones = (data == 1).sum(axis=1, keepdims=True)
        num_twos = (data == 2).sum(axis=1, keepdims=True)
        label = (num_ones > num_twos).astype(int)

        self.data = torch.tensor(onehot_data).float()
        self.label = torch.tensor(label).long()

    def __len__(self):
        return self.n

    def __getitem__(self, item_index):
        """
        Allow us to select items with `dataset[0]`
        Returns (x, y)
            x: the data tensor
            y: the label tensor
        """
        return self.data[item_index], self.label[item_index]
    
d = CountingDataset(3, max_length=8, vocab_size=3)
d[0]

(tensor([[0., 0., 1.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 0.],
         [0., 0., 0.]]),
 tensor([0]))

In [4]:
vocab_size = 5
hidden_size = 2
num_epochs = 20
max_length = 10

dataset = CountingDataset(
    10000, max_length=max_length, vocab_size=vocab_size)
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=100, shuffle=False)

lstm = CustomLSTM(vocab_size, hidden_size)
# lstm = nn.LSTM(vocab_size, hidden_size)
classifier = nn.Linear(hidden_size, 2)

loss_func = torch.nn.CrossEntropyLoss()
lstm_opt = torch.optim.SGD(lstm.parameters(), lr=1e-3)
clf_opt = torch.optim.SGD(classifier.parameters(), lr=1e-3)

start = time.time()
for epoch in range(num_epochs):
    total_loss = []
    total_acc = []
    for (X, y) in data_loader:

        hidden, _ = lstm(X)
        final_hidden = hidden[:, -1, :]
        logits = classifier(final_hidden)

        loss = loss_func(logits, y.squeeze())
        acc = torch.argmax(logits, dim=1, keepdims=True) == y
        acc = torch.mean(acc.float())

        loss.backward()
        lstm_opt.step()
        clf_opt.step()

        total_loss.append(loss.detach().numpy())
        total_acc.append(acc.detach().numpy())

    if (epoch + 1) % max(1, num_epochs // 10) == 0:
        mins = (time.time() - start) / 60

        print(", ".join([
            f"Epoch: {epoch + 1:4d}",
            f"Loss: {np.mean(total_loss):.5f}",
            f"Acc: {100*np.mean(total_acc):.1f}",
            f"in {mins:.1f}min",
        ]))

Epoch:    2, Loss: 0.68721, Acc: 61.4, in 0.0min
Epoch:    4, Loss: 0.67619, Acc: 61.6, in 0.1min
Epoch:    6, Loss: 0.67887, Acc: 59.3, in 0.1min
Epoch:    8, Loss: 0.64599, Acc: 63.6, in 0.1min
Epoch:   10, Loss: 0.39516, Acc: 85.5, in 0.1min
Epoch:   12, Loss: 0.20761, Acc: 92.3, in 0.1min
Epoch:   14, Loss: 0.01162, Acc: 99.6, in 0.2min
Epoch:   16, Loss: 0.00279, Acc: 100.0, in 0.2min
Epoch:   18, Loss: 0.00028, Acc: 100.0, in 0.2min
Epoch:   20, Loss: 0.00014, Acc: 100.0, in 0.2min


In [5]:
i = 4
X, y = dataset[i:i+1]

X = np.array([[1., 0, 0, 0, 0, 0, 0, 2]])
X = torch.tensor(make_onehot(X, vocab_size)).float()

out = lstm.verbose_forward(X)
mask = (torch.sum(X, axis=2) > 0).squeeze().numpy()

val = np.argmax(X[:, mask, :], axis=2).reshape(-1, 1)
i_t = np.array(out["i_t"])[mask].squeeze()
g_t = np.array(out["g_t"])[mask].squeeze()
f_t = np.array(out["f_t"])[mask].squeeze()
c_t = np.array(out["c_t"])[mask].squeeze()
h_t = np.array(out["h_t"])[mask].squeeze()
o_t = np.array(out["o_t"])[mask].squeeze()

print("forget gate mean: {} std: {}".format(
    np.mean(f_t, axis=0), np.std(f_t, axis=0)))

columns = ["i_t", "g_t", "c_t", "h_t"]
col_vals = [np.array(out[col])[mask].squeeze() for col in columns]
table = np.concatenate([val] + col_vals, axis=1).tolist()

fmt = "{:^3s}  " + " ".join(
    ["{:^13s}" for _ in range(len(columns))])
print(fmt.format("val", *columns))
for i in range(len(table)):
    row = table[i]
    print("{:^3d}".format(int(row[0])), end=" ")
    print(" ".join(map("{:6.2f}".format, row[1:])))

forget gate mean: [1.000 0.291] std: [0.000 0.048]
val       i_t           g_t           c_t           h_t     
 1    0.95   0.14   1.00  -1.00   0.95  -0.14   0.74  -0.05
 0    0.00   0.08   0.58  -0.67   0.95  -0.09   0.74  -0.01
 0    0.00   0.08   0.55  -0.67   0.95  -0.08   0.74  -0.00
 0    0.00   0.08   0.54  -0.67   0.95  -0.08   0.74  -0.00
 0    0.00   0.08   0.54  -0.67   0.95  -0.08   0.74  -0.00
 0    0.00   0.08   0.54  -0.67   0.95  -0.08   0.74  -0.00
 0    0.00   0.08   0.54  -0.67   0.95  -0.08   0.74  -0.00
 2    1.00   0.19  -1.00   0.92  -0.05   0.15  -0.05   0.03
