# Training Chromatix models

<a target="_blank" href="https://colab.research.google.com/github/chromatix-team/chromatix/blob/main/docs/training.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [None]:
# If in Colab, install Chromatix. Don't forget to select a GPU!
!pip install --upgrade pip
!pip install git+https://github.com/chromatix-team/chromatix.git

Chromatix is a fully differentiable library, meaning we can calculate gradients with respect to (almost) every quantity in our models. In this notebook we'll show how to optimize and train Chromatix models using the most well-known JAX optimization library: [Optax](https://github.com/deepmind/optax) for deep learning optimizers such as Adam.

In [1]:
from typing import Any

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
import skimage
from einops import rearrange
from jaxtyping import Array

from chromatix import Field
from chromatix.elements import (
    FFLens,
    PhaseMask,
    PlaneWave,
    Propagate,
)
from chromatix.functional import (
    amplitude_change,
    ff_lens,
    phase_change,
    plane_wave,
    transfer_propagate,
)
from chromatix.systems import OpticalSystem

## Training with Optax

Most of the time, you'll want to optimize simulations in Chromatix using a gradient descent optimizer or more generally a modern deep-learning gradient descent optimizer like Adam (especially when you have millions of parameters, e.g. if your parameters are pixels in a sample). These first-order gradient descent methods are available via a library called [Optax](https://optax.readthedocs.io/en/latest/). There are three ways you might go about organizing optimizations like that:

### 1. Optimizing parameters of a simulation using the functional API

The simplest and most flexible way to define a simulation and optimize it is to use the functional API. For most quick experiments or computational optics inverse problems, this is the method you will want to use. This involves writing a function that takes the optical parameter(s) you wish to optimize as an input and then using `jax.grad` to differentiate a loss with respect to that parameter. Here we show an example of this style for performing [Fourier ptychography via gradient descent](https://chromatix.readthedocs.io/en/latest/examples/fourier_ptychography/):

In [2]:
# First we define the simulation we want to optimize as a function.
# In this case, we simulate imaging a sample illuminated at an angle
# using a low NA microscope.
def tilted_illumination_system(amplitude: Array, phase: Array, kykx: Array) -> Array:
    field = plane_wave(amplitude.shape, 0.3, 0.532, kykx=kykx)
    field = amplitude_change(field, amplitude)
    field = phase_change(field, phase)
    field = ff_lens(field, 1.8e3, 1.33)
    field = ff_lens(field, 1.8e3, 1.33, NA=0.3)
    return field.intensity


# Here's some data we are using to simulate some "measurements"
# and then perform a reconstruction.
amplitude = skimage.data.camera().astype("float")
amplitude = amplitude / amplitude.max()
phase = skimage.data.moon().astype("float")
phase = np.pi * phase / phase.max()


# Now we simulate some "measurements" from our low NA imaging system.
kykx = (
    jnp.array(
        jnp.meshgrid(jnp.linspace(-0.5, 0.5, num=11), jnp.linspace(-0.5, 0.5, num=11))
    )
    * 2
    * jnp.pi
)
kykx = rearrange(kykx, "d h w -> (h w) d")
images = jax.vmap(
    lambda k: tilted_illumination_system(jnp.array(amplitude), jnp.array(phase), k)
)(kykx)


# We're initializing the parameters that we want to optimize, i.e. the amplitude and phase of the sample.
parameters = (images[60][::-1, ::-1], jnp.zeros_like(images[60]))
# We also initialize an optimizer (in this case we're just performing gradient descent, so there's no optimizer state).
optimizer = optax.sgd(1e13)
opt_state = optimizer.init(parameters)


# This defines our loss function, and it takes in the parameters
# we want to take gradients with respect to as the first argument.
def fp_loss_fn(parameters: Array, measured_image: Array, kykx: Array) -> Array:
    # 1. Extract the amplitude and phase from the parameters tuple.
    # Remember that the first element of the tuple is the amplitude and
    # the second element is the phase.
    amplitude = parameters[0]
    phase = parameters[1]
    # 2. Simulate imaging the amplitude and phase you just got using
    # the forward model we defined previously.
    simulated = tilted_illumination_system(amplitude, phase, kykx)
    # 3. Return the mean squared error between the simulated image of
    # our reconstruction and the "measured" image that we passed to
    # this loss function. WARNING: Make sure to squeeze the output of your
    # simulation! Otherwise you'll get weird broadcasting that creates a
    # huge array, giving you incorrect results and making things really slow.
    return jnp.mean((simulated - measured_image) ** 2)


# This defines the update step which computes the loss but also the gradient of the loss
# with respect to our parameters (i.e. our guess of the reconstructed sample). We then use
# that gradient and our optimizer to update our reconstruction.
@jax.jit
def update(
    parameters: tuple[Array, Array], opt_state: Any, image: Array, kykx: Array
) -> tuple[tuple[Array, Array], Any]:
    loss, grads = jax.value_and_grad(fp_loss_fn)(parameters, image, kykx)
    updates, opt_state = optimizer.update(grads, opt_state)
    parameters = optax.apply_updates(parameters, updates)
    return loss, parameters, opt_state


# We then run this update step multiple times over all 121 measured images in order to arrive at a reconstruction.
# We've chosen the learning rate for you so that this should appropriately converge if everything's gone right.
losses = []
for i in range(10):
    for j in range(kykx.shape[0]):
        loss, parameters, opt_state = update(parameters, opt_state, images[j], kykx[j])
        losses.append(np.array(loss))
    print(
        f"iteration {i + 1} loss = {np.mean(np.array(losses[-1 : -kykx.shape[0] : -1]))} over {kykx.shape[0]} images"
    )

iteration 1 loss = 1.8397162959704616e-10 over 121 images
iteration 2 loss = 6.692484674081234e-12 over 121 images
iteration 3 loss = 8.807251260754823e-13 over 121 images
iteration 4 loss = 7.475818466877449e-13 over 121 images
iteration 5 loss = 6.836539255712648e-13 over 121 images
iteration 6 loss = 6.424953884892615e-13 over 121 images
iteration 7 loss = 6.116519508762852e-13 over 121 images
iteration 8 loss = 5.859798865755217e-13 over 121 images
iteration 9 loss = 5.633953590114538e-13 over 121 images
iteration 10 loss = 5.431229467903198e-13 over 121 images


### 2. Combining an optical simulation with its parameters and data using custom Equinox Modules

Sometimes, you might want to wrap up a simulation, its optimizable parameters, and also any other parameters that define the simulation together. This is useful for keeping track of the proper parameters in a convenient way without having to pass them around explicitly all the time, serializing/saving different simulation configurations, and when you want to optimize multiple kinds of parameters at the same time. We use Equinox [`Module`s](https://docs.kidger.site/equinox/api/module/module/) to do this. **Note**: the only two real changes to the style of the optimization here are how we define the simulation and how we pass the parameters to the loss function (we pass the whole `Module` rather than just the parameter itself). Here we'll show a simple example of this style using our [holography example](https://chromatix.readthedocs.io/en/latest/examples/cgh/):

In [3]:
%%time


class CGH(eqx.Module):
    phase: (
        Array  # This is the phase mask we want to optimize, and is not marked static!
    )
    shape: tuple[int, int] = eqx.field(static=True)
    spacing: float = eqx.field(static=True)
    z: Array = eqx.field(static=True)
    f: float = eqx.field(static=True)
    n: float = eqx.field(static=True)
    NA: float | None = eqx.field(static=True)
    pad_width: int = eqx.field(static=True)
    spectrum: float | Array = eqx.field(static=True)

    def __init__(
        self,
        shape: tuple[int, int],
        spacing: float,  # microns
        z: Array,  # microns
        f: float = 200.0e3,  # microns
        n: float = 1.0,
        NA: float | None = None,
        pad_width: int = 0,
        spectrum: float | Array = 1.035,  # microns
    ):
        self.shape = shape
        self.spacing = spacing
        self.z = z
        self.f = f
        self.n = n
        self.NA = NA
        self.pad_width = pad_width
        self.spectrum = spectrum
        self.phase = jnp.zeros(
            self.shape
        )  # Initialization of our phase mask parameter to zeros

    def __call__(self) -> Field:
        field = plane_wave(self.shape, self.spacing, self.spectrum)
        field = phase_change(field, self.phase)
        field = ff_lens(field, self.f, self.n, self.NA)
        field = transfer_propagate(
            field, self.z, self.n, pad_width=self.pad_width, mode="same"
        )
        return field


# Let's initialize the holography model
shape = (256, 256)
spacing = 9.2  # microns
z = jnp.linspace(0.0, 100.0e4, num=51)  # Planes we want to simulate the hologram at
model = CGH(shape=shape, spacing=spacing, z=z)


# Now we create the optimizer
optimizer = optax.adam(learning_rate=1e-1)
opt_state = optimizer.init(model)


# Create a target pattern for which we want to optimize a hologram
sample = np.zeros((51, 256, 256))
sample[30, 128, 128] = 1.0
sample[10, 51, 92] = 1.0
sample[50, 10, 25] = 1.0
diameter = 25
kernel = np.zeros((diameter, diameter, diameter))
grid = np.meshgrid(
    np.linspace(-diameter / 2, diameter / 2, num=diameter),
    np.linspace(-diameter / 2, diameter / 2, num=diameter),
    np.linspace(-diameter / 2, diameter / 2, num=diameter),
)
grid = np.sqrt(grid[0] ** 2 + grid[1] ** 2 + grid[2] ** 2)
kernel[grid < diameter / 5] = 1.0
sample = jnp.fft.ifftn(
    jnp.fft.fftn(jnp.array(sample)) * jnp.fft.fftn(jnp.array(kernel), s=sample.shape)
).real
sample = sample[..., jnp.newaxis, jnp.newaxis]
sample *= 1000.0


# Here we define a loss function, but note that this time
# we're taking the whole model directly! The model itself
# contains the parameters, and we'll update the model itself
# in our optimization loop.
def cgh_loss_fn(model, target):
    approx = model().intensity
    correlation = jnp.corrcoef(approx.flatten(), target.flatten())[0, 1]
    loss = 1.0 - correlation
    return loss, {"loss": loss, "correlation": correlation}


# This is our update function just like before, but this time
# taking the model directly.
@jax.jit
def update(model, opt_state, target):
    grads, metrics = jax.grad(cgh_loss_fn, has_aux=True)(model, target)
    updates, opt_state = optimizer.update(grads, opt_state, model)
    model = optax.apply_updates(model, updates)
    return model, opt_state, metrics


# Now we just run the optimization loop!
max_iterations = 1000
history = {
    "loss": np.zeros((max_iterations)),
    "correlation": np.zeros((max_iterations)),
}
for iteration in range(max_iterations):
    model, opt_state, metrics = update(model, opt_state, sample)
    for m in metrics:
        history[m][iteration] = metrics[m]
    if iteration % 200 == 0:
        print(iteration, metrics)



0 {'correlation': Array(0.00036571, dtype=float32), 'loss': Array(0.99963427, dtype=float32)}
200 {'correlation': Array(0.70244604, dtype=float32), 'loss': Array(0.29755396, dtype=float32)}
400 {'correlation': Array(0.70941615, dtype=float32), 'loss': Array(0.29058385, dtype=float32)}
600 {'correlation': Array(0.7137658, dtype=float32), 'loss': Array(0.2862342, dtype=float32)}
800 {'correlation': Array(0.7153402, dtype=float32), 'loss': Array(0.2846598, dtype=float32)}
CPU times: user 31 s, sys: 489 ms, total: 31.5 s
Wall time: 56.2 s


### 3. Using OpticalSystem and Equinox partitioning to choose what parameters are optimizable

The previous approaches have defined the parameter to be optimized either by setting that parameter as the first argument of the function we differentiate or by creating an Equinox `Module` that has all its attributes set to static except the parameter we wish to optimize. You'll notice that Equinox prints out a warning due to this use of static because it can lead to confusing bugs when combined with other JAX transformations. We set everything that's not optimized to static in order to easily avoid having to deal with gradients to wrong parameters in our model. However, this abuse of static could be limiting (e.g. if we need to combine the optical simulation with other JAX transformations like `vmap` or optimize only some parameters and not others with the same simulation). Another way to choose what parameters of a simulation are optimizable is with `partition`. Here, we'll revisit the CGH example and create an `OpticalSystem` in a very succinct way which does not let us choose what is optimizable. Then, we'll use `partition` to select which parameter should be optimized:

In [4]:
%%time
# Let's initialize the holography model
shape = (256, 256)
spacing = 9.2  # microns
image_plane_spacing = 200.0e3 * 1.035 / (spacing * shape[0])
z = jnp.linspace(0.0, 100.0e4, num=51)  # Planes we want to simulate the hologram at
# This time, the model doesn't let us set static fields
# because we are not creating our own Module.
model = OpticalSystem(
    [
        PlaneWave(shape, spacing, spectrum=1.035),
        PhaseMask(phase=jnp.zeros(shape)),
        FFLens(f=200.0e3, n=1.0),
        Propagate(
            Field.empty(shape, image_plane_spacing, spectrum=1.035),
            z,
            n=1.0,
            method="transfer",
            mode="same",
        ),
    ]
)


# We want to create the optimizer, but first we have to
# select which parameter to optimize. We start by setting
# that we don't want to optimize any parameter:
filter_spec = jax.tree.map(lambda _: False, model)
# Then, we select the parameter we actually want to optimize.
# In this line, the lambda function selects the phase pixels
# from the second element in our optical system, and sets that to true:
filter_spec = eqx.tree_at(lambda m: m.elements[1].phase, filter_spec, True)
# Then we split the model using this filter_spec into the part that
# we want to optimize (the parameters) and the part we don't (the state):
parameters, state = eqx.partition(model, filter_spec)
# Now we can create the optimizer
optimizer = optax.adam(learning_rate=1e-1)
opt_state = optimizer.init(parameters)


# Create a target pattern for which we want to optimize a hologram
sample = np.zeros((51, 256, 256))
sample[30, 128, 128] = 1.0
sample[10, 51, 92] = 1.0
sample[50, 10, 25] = 1.0
diameter = 25
kernel = np.zeros((diameter, diameter, diameter))
grid = np.meshgrid(
    np.linspace(-diameter / 2, diameter / 2, num=diameter),
    np.linspace(-diameter / 2, diameter / 2, num=diameter),
    np.linspace(-diameter / 2, diameter / 2, num=diameter),
)
grid = np.sqrt(grid[0] ** 2 + grid[1] ** 2 + grid[2] ** 2)
kernel[grid < diameter / 5] = 1.0
sample = jnp.fft.ifftn(
    jnp.fft.fftn(jnp.array(sample)) * jnp.fft.fftn(jnp.array(kernel), s=sample.shape)
).real
sample = sample[..., jnp.newaxis, jnp.newaxis]
sample *= 1000.0


# Here we define a loss function, but note that this time
# we're taking the parameters and state separately! The model
# is reconstructed in the loss function, and we'll update just
# the parameters in our optimization loop. The overhead of this
# recombination is compiled away by JAX.
def cgh_loss_fn_partitioned(parameters, state, target):
    model = eqx.combine(parameters, state)
    approx = model().intensity
    correlation = jnp.corrcoef(approx.flatten(), target.flatten())[0, 1]
    loss = 1.0 - correlation
    return loss, {"loss": loss, "correlation": correlation}


# This is our update function just like before, but this time
# we partition and combine the model before computing the loss.
@jax.jit
def update(model, opt_state, target):
    parameters, state = eqx.partition(model, filter_spec)
    grads, metrics = jax.grad(cgh_loss_fn_partitioned, has_aux=True)(
        parameters, state, target
    )
    updates, opt_state = optimizer.update(grads, opt_state, parameters)
    parameters = optax.apply_updates(parameters, updates)
    model = eqx.combine(parameters, state)
    return model, opt_state, metrics


# Now we just run the optimization loop!
max_iterations = 1000
history = {
    "loss": np.zeros((max_iterations)),
    "correlation": np.zeros((max_iterations)),
}
for iteration in range(max_iterations):
    model, opt_state, metrics = update(model, opt_state, sample)
    for m in metrics:
        history[m][iteration] = metrics[m]
    if iteration % 200 == 0:
        print(iteration, metrics)

0 {'correlation': Array(0.00036571, dtype=float32), 'loss': Array(0.99963427, dtype=float32)}
200 {'correlation': Array(0.70244604, dtype=float32), 'loss': Array(0.29755396, dtype=float32)}
400 {'correlation': Array(0.70941615, dtype=float32), 'loss': Array(0.29058385, dtype=float32)}
600 {'correlation': Array(0.7137658, dtype=float32), 'loss': Array(0.2862342, dtype=float32)}
800 {'correlation': Array(0.7153402, dtype=float32), 'loss': Array(0.2846598, dtype=float32)}
CPU times: user 29.2 s, sys: 577 ms, total: 29.8 s
Wall time: 54.8 s
