<small>

**Key differences from the pure-JAX implementation:**  
- <b>Network definition:</b> Use a Flax <code>nn.Module</code> (e.g., an <code>MLP</code> class) instead of lists of parameter dicts.  
- <b>Initialization:</b> Flax initializes parameters via <code>model.init(...)</code>, using the initializers specified inside the class.  
- <b>Forward pass:</b> Compute outputs with <code>model.apply({"params": params}, x)</code> instead of manual matrix multiplications.  

</small>

In [34]:
from typing import Sequence

import jax
import jax.numpy as jnp
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.training import train_state
import optax

In [35]:
# Fetch every MNIST split at once as in-memory tensors (batch_size=-1), along
# with the dataset metadata. Alternative cache location: './data/tfds'.
data_dir = '/tmp/tfds'
mnist_data, info = tfds.load(
    name="mnist",
    data_dir=data_dir,
    batch_size=-1,
    with_info=True,
)

In [36]:
def normalise(x, x_max=255.0):
    """Return ``x`` divided by ``x_max`` (255.0 by default, the 8-bit pixel maximum)."""
    scaled = x / x_max
    return scaled

def convert_to_jax(data_np, data_type):
    """Convert a NumPy array to a JAX array according to its role.

    Args:
        data_np: the raw array for one field of a split.
        data_type: ``"image"`` or ``"label"``.

    Returns:
        For ``"image"``: a float32 JAX array divided by 255 via ``normalise``.
        For ``"label"``: ``jnp.array(data_np)`` with its dtype unchanged.

    Raises:
        ValueError: for any other ``data_type``.
    """
    if data_type == "image":
        as_float = jnp.array(data_np, dtype=jnp.float32)
        return normalise(as_float)
    if data_type == "label":
        return jnp.array(data_np)
    raise ValueError("not image or label")

def flatten_image_for_mlp(data_jax):
    """Produce one flat vector per sample from a batch of images.

    Args:
        data_jax: array of shape (n_batch, height, width, n_channels).

    Returns:
        Array of shape (n_batch, height * width * n_channels); for single-channel
        greyscale images this is one pixel vector per sample.

    Raises:
        ValueError: if ``data_jax`` is not 4-dimensional.
    """
    if data_jax.ndim != 4:
        raise ValueError(
            "expected images of shape (n_batch, height, width, n_channels), "
            f"got shape {tuple(data_jax.shape)}"
        )
    n_batch, n_pixels_vertical, n_pixels_horizontal, n_channels = data_jax.shape
    # Spell out the feature count instead of using -1: reshape cannot infer a
    # -1 dimension when the batch is empty (total size 0).
    n_features = n_pixels_vertical * n_pixels_horizontal * n_channels
    return data_jax.reshape(n_batch, n_features)

def prepare_data(data_dict):
    """Convert one dataset split to JAX arrays, flattening images for the MLP.

    Args:
        data_dict: maps a field name ("image"/"label") to a tensor exposing
            ``.numpy()`` — here, the tensors produced by the notebook's
            ``tfds.load(..., batch_size=-1)`` call.

    Returns:
        A new dict with the same keys holding JAX arrays; "image" entries are
        additionally flattened to one vector per sample.
    """
    prepared = {}
    for field, tensor in data_dict.items():
        converted = convert_to_jax(tensor.numpy(), field)
        prepared[field] = flatten_image_for_mlp(converted) if field == "image" else converted
    return prepared

In [37]:
# Convert every split returned by tfds.load, then pull out the training arrays.
prepared_data = {
    split_name: prepare_data(split_tensors)
    for split_name, split_tensors in mnist_data.items()
}
train_data = prepared_data["train"]
images_train, labels_train = train_data["image"], train_data["label"]

print("Images shape:", images_train.shape)
print("Labels shape:", labels_train.shape)

Images shape: (60000, 784)
Labels shape: (60000,)


In [38]:
class MLP(nn.Module):
    """Stack of Dense layers with a ReLU after every layer except the last.

    ``layer_sizes`` lists the output width of each Dense layer in order. Kernels
    and biases are drawn from a normal distribution with stddev 0.1. The final
    layer's output is returned without an activation (raw logits).
    """

    layer_sizes: Sequence[int]

    @nn.compact
    def __call__(self, activations):
        last_layer = len(self.layer_sizes) - 1
        for index, width in enumerate(self.layer_sizes):
            dense = nn.Dense(
                width,
                kernel_init=nn.initializers.normal(0.1),
                bias_init=nn.initializers.normal(0.1),
            )
            activations = dense(activations)
            if index < last_layer:
                activations = nn.relu(activations)
        return activations

In [39]:
def initialise_network_params(model, input_layer_size, key):
    """Create the initial parameters for every layer of ``model``.

    ``model.init`` is run on a dummy batch holding a single all-ones sample of
    width ``input_layer_size``; only the "params" collection of the resulting
    variables is returned.
    """
    dummy_batch = jnp.ones((1, input_layer_size))
    variables = model.init(key, dummy_batch)
    return variables["params"]

In [40]:
def calculate_mean_loss_batch(params, apply_fn, images, labels):
    """Return the mean softmax cross-entropy of the model over one batch.

    The forward pass is ``apply_fn({"params": params}, images)``, which gives
    the logits. These are scored against the integer class ``labels`` with
    Optax, and the per-sample losses are averaged.
    """
    per_sample_loss = optax.softmax_cross_entropy_with_integer_labels(
        apply_fn({"params": params}, images),  # forward pass -> logits
        labels,
    )
    return per_sample_loss.mean()

In [41]:
@jax.jit
def take_training_step(training_state, images, labels):
    """Apply a single optimiser update to ``training_state`` using one batch.

    The model's ``apply_fn`` and the optimiser both travel inside the
    TrainState, so only the batch is passed separately. Returns the updated
    TrainState.
    """
    # jax.grad differentiates w.r.t. the first positional argument only, so
    # params must come first -> gradients are taken w.r.t. params.
    loss_grad = jax.grad(calculate_mean_loss_batch)
    grads = loss_grad(training_state.params, training_state.apply_fn, images, labels)
    return training_state.apply_gradients(grads=grads)

In [42]:
def get_batches(images, labels, n_batches):
    """Split ``images``/``labels`` into exactly ``n_batches`` equal-sized batches.

    Each batch holds ``len(images) // n_batches`` consecutive samples. The
    trailing ``len(images) % n_batches`` samples that cannot fill a whole batch
    are dropped.

    Args:
        images: sliceable sequence/array of samples.
        labels: sliceable sequence/array aligned with ``images``.
        n_batches: number of batches to produce; must satisfy
            ``1 <= n_batches <= len(images)``.

    Returns:
        An iterator of ``(images_batch, labels_batch)`` tuples.

    Raises:
        ValueError: if the lengths differ or ``n_batches`` is out of range.
            Raised immediately on call, not deferred to the first iteration.
    """
    n_samples = len(images)
    if len(labels) != n_samples:
        raise ValueError(
            f"images and labels differ in length: {n_samples} != {len(labels)}"
        )
    # n_batches == 0 would otherwise reach a division by zero below.
    if not 0 < n_batches <= n_samples:
        raise ValueError(
            f"n_batches must be between 1 and {n_samples}, got {n_batches}"
        )
    return _iter_batches(images, labels, n_batches, n_samples // n_batches)


def _iter_batches(images, labels, n_batches, batch_size):
    """Yield ``n_batches`` consecutive ``(images, labels)`` slices of ``batch_size``."""
    # Iterate a fixed count rather than "until the end fits" so the caller
    # gets exactly n_batches batches (e.g. 10 samples / 4 batches -> 4, not 5).
    for batch_index in range(n_batches):
        start = batch_index * batch_size
        stop = start + batch_size
        yield images[start:stop], labels[start:stop]

In [47]:
def run_training(images, labels, n_steps, layer_sizes, key, initial_learning_rate=1e-3):
    """Train an MLP with Adam over ``n_steps`` batches and return its params.

    ``layer_sizes[0]`` is the input width. The remaining entries are the Dense
    layer widths given to ``MLP``. After each update, the loss on that same
    batch (computed with the updated params) is printed.

    Training progress lives in a flax TrainState, which bundles:
    - apply_fn: the model's apply function, used for forward passes
    - params: the parameters of the neural network
    - tx: the Optax optimiser transformation used for parameter updates
    - opt_state: the optimiser's state
    """
    model = MLP(layer_sizes=layer_sizes[1:])
    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=initialise_network_params(model, layer_sizes[0], key),
        tx=optax.adam(initial_learning_rate),
    )

    batches = get_batches(images=images, labels=labels, n_batches=n_steps)
    for step, (images_batch, labels_batch) in enumerate(batches):
        state = take_training_step(state, images_batch, labels_batch)
        loss = calculate_mean_loss_batch(state.params, state.apply_fn, images_batch, labels_batch)
        print(f"step {step}: loss={loss}")

    return state.params

In [51]:
# Use a small slice of the training set so the demo run finishes quickly.
trial_set_size = 100
images, labels = images_train[:trial_set_size], labels_train[:trial_set_size]

In [52]:
# 784 inputs -> 128 hidden units -> 10 output logits, trained for 5 batches.
layer_sizes = [784, 128, 10]
n_steps = 5
key = jax.random.key(0)
final_params = run_training(
    images, labels, n_steps=n_steps, layer_sizes=layer_sizes, key=key
)

step 0: loss=2.1574771404266357
step 1: loss=2.396851062774658
step 2: loss=2.1652634143829346
step 3: loss=2.0755467414855957
step 4: loss=1.9482983350753784


In [53]:
# Rebuild the same architecture and score the trained params on the trial set.
model = MLP(layer_sizes=layer_sizes[1:])
logits = model.apply({"params": final_params}, images)
predictions = logits.argmax(axis=1)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)

print("True labels:    ", labels)
print("Predictions:    ", predictions)
print("Match:          ", predictions == labels)
print("Loss            ", loss)

True labels:     [4 1 0 7 8 1 2 7 1 6 6 4 7 7 3 3 7 9 9 1 0 6 6 9 9 4 8 9 4 7 3 3 0 9 4 9 0
 6 8 4 7 2 6 0 3 1 1 7 2 4 4 6 5 1 9 3 2 4 3 4 4 7 5 8 1 1 4 1 5 3 5 8 4 1
 1 4 5 3 2 4 1 4 8 1 2 1 9 0 7 6 7 4 4 9 7 5 6 8 4 6]
Predictions:     [1 1 4 7 8 1 7 4 1 6 5 1 4 1 3 8 7 4 9 1 4 5 6 2 9 7 1 1 2 9 6 4 4 9 4 4 7
 6 1 2 7 1 9 7 2 1 1 9 6 1 4 2 9 1 9 7 9 9 9 4 9 7 7 6 7 1 1 1 7 3 1 1 4 1
 1 4 5 1 9 4 1 0 1 1 1 1 4 9 7 6 1 7 9 9 9 4 5 2 1 1]
Match:           [False  True False  True  True  True False False  True  True False False
 False False  True False  True False  True  True False False  True False
  True False False False False False False False False  True  True False
 False  True False False  True False False False False  True  True False
 False False  True False False  True  True False False False False  True
 False  True False False False  True False  True False  True False False
  True  True  True  True  True False False  True  True False False  True
 False  True False False  True