In [None]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
from pathlib import Path
from compile_sweep_results import compile_sweep_results
import os

In [None]:
# Base Dropbox folder for this project. Set the SYMMETRY_BREAKING_DROPBOX environment
# variable to run the notebook on another machine; the default is the original author's path.
dropbox_dir = Path(os.environ.get(
    "SYMMETRY_BREAKING_DROPBOX",
    "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick",
))
root = dropbox_dir / "symmetry_breaking" / "sweep_results"
sweep_name = "sweep01_neutralization_v2"
# Output folder for slides/figures; created (with parents) if missing.
fig_path = dropbox_dir / "slides" / "killifish" / "20250815"
fig_path.mkdir(parents=True, exist_ok=True)

## Run simulation

In [None]:
from symmetry_breaking.models.NL_field_1D import NodalLeftyField1D
from symmetry_breaking.models.trackers import NodalROITracker
from symmetry_breaking.models.sweep import run_simulation_1D, make_1d_grid
from pde import ProgressTracker, PlotTracker

# --- Simulation hyperparameters ---
dx = 5          # grid spacing (position axis is labelled in microns further down)
L = 3500        # domain length
T = 5 * 3600    # total simulated time: 5 hours
# Time step chosen from the explicit-diffusion stability condition, with a safety margin.
# NOTE(review): the literal 60 presumably stands for 2 * D_L (D_L = 30 below) — confirm.
dt = 0.5 * dx ** 2 / 60 / 1.25

# Model parameters; comments compare each value to the literature model.
param_dict = dict(
    K_NL=667,                # from lit
    K_A=667,                 # from lit
    K_I=10,                  # not in lit model
    mu_L=0.0002275846,       # 5x larger than lit
    N_amp=0,                 # NA
    N_sigma=31.6227766017,   # NA
    sigma_N=10,              # 10x lit
    sigma_L=10,              # 1000x (!!) lit
    D_N=2 * 1.85,            # 2x lit
    D_L=2 * 15.0,            # 2x lit
    n=2,                     # lit
    m=1,                     # NA
    p=2,                     # lit
    q=2,
    stoch_to_N=True,
    rate_N=1 / (3500 * 60) / 3,
    amp_median_N=500,
    amp_sigma_N=0.00001,
    sigma_x=31,              # lit
)

# Numerical settings plus the model/tracker classes to instantiate.
sim_config = dict(
    dx=dx,
    L=L,
    T=T,
    dt=dt,
    model_class=NodalLeftyField1D,
    tracker_class=NodalROITracker,
)

# (parameters, config, output directory) — unpacked by the run cell.
sim_inputs = (param_dict, sim_config, "")

In [None]:
param_dict, sim_config, output_dir = sim_inputs

# Unpack the shared simulation settings.
dx = sim_config["dx"]
L = sim_config["L"]
T = sim_config["T"]
dt = sim_config["dt"]

interval = 100  # update interval handed to every tracker
model_class = sim_config["model_class"]
tracker_class = sim_config["tracker_class"]

# One entry per simulation; each value overrides the matching key of param_dict.
# Run 0 turns off Nodal auto-activation (sigma_N = 0); run 1 keeps sigma_N = 10.
sigma_N_list = [0, 10]
sigma_L_list = [10, 10]
D_N_list = [2*1.85, 2*1.85]
D_L_list = [2*15, 2*15]
L_value_list = [0, 0]
sweep_lists = (sigma_N_list, sigma_L_list, D_N_list, D_L_list, L_value_list)
# zip() would silently truncate mismatched lists, so fail loudly instead.
assert len({len(v) for v in sweep_lists}) == 1, "per-run parameter lists must have equal length"
result_list = []

for sigma_N_s, sigma_L_s, D_N_s, D_L_s, L_value_s in zip(*sweep_lists):
    # Copy of the base parameters with this run's overrides applied
    # (existing keys keep their position; "L_value" is appended).
    params_temp = {
        **param_dict,
        "sigma_N": sigma_N_s,
        "sigma_L": sigma_L_s,
        "D_N": D_N_s,
        "D_L": D_L_s,
        "L_value": L_value_s,
    }

    # --- Setup grid and model ---
    blip_logger = []  # the same list is handed to both the model and the ROI tracker
    grid = make_1d_grid(length=L, dx=dx)
    model = model_class(**params_temp, blip_logger=blip_logger)
    state = model.get_state(grid)

    # --- Trackers: progress bar + custom ROI tracker ---
    # (A PlotTracker used to be constructed here but was never passed to solve().)
    progress = ProgressTracker(interval=interval)
    roi_tracker = tracker_class(grid, interval=interval, blip_logger=blip_logger)
    trackers = [progress, roi_tracker]

    # --- Run simulation ---
    state = model.solve(state, t_range=T, dt=dt, tracker=trackers)

    # Keep the ROI tracker; later cells read profiles and pulses from it.
    result_list.append(roi_tracker)

In [None]:
from tqdm import tqdm
from src.utilities.plot_functions import format_2d_plotly

offset = 1e-1
y0 = 2500        # floor for the y-axis ceiling of every frame
thresh = 150     # Lefty threshold; only used by the disabled band-shading code below
pulse_mem = 300  # how long a Nodal pulse stays marked, in the units of profiles["times"]
                 # (seconds, given the /3600 -> hours conversion in the frame title)


def _pulse_times(pulses):
    """Return the pulse times stored in a tracker's pulse record as a numpy array.

    NOTE(review): this cell previously used ``pulses["x"]`` (pulse *positions*) as the
    pulse times, so the "recent pulse" filter compared times against positions. The
    time key is not visible from this notebook, so likely names are tried in order and
    a KeyError listing the available keys is raised if none match — confirm against
    ``NodalROITracker.get_pulses``.
    """
    for key in ("t", "time", "times"):
        if key in pulses:
            return np.asarray(pulses[key])
    raise KeyError(f"no pulse-time key found in pulse record; available keys: {list(pulses)}")


for r, tracker in enumerate(result_list):

    frame_path = fig_path / f"pulse_sim{r:02}_frames"
    os.makedirs(frame_path, exist_ok=True)

    profiles = tracker.get_profiles()
    x = profiles["x"] - 1750   # centre the axis (presumably half of the 3500-micron domain)
    time_vec = profiles["times"]
    # Descriptive names so the domain length `L` is not overwritten by an array.
    nodal = profiles["Activator"]
    lefty = profiles["Repressor"]

    nodal_pulses = tracker.get_pulses("Activator")
    pulse_x_vec = np.asarray(nodal_pulses["x"])
    pulse_t_vec = _pulse_times(nodal_pulses)

    for t, time in enumerate(tqdm(time_vec)):

        ym = np.max([y0, 1.1 * np.max(lefty[t, :]), 1.1 * np.max(nodal[t, :])])

        fig = go.Figure()
        fig.add_traces(go.Scatter(
            x=x,
            y=nodal[t, :],
            mode="lines",
            line=dict(width=5, color="#fc8d62"),  # Set2 orange
            fill="tozeroy",
            fillcolor="rgba(252, 141, 98, 0.3)",  # semi-transparent orange
            name="Nodal"
        ))
        fig.add_traces(go.Scatter(
            x=x,
            y=lefty[t, :],
            mode="lines",
            line=dict(width=5, color="#8da0cb"),  # Set2 blue
            fill="tozeroy",
            fillcolor="rgba(141, 160, 203, 0.6)",  # semi-transparent blue
            name="Lefty"
        ))

        # Mark every Nodal pulse whose time lies within the last `pulse_mem` time units.
        pulse_dt = time - pulse_t_vec
        extant_pulse_indices = np.where((pulse_dt <= pulse_mem) & (pulse_dt >= 0))[0]
        for e in extant_pulse_indices:
            fig.add_vline(x=pulse_x_vec[e], line_width=2, line_color="#fc8d62", layer="above")

        # Disabled: shade the region where Lefty exceeds `thresh`.
        # band_color = "rgba(200,200,200,0.25)"  # light gray with alpha
        # rep_indices = np.where(lefty[t, :] >= thresh)[0]
        # if rep_indices.size:  # (np.any(rep_indices) would miss a lone index 0)
        #     xl = x[rep_indices[0]]
        #     xr = x[rep_indices[-1]]
        #
        #     # vertical dashed lines
        #     fig.add_vline(x=xl, line_width=3, line_dash="dash", line_color="white", layer="above")
        #     fig.add_vline(x=xr, line_width=3, line_dash="dash", line_color="white", layer="above")
        #
        #     # shaded band spanning the full plot height (yref='paper')
        #     fig.add_shape(
        #         type="rect",
        #         x0=xl, x1=xr,
        #         y0=0, y1=1, yref="paper",
        #         fillcolor=band_color,
        #         line=dict(width=0),
        #         layer="below"
        #     )

        fig = format_2d_plotly(fig, axis_labels=["position (microns)", "concentration (nmol per micron squared)"],
                               font_size=18)

        fig.update_layout(title=f"Nodal and Lefty concentrations ({np.round(time/3600,2)} hours)")
        fig.update_xaxes(range=[-1750, 1750])
        fig.update_yaxes(range=[0, ym])

        fig.write_image(frame_path / f"frame_{t:05}.png")

# Show the last rendered frame inline.
fig.show()

In [None]:
# Debug: pulse times highlighted in the last rendered frame, then that frame's time.
# Relies on loop variables left over from the frame-rendering cell.
active_pulse_times = pulse_t_vec[extant_pulse_indices]
print(active_pulse_times)
time

In [None]:
from matplotlib import pyplot as plt

# False (default) plots the central Nodal level, which is what this figure has always shown.
# True plots central minus left-flank Nodal, the "relative" quantity the old x-label named.
RELATIVE_NODAL = False
center_half_width = 250  # half-width of the central ROI, in the units of profiles["x"]

lefty_list = []
nodal_list = []
time_list = []
id_list = []

for r, tracker in enumerate(result_list):

    # NOTE(review): no frames are written to this directory in this cell.
    frame_path = fig_path / f"lefty_sim{r:02}_frames"
    os.makedirs(frame_path, exist_ok=True)

    profiles = tracker.get_profiles()
    x = profiles["x"] - 1750
    time_vec = profiles["times"]
    nodal = profiles["Activator"]
    lefty = profiles["Repressor"]

    # Mean concentrations inside the central ROI at every stored time.
    center = (x >= -center_half_width) & (x <= center_half_width)
    nodal_center = np.mean(nodal[:, center], axis=1)
    lefty_center = np.mean(lefty[:, center], axis=1)
    if RELATIVE_NODAL:
        nodal_left = np.mean(nodal[:, x <= -center_half_width], axis=1)
        nodal_x = nodal_center - nodal_left
    else:
        nodal_x = nodal_center

    # Drop the first stored frame (index 0) of every run.
    nodal_list += list(nodal_x[1:])
    lefty_list += list(lefty_center[1:])
    time_list += list(time_vec[1:] / 3600)   # seconds -> hours
    id_list += list(np.tile(r, len(time_vec[1:])))

if RELATIVE_NODAL:
    x_label = "Relative Nodal concentration ($[N_c]$-$[N_p]$)"
else:
    x_label = "Central Nodal concentration ($N_c$)"

plt.close('all')
with plt.style.context("dark_background"):
    fig, ax = plt.subplots(figsize=(8, 6))

    # Phase-space trajectory, coloured by time.
    sc = ax.scatter(nodal_list, lefty_list, c=time_list, cmap="plasma", zorder=3)

    ax.set_xlabel(x_label)
    ax.set_ylabel("Central Lefty concentration ($L_c$)")
    ax.set_title("Patterning phase space")

    # Colorbar for simulation time.
    fig.colorbar(sc, ax=ax, label="hours")

    # Save before showing: the inline backend may close the figure on show().
    fig.savefig(fig_path / "phase_example.png", dpi=300, bbox_inches="tight")
    plt.show()

In [None]:
# Display the output directory used for figures and frame folders.
fig_path

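### Re-run the simulation with no link between Lefty and Nodal

Note: this is not an exact repeat of the run above. The model class, several parameters (e.g. `K_NL`, `K_A`, `K_I`), and the simulated duration `T` (10 h instead of 5 h) all differ.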

In [None]:
# Display whichever frame directory was assigned most recently (depends on which frame cell ran last).
frame_path

In [None]:
# --- Simulation hyperparameters (longer, 10-hour run) ---
dx = 5
L = 3500
T = 10 * 3600
# Same stability-limited time step as the first simulation.
dt = 0.5 * dx ** 2 / 60 / 1.25

param_dict = dict(
    K_NL=138.9495494373,
    K_A=193.0697728883,
    # K_R=19.3069772888,
    K_I=1e10,   # huge value, presumably to disable the Lefty-Nodal link (was 19.3069772888)
    mu_L=0.0002275846,
    N_amp=517.9474679231,
    N_sigma=31.6227766017,
    sigma_N=10.0,
    sigma_L=10.0,
    D0_N=1.85,
    D0_L=15.0,
    alpha_L=0,
    alpha_N=0,
    tau_rho=3600,
    n=2,
    m=1,
    p=2,
    q=2,
)

# Overrides merged on top of param_dict (these win on duplicate keys).
static_params = dict(
    sigma_N=10.0,  # Nodal auto-activation
    no_density_dependence=True,
    alpha_L=0,
    alpha_N=0,
    tau_rho=3600,
)

# NOTE(review): NodalLeftyNeutralization1D is not imported anywhere in this notebook, so
# this cell raises NameError on a fresh kernel — add the missing import.
sim_config = dict(
    dx=dx,
    L=L,
    T=T,
    dt=dt,
    model_class=NodalLeftyNeutralization1D,
    tracker_class=NodalROITracker,
    interval=1000,  # NOTE(review): not read by the run cell, which uses its own interval
)

sim_inputs = (param_dict | static_params, sim_config, "")

In [None]:
# Scratch arithmetic left over from exploration (25 squared); its result is not used elsewhere — TODO delete if no longer needed.
25**2

In [None]:
param_dict, sim_config, output_dir = sim_inputs

# Pull the numerical settings and classes out of the config dict.
dx, L, T, dt = (sim_config[key] for key in ("dx", "L", "T", "dt"))
model_class, tracker_class = sim_config["model_class"], sim_config["tracker_class"]
# NOTE(review): sim_config["interval"] (if present) is ignored; a fixed 300 is used instead.
interval = 300

# --- Setup grid and model ---
grid = make_1d_grid(length=L, dx=dx)
model = model_class(**param_dict)
state = model.get_state(grid)

# Progress bar first, custom ROI tracker last (later cells read trackers[-1]).
progress = ProgressTracker(interval=interval)
trackers = [progress, tracker_class(grid, interval=interval)]

# --- Run simulation ---
state = model.solve(state, t_range=T, dt=dt, tracker=trackers)

# Parameters merged with the ROI tracker's summary metrics (metrics win on key clashes).
result = dict(param_dict)
result.update(trackers[-1].get_metrics())

In [None]:
# Stored profiles from the ROI tracker of the run above.
profiles = trackers[-1].get_profiles()
# Shift positions so the axis is centred (presumably half of the 3500-micron domain).
x = profiles["x"] - 1750
time_vec = profiles["times"]
# Nodal and Lefty arrays; note `L` here shadows the domain length from the config cell.
N, L = profiles["Activator"], profiles["Repressor"]
rho = profiles.get("rho")

In [None]:
# NOTE(review): `sim_id` was referenced here but never defined in this notebook, so the
# cell raised NameError on a fresh kernel. Frames are now labelled with the sweep name;
# change this if a different run label was intended.
sim_id = sweep_name
frame_path = fig_path / f"{sim_id}_frames_no_LN"
os.makedirs(frame_path, exist_ok=True)

# NOTE(review): colours are swapped relative to the first movie (Nodal was orange there).
for t, time in enumerate(tqdm(time_vec)):
    fig = go.Figure()
    fig.add_traces(go.Scatter(
        x=x,
        y=N[t, :],
        mode="lines",
        line=dict(width=3, color="#8da0cb"),  # Set2 blue
        fill="tozeroy",
        fillcolor="rgba(141, 160, 203, 0.8)",  # semi-transparent blue
        name="Nodal"
    ))
    fig.add_traces(go.Scatter(
        x=x,
        y=L[t, :],
        mode="lines",
        line=dict(width=3, color="#fc8d62"),  # Set2 orange
        fill="tozeroy",
        fillcolor="rgba(252, 141, 98, 0.3)",  # semi-transparent orange
        name="Lefty"
    ))

    fig = format_2d_plotly(fig, axis_labels=["position (microns)", "concentration (nmol per micron squared)"],
                           font_size=18)
    fig.update_layout(title=f"Nodal and Lefty concentrations ({np.round(time/3600,2)} hours)")
    fig.update_xaxes(range=[-1500, 1500])
    fig.update_yaxes(range=[0, 50000])

    fig.write_image(frame_path / f"frame_{t:05}.png")

# Show the last rendered frame inline.
fig.show()

### Add random field