In [90]:
from typing import Any, Callable, Optional

import equinox as eqx
import jax
import jax.numpy as np
import matplotlib.pyplot as plt
import numpy as onp
from jaxtyping import Array

In [45]:
# Root PRNG key for the notebook. A fixed seed makes every draw derived
# from this key reproducible across kernel restarts.
key = jax.random.PRNGKey(seed=0)

# Cell State

In [149]:
import jax_md
#import abc

class BaseCellState(eqx.Module):
    '''
    Module containing the basic features of a system state.

    A freshly constructed state holds zero cells: every per-cell array is
    created with a leading axis of length 0.

    Attributes
    ----------
    displacement: jax_md.space.DisplacementFn
        Displacement function from ``jax_md.space.free()`` (no boundary conditions).
    shift: jax_md.space.ShiftFn
        Shift function from ``jax_md.space.free()``.
    position: Array
        Created with shape (0, n_dim), dtype float32.
    celltype: Array
        Created with shape (0,), dtype int8.
    radius: Array
        Created with shape (0,), dtype float32.
    divrate: Array
        Created with shape (0,), dtype float32 (presumably a per-cell division
        rate -- confirm against the simulation code).
    key: Array or None
        PRNG key carried by the state; None unless one is passed to __init__.

    '''

    # METHODS
    displacement:   jax_md.space.DisplacementFn
    shift:          jax_md.space.ShiftFn

    # STATE
    position:   Array
    celltype:   Array
    radius:     Array
    divrate:    Array
    # __init__ stores None here when no key is supplied, so None must be admitted.
    key:        Optional[Array]


    def __init__(self, n_dim=2, key=None):
        '''
        Initializes a CellState with no cells (empty data structures, with correct shapes).

        Parameters
        ----------
        n_dim: int
            Number of spatial dimensions; must be 2 or 3.
        key: Array, optional
            PRNG key to store on the state (e.g. from ``jax.random.PRNGKey``).
            Defaults to None, which leaves the state without a key.

        Raises
        ------
        AssertionError
            If ``n_dim`` is not 2 or 3.

        '''

        assert n_dim in (2, 3), f'n_dim must be 2 or 3, got {n_dim!r}'

        # Unbounded (non-periodic) space; no box size is needed.
        disp, shift = jax_md.space.free()

        # Zero-length leading axis: the state starts with no cells, but the
        # trailing dimensions and dtypes are already fixed.
        self.position  =   np.empty(shape=(0, n_dim),              dtype=np.float32)
        self.celltype  =   np.empty(shape=(0,),                    dtype=np.int8)
        self.radius    =   np.empty(shape=(0,),                    dtype=np.float32)
        self.divrate   =   np.empty(shape=(0,),                    dtype=np.float32)
        self.key       =   key

        self.displacement  =   disp
        self.shift         =   shift

In [150]:
class CellState(BaseCellState):
    '''
    Cell state extended with chemical fields and a per-cell hidden state.

    As with the base state, a new instance holds zero cells: the extra
    arrays below are created with a leading axis of length 0.

    Attributes
    ----------
    chemical: Array
        Created with shape (0, n_chem), dtype float32.
    chemgrad: Array
        Created with shape (0, n_dim * n_chem), dtype float32 (presumably a
        flattened spatial gradient per chemical -- confirm).
    hidden_state: Array
        Created with shape (0, hidden_size), dtype float32.

    '''
    chemical:       Array
    chemgrad:       Array
    hidden_state:   Array

    def __init__(self, n_dim=2, n_chem=1, hidden_size=10):
        '''
        Build an empty CellState.

        Parameters
        ----------
        n_dim: int
            Number of spatial dimensions, forwarded to ``BaseCellState``.
        n_chem: int
            Number of columns of ``chemical``.
        hidden_size: int
            Number of columns of ``hidden_state``.

        '''
        super().__init__(n_dim)

        # One gradient column per (spatial dimension, chemical) pair.
        grad_width = int(n_dim * n_chem)

        self.chemical     = np.empty(shape=(0, n_chem),      dtype=np.float32)
        self.chemgrad     = np.empty(shape=(0, grad_width),  dtype=np.float32)
        self.hidden_state = np.empty(shape=(0, hidden_size), dtype=np.float32)