In [None]:
!pip install diffrax

In [None]:
!pip install optax



# 2D reaction diffusion model

In [None]:
from typing import Tuple, Callable
import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as random
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from jaxtyping import Array, Float

jax.config.update("jax_enable_x64", True)

# Define the 2D grid for the reaction-diffusion system
class SpatialDiscretisation(eqx.Module):
    """Two fields (u, v) sampled on a uniform rectangular 2D grid.

    The grid covers [x0, x_final] x [y0, y_final] with both endpoints included
    (it is built with ``jnp.linspace``). Arrays are indexed ``[i, j]`` with
    ``i`` running along x and ``j`` along y (``meshgrid(..., indexing="ij")``).
    The four domain bounds are static fields, so only ``vals`` is a pytree leaf.
    """

    x0: float = eqx.field(static=True)
    x_final: float = eqx.field(static=True)
    y0: float = eqx.field(static=True)
    y_final: float = eqx.field(static=True)
    # vals[0] is the u field and vals[1] the v field, each of shape (nx, ny).
    vals: Tuple[Float[Array, "nx ny"], Float[Array, "nx ny"]]

    @classmethod
    def discretise_fn(cls, x0: float, x_final: float, y0: float, y_final: float, nx: int, ny: int, fn: Callable):
        """Sample ``fn`` on an ``nx`` x ``ny`` grid and wrap the result.

        ``fn`` receives the two ``ij``-indexed coordinate arrays and must
        return a pair ``(u, v)`` of arrays on that grid.
        """
        grid_x, grid_y = jnp.meshgrid(
            jnp.linspace(x0, x_final, nx),
            jnp.linspace(y0, y_final, ny),
            indexing="ij",
        )
        u_field, v_field = fn(grid_x, grid_y)
        return cls(x0, x_final, y0, y_final, (u_field, v_field))

    @property
    def δx(self):
        """Grid spacing along x (the first array axis)."""
        n_points = self.vals[0].shape[0]
        return (self.x_final - self.x0) / (n_points - 1)

    @property
    def δy(self):
        """Grid spacing along y (the second array axis)."""
        n_points = self.vals[0].shape[1]
        return (self.y_final - self.y0) / (n_points - 1)

def _neumann_laplacian(field, dx, dy):
    """Return the 5-point Laplacian of one 2D field with zero-flux boundaries.

    The zero normal-derivative (Neumann) condition is imposed with mirrored
    ghost nodes. ``field`` is padded by one cell with ``mode="reflect"``, so
    the ghost value outside each edge equals the first interior neighbour
    (u[-1] = u[1]). The centred first difference at the boundary node is then
    zero, and the boundary stencil reduces to ``2 * (u[1] - u[0]) / dx**2``.
    """
    padded = jnp.pad(field, 1, mode="reflect")
    centre = padded[1:-1, 1:-1]
    # The corner cells of `padded` are never read by the 5-point stencil.
    d2_dx2 = (padded[2:, 1:-1] - 2 * centre + padded[:-2, 1:-1]) / dx**2
    d2_dy2 = (padded[1:-1, 2:] - 2 * centre + padded[1:-1, :-2]) / dy**2
    return d2_dx2 + d2_dy2


def laplacian_2d(y: "SpatialDiscretisation") -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute the discrete 2D Laplacian of both fields with no-flux boundaries.

    Args:
        y: state whose ``vals`` is the pair ``(u, v)`` of ``ij``-indexed 2D
            arrays, with grid spacings ``y.δx`` (first axis) and ``y.δy``
            (second axis).

    Returns:
        ``(Δu, Δv)``, arrays with the same shapes as ``u`` and ``v``.

    The boundary treatment uses reflected ghost nodes rather than periodic
    wrap-around. As a result, the diffusion operator conserves the
    trapezoid-weighted total of each field, as a no-flux boundary requires.
    """
    u, v = y.vals
    return _neumann_laplacian(u, y.δx, y.δy), _neumann_laplacian(v, y.δx, y.δy)

# Define model
def vector_field(t, y: SpatialDiscretisation, args) -> SpatialDiscretisation:
    """Right-hand side of the Schnakenberg reaction-diffusion system.

        du/dt = Du * Δu + alpha - u + u**2 * v
        dv/dt = Dv * Δv + beta  - u**2 * v

    Args:
        t: current time (unused; the system is autonomous).
        y: current state; ``y.vals`` holds the ``(u, v)`` fields.
        args: optional ``(Du, Dv, alpha, beta)`` tuple, as passed through
            ``diffrax.diffeqsolve(..., args=...)``. When ``None`` (diffrax's
            default), the module-level ``Du, Dv, alpha, beta`` are used.

    Returns:
        A SpatialDiscretisation on the same grid holding ``(du/dt, dv/dt)``.
    """
    d_u, d_v, a, b = (Du, Dv, alpha, beta) if args is None else args
    u, v = y.vals
    lap_u, lap_v = laplacian_2d(y)
    reaction = u**2 * v  # autocatalytic term, shared by both equations
    du_dt = d_u * lap_u + a - u + reaction
    dv_dt = d_v * lap_v + b - reaction
    return SpatialDiscretisation(y.x0, y.x_final, y.y0, y.y_final, (du_dt, dv_dt))

def initial_conditions(x, y, key=None, amplitude=0.05):
    """Return the initial ``(u, v)`` fields sampled on a meshgrid.

    ``u`` starts uniformly at 1. ``v`` is uniform noise in ``[0, amplitude)``,
    and this randomness breaks spatial symmetry so a pattern can develop.

    Args:
        x, y: coordinate arrays from ``jnp.meshgrid(..., indexing="ij")``;
            only their shapes are used.
        key: JAX PRNG key for the noise. Defaults to ``random.PRNGKey(0)``,
            so the default initial state is reproducible.
        amplitude: scale of the uniform noise in ``v``.

    Returns:
        Tuple ``(u, v)`` of arrays shaped ``(x.shape[0], y.shape[1])``.
    """
    if key is None:
        key = random.PRNGKey(0)
    shape = (x.shape[0], y.shape[1])
    u = jnp.ones(shape)
    v = amplitude * random.uniform(key, shape=shape, minval=0.0, maxval=1.0)
    return u, v


# Parameters for the Schnakenberg model
Du, Dv = 0.1, 1.0  # Diffusion coefficients for u and v
alpha, beta = 0.1, 0.8  # Reaction parameters

# Spatial domain and resolution. The lower y bound is named `y_min` rather
# than `y0` because `y0` is bound below to the ODE initial state; reusing the
# name would silently discard the domain bound.
x0, x_final = -50, 50
y_min, y_final = -50, 50
nx, ny = 100, 100  # Number of grid points in x and y directions
dx = (x_final - x0) / (nx - 1)
dy = (y_final - y_min) / (ny - 1)

# Temporal parameters
t0, t_final = 0, 1000.0
δt = 0.01  # initial step size for the adaptive solver

# Initial condition sampled on the grid
y0 = SpatialDiscretisation.discretise_fn(x0, x_final, y_min, y_final, nx, ny, initial_conditions)
u0 = y0.vals[0]
v0 = y0.vals[1]

# Plot initial conditions. The fields are indexed [i_x, i_y] (meshgrid "ij"),
# but imshow draws the first array axis vertically, so transpose them. Pass the
# physical domain (the state's static bounds) as `extent` so the axes show
# x/y coordinates instead of array indices.
domain_extent = [y0.x0, y0.x_final, y0.y0, y0.y_final]
fig, axs = plt.subplots(1, 2, figsize=(6, 3))
for panel_ax, panel_field, panel_name in zip(axs, (u0, v0), ("u", "v")):
    panel_ax.imshow(panel_field.T, origin="lower", cmap="gray", vmin=0, vmax=1, extent=domain_extent)
    panel_ax.set_xlabel("x")
    panel_ax.set_ylabel("y")
    panel_ax.set_title(f"Initial {panel_name} concentration")
plt.tight_layout()
plt.show()

# Solver: adaptive Tsit5 (explicit Runge-Kutta) driven by a PID step-size
# controller, with steps capped at dtmax = 0.1.
solver = diffrax.Tsit5()
rtol = 1e-5
atol = 1e-5
stepsize_controller = diffrax.PIDController(
    rtol=rtol, atol=atol, pcoeff=0.3, icoeff=0.4, dtmax=0.1
)
# Record 100 evenly spaced snapshots over [t0, t_final].
saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t_final, 100))

# Integrate the spatially discretised PDE (method of lines).
term = diffrax.ODETerm(vector_field)
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=t0,
    t1=t_final,
    dt0=δt,
    y0=y0,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
    # No step cap: t_final / dtmax alone already requires >= 10_000 steps.
    max_steps=None,
)

# sol.ys is a batched SpatialDiscretisation: each array in `vals` carries a
# leading time axis of length len(sol.ts), so the fields can be read directly.
u_vals, v_vals = sol.ys.vals

# Use one colour range per field across all frames. A fixed [0, 1] range
# would clip the pattern, because the homogeneous steady state
# (u* = alpha + beta, v* = beta / (alpha + beta)**2) sits near or above 1.
u_lo, u_hi = float(u_vals.min()), float(u_vals.max())
v_lo, v_hi = float(v_vals.min()), float(v_vals.max())

# Fields are indexed [i_x, i_y]; transpose for imshow (first axis is drawn
# vertically) and pass the physical domain so the axes show coordinates.
domain_extent = [sol.ys.x0, sol.ys.x_final, sol.ys.y0, sol.ys.y_final]

fig, (ax_u, ax_v) = plt.subplots(1, 2, figsize=(10, 5))
cax_u = ax_u.imshow(u_vals[0].T, origin="lower", cmap="gray", vmin=u_lo, vmax=u_hi, extent=domain_extent)
cax_v = ax_v.imshow(v_vals[0].T, origin="lower", cmap="gray", vmin=v_lo, vmax=v_hi, extent=domain_extent)
for panel_ax in (ax_u, ax_v):
    panel_ax.set_xlabel("x")
    panel_ax.set_ylabel("y")
fig.colorbar(cax_u, ax=ax_u, label="u concentration")
fig.colorbar(cax_v, ax=ax_v, label="v concentration")
ax_u.set_title("2D Turing Pattern in u concentration")
ax_v.set_title("2D Turing Pattern in v concentration")


# Create animation for the evolution history
def update(frame):
    """Show saved snapshot `frame` (an index into sol.ts) in both panels."""
    cax_u.set_data(u_vals[frame].T)
    cax_v.set_data(v_vals[frame].T)

    t_now = float(sol.ts[frame])
    ax_u.set_title(f"2D Turing Pattern in u - t = {t_now:.0f}")
    ax_v.set_title(f"2D Turing Pattern in v - t = {t_now:.0f}")
    return [cax_u, cax_v]


# blit=False: the titles change every frame but are not among the returned
# artists, so blitting would leave them stale in interactive backends.
ani = FuncAnimation(fig, update, frames=len(u_vals), interval=100, blit=False)
HTML(ani.to_jshtml())
