<small>

**Key differences between NNX and Linen:**

* **State model:** NNX modules are *stateful Python objects*; Linen modules are *stateless templates* whose state lives in external PyTrees.
* **Initialization:** NNX creates parameters when you **instantiate** the module; Linen requires a separate **`model.init(rng, x)`** step.
* **Calling convention:** NNX uses **direct calls** (`y = model(x)`); Linen uses **`apply`** (`y = model.apply(params, x)`).
* **Mental model:** NNX behaves more like **normal Python classes with live attributes**; Linen enforces a **pure-functional style** with explicit data flow.

</small>


In [1]:
from typing import Sequence

import jax
import jax.numpy as jnp
import tensorflow_datasets as tfds
from flax import nnx
import optax

In [2]:
# Load MNIST from TensorFlow Datasets.
# batch_size=-1 requests each split as a single full-size batch rather than
# an iterable dataset; with_info=True additionally returns dataset metadata.
data_dir = '/tmp/tfds'  # alternative location: './data/tfds'
mnist_data, info = tfds.load(
    name="mnist",
    batch_size=-1,
    data_dir=data_dir,
    with_info=True,
)

E0000 00:00:1764797987.289718    8223 cuda_executor.cc:1309] INTERNAL: CUDA Runtime error: Failed call to cudaGetRuntimeVersion: Error loading CUDA libraries. GPU will not be used.: Error loading CUDA libraries. GPU will not be used.
W0000 00:00:1764797987.293828    8223 gpu_device.cc:2342] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [3]:
def normalise(x, x_max=255.0):
    """Rescale ``x`` by dividing it by ``x_max``.

    Args:
        x: Value or array to rescale.
        x_max: Divisor; defaults to 255.0.

    Returns:
        ``x / x_max``.
    """
    scaled = x / x_max
    return scaled

def convert_to_jax(data_np, data_type):
    """Convert raw data into a JAX array according to its kind.

    Args:
        data_np: Array-like data to convert.
        data_type: ``"image"`` or ``"label"``.

    Returns:
        For ``"label"``: ``data_np`` as a JAX array.
        For ``"image"``: ``data_np`` as a float32 JAX array passed through
        ``normalise``.

    Raises:
        ValueError: if ``data_type`` is neither ``"image"`` nor ``"label"``.
    """
    if data_type == "label":
        return jnp.array(data_np)
    if data_type == "image":
        as_float = jnp.array(data_np, dtype=jnp.float32)
        return normalise(as_float)
    raise ValueError("not image or label")

def flatten_image_for_mlp(data_jax):
    """Flatten each sample of a batch into a single feature vector.

    The first axis is treated as the batch axis and all remaining axes are
    collapsed, so both ``(n_batch, height, width, channels)`` and
    channel-less ``(n_batch, height, width)`` inputs are accepted.

    Args:
        data_jax: Array with a leading batch axis and at least one further axis.

    Returns:
        Array of shape ``(n_batch, n_features)`` where ``n_features`` is the
        product of all non-batch axis lengths.

    Raises:
        ValueError: if ``data_jax`` has fewer than two axes.
    """
    import math

    if data_jax.ndim < 2:
        raise ValueError(
            f"expected a batch axis plus at least one feature axis, got shape {data_jax.shape}"
        )
    n_batch = data_jax.shape[0]
    # Compute the feature count explicitly instead of reshaping with -1 so the
    # output width is well defined even when the batch is empty.
    n_features = math.prod(data_jax.shape[1:])
    return data_jax.reshape(n_batch, n_features)

def prepare_data(data_dict):
    """Convert every entry of a dataset dictionary into a JAX array.

    Each value is turned into NumPy via its ``.numpy()`` method and then passed
    to ``convert_to_jax`` together with its key. The ``"image"`` entry is
    additionally flattened with ``flatten_image_for_mlp``.

    Args:
        data_dict: Mapping from data type name to an object exposing ``.numpy()``.

    Returns:
        Dict with the same keys holding the converted arrays.
    """
    data_jax = {}
    for data_type, data_tf in data_dict.items():
        converted = convert_to_jax(data_tf.numpy(), data_type)
        data_jax[data_type] = (
            flatten_image_for_mlp(converted) if data_type == "image" else converted
        )
    return data_jax

In [4]:
# Pick the split named by dataset_tf out of the loaded MNIST data and convert
# its entries to JAX arrays (images are normalised and flattened by prepare_data).
dataset_tf = "train"
all_data_tf = mnist_data[dataset_tf]
all_data_jax = prepare_data(all_data_tf)



In [5]:
# Unpack the converted split into the two arrays used for training below.
images = all_data_jax["image"]
labels = all_data_jax["label"]

In [6]:
# Sanity check: images should be 2-D (one flattened vector per sample) and
# labels 1-D with the same leading length.
print("Images shape:", images.shape)
print("Labels shape:", labels.shape)

Images shape: (60000, 784)
Labels shape: (60000,)


In [7]:
class Linear(nnx.Module):
    """Affine layer computing ``input_activations @ weights + biases``."""

    def __init__(self, input_size, output_size, *, rngs, init_sd=0.05):
        """Create the layer's parameters.

        Args:
            input_size: Number of input features.
            output_size: Number of output features.
            rngs: ``nnx.Rngs`` whose ``params`` stream supplies the weight samples.
            init_sd: Factor the sampled weights are multiplied by.
        """
        weight_shape = (input_size, output_size)
        initial_weights = rngs.params.normal(weight_shape) * init_sd
        self.weights = nnx.Param(initial_weights)
        # Biases start at zero.
        self.biases = nnx.Param(jnp.zeros((output_size,)))

    def __call__(self, input_activations):
        """Apply the affine map to ``input_activations``."""
        projected = input_activations @ self.weights
        return projected + self.biases

class MLP(nnx.Module):
    """Stack of ``Linear`` layers with ReLU between them (none after the last)."""

    def __init__(self, layer_sizes, *, rngs):
        """Build one ``Linear`` layer per consecutive pair in ``layer_sizes``.

        Args:
            layer_sizes: Sequence of widths, e.g. ``[n_in, n_hidden, n_out]``.
            rngs: ``nnx.Rngs`` forwarded to every ``Linear`` layer.
        """
        # Pair each width with its successor: (in, out) for every layer.
        size_pairs = zip(layer_sizes[:-1], layer_sizes[1:])
        self.layers = nnx.List(
            [Linear(n_in, n_out, rngs=rngs) for n_in, n_out in size_pairs]
        )

    def __call__(self, activations):
        """Run ``activations`` through all layers and return the final output."""
        final_index = len(self.layers) - 1
        for index, layer in enumerate(self.layers):
            activations = layer(activations)
            # The last layer's output is returned without a nonlinearity.
            if index < final_index:
                activations = jax.nn.relu(activations)
        return activations

In [8]:
def calculate_mean_loss_batch(model, images, labels):
    """Return the mean softmax cross-entropy of ``model`` on one batch.

    Args:
        model: Callable mapping ``images`` to logits.
        images: Batch of model inputs.
        labels: Integer class labels for the batch.

    Returns:
        Mean over samples of the per-sample cross-entropy.
    """
    per_sample_loss = optax.softmax_cross_entropy_with_integer_labels(
        model(images),  # forward pass
        labels,
    )
    return per_sample_loss.mean()

In [None]:
@nnx.jit
def take_training_step(model, optimizer, images, labels):
    """Perform one optimisation step on a batch and return the loss.

    Evaluates ``calculate_mean_loss_batch`` and its gradients, then hands the
    gradients to ``optimizer.update`` together with the model.

    Args:
        model: The NNX model being trained.
        optimizer: Optimizer object exposing ``update(model, grads)``.
        images: Batch of model inputs.
        labels: Integer class labels for the batch.

    Returns:
        The loss value computed for this step.
    """
    loss_and_grads_fn = nnx.value_and_grad(calculate_mean_loss_batch)
    loss, grads = loss_and_grads_fn(model, images, labels)
    optimizer.update(model, grads)
    return loss

In [None]:
def run_training(layer_sizes, images, labels, n_steps, initial_learning_rate=1e-3, *, seed=0):
    """Build an MLP and train it with Adam for a fixed number of steps.

    Every step uses all of ``images``/``labels`` as a single batch, and the
    loss of each step is printed.

    Args:
        layer_sizes: Layer widths passed to ``MLP``.
        images: Training inputs.
        labels: Integer class labels for ``images``.
        n_steps: Number of optimisation steps to run.
        initial_learning_rate: Learning rate given to ``optax.adam``.
        seed: Seed for the ``nnx.Rngs`` used to initialise the model. The
            default of 0 matches the previously hard-coded value.

    Returns:
        The trained ``MLP`` instance.
    """
    model = MLP(layer_sizes, rngs=nnx.Rngs(seed))
    # wrt=nnx.Param restricts the optimizer to the model's Param variables.
    optimizer = nnx.Optimizer(model, optax.adam(initial_learning_rate), wrt=nnx.Param)

    for step in range(n_steps):
        loss = take_training_step(model, optimizer, images, labels)
        print(f"step {step}: loss={loss}")

    return model

In [11]:
# Smoke test: train briefly on a tiny subset to confirm the loss decreases.
# NOTE: despite the "test_" prefix, these samples are sliced from `images` /
# `labels` (the "train" split) and are the data the model is trained on here.
trial_set_size = 20
test_images = images[:trial_set_size]
test_labels = labels[:trial_set_size]

# 784 input features -> hidden layer of 128 units -> 10 output logits.
layer_sizes = [784, 128, 10]
model = run_training(layer_sizes, test_images, test_labels, n_steps=5, initial_learning_rate=1e-3)

step 0: loss=2.2502477169036865
step 1: loss=2.117492198944092
step 2: loss=1.9914255142211914
step 3: loss=1.8695420026779175
step 4: loss=1.7506182193756104


In [12]:
# Inspect the model on the same samples it was just trained on, so this shows
# how well the training subset was fitted rather than held-out accuracy.
logits = model(test_images)
# Per-sample cross-entropy (not averaged), one value per image.
loss = optax.softmax_cross_entropy_with_integer_labels(logits, test_labels)

# Predicted class is the index of the largest logit for each sample.
predictions = jnp.argmax(logits, axis=1)

print("True labels:    ", test_labels)
print("Predictions:    ", predictions)
print("Match:          ", predictions == test_labels)
print("Loss            ", loss)

True labels:     [4 1 0 7 8 1 2 7 1 6 6 4 7 7 3 3 7 9 9 1]
Predictions:  s   [7 1 7 7 8 1 2 7 1 2 6 4 7 7 3 3 7 7 7 1]
Match:           [False  True False  True  True  True  True  True  True False  True  True
  True  True  True  True  True False False  True]
Loss             [2.3186948 1.6746998 2.1503305 1.1054006 1.7520564 1.7254624 1.4025867
 1.3423988 1.4489025 2.0254958 1.939018  1.9059012 1.325226  1.439785
 1.5715833 1.4464514 1.1044837 1.7482777 1.7558821 1.5099177]
