## Importing the Required Libraries

In [1]:
from tensorflow.keras import layers
from tensorflow import keras
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import io

# Only enable this for tensor-core GPUs
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy("mixed_float16")

# Hyperparameters

In [2]:
RESIZE_TO = 32
PATCH_SIZE = 32

NUM_MIXER_LAYERS = 2
HIDDEN_SIZE = 64
MLP_SEQ_DIM = 64
MLP_CHANNEL_DIM = 64

EPOCHS = 5
BATCH_SIZE = 128

# Dataset

In [3]:
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [4]:
x_train[0].shape

(32, 32, 3)

## MLP-Mixer Utilities

In [5]:
def mlp_block(x, mlp_dim):
    x = layers.Dense(mlp_dim)(x)
    x = tf.nn.gelu(x)
    return layers.Dense(x.shape[-1])(x)

def mixer_block(x, tokens_mlp_dim, channels_mlp_dim):
    y = layers.LayerNormalization()(x)
    y = layers.Permute((2, 1))(y)
    
    token_mixing = mlp_block(y, tokens_mlp_dim)
    token_mixing = layers.Permute((2, 1))(token_mixing)
    x = layers.Add()([x, token_mixing])
    
    y = layers.LayerNormalization()(x)
    channel_mixing = mlp_block(y, channels_mlp_dim)
    output = layers.Add()([x, channel_mixing])
    return output

def mlp_mixer(x, num_blocks, patch_size, hidden_dim, 
              tokens_mlp_dim, channels_mlp_dim,
              num_classes=10):
    x = layers.Conv2D(hidden_dim, kernel_size=patch_size,
                      strides=patch_size, padding="valid")(x)
    x = layers.Reshape((x.shape[1]*x.shape[2], x.shape[3]))(x)

    for _ in range(num_blocks):
        x = mixer_block(x, tokens_mlp_dim, channels_mlp_dim)
    
    x = layers.LayerNormalization()(x)
    x = layers.Dropout(0.25)(x)
    x = layers.GlobalAveragePooling1D()(x)
    return layers.Dense(num_classes, activation="softmax", dtype="float32")(x)

In [6]:
def create_mlp_mixer():
    inputs = layers.Input(shape=(32, 32, 3))
    outputs = mlp_mixer(inputs, NUM_MIXER_LAYERS,
                        PATCH_SIZE, HIDDEN_SIZE, 
                        MLP_SEQ_DIM, MLP_CHANNEL_DIM)
    return tf.keras.Model(inputs, outputs, name="mlp_mixer")

In [7]:
mlp_mixer_classifier = create_mlp_mixer()

In [8]:
def run_experiment(model):
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"]
    )

    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )
    
    
    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, top_1_accuracy = model.evaluate(x_test, y_test)
    print()
    print(f"Test accuracy: {round(top_1_accuracy * 100, 2)}%")
    
    return history, model

## Model Training and Evaluation

In [9]:
history, model = run_experiment(mlp_mixer_classifier)
model.save(f"mlp_mixer_{NUM_MIXER_LAYERS}")

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5

Test accuracy: 40.08%
