# In [None]:
# 09_mw_microtune_showcase_progress.py
import itertools
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg as npl
import pandas as pd
from scipy.optimize import root
from tqdm import tqdm

from mw_model_core import load_params, find_equilibria, relax_to_ss, rhs, jac_fd
from mw_model_constants import FIT_PATH, N_HILL as N_HILL_BASE, KQ as KQ_BASE, D_OVERRIDE

OUT = "results/microtune_showcase"
os.makedirs(OUT, exist_ok=True)

# Fitted parameter vector. Index 6 is the parameter scanned as "d" in the
# bifurcation slices; d_fit is its fitted baseline value.
p_fit = load_params()
d_fit = float(p_fit[6])

# ------------ parameter grids ------------
# Hill exponents and KQ values scanned by the grid search. They are set
# explicitly here and are not derived from N_HILL_BASE / KQ_BASE.
N_HILL_grid = [3, 4, 5]
KQ_grid = [80, 100]

# Multiplicative scale factors applied to individual fitted parameters in the
# grid search (p_high -> p[11], rHP -> p[1], K_u -> p[9], gamma -> p[4]).
p_high_scales = [0.95, 1.05, 1.15]
rHP_scales = [0.9, 1.0, 1.1]
K_u_scales = [0.9, 1.0, 1.1]
gamma_scales = [1.0, 1.1, 0.9]

# ------------ helpers ------------
def eqs_with_params(p, KQ_local, N_HILL_local):
    """Return ``find_equilibria`` for parameter vector ``p`` at the given KQ and Hill exponent.

    Thin adapter that forwards ``KQ_local`` and ``N_HILL_local`` to
    ``find_equilibria`` as keyword arguments.
    """
    overrides = dict(KQ_local=KQ_local, N_HILL_local=N_HILL_local)
    return find_equilibria(p, **overrides)

def basin_fraction_high(p, KQ_local, N_HILL_local, H_cut=0.5, grid=20, T=900):
    """Estimate the fraction of initial conditions that settle into the high-H state.

    A ``grid`` x ``grid`` lattice of starting points is built. H0 (state
    component 2) spans [0.10, 0.95] and q0 (component 4) spans [0, 1]. The
    other components are fixed at 0.10, 0.10 and 0.08. Each start is relaxed
    with ``relax_to_ss`` for time ``T``. A start counts as "high" when its
    final component 2 is >= ``H_cut``.

    Returns the high fraction as a float, or 0.0 when the lattice is empty.
    """
    H_starts = np.linspace(0.10, 0.95, grid)
    q_starts = np.linspace(0.00, 1.00, grid)

    outcomes = []
    for H_init in H_starts:
        for q_init in q_starts:
            start = np.array([0.10, 0.10, H_init, 0.08, q_init], float)
            final_state, _ = relax_to_ss(
                p, start, T=T, KQ_local=KQ_local, N_HILL_local=N_HILL_local
            )
            outcomes.append(bool(final_state[2] >= H_cut))

    if not outcomes:
        return 0.0
    return sum(outcomes) / len(outcomes)

def bifurcation_slice_H_vs_d(p, KQ_local, N_HILL_local, tag,
                             d_span=(0.7, 1.4), n_d=120, box_tol=1e-6):
    """Scan parameter ``p[6]`` (plotted as d, 1/h) and plot equilibrium H* against it.

    The scan covers ``n_d`` values of d spanning ``d_span`` times the baseline
    ``p[6]``. At each d, ``scipy.optimize.root`` (hybr) is started from four
    fixed seeds. A converged root is kept only if it is finite and lies inside
    the admissible box (all components >= 0; H = component 2 and
    q = component 4 <= 1.2), allowing ``box_tol`` slack. A kept root is
    stable when every eigenvalue of ``jac_fd`` at that root has a negative
    real part.

    The scatter plot (stable "o", unstable "x", dashed line at the baseline
    d) is written to ``OUT/bifurcation_<tag>.png``. Nothing is written when no
    admissible root is found.

    Parameters
    ----------
    p : array
        Parameter vector; copied per d, never mutated. Index 6 is scanned.
    KQ_local, N_HILL_local :
        Forwarded unchanged to ``rhs``.
    tag : str
        Used in the progress label, plot title and output filename.
    d_span : (float, float)
        Scan range as multiples of the baseline ``p[6]``.
    n_d : int
        Number of d values in the scan.
    box_tol : float
        How far a root may lie outside the admissible box before it is
        rejected; smaller excursions are clipped back as round-off.
    """
    d_ref = float(p[6])
    ds = np.linspace(d_span[0] * d_ref, d_span[1] * d_ref, n_d)
    seeds = [
        np.array([0.12, 0.12, 0.30, 0.10, 1.0]),
        np.array([0.05, 0.20, 0.90, 0.12, 0.0]),
        np.array([0.30, 0.08, 0.55, 0.10, 0.6]),
        np.array([0.15, 0.15, 0.65, 0.12, 0.4]),
    ]
    # Admissible state box: non-negative everywhere, H and q capped at 1.2.
    lower = np.zeros(5)
    upper = np.array([np.inf, np.inf, 1.2, np.inf, 1.2])

    rows = []
    for d in tqdm(ds, desc=f"Bifurcation {tag}", leave=False):
        p2 = p.copy()
        p2[6] = d

        def f(yy):
            return rhs(yy, p2, KQ_local, N_HILL_local)

        for wi, s in enumerate(seeds):
            sol = root(f, s, method="hybr")
            if not sol.success:
                continue
            x = np.asarray(sol.x, float)
            # Reject, rather than clip, roots outside the box. A point clipped
            # onto the boundary is generally not an equilibrium, so its
            # Jacobian would give a meaningless stability verdict. NaN/inf
            # must be tested explicitly: comparisons with NaN are False.
            if (not np.all(np.isfinite(x))
                    or np.any(x < lower - box_tol)
                    or np.any(x > upper + box_tol)):
                continue
            y = np.clip(x, lower, upper)  # only removes round-off-sized excursions
            lam = np.real(npl.eigvals(jac_fd(f, y)))
            rows.append({"d": float(d), "H": float(y[2]),
                         "stable": bool(np.max(lam) < 0), "seed": wi})
    if not rows:
        return

    df = pd.DataFrame(rows)
    plt.figure(figsize=(7.2, 5.0))
    for st, mk in [(True, "o"), (False, "x")]:
        sub = df[df["stable"] == st]
        if len(sub):
            plt.scatter(sub["d"], sub["H"], s=16, marker=mk, alpha=0.7,
                        label=("stable" if st else "unstable"))
    plt.axvline(d_ref, ls="--", c="gray", label="baseline d")
    plt.xlabel("d (1/h)")
    plt.ylabel("H*")
    plt.title(tag)
    plt.legend()
    plt.grid(True, ls=":")
    plt.tight_layout()
    plt.savefig(os.path.join(OUT, f"bifurcation_{tag}.png"), dpi=180)
    plt.close()

# ------------ grid search ------------
# Every combination of Hill exponent, KQ and the four parameter scale factors,
# in the same order as the equivalent nested loops. A single tqdm bar over the
# flattened product replaces nested bars plus manual "\r" status writes, which
# overwrote each other on the console.
combos = list(itertools.product(
    N_HILL_grid, KQ_grid, p_high_scales, rHP_scales, K_u_scales, gamma_scales,
))
total = len(combos)
cands = []

print(f"Running parameter grid ({total} combinations)...")
for nH, KQv, sh, sr, sku, sg in tqdm(combos, desc="Grid search", total=total):
    # Scale the fitted parameters for this combination.
    p = p_fit.copy()
    p[11] *= sh  # p_high
    p[1] *= sr   # rHP
    p[9] *= sku  # K_u
    p[4] *= sg   # gamma

    eqs = eqs_with_params(p, KQv, nH)
    # .eq(True) is the lint-clean spelling of the pandas mask `== True`.
    stable = eqs.loc[eqs["stable"].eq(True)].sort_values("H")
    n_stable = len(stable)
    if n_stable < 2:
        # Fewer than two stable states: no H* gap, only a small basin-based score.
        dH = 0.0
        bf = basin_fraction_high(p, KQv, nH, H_cut=0.6, grid=14)
        score = bf * 0.2
    else:
        # Two or more stable states: reward the gap between the lowest and
        # highest stable H*; penalise basins far from a 50/50 split.
        dH = float(stable["H"].iloc[-1] - stable["H"].iloc[0])
        bf = basin_fraction_high(p, KQv, nH, H_cut=0.6, grid=16)
        score = dH - 0.3 * abs(bf - 0.5)

    cands.append({
        "N_HILL": nH, "KQ": KQv, "p_high_scale": sh, "rHP_scale": sr,
        "K_u_scale": sku, "gamma_scale": sg,
        "n_stable": int(n_stable), "dH": dH,
        "basin_high_frac": bf, "score": score,
    })

print(f"\nGrid search complete: {len(cands)} tested.\n")

# ------------ ranking ------------
# Explicit columns (same order as the candidate dicts) keep the sort and the
# CSV header valid even when no candidates were produced. pd.DataFrame([])
# has no columns, so sort_values would raise KeyError on it.
CAND_COLUMNS = [
    "N_HILL", "KQ", "p_high_scale", "rHP_scale", "K_u_scale", "gamma_scale",
    "n_stable", "dH", "basin_high_frac", "score",
]
cand_path = os.path.join(OUT, "microtune_candidates.csv")
# Rank by number of stable states first, then score, then H* gap; all descending.
cand_df = (
    pd.DataFrame(cands, columns=CAND_COLUMNS)
    .sort_values(["n_stable", "score", "dH"], ascending=[False, False, False])
    .reset_index(drop=True)
)
cand_df.to_csv(cand_path, index=False)
print(f"Saved ranking -> {cand_path}")

# ------------ visualization ------------
# Plot basins and a d-bifurcation slice for the top three candidates with >= 2 stable states.
top = cand_df[cand_df["n_stable"] >= 2].head(3)
print(f"\nPlotting top {len(top)} bistable parameter sets...\n")

for _, row in top.iterrows():
    # Rebuild the candidate's parameter vector exactly as in the grid search.
    p = p_fit.copy()
    p[11] *= row["p_high_scale"]
    p[1] *= row["rHP_scale"]
    p[9] *= row["K_u_scale"]
    p[4] *= row["gamma_scale"]

    # pandas iterrows() does not preserve per-column dtypes, so cast back.
    # N_HILL is an integer exponent. KQ is kept as a float so non-integer
    # values are not truncated; ':g' renders integral values like 80 as "80".
    n_hill = int(row["N_HILL"])
    kq = float(row["KQ"])
    tag = (f"n{n_hill}_KQ{kq:g}_ph{row['p_high_scale']:.2f}_"
           f"rhp{row['rHP_scale']:.2f}_Ku{row['K_u_scale']:.2f}_g{row['gamma_scale']:.2f}")

    print(f" → generating plots for {tag}")
    Hs = np.linspace(0.10, 0.95, 35)
    qs = np.linspace(0.0, 1.0, 35)
    Z = np.zeros((len(Hs), len(qs)))  # rows: H0, columns: q0

    for i, H0 in enumerate(tqdm(Hs, desc=f"Basins {tag}", leave=False)):
        for j, q0 in enumerate(qs):
            y0 = np.array([0.10, 0.10, H0, 0.08, q0], float)
            yss, _ = relax_to_ss(p, y0, T=1100, KQ_local=kq, N_HILL_local=n_hill)
            Z[i, j] = yss[2]

    plt.figure(figsize=(6.6, 5.2))
    plt.imshow(Z, origin="lower", extent=[qs[0], qs[-1], Hs[0], Hs[-1]],
               aspect="auto", vmin=0.0, vmax=1.0, cmap="viridis")
    plt.colorbar(label="final H*")
    plt.xlabel("q0")
    plt.ylabel("H0")
    plt.title(f"Basins | {tag}\nΔH*={row['dH']:.2f}, basin_high≈{row['basin_high_frac']:.2f}")
    plt.tight_layout()
    # Same resolution as the bifurcation figure.
    plt.savefig(os.path.join(OUT, f"basins_{tag}.png"), dpi=180)
    plt.close()

    bifurcation_slice_H_vs_d(p, kq, n_hill, tag)

print(f"\nAll plots saved in -> {OUT}")


# Captured cell output:
# Running parameter grid (486 combinations)...
#
#
# N_HILL:   0%|                                             | 0/3 [00:00<?, ?it/s]
# KQ:   0%|                                                 | 0/2 [00:00<?, ?it/s]
#
#  → tested 40/486 combinations...