In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import plotly.express as px
from symmetry_breaking_JAX.models_2D.run import build_initial_state_2d, run_2d
from symmetry_breaking_JAX.models_2D.geom_utils import make_grid
from symmetry_breaking_JAX.models_2D.helpers_2D import single_run_2D
import numpy as np
from pathlib import Path
from PIL import Image
from matplotlib.cm import ScalarMappable

In [None]:
# Standalone vertical magma colorbar, exported as a transparent PNG for slides.
# White labels on a transparent background are presumably meant for a dark slide.
colorbar_path = Path(
    "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/killifish/20251008_symmetry/colorbar_magma.png"
)

fig, ax = plt.subplots(figsize=(1, 5))  # narrow, tall figure; the whole axes is used as the colorbar
fig.subplots_adjust(right=0.8)

# 0.01 * 1e4 == 100.
# NOTE(review): origin of this scale factor is not shown here. Confirm it matches the
# colour limits used for the simulation frames.
norm = plt.Normalize(vmin=0, vmax=0.01 * 1e4)
cmap = plt.get_cmap("magma")

cb = fig.colorbar(
    ScalarMappable(norm=norm, cmap=cmap),
    cax=ax,
    orientation="vertical",
)

cb.set_label(r"density (cells per $100\,\mu m^2$)", fontsize=16, color="white", labelpad=5)

# Tick style: white, large, thick ticks that point outward only.
cb.ax.tick_params(
    colors="white",   # tick and tick-label colour
    labelsize=14,     # tick label size
    width=2,          # tick line width
    length=8,         # tick length (points)
    direction="out",  # ticks point outward; use "inout" for both directions
)

# Save before show(). show() may close the figure (e.g. the inline backend closes
# figures after display), so saving first avoids writing a blank image.
fig.savefig(colorbar_path, dpi=300, bbox_inches="tight", transparent=True)
plt.show()

In [None]:
def load_png_field(path, invert=False, normalize=True):
    """
    Load a PNG as a grayscale field, using transparency as a mask.

    Pixels whose alpha is 0 are treated as background and are 0 in the
    returned field, regardless of ``invert``.

    Parameters
    ----------
    path : str or Path
        Path to PNG image file.
    invert : bool, optional
        If True, invert intensities of opaque pixels (dark → high).
    normalize : bool, optional
        Min-max rescale the field to [0, 1]. Min and max are taken over the
        whole image, including the zeroed transparent pixels. A constant
        field becomes all zeros.

    Returns
    -------
    field : jnp.ndarray (float32)
        Grayscale field, shape (H, W).
    """
    # Image.open is lazy and keeps the file handle open; the context manager releases it.
    with Image.open(path) as img:
        arr = np.array(img.convert("RGBA")).astype(np.float32) / 255.0  # shape (H, W, 4)

    rgb = arr[..., :3]
    opaque = arr[..., 3] > 0

    # Convert to luminance (perceptual grayscale)
    field = 0.2989 * rgb[..., 0] + 0.5870 * rgb[..., 1] + 0.1140 * rgb[..., 2]

    if invert:
        field = 1.0 - field

    # Apply the transparency mask AFTER inverting. Otherwise invert=True turns
    # transparent pixels into 1.0, contradicting the documented contract.
    field = np.where(opaque, field, 0.0)

    if normalize:
        minv, maxv = field.min(), field.max()
        if maxv > minv:
            field = (field - minv) / (maxv - minv)
        else:
            field = np.zeros_like(field)

    return jnp.array(field, dtype=jnp.float32)

In [None]:
from skimage.filters import gaussian
from skimage.transform import resize

# Simulation extent. Lx and dx must agree with the 'Lx'/'dx' grid parameters given to
# make_grid, otherwise the resampled initial field and the grid will not match.
T = 3600 * 10
Lx = 1250
dx = 10
n_pts = int(Lx / dx) + 1  # NOTE(review): assumes the grid has Lx/dx + 1 nodes per axis — confirm against make_grid
arr_size = (n_pts, n_pts)

# Load one density frame to use as the initial state
field_dir = Path("/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/killifish/20251008_symmetry/20250716_well0000_density/")
image_list = sorted(field_dir.glob("*.png"))
if not image_list:
    raise FileNotFoundError(f"No PNG frames found in {field_dir}")
image_ind = 27
field = load_png_field(image_list[image_ind])

# Subtract a 0.2 offset and clip to [0, 1] (presumably background removal), smooth,
# then rescale so the peak is 1.
field_norm = np.clip(field - 0.2, 0, 1)
field_norm = gaussian(field_norm, sigma=3, mode="nearest")
peak = np.max(field_norm)
if peak > 0:  # an all-zero field (nothing above the offset) would otherwise become NaN
    field_norm = field_norm / peak

field_rs = resize(field_norm, arr_size)

px.imshow(field_rs)

In [None]:
# Total simulated time passed to single_run_2D.
# NOTE(review): units presumably seconds (10 h) — confirm.
T = 3600 * 10
n_save = 50  # NOTE(review): not referenced by any visible cell

# Scale factors applied to the base parameter values below.
c_factor = 50     # divides the sigma_* and K_* values
rate_factor = 4   # multiplies the D_* and sigma_* values

# Model parameters for the 2D run. Domain size and spacing reuse Lx and dx from the
# field-loading cell, so the grid and arr_size (used to resample field_rs) come from
# the same numbers.
params0 = {
    "D_N": 1.85 / 4 * rate_factor,
    "D_L": 15*1.85 / 1.5 * rate_factor,
    "sigma_N": 10 / c_factor * rate_factor,
    "sigma_L": 16 / c_factor * rate_factor,
    "mu_N": 0.0001,
    "mu_L": 0.00005,
    "n": 2,
    "p": 2,
    "alpha": 1.0,
    "K_A": 667.0 / c_factor,
    "K_NL": 667 / c_factor / 5 / 8,
    "K_I": 1 / c_factor,
    "Lx": Lx,
    "Ly": Lx,  # square domain, matching the square arr_size
    "dx": dx,
    "dy": dx,
    "geometry": "rectangle",
    "bc": "periodic",
    # Alternative blip-based initial condition (presumably ignored while N_init_mode == "array"):
    # "N_positions": jnp.array([[0, 0]]),
    # "N_amps": jnp.array([1000 / c_factor]),
    # "N_sigmas": jnp.array([25]),
    "N_init_mode": "array",
    "N_init_array": field_rs,
    "N_scale_range": (0, 2000),
    "L_mode": "constant",
    "L_amp": 600,
}

grid0 = make_grid(params0)

In [None]:
# Build the spatial grid from the current params0, then integrate the model up to time T.
# Rebuilding the grid here keeps grid0 in sync if params0 was edited after the parameter cell ran.
grid0 = make_grid(params0)
sim_result = single_run_2D(params0, grid0, T)
N_t0, L_t0 = sim_result  # N and L outputs; later cells index them by saved frame (N_t0[i], L_t0[-1])

In [None]:
# x = grid0.X[0, :]
# fig = px.line(x=x, y=N_t0[-1][61, :])
# # fig.update_layout(yaxis=dict(range=[0, np.max(strip)]))
# fig.show()

In [None]:
# Final N frame on a fixed colour scale (0–5000) with the magma colormap.
N_final = N_t0[-1]
fig = px.imshow(N_final, color_continuous_scale="magma", range_color=[0, 5000])
fig.show()

# Render and save the simulation frames

In [None]:
import matplotlib.pyplot as plt
import time
from tqdm import tqdm 

%matplotlib tk

ts = range(N_t0.shape[0])
out_dir = Path("/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/killifish/20251008_symmetry/pde_simulation/")
out_dir.mkdir(exist_ok=True, parents=True)

fig, ax = plt.subplots(figsize=(6, 5))
vmin = np.percentile(N_t0[-1], 0.1)
vmax = 5500  # or np.percentile(N_t0[-1], 99.5) * 0.85

for i, t in enumerate(tqdm(ts)):
    ax.clear()
    im = ax.imshow(
        N_t0[i],
        cmap="magma",
        origin="lower",
        vmin=vmin,
        vmax=vmax,
    )
    ax.set_title(f"t = {t:.2f}")
    ax.set_xlabel("x")
    ax.set_ylabel("y")

    # Save frame
    out_path = out_dir / f"frame_{i:04d}.png"
    fig.savefig(out_path, dpi=300, bbox_inches="tight")

    # plt.pause(0.05)  # smaller pause if saving every frame

plt.close(fig)

In [None]:
# Step through the L field frame by frame in the interactive window.
# Fix: the previous version used L_t, X and Y, which no visible cell defines (NameError on a
# fresh kernel). It now uses L_t0 from the run cell and the grid coordinates from grid0.
# NOTE(review): assumes grid0 exposes 2D coordinate arrays X and Y — confirm grid0.Y exists.
x_coords, y_coords = grid0.X, grid0.Y
extent = [x_coords.min(), x_coords.max(), y_coords.min(), y_coords.max()]  # same for every frame

fig, ax = plt.subplots(figsize=(6, 5))
for i in range(L_t0.shape[0]):
    ax.clear()
    ax.imshow(
        L_t0[i],
        cmap="viridis",
        origin="lower",
        extent=extent,
        # Colour limits autoscale per frame; pass vmin=0, vmax=L_t0.max() for a fixed scale.
    )
    ax.set_title(f"frame {i}")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    # pause() redraws the active figure and runs the GUI event loop, so no per-frame show() is needed.
    plt.pause(0.05)



In [None]:
# Peak L value in the last saved frame.
L_final = L_t0[-1]
np.max(L_final)

In [None]:
# Display the parameter dict used for the run. No visible cell defines `params`; the run uses params0.
params0