# <center>Growing clusters</center>

## NOTE:

- You can tell jax to trace non-array values but apparently cannot force static on DeviceArray values

- `grad` works fine in general. `jacfwd` is needed when the differentiated computation contains very long iterations, e.g. when differentiating with respect to the epsilons through many relaxation steps (`grad` is fine up to around 40 steps). In that case forward mode substantially reduces memory requirements, at the cost of longer computation times.


## To do:
- Use SafeKey implementation from AlphaFold (https://github.com/deepmind/alphafold/blob/main/alphafold/model/prng.py)


## Imports

In [1]:
### JAX

import jax
import jax.numpy as np

from jax import random, value_and_grad, grad, jacfwd, jacrev
from jax import jit, lax, vmap

from jax.random import split
from jax.example_libraries import optimizers
from jax.tree_util import tree_flatten, tree_unflatten, tree_map, tree_structure
from jax.flatten_util import ravel_pytree

# Configure JAX through the `jax.config` attribute of the already-imported
# `jax` module; importing the `jax.config` submodule
# (`from jax.config import config`) is deprecated in recent JAX releases.
jax.config.update('jax_enable_x64', True)   # allow float64 arrays
jax.config.update('jax_debug_nans', True)   # raise as soon as an op produces a NaN
#jax.config.update('jax_platform_name', 'cpu')

In [2]:
### JAX MD

import jax_md.dataclasses as jax_dataclasses
from jax_md import space, energy, minimize
#from jax_md.colab_tools import renderer

from jax_md import  util

In [3]:
# LIMIT GPU MEMORY TAKEN UP BY THE NOTEBOOK
# NOTE(review): XLA reads these variables when the JAX backend is first
# initialized, so this cell must run before any JAX computation. JAX's
# documented default preallocation fraction is 0.75, so '.9' raises rather
# than limits the preallocated share of GPU memory -- confirm intent.

import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.9'
#os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

In [4]:
# OTHER STUFF

import equinox as eqx

from tqdm import tqdm

from functools import partial

In [5]:
### MATPLOTLIB

import matplotlib.pyplot as plt

#set global properties of plots
# (applies to every figure created afterwards in this session)
plt.rcParams.update({'font.size': 15})

In [6]:
### TYPING

from typing import Callable, Tuple
# Type aliases re-exported from jax_md.util (Array is used in the annotations below)
Array = util.Array
f32 = util.f32
f64 = util.f64

## Utils

### Helper functions

#### Logistic function

In [7]:
def logistic(x, gamma, k):
    """Logistic (sigmoid) curve 1 / (1 + exp(-gamma * (x - k))).

    Args:
      x: input value(s); broadcast against gamma and k.
      gamma: steepness of the transition.
      k: midpoint of the curve, where the output equals 0.5.
    """
    exponent = -gamma * (x - k)
    return 1. / (1. + np.exp(exponent))

### State Dataclass

In [8]:
@jax_dataclasses.dataclass
class CellState:
    '''
    Dataclass containing the system state.

    Every per-cell array has one entry (row) per slot; slots with celltype 0
    are empty, i.e. not yet occupied by a cell.

    Attributes:
    
    position: The current position of the particles. An ndarray of floats with
              shape [n, spatial_dimension].
    celltype: The cell type of each particle. An ndarray of integers with shape
              [n]; 0 marks an empty slot, 1 and 2 are the two cell types.
    radius:   Radius of each particle, shape [n]. Cells are born with radius
              params['cellRadBirth'] and grow up to params['cellRad'].
    chemical: Chemical concentration at the location of each particle. An
              ndarray of floats with shape [n, n_chem].
    divrate:  Division propensity of each particle, shape [n]; it is normalized
              to probabilities when the dividing cell is sampled.
    key:      The current state of the random number generator.
    '''
    
    position: Array
    celltype: Array
    radius: Array 
    chemical: Array
    divrate: Array
    key: Array
    
    
@jax_dataclasses.dataclass
class SpaceFunc:
    '''
    Dataclass containing functions that handle space.

    It is built from the (displacement, shift) pair returned by a jax_md.space
    constructor, e.g. SpaceFunc(*space.free()).

    Attributes:

    displacement: Function computing the displacement between two positions.
    shift:        Function moving positions by a given displacement.
    '''
    
    displacement: space.DisplacementFn
    shift: space.ShiftFn

### Visualization

In [9]:
def draw_circles_ctype(state, ax=None, **kwargs):
    '''
    Draw every particle as a translucent circle colored by cell type.

    Args:
      state: CellState to draw (position, radius and celltype are read).
      ax: matplotlib Axes to draw on; a new Axes is created when None.
      **kwargs: extra keyword arguments forwarded to plt.Circle.

    Returns:
      The Axes that was drawn on. Limits, ticks and figure styling are applied
      to this Axes and its figure, not to whatever Axes happens to be current.
    '''
    if ax is None:
        ax = plt.axes()

    # celltype 1 -> 0. and celltype 2 -> 1. on the coolwarm colormap;
    # only usable for two cell types
    color = plt.cm.coolwarm(np.float32(state.celltype-1))

    for cell, radius, c in zip(state.position, state.radius, color):
        ax.add_patch(plt.Circle(cell, radius=radius, color=c, alpha=.5, **kwargs))

    ## square axis limits with a margin of 3 around all particle centers
    xmin = np.min(state.position[:,0])
    xmax = np.max(state.position[:,0])
    ymin = np.min(state.position[:,1])
    ymax = np.max(state.position[:,1])

    max_coord = float(max(xmax, ymax)) + 3
    min_coord = float(min(xmin, ymin)) - 3

    ax.set_xlim(min_coord, max_coord)
    ax.set_ylim(min_coord, max_coord)
    ax.set_xticks([])
    ax.set_yticks([])

    # style the figure that owns `ax` (plt.gcf() may be a different figure)
    fig = ax.figure
    background_color = [56 / 256] * 3
    fig.patch.set_facecolor(background_color)
    fig.set_size_inches(6, 6)
    fig.tight_layout()

    return ax


def draw_circles_chem(state, chem=0, ax=None, edges=False, **kwargs):
    '''
    Draw every particle as a circle colored by the concentration of one chemical.

    Args:
      state: CellState to draw (position, radius, chemical and, with
        edges=True, celltype are read).
      chem: column of state.chemical to display. Chemical 0 uses the YlGn
        colormap, chemical 1 BuPu, any other index coolwarm.
      ax: matplotlib Axes to draw on; a new Axes is created when None.
      edges: if True, outline each circle with its cell-type color.
      **kwargs: extra keyword arguments forwarded to plt.Circle.

    Returns:
      The Axes that was drawn on (limits/ticks/styling are applied to it).
    '''
    if ax is None:
        ax = plt.axes()

    # rescale to [0, 1]; a constant field (e.g. all zeros) maps to 0 instead
    # of dividing by zero, which would produce NaNs (and raise under
    # jax_debug_nans)
    chemical = np.float32(state.chemical[:,chem])
    low = chemical.min()
    span = chemical.max() - low
    chemical = (chemical - low) / np.where(span > 0, span, 1.)

    if chem == 0:
        cmap = plt.cm.YlGn
    elif chem == 1:
        cmap = plt.cm.BuPu
    else:
        cmap = plt.cm.coolwarm
    color = cmap(chemical)

    if edges:
        # cell-type outline color; only usable for two cell types
        ct_color = plt.cm.coolwarm(np.float32(state.celltype-1))
        for cell, radius, c, ctc in zip(state.position, state.radius, color, ct_color):
            ax.add_patch(plt.Circle(cell, radius=radius, fc=c, ec=ctc, lw=2, alpha=.5, **kwargs))
    else:
        for cell, radius, c in zip(state.position, state.radius, color):
            ax.add_patch(plt.Circle(cell, radius=radius, fc=c, alpha=.5, **kwargs))

    ## square axis limits with a margin of 3 around all particle centers
    xmin = np.min(state.position[:,0])
    xmax = np.max(state.position[:,0])
    ymin = np.min(state.position[:,1])
    ymax = np.max(state.position[:,1])

    max_coord = float(max(xmax, ymax)) + 3
    min_coord = float(min(xmin, ymin)) - 3

    ax.set_xlim(min_coord, max_coord)
    ax.set_ylim(min_coord, max_coord)
    ax.set_xticks([])
    ax.set_yticks([])

    # style the figure that owns `ax` (plt.gcf() may be a different figure)
    fig = ax.figure
    background_color = [56 / 256] * 3
    fig.patch.set_facecolor(background_color)
    fig.set_size_inches(6, 6)
    fig.tight_layout()

    return ax


    
def draw_circles_divrate(state, ax=None, edges=False, **kwargs):
    '''
    Draw every particle as a circle colored by its division rate.

    Args:
      state: CellState to draw (position, radius, divrate and, with
        edges=True, celltype are read).
      ax: matplotlib Axes to draw on; a new Axes is created when None.
      edges: if True, outline each circle with its cell-type color.
      **kwargs: extra keyword arguments forwarded to plt.Circle.

    Returns:
      The Axes that was drawn on (limits/ticks/styling are applied to it).
    '''
    if ax is None:
        ax = plt.axes()

    # rescale to [0, 1]; constant rates map to 0 instead of dividing by zero,
    # which would produce NaNs (and raise under jax_debug_nans)
    divrate = np.float32(state.divrate)
    low = divrate.min()
    span = divrate.max() - low
    divrate = (divrate - low) / np.where(span > 0, span, 1.)

    # fill color: normalized division rate on the coolwarm colormap
    color = plt.cm.coolwarm(divrate)

    if edges:
        # cell-type outline color; only usable for two cell types
        ct_color = plt.cm.coolwarm(np.float32(state.celltype-1))
        for cell, radius, c, ctc in zip(state.position, state.radius, color, ct_color):
            ax.add_patch(plt.Circle(cell, radius=radius, fc=c, ec=ctc, lw=2, alpha=.5, **kwargs))
    else:
        for cell, radius, c in zip(state.position, state.radius, color):
            ax.add_patch(plt.Circle(cell, radius=radius, fc=c, alpha=.5, **kwargs))

    ## square axis limits with a margin of 3 around all particle centers
    xmin = np.min(state.position[:,0])
    xmax = np.max(state.position[:,0])
    ymin = np.min(state.position[:,1])
    ymax = np.max(state.position[:,1])

    max_coord = float(max(xmax, ymax)) + 3
    min_coord = float(min(xmin, ymin)) - 3

    ax.set_xlim(min_coord, max_coord)
    ax.set_ylim(min_coord, max_coord)
    ax.set_xticks([])
    ax.set_yticks([])

    # style the figure that owns `ax` (plt.gcf() may be a different figure)
    fig = ax.figure
    background_color = [56 / 256] * 3
    fig.patch.set_facecolor(background_color)
    fig.set_size_inches(6, 6)
    fig.tight_layout()

    return ax

# Model

## Diffusion

Wrong, but for the moment we use this.

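1D diffusion with degradation. Since production and degradation are continuous in time, we use the steady-state solution for a point source with secretion rate $Source$, diffusion coefficient $D$ and degradation rate $K$: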

$$ D \frac{d^2}{dr^2} c(r) - K c(r) = -Source \, \delta(r) $$

with solution:

$$c(r)= \frac{Source}{2 \sqrt{DK}} \exp\left(-\sqrt{\frac{K}{D}}\, r\right)$$

which gives, at the source ($r=0$),

$$ c_{max} = \frac{Source}{2 \sqrt{DK}} $$

In [10]:
def diffuse_onechem(r, secRate, degRate, diffCoeff):
    '''
    Steady-state concentration of one chemical at distance r from a point source.

    Evaluates Source / (2 sqrt(D K)) * exp(-sqrt(K / D) r) with
    Source = secRate, K = degRate and D = diffCoeff.

    NOTE: it is assumed that r is a pairwise distance
    '''
    amplitude = secRate / (2 * np.sqrt(degRate * diffCoeff))
    decay = np.exp(-r * np.sqrt(degRate / diffCoeff))
    return amplitude * decay


def diffuse_allchem(secretions, state, params, fspace):
    '''
    Steady-state concentration of every chemical at each particle position.

    Every particle acts as a point source of each chemical, with the rate
    given by `secretions`. The profile of a source (diffuse_onechem) is
    smoothly truncated to zero between params['r_onsetDiff'] and
    params['r_cutoffDiff'], and the contributions of all sources are summed.

    Args:
      secretions: (n_cells, n_chem) secretion rates, one column per chemical.
      state: CellState; position and celltype are read.
      params: parameter dictionary ('n_chem', 'degRate', 'diffCoeff',
        'r_onsetDiff', 'r_cutoffDiff').
      fspace: SpaceFunc providing the displacement function.

    Returns:
      (n_cells, n_chem) array of concentrations, zero on empty slots.
    '''
    # Wrap the profile with the smooth cutoff so that r_onsetDiff and
    # r_cutoffDiff actually take effect.
    diff = energy.multiplicative_isotropic_cutoff(diffuse_onechem,
                                                  r_onset=params['r_onsetDiff'],
                                                  r_cutoff=params['r_cutoffDiff'])

    metric = space.metric(fspace.displacement)
    # all pairwise distances, dist[a, b] = distance between particles a and b
    dist = space.map_product(metric)(state.position, state.position)

    occupied = state.celltype > 0

    # loop over all chemicals (vmap in future)
    new_chem = []
    for i in range(params['n_chem']):
        # secretions[:, i] broadcasts along axis 1, so row a sums the
        # contributions of every source b at particle a
        c = diff(dist,
                 secretions[:, i],
                 params['degRate'][i],
                 params['diffCoeff'][i]).sum(axis=1)

        # zero out concentration on empty sites
        new_chem.append(np.where(occupied, c, 0.))

    # one column per chemical
    return np.stack(new_chem, axis=1)

## Secretion

The secretion of a chemical in response to the presence of all other chemicals is:

$$Secretion_i(x)= \mu_i^{MAX} \prod_{j=0}^{n_{chem}-1} \frac{1}{1+e^{-\gamma^s_{ij}(x_j-k^s_{ij})}}$$

NOTE: This response uses the same parameters for every cell (no distinction between cell types). Hence the two parameter matrices ($\gamma^s$ and $k^s$) are to be understood as (chemical x chemical), not as (cell_type x chemical)!

Therefore, element $k^s_{ij}$ for example is the k of the logistic that defines the contribution of chemical j to the secretion of chemical i.

In [11]:
# Function for the secretion rate of one chemical
# as a function of the concentration of all chemicals
def _sec_onechem(chem, mumax, gammavec, kvec):
    """
    Helper function.
    Calculates the secretion rate of one chemical by each cell
    from the concentrations of all chemicals:

        secretion = mumax * prod_j logistic(chem[:, j], gammavec[j], kvec[j])

    Args:
    chem : concentration of chemicals, an nCells x n_chem matrix
    mumax : maximal secretion rate of this chemical
    gammavec, kvec : length-n_chem logistic parameters; entry j sets how
        chemical j modulates this secretion

    Returns:
    sec_onechem :
        array of shape (nCells,) with the secretion rate of this chemical by
        each cell
    """
    # map logistic over the columns of chem and the entries of gammavec/kvec,
    # stacking the per-chemical responses along axis 1 -> (nCells, n_chem)
    vmap_logistic = vmap(logistic, (1, 0, 0), 1)

    # product over chemicals, in the input precision: no dtype override, since
    # a float32 accumulator would truncate float64 concentrations under x64
    sec_onechem = mumax * np.prod(vmap_logistic(chem, gammavec, kvec), axis=1)

    return sec_onechem

In [12]:
# Secretion rates given the current concentrations:
# (sec x conc) x (conc x cell)
def sec_chemical(state, params):
    """
    Secretion rate of every chemical by every cell.

    Cell type c+1 secretes only chemical c (celltype 0 marks empty slots).
    The rate comes from _sec_onechem with entry/row c of params['sec_max'],
    params['sec_gamma'] and params['sec_k']; every other cell secretes 0 of
    that chemical.

    Args:
    state: current state
    params: dictionary with parameters

    Returns:
    (cells x chemical) matrix of secretion rates
    """
    def _column(c):
        rate = _sec_onechem(state.chemical,
                            params['sec_max'][c],
                            params['sec_gamma'][c, :],
                            params['sec_k'][c, :])
        # keep the rate only where the secreting cell type sits
        return np.where(state.celltype == c + 1, rate, 0.).reshape(-1, 1)

    return np.concatenate([_column(c) for c in range(params['n_chem'])], axis=1)

## Sec+Diff until steady state

Since secretion depends on the chemical concentration, after one sec+diff cycle we need to recalculate the new secretion values based on the new concentrations obtained after diffusing the old ones.

We keep alternating until we reach steady state.

In [13]:
# Secretion/diffusion relaxation of the chemical field (new version of Findcss).
#
# This step is also traced inside the eqx.filter_jit'd sim_step, so it has to
# stay jittable; celltype masking is done with np.where in sec_chemical and
# diffuse_allchem rather than with boolean indexing.

def S_ss_chemfield(state, params, fspace=None, n_iter=5):
    '''
    Simulation step: relax the chemical field by alternating secretion and
    diffusion n_iter times.

    Secretion depends on the local concentrations, so each round recomputes
    the secretion from the previous round's field and diffuses it again.
    Heuristically, steady state is reached after 5 iterations

    Returns a CellState whose `chemical` field holds the relaxed concentrations.
    '''
    def _one_round(current, _):
        secretion = sec_chemical(current, params)
        field = diffuse_allchem(secretion, current, params, fspace)
        return jax_dataclasses.replace(current, chemical=field), 0.

    # fresh CellState used as the scan carry
    carry = CellState(*jax_dataclasses.unpack(state))

    relaxed, _ = lax.scan(_one_round, carry, np.arange(n_iter))
    return relaxed

## Division rates

In [14]:
# Signature of a division-rate function: (state, params) -> per-cell rates.
# np.array is a constructor function, not a type, so the Array alias is used.
Division_fn = Callable[[CellState, dict], Array]

In [15]:
#DANGER: change of conventions for chemicals!!!!
#Now
# cell type 1 secretes chemical 0 and divides according to chemical 1
# cell type 2 secretes chemical 1 and divides according to chemical 0
# (celltype 0 marks empty slots, which never divide)


def div_chemical(state: CellState,
                params: dict,
                ) -> Array:
    '''
    Division rate of each cell from the concentration of the chemical secreted
    by the other cell type.

    Args:
      state: current CellState (chemical, celltype and radius are read).
      params: parameter dictionary; uses 'div_gamma' and 'div_k' (entry 0 for
        cell type 1, entry 1 for cell type 2) and 'cellRad'.

    Returns:
      Per-slot division rates with the same shape as state.celltype: zero on
      empty slots and strongly suppressed for cells below full size.
    '''
    div_gamma = params['div_gamma']
    div_k = params['div_k']

    ### NOTE: possibility to extend this function to account for more complex interactions

    # logistic response of each cell type to the other type's chemical
    div1 = logistic(state.chemical[:,1],div_gamma[0],div_k[0])
    div2 = logistic(state.chemical[:,0],div_gamma[1],div_k[1])

    #create array with new divrates
    divrate = np.where(state.celltype==1,div1,div2)
    divrate = np.where(state.celltype==0,0,divrate)

    # cells cannot divide if they are too small: with cellRad = .5 this factor
    # is ~0.95 at full radius and ~0.01 at the birth radius cellRad/sqrt(2).
    # Constants are arbitrary, change if you change cell radius
    divrate = divrate*logistic(state.radius+.06, 50, params['cellRad'])

    return divrate


def S_set_divrate(state, params, fspace=None):
    '''
    Simulation step: recompute division rates from the current chemical field.

    fspace is unused; it is accepted so that every step in the simulation's
    step list can be called as step(state, params, fspace).
    '''
    return jax_dataclasses.replace(state, divrate=div_chemical(state, params))

## Cell division

In [16]:
def S_cell_division(state, params, fspace=None):
    '''
    Simulation step: one cell divides into two daughter cells.

    The mother is sampled with probability proportional to state.divrate.
    The daughters are placed symmetrically, params['cellRadBirth'] away from
    the mother's position along a random direction. One daughter keeps the
    mother's slot; the other takes slot count_nonzero(celltype) and copies the
    mother's celltype, chemical and divrate. Both get radius cellRadBirth.

    Returns:
      (new_state, log_p), where log_p is the log-probability of the sampled
      mother (kept for optimization purposes).
    '''
    r_birth = params['cellRadBirth']

    # split order: carried key, mother selection, daughter placement
    carried_key, key_choice, key_angle = random.split(state.key, 3)

    probs = state.divrate / state.divrate.sum()
    mother = random.choice(key_choice, a=len(probs), p=probs)
    log_p = np.log(probs[mother])

    # assumes occupied slots are packed at the front of the arrays
    daughter = np.count_nonzero(state.celltype)

    # daughters are point-symmetric around the mother, so half a turn suffices
    theta = random.uniform(key_angle, minval=0., maxval=np.pi, dtype=np.float32)
    offset = r_birth * np.array([np.cos(theta), np.sin(theta)])
    origin = state.position[mother]

    position = state.position.at[mother].set(origin + offset)
    position = position.at[daughter].set(origin - offset)

    radius = state.radius.at[mother].set(r_birth).at[daughter].set(r_birth)

    # chemical and divrate are recomputed by later steps, but copied anyway
    def inherit(field):
        return field.at[daughter].set(field[mother])

    new_state = jax_dataclasses.replace(state,
                                        position=position,
                                        radius=radius,
                                        celltype=inherit(state.celltype),
                                        chemical=inherit(state.chemical),
                                        divrate=inherit(state.divrate),
                                        key=carried_key)

    return new_state, log_p

## Cell growth

In [17]:
def S_grow_cells(state, params, fspace=None):
    '''
    Simulation step: every cell grows by a constant factor, capped at cellRad.

    Radii are multiplied by exp(params.get('growth_rate', .1)) and then capped
    at params['cellRad']. Empty slots have radius 0 and therefore stay at 0.
    fspace is unused (kept for the common step signature).
    '''
    # constant exponential growth; 'growth_rate' is optional (default .1)
    new_radius = state.radius * np.exp(params.get('growth_rate', .1))

    # cap the radius at the adult size
    new_radius = np.where(new_radius < params['cellRad'], new_radius, params['cellRad'])

    return jax_dataclasses.replace(state, radius=new_radius)

## Mechanical interactions

In [18]:
def _generate_morse_params(state, params):
    '''
    Morse interaction params for each particle couple.

    Args:
      state: object with `radius` and `celltype` arrays of shape [n]; cell
        types 1 and 2 interact, any other label (e.g. 0 for empty slots) gets
        a zero well depth.
      params: dict with the well depths 'eps_OneOne', 'eps_TwoTwo' and
        'eps_OneTwo'.

    Returns:
      (epsilon_matrix, sigma_matrix), both of shape [n, n]:
      epsilon_matrix: Depth of the Morse well, selected by the two cell types.
      sigma_matrix: Distance between particles where the energy has a minimum,
        set to the sum of the two radii.
    '''
    epsilon_OneOne = params['eps_OneOne']
    epsilon_TwoTwo = params['eps_TwoTwo']
    epsilon_OneTwo = params['eps_OneTwo']

    # minimum energy when the two cells barely touch: sigma_ij = r_i + r_j
    # (a [1, n] row plus its [n, 1] transpose broadcasts to [n, n])
    radii = np.array([state.radius])
    sigma_matrix = radii + radii.T

    # 0/1 indicator of each cell type
    is_one = np.where(state.celltype == 1, 1, 0)
    is_two = np.where(state.celltype == 2, 1, 0)

    # well depth by type pair; mixed pairs use eps_OneTwo in both orders
    epsilon_matrix = (np.outer(is_one, is_one) * epsilon_OneOne
                      + np.outer(is_two, is_two) * epsilon_TwoTwo
                      + np.outer(is_one, is_two) * epsilon_OneTwo
                      + np.outer(is_two, is_one) * epsilon_OneTwo)

    return epsilon_matrix, sigma_matrix


In [19]:
def S_minimize_mech_sgd(state, params, fspace, dt=.001):
    '''
    Simulation step: relax positions under the pairwise Morse potential.

    Runs params['mech_relaxation_steps'] jax_md gradient-descent steps with
    step size dt (0.001 is a timestep that seems to work). Well depths and
    equilibrium distances come from _generate_morse_params.

    Returns a copy of state with the relaxed positions.
    '''
    epsilon_matrix, sigma_matrix = _generate_morse_params(state, params)

    morse = energy.morse_pair(fspace.displacement,
                              alpha=params['alpha'],
                              epsilon=epsilon_matrix,
                              sigma=sigma_matrix,
                              r_onset=params['r_onset'],
                              r_cutoff=params['r_cutoff'])

    gd_init, gd_apply = minimize.gradient_descent(morse, fspace.shift, dt)

    def _descend(positions, _):
        return gd_apply(positions), 0.

    relaxed, _ = lax.scan(_descend,
                          gd_init(state.position),
                          np.arange(params['mech_relaxation_steps']))

    return jax_dataclasses.replace(state, position=relaxed)

## Simulation

### Generate initial state

First create a state with one cell, let it grow to a state with the initial number of cells, and then set all the relevant parameters.

This process avoids explosions in the mechanical relaxation of the system, since positions are never too far from the equilibrium.

In [20]:
def _create_onecell_state(key, params):
    '''
    Build a CellState with params['ncells_init'] slots where only slot 0 is
    occupied: one type-1 cell of radius params['cellRad'] at the origin with
    division rate 1. All other slots are empty (celltype 0, radius 0).
    '''
    # plain Python int for the array length (a jnp.int16 scalar would
    # overflow for more than 32767 slots)
    N = int(params['ncells_init'])

    celltype = np.zeros(N, dtype=np.int16).at[0].set(1)
    radius = np.zeros(N, dtype=np.float32).at[0].set(params['cellRad'])
    position = np.zeros((N, 2), dtype=np.float32)
    chemical = np.zeros((N, params['n_chem']), dtype=np.float64)
    divrate = np.zeros(N, dtype=np.float32).at[0].set(1.)

    return CellState(position, celltype, radius, chemical, divrate, key)



def init_state_grow(key, params, fspace):
    '''
    Generate a relaxed initial cluster of params['ncells_init'] cells.

    Starting from a single cell, cells repeatedly divide, grow and relax until
    every slot is occupied. Then the first params['n_ones_init'] cells are
    labelled type 1 and the rest type 2, all radii are set to
    params['cellRad'], the positions are relaxed once more, and the chemical
    field and division rates are made consistent with the final layout.
    '''
    n_init_tot = int(params['ncells_init'])
    n_init_ones = params['n_ones_init']

    onecell_state = _create_onecell_state(key, params)

    def _init_add(state, i):
        # one division, then growth and mechanical relaxation
        state, _ = S_cell_division(state, params, fspace)
        state = S_grow_cells(state, params, fspace)
        state = S_minimize_mech_sgd(state, params, fspace)
        return state, 0.

    # The starting cell already occupies one slot, so n_init_tot - 1 divisions
    # fill every slot. One more division would target index n_init_tot: JAX
    # drops that out-of-bounds update, but the mother cell would still move.
    iterations = np.arange(n_init_tot - 1)
    state, _ = lax.scan(_init_add, onecell_state, iterations)

    # divide cells in their subtypes arbitrarily
    celltype = np.zeros(n_init_tot, dtype=np.int16)
    celltype = celltype.at[:n_init_ones].set(1)
    celltype = celltype.at[n_init_ones:].set(2)

    # set all cells to max radius and relax the system
    radius = np.full(n_init_tot, params['cellRad'], dtype=np.float32)
    state = jax_dataclasses.replace(state, celltype=celltype, radius=radius)

    state = S_minimize_mech_sgd(state, params, fspace)

    # calculate consistent chemfield
    state = S_ss_chemfield(state, params, fspace)

    # calculate consistent division rates
    state = S_set_divrate(state, params, fspace)

    return state
    

### init-apply formulation

In [21]:
def simulation(fstep, params, fspace):
    '''
    Build the (sim_init, sim_step) pair of a simulation.

    Args:
      fstep: sequence of step functions executed in order at every step. Each
        is called as step(state, params, fspace); fstep[0] must be the
        division step and return (state, logp), the others return the state.
      params: parameter dictionary passed to every step function.
      fspace: SpaceFunc passed to every step function.

    Returns:
      sim_init(istate, key=None): appends params['ncells_add'] empty slots.
      sim_step(state): jitted; applies every step once, returns (state, logp).
    '''
    # snapshot the steps so later edits to the caller's list cannot change
    # what sim_step traces
    fstep = tuple(fstep)

    def sim_init(istate, key=None):
        '''
        Elongate every per-cell array by params['ncells_add'] empty slots.

        If key is None use the key packed in the initial state, else use the
        provided key.
        '''
        ncells_add = params['ncells_add']

        new_position = np.concatenate([istate.position, np.zeros((ncells_add,2))])
        new_chemical = np.concatenate([istate.chemical, np.zeros((ncells_add,params['n_chem']))])
        # pad with the labels' own integer dtype: float zeros would silently
        # promote celltype to a float array
        new_celltype = np.concatenate([istate.celltype,
                                       np.zeros(ncells_add, dtype=istate.celltype.dtype)])
        new_radius = np.concatenate([istate.radius, np.zeros(ncells_add)])
        new_divrate = np.concatenate([istate.divrate, np.zeros(ncells_add)])

        new_key = istate.key if key is None else key

        return CellState(new_position, new_celltype, new_radius, new_chemical, new_divrate, new_key)

    @eqx.filter_jit
    def sim_step(state):
        '''Advance the simulation by one step; returns (new_state, logp).'''
        # first step must always be cell division (it also returns logp)
        state, logp = fstep[0](state, params, fspace)

        # plain Python loop, unrolled while tracing: the entries of fstep are
        # Python callables, so a traced lax.scan index cannot select them
        for op in fstep[1:]:
            state = op(state, params, fspace)

        return state, logp

    return sim_init, sim_step

### Simulate trajectory

In [22]:
def sim_trajectory(istate, sim_init, sim_step, key=None):
    '''
    Run a full simulation.

    istate is padded with sim_init(istate, key), then sim_step is applied once
    per added slot.

    Returns:
      (final_state, logp), where logp stacks the per-step values returned by
      sim_step along a new leading axis.
    '''
    padded = sim_init(istate, key)
    n_steps = len(padded.celltype) - len(istate.celltype)

    return lax.scan(lambda carry, _: sim_step(carry), padded, None, length=n_steps)

# Model Parameters

In [23]:
# Define parameters -- draw_circles_ctype maps type-1 cells to the blue (low)
# end and type-2 cells to the red (high) end of the coolwarm colormap
# keep type casting to place vars in gpu memory

# Number of chemical signals
n_chem = 2


### CELL DIMENSIONS

#change constants in divrate calculation if cellRad != .5 
# always use python scalars
cellRad = .5
# birth radius: two daughters of radius cellRad/sqrt(2) have the mother's total area
cellRadBirth = float(cellRad / np.sqrt(2))


### DIFFUSION

#possibly different diffusion and degradation for each chemical
diffCoeff = np.ones(n_chem) #np.array([1.,1.])
degRate = np.ones(n_chem) #np.array([1.,1.])

#diffusion cutoff
r_cutoffDiff = 5.*cellRad
r_onsetDiff = r_cutoffDiff - .5


### SECRETION

# sec rate that gives concentration 1 at source at steady state:
# c_max = sec / (2 sqrt(D K)) = 1  =>  sec = 2 sqrt(D K)
sec_max_unitary = 2*np.sqrt(diffCoeff*degRate)

sec_max = sec_max_unitary*np.ones((n_chem,), dtype=np.float32)
#sec_max = sec_max.at[0].set(10)

# logistic parameters, (chemical x chemical): row i = secreted chemical,
# column j = modulating chemical
sec_gamma = 1.* np.ones((n_chem,n_chem), dtype=np.float32) #was 0.01
sec_k = np.zeros((n_chem,n_chem), dtype=np.float32) 


# DIVISION (logistic response of the division rate to the other type's
# chemical; entry 0 is used for cell type 1, entry 1 for cell type 2)

div_gamma = 1.*np.ones(n_chem, dtype=np.float32)
div_k = 0.*np.ones(n_chem, dtype=np.float32)
#div_k = div_k.at[0].set(5.)


# MORSE POTENTIAL
# always use python scalars
alpha = 3.
eps_TwoTwo = 2. #type 2 - type 2 pairs
eps_OneOne = 2. #type 1 - type 1 pairs
eps_OneTwo = 2. #mixed pairs

#morse cutoff
r_cutoff = 4.*cellRad
r_onset = r_cutoff - .2


# number of gradient descent steps for Morse potential minimization
mech_relaxation_steps = 100


# Initialization and number of added cells. 
ncells_init = 10 #number of cells in the initial cluster
n_ones_init = 5 #number of type-1 cells in the initial cluster
ncells_add = 100 #number of empty slots appended by sim_init (one division per slot)
#ncells_tot =  np.int16(ncells_init+ncells_add)# final number of cells

# the initial cluster must contain at least one type-2 cell
assert ncells_init > n_ones_init

In [24]:
def _maybe_array(name, value, train_params):
    '''Return value as a jax array when train_params[name] is truthy, else unchanged.'''
    return np.array(value) if train_params[name] else value

In [25]:
# Flags marking which parameters are meant to be trainable: an entry flagged
# True is turned into a jax array by _maybe_array when `params` is built.
# NOTE(review): in the code visible here only 'alpha' and the 'eps_*' entries
# are passed through _maybe_array, so the other flags (including
# 'div_k' : True) currently have no effect -- div_k is already a jax array.
train_params = {
    'n_chem': False,
    
    'sec_max': False,
    'sec_gamma': False,
    'sec_k' : False,
    
    'div_gamma' : False,
    'div_k' : True,
    
    'cellRad' : False,
    'cellRadBirth' : False,
    
    'diffCoeff' : False,
    'degRate' : False,
    'r_onsetDiff' : False,
    'r_cutoffDiff' : False,
    
    'alpha': False, 
    'eps_TwoTwo': False, 
    'eps_OneOne' : False,
    'eps_OneTwo' : False,
    'r_onset' : False,
    'r_cutoff' : False,
    'mech_relaxation_steps' : False,
    
    'ncells_init' : False,
    'n_ones_init': False, 
    #'ncells_tot': ncells_tot,
    'ncells_add': False,
}

In [26]:
# Parameter dictionary handed to every simulation step function.
# 'alpha' and the 'eps_*' well depths go through _maybe_array (jax arrays when
# flagged in train_params); all other entries keep the types defined above.
params = {
    'n_chem': n_chem,
    
    'sec_max': sec_max,
    'sec_gamma': sec_gamma,
    'sec_k' : sec_k,
    
    'div_gamma' : div_gamma,
    'div_k' : div_k,
    
    'cellRad' : cellRad,
    'cellRadBirth' : cellRadBirth,
    
    'diffCoeff' : diffCoeff,
    'degRate' : degRate,
    'r_onsetDiff' : r_onsetDiff,
    'r_cutoffDiff' : r_cutoffDiff,
    
    'alpha': _maybe_array('alpha', alpha, train_params), 
    'eps_TwoTwo': _maybe_array('eps_TwoTwo', eps_TwoTwo, train_params), 
    'eps_OneOne' : _maybe_array('eps_OneOne', eps_OneOne, train_params),
    'eps_OneTwo' : _maybe_array('eps_OneTwo', eps_OneTwo, train_params),
    'r_onset' : r_onset,
    'r_cutoff' : r_cutoff,
    'mech_relaxation_steps' : mech_relaxation_steps,
    
    'ncells_init' : ncells_init,
    'n_ones_init': n_ones_init, 
    #'ncells_tot': ncells_tot,
    'ncells_add': ncells_add,
}

In [27]:
# master PRNG key (fixed seed for reproducibility); split before building the initial state
key = random.PRNGKey(0)

# Simulation

In [28]:
# build initial state and space handling functions

# displacement/shift pair from jax_md's free-space constructor
fspace = SpaceFunc(*space.free())

key, init_key = split(key)
# generate initial state by growing from single cell
istate = init_state_grow(init_key, params, fspace)

In [29]:
# functions in this list will be executed in the given order
# at each simulation step; S_cell_division must come first because
# sim_step unpacks (state, logp) from the first entry
fstep = [
    S_cell_division,
    S_grow_cells,
    S_minimize_mech_sgd,
    S_ss_chemfield,
    S_set_divrate
]

sim_init, sim_step = simulation(fstep, params, fspace)

In [30]:
#initialize simulation state and advance one step
# (the first sim_step call also compiles it, since it is wrapped in eqx.filter_jit)

state = sim_init(istate)

state, logp = sim_step(state)

In [31]:
%%time

#OR run entire simulation
# (one sim_step per slot added by sim_init; logps stacks the per-step logp)
fstate, logps = sim_trajectory(istate, sim_init, sim_step)

CPU times: user 6.29 s, sys: 2.4 s, total: 8.7 s
Wall time: 6.14 s
