# LSTM Long Short Term Memory
### From scratch

## Imports

In [1]:
from utilities.std_imports import *
import random
import math

 ![](lstm.png)

## 1. Activation functions

In [11]:
def Sigmoid(x): 
    return 1. / (1 + np.exp(-x))

def SigmoidDeriv(values): 
    return values * (1-values)

def TanhDeriv(values): 
    return 1. - values ** 2

## 2. State

In [12]:
class State:

    def __init__(self, nCells, x_dim):
        self.g = np.zeros(nCells)
        self.i = np.zeros(nCells)
        self.f = np.zeros(nCells)
        self.o = np.zeros(nCells)
        self.s = np.zeros(nCells)
        self.h = np.zeros(nCells)
        self.bottomDiffH = np.zeros_like(self.h)
        self.bottomDiffS = np.zeros_like(self.s)

## 3. Parameters

In [13]:
class Params:

    def __init__(self, nCells, xDim):

        self.nCells = nCells
        self.xDim = xDim

        # Weights
        self.wg = RandArr(-0.1, 0.1, nCells, xDim + nCells)
        self.wi = RandArr(-0.1, 0.1, nCells, xDim + nCells) 
        self.wf = RandArr(-0.1, 0.1, nCells, xDim + nCells)
        self.wo = RandArr(-0.1, 0.1, nCells, xDim + nCells)
        
        # bias
        self.bg = RandArr(-0.1, 0.1, nCells) 
        self.bi = RandArr(-0.1, 0.1, nCells) 
        self.bf = RandArr(-0.1, 0.1, nCells) 
        self.bo = RandArr(-0.1, 0.1, nCells) 

        # diffs (derivative of loss function in all parameters)
        self.wgDiff = np.zeros((nCells, xDim + nCells)) 
        self.wiDiff = np.zeros((nCells, xDim + nCells)) 
        self.wfDiff = np.zeros((nCells, xDim + nCells)) 
        self.woDiff = np.zeros((nCells, xDim + nCells)) 
        
        self.bgDiff = np.zeros(nCells) 
        self.biDiff = np.zeros(nCells) 
        self.bfDiff = np.zeros(nCells) 
        self.boDiff = np.zeros(nCells) 

    def ApplyDiff(self, lr = 1):
        self.wg -= lr * self.wgDiff
        self.wi -= lr * self.wiDiff
        self.wf -= lr * self.wfDiff
        self.wo -= lr * self.woDiff
        self.bg -= lr * self.bgDiff
        self.bi -= lr * self.biDiff
        self.bf -= lr * self.bfDiff
        self.bo -= lr * self.boDiff

        # reset diffs to zero
        self.wgDiff = np.zeros_like(self.wg)
        self.wiDiff = np.zeros_like(self.wi) 
        self.wfDiff = np.zeros_like(self.wf) 
        self.woDiff = np.zeros_like(self.wo) 
        self.bgDiff = np.zeros_like(self.bg)
        self.biDiff = np.zeros_like(self.bi) 
        self.bfDiff = np.zeros_like(self.bf) 
        self.boDiff = np.zeros_like(self.bo) 

        
# Create uniform random array w/ values in [a,b) and shape args
def RandArr(a, b, *args): 
    np.random.seed(0)
    return np.random.rand(*args) * (b - a) + a

## 4. Node

In [14]:
class Node:

    def __init__(self, pars, state):
        self.state = state
        self.pars = pars
        self.xc = None

    def BottomDataIs(self, x, sPrev = None, hPrev = None):

        if sPrev is None: sPrev = np.zeros_like(self.state.s)
        if hPrev is None: hPrev = np.zeros_like(self.state.h)

        self.sPrev = sPrev
        self.hPrev = hPrev

        # Concatenate x(t) and h(t-1)
        xc = np.hstack((x,  hPrev))
        self.state.g = np.tanh(np.dot(self.pars.wg, xc) + self.pars.bg)
        self.state.i = Sigmoid(np.dot(self.pars.wi, xc) + self.pars.bi)
        self.state.f = Sigmoid(np.dot(self.pars.wf, xc) + self.pars.bf)
        self.state.o = Sigmoid(np.dot(self.pars.wo, xc) + self.pars.bo)
        self.state.s = self.state.g * self.state.i + sPrev * self.state.f
        self.state.h = self.state.s * self.state.o

        self.xc = xc
    
    def TopDiffIs(self, topDiffH, topDiffS):

        # topDiffS is carried along the constant error carousel
        ds = self.state.o * topDiffH + topDiffS
        do = self.state.s * topDiffH
        di = self.state.g * ds
        dg = self.state.i * ds
        df = self.sPrev * ds

        # diffs in vector inside sigma/tanh function
        diInput = SigmoidDeriv(self.state.i) * di 
        dfInput = SigmoidDeriv(self.state.f) * df 
        doInput = SigmoidDeriv(self.state.o) * do 
        dgInput = TanhDeriv(self.state.g) * dg

        # diffs in inputs
        self.pars.wiDiff += np.outer(diInput, self.xc)
        self.pars.wfDiff += np.outer(dfInput, self.xc)
        self.pars.woDiff += np.outer(doInput, self.xc)
        self.pars.wgDiff += np.outer(dgInput, self.xc)
        self.pars.biDiff += diInput
        self.pars.bfDiff += dfInput  
        self.pars.boDiff += doInput
        self.pars.bgDiff += dgInput     

        # compute bottom diff
        dxc = np.zeros_like(self.xc)
        dxc += np.dot(self.pars.wi.T, diInput)
        dxc += np.dot(self.pars.wf.T, dfInput)
        dxc += np.dot(self.pars.wo.T, doInput)
        dxc += np.dot(self.pars.wg.T, dgInput)

        # save bottom diffs
        self.state.bottomDiffS = ds * self.state.f
        self.state.bottomDiffH = dxc[self.pars.xDim:]

## 5. Toy toss layer

In [15]:
# Computes square loss with first element of hidden layer array
class ToyLossLayer:
    @classmethod
    def Loss(self, pred, label):
        return (pred[0] - label) ** 2

    @classmethod
    def BottomDiff(self, pred, label):
        diff = np.zeros_like(pred)
        diff[0] = 2 * (pred[0] - label)
        return diff

## 6. LSTM

In [16]:
class LSTM():

    def __init__(self, pars):
        self.pars = pars
        self.nodes = []
        self.x = []

    # Updates diffs by setting target sequence with corresponding loss layer. 
    def YIs(self, y, lossLayer):
        assert len(y) == len(self.x)
        i = len(self.x) - 1

        # first node only gets diffs from label 
        loss = lossLayer.Loss(self.nodes[i].state.h, y[i])
        diffH = lossLayer.BottomDiff(self.nodes[i].state.h, y[i])

        # here s is not affecting loss due to h(t+1), hence we set equal to zero
        diffS = np.zeros(self.pars.nCells)
        self.nodes[i].TopDiffIs(diffH, diffS)
        i -= 1

        # following nodes also get diffs from next nodes, hence we add diffs to diffH,  also propagate error along constant error carousel using diffS
        while i>= 0:
            loss += lossLayer.Loss(self.nodes[i].state.h, y[i])
            diffH = lossLayer.BottomDiff(self.nodes[i].state.h, y[i])
            diffH += self.nodes[i+1].state.bottomDiffH
            diffS = self.nodes[i+1].state.bottomDiffS
            self.nodes[i].TopDiffIs(diffH, diffS)
            i -= 1 

        return loss

    def ClearX(self):
        self.x = []

    def XAdd(self, x):
        self.x.append(x)
        if len(self.x) > len(self.nodes):
            # need to add new lstm node, create new state mem
            state = State(self.pars.nCells, self.pars.xDim)
            self.nodes.append(Node(self.pars, state))

        # get index of most recent x input
        i = len(self.x) - 1
        if i == 0:
            # no recurrent inputs yet
            self.nodes[i].BottomDataIs(x)
        else:
            sPrev = self.nodes[i-1].state.s
            hPrev = self.nodes[i-1].state.h
            self.nodes[i].BottomDataIs(x, sPrev, hPrev)

## Testing

In [17]:
# Generate random inputs and create network
epochs = 100
pars = Params(nCells=100, xDim=50)
np.random.seed(0)
y = [-0.5, 0.2, 0.1, -0.5]
inputs = [np.random.random(pars.xDim) for _ in y]
lstm = LSTM(pars)

y = [-0.5, 0.2, 0.1, -0.5]
inputs = [np.random.random(pars.xDim) for _ in y]

# Learns to repeat simple sequence from inputs
for e in range(epochs):
    print("Epoch", e)
    for i in range(len(y)): lstm.XAdd(inputs[i])
    print('\t y = [' + ', '.join(['% 2.5f' % lstm.nodes[i].state.h[0] for i in range(len(y))]) + ']', end=", "); print(']')

    loss = lstm.YIs(y, ToyLossLayer)
    print("\t loss:", "%.3e" % loss)
    pars.ApplyDiff(lr=0.1)
    lstm.ClearX()

Epoch 0
	 y = [ 0.11995,  0.19698,  0.16151,  0.10199], ]
	 loss: 7.505e-01
Epoch 1
	 y = [-0.23957, -0.37839, -0.38047, -0.35234], ]
	 loss: 6.550e-01
Epoch 2
	 y = [-0.14506, -0.18230, -0.19233, -0.19017], ]
	 loss: 4.536e-01
Epoch 3
	 y = [-0.13779, -0.15888, -0.17133, -0.18357], ]
	 loss: 4.337e-01
Epoch 4
	 y = [-0.13747, -0.14735, -0.16134, -0.18676], ]
	 loss: 4.185e-01
Epoch 5
	 y = [-0.14056, -0.14100, -0.15620, -0.19491], ]
	 loss: 4.042e-01
Epoch 6
	 y = [-0.14560, -0.13718, -0.15345, -0.20613], ]
	 loss: 3.899e-01
Epoch 7
	 y = [-0.15177, -0.13450, -0.15181, -0.21937], ]
	 loss: 3.753e-01
Epoch 8
	 y = [-0.15857, -0.13209, -0.15046, -0.23386], ]
	 loss: 3.604e-01
Epoch 9
	 y = [-0.16564, -0.12934, -0.14884, -0.24897], ]
	 loss: 3.452e-01
Epoch 10
	 y = [-0.17274, -0.12577, -0.14652, -0.26417], ]
	 loss: 3.296e-01
Epoch 11
	 y = [-0.17975, -0.12100, -0.14318, -0.27901], ]
	 loss: 3.136e-01
Epoch 12
	 y = [-0.18661, -0.11468, -0.13854, -0.29319], ]
	 loss: 2.969e-01
Epoch 13
