<small>

**Key differences from the JAX implementation:**  
- <b>Network definition:</b> Use a Flax <code>nn.Module</code> (e.g., an <code>MLP</code> class) instead of lists of parameter dicts.  
- <b>Initialization:</b> Flax handles parameter initialization with <code>model.init(...)</code>, using specified initializers within the class.  
- <b>Forward pass:</b> Compute outputs with <code>model.apply(params, x)</code> instead of manual matrix multiplications.  

</small>

In [29]:
from typing import Sequence

import jax
import jax.numpy as jnp
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.training import train_state
import optax
import orbax.checkpoint as ocp
from pathlib import Path


In [30]:
# Load MNIST from TensorFlow Datasets.
# batch_size=-1 asks tfds.load for each split as a single full batch (see the
# tfds.load documentation); with_info=True additionally returns the dataset
# metadata, bound here to `info`.
data_dir = '/tmp/tfds' # alternative location: './data/tfds'
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)

In [31]:
def normalise(x, x_max=255.0):
    """Scale ``x`` by dividing it by ``x_max``.

    With the default ``x_max`` of 255.0, values in [0, 255] are mapped onto
    [0, 1].
    """
    scaled = x / x_max
    return scaled

def convert_to_jax(data_np, data_type):
    """Turn an array into a JAX array according to the role it plays.

    Args:
        data_np: array-like holding either images or labels.
        data_type: ``"image"`` or ``"label"``.

    Returns:
        For ``"image"``: a float32 JAX array passed through ``normalise``.
        For ``"label"``: the result of ``jnp.array(data_np)``.

    Raises:
        ValueError: if ``data_type`` is neither ``"image"`` nor ``"label"``.
    """
    if data_type == "label":
        return jnp.array(data_np)
    if data_type == "image":
        as_float = jnp.array(data_np, dtype=jnp.float32)
        return normalise(as_float)
    raise ValueError("not image or label")

def flatten_image_for_mlp(data_jax):
    """Collapse every sample of a 4-D image batch into a single vector.

    The input shape is unpacked into exactly four axes, so any other rank
    raises ``ValueError``. The first axis is kept; the remaining three are
    merged into one.
    """
    # The 4-way unpack doubles as a rank check; only the first axis is used.
    n_batch, _height, _width, _channels = data_jax.shape
    return data_jax.reshape(n_batch, -1)

def prepare_data(data_dict: dict, subsample_size: int = 0):
    """Convert one dataset split into JAX arrays ready for the MLP.

    Args:
        data_dict: mapping from data type (``"image"`` / ``"label"``) to an
            object exposing ``.numpy()``.
        subsample_size: when greater than 0, only the first
            ``subsample_size`` samples of every entry are kept; 0 (the
            default) keeps everything.

    Returns:
        A dict with the same keys as ``data_dict``. Images are converted by
        ``convert_to_jax`` and flattened by ``flatten_image_for_mlp``; labels
        are converted by ``convert_to_jax`` only.

    Raises:
        ValueError: propagated from ``convert_to_jax`` for an unknown key.
    """
    data_jax = {}
    for data_type, data_tf in data_dict.items():
        data_numpy = data_tf.numpy()
        if subsample_size > 0:
            # Slice along the sample axis BEFORE converting, so only the
            # retained samples are cast, normalised and reshaped. The result
            # matches slicing afterwards because those steps act per sample.
            data_numpy = data_numpy[:subsample_size]
        data = convert_to_jax(data_numpy, data_type)
        if data_type == "image":
            data = flatten_image_for_mlp(data)
        data_jax[data_type] = data

    return data_jax

In [32]:
class MLP(nn.Module):
    """Fully-connected network: a stack of Dense layers with ReLU in between.

    No activation is applied after the last layer, so the output is the raw
    result of the final Dense layer.
    """
    # Output width of every Dense layer, in order (the input width is not
    # listed here; the Dense layers take it from the data they are applied to).
    layer_sizes: Sequence[int]

    @nn.compact
    def __call__(self, activations):
        """Run ``activations`` through every layer and return the result."""
        # Kernels and biases both use a normal initializer built with 0.1.
        gaussian_init = nn.initializers.normal(0.1)
        final_index = len(self.layer_sizes) - 1

        for index, width in enumerate(self.layer_sizes):
            dense = nn.Dense(
                width,
                kernel_init=gaussian_init,
                bias_init=gaussian_init,
            )
            activations = dense(activations)
            # ReLU follows every layer except the last one.
            if index < final_index:
                activations = nn.relu(activations)

        return activations

In [None]:
class LowRankDense(nn.Module):
    """Dense layer whose weight matrix is factorised into two thin matrices.

    The layer owns ``U`` with shape (in_features, rank) and ``V`` with shape
    (rank, features), plus an optional ``bias`` of shape (features,). The
    output is ``(x @ U) @ V + bias``, evaluated with two einsum calls.
    """
    # Width of the layer output.
    features: int
    # Inner dimension shared by the two factors.
    rank: int
    # Whether a learned bias vector is added to the output.
    use_bias: bool = True

    @nn.compact
    def __call__(self, inputs):
        """Apply the layer; the einsum specs require ``inputs`` to be 2-D."""
        n_in = inputs.shape[-1]
        init = nn.initializers.normal(0.1)

        # Parameters are declared in the order U, V, bias.
        factor_in = self.param("U", init, (n_in, self.rank))
        factor_out = self.param("V", init, (self.rank, self.features))

        # Project down to `rank` dimensions, then up to `features`.
        projected = jnp.einsum("bi,ir->br", inputs, factor_in)
        outputs = jnp.einsum("br,rf->bf", projected, factor_out)

        if not self.use_bias:
            return outputs

        offset = self.param("bias", init, (self.features,))
        return outputs + offset


class LowRankMLP(nn.Module):
    """MLP built from ``LowRankDense`` layers with ReLU in between.

    The same ``rank`` is used for every layer, and every layer has a bias.
    No activation is applied after the last layer.
    """
    # Output width of every layer, in order.
    layer_sizes: Sequence[int]
    # Low-rank dimension shared by all layers.
    rank: int

    @nn.compact
    def __call__(self, activations):
        """Run ``activations`` through every layer and return the result."""
        final_index = len(self.layer_sizes) - 1

        for index, width in enumerate(self.layer_sizes):
            layer = LowRankDense(features=width, rank=self.rank, use_bias=True)
            activations = layer(activations)
            # ReLU follows every layer except the last one.
            if index < final_index:
                activations = nn.relu(activations)

        return activations


In [34]:
def initialise_network_params(model, input_layer_size, key):
    """Create the initial parameters of ``model``.

    Args:
        model: object whose ``init(key, x)`` returns a mapping with a
            ``"params"`` entry.
        input_layer_size: number of input features per sample.
        key: random key forwarded to ``model.init``.

    Returns:
        The ``"params"`` entry of what ``model.init`` returns.
    """
    # A single all-ones sample is passed to model.init as the example input.
    dummy_batch = jnp.ones((1, input_layer_size))
    variables = model.init(key, dummy_batch)
    return variables["params"]

In [35]:
def calculate_mean_loss_batch(params, apply_fn, images, labels):
    """Return the mean softmax cross-entropy of a batch.

    Args:
        params: model parameters, passed to ``apply_fn`` under ``"params"``.
        apply_fn: callable invoked as ``apply_fn({"params": params}, images)``
            to obtain the logits.
        images: batch of inputs.
        labels: integer class labels, one per sample.

    Returns:
        The per-sample cross-entropy averaged over the batch.
    """
    variables = {"params": params}
    logits = apply_fn(variables, images)  # forward pass
    per_sample_loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    return per_sample_loss.mean()

In [36]:
@jax.jit
def take_training_step(training_state, images, labels):
    """Perform one optimisation step and return the updated training state.

    The gradient of ``calculate_mean_loss_batch`` with respect to the
    parameters is computed on the given batch and handed to
    ``training_state.apply_gradients``. The model's apply function and the
    parameters are both read from ``training_state``.
    """
    # argnums=0 -> differentiate with respect to the first argument (params).
    loss_gradient = jax.grad(calculate_mean_loss_batch, argnums=0)
    gradients = loss_gradient(
        training_state.params,
        training_state.apply_fn,
        images,
        labels,
    )
    return training_state.apply_gradients(grads=gradients)

In [37]:
def get_batches(images, labels, n_batches):
    """Yield exactly ``n_batches`` equally sized ``(images, labels)`` batches.

    Each batch holds ``len(images) // n_batches`` consecutive samples. Any
    samples left over after the last full batch are dropped.

    This is a generator, so the argument checks below only run once
    iteration starts.

    Args:
        images: sliceable sequence of inputs.
        labels: sliceable sequence of targets, same length as ``images``.
        n_batches: number of batches to produce; must be positive and no
            larger than the number of samples.

    Yields:
        Tuples ``(images[start:end], labels[start:end])``.

    Raises:
        AssertionError: if the lengths differ, ``n_batches`` exceeds the
            number of samples, or ``n_batches`` is not positive.
    """
    n_samples = len(images)
    assert len(images) == len(labels)
    assert n_samples >= n_batches
    assert n_batches > 0
    n_samples_per_batch = n_samples // n_batches
    # Iterate over the batch index rather than "while end <= n_samples":
    # the latter yields extra batches whenever the floor-divided batch size
    # fits more than n_batches times (e.g. 10 samples, 4 batches -> 5 batches).
    for batch_index in range(n_batches):
        start = batch_index * n_samples_per_batch
        end = start + n_samples_per_batch
        yield (images[start:end], labels[start:end])

In [38]:
def make_experiment_name(layer_sizes, optimizer):
    """Build an experiment name from the layer sizes and the optimizer class.

    The result has the form ``mlp_<size>-<size>-..._<OptimizerClassName>``.

    NOTE(review): the optimizer part is the name of the object's class. If
    all optimizers passed in are instances of one generic class (which may be
    the case for Optax optimizers -- verify), different optimizers and
    hyper-parameters will produce the same experiment name.
    """
    sizes_joined = "-".join(map(str, layer_sizes))
    optimizer_class_name = optimizer.__class__.__name__
    return f"mlp_{sizes_joined}_{optimizer_class_name}"

def initialise_checkpoint_manager(experiment_name: str = "mlp", max_to_keep=20):
    """Create a checkpoint manager for ``<cwd>/checkpoints/<experiment_name>``.

    The directory (including missing parents) is created if it does not
    exist yet.

    Args:
        experiment_name: sub-directory name under ``checkpoints``.
        max_to_keep: forwarded to ``ocp.CheckpointManagerOptions``.

    Returns:
        An ``ocp.CheckpointManager`` bound to that directory.
    """
    checkpoint_dir = Path().resolve() / "checkpoints" / experiment_name
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    manager_options = ocp.CheckpointManagerOptions(max_to_keep=max_to_keep)
    return ocp.CheckpointManager(
        directory=str(checkpoint_dir),
        options=manager_options,
    )

In [39]:
def create_training_state(layer_sizes, optimizer, key, use_lowrank: bool = False, rank: int | None = None):
    """Build a model, initialise its parameters and wrap them in a TrainState.

    Args:
        layer_sizes: ``layer_sizes[0]`` is the input width; the remaining
            entries are the widths of the network layers.
        optimizer: passed to ``TrainState.create`` as ``tx``.
        key: random key used for parameter initialisation.
        use_lowrank: build a ``LowRankMLP`` instead of an ``MLP``.
        rank: low-rank dimension; required when ``use_lowrank`` is true.

    Returns:
        A ``train_state.TrainState`` holding the model's apply function, the
        initial parameters and the optimizer.

    Raises:
        ValueError: if ``use_lowrank`` is true and ``rank`` is ``None``.
    """
    input_width, hidden_and_output = layer_sizes[0], layer_sizes[1:]

    if not use_lowrank:
        model = MLP(layer_sizes=hidden_and_output)
    elif rank is None:
        raise ValueError("rank must be provided when use_lowrank=True")
    else:
        model = LowRankMLP(layer_sizes=hidden_and_output, rank=rank)

    initial_params = initialise_network_params(model, input_width, key)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=initial_params,
        tx=optimizer,
    )

In [40]:
def run_training(
    images,
    labels,
    n_steps,
    layer_sizes,
    optimizer,
    checkpoint_manager,
    key,
    steps_per_save,
    training_state,
    use_lowrank: bool = False,
    rank: int | None = None,
):
    """Train on successive batches, checkpointing along the way.

    The data is split by ``get_batches`` with ``n_batches=n_steps`` and one
    training step is taken per batch. After each step the loss on that same
    batch is recomputed with the updated parameters and printed. The full
    training state is saved when its step counter equals 1 and whenever the
    counter is a multiple of ``steps_per_save``.

    Args:
        images, labels: training data.
        n_steps: number of batches requested from ``get_batches``.
        layer_sizes, optimizer, key, use_lowrank, rank: only used to create a
            fresh state when ``training_state`` is ``None``.
        checkpoint_manager: object whose ``save(step, args=...)`` is called.
        steps_per_save: checkpoint interval, in steps.
        training_state: existing state to continue from, or ``None``.

    Returns:
        The parameters of the final training state.
    """
    state = training_state
    if state is None:
        state = create_training_state(
            layer_sizes,
            optimizer,
            key,
            use_lowrank=use_lowrank,
            rank=rank,
        )

    batches = get_batches(images=images, labels=labels, n_batches=n_steps)
    for batch_images, batch_labels in batches:
        state = take_training_step(state, batch_images, batch_labels)
        # The step counter comes from the state, so a resumed run keeps
        # counting from the step stored in the supplied state.
        step = state.step
        loss = calculate_mean_loss_batch(state.params, state.apply_fn, batch_images, batch_labels)
        print(f"step {step}: loss={loss}")

        should_save = step == 1 or step % steps_per_save == 0
        if should_save:
            checkpoint_manager.save(step, args=ocp.args.StandardSave(state))

    return state.params

In [41]:
def train_mlp(
    train_data,
    optimizer,
    n_steps=10**3,
    steps_per_save=100,
    training_state=None,
    key=jax.random.key(0),
    use_lowrank: bool = False,
    rank: int | None = None,
    layer_sizes=(784, 128, 10),
):
    """Set up checkpointing for an experiment and run the training loop.

    The experiment name comes from ``make_experiment_name``; for low-rank
    runs ``_lowrank-r<rank>`` is appended. A checkpoint manager is created
    for that name and everything is handed to ``run_training``.

    Args:
        train_data: mapping with ``"image"`` and ``"label"`` entries.
        optimizer: optimizer used for training and for the experiment name.
        n_steps, steps_per_save, training_state, key: forwarded to
            ``run_training``.
        use_lowrank: train a low-rank model instead of the dense MLP.
        rank: low-rank dimension; required when ``use_lowrank`` is true.
        layer_sizes: input width followed by the layer widths.

    Returns:
        The final parameters returned by ``run_training``.

    Raises:
        ValueError: if ``use_lowrank`` is true and ``rank`` is ``None``.
    """
    sizes = list(layer_sizes)
    name_parts = [make_experiment_name(sizes, optimizer)]
    if use_lowrank:
        if rank is None:
            raise ValueError("rank must be provided when use_lowrank=True")
        name_parts.append(f"_lowrank-r{rank}")

    manager = initialise_checkpoint_manager("".join(name_parts))
    return run_training(
        images=train_data["image"],
        labels=train_data["label"],
        n_steps=n_steps,
        layer_sizes=sizes,
        optimizer=optimizer,
        checkpoint_manager=manager,
        key=key,
        steps_per_save=steps_per_save,
        training_state=training_state,
        use_lowrank=use_lowrank,
        rank=rank,
    )

In [42]:
def extract_layer_sizes(params):
    """Recover ``[input_size, layer_1_size, ...]`` from dense-layer params.

    Args:
        params: mapping from layer name to a mapping with ``"kernel"``
            (shape ``(in, out)``) and ``"bias"`` (shape ``(out,)``) entries.

    Returns:
        A list whose first element is the input width of the first layer,
        followed by the output width of every layer in network order. An
        empty mapping gives an empty list.

    When every layer name ends in ``_<integer>`` (e.g. ``Dense_0``,
    ``Dense_1``), layers are ordered by that integer rather than by dict
    iteration order. Dict order is not reliable here: a mapping whose keys
    are in lexicographic order places ``Dense_10`` before ``Dense_2``, which
    would scramble the sizes of networks with 11 or more layers. If any name
    lacks a numeric suffix, the mapping's own iteration order is used.
    """
    names = list(params)
    suffixes = [str(name).rpartition("_")[2] for name in names]
    if all(suffix.isascii() and suffix.isdigit() for suffix in suffixes):
        # sorted() is stable, so equal indices keep their original order.
        indexed = sorted(zip(map(int, suffixes), names), key=lambda pair: pair[0])
        names = [name for _, name in indexed]

    layer_sizes = []
    for position, name in enumerate(names):
        layer_params = params[name]
        if position == 0:
            # Only the first layer contributes its input width.
            layer_sizes.append(layer_params["kernel"].shape[0])
            layer_sizes.append(layer_params["kernel"].shape[1])
        else:
            layer_sizes.append(layer_params["bias"].shape[0])
    return layer_sizes

In [43]:
def evaluate_mlp(
    test_data,
    params,
    n_examples=10,
    use_lowrank: bool = False,
    rank: int | None = None,
    layer_sizes=None,
):
    """Print the mean loss and a few example predictions for a trained model.

    For the dense MLP the architecture is recovered from ``params`` via
    ``extract_layer_sizes``; for the low-rank model ``layer_sizes`` and
    ``rank`` must be supplied by the caller.

    Args:
        test_data: mapping with ``"image"`` and ``"label"`` entries.
        params: trained parameters.
        n_examples: number of leading samples whose labels and predictions
            are printed.
        use_lowrank: evaluate a ``LowRankMLP`` instead of an ``MLP``.
        rank: low-rank dimension (low-rank model only).
        layer_sizes: input width followed by layer widths (low-rank model
            only; ignored and recomputed otherwise).

    Raises:
        ValueError: if ``use_lowrank`` is true and ``layer_sizes`` or
            ``rank`` is ``None``.
    """
    images, labels = test_data["image"], test_data["label"]

    if not use_lowrank:
        layer_sizes = extract_layer_sizes(params)
        model = MLP(layer_sizes=layer_sizes[1:])
    else:
        if layer_sizes is None:
            raise ValueError("layer_sizes must be provided when use_lowrank=True")
        if rank is None:
            raise ValueError("rank must be provided when use_lowrank=True")
        model = LowRankMLP(layer_sizes=layer_sizes[1:], rank=rank)

    forward = model.apply

    # Loss over the whole test set; predictions only for the first few samples.
    mean_loss = calculate_mean_loss_batch(params, forward, images, labels)
    example_labels = labels[:n_examples]
    example_logits = forward({"params": params}, images[:n_examples])
    example_predictions = jnp.argmax(example_logits, axis=1)

    prefix = ""
    if use_lowrank:
        prefix = "[low-rank] "
    print(prefix + "Mean loss       ", mean_loss)
    print(prefix + "True labels:    ", example_labels)
    print(prefix + "Predictions:    ", example_predictions)

1. Learning rate decay
2. Weight decay

In [44]:
# Convert both splits to JAX arrays, keeping only the first 10**3 samples of
# each (see prepare_data's subsample_size argument).
train_data = prepare_data(mnist_data["train"], subsample_size=10**3) 
test_data = prepare_data(mnist_data["test"], subsample_size=10**3) 

In [45]:
# Baseline: dense MLP trained with Adam at a constant learning rate of 1e-3
# (all other train_mlp arguments at their defaults), then evaluated on the
# test subset.
learning_rate = 1e-3
optimizer = optax.adam(learning_rate)
params = train_mlp(train_data, optimizer)
evaluate_mlp(test_data, params)

step 1: loss=2.2421865463256836
step 2: loss=1.4549814462661743
step 3: loss=3.0767364501953125
step 4: loss=1.6018766164779663
step 5: loss=1.922335147857666
step 6: loss=1.5518207550048828
step 7: loss=2.2047572135925293
step 8: loss=2.28387188911438
step 9: loss=1.3409204483032227
step 10: loss=2.344090700149536
step 11: loss=2.6371214389801025
step 12: loss=3.4621236324310303
step 13: loss=1.713283658027649
step 14: loss=1.9781702756881714
step 15: loss=3.3416006565093994
step 16: loss=2.9397127628326416
step 17: loss=1.0316027402877808
step 18: loss=2.92775821685791
step 19: loss=2.7445106506347656
step 20: loss=1.4158217906951904
step 21: loss=1.9610295295715332
step 22: loss=1.9986664056777954
step 23: loss=1.672990083694458
step 24: loss=2.457414388656616
step 25: loss=1.5557224750518799
step 26: loss=2.3051018714904785
step 27: loss=2.603790044784546
step 28: loss=1.988590955734253
step 29: loss=2.3711462020874023
step 30: loss=1.6227275133132935
step 31: loss=1.86228585243225

In [46]:
# Restore a previously saved training state so that training can continue.
resume_from_step = 1000  # step number of the checkpoint to restore
layer_sizes = [784, 128, 10]

# Re-create the checkpoint manager for the directory derived from the layer
# sizes and the current global `optimizer`.
# NOTE(review): this must match the experiment that wrote the checkpoint --
# confirm `optimizer` is still the one used for that run.
experiment_name = make_experiment_name(layer_sizes, optimizer)
checkpoint_manager = initialise_checkpoint_manager(experiment_name)

# A freshly initialised state is handed to StandardRestore, presumably as the
# structural template for the restored state (see the Orbax documentation).
template_state = create_training_state(layer_sizes, optimizer, jax.random.key(0))
restored_state = checkpoint_manager.restore(
    resume_from_step,
    args=ocp.args.StandardRestore(template_state),
)

In [20]:
# Continue training from the restored state for `extra_steps` more steps.
extra_steps = 1000
# The key is forwarded for completeness; run_training only uses it to build a
# new state when training_state is None, so it has no effect on a resumed run.
key = jax.random.key(1)
params = train_mlp(
    train_data=train_data,
    optimizer=optimizer,
    n_steps=extra_steps,  # previously defined but never passed on
    training_state=restored_state,
    key=key,
)
evaluate_mlp(test_data, params)

step 1001: loss=0.1626293659210205
step 1002: loss=0.017407894134521484
step 1003: loss=0.031241416931152344
step 1004: loss=0.004973888397216797
step 1005: loss=0.10784530639648438
step 1006: loss=0.002956390380859375
step 1007: loss=0.0009908676147460938
step 1008: loss=0.5485451221466064
step 1009: loss=0.004882335662841797
step 1010: loss=0.015553951263427734
step 1011: loss=0.22939682006835938
step 1012: loss=0.03225088119506836
step 1013: loss=0.014636039733886719
step 1014: loss=0.19443297386169434
step 1015: loss=0.0006361007690429688
step 1016: loss=9.5367431640625e-05
step 1017: loss=0.0055103302001953125
step 1018: loss=0.4422445297241211
step 1019: loss=0.8210334777832031
step 1020: loss=0.043196678161621094
step 1021: loss=0.042943477630615234
step 1022: loss=0.1428360939025879
step 1023: loss=0.09797310829162598
step 1024: loss=0.16956233978271484
step 1025: loss=0.3048539161682129
step 1026: loss=1.570849061012268
step 1027: loss=0.314772367477417
step 1028: loss=0.14598

In [21]:
# Same experiment as the Adam baseline, with the learning rate raised to 1e-2.
learning_rate = 1e-2
optimizer = optax.adam(learning_rate)
params = train_mlp(train_data, optimizer)
evaluate_mlp(test_data, params)

step 1: loss=0.32265305519104004
step 2: loss=0.7677340507507324
step 3: loss=1.3151179552078247
step 4: loss=0.6902498006820679
step 5: loss=1.5896849632263184
step 6: loss=0.20079827308654785
step 7: loss=1.7467846870422363
step 8: loss=0.7341939210891724
step 9: loss=0.4832557439804077
step 10: loss=2.7541565895080566
step 11: loss=1.8300225734710693
step 12: loss=4.892587661743164
step 13: loss=0.03459310531616211
step 14: loss=0.21892571449279785
step 15: loss=6.246583938598633
step 16: loss=3.886425018310547
step 17: loss=0.005482196807861328
step 18: loss=6.574459075927734
step 19: loss=5.217857837677002
step 20: loss=0.5348055362701416
step 21: loss=2.4435982704162598
step 22: loss=0.31761860847473145
step 23: loss=0.20634746551513672
step 24: loss=0.8876932859420776
step 25: loss=0.6408635377883911
step 26: loss=2.7429752349853516
step 27: loss=4.988874912261963
step 28: loss=0.7707905769348145
step 29: loss=1.6248918771743774
step 30: loss=1.6446428298950195
step 31: loss=1.1

In [22]:
# Same experiment again, with the learning rate raised to 1e-1.
learning_rate = 1e-1
optimizer = optax.adam(learning_rate)
params = train_mlp(train_data, optimizer)
evaluate_mlp(test_data, params)

step 1: loss=0.0
step 2: loss=0.021732807159423828
step 3: loss=0.007058143615722656
step 4: loss=1.9073486328125e-06
step 5: loss=3.5937793254852295
step 6: loss=0.033194541931152344
step 7: loss=7.7817277908325195
step 8: loss=6.1789631843566895
step 9: loss=0.11160469055175781
step 10: loss=15.611993789672852
step 11: loss=1.6699985265731812
step 12: loss=0.4520533084869385
step 13: loss=0.01483774185180664
step 14: loss=0.605947732925415
step 15: loss=8.520060539245605
step 16: loss=4.101454734802246
step 17: loss=0.00015354156494140625
step 18: loss=9.17702865600586
step 19: loss=2.7513885498046875
step 20: loss=0.17457973957061768
step 21: loss=1.8984501361846924
step 22: loss=0.6763033866882324
step 23: loss=0.31062984466552734
step 24: loss=2.33121395111084
step 25: loss=1.7108426094055176
step 26: loss=1.2369067668914795
step 27: loss=2.4545674324035645
step 28: loss=1.2753641605377197
step 29: loss=0.6617397665977478
step 30: loss=0.47773122787475586
step 31: loss=2.540803432

In [23]:
learning_rate = 1e-2 # read by the AdamW and Muon cells below; the cosine-decay and low-rank cells set their own rate

In [24]:
# AdamW: Adam plus weight decay (1e-4), using the global learning_rate.
# No mask is given here, so the decay is not restricted to a subset of
# parameters.
optimizer = optax.adamw(learning_rate, weight_decay=1e-4)
params = train_mlp(train_data, optimizer)
evaluate_mlp(test_data, params)

step 1: loss=0.32265353202819824
step 2: loss=0.7677361965179443
step 3: loss=1.3151150941848755
step 4: loss=0.6902555227279663
step 5: loss=1.5896856784820557
step 6: loss=0.20080161094665527
step 7: loss=1.746786117553711
step 8: loss=0.7341992855072021
step 9: loss=0.48326241970062256
step 10: loss=2.7541470527648926
step 11: loss=1.8300195932388306
step 12: loss=4.892520427703857
step 13: loss=0.034595489501953125
step 14: loss=0.2189323902130127
step 15: loss=6.246511936187744
step 16: loss=3.8863754272460938
step 17: loss=0.005483150482177734
step 18: loss=6.574321269989014
step 19: loss=5.217706203460693
step 20: loss=0.5348190069198608
step 21: loss=2.4435720443725586
step 22: loss=0.3176398277282715
step 23: loss=0.2063615322113037
step 24: loss=0.8877159953117371
step 25: loss=0.6408861875534058
step 26: loss=2.742943525314331
step 27: loss=4.988760471343994
step 28: loss=0.7708103656768799
step 29: loss=1.6248962879180908
step 30: loss=1.644683837890625
step 31: loss=1.1385

In [25]:
# AdamW with the weight decay restricted by a mask. mask_fn returns a tree of
# booleans shaped like the params: True for every leaf that is not 1-D and
# False for 1-D leaves (the bias vectors). Per the optax documentation, decay
# is applied where the mask is True, so biases are left undecayed.
mask_fn = lambda p: jax.tree_util.tree_map(lambda x: x.ndim != 1, p) # True = decay this leaf; biases (1-D) get False
optimizer = optax.adamw(learning_rate, weight_decay=1e-4, mask=mask_fn)
params = train_mlp(train_data, optimizer)
evaluate_mlp(test_data, params)

step 1: loss=0.32265353202819824
step 2: loss=0.7677361965179443
step 3: loss=1.3151150941848755
step 4: loss=0.6902559995651245
step 5: loss=1.5896861553192139
step 6: loss=0.20080137252807617
step 7: loss=1.7467864751815796
step 8: loss=0.7341995239257812
step 9: loss=0.48326265811920166
step 10: loss=2.754148483276367
step 11: loss=1.8300219774246216
step 12: loss=4.892520904541016
step 13: loss=0.034595489501953125
step 14: loss=0.2189321517944336
step 15: loss=6.246511459350586
step 16: loss=3.886373996734619
step 17: loss=0.005483150482177734
step 18: loss=6.574323654174805
step 19: loss=5.217709541320801
step 20: loss=0.5348189473152161
step 21: loss=2.443570613861084
step 22: loss=0.3176398277282715
step 23: loss=0.2063612937927246
step 24: loss=0.8877174854278564
step 25: loss=0.6408871412277222
step 26: loss=2.742940664291382
step 27: loss=4.988762855529785
step 28: loss=0.7708112001419067
step 29: loss=1.6248937845230103
step 30: loss=1.6446845531463623
step 31: loss=1.13855

In [26]:
# Masked AdamW again, now with a cosine-decay learning-rate schedule that
# starts at 1e-3 instead of the constant global learning_rate.
# NOTE(review): decay_steps=20 is much shorter than the default n_steps=10**3
# used by train_mlp. Per the optax documentation the schedule reaches its
# final value (0 with the default alpha) after decay_steps, so updates after
# step 20 would be made with a zero learning rate -- confirm this is intended
# rather than decay_steps matching the number of training steps.
lr = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=20)
optimizer = optax.adamw(lr, weight_decay=1e-4, mask=mask_fn)
params = train_mlp(train_data, optimizer)
evaluate_mlp(test_data, params)

step 1: loss=2.2421863079071045
step 2: loss=1.4559381008148193
step 3: loss=3.085052490234375
step 4: loss=1.6123230457305908
step 5: loss=1.9464938640594482
step 6: loss=1.5765169858932495
step 7: loss=2.2348220348358154
step 8: loss=2.330615520477295
step 9: loss=1.399373173713684
step 10: loss=2.4562790393829346
step 11: loss=2.742587089538574
step 12: loss=3.5112454891204834
step 13: loss=1.9610660076141357
step 14: loss=2.111886978149414
step 15: loss=3.3601794242858887
step 16: loss=3.1095786094665527
step 17: loss=1.3849537372589111
step 18: loss=2.7661705017089844
step 19: loss=2.647594928741455
step 20: loss=1.561249017715454
step 21: loss=2.3637077808380127
step 22: loss=2.654891014099121
step 23: loss=2.297776699066162
step 24: loss=2.9272146224975586
step 25: loss=1.6235723495483398
step 26: loss=2.1043612957000732
step 27: loss=2.7969167232513428
step 28: loss=2.523972749710083
step 29: loss=2.4751412868499756
step 30: loss=2.1431922912597656
step 31: loss=3.2830438613891

In [27]:
# Train the same dense MLP with the Muon optimizer from optax.contrib, using
# the global learning_rate.
optimizer = optax.contrib.muon(learning_rate=learning_rate)
params = train_mlp(train_data, optimizer)
evaluate_mlp(test_data, params)

step 1: loss=2.430816650390625
step 2: loss=1.5101913213729858
step 3: loss=3.2343151569366455
step 4: loss=1.6777729988098145
step 5: loss=2.034066677093506
step 6: loss=1.636694073677063
step 7: loss=2.13456130027771
step 8: loss=2.250941276550293
step 9: loss=1.3820940256118774
step 10: loss=2.4763283729553223
step 11: loss=2.7123100757598877
step 12: loss=3.4853312969207764
step 13: loss=1.7098143100738525
step 14: loss=1.984086275100708
step 15: loss=3.205305576324463
step 16: loss=3.001132011413574
step 17: loss=1.0761189460754395
step 18: loss=2.6243045330047607
step 19: loss=2.5216028690338135
step 20: loss=1.3572767972946167
step 21: loss=1.971611738204956
step 22: loss=2.129270076751709
step 23: loss=1.6986463069915771
step 24: loss=2.4057188034057617
step 25: loss=1.6564723253250122
step 26: loss=2.4194247722625732
step 27: loss=2.6503312587738037
step 28: loss=2.3030922412872314
step 29: loss=2.2875442504882812
step 30: loss=1.3227665424346924
step 31: loss=1.96161544322967

In [48]:
# Low-rank experiment: every layer is a LowRankDense with rank 32, trained
# with Adam at a learning rate of 1e-3.
learning_rate = 1e-3
rank = 32
layer_sizes = (784, 128, 10)

optimizer = optax.adam(learning_rate)

params_lowrank = train_mlp(
    train_data,
    optimizer,
    use_lowrank=True,
    rank=rank,
    layer_sizes=layer_sizes,
)
# For the low-rank model evaluate_mlp requires layer_sizes and rank
# explicitly; it only derives the architecture from params for the dense MLP.
evaluate_mlp(
    test_data,
    params_lowrank,
    use_lowrank=True,
    rank=rank,
    layer_sizes=layer_sizes,
)


step 1: loss=1.940758228302002
step 2: loss=2.1760780811309814
step 3: loss=2.123516082763672
step 4: loss=1.7909069061279297
step 5: loss=2.7731266021728516
step 6: loss=1.962531328201294
step 7: loss=2.494056224822998
step 8: loss=1.8388534784317017
step 9: loss=1.8824307918548584
step 10: loss=2.6979455947875977
step 11: loss=2.197720766067505
step 12: loss=2.1643967628479004
step 13: loss=1.4617383480072021
step 14: loss=1.5403313636779785
step 15: loss=2.446437358856201
step 16: loss=2.2722434997558594
step 17: loss=1.3893849849700928
step 18: loss=3.549886703491211
step 19: loss=3.3583061695098877
step 20: loss=1.5383902788162231
step 21: loss=2.0145435333251953
step 22: loss=1.8640811443328857
step 23: loss=1.9871046543121338
step 24: loss=2.46518611907959
step 25: loss=2.6712143421173096
step 26: loss=2.0123965740203857
step 27: loss=3.0125784873962402
step 28: loss=2.691230297088623
step 29: loss=1.9900810718536377
step 30: loss=1.6837618350982666
step 31: loss=1.7050302028656