# Training Chromatix models

In this notebook we show how to train Chromatix models with either Optax, for deep-learning-style optimisers such as Adam, or Jaxopt, for more classical optimisation methods such as conjugate gradient or L-BFGS.

As our model, we'll reuse the one from the Zernike polynomials tutorial.

## Making the model and data

In [1]:
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
import numpy as np

import optax
import jaxopt
from flax.training.train_state import TrainState

from chromatix.elements import ObjectivePointSource, FFLens, ZernikeAberrations
from chromatix.utils import trainable

from typing import Tuple

key = random.PRNGKey(42)

In [2]:
# Simple model
camera_shape: Tuple[int, int] = (256, 256)
camera_pixel_pitch: float = 0.125
f: float = 100
NA: float = 0.8
n: float = 1.33
wavelength: float = 0.532
wavelength_ratio: float = 1.0
# Input sample spacing, chosen so that the sampling after FFLens matches camera_pixel_pitch
spacing = f * wavelength / (n * camera_shape[0] * camera_pixel_pitch)

class ZernikePSF(nn.Module):
    ansi_indices: np.ndarray

    @nn.compact
    def __call__(self):
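        # In-focus (z=0) point source as seen through the objective
        field = ObjectivePointSource(camera_shape, spacing, wavelength, wavelength_ratio, f, n, NA, power=1e7)(z=0)
        # Zernike phase aberrations with trainable coefficients, initialised at zero
        field = ZernikeAberrations(trainable(jnp.zeros_like(self.ansi_indices, dtype=jnp.float32)), f, n, NA, self.ansi_indices)(field)
        # Lens Fourier-transforms the field onto the camera plane
        field = FFLens(f, n)(field)
        return field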
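The spacing above appears to be chosen so that, after the Fourier-transforming lens, one simulation sample corresponds to one camera pixel. A minimal arithmetic check, assuming FFLens follows the standard discrete Fourier-optics relation in which an N-sample input with spacing dx maps to an output spacing of f·λ/(n·N·dx). The helper `output_spacing` is hypothetical and not part of Chromatix.

```python
# Hypothetical helper: sample spacing in the back focal plane of a lens with
# focal length f in a medium of refractive index n, for an N-sample input grid
# with spacing dx (assumes the standard DFT-based Fourier-optics relation).
def output_spacing(f: float, wavelength: float, n: float, N: int, dx: float) -> float:
    return f * wavelength / (n * N * dx)

# Values from the model above (lengths in the same units as the wavelength).
camera_shape = (256, 256)
camera_pixel_pitch = 0.125
f, n, wavelength = 100.0, 1.33, 0.532

# Input-plane spacing, computed exactly as in the notebook cell.
spacing = f * wavelength / (n * camera_shape[0] * camera_pixel_pitch)
print(spacing)  # approximately 1.25

# Feeding that spacing back through the lens relation recovers the camera pitch.
print(output_spacing(f, wavelength, n, camera_shape[0], spacing))  # approximately 0.125
```

With these numbers the simulated input field spans 256 × 1.25 = 320 length units per side.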

We initialise the model and simulate data from a set of known "ground truth" coefficients. We then define a loss function, which returns a (loss, metrics) pair; the metrics are passed through untouched when we differentiate with has_aux=True:

In [3]:
# Instantiating model
model = ZernikePSF(np.arange(1, 11))
params = model.init(key)

# Specify "ground truth" parameters for Zernike coefficients
coefficients_truth = jnp.array([2.0, 5.0, 3.0, 0, 1, 0, 1, 0, 1, 0])
params_true = jax.tree_util.tree_map(lambda _: coefficients_truth, params)  # params has a single leaf, so swap in the true coefficients

# Generating data
data = model.apply(params_true).intensity.squeeze()

# Our loss function: MSE normalised by the mean squared data, plus a metrics dict
def loss_fn(params, data):
    psf_estimate = model.apply(params).intensity.squeeze()
    loss = jnp.mean((psf_estimate - data)**2) / jnp.mean(data**2)
    return loss, {"loss": loss}
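Returning a (loss, metrics) pair is what `jax.grad` expects with `has_aux=True`: the gradient is taken with respect to the first output only, and the second output is returned as-is. A minimal sketch on a hypothetical quadratic `toy_loss` (standing in for loss_fn) with a nested-dict parameter pytree, loosely mimicking the structure of the Flax params; all `toy_*` names are illustrative and chosen so they don't clash with the notebook's variables.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss with the same (params, data) -> (loss, metrics) signature as loss_fn.
def toy_loss(params, target):
    pred = params["layer"]["w"]
    loss = jnp.mean((pred - target) ** 2)
    return loss, {"loss": loss}

# Nested-dict pytree, loosely mimicking the structure of Flax params.
toy_params = {"layer": {"w": jnp.zeros(3)}}
toy_target = jnp.array([1.0, 2.0, 3.0])

# has_aux=True: differentiate the first output only, return the metrics dict as-is.
toy_grads, toy_metrics = jax.grad(toy_loss, has_aux=True)(toy_params, toy_target)

# value_and_grad with has_aux=True also returns the (loss, metrics) pair itself.
(toy_value, toy_metrics), toy_grads2 = jax.value_and_grad(toy_loss, has_aux=True)(
    toy_params, toy_target
)

# The gradients mirror the structure of toy_params: {"layer": {"w": ...}}.
print(toy_grads["layer"]["w"], toy_value)
```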

## Training with Optax

We first train with Optax, using Flax's TrainState to hold the parameters and the optimiser state.

In [4]:
# Setting up training: optimiser, state and grad_fn
optimizer = optax.adam(learning_rate=0.5)
state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)
grad_fn = jax.jit(jax.grad(loss_fn, has_aux=True))

In [5]:
%%time
# Simple training loop
max_iterations = 500
for iteration in range(max_iterations):
    grads, metrics = grad_fn(state.params, data) 
    state = state.apply_gradients(grads=grads)

    if iteration % 100 == 0:
        print(iteration, metrics)

0 {'loss': Array(3.695996, dtype=float32)}
100 {'loss': Array(0.06857503, dtype=float32)}
200 {'loss': Array(0.02074304, dtype=float32)}
300 {'loss': Array(6.8586576e-07, dtype=float32)}
400 {'loss': Array(2.7676396e-11, dtype=float32)}
CPU times: user 4.11 s, sys: 268 ms, total: 4.38 s
Wall time: 2.88 s


In [6]:
print(state.params)

FrozenDict({
    params: {
        ZernikeAberrations_0: {
            zernike_coefficients: Array([ 2.0000017e+00,  5.0000029e+00, -3.0000019e+00, -2.2701904e-06,
                   -1.0000006e+00,  7.8149048e-07,  1.0000002e+00,  1.0779938e-06,
                    9.9999923e-01,  6.9686160e-07], dtype=float32),
        },
    },
})
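The recovered coefficients match the ground truth except that the third and fifth entries (ANSI indices 3 and 5) have flipped sign, even though the loss is around 1e-11. This is likely not an optimisation failure: for a real, centrosymmetric pupil, the wavefronts W(r) and -W(-r) give exactly the same in-focus intensity, and since a Zernike mode of radial order n satisfies Z(-r) = (-1)^n Z(r), this "twin" flips the sign of exactly the even-order modes (ANSI indices 3, 4, 5 and 10 here; the true coefficients for 4 and 10 are zero, so only 3 and 5 visibly change). The sketch below checks this numerically with a plain FFT-based scalar model, which is an assumption and not the Chromatix pipeline; the unnormalised mode shapes are illustrative.

```python
import jax.numpy as jnp

# Odd grid size, so reversing the array maps each sample x exactly onto -x.
N = 65
coords = jnp.linspace(-1.5, 1.5, N)
x, y = jnp.meshgrid(coords, coords, indexing="xy")
rho2 = x**2 + y**2
pupil = (rho2 <= 1.0).astype(jnp.float32)  # real, centrosymmetric aperture

# Unnormalised Zernike-like modes, each tagged with its radial order n.
modes = [
    (1, x),                           # tilt
    (2, x**2 - y**2),                 # astigmatism
    (2, 2 * rho2 - 1),                # defocus
    (3, (3 * rho2 - 2) * x),          # coma
    (4, 6 * rho2**2 - 6 * rho2 + 1),  # primary spherical
]
coeffs = jnp.array([0.5, 1.0, 0.3, 0.4, 0.2])

def psf(c):
    # Scalar Fraunhofer model: PSF intensity is |FFT(pupil * exp(i W))|^2.
    wavefront = sum(ci * z for ci, (_, z) in zip(c, modes))
    field = pupil * jnp.exp(1j * wavefront)
    return jnp.abs(jnp.fft.fft2(field)) ** 2

# Twin wavefront -W(-r): every even-order coefficient changes sign.
parity = jnp.array([(-1.0) ** (n + 1) for n, _ in modes])
twin = coeffs * parity
# For contrast, flip only the (odd-order) coma coefficient.
coma_flipped = coeffs.at[3].set(-coeffs[3])

ref = psf(coeffs)
twin_err = jnp.max(jnp.abs(psf(twin) - ref)) / jnp.max(ref)
coma_err = jnp.max(jnp.abs(psf(coma_flipped) - ref)) / jnp.max(ref)
print(twin, twin_err, coma_err)  # twin_err is at round-off level, coma_err is not
```

Applied to the Chromatix model, this argument presumes the source and pupil are centrosymmetric, which seems plausible for an in-focus point source but has not been checked against the Chromatix code.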


## Training with Jaxopt

The JAX ecosystem also has a rich set of classical optimisers, such as L-BFGS and nonlinear conjugate gradient solvers, implemented in Jaxopt. Here we show how to train the model defined above with Jaxopt. Because JAX represents parameters as pytrees, everything composes nicely, and you don't need to change your model at all to switch from Optax to Jaxopt.

In [7]:
params = model.init(key) # reset the parameters

# Defining solver
solver = jaxopt.LBFGS(loss_fn, has_aux=True)

# Running solver; extra arguments to run() (here, data) are forwarded to loss_fn
res = solver.run(params, data)

In [8]:
print(res.params)

FrozenDict({
    params: {
        ZernikeAberrations_0: {
            zernike_coefficients: Array([ 1.9957651e+00,  4.9954557e+00,  2.9984114e+00, -1.5044229e-03,
                    1.0011595e+00, -1.2449678e-03,  9.9792808e-01, -2.7636252e-03,
                    1.0069935e+00,  5.0141285e-03], dtype=float32),
        },
    },
})
