In [1]:
import os

# Keras picks its backend once, at the moment `keras` is first imported, so
# KERAS_BACKEND has to be set BEFORE any `import keras` line. Assigning it
# after the imports (as this cell used to) leaves whatever backend was already
# selected in place. If Keras was already imported in this kernel, restart the
# kernel for the setting to take effect.
os.environ["KERAS_BACKEND"] = "jax"

import numpy as np
import keras
from keras import layers
from keras import ops
from keras.utils import pad_sequences
from keras.callbacks import EarlyStopping
import tensorflow as tf
from keras.layers import TextVectorization



# Learning the Keras 3 Abstraction

This notebook re-implements the jax_imbdb_mlp model I made earlier, this time while exploring the Keras 3 framework. I will stick with the JAX backend and try to break down each step, so that I keep full confidence in what is happening without becoming reliant on the Keras 3 high-level tools.

There is still new material to pick up, though:
1. Dropout, which applies a random mask over the previous layer's activations during training
2. The early-stopping callback, which monitors a validation set
3. Incorporating a pre-trained set of word embeddings (absolute hell)

## With the functional API

In [20]:

# Hyperparameters for this cell.
EMBED_DIM = 200          # must agree with the "200d" GloVe file named below
VOCABULARY_SIZE = 10000  # number of distinct token ids the model knows about
MAX_SEQUENCE_LEN = 256   # every review is padded/truncated to this many tokens

path_to_glove_file = './pre_trained_models/glove.6B.200d.txt'

# Each line of the GloVe text file is "<token> <v1> <v2> ...". Split off the
# token, parse the remainder into a float32 vector, and index vectors by token.
with open(path_to_glove_file, encoding="utf8") as glove_file:
    token_and_vector_text = (
        raw_line.split(maxsplit=1) for raw_line in glove_file
    )
    embeddings_index = {
        token: np.fromstring(vector_text, "f", sep=" ")
        for token, vector_text in token_and_vector_text
    }
print(f"Found {len(embeddings_index)} word vectors.")

# The IMDB reviews come already tokenized: each review is a list of integer
# word ids (0 = padding, start_char = 1, oov_char = 2, real words from 4 up
# because word ranks start at 1 and are shifted by index_from = 3).
#
# NOTE: `load_data` has no `max_len` argument -- its length filter is spelled
# `maxlen`. The `max_len=256` that used to be passed here was swallowed by
# `**kwargs` and filtered nothing (the recorded run trains on 313 steps of 64,
# i.e. the full 20,000-review training split). It is removed so the call says
# what it actually does; `pad_sequences` below is the only place where the
# sequence length is enforced.
(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(
    num_words=VOCABULARY_SIZE,
    skip_top=0,
    seed=113,
    start_char=1,
    oov_char=2,
    index_from=3,
)

# With the default padding="pre"/truncating="pre", short reviews are
# left-padded with 0 and long reviews keep their LAST MAX_SEQUENCE_LEN tokens.
x_train = pad_sequences(x_train, maxlen=MAX_SEQUENCE_LEN)
x_test = pad_sequences(x_test, maxlen=MAX_SEQUENCE_LEN)

# word -> frequency rank (before the index_from shift applied by load_data).
word_index = keras.datasets.imdb.get_word_index()

# Initial weights for the Embedding layer: row r holds the GloVe vector of the
# word whose token id is r. Rows for padding/special ids and for words that
# GloVe does not know stay all-zero.
embedding_matrix = np.zeros((VOCABULARY_SIZE, EMBED_DIM))
for word, rank in word_index.items():
    # `word_index` stores raw ranks; the token ids in x_train/x_test were
    # shifted by index_from=3 when the dataset was loaded.
    token_id = rank + 3
    if token_id >= VOCABULARY_SIZE:
        continue  # outside the vocabulary kept by load_data(num_words=...)
    glove_vector = embeddings_index.get(word)
    if glove_vector is None:
        continue  # no pretrained vector for this word
    embedding_matrix[token_id] = glove_vector

# Functional-API graph:
#   token ids -> GloVe-initialised embedding -> average over the sequence
#   -> Dropout(0.5) -> 2 x [Dense(128, relu) -> Dropout(0.3)] -> Dense(1)
inputs = keras.Input(shape=(MAX_SEQUENCE_LEN,), dtype='int32')

glove_embedding = keras.layers.Embedding(
    input_dim=VOCABULARY_SIZE,
    output_dim=EMBED_DIM,
    embeddings_initializer=keras.initializers.Constant(embedding_matrix),
    trainable=True,  # the pretrained vectors are fine-tuned during training
)
x = glove_embedding(inputs)
x = keras.layers.GlobalAveragePooling1D()(x)
x = keras.layers.Dropout(0.5)(x)

# Two identical hidden stages, each followed by its own dropout.
for drop_rate in (0.3, 0.3):
    x = keras.layers.Dense(128, activation='relu')(x)
    x = keras.layers.Dropout(drop_rate)(x)

# No activation on the head: it emits a raw logit, which is what the
# BinaryCrossentropy(from_logits=True) loss used at compile time expects.
outputs = keras.layers.Dense(1)(x)

model = keras.Model(inputs=inputs, outputs=outputs, name='imdb_model')
model.summary()

# Stop once val_loss has failed to improve for 3 consecutive epochs and roll
# the model back to the weights of the best epoch seen.
early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

# The model outputs logits, hence from_logits=True on the loss.
binary_loss = keras.losses.BinaryCrossentropy(from_logits=True)
rmsprop = keras.optimizers.RMSprop()
model.compile(loss=binary_loss, optimizer=rmsprop, metrics=['accuracy'])

# validation_split holds out a fraction of the training arrays for val_loss.
fit_options = dict(
    epochs=20,
    batch_size=64,
    validation_split=0.2,
    callbacks=[early_stopping_callback],
)
history = model.fit(x_train, y_train, **fit_options)

# evaluate() returns [loss, accuracy], matching the compile() configuration.
test_scores = model.evaluate(x_test, y_test, verbose=1)

print(f'Test Loss: {test_scores[0]}')
print(f'Test Accuracy: {test_scores[1]}')

Found 400000 word vectors.


Epoch 1/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 13ms/step - accuracy: 0.5128 - loss: 0.6880 - val_accuracy: 0.6972 - val_loss: 0.6164
Epoch 2/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 12ms/step - accuracy: 0.6472 - loss: 0.6052 - val_accuracy: 0.6536 - val_loss: 0.5281
Epoch 3/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 12ms/step - accuracy: 0.7186 - loss: 0.5317 - val_accuracy: 0.8140 - val_loss: 0.4347
Epoch 4/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 13ms/step - accuracy: 0.7664 - loss: 0.4726 - val_accuracy: 0.7974 - val_loss: 0.3998
Epoch 5/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 12ms/step - accuracy: 0.7778 - loss: 0.4560 - val_accuracy: 0.8132 - val_loss: 0.4267
Epoch 6/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.7942 - loss: 0.4291 - val_accuracy: 0.8408 - val_loss: 0.3747
Epoch 7/20
[1m313/313

## Cracking it open if we can?

So that was really easy. Let's make it hard! This object-oriented approach uses model subclassing, which allows for custom training/evaluation logic and for a dynamic forward pass (`call`) as opposed to a static graph. For IMDB we are just copying the same structure, but I want to use this method later because:
- In the forward pass, you can use conditional logic and for loops; for example, an "Imitation Learning" model might choose to use one of its internal "expert" sub-networks based on an if condition on the input data.
- In the train step, we can write more complex update logic, such as for a Generative Adversarial Network, which requires two optimizers and two model updates side by side.

The drawbacks include:
- model.summary() cannot see inside the call() function, so it can't show how the layers are connected.
- Less portable. Parameters are easy to save, but the architecture is not: it is not serializable out of the box, so recreating it requires the code.
- Later error checking. The functional API constructs the entire graph before any data is passed through, so mistakes are caught early; that is not possible with this imperative method.
- Slightly more boilerplate, since it is object-oriented.

In [21]:
# Hyperparameters for the subclassed model. EMBED_DIM is rebound here: this
# model learns its embeddings from scratch, so it is not tied to the width of
# the GloVe vectors used in the functional-API cell.
EMBED_DIM = 256
VOCABULARY_SIZE = 10_000
MAX_SEQUENCE_LEN = 256

class ImdbModel(keras.Model):
    """Subclassed-API version of the IMDB sentiment classifier.

    Architecture: Embedding(VOCABULARY_SIZE, EMBED_DIM) -> average over the
    sequence axis -> Dropout(0.5) -> Dense(128, relu) -> Dropout(0.3)
    -> Dense(128, relu) -> Dropout(0.3) -> Dense(1).

    The final layer has no activation, so the model returns one logit per
    review; pair it with a loss configured with ``from_logits=True``.

    Weights are created lazily, the first time the model is called on data
    (for example by ``fit``). The earlier version of ``__init__`` ended with
    ``self.build(keras.Input(...))``; that handed a symbolic tensor to
    ``build()``, which expects a shape, and for a subclassed model without its
    own ``build()`` the inherited default only marks the model as built
    without creating any sub-layer weights. The call is therefore removed.

    Args:
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Layers are only declared here; how they connect is defined in call().
        self.embedding = keras.layers.Embedding(VOCABULARY_SIZE, EMBED_DIM)
        self.globalAveragePooling = keras.layers.GlobalAveragePooling1D()
        self.dropout1 = keras.layers.Dropout(0.5)
        self.dense1 = keras.layers.Dense(128, activation='relu')
        self.dropout2 = keras.layers.Dropout(0.3)
        self.dense2 = keras.layers.Dense(128, activation='relu')
        self.dropout3 = keras.layers.Dropout(0.3)
        self.dense3 = keras.layers.Dense(1)

    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: Batch of integer token ids, one padded review per row.

        Returns:
            The output of the final Dense(1) layer: one logit per review.
        """
        x = self.embedding(inputs)
        x = self.globalAveragePooling(x)
        x = self.dropout1(x)
        x = self.dense1(x)
        x = self.dropout2(x)
        x = self.dense2(x)
        x = self.dropout3(x)
        x = self.dense3(x)
        return x
        
# Same training recipe as the functional-API model, applied to the subclassed
# one. Note that this rebinds `model`.
model = ImdbModel()

# Halt when val_loss stalls for 3 epochs; keep the best epoch's weights.
early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

model.compile(
    optimizer=keras.optimizers.RMSprop(),
    # ImdbModel returns logits, so the loss applies the sigmoid itself.
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

training_callbacks = [early_stopping_callback]
history = model.fit(
    x_train, y_train,
    epochs=20, batch_size=64,
    validation_split=0.2,
    callbacks=training_callbacks,
)

# evaluate() returns [loss, accuracy], matching the compile() configuration.
test_scores = model.evaluate(x_test, y_test, verbose=1)

print(f'Test Loss: {test_scores[0]}')
print(f'Test Accuracy: {test_scores[1]}')

Epoch 1/20




[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 14ms/step - accuracy: 0.5010 - loss: 0.6922 - val_accuracy: 0.5118 - val_loss: 0.7036
Epoch 2/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 13ms/step - accuracy: 0.5619 - loss: 0.6516 - val_accuracy: 0.6898 - val_loss: 0.5934
Epoch 3/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 13ms/step - accuracy: 0.6779 - loss: 0.5573 - val_accuracy: 0.6038 - val_loss: 0.8062
Epoch 4/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 12ms/step - accuracy: 0.7374 - loss: 0.4988 - val_accuracy: 0.8442 - val_loss: 0.3966
Epoch 5/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.7652 - loss: 0.4590 - val_accuracy: 0.6848 - val_loss: 0.6144
Epoch 6/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 13ms/step - accuracy: 0.8004 - loss: 0.4153 - val_accuracy: 0.8526 - val_loss: 0.3866
Epoch 7/20
[1m313/313[0m [32m━

In [22]:
# Print the layer/parameter table for the current `model` (the subclassed
# ImdbModel when the cells are run in order).
model.summary()