In [40]:
import jax
import jax.numpy as np
import jax.tree_util as jtu

import numpy as onp

from typing import Any, Callable, Union

import equinox as eqx

import matplotlib.pyplot as plt

In [2]:
key = jax.random.PRNGKey(0)

# Cell State

In [3]:
import jax_md

In [233]:
class BaseCellState(eqx.Module):
    '''
    Module containing the basic features of a system state.

    '''

    # METHODS
    displacement:   jax_md.space.DisplacementFn = eqx.field(static=True)
    shift:          jax_md.space.ShiftFn = eqx.field(static=True)

    # STATE
    position:   jax.Array
    celltype:   jax.Array
    radius:     jax.Array
    divrate:    jax.Array
    #key:        jax.Array


    @classmethod
    def empty(cls, n_dim=2):

        '''
        Intializes a CellState with no cells (empty data structures, with correct shapes).

        Parameters
        ----------
        n_dim: int
            Number of spatial dimensions.

        '''

        assert n_dim == 2 or n_dim == 3, 'n_dim must be 2 or 3'

        disp, shift = jax_md.space.free()
        

        args = {
            'displacement'  :   disp,
            'shift'         :   shift,
            'position'  :   np.empty(shape=(0, n_dim), dtype=np.float32),
            'celltype'  :   np.empty(shape=(0,), dtype=np.float32),
            'radius'    :   np.empty(shape=(0,), dtype=np.float32),
            'divrate'   :   np.empty(shape=(0,), dtype=np.float32),
            #'key'       :   None,
            }
        
        return cls(**args)


In [234]:
class CellState(BaseCellState):
    chemical:       jax.Array
    chemgrad:       jax.Array
    hidden_state:   jax.Array


    @classmethod
    def empty(cls, n_dim=2, n_chem=1, hidden_size=10):

        a = BaseCellState.empty(n_dim).__dict__

        new_args = {
            'chemical' :   np.empty(shape=(0, n_chem), dtype=np.float32),
            'chemgrad' :   np.empty(shape=(0, int(n_dim*n_chem)), dtype=np.float32),
            'hidden_state' :   np.empty(shape=(0, hidden_size), dtype=np.float32),
            }
        
        a.update(new_args)

        return cls(**a)

In [235]:
disp, shift = jax_md.space.free()

N = 5

test_state = BaseCellState(
    displacement=disp,
    shift=shift,
    position=np.zeros(shape=(N,2)),
    celltype=np.zeros(shape=(N,)).at[0].set(1),
    radius=np.zeros(shape=(N,)).at[0].set(1.),
    divrate=np.zeros(shape=(N,)).at[0].set(1.),
)    

# Cell division

**NOTE**: The whole Gumbel-Softmax is useless in this context. By formulating copy operations as multiplications autodiff is already able to deal with this situation. The gradient of the operation is just the copy matrix!

In [258]:
class CellDivisionReparam(eqx.Module):
    birth_radius_multiplier:    float
    #inference:                  bool = eqx.field(static=True)

    def __init__(self, #state,
                 *,
                 birth_radius_multiplier=float(1/np.sqrt(2)),
                 #inference=True,
                 **kwargs
                 ):

        # if not hasattr(state, 'divrate'):
        #     raise AttributeError('CellState must have "divrate" attribute')
        # if not hasattr(state, 'radius'):
        #     raise AttributeError('CellState must have "radius" attribute')
        # if not hasattr(state, 'celltype'):
        #     raise AttributeError('CellState must have "celltype" attribute')
        # if not hasattr(state, 'position'):
        #     raise AttributeError('CellState must have "position" attribute')
        # if not hasattr(state, 'key'):
        #     raise AttributeError('CellState must have valid "key" attribute')
        
        self.birth_radius_multiplier = birth_radius_multiplier
        #self.inference = inference
        

    def __call__(self, state, *, key=None):#, inference=None, softmax_T=1.):

        # if inference is None:
        #     inference = self.inference


        #split key
        subkey_div, subkey_place = jax.random.split(key,2)
        
        p = state.divrate/state.divrate.sum()

        # safe_p = np.where(p > 0, p, 1)
        # logp = np.where(p > 0, np.log(safe_p), -np.inf)

        # logit = (logp + jax.random.gumbel(subkey_div, shape=logp.shape))/softmax_T
        # idx_dividing_cell = np.argmax(logit)

        idx_dividing_cell = jax.random.choice(subkey_div, a=len(p), p=p)
        new_cell_contribs = np.zeros_like(state.celltype).at[idx_dividing_cell].set(1.)

        # if inference:
        #     new_cell_contribs = np.zeros_like(state.celltype).at[idx_dividing_cell].set(1.)
        #     print(new_cell_contribs)
        # else:
        #     new_cell_contribs = jax.nn.softmax(logit)
        #     print(new_cell_contribs)

        
        idx_new_cell = np.count_nonzero(state.celltype)

        division_matrix = np.eye(state.celltype.shape[0]).at[idx_new_cell].set(new_cell_contribs)

        new_state = jax.tree_map(lambda x: np.dot(division_matrix, x), state)


        #resize cell radii
        get_radius = lambda s: s.radius
        resize_rad = lambda r: r.at[idx_new_cell].set(r[idx_dividing_cell]*self.birth_radius_multiplier).at[idx_dividing_cell].set(r[idx_dividing_cell]*self.birth_radius_multiplier)
        new_state = eqx.tree_at(get_radius, new_state, replace_fn=resize_rad)
        
        ### POSITION OF NEW CELLS
        angle = jax.random.uniform(subkey_place, minval=0., maxval=2*np.pi)
        cell_displacement = self.birth_radius_multiplier*np.array([np.cos(angle),np.sin(angle)])

        get_position = lambda s: s.position
        new_position = lambda p: p.at[idx_new_cell].set(p[idx_dividing_cell]-cell_displacement).at[idx_dividing_cell].set(p[idx_dividing_cell]+cell_displacement)
        new_state = eqx.tree_at(get_position, new_state, replace_fn=new_position)
        

        return new_state

In [106]:
def loss(a, state, key):
    div = CellDivisionReparam(inference=False, birth_radius_multiplier=a)

    for i in range(1):
        key, subkey = jax.random.split(key)
        state = div(state, key=subkey, softmax_T=1.)

    return state.radius.sum(), state

In [107]:
key, subkey = jax.random.split(key)

g, s = jax.grad(loss, has_aux=True)(1., test_state, subkey)


In [108]:
g

Array(2., dtype=float32, weak_type=True)

# Cell growth

In [64]:
class CellGrowth(eqx.Module):
    max_radius:     float
    growth_rate:    float
    growth_type:    str = eqx.field(static=True)
    
    def __init__(self, #state, 
                 *, 
                 growth_rate=1., 
                 max_radius=.5, 
                 growth_type='linear',
                 **kwargs
                 ):

        # if not hasattr(state, 'radius'):
        #     raise AttributeError('CellState must have "radius" attribute')
        
        if growth_type not in ['linear', 'exponential']:
            raise ValueError('growth_type must be either "linear" or "exponential"')
        
        self.growth_rate = growth_rate
        self.max_radius = max_radius
        self.growth_type = growth_type


    def __call__(self, state, *, key=None):

        if self.growth_type == 'linear':
            new_radius = state.radius + self.growth_rate
        elif self.growth_type == 'exponential':
            new_radius = state.radius*np.exp(self.growth_rate)

        new_radius = np.where(new_radius > self.max_radius, self.max_radius, new_radius)*np.where(state.celltype>0, 1, 0)

        get_radius = lambda s: s.radius
        state = eqx.tree_at(get_radius, state, new_radius)

        return state

In [65]:
cg = CellGrowth(growth_rate=.01, max_radius=1.1, growth_type='linear')

In [66]:
def loss(a, state, key):
    div = CellDivisionReparam(inference=False, birth_radius_multiplier=a)
    cg = CellGrowth(growth_rate=.01, max_radius=1.1, growth_type='linear')

    for i in range(1):
        key, subkey = jax.random.split(key)
        state = div(state, key=subkey, softmax_T=1.)
        state = cg(state)

    return state.radius.sum(), state

In [67]:
g, s = jax.grad(loss, has_aux=True)(1., test_state, subkey)

In [68]:
s.radius

Array([1.01, 1.01, 0.  , 0.  , 0.  ], dtype=float32)

In [74]:
boh = eqx.nn.Sequential(layers=[
    CellDivisionReparam(inference=False, birth_radius_multiplier=.5), 
    CellGrowth(growth_rate=.01, max_radius=1.1, growth_type='linear')
    ])

In [75]:
boh(test_state, key=subkey).radius

Array([0.51, 0.51, 0.  , 0.  , 0.  ], dtype=float32)

# Mechanical

In [58]:
class MinimizePotential(eqx.Module):
    relaxation_steps:   int = eqx.field(default=15, static=True)
    dt:                 float = eqx.field(default=1e-3, static=True)


    def _sgd(self, state, pair_potential):

        init, apply = jax_md.minimize.gradient_descent(pair_potential, state.shift, self.dt) 
 
        def scan_fn(opt_state, i):
            return apply(opt_state), 0.

        #relax system
        opt_state = init(state.position)
        opt_state, _ = jax.lax.scan(scan_fn, opt_state, np.arange(self.relaxation_steps))

        return opt_state

In [None]:
class MinimizeMorse(MinimizePotential):
    epsilon:   Union[float, jax.Array] = 3.
    alpha:     float = 2.8
    r_cutoff:  float = eqx.field(default=2., static=True)
    r_onset:   float = eqx.field(default=1.7, static=True)
    


    def _calculate_epsilon(self, state):

        if isinstance(self.epsilon, float):
            alive = np.where(state.celltype > 0, 1, 0)
            epsilon_matrix = (np.outer(alive, alive)-np.eye(alive.shape[0]))*self.epsilon


        elif isinstance(self.epsilon, jax.interpreters.xla.DeviceArray):
            
            ### implement general logic for multiple cell types
            pass


        return epsilon_matrix



    def __call__(self, state, *, key=None):

        #calculate sigma matrix
        sigma_matrix = state.radius[:,None] + state.radius[None,:]

        #calculate epsilon matrix
        epsilon_matrix = self._calculate_epsilon(state)

        #generate morse pair potential
        morse_energy = jax_md.energy.morse_pair(state.displacement,
                                                alpha=self.alpha,
                                                epsilon=epsilon_matrix,
                                                sigma=sigma_matrix, 
                                                r_onset=self.r_onset, 
                                                r_cutoff=self.r_cutoff
                                                )
        
        #minimize
        new_positions = self._sgd(state, morse_energy)

        state = eqx.tree_at(lambda s: s.position, state, new_positions)

        return state

In [59]:
MinimizeMorse()

MinimizeMorse(
  relaxation_steps=15,
  dt=0.001,
  epsilon=3.0,
  alpha=2.8,
  r_cutoff=2.0,
  r_onset=1.7
)

In [80]:
boh1 = eqx.nn.Sequential(layers=[
    CellDivisionReparam(inference=False, birth_radius_multiplier=.5), 
    CellGrowth(growth_rate=.01, max_radius=1.1, growth_type='linear'),
    MinimizeMorse(epsilon=.5)
    ])

In [78]:
ns = boh(boh(test_state, key=subkey), key=subkey)

In [84]:
MinimizeMorse(epsilon=.5)(ns, key=subkey).position

Array([[-0.46525076,  0.3462038 ],
       [ 0.08624908, -0.06417993],
       [ 0.78013045, -0.580513  ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ]], dtype=float32)

In [270]:
def loss(a, state, key):
    boh1 = eqx.nn.Sequential(layers=[
        CellDivisionReparam(inference=True, birth_radius_multiplier=a, softmax_T=10.), 
        CellGrowth(growth_rate=.03, max_radius=1.1, growth_type='linear'),
        MinimizeMorse(epsilon=3.)
        ])
    
    for i in range(4):
        key, subkey = jax.random.split(key)
        state = boh1(state, key=subkey)

    return (state.position[1]**2).sum(), state

In [271]:
g, s = jax.grad(loss, has_aux=True)(.8, test_state, subkey)

In [272]:
ss = loss(.8, ns, subkey)

In [273]:
g

Array(8.629837, dtype=float32, weak_type=True)

In [232]:
s.position.sum()

Array(1.2136753, dtype=float32)

In [207]:
ns.radius

Array([0.52 , 0.265, 0.265, 0.   , 0.   ], dtype=float32)

In [206]:
ss[1].radius

Array([0.41599998, 0.265     , 0.265     , 0.41599998, 0.        ],      dtype=float32)