# JAX-native distribution with a Keras model

In [None]:
!pip install --upgrade -q tensorflow
!pip install --upgrade keras

In [None]:
!pip install tqdm

In [4]:
# Select the JAX backend. This has to run before `keras` is imported below,
# since Keras picks its backend when it is first imported.
import os
os.environ.update(KERAS_BACKEND='jax')

In [5]:
import jax
import jax.numpy as jnp
import tensorflow as tf # just for tf.data
import keras # Keras multi-backend

import numpy as np
import collections
from tqdm import tqdm

keras_version = keras.__version__  # multi-backend Keras 3.x is expected here
print(keras_version)

2024-01-25 06:49:00.626221: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-25 06:49:00.626273: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-25 06:49:00.628363: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


3.0.4


# Dataset
Classic MNIST, loaded with `keras.datasets` and fed to the model through a `tf.data` pipeline.

In [6]:
BATCH_SIZE = 192

(x_train, train_labels), (x_eval, eval_labels) = keras.datasets.mnist.load_data()
# Add a trailing channel axis: 28x28 -> 28x28 x 1 color channel (B&W), as float32.
x_train = np.expand_dims(x_train, axis=-1).astype(np.float32)
x_eval = np.expand_dims(x_eval, axis=-1).astype(np.float32)

# Training pipeline: shuffled, fixed-size batches, repeated forever.
# drop_remainder=True keeps every batch exactly BATCH_SIZE long (static shapes).
train_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels))
train_data = train_data.shuffle(5000, reshuffle_each_iteration=True)
train_data = train_data.batch(BATCH_SIZE, drop_remainder=True)
train_data = train_data.repeat()

# Evaluation pipeline: everything as one batch. The batch size is derived from
# the data (not hard-coded) so that `validation_steps=1` always covers the
# whole eval set.
eval_data = tf.data.Dataset.from_tensor_slices((x_eval, eval_labels))
eval_data = eval_data.batch(len(eval_labels))

# Full batches per pass over the training set (the partial batch is dropped).
STEPS_PER_EPOCH = len(train_labels) // BATCH_SIZE
print(train_data.element_spec)

(TensorSpec(shape=(192, 28, 28, 1), dtype=tf.float32, name=None), TensorSpec(shape=(192,), dtype=tf.uint8, name=None))


# Keras model
Simple but non-trivial model with:
* Batch Normalization (non-trainable state updated during training, different training-time and inference behavior)
* Dropout (randomness, different training time and inference behavior)

In [7]:
# Keras "sequential" model building style
def make_backbone():
  """Convolutional feature extractor used by every model in this notebook.

  Input pixels are rescaled from [0, 255] to [0, 1], then passed through three
  Conv2D -> BatchNormalization -> ReLU stages. The last two convolutions use
  stride 2, and the last one is named 'large_k' so it can be looked up by name.
  """
  # (filters, kernel_size, strides, conv layer name) for each stage
  conv_stages = [
      (12, 3, 1, None),
      (24, 6, 2, None),
      (32, 6, 2, 'large_k'),
  ]

  stack = [keras.layers.Rescaling(1./255.)]  # input images are in the range [0, 255]
  for filters, kernel_size, strides, conv_name in conv_stages:
    # The conv has no bias: the following BatchNormalization (center=True) adds one.
    stack.append(keras.layers.Conv2D(filters=filters, kernel_size=kernel_size,
                                     padding='same', use_bias=False,
                                     strides=strides, name=conv_name))
    stack.append(keras.layers.BatchNormalization(scale=False, center=True))
    stack.append(keras.layers.Activation('relu'))

  return keras.Sequential(stack, name="backbone")

# Keras "functional" model building style: adding a classification head
# (`inputs`, not `input`, to avoid shadowing the Python builtin)
inputs = keras.Input(shape=[28, 28, 1])
y = make_backbone()(inputs)
y = keras.layers.Flatten()(y)
y = keras.layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = keras.layers.Dropout(0.4)(y)
y = keras.layers.Dense(units=10, activation='softmax')(y)
model = keras.Model(inputs=inputs, outputs=y)

model.summary(expand_nested=True)

In [8]:
# Optimizer, loss, metrics and learning-rate schedule.
# The learning rate starts at 0.01 and decays (continuously, staircase is off)
# by a factor of 0.6 every STEPS_PER_EPOCH steps, i.e. per epoch.
lr = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.01,
    decay_steps=STEPS_PER_EPOCH,
    decay_rate=0.6,
)
optimizer = keras.optimizers.Adam(learning_rate=lr)
model.compile(
    optimizer=optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Default Keras trainer (optional)
Keras offers a default trainer (`model.fit`). You can also write a custom training loop (see further down).

In [9]:
EPOCHS = 5
# train_data repeats forever, so steps_per_epoch is what delimits an epoch;
# eval_data is a single batch, hence validation_steps=1.
history = model.fit(
    train_data,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=eval_data,
    validation_steps=1,
)

Epoch 1/5
[1m312/312[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 17ms/step - loss: 0.9698 - sparse_categorical_accuracy: 0.8040 - val_loss: 0.2725 - val_sparse_categorical_accuracy: 0.9209
Epoch 2/5
[1m312/312[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.0945 - sparse_categorical_accuracy: 0.9730 - val_loss: 0.0813 - val_sparse_categorical_accuracy: 0.9760
Epoch 3/5
[1m 52/312[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m0s[0m 3ms/step - loss: 0.0605 - sparse_categorical_accuracy: 0.9826

  for step, data in epoch_iterator.enumerate_epoch():


[1m312/312[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.0558 - sparse_categorical_accuracy: 0.9839 - val_loss: 0.0325 - val_sparse_categorical_accuracy: 0.9888
Epoch 4/5
[1m312/312[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.0403 - sparse_categorical_accuracy: 0.9882 - val_loss: 0.0380 - val_sparse_categorical_accuracy: 0.9887
Epoch 5/5
[1m312/312[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.0286 - sparse_categorical_accuracy: 0.9917 - val_loss: 0.0217 - val_sparse_categorical_accuracy: 0.9927


# Custom Keras layer
Multi-backend compatible.<BR/>
Keras provides several multi-backend abstractions:
* `Layer.add_weight` for creating trainable variables
* `keras.ops` for multi-backend math ops

All standard Keras elements are also multi-backend: `keras.initializers`, `keras.layers`, `keras.optimizers`, etc.

In [10]:
class MyDense(keras.Layer):
    """Minimal fully-connected layer computing ``inputs @ kernel + bias`` (no activation).

    Only multi-backend Keras APIs are used (``add_weight``, ``keras.ops``), so the
    layer runs on any Keras backend.

    Args:
        units: number of output features.
        name: optional layer name.
        **kwargs: forwarded to ``keras.Layer`` (e.g. ``dtype``, ``trainable``), so
            base-layer options can be set and the layer can be re-created from its
            serialized config, which includes such base-layer keys.
    """

    def __init__(self, units, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.units = units

    # Weights can be instantiated either in __init__ or in build.
    # In build, the input shape of the layer is available automatically.
    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.w = self.add_weight(name="kernel", shape=(input_dim, self.units),
                                 initializer="GlorotUniform")
        self.b = self.add_weight(name="bias", shape=(self.units,),
                                 initializer="Zeros")

    def call(self, inputs):
        return keras.ops.matmul(inputs, self.w) + self.b

In [11]:
# model with new classification head using MyDense
# (`inputs`, not `input`, to avoid shadowing the Python builtin)
inputs = keras.Input(shape=[28, 28, 1])
y = make_backbone()(inputs)
y = keras.layers.Flatten()(y)
y = MyDense(200)(y)
y = keras.layers.Activation('relu')(y)
y = keras.layers.Dropout(0.4)(y)
y = MyDense(10)(y)
y = keras.layers.Activation('softmax')(y)
my_model = keras.Model(inputs=inputs, outputs=y)

In [12]:
# Train the MyDense model with the same recipe as the first model
# (a fresh Adam optimizer sharing the same learning-rate schedule).
optimizer = keras.optimizers.Adam(lr)
my_model.compile(
    optimizer=optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
EPOCHS = 5
history = my_model.fit(
    train_data,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=eval_data,
    validation_steps=1,
)

Epoch 1/5
[1m312/312[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 16ms/step - loss: 1.0293 - sparse_categorical_accuracy: 0.8116 - val_loss: 0.0946 - val_sparse_categorical_accuracy: 0.9714
Epoch 2/5
[1m312/312[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.0953 - sparse_categorical_accuracy: 0.9721 - val_loss: 0.0866 - val_sparse_categorical_accuracy: 0.9746
Epoch 3/5
[1m312/312[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.0601 - sparse_categorical_accuracy: 0.9831 - val_loss: 0.0346 - val_sparse_categorical_accuracy: 0.9889
Epoch 4/5
[1m312/312[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.0447 - sparse_categorical_accuracy: 0.9887 - val_loss: 0.0415 - val_sparse_categorical_accuracy: 0.9888
Epoch 5/5
[1m312/312[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.0322 - sparse_categorical_accuracy: 0.9903 - val_loss: 0.0282 - val_sparse_categorical_accuracy: 0.9908


# Custom training loop
Written in Keras + JAX


In [13]:
# This cell is all Keras

def make_model():
    """Build a fresh, untrained classifier: make_backbone() followed by a dense head.

    Returns a functional ``keras.Model`` taking (28, 28, 1) images and producing
    10 softmax outputs.
    """
    inputs = keras.Input(shape=[28, 28, 1])  # not `input`: avoid shadowing the builtin
    y = make_backbone()(inputs)
    y = keras.layers.Flatten()(y)
    y = keras.layers.Dense(200, activation="relu")(y)
    y = keras.layers.Dropout(0.4)(y)
    y = keras.layers.Dense(10, activation='softmax')(y)
    return keras.Model(inputs=inputs, outputs=y)

# instantiate the model again
model = make_model()

# optimizer
optimizer = keras.optimizers.Adam(lr)

# initialize all state with .build()
# Model.build takes an input *shape*, not a tensor; the labels are not needed.
one_batch = next(iter(train_data))[0]
model.build(tuple(one_batch.shape))
optimizer.build(model.trainable_variables)

# collect state in a handy named tuple
TrainingState = collections.namedtuple(
    'TrainingState',
    ['trainable_variables', 'non_trainable_variables', 'optimizer_variables'])

train_state = TrainingState(trainable_variables=model.trainable_variables,
                            non_trainable_variables=model.non_trainable_variables,
                            optimizer_variables=optimizer.variables)

The next cell uses JAX. Keras provides two pure (stateless) functions that make this possible:
* `model.stateless_call`
* `optimizer.stateless_apply`

These functions also work on other backends.

In [14]:
# This cell is all JAX,
# using a Keras loss as well as Keras stateless functions
# model.stateless_call and optimizer.stateless_apply

# define loss
# Named `loss_fn` (not `loss`): compute_loss looks this global up whenever JAX
# (re)traces it, and later cells unpack `model.evaluate(...)` results, which
# would otherwise overwrite the loss function with a float.
loss_fn = keras.losses.SparseCategoricalCrossentropy()

# This is the loss function that will be differentiated.
# Keras provides a pure functional forward pass: model.stateless_call
def compute_loss(trainable_variables, non_trainable_variables, x, y):
    """Pure forward pass + loss for the global `model`.

    Returns ``(loss_value, updated_non_trainable_variables)``; the second element
    is the auxiliary output consumed by ``jax.value_and_grad(..., has_aux=True)``.
    """
    # training=True: BatchNormalization uses batch statistics and updates its
    # moving averages (returned in updated_non_trainable_variables), and Dropout
    # is active. Without it the forward pass runs in inference mode.
    y_pred, updated_non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True)
    loss_value = loss_fn(y, y_pred)
    return loss_value, updated_non_trainable_variables

# function to compute gradients (w.r.t. the first argument, trainable_variables)
compute_gradients = jax.value_and_grad(compute_loss, has_aux=True)

# Training step
# Keras provides a pure functional optimizer.stateless_apply
@jax.jit
def train_step(train_state, x, y):
    """One jitted optimization step on the batch (x, y).

    Takes a TrainingState and returns ``(loss_value, new TrainingState)``.
    Nothing is mutated; the caller threads the returned state into the next call.
    """
    trainable = train_state.trainable_variables
    non_trainable = train_state.non_trainable_variables
    opt_state = train_state.optimizer_variables

    (loss_value, new_non_trainable), grads = compute_gradients(
        trainable, non_trainable, x, y)
    new_trainable, new_opt_state = optimizer.stateless_apply(
        opt_state, grads, trainable)

    return loss_value, TrainingState(
        trainable_variables=new_trainable,
        non_trainable_variables=new_non_trainable,
        optimizer_variables=new_opt_state,
    )

In [15]:
# Custom training loop: draw batches from the (infinitely repeating) dataset
# and thread the TrainingState through the jitted train_step.
data_iter = iter(train_data)
for epoch in range(EPOCHS):
    for _step in tqdm(range(STEPS_PER_EPOCH)):
        x, y = next(data_iter)
        # hand JAX plain NumPy arrays rather than tf tensors
        step_result = train_step(train_state, x.numpy(), y.numpy())
        loss_value, train_state = step_result
    # loss of the epoch's last batch, not an epoch average
    print("Epoch", epoch, "loss:", loss_value)

100%|█████████████████████████████████████████| 312/312 [00:03<00:00, 80.00it/s]


Epoch 0 loss: 0.007071685


100%|████████████████████████████████████████| 312/312 [00:00<00:00, 789.35it/s]


Epoch 1 loss: 0.025994292


100%|████████████████████████████████████████| 312/312 [00:00<00:00, 805.98it/s]


Epoch 2 loss: 0.0018344952


100%|████████████████████████████████████████| 312/312 [00:00<00:00, 811.68it/s]


Epoch 3 loss: 0.0017096938


100%|████████████████████████████████████████| 312/312 [00:00<00:00, 815.06it/s]

Epoch 4 loss: 0.0014083621





In [None]:
# Post-processing model state update: copy the functional state produced by the
# custom loop back into the stateful Keras variables.
def update(variable, value):
    variable.assign(value)

# jax.tree_util.tree_map (jax.tree_map is deprecated / removed in recent JAX)
jax.tree_util.tree_map(update, model.trainable_variables, train_state.trainable_variables)
jax.tree_util.tree_map(update, model.non_trainable_variables, train_state.non_trainable_variables)
jax.tree_util.tree_map(update, optimizer.variables, train_state.optimizer_variables)

# check that the model has the new state by running an eval
# known issue: the optimizer should not be required here
model.compile(optimizer=optimizer,
              loss=keras.losses.SparseCategoricalCrossentropy(),
              metrics=[keras.metrics.SparseCategoricalAccuracy()])
# Distinct names: unpacking into `loss` would rebind a global that the
# custom-training cells use for the loss function, breaking any later retrace.
eval_loss, eval_accuracy = model.evaluate(eval_data)
print("The model achieved an evaluation accuracy of:", eval_accuracy)

# JAX-native distribution with a Keras model
For now, this requires writing a custom training loop.

__Note: The features required by jax.sharding are not supported by the Colab TPU runtime at this time, but are available on Cloud TPU VMs and Kaggle TPU VMs. Try Kaggle first.__

In [17]:
# The meshes below are built with mesh_utils.create_device_mesh((8,)), which uses
# jax.devices() and needs exactly that many devices.
REQUIRED_DEVICES = 8
n_devices = len(jax.devices())
if n_devices != REQUIRED_DEVICES:
  raise RuntimeError(
      f"This part requires exactly {REQUIRED_DEVICES} devices to run, found {n_devices}")

In [18]:
# This cell is all Keras

# instantiate the model again
model = make_model()

# optimizer
optimizer = keras.optimizers.Adam(lr)

# initialize all state with .build()
# Model.build takes an input *shape*, not a tensor; the labels are not needed.
one_batch = next(iter(train_data))[0]
model.build(tuple(one_batch.shape))
optimizer.build(model.trainable_variables)

## Distribution settings
* Sharding the data on the batch axis
* Replicating all model variables

__Note__: this implements standard "data parallel" distributed training

* Just for show, sharding the largest convolutional kernel along the "channels" axis 4-ways and replicating 2-ways

__Note__: this does not reflect a best practice but is intended to show that you can split a very large kernel across multiple devices if you have to

In [19]:
# this cell is all JAX

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P

devices = mesh_utils.create_device_mesh((8,))

# data will be split along the batch axis
data_mesh = Mesh(devices, axis_names=('batch',)) # naming axes of the mesh
data_sharding = NamedSharding(data_mesh, P('batch',)) # naming axes of the sharded partition

# all variables will be replicated on all devices
# Note the trailing comma: ('_',) is a 1-tuple, whereas ('_') is just the string '_'.
var_mesh = Mesh(devices, axis_names=('_',))
var_replication = NamedSharding(var_mesh, P()) # in NamedSharding, mesh axes that are not mentioned are replicated (all axes here)

# for the demo, we will split the largest kernel 4-ways (and replicate 2-ways since we have 8 devices)
# 2x4 mesh: 'replica' (size 2) does not appear in the partition spec, so the kernel
# is replicated along it; 'out_chan' (size 4) splits the kernel's last dimension.
large_kernel_mesh = Mesh(devices.reshape((-1, 4)), axis_names=('replica', 'out_chan')) # naming axes of the mesh
large_kernel_sharding = NamedSharding(large_kernel_mesh, P(None, None, None, 'out_chan')) # naming axes of the sharded partition

In [20]:
# Use Keras APIs to find the variable of a specific layer (we will be sharding this one in a special way).
# The conv layer 'large_k' lives inside the nested "backbone" Sequential model;
# a Conv2D exposes its variables as the `kernel` and `bias` attributes.
backbone_model = model.get_layer("backbone")
special_layer_var = backbone_model.get_layer("large_k").kernel

In [21]:
# this cell is all JAX,
# accessing variables in Keras lists model.trainable_variables, model.non_trainable_variables and optimizer.variables

# Apply the distribution settings: non-trainable and optimizer variables are
# replicated on every device.
non_trainable_variables = jax.device_put(model.non_trainable_variables, var_replication)
optimizer_variables = jax.device_put(optimizer.variables, var_replication)

# Trainable variables are replicated as well, except (for the demo) the largest
# kernel, which is split 4-ways instead. A new list is built rather than
# overwriting entries of the list returned by model.trainable_variables.
trainable_variables = []
replication_shown = False
for var in model.trainable_variables:
    if var is special_layer_var:
        # Apply distribution settings: sharding
        placed = jax.device_put(var, large_kernel_sharding)
        print("Sharding of convolutional", var.name, var.shape)
        jax.debug.visualize_array_sharding(jnp.reshape(placed, [-1, var.shape[-1]]))
    else:
        # Apply distribution settings: replication (visualized once only)
        placed = jax.device_put(var, var_replication)
        if not replication_shown:
            replication_shown = True
            print("Sharding of all other model variables")
            jax.debug.visualize_array_sharding(jnp.reshape(placed, [-1, var.shape[-1]]))
    trainable_variables.append(placed)

# collect state in a handy named tuple
device_train_state = TrainingState(trainable_variables=trainable_variables,
                                   non_trainable_variables=non_trainable_variables,
                                   optimizer_variables=optimizer_variables)

Sharding of all other model variables


Sharding of convolutional kernel (6, 6, 24, 32)


In [22]:
# Custom training loop (data parallel)

# display data sharding
x, y = next(iter(train_data))
sharded_x = jax.device_put(x.numpy(), data_sharding)
print("Data sharding")
jax.debug.visualize_array_sharding(jnp.reshape(sharded_x, [-1, 28*28]))

# training loop
# Images AND labels are split along the batch axis, so each device holds the
# matching (x, y) slice of every batch; model and optimizer state keep the
# shardings applied above.
data_iter = iter(train_data)
for epoch in range(EPOCHS):
    for i in tqdm(range(STEPS_PER_EPOCH)):
        x, y = next(data_iter)
        sharded_x = jax.device_put(x.numpy(), data_sharding)
        sharded_y = jax.device_put(y.numpy(), data_sharding)
        loss_value, device_train_state = train_step(device_train_state, sharded_x, sharded_y)
    print("Epoch", epoch, "loss:", loss_value)

Data sharding


100%|█████████████████████████████████████████| 312/312 [00:09<00:00, 33.16it/s]


Epoch 0 loss: 0.011462945


100%|████████████████████████████████████████| 312/312 [00:02<00:00, 129.64it/s]


Epoch 1 loss: 0.019896783


100%|████████████████████████████████████████| 312/312 [00:02<00:00, 128.28it/s]


Epoch 2 loss: 0.0040018708


100%|████████████████████████████████████████| 312/312 [00:02<00:00, 129.54it/s]


Epoch 3 loss: 0.00522309


100%|████████████████████████████████████████| 312/312 [00:02<00:00, 130.06it/s]

Epoch 4 loss: 0.00052411645





The output of the model is still sharded. Sharding follows the data.

In [23]:
data, labels = next(iter(eval_data))
sharded_data = jax.device_put(data.numpy(), data_sharding)

@jax.jit
def predict(data, train_state=None):
    """Inference-mode forward pass returning the model's class probabilities.

    Pass ``train_state`` explicitly: jax.jit treats values read from globals as
    compile-time constants, so later updates to ``device_train_state`` would be
    ignored by an already-compiled ``predict``. When omitted, the global
    ``device_train_state`` is read once, at trace time (original behavior).
    """
    if train_state is None:
        train_state = device_train_state
    predictions, _ = model.stateless_call(
        train_state.trainable_variables,
        train_state.non_trainable_variables,
        data,
        training=False)  # BatchNorm moving statistics, no Dropout
    return predictions

predictions = predict(sharded_data, device_train_state)
print("Model output sharding")
jax.debug.visualize_array_sharding(predictions)

Model output sharding


In [37]:
(jnp.argmax(predictions, axis=-1), labels)  # predicted vs. true classes

(Array([7, 2, 1, ..., 4, 5, 6], dtype=int32),
 <tf.Tensor: shape=(10000,), dtype=uint8, numpy=array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)>)

In [38]:
# Post-processing model state update: copy the (sharded) functional state back
# into the stateful Keras variables.
def update(variable, value):
    variable.assign(value)

# jax.tree_util.tree_map (jax.tree_map is deprecated / removed in recent JAX)
jax.tree_util.tree_map(update, model.trainable_variables, device_train_state.trainable_variables)
jax.tree_util.tree_map(update, model.non_trainable_variables, device_train_state.non_trainable_variables)
jax.tree_util.tree_map(update, optimizer.variables, device_train_state.optimizer_variables)

# check that the model has the new state by running an eval
# known issue: https://github.com/keras-team/keras/issues/18681
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),
              metrics=[keras.metrics.SparseCategoricalAccuracy()])
# Distinct names: unpacking into `loss` would rebind a global that the
# custom-training cells use for the loss function.
eval_loss, eval_accuracy = model.evaluate(eval_data)
print("The model achieved an evaluation accuracy of:", eval_accuracy)

The model achieved an evaluation accuracy of: 0.9906999468803406


# And coming soon: KerasCV and KerasNLP pre-trained models
We are currently adjusting them to support Keras multi-backend. Stay tuned.
