In [1]:
import os

os.environ["KERAS_BACKEND"] = "jax"

import jax
import numpy as np
import tensorflow as tf
import keras

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P

In [2]:
print(jax.devices())

[cuda(id=0), cuda(id=1)]


In [4]:
def get_model():
    # Make a simple convnet with batch normalization and dropout.
    inputs = keras.Input(shape=(28, 28, 1))
    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
    x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(
        x
    )
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(
        filters=24,
        kernel_size=6,
        use_bias=False,
        strides=2,
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(
        filters=32,
        kernel_size=6,
        padding="same",
        strides=2,
        name="large_k",
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dense(256, activation="relu")(x)
    x = keras.layers.Dropout(0.5)(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    return model


def get_datasets():
    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Scale images to the [0, 1] range
    x_train = x_train.astype("float32")
    x_test = x_test.astype("float32")
    # Make sure images have shape (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # Create TF Datasets
    train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    eval_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    return train_data, eval_data

In [5]:
# Config
num_epochs = 2
batch_size = 64

train_data, eval_data = get_datasets()
train_data = train_data.batch(batch_size, drop_remainder=True)

model = get_model()
optimizer = keras.optimizers.Adam(1e-3)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Initialize all state with .build()
(one_batch, one_batch_labels) = next(iter(train_data))
model.build(one_batch)
optimizer.build(model.trainable_variables)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples


In [7]:
# This is the loss function that will be differentiated.
# Keras provides a pure functional forward pass: model.stateless_call
def compute_loss(trainable_variables, non_trainable_variables, x, y):
    y_pred, updated_non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss_value = loss(y, y_pred)
    return loss_value, updated_non_trainable_variables


# Function to compute gradients
compute_gradients = jax.value_and_grad(compute_loss, has_aux=True)

In [8]:
# Training step, Keras provides a pure functional optimizer.stateless_apply
@jax.jit
def train_step(train_state, x, y):
    trainable_variables, non_trainable_variables, optimizer_variables = train_state
    (loss_value, non_trainable_variables), grads = compute_gradients(
        trainable_variables, non_trainable_variables, x, y
    )

    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )

    return loss_value, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

In [9]:
# Replicate the model and optimizer variable on all devices
def get_replicated_train_state(devices):
    # All variables will be replicated on all devices
    var_mesh = Mesh(devices, axis_names=("_"))
    # In NamedSharding, axes not mentioned are replicated (all axes here)
    var_replication = NamedSharding(var_mesh, P())

    # Apply the distribution settings to the model variables
    trainable_variables = jax.device_put(model.trainable_variables, var_replication)
    non_trainable_variables = jax.device_put(
        model.non_trainable_variables, var_replication
    )
    optimizer_variables = jax.device_put(optimizer.variables, var_replication)

    # Combine all state in a tuple
    return (trainable_variables, non_trainable_variables, optimizer_variables)


In [10]:
num_devices = len(jax.local_devices())
print(f"Running on {num_devices} devices: {jax.local_devices()}")
devices = mesh_utils.create_device_mesh((num_devices,))

# Data will be split along the batch axis
data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
data_sharding = NamedSharding(
    data_mesh,
    P(
        "batch",
    ),
)  # naming axes of the sharded partition

Running on 2 devices: [cuda(id=0), cuda(id=1)]


In [11]:
# Display data sharding
x, y = next(iter(train_data))
sharded_x = jax.device_put(x.numpy(), data_sharding)
print("Data sharding")
jax.debug.visualize_array_sharding(jax.numpy.reshape(sharded_x, [-1, 28 * 28]))

train_state = get_replicated_train_state(devices)

Data sharding


In [14]:
# Custom training loop
for epoch in range(100):
    data_iter = iter(train_data)
    for data in data_iter:
        x, y = data
        sharded_x = jax.device_put(x.numpy(), data_sharding)
        loss_value, train_state = train_step(train_state, sharded_x, y.numpy())
    print("Epoch", epoch, "loss:", loss_value)

Epoch 0 loss: 0.015642429
Epoch 1 loss: 8.512389e-05
Epoch 2 loss: 0.00097128894
Epoch 3 loss: 0.00017838652
Epoch 4 loss: 0.0016941413
Epoch 5 loss: 0.0033031378
Epoch 6 loss: 0.0008292221
Epoch 7 loss: 0.0012198741
Epoch 8 loss: 0.00859246
Epoch 9 loss: 0.00043380394
Epoch 10 loss: 0.0349427
Epoch 11 loss: 8.1968516e-05
Epoch 12 loss: 0.001320726
Epoch 13 loss: 6.179434e-05
Epoch 14 loss: 0.00024736577
Epoch 15 loss: 0.0011571057
Epoch 16 loss: 0.003857566
Epoch 17 loss: 0.07024604
Epoch 18 loss: 2.021938e-05
Epoch 19 loss: 0.00043419388
Epoch 20 loss: 2.0674895e-06
Epoch 21 loss: 0.0019072035
Epoch 22 loss: 4.773264e-05
Epoch 23 loss: 9.284496e-05
Epoch 24 loss: 0.010036009
Epoch 25 loss: 0.009936109
Epoch 26 loss: 3.162212e-05
Epoch 27 loss: 0.0011522755
Epoch 28 loss: 0.00034596364
Epoch 29 loss: 0.006314157
Epoch 30 loss: 0.00015106531
Epoch 31 loss: 2.9397821e-05
Epoch 32 loss: 0.0006282042
Epoch 33 loss: 0.0003507756
Epoch 34 loss: 0.00049265556
Epoch 35 loss: 0.008983544
Epoch