<small>

**Key differences from the pure-JAX implementation:**  
- <b>Network definition:</b> Use a Flax <code>nn.Module</code> (e.g., an <code>MLP</code> class) instead of lists of parameter dicts.  
- <b>Initialization:</b> Flax handles parameter initialization with <code>model.init(...)</code>, using specified initializers within the class.  
- <b>Forward pass:</b> Compute outputs with <code>model.apply(params, x)</code> instead of manual matrix multiplications.  

</small>

In [43]:
from typing import Sequence

import jax.numpy as jnp
from jax import grad, jit, random, tree_util, vmap
from jax.scipy.special import logsumexp
import tensorflow as tf
import tensorflow_datasets as tfds
from flax import linen as nn

In [44]:
# Load MNIST from TensorFlow Datasets. batch_size=-1 asks TFDS for each split
# as one full batch of tensors (later converted with .numpy()).
data_dir = '/tmp/tfds'  # alternative location: './data/tfds'
mnist_data, info = tfds.load(
    name="mnist",
    data_dir=data_dir,
    batch_size=-1,
    with_info=True,
)

In [45]:
def normalise(x, x_max=255.0):
    """Divide ``x`` by ``x_max``.

    For inputs in ``[0, x_max]`` (e.g. 8-bit pixel intensities with the
    default ``x_max=255.0``) the result lies in ``[0, 1]``.
    """
    scaled = x / x_max
    return scaled

def convert_to_jax(data_np, data_type):
    """Turn a NumPy array into a JAX array according to its role.

    Args:
        data_np: array-like data for one field of the dataset.
        data_type: ``"image"`` -> cast to float32 and scaled with ``normalise``;
            ``"label"`` -> converted with ``jnp.array`` as-is.

    Raises:
        ValueError: if ``data_type`` is neither ``"image"`` nor ``"label"``.
    """
    if data_type not in ("image", "label"):
        raise ValueError("not image or label")
    if data_type == "label":
        return jnp.array(data_np)
    # Images: convert to float32 first so the division produces fractional values.
    return normalise(jnp.array(data_np, dtype=jnp.float32))

def flatten_image_for_mlp(data_jax):
    """Produces one greyscale vector per sample.

    Merges every axis after the leading (sample) axis into a single feature
    axis, e.g. ``(n_batch, height, width, channels) -> (n_batch, height*width*channels)``.

    Args:
        data_jax: array whose first axis indexes samples (rank >= 1).

    Returns:
        Array of shape ``(n_batch, n_features)``.
    """
    import math

    n_batch = data_jax.shape[0]
    # Compute the feature count explicitly instead of using -1: reshape(0, -1)
    # is ambiguous and fails for an empty batch.
    n_features = math.prod(data_jax.shape[1:])
    return data_jax.reshape(n_batch, n_features)

def prepare_data(data_dict):
    """Convert a dict of TensorFlow tensors into a dict of JAX arrays.

    Every value is pulled to NumPy with ``.numpy()`` and converted by
    ``convert_to_jax`` (keyed by its dict key, so only "image"/"label" are
    accepted). "image" entries are additionally flattened to one vector per
    sample with ``flatten_image_for_mlp``. Key order is preserved.
    """
    converted = {}
    for field_name, tensor in data_dict.items():
        array = convert_to_jax(tensor.numpy(), field_name)
        if field_name == "image":
            array = flatten_image_for_mlp(array)
        converted[field_name] = array
    return converted

In [46]:
dataset_tf = "train"  # which TFDS split to prepare
all_data_tf = mnist_data[dataset_tf]
# TF tensors -> JAX arrays (images scaled by 1/255 and flattened per sample).
all_data_jax = prepare_data(all_data_tf)

In [47]:
# Unpack the prepared arrays used throughout the rest of the notebook.
images, labels = all_data_jax["image"], all_data_jax["label"]

In [48]:
for caption, array in (("Images shape:", images), ("Labels shape:", labels)):
    print(caption, array.shape)

Images shape: (60000, 784)
Labels shape: (60000,)


In [49]:
class MLP(nn.Module):
    """Fully-connected network built from ``nn.Dense`` layers.

    One Dense layer is created per entry of ``layer_sizes`` (its output width).
    A ReLU follows every layer except the last, so the output is linear
    (logits). Kernels and biases are drawn from a normal distribution with
    stddev 0.1.
    """

    # Output width of each Dense layer, in order (the input width is inferred by Flax).
    layer_sizes: Sequence[int]

    @nn.compact
    def __call__(self, x):
        last_index = len(self.layer_sizes) - 1
        for index, n_units in enumerate(self.layer_sizes):
            dense = nn.Dense(
                features=n_units,
                kernel_init=nn.initializers.normal(stddev=0.1),
                bias_init=nn.initializers.normal(stddev=0.1),
            )
            x = dense(x)
            # Hidden layers get a ReLU; the final layer stays linear.
            if index < last_index:
                x = nn.relu(x)
        return x

In [50]:
def forward_pass(x, params):
    """Forward pass through all layers: inputs and outputs are vectors.

    The layer widths are read from ``params`` (the variables pytree returned
    by ``MLP.init`` / ``initialise_network``) instead of the notebook-global
    ``layer_sizes``. The forward pass therefore always matches the network it
    is given. Previously, ``run_training`` called with a ``layer_sizes``
    argument different from the global one would build the wrong module.

    Args:
        x: input for one example (``forward_pass_batch`` vmaps this over a batch).
        params: variables pytree as returned by ``MLP.init``.

    Returns:
        The network's output logits for ``x``.
    """
    dense_params = params["params"]
    # Flax auto-names the compact Dense submodules Dense_0, Dense_1, ... in
    # creation order. Index them numerically (not by sorted key) so that
    # Dense_10 does not come before Dense_2. Shapes are static under jit/vmap/grad.
    widths = [
        dense_params[f"Dense_{i}"]["kernel"].shape[-1]
        for i in range(len(dense_params))
    ]
    model = MLP(layer_sizes=widths)
    return model.apply(params, x)

In [51]:
def calculate_loss(predictions_logits, observed_label):
    """Cross-entropy (negative log-likelihood) of one example.

    Computes ``logsumexp(logits) - logits[label]``, i.e. minus the
    log-softmax probability assigned to ``observed_label``.
    """
    log_normaliser = logsumexp(predictions_logits)
    return log_normaliser - predictions_logits[observed_label]

In [52]:
# Batch over the leading axis of the inputs; params are shared (not mapped).
forward_pass_batch = vmap(forward_pass, in_axes=(0, None))
# vmap's default in_axes=0 maps both logits and labels over the batch axis.
calculate_loss_batch = vmap(calculate_loss)

In [None]:
def initialise_network(sizes, key):
    """Initialize all layers for a fully-connected neural network.

    Args:
        sizes: layer widths including the input width, e.g. ``[784, 128, 10]``.
            ``sizes[0]`` is only used to build a dummy input for shape inference.
        key: JAX PRNG key passed to ``model.init``.

    Returns:
        The variables pytree produced by ``MLP.init``.
    """
    network = MLP(layer_sizes=sizes[1:])
    sample_input = jnp.ones((1, sizes[0]))
    return network.init(key, sample_input)

In [54]:
def calculate_mean_loss_batch(params, images, labels):
    """Average per-example cross-entropy of the network over a batch."""
    per_example_loss = calculate_loss_batch(forward_pass_batch(images, params), labels)
    return jnp.mean(per_example_loss)
    
# Gradient of the mean batch loss w.r.t. its first argument (the params pytree).
calculate_gradients_by_param = grad(calculate_mean_loss_batch, argnums=0)

def update_parameters(params, gradients_by_param, learning_rate):
    """Apply one plain gradient-descent update leaf-wise across the params pytree.

    Returns a new pytree holding ``param - learning_rate * grad`` at every
    leaf. The input pytrees are not modified.
    """
    def descend(param, gradient):
        return param - learning_rate * gradient

    return tree_util.tree_map(descend, params, gradients_by_param)

@jit
def take_training_step(params, images, labels, learning_rate=0.1):
    """Run one JIT-compiled gradient-descent step on a batch; return the updated params."""
    return update_parameters(
        params,
        calculate_gradients_by_param(params, images, labels),
        learning_rate,
    )

In [None]:
def run_training(images, labels, n_steps, layer_sizes, key, learning_rate=0.1):
    """Initialise an MLP and train it with full-batch gradient descent.

    Prints the mean loss after every step and a final summary line.

    Args:
        images: batch of inputs (in this notebook, flattened MNIST images).
        labels: integer class label for each row of ``images``.
        n_steps: number of gradient-descent steps. 0 returns the freshly
            initialised parameters (and reports their loss).
        layer_sizes: layer widths including the input width, e.g. ``[784, 128, 10]``.
        key: JAX PRNG key used to initialise the parameters.
        learning_rate: step size forwarded to ``take_training_step``
            (default 0.1, matching that function's default).

    Returns:
        The trained parameter pytree.
    """
    params = initialise_network(layer_sizes, key)
    # Start from the untrained loss so the final report is defined even when
    # n_steps == 0. Previously `loss` was unbound and the final print raised.
    loss = calculate_mean_loss_batch(params, images, labels)
    for step in range(n_steps):
        params = take_training_step(params, images, labels, learning_rate)
        loss = calculate_mean_loss_batch(params, images, labels)
        print(f"step {step} complete: loss = {loss}")
    print(f"training finished after {n_steps}: final loss = {loss}")
    return params

In [56]:
# Architecture: 784 input pixels -> 128 hidden units -> 10 output logits.
layer_sizes = [784, 128, 10]
key = random.key(0)

# A small slice of the *training* split, so each step is quick.
trial_set_size = 20
test_images, test_labels = images[:trial_set_size], labels[:trial_set_size]

# Train the model
params = run_training(
    test_images,
    test_labels,
    n_steps=5,
    layer_sizes=layer_sizes,
    key=key,
)

# Evaluate the trained network on the same examples it was trained on.
logits = forward_pass_batch(test_images, params)
predictions = jnp.argmax(logits, axis=1)
loss = calculate_loss_batch(logits, test_labels)

step 0 complete: loss = 1.9175628423690796
step 1 complete: loss = 1.587611436843872
step 2 complete: loss = 1.333425521850586
step 3 complete: loss = 1.1325856447219849
step 4 complete: loss = 0.9700950980186462
training finished after 5: final loss = 0.9700950980186462


In [57]:
report = [
    ("True labels:    ", test_labels),
    ("Predictions:    ", predictions),
    ("Match:          ", predictions == test_labels),
    ("Loss            ", loss),
]
for caption, value in report:
    print(caption, value)

True labels:     [4 1 0 7 8 1 2 7 1 6 6 4 7 7 3 3 7 9 9 1]
Predictions:     [1 1 6 7 8 1 7 7 1 6 6 9 7 7 3 3 7 9 9 1]
Match:           [False  True False  True  True  True False  True  True  True  True False
  True  True  True  True  True  True  True  True]
Loss             [1.6613489  0.70690393 1.6412845  0.3408208  1.1715062  0.8336673
 1.494753   0.89425874 0.5227027  0.8694961  1.3074994  2.3385448
 0.539485   0.9489362  0.5305395  0.8975339  0.35296202 0.75522757
 1.0567391  0.5376885 ]
