In [1]:
import jax
import jax.numpy as np
import jax.tree_util as jtu

import numpy as onp

from typing import Any, Callable, Union

import equinox as eqx

import matplotlib.pyplot as plt

In [2]:
key = jax.random.PRNGKey(0)

# Cell State

In [3]:
import jax_md

In [23]:
class BaseCellState(eqx.Module):
    '''
    Module containing the basic features of a system state.

    '''

    # METHODS
    displacement:   jax_md.space.DisplacementFn = eqx.field(static=True)
    shift:          jax_md.space.ShiftFn = eqx.field(static=True)

    # STATE
    position:   jax.Array
    celltype:   jax.Array
    radius:     jax.Array
    division:   jax.Array
    #key:       jax.Array


    @classmethod
    def empty(cls, n_dim=2):

        '''
        Intializes a CellState with no cells (empty data structures, with correct shapes).

        Parameters
        ----------
        n_dim: int
            Number of spatial dimensions.

        '''

        assert n_dim == 2 or n_dim == 3, 'n_dim must be 2 or 3'

        disp, shift = jax_md.space.free()
        

        args = {
            'displacement'  :   disp,
            'shift'         :   shift,
            'position'  :   np.empty(shape=(0, n_dim), dtype=np.float32),
            'celltype'  :   np.empty(shape=(0,), dtype=np.float32),
            'radius'    :   np.empty(shape=(0,), dtype=np.float32),
            'division'   :   np.empty(shape=(0,), dtype=np.float32),
            #'key'       :   None,
            }
        
        return cls(**args)


In [24]:
class CellState(BaseCellState):
    chemical:           jax.Array
    secretion_rate:     jax.Array
    chemical_grad:      jax.Array
    hidden_state:       jax.Array
    mechanical_stress:  jax.Array



    @classmethod
    def empty(cls, n_dim=2, n_chem=1, hidden_size=10):

        a = BaseCellState.empty(n_dim).__dict__

        new_args = {
            'chemical' :   np.empty(shape=(0, n_chem), dtype=np.float32),
            'secretion_rate' :   np.empty(shape=(0, n_chem), dtype=np.float32),
            'chemical_grad' :   np.empty(shape=(0, int(n_dim*n_chem)), dtype=np.float32),
            'hidden_state' :   np.empty(shape=(0, hidden_size), dtype=np.float32),
            'mechanical_stress' :  np.empty(shape=(0,), dtype=np.float32),
            }
        
        a.update(new_args)

        return cls(**a)

In [25]:
disp, shift = jax_md.space.free()

N_DIM = 2
N = 5

N_CHEM = 3
N_HIDDEN = 5

test_state = BaseCellState(
    displacement=disp,
    shift=shift,
    position=np.zeros(shape=(N,N_DIM)),
    celltype=np.zeros(shape=(N,)).at[0].set(1),
    radius=np.zeros(shape=(N,)).at[0].set(1.),
    division=np.zeros(shape=(N,)).at[0].set(1.),
)    


test_state2 = CellState(
    displacement=disp,
    shift=shift,
    position=np.zeros(shape=(N,N_DIM)),
    celltype=np.zeros(shape=(N,)).at[0].set(1),
    radius=np.zeros(shape=(N,)).at[0].set(1.),
    division=np.zeros(shape=(N,)).at[0].set(1.),
    chemical=np.zeros(shape=(N,N_CHEM)),
    secretion_rate=np.zeros(shape=(N,N_CHEM)).at[0].set(1.),
    chemical_grad=np.zeros(shape=(N,int(N_DIM*N_CHEM))),
    hidden_state=np.zeros(shape=(N,N_HIDDEN)),
    mechanical_stress=np.zeros(shape=(N,)),
)

# Environment

In [38]:
import abc

class SimulationStep(eqx.Module):

    @abc.abstractmethod
    def __call__(self, state, *, key=None):
        pass

## Cell division

**NOTE**: The whole Gumbel-Softmax is useless in this context. By formulating copy operations as multiplications autodiff is already able to deal with this situation. The gradient of the operation is just the copy matrix!

In [39]:
class CellDivisionReparam(SimulationStep):
    birth_radius_multiplier:    float
    #inference:                  bool = eqx.field(static=True)

    def __init__(self, #state,
                 *,
                 birth_radius_multiplier=float(1/np.sqrt(2)),
                 #inference=True,
                 **kwargs
                 ):

        # if not hasattr(state, 'divrate'):
        #     raise AttributeError('CellState must have "divrate" attribute')
        # if not hasattr(state, 'radius'):
        #     raise AttributeError('CellState must have "radius" attribute')
        # if not hasattr(state, 'celltype'):
        #     raise AttributeError('CellState must have "celltype" attribute')
        # if not hasattr(state, 'position'):
        #     raise AttributeError('CellState must have "position" attribute')
        # if not hasattr(state, 'key'):
        #     raise AttributeError('CellState must have valid "key" attribute')
        
        self.birth_radius_multiplier = birth_radius_multiplier
        #self.inference = inference
        

    def __call__(self, state, *, key=None):#, inference=None, softmax_T=1.):

        # if inference is None:
        #     inference = self.inference


        #split key
        subkey_div, subkey_place = jax.random.split(key,2)
        
        p = state.division/state.division.sum()

        # safe_p = np.where(p > 0, p, 1)
        # logp = np.where(p > 0, np.log(safe_p), -np.inf)

        # logit = (logp + jax.random.gumbel(subkey_div, shape=logp.shape))/softmax_T
        # idx_dividing_cell = np.argmax(logit)

        idx_dividing_cell = jax.random.choice(subkey_div, a=len(p), p=p)
        new_cell_contribs = np.zeros_like(state.celltype).at[idx_dividing_cell].set(1.)

        # if inference:
        #     new_cell_contribs = np.zeros_like(state.celltype).at[idx_dividing_cell].set(1.)
        #     print(new_cell_contribs)
        # else:
        #     new_cell_contribs = jax.nn.softmax(logit)
        #     print(new_cell_contribs)

        
        idx_new_cell = np.count_nonzero(state.celltype)

        division_matrix = np.eye(state.celltype.shape[0]).at[idx_new_cell].set(new_cell_contribs)

        new_state = jax.tree_map(lambda x: np.dot(division_matrix, x), state)


        #resize cell radii
        get_radius = lambda s: s.radius
        resize_rad = lambda r: r.at[idx_new_cell].set(r[idx_dividing_cell]*self.birth_radius_multiplier).at[idx_dividing_cell].set(r[idx_dividing_cell]*self.birth_radius_multiplier)
        new_state = eqx.tree_at(get_radius, new_state, replace_fn=resize_rad)
        
        ### POSITION OF NEW CELLS
        angle = jax.random.uniform(subkey_place, minval=0., maxval=2*np.pi)
        cell_displacement = self.birth_radius_multiplier*np.array([np.cos(angle),np.sin(angle)])

        get_position = lambda s: s.position
        new_position = lambda p: p.at[idx_new_cell].set(p[idx_dividing_cell]-cell_displacement).at[idx_dividing_cell].set(p[idx_dividing_cell]+cell_displacement)
        new_state = eqx.tree_at(get_position, new_state, replace_fn=new_position)
        

        return new_state

## Cell growth

In [40]:
class CellGrowth(SimulationStep):
    max_radius:     float
    growth_rate:    float
    growth_type:    str = eqx.field(static=True)
    
    def __init__(self, #state, 
                 *, 
                 growth_rate=1., 
                 max_radius=.5, 
                 growth_type='linear',
                 **kwargs
                 ):

        # if not hasattr(state, 'radius'):
        #     raise AttributeError('CellState must have "radius" attribute')
        
        if growth_type not in ['linear', 'exponential']:
            raise ValueError('growth_type must be either "linear" or "exponential"')
        
        self.growth_rate = growth_rate
        self.max_radius = max_radius
        self.growth_type = growth_type


    def __call__(self, state, *, key=None):

        if self.growth_type == 'linear':
            new_radius = state.radius + self.growth_rate
        elif self.growth_type == 'exponential':
            new_radius = state.radius*np.exp(self.growth_rate)

        new_radius = np.where(new_radius > self.max_radius, self.max_radius, new_radius)*np.where(state.celltype>0, 1, 0)

        get_radius = lambda s: s.radius
        state = eqx.tree_at(get_radius, state, new_radius)

        return state

## Mechanical

In [41]:
class SGDMechanicalRelaxation(SimulationStep):
    mechanical_potential:   eqx.Module
    relaxation_steps:       int = eqx.field(default=15, static=True)
    dt:                     float = eqx.field(default=1e-3, static=True)


    def _sgd(self, state, pair_potential):

        init, apply = jax_md.minimize.gradient_descent(pair_potential, state.shift, self.dt) 
 
        def scan_fn(opt_state, i):
            return apply(opt_state), 0.

        #relax system
        opt_state = init(state.position)
        opt_state, _ = jax.lax.scan(scan_fn, opt_state, np.arange(self.relaxation_steps))

        return opt_state
    

    def __call__(self, state, *, key=None):

        #generate morse pair potential
        energy_fn = self.mechanical_potential.energy_fn(state)
        
        #minimize
        new_positions = self._sgd(state, energy_fn)

        state = eqx.tree_at(lambda s: s.position, state, new_positions)

        return state

In [45]:
#Define Potential ABC
class MechanicalInteractionPotential(eqx.Module):

    @abc.abstractmethod
    def energy_fn(self, state, *, per_particle):
        pass


class MorsePotential(MechanicalInteractionPotential):
    epsilon:   Union[float, jax.Array] = 3.
    alpha:     float = 2.8
    r_cutoff:  float = eqx.field(default=2., static=True)
    r_onset:   float = eqx.field(default=1.7, static=True)


    def _calculate_epsilon_matrix(self, state):

        if np.atleast_1d(self.epsilon).size == 1:
            alive = np.where(state.celltype > 0, 1, 0)
            epsilon_matrix = (np.outer(alive, alive)-np.eye(alive.shape[0]))*self.epsilon


        elif isinstance(self.epsilon, jax.interpreters.xla.DeviceArray):
            
            ### implement general logic for multiple cell types
            raise NotImplementedError('Multiple cell types not implemented yet')


        return epsilon_matrix
    

    def _calculate_sigma_matrix(self, state):

        sigma_matrix = state.radius[:,None] + state.radius[None,:]

        return sigma_matrix
    

    def energy_fn(self, state, *, per_particle=False):

        epsilon_matrix = self._calculate_epsilon_matrix(state)
        sigma_matrix = self._calculate_sigma_matrix(state)

        #generate morse pair potential
        morse_energy = jax_md.energy.morse_pair(state.displacement,
                                                alpha=self.alpha,
                                                epsilon=epsilon_matrix,
                                                sigma=sigma_matrix, 
                                                r_onset=self.r_onset, 
                                                r_cutoff=self.r_cutoff,
                                                per_particle=per_particle
                                                )
        
        return morse_energy


## Diffusion

In [13]:
class SteadyStateDiffusion(SimulationStep):
    diffusion_coeff:    Union[float, jax.Array]
    degradation_rate:   Union[float, jax.Array]
    _vmap_diff_inaxes:  tuple = eqx.field(static=True)

    def __init__(self, diffusion_coeff=2., degradation_rate=1., **kwargs):

        self.diffusion_coeff = diffusion_coeff
        self.degradation_rate = degradation_rate

        _inaxes_diffcoef = 0 if np.atleast_1d(self.diffusion_coeff).size > 1 else None
        _inaxes_degrate = 0 if np.atleast_1d(self.degradation_rate).size > 1 else None
        self._vmap_diff_inaxes = (1, _inaxes_diffcoef, _inaxes_degrate, None)


    def __call__(self, state, *, key=None):

        #calculate all pairwise distances
        dist = jax_md.space.map_product(jax_md.space.metric(state.displacement))(state.position, state.position)

        #prevent division by zero
        dist *= np.where(np.outer(state.celltype, state.celltype)>0, 1, -1)
        dist -= np.eye(dist.shape[0])

        #adjacency matrix
        # zero out connections to inexistent cells
        A = 1/dist
        A = (np.where(A>0, A, 0))**2


        #calculate graph laplacian
        L = np.diag(np.sum(A, axis=0)) - A
    

        def _ss_chemfield(P, D, K, L):

            #update laplacian with degradation
            L = D*L + K*np.eye(L.shape[0])

            #solve for steady state
            c = np.linalg.solve(L, P)

            return c
        
        #calculate steady state chemical field
        _ss_chemfield = jax.vmap(_ss_chemfield, in_axes=self._vmap_diff_inaxes, out_axes=1)

        new_chem = _ss_chemfield(state.secretion_rate, self.diffusion_coeff, self.degradation_rate, L)

        #update chemical field
        state = eqx.tree_at(lambda s: s.chemical, state, new_chem)

        return state

# Cell internals

## Chemical gradients

In [14]:
class LocalChemicalGradients(SimulationStep):
    neighbor_radius:    Union[float,None] = eqx.field(default=None, static=True)

    def __call__(self, state, *, key=None):

        # mask only cells that exist
        c_alive = state.celltype>0.

        # displacements between cell pairs (ncells x ncells x ndim_space)
        disp = jax.vmap(jax.vmap(state.displacement, in_axes=[0,None]), in_axes=[None,0])(state.position, state.position)

        # distances btw cell pairs
        # dist w/ non-existing cells are zeroed out
        dist = np.sqrt((disp**2).sum(2))*np.outer(c_alive, c_alive)

        # consider as neigbors:
        # - cells less than one radius away (+ small tolerance)
        # - cells differents from themselves
        # - cells that exist
        if None == self.neighbor_radius:
            # "touching" distance betw. cells
            R = (state.radius+state.radius[:,None])*np.outer(c_alive, c_alive)
        else:
            R = (self.neighbor_radius)*np.outer(c_alive, c_alive)
        
        neig = (dist<R)*(dist>0.)

        # normalize all displacements
        norm_disp = (disp*neig[:,:,None])/(dist[:,:,None]+1e-8)

        # calculates x and y components of grad for single chemical
        def _grad_chem(chem):
            return (norm_disp*chem.ravel()[:,None]).sum(1)
            

        #vectorize over chemicals
        #OUTPUT SHAPE: ncells x ndim x nchem
        _grad_chem = jax.vmap(_grad_chem, in_axes=1, out_axes=2)

        #calc grads (no non-existing cells or lone cells w/ no neighbors)
        chemgrads = _grad_chem(state.chemical)
            
        # transform into ncells x (grad_x + grad_y)
        #reshape like ncells x ndim x nchem to revert
        chemgrads = chemgrads.reshape(len(state.celltype), -1)

        #update state
        state = eqx.tree_at(lambda s: s.chemical_grad, state, chemgrads)

        return state

## Mechanical stress

In [15]:
class LocalMechanicalStress(SimulationStep):
    mechanical_potential:       eqx.Module
    
    def __call__(self, state, *, key=None):

        #generate pair potential
        pair_potential = self.mechanical_potential.energy_fn(state)
        
        forces = jax.jacrev(pair_potential)(state.position)
        
        # F_ij = force on i by j, r_ij = displacement from i to j
        drs = jax_md.space.map_product(state.displacement)(state.position, state.position)
        
        stresses = np.sum(np.multiply(forces, np.sign(drs)), axis=(0, 2))
        stresses = np.where(state.celltype > 0, stresses, 0.0)

        state = eqx.tree_at(lambda s: s.mechanical_stress, state, stresses)

        return state

## Division Rates

In [None]:
class DivisionMLP(SimulationStep):
    mlp:    eqx.nn.MLP


    

# Tests

In [49]:
key, subkey = jax.random.split(key)

In [94]:
# class RelaxAndStress(eqx.Module):
#     shared:     eqx.Module

#     def __init__(self, relax, stress, **kwargs):
        
#         where = lambda x: x[1].mechanical_potential
#         get = lambda x: x[0].mechanical_potential

#         self.shared = eqx.nn.Shared((relax, stress), where, get)


#     def __call__(self, state, *, key=None):

#         relax, stress = self.shared()

#         state = relax(state, key=key)
#         state = stress(state, key=key)

#         return state

In [95]:
mech_potential = MorsePotential(epsilon=np.asarray(3.), alpha=2.8)



model = eqx.nn.Sequential(
    layers=[
        CellDivisionReparam(birth_radius_multiplier=np.asarray(.5)), 
        CellGrowth(growth_rate=np.asarray(.03), max_radius=1.1, growth_type='linear'),
        SGDMechanicalRelaxation(mech_potential),
        SteadyStateDiffusion(degradation_rate=np.asarray(2.), diffusion_coeff=.5),
        LocalChemicalGradients(),
        # RelaxAndStress(
            relax=SGDMechanicalRelaxation(mech_potential),
            stress=LocalMechanicalStress(mech_potential)
        )
        # LocalMechanicalStress(mech_potential)
    ])

def loss(model, state, key):
    
    for i in range(3):
        key, subkey = jax.random.split(key)
        state = model(state, key=subkey)

    return (state.position[:,1]**2).sum(), state

In [96]:
_, s = loss(model, test_state2, subkey)

In [52]:
s.mechanical_stress

Array([21.791218 , 14.1632595, 32.398632 , 14.496561 ,  0.       ],      dtype=float32)

In [53]:
s

CellState(
  displacement=<function displacement_fn>,
  shift=<function shift_fn>,
  position=f32[5,2],
  celltype=f32[5],
  radius=f32[5],
  division=f32[5],
  chemical=f32[5,3],
  secretion_rate=f32[5,3],
  chemical_grad=f32[5,6],
  hidden_state=f32[5,5],
  mechanical_stress=f32[5]
)

In [97]:
g, s = eqx.filter_grad(loss, has_aux=True)(model, test_state2, subkey)

In [99]:
g

Sequential(
  layers=(
    CellDivisionReparam(birth_radius_multiplier=f32[]),
    CellGrowth(max_radius=None, growth_rate=f32[], growth_type='linear'),
    SteadyStateDiffusion(
      diffusion_coeff=None,
      degradation_rate=f32[],
      _vmap_diff_inaxes=(1, None, None, None)
    ),
    LocalChemicalGradients(neighbor_radius=None),
    RelaxAndStress(
      shared=Shared(
        pytree=(
          SGDMechanicalRelaxation(
            mechanical_potential=MorsePotential(
              epsilon=f32[],
              alpha=None,
              r_cutoff=2.0,
              r_onset=1.7
            ),
            relaxation_steps=15,
            dt=0.001
          ),
          LocalMechanicalStress(mechanical_potential=None)
        ),
        where=None,
        get=None
      )
    )
  )
)

In [102]:
g.layers[-1].shared.pytree[0].mechanical_potential.epsilon

Array(-0.1339076, dtype=float32, weak_type=True)

In [103]:
g.layers[-1].shared.pytree[1].mechanical_potential.epsilon

AttributeError: 'NoneType' object has no attribute 'epsilon'

In [None]:
g.layers[-1].mechanical_potential.epsilon

Array(0., dtype=float32, weak_type=True)

In [90]:
model_shared = eqx.nn.Shared(model, lambda m: m.layers[-1].mechanical_potential.epsilon, lambda m: m.layers[2].mechanical_potential.epsilon)

In [104]:
g, s = eqx.filter_grad(loss, has_aux=True)(model_shared(), test_state2, subkey)

In [106]:
g.layers[2].mechanical_potential.epsilon

Array(-0.0153906, dtype=float32, weak_type=True)

In [93]:
assert model_shared().layers[-1].mechanical_potential.epsilon is model_shared().layers[2].mechanical_potential.epsilon