# JAX-MORPH: Gradients and Optimization

The advantage of writing simulations in JAX is that we can easily compute gradients of a simulation's output with respect to its parameters. These gradients can be used both for sensitivity analysis and for optimizing the simulation parameters with respect to some objective function.

In this notebook we will show how to use JAX-MORPH to compute gradients of simulations and carry out gradient-based optimization in this setting.

# Imports & Utils

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
# NOTE: jax.numpy is aliased as `np` throughout this notebook.
import jax.numpy as np
# Enable double precision; arrays created below are float64.
jax.config.update("jax_enable_x64", True)

# Root PRNG key; later cells split it to obtain fresh subkeys.
key = jax.random.PRNGKey(0)

import jax_morph as jxm
import equinox as eqx
import jax_md


import matplotlib.pyplot as plt

## Animation functions

In [3]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML


def animate_traj_ctype(trajectory, xlim=(-25, 25), ylim=(-25, 25)):
    """Render a cell trajectory as an inline HTML animation colored by cell type.

    Args:
        trajectory: object exposing stacked ``position``, ``celltype`` and
            ``radius`` arrays whose leading axis indexes simulation steps
            (the number of frames is ``trajectory.position.shape[0]``).
        xlim, ylim: fixed axis limits applied to every frame.

    Returns:
        An ``IPython.display.HTML`` object wrapping the JS/HTML animation.
    """

    sim_steps = trajectory.position.shape[0]

    fig, ax = plt.subplots()

    def format_axes():
        # ax.clear() resets aspect, limits and ticks, so they are re-applied
        # for the initial figure and after every per-frame clear.
        ax.set_aspect('equal', 'box')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_xticks([])
        ax.set_yticks([])

    def animate(i):

        ax.clear()  # Clear the axis to remove the previous frame
        format_axes()

        # displacement/shift are not needed for drawing, so they are None.
        state = jxm.BaseCellState(None,
                                  None,
                                  trajectory.position[i],
                                  trajectory.celltype[i],
                                  trajectory.radius[i],
                                  )

        jxm.visualization.draw_circles_ctype(state, ax=ax)

    format_axes()

    anim = FuncAnimation(fig, animate, frames=sim_steps, interval=100)

    html_anim = HTML(anim.to_jshtml())

    # Close exactly the figure created here (not whatever figure happens to be
    # current), presumably so the notebook does not also render it statically.
    plt.close(fig)

    return html_anim




def animate_traj_chem(trajectory, chem=0, xlim=(-25, 25), ylim=(-25, 25), colorbar=False):
    """Render a cell trajectory as an inline HTML animation colored by a chemical.

    Args:
        trajectory: object exposing stacked ``position``, ``celltype``,
            ``radius`` and ``chemical`` arrays whose leading axis indexes
            simulation steps (frames = ``trajectory.position.shape[0]``).
        chem: index of the chemical used to color the cells.
        xlim, ylim: fixed axis limits applied to every frame.
        colorbar: if True, add one figure-level colorbar for chemical ``chem``.

    Returns:
        An ``IPython.display.HTML`` object wrapping the JS/HTML animation.
    """

    sim_steps = trajectory.position.shape[0]

    cm = plt.cm.YlGn

    # Minimal state type for plotting: the base fields plus the chemical field.
    class PlottingCellState(jxm.BaseCellState):
        chemical: np.ndarray

    fig, ax = plt.subplots()

    def format_axes():
        # ax.clear() resets aspect, limits and ticks, so they are re-applied
        # for the initial figure and after every per-frame clear.
        ax.set_aspect('equal', 'box')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_xticks([])
        ax.set_yticks([])

    def animate(i):

        ax.clear()  # Clear the axis to remove the previous frame
        format_axes()

        # displacement/shift are not needed for drawing, so they are None.
        state = PlottingCellState(None,
                                  None,
                                  trajectory.position[i],
                                  trajectory.celltype[i],
                                  trajectory.radius[i],
                                  trajectory.chemical[i]
                                  )

        # colorbar=False here; at most one figure-level colorbar is added below.
        jxm.visualization.draw_circles_chem(state, chem, colorbar=False, cm=cm, ax=ax)

    format_axes()

    if colorbar:
        # Scale the colorbar to the chemical actually plotted (over all steps),
        # not to the maximum across every chemical species.
        # NOTE(review): assumes the last axis of trajectory.chemical indexes
        # chemical species — TODO confirm against the state definition.
        vmax = float(np.max(trajectory.chemical[..., chem]))

        sm = plt.cm.ScalarMappable(cmap=cm, norm=plt.Normalize(vmin=0, vmax=vmax))
        sm.set_array([])  # public API instead of writing the private `_A` attribute

        cbar = fig.colorbar(sm, ax=ax, fraction=.05, alpha=.5) # rule of thumb
        cbar.set_label('Conc. Chem. '+str(chem), labelpad=20)

    anim = FuncAnimation(fig, animate, frames=sim_steps, interval=100)

    html_anim = HTML(anim.to_jshtml())

    # Close exactly the figure created here (not whatever figure happens to be
    # current), presumably so the notebook does not also render it statically.
    plt.close(fig)

    return html_anim

# Define simulation

We define a simple simulation, following the same structure as in notebook 1. Cells move with Brownian motion while growing, and interact mechanically via a Morse potential.

Suppose we want to calculate the gradient of a function of the final output of a simulation (e.g. the average cell distance to the center of mass of the whole cluster) with respect to the parameters of the simulation (e.g. the cell-cell adhesion strength).

JAX-MORPH is designed to take advantage of the features of the `Equinox` library, which make computing gradients easy.

In [4]:
### CREATE INITIAL STATE

N_CELLS = 10

### INITIALIZE FIELD VALUES

# Free (non-periodic) space from jax_md.
displacement_fn, shift_fn = jax_md.space.free()

key, pos_subkey, rad_subkey = jax.random.split(key, 3)

# Positions drawn uniformly from the square [-5, 5)^2.
pos = jax.random.uniform(pos_subkey, (N_CELLS, 2), minval=-5.0, maxval=5.0)

# Every cell gets the same cell-type value (1).
ctype = np.ones((N_CELLS, 1))

# Radii ~ Normal(mean=0.3, std=0.1), clipped at zero by the ReLU.
rad = jax.nn.relu(.3 + .1 * jax.random.normal(rad_subkey, (N_CELLS, 1)))

istate = jxm.BaseCellState(
    displacement=displacement_fn,
    shift=shift_fn,
    position=pos,
    celltype=ctype,
    radius=rad,
)

print(istate)

BaseCellState(
  displacement=<function displacement_fn>,
  shift=<function shift_fn>,
  position=f64[10,2],
  celltype=f64[10,1],
  radius=f64[10,1]
)


In [5]:
def avg_dist_loss(model, istate, run_key, n_steps):
    """Simulate, then return the mean distance of cells from the cluster centroid.

    Args:
        model: simulation step passed to ``jxm.simulate``.
        istate: initial cell state.
        run_key: PRNG key for the (stochastic) simulation.
        n_steps: number of simulation steps.

    Returns:
        Scalar: mean Euclidean distance between each final cell position and
        the mean of all final positions.
    """
    final_positions = jxm.simulate(model, istate, key=run_key, n_steps=n_steps).position

    # Offsets of each cell from the centroid of the cluster.
    offsets = final_positions - final_positions.mean(axis=0)

    return np.mean(np.linalg.norm(offsets, axis=1))

**NOTE:** By default, all parameters declared as floating-point `jax.numpy` arrays are treated as trainable, and their gradients are computed. All other parameters (e.g. plain Python floats or strings) are treated as fixed, and no gradients are computed for them.

Therefore, by choosing whether a parameter is given as a JAX array or as a plain Python scalar, we control which parameters are trainable and which are not. This behavior is guaranteed to work as long as we use Equinox's filtered transformations to compute gradients (e.g. `eqx.filter_grad`).

In [6]:
### DEFINE SIMULATION STEP

# epsilon and alpha are JAX arrays, so eqx.filter_grad treats them as trainable;
# r_cutoff and r_onset are plain Python floats and stay fixed.
mech_potential = jxm.env.mechanics.MorsePotential(epsilon=np.asarray(3.), alpha=np.asarray(3.), r_cutoff=20., r_onset=15.)


model = jxm.Sequential([

    #cell radius grows linearly in time (growth parameters are Python floats -> fixed)
    jxm.env.CellGrowth(growth_rate=.005, max_radius=2., growth_type='linear'),

    #cells move with brownian dynamics for 10 relaxation steps after each growth step
    jxm.env.mechanics.BrownianMechanicalRelaxation(mech_potential, relaxation_steps=10, kT=1., dt=1e-4)
])

In [7]:
# Fresh subkey for this stochastic simulation run.
key, subkey = jax.random.split(key)

# Gradient w.r.t. the first argument (`model`); only its floating-point array
# leaves are differentiated, all other fields come back unchanged or as None.
g = eqx.filter_grad(avg_dist_loss)(model, istate, run_key=subkey, n_steps=100)
print(g)

Sequential(
  substeps=(
    CellGrowth(
      max_radius=None,
      growth_rate=None,
      growth_type='linear',
      _smoothing_exp=10.0
    ),
    BrownianMechanicalRelaxation(
      mechanical_potential=MorsePotential(
        epsilon=f64[],
        alpha=f64[],
        r_cutoff=20.0,
        r_onset=15.0
      ),
      relaxation_steps=10,
      dt=0.0001,
      kT=1.0,
      gamma=0.8,
      discount=1.0
    )
  ),
  _return_logp=False
)


In [8]:
# None entries are empty subtrees, so only the computed gradients are listed.
jax.tree_util.tree_leaves(g)

[Array(-0.01351579, dtype=float64, weak_type=True),
 Array(0.02199259, dtype=float64, weak_type=True)]

If instead we want to calculate gradients with respect to the maximum cell radius while keeping the cell-cell adhesion strength fixed, we just need to change the types of the corresponding parameters:

In [9]:
### DEFINE SIMULATION STEP

# All Morse parameters are plain Python floats here, so they stay fixed.
mech_potential = jxm.env.mechanics.MorsePotential(epsilon=3., alpha=3., r_cutoff=20., r_onset=15.)


model = jxm.Sequential([

    #cell radius grows linearly in time; max_radius is a JAX array -> trainable
    jxm.env.CellGrowth(growth_rate=.005, max_radius=np.asarray(2.), growth_type='linear'),

    #cells move with brownian dynamics for 10 relaxation steps after each growth step
    jxm.env.mechanics.BrownianMechanicalRelaxation(mech_potential, relaxation_steps=10, kT=1., dt=1e-4)
])

In [10]:
# Fresh subkey for this stochastic simulation run.
key, subkey = jax.random.split(key)

# Gradient w.r.t. the first argument (`model`); only its floating-point array
# leaves are differentiated, all other fields come back unchanged or as None.
g = eqx.filter_grad(avg_dist_loss)(model, istate, run_key=subkey, n_steps=100)
print(g)

Sequential(
  substeps=(
    CellGrowth(
      max_radius=f64[],
      growth_rate=None,
      growth_type='linear',
      _smoothing_exp=10.0
    ),
    BrownianMechanicalRelaxation(
      mechanical_potential=MorsePotential(
        epsilon=None,
        alpha=None,
        r_cutoff=20.0,
        r_onset=15.0
      ),
      relaxation_steps=10,
      dt=0.0001,
      kT=1.0,
      gamma=0.8,
      discount=1.0
    )
  ),
  _return_logp=False
)


In [11]:
# None entries are empty subtrees, so only the computed gradients are listed.
jax.tree_util.tree_leaves(g)

[Array(-0.00052793, dtype=float64, weak_type=True)]

In [10]:
import jax
import jax.numpy as jnp

# Factor by which the custom backward rule scales incoming gradients.
# Defined before the rules that read it (the original defined it after use).
constant = 4.0


@jax.custom_vjp
def f(x):
    """Identity function whose gradient is scaled by `constant`."""
    return x


def f_fwd(x):
    """Forward rule: return the primal output; no residuals are needed."""
    return f(x), None


def f_bwd(res, g):
    """Backward rule: scale the output cotangent `g` by `constant`.

    Returns a 1-tuple because `f` has exactly one primal input; `res` is unused.
    """
    return (g * constant,)


# Register the custom forward/backward rules on `f`.
f.defvjp(f_fwd, f_bwd)

# Example usage
x = jnp.array([1.0, 2.0, 3.0])

y = f(x)
grad = jax.grad(lambda x: jnp.sum(f(x)))(x)
print("Output:", y)
print("Gradient:", grad)


Output: [1. 2. 3.]
Gradient: [4. 4. 4.]


In [15]:
import jax
import jax.numpy as jnp


# Identity in `x` whose custom VJP multiplies the incoming gradient by `c`.
# `c` is passed as an ordinary argument rather than read from a global.
@jax.custom_vjp
def f(x, c):
    return x


def f_fwd(x, c):
    """Forward rule: return the primal output and keep `c` as the only residual."""
    saved = (c,)
    return x, saved


def f_bwd(residuals, g):
    """Backward rule: cotangent for `x` is ``g * c``; ``None`` is returned for `c`."""
    (scale,) = residuals
    grad_x = g * scale
    return (grad_x, None)


f.defvjp(f_fwd, f_bwd)

# Test the function
x = jnp.array([1.0, 2.0, 3.0])
c = 5.0
y = f(x, c)
grad = jax.grad(lambda x: jnp.sum(f(x,c)))(x)
print("Output:", y)
print("Gradient w.r.t x:", grad)


Output: [1. 2. 3.]
Gradient w.r.t x: [5. 5. 5.]
