In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from symmetry_breaking_JAX.models_2D.run import simulate_2d, build_initial_state_2d, run_2d
from symmetry_breaking_JAX.models_2D.geom_utils import make_grid
from symmetry_breaking_JAX.models_2D.param_classes import Params2D, BlipSet2D

In [None]:
# Integration horizon and number of evenly spaced snapshots to keep.
T = 1000
n_save = 50

# 1. Define parameters and blips
params = Params2D(
    # N-field coefficients
    D_N=1.0,
    sigma_N=0.001,
    mu_N=1e-4,
    # L-field coefficients
    D_L=15.0,
    sigma_L=0.001,
    mu_L=1e-4,
    # exponents and coupling constants
    n=2,
    p=2,
    alpha=1.0,
    K_A=667.0,
    K_NL=670,
    K_I=1.0,
    # domain extent and grid spacing; geometry selects the domain shape
    Lx=500.0,
    Ly=500.0,
    dx=5.0,
    dy=5.0,
    geometry="disk",
)

# No dynamic events (blips) for this test run.
blips = BlipSet2D.empty()

In [None]:
# 2. Build the spatial grid for the configured domain.
grid = make_grid(params)

# 3. Initialize state: Gaussian bumps for N, a uniform value for L.
y0 = build_initial_state_2d(
    params,
    grid,
    # N: bumps centred at (0, 0) and (100, 100) with amplitudes 1000 / 500
    #    and width 25 each (presumably the Gaussian standard deviation)
    N_mode="gaussian",
    N_positions=jnp.array([[0, 0], [100, 100]]),
    N_amps=jnp.array([1000, 500]),
    N_sigmas=jnp.array([25, 25]),
    # L: constant 0 everywhere
    L_mode="constant",
    L_amp=0.0,
)


# Run the simulation

In [None]:
# 4. Integrate from t = 0 to T, saving the state at n_save evenly spaced
#    times (jnp.linspace includes both endpoints 0 and T).
ts, ys = run_2d(params, blips, grid, T=T, save_ts=jnp.linspace(0, T, n_save), y0=y0)

In [None]:
# Grid coordinates, domain mask and field size.
X, Y, mask = grid.X, grid.Y, grid.mask
nx, ny = X.shape
nxy = nx * ny

# Each row of ys is one saved state vector: entries [0, nxy) hold the Nodal (N)
# field and [nxy, 2*nxy) hold the L field. Reshape each slice to (nt, nx, ny).
N_t, L_t = (
    ys[:, k * nxy:(k + 1) * nxy].reshape(len(ts), nx, ny)
    for k in range(2)
)

In [None]:
import matplotlib.pyplot as plt
import time

%matplotlib tk

# Animate the L field over the saved time points in the GUI window selected by
# `%matplotlib tk`; plt.pause shows each frame for ~0.05 s.
# NOTE(review): values outside grid.mask (disk geometry) are drawn as-is — confirm
# whether they should be hidden, e.g. via jnp.where(mask, frame, jnp.nan).
# NOTE(review): imshow puts array axis 0 on the vertical axis; if make_grid uses
# "ij" indexing (X.shape == (nx, ny)) the image is transposed — confirm.
extent = [float(X.min()), float(X.max()), float(Y.min()), float(Y.max())]  # loop-invariant

fig, ax = plt.subplots(figsize=(6, 5))
ax.set_xlabel("x")
ax.set_ylabel("y")
# Build the image artist once and only swap its pixels per frame, rather than
# clearing the axes and re-running imshow on every iteration.
im = ax.imshow(L_t[0], cmap="viridis", origin="lower", extent=extent)
for i in range(len(ts)):
    im.set_data(L_t[i])
    # Rescale the colour limits to this frame's min/max, so each frame uses the
    # full colormap. For one scale shared by all frames use instead:
    #     im.set_clim(0, float(L_t.max()))
    im.autoscale()
    ax.set_title(f"t = {float(ts[i]):.2f}")
    plt.pause(0.05)  # redraw and run the GUI event loop so the frame appears

plt.show()  # once, after the loop; plt.pause already draws each frame



In [None]:
# Animate the Nodal (N) field over the saved time points; plt.pause shows each
# frame for ~0.05 s in the interactive GUI window.
# NOTE(review): values outside grid.mask (disk geometry) are drawn as-is — confirm
# whether they should be hidden, e.g. via jnp.where(mask, frame, jnp.nan).
# NOTE(review): imshow puts array axis 0 on the vertical axis; if make_grid uses
# "ij" indexing (X.shape == (nx, ny)) the image is transposed — confirm.
extent = [float(X.min()), float(X.max()), float(Y.min()), float(Y.max())]  # loop-invariant

fig, ax = plt.subplots(figsize=(6, 5))
ax.set_xlabel("x")
ax.set_ylabel("y")
# Build the image artist once and only swap its pixels per frame, rather than
# clearing the axes and re-running imshow on every iteration.
im = ax.imshow(N_t[0], cmap="viridis", origin="lower", extent=extent)
for i in range(len(ts)):
    im.set_data(N_t[i])
    # Rescale the colour limits to this frame's min/max, so each frame uses the
    # full colormap. For one scale shared by all frames use instead:
    #     im.set_clim(0, float(N_t.max()))
    im.autoscale()
    ax.set_title(f"t = {float(ts[i]):.2f}")
    plt.pause(0.05)  # redraw and run the GUI event loop so the frame appears

plt.show()  # once, after the loop; plt.pause already draws each frame