# JAX-MORPH: Gradients and Optimization

The advantage of writing simulations in JAX is that gradients of a simulation with respect to its parameters can be computed easily. These gradients can be used both to carry out sensitivity analysis and to optimize the simulation parameters with respect to an objective function.

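In this notebook we show how to use JAX-MORPH to compute gradients of simulations and to carry out gradient-based optimization of their parameters. Here the model is trained with REINFORCE (`jxm.opt.train_reinforce`), a policy-gradient method that estimates gradients of the expected reward over stochastic simulation trajectories.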

# Imports & Utils

In [2]:
# JAX Imports
import jax
import jax.numpy as np
import matplotlib.pyplot as plt

# JAX-Morph Imports
import jax_morph as jxm  # type: ignore

import equinox as eqx

# Raise an error as soon as a JAX computation produces a NaN (debugging aid).
# NOTE(review): this adds checking overhead to every JAX op — consider turning it
# off for long training runs once the pipeline is known to be NaN-free.
jax.config.update("jax_debug_nans", True)
# Use 64-bit floats/ints by default; must be set before any arrays are created.
jax.config.update("jax_enable_x64", True)

## Plotting

In [None]:
def plot_reward(train_log):
    """Plot the per-epoch reward curve recorded during training.

    Parameters:
    - train_log: training log returned by ``jxm.opt.train_reinforce``; only its
      ``rewards`` attribute is read.

    Returns:
    - fig: the matplotlib Figure containing the plot. The figure is closed so
      pyplot no longer tracks it; returning it lets the caller display it (e.g.
      as the last expression of a notebook cell) or save it.
    """

    # Draw on a fresh figure instead of whatever pyplot considers "current".
    fig, ax = plt.subplots()

    ax.plot(train_log.rewards, "r")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Reward")
    ax.grid(alpha=0.2)
    fig.tight_layout()

    # Close exactly this figure and hand it back; previously the figure was
    # closed and discarded, so the plot never reached the caller.
    plt.close(fig)

    return fig


def plot_spheres_div(model, istate, sim_key):
    """Simulate until the empty cell slots are used up and draw the final state in 3D.

    Parameters:
    - model: simulation model passed to ``jxm.simulate``.
    - istate: initial CellState; the number of its empty (all-zero celltype)
      rows sets the number of simulation steps.
    - sim_key: JAX PRNG key for the simulation.

    Returns:
    - fig: matplotlib Figure produced by
      ``jxm.visualization.draw_spheres_division``. It is closed so pyplot no
      longer tracks it; the caller can display or save it explicitly.
    """

    # Empty slots = total slots minus occupied ones.
    # NOTE(review): assumes alive cells have one-hot celltype rows and empty
    # slots are all zeros, and that each simulation step adds one cell.
    n_alive = istate.celltype.sum(-1).sum(-1)
    N_ADD = int(istate.celltype.shape[0] - n_alive)

    fstate_opt, _ = jxm.simulate(model, istate, key=sim_key, n_steps=N_ADD)

    fig, ax = jxm.visualization.draw_spheres_division(
        fstate_opt,
        grid=True,
        elev=0,
        azim=0,
    )

    # axis labels
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")

    # Operate on the returned figure explicitly rather than pyplot's current one.
    fig.tight_layout()

    # Previously the figure was closed and discarded; return it instead.
    plt.close(fig)

    return fig

## Animation functions

In [3]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML


def animate_traj_ctype(trajectory, xlim=(-25, 25), ylim=(-25, 25)):
    """Render a trajectory as a 2D animation with cells colored by cell type.

    Parameters:
    - trajectory: stacked simulation states whose ``position``, ``celltype``
      and ``radius`` arrays are indexed by frame along axis 0.
    - xlim, ylim: fixed axis limits applied to every frame.

    Returns:
    - IPython ``HTML`` object wrapping the JavaScript animation.
    """

    n_frames = trajectory.position.shape[0]

    def reset_axis():
        # Same framing for the initial canvas and every redrawn frame.
        ax.set_aspect("equal", "box")
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_xticks([])
        ax.set_yticks([])

    def draw_frame(frame):
        # Wipe the previous frame before drawing the next one.
        ax.clear()
        reset_axis()

        frame_state = jxm.BaseCellState(
            None,
            None,
            trajectory.position[frame],
            trajectory.celltype[frame],
            trajectory.radius[frame],
        )
        jxm.visualization.draw_circles_ctype(frame_state, ax=ax)

    fig, ax = plt.subplots()
    reset_axis()

    animation = FuncAnimation(fig, draw_frame, frames=n_frames, interval=100)
    rendered = HTML(animation.to_jshtml())

    plt.close()

    return rendered


def animate_traj_chem(
    trajectory, chem=0, xlim=(-25, 25), ylim=(-25, 25), colorbar=False
):
    """Render a trajectory as a 2D animation with cells colored by one chemical.

    Parameters:
    - trajectory: stacked simulation states whose ``position``, ``celltype``,
      ``radius`` and ``chemical`` arrays are indexed by frame along axis 0.
    - chem: index of the chemical species used for coloring.
    - xlim, ylim: fixed axis limits applied to every frame.
    - colorbar: if True, add a single colorbar for the whole animation.

    Returns:
    - IPython ``HTML`` object wrapping the JavaScript animation.
    """

    sim_steps = trajectory.position.shape[0]

    cm = plt.cm.YlGn

    # we only need the chemical field for the plot
    class PlottingCellState(jxm.BaseCellState):
        chemical: np.ndarray

    def _style_axis():
        # Same framing for the initial canvas and every redrawn frame.
        ax.set_aspect("equal", "box")
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_xticks([])
        ax.set_yticks([])

    def animate(i):
        ax.clear()  # Clear the axis to remove the previous frame
        _style_axis()

        state = PlottingCellState(
            None,
            None,
            trajectory.position[i],
            trajectory.celltype[i],
            trajectory.radius[i],
            trajectory.chemical[i],
        )

        # Never draw a colorbar per frame; at most one static colorbar is added
        # below, outside the animation loop.
        jxm.visualization.draw_circles_chem(state, chem, colorbar=False, cm=cm, ax=ax)

    fig, ax = plt.subplots()
    _style_axis()

    if colorbar:
        # Scale to the plotted species only (the label names species `chem`);
        # previously the max was taken over all species.
        # NOTE(review): assumes species are on the last axis, as in CellState.
        vmax = float(np.max(trajectory.chemical[..., chem]))

        sm = plt.cm.ScalarMappable(cmap=cm, norm=plt.Normalize(vmin=0, vmax=vmax))
        sm.set_array([])  # public API instead of the private `_A` attribute

        cbar = fig.colorbar(sm, ax=ax, fraction=0.05, alpha=0.5)  # rule of thumb
        cbar.set_label("Conc. Chem. " + str(chem), labelpad=20)

    anim = FuncAnimation(fig, animate, frames=sim_steps, interval=100)

    html_anim = HTML(anim.to_jshtml())

    # Close exactly the figure created here.
    plt.close(fig)

    return html_anim

# Build initial state

In [None]:
# Define CellState class
class CellState(jxm.BaseCellState):
    """Cell state for this notebook: the base jax_morph state plus division,
    chemical, secretion and hidden-gene fields."""

    # Per-cell division propensity; shape (N_TOT, 1) in build_istate. Written
    # by the GeneNetwork step (output field) and zeroed for masked cell types.
    division: jax.Array
    # Chemical concentrations; shape (N_TOT, N_CHEM) in build_istate. Set by
    # ExponentialSteadyStateDiffusion and read by the GeneNetwork (input field).
    chemical: jax.Array
    # Per-cell secretion rate per chemical; shape (N_TOT, N_CHEM) in build_istate.
    # NOTE(review): presumably consumed by the diffusion step — confirm.
    secretion_rate: jax.Array
    # Hidden gene expression; shape (N_TOT, N_HIDDEN_GENES) in build_istate.
    # NOTE(review): presumably used internally by GeneNetwork — confirm.
    hidden_state: jax.Array

In [None]:
# Build initial state
def build_istate(
    N_HIDDEN_GENES=1,
    N_INIT=15,
    N_TOT=150,
    N_DIM=3,
    *,
    key=None,
):
    """
    Initialize CellState data structure.

    A single founder cell is grown into a cluster of N_INIT cells with a simple
    growth / mechanical relaxation / division model. Alive cells with y < 0.5
    are then switched to cell type 1, with zero division and secretion rate 1.
    Secretion is zeroed for every other cell. Finally the steady-state chemical
    field is computed.

    Parameters:
    - N_HIDDEN_GENES: number of hidden genes
    - N_INIT: number of cells in the initial cluster
    - N_TOT: total number of cell slots (alive + empty)
    - N_DIM: number of spatial dimensions
    - key: optional JAX PRNG key used to grow the initial cluster (keyword-only).
      Defaults to ``jax.random.PRNGKey(0)``, i.e. the same fixed seed as before,
      so the default initial state is reproducible.

    Returns:
    - istate: initial CellState
    """

    # Number of cell types and chemical species used in this notebook
    N_CTYPES = 2
    N_CHEM = 1

    if key is None:
        key = jax.random.PRNGKey(0)

    disp, shift = jxm.space.free()

    # Founder cell in slot 0: type 0, radius 0.5, division 1, secretion 1.
    # All other slots start empty (zero celltype, radius and position).
    istate = CellState(
        displacement=disp,
        shift=shift,
        position=np.zeros(shape=(N_TOT, N_DIM)),
        celltype=np.zeros(shape=(N_TOT, N_CTYPES)).at[0].set(np.array([1.0, 0.0])),
        radius=np.zeros(shape=(N_TOT, 1)).at[0].set(0.5),
        division=np.zeros(shape=(N_TOT, 1)).at[0].set(1.0),
        chemical=np.zeros(shape=(N_TOT, N_CHEM)),
        secretion_rate=np.zeros(shape=(N_TOT, N_CHEM)).at[0].set(1.0),
        hidden_state=np.zeros(shape=(N_TOT, N_HIDDEN_GENES)),
    )

    # Grow initial cluster of cells (uniform division probability)
    mech_potential = jxm.env.mechanics.MorsePotential(epsilon=2.0, alpha=2.8)

    imodel = jxm.Sequential(
        substeps=[
            jxm.env.CellGrowth(growth_rate=0.03, max_radius=0.5, growth_type="linear"),
            jxm.env.mechanics.SGDMechanicalRelaxation(mech_potential),
            jxm.env.CellDivision(),
        ]
    )

    # N_INIT - 1 steps starting from one cell.
    # NOTE(review): assumes one division per step.
    istate = jxm.simulate(imodel, istate, key=key, n_steps=N_INIT - 1)[0]

    # Prepare the initial state.
    # Empty slots sit at the origin (y = 0 < 0.5), so the y-test alone would
    # select them too; intersect with the alive cells.
    lower_half_idx = np.where(istate.position[:, 1] < 0.5)[0]
    alive_idx = np.where(istate.celltype.sum(1) > 0)[0]
    lower_half_alive_idx = np.intersect1d(lower_half_idx, alive_idx)

    # Lower-half cells stop dividing ...
    istate = eqx.tree_at(
        lambda s: s.division, istate, istate.division.at[lower_half_alive_idx].set(0.0)
    )
    # ... become the only secreting cells ...
    istate = eqx.tree_at(
        lambda s: s.secretion_rate,
        istate,
        np.zeros_like(istate.secretion_rate).at[lower_half_alive_idx].set(1.0),
    )
    # ... and are relabelled as cell type 1.
    istate = eqx.tree_at(
        lambda s: s.celltype,
        istate,
        istate.celltype.at[lower_half_alive_idx].set(np.array([0.0, 1.0])),
    )

    # Initial chemical field from the new secretion pattern.
    # NOTE(review): build_model uses diffusion_coeff=1.0, while this initial
    # field uses 0.1 — confirm this difference is intentional.
    istate = jxm.env.ExponentialSteadyStateDiffusion(
        diffusion_coeff=0.1,
        degradation_rate=0.1,
    )(istate)

    return istate

# Build the model to simulate

In [None]:
class ZeroDivisionMaskByCellType(jxm.SimulationStep):
    """Simulation step that sets ``division`` to zero for selected cell types.

    ``ctype_zero_div`` holds one weight per cell type. A cell is masked when its
    celltype vector has positive overlap (dot product) with these weights.
    Masked cells get their division value multiplied by 0.
    """

    # One weight per cell type (positive => that type is prevented from dividing).
    # Stored as a (nested) tuple because static equinox fields must be hashable;
    # a list here breaks hashing of the module's static structure.
    ctype_zero_div: tuple = eqx.field(static=True)

    def return_logprob(self) -> bool:
        # Deterministic step: it contributes no log-probability.
        return False

    def __init__(self, state, ctype_zero_div=None):
        """
        Parameters:
        - state: example CellState; only used to check the number of cell types.
        - ctype_zero_div: list, tuple, or (numpy/JAX) array with one entry per
          cell type.

        Raises:
        - ValueError: if ctype_zero_div is missing, is not a sequence/array, or
          its length differs from the number of cell types.
        """
        if ctype_zero_div is None:
            raise ValueError("ctype_zero_div must be provided")

        # Accept plain sequences as well as numpy and JAX arrays (both expose
        # __array__); jnp.ndarray alone does not match numpy arrays.
        if not (
            isinstance(ctype_zero_div, (list, tuple))
            or hasattr(ctype_zero_div, "__array__")
        ):
            raise ValueError("ctype_zero_div must be a list, tuple or array")

        if len(ctype_zero_div) != state.celltype.shape[1]:
            raise ValueError(
                f"ctype_zero_div must have length {state.celltype.shape[1]} (number of cell types)"
            )

        def _freeze(value):
            # Recursively convert nested lists to tuples so the field is hashable.
            if isinstance(value, list):
                return tuple(_freeze(v) for v in value)
            return value

        self.ctype_zero_div = _freeze(np.asarray(ctype_zero_div).tolist())

    @jax.named_scope("jax_morph.ZeroDivisionMaskByCellType")
    def __call__(self, state, *, key=None, **kwargs):
        # Per-cell overlap with the masked types: (n_cells, n_ctypes) @ (n_ctypes,)
        overlap = state.celltype @ np.asarray(self.ctype_zero_div)

        # A 2-D ctype_zero_div yields (n_cells, k); collapse the extra columns.
        if overlap.ndim > 1:
            overlap = overlap.sum(axis=1)

        div_mask = overlap > 0.0

        # Zero out division for masked cells by multiplying by (1 - mask);
        # division has a trailing singleton axis, hence [:, None].
        division = state.division * (1.0 - div_mask)[:, None]

        return eqx.tree_at(lambda s: s.division, state, division)

In [None]:
# Build the simulation step model
def build_model(init_key, istate):
    """
    Assemble the per-step simulation model that is trained in this notebook.

    Each step runs, in order:
    - linear cell growth;
    - mechanical relaxation under a Morse potential;
    - steady-state diffusion of the secreted chemical;
    - a gene network mapping the chemical field to the division propensity;
    - zeroing of division for cell type 1;
    - cell division.

    Parameters:
    - init_key: PRNG key used to initialize the gene-network parameters
    - istate: initial CellState, used to size the gene network and to validate
      the division mask

    Returns:
    - model: ``jxm.Sequential`` containing the steps above
    """

    def gate_division_by_size(s, x):
        # Steep sigmoid in the radius: ~0 for radius well below 0.45, ~1 above it.
        return x * jax.nn.sigmoid(50 * (s.radius - 0.45))

    morse = jxm.env.mechanics.MorsePotential(epsilon=1.7, alpha=3.5)

    growth_step = jxm.env.CellGrowth(
        growth_rate=0.03, max_radius=0.5, growth_type="linear"
    )
    mechanics_step = jxm.env.mechanics.SGDMechanicalRelaxation(morse)
    diffusion_step = jxm.env.ExponentialSteadyStateDiffusion(
        diffusion_coeff=1.0,
        degradation_rate=0.1,
    )
    gene_step = jxm.cell.GeneNetwork(
        istate,
        input_fields=["chemical"],
        output_fields=["division"],
        key=init_key,
        transform_output={"division": gate_division_by_size},
        degradation_init=jax.nn.initializers.constant(0.1),
        interaction_init=jax.nn.initializers.orthogonal(1e-2),
        dt=0.1,
        T=5.0,
        expression_offset=-0.5,
    )
    mask_step = ZeroDivisionMaskByCellType(istate, ctype_zero_div=[0.0, 1.0])
    division_step = jxm.env.CellDivision()

    return jxm.Sequential(
        substeps=[
            growth_step,
            mechanics_step,
            diffusion_step,
            gene_step,
            mask_step,
            division_step,
        ]
    )

# Training

In [None]:
# BUILD INITIAL STATE
key = jxm.utils.generate_random_key()

# The initial cluster is grown with build_istate's own fixed default seed,
# so `key` is not consumed here.
istate = build_istate(N_HIDDEN_GENES=1, N_INIT=15, N_TOT=150, N_DIM=3)

# BUILD INITIAL MODEL: split off a dedicated key for parameter initialization.
key, init_key = jax.random.split(key)
model = build_model(init_key, istate)

In [None]:
# Define rewards function (maximize sum of squares along y-axis)
rewards_fn = jxm.opt.reward_ssq_diff(coordinate_idx=1, center_of_mass=True)

# create Reinforce loss function
rloss_fn = jxm.opt.reinforce_loss()

# Weight of the L1 penalty on the gene-network interaction matrix.
L1_PENALTY = 1.0


def loss_fn(model, trajectory, rewards, sim_keys):
    """REINFORCE loss plus an L1 penalty on the gene-network interaction matrix.

    The penalty is the mean absolute interaction weight (the L1 norm divided by
    the number of entries), scaled by ``L1_PENALTY``; it encourages sparse
    gene interactions.
    """
    reinforce_term = rloss_fn(model, trajectory, rewards, sim_keys)
    l1_term = np.mean(np.abs(model.GeneNetwork.interaction_matrix))
    return reinforce_term + L1_PENALTY * l1_term

In [None]:
# TRAIN MODEL
key, train_key = jax.random.split(key)

# REINFORCE optimization hyperparameters.
train_kwargs = dict(
    epochs=1000,
    learning_rate=5e-3,
    batch_size=4,
    return_discount=0.6,
)

trained_model, training_log = jxm.opt.train_reinforce(
    train_key, model, istate, rewards_fn, loss_fn, **train_kwargs
)

In [None]:
# Reward per training epoch (plot_reward reads training_log.rewards).
plot_reward(training_log)