<small>

**Key differences from the pure-JAX implementation:**  
- <b>Network definition:</b> Use a Flax <code>nn.Module</code> (e.g., an <code>MLP</code> class) instead of lists of parameter dicts.  
- <b>Initialization:</b> Flax initializes parameters via <code>model.init(...)</code>, using the initializers declared inside the class.  
- <b>Forward pass:</b> Compute outputs with <code>model.apply({"params": params}, x)</code> instead of manual matrix multiplications.  

</small>

In [175]:
from typing import Sequence

import jax.numpy as jnp
from jax import grad, jit, random, tree_util, vmap
from jax.scipy.special import logsumexp
import tensorflow as tf
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.training import train_state
import optax

In [176]:
# Fetch MNIST through TensorFlow Datasets (cached under `data_dir`).
# batch_size=-1 asks TFDS for each split as one full batch of tensors instead of a tf.data pipeline.
data_dir = '/tmp/tfds'  # alternative cache location: './data/tfds'
mnist_data, info = tfds.load(
    name="mnist",
    batch_size=-1,
    data_dir=data_dir,
    with_info=True,
)

In [177]:
def normalise(x, x_max=255.0):
    """Divide ``x`` by ``x_max`` (default 255.0, the 8-bit pixel maximum).

    With the default, 8-bit intensities in [0, 255] map onto [0, 1].
    """
    scaled = x / x_max
    return scaled

def convert_to_jax(data_np, data_type):
    """Convert a NumPy array into a JAX array according to its role in the dataset.

    Args:
        data_np: array-like data for one field of a dataset split.
        data_type: ``"image"`` or ``"label"``.

    Returns:
        For ``"image"``: a float32 JAX array scaled by ``normalise``.
        For ``"label"``: a JAX array built with no dtype override.

    Raises:
        ValueError: for any other ``data_type``.
    """
    if data_type == "image":
        as_float = jnp.array(data_np, dtype=jnp.float32)
        return normalise(as_float)
    if data_type == "label":
        return jnp.array(data_np)
    raise ValueError("not image or label")

def flatten_image_for_mlp(data_jax):
    """Produce one flat feature vector per sample (e.g. one greyscale pixel vector per image).

    Args:
        data_jax: array whose leading axis indexes samples, e.g. images shaped
            (n_batch, height, width, channels).

    Returns:
        Array of shape (n_batch, product of the remaining dimensions).

    Raises:
        ValueError: if ``data_jax`` has fewer than 2 dimensions.
    """
    import math

    if data_jax.ndim < 2:
        raise ValueError(
            f"expected an array of shape (batch, ...) with at least 2 dimensions, got {data_jax.shape}"
        )
    n_batch = data_jax.shape[0]
    # Compute the per-sample size explicitly instead of reshape(n_batch, -1):
    # with an empty batch (n_batch == 0) the -1 dimension is ambiguous.
    n_features = math.prod(data_jax.shape[1:])
    return data_jax.reshape(n_batch, n_features)

def prepare_data(data_dict):
    """Convert a dict of TF tensors (one TFDS split) into model-ready JAX arrays.

    Each value is converted via ``convert_to_jax`` using its key as the data type.
    ``"image"`` entries are additionally flattened to one vector per sample.
    """
    prepared = {}
    for data_type, tensor in data_dict.items():
        converted = convert_to_jax(tensor.numpy(), data_type)
        prepared[data_type] = flatten_image_for_mlp(converted) if data_type == "image" else converted
    return prepared

In [178]:
# Pick the TFDS split to work with and turn it into normalised, flattened JAX arrays.
dataset_tf = "train"
all_data_tf = mnist_data[dataset_tf]

all_data_jax = prepare_data(all_data_tf)

In [179]:
images = all_data_jax["image"]
labels = all_data_jax["label"]

# Sanity checks: prepare_data must yield one label per image, with images already flattened.
assert images.shape[0] == labels.shape[0], "number of images and labels differ"
assert images.ndim == 2, "images should be flattened to (n_samples, n_pixels)"

In [180]:
print(f"Images shape: {images.shape}")
print(f"Labels shape: {labels.shape}")

Images shape: (60000, 784)
Labels shape: (60000,)


In [181]:
class MLP(nn.Module):
    """Fully-connected network with one Dense layer per entry of ``layer_sizes``.

    Every layer except the last is followed by a ReLU, so the final layer's output is
    returned without an activation (used as logits elsewhere in this notebook).
    Kernels and biases are both initialised from a normal distribution with stddev 0.1.
    """

    layer_sizes: Sequence[int]  # output width of each Dense layer, in order

    @nn.compact
    def __call__(self, activations):
        last_index = len(self.layer_sizes) - 1
        for index, width in enumerate(self.layer_sizes):
            dense = nn.Dense(
                width,
                kernel_init=nn.initializers.normal(0.1),
                bias_init=nn.initializers.normal(0.1),
            )
            activations = dense(activations)
            # Hidden layers only: the last layer stays linear.
            if index < last_index:
                activations = nn.relu(activations)
        return activations

In [182]:
def initialise_network_params(layer_sizes, key, model):
    """Initialise the parameters of a Flax model by running ``model.init`` on a dummy batch.

    Args:
        layer_sizes: layer widths; only ``layer_sizes[0]`` (the input width) is used here.
        key: PRNG key passed to ``model.init``.
        model: Flax module to initialise.

    Returns:
        The ``"params"`` collection of the variables returned by ``model.init``.
    """
    n_inputs = layer_sizes[0]
    # A single all-ones row is enough for init to infer every layer's shape.
    dummy_batch = jnp.ones((1, n_inputs))
    variables = model.init(key, dummy_batch)
    return variables["params"]

In [183]:
def calculate_loss(predictions_logits, observed_label):
    """Negative log-likelihood of ``observed_label`` under softmax(``predictions_logits``).

    Intended for a single example: a vector of class scores and one integer class
    index (``calculate_loss_batch`` below maps it over a batch).
    """
    log_normaliser = logsumexp(predictions_logits)
    return log_normaliser - predictions_logits[observed_label]


# Batched version: maps over the leading axis of both the logits and the labels.
calculate_loss_batch = vmap(calculate_loss, in_axes=(0, 0))

def calculate_mean_loss_batch(params, images, labels, model):
    """Mean softmax cross-entropy of ``model`` over a batch with integer class labels.

    Args:
        params: the model's ``"params"`` collection.
        images: batch of model inputs.
        labels: integer class labels, one per input.
        model: Flax module used for the forward pass.

    Returns:
        Scalar mean of the per-example losses.
    """
    # Forward pass; presumably logits of shape (batch, n_classes) — TODO confirm for other models.
    logits = model.apply({"params": params}, images)
    per_example_losses = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    return per_example_losses.mean()

In [184]:
def _take_training_step_functional(params, images, labels, optimizer_state, model=None, optimizer=None):
    """Single training step on explicit ``(params, optimizer_state)``, without a TrainState.

    Args:
        params: the model's ``"params"`` collection.
        images: batch of model inputs.
        labels: integer class labels for ``images``.
        optimizer_state: optax state from ``optimizer.init`` or a previous step.
        model: Flax module whose ``apply`` defines the forward pass.
        optimizer: optax gradient transformation that turns gradients into updates.

    Returns:
        ``(new_params, new_optimizer_state)``.

    Raises:
        TypeError: if ``model`` or ``optimizer`` is not supplied.
    """
    if model is None or optimizer is None:
        raise TypeError("take_training_step needs `model` and `optimizer` keyword arguments")
    # Differentiate w.r.t. params only (argument 0); `model` is passed through unchanged.
    gradients_by_param = grad(calculate_mean_loss_batch)(params, images, labels, model)
    # Passing params lets optimizers that need them (e.g. decoupled weight decay) work too.
    updates_by_param, optimizer_state = optimizer.update(gradients_by_param, optimizer_state, params)
    params = optax.apply_updates(params, updates_by_param)
    return params, optimizer_state


# A Flax Module and an optax transformation are not JAX array types, so jit cannot trace them as
# ordinary arguments. Declaring them static makes them compile-time constants instead; static
# arguments must be hashable (NOTE(review): e.g. build MLP with a tuple of layer sizes — confirm).
# NOTE: the next cell redefines `take_training_step` with the TrainState-based signature
# `(state, images, labels)`; that later definition is the one `run_training` calls.
take_training_step = jit(_take_training_step_functional, static_argnames=("model", "optimizer"))

In [185]:
@jit
def take_training_step(state, images, labels):
    """One jitted gradient step on a Flax ``TrainState`` using mean softmax cross-entropy.

    Args:
        state: TrainState holding ``apply_fn``, ``params`` and the optimizer.
        images: batch of model inputs.
        labels: integer class labels for ``images``.

    Returns:
        The TrainState after applying the gradients.
    """
    def mean_cross_entropy(candidate_params):
        predicted_logits = state.apply_fn({"params": candidate_params}, images)
        per_example = optax.softmax_cross_entropy_with_integer_labels(predicted_logits, labels)
        return per_example.mean()

    gradients = grad(mean_cross_entropy)(state.params)
    updated_state = state.apply_gradients(grads=gradients)
    return updated_state


In [186]:
def run_training(images, labels, n_steps, layer_sizes, key, lr=1e-3):
    """Build an MLP, train it full-batch with Adam, and print the loss after every step.

    Args:
        images: training inputs (the whole batch is used at every step).
        labels: integer class labels for ``images``.
        n_steps: number of optimisation steps.
        layer_sizes: ``[input_width, hidden..., output_width]``.
        key: PRNG key for parameter initialisation.
        lr: Adam learning rate.

    Returns:
        ``(trained_params, model)``.
    """
    # layer_sizes[0] is only the input width; the Dense layers need the remaining widths.
    model = MLP(layer_sizes=layer_sizes[1:])
    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=initialise_network_params(layer_sizes, key, model),
        tx=optax.adam(lr),
    )

    for step_index in range(n_steps):
        state = take_training_step(state, images, labels)
        # Loss of the freshly updated parameters on the same batch.
        current_loss = calculate_mean_loss_batch(state.params, images, labels, model)
        print(f"step {step_index}: loss={current_loss}")

    return state.params, model

In [187]:
# NOTE(review): despite the `test_` prefix, these are the first samples of the *training* split,
# and the model below is both trained and evaluated on them (an overfitting smoke test).
trial_set_size = 20
test_images, test_labels = images[:trial_set_size], labels[:trial_set_size]

layer_sizes = [784, 128, 10]
key = random.key(0)

params, model = run_training(
    test_images,
    test_labels,
    n_steps=5,
    layer_sizes=layer_sizes,
    key=key,
)

# Per-example losses and argmax class predictions of the trained model.
logits = model.apply({"params": params}, test_images)
loss = calculate_loss_batch(logits, test_labels)
predictions = logits.argmax(axis=1)

step 0: loss=2.1574771404266357
step 1: loss=1.9256445169448853
step 2: loss=1.7224485874176025
step 3: loss=1.5373908281326294
step 4: loss=1.3704960346221924


In [188]:
# Each caption is left-aligned in a 16-character column, followed by a space and the array.
print("\n".join(
    f"{caption:<16} {value}"
    for caption, value in (
        ("True labels:", test_labels),
        ("Predictions:", predictions),
        ("Match:", predictions == test_labels),
        ("Loss", loss),
    )
))

True labels:     [4 1 0 7 8 1 2 7 1 6 6 4 7 7 3 3 7 9 9 1]
Predictions:     [1 1 6 7 8 1 2 9 1 6 6 1 7 7 3 3 7 9 9 1]
Match:           [False  True False  True  True  True  True False  True  True  True False
  True  True  True  True  True  True  True  True]
Loss             [1.7879559  1.1756442  2.1219096  0.80904436 1.1157752  1.352685
 1.3984716  1.574126   0.9887434  1.3136357  1.7677958  2.639235
 1.1510949  1.5284108  0.9676604  1.4889581  0.8168993  1.0736182
 1.2909293  1.0473256 ]
