# Learning Flax
tutorial: https://flax-linen.readthedocs.io/en/latest/quick_start.html

This tutorial uses JAX, Flax, and Optax to build and train a simple CNN on the MNIST dataset.

In [1]:
import os
# Set these before JAX is imported; they are read when the GPU backend starts up.
# Expose only the GPU with device index 3 to this process. CUDA device indices are
# zero-based, so this is the fourth GPU on the machine, not the third.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# 'platform' makes JAX allocate exactly what is needed on demand and deallocate memory
# that is no longer needed, instead of preallocating a large pool up front.
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

In [2]:
from jax import random

# hyperparameters
num_epochs = 10 # number of full passes over the training split
batch_size = 32 # samples per batch, used by both the train and validation loaders
train_split = 0.8 # fraction of the dataset used for training; the rest becomes the validation split
learning_rate = 0.01 # step size passed to the SGD optimizer
momentum = 0.9 # momentum coefficient passed to the SGD optimizer
eval_iters = 200 # at each end-of-epoch evaluation, use 200 batches from the training split and 200 from the validation split

In [3]:
import numpy as np
import jax.numpy as jnp
from jax.tree_util import tree_map
from torch.utils.data import DataLoader, default_collate, random_split
from torchvision.datasets import MNIST
from einops import rearrange

def numpy_collate(batch):
    """
    Combine a list of samples into one batch of NumPy arrays.

    PyTorch's `default_collate` does the stacking (producing torch tensors);
    every tensor in the resulting structure is then converted with `np.asarray`.
    """
    torch_batch = default_collate(batch)
    return tree_map(np.asarray, torch_batch)

def reshape_and_cast(pic):
    """Convert a 2-D PIL image to a float32 NumPy array with a trailing channel axis.

    Pixel values are divided by 255.0; for a 28x28 MNIST digit the result has
    shape (28, 28, 1).
    """
    pixels = np.array(pic, dtype=jnp.float32)
    scaled = pixels / 255.0
    # einops appends a singleton channel dimension (and rejects inputs that are not 2-D)
    return rearrange(scaled, 'width height -> width height 1')

In [4]:
# Build the dataset with torchvision; each image goes through reshape_and_cast when accessed.
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=reshape_and_cast)

# Randomly partition the dataset into training and validation subsets by fraction.
split_fractions = [train_split, 1-train_split]
train_dataset, val_dataset = random_split(mnist_dataset, split_fractions)

# Create one PyTorch data loader per split. Both share the same settings and use the
# custom collate function so that batches arrive as NumPy arrays.
loader_kwargs = dict(batch_size=batch_size, shuffle=True, collate_fn=numpy_collate)
train_loader = DataLoader(train_dataset, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)

In [5]:
from flax import linen as nn 

In [6]:
class CNN(nn.Module):
    """A simple CNN: two conv / ReLU / average-pool stages followed by two dense layers."""

    @nn.compact  # declare submodules inline instead of writing a separate setup() method
    def __call__(self, x):
        """Forward pass (the Flax counterpart of PyTorch's forward()); returns 10 logits per example."""
        # The two feature-extraction stages differ only in their number of output channels.
        for out_channels in (32, 64):
            x = nn.Conv(features=out_channels, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Flatten every non-batch axis into one feature vector before the dense head;
        # in einops, parentheses mean "merge these dims".
        x = rearrange(x, "batch width height channels -> batch (width height channels)")
        x = nn.Dense(features=256)(x)  # features = output dimension
        x = nn.relu(x)
        return nn.Dense(features=10)(x)

In [7]:
cnn = CNN()

# Print a per-layer summary table, traced with a single dummy 28x28x1 image.
dummy_image = jnp.ones((1, 28, 28, 1))
summary = cnn.tabulate(
    random.key(0),
    dummy_image,
    compute_flops=True,
    compute_vjp_flops=True,
)
print(summary)

2026-01-08 05:28:31.713305: W external/xla/xla/service/gpu/nvptx_compiler.cc:930] The NVIDIA driver's CUDA version is 12.8 which is older than the PTX compiler version 12.9.86. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.



[3m                                  CNN Summary                                   [0m
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃[1m [0m[1mpath   [0m[1m [0m┃[1m [0m[1mmodule[0m[1m [0m┃[1m [0m[1minputs    [0m[1m [0m┃[1m [0m[1moutputs    [0m[1m [0m┃[1m [0m[1mflops[0m[1m [0m┃[1m [0m[1mvjp_flops[0m[1m [0m┃[1m [0m[1mparams    [0m[1m [0m┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━┩
│         │ CNN    │ [2mfloat32[0m[1… │ [2mfloat32[0m[1,… │ 0     │ 0         │            │
├─────────┼────────┼────────────┼─────────────┼───────┼───────────┼────────────┤
│ Conv_0  │ Conv   │ [2mfloat32[0m[1… │ [2mfloat32[0m[1,… │ 0     │ 0         │ bias:      │
│         │        │            │             │       │           │ [2mfloat32[0m[3… │
│         │        │            │             │       │           │ kernel:    │
│         │        │            │             │       

tabulate() runs a forward pass to discover the model's structure. Flax needs a random key for this pass, both to initialize the parameters and to execute any layers that use randomness (such as dropout).

FLOPs = the number of floating-point arithmetic operations (additions, multiplications, etc.) required for a forward pass through the model.

VJP FLOPs = the number of floating-point operations required to compute gradients during backpropagation.

In [8]:
from flax.training import train_state # useful dataclass to keep train state
from flax import struct # flax dataclass
import optax # common loss functions and optimizers

In [9]:
import wandb

# Start a Weights & Biases run in this project; metrics passed to wandb.log() later
# in the notebook are recorded against this run.
wandb.init(project='learning-flax-by-testing-CNN-on-MNIST')

[34m[1mwandb[0m: Currently logged in as: [33mnityakas[0m ([33mnityakas-usc[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
def create_train_state(module, rng, learning_rate, momentum):
    """Build the initial `TrainState` for `module`.

    Parameters are initialised by tracing the module with a dummy batch of one
    28x28x1 image, and are paired with an SGD optimizer that uses the given
    learning rate and momentum.
    """
    dummy_input = jnp.ones([1, 28, 28, 1])
    variables = module.init(rng, dummy_input)
    optimizer = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(
        apply_fn=module.apply,
        params=variables['params'],
        tx=optimizer,
    )

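PyTorch vs Flax:
In PyTorch the model parameters live inside the model object. In Flax they do not: the module only describes the computation, and the parameters are kept in a separate pytree that is passed in explicitly. This keeps the forward pass a pure function of its inputs, which is what JAX transformations such as jax.jit(), jax.grad() and jax.vmap() require.

Flax makes you call apply() so that you always pass the parameters to the model explicitly; under the hood, apply() binds those parameters to the module and then runs model.__call__().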

In [11]:
from jax import jit, value_and_grad

def loss_fn(params, state, batch):
  """Return `(mean cross-entropy loss, logits)` for one batch.

  `params` is a separate first argument (rather than being read from `state`)
  so that it can be the argument the loss is differentiated with respect to.
  """
  images, labels = batch['image'], batch['label']
  logits = state.apply_fn({'params': params}, images)
  per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=labels)
  return per_example_loss.mean(), logits

def calculate_accuracy(logits, batch):
  """Fraction of examples whose highest-scoring class equals `batch['label']`."""
  predicted_classes = jnp.argmax(logits, axis=-1)
  is_correct = predicted_classes == batch['label']
  return jnp.mean(is_correct)

@jit
def eval(state, batch):
  """Compute `(loss, accuracy)` for one batch using the current parameters, without updating them.

  NOTE(review): this name shadows Python's builtin `eval` for the rest of the notebook.
  """
  batch_loss, logits = loss_fn(state.params, state, batch)
  return batch_loss, calculate_accuracy(logits, batch)
  

@jit # apply to whole function
def train_step(state, batch):
  """Train for a single step and return the updated train state.

  Computes the gradient of the loss with respect to `state.params` and applies
  it through the optimizer stored in `state`. No metrics are returned or logged
  here; use `eval` to measure loss and accuracy.
  """
  # has_aux=True because loss_fn returns (loss, logits); only the loss is differentiated.
  # (Use jax.grad(loss_fn, argnums=0, has_aux=True) if the loss value itself is not needed.)
  grad_fn = value_and_grad(loss_fn, argnums=0, has_aux=True)
  (_loss, _logits), grads = grad_fn(state.params, state, batch)
  return state.apply_gradients(grads=grads)

why does loss_fn need to take in params as an argument?

Because jax.grad() needs to know which argument to differentiate loss_fn with respect to. By default it differentiates with respect to the first positional argument of loss_fn, which is why the params come first. Use `argnums` (default = 0) to choose which positional argument(s) to differentiate with respect to (it can be more than one).

Which functions need the jax.jit decorator?
The typical pattern is:
- Apply @jax.jit to the top-level functions that orchestrate computations (train_step, eval_step, etc.)  
- Keep helper functions like loss_fn or compute_loss as regular Python functions  
- JAX will automatically optimize everything when it compiles the outer JITted function  

what if i wanna return logits and loss in loss_fn, not just loss?
jax.value_and_grad and jax.grad let you specify `has_aux` for this.  
- has_aux=True tells JAX that loss_fn returns (value, auxiliary_data).  
- JAX only differentiates the first return value (loss).  
- The second return value (logits) is passed through unchanged as "auxiliary" data.  
- With jax.value_and_grad you get back ((loss, logits), grads) - note the nested tuple structure.  

In [12]:
# Fixed PRNG key for parameter initialisation.
init_rng = random.key(0)

state = create_train_state(
    module=cnn,
    rng=init_rng,
    learning_rate=learning_rate,
    momentum=momentum,
)
del init_rng  # Must not be used anymore: the key has been consumed.

In [13]:
def _mean_loss_and_accuracy(state, loader, num_batches):
    """Average the loss and accuracy from `eval` over up to `num_batches` batches of `loader`.

    The loader is iterated once (one shuffle), rather than building a fresh
    iterator for every batch. If the loader yields fewer than `num_batches`
    batches, all of them are used.
    """
    losses, accuracies = [], []
    # range() comes first in zip() so no extra batch is fetched once the cap is reached
    for _, (img_batch, label_batch) in zip(range(num_batches), loader):
        loss, accuracy = eval(state, {'image': img_batch, 'label': label_batch})
        losses.append(loss)
        accuracies.append(accuracy)
    return float(np.mean(np.array(losses))), float(np.mean(np.array(accuracies)))


for epoch in range(num_epochs):
    for img_batch, label_batch in train_loader:
        batch = {'image': img_batch, 'label': label_batch}
        state = train_step(state, batch) # forward pass, backward pass, parameter update

    # evaluation on both training and val data after each epoch
    train_loss, train_acc = _mean_loss_and_accuracy(state, train_loader, eval_iters)
    val_loss, val_acc = _mean_loss_and_accuracy(state, val_loader, eval_iters)

    wandb.log({'train loss': train_loss, 'train acc': train_acc, 'val loss': val_loss, 'val acc': val_acc, 'epoch': epoch})

Since we're using `optax.softmax_cross_entropy_with_integer_labels()` in `loss_fn()`, we don't need to convert the labels to one-hot vectors.

In [14]:
def classify(state, batch):
    """Return the predicted class index for every image in `batch['image']`."""
    variables = {'params': state.params}
    return jnp.argmax(state.apply_fn(variables, batch['image']), axis=-1)


# Take one shuffled validation batch and compare the model's predictions with the true labels.
img_batch, label_batch = next(iter(val_loader))
val_batch = dict(image=img_batch, label=label_batch)

preds = classify(state, val_batch)
print("Predictions shape:", preds.shape)
print("Predictions:", preds)
print("\nActual labels:", val_batch['label'])

Predictions shape: (32,)
Predictions: [1 8 8 3 2 0 3 0 2 7 2 1 6 3 1 0 0 6 1 2 3 4 1 9 2 1 8 7 2 9 0 1]

Actual labels: [1 5 8 3 2 0 3 0 2 7 2 1 6 3 1 0 0 6 1 2 3 6 1 9 2 1 8 7 2 9 0 1]
