In [None]:
from __future__ import print_function, division, absolute_import

In [None]:
from keras import backend as K

In [None]:
# LSTM is a class exported by keras.layers, not a submodule, so it has to be
# brought in with a from-import (``import pkg.Class as Class`` raises ImportError).
from keras.layers import LSTM

In [None]:
# modeled after ryankiros and gruln code
def sample_normalize(x, _eps=1e-5):
    """Standardize `x` along its last axis.

    The mean over the last axis is subtracted and the result is divided by a
    standard-deviation estimate computed over the same axis. Both reductions
    use ``keepdims=True``, so the reduced axis is kept with size one and the
    statistics broadcast against `x`.

    `_eps` is applied twice: it is added to the variance inside the square
    root and added again to the resulting standard deviation in the
    denominator.
    """
    last_axis = dict(axis=-1, keepdims=True)
    centered = x - K.mean(x, **last_axis)
    # sqrt(var + _eps) instead of K.std, so the root never sees an exact zero.
    spread = K.sqrt(K.var(x, **last_axis) + _eps)
    return centered / (spread + _eps)

Here is what the code for a single LSTM step looks like in Keras
(the `step` method of the `LSTM` class in `recurrent.py`):

```python 
    def step(self, x, states):
        h_tm1 = states[0]
        c_tm1 = states[1]
        B_U = states[2]
        B_W = states[3]

        if self.consume_less == 'gpu':
            z = K.dot(x * B_W[0], self.W) + K.dot(h_tm1 * B_U[0], self.U) + self.b

            z0 = z[:, :self.output_dim]
            z1 = z[:, self.output_dim: 2 * self.output_dim]
            z2 = z[:, 2 * self.output_dim: 3 * self.output_dim]
            z3 = z[:, 3 * self.output_dim:]

            i = self.inner_activation(z0)
            f = self.inner_activation(z1)
            c = f * c_tm1 + i * self.activation(z2)
            o = self.inner_activation(z3)
        else:
            if self.consume_less == 'cpu':
                x_i = x[:, :self.output_dim]
                x_f = x[:, self.output_dim: 2 * self.output_dim]
                x_c = x[:, 2 * self.output_dim: 3 * self.output_dim]
                x_o = x[:, 3 * self.output_dim:]
            elif self.consume_less == 'mem':
                x_i = K.dot(x * B_W[0], self.W_i) + self.b_i
                x_f = K.dot(x * B_W[1], self.W_f) + self.b_f
                x_c = K.dot(x * B_W[2], self.W_c) + self.b_c
                x_o = K.dot(x * B_W[3], self.W_o) + self.b_o
            else:
                raise Exception('Unknown `consume_less` mode.')

            i = self.inner_activation(x_i + K.dot(h_tm1 * B_U[0], self.U_i))
            f = self.inner_activation(x_f + K.dot(h_tm1 * B_U[1], self.U_f))
            c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1 * B_U[2], self.U_c))
            o = self.inner_activation(x_o + K.dot(h_tm1 * B_U[3], self.U_o))

        h = o * self.activation(c)
        return h, [h, c]
```

Here is the `_step` function from ryankiros, together with the `ln` helper it calls:

```python
def ln(x, b, s):
    _eps = 1e-5
    output = (x - x.mean(1)[:,None]) / tensor.sqrt((x.var(1)[:,None] + _eps))
    output = s[None, :] * output + b[None,:]
    return output

# class method for LSTM layer normalization



    def _step(mask, sbelow, sbefore, cell_before, *args):
        sbelow_ = ln(sbelow, param('b1'), param('s1'))
        sbefore_ = ln(dot(sbefore, param('U')), param('b2'), param('s2'))

        preact = sbefore_ + sbelow_ + param('b')

        i = Sigmoid(_slice(preact, 0, dim))
        f = Sigmoid(_slice(preact, 1, dim))
        o = Sigmoid(_slice(preact, 2, dim))
        c = Tanh(_slice(preact, 3, dim))

        c = f * cell_before + i * c
        c = mask * c + (1. - mask) * cell_before

        c_ = ln(c, param('b3'), param('s3'))
        h = o * tensor.tanh(c_)
        h = mask * h + (1. - mask) * sbefore


```

In [None]:
class LSTM_LN(LSTM):
    """LSTM variant whose fused ('gpu') step normalizes the gate pre-activations.

    Only `step` is overridden. In the 'gpu' `consume_less` mode the input
    projection and the recurrent projection are each passed through
    `sample_normalize` (normalization over the last axis, i.e. across all four
    gate slices at once) before they are summed, and the bias `self.b` is
    added after normalization. The 'cpu' and 'mem' modes run the plain,
    unnormalized LSTM update.

    NOTE(review): the ryankiros reference this is modeled on additionally uses
    learned gain/shift parameters in `ln` and normalizes the cell state before
    the output activation; neither is implemented here.
    """

    def step(self, x, states):
        """Compute one recurrent time step.

        Args:
            x: input for the current time step. It is indexed as a 2-D tensor;
                in 'cpu' mode it is sliced into four gate inputs of width
                `self.output_dim` (presumably already projected upstream --
                TODO confirm against the base class).
            states: sequence whose first four entries are read as the previous
                output `h_tm1`, the previous cell state `c_tm1`, and the
                multipliers `B_U` / `B_W` applied to `h_tm1` / `x` before the
                projections (presumably dropout masks -- TODO confirm).

        Returns:
            A pair ``(h, [h, c])``: the new output and the new state list.

        Raises:
            ValueError: if `self.consume_less` is not 'gpu', 'cpu' or 'mem'.
        """
        h_tm1 = states[0]
        c_tm1 = states[1]
        B_U = states[2]
        B_W = states[3]

        if self.consume_less == 'gpu':
            # Normalize the contribution from below (input) and the recurrent
            # contribution from t-1 separately, as the reference `_step` does.
            z_below = sample_normalize(K.dot(x * B_W[0], self.W))
            z_before = sample_normalize(K.dot(h_tm1 * B_U[0], self.U))
            # The bias is added only after normalization; adding it before
            # would let the mean subtraction partially cancel it.
            z = z_below + z_before + self.b

            z0 = z[:, :self.output_dim]                          # input gate
            z1 = z[:, self.output_dim: 2 * self.output_dim]      # forget gate
            z2 = z[:, 2 * self.output_dim: 3 * self.output_dim]  # cell candidate
            z3 = z[:, 3 * self.output_dim:]                      # output gate

            i = self.inner_activation(z0)
            f = self.inner_activation(z1)
            c = f * c_tm1 + i * self.activation(z2)
            o = self.inner_activation(z3)
        else:
            # These two modes are the unnormalized LSTM update.
            if self.consume_less == 'cpu':
                x_i = x[:, :self.output_dim]
                x_f = x[:, self.output_dim: 2 * self.output_dim]
                x_c = x[:, 2 * self.output_dim: 3 * self.output_dim]
                x_o = x[:, 3 * self.output_dim:]
            elif self.consume_less == 'mem':
                x_i = K.dot(x * B_W[0], self.W_i) + self.b_i
                x_f = K.dot(x * B_W[1], self.W_f) + self.b_f
                x_c = K.dot(x * B_W[2], self.W_c) + self.b_c
                x_o = K.dot(x * B_W[3], self.W_o) + self.b_o
            else:
                raise ValueError('Unknown `consume_less` mode.')

            i = self.inner_activation(x_i + K.dot(h_tm1 * B_U[0], self.U_i))
            f = self.inner_activation(x_f + K.dot(h_tm1 * B_U[1], self.U_f))
            c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1 * B_U[2], self.U_c))
            o = self.inner_activation(x_o + K.dot(h_tm1 * B_U[3], self.U_o))

        h = o * self.activation(c)
        return h, [h, c]