In [1]:
import jax
import jax.numpy as np
import jax.tree_util as jtu

import numpy as onp

from typing import Any, Callable, Union

import equinox as eqx

import matplotlib.pyplot as plt

In [2]:
# Root PRNG key for the notebook; later cells split it to obtain fresh subkeys.
key = jax.random.PRNGKey(0)

# Cell State

In [3]:
import jax_md

In [4]:
class BaseCellState(eqx.Module):
    '''
    Minimal description of a system of cells.

    Holds the two space functions (static, i.e. not differentiated or traced)
    and one array per cell property. The first axis of every array indexes cells.

    '''

    # METHODS
    displacement:   jax_md.space.DisplacementFn = eqx.field(static=True)
    shift:          jax_md.space.ShiftFn = eqx.field(static=True)

    # STATE
    position:   jax.Array
    celltype:   jax.Array
    radius:     jax.Array
    divrate:    jax.Array


    @classmethod
    def empty(cls, n_dim=2):

        '''
        Initializes a state with no cells (empty data structures, with correct shapes).

        Parameters
        ----------
        n_dim: int
            Number of spatial dimensions (2 or 3).

        Returns
        -------
        An instance of ``cls`` whose arrays all have a leading axis of length 0
        and whose displacement/shift come from ``jax_md.space.free()``.

        '''

        assert n_dim in (2, 3), 'n_dim must be 2 or 3'

        displacement_fn, shift_fn = jax_md.space.free()

        def _no_cells(*trailing_dims):
            # float32 array with zero rows and the requested trailing dimensions
            return np.empty(shape=(0, *trailing_dims), dtype=np.float32)

        return cls(
            displacement=displacement_fn,
            shift=shift_fn,
            position=_no_cells(n_dim),
            celltype=_no_cells(),
            radius=_no_cells(),
            divrate=_no_cells(),
        )


In [44]:
class CellState(BaseCellState):
    '''
    Cell state extended with chemical signalling and a per-cell hidden state.

    Adds four per-cell arrays to the fields inherited from BaseCellState.

    '''

    chemical:       jax.Array
    secretion_rate: jax.Array
    chemical_grad:  jax.Array
    hidden_state:   jax.Array



    @classmethod
    def empty(cls, n_dim=2, n_chem=1, hidden_size=10):

        '''
        Initializes a CellState with no cells (empty data structures, with correct shapes).

        Parameters
        ----------
        n_dim: int
            Number of spatial dimensions, forwarded to BaseCellState.empty.
        n_chem: int
            Number of chemical species.
        hidden_size: int
            Length of the per-cell hidden state vector.

        Returns
        -------
        A ``cls`` instance whose arrays have zero rows; chemical and
        secretion_rate have n_chem columns, chemical_grad has n_dim*n_chem
        columns and hidden_state has hidden_size columns.

        '''

        import dataclasses

        base = BaseCellState.empty(n_dim)

        # Build a fresh dict from the declared dataclass fields instead of
        # grabbing (and then mutating) the live __dict__ of the base instance.
        args = {f.name: getattr(base, f.name) for f in dataclasses.fields(base)}

        args.update({
            'chemical' :   np.empty(shape=(0, n_chem), dtype=np.float32),
            'secretion_rate' :   np.empty(shape=(0, n_chem), dtype=np.float32),
            'chemical_grad' :   np.empty(shape=(0, int(n_dim*n_chem)), dtype=np.float32),
            'hidden_state' :   np.empty(shape=(0, hidden_size), dtype=np.float32),
            })

        return cls(**args)

In [60]:
# Displacement/shift pair returned by jax_md.space.free(), shared by both test states.
disp, shift = jax_md.space.free()

# Sizes used to build the test states below.
N_DIM = 2
N = 5

N_CHEM = 3
N_HIDDEN = 5

# Base state with N slots: every array is zero except slot 0, whose
# celltype, radius and divrate are set to 1.
test_state = BaseCellState(
    displacement=disp,
    shift=shift,
    position=np.zeros(shape=(N,N_DIM)),
    celltype=np.zeros(shape=(N,)).at[0].set(1),
    radius=np.zeros(shape=(N,)).at[0].set(1.),
    divrate=np.zeros(shape=(N,)).at[0].set(1.),
)    


# Same layout with the additional CellState fields. Note that .at[0].set(1.)
# on the 2-D secretion_rate array sets the whole first row to 1.
test_state2 = CellState(
    displacement=disp,
    shift=shift,
    position=np.zeros(shape=(N,N_DIM)),
    celltype=np.zeros(shape=(N,)).at[0].set(1),
    radius=np.zeros(shape=(N,)).at[0].set(1.),
    divrate=np.zeros(shape=(N,)).at[0].set(1.),
    chemical=np.zeros(shape=(N,N_CHEM)),
    secretion_rate=np.zeros(shape=(N,N_CHEM)).at[0].set(1.),
    chemical_grad=np.zeros(shape=(N,int(N_DIM*N_CHEM))),
    hidden_state=np.zeros(shape=(N,N_HIDDEN)),
)

# Cell division

**NOTE**: The Gumbel-Softmax relaxation is unnecessary in this context. Because the copy operation is formulated as a matrix multiplication, autodiff can already differentiate through it: the gradient of the operation is just the copy matrix!

In [46]:
class CellDivisionReparam(eqx.Module):
    '''
    Divides one cell per call, writing the daughter into the next free slot.

    The copy of the mother's properties into the daughter's slot is expressed
    as a matrix product applied to every array leaf of the state, so the
    operation is differentiable by plain autodiff.

    Parameters
    ----------
    birth_radius_multiplier: float
        Factor applied to the mother's radius to obtain the radius of both
        cells after division. Also used as the magnitude of the positional
        offset between the two cells.
    **kwargs
        Accepted and ignored, so that callers can keep passing options that
        are no longer used (e.g. ``inference``).

    '''

    birth_radius_multiplier:    float

    def __init__(self,
                 *,
                 birth_radius_multiplier=float(1/np.sqrt(2)),
                 **kwargs
                 ):

        self.birth_radius_multiplier = birth_radius_multiplier


    def __call__(self, state, *, key=None):

        '''
        Performs one division and returns the updated state.

        Parameters
        ----------
        state:
            State with ``divrate``, ``celltype``, ``radius`` and ``position`` arrays.
        key:
            PRNG key (required). Split into one key for choosing the dividing
            cell and one for the division angle.

        Returns
        -------
        New state of the same type as ``state``.

        Raises
        ------
        ValueError
            If ``key`` is None.

        '''

        if key is None:
            raise ValueError('CellDivisionReparam requires a PRNG key: call it with key=<jax PRNG key>')

        #split key
        subkey_div, subkey_place = jax.random.split(key,2)

        #division probability of each cell, proportional to its divrate
        p = state.divrate/state.divrate.sum()

        idx_dividing_cell = jax.random.choice(subkey_div, a=len(p), p=p)
        new_cell_contribs = np.zeros_like(state.celltype).at[idx_dividing_cell].set(1.)

        #slot of the new cell = number of entries with non-zero celltype
        # NOTE(review): this presumes existing cells occupy the leading slots
        # and that a free slot remains (idx_new_cell < array length) -- confirm
        # against callers.
        idx_new_cell = np.count_nonzero(state.celltype)

        #identity matrix whose row idx_new_cell is the one-hot vector of the
        #dividing cell: multiplying copies the mother's row into the new slot
        division_matrix = np.eye(state.celltype.shape[0]).at[idx_new_cell].set(new_cell_contribs)

        #jax.tree_map is deprecated/removed in recent JAX; use the tree_util version
        new_state = jax.tree_util.tree_map(lambda x: np.dot(division_matrix, x), state)


        #resize cell radii (both cells get mother radius * multiplier)
        get_radius = lambda s: s.radius
        resize_rad = lambda r: r.at[idx_new_cell].set(r[idx_dividing_cell]*self.birth_radius_multiplier).at[idx_dividing_cell].set(r[idx_dividing_cell]*self.birth_radius_multiplier)
        new_state = eqx.tree_at(get_radius, new_state, replace_fn=resize_rad)

        ### POSITION OF NEW CELLS
        #the offset is built from (cos, sin), i.e. it has two components
        angle = jax.random.uniform(subkey_place, minval=0., maxval=2*np.pi)
        cell_displacement = self.birth_radius_multiplier*np.array([np.cos(angle),np.sin(angle)])

        #the two cells are placed symmetrically around the mother's position
        get_position = lambda s: s.position
        new_position = lambda p: p.at[idx_new_cell].set(p[idx_dividing_cell]-cell_displacement).at[idx_dividing_cell].set(p[idx_dividing_cell]+cell_displacement)
        new_state = eqx.tree_at(get_position, new_state, replace_fn=new_position)


        return new_state

In [47]:
def loss(a, state, key):
    '''
    Sum of the cell radii after one division step.

    Parameters
    ----------
    a:
        Value used as birth_radius_multiplier of the division layer.
    state:
        State to evolve.
    key:
        PRNG key, split once per step.

    Returns
    -------
    Tuple (sum of radii of the final state, final state).
    '''
    n_steps = 1

    # inference is passed as before; the constructor absorbs it through **kwargs
    divide = CellDivisionReparam(inference=False, birth_radius_multiplier=a)

    for _ in range(n_steps):
        key, step_key = jax.random.split(key)
        state = divide(state, key=step_key)

    total_radius = state.radius.sum()
    return total_radius, state

In [49]:
# Advance the notebook key and keep a subkey for the stochastic steps.
key, subkey = jax.random.split(key)

# Differentiate loss with respect to its first argument (here 1.);
# has_aux=True makes the second return value (the final state) come back as s.
g, s = jax.grad(loss, has_aux=True)(1., test_state2, subkey)


In [50]:
g

Array(2., dtype=float32, weak_type=True)

# Cell growth

In [14]:
class CellGrowth(eqx.Module):
    '''
    Grows the radius of every existing cell by one step, up to a maximum.

    Parameters
    ----------
    growth_rate: float
        Added to the radius ('linear') or used as exponent of the
        multiplicative factor exp(growth_rate) ('exponential').
    max_radius: float
        Upper bound applied to the grown radius.
    growth_type: str
        Either 'linear' or 'exponential' (static field).
    **kwargs
        Accepted and ignored.

    '''

    max_radius:     float
    growth_rate:    float
    growth_type:    str = eqx.field(static=True)

    def __init__(self,
                 *,
                 growth_rate=1.,
                 max_radius=.5,
                 growth_type='linear',
                 **kwargs
                 ):

        if growth_type not in ('linear', 'exponential'):
            raise ValueError('growth_type must be either "linear" or "exponential"')

        self.growth_type = growth_type
        self.max_radius = max_radius
        self.growth_rate = growth_rate


    def __call__(self, state, *, key=None):

        '''
        Returns ``state`` with its radius array replaced by the grown radii.

        ``key`` is accepted for interface compatibility and is not used.
        '''

        if self.growth_type == 'linear':
            grown = state.radius + self.growth_rate
        elif self.growth_type == 'exponential':
            grown = state.radius*np.exp(self.growth_rate)

        #cap at max_radius
        capped = np.where(grown > self.max_radius, self.max_radius, grown)

        #slots with celltype <= 0 get radius 0
        exists = np.where(state.celltype>0, 1, 0)

        return eqx.tree_at(lambda s: s.radius, state, capped*exists)

In [65]:
# Linear growth layer: +0.01 per call, radius capped at 1.1.
cg = CellGrowth(growth_rate=.01, max_radius=1.1, growth_type='linear')

In [66]:
def loss(a, state, key):
    '''
    Sum of the cell radii after one division step followed by one growth step.

    Parameters
    ----------
    a:
        Value used as birth_radius_multiplier of the division layer.
    state:
        State to evolve.
    key:
        PRNG key, split once per step.

    Returns
    -------
    Tuple (sum of radii of the final state, final state).
    '''
    n_steps = 1

    # inference is passed as before; the constructor absorbs it through **kwargs
    divide = CellDivisionReparam(inference=False, birth_radius_multiplier=a)
    grow = CellGrowth(growth_rate=.01, max_radius=1.1, growth_type='linear')

    for _ in range(n_steps):
        key, step_key = jax.random.split(key)
        state = divide(state, key=step_key)
        state = grow(state)

    total_radius = state.radius.sum()
    return total_radius, state

In [67]:
# Gradient of loss with respect to its first argument (here 1.); s is the final
# state returned as auxiliary output. Reuses the subkey created in an earlier cell.
g, s = jax.grad(loss, has_aux=True)(1., test_state, subkey)

In [68]:
s.radius

Array([1.01, 1.01, 0.  , 0.  , 0.  ], dtype=float32)

In [74]:
# Division followed by growth, chained into a single callable.
# (inference is swallowed by CellDivisionReparam's **kwargs.)
boh = eqx.nn.Sequential(layers=[
    CellDivisionReparam(inference=False, birth_radius_multiplier=.5), 
    CellGrowth(growth_rate=.01, max_radius=1.1, growth_type='linear')
    ])

In [75]:
boh(test_state, key=subkey).radius

Array([0.51, 0.51, 0.  , 0.  , 0.  ], dtype=float32)

# Mechanical

In [138]:
class MinimizePotential(eqx.Module):
    '''
    Base class for layers that relax cell positions under a pair potential.

    Both fields are static:
    relaxation_steps is the number of minimizer steps per call,
    dt is the step size handed to the minimizer.

    '''

    relaxation_steps:   int = eqx.field(default=15, static=True)
    dt:                 float = eqx.field(default=1e-3, static=True)


    def _sgd(self, state, pair_potential):

        '''
        Runs ``relaxation_steps`` steps of jax_md gradient descent on
        ``pair_potential``, starting from ``state.position`` and using
        ``state.shift``; returns the minimizer state after the last step.
        '''

        init_fn, step_fn = jax_md.minimize.gradient_descent(pair_potential, state.shift, self.dt)

        def body(carry, _):
            #the per-step output is a dummy; only the carry matters
            return step_fn(carry), 0.

        #relax system
        relaxed, _ = jax.lax.scan(body, init_fn(state.position), None, length=self.relaxation_steps)

        return relaxed


    def _brownian(self, state, pair_potential):

        '''Placeholder for Brownian dynamics; always raises NotImplementedError.'''

        raise NotImplementedError('Brownian dynamics not implemented yet')

In [106]:
class MinimizeMorse(MinimizePotential):
    '''
    Relaxes cell positions under a Morse pair potential.

    epsilon and alpha are regular (differentiable) fields; r_cutoff and r_onset
    are static. Only a single scalar epsilon shared by all cells is supported.

    '''

    epsilon:   Union[float, jax.Array] = 3.
    alpha:     float = 2.8
    r_cutoff:  float = eqx.field(default=2., static=True)
    r_onset:   float = eqx.field(default=1.7, static=True)



    def _calculate_epsilon(self, state):

        '''
        Builds the (n_cells, n_cells) matrix of well depths.

        Entry (i, j) is epsilon when i != j and both cells have celltype > 0,
        and 0 otherwise.

        Raises
        ------
        NotImplementedError
            If epsilon holds more than one value (per-celltype epsilon).
        '''

        if np.atleast_1d(self.epsilon).size != 1:
            # Any non-scalar epsilon ends up here. The previous check relied on
            # jax.interpreters.xla.DeviceArray, which no longer exists in
            # current JAX, and other inputs left epsilon_matrix unbound.
            ### implement general logic for multiple cell types
            raise NotImplementedError('Multiple cell types not implemented yet')

        alive = np.where(state.celltype > 0, 1, 0)

        #mask the diagonal multiplicatively: subtracting the identity instead
        #would leave -epsilon on the diagonal entries of non-existent cells
        off_diagonal = 1 - np.eye(alive.shape[0])
        epsilon_matrix = np.outer(alive, alive)*off_diagonal*self.epsilon

        return epsilon_matrix



    def __call__(self, state, *, key=None):

        '''
        Returns ``state`` with ``position`` replaced by the relaxed positions.

        ``key`` is accepted for interface compatibility and is not used.
        '''

        #calculate sigma matrix (sum of the two radii for every pair)
        sigma_matrix = state.radius[:,None] + state.radius[None,:]

        #calculate epsilon matrix
        epsilon_matrix = self._calculate_epsilon(state)

        #generate morse pair potential
        morse_energy = jax_md.energy.morse_pair(state.displacement,
                                                alpha=self.alpha,
                                                epsilon=epsilon_matrix,
                                                sigma=sigma_matrix, 
                                                r_onset=self.r_onset, 
                                                r_cutoff=self.r_cutoff
                                                )

        #minimize
        new_positions = self._sgd(state, morse_energy)

        state = eqx.tree_at(lambda s: s.position, state, new_positions)

        return state

In [107]:
MinimizeMorse()

MinimizeMorse(
  relaxation_steps=15,
  dt=0.001,
  epsilon=3.0,
  alpha=2.8,
  r_cutoff=2.0,
  r_onset=1.7
)

In [108]:
# Division, growth and Morse relaxation chained into a single callable.
# (inference is swallowed by CellDivisionReparam's **kwargs.)
boh1 = eqx.nn.Sequential(layers=[
    CellDivisionReparam(inference=False, birth_radius_multiplier=.5), 
    CellGrowth(growth_rate=.01, max_radius=1.1, growth_type='linear'),
    MinimizeMorse(epsilon=.5)
    ])

In [25]:
def loss(a, state, key):
    '''
    Squared norm of the position of the cell in slot 1 after four
    division / growth / relaxation steps.

    Parameters
    ----------
    a:
        Value used as birth_radius_multiplier of the division layer.
    state:
        State to evolve.
    key:
        PRNG key, split once per step.

    Returns
    -------
    Tuple (sum of squared coordinates of state.position[1], final state).
    '''
    n_steps = 4

    # inference is passed as before; the constructor absorbs it through **kwargs
    pipeline = eqx.nn.Sequential(layers=[
        CellDivisionReparam(inference=True, birth_radius_multiplier=a),
        CellGrowth(growth_rate=.03, max_radius=1.1, growth_type='linear'),
        MinimizeMorse(epsilon=3.),
        ])

    for _ in range(n_steps):
        key, step_key = jax.random.split(key)
        state = pipeline(state, key=step_key)

    second_cell = state.position[1]
    return (second_cell**2).sum(), state

In [26]:
# Gradient of loss with respect to its first argument (here .8); s is the final
# state returned as auxiliary output. Reuses the subkey created in an earlier cell.
g, s = jax.grad(loss, has_aux=True)(.8, test_state, subkey)

In [27]:
g

Array(19.080671, dtype=float32, weak_type=True)

In [28]:
s.position

Array([[ 1.4130108 ,  0.41121322],
       [ 0.8397293 , -0.9832696 ],
       [-1.2640176 ,  0.59299135],
       [-2.3082333 ,  1.2221705 ],
       [-0.02124166,  0.74504286]], dtype=float32)

In [33]:
s.radius

Array([0.60919994, 0.9199999 , 0.6151999 , 0.6151999 , 0.60919994],      dtype=float32)

# Diffusion

In [144]:
class SteadyStateDiffusion(eqx.Module):
    '''
    Computes the steady-state chemical concentrations on the graph of cells.

    For every chemical the linear system (D*L + K*I) c = P is solved, where L
    is a graph Laplacian with edge weights 1/distance**2 between existing
    cells, D is diffusion_coeff, K is degradation_rate and P the corresponding
    column of state.secretion_rate.

    Parameters
    ----------
    diffusion_coeff: float or array
        A single value shared by all chemicals, or one value per chemical.
    degradation_rate: float or array
        A single value shared by all chemicals, or one value per chemical.
    **kwargs
        Accepted and ignored.

    '''

    diffusion_coeff:    Union[float, jax.Array]
    degradation_rate:   Union[float, jax.Array]
    _vmap_diff_inaxes:  tuple = eqx.field(static=True)

    def __init__(self, diffusion_coeff=2., degradation_rate=1., **kwargs):

        self.diffusion_coeff = diffusion_coeff
        self.degradation_rate = degradation_rate

        #coefficients holding more than one value are mapped over their first
        #axis together with the chemicals; single values are broadcast
        _inaxes_diffcoef = 0 if np.atleast_1d(self.diffusion_coeff).size > 1 else None
        _inaxes_degrate = 0 if np.atleast_1d(self.degradation_rate).size > 1 else None
        self._vmap_diff_inaxes = (1, _inaxes_diffcoef, _inaxes_degrate, None)


    def __call__(self, state, *, key=None):

        '''
        Returns ``state`` with ``chemical`` replaced by the steady-state field.

        ``key`` is accepted for interface compatibility and is not used.
        '''

        #calculate all pairwise distances
        dist = jax_md.space.map_product(jax_md.space.metric(state.displacement))(state.position, state.position)

        #an edge exists between two different cells whose celltype product is positive
        both_exist = np.outer(state.celltype, state.celltype) > 0
        not_self = np.logical_not(np.eye(dist.shape[0], dtype=bool))
        connected = np.logical_and(both_exist, not_self)

        #adjacency matrix
        # Mask BEFORE dividing: entries without an edge get a dummy distance of
        # 1 so that 1/dist is never evaluated on the zero distances of
        # self-pairs or coincident non-existent cells. The resulting weights are
        # the same as dividing first and masking afterwards, but no inf is
        # created that could turn into NaN in the backward pass.
        safe_dist = np.where(connected, dist, 1.)
        A = np.where(connected, 1/safe_dist, 0.)**2


        #calculate graph laplacian
        L = np.diag(np.sum(A, axis=0)) - A


        def _solve_single(P, D, K, L):

            #update laplacian with degradation
            system = D*L + K*np.eye(L.shape[0])

            #solve for steady state
            return np.linalg.solve(system, P)

        #calculate steady state chemical field, one solve per chemical (column)
        solve_all = jax.vmap(_solve_single, in_axes=self._vmap_diff_inaxes, out_axes=1)

        new_chem = solve_all(state.secretion_rate, self.diffusion_coeff, self.degradation_rate, L)

        #update chemical field
        state = eqx.tree_at(lambda s: s.chemical, state, new_chem)

        return state

In [96]:
# Apply diffusion with default coefficients to the chemical test state.
s = SteadyStateDiffusion()(test_state2)

In [139]:
# Full pipeline: division, growth, Morse relaxation, steady-state diffusion.
# The parameters to be differentiated are passed as JAX arrays (np.asarray).
model = eqx.nn.Sequential(layers=[
    CellDivisionReparam(inference=True, birth_radius_multiplier=np.asarray(.5)), 
    CellGrowth(growth_rate=.03, max_radius=1.1, growth_type='linear'),
    MinimizeMorse(epsilon=np.asarray(3.)),
    SteadyStateDiffusion(degradation_rate=np.asarray(2.))
    ])

def loss(model, state, key):
    '''
    Sum of squared concentrations of chemical 1 after three model steps.

    Parameters
    ----------
    model:
        Callable applied as ``model(state, key=...)`` at every step.
    state:
        State to evolve.
    key:
        PRNG key, split once per step.

    Returns
    -------
    Tuple (sum of state.chemical[:, 1]**2 for the final state, final state).
    '''
    n_steps = 3

    for _ in range(n_steps):
        key, step_key = jax.random.split(key)
        state = model(state, key=step_key)

    second_chemical = state.chemical[:,1]
    return (second_chemical**2).sum(), state

In [140]:
# Gradient of loss with respect to model (its first argument); s is the final
# state returned as auxiliary output. Reuses the subkey created in an earlier cell.
g, s = eqx.filter_grad(loss, has_aux=True)(model, test_state2, subkey)

In [141]:
g.layers[2].epsilon

Array(0., dtype=float32, weak_type=True)

In [142]:
g.layers[0].birth_radius_multiplier

Array(0., dtype=float32, weak_type=True)

In [143]:
g.layers[-1].degradation_rate

Array(-0.9999995, dtype=float32, weak_type=True)

In [91]:
np.atleast_1d(1).size

1