# Training Chromatix models

Chromatix is a fully differentiable library, meaning we can calculate gradients with respect to (almost) every quantity in our models. In this notebook we'll show how to optimise and train Chromatix models using two of the best-known optimisation libraries: [Optax](https://github.com/deepmind/optax) for deep-learning optimisers such as Adam, and [Jaxopt](https://github.com/google/jaxopt) for classical optimisers such as L-BFGS.

In [1]:
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
from flax.core import freeze, unfreeze
import numpy as np

import optax
import jaxopt
from flax.training.train_state import TrainState

from chromatix.elements import ObjectivePointSource, FFLens, ZernikeAberrations
from chromatix.utils import trainable

from typing import Tuple

key = random.PRNGKey(42)

## Making the model and data

As our model, we'll use the Zernike fitting example: we simulate a PSF with some Zernike coefficients and try to infer them from the simulated data.

In [2]:
class ZernikePSF(nn.Module):
    ansi_indices: np.ndarray = np.arange(1, 11)
    camera_shape: Tuple[int, int] = (256, 256)
    camera_pixel_pitch: float = 0.125
    f: float = 100
    NA: float = 0.8
    n: float = 1.33
    wavelength: float = 0.532
    wavelength_ratio: float = 1.0
    
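    @nn.compact
    def __call__(self):
        # Pupil-plane sample spacing chosen so that, after the Fourier-transforming
        # lens (FFLens), the camera-plane sampling equals camera_pixel_pitch
        spacing = self.f * self.wavelength / (self.n * self.camera_shape[0] * self.camera_pixel_pitch)
        # Field after the objective from an in-focus (z=0) point source
        field = ObjectivePointSource(self.camera_shape, spacing, self.wavelength, self.wavelength_ratio, self.f, self.n, self.NA, power=1e7)(z=0)
        # Zernike phase aberrations; the coefficients start at zero and are trainable
        field = ZernikeAberrations(trainable(jnp.zeros_like(self.ansi_indices, dtype=jnp.float32)), self.f, self.n, self.NA, self.ansi_indices)(field)
        # Fourier-transforming lens maps the pupil field onto the camera plane
        field = FFLens(self.f, self.n)(field)
        return field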
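The `spacing` line in `ZernikePSF` inverts the sampling relation of a Fourier-transforming lens: with `N` samples per axis, an input spacing `dx_in` maps to an output spacing `dx_out = f * wavelength / (n * N * dx_in)`. We're assuming this is the convention `FFLens` follows, which is consistent with how the model is set up. The relation is symmetric in `dx_in` and `dx_out`, so the same formula takes us from the camera pitch to the pupil spacing and back. A quick check with the class defaults in plain Python (the helper `fourier_spacing` is our own, not a Chromatix function):

```python
# Defaults copied from ZernikePSF (all lengths in the same unit)
f = 100.0
n = 1.33
wavelength = 0.532
N = 256  # camera_shape[0]
camera_pixel_pitch = 0.125


def fourier_spacing(f, n, wavelength, N, dx):
    """Hypothetical helper: spacing on the other side of a Fourier-transforming lens.

    Assumes dx_in * dx_out = f * wavelength / (n * N), so the same formula
    works in both directions.
    """
    return f * wavelength / (n * N * dx)


pupil_spacing = fourier_spacing(f, n, wavelength, N, camera_pixel_pitch)
camera_spacing = fourier_spacing(f, n, wavelength, N, pupil_spacing)
print(pupil_spacing)   # approximately 1.25
print(camera_spacing)  # approximately 0.125, i.e. camera_pixel_pitch again
```

With the defaults, the pupil is therefore sampled at roughly 1.25 length units.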

When we initialise the model, we get a dictionary containing both the trainable parameters (`params`) and a so-called state (`state`). The state holds everything we want to calculate only once and then cache. Here it's just some of the other, non-trainable parameters, but it could also be a more complicated phase mask or a propagator.

In [3]:
model = ZernikePSF()
variables = model.init(key)
print(variables)

# Split into two
params, state = variables["params"], variables["state"]
del variables # delete for memory

FrozenDict({
    state: {
        ObjectivePointSource_0: {
            _f: 100,
            _n: 1.33,
            _NA: 0.8,
            _power: 10000000.0,
            _amplitude: 1.0,
        },
        FFLens_0: {
            _f: 100,
            _n: 1.33,
            _NA: None,
        },
    },
    params: {
        ZernikeAberrations_0: {
            _coefficients: Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
        },
    },
})


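We generate some synthetic data from a chosen set of "ground truth" Zernike coefficients.

We then define a loss function, which should return a `(loss, metrics)` pair. Note that the loss function takes the trainable `params` and the cached `state` as two separate inputs (alongside the data): `jax.grad` differentiates only with respect to the first argument by default, so gradients are computed for `params` alone: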

In [4]:
# Specify "ground truth" parameters for Zernike coefficients
coefficients_truth = jnp.array([2.0, 5.0, 3.0, 0, 1, 0, 1, 0, 1, 0])
params_true = unfreeze(params)
params_true["ZernikeAberrations_0"]["_coefficients"] = coefficients_truth
params_true = freeze(params_true)

# Generating data
data = model.apply({"params": params_true, "state": state}).intensity.squeeze()

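# Our loss function: mean squared error normalised by the mean squared data
def loss_fn(params, state, data):
    psf_estimate = model.apply({"params": params, "state": state}).intensity.squeeze()
    loss = jnp.mean((psf_estimate - data)**2) / jnp.mean(data**2)
    # Return (loss, metrics); the metrics dict is passed through as auxiliary output
    return loss, {"loss": loss}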
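Two JAX conventions make this loss function work with `jax.grad`: `has_aux=True` tells JAX that the function returns `(value, aux)` and that only `value` should be differentiated, and by default `jax.grad` differentiates only with respect to the first argument. A minimal sketch with a hypothetical linear forward model (`toy_forward`, matrix `A`) standing in for `model.apply`:

```python
import jax
import jax.numpy as jnp


def toy_forward(params, state):
    # Hypothetical linear forward model: fixed matrix in `state`, trainable weights in `params`
    return state["A"] @ params["w"]


def toy_loss_fn(params, state, data):
    estimate = toy_forward(params, state)
    # Mean squared error normalised by the mean squared data, as in loss_fn above
    loss = jnp.mean((estimate - data) ** 2) / jnp.mean(data**2)
    return loss, {"loss": loss}  # (value to differentiate, auxiliary metrics)


toy_state = {"A": jnp.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])}
toy_data = toy_forward({"w": jnp.array([1.0, -1.0])}, toy_state)
toy_params = {"w": jnp.zeros(2)}

# has_aux=True: differentiate only the loss and pass the metrics dict through.
# argnums defaults to 0, so only toy_params is differentiated.
toy_grads, toy_metrics = jax.grad(toy_loss_fn, has_aux=True)(toy_params, toy_state, toy_data)
print(toy_metrics["loss"])  # 1.0: with w = 0 the estimate is all zeros
print(toy_grads["w"])       # same shape as toy_params["w"]
```

The normalisation by `jnp.mean(data**2)` makes the loss equal to 1 for an all-zero estimate, whatever the overall brightness of the data.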

## Training with Optax

Now that we have our model and data, we infer the parameters by training the model with Optax. We'll use the Adam optimiser (note the very high learning rate!) and Flax's `TrainState` to manage the optimiser state:

In [5]:
# Setting the state which has the model, params and optimiser
trainstate = TrainState.create(apply_fn=model.apply,
                               params=params,
                               tx=optax.adam(learning_rate=0.5))

# Defining the (jitted) function which returns the gradients and the metrics
grad_fn = jax.jit(jax.grad(loss_fn, has_aux=True))

In [6]:
%%time
# Simple training loop
max_iterations = 500
for iteration in range(max_iterations):
    grads, metrics = grad_fn(trainstate.params, state, data) 
    trainstate = trainstate.apply_gradients(grads=grads)

    if iteration % 100 == 0:
        print(iteration, metrics)

0 {'loss': Array(3.695996, dtype=float32)}
100 {'loss': Array(0.06857503, dtype=float32)}
200 {'loss': Array(0.02074304, dtype=float32)}
300 {'loss': Array(6.8586576e-07, dtype=float32)}
400 {'loss': Array(2.7676396e-11, dtype=float32)}
CPU times: user 4.09 s, sys: 239 ms, total: 4.33 s
Wall time: 2.84 s


In [7]:
print(f"Learned coefficients: {jnp.abs(jnp.around(trainstate.params['ZernikeAberrations_0']['_coefficients'], 2))}")
print(f"True Coefficients: {coefficients_truth}")

Learned coefficients: [2. 5. 3. 0. 1. 0. 1. 0. 1. 0.]
True Coefficients: [2. 5. 3. 0. 1. 0. 1. 0. 1. 0.]


## Training with Jaxopt

Because JAX works with pytrees (nested containers such as our `params` dictionary), classical optimisation with Jaxopt doesn't require any code changes! We can optimise this model with the following simple two-liner:

In [29]:
# Defining solver
solver = jaxopt.LBFGS(loss_fn, has_aux=True)

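# Running solver from freshly initialised (all-zero) coefficients;
# the extra arguments (state, data) are passed through to loss_fn
res = solver.run(model.init(key)["params"], state, data)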

In [31]:
print(f"Learned coefficients: {jnp.abs(jnp.around(res.params['ZernikeAberrations_0']['_coefficients'], 2))}")
print(f"True Coefficients: {coefficients_truth}")

Learned coefficients: [2.   5.   3.   0.   1.   0.   1.   0.   1.01 0.01]
True Coefficients: [2. 5. 3. 0. 1. 0. 1. 0. 1. 0.]
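The reason no code changes are needed is that JAX treats nested containers such as our `params` dictionary as pytrees: `jax.grad` returns gradients with the same nested structure as the parameters, and optimisers written with JAX's tree utilities, as Jaxopt's solvers are, operate on that structure directly. A minimal illustration with a hypothetical quadratic loss standing in for `loss_fn`:

```python
import jax
import jax.numpy as jnp

# Nested parameters shaped like the notebook's `params` (values are illustrative)
toy_params = {"ZernikeAberrations_0": {"_coefficients": jnp.zeros(10)}}


def toy_quadratic_loss(params):
    # Hypothetical loss standing in for loss_fn
    c = params["ZernikeAberrations_0"]["_coefficients"]
    return jnp.sum((c - 1.0) ** 2)


toy_grads = jax.grad(toy_quadratic_loss)(toy_params)

# The gradient has exactly the same nested structure as the parameters...
same_structure = jax.tree_util.tree_structure(toy_grads) == jax.tree_util.tree_structure(toy_params)
print(same_structure)  # True

# ...so an update written with tree utilities works on it unchanged,
# e.g. a single plain gradient-descent step:
toy_updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, toy_params, toy_grads)
print(toy_updated["ZernikeAberrations_0"]["_coefficients"])  # ten entries of 0.2
```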
