<small>

**Key differences from JAX implementation:**  
- <b>Network definition:</b> Use a Flax <code>nn.Module</code> (e.g., an <code>MLP</code> class) instead of lists of parameter dicts.  
- <b>Initialization:</b> Flax handles parameter initialization with <code>model.init(...)</code>, using specified initializers within the class.  
- <b>Forward pass:</b> Compute outputs with <code>model.apply(params, x)</code> instead of manual matrix multiplications.  

</small>

In [22]:
from typing import Sequence

import jax
import jax.numpy as jnp
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.training import train_state
import optax
import orbax.checkpoint as ocp
from pathlib import Path


In [23]:
# Load MNIST from TensorFlow Datasets
data_dir = '/tmp/tfds' # data_dir = './data/tfds'
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)

In [24]:
def normalise(x, x_max=255.0):
    return x / x_max

def convert_to_jax(data_np, data_type):
    if data_type == "image":
        data_jax = normalise(jnp.array(data_np, dtype=jnp.float32))
    elif data_type == "label":
        data_jax = jnp.array(data_np)
    else:
        raise ValueError("not image or label")
    return data_jax

def flatten_image_for_mlp(data_jax):
    """Produces one greyscale vector per sample"""
    n_batch, n_pixels_vertical, n_pixels_horizontal, n_channels = data_jax.shape
    data_flattened = data_jax.reshape(n_batch, -1)
    return data_flattened

def prepare_data(data_dict: dict, subsample_size: int=0):
    data_jax = {}
    for data_type, data_tf in data_dict.items():
        data_numpy = data_tf.numpy()
        data = convert_to_jax(data_numpy, data_type)
        if data_type == "image":
            data = flatten_image_for_mlp(data)
        if subsample_size > 0:
            data = data[:subsample_size]
        data_jax[data_type] = data

    return data_jax

In [25]:
class MLP(nn.Module):
    layer_sizes: Sequence[int]

    @nn.compact
    def __call__(self, activations):
        for layer_number, layer_size in enumerate(self.layer_sizes):
            activations = nn.Dense(
                layer_size,
                kernel_init=nn.initializers.normal(0.1),
                bias_init=nn.initializers.normal(0.1)
            )(activations)

            if layer_number != (len(self.layer_sizes) - 1):
                activations = nn.relu(activations)

        return activations

In [26]:
class LowRankDense(nn.Module):
    """Low-rank dense layer implemented with two factors and einsum.

    Parameters are U in R^{in_features x rank} and V in R^{rank x features}.
    The forward pass computes y = (x @ U) @ V + b using einsum.
    """
    features: int
    rank: int
    use_bias: bool = True

    @nn.compact
    def __call__(self, inputs):
        # inputs: [batch, in_features]
        in_features = inputs.shape[-1]

        U = self.param(
            "U",
            nn.initializers.normal(0.1),
            (in_features, self.rank),
        )
        V = self.param(
            "V",
            nn.initializers.normal(0.1),
            (self.rank, self.features),
        )

        hidden = jnp.einsum("bi,ir->br", inputs, U)
        y = jnp.einsum("br,rf->bf", hidden, V)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.initializers.normal(0.1),
                (self.features,),
            )
            y = y + bias

        return y


class LowRankMLP(nn.Module):
    """
    Every layer uses the same low-rank dimension rank (="rank")
    """
    layer_sizes: Sequence[int]
    rank: int

    @nn.compact
    def __call__(self, activations):
        for layer_number, layer_size in enumerate(self.layer_sizes):
            activations = LowRankDense(
                features=layer_size,
                rank=self.rank,
                use_bias=True,
            )(activations)

            if layer_number != (len(self.layer_sizes) - 1):
                activations = nn.relu(activations)

        return activations


In [27]:
def initialise_network_params(model, input_layer_size, key):
    """Initialize all layers for a fully-connected neural network"""
    input_shape_dummy = jnp.ones((1, input_layer_size))
    params = model.init(key, input_shape_dummy)["params"]
    return params

In [28]:
def calculate_mean_loss_batch(params, apply_fn, images, labels):
    logits = apply_fn({"params": params}, images) # FORWARD PASS
    cross_entropy_by_sample = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    cross_entropy_mean = cross_entropy_by_sample.mean()
    return cross_entropy_mean

In [29]:
@jax.jit
def take_training_step(training_state, images, labels):
    """
    Single training step 
    The model and optimiser are passed in the training state
    returns a training state
    """
    grads_by_params_fn = jax.grad(calculate_mean_loss_batch)
    grads_by_params = grads_by_params_fn(
        training_state.params,     # params is first â†’ grad w.r.t. params
        training_state.apply_fn,
        images,
        labels,
    )
    return training_state.apply_gradients(grads=grads_by_params)

In [30]:
def get_batches(images, labels, n_batches):
    """Drops the last set of samples if they're not the right length"""
    n_samples = len(images)
    assert len(images) == len(labels)
    assert n_samples >= n_batches
    assert n_batches > 0
    n_samples_per_batch = n_samples // n_batches
    start = 0
    end = n_samples_per_batch
    while end <= n_samples: 
        yield (images[start:end], labels[start:end])
        start += n_samples_per_batch
        end += n_samples_per_batch

In [31]:
def make_experiment_name(layer_sizes, optimizer):
    layer_part = "mlp_" + "-".join(str(s) for s in layer_sizes)
    opt_name = optimizer.__class__.__name__
    return f"{layer_part}_{opt_name}"

def initialise_checkpoint_manager(experiment_name: str = "mlp", max_to_keep=20):
    project_root = Path().resolve()
    base_dir = project_root / "checkpoints"
    checkpoint_dir = base_dir / experiment_name
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_manager = ocp.CheckpointManager(
        directory=str(checkpoint_dir),
        options=ocp.CheckpointManagerOptions(max_to_keep=max_to_keep),
    )
    return checkpoint_manager

In [32]:
def create_training_state(layer_sizes, optimizer, key, use_lowrank: bool = False, rank: int | None = None):
    input_layer_size = layer_sizes[0]
    network_layer_sizes = layer_sizes[1:]

    if use_lowrank:
        if rank is None:
            raise ValueError("rank must be provided when use_lowrank=True")
        model = LowRankMLP(layer_sizes=network_layer_sizes, rank=rank)
    else:
        model = MLP(layer_sizes=network_layer_sizes)

    apply_fn = model.apply
    params = initialise_network_params(model, input_layer_size, key)
    training_state = train_state.TrainState.create(
        apply_fn=apply_fn,
        params=params,
        tx=optimizer,
    )
    return training_state

In [33]:
def run_training(
    images,
    labels,
    n_steps,
    layer_sizes,
    optimizer,
    checkpoint_manager,
    key,
    steps_per_save,
    training_state,
    use_lowrank: bool = False,
    rank: int | None = None,
    ): 
    """
    The training state ('state') is an instance of TrainState that holds:
    - apply_fn: the model's apply function, used for forward passes
    - params: the parameters of the neural network
    - tx: the optimizers (Optax transformation) for parameter updates
    - opt_state: the state of the optimizer
    """
    if training_state is None:
        training_state = create_training_state(
            layer_sizes,
            optimizer,
            key,
            use_lowrank=use_lowrank,
            rank=rank,
        )

    for images_batch, labels_batch in get_batches(images=images, labels=labels, n_batches=n_steps):
        training_state = take_training_step(training_state, images_batch, labels_batch)
        step = training_state.step
        loss = calculate_mean_loss_batch(training_state.params, training_state.apply_fn, images_batch, labels_batch)
        print(f"step {step}: loss={loss}")
        if step == 1 or step % steps_per_save == 0:
            step_dir = step
            checkpoint_manager.save(
                step_dir,
                args=ocp.args.StandardSave(training_state)
                )

    return training_state.params

In [34]:
def train_mlp(
    train_data,
    optimizer,
    n_steps=10**3,
    steps_per_save=100,
    training_state=None,
    key=jax.random.key(0),
    use_lowrank: bool = False,
    rank: int | None = None,
    layer_sizes=(784, 128, 10),
):
    layer_sizes = list(layer_sizes)
    experiment_name = make_experiment_name(layer_sizes, optimizer)
    if use_lowrank:
        if rank is None:
            raise ValueError("rank must be provided when use_lowrank=True")
        experiment_name = experiment_name + f"_lowrank-r{rank}"

    checkpoint_manager = initialise_checkpoint_manager(experiment_name)
    final_params = run_training(
        images=train_data["image"],
        labels=train_data["label"],
        n_steps=n_steps,
        layer_sizes=layer_sizes,
        optimizer=optimizer,
        checkpoint_manager=checkpoint_manager,
        key=key,
        steps_per_save=steps_per_save,
        training_state=training_state,
        use_lowrank=use_lowrank,
        rank=rank,
    )
    return final_params

In [35]:
def extract_layer_sizes(params):
    layer_sizes = []
    for layer, layer_params in enumerate(params.values()):
        if layer == 0:
            layer_sizes.append(layer_params["kernel"].shape[0])
            layer_sizes.append(layer_params["kernel"].shape[1])
        else:
            layer_sizes.append(layer_params["bias"].shape[0])
    return layer_sizes

In [36]:
def evaluate_mlp(
    test_data,
    params,
    n_examples=10,
    use_lowrank: bool = False,
    rank: int | None = None,
    layer_sizes=None,
):
    images = test_data["image"]
    labels = test_data["label"]

    if use_lowrank:
        if layer_sizes is None:
            raise ValueError("layer_sizes must be provided when use_lowrank=True")
        if rank is None:
            raise ValueError("rank must be provided when use_lowrank=True")
        model = LowRankMLP(layer_sizes=layer_sizes[1:], rank=rank)
    else:
        layer_sizes = extract_layer_sizes(params)
        model = MLP(layer_sizes=layer_sizes[1:])

    apply_fn = model.apply

    mean_loss = calculate_mean_loss_batch(params, apply_fn, images, labels)
    example_images = images[:n_examples]
    example_labels = labels[:n_examples]
    logits = apply_fn({"params": params}, example_images)
    example_predictions = jnp.argmax(logits, axis=1)

    prefix = "[low-rank] " if use_lowrank else ""
    print(prefix + "Mean loss       ", mean_loss)
    print(prefix + "True labels:    ", example_labels)
    print(prefix + "Predictions:    ", example_predictions)

1. Learning rate decay
2. Weight decay

In [37]:
train_data = prepare_data(mnist_data["train"], subsample_size=10**3) 
test_data = prepare_data(mnist_data["test"], subsample_size=10**3) 

In [18]:
learning_rate = 1e-3
optimizer = optax.adam(learning_rate)
params = train_mlp(train_data, optimizer)
evaluate_mlp(test_data, params)

step 1: loss=2.2421865463256836
step 2: loss=1.4549814462661743
step 3: loss=3.0767364501953125
step 4: loss=1.6018766164779663
step 5: loss=1.922335147857666
step 6: loss=1.5518207550048828
step 7: loss=2.2047572135925293
step 8: loss=2.28387188911438
step 9: loss=1.3409204483032227
step 10: loss=2.344090700149536
step 11: loss=2.6371214389801025
step 12: loss=3.4621236324310303
step 13: loss=1.713283658027649
step 14: loss=1.9781702756881714
step 15: loss=3.3416006565093994
step 16: loss=2.9397127628326416
step 17: loss=1.0316027402877808
step 18: loss=2.92775821685791
step 19: loss=2.7445106506347656
step 20: loss=1.4158217906951904
step 21: loss=1.9610295295715332
step 22: loss=1.9986664056777954
step 23: loss=1.672990083694458
step 24: loss=2.457414388656616
step 25: loss=1.5557224750518799
step 26: loss=2.3051018714904785
step 27: loss=2.603790044784546
step 28: loss=1.988590955734253
step 29: loss=2.3711462020874023
step 30: loss=1.6227275133132935
step 31: loss=1.86228585243225

In [19]:
resume_from_step = 1000  # e.g. resume from checkpoint at step 1000
layer_sizes = [784, 128, 10]

experiment_name = make_experiment_name(layer_sizes, optimizer)
checkpoint_manager = initialise_checkpoint_manager(experiment_name)

template_state = create_training_state(layer_sizes, optimizer, jax.random.key(0))
restored_state = checkpoint_manager.restore(
    resume_from_step,
    args=ocp.args.StandardRestore(template_state),
)

In [20]:
learning_rate = 1e-3
rank = 32
layer_sizes = (784, 128, 10)

optimizer = optax.adam(learning_rate)

params_lowrank = train_mlp(
    train_data,
    optimizer,
    use_lowrank=True,
    rank=rank,
    layer_sizes=layer_sizes,
)
evaluate_mlp(
    test_data,
    params_lowrank,
    use_lowrank=True,
    rank=rank,
    layer_sizes=layer_sizes,
)


step 1: loss=1.940758228302002
step 2: loss=2.1760780811309814
step 3: loss=2.123516082763672
step 4: loss=1.7909069061279297
step 5: loss=2.7731266021728516
step 6: loss=1.962531328201294
step 7: loss=2.494056224822998
step 8: loss=1.8388534784317017
step 9: loss=1.8824307918548584
step 10: loss=2.6979455947875977
step 11: loss=2.197720766067505
step 12: loss=2.1643967628479004
step 13: loss=1.4617383480072021
step 14: loss=1.5403313636779785
step 15: loss=2.446437358856201
step 16: loss=2.2722434997558594
step 17: loss=1.3893849849700928
step 18: loss=3.549886703491211
step 19: loss=3.3583061695098877
step 20: loss=1.5383902788162231
step 21: loss=2.0145435333251953
step 22: loss=1.8640811443328857
step 23: loss=1.9871046543121338
step 24: loss=2.46518611907959
step 25: loss=2.6712143421173096
step 26: loss=2.0123965740203857
step 27: loss=3.0125784873962402
step 28: loss=2.691230297088623
step 29: loss=1.9900810718536377
step 30: loss=1.6837618350982666
step 31: loss=1.7050302028656

GANs

In [21]:
def gan_discriminator_loss(d_real_logits, d_fake_logits):
    real_targets = jnp.ones_like(d_real_logits)
    fake_targets = jnp.zeros_like(d_fake_logits)
    real_loss = optax.sigmoid_binary_cross_entropy(d_real_logits, real_targets)
    fake_loss = optax.sigmoid_binary_cross_entropy(d_fake_logits, fake_targets)
    return (real_loss.mean() + fake_loss.mean())


def gan_generator_loss(d_fake_logits):
    targets = jnp.ones_like(d_fake_logits)
    loss = optax.sigmoid_binary_cross_entropy(d_fake_logits, targets)
    return loss.mean()

In [38]:
def discriminator_loss(real_logits, fake_logits):
    """
    Discriminator loss.
    Real -> 1, Fake -> 0
    """
    real_targets = jnp.ones_like(real_logits)
    fake_targets = jnp.zeros_like(fake_logits)

    loss_real = optax.sigmoid_binary_cross_entropy(
        logits=real_logits, labels=real_targets
    ).mean()

    loss_fake = optax.sigmoid_binary_cross_entropy(
        logits=fake_logits, labels=fake_targets
    ).mean()

    return loss_real + loss_fake

In [40]:
def generator_loss_saturating(fake_logits):
    """
    Generator minimax loss.
    Literal mirror of discriminator fake term.
    (Saturates when D(fake) â‰ˆ 0)
    """
    fake_targets = jnp.zeros_like(fake_logits)

    return optax.sigmoid_binary_cross_entropy(
        logits=fake_logits, labels=fake_targets
    ).mean()

In [41]:
def generator_loss_non_saturating(fake_logits):
    """
    Generator surrogate loss.
    Train G so that D(fake) -> 1 (non-saturating).
    """
    real_targets = jnp.ones_like(fake_logits)

    return optax.sigmoid_binary_cross_entropy(
        logits=fake_logits, labels=real_targets
    ).mean()

In [None]:
batch_size = 32

# D outputs 0.5 for both real and fake
real_logits = jnp.zeros((batch_size, 1), dtype=jnp.float32)
fake_logits = jnp.zeros((batch_size, 1), dtype=jnp.float32)

print("D loss:",
      discriminator_loss(real_logits, fake_logits))
# log(2) + log(2) = ~1.386

print("G loss (saturating):",
      generator_loss_saturating(fake_logits))
# -log(1 - 0.5) = ~0.693

print("G loss (non-saturating):",
      generator_loss_non_saturating(fake_logits))
# -log(0.5) = ~0.693




D loss: 1.386295
G loss (saturating): 0.6931475
G loss (non-saturating): 0.6931475
