In [1]:
import jax
import jax.numpy as np

from typing import Any, Callable
from jaxtyping import Array

import equinox as eqx

In [3]:
key = jax.random.PRNGKey(0)

In [87]:
class Substates(eqx.Module):
    s : eqx.nn.StateIndex
    a : float
    f : Callable

    def __init__(self):
        state_init = dict([(1,np.zeros((1, 1))), (2,np.ones((1, 1)))])
        self.s = eqx.nn.StateIndex(state_init)
        self.a = 2.0
        self.f = lambda x: x + 1.0

    def __call__(self, x: Array) -> Array:
        return self.a * self.s(x)

In [88]:
ss, state = eqx.nn.make_with_state(Substates)()


#eqx.nn.State(ss).substate({1: 0})

In [4]:
N = 100

#define set of random positions in 2d
pos = 10*jax.random.uniform(key, (N, 2))

In [30]:
#define chemical degradations and update laplacian
K = .1*np.eye(N)

#diffusion coefficient
D = 2e-2

#define random production rates from exponential distribution
P = jax.random.uniform(key, (N,))

In [31]:
def ss_chemfield(pos, D, K, P):
    """
    Calculate the steady state chemical field for a given set of positions and parameters

    Parameters
    ----------
    pos : array
        Nx2 array of positions
    D : float
        Diffusion coefficient
    K : float
        Decay rate
    P : array
        Nx1 array of production rates

    Returns
    -------
    x : array
        Nx1 array of steady state concentrations
    """

    #calculate all pairwise distances using pure jax
    dists = np.sqrt(np.sum((pos[:, None, :] - pos[None, :, :])**2, axis=-1)) + np.diag(np.ones(N))

    #adjacency matrix
    diag_mask = np.ones((N, N)) - np.eye(N)
    A = (diag_mask)*(1/dists)**2

    #graph laplacian
    L = np.diag(np.sum(A, axis=0)) - A

    #update laplacian
    L = D*L + K

    #solve for steady state
    x = np.linalg.solve(L, P)

    return x

In [32]:
x = ss_chemfield(pos, D, K, P)

In [33]:
x

Array([3.9190333, 5.5560317, 4.945663 , 5.737023 , 3.8454523, 5.4835463,
       5.0753565, 4.443593 , 5.3435354, 4.5845647, 4.697924 , 4.529473 ,
       6.949196 , 6.030072 , 4.461426 , 6.1976566, 4.865279 , 3.7751825,
       5.070111 , 5.4399376, 3.816938 , 6.1518393, 5.331742 , 5.580075 ,
       4.9043703, 5.4377904, 5.190167 , 4.3016667, 6.622356 , 3.779549 ,
       5.817896 , 3.7215676, 5.0629745, 5.5592227, 4.958537 , 6.0261908,
       5.56019  , 4.257744 , 5.3171906, 5.061906 , 5.7816963, 4.9973135,
       5.989546 , 6.1399097, 6.73348  , 6.0398927, 5.138991 , 3.9823284,
       4.315284 , 5.9256234, 5.3551397, 6.018447 , 6.023216 , 5.254752 ,
       5.1014843, 4.7470174, 5.334473 , 5.46376  , 5.368791 , 4.7257895,
       5.5886235, 6.0267735, 5.493993 , 6.0267115, 5.5161595, 6.7781253,
       5.525284 , 5.1957827, 4.16076  , 6.42015  , 4.8727584, 5.68922  ,
       6.910315 , 5.4699125, 5.232961 , 4.2446747, 5.4115   , 6.52598  ,
       4.392853 , 4.932646 , 6.0187516, 5.217474 , 