In [None]:
import jax
import jax.numpy as jnp

# Global JAX configuration: pin execution to the CPU backend and enable 64-bit
# floats (JAX defaults to 32-bit). These run before any arrays are created
# below so that every array in the notebook picks up the x64 setting.
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

from functools import partial
from typing import Callable, Optional, Tuple

import diffrax
import matplotlib.pyplot as plt
from jaxtyping import Array, ArrayLike

In [None]:
def reaction_diffusion_ode(
    t: ArrayLike, y: Array, args: Tuple[Callable, Callable, int, int]
) -> Array:
    """Right-hand side of the reaction-diffusion system on an ``n x m`` grid.

    The reaction term is applied independently at every grid site (vmapped
    over sites); the diffusion term couples neighbouring sites.

    Args:
        t: Current time. Unused; present because diffrax calls terms as
            ``f(t, y, args)``.
        y: Grid state of shape ``(n, m, c)``: ``c`` state variables at each of
            the ``n * m`` sites. ``c`` is read from ``y`` rather than assumed.
        args: ``(reaction_ode, diffusion_ode, n, m)``.
            ``reaction_ode(y_site, u=...)`` maps one site's length-``c`` state
            to its time derivative. ``diffusion_ode(y)`` maps the full
            ``(n, m, c)`` grid to its diffusive time derivative.

    Returns:
        ``dy/dt`` of shape ``(n, m, c)``.
    """
    reaction_ode, diffusion_ode, n, m = args
    num_states = y.shape[-1]

    # Flatten the grid to a batch of sites and apply the per-site reaction
    # term to all of them at once. The control input is fixed at zeros(2);
    # for `fibrosis_ode` that means no antibody treatment.
    site_states = y.reshape(n * m, num_states)
    reaction_site = partial(reaction_ode, u=jnp.zeros(2))
    reaction_dy = jax.vmap(reaction_site, in_axes=0, out_axes=0)(site_states)
    reaction_dy = reaction_dy.reshape(n, m, num_states)

    diffusion_dy = diffusion_ode(y)

    return reaction_dy + diffusion_dy


def diffusion_ode(y: Array, diffusivity: Optional[ArrayLike] = None) -> Array:
    """Diffusive time derivative of every state on a 2-D grid.

    Each state channel is convolved with a 3x3 stencil. Neighbours are
    weighted by -1/distance (orthogonal -1, diagonal -1/sqrt(2)), and the
    centre weight balances them so the stencil sums to zero. That stencil is a
    negated, distance-weighted discrete Laplacian, so the result is
    ``diffusivity * Laplacian(y)`` per channel. Grid boundaries use edge
    replication (zero-gradient boundary).

    Args:
        y: Grid state of shape ``(n, m, c)``.
        diffusivity: Per-state diffusion coefficients, broadcast against the
            last axis of ``y``. Defaults to ``0.1 * [0.01, 0.1, 1.0, 1.0]``,
            which assumes ``c == 4``.

    Returns:
        Array of shape ``(n, m, c)``.
    """
    # Offsets of the 3x3 stencil relative to its centre.
    offsets = jnp.linspace(-1, 1, 3)
    kernel_x, kernel_y = jnp.meshgrid(offsets, offsets)
    dist = jnp.sqrt(jnp.square(kernel_x) + jnp.square(kernel_y))

    # Divide safely: the centre has distance 0, so substitute 1 before
    # dividing. No inf is produced even transiently, which keeps
    # `jax_debug_infs` usable. The centre is then zeroed.
    is_center = dist == 0
    kernel_d = jnp.where(is_center, 0.0, -1.0 / jnp.where(is_center, 1.0, dist))

    # Centre weight balances the neighbours so the stencil sums to zero:
    # a spatially uniform field produces no diffusion.
    kernel = kernel_d.at[1, 1].set(-jnp.sum(kernel_d))

    if diffusivity is None:
        diffusivity = jnp.asarray([0.01, 0.1, 1.0, 1.0]) * 0.1
    d = jnp.asarray(diffusivity)

    # One-cell halo on both spatial axes by replicating edge values; the
    # state axis is not padded.
    y_padded = jnp.pad(y, ((1, 1), (1, 1), (0, 0)), mode="edge")

    # The kernel is symmetric, so convolution == correlation. "valid" mode
    # trims the halo and returns the original (n, m) extent per channel.
    conv = jax.vmap(
        partial(jax.scipy.signal.convolve, mode="valid"),
        in_axes=(-1, None),
        out_axes=-1,
    )(y_padded, kernel)

    # kernel is the *negated* Laplacian, hence the leading minus sign.
    return -d * conv


def fibrosis_ode(y: Array, u: Array) -> Array:
    """Reaction term for a single grid site of the fibroblast/macrophage circuit.

    State ``y`` (indices 0..3): fibroblasts, macrophages, CSF1, PDGF.
    Control ``u`` (indices 0..1): PDGF antibody level, CSF1 antibody level.
    Rates are per day; per-minute rates are converted with 1440 min/day.

    Returns the time derivatives of the four states, stacked on the last axis.
    """
    fib, mph, csf, pdgf = y[0], y[1], y[2], y[3]

    # Cell turnover (per day).
    lambda_1 = 0.9  # fibroblast proliferation
    lambda_2 = 0.8  # macrophage proliferation
    mu = 0.3  # death rate, shared by fibroblasts and macrophages
    carrying_cap = 1e6  # fibroblast carrying capacity (cells)
    gamma = 2  # growth-factor degradation

    # Secretion / endocytosis: molecules per cell per minute, converted to per day.
    minutes_per_day = 1440
    beta_3 = 240 * minutes_per_day  # PDGF secreted by fibroblasts
    beta_1 = 470 * minutes_per_day  # CSF1 secreted by fibroblasts
    beta_2 = 70 * minutes_per_day  # PDGF secreted by macrophages
    alpha_1 = 940 * minutes_per_day  # CSF1 endocytosis by macrophages
    alpha_2 = 510 * minutes_per_day  # PDGF endocytosis by fibroblasts

    # Binding affinities (molecules).
    k_1 = 6e8  # PDGF
    k_2 = 6e8  # CSF1

    # Constant macrophage influx; the note beside it suggests 120 for an
    # inflammation pulse (currently disabled).
    inflammation = 0.0
    k_13_unused = 1e6  # NOTE(review): not referenced by any equation below

    # Antibody-driven clearance of free growth factor, 1 / (min * molecule) -> per day.
    k_ab_pdgf = 1 * minutes_per_day
    k_ab_csf1 = 1 * minutes_per_day
    pdgf_clearance = -k_ab_pdgf * pdgf * u[0]
    csf1_clearance = -k_ab_csf1 * csf * u[1]

    d_fib = fib * (lambda_1 * pdgf / (k_1 + pdgf) * (1 - fib / carrying_cap) - mu)
    d_mph = mph * (lambda_2 * csf / (k_2 + csf) - mu) + inflammation
    d_csf = (
        csf1_clearance
        + beta_1 * fib
        - alpha_1 * mph * csf / (k_2 + csf)
        - gamma * csf
    )
    d_pdgf = (
        pdgf_clearance
        + beta_2 * mph
        + beta_3 * fib
        - alpha_2 * fib * pdgf / (k_1 + pdgf)
        - gamma * pdgf
    )

    return jnp.stack([d_fib, d_mph, d_csf, d_pdgf], axis=-1)


def zero_ode(y: Array, u: Array) -> Array:
    """Null reaction term: every state has zero time derivative.

    Shares the ``(y, u)`` signature of `fibrosis_ode` so it can be swapped in
    as the reaction term (e.g. to simulate diffusion alone).
    """
    del u  # accepted only for signature compatibility
    return jnp.zeros_like(y)

In [None]:
# Grid geometry. Each site carries the 4 states used by `fibrosis_ode`
# (fibroblasts, macrophages, CSF1, PDGF). Deriving the seed site from the
# grid size keeps the initial condition and `args` consistent.
N_ROWS, N_COLS, N_STATES = 16, 16, 4
SEED_ROW, SEED_COL = N_ROWS // 2, N_COLS // 2
T_END = 100.0

# Every state starts at 1 everywhere; a large macrophage population
# (state index 1) is seeded at the centre site.
y0 = jnp.ones((N_ROWS, N_COLS, N_STATES))
y0 = y0.at[SEED_ROW, SEED_COL, 1].set(1e8)

# Implicit Kvaerno5 solver with adaptive step-size control. Dense output lets
# later cells evaluate the solution at any t in [0, T_END].
sol = diffrax.diffeqsolve(
    terms=diffrax.ODETerm(reaction_diffusion_ode),
    solver=diffrax.Kvaerno5(),
    t0=0.0,
    t1=T_END,
    dt0=0.1,
    y0=y0,
    args=(fibrosis_ode, diffusion_ode, N_ROWS, N_COLS),
    saveat=diffrax.SaveAt(dense=True),
    stepsize_controller=diffrax.PIDController(
        rtol=1e-5, atol=1e-5, pcoeff=0.3, icoeff=0.3
    ),
    max_steps=10000,
)

In [None]:
# State vector (fibroblasts, macrophages, CSF1, PDGF) at the seeded centre
# site at the final time t = 100.
centre_state_final = sol.evaluate(100.0)[8, 8]
centre_state_final

In [None]:
# Final macrophage field (state index 1). The colour scale is capped at 1e5,
# presumably so the spread around the 1e8 seed stays visible.
macrophages_final = sol.evaluate(100.0)[..., 1]

fig, ax = plt.subplots()
image = ax.imshow(macrophages_final, vmin=0, vmax=1e5)
fig.colorbar(image, ax=ax, label="macrophages")
ax.set(title="Macrophages at t = 100", xlabel="grid column", ylabel="grid row")
plt.show()