In [1]:
import numpy as np
import jax
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds
import tensorflow as tf
from functools import partial
from jax import random, jit, vmap, value_and_grad
import time

  from .autonotebook import tqdm as notebook_tqdm


# IMDB Sentiment Analysis in jax

## The Dense Neural Network: A naive approach

Currently just realized that building the tokenizer from scratch is a very complicated problem that I will try doing later. For now, we will take the dataset as is. The IMDB reviews are broken up and encoded as an array of numbers [0,2,3] for instance. Each value in the array corresponds to a different word token that can easily be rebuilt by the encoder in the tfds given ds_info. This encoding is very special; a different problem (that has already been solved) is on how to make this encoding efficient and meaningful (so for instance, this is not a given simple word to number assignment).

In [2]:
MAX_LEN = 256
BATCH_SIZE = 64
LEARNING_RATE = 1e-3

def preprocess_data():
    (ds_train, ds_test), ds_info = tfds.load(
        'imdb_reviews/subwords8k',
        split = ['train', 'test'],
        with_info=True,
        as_supervised= True
    )
    
    # for debugging, allows us to read encoded reviews
    tokenizer = ds_info.features['text'].encoder
    # shape of each data input. first one is the review input of length (max_len) and the second is for the label (scalar)
    def truncate_text(text, label):
        return text[:MAX_LEN], label
    # Apply this function to every example in the datasets
    ds_train = ds_train.map(truncate_text)
    ds_test = ds_test.map(truncate_text)
    padded_shapes = ([MAX_LEN], [])
    # 1000 is buffer size. It pulls first 64 from shuffled buffer
    ds_train = ds_train.shuffle(1000).padded_batch(BATCH_SIZE, padded_shapes=padded_shapes)
    # test data is batched because that is more scalable than my previous approach for MNIST
    ds_test = ds_test.padded_batch(BATCH_SIZE, padded_shapes=padded_shapes)
    return ds_train, ds_test, tokenizer

layer_sizes = [MAX_LEN, 128, 128, 1]

def init_mlp_params(layer_sizes, key):
    params = []
    keys = random.split(key, len(layer_sizes) - 1)
    for nin, nout, layer_key in zip(layer_sizes[:-1], layer_sizes[1:], keys):
        w_key, b_key = random.split(layer_key)
        layer_params = {
            'w': random.normal(w_key, (nout, nin)),
            'b': jnp.zeros((nout,))
        }
        params.append(layer_params)
    return params

def mlp_apply(params, inputs):
    x = inputs
    for layer_params in params[:-1]:
        z = layer_params['w'] @ x + layer_params['b']
        x = jax.nn.relu(z)
    final_layer_params = params[-1]
    output = final_layer_params['w'] @ x + final_layer_params['b']
    output = jax.nn.sigmoid(output)
    return output
    
def loss_fn(params, inputs, targets):
    predictions = vmap(mlp_apply, in_axes=(None,0))(params, inputs)
    return jnp.mean((jnp.squeeze(predictions)-targets) ** 2)

weight_decay = 0.0001
solver = optax.adamw(learning_rate = 0.001, weight_decay=weight_decay)
@jit
def train_step(params, inputs, targets, opt_state):
    loss, grads = value_and_grad(loss_fn)(params, inputs, targets)
    updates, opt_state = solver.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return loss, new_params, opt_state

def test_accuracy(params, ds_test):
    batch_accuracies = []
    for batch in tfds.as_numpy(ds_test):
        inputs, labels = batch
        batch_predictions = vmap(mlp_apply, in_axes=(None,0))(params,inputs)
        batch_accuracies.append(jnp.mean(batch_predictions.round() == labels))
    return sum(batch_accuracies) / len(batch_accuracies)

In [3]:
ds_train, ds_test, tokenizer = preprocess_data()



In [4]:
key = random.PRNGKey(42)
params = init_mlp_params(layer_sizes, key)
opt_state = solver.init(params)

epochs = 50

# total_training_start_time = time.time()
# for epoch in range(epochs):
#     for test_batch in tfds.as_numpy(ds_train):
#         reviews, labels = test_batch
#         loss, params, opt_state = train_step(params, reviews, labels, opt_state)
#     if epoch % 10 == 0:
#         print(f'accuracy at epoch {epoch}: {test_accuracy(params, ds_test)}')
# print(f'accuracy at epoch 50: {test_accuracy(params, ds_test)}')
# total_training_duration = time.time() - total_training_start_time
# print(f'Training finished in {total_training_duration:.2f}s')

## The Dense Neural Network with upgrades
1. Add an embedding layer. This translates meaningless token values (50 for 'movie' for instance) to a mapped vector of size EMBED_DIM. 
   1. Instead of inputting a (MAX_LEN * EMBED_DIM) size vector after flattening this input, we summarize with global average pooling: we take the mean of the (MAX_LEN, EMBED_DIM) vector into a (EMBED_DIM) vector
   2. This input does not maintain the order of the words, simple meanings associated with each word
   3. Include Embed layer in params, and also before MLP in forward pass
2. Incorporate Binary Cross Entropy for Loss
   1. Sigmoid converts values into probalities between 0 and 1
   2. Previous attempt uses p or 1-p squared to calculate loss
   3. BCE uses -log(1-p) to calculuate loss. This non 0-1 bounded output makes penalties bigger near the tails (0 and 1)
   4. We still keep just the sigmoid for accuracy
   5. notably, we apply sigmoid and bce at the same time (sigmoid_binary_cross_entropy) because library designers made this more robust for floating point arithmetic
      1. Means we predict the logits, and apply sigmoid in accuracy, sigmoid_binary_cross_entropy in loss
3. Tune Hyperparameters
   1. Will try [200, 256, 128, 1]
4. Implement Dropout for overfitting
5. Integrate Pre-trained GloVe embeddings
   1. Will be important to translate this to the tokenization system that subwords 8k already did and handle out of vocabulary words

we keep the same preprocess data function

Due to the complexity, I think now is a good time to switch to Keras

In [None]:
MAX_LEN = 256
BATCH_SIZE = 64
EMBED_DIM = 200
VOCAB_SIZE = 10000
LEARNING_RATE = 1e-3
NUM_EPOCHS = 5

mlp_layer_sizes = [EMBED_DIM, 256, 128, 1]

def init_mlp_params(layer_sizes, key):
    mlp_params = []
    embed_key, mlp_key = random.split(key)
    embedding_matrix = random.normal(embed_key, (VOCAB_SIZE, EMBED_DIM))
    mlp_keys = random.split(mlp_key, len(layer_sizes) - 1)
    for nin, nout, layer_key in zip(layer_sizes[:-1], layer_sizes[1:], mlp_keys):
        w_key, b_key = random.split(layer_key)
        layer_params = {
            'w': random.normal(w_key, (nout, nin)),
            'b': jnp.zeros((nout,))
        }
        mlp_params.append(layer_params)
    return {
        'embedding': embedding_matrix,
        'mlp': mlp_params
    }

def mlp_apply(params, inputs):
    x = params['embedding'][inputs]
    x = jnp.mean(x, axis=0)
    for layer_params in params['mlp'][:-1]:
        z = layer_params['w'] @ x + layer_params['b']
        x = jax.nn.relu(z)
    final_layer_params = params['mlp'][-1]
    output = final_layer_params['w'] @ x + final_layer_params['b']
    return output

def loss_fn(params, inputs, targets):
    prediction_logits = vmap(mlp_apply, in_axes=(None,0))(params, inputs)
    prediction_logits = jnp.squeeze(prediction_logits)
    return jnp.mean(optax.sigmoid_binary_cross_entropy(prediction_logits, targets))

weight_decay = 0.0001
solver = optax.adamw(learning_rate = 0.001, weight_decay=weight_decay)
@jit
def train_step(params, inputs, targets, opt_state):
    loss, grads = value_and_grad(loss_fn)(params, inputs, targets)
    updates, opt_state = solver.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return loss, new_params, opt_state

def test_accuracy(params, ds_test):
    batch_accuracies = []
    for batch in tfds.as_numpy(ds_test):
        inputs, labels = batch
        logits = vmap(mlp_apply, in_axes=(None,0))(params,inputs)
        probabilities = jax.nn.sigmoid(logits)
        predictions = jnp.squeeze(probabilities).round()
        batch_accuracies.append(jnp.mean(predictions == labels))
    return jnp.mean(jnp.array(batch_accuracies))

In [6]:
key = random.PRNGKey(42)
params = init_mlp_params(mlp_layer_sizes, key)
opt_state = solver.init(params)

epochs = 50
total_training_start_time = time.time()
for epoch in range(epochs):
    epoch_losses = []
    for test_batch in tfds.as_numpy(ds_train):
        reviews, labels = test_batch
        loss, params, opt_state = train_step(params, reviews, labels, opt_state)
        epoch_losses.append(loss)
    avg_loss = np.mean(epoch_losses)
    if epoch % 10 == 0:
        accuracy = test_accuracy(params, ds_test)
        print(f'After epoch {epoch}: Avg Loss = {avg_loss:.4f}, Accuracy = {accuracy:.4f}')
print(f'accuracy at epoch 50: {test_accuracy(params, ds_test)}')
total_training_duration = time.time() - total_training_start_time
print(f'Training finished in {total_training_duration:.2f}s')

After epoch 0: Avg Loss = 36.8940, Accuracy = 0.5988
After epoch 10: Avg Loss = 0.5627, Accuracy = 0.7163
After epoch 20: Avg Loss = 0.1415, Accuracy = 0.7389
After epoch 30: Avg Loss = 0.0996, Accuracy = 0.7555
After epoch 40: Avg Loss = 0.0592, Accuracy = 0.7646
accuracy at epoch 50: 0.7655929923057556
Training finished in 116.62s
