<small>

**Key differences from the JAX implementation:**  
- <b>Network definition:</b> Use a Flax <code>nn.Module</code> (e.g., an <code>MLP</code> class) instead of lists of parameter dicts.  
- <b>Initialization:</b> Flax handles parameter initialization with <code>model.init(...)</code>, using specified initializers within the class.  
- <b>Forward pass:</b> Compute outputs with <code>model.apply(params, x)</code> instead of manual matrix multiplications.  

</small>

In [1]:
from typing import Sequence

import jax
import jax.numpy as jnp
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.training import train_state
import optax
import orbax.checkpoint as ocp
from pathlib import Path


In [2]:
# Load MNIST from TensorFlow Datasets.
# batch_size=-1 asks TFDS for each split as one full batch instead of an iterator
# (prepare_data later calls .numpy() on every entry of a split).
# with_info=True additionally returns the dataset metadata as `info`.
data_dir = '/tmp/tfds'  # alternative location: './data/tfds'
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)

E0000 00:00:1765323872.906212    1569 cuda_executor.cc:1309] INTERNAL: CUDA Runtime error: Failed call to cudaGetRuntimeVersion: Error loading CUDA libraries. GPU will not be used.: Error loading CUDA libraries. GPU will not be used.
W0000 00:00:1765323872.921617    1569 gpu_device.cc:2342] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [3]:
def normalise(x, x_max=255.0):
    """Scale ``x`` by dividing it by ``x_max`` (255.0 unless overridden)."""
    scaled = x / x_max
    return scaled

def convert_to_jax(data_np, data_type):
    """Turn an array into a JAX array according to the role it plays.

    ``"image"`` data is cast to float32 and passed through ``normalise``;
    ``"label"`` data is converted without an explicit dtype.

    Raises:
        ValueError: if ``data_type`` is neither ``"image"`` nor ``"label"``.
    """
    if data_type not in ("image", "label"):
        raise ValueError("not image or label")

    if data_type == "label":
        return jnp.array(data_np)

    as_float = jnp.array(data_np, dtype=jnp.float32)
    return normalise(as_float)

def flatten_image_for_mlp(data_jax):
    """Collapse every sample of a 4-D batch into one flat vector.

    The result has shape ``(n_samples, -1)``: all axes after the first are
    merged into a single axis.
    """
    # Unpacking exactly four names doubles as a check that the input is 4-D.
    n_samples, _height, _width, _channels = data_jax.shape
    return data_jax.reshape((n_samples, -1))

def prepare_data(data_dict: dict, subsample_size: int = 0):
    """Convert every entry of a dataset split into a JAX array.

    Each value is turned into NumPy via ``.numpy()``, converted with
    ``convert_to_jax`` (the dict key selects the conversion), and images are
    additionally flattened with ``flatten_image_for_mlp``.

    Args:
        data_dict: mapping from data type (e.g. "image", "label") to an object
            exposing ``.numpy()``.
        subsample_size: when positive, keep only the first ``subsample_size``
            samples of every entry; otherwise keep everything.

    Returns:
        A dict with the same keys holding the converted arrays.
    """
    prepared = {}
    for name, tensor in data_dict.items():
        converted = convert_to_jax(tensor.numpy(), name)
        if name == "image":
            converted = flatten_image_for_mlp(converted)
        prepared[name] = converted[:subsample_size] if subsample_size > 0 else converted

    return prepared

In [4]:
class MLP(nn.Module):
    """Fully-connected network built from ``nn.Dense`` layers.

    One Dense layer is created per entry of ``layer_sizes``; ReLU follows
    every layer except the last, so the final layer's raw outputs are returned.
    Kernels and biases both use a normal initializer with argument 0.1.
    """
    # Output width of each Dense layer, in order.
    layer_sizes: Sequence[int]

    @nn.compact
    def __call__(self, activations):
        final_index = len(self.layer_sizes) - 1
        for index, width in enumerate(self.layer_sizes):
            dense = nn.Dense(
                width,
                kernel_init=nn.initializers.normal(0.1),
                bias_init=nn.initializers.normal(0.1),
            )
            activations = dense(activations)

            # No non-linearity after the last layer.
            if index < final_index:
                activations = nn.relu(activations)

        return activations

In [5]:
class LowRankDense(nn.Module):
    """Low-rank dense layer implemented with two factors and einsum.

    Parameters are U in R^{in_features x rank} and V in R^{rank x features},
    plus an optional bias in R^{features}. The forward pass computes
    y = (x @ U) @ V + b using einsum. Only the last axis of the input is
    contracted, so any number of leading (batch) axes is accepted.
    All parameters use a normal initializer with argument 0.1.
    """
    features: int  # output width of the layer
    rank: int  # inner dimension shared by U and V
    use_bias: bool = True  # add a learned bias vector when True

    @nn.compact
    def __call__(self, inputs):
        # inputs: [..., in_features]
        in_features = inputs.shape[-1]
        init = nn.initializers.normal(0.1)

        U = self.param("U", init, (in_features, self.rank))
        V = self.param("V", init, (self.rank, self.features))

        # Ellipsis keeps every leading axis untouched; a plain [batch, in_features]
        # input behaves exactly as with the explicit "bi,ir->br" spec.
        hidden = jnp.einsum("...i,ir->...r", inputs, U)
        y = jnp.einsum("...r,rf->...f", hidden, V)

        if self.use_bias:
            bias = self.param("bias", init, (self.features,))
            y = y + bias

        return y


class LowRankMLP(nn.Module):
    """MLP whose layers are all ``LowRankDense`` with one shared ``rank``.

    One low-rank layer (with bias) is created per entry of ``layer_sizes``;
    ReLU follows every layer except the last.
    """
    # Output width of each low-rank layer, in order.
    layer_sizes: Sequence[int]
    # Low-rank dimension used by every layer.
    rank: int

    @nn.compact
    def __call__(self, activations):
        final_index = len(self.layer_sizes) - 1
        for index, width in enumerate(self.layer_sizes):
            layer = LowRankDense(features=width, rank=self.rank, use_bias=True)
            activations = layer(activations)

            # No non-linearity after the last layer.
            if index < final_index:
                activations = nn.relu(activations)

        return activations


In [6]:
def initialise_network_params(model, input_layer_size, key):
    """Create the model's parameters by initialising it on a dummy input.

    A batch of ones with shape ``(1, input_layer_size)`` is fed to
    ``model.init`` together with ``key``; only the ``"params"`` entry of the
    returned variables is handed back.
    """
    dummy_batch = jnp.ones((1, input_layer_size))
    variables = model.init(key, dummy_batch)
    return variables["params"]

In [7]:
def calculate_mean_loss_batch(params, apply_fn, images, labels):
    """Return the mean softmax cross-entropy of the model over one batch.

    ``apply_fn`` is called with ``{"params": params}`` and ``images`` to get
    logits (the forward pass); ``labels`` are integer class labels.
    """
    logits = apply_fn({"params": params}, images)
    per_sample_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    )
    return per_sample_loss.mean()

In [8]:
@jax.jit
def take_training_step(training_state, images, labels):
    """
    Run a single optimisation step on one batch.

    The model (``apply_fn``), its parameters and the optimiser all travel
    inside ``training_state``; the updated training state is returned.
    """
    # params is argument 0 of the loss, so gradients are taken w.r.t. params only.
    loss_gradient = jax.grad(calculate_mean_loss_batch, argnums=0)
    grads = loss_gradient(
        training_state.params,
        training_state.apply_fn,
        images,
        labels,
    )
    updated_state = training_state.apply_gradients(grads=grads)
    return updated_state

In [9]:
def get_batches(images, labels, n_batches):
    """Yield exactly ``n_batches`` equally sized ``(images, labels)`` batches.

    Every batch holds ``len(images) // n_batches`` consecutive samples. Samples
    left over after ``n_batches`` full batches are dropped, so the caller gets
    the number of batches (training steps) it asked for and never more.

    Args:
        images: sliceable sequence of samples.
        labels: sliceable sequence with the same length as ``images``.
        n_batches: number of batches to produce; must satisfy
            ``0 < n_batches <= len(images)``.

    Yields:
        Tuples ``(images[start:end], labels[start:end])``.

    Raises:
        AssertionError: if the lengths differ or ``n_batches`` is out of range
            (raised when iteration starts, since this is a generator).
    """
    n_samples = len(images)
    assert n_samples == len(labels), "images and labels must have the same length"
    assert n_batches > 0, "n_batches must be positive"
    assert n_samples >= n_batches, "need at least one sample per batch"

    batch_size = n_samples // n_batches
    # Iterating over the batch index (not over the sample range) caps the output at
    # n_batches even when batch_size would fit into n_samples more often than that.
    for batch_index in range(n_batches):
        start = batch_index * batch_size
        end = start + batch_size
        yield (images[start:end], labels[start:end])

In [10]:
def make_experiment_name(layer_sizes, optimizer):
    """Build an experiment name of the form ``mlp_<sizes>_<optimizer class>``.

    ``<sizes>`` is the layer sizes joined with ``-``; the optimizer part is the
    class name of the ``optimizer`` object.
    """
    # NOTE(review): the name comes from the optimizer object's class, so different
    # optimizers sharing one wrapper class would map to the same name; confirm.
    sizes = "-".join(map(str, layer_sizes))
    optimizer_type = optimizer.__class__.__name__
    return "mlp_" + sizes + "_" + optimizer_type

def initialise_checkpoint_manager(experiment_name: str = "mlp", max_to_keep=20):
    """Create (if needed) the experiment's checkpoint folder and return its manager.

    Checkpoints live under ``<current working directory>/checkpoints/<experiment_name>``.
    ``max_to_keep`` is forwarded to ``ocp.CheckpointManagerOptions``.
    """
    checkpoint_dir = Path().resolve() / "checkpoints" / experiment_name
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    manager_options = ocp.CheckpointManagerOptions(max_to_keep=max_to_keep)
    return ocp.CheckpointManager(
        directory=str(checkpoint_dir),
        options=manager_options,
    )

In [11]:
def create_training_state(layer_sizes, optimizer, key, use_lowrank: bool = False, rank: int | None = None):
    input_layer_size = layer_sizes[0]
    network_layer_sizes = layer_sizes[1:]

    if use_lowrank:
        if rank is None:
            raise ValueError("rank must be provided when use_lowrank=True")
        model = LowRankMLP(layer_sizes=network_layer_sizes, rank=rank)
    else:
        model = MLP(layer_sizes=network_layer_sizes)

    apply_fn = model.apply
    params = initialise_network_params(model, input_layer_size, key)
    training_state = train_state.TrainState.create(
        apply_fn=apply_fn,
        params=params,
        tx=optimizer,
    )
    return training_state

In [12]:
def run_training(
    images,
    labels,
    n_steps,
    layer_sizes,
    optimizer,
    checkpoint_manager,
    key,
    steps_per_save,
    training_state,
    use_lowrank: bool = False,
    rank: int | None = None,
    ):
    """
    Train on consecutive batches, print the loss and checkpoint periodically.

    The training state (``training_state``) is an instance of TrainState that holds:
    - apply_fn: the model's apply function, used for forward passes
    - params: the parameters of the neural network
    - tx: the optimizer (Optax transformation) for parameter updates
    - opt_state: the state of the optimizer

    Args:
        images, labels: training samples, split into batches by ``get_batches``.
        n_steps: passed to ``get_batches`` as the number of batches; one
            training step is taken per batch.
        layer_sizes, optimizer, key, use_lowrank, rank: only used to build a
            new training state when ``training_state`` is None.
        checkpoint_manager: receives the training state at step 1 and at every
            step that is a multiple of ``steps_per_save``.
        steps_per_save: checkpoint interval in steps; expected to be a positive integer.
        training_state: state to continue from, or None to start from scratch.

    Returns:
        The parameters of the final training state.
    """
    if training_state is None:
        training_state = create_training_state(
            layer_sizes,
            optimizer,
            key,
            use_lowrank=use_lowrank,
            rank=rank,
        )

    for images_batch, labels_batch in get_batches(images=images, labels=labels, n_batches=n_steps):
        training_state = take_training_step(training_state, images_batch, labels_batch)
        # Use a plain Python int for the step counter so that the checkpoint
        # manager receives the integer step it documents, not an array scalar.
        step = int(training_state.step)
        # Loss is evaluated after the update, on the batch that was just trained on.
        loss = calculate_mean_loss_batch(training_state.params, training_state.apply_fn, images_batch, labels_batch)
        print(f"step {step}: loss={loss}")
        if step == 1 or step % steps_per_save == 0:
            checkpoint_manager.save(
                step,
                args=ocp.args.StandardSave(training_state)
                )

    return training_state.params

In [13]:
def train_mlp(
    train_data,
    optimizer,
    n_steps=10**3,
    steps_per_save=100,
    training_state=None,
    key=jax.random.key(0),
    use_lowrank: bool = False,
    rank: int | None = None,
    layer_sizes=(784, 128, 10),
):
    """Train a (low-rank) MLP on ``train_data`` and return its final parameters.

    The experiment name is derived from the layer sizes and optimizer, with a
    ``_lowrank-r<rank>`` suffix for low-rank runs; it selects the checkpoint
    folder. All training arguments are forwarded to ``run_training``.

    Args:
        train_data: mapping with ``"image"`` and ``"label"`` entries.
        optimizer: optimiser used for training and for naming the experiment.
        n_steps: number of training steps (batches).
        steps_per_save: checkpoint interval in steps.
        training_state: existing state to continue from, or None.
        key: PRNG key for parameter initialisation.
        use_lowrank: train a low-rank MLP instead of a dense one.
        rank: low-rank dimension; required when ``use_lowrank`` is true.
        layer_sizes: input size followed by the width of each layer.

    Raises:
        ValueError: if ``use_lowrank`` is true and ``rank`` is None.
    """
    layer_sizes = list(layer_sizes)
    experiment_name = make_experiment_name(layer_sizes, optimizer)

    if use_lowrank and rank is None:
        raise ValueError("rank must be provided when use_lowrank=True")
    if use_lowrank:
        experiment_name += f"_lowrank-r{rank}"

    checkpoint_manager = initialise_checkpoint_manager(experiment_name)
    return run_training(
        images=train_data["image"],
        labels=train_data["label"],
        n_steps=n_steps,
        layer_sizes=layer_sizes,
        optimizer=optimizer,
        checkpoint_manager=checkpoint_manager,
        key=key,
        steps_per_save=steps_per_save,
        training_state=training_state,
        use_lowrank=use_lowrank,
        rank=rank,
    )



In [14]:
def extract_layer_sizes(params):
    """Recover ``[input_size, width_1, ..., width_n]`` from dense-layer parameters.

    The first layer's ``"kernel"`` shape supplies the input size and the first
    width; each later layer contributes the length of its ``"bias"``. Layers
    are read in the iteration order of ``params``. Empty ``params`` gives ``[]``.
    """
    per_layer = list(params.values())
    if not per_layer:
        return []

    first_layer, later_layers = per_layer[0], per_layer[1:]
    sizes = [first_layer["kernel"].shape[0], first_layer["kernel"].shape[1]]
    sizes.extend(layer["bias"].shape[0] for layer in later_layers)
    return sizes

In [15]:
def evaluate_mlp(
    test_data,
    params,
    n_examples=10,
    use_lowrank: bool = False,
    rank: int | None = None,
    layer_sizes=None,
):
    images = test_data["image"]
    labels = test_data["label"]

    if use_lowrank:
        if layer_sizes is None:
            raise ValueError("layer_sizes must be provided when use_lowrank=True")
        if rank is None:
            raise ValueError("rank must be provided when use_lowrank=True")
        model = LowRankMLP(layer_sizes=layer_sizes[1:], rank=rank)
    else:
        layer_sizes = extract_layer_sizes(params)
        model = MLP(layer_sizes=layer_sizes[1:])

    apply_fn = model.apply

    mean_loss = calculate_mean_loss_batch(params, apply_fn, images, labels)
    example_images = images[:n_examples]
    example_labels = labels[:n_examples]
    logits = apply_fn({"params": params}, example_images)
    example_predictions = jnp.argmax(logits, axis=1)

    prefix = "[low-rank] " if use_lowrank else ""
    print(prefix + "Mean loss       ", mean_loss)
    print(prefix + "True labels:    ", example_labels)
    print(prefix + "Predictions:    ", example_predictions)

1. Learning rate decay
2. Weight decay

In [16]:
# Convert both splits to JAX arrays, keeping only the first 10**3 samples of each.
train_data = prepare_data(mnist_data["train"], subsample_size=10**3)
test_data = prepare_data(mnist_data["test"], subsample_size=10**3)

In [None]:
# Train the dense MLP with Adam using train_mlp's defaults, then report the
# loss and a few predictions on the test subset.
learning_rate = 1e-3
optimizer = optax.adam(learning_rate)
params = train_mlp(train_data, optimizer)
evaluate_mlp(test_data, params)

In [46]:
# Restore a saved training state from the dense-MLP experiment's checkpoints.
# `optimizer` is the global defined in an earlier cell; it determines the
# experiment name and therefore which checkpoint folder is opened.
resume_from_step = 1000  # e.g. resume from checkpoint at step 1000
layer_sizes = [784, 128, 10]

experiment_name = make_experiment_name(layer_sizes, optimizer)
checkpoint_manager = initialise_checkpoint_manager(experiment_name)

# A freshly initialised state is handed to StandardRestore as the template
# for the checkpoint being restored.
template_state = create_training_state(layer_sizes, optimizer, jax.random.key(0))
restored_state = checkpoint_manager.restore(
    resume_from_step,
    args=ocp.args.StandardRestore(template_state),
)
# NOTE(review): restored_state is not used further in this part of the notebook;
# presumably it is meant to be passed as training_state= to train_mlp. Confirm.

In [None]:
# Train and evaluate the low-rank variant: same layer sizes as the dense model,
# with every layer factorised through rank 32.
learning_rate = 1e-3
rank = 32
layer_sizes = (784, 128, 10)

optimizer = optax.adam(learning_rate)

params_lowrank = train_mlp(
    train_data,
    optimizer,
    use_lowrank=True,
    rank=rank,
    layer_sizes=layer_sizes,
)
# evaluate_mlp requires rank and layer_sizes explicitly when use_lowrank=True.
evaluate_mlp(
    test_data,
    params_lowrank,
    use_lowrank=True,
    rank=rank,
    layer_sizes=layer_sizes,
)


GANs

In [None]:
def gan_discriminator_loss(d_real_logits, d_fake_logits):
    """Discriminator loss: mean sigmoid BCE on real logits (target 1)
    plus mean sigmoid BCE on fake logits (target 0)."""
    loss_on_real = optax.sigmoid_binary_cross_entropy(
        d_real_logits, jnp.ones_like(d_real_logits)
    ).mean()
    loss_on_fake = optax.sigmoid_binary_cross_entropy(
        d_fake_logits, jnp.zeros_like(d_fake_logits)
    ).mean()
    return loss_on_real + loss_on_fake


def gan_generator_loss(d_fake_logits):
    """Generator loss: mean sigmoid BCE of the discriminator's fake logits
    against an all-ones target."""
    wanted = jnp.ones_like(d_fake_logits)
    per_sample = optax.sigmoid_binary_cross_entropy(d_fake_logits, wanted)
    return per_sample.mean()

In [21]:
def calculate_discriminator_loss_batch(real_logits, fake_logits):
    """
    Discriminator loss for one batch.

    The discriminator scores both real and fake images and should answer
    1 for real and 0 for fake. The loss is therefore the sum of two mean
    binary-cross-entropy terms, one per kind of input.
    """
    def mean_bce(logits, labels):
        return optax.sigmoid_binary_cross_entropy(logits=logits, labels=labels).mean()

    loss_on_real = mean_bce(real_logits, jnp.ones_like(real_logits))
    loss_on_fake = mean_bce(fake_logits, jnp.zeros_like(fake_logits))
    return loss_on_real + loss_on_fake


def calculate_generator_loss_batch(fake_logits):
    """
    Generator loss for one batch.

    The generator only influences the fake images, so its loss depends on
    D(fake) alone. Rather than mirroring the discriminator's loss (which
    gives almost no gradient once D gets good), the 'non-saturating' form is
    used: the fake logits are scored against the label 1.
    """
    # The generator wants the discriminator to output 1 on its samples.
    target_real = jnp.ones_like(fake_logits)
    per_sample = optax.sigmoid_binary_cross_entropy(logits=fake_logits, labels=target_real)
    return per_sample.mean()

In [22]:
# Sanity check of both GAN losses with an undecided discriminator:
# a logit of 0 corresponds to a sigmoid probability of 0.5.
batch_size = 32
real_logits = jnp.zeros((batch_size, 1), dtype=jnp.float32)  # predicts 0.5 for real
fake_logits = jnp.zeros((batch_size, 1), dtype=jnp.float32)  # predicts 0.5 for fake

# Each BCE term is ln 2 at probability 0.5, so D (two terms) is about 2 * ln 2
# and G (one term) is about ln 2.
print("D loss:", calculate_discriminator_loss_batch(real_logits, fake_logits))  # ~ 1.386
print("G loss:", calculate_generator_loss_batch(fake_logits))                  # ~ 0.693

D loss: 1.386295
G loss: 0.6931475
