<a href="https://colab.research.google.com/github/fmottes/jax-morph/blob/dev/Ramya/04_NN_response_function.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Notebook with a basic neural-network (NN) response function: it builds a forward simulation of the growing cell cluster, and gradients can be taken through that simulation.

# Imports

In [None]:
import os
# Turn off XLA's up-front GPU memory preallocation so the notebook only takes
# memory as it needs it. Kept in the very first cell so it is set before any
# JAX import initialises a backend.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

: 

In [None]:
import numpy as onp
import matplotlib.pyplot as plt
import jax.numpy as np

In [None]:
%%capture
#!pip install jax
#!pip install jax-md
#!pip install optax
#!pip install equinox==0.7.1
#!pip install --upgrade dm-haiku

In [None]:
#%%capture
# SECURITY: this commented-out clone command used to embed a personal access
# token directly in the URL. The token has been redacted here; because it was
# committed in plain text it must be revoked on GitHub. Supply credentials via
# an environment variable, getpass, or a git credential helper instead.
#!git clone https://<GITHUB_USER>:<GITHUB_TOKEN>@github.com/fmottes/jax-morph.git
#!git checkout dev
# NOTE: `%cd ..` is not idempotent - re-running this cell moves one more
# directory up each time.
%cd ..

In [None]:
from jax import random, vmap, tree_leaves, lax
from jax_md import space, quantity, util
import jax_md.dataclasses as jdc
from jax_md import space


########## IMPORT JAX-MORPH FUNCTIONS ##########
################################################

from jax_morph.datastructures import SpaceFunc
from jax_morph.utils import _maybe_array, logistic

from jax_morph.simulation import simulation, sim_trajectory

# IMPORT STATE-CHANGING FUNCTIONS
from jax_morph.division_and_growth.cell_division import S_cell_division
from jax_morph.division_and_growth.cell_growth import S_grow_cells

from jax_morph.mechanics.morse import S_mech_morse_relax
from jax_morph.cell_internals.stress import S_set_stress
from jax_morph.chemicals.secdiff import S_ss_chemfield

from jax_morph.cell_internals.divrates import S_set_divrate, div_nn
from jax_morph.cell_internals.secretion import sec_nn
from jax_morph.cell_internals.grad_estimate import S_chemical_gradients
from jax_morph.cell_internals.hidden_state import hidden_state_nn, S_hidden_state

from jax_morph.initial_states import init_state_grow

from jax_morph.visualization import draw_circles_ctype, draw_circles_chem, draw_circles_divrate
from Ramya.mech_homogeneous_growth.chemical import S_fixed_chemfield


########## IMPORT PLOTTING UTILITIES ##########
###############################################
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 20})

from functools import partial
import equinox as eqx
import haiku as hk
from tqdm import tqdm


In [None]:
# The warnings annoy me :(
import warnings
warnings.filterwarnings('ignore')

In [None]:
# For saving data
import pickle
from pathlib import Path
import os
ROOT_DIR = '../data/'

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
def get_state(t, traj):
    '''
    Build the CellState snapshot at index ``t`` of a recorded trajectory.

    Parameters
    ----------
    t : int
        Index along the leading axis of the recorded fields.
    traj : sequence
        Trajectory container; only ``traj[0]`` is read. It must expose the same
        field names as CellState, each (presumably) stacked along a leading
        time axis - TODO confirm against ``sim_trajectory``.

    Returns
    -------
    CellState
        State whose per-cell fields are the ``t``-th slice of the recorded ones.
    '''
    history = traj[0]

    # Every per-cell field is sliced at the same index. `divangle` was previously
    # passed through un-sliced, unlike all its sibling fields.
    sliced_fields = (
        'position', 'celltype', 'radius', 'chemical', 'chemgrad',
        'field', 'stress', 'divrate', 'divangle', 'hidden_state',
    )
    snapshot = {name: getattr(history, name)[t] for name in sliced_fields}

    # NOTE(review): the PRNG key is forwarded un-sliced, as before. If the
    # trajectory also stacks keys over time this should be `history.key[t]` -
    # confirm against `sim_trajectory`.
    return CellState(**snapshot, key=history.key)

# Params

In [None]:
#@title Define Params
# Define parameters -- blue particles are type 1, orange are type 2
# keep type casting to place vars in gpu memory

# Number of chemical signals
n_chem = 2


### CELL DIMENSIONS
cellRad = .5
# radius assigned to both daughter cells right after a division
cellRadBirth = float(cellRad / np.sqrt(2))


### DIFFUSION

# NOTE(review): the original comment here said there is no diffusion or secretion
# in this simulation (only an external chemical field). The step list further down
# does include a secretion/diffusion step, so confirm which statement is current.
diffCoeff = np.ones(n_chem) 
degRate = np.ones(n_chem) 

# diffusion cutoff
r_cutoffDiff = 5.*cellRad
r_onsetDiff = r_cutoffDiff - .5

# CHEMICAL FIELD
chem_max = 100.0
chem_k = 2.0
chem_gamma = 0.4

### SECRETION

# sec rate that gives concentration 1 at source at SS
#sec_max_unitary = 2*np.sqrt(diffCoeff*degRate)

sec_max = np.ones((n_chem,), dtype=np.float32)
#sec_max = sec_max.at[0].set(10) 
#secreted_by_ctypes = np.ones((n_chem, 1))
#ctype_sec_chem = np.ones((1, n_chem))
# identity matrix of size n_chem; presumably "cell type i secretes chemical i" - TODO confirm
ctype_sec_chem = np.identity((n_chem))

# GROWTH
# MORSE POTENTIAL
# always use python scalars
alpha = 2.7
epsilon = 3.
eps_OneOne = 3.
eps_OneTwo = 3.
eps_TwoTwo = 3.

# morse cutoff
r_cutoff = 2.1*cellRad
r_onset = r_cutoff - .2

# number of gradient descent steps for Morse potential minimization
mech_relaxation_steps = 10
# Initialization and number of added cells. 
ncells_init = n_chem #number of cells in the initial cluster (one per chemical/cell type)
# number of type-1 cells in the initial cluster
# NOTE(review): 100 is larger than ncells_init (= n_chem = 2) - confirm this value is still used
n_ones_init = 100
ncells_add = 100

# size of the per-cell hidden state vector
hidden_state_size = 10


In [None]:
#@title Define trainable params
# Trainability flag for every entry of `params` (same keys): True marks a
# parameter as trainable, False keeps it fixed. Only 'sec_max' and 'diffCoeff'
# are flagged trainable here; the NN initialisers further down add their own
# entries to this dict.
train_params = {
    'n_chem': False,
    'n_dim': False,
    'sec_max': True,
    'ctype_sec_chem': False,

    
    'cellRad' : False,
    'cellRadBirth' : False,
    
    'diffCoeff' : True,
    'degRate' : False,
    'r_onsetDiff' : False,
    'r_cutoffDiff' : False,
    
    'alpha': False, 
    'epsilon': False,
    'eps_OneOne': False,
    'eps_OneTwo': False,
    'eps_TwoTwo': False,
    'r_onset' : False,
    'r_cutoff' : False,
    'mech_relaxation_steps' : False,
    
    'ncells_init' : False,
    'n_ones_init': False, 
    'ncells_add': False,

    'chem_max': False, 
    'chem_k': False,
    'chem_gamma': False,

    'hidden_state_size': False
}

In [None]:
#@title Initialize params
# Collect the constants defined above into the single `params` dict that is
# passed to every simulation function. The spatial dimension is fixed to 2 here.
params = {
    'n_chem': n_chem,
    'n_dim': 2,
    'sec_max': sec_max,
    'ctype_sec_chem' : ctype_sec_chem,
    
    'cellRad' : cellRad,
    'cellRadBirth' : cellRadBirth,
    
    'diffCoeff' : diffCoeff,
    'degRate' : degRate,
    'r_onsetDiff' : r_onsetDiff,
    'r_cutoffDiff' : r_cutoffDiff,
    
    # Morse-potential scalars go through `_maybe_array`, which receives the
    # trainability flags; presumably it only converts to an array when the
    # entry is trainable - TODO confirm in jax_morph.utils.
    'alpha': _maybe_array('alpha', alpha, train_params), 
    'epsilon':  _maybe_array('epsilon', epsilon, train_params),
    'eps_OneOne': _maybe_array('eps_OneOne', eps_OneOne, train_params),
    'eps_OneTwo': _maybe_array('eps_OneTwo', eps_OneTwo, train_params),
    'eps_TwoTwo': _maybe_array('eps_TwoTwo', eps_TwoTwo, train_params),
    'r_onset' : r_onset,
    'r_cutoff' : r_cutoff,
    'mech_relaxation_steps' : mech_relaxation_steps,
    
    'ncells_init' : ncells_init,
    'n_ones_init': n_ones_init, 
    'ncells_add': ncells_add,

    'chem_max': chem_max,
    'chem_k': chem_k,
    'chem_gamma': chem_gamma,

    'hidden_state_size':  hidden_state_size,

}

# Create cell state

In [None]:
# decorator MUST be jax_md.dataclass instead of dataclasses.dataclass
# to make dataclass compatible with jax tree operations
@jdc.dataclass
class CellState:
    '''
    Dataclass holding the complete state of the simulated cell cluster.

    Each array field has one row (leading axis) per cell slot; `key` carries the
    PRNG key used by the stochastic simulation steps.

    NOTE: keep the field declaration order unchanged - the order of the leaves
    extracted from a state (and hence the ordering of NN inputs) presumably
    follows it.
    '''

    # STATE
    position:   util.Array
    celltype:   util.Array
    radius:     util.Array
    chemical:   util.Array
    chemgrad:   util.Array
    field:     util.Array
    stress:   util.Array
    hidden_state: util.Array
    divrate:    util.Array
    divangle:   util.Array
    key:        util.Array


    @classmethod
    def default_init(cls, n_dim=2, n_chem=1, hidden_size=10):
        '''
        Create a CellState that contains no cells: every array field has a
        leading axis of length zero and the correct trailing shape, and `key`
        is None.

        Parameters
        ----------
        n_dim: int
            Number of spatial dimensions (2 or 3).
        n_chem: int
            Number of chemical species (positive integer).
        hidden_size: int
            Length of the per-cell hidden state vector.

        Returns
        -------
        CellState
        '''

        assert n_dim in (2, 3), 'n_dim must be 2 or 3'
        assert n_chem > 0 and isinstance(n_chem, int), 'n_chem must be a positive integer'

        def _no_cells(*trailing, dtype=np.float32):
            # zero-length leading (cell) axis followed by the per-cell shape
            return np.empty(shape=(0,) + trailing, dtype=dtype)

        return cls(
            position=_no_cells(n_dim),
            celltype=_no_cells(dtype=np.int8),
            radius=_no_cells(),
            chemical=_no_cells(n_chem),
            # one gradient component per (dimension, chemical) pair, flattened
            chemgrad=_no_cells(int(n_dim*n_chem)),
            field=_no_cells(),
            stress=_no_cells(),
            hidden_state=_no_cells(hidden_size),
            divrate=_no_cells(),
            divangle=_no_cells(),
            key=None,
        )

In [None]:
#!pip install memory_profiler
%load_ext memory_profiler

# Forward simulation

In [None]:
def S_cell_division(state, params, fspace=None, divangle_sigma=0.1):#, ST_grad=False):
    '''
    Performs one cell division with probability proportional to the current state divrates.

    NOTE: this experimental variant is redefined under the same name in a later
    cell of this notebook, so on a top-to-bottom run it is shadowed before it is
    ever called.

    Parameters
    ----------
    state : CellState-like dataclass
        Reads `key`, `divrate`, `divangle`, `celltype`, `position`.
    params : dict
        Only `params['cellRadBirth']` is read.
    fspace : unused
    divangle_sigma : float
        Standard deviation of the Gaussian noise added to the division angle mean.

    Returns
    -------
    (new_state, log_p) when the summed division rate is positive,
    (state, 0.) otherwise.
    '''

    def _divide(): 
        cellRadBirth = params['cellRadBirth'] #easier to reuse
        #split key: one for the next state, one to pick the cell, one for the angle
        new_key, subkey_div, subkey_place = random.split(state.key,3)
        
        # normalised division propensities
        p = state.divrate/state.divrate.sum()

        
        ### A straight-through estimator for the sampling step was tried here and
        ### abandoned: all numbers in the grad calculation are converted to float32,
        ### so the result cannot be used to index arrays.
        idx_dividing_cell = random.choice(subkey_div, a=len(p), p=p)
        # number of entries with non-zero celltype; used as the slot of the new cell
        # (assumes living cells occupy the first slots - TODO confirm)
        idx_new_cell = np.count_nonzero(state.celltype)
        
        ### POSITION OF NEW CELLS
        divangle_mu = state.divangle[idx_dividing_cell]
        # Gaussian sample around the per-cell mean, then a nonlinear map.
        # NOTE(review): the original comment called this a "differentiable clip",
        # but `tan` is unbounded, and the density below uses `arctan(angle)`
        # rather than the inverse of this transform - the sample and its density
        # look inconsistent; confirm before reusing this variant.
        z = random.normal(subkey_place)*divangle_sigma + divangle_mu
        angle = 2*np.tan(z) + np.pi
        
        #save logp for optimization purposes
        # probability density assigned to the sampled angle
        p_angle = 1/(angle*np.sqrt(8*np.pi*divangle_sigma**2))*np.exp(-(np.arctan(angle) - 2*divangle_mu - np.pi)**2/(8*divangle_sigma**2))
        log_p = np.log(p[idx_dividing_cell]) + np.log(p_angle)
        # fix small standard deviation and only optimize mean
        # unit vectors pointing in opposite directions along the division axis
        first_cell = np.array([np.cos(angle),np.sin(angle)])
        second_cell = np.array([-np.cos(angle),-np.sin(angle)])
        
        # daughters are placed symmetrically around the mother's position
        pos1 = state.position[idx_dividing_cell] + cellRadBirth*first_cell
        pos2 = state.position[idx_dividing_cell] + cellRadBirth*second_cell
        
        
        new_fields = {}
        for field in jdc.fields(state):

            value = getattr(state, field.name)
            if 'position' == field.name:
                new_fields[field.name] = value.at[idx_dividing_cell].set(pos1).at[idx_new_cell].set(pos2)
            elif 'radius' == field.name:
                new_fields[field.name] = value.at[idx_dividing_cell].set(cellRadBirth).at[idx_new_cell].set(cellRadBirth)
            elif 'key' == field.name:
                new_fields[field.name] = new_key
            else:
                # every other field: the new cell inherits the mother's value
                new_fields[field.name] = value.at[idx_new_cell].set(value[idx_dividing_cell])
        new_state = type(state)(**new_fields)
        return new_state, log_p
    
    
    def _no_division():
        return state, 0.
    # NOTE: relies on the name `jax` being bound when the function is called
    return jax.lax.cond(state.divrate.sum()>0, _divide, _no_division)


In [None]:
def prob(theta, b, sigma):
    '''
    Gaussian probability density with mean `b` and standard deviation `sigma`,
    evaluated at `theta`.
    '''
    normaliser = sigma*np.sqrt(2*np.pi)
    exponent = -0.5*(theta-b)**2/sigma**2
    return 1/normaliser*np.exp(exponent)

In [None]:
def S_cell_division(state, params, fspace=None, divangle_sigma=0.1):
    '''
    Performs one cell division with probability proportional to the current state divrates.

    The dividing cell is sampled with probability ``divrate / divrate.sum()``.
    The division angle is drawn from a Gaussian centred on that cell's
    ``divangle`` with standard deviation ``divangle_sigma``. The mean is wrapped
    in ``stop_gradient`` when sampling, so gradients with respect to it flow only
    through the returned log-probability.

    Parameters
    ----------
    state : CellState-like dataclass
        Reads `key`, `divrate`, `divangle`, `celltype`, `position`.
    params : dict
        Only `params['cellRadBirth']` is read.
    fspace : unused
        Accepted for signature compatibility with the other step functions.
    divangle_sigma : float
        Standard deviation of the division-angle distribution.

    Returns
    -------
    (new_state, log_p)
        Updated state and the log-probability of the sampled (cell, angle) pair
        when the summed division rate exceeds 1e-2; otherwise ``(state, 0.)``.
    '''

    def _divide():
        cellRadBirth = params['cellRadBirth']
        # one key for the next state, one to pick the cell, one for the angle
        new_key, subkey_div, subkey_place = random.split(state.key, 3)

        # normalised division propensities
        p = state.divrate/state.divrate.sum()

        idx_dividing_cell = random.choice(subkey_div, a=len(p), p=p)
        # number of entries with non-zero celltype; used as the slot of the new cell
        # (assumes living cells occupy the first slots - TODO confirm)
        idx_new_cell = np.count_nonzero(state.celltype)

        ### POSITION OF NEW CELLS
        divangle_mu = state.divangle[idx_dividing_cell]
        angle = random.normal(subkey_place)*divangle_sigma + lax.stop_gradient(divangle_mu)

        # Gaussian log-density of the sampled angle, evaluated directly in log
        # space (same value as log(pdf) without the exp/log round trip).
        log_p_angle = (-0.5*np.log(2*np.pi*divangle_sigma**2)
                       - (angle - divangle_mu)**2/(2*divangle_sigma**2))
        log_p = np.log(p[idx_dividing_cell]) + log_p_angle

        # daughters are placed symmetrically around the mother's position
        direction = np.array([np.cos(angle), np.sin(angle)])
        pos1 = state.position[idx_dividing_cell] + cellRadBirth*direction
        pos2 = state.position[idx_dividing_cell] - cellRadBirth*direction

        new_fields = {}
        for field in jdc.fields(state):
            value = getattr(state, field.name)
            if field.name == 'position':
                new_fields[field.name] = value.at[idx_dividing_cell].set(pos1).at[idx_new_cell].set(pos2)
            elif field.name == 'radius':
                new_fields[field.name] = value.at[idx_dividing_cell].set(cellRadBirth).at[idx_new_cell].set(cellRadBirth)
            elif field.name == 'key':
                new_fields[field.name] = new_key
            else:
                # every other field: the new cell inherits the mother's value
                new_fields[field.name] = value.at[idx_new_cell].set(value[idx_dividing_cell])

        return type(state)(**new_fields), log_p

    def _no_division():
        return state, 0.

    # `lax` comes from the notebook's main import cell, so this function does not
    # depend on a later `import jax`.
    return lax.cond(state.divrate.sum()>1e-2, _divide, _no_division)


In [None]:
import jax_md
import jax
def init_state_grow(key, empty_state, params, fspace, n_cells=5):
    '''
    Seed an empty state with one cell and grow it, one division at a time, to
    `n_cells` cells.

    NOTE: empty_state must include the following fields for this initialization
    method to work correctly:
    - position
    - celltype
    - radius
    - divrate
    - divangle

    All other array fields start at zero. The seed cell gets celltype 1, radius
    params['cellRad'], divrate 1 and divangle 0. After growth every radius is
    reset to params['cellRad'] and the system is relaxed once more.

    Returns the grown state with `key` set to None.
    '''

    assert n_cells > 0, 'Must initialize at least one cell.'

    def _allocate(name):
        # `key` takes the supplied PRNG key; arrays with at least one axis are
        # re-allocated as zeros with n_cells rows; anything else is passed through.
        if name == 'key':
            return key
        template = getattr(empty_state, name)
        if jax_md.util.is_array(template) and len(template.shape) > 0:
            return np.zeros((n_cells,) + template.shape[1:], dtype=template.dtype)
        return template

    state = type(empty_state)(**{f.name: _allocate(f.name) for f in jdc.fields(empty_state)})

    # initialize the first cell
    state = jdc.replace(
        state,
        celltype=state.celltype.at[0].set(1),
        radius=state.radius.at[0].set(params['cellRad']),
        divrate=state.divrate.at[0].set(1.),
        divangle=state.divangle.at[0].set(0.0),
    )

    # one scan iteration = divide once, grow, mechanically relax
    def _add_one_cell(current, _):
        current, _log_p = S_cell_division(current, params, fspace)
        current = S_grow_cells(current, params, fspace)
        current = S_mech_morse_relax(current, params, fspace)
        return current, 0.

    state, _ = jax.lax.scan(_add_one_cell, state, np.arange(n_cells-1))

    # set all cells to max radius and relax the system
    state = jdc.replace(state, radius=np.ones_like(state.radius)*params['cellRad'])
    state = S_mech_morse_relax(state, params, fspace)

    # key=None signals a possibly inconsistent state
    return jdc.replace(state, key=None)

In [None]:
# build space handling function and initial state
key = random.PRNGKey(20)
# free (non-periodic) space from jax_md
fspace = SpaceFunc(*space.free())

N_CELLS_INIT = params['ncells_init']



#generate empty data structure with correct shapes 
istate = CellState.default_init(n_dim=params['n_dim'], 
                                n_chem=params['n_chem'],
                                hidden_size=params['hidden_state_size']
                                )

# populate initial state by growing from single cell
key, init_key = random.split(key)
istate = init_state_grow(init_key, istate, params, fspace, N_CELLS_INIT)
# give the first n_chem cells distinct cell types 1..n_chem
# (ncells_init is set to n_chem above, so this covers every initial cell)
for i in range(n_chem):
    istate = jdc.replace(istate, celltype=istate.celltype.at[i].set(i + 1))
#istate = jdc.replace(istate, celltype=istate.celltype.at[istate.celltype.shape[0]//2:].set(2))

In [None]:
#randomly initialize hidden states
from jax.nn import softplus
key, init_key = random.split(key)
# hidden_regulation_init = 5*jax.random.normal(init_key, shape=istate.hidden_state.shape)
# uniform samples in [0, 1) are rescaled to [-10, 10) and passed through softplus,
# so every initial hidden-state entry is positive
hidden_state_init = softplus(10*(random.uniform(init_key, shape=istate.hidden_state.shape)*2 - 1))
istate = jdc.replace(istate, hidden_state=hidden_state_init)

In [None]:
# randomly initialize chemical species: uniform in [0, 1) scaled per chemical by sec_max
key, init_key = random.split(key)
ichem = random.uniform(init_key, istate.chemical.shape)*params['sec_max']
istate = jdc.replace(istate, chemical=ichem)

In [None]:
#hidden neurons per layer
from jax.nn import tanh, softplus
HID_HIDDEN = [64] 

#input fields to the network: a CellState of boolean flags, True = field is an input.
# Here the hidden-state network reads the chemical concentrations and their gradients.
use_state_fields = CellState(position=      False, 
                             celltype=      False, 
                             radius=            False, 
                             chemical=          True,
                             chemgrad=          True,
                             field=             False,
                             stress=            False,
                             divrate=           False,
                             divangle=          False,
                             hidden_state=      False,
                             key=           False
                            )


# init nn functions (softplus is applied to the MLP output)
hid_init, hid_nn_apply = hidden_state_nn(params,
                                         train_params,
                                         HID_HIDDEN,
                                         use_state_fields,
                                         train=True,
                                         transform_mlp_out=softplus,
                                         )


# initialise the network weights; the returned dicts replace params / train_params
key, init_key = random.split(key)
params, train_params = hid_init(istate, init_key)

In [None]:
# GENERATE DIVISION FUNCTION WITH NEURAL NETWORK
from jax.nn import softplus, leaky_relu, sigmoid, initializers
from jax import tree_map, tree_leaves
def div_nn(params, 
           train_params=None, 
           n_hidden=3,
           use_state_fields=None,
           train=True,
           transform_mlp_out=sigmoid,
           transform_fwd=None,
           w_init=None,
           b_init=None,
           params_name="divrate_fn",
          ):
    '''
    Build the (init, fwd) pair of a neural network that outputs division rates.

    Parameters
    ----------
    params, train_params : dict
        Parameter dict and matching trainability-flag dict. `init` writes the
        network weights / flags into them IN PLACE under `params_name`.
    n_hidden : int or list of int
        Hidden layer widths; a single output unit is always appended.
    use_state_fields : CellState of bool flags
        Selects which state fields are concatenated to form the network input.
    train : bool
        Trainability flag assigned to every network weight.
    transform_mlp_out : callable or None
        Applied to the raw MLP output (identity if None).
    transform_fwd : callable(state, divrate) or None
        Applied to the flattened network output in `fwd` (identity if None).
    w_init, b_init : haiku initializers or None
    params_name : str
        Key under which the weights are stored in `params`.

    Returns
    -------
    init(state, key) -> (params, train_params), or just params when
        train_params is not a dict
    fwd(state, params) -> per-cell division rates
    '''

    if use_state_fields is None:
        raise ValueError('Input fields flags must be passed explicitly as a CellState dataclass.')

    if type(n_hidden) in (np.int_, int):
        n_hidden = [int(n_hidden)]

    mlp_out_fn = (lambda x: x) if transform_mlp_out is None else transform_mlp_out
    fwd_fn = (lambda state,divrate: divrate) if transform_fwd is None else transform_fwd

    def _forward(in_fields, w_init=w_init, b_init=b_init):
        net = hk.nets.MLP(n_hidden+[1],
                          activation=leaky_relu,
                          activate_final=False,
                          w_init=w_init,
                          b_init=b_init,
                         )
        return mlp_out_fn(net(in_fields))

    network = hk.without_apply_rng(hk.transform(_forward))

    def _stack_inputs(state):
        # selected fields side by side, one row per cell; 1-D fields become columns
        selected = tree_leaves(eqx.filter(state, use_state_fields))
        columns = [leaf if len(leaf.shape)>1 else leaf[:,np.newaxis] for leaf in selected]
        return np.hstack(columns)

    def init(state, key):
        input_dim = _stack_inputs(state).shape[1]
        weights = network.init(key, np.zeros(input_dim))

        #add to param dict
        params[params_name] = weights

        # no need to update train_params when generating initial state
        if type(train_params) is not dict:
            return params

        #set trainability flag on every weight
        train_params[params_name] = tree_map(lambda x: train, weights)
        return params, train_params

    def fwd(state, params):
        raw = network.apply(params[params_name], _stack_inputs(state)).flatten()
        rates = fwd_fn(state, raw)
        # scale by a logistic function of the cell radius (steepness 50, centred via cellRad)
        rates = rates*logistic(state.radius+.06, 50, params['cellRad'])
        # slots with celltype < 1 (no cell) never divide
        return np.where(state.celltype<1.,0.,rates)

    return init, fwd

In [None]:
# GENERATE DIVISION FUNCTION WITH NEURAL NETWORK
from jax.nn import softplus, leaky_relu, sigmoid, initializers
from jax import tree_map, tree_leaves
def divangle_nn(params, 
           train_params=None, 
           n_hidden=3,
           use_state_fields=None,
           train=True,
           transform_mlp_out=sigmoid,
           transform_fwd=None,
           w_init=None,
           b_init=None,
           params_name="divrate_fn",
          ):
    '''
    Build the (init, fwd) pair of a neural network that outputs division angles.

    Same construction as `div_nn`, except that `fwd` does not scale the output
    by the radius-dependent logistic factor.

    NOTE(review): the default `params_name` is "divrate_fn", the same default as
    the division-rate network. Callers that rely on the default would overwrite
    the division-rate weights, so always pass `params_name` explicitly.

    Parameters
    ----------
    params, train_params : dict
        Parameter dict and matching trainability-flag dict. `init` writes the
        network weights / flags into them IN PLACE under `params_name`.
    n_hidden : int or list of int
        Hidden layer widths; a single output unit is always appended.
    use_state_fields : CellState of bool flags
        Selects which state fields are concatenated to form the network input.
    train : bool
        Trainability flag assigned to every network weight.
    transform_mlp_out : callable or None
        Applied to the raw MLP output (identity if None).
    transform_fwd : callable(state, divangle) or None
        Applied to the flattened network output in `fwd` (identity if None).
    w_init, b_init : haiku initializers or None
    params_name : str
        Key under which the weights are stored in `params`.

    Returns
    -------
    init(state, key) -> (params, train_params), or just params when
        train_params is not a dict
    fwd(state, params) -> per-cell division angles
    '''

    if use_state_fields is None:
        raise ValueError('Input fields flags must be passed explicitly as a CellState dataclass.')

    if type(n_hidden) in (np.int_, int):
        n_hidden = [int(n_hidden)]

    mlp_out_fn = (lambda x: x) if transform_mlp_out is None else transform_mlp_out
    fwd_fn = (lambda state,divrate: divrate) if transform_fwd is None else transform_fwd

    def _forward(in_fields, w_init=w_init, b_init=b_init):
        net = hk.nets.MLP(n_hidden+[1],
                          activation=leaky_relu,
                          activate_final=False,
                          w_init=w_init,
                          b_init=b_init,
                         )
        return mlp_out_fn(net(in_fields))

    network = hk.without_apply_rng(hk.transform(_forward))

    def _stack_inputs(state):
        # selected fields side by side, one row per cell; 1-D fields become columns
        selected = tree_leaves(eqx.filter(state, use_state_fields))
        columns = [leaf if len(leaf.shape)>1 else leaf[:,np.newaxis] for leaf in selected]
        return np.hstack(columns)

    def init(state, key):
        input_dim = _stack_inputs(state).shape[1]
        weights = network.init(key, np.zeros(input_dim))

        #add to param dict
        params[params_name] = weights

        # no need to update train_params when generating initial state
        if type(train_params) is not dict:
            return params

        #set trainability flag on every weight
        train_params[params_name] = tree_map(lambda x: train, weights)
        return params, train_params

    def fwd(state, params):
        raw = network.apply(params[params_name], _stack_inputs(state)).flatten()
        angles = fwd_fn(state, raw)
        # slots with celltype < 1 (no cell) get angle 0
        return np.where(state.celltype<1.,0.,angles)

    return init, fwd

In [None]:
#hidden neurons per layer (empty list = no hidden layer, a single linear map to one output)
from jax.nn import sigmoid, initializers
DIVRATE_HIDDEN = []
DIVANGLE_HIDDEN = []
# separate PRNG keys for the two network initialisations
key, divrate_key, divangle_key = random.split(key, 3)
#input fields to the network: both division networks read only the hidden state
use_state_fields_div = CellState(position=   False, celltype=   False, radius=     False, chemical=     False,field=      False,stress=    False,chemgrad=   False,hidden_state= True,divrate=    False, divangle=  False,key=        False)
# division-rate network: weights initialised to the constant 0
div_init, div_nn_apply = div_nn(params,
                                train_params,
                                DIVRATE_HIDDEN,
                                use_state_fields_div,
                                train=True,
                                transform_fwd=None,
                                w_init=hk.initializers.Constant(0.0),
                                b_init=None)
#initialize network parameters
params, train_params = div_init(istate, divrate_key)
# division-angle network: the network output is multiplied by 2*pi; its weights
# are stored under their own key "divangle_fn"
divangle_init, divangle_nn_apply = divangle_nn(params,
                                train_params,
                                DIVANGLE_HIDDEN,
                                use_state_fields_div,
                                train=True,
                                transform_fwd=lambda state,divangle: 2*np.pi*divangle,
                                w_init=None,
                                b_init=None,
                                params_name="divangle_fn",)
params, train_params = divangle_init(istate, divangle_key)

In [None]:
#hidden neurons per layer (empty list = no hidden layer)
SEC_HIDDEN = []


#input fields to the network: secretion reads the chemical concentrations and their gradients
use_state_fields_sec = CellState(position=   False, 
                             celltype=   False, 
                             radius=     False, 
                             chemical=      True,
                             chemgrad=   True,
                             field=      False,
                             stress=   False,
                             divrate=    False,
                             divangle=   False,
                             hidden_state= False, 
                             key=        False
                            )


# init nn functions
sec_init, sec_nn_apply = sec_nn(params,
                                train_params,
                                SEC_HIDDEN,
                                use_state_fields_sec,
                                train=True)


#initialize network parameters
key, init_key = random.split(key)
params, train_params = sec_init(istate, init_key)

In [None]:
import jax_md.dataclasses as jax_dataclasses
def S_set_divrate_divangle(state, params, fspace=None, divrate_fn=None, divangle_fn=None):
    '''
    Recompute the division rates and division angles of every cell.

    Both functions are evaluated on the incoming `state` (the angle function
    does not see the freshly computed rates).

    Parameters
    ----------
    state : CellState-like dataclass
    params : dict
        Forwarded unchanged to both functions.
    fspace : unused
        Accepted for signature compatibility with the other step functions.
    divrate_fn : callable(state, params)
        Returns the new `divrate` field.
    divangle_fn : callable(state, params)
        Returns the new `divangle` field.

    Returns
    -------
    Copy of `state` with `divrate` and `divangle` replaced.

    Raises
    ------
    ValueError
        If either `divrate_fn` or `divangle_fn` is None.
    '''

    # identity checks: `None == obj` would defer to obj's own __eq__
    if divrate_fn is None:
        raise ValueError('Need to pass a valid function for the calculation of the division rates.')
    if divangle_fn is None:
        raise ValueError('Need to pass a valid function for the calculation of the division angles.')

    divrates = divrate_fn(state, params)
    divangles = divangle_fn(state, params)
    return jax_dataclasses.replace(state, divrate=divrates, divangle=divangles)

In [None]:
# functions in this list will be executed in the given order
# at each simulation step

fstep = [
    # ENV CHANGES
    # one (possible) division per step, division-angle noise std = 0.3
    partial(S_cell_division, divangle_sigma=0.3),
    S_grow_cells,
    partial(S_mech_morse_relax, morse_eps_sigma='twotypes'),
    # chemical secretion driven by the secretion network
    partial(S_ss_chemfield, sec_fn=sec_nn_apply, n_iter=3),

    # SENSING
    #chemicals sensed directly
    S_chemical_gradients,
    S_fixed_chemfield,
    S_set_stress,
    # INTERNAL (HIDDEN) STATE
    # updated by the hidden-state network, with state_decay = 0
    partial(S_hidden_state, dhidden_fn=eqx.filter_jit(hid_nn_apply), state_decay=.0),
    # POLICIES
    # division rate and division angle are both set from the current state
    partial(S_set_divrate_divangle, divrate_fn=eqx.filter_jit(div_nn_apply), divangle_fn=eqx.filter_jit(divangle_nn_apply)),
]


sim_init, sim_step = simulation(fstep, params, fspace)

In [None]:
# Reference run: fresh PRNG key, 100 simulation steps, full history recorded.
# NOTE: this rebinds the global `key` and rebuilds sim_init / sim_step.
key = random.PRNGKey(5)
sim_init, sim_step = simulation(fstep, params, fspace)
fstate, traj = sim_trajectory(istate, sim_init, sim_step, 100, key, history=True) 

In [None]:
draw_circles_ctype(istate)

In [None]:
draw_circles_ctype(fstate)

In [None]:
def sim(key, b_val, param, batch_size):
    '''
    Run a batch of forward simulations with the division-angle bias set to
    `b_val` and return batch-averaged summaries.

    Parameters
    ----------
    key : PRNG key
        Split into `batch_size` sub-keys, one per simulation.
    b_val : array
        Value written to the "b" entry of layer "mlp/~/linear_1" of the
        division-angle network parameters.
    param : dict
        Simulation parameters. NOT modified: the bias override is applied to
        copies of the nested dicts, so a caller's (shallow-copied) parameter
        dict is left untouched.
    batch_size : int
        Number of independent simulations.

    Returns
    -------
    (mean_metric, mean_traj_sum)
        Average of `position_sum_of_squares` over the final states, and average
        over the batch of the second `sim_trajectory` output summed along axis 1.
    '''
    batch_subkeys = random.split(key, batch_size)

    # NOTE(review): the layer name is hard-coded. With no hidden layers the MLP
    # has a single linear layer, which Haiku presumably names "linear_0" -
    # confirm this key exists for the current network configuration.
    bias_layer = "mlp/~/linear_1"

    # rebuild only the nested dicts on the path to the bias, leaving inputs intact
    divangle_params = dict(param["divangle_fn"])
    divangle_params[bias_layer] = {**divangle_params[bias_layer], "b": b_val}
    param = {**param, "divangle_fn": divangle_params}

    sim_init, sim_step = simulation(fstep, param, fspace)

    # vectorise over the PRNG key only; every other argument is shared
    sim_traj_vmap = vmap(sim_trajectory, (None, None, None, None, 0))
    fstates, trajs = sim_traj_vmap(istate, sim_init, sim_step, 100, batch_subkeys)

    # NOTE(review): `position_sum_of_squares` is not defined in the visible part
    # of this notebook - it must be in scope before `sim` is called.
    vmap_metric_fn = vmap(position_sum_of_squares, (0,))
    return np.average(vmap_metric_fn(fstates)), np.average(np.sum(trajs, axis=1))
# sim_vmapped: one run per PRNG key, same bias value for all runs
sim_vmapped = vmap(sim, (0, None,  None, None))
# param_vmapped: keys and bias values are paired element-wise (key i with bias i)
param_vmapped = vmap(sim, (0, 0, None, None))
#sim_vmapped_batched = vmap(vmap(sim, (0, None, None, None)), (0, None, None, None))
#param_vmapped_batched = vmap(vmap(sim, (None, 0, None, None)), (0, None, None, None))

In [None]:
%%time
# 50 independent PRNG keys, reused by all four experiments below
keys = random.split(key, 50)
#batch_keys = random.split(key, 20)
# 50 bias values evenly spaced in [0, 2*pi], as a column so each row is one bias
bs = np.atleast_2d(np.linspace(0, 2*np.pi, 50)).T
# fixed bias b = pi vs. swept bias, each with batch size 1 and batch size 50
# (params.copy() is a shallow copy)
losses_sim, lprob_sim = sim_vmapped(keys, np.array([np.pi]), params.copy(), 1)
losses_param, lprob_param = param_vmapped(keys, bs, params.copy(), 1)
losses_sim_50, lprob_sim_50 = sim_vmapped(keys, np.array([np.pi]), params.copy(), 50)
losses_param_50, lprob_param_50 = param_vmapped(keys, bs, params.copy(), 50)

In [None]:
# Loss distributions for the four experiments. All histograms share the same bin
# edges and style so they can be compared directly. The batched runs use a
# batch size of 50, and the legend labels now say so.
loss_bins = np.linspace(5, 10, 30)
hist_style = dict(alpha=0.6, edgecolor='black', bins=loss_bins)

plt.hist(losses_sim, label=r'b = $\pi$', **hist_style);
plt.hist(losses_sim_50, label=r'b = $\pi$, batch 50', **hist_style);

plt.hist(losses_param, label=r'b $\in$ [0, $2\pi$]', **hist_style);
plt.hist(losses_param_50, label=r'b $\in$ [0, $2\pi$], batch 50', **hist_style);

plt.xlabel("Loss")
plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left');

In [None]:
# Loss as a function of the division-angle bias, for batch sizes 1 and 50.
plt.plot(np.linspace(0, 2*np.pi, 50), losses_param, label='1')
plt.plot(np.linspace(0, 2*np.pi, 50), losses_param_50, label='50')
plt.xlabel("b value (rads)")
plt.ylabel("loss");
plt.legend(title="batch size");
# single vertical dashed black line at b = 1.1
# NOTE(review): an earlier comment here mentioned lines at pi/2 and 3*pi/2 - confirm which is intended
plt.axvline(1.1, color='black', linestyle='--')


In [None]:
# Second output of the bias sweep (plotted as "log prob") as a function of the
# division-angle bias, for batch sizes 1 and 50.
plt.plot(np.linspace(0, 2*np.pi, 50), lprob_param, label='1')
plt.plot(np.linspace(0, 2*np.pi, 50), lprob_param_50, label='50')
plt.xlabel("b value (rads)")
plt.ylabel("log prob");
plt.legend(title="batch size");

In [None]:
# Snapshot of the recorded reference trajectory, coloured by chemical index 1.
# NOTE(review): the trajectory was generated with 100 steps - confirm that index
# 100 is in range (JAX array indexing clamps out-of-range indices instead of raising).
state = get_state(100, traj)
draw_circles_chem(state, 1);    

In [None]:
def draw_circles_divangle(state, idx=0, probability=False, colorbar=True, ax=None, cm=plt.cm.coolwarm, grid=False, labels=False, edges=False, cm_edges=plt.cm.coolwarm, **kwargs):
    """Draw every alive cell as a circle coloured by its division angle.

    Args:
        state: object exposing ``celltype``, ``divangle``, ``position`` and
            ``radius`` arrays indexed by cell. Cells with ``celltype > 0`` are drawn.
        idx: only selects the colorbar caption ('Division Angle mean' when 0,
            'Division Angle std' otherwise).
        probability: unused; kept for signature compatibility.
        colorbar: whether to attach a colorbar to ``ax``.
        ax: axes to draw on; a new one is created with ``plt.axes()`` when None.
        cm: colormap applied to the min-max normalised division angles.
        grid: show a faint grid instead of hiding spines and ticks.
        labels: write the index of each drawn cell at its centre.
        edges: colour the circle edge by cell type using ``cm_edges``.
        cm_edges: colormap used for the edges.
        **kwargs: forwarded to ``plt.Circle``.

    Returns:
        ``(figure, ax)`` where ``figure`` is the figure that owns ``ax``.
    """
    if ax is None:
        ax = plt.axes()
    # Always act on the axes/figure we draw on, not on pyplot's "current" ones,
    # so that a caller-supplied ``ax`` in a multi-panel figure is handled correctly.
    fig = ax.figure

    alive_cells = state.celltype > 0
    positions = state.position[alive_cells]
    radii = state.radius[alive_cells]
    angles = state.divangle[alive_cells]

    # Min-max normalise to [0, 1]; the tiny offset avoids 0/0 when all angles are equal.
    angle_min, angle_max = angles.min(), angles.max()
    color = cm((angles - angle_min + 1e-20) / (angle_max - angle_min + 1e-20))

    if edges:
        # edge colouring is only meaningful for two cell types
        edge_colors = cm_edges(np.float32(state.celltype-1)[alive_cells])
    else:
        edge_colors = [None] * len(color)

    for i, (cell, radius, c, ec) in enumerate(zip(positions, radii, color, edge_colors)):
        if edges:
            circle = plt.Circle(cell, radius=radius, fc=c, ec=ec, lw=2, alpha=.5, **kwargs)
        else:
            circle = plt.Circle(cell, radius=radius, fc=c, alpha=.5, **kwargs)
        ax.add_patch(circle)
        if labels:
            ax.text(*cell, str(i), horizontalalignment='center', verticalalignment='center')

    #show colorbar
    if colorbar:
        sm = plt.cm.ScalarMappable(cmap=cm, norm=plt.Normalize(vmin=angle_min, vmax=angle_max))
        sm.set_array([])
        cbar_text = 'Division Angle mean' if idx == 0 else 'Division Angle std'
        # ax= is required because the mappable is not attached to any axes.
        cbar = fig.colorbar(sm, ax=ax, shrink=0.7, alpha=.5) # rule of thumb
        cbar.set_label(cbar_text, labelpad=20)

    ## calculate ax limits: one square window containing all alive cells plus a margin
    max_coord = max([np.max(positions[:, 0]), np.max(positions[:, 1])])+3
    min_coord = min([np.min(positions[:, 0]), np.min(positions[:, 1])])-3

    ax.set_xlim(min_coord, max_coord)
    ax.set_ylim(min_coord, max_coord)

    #scale x and y in the same way
    ax.set_aspect('equal', adjustable='box')

    #white bg color for ax
    ax.set_facecolor([1,1,1])

    if grid:
        ax.grid(alpha=.2)
    else:
        #remove axis spines and ticks
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)

        ax.set_xticks([])
        ax.set_yticks([])

    # dark figure patch made fully transparent
    background_color = [56 / 256] * 3
    fig.patch.set_facecolor(background_color)
    fig.patch.set_alpha(0)
    fig.set_size_inches(8, 8)
    return fig, ax

In [None]:
# Plot the division angle of each alive cell in `fstate`
# (idx defaults to 0, so the colorbar is captioned "Division Angle mean").
draw_circles_divangle(fstate);    

In [None]:
# Value of the optimisation metric on the un-optimised final state.
# NOTE(review): in the visible part of the notebook position_sum_of_squares is only
# defined in a later cell; unless it is also defined earlier, this fails on Restart & Run All.
position_sum_of_squares(fstate)

# Optimization

In [None]:
import optax
from jax import value_and_grad
from jax_morph.optimization.losses import loss, avg_loss
from jax_morph.optimization.state_metrics import cv_divrates

In [None]:
from jax import tree_map
def train(key,
          params, train_params, 
          EPOCHS, 
          EPISODES_PER_UPDATE, 
          EPISODES_PER_EVAL, 
          LEARNING_RATE, 
          rloss, 
          sloss, 
          fstep, 
          fspace, 
          istate,
          normalize_grads=True,
          ):
    """Optimise the trainable part of ``params`` with Adam.

    ``params`` is split with ``eqx.partition(params, train_params)`` into the
    trainable pytree ``p`` and the frozen remainder ``hp``. Gradients come from
    ``avg_loss`` evaluated with ``rloss`` on ``EPISODES_PER_UPDATE`` keys; the
    reported loss is ``avg_loss`` evaluated with ``sloss`` on
    ``EPISODES_PER_EVAL`` keys.

    Args:
        key: JAX PRNG key; split internally for every batch.
        params: full parameter pytree.
        train_params: filter spec passed to ``eqx.partition``.
        EPOCHS: number of optimiser steps.
        EPISODES_PER_UPDATE: number of keys per gradient estimate.
        EPISODES_PER_EVAL: number of keys per loss evaluation.
        LEARNING_RATE: passed unchanged to ``optax.adam``.
        rloss: loss function used for the gradient.
        sloss: loss function used for the reported loss.
        fstep, fspace, istate: forwarded to ``avg_loss``.
        normalize_grads: if True, every gradient leaf is divided by its own
            norm (plus 1e-10) before the optimiser update.

    Returns:
        ``(loss_t, params_t, grads_t)``: lists of length ``EPOCHS + 1`` holding the
        evaluated loss, the trainable parameters and the un-normalised gradients
        at the initial point and after every update.
    """

    p, hp = eqx.partition(params, train_params)

    # init optimizer
    optimizer = optax.adam(LEARNING_RATE)
    opt_state = optimizer.init(p)

    sim_kwargs = dict(fstep=fstep, fspace=fspace, istate=istate)
    loss_and_grad = value_and_grad(avg_loss)

    def _split(key, n):
        # advance ``key`` and return ``n`` fresh subkeys stacked in one array
        key, *subkeys = random.split(key, n+1)
        return key, np.array(subkeys)

    #--------------------------------------------
    #store loss at initial params and calc grad 

    key, batch_subkeys = _split(key, EPISODES_PER_UPDATE)
    _, grads = loss_and_grad(p, hp, rloss, batch_subkeys, **sim_kwargs)

    key, eval_subkeys = _split(key, EPISODES_PER_EVAL)
    l = avg_loss(p, hp, sloss, eval_subkeys, **sim_kwargs)
    print(float(l))

    #store initial params and loss
    loss_t = [float(l)]
    params_t = [p]
    grads_t = [grads]

    #--------------------------------------------

    for t in range(EPOCHS):
        #generate batch of random keys (used for the gradient at the END of this epoch)
        key, batch_subkeys = _split(key, EPISODES_PER_UPDATE)

        #normalize grads (new pytree; the raw grads already stored in grads_t are untouched)
        if normalize_grads:
            grads = tree_map(lambda x: x/(np.linalg.norm(x)+1e-10), grads)

        # optimizer step
        updates, opt_state = optimizer.update(grads, opt_state, p)
        p = eqx.apply_updates(p, updates)

        #clip diffCoeff if trained; identity test instead of `!=` so that an
        #array-valued entry is never compared element-wise against None
        if 'diffCoeff' in p and p['diffCoeff'] is not None:
            p['diffCoeff'] = np.clip(p['diffCoeff'],.2)

        #estimate actual avg loss
        key, eval_subkeys = _split(key, EPISODES_PER_EVAL)
        l = avg_loss(p, hp, sloss, eval_subkeys, **sim_kwargs)

        # surrogate loss and grads (only the grads are used)
        _, grads = loss_and_grad(p, hp, rloss, batch_subkeys, **sim_kwargs)

        #store
        loss_t += [float(l)]
        params_t += [p]
        grads_t += [grads]
        print(float(l))

    # final loss printed once more, as before
    print(float(l))
    return loss_t, params_t, grads_t

In [None]:
# Training configuration used by the train(...) calls below.
EPOCHS = 300                # number of optimiser steps
EPISODES_PER_UPDATE = 100   # PRNG keys per gradient estimate
EPISODES_PER_EVAL = 100     # PRNG keys per loss evaluation

LEARNING_RATE = 1e-3        # passed to optax.adam inside train()


############## define loss parameters
# NOTE(review): METRIC_FN, TARGET_METRIC and LAMBDA are not referenced by the
# visible training cells, which pass position_sum_of_squares / v_loss / heart_loss directly.
METRIC_FN = cv_divrates
TARGET_METRIC = 0.

LAMBDA = 0. #.01

In [None]:
def entropy_fn(state):
    """Return ``sum(p * log(p + 1e-12))`` over ``state.divrate``.

    Note that no minus sign is applied, so this is the negative of the usual
    Shannon-entropy expression. The 1e-12 offset keeps the log finite at zero.
    """
    rates = state.divrate
    log_rates = np.log(rates + 1e-12)
    return np.sum(rates * log_rates)

In [None]:
def position_sum_of_squares(state, coordinate=0):
    """Mean squared value of one position coordinate over alive cells.

    Cells with ``state.celltype > 0`` count as alive; the chosen column of
    ``state.position`` is zeroed for all other cells, squared, summed and
    divided by the number of alive cells.
    """
    is_alive = state.celltype > 0
    masked_coord = state.position[:, coordinate] * is_alive
    total = np.sum(masked_coord**2)
    return total / np.sum(is_alive)

In [None]:
def save_data(params_tt, loss_tt, grads_tt, PATH_NAME, file_name, lr):
    """Pickle the parameter, loss and gradient histories of one training run.

    Three files with the same name are written under
    ``ROOT_DIR + PATH_NAME`` in the sub-directories ``params_tt/``, ``loss_tt/``
    and ``grads_tt/``. The file name encodes ``file_name``, ``lr`` and the
    notebook globals ``EPOCHS``, ``EPISODES_PER_UPDATE`` and ``HID_HIDDEN``
    as they are at call time.

    All missing directories (including the three sub-directories) are created,
    and every file handle is closed after writing.

    Args:
        params_tt, loss_tt, grads_tt: objects to pickle.
        PATH_NAME: path relative to ``ROOT_DIR``; concatenated as a string, so it
            is expected to end with '/'.
        file_name: run identifier used as the file-name prefix.
        lr: learning-rate tag inserted into the file name.
    """
    base_dir = ROOT_DIR + PATH_NAME
    Path(base_dir).mkdir(parents=True, exist_ok=True)

    run_tag = f'{file_name}_lr{lr}_epochs{EPOCHS}_episodes{EPISODES_PER_UPDATE}_{HID_HIDDEN}_hidden'

    for subdir, payload in (('params_tt', params_tt), ('loss_tt', loss_tt), ('grads_tt', grads_tt)):
        target = Path(base_dir + f'{subdir}/{run_tag}')
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, 'wb') as handle:
            pickle.dump(payload, handle)


In [None]:
# Minimise position_sum_of_squares (metric_type='cost').
# rloss (REINFORCE=True) provides the gradient, sloss (REINFORCE=False) the reported loss.
rloss = eqx.filter_jit(partial(loss, metric_fn=position_sum_of_squares, REINFORCE=True, metric_type='cost', GAMMA=0.90))
sloss = eqx.filter_jit(partial(loss, metric_fn=position_sum_of_squares, metric_type='cost', REINFORCE=False, GAMMA=0.90))
loss_t, params_t, grads_t = train(key, params, train_params, EPOCHS, EPISODES_PER_UPDATE, EPISODES_PER_EVAL, LEARNING_RATE, rloss, sloss, fstep, fspace, istate)

In [None]:
# Same optimisation with discount factor 0.80 (results are stored/saved as "gamma80").
# The discount only enters the REINFORCE loss, so it has to be set on rloss;
# previously rloss still used 0.90 and only sloss -- which has REINFORCE=False -- was changed.
rloss = eqx.filter_jit(partial(loss, metric_fn=position_sum_of_squares, REINFORCE=True, metric_type='cost', GAMMA=0.80))
sloss = eqx.filter_jit(partial(loss, metric_fn=position_sum_of_squares, metric_type='cost', REINFORCE=False, GAMMA=0.80))
loss_80_2, params_80_2, grads_80_2 = train(key, params, train_params, EPOCHS, EPISODES_PER_UPDATE, EPISODES_PER_EVAL, LEARNING_RATE, rloss, sloss, fstep, fspace, istate)

In [None]:
# save_data requires the learning-rate tag as its sixth argument; pass the value used for this run.
save_data(params_80_2, loss_80_2, grads_80_2, "optimizations/divangle/", "gamma80_2", LEARNING_RATE)

In [None]:
# Plot loss for each gamma value (solid: first run, dashed: second run with the same gamma).
# NOTE(review): of the loss_* lists used here only loss_80_2 is produced in the visible
# part of the notebook; the others must come from cells that were edited/re-run or removed.
#plt.plot(loss_10, label='gamma=0.10')
plt.plot(loss_80, label='gamma=0.80', color='green')
plt.plot(loss_80_2, color='green', linestyle='dashed')
plt.plot(loss_90, label='gamma=0.90', color='black')
plt.plot(loss_90_2, color='black', linestyle='dashed')
plt.plot(loss_99, label='gamma=0.99', color='maroon')
plt.plot(loss_99_2, color='maroon', linestyle='dashed')
plt.title('Loss')
plt.xlabel('Epoch')
# Show only the first 200 epochs
plt.xlim(0, 200)
plt.legend();

In [None]:
# Evaluated loss per epoch for the run stored in loss_t (index 0 is the initial parameters).
plt.plot(loss_t)
plt.title('Loss')
plt.xlabel('Epoch')

In [None]:
# Merge the trainable parameters stored at index 250 of the training history with
# the full parameter set, then re-run the forward simulation with them.
opt_params = eqx.combine(params_t[250], params)
sim_init, sim_step = simulation(fstep, opt_params, fspace)
# 100 is passed as the simulation length; history=True also returns the trajectory.
fstate_opt, traj_opt = sim_trajectory(istate, sim_init, sim_step, 100, key, history=True) 

In [None]:
# Distribution of division angles with optimised vs. initial parameters.
# NOTE(review): no celltype > 0 mask is applied here, so entries of non-alive cells
# (if the state arrays contain any) are included in both histograms.
plt.hist(fstate_opt.divangle, label='opt');
plt.hist(fstate.divangle, label='init');
plt.legend()
plt.xlabel("divangle");

In [None]:
# Division angle of each alive cell in the optimised final state.
draw_circles_divangle(fstate_opt);

In [None]:
# Optimised final state drawn with draw_circles_chem, second argument 1
# (presumably the chemical index -- confirm against draw_circles_chem).
draw_circles_chem(fstate_opt, 1);

In [None]:
# Same plot using draw_circles_chem's default for the second argument.
draw_circles_chem(fstate_opt);

In [None]:
# Project the per-cell hidden states of the optimised final state to 2D with PCA
# and colour each point by that cell's division rate.
# NOTE(review): no celltype > 0 mask is applied, so rows of non-alive cells
# (if the state arrays contain any) take part in the fit and the scatter plot.
from sklearn.decomposition import PCA
pca = PCA(n_components=2)

# pca_hidden = pca.fit_transform(np.log(fstate.hidden_state + 1e-40))
pca_hidden = pca.fit_transform(fstate_opt.hidden_state)


plt.scatter(pca_hidden[:,0], pca_hidden[:,1], c=fstate_opt.divrate, cmap='coolwarm', alpha=.5)
plt.xlabel('PCA 1')
plt.ylabel('PCA 2');

In [None]:
def mask_metric(mask_fn=None, reward=3., penalty=-1.):
    """Build a state metric that rewards alive cells lying inside a mask.

    Args:
        mask_fn: callable mapping ``state.position`` to a per-cell boolean mask.
            It is called unconditionally, so it must be provided.
        reward: score added for each alive cell inside the mask.
        penalty: score added for each alive cell outside the mask.

    Returns:
        ``metric(state)``, which returns the summed per-cell score of alive cells
        (``celltype > 0``) minus ``0.5 * |sum of alive x-coordinates|``.
    """

    def metric(state):

        alive = state.celltype > 0

        mask = mask_fn(state.position)

        # per-cell score, zeroed for non-alive cells
        m = np.sum(np.where(mask, reward, penalty)*alive)

        # penalize asymmetric growth (net x-displacement of alive cells)
        m += -.5*np.abs(np.sum(state.position[:, 0] * alive))
        # penalize unequal cell types
        #m += -.5*np.abs(np.sum(alive) - np.sum(np.where(state.celltype == 1, 1, 0)))
        return m

    return metric

def v_mask(pos):
    '''
    Boolean mask selecting positions inside a V-shaped band.

    A point (x, y) is inside when y + 1.5 lies strictly between 0.5*|x| and
    2.5 + 0.5*|x|, and y is strictly positive.
    '''
    half_abs_x = .5*np.abs(pos[:,0])
    shifted_y = pos[:,1]+1.5

    above_lower_arm = shifted_y > half_abs_x
    below_upper_arm = shifted_y < 2.5+half_abs_x
    in_upper_half = pos[:,1]>0.

    return above_lower_arm & below_upper_arm & in_upper_half

def heart_mask(pos):
    '''
    Boolean mask selecting positions inside the heart curve
    x^2 + (y - |x|^(2/3))^2 < 35.

    The absolute value is required: raising a negative x to the fractional power
    2/3 yields NaN, which made every point with x < 0 fall outside the mask.
    '''
    x, y = pos[:, 0], pos[:, 1]
    return np.power(x, 2) + np.power(y - np.power(np.abs(x), 2.0/3.0), 2) < 35.0

In [None]:
# State metrics (higher is better) for the V-shaped and heart-shaped targets,
# using mask_metric's default reward and penalty.
v_loss = mask_metric(v_mask)
heart_loss = mask_metric(heart_mask)

In [None]:
# Training configuration for the shape optimisations (overrides the earlier values).
#LEARNING_RATE = 5e-3
EPOCHS = 1000
EPISODES_PER_UPDATE = 100
EPISODES_PER_EVAL = 100

In [None]:
# Learning-rate schedule passed to optax.adam: starts at 2e-3 and is multiplied
# by 0.2 at step 300 (i.e. 4e-4 afterwards).
LEARNING_RATE = optax.piecewise_constant_schedule(2e-3, {300: .2})

In [None]:
def mask_metric(mask_fn=None, reward=3., penalty=-1.):
    """Build a state metric that rewards alive cells lying inside a mask.

    Args:
        mask_fn: callable mapping ``state.position`` to a per-cell boolean mask.
            It is called unconditionally, so it must be provided.
        reward: score added for each alive cell inside the mask.
        penalty: score added for each alive cell outside the mask.

    Returns:
        ``metric(state)``, which returns the summed per-cell score of alive cells
        (``celltype > 0``), minus ``0.5 * |sum of alive x-coordinates|``, minus
        ``0.5 * (n_alive - 60)^2`` whenever more than 60 cells are alive.
    """

    def metric(state):

        alive = state.celltype > 0
        n = np.sum(alive)

        mask = mask_fn(state.position)

        # per-cell score, zeroed for non-alive cells
        m = np.sum(np.where(mask, reward, penalty)*alive)

        # penalize asymmetric growth (net x-displacement of alive cells)
        m += -.5*np.abs(np.sum(state.position[:, 0] * alive))
        # penalize number of cells more than 60 (heaviside is 0 at exactly 60)
        m += -.5*np.heaviside(n - 60.0, 0)*np.power(n - 60, 2)
        #m += -.5*np.abs(np.sum(alive) - np.sum(np.where(state.celltype == 1, 1, 0)))
        return m

    return metric

def v_mask(pos):
    '''
    Boolean mask selecting positions inside a V-shaped band.

    A point (x, y) is inside when y + 1.5 lies strictly between 0.5*|x| and
    2.5 + 0.5*|x|, and y is strictly positive.
    '''
    half_abs_x = .5*np.abs(pos[:,0])
    shifted_y = pos[:,1]+1.5

    above_lower_arm = shifted_y > half_abs_x
    below_upper_arm = shifted_y < 2.5+half_abs_x
    in_upper_half = pos[:,1]>0.

    return above_lower_arm & below_upper_arm & in_upper_half

def heart_mask(pos):
    '''
    Boolean mask selecting positions inside the heart curve
    x^2 + (y - |x|^(2/3))^2 < 35.

    The absolute value is required: raising a negative x to the fractional power
    2/3 yields NaN, which made every point with x < 0 fall outside the mask.
    '''
    x, y = pos[:, 0], pos[:, 1]
    return np.power(x, 2) + np.power(y - np.power(np.abs(x), 2.0/3.0), 2) < 35.0

In [None]:
# Rebuild the V-shape metric from the mask_metric defined in the cell above
# (the version with the >60-cells penalty).
v_loss = mask_metric(v_mask)

In [None]:
# V-shape metric evaluated on entries 0..99 of the stored optimised trajectory.
losses = [v_loss(get_state(i, traj_opt)) for i in range(100)]

In [None]:
@eqx.filter_jit
@eqx.filter_vmap(default=None, kwargs=dict(sim_key=0))
def loss(params, 
         hyper_params,
         fstep,
         fspace,
         istate,
         sim_key=None,
         metric_fn=None,
         metric_type='reward',
         REINFORCE=True,
         GAMMA=.99,
         ncells_add=None
         ):
    '''
    Reinforce loss on trajectory (with discounting). Rewards are differences in successive state metrics.

    If REINFORCE=False, then the loss is just the state measure on the final state
    (metric_type and GAMMA are not used in that case).

    GAMMA is the discount factor for the calculation of the returns.

    If metric_type='reward', it is maximized, if metric_type='cost', it is minimized.
    Any other value raises ValueError when REINFORCE=True.

    The function is vmapped over `sim_key` only (axis 0); every other argument is broadcast.
    If `ncells_add` is None, hyper_params['ncells_add'] is used as the number of simulation steps.

    '''

    #simulation length
    ncells_add = hyper_params['ncells_add'] if ncells_add is None else ncells_add
    
    def _sim_trajectory(istate, sim_init, sim_step, ncells_add, key=None):
        # run the simulation with lax.scan, recording the log-probability returned by
        # each step and the metric of the state after each step

        state = sim_init(istate, ncells_add, key)

        def scan_fn(state, i):
            state, logp = sim_step(state)
            measure = metric_fn(state)
            return state, (logp, measure)


        iterations = np.arange(ncells_add)
        fstate, aux = jax.lax.scan(scan_fn, state, iterations)

        return fstate, aux

    # merge params dicts
    all_params = eqx.combine(params, hyper_params)

    #forward pass - simulation
    sim_init, sim_step = simulation(fstep, all_params, fspace)
    _, (logp, measures) = _sim_trajectory(istate, sim_init, sim_step, ncells_add, sim_key)

    
    if REINFORCE:
        
        def _returns_rec(rewards):
            # discounted return G_t = r_t + GAMMA * G_{t+1}, accumulated backwards in time
            Gs=[]
            G=0
            for r in rewards[::-1]:
                G = r+G*GAMMA
                Gs.append(G)

            return np.array(Gs)[::-1]
        
        
        # prepend the metric of the initial state so that diff() yields one reward per step
        measures = np.append(np.array([metric_fn(istate)]),measures)
        
        if metric_type=='reward':
            rewards = np.diff(measures)
        elif metric_type=='cost':
            rewards = -np.diff(measures)
        else:
            # previously fell through to an UnboundLocalError on `rewards`
            raise ValueError(f"metric_type must be 'reward' or 'cost', got {metric_type!r}")


        returns = _returns_rec(rewards)
        # standardizing returns helps with convergence
        returns = (returns-returns.mean())/(returns.std()+1e-8)
        # returns are treated as constants: gradients flow only through logp
        loss = -np.sum(logp*jax.lax.stop_gradient(returns))

        return loss

    else:
        return measures[-1]


In [None]:
# Scratch calculation (evaluates to 0.001).
0.1/100.0

In [None]:
# Override the number of optimiser steps for the runs below
# (train() then returns histories with EPOCHS + 1 = 101 entries).
EPOCHS = 100

In [None]:
# V-shape optimisation training BOTH the division-angle and division-rate networks.
# train_params is modified in place: the first linear layer of both networks is marked trainable.
train_params["divangle_fn"] = {'mlp/~/linear_0': {'b': True, 'w': True}}
train_params["divrate_fn"] = {'mlp/~/linear_0': {'b': True, 'w': True}}
# v_loss is a reward (higher is better), hence metric_type='reward'.
rloss = eqx.filter_jit(partial(loss, metric_fn=v_loss, REINFORCE=True, metric_type='reward', GAMMA=0.90))
sloss = eqx.filter_jit(partial(loss, metric_fn=v_loss, metric_type='reward', REINFORCE=False, GAMMA=0.90))
loss_tn, params_tn, grads_tn = train(key, params, train_params, EPOCHS, EPISODES_PER_UPDATE, EPISODES_PER_EVAL, LEARNING_RATE, rloss, sloss, fstep, fspace, istate)

In [None]:
# V-shape optimisation training ONLY the division-rate network
# (division-angle layer frozen). train_params is modified in place.
train_params["divangle_fn"] = {'mlp/~/linear_0': {'b': False, 'w': False}}
train_params["divrate_fn"] = {'mlp/~/linear_0': {'b': True, 'w': True}}
rloss = eqx.filter_jit(partial(loss, metric_fn=v_loss, REINFORCE=True, metric_type='reward', GAMMA=0.90))
sloss = eqx.filter_jit(partial(loss, metric_fn=v_loss, metric_type='reward', REINFORCE=False, GAMMA=0.90))
loss_tn_rate, params_tn_rate, grads_tn_rate = train(key, params, train_params, EPOCHS, EPISODES_PER_UPDATE, EPISODES_PER_EVAL, LEARNING_RATE, rloss, sloss, fstep, fspace, istate)

In [None]:
# V-shape optimisation training ONLY the division-angle network
# (division-rate layer frozen). train_params is modified in place.
train_params["divangle_fn"] = {'mlp/~/linear_0': {'b': True, 'w': True}}
train_params["divrate_fn"] = {'mlp/~/linear_0': {'b': False, 'w': False}}
rloss = eqx.filter_jit(partial(loss, metric_fn=v_loss, REINFORCE=True, metric_type='reward', GAMMA=0.90))
sloss = eqx.filter_jit(partial(loss, metric_fn=v_loss, metric_type='reward', REINFORCE=False, GAMMA=0.90))
loss_tn_angle, params_tn_angle, grads_tn_angle = train(key, params, train_params, EPOCHS, EPISODES_PER_UPDATE, EPISODES_PER_EVAL, LEARNING_RATE, rloss, sloss, fstep, fspace, istate)

In [None]:
# Persist the three V-shape runs. The last argument is only a tag inserted into the
# file name; the rest of the name uses the CURRENT values of the globals EPOCHS,
# EPISODES_PER_UPDATE and HID_HIDDEN, so they must still match the runs being saved.
save_data(params_tn_angle, loss_tn_angle, grads_tn_angle, "optimizations/divangle/", "vshape_2type_divangle", "schedule_0.002_0.0004")
save_data(params_tn, loss_tn, grads_tn, "optimizations/divangle/", "vshape_2type_divangle_divrate", "schedule_0.002_0.0004")
save_data(params_tn_rate, loss_tn_rate, grads_tn_rate, "optimizations/divangle/", "vshape_2type_divrate", "schedule_0.002_0.0004")

In [None]:
# List the saved loss histories (path is relative to the kernel's working directory).
%ls ../data/optimizations/divangle/loss_tt

In [None]:
# Reload previously saved loss histories of the 1000-epoch V-shape runs.
# pickle.load executes arbitrary code from the file: only load files you created yourself.
# NOTE: this overwrites `loss_t` from the earlier position_sum_of_squares run.
with open('../data/optimizations/divangle/loss_tt/vshape_2type_divangle_lrschedule_0.002_0.0004_epochs1000_episodes100_[64]_hidden', 'rb') as handle:
    loss_t_angle = pickle.load(handle)
with open('../data/optimizations/divangle/loss_tt/vshape_2type_divrate_lrschedule_0.002_0.0004_epochs1000_episodes100_[64]_hidden', 'rb') as handle:
    loss_t_rate = pickle.load(handle)
with open('../data/optimizations/divangle/loss_tt/vshape_2type_divangle_divrate_lrschedule_0.002_0.0004_epochs1000_episodes100_[64]_hidden', 'rb') as handle:
    loss_t = pickle.load(handle)


In [None]:
# Evaluated V-shape reward per update step for the three in-memory runs
# (both networks trained / division rate only / division angle only).
plt.plot(loss_tn, label="divangle + divrate", color='black', linestyle='solid');
plt.plot(loss_tn_rate, label="divrate only", color='maroon', linestyle='dashed');
plt.plot(loss_tn_angle, label="divangle only", color='darkblue', linestyle='dotted');
plt.ylabel("reward")
plt.xlabel("update step");
plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left');

In [None]:
# Reward history of the run with both networks trained.
plt.plot(loss_tn)

In [None]:
# Re-run the forward simulation with the parameters stored at index 200 of params_tn.
# NOTE(review): train() returns EPOCHS + 1 entries; with EPOCHS = 100 as set in an earlier
# cell, index 200 is out of range -- this presumably refers to a longer run; confirm.
opt_params = eqx.combine(params_tn[200], params)
sim_init, sim_step = simulation(fstep, opt_params, fspace)
fstate_opt, traj_opt = sim_trajectory(istate, sim_init, sim_step, 100, key, history=True) 

In [None]:
# Division rates for entry 100 of the optimised trajectory.
draw_circles_divrate(get_state(100, traj_opt));

In [None]:
# Division angle of each alive cell in the optimised final state.
draw_circles_divangle(fstate_opt);

In [None]:
# Division-angle distribution per cell type in the optimised final state.
# Bins cover [2.0, 5.0] only; angles outside that interval are not shown.
plt.hist(fstate_opt.divangle[fstate_opt.celltype == 1], color='maroon', label="type 1", alpha=0.7, bins=np.linspace(2.0, 5.0, 30));
plt.hist(fstate_opt.divangle[fstate_opt.celltype == 2], color='green', label="type 2", alpha=0.7, bins=np.linspace(2.0, 5.0, 30));
plt.xlabel("division angle")
plt.legend();

In [None]:
# Division rates in the optimised final state.
draw_circles_divrate(fstate_opt);

In [None]:
# Optimised final state drawn with draw_circles_chem, second argument 1
# (presumably the chemical index -- confirm against draw_circles_chem).
draw_circles_chem(fstate_opt, 1);

In [None]:
# Debugging aid: make JAX raise an error as soon as a computation produces NaN.
# This is a global setting that stays active for every cell executed afterwards.
from jax import config
config.update("jax_debug_nans", True)

In [None]:
# Compare two stored loss histories.
# NOTE(review): the labels ('reinforce' / 'simple') and the y-label ('CV division rates')
# do not obviously match what loss_t and loss_tn hold at this point in the notebook
# (V-shape rewards from the cells above) -- confirm they are not left over from an older experiment.
plt.plot(np.array(loss_t), 'r', label='reinforce')
plt.plot(np.array(loss_tn), 'g', label='simple')
plt.grid(alpha=.2)
plt.xlabel('Training Steps')
plt.ylabel('CV division rates');
plt.legend()

In [None]:
# Heart-shape optimisation training both the division-angle and division-rate networks.
# train_params is modified in place; loss_tn / params_tn / grads_tn from the V-shape run are overwritten.
train_params["divangle_fn"] = {'mlp/~/linear_0': {'b': True, 'w': True}}
train_params["divrate_fn"] = {'mlp/~/linear_0': {'b': True, 'w': True}}
rloss = eqx.filter_jit(partial(loss, metric_fn=heart_loss, REINFORCE=True, metric_type='reward', GAMMA=0.90))
sloss = eqx.filter_jit(partial(loss, metric_fn=heart_loss, metric_type='reward', REINFORCE=False, GAMMA=0.90))
loss_tn, params_tn, grads_tn = train(key, params, train_params, EPOCHS, EPISODES_PER_UPDATE, EPISODES_PER_EVAL, LEARNING_RATE, rloss, sloss, fstep, fspace, istate)

In [None]:
# Reward history of the heart-shape run.
plt.plot(loss_tn)
# vertical dashed red line at update step 390 (the checkpoint simulated in a later cell)
plt.axvline(x=390, color='r', linestyle='--')


In [None]:
def draw_circles_ctype(state, ax=None, cm=plt.cm.coolwarm, grid=False, **kwargs):
    """Draw every alive cell as a circle coloured by its cell type.

    Note: this definition replaces the ``draw_circles_ctype`` imported from
    ``jax_morph.visualization`` at the top of the notebook.

    Args:
        state: object exposing ``celltype``, ``position`` and ``radius`` arrays
            indexed by cell. Cells with ``celltype > 0`` are drawn.
        ax: axes to draw on; a new one is created with ``plt.axes()`` when None.
        cm: colormap evaluated at ``(celltype - 1) / max(celltype)``.
        grid: show a faint grid instead of hiding spines and ticks.
        **kwargs: forwarded to ``plt.Circle``.

    Returns:
        ``(figure, ax)`` where ``figure`` is the figure that owns ``ax``.
    """
    if ax is None:
        ax = plt.axes()
    # Always act on the axes/figure we draw on, not on pyplot's "current" ones,
    # so that a caller-supplied ``ax`` in a multi-panel figure is handled correctly.
    fig = ax.figure

    alive_cells = state.celltype > 0
    ntypes = np.max(state.celltype)
    # colour index in [0, 1): type k maps to (k - 1) / ntypes
    color = cm(np.float32(state.celltype-1)[alive_cells]/ntypes)

    positions = state.position[alive_cells]
    radii = state.radius[alive_cells]

    for cell, radius, c in zip(positions, radii, color):
        ax.add_patch(plt.Circle(cell, radius=radius, color=c, alpha=.5, **kwargs))

    ## calculate ax limits: one square window containing all alive cells plus a margin
    max_coord = max([np.max(positions[:, 0]), np.max(positions[:, 1])])+3
    min_coord = min([np.min(positions[:, 0]), np.min(positions[:, 1])])-3

    ax.set_xlim(min_coord, max_coord)
    ax.set_ylim(min_coord, max_coord)

    #scale x and y in the same way
    ax.set_aspect('equal', adjustable='box')

    #white bg color for ax
    ax.set_facecolor([1,1,1])

    if grid:
        ax.grid(alpha=.2)
    else:
        #remove axis spines and ticks
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)

        ax.set_xticks([])
        ax.set_yticks([])

    # dark figure patch made fully transparent
    background_color = [56 / 256] * 3
    fig.patch.set_facecolor(background_color)
    fig.patch.set_alpha(0)

    fig.set_size_inches(8, 8)

    return fig, ax

In [None]:
# Re-run the forward simulation with the parameters stored at index 390 of params_tn.
# NOTE(review): index 390 requires a run with at least 390 epochs; with EPOCHS = 100
# as set in an earlier cell it is out of range -- confirm which run this refers to.
opt_params = eqx.combine(params_tn[390], params)
sim_init, sim_step = simulation(fstep, opt_params, fspace)
# history=False, unlike the earlier sim_trajectory calls in this notebook.
fstate_opt, traj_opt = sim_trajectory(istate, sim_init, sim_step, 100, key, history=False) 

In [None]:
# Draw the optimised final state by cell type and overlay the target mask.
draw_circles_ctype(fstate_opt);
mask_fn = heart_mask

#xlim = plt.gca().get_xlim()
xlim = (-10, 10)
# 50 x 50 grid of probe points covering [-10, 10] in both directions
xx, yy = np.meshgrid(np.linspace(*xlim,50),np.linspace(*xlim,50))

pos = np.vstack((xx.flatten(),yy.flatten())).T

#scatter the probe points, shaded by the value of mask_fn (here heart_mask), as a faint heatmap
plt.scatter(pos[:,0], pos[:,1], c=mask_fn(pos), cmap=plt.cm.gray_r, s=70, alpha=.1);

# Optimize with different conditions

In [None]:
# Load losses
with open('../data/optimizations/homogeneous_growth/loss_tt/div_and_sec_stress_lr0.01_epochs100_episodes30_[8]_hidden', 'rb') as handle:
    loss_tt_stress = pickle.load(handle)
with open('../data/optimizations/homogeneous_growth/loss_tt/div_and_sec_chem_lr0.01_epochs100_episodes30_[8]_hidden', 'rb') as handle:
    loss_tt_chem = pickle.load(handle)
with open('../data/optimizations/homogeneous_growth/loss_tt/div_and_sec_chemical_chemgrad_lr0.01_epochs100_episodes30_[8]_hidden', 'rb') as handle:
    loss_tt_chem_stress = pickle.load(handle)
with open('../data/optimizations/homogeneous_growth/loss_tt/div_and_sec_chemical_chemgrad_stress_lr0.01_epochs100_episodes30_[8]_hidden', 'rb') as handle:
    loss_tt_chem_stress_chemgrad = pickle.load(handle)

In [None]:
# Load gradients
with open('../data/optimizations/homogeneous_growth/grads_tt/div_and_sec_stress_lr0.01_epochs100_episodes30_[8]_hidden', 'rb') as handle:
    grads_tt_stress = pickle.load(handle)
with open('../data/optimizations/homogeneous_growth/grads_tt/div_and_sec_chem_lr0.01_epochs100_episodes30_[8]_hidden', 'rb') as handle:
    grads_tt_chem = pickle.load(handle)
with open('../data/optimizations/homogeneous_growth/grads_tt/div_and_sec_chemical_chemgrad_lr0.01_epochs100_episodes30_[8]_hidden', 'rb') as handle:
    grads_tt_chem_stress = pickle.load(handle)
with open('../data/optimizations/homogeneous_growth/grads_tt/div_and_sec_chemical_chemgrad_stress_lr0.01_epochs100_episodes30_[8]_hidden', 'rb') as handle:
    grads_tt_chem_stress_chemgrad = pickle.load(handle)

In [None]:
# Load params
with open('../data/optimizations/homogeneous_growth/params_tt/div_and_sec_stress_lr0.01_epochs100_episodes30_[8]_hidden', 'rb') as handle:
    params_tt_stress = pickle.load(handle)
with open('../data/optimizations/homogeneous_growth/params_tt/div_and_sec_chemical_chemgrad_lr0.01_epochs100_episodes30_[8]_hidden', 'rb') as handle:
    params_tt_chem_stress = pickle.load(handle)
#with open('../data/optimizations/homogeneous_growth/params_tt/div_and_sec_chemical_chemgrad_stress_lr0.01_epochs100_episodes30_[8]_hidden', 'rb') as handle:
#    params_tt_chem_stress_chemgrad = pickle.load(handle)

In [None]:
from matplotlib.lines import Line2D
# Plot the loss histories of the four input conditions, one colour per condition.
# np.array(...).T is plotted, so a condition may contribute several lines; the legend
# is therefore built by hand with exactly one entry per condition. It is derived from
# the same (label, colour) pairs as the curves, so it can no longer omit a plotted
# condition or disagree with the plotted labels.
loss_curves = [
    (loss_tt_stress, 'stress', 'green'),
    (loss_tt_chem, 'chemicals', 'darkorange'),
    (loss_tt_chem_stress, 'chemicals and chemgrad', 'maroon'),
    (loss_tt_chem_stress_chemgrad, 'chemicals, stress and chemgrad', 'darkblue'),
]
for curve, curve_label, curve_color in loss_curves:
    plt.plot(np.array(curve).T, label=curve_label, c=curve_color);
plt.xlabel("grad descent step")
plt.ylabel("loss")
plt.legend([Line2D([0], [0], color=curve_color, lw=4) for _, _, curve_color in loss_curves],
           [curve_label for _, curve_label, _ in loss_curves],
           bbox_to_anchor=(1.1, 1.05));

In [None]:
def make_range(N, traj_vals):
    """Return N evenly spaced values between the 10th and 90th percentile of ``traj_vals``.

    ``traj_vals`` is flattened first, so percentiles are taken over all entries.
    """
    flat_vals = traj_vals.flatten()
    lower = np.percentile(flat_vals, 10)
    upper = np.percentile(flat_vals, 90)
    return np.linspace(lower, upper, N)

In [None]:
from jax.nn import leaky_relu
# Stand-alone haiku copy of the hidden-state network, used below to probe its response.
n_hidden = HID_HIDDEN
#if transform_mlp_out is None:
#    transform_mlp_out = lambda x: x
def _hidden_nn(in_fields):
    # MLP with hidden layer sizes n_hidden followed by an output layer of size
    # params['hidden_state_size']; leaky-ReLU between layers, none on the output.
    mlp = hk.nets.MLP(n_hidden+[params['hidden_state_size']],
                        activation=leaky_relu,
                        w_init=hk.initializers.Orthogonal(),
                        activate_final=False
                        )
        
    out = mlp(in_fields)
    # NOTE(review): transform_mlp_out is a notebook global that is first assigned in a
    # LATER cell (to sigmoid); it is looked up at call time, so this only works once
    # that cell has been run -- confirm sigmoid is the intended output transform here.
    out = transform_mlp_out(out)

    return out
# Rebinds the name to the haiku-transformed (init, apply) pair without RNG in apply.
_hidden_nn = hk.without_apply_rng(hk.transform(_hidden_nn))

In [None]:
# Stand-alone haiku copy of the division network: with n_hidden_div empty this is a
# single linear layer with 3 outputs, followed by transform_mlp_out (sigmoid).
n_hidden_div = []
from jax.nn import leaky_relu, softplus, sigmoid
# Notebook-wide output transform; also read by _hidden_nn and _sec_nn at call time.
transform_mlp_out = sigmoid
def _div_nn(in_fields):
    mlp = hk.nets.MLP(n_hidden_div+[3],
                        activation=leaky_relu,
                      activate_final=False
                        )
        
    out = mlp(in_fields)
    out = transform_mlp_out(out)

    return out

# Rebinds the name to the haiku-transformed (init, apply) pair without RNG in apply.
_div_nn = hk.without_apply_rng(hk.transform(_div_nn))

In [None]:
# Stand-alone haiku copy of the secretion network: with n_hidden_sec empty this is a
# single linear layer with params['n_chem'] outputs, followed by sigmoid.
transform_mlp_out=sigmoid
n_hidden_sec=[]
def _sec_nn(in_fields):
    mlp = hk.nets.MLP(n_hidden_sec+[params['n_chem']],
                        activation=leaky_relu,
                        activate_final=False
                        )
        
    out = mlp(in_fields)
    out = transform_mlp_out(out)
    return out
# Rebinds the name to the haiku-transformed (init, apply) pair without RNG in apply.
_sec_nn = hk.without_apply_rng(hk.transform(_sec_nn))

In [None]:
# Evaluate the division network on a grid of input values
def get_div_output(key, params_to_use, traj, type='stress'):
    """Evaluate hidden network + division network on a grid of synthetic inputs.

    The grid is built from statistics of ``traj[0]``: ranges span the 10th-90th
    percentile (see ``make_range``), fixed inputs are set to their median.

    Args:
        key: PRNG key used for the (discarded) haiku ``init`` calls.
        params_to_use: mapping with entries ``"hidden_fn"`` and ``"div_fn"``
            holding the haiku parameters of the two networks.
        traj: trajectory whose first element exposes ``stress``, ``chemical``
            and (for the default branch) ``chemgrad``.
        type: ``'stress'`` -> 20 inputs with stress as the only column;
            ``'chem'`` -> 20x20 grid with 4 input columns;
            anything else -> 20x20 grid with 7 input columns.
            (The name shadows the builtin but is kept for compatibility.)

    Returns:
        ``(plotting_data, div_output)``: the stress range (``'stress'``) or the
        meshgrid pair ``(x, y)`` otherwise, and the flattened network output.
    """
    n_grid = 20
    first = traj[0]
    stress_range = make_range(n_grid, first.stress[:])

    if type=='stress':
        input_data = np.reshape(stress_range, (n_grid, 1))
        plotting_data = stress_range
    else:
        chem1 = np.median(first.chemical[:, :, 1].flatten())
        chem2 = make_range(n_grid, first.chemical[:, :, 0])
        # Needed by BOTH remaining branches; the 'chem' branch previously used
        # chem1_x without defining it, which raised NameError.
        # NOTE(review): confirm that the 'chem' model really takes this gradient
        # component as its third input column.
        chem1_x = np.median(first.chemgrad[:, :, 0])

        x, y = np.meshgrid(chem2, stress_range)
        plotting_data = (x, y)
        n_points = x.size

        if type=='chem':
            columns = (np.repeat(chem1, n_points), x.flatten(), np.repeat(chem1_x, n_points), y.flatten())
        else:
            chem1_y = np.median(first.chemgrad[:, :, 1])
            chem2_x = np.median(first.chemgrad[:, :, 2])
            chem2_y = np.median(first.chemgrad[:, :, 3])
            columns = (np.repeat(chem1, n_points), x.flatten(),
                       np.repeat(chem1_x, n_points), np.repeat(chem1_y, n_points),
                       np.repeat(chem2_x, n_points), np.repeat(chem2_y, n_points),
                       y.flatten())
        input_data = np.vstack(columns).T

    # init results are discarded; the stored parameters are applied instead
    _ = _hidden_nn.init(key, np.ones(input_data.shape))
    hidden_output = _hidden_nn.apply(params_to_use["hidden_fn"].copy(), input_data)
    _ = _div_nn.init(key, np.ones(np.array(hidden_output).shape))
    div_output = _div_nn.apply(params_to_use["div_fn"].copy(), hidden_output).flatten()
    return plotting_data, div_output

In [None]:
# Evaluate the secretion network on a grid of input values
def get_sec_output(key, params_to_use, traj, type='stress'):
    """Evaluate hidden network + secretion network on a grid of synthetic inputs.

    The grid is built from statistics of ``traj[0]``: ranges span the 10th-90th
    percentile (see ``make_range``), fixed inputs are set to their median.

    Args:
        key: PRNG key used for the (discarded) haiku ``init`` calls.
        params_to_use: mapping with entries ``"hidden_fn"`` and ``"sec_fn"``
            holding the haiku parameters of the two networks.
        traj: trajectory whose first element exposes ``stress``, ``chemical``
            and (for the default branch) ``chemgrad``.
        type: ``'stress'`` -> 20 inputs with stress as the only column;
            ``'chem'`` -> 20x20 grid with 4 input columns;
            anything else -> 20x20 grid with 7 input columns.
            (The name shadows the builtin but is kept for compatibility.)

    Returns:
        ``(plotting_data, sec_output)``: the stress range (``'stress'``) or the
        meshgrid pair ``(x, y)`` otherwise, and the flattened network output.
    """
    n_grid = 20
    first = traj[0]
    stress_range = make_range(n_grid, first.stress[:])

    if type=='stress':
        input_data = np.reshape(stress_range, (n_grid, 1))
        plotting_data = stress_range
    else:
        chem1 = np.median(first.chemical[:, :, 1].flatten())
        # Needed by BOTH remaining branches; the 'chem' branch previously used
        # chem1_x without defining it, which raised NameError.
        # NOTE(review): confirm that the 'chem' model really takes this gradient
        # component as its third input column.
        chem1_x = np.median(first.chemgrad[:, :, 0])

        if type=='chem':
            chem2 = make_range(n_grid, first.chemical[:, :, 0])
        else:
            # NOTE(review): this branch ranges over chemical channel 1 (the same channel
            # as chem1), whereas the 'chem' branch and get_div_output use channel 0 --
            # kept as in the original; confirm which is intended.
            chem2 = make_range(n_grid, first.chemical[:, :, 1])

        x, y = np.meshgrid(chem2, stress_range)
        plotting_data = (x, y)
        n_points = x.size

        if type=='chem':
            columns = (np.repeat(chem1, n_points), x.flatten(), np.repeat(chem1_x, n_points), y.flatten())
        else:
            chem1_y = np.median(first.chemgrad[:, :, 1])
            chem2_x = np.median(first.chemgrad[:, :, 2])
            chem2_y = np.median(first.chemgrad[:, :, 3])
            columns = (np.repeat(chem1, n_points), x.flatten(),
                       np.repeat(chem1_x, n_points), np.repeat(chem1_y, n_points),
                       np.repeat(chem2_x, n_points), np.repeat(chem2_y, n_points),
                       y.flatten())
        input_data = np.vstack(columns).T

    # init results are discarded; the stored parameters are applied instead
    _ = _hidden_nn.init(key, np.ones(input_data.shape))
    hidden_output = _hidden_nn.apply(params_to_use["hidden_fn"], input_data)
    _ = _sec_nn.init(key, np.ones(np.array(hidden_output).shape))
    sec_output = _sec_nn.apply(params_to_use["sec_fn"].copy(), hidden_output).flatten()   
    return plotting_data, sec_output

In [None]:
# Display the shape of the hidden-state array stored on `fstate`.
fstate.hidden_state.shape

In [None]:
# Display the shape of the `w` entry of the 'mlp/~/linear_0' layer in the hidden-network parameters.
params_t["hidden_fn"]['mlp/~/linear_0']['w'].shape

In [None]:
## check result of entropy optimization
# Plot different growth functions learned using only stress
# NOTE(review): the comment above and the plot title say "only stress", but use_state_fields
# below enables chemical/chemgrad and disables stress -- confirm which inputs were trained.
# Combine the last entry of params_tn with the base `params`.
params_t = eqx.combine(params_tn[-1], params)
# Update hidden state network for running simulation
# NOTE(review): this CellState(...) call passes `divangle`, unlike the other cells nearby -- confirm.
use_state_fields = CellState(position=False, celltype=False, radius=False, chemical=True,chemgrad=True,field=False,stress=False,divrate=False,hidden_state=False, divangle=False, key=False)
hid_init, hid_nn_apply = hidden_state_nn(params_t.copy(),train_params,HID_HIDDEN,use_state_fields,train=True,transform_mlp_out=tanh,)
# Advance the PRNG key, then call the hidden-network init (both return values discarded).
_, key = random.split(key)
_, _ = hid_init(istate, key)
# Ordered list of state-update functions applied at every simulation step.
# `sec_nn_apply` and `div_nn_apply` come from earlier cells; only the hidden network is rebuilt here.
fstep = [
    S_cell_division, S_grow_cells, partial(S_mech_morse_relax, morse_eps_sigma='twotypes'), partial(S_ss_chemfield, sec_fn=sec_nn_apply, n_iter=3),
    S_chemical_gradients,S_fixed_chemfield,S_set_stress,
    partial(S_hidden_state, dhidden_fn=eqx.filter_jit(hid_nn_apply), state_decay=.0),
    partial(S_set_divrate, divrate_fn=eqx.filter_jit(div_nn_apply))
]    
sim_init, sim_step = simulation(fstep,params_t, fspace)
# 100 is the 4th positional argument of sim_trajectory (presumably ncells_add -- confirm);
# history=True so that `traj` holds the recorded history.
state, traj = sim_trajectory(istate, sim_init, sim_step, 100, key, history=True)
# Division-network response along the stress range observed in the trajectory.
x, div_output = get_div_output(key, params_t, traj, type='stress')
plt.plot(x, div_output, color='gray');
plt.xlabel("stress")
plt.ylabel("division output")
plt.title("optimized, only stress");

In [None]:
## ONLY STRESS
# Plot different growth functions learned using only stress
# For every entry of params_tt_stress (presumably one optimization run each -- confirm):
# combine its last element with `params`, re-run the simulation, and draw the division
# response as one gray curve on the shared axes.
for i, params_t in enumerate(params_tt_stress):
    params_t = eqx.combine(params_t[-1], params)
    # Update hidden state network for running simulation
    # Only `stress` is enabled as an input field.
    use_state_fields = CellState(position=False, celltype=False, radius=False, chemical=False,chemgrad=False,field=False,stress=True,divrate=False,hidden_state=False,key=False)
    hid_init, hid_nn_apply = hidden_state_nn(params_t.copy(),train_params,HID_HIDDEN,use_state_fields,train=True,transform_mlp_out=tanh,)
    # The global `key` is advanced on every iteration, so each run uses a different key.
    _, key = random.split(key)
    _, _ = hid_init(istate, key)
    # Ordered list of state-update functions applied at every simulation step.
    fstep = [
    S_cell_division, S_grow_cells, partial(S_mech_morse_relax, morse_eps_sigma='twotypes'), partial(S_ss_chemfield, sec_fn=sec_nn_apply, n_iter=3),
    S_chemical_gradients,S_fixed_chemfield,S_set_stress,
    partial(S_hidden_state, dhidden_fn=eqx.filter_jit(hid_nn_apply), state_decay=.0),
    partial(S_set_divrate, divrate_fn=eqx.filter_jit(div_nn_apply))
    ]    
    sim_init, sim_step = simulation(fstep,params_t, fspace)
    # 100 is the 4th positional argument of sim_trajectory (presumably ncells_add -- confirm).
    state, traj = sim_trajectory(istate, sim_init, sim_step, 100, key, history=True)
    x, div_output = get_div_output(key, params_t, traj, type='stress')
    plt.plot(x, div_output, color='gray');
# Axis labels and title are set once, after all curves are drawn.
plt.xlabel("stress")
plt.ylabel("division output")
plt.title("optimized, only stress");

In [None]:
# Secretion-network response for every stress-only parameter set, all drawn on shared axes.
for i, params_t in enumerate(params_tt_stress):
    params_t = eqx.combine(params_t[-1], params)
    # Update hidden state network for running simulation
    # Only `stress` is enabled as an input field.
    use_state_fields = CellState(position=False, celltype=False, radius=False, chemical=False,chemgrad=False,field=False,stress=True,divrate=False,hidden_state=False,key=False)
    hid_init, hid_nn_apply = hidden_state_nn(params_t.copy(),train_params,HID_HIDDEN,use_state_fields,train=True,transform_mlp_out=tanh,)
    # The global `key` is advanced on every iteration.
    _, key = random.split(key)
    _, _ = hid_init(istate, key)
    # Ordered list of state-update functions applied at every simulation step.
    fstep = [
    S_cell_division, S_grow_cells, partial(S_mech_morse_relax, morse_eps_sigma='twotypes'), partial(S_ss_chemfield, sec_fn=sec_nn_apply, n_iter=3),
    S_chemical_gradients,S_fixed_chemfield,S_set_stress,
    partial(S_hidden_state, dhidden_fn=eqx.filter_jit(hid_nn_apply), state_decay=.0),
    partial(S_set_divrate, divrate_fn=eqx.filter_jit(div_nn_apply))
    ]    
    sim_init, sim_step = simulation(fstep,params_t, fspace)
    fstate, traj = sim_trajectory(istate, sim_init, sim_step, 100, key, history=True)
    x, sec_output = get_sec_output(key, params_t, traj, type='stress')
    # The flat output is viewed as (20 stress samples, 2 outputs); column 0 / 1 are drawn separately.
    # Chemical 1
    plt.plot(x, sec_output.reshape(20, 2)[:, 0], color='maroon');
    # Chemical 2
    plt.plot(x, sec_output.reshape(20, 2)[:, 1], color='navy');
plt.xlabel("stress")
# NOTE(review): the y-label mentions only chem 1, but both output columns are plotted above.
plt.ylabel("secretion of chem 1")
plt.title("optimized, only stress");

In [None]:
# CHEMICALS + CHEMGRAD
# Make ten subplots plotting the growth function for every set of parameters in params_tt_stress
# NOTE(review): the heading says chemicals + chemgrad but the loop iterates params_tt_stress -- confirm.
# NOTE(review): `traj` is not recomputed in this cell; it is whatever an earlier cell left behind.
fig, axs = plt.subplots(2, 5, figsize=(20, 8))
axs = axs.flatten()
for i, params_t in enumerate(params_tt_stress):
    params_t = eqx.combine(params_t[-1], params)
    # NOTE(review): with type='stress' the helper builds a 20-row input and returns the stress
    # range as `x`; the code below instead treats `x` as a 2-D grid, reshapes the output to
    # (20, 20) and reads `y`, which is never assigned in this cell. This looks like it was meant
    # to unpack `(x, y)` from a grid-type call -- confirm before relying on this figure.
    x, div_output = get_div_output(key, params_t, traj, type='stress')
    for j in range(20):
        # One curve per column j of the grid, coloured by j.
        axs[i].plot(y.T[j], div_output.reshape(20, 20).T[j], c=plt.cm.viridis(j/20));
        axs[i].set_xlabel("stress")
    axs[i].set_ylabel("division output")
    # Stand-alone mappable used only to draw a colorbar spanning x[0][0]..x[0][-1].
    sm = plt.cm.ScalarMappable(cmap=plt.cm.viridis, norm=plt.Normalize(vmin=x[0][0], vmax=x[0][-1]))
    sm._A = []
    cbar = plt.colorbar(sm, shrink=0.7, alpha=.5, ax=axs[i]) # rule of thumb
    cbar.set_label('Chemical', labelpad=20)

In [None]:
# STRESS + CHEMICALS + CHEMGRAD
# One subplot of the learned division response per parameter set in params_tt_chem_stress_chemgrad
fig, axs = plt.subplots(2, 5, figsize=(25, 8))
axs = axs.flatten()
for i, params_t in enumerate(params_tt_chem_stress_chemgrad):
    params_t = eqx.combine(params_t[-1], params)
    # Hidden network reads chemical, chemgrad and stress.
    use_state_fields = CellState(position=False, celltype=False, radius=False, chemical=True, chemgrad=True, field=False, stress=True, divrate=False, hidden_state=False, key=False)
    hid_init, hid_nn_apply = hidden_state_nn(params_t.copy(), train_params, HID_HIDDEN, use_state_fields, train=True, transform_mlp_out=tanh)
    # The global `key` is advanced on every iteration.
    _, key = random.split(key)
    _, _ = hid_init(istate, key)
    # Ordered list of state-update functions applied at every simulation step.
    fstep = [
        S_cell_division,
        S_grow_cells,
        partial(S_mech_morse_relax, morse_eps_sigma='twotypes'),
        partial(S_ss_chemfield, sec_fn=sec_nn_apply, n_iter=3),
        S_chemical_gradients,
        S_fixed_chemfield,
        S_set_stress,
        partial(S_hidden_state, dhidden_fn=eqx.filter_jit(hid_nn_apply), state_decay=.0),
        partial(S_set_divrate, divrate_fn=eqx.filter_jit(div_nn_apply)),
    ]
    sim_init, sim_step = simulation(fstep, params_t, fspace)
    state, traj = sim_trajectory(istate, sim_init, sim_step, 100, key, history=True)
    # Squared coefficient of variation of the final division rates.
    loss = np.power(np.std(state.divrate)/np.mean(state.divrate), 2)
    # Use the optimized `params_t` (the set that produced `traj` and `loss`); the original
    # passed the base `params`, so the curves did not belong to the run shown in the title.
    (x, y), div_output = get_div_output(key, params_t, traj, type='chemgrad')
    for j in range(20):
        # One curve per chemical grid value j, coloured by j.
        axs[i].plot(y.T[j], div_output.reshape(20, 20).T[j], c=plt.cm.viridis(j/20));
    # The title is the same for every curve, so set it once per subplot.
    axs[i].set_title(f"loss: {loss:.2f}")
fig.supxlabel('stress')
fig.supylabel('division output')
# Stand-alone mappable used only to draw a colorbar over the chemical range of the last grid.
sm = plt.cm.ScalarMappable(cmap=plt.cm.viridis, norm=plt.Normalize(vmin=x[0][0], vmax=x[0][-1]))
sm.set_array([])  # public equivalent of assigning the private `_A` attribute
cbar = plt.colorbar(sm, shrink=0.7, alpha=.5, ax=axs[i]) # rule of thumb
cbar.set_label('Chemical', labelpad=20)
fig.suptitle("optimized, stress + chemicals + chemgrad");
plt.tight_layout()

In [None]:
# Secretion-network response for every stress + chem + chemgrad parameter set, one subplot each.
fig, axs = plt.subplots(2, 5, figsize=(25, 8))
axs = axs.flatten()
for i, params_t in enumerate(params_tt_chem_stress_chemgrad):
    params_t = eqx.combine(params_t[-1], params)
    # Update hidden state network for running simulation
    use_state_fields = CellState(position=False, celltype=False, radius=False, chemical=True, chemgrad=True, field=False, stress=True, divrate=False, hidden_state=False, key=False)
    hid_init, hid_nn_apply = hidden_state_nn(params_t.copy(), train_params, HID_HIDDEN, use_state_fields, train=True, transform_mlp_out=tanh)
    # The global `key` is advanced on every iteration.
    _, key = random.split(key)
    _, _ = hid_init(istate, key)
    # Ordered list of state-update functions applied at every simulation step.
    fstep = [
        S_cell_division,
        S_grow_cells,
        partial(S_mech_morse_relax, morse_eps_sigma='twotypes'),
        partial(S_ss_chemfield, sec_fn=sec_nn_apply, n_iter=3),
        S_chemical_gradients,
        S_fixed_chemfield,
        S_set_stress,
        partial(S_hidden_state, dhidden_fn=eqx.filter_jit(hid_nn_apply), state_decay=.0),
        partial(S_set_divrate, divrate_fn=eqx.filter_jit(div_nn_apply)),
    ]
    sim_init, sim_step = simulation(fstep, params_t, fspace)
    fstate, traj = sim_trajectory(istate, sim_init, sim_step, 100, key, history=True)
    (x, y), sec_output = get_sec_output(key, params_t.copy(), traj, type='chemgrad')
    # The flat output is viewed as (20, 20, 2): a 20x20 input grid with two outputs per point.
    # Chemical 1
    for j in range(20):
        axs[i].plot(y.T[j], sec_output.reshape(20, 20, 2)[:, :, 0].T[j, :], c=plt.cm.viridis(j/20));
    # Chemical 2
    for j in range(20):
        axs[i].plot(y.T[j], sec_output.reshape(20, 20, 2)[:, :, 1].T[j, :], c=plt.cm.coolwarm(j/20));
fig.supxlabel("stress")
# NOTE(review): the y-label mentions only chemical 1, but both outputs are plotted.
fig.supylabel("secretion, chemical 1")
fig.suptitle("optimized, stress + chem + chemgrad");
plt.tight_layout()
# Stand-alone mappables used only for the two colorbars (one per colormap used above).
# NOTE(review): both norms are taken from rows of the same meshgrid `x`, so the two bars
# span the same input range; the labels refer to the plotted outputs -- confirm wording.
sm = plt.cm.ScalarMappable(cmap=plt.cm.viridis, norm=plt.Normalize(vmin=x[0][0], vmax=x[0][-1]))
sm2 = plt.cm.ScalarMappable(cmap=plt.cm.coolwarm, norm=plt.Normalize(vmin=x[1][0], vmax=x[1][-1]))

sm.set_array([])   # public equivalent of assigning the private `_A` attribute
sm2.set_array([])

# These mappables are not attached to any Axes, so the Axes to take space from must be
# given explicitly; current Matplotlib raises if it has to guess.
cbar = plt.colorbar(sm, ax=axs, shrink=0.7, alpha=.5) # rule of thumb
cbar2 = plt.colorbar(sm2, ax=axs, shrink=0.7, alpha=.5) # rule of thumb

cbar.set_label('Chemical 1', labelpad=10)
cbar2.set_label('Chemical 2', labelpad=10)

In [None]:
# Re-run the simulation with the last element of the first params_tt_chem_stress entry,
# passing 200 as the 4th positional argument of sim_trajectory (presumably ncells_add -- confirm).
params_t = eqx.combine(params_tt_chem_stress[0][-1], params)

# Update hidden state network for running simulation:
# chemical and chemgrad are the only enabled input fields.
use_state_fields = CellState(
    position=False,
    celltype=False,
    radius=False,
    chemical=True,
    chemgrad=True,
    field=False,
    stress=False,
    divrate=False,
    hidden_state=False,
    key=False,
)
hid_init, hid_nn_apply = hidden_state_nn(
    params_t.copy(),
    train_params,
    HID_HIDDEN,
    use_state_fields,
    train=True,
    transform_mlp_out=tanh,
)

# Advance the PRNG key, then call the hidden-network init (both return values discarded).
_, key = random.split(key)
_, _ = hid_init(istate, key)

# Ordered list of state-update functions applied at every simulation step.
fstep = [
    S_cell_division,
    S_grow_cells,
    partial(S_mech_morse_relax, morse_eps_sigma='twotypes'),
    partial(S_ss_chemfield, sec_fn=sec_nn_apply, n_iter=3),
    S_chemical_gradients,
    S_fixed_chemfield,
    S_set_stress,
    partial(S_hidden_state, dhidden_fn=eqx.filter_jit(hid_nn_apply), state_decay=.0),
    partial(S_set_divrate, divrate_fn=eqx.filter_jit(div_nn_apply)),
]

sim_init, sim_step = simulation(fstep, params_t, fspace)
fstate, traj = sim_trajectory(istate, sim_init, sim_step, 200, key, history=True)

In [None]:
# Draw `fstate` with draw_circles_chem using its default `chem` argument.
draw_circles_chem(fstate);

In [None]:
# Same drawing, this time selecting chem=1.
draw_circles_chem(fstate, chem=1);

# Generalization to more timesteps

In [None]:
# Loss over the recorded history for three input configurations; one colour per configuration.
colors = ['green', 'maroon', 'darkblue']

# Input fields enabled for the hidden network, in the same order as `colors`
# and as the parameter lists iterated below.
fields = [
    # stress only
    CellState(position=False, celltype=False, radius=False, chemical=False, chemgrad=False,
              field=False, stress=True, divrate=False, hidden_state=False, key=False),
    # chemical + chemgrad
    CellState(position=False, celltype=False, radius=False, chemical=True, chemgrad=True,
              field=False, stress=False, divrate=False, hidden_state=False, key=False),
    # chemical + chemgrad + stress
    CellState(position=False, celltype=False, radius=False, chemical=True, chemgrad=True,
              field=False, stress=True, divrate=False, hidden_state=False, key=False),
]

for k, params_list in enumerate([params_tt_stress, params_tt_chem_stress, params_tt_chem_stress_chemgrad]):
    for i, params_t in enumerate(params_list):
        params_t = eqx.combine(params_t[-1], params)

        # Update hidden state network for running simulation
        hid_init, hid_nn_apply = hidden_state_nn(
            params_t.copy(),
            train_params,
            HID_HIDDEN,
            fields[k],
            train=True,
            transform_mlp_out=tanh,
        )
        # The global `key` is advanced for every run.
        _, key = random.split(key)
        _, _ = hid_init(istate, key)

        # Ordered list of state-update functions applied at every simulation step.
        fstep = [
            S_cell_division,
            S_grow_cells,
            partial(S_mech_morse_relax, morse_eps_sigma='twotypes'),
            partial(S_ss_chemfield, sec_fn=sec_nn_apply, n_iter=3),
            S_chemical_gradients,
            S_fixed_chemfield,
            S_set_stress,
            partial(S_hidden_state, dhidden_fn=eqx.filter_jit(hid_nn_apply), state_decay=.0),
            partial(S_set_divrate, divrate_fn=eqx.filter_jit(div_nn_apply)),
        ]

        sim_init, sim_step = simulation(fstep, params_t, fspace)
        fstate, traj = sim_trajectory(istate, sim_init, sim_step, 200, key, history=True)

        # Squared (std / mean) of divrate, reduced over axis 1 of the recorded history.
        loss = np.power(
            np.std(traj[0].divrate, axis=1) / np.mean(traj[0].divrate, axis=1),
            2,
        )
        plt.plot(loss, color=colors[k]);

plt.xlabel("time")
plt.ylabel("loss")
# Dashed marker at x=100 (the value passed to sim_trajectory in the earlier 100-run cells).
plt.axvline(x=100, color='black', linestyle='--')

# Make legend with three color lines and labels
plt.legend(
    [Line2D([0], [0], color=legend_color, lw=4) for legend_color in colors],
    ['stress', 'chem + chemgrad', 'chem + stress + chemgrad'],
    bbox_to_anchor=(1.1, 1.05),
);

# Generalization to other chemical fields

In [None]:
def change_field(chem_b, params_t, istate, key):
    """Re-run the simulation with ``params_t["chem_gamma"]`` replaced and return the final loss.

    Reads ``use_state_fields``, ``train_params``, ``HID_HIDDEN``, ``tanh``, ``sec_nn_apply``,
    ``div_nn_apply`` and ``fspace`` from the notebook's global namespace, so the result
    depends on the values those names have when the function is called.

    Args:
        chem_b: value stored under the ``"chem_gamma"`` key of the parameters.
        params_t: parameter mapping (must support ``.copy()``); it is not modified.
        istate: initial simulation state.
        key: PRNG key; it is split once before being used.

    Returns:
        Squared ratio ``std / mean`` of ``divrate`` on the final state.
    """
    # Work on a copy: the original assigned into `params_t` directly and therefore
    # overwrote "chem_gamma" in the caller's dict whenever it was called un-vmapped.
    params_t = params_t.copy()
    params_t["chem_gamma"] = chem_b
    # Update hidden state network for running simulation
    hid_init, hid_nn_apply = hidden_state_nn(params_t.copy(), train_params, HID_HIDDEN, use_state_fields, train=True, transform_mlp_out=tanh)
    # Ordered list of state-update functions applied at every simulation step.
    fstep = [
        S_cell_division,
        S_grow_cells,
        partial(S_mech_morse_relax, morse_eps_sigma='twotypes'),
        partial(S_ss_chemfield, sec_fn=sec_nn_apply, n_iter=3),
        S_chemical_gradients,
        S_fixed_chemfield,
        S_set_stress,
        partial(S_hidden_state, dhidden_fn=eqx.filter_jit(hid_nn_apply), state_decay=.0),
        partial(S_set_divrate, divrate_fn=eqx.filter_jit(div_nn_apply)),
    ]
    _, key = random.split(key)
    _, _ = hid_init(istate, key)
    sim_init, sim_step = simulation(fstep, params_t, fspace)
    fstate, traj = sim_trajectory(istate, sim_init, sim_step, 100, key)
    return np.power(np.std(fstate.divrate)/np.mean(fstate.divrate), 2)

# Inner vmap maps over `chem_b`; outer vmap maps over the leading axis of every leaf of
# `params_t` (i.e. a stacked batch of parameter sets). Result shape: (n_param_sets, n_chem_b).
change_field_vmap = vmap(vmap(change_field, (0, None, None, None)), (None, 0, None, None))
# 20 values of chem_b, log-spaced from 1e-3 to 1.
bs = np.power(10, np.linspace(-3, 0, 20))

In [None]:
from jax.tree_util import tree_flatten, tree_unflatten
def tree_stack(trees):
    """Takes a list of trees and stacks every corresponding leaf.
    For example, given two trees ((a, b), c) and ((a', b'), c'), returns
    ((stack(a, a'), stack(b, b')), stack(c, c')).
    Useful for turning a list of objects into something you can feed to a
    vmapped function.

    Args:
        trees: non-empty iterable of pytrees that all share the same structure.

    Returns:
        A pytree with the structure of the first input whose leaves are the
        corresponding input leaves stacked along a new leading axis.

    Raises:
        ValueError: if ``trees`` is empty or the trees do not share one structure.
    """
    trees = list(trees)
    if not trees:
        raise ValueError("tree_stack needs at least one tree to stack")

    flattened = [tree_flatten(tree) for tree in trees]
    reference_def = flattened[0][1]
    # Without this check, zip() below would silently pair up unrelated leaves
    # (or drop trailing ones) when the structures differ.
    for position, (_, treedef) in enumerate(flattened):
        if treedef != reference_def:
            raise ValueError(
                f"tree_stack: tree {position} has structure {treedef}, expected {reference_def}"
            )

    # Group the i-th leaf of every tree together and stack each group.
    grouped_leaves = zip(*(leaves for leaves, _ in flattened))
    stacked_leaves = [np.stack(group) for group in grouped_leaves]
    return reference_def.unflatten(stacked_leaves)

In [None]:
%%time
# Vary the steepness of the chemical field gradient and see if learned mechanisms generalize.
# The PRNG key is reset here, so this cell does not depend on how often `key` was split above.
key = random.PRNGKey(0)
# `change_field` reads `use_state_fields` from the global namespace; stress is the only enabled field.
use_state_fields = CellState(position=False, celltype=False, radius=False, chemical=False,chemgrad=False,field=False,stress=True,divrate=False,hidden_state=False,key=False)
# Combine the last element of each of the first 10 entries of params_tt_stress with `params`.
params_ts = [eqx.combine(params_tt_stress[i][-1], params) for i in range(10)]
# Stack the 10 parameter trees leaf-wise so the outer vmap can map over their leading axis.
p_stack = tree_stack(params_ts)
# Update hidden state network for running simulation
#hid_init, hid_nn_apply = hidden_state_nn(params_t.copy(),train_params,HID_HIDDEN,use_state_fields,train=True,transform_mlp_out=tanh,)
#_, key = random.split(key)
#_, _ = hid_init(istate, key)
# Outer vmap axis = parameter sets, inner vmap axis = chem_b values in `bs`.
losses_stress = change_field_vmap(bs, p_stack, istate, key)

In [None]:
%%time
# Vary the steepness of the chemical field gradient and see if learned mechanisms generalize.
use_state_fields = CellState(position=False, celltype=False, radius=False, chemical=True,chemgrad=True,field=False,stress=False,divrate=False,hidden_state=False,key=False)
params_t = eqx.combine(params_tt_chem_stress[0][-1], params)
# Update hidden state network for running simulation
losses_chem = change_field_vmap(bs, params_t, istate, key)

In [None]:
# `change_field_vmap` returns one row per parameter set and one column per chem_b value,
# so the loss arrays are transposed to put chem_b on the first axis, as plt.plot expects.
# np.atleast_2d lets the same code handle a 1-D loss curve (a single parameter set).
stress_lines = plt.plot(bs, np.atleast_2d(losses_stress).T, color='C0')
stress_lines[0].set_label('stress')  # one legend entry for the whole family of curves
chem_lines = plt.plot(bs, np.atleast_2d(losses_chem).T, color='C1')
chem_lines[0].set_label('chem')
# Make dotted black vertical line at x=0.4
plt.axvline(x=0.4, color='black', linestyle='--')
plt.xlabel("chem_b")
plt.ylabel("loss");
plt.xscale ('log');
plt.legend();

## Disruptions in field

In [None]:
from jax_md import dataclasses

In [None]:
def perturb_field(state, params, fspace, *, patch=(2.0, 4.5, 2.0, 4.5), value=50.0):
    """Overwrite ``state.field`` inside a fixed axis-aligned rectangular patch.

    Args:
        state: object exposing ``position`` (read as ``[:, 0]`` and ``[:, 1]``) and ``field``.
        params: unused; kept so the function has the ``(state, params, fspace)``
            signature of the other entries in the step list.
        fspace: unused; see ``params``.
        patch: ``(x1, x2, y1, y2)`` bounds of the rectangle; the comparison is strict,
            so points exactly on the border are left untouched.
        value: field value written for every point inside the rectangle.

    Returns:
        A copy of ``state`` (via ``dataclasses.replace``) with the updated ``field``.
    """
    x1, x2, y1, y2 = patch
    xs = state.position[:, 0]
    ys = state.position[:, 1]
    inside = ((xs > x1) & (xs < x2)) & ((ys > y1) & (ys < y2))
    # Points outside the patch keep their current field value.
    field_vals = np.where(inside, value, state.field)
    return dataclasses.replace(state, field=field_vals)

In [None]:
# Initial state with perturbation.
# `params` and `fspace` are passed only to satisfy the signature; perturb_field ignores them.
perturbed_state = perturb_field(istate, params, fspace)
# Colour range 0..50 matches the value (50.0) written into the patch.
# NOTE(review): `visualization` is used as a module name here, but the visible import cell
# only imports individual draw_* functions from jax_morph.visualization -- confirm it is imported.
visualization.draw_circles(perturbed_state, perturbed_state.field, min_val=0.0, max_val=50.0);

In [None]:
# Step list for the perturbation experiment: `perturb_field` is applied on every step,
# after the fixed chemical field and before stress / division rate are set.
# NOTE(review): these entries are referenced through module names (cell_division, cell_growth,
# mechanical, secdiff, chemical, divrates, divrates_chem) that differ from the direct function
# imports in the notebook's import cell -- confirm these modules are imported before running.
fstep_perturb = [
    cell_division.S_cell_division,
    cell_growth.S_grow_cells,
    mechanical.S_mechmin_twotypes,
    partial(secdiff.S_ss_chemfield, sec_fn=sec_nn_apply),
    chemical.S_fixed_chemfield,
    perturb_field,
    divrates.S_set_stress,
    partial(divrates_chem.S_set_divrate, divrate_fn=eqx.filter_jit(div_nn_apply)), 
]

In [None]:
# Simulation with unoptimized parameters
# NOTE(review): the notebook's import cell binds the name `simulation` to a function
# (`from jax_morph.simulation import simulation, sim_trajectory`); `simulation.simulation`
# only works if `simulation` has been re-bound to the module elsewhere -- confirm.
sim_init, sim_step = simulation.simulation(fstep_perturb, params, fspace)
# history=True: `state_all` receives the recorded history alongside the final state.
fstate, state_all = simulation.sim_trajectory(perturbed_state, sim_init, sim_step, key=key, history=True)
# Final division rates, colour range 0..0.25.
visualization.draw_circles(fstate, fstate.divrate, min_val=0.0, max_val=0.25);

In [None]:
# Column 0 of the final chemical array, colour range 0..0.3.
visualization.draw_circles(fstate, fstate.chemical[:, 0], min_val=0.0, max_val=0.3);

In [None]:
# Column 1 of the final chemical array, colour range 0..15.
visualization.draw_circles(fstate, fstate.chemical[:, 1], min_val=0.0, max_val=15);

In [None]:
# Final stress values, colour range -240..1.
visualization.draw_circles(fstate, fstate.stress, min_val=-240, max_val=1.00);

In [None]:
# Simulation with optimized parameters
# NOTE(review): `params_t[-1]` assumes `params_t` is an indexable sequence here; earlier cells
# assign `params_t` the result of eqx.combine(...) -- confirm which value is intended.
opt_params = eqx.combine(params_t[-1], params)
sim_init, sim_step = simulation.simulation(fstep_perturb, opt_params, fspace)
# Unlike the unoptimized run above, this call passes ncells_add=200 explicitly.
fstate_opt, state_all_opt = simulation.sim_trajectory(perturbed_state, sim_init, sim_step, key=key, history=True, ncells_add=200)
# Final division rates, colour range 0..0.25 (same range as the unoptimized figure).
visualization.draw_circles(fstate_opt, fstate_opt.divrate, min_val=0.0, max_val=0.25);

In [None]:
# Column 0 of the final chemical array for the optimized run, colour range 0..0.3.
visualization.draw_circles(fstate_opt, fstate_opt.chemical[:, 0], min_val=0.0, max_val=0.3);

In [None]:
# Column 1 of the final chemical array for the optimized run, colour range 0..15.
visualization.draw_circles(fstate_opt, fstate_opt.chemical[:, 1], min_val=0.0, max_val=15);

In [None]:
# Final stress values for the optimized run, colour range -240..1.
visualization.draw_circles(fstate_opt, fstate_opt.stress, min_val=-240, max_val=1.00);

# Noise in chemical field