<small>

**Key differences from JAX implementation:**  
- <b>Network definition:</b> Use a Flax <code>nn.Module</code> (e.g., an <code>MLP</code> class) instead of lists of parameter dicts.  
- <b>Initialization:</b> Flax handles parameter initialization with <code>model.init(...)</code>, using specified initializers within the class.  
- <b>Forward pass:</b> Compute outputs with <code>model.apply(params, x)</code> instead of manual matrix multiplications.  

</small>

In [147]:
from typing import Sequence

import jax
import jax.numpy as jnp
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.training import train_state
import optax

In [148]:
# Load MNIST from TensorFlow Datasets
data_dir = '/tmp/tfds' # data_dir = './data/tfds'
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)

In [149]:
def normalise(x, x_max=255.0):
    return x / x_max

def convert_to_jax(data_np, data_type):
    if data_type == "image":
        data_jax = normalise(jnp.array(data_np, dtype=jnp.float32))
    elif data_type == "label":
        data_jax = jnp.array(data_np)
    else:
        raise ValueError("not image or label")
    return data_jax

def flatten_image_for_mlp(data_jax):
    """Produces one greyscale vector per sample"""
    n_batch, n_pixels_vertical, n_pixels_horizontal, n_channels = data_jax.shape
    data_flattened = data_jax.reshape(n_batch, -1)
    return data_flattened

def prepare_data(data_dict: dict, subsample_size: int=0):
    data_jax = {}
    for data_type, data_tf in data_dict.items():
        data_numpy = data_tf.numpy()
        data = convert_to_jax(data_numpy, data_type)
        if data_type == "image":
            data = flatten_image_for_mlp(data)
        if subsample_size > 0:
            data = data[:subsample_size]
        data_jax[data_type] = data

    return data_jax

In [150]:
class MLP(nn.Module):
    layer_sizes: Sequence[int]

    @nn.compact
    def __call__(self, activations):
        for layer_number, layer_size in enumerate(self.layer_sizes):
            activations = nn.Dense(
                layer_size,
                kernel_init=nn.initializers.normal(0.1),
                bias_init=nn.initializers.normal(0.1)
            )(activations)

            if layer_number != (len(self.layer_sizes) - 1):
                activations = nn.relu(activations)

        return activations

In [151]:
def initialise_network_params(model, input_layer_size, key):
    """Initialize all layers for a fully-connected neural network"""
    input_shape_dummy = jnp.ones((1, input_layer_size))
    params = model.init(key, input_shape_dummy)["params"]
    return params

In [152]:
def calculate_mean_loss_batch(params, apply_fn, images, labels):
    logits = apply_fn({"params": params}, images) # FORWARD PASS
    cross_entropy_by_sample = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    cross_entropy_mean = cross_entropy_by_sample.mean()
    return cross_entropy_mean

In [153]:
@jax.jit
def take_training_step(training_state, images, labels):
    """
    Single training step 
    The model and optimiser are passed in the training state
    returns a training state
    """
    grads_by_params_fn = jax.grad(calculate_mean_loss_batch)
    grads_by_params = grads_by_params_fn(
        training_state.params,     # params is first â†’ grad w.r.t. params
        training_state.apply_fn,
        images,
        labels,
    )
    return training_state.apply_gradients(grads=grads_by_params)

In [154]:
def get_batches(images, labels, n_batches):
    """Drops the last set of samples if they're not the right length"""
    n_samples = len(images)
    assert len(images) == len(labels)
    assert n_samples >= n_batches
    assert n_batches > 0
    n_samples_per_batch = n_samples // n_batches
    start = 0
    end = n_samples_per_batch
    while end <= n_samples: 
        yield (images[start:end], labels[start:end])
        start += n_samples_per_batch
        end += n_samples_per_batch

In [None]:
def run_training(images, labels, n_steps, layer_sizes, optimizer, key):
    """
    The training state ('state') is an instance of TrainState that holds:
    - apply_fn: the model's apply function, used for forward passes
    - params: the parameters of the neural network
    - tx: the optimizers (Optax transformation) for parameter updates
    - opt_state: the state of the optimizer
    """

    input_layer_size = layer_sizes[0]
    network_layer_sizes = layer_sizes[1:]
    model = MLP(layer_sizes=network_layer_sizes)
    params = initialise_network_params(model, input_layer_size, key)
    apply_fn = model.apply
    
    training_state = train_state.TrainState.create(
        apply_fn=apply_fn, 
        params=params, 
        tx=optimizer,
        )

    step = 1
    for images_batch, labels_batch in get_batches(images=images, labels=labels, n_batches=n_steps):
        training_state = take_training_step(training_state, images_batch, labels_batch)
        loss = calculate_mean_loss_batch(training_state.params, training_state.apply_fn, images_batch, labels_batch)
        print(f"step {step}: loss={loss}")
        step += 1

    return training_state.params

In [171]:
train_data = prepare_data(mnist_data["train"], subsample_size=10**3) 
test_data = prepare_data(mnist_data["test"], subsample_size=10**2) 

In [172]:
def train_mlp(train_data, optimizer):
    n_steps = 20
    layer_sizes = [784, 128, 10]
    key = jax.random.key(0)
    final_params = run_training(
        train_data["image"], 
        train_data["label"], 
        n_steps, 
        layer_sizes, 
        optimizer,
        key,
        )
    return final_params

In [None]:
def extract_layer_sizes(params):
    layer_sizes = []
    for layer, layer_params in enumerate(params.values()):
        if layer == 0:
            layer_sizes.append(layer_params["kernel"].shape[0])
            layer_sizes.append(layer_params["kernel"].shape[1])
        else:
            layer_sizes.append(layer_params["bias"].shape[0])
    return layer_sizes

In [None]:
def evaluate_mlp(test_data, params, n_examples=10):
    layer_sizes = extract_layer_sizes(params)
    model = MLP(layer_sizes=layer_sizes[1:])
    apply_fn = model.apply

    images = test_data["image"]
    labels = test_data["label"]

    mean_loss = calculate_mean_loss_batch(params, apply_fn, images, labels)
    example_images = images[:n_examples]
    example_labels = labels[:n_examples]
    logits = apply_fn({"params": params}, example_images)
    example_predictions = jnp.argmax(logits, axis=1)

    print("Mean loss       ", mean_loss)
    print("True labels:    ", example_labels)
    print("Predictions:    ", example_predictions)

1. Learning rate decay
2. Weight decay

In [173]:
learning_rate = 1e-3
optimizer = optax.adam(learning_rate)
params = train_mlp(train_data, optimizer)
evaluate_mlp(test_data, params)

step 1: loss=2.329394817352295
step 2: loss=2.110353708267212
step 3: loss=2.0951199531555176
step 4: loss=2.107095956802368
step 5: loss=2.1107473373413086
step 6: loss=1.921804428100586
step 7: loss=1.9533743858337402
step 8: loss=1.84724760055542
step 9: loss=1.7940917015075684
step 10: loss=1.764140248298645
step 11: loss=1.751694917678833
step 12: loss=1.6859773397445679
step 13: loss=1.7524176836013794
step 14: loss=1.642892599105835
step 15: loss=1.6448626518249512
step 16: loss=1.460929274559021
step 17: loss=1.3634432554244995
step 18: loss=1.413408875465393
step 19: loss=1.2928463220596313
step 20: loss=1.3445155620574951
Mean loss        1.3096466
True labels:     [2 0 4 8 7 6 0 6 3 1]
Predictions:     [2 0 4 8 7 6 0 3 3 1]


In [174]:
learning_rate = 1e-2
optimizer = optax.adam(learning_rate)
params = train_mlp(train_data, optimizer)
evaluate_mlp(test_data, params)

step 1: loss=1.3223949670791626
step 2: loss=1.4270418882369995
step 3: loss=1.1914063692092896
step 4: loss=0.8947293162345886
step 5: loss=0.8820065259933472
step 6: loss=0.6795811057090759
step 7: loss=0.6839590072631836
step 8: loss=0.7954467535018921
step 9: loss=0.7858410477638245
step 10: loss=0.7066607475280762
step 11: loss=0.508007824420929
step 12: loss=0.6086064577102661
step 13: loss=0.6874796152114868
step 14: loss=0.6014252305030823
step 15: loss=0.6547794938087463
step 16: loss=0.4428560435771942
step 17: loss=0.35594701766967773
step 18: loss=0.3600447475910187
step 19: loss=0.4020954668521881
step 20: loss=0.3246818482875824
Mean loss        0.5156536
True labels:     [2 0 4 8 7 6 0 6 3 1]
Predictions:     [2 0 4 8 7 6 0 5 5 1]


In [175]:
learning_rate = 1e-1
optimizer = optax.adam(learning_rate)
params = train_mlp(train_data, optimizer)
evaluate_mlp(test_data, params)

step 1: loss=7.921195030212402
step 2: loss=22.573795318603516
step 3: loss=15.713532447814941
step 4: loss=14.277660369873047
step 5: loss=4.392229080200195
step 6: loss=2.0070979595184326
step 7: loss=1.3166155815124512
step 8: loss=1.6341989040374756
step 9: loss=1.6222692728042603
step 10: loss=1.4166510105133057
step 11: loss=1.3110309839248657
step 12: loss=1.5402616262435913
step 13: loss=1.504196286201477
step 14: loss=1.6209816932678223
step 15: loss=1.6276339292526245
step 16: loss=1.2568767070770264
step 17: loss=0.8523489236831665
step 18: loss=1.3203682899475098
step 19: loss=1.4638503789901733
step 20: loss=1.3379669189453125
Mean loss        1.8293804
True labels:     [2 0 4 8 7 6 0 6 3 1]
Predictions:     [2 0 4 3 7 8 3 8 8 1]


In [189]:
learning_rate = 1e-2 # for all subsequent models

In [190]:
optimizer = optax.adamw(learning_rate, weight_decay=1e-4)
params = train_mlp(train_data, optimizer)
evaluate_mlp(test_data, params)

step 1: loss=1.3223958015441895
step 2: loss=1.4270424842834473
step 3: loss=1.191407322883606
step 4: loss=0.8947314023971558
step 5: loss=0.8820069432258606
step 6: loss=0.6795834898948669
step 7: loss=0.6839614510536194
step 8: loss=0.7954472303390503
step 9: loss=0.7858410477638245
step 10: loss=0.7066599726676941
step 11: loss=0.5080094337463379
step 12: loss=0.6086087822914124
step 13: loss=0.687477707862854
step 14: loss=0.6014255285263062
step 15: loss=0.6547796726226807
step 16: loss=0.44285672903060913
step 17: loss=0.35594889521598816
step 18: loss=0.36004677414894104
step 19: loss=0.4020947813987732
step 20: loss=0.3246852457523346
Mean loss        0.51565444
True labels:     [2 0 4 8 7 6 0 6 3 1]
Predictions:     [2 0 4 8 7 6 0 5 5 1]


In [191]:
mask_fn = lambda p: jax.tree_util.tree_map(lambda x: x.ndim != 1, p) # mask biases
optimizer = optax.adamw(learning_rate, weight_decay=1e-4, mask=mask_fn)
params = train_mlp(train_data, optimizer)
evaluate_mlp(test_data, params)

step 1: loss=1.3223958015441895
step 2: loss=1.4270424842834473
step 3: loss=1.1914074420928955
step 4: loss=0.8947315812110901
step 5: loss=0.8820069432258606
step 6: loss=0.6795834302902222
step 7: loss=0.6839614510536194
step 8: loss=0.795447051525116
step 9: loss=0.7858409881591797
step 10: loss=0.7066601514816284
step 11: loss=0.5080092549324036
step 12: loss=0.6086088418960571
step 13: loss=0.687477707862854
step 14: loss=0.6014255881309509
step 15: loss=0.6547797918319702
step 16: loss=0.44285669922828674
step 17: loss=0.3559488356113434
step 18: loss=0.3600468337535858
step 19: loss=0.4020947813987732
step 20: loss=0.3246852159500122
Mean loss        0.5156544
True labels:     [2 0 4 8 7 6 0 6 3 1]
Predictions:     [2 0 4 8 7 6 0 5 5 1]


In [192]:
lr = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=20)
optimizer = optax.adamw(lr, weight_decay=1e-4, mask=mask_fn)
params = train_mlp(train_data, optimizer)
evaluate_mlp(test_data, params)

step 1: loss=2.329394578933716
step 2: loss=2.111168384552002
step 3: loss=2.098027229309082
step 4: loss=2.115138530731201
step 5: loss=2.1225409507751465
step 6: loss=1.9438470602035522
step 7: loss=1.989481806755066
step 8: loss=1.8919346332550049
step 9: loss=1.8656792640686035
step 10: loss=1.854970932006836
step 11: loss=1.867348551750183
step 12: loss=1.8192609548568726
step 13: loss=1.8986468315124512
step 14: loss=1.8100045919418335
step 15: loss=1.870030164718628
step 16: loss=1.7674648761749268
step 17: loss=1.6689090728759766
step 18: loss=1.7513816356658936
step 19: loss=1.673409342765808
step 20: loss=1.7518361806869507
Mean loss        1.7164931
True labels:     [2 0 4 8 7 6 0 6 3 1]
Predictions:     [2 0 4 8 7 6 0 8 8 1]


In [193]:
optimizer = optax.contrib.muon(learning_rate=learning_rate)
params = train_mlp(train_data, optimizer)
evaluate_mlp(test_data, params)

step 1: loss=2.4460508823394775
step 2: loss=2.2159481048583984
step 3: loss=2.226898193359375
step 4: loss=2.282470703125
step 5: loss=2.2132441997528076
step 6: loss=2.1017580032348633
step 7: loss=2.186561346054077
step 8: loss=2.037569284439087
step 9: loss=2.0478765964508057
step 10: loss=2.033282995223999
step 11: loss=2.013596534729004
step 12: loss=1.9634274244308472
step 13: loss=2.019657850265503
step 14: loss=1.8717783689498901
step 15: loss=1.937285304069519
step 16: loss=1.8528590202331543
step 17: loss=1.723109245300293
step 18: loss=1.7739412784576416
step 19: loss=1.659662127494812
step 20: loss=1.698677659034729
Mean loss        1.7045716
True labels:     [2 0 4 8 7 6 0 6 3 1]
Predictions:     [2 0 4 8 7 6 0 8 8 1]
