In [42]:
import jax
import jax.numpy as np
import jax.tree_util as jtu

import numpy as onp

from typing import Any, Callable
from jaxtyping import Array

import equinox as eqx

import matplotlib.pyplot as plt

In [2]:
key = jax.random.PRNGKey(0)

# Cell State

In [None]:
import jax_md

In [89]:
class BaseCellState(eqx.Module):
    '''
    Module containing the basic features of a system state.

    '''

    # METHODS
    displacement:   jax_md.space.DisplacementFn = eqx.field(static=True)
    shift:          jax_md.space.ShiftFn = eqx.field(static=True)

    # STATE
    position:   jax.Array
    celltype:   jax.Array
    radius:     jax.Array
    divrate:    jax.Array
    #key:        jax.Array


    @classmethod
    def empty(cls, n_dim=2):

        '''
        Intializes a CellState with no cells (empty data structures, with correct shapes).

        Parameters
        ----------
        n_dim: int
            Number of spatial dimensions.

        '''

        assert n_dim == 2 or n_dim == 3, 'n_dim must be 2 or 3'

        disp, shift = jax_md.space.free()
        

        args = {
            'displacement'  :   disp,
            'shift'         :   shift,
            'position'  :   np.empty(shape=(0, n_dim), dtype=np.float32),
            'celltype'  :   np.empty(shape=(0,), dtype=np.int8),
            'radius'    :   np.empty(shape=(0,), dtype=np.float32),
            'divrate'   :   np.empty(shape=(0,), dtype=np.float32),
            #'key'       :   None,
            }
        
        return cls(**args)


In [90]:
class CellState(BaseCellState):
    chemical:       jax.Array
    chemgrad:       jax.Array
    hidden_state:   jax.Array


    @classmethod
    def empty(cls, n_dim=2, n_chem=1, hidden_size=10):

        a = BaseCellState.empty(n_dim).__dict__

        new_args = {
            'chemical' :   np.empty(shape=(0, n_chem), dtype=np.float32),
            'chemgrad' :   np.empty(shape=(0, int(n_dim*n_chem)), dtype=np.float32),
            'hidden_state' :   np.empty(shape=(0, hidden_size), dtype=np.float32),
            }
        
        a.update(new_args)

        return cls(**a)

# Cell division

In [86]:
class CellDivision(eqx.Module):
    birth_radius_multiplier:   float
    inference:                  bool = eqx.field(static=True)

    def __init__(self, 
                 state, 
                 birth_radius_multiplier=None,
                 inference=True
                 ):

        if not hasattr(state, 'divrate'):
            raise AttributeError('CellState must have "divrate" attribute')
        if not hasattr(state, 'radius'):
            raise AttributeError('CellState must have "radius" attribute')
        if not hasattr(state, 'celltype'):
            raise AttributeError('CellState must have "celltype" attribute')
        if not hasattr(state, 'position'):
            raise AttributeError('CellState must have "position" attribute')
        if not hasattr(state, 'key'):
            raise AttributeError('CellState must have valid "key" attribute')
        
        self.birth_radius_multiplier = birth_radius_multiplier or float(np.sqrt(2))
        self.inference = inference
        

    def __call__(self, state, *, key=None, inference=None):

        if inference is None:
            inference = self.inference


        #split key
        new_key, subkey_div, subkey_place = jax.random.split(state.key,3)
        
        p = state.divrate/state.divrate.sum()

        safe_p = np.where(p > 0, p, 1)
        logp = np.where(p > 0, np.log(safe_p), 0)

        logit = logp + jax.random.gumbel(subkey_div, shape=logp.shape)
        idx_dividing_cell = np.argmax(state.divrate)


        if inference:
            new_cell_contribs = np.one_hot(idx_dividing_cell, state.celltype.shape[0])
        else:
            new_cell_contribs = jax.nn.softmax(logit)

        
        idx_new_cell = np.count_nonzero(state.celltype)

        division_matrix = np.eye(state.celltype.shape[0]).at[idx_new_cell].set(new_cell_contribs)


        #to figure out cell radius
        





        
        ### POSITION OF NEW CELLS
        #note that cell positions will be symmetric so max is pi
        angle = jax.random.uniform(subkey_place, minval=0., maxval=np.pi, dtype=np.float32)

        first_cell = np.array([np.cos(angle),np.sin(angle)])
        second_cell = np.array([-np.cos(angle),-np.sin(angle)])
        
        pos1 = state.position[idx_dividing_cell] + self.birth_radius_multiplier*first_cell
        pos2 = state.position[idx_dividing_cell] + self.birth_radius_multiplier*second_cell
        
        
        new_fields = {}
        for field in jdc.fields(state):

            value = getattr(state, field.name)

            if 'position' == field.name:
                new_fields[field.name] = value.at[idx_dividing_cell].set(pos1).at[idx_new_cell].set(pos2)
            elif 'radius' == field.name:
                new_fields[field.name] = value.at[idx_dividing_cell].set(cellRadBirth).at[idx_new_cell].set(cellRadBirth)
            elif 'key' == field.name:
                new_fields[field.name] = new_key
            else:
                new_fields[field.name] = value.at[idx_new_cell].set(value[idx_dividing_cell])

        new_state = type(state)(**new_fields)
        
        return new_state, log_p
    
    
    def _no_division():
        return state, 0.
    
    return jax.lax.cond(state.divrate.sum()>0, _divide, _no_division)

        


        
        


SyntaxError: invalid syntax (1556326297.py, line 79)

In [27]:
CellDivision(CellState())

CellDivision(birth_radius_multiplier=1.4142135381698608)

In [28]:
jax.tree_util.tree_map(lambda x: 100, CellState(), is_leaf=lambda x: isinstance(x, Array))

CellState(
  displacement=100,
  shift=100,
  position=100,
  celltype=100,
  radius=100,
  divrate=100,
  key=None,
  chemical=100,
  chemgrad=100,
  hidden_state=100
)

In [61]:
cs = CellState()

In [54]:
isinstance(cs.displacement, Array)

False

In [55]:
len(list(map(eqx.is_array, jtu.tree_flatten(cs)[0])))

7

In [56]:
len(jax.tree_flatten(cs)[0])

  len(jax.tree_flatten(cs)[0])


7

In [65]:
jtu.tree_map(lambda x: 2*x, cs)

CellState(
  displacement=<function displacement_fn>,
  shift=<function shift_fn>,
  position=f32[0,2],
  celltype=i8[0],
  radius=f32[0],
  divrate=f32[0],
  key=None,
  chemical=f32[0,1],
  chemgrad=f32[0,2],
  hidden_state=f32[0,10]
)