In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from symmetry_breaking_JAX.models.JAX_NL_1D import run_1d, Params1D, BlipSet1D

In [None]:
# --------------------------------------------------
# Experiment config
# --------------------------------------------------
# Spatial grid: nx evenly spaced nodes covering [0, Lx] with spacing dx.
Lx = 1500.0   # µm
dx = 10.0     # grid spacing, same length unit as Lx
# round() before int(): plain truncation silently drops a node whenever
# Lx / dx lands just below an integer in floating point (e.g. 0.3 / 0.1).
nx = int(round(Lx / dx)) + 1
x = jnp.linspace(0.0, Lx, nx)

# Time axis: n_save evenly spaced save points over [0, T].
T = 30 * 3600  # 30 hours, expressed in seconds
n_save = 121   # 120 intervals over T -> one saved frame every 15 minutes
save_ts = jnp.linspace(0.0, T, n_save)

# Parameters (pick one set you want to test)
# Keywords are grouped by suffix. Each L_* value is spelled as a multiple of
# the matching N_* value (x10 for D, x5 for sigma), and K_NL as 5x the K_A value,
# so the ratios between the two stay visible.
params = Params1D(
    # --- N_* fields ---
    D_N=1.85,
    sigma_N=1,
    mu_N=1e-4,
    # --- L_* fields ---
    D_L=10 * 1.85,
    sigma_L=5 * 1,
    mu_L=1e-4,
    # --- remaining scalar constants ---
    n=2,
    p=2,
    alpha=1.0,
    K_A=667.0,
    K_NL=5 * 667,
    K_I=1.0,
    # --- boundary condition and direct-blip option ---
    bc="periodic",
    sigma_t_direct=60.0,
)

# --------------------------------------------------
# Choose seeding strategy
# --------------------------------------------------
# "init"   -> random Gaussian spots placed in the initial state
# "induce" -> random timed blips
# "manual" -> hand-picked timed blips
mode = "manual"   # one of: "init", "induce", "manual"
n_spots = 100     # number of random spots/blips; read by "init" and "induce" only

# Fail here, next to the setting, rather than leaving later cells to run on
# stale or undefined seeding variables.
if mode not in ("init", "induce", "manual"):
    raise ValueError(
        f"Unknown mode {mode!r}; expected 'init', 'induce' or 'manual'."
    )


In [None]:
# sqrt(value / 1e-4) for 18.5 and 1.85. The literals mirror D_L, D_N and
# mu_N / mu_L in the params cell (presumably the sqrt(D / mu) length scale);
# they are NOT read from `params`, so keep them in sync by hand.
for _diffusivity in (18.5, 1.85):
    print(np.sqrt(_diffusivity / 1e-4))

In [None]:
# --------------------------------------------------
# Build the seeding inputs for run_1d: `y0` and `blips`
# --------------------------------------------------
rng_seed = None  # set to an int for reproducible random draws; None uses OS entropy


def _n_direct_blips(times, xs, amps, sigx):
    """Return a BlipSet1D whose only populated channel is ``N_dir``.

    Parameters
    ----------
    times, xs, amps, sigx : array-like
        Per-blip values passed straight through as ``N_dir_times``,
        ``N_dir_x``, ``N_dir_amp`` and ``N_dir_sigx``.

    Returns
    -------
    BlipSet1D
        The ``N_imp``, ``L_dir`` and ``L_imp`` channels each receive an empty
        array for their ``times`` / ``x`` / ``amp`` / ``sigx`` fields.
    """
    empty_channels = {
        f"{channel}_{field}": jnp.array([])
        for channel in ("N_imp", "L_dir", "L_imp")
        for field in ("times", "x", "amp", "sigx")
    }
    return BlipSet1D(
        N_dir_times=times, N_dir_x=xs, N_dir_amp=amps, N_dir_sigx=sigx,
        **empty_channels,
    )


if mode == "init":
    # Initial Gaussians: n_spots bumps at uniformly random centres, no blips.
    rng = np.random.default_rng(rng_seed)
    centers = rng.uniform(0.0, Lx, size=n_spots)
    sigmas = 25.0 * jnp.ones_like(centers)
    amps = 50 * jnp.ones_like(centers)

    # One broadcast instead of a Python-level sum over spots:
    # rows index spots, columns index grid nodes. Also yields an all-zero
    # profile (rather than the scalar 0) when n_spots == 0.
    scaled_offsets = (x[None, :] - centers[:, None]) / sigmas[:, None]
    N0 = jnp.sum(amps[:, None] * jnp.exp(-0.5 * scaled_offsets**2), axis=0)
    L0 = jnp.zeros_like(x)
    rho0 = jnp.ones_like(x) * 0.1
    F_N0 = jnp.zeros_like(x)
    F_L0 = jnp.zeros_like(x)
    # State vector layout: [N, L, rho, F_N, F_L], each block as long as x.
    y0 = jnp.concatenate([N0, L0, rho0, F_N0, F_L0])

    blips = BlipSet1D.empty()

elif mode == "induce":
    # No initial spots, but add blips at random times and positions.
    rng = np.random.default_rng(rng_seed)
    times = rng.uniform(0.0, T, size=n_spots)
    xs = rng.uniform(0.0, Lx, size=n_spots)

    amps = 200.0 * jnp.ones_like(xs)
    sigx = 25.0 * jnp.ones_like(xs)

    blips = _n_direct_blips(times, xs, amps, sigx)
    y0 = None  # use default initializer inside run_1d

elif mode == "manual":
    # No initial spots; two hand-placed blips, both at t = 0.
    times = np.array([0.0, 0.0])
    xs = np.array([375.0, 1000.0])

    amps = 666667.0 * jnp.ones_like(xs)
    sigx = 25.0 * jnp.ones_like(xs)

    blips = _n_direct_blips(times, xs, amps, sigx)
    y0 = None  # use default initializer inside run_1d

else:
    # Without this branch a typo in `mode` leaves `y0` / `blips` undefined
    # (or stale from an earlier run) and the failure surfaces far from here.
    raise ValueError(
        f"Unknown mode {mode!r}; expected 'init', 'induce' or 'manual'."
    )

In [None]:
# --------------------------------------------------
# Run simulation
# --------------------------------------------------
# NOTE: this rebinds `x` to the grid returned by run_1d.
x, ts, ys = run_1d(params, blips, T, nx, save_ts, y0_override=y0)

# Each row of `ys` is one saved state; the first two nx-wide column slabs
# are taken as the N and L profiles.
N_traj = ys[:, :nx]
L_traj = ys[:, nx:2 * nx]

# --------------------------------------------------
# Quick visualization
# --------------------------------------------------
# Normalise by params.K_A instead of a hard-coded 667 so the plot follows
# the parameter set (K_A is 667.0 in the params cell).
fig, ax = plt.subplots(figsize=(8, 5))
frame_stride = max(1, len(ts) // 10)  # draw roughly ten of the saved frames
for i in range(0, len(ts), frame_stride):
    ax.plot(x, N_traj[i] / params.K_A, label=f"t={ts[i]/3600:.1f} h")
ax.set_xlabel("x (µm)")
ax.set_ylabel("Nodal concentration / K_A")
ax.set_title("Spot test trajectories")
ax.legend()
plt.show()

In [None]:
import numpy as np
from scipy.optimize import curve_fit
from tqdm import tqdm

def gauss(x, A, mu, sigma):
    """Evaluate an unnormalised Gaussian ``A * exp(-(x - mu)^2 / (2 sigma^2))``.

    Parameters
    ----------
    x : scalar or array-like
        Point(s) at which to evaluate.
    A : scalar
        Peak height (value at ``x == mu``).
    mu : scalar
        Centre of the peak.
    sigma : scalar
        Width parameter.

    Returns
    -------
    Same shape as ``x`` under NumPy broadcasting.
    """
    exponent = -(x - mu) ** 2 / (2 * sigma ** 2)
    return A * np.exp(exponent)

# Amplitude-weighted moments of the last saved N profile:
# weighted mean position, then the weighted standard deviation about it.
A = np.array(N_traj[-1])
grid = np.asarray(x)
weight_sum = A.sum()

x_mu = (A * grid).sum() / weight_sum
A_var = ((grid - x_mu) ** 2 * A).sum() / weight_sum

# Printed in the same order as before: spread first, then centre.
print(np.sqrt(A_var))
print(x_mu)

In [None]:
# gauss() takes four arguments (x, A, mu, sigma); the previous three-argument
# call raised TypeError. NOTE(review): assumes the intent was x=400, mu=0,
# sigma=300 with unit amplitude — confirm.
gauss(x=400, A=1.0, mu=0, sigma=300)

In [None]:
import plotly.express as px

# Heatmap of the first 25 saved frames (rows = save index, columns = grid node).
# log1p(N) equals log(1 + N) but stays accurate where N is tiny.
early_frames = np.asarray(N_traj[0:25, :])
fig = px.imshow(
    np.log1p(early_frames),
    labels=dict(x="grid node", y="save index", color="log(1 + N)"),
)
fig.show()

In [None]:
# Blip schedule: position vs. firing time of each N_dir blip.
# `xs` / `times` are only defined by the blip-based modes; "init" seeds through
# y0 instead, so plotting unconditionally raised NameError on a fresh kernel.
if mode in ("induce", "manual"):
    fig = px.scatter(
        x=xs, y=times,
        labels={"x": "x (µm)", "y": "blip time (s)"},
    )
    fig.show()
else:
    print(f"mode={mode!r} defines no timed blips; nothing to plot.")