#### Gain a better understanding of LSTM cells by doing a basic implementation in numpy without considering performance. 

In [1]:
import numpy as np
import tensorflow as tf

A single LSTM Cell

In [2]:
def sigmoid(x): 
    return 1/ (1+np.exp(-x))

class LSTMCell(): 
    def __init__(self, x, c, h, W=None, b=None):
        self.x = x 
        self.c = c
        self.h = h 
        self.W = W
        self.b = b
        self.concat = np.concatenate([self.x, self.h], axis=1)
        
    def call(self): 
        # arr has dimensions [x.shape[0], c.shape[1]*4]
        arr = np.matmul(self.concat, self.W) + self.b
        
        # input, new_c, forget, output
        i, j, f, o = np.split(arr, 4, axis=1)
        
        # Calculate new c and h 
        # Operations are pointwise 
        new_c = (self.c * sigmoid(f)) + (sigmoid(i) * np.tanh(j))
        self.h = np.tanh(new_c) * sigmoid(o)
        self.c = new_c
        return self.c, self.h

Set initial states

In [3]:
tf.reset_default_graph()
x = tf.constant([[1,1]], dtype=tf.float32) 
c = tf.constant([[.1,.1]], dtype=tf.float32) 
h = tf.constant([[.3,.5]], dtype=tf.float32)

Run in Tensorflow to check work

In [4]:
cell = tf.contrib.rnn.BasicLSTMCell(num_units=2,  
                                    forget_bias=1.0, 
                                    state_is_tuple=True)
state = (c,h)
outputs, states = cell(x, state)
init = tf.global_variables_initializer()
with tf.Session() as sess: 
    sess.run(init)
    weights = sess.run(cell.weights)
    expected_c = sess.run(states.c)
    expected_output = sess.run(states.h)
    print('Expected c:', expected_c)
    print('Expected h:', expected_output)

Expected c: [[-0.08290453 -0.21169293]]
Expected h: [[-0.02955644 -0.0449947 ]]


Run the manual LSTM Cell

In [6]:
# Use the same weights initialized in the Tensorflow LSTMCell 
W,b = weights 

x = np.array([[1.,1.]])
c = np.array([[.1,.1]])
h = np.array([[.3,.5]])

manual_cell = LSTMCell(x, c, h, W, b)
c, h = manual_cell.call()

# End up with close-ish results 
print('Manual c:', c)
print('Manual h:', h) 

Manual c: [[-0.10680572 -0.23565682]]
Manual h: [[-0.03802023 -0.04991358]]
