In [1]:
import os

# Keras reads KERAS_BACKEND once, the first time `keras` is imported, so the
# backend has to be chosen *before* any keras import. Setting it afterwards
# (as this cell previously did) leaves the backend unchanged.
os.environ["KERAS_BACKEND"] = "jax"

import numpy as np
import tensorflow as tf
import keras
from keras import layers
from keras import ops
from keras.callbacks import EarlyStopping
from keras.utils import pad_sequences



# Recurrent Neural Network and LSTM for IMDB Sentiment Analysis

Apparently this is a pre-requisite for understanding Transformers, so let's do it

In this notebook I implement a recurrent neural network: first with the built-in high-level layer through the functional API, then by writing the recurrent layer myself as a custom subclass.

The mlp had 2.5 million parameters and .88 accuracy

In [2]:
# Hyperparameters are kept small because the recurrent models train much more
# slowly than the earlier MLP.
EMBED_DIM = 100
VOCABULARY_SIZE = 8000
MAX_SEQUENCE_LEN = 64

# NOTE: `load_data` names its length filter `maxlen`. This cell used to pass
# `max_len=MAX_SEQUENCE_LEN`, which appears to have been silently ignored: the
# recorded runs below train and evaluate on the full 25,000 reviews. The
# misspelled argument is dropped rather than "corrected", because a real
# `maxlen=64` would discard every review longer than 64 tokens. Sequence
# length is enforced by `pad_sequences` below instead.
(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(
    num_words=VOCABULARY_SIZE,
    skip_top=0,
    seed=113,
    start_char=1,
    oov_char=2,
    index_from=3,
)

# Pad short reviews and truncate long ones so every row has exactly
# MAX_SEQUENCE_LEN token ids.
x_train = pad_sequences(x_train, maxlen=MAX_SEQUENCE_LEN)
x_test = pad_sequences(x_test, maxlen=MAX_SEQUENCE_LEN)
assert x_train.shape[1] == MAX_SEQUENCE_LEN, "unexpected training sequence length"
assert x_test.shape[1] == MAX_SEQUENCE_LEN, "unexpected test sequence length"

word_index = keras.datasets.imdb.get_word_index(path="imdb_word_index.json")


## Implementation with keras.layers.LSTM

In [3]:
# Baseline model: token ids -> learned embeddings -> built-in LSTM -> one logit.
inputs = keras.Input(shape=(MAX_SEQUENCE_LEN,), dtype="int32")
x = layers.Embedding(VOCABULARY_SIZE, EMBED_DIM)(inputs)
x = layers.LSTM(64)(x)
outputs = layers.Dense(1)(x)

model = keras.Model(inputs=inputs, outputs=outputs, name="imdb_model")
model.summary()

# Watch validation loss with a patience of 3 epochs and keep the best weights.
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    patience=3,
    restore_best_weights=True,
)

# The Dense head has no activation, so the loss is told it receives logits.
model.compile(
    optimizer=keras.optimizers.RMSprop(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

fit_options = dict(
    epochs=5,
    batch_size=64,
    validation_split=0.2,
    callbacks=[early_stopping_callback],
)
history = model.fit(x_train, y_train, **fit_options)

# evaluate() returns [loss, accuracy] in the order given to compile().
test_scores = model.evaluate(x_test, y_test, verbose=1)

print(f'Test Loss: {test_scores[0]}')
print(f'Test Accuracy: {test_scores[1]}')

Epoch 1/5
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 38ms/step - accuracy: 0.6172 - loss: 0.5980 - val_accuracy: 0.8096 - val_loss: 0.4002
Epoch 2/5
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 40ms/step - accuracy: 0.8302 - loss: 0.3726 - val_accuracy: 0.8250 - val_loss: 0.3849
Epoch 3/5
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 40ms/step - accuracy: 0.8619 - loss: 0.3122 - val_accuracy: 0.7970 - val_loss: 0.4078
Epoch 4/5
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 35ms/step - accuracy: 0.8837 - loss: 0.2722 - val_accuracy: 0.8212 - val_loss: 0.3960
Epoch 5/5
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 38ms/step - accuracy: 0.8985 - loss: 0.2469 - val_accuracy: 0.8162 - val_loss: 0.4184
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 5ms/step - accuracy: 0.8264 - loss: 0.3835
Test Loss: 0.3794962465763092
Test Accuracy: 0.8289200067520142


## Implementation with custom layers

Just learned now that unlike in the jax docs or my familiar linear algebra textbook, the vectors in Keras are conventionally row vectors, which works out more nicely when batch is the first dimension of input. As a result, matrix multiplications look like Vector @ Matrix (of shape input, output)

Also, a quick thing I noticed: dimensional analysis shows that b has to be broadcast for these formulas to work. b is written as a vector of shape (units,), but it gets added to a batched matrix of shape (batch, units). On paper that is not a legal matrix addition; it only works because broadcasting adds b to every row.

In [7]:
class MyRNN(keras.layers.Layer):
    """Minimal recurrent layer that returns only the final hidden state.

    At every timestep the hidden state is updated as
    ``h = activation(x_t @ wx + h @ wh + b)``, starting from ``h = 0``.

    Args:
        units: Size of the hidden state (and of the layer output).
        activation: Activation applied to the pre-activation at each step.
            Anything accepted by ``keras.activations.get``; defaults to
            ``"tanh"``, the value that was previously hard-coded.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, units, activation="tanh", **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = keras.activations.get(activation)

    def build(self, input_shape):
        """Create the input kernel, bias and recurrent kernel.

        Args:
            input_shape: ``(batch_size, timesteps, input_features)``.
        """
        input_features = input_shape[-1]
        # Creation order (wx, b, wh) is kept so the layer's weight list
        # ordering is unchanged.
        self.wx = self.add_weight(
            shape=(input_features, self.units),
            initializer="random_normal",
            trainable=True,
            name="wx",
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer="zeros",
            trainable=True,
            name="b",
        )
        self.wh = self.add_weight(
            shape=(self.units, self.units),
            initializer="random_normal",
            trainable=True,
            name="wh",
        )
        super().build(input_shape)

    def call(self, inputs):
        """Run the recurrence over axis 1 and return the last hidden state.

        Args:
            inputs: Tensor of shape ``(batch_size, timesteps, input_features)``.

        Returns:
            Tensor of shape ``(batch_size, units)``.
        """
        batch_size = ops.shape(inputs)[0]
        # Build the initial state in the layer's compute dtype so it matches
        # the (auto-cast) inputs and weights, e.g. under mixed precision.
        h = ops.zeros(shape=(batch_size, self.units), dtype=self.compute_dtype)
        for x_t in ops.unstack(inputs, axis=1):
            # self.b has shape (units,) and is broadcast across the batch.
            z = ops.matmul(x_t, self.wx) + ops.matmul(h, self.wh) + self.b
            h = self.activation(z)
        return h

    def compute_output_shape(self, input_shape):
        """Only the final state is returned, so the time axis is dropped."""
        return (input_shape[0], self.units)

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super().get_config()
        config.update(
            {
                "units": self.units,
                "activation": keras.activations.serialize(self.activation),
            }
        )
        return config
            

In [10]:
class MyLSTM(keras.layers.Layer):
    """Hand-written LSTM layer that returns only the final short-term state.

    Each of the four transforms owns an input kernel ``xw*``, a recurrent
    kernel ``sw*`` and a bias ``b*``:

    * 1 - forget gate (sigmoid)
    * 2 - input gate (sigmoid)
    * 3 - candidate values (tanh)
    * 4 - output gate (sigmoid)

    Args:
        units: Size of the short-term and long-term state vectors (and of the
            layer output).
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        """Create the twelve weights of the layer.

        Args:
            input_shape: Shape of the input; its last axis is the feature size.
        """
        input_features = input_shape[-1]

        def kernel(rows, name):
            return self.add_weight(
                shape=(rows, self.units),
                initializer="random_normal",
                trainable=True,
                name=name,
            )

        def bias(name):
            return self.add_weight(
                shape=(self.units,),
                initializer="zeros",
                trainable=True,
                name=name,
            )

        # Creation order (recurrent kernels, input kernels, biases) is kept so
        # the layer's weight list ordering is unchanged.
        self.sw1 = kernel(self.units, "sw1")
        self.sw2 = kernel(self.units, "sw2")
        self.sw3 = kernel(self.units, "sw3")
        self.sw4 = kernel(self.units, "sw4")
        self.xw1 = kernel(input_features, "xw1")
        self.xw2 = kernel(input_features, "xw2")
        self.xw3 = kernel(input_features, "xw3")
        self.xw4 = kernel(input_features, "xw4")
        self.b1 = bias("b1")
        self.b2 = bias("b2")
        self.b3 = bias("b3")
        self.b4 = bias("b4")
        super().build(input_shape)

    def _gate(self, activation, x_t, short_term, input_kernel, recurrent_kernel, gate_bias):
        """Return ``activation(x_t @ input_kernel + short_term @ recurrent_kernel + gate_bias)``."""
        pre_activation = (
            ops.matmul(x_t, input_kernel)
            + ops.matmul(short_term, recurrent_kernel)
            + gate_bias
        )
        return activation(pre_activation)

    def call(self, inputs):
        """Run the LSTM recurrence over axis 1 and return the last short-term state.

        Args:
            inputs: Tensor whose axis 0 is the batch and axis 1 the timesteps.

        Returns:
            Tensor of shape ``(batch_size, units)``.
        """
        sigmoid = keras.activations.sigmoid
        tanh = keras.activations.tanh

        batch_size = ops.shape(inputs)[0]
        state_shape = (batch_size, self.units)
        # Short-term and long-term memory both start at zero, in the layer's
        # compute dtype so they match the (auto-cast) inputs and weights.
        short_term = ops.zeros(shape=state_shape, dtype=self.compute_dtype)
        long_term = ops.zeros(shape=state_shape, dtype=self.compute_dtype)

        for x_t in ops.unstack(inputs, axis=1):
            # All four transforms read the short-term state of the previous step.
            forget = self._gate(sigmoid, x_t, short_term, self.xw1, self.sw1, self.b1)
            admit = self._gate(sigmoid, x_t, short_term, self.xw2, self.sw2, self.b2)
            candidate = self._gate(tanh, x_t, short_term, self.xw3, self.sw3, self.b3)
            expose = self._gate(sigmoid, x_t, short_term, self.xw4, self.sw4, self.b4)

            # Forget part of the old long-term memory, then add the gated candidate.
            long_term = long_term * forget + admit * candidate
            # The new short-term state is a gated view of the updated long-term memory.
            short_term = expose * tanh(long_term)
        return short_term

    def compute_output_shape(self, input_shape):
        """Only the final state is returned, so the time axis is dropped."""
        return (input_shape[0], self.units)

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super().get_config()
        config.update({"units": self.units})
        return config

In [11]:
# Same pipeline as the baseline, with the hand-written recurrent layer swapped in.
inputs = keras.Input(shape=(MAX_SEQUENCE_LEN,), dtype="int32")
x = layers.Embedding(VOCABULARY_SIZE, EMBED_DIM)(inputs)
# Alternative recurrent layer to try instead of the LSTM:
# x = MyRNN(64)(x)
x = MyLSTM(64)(x)

outputs = layers.Dense(1)(x)

model = keras.Model(inputs=inputs, outputs=outputs, name="imdb_model")
model.summary()

# Watch validation loss with a patience of 3 epochs and keep the best weights.
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    patience=3,
    restore_best_weights=True,
)

# The Dense head has no activation, so the loss is told it receives logits.
model.compile(
    optimizer=keras.optimizers.RMSprop(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

fit_options = dict(
    epochs=5,
    batch_size=64,
    validation_split=0.2,
    callbacks=[early_stopping_callback],
)
history = model.fit(x_train, y_train, **fit_options)

# evaluate() returns [loss, accuracy] in the order given to compile().
test_scores = model.evaluate(x_test, y_test, verbose=1)

print(f'Test Loss: {test_scores[0]}')
print(f'Test Accuracy: {test_scores[1]}')

Epoch 1/5
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 30ms/step - accuracy: 0.5408 - loss: 0.6596 - val_accuracy: 0.7948 - val_loss: 0.4762
Epoch 2/5
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 32ms/step - accuracy: 0.8159 - loss: 0.4045 - val_accuracy: 0.8138 - val_loss: 0.4239
Epoch 3/5
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 31ms/step - accuracy: 0.8574 - loss: 0.3375 - val_accuracy: 0.8200 - val_loss: 0.3939
Epoch 4/5
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 31ms/step - accuracy: 0.8650 - loss: 0.3103 - val_accuracy: 0.8240 - val_loss: 0.3965
Epoch 5/5
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 33ms/step - accuracy: 0.8796 - loss: 0.2900 - val_accuracy: 0.8116 - val_loss: 0.3931
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 6ms/step - accuracy: 0.8095 - loss: 0.3935
Test Loss: 0.387563019990921
Test Accuracy: 0.8129199743270874


## Takeaways

I'm pretty proud of that LSTM implementation. Got it in one try by carefully tracing the matrix dimensions. Notably, I could improve it by running more computations at once. For instance, many of the gates can be computed at the same time using one larger matrix. Additionally, the keras.layers.RNN class allows me to instantiate it with a custom Cell containing this new call function. Apparently that would move the for loop out of the Python interpreter and into the low-level backend, making it a lot faster.

Empirically though, I do not notice a difference in speed (I don't have hardware compatible with GPU acceleration), so I don't expect that refactoring exercise to be that satisfying. Additionally, I think I got the pedagogical takeaways for this module.