In [None]:
import numpy as np, json
from pathlib import Path
import matplotlib
# matplotlib.use("TkAgg")   # Interactive Tk GUI backend (opens figure windows); use "Agg" for a non-interactive backend that only saves figures to file
from matplotlib import pyplot as plt
import plotly.express as px
import pandas as pd

## Load sweep results

In [None]:
import os

# Root of the parameter-sweep outputs. Set the SYMMETRY_SWEEP_ROOT environment
# variable to point elsewhere instead of editing this hard-coded local path.
root = Path(os.environ.get("SYMMETRY_SWEEP_ROOT", "/Users/nick/Data/symmetry/sweeps/"))
output_dir = root / "sweep00_jax_stable"

# Reload arrays. np.load returns an NpzFile that keeps the archive open; the
# context manager closes it once both arrays have been read into memory.
with np.load(output_dir / "results.npz") as data:
    N = data["N"]
    L = data["L"]

# Reload parameters (presumably one dict per simulation, in the same order as N/L)
with open(output_dir / "params.json") as f:
    param_dicts = json.load(f)

param_df = pd.DataFrame(param_dicts)

# Later cells index simulations with boolean masks built from param_df rows,
# so every simulation in N and L must have exactly one parameter row.
assert N.shape[0] == L.shape[0] == len(param_df), (N.shape, L.shape, len(param_df))
param_df.head()

In [None]:
# Distinct values of the kappa_NL sweep parameter, in ascending order.
sorted(pd.unique(param_df["kappa_NL"]))

## Plot example

In [None]:
sim_idx = 355  # <-- change this
sample_data = np.array(N[sim_idx])  # copy of one simulation, indexed (time, space)

# px.imshow draws array rows along y and columns along x. Transposing the
# (time, space) array puts space on the rows (y) and time on the columns (x).
fig = px.imshow(
    sample_data.T,
    aspect="auto",
    origin="lower",
    labels=dict(x="Time", y="Position", color="Intensity"),
    color_continuous_scale="viridis",
    title=f"N for simulation {sim_idx}",
)
fig.show()

In [None]:
import plotly.graph_objects as go

# Spatial extent of the simulation domain.
# NOTE(review): the +/-1000 extent is hard-coded; confirm it matches the solver settings.
X_MIN, X_MAX = -1000, 1000
# One grid point per spatial column of the data, so x_grid cannot silently
# disagree with N's spatial dimension (it was previously fixed at 151 points).
x_grid = np.linspace(X_MIN, X_MAX, sample_data.shape[1])

# Every 10th time index up to 120, dropping any beyond the end of the simulation.
plot_times = np.arange(0, 121, 10)
plot_times = plot_times[plot_times < sample_data.shape[0]]

fig = go.Figure()
for p in plot_times:
    fig.add_trace(go.Scatter(x=x_grid, y=sample_data[p], name=f"t index {p}"))
fig.update_layout(xaxis_title="Position", yaxis_title="N", legend_title="Time index")
fig.show()

In [None]:
# NOTE(review): magic constants -- presumably a characteristic length sqrt(D / rate);
# confirm what 1.85 and 1e-4 represent and promote them to named constants.
print(np.sqrt(1.85/1e-4))

# Position where the cumulative N mass of the last plotted profile first exceeds
# 1 - 0.1587 ~= 0.8413 = Phi(1); for a Gaussian profile this lies ~1 sigma right of the mean.
# The time index is taken explicitly from plot_times rather than from the loop
# variable left over by the plotting cell.
t_last = plot_times[-1]
profile_pdf = sample_data[t_last] / np.sum(sample_data[t_last])
profile_cdf = np.cumsum(profile_pdf)
above = np.flatnonzero(profile_cdf > (1 - 0.1587))
x_grid[above[0]]

### Question 1: what dictates survival/extinction of the initial Nodal blip?
To start, consider the case K_N = N_nl.

In [None]:
# Sample every simulation at the final time index (t_samp = -1) at spatial index
# x_samp; N_vec_e is the same quantity at spatial index 0 (left end of the grid).
x_samp = 61
t_samp = -1
K_A = 667  # normalisation applied to N and L. NOTE(review): document what 667 is

N_vec = N[:, t_samp, x_samp].ravel() / K_A
N_vec_e = N[:, t_samp, 0].ravel() / K_A
L_vec = L[:, t_samp, x_samp].ravel() / K_A

# Select simulations with kappa_NL == 1. np.isclose tolerates float round-off and
# avoids .astype(int), whose truncation would also admit any kappa_NL in [1, 2).
kappa_filter = np.isclose(param_df["kappa_NL"].to_numpy(dtype=float), 1.0)

fig = px.scatter(
    x=N_vec[kappa_filter],
    y=L_vec[kappa_filter],
    color=np.log10(param_df.loc[kappa_filter, "delta"].to_numpy()),
    labels=dict(x="N / K_A", y="L / K_A", color="log10(delta)"),
)
fig.show()

In [None]:
# Ratio L/N at the sampled point. Entries where N is zero are left as NaN, so the
# division emits no divide-by-zero warnings and produces no inf values.
rat_vec = np.divide(
    L_vec, N_vec, out=np.full_like(L_vec, np.nan, dtype=float), where=N_vec != 0
)

N_MIN = 2  # only simulations with N_vec above this (K_A units) are plotted
nz_ft = N_vec > N_MIN
ratio_mask = kappa_filter & nz_ft

fig = px.scatter(
    x=param_df.loc[ratio_mask, "beta_r"],
    y=rat_vec[ratio_mask],
    labels=dict(x="beta_r", y="L / N"),
)
fig.show()

In [None]:
import numpy as np
from scipy.optimize import curve_fit
from tqdm import tqdm

def gauss(x, A, mu, sigma):
    """Unnormalised Gaussian bump with peak height ``A`` at ``x = mu`` and width ``sigma``.

    Serves as the model function for ``scipy.optimize.curve_fit``; ``x`` may be a
    scalar or an ndarray and the result broadcasts accordingly. The value depends on
    ``sigma`` only through ``sigma**2``, so the sign of a fitted ``sigma`` is arbitrary.
    """
    offset = x - mu
    exponent = -(offset ** 2) / (2 * sigma ** 2)
    return A * np.exp(exponent)

In [None]:
# Simulations whose final N at x_samp exceeds the value at spatial index 0 by more
# than 1 (K_A units), i.e. where the initial blip appears to have survived.
peaked_inds = np.flatnonzero((N_vec - N_vec_e) > 1)

# Time index of the profile to fit (0 = initial condition). A separate name keeps
# the final-time index `t_samp` from the sampling cell intact.
t_fit = 0
assert N.shape[2] == x_grid.size, "x_grid must have one point per spatial column of N"
dx = x_grid[1] - x_grid[0]

amp_vec = np.full_like(N_vec, np.nan)    # fitted amplitude A (raw N units, not / K_A)
sigma_vec = np.full_like(N_vec, np.nan)  # fitted width |sigma| (x_grid units)
failed_inds = []

for i in tqdm(peaked_inds):
    A = np.squeeze(N[i, t_fit, :])
    # Data-driven starting point instead of a fixed mu=0, sigma=1.0 guess: centre at
    # the profile maximum, width from the second moment of the positive part of the
    # profile, and never narrower than one grid spacing.
    mu0 = x_grid[np.argmax(A)]
    weights = np.clip(A, 0, None)
    mass = weights.sum()
    sigma0 = np.sqrt(np.sum(weights * (x_grid - mu0) ** 2) / mass) if mass > 0 else dx
    sigma0 = max(sigma0, dx)
    try:
        popt, _ = curve_fit(gauss, x_grid, A, p0=[A.max(), mu0, sigma0])
    except RuntimeError:
        # curve_fit raises RuntimeError when it fails to converge; keep NaN for this
        # simulation rather than aborting the whole loop.
        failed_inds.append(i)
        continue
    amp_vec[i] = popt[0]
    # gauss depends on sigma only through sigma**2, so report the magnitude.
    sigma_vec[i] = abs(popt[2])

print(f"{len(failed_inds)} of {len(peaked_inds)} Gaussian fits failed to converge")

In [None]:
# Fitted amplitude vs. width of the Gaussian fitted to each selected simulation's
# N profile at time index 0.
fig = px.scatter(
    x=amp_vec,
    y=sigma_vec,
    labels=dict(x="Fitted amplitude A", y="Fitted width sigma"),
    title="Gaussian fit to N at time index 0",
)
fig.show()

In [None]:
# Number of simulations that received a Gaussian fit (non-NaN amplitude).
(~np.isnan(amp_vec)).sum()

In [None]:
px.line(x=x_grid, A