In [1]:
import tensorflow as tf
import tensorflow.keras as keras
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
from flax.training import train_state
import optax

In [2]:
IMAGE_SIZE = 28 * 28
NUM_CLASSES = 10
TRAIN_BATCH_SIZE = 100
TRAIN_STEPS = 1000

train, test = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()

In [3]:
def create_keras_model():
    model = keras.Sequential([
        keras.layers.Conv2D(32, 3, 1, padding='SAME', input_shape=[28, 28, 1], activation='relu'),
        keras.layers.MaxPool2D(2, padding='SAME'),
        keras.layers.Conv2D(64, 3, 1, padding='SAME', activation='relu'),
        keras.layers.MaxPool2D(2, padding='SAME'),
        keras.layers.Conv2D(128, 3, 1, padding='SAME', activation='relu'),
        keras.layers.MaxPool2D(2, padding='SAME'),
        keras.layers.Flatten(),
        keras.layers.Dense(NUM_CLASSES)
    ])
    model.summary()
    return model


keras.backend.clear_session()
model = create_keras_model()
optimizer = keras.optimizers.Adam()


# Casting from raw data to the required datatypes.
def cast(images, labels):
    images = tf.cast(
        tf.reshape(images, [-1, 28, 28, 1]), tf.float32)
    labels = tf.cast(labels, tf.int64)
    return (images, labels)


@tf.function(jit_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
        predicted_labels = model(images)
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=predicted_labels, labels=labels
        ))
    layer_variables = model.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))


def train():
    steps = 0
    for images, labels in train_ds:
        if steps > TRAIN_STEPS:
            break
        train_mnist(images, labels)
        steps += 1


%timeit train()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 28, 28, 32)        320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 7, 7, 128)         73856     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 4, 4, 128)         0         
_________________________________________________________________
flatten (Flatten)            (None, 2048)              0

In [4]:
class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=128, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=NUM_CLASSES)(x)
        return x


@jax.jit
def update_state(state, grads):
    return state.apply_gradients(grads=grads)


@jax.jit
def train_step(state, images, labels):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, images)
        y = jax.nn.one_hot(labels, NUM_CLASSES)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y))
        return loss
    
    grad_fn = jax.grad(loss_fn)
    grads = grad_fn(state.params)
    return update_state(state, grads)


def train_jax():
    key = jax.random.PRNGKey(666)
    model = Model()
    params = model.init(key, jnp.ones([1, 28, 28, 1]))['params']
    state = train_state.TrainState.create(apply_fn=model.apply,
                                          params=params,
                                          tx=optax.adam(1e-3))
    steps = 0
    for batch in train_ds:
        if steps > TRAIN_STEPS:
            break
        images = np.array(batch[0]).reshape([-1, 28, 28, 1])
        labels = np.array(batch[1])
        state = train_step(state, images, labels)
        steps += 1


%timeit train_jax()

1.78 s ± 43.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
