In [2]:
import jax
import jax.numpy as np
import jax.tree_util as jtu

import numpy as onp

from typing import Any, Callable

import equinox as eqx

import matplotlib.pyplot as plt

In [3]:
key = jax.random.PRNGKey(0)

# Cell State

In [4]:
import jax_md

In [79]:
class BaseCellState(eqx.Module):
    '''
    Module containing the basic features of a system state.

    '''

    # METHODS
    displacement:   jax_md.space.DisplacementFn = eqx.field(static=True)
    shift:          jax_md.space.ShiftFn = eqx.field(static=True)

    # STATE
    position:   jax.Array
    celltype:   jax.Array
    radius:     jax.Array
    divrate:    jax.Array
    #key:        jax.Array


    @classmethod
    def empty(cls, n_dim=2):

        '''
        Intializes a CellState with no cells (empty data structures, with correct shapes).

        Parameters
        ----------
        n_dim: int
            Number of spatial dimensions.

        '''

        assert n_dim == 2 or n_dim == 3, 'n_dim must be 2 or 3'

        disp, shift = jax_md.space.free()
        

        args = {
            'displacement'  :   disp,
            'shift'         :   shift,
            'position'  :   np.empty(shape=(0, n_dim), dtype=np.float32),
            'celltype'  :   np.empty(shape=(0,), dtype=np.float32),
            'radius'    :   np.empty(shape=(0,), dtype=np.float32),
            'divrate'   :   np.empty(shape=(0,), dtype=np.float32),
            #'key'       :   None,
            }
        
        return cls(**args)


In [80]:
class CellState(BaseCellState):
    chemical:       jax.Array
    chemgrad:       jax.Array
    hidden_state:   jax.Array


    @classmethod
    def empty(cls, n_dim=2, n_chem=1, hidden_size=10):

        a = BaseCellState.empty(n_dim).__dict__

        new_args = {
            'chemical' :   np.empty(shape=(0, n_chem), dtype=np.float32),
            'chemgrad' :   np.empty(shape=(0, int(n_dim*n_chem)), dtype=np.float32),
            'hidden_state' :   np.empty(shape=(0, hidden_size), dtype=np.float32),
            }
        
        a.update(new_args)

        return cls(**a)

# Cell division

In [233]:
class CellDivision(eqx.Module):
    birth_radius_multiplier:    float
    inference:                  bool = eqx.field(static=True)

    def __init__(self, 
                 state, 
                 birth_radius_multiplier=float(1/np.sqrt(2)),
                 inference=True
                 ):

        if not hasattr(state, 'divrate'):
            raise AttributeError('CellState must have "divrate" attribute')
        if not hasattr(state, 'radius'):
            raise AttributeError('CellState must have "radius" attribute')
        if not hasattr(state, 'celltype'):
            raise AttributeError('CellState must have "celltype" attribute')
        if not hasattr(state, 'position'):
            raise AttributeError('CellState must have "position" attribute')
        # if not hasattr(state, 'key'):
        #     raise AttributeError('CellState must have valid "key" attribute')
        
        self.birth_radius_multiplier = birth_radius_multiplier
        self.inference = inference
        

    def __call__(self, state, *, key=None, inference=None, softmax_T=1.):

        if inference is None:
            inference = self.inference


        #split key
        new_key, subkey_div, subkey_place = jax.random.split(key,3)
        
        p = state.divrate/state.divrate.sum()

        safe_p = np.where(p > 0, p, 1)
        logp = np.where(p > 0, np.log(safe_p), -np.inf)

        logit = logp + jax.random.gumbel(subkey_div, shape=logp.shape)
        idx_dividing_cell = np.argmax(logit)

        print(idx_dividing_cell)


        if inference:
            new_cell_contribs = np.one_hot(idx_dividing_cell, state.celltype.shape[0])
        else:
            new_cell_contribs = jax.nn.softmax(logit)

        
        idx_new_cell = np.count_nonzero(state.celltype)

        division_matrix = np.eye(state.celltype.shape[0]).at[idx_new_cell].set(new_cell_contribs)

        new_state = jax.tree_map(lambda x: np.dot(division_matrix, x), state)


        #resize cell radii
        get_radius = lambda s: s.radius
        resize_rad = lambda r: r.at[idx_new_cell].set(r[idx_dividing_cell]*self.birth_radius_multiplier).at[idx_dividing_cell].set(r[idx_dividing_cell]*self.birth_radius_multiplier)
        new_state = eqx.tree_at(get_radius, new_state, replace_fn=resize_rad)
        
        ### POSITION OF NEW CELLS
        angle = jax.random.uniform(subkey_place, minval=0., maxval=2*np.pi, dtype=np.float32)

        cell_displacement = self.birth_radius_multiplier*np.array([np.cos(angle),np.sin(angle)])
        get_position = lambda s: s.position
        new_position = lambda p: p.at[idx_new_cell].set(p[idx_dividing_cell]-cell_displacement).at[idx_dividing_cell].set(p[idx_dividing_cell]+cell_displacement)
        new_state = eqx.tree_at(get_position, new_state, replace_fn=new_position)
        
        if inference:
            return new_state, logp[idx_dividing_cell]
        else:
            return new_state

In [234]:
div = CellDivision(BaseCellState.empty(), inference=False)

In [235]:
disp, shift = jax_md.space.free()

N = 5

state = BaseCellState(
    displacement=disp,
    shift=shift,
    position=np.zeros(shape=(N,2)),
    celltype=np.zeros(shape=(N,), dtype=np.int8).at[0].set(1),
    radius=np.zeros(shape=(N,), dtype=np.float32).at[0].set(1.),
    divrate=np.zeros(shape=(N,), dtype=np.float32).at[0].set(1.),
)
    

    

In [261]:
key, subkey = jax.random.split(key)
ns = div(state, key=subkey)

0


In [262]:
nns = div(ns, key=subkey)

1


In [263]:
nns.celltype

Array([1., 1., 1., 0., 0.], dtype=float32)

In [281]:
def loss(a, state, key):
    div = CellDivision(BaseCellState.empty(), inference=False, birth_radius_multiplier=a)

    for i in range(1):
        key, subkey = jax.random.split(key)
        state = div(state, key=subkey)

    return state.radius[1], state

In [282]:
g, s = jax.grad(loss, has_aux=True)(1., state, key)

0


In [283]:
g

Array(1., dtype=float32, weak_type=True)

In [277]:
s.radius

Array([1., 1., 0., 0., 0.], dtype=float32)