In [None]:
# calibrate_hysteresis_from_scored_updated_v3.py >>>>>>>>> smooth hysteresis:: rhs_smooth and q_inf smooth 



# calibrate_hysteresis_from_scored_smooth.py
# ------------------------------------------------------------
# Smooth-hysteresis calibration to SCFA intensities and an H proxy.
# Fixes the variable-length residual issue by precomputing fixed masks.
# ------------------------------------------------------------

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares

# ----------------- Config -----------------
# Input table and output folder (the folder is created up front).
INPATH = "timeseries/combined_scfas_table_scored.csv"
OUTDIR = "mw_fit_out_smooth"
os.makedirs(OUTDIR, exist_ok=True)

USE_H_COL = "H_proxy_meta_smooth"    # preferred H proxy; "H_proxy_meta" is the fallback
USE_B_COLS = ["butyrate"]            # SCFA columns that are averaged into the B observable

MIN_POINTS_PER_SUBJECT = 4           # subjects with fewer rows are skipped
PENALTY = 1e3                        # finite residual value used for pathological cases
KQ = 40.0                            # steepness k of the smooth switch sigma(k*(H - theta(q)))

# ----------------- Load & checks -----------------
df = pd.read_csv(INPATH)

needed = {"subject_id", "sample_id"}
missing = [c for c in needed if c not in df.columns]
if missing:
    raise ValueError(f"Missing required columns in CSV: {missing}")

# Fall back to the unsmoothed proxy when the smoothed column is absent.
if USE_H_COL not in df.columns:
    alt = "H_proxy_meta"
    if alt not in df.columns:
        raise ValueError("No usable H proxy column found (need H_proxy_meta_smooth or H_proxy_meta).")
    print(f"[info] {USE_H_COL} not found; using {alt} instead.")
    USE_H_COL = alt

for c in USE_B_COLS:
    if c not in df.columns:
        raise ValueError(f"SCFA column '{c}' not found in CSV.")

# Keep only the columns used below and drop rows without identifiers.
keep = ["subject_id", "sample_id", USE_H_COL, *USE_B_COLS]
df = df[keep].dropna(subset=["subject_id", "sample_id"]).copy()

# ----------------- Time indexing -----------------
# Time is the 0-based row position within each subject, in file order.
# NOTE(review): this assumes the CSV rows are already in chronological order per subject.
df["t_idx"] = df.groupby("subject_id").cumcount().astype(float)

# ----------------- Robust scaling (MAD) -----------------
def robust_mad_scale(x: pd.Series) -> pd.Series:
    """Center a series on its median and divide by a robust spread estimate.

    The spread is the median absolute deviation (MAD). If the MAD is
    (numerically) zero, the inter-quartile range is used instead, and if that
    is zero too, the standard deviation plus a small epsilon.

    Missing values stay missing. In particular, a series with no valid entry
    at all is returned as all-NaN rather than as zeros, so that downstream
    ``np.isfinite`` masks do not mistake absent data for real observations.

    Parameters
    ----------
    x : pd.Series
        Values to scale; converted to float.

    Returns
    -------
    pd.Series
        Scaled values on the same index as ``x``.
    """
    x_valid = x.dropna().astype(float)
    if len(x_valid) == 0:
        # Nothing observed: keep everything missing instead of inventing zeros.
        return pd.Series(np.nan, index=x.index, dtype=float)
    med = np.median(x_valid)
    mad = np.median(np.abs(x_valid - med))
    if mad < 1e-9:
        # Degenerate MAD (more than half of the values equal the median).
        q75, q25 = np.percentile(x_valid, [75, 25])
        iqr = q75 - q25
        scale = iqr if iqr > 1e-9 else (np.std(x_valid) + 1e-9)
    else:
        scale = mad
    # The extra epsilon guards the division even when the scale itself is tiny.
    return (x.astype(float) - med) / (scale + 1e-9)

# Per-subject robust z-scores for every selected SCFA column.
zcols = [f"{c}_z" for c in USE_B_COLS]
for c, zc in zip(USE_B_COLS, zcols):
    df[zc] = df.groupby("subject_id")[c].transform(robust_mad_scale)

# B observable: the single z-scored column, or the row-wise mean of several.
df["B_obs"] = df[zcols[0]] if len(zcols) == 1 else df[zcols].mean(axis=1)

# H observable: the proxy limited to the interval [0, 1].
df["H_obs"] = df[USE_H_COL].clip(lower=0, upper=1)

# ----------------- Pack per-subject series (with fixed masks) -----------------
def first_finite(x: np.ndarray, default: float) -> float:
    """Return the first finite entry of ``x`` as a float.

    ``default`` (converted to float) is returned when ``x`` has no finite
    entry, which includes the empty array.
    """
    finite_positions = np.nonzero(np.isfinite(x))[0]
    if finite_positions.size == 0:
        return float(default)
    return float(x[finite_positions[0]])

subjects = df["subject_id"].unique().tolist()
subs = []

for s in subjects:
    # Rows of this subject, ordered by the per-subject time index.
    sub = df.loc[df["subject_id"] == s].sort_values("t_idx").copy()

    # Too few rows to be worth fitting.
    if len(sub) < MIN_POINTS_PER_SUBJECT:
        continue

    t_all = sub["t_idx"].to_numpy(dtype=float, copy=True)
    B_all = sub["B_obs"].to_numpy(dtype=float, copy=True)
    H_all = sub["H_obs"].to_numpy(dtype=float, copy=True)

    # Finite-value masks are computed once here and reused for every residual
    # evaluation, which keeps the residual vector length constant.
    maskB, maskH = np.isfinite(B_all), np.isfinite(H_all)
    nB, nH = int(maskB.sum()), int(maskH.sum())
    if min(nB, nH) < 3:
        continue

    subs.append(dict(
        sid=s,
        t=t_all,                      # grid on which the model is simulated
        B_all=B_all, H_all=H_all,
        maskB=maskB, maskH=maskH,
        nB=nB, nH=nH,
        # initial values come from the first finite observation of each series
        H0=float(np.clip(first_finite(H_all, 0.5), 0, 1)),
        B0=float(max(0.05, first_finite(B_all, 0.1))),
    ))

if len(subs) == 0:
    raise RuntimeError("No subject has enough finite points to fit (need ≥4 rows and ≥3 finite in both B & H).")

print(f"[info] Fitting {len(subs)} subjects...")
print("[info] Per-subject usable points (finite counts):")
# Only the first ten subjects are listed to keep the log short.
for S in subs[:10]:
    print(f"  - {S['sid']}: B={S['nB']}, H={S['nH']} (of {len(S['t'])})")
if len(subs) > 10:
    print("  ...")

# ----------------- Smooth hysteretic model -----------------
# y = [M, H, B, q]
# p = [r_max,K_M,c,d,g,u,p_low,p_high,H_on,H_off,tau_q]

def q_inf_smooth(H, q, H_on, H_off, k=KQ):
    """Target value of q: a logistic function of H around a q-dependent threshold.

    The threshold is ``H_on`` when ``q == 0`` and ``H_off`` when ``q == 1``,
    interpolated linearly in between; ``k`` sets the steepness of the logistic.
    """
    theta = (1.0 - q) * H_on + q * H_off
    exponent = -k * (H - theta)
    return 1.0 / (1.0 + np.exp(exponent))

def rhs_smooth(t, y, p):
    """Right-hand side of the smooth-hysteresis ODE system.

    ``y`` is ``[M, H, B, q]`` and ``p`` is
    ``[r_max, K_M, c, d, g, u, p_low, p_high, H_on, H_off, tau_q]``.
    ``t`` is unused in the body. Returns the four derivatives as a list in
    the same order as ``y``.
    """
    M, H, B, q = y
    r_max, K_M, c, d, g, u, pL, pH, H_on, H_off, tau = p

    # production rate interpolates between p_low and p_high with q limited to [0, 1]
    q_sat = np.clip(q, 0, 1)
    pB = pL + (pH - pL) * q_sat

    # q relaxes towards its smooth target with time constant tau
    q_target = q_inf_smooth(H, q, H_on, H_off, k=KQ)

    return [
        (r_max - c * pB) * M * (1 - M / K_M),   # dM
        g * B * (1 - H) - d * H,                # dH
        pB * M - u * H * B,                     # dB
        (q_target - q) / tau,                   # dq
    ]

def simulate(ts, y0, p):
    """Integrate ``rhs_smooth`` over ``ts`` and return the states on that grid.

    Returns ``sol.y`` (one row per state variable, one column per entry of
    ``ts``) on success. If ``y0`` contains non-finite values, the solver
    reports failure, or any exception is raised during integration, a
    ``(4, len(ts))`` array filled with NaN is returned instead.
    """
    def _all_nan():
        # same shape as a successful result, so callers can treat both alike
        return np.vstack([np.full(len(ts), np.nan)] * 4)

    y0 = np.array(y0, dtype=float)
    if not np.all(np.isfinite(y0)):
        return _all_nan()

    try:
        sol = solve_ivp(
            lambda t, y: rhs_smooth(t, y, p),
            (ts[0], ts[-1]),
            y0,
            t_eval=ts,
            rtol=1e-6,
            atol=1e-8,
            max_step=0.5,
        )
    except Exception:
        return _all_nan()

    return sol.y if sol.success else _all_nan()

# ----------------- Parameters -----------------
# global parameters (11)
# Order of the global entries: r_max, K_M, c, d, g, u, p_low, p_high, H_on, H_off, tau_q
LBg = np.array([0.1, 0.4, 0.02, 0.01, 0.05, 0.2, 0.0, 0.5, 0.2, 0.3, 0.5])
UBg = np.array([0.6, 1.5, 0.25, 0.5, 2.0, 1.2, 0.8, 4.0, 0.8, 0.95, 24.0])
x0g = np.array([0.32, 1.0, 0.10, 0.12, 0.5, 0.6, 0.1, 2.5, 0.55, 0.80, 4.0])  # slightly higher H_off start

# per-subject linear links: Bhat = alpha_B * B ; Hhat = clip(beta0 + beta1*H, 0,1)
# Each subject contributes one (alpha_B, beta0_H, beta1_H) triple, in the order of `subs`.
n_subjects = len(subs)
x0s = [1.0, 0.0, 1.0] * n_subjects
LBs = [0.1, -0.5, 0.1] * n_subjects
UBs = [5.0, 0.5, 2.0] * n_subjects

# Full parameter vector: the globals first, then the per-subject triples.
x0 = np.concatenate((x0g, np.asarray(x0s, dtype=float)))
LB = np.concatenate((LBg, np.asarray(LBs, dtype=float)))
UB = np.concatenate((UBg, np.asarray(UBs, dtype=float)))

def unpack(x):
    """Split the flat parameter vector into globals and per-subject triples.

    The first 11 entries are the global parameters; the remainder is split
    into ``len(subs)`` equal pieces, one per subject.
    """
    n_global = 11
    gpar, spar = x[:n_global], x[n_global:]
    return gpar, np.split(spar, len(subs))

# ----------------- Residuals (fixed length) -----------------
def residuals(x):
    """Residual vector for ``least_squares``; its length does not depend on ``x``.

    For every subject in ``subs`` the model is simulated on the subject's
    time grid, mapped through the per-subject observation links and compared
    with the observations at the pre-computed finite positions. B residuals
    come first, then H residuals, subject by subject. Failed simulations and
    non-finite residuals are replaced by ``PENALTY``.
    """
    gpar, triples = unpack(x)
    H_on, H_off = gpar[8], gpar[9]

    # soft constraint: the whole vector is a penalty unless H_off > H_on
    if H_off <= H_on:
        return np.full(total_length, PENALTY, dtype=float)

    mid_threshold = 0.5 * (H_on + H_off)
    chunks = []   # fixed-length pieces, in subject order

    for S, (alpha_B, beta0, beta1) in zip(subs, triples):
        # initial state: M fixed at 0.1, q chosen from H0 relative to the mid-threshold
        H0 = float(np.clip(S["H0"], 0, 1))
        B0 = float(max(0.05, S["B0"]))
        q0 = 1.0 if H0 < mid_threshold else 0.0

        _, H, B, _ = simulate(S["t"], [0.1, H0, B0, q0], gpar)

        # integration failed -> penalties with the same lengths as real residuals
        if not (np.all(np.isfinite(H)) and np.all(np.isfinite(B))):
            chunks.append(np.full(S["nB"], PENALTY, dtype=float))
            chunks.append(np.full(S["nH"], PENALTY, dtype=float))
            continue

        mB, mH = S["maskB"], S["maskH"]

        # observation models on the full grid, then restricted by the fixed masks
        Bhat = alpha_B * B
        Hhat = np.clip(beta0 + beta1 * H, 0, 1)
        b_res = Bhat[mB] - S["B_all"][mB]
        h_res = Hhat[mH] - S["H_all"][mH]

        # any remaining non-finite entry becomes a finite penalty
        b_res = np.where(np.isfinite(b_res), b_res, PENALTY)
        h_res = np.where(np.isfinite(h_res), h_res, PENALTY)

        chunks.append(b_res.astype(float))
        chunks.append(h_res.astype(float))

    return np.concatenate(chunks)

# compute the total residual length ONCE so we can return fixed penalties if needed
# Number of data residuals: one per finite B and one per finite H observation.
total_length = sum(S["nB"] for S in subs) + sum(S["nH"] for S in subs)

# ----------------- Fit -----------------
fit = least_squares(
    residuals,
    x0,
    bounds=(LB, UB),
    loss="soft_l1",
    f_scale=1.0,
    max_nfev=800,
    verbose=2,
)
gpar_hat, triples_hat = unpack(fit.x)

# Global parameters: one named value per row.
param_names = ["r_max", "K_M", "c", "d", "g", "u", "p_low", "p_high", "H_on", "H_off", "tau_q"]
global_path = os.path.join(OUTDIR, "fitted_global_params.csv")
pd.Series(gpar_hat, index=param_names).to_csv(global_path)

# Per-subject observation-link parameters: one row per fitted subject.
scale_rows = []
for S, (a_B, b0_H, b1_H) in zip(subs, triples_hat):
    scale_rows.append({"subject_id": S["sid"], "alpha_B": a_B, "beta0_H": b0_H, "beta1_H": b1_H})
scales_path = os.path.join(OUTDIR, "fitted_subject_scales.csv")
pd.DataFrame(scale_rows).to_csv(scales_path, index=False)

print("[info] Fitted global params:", dict(zip(param_names, gpar_hat)))

# ----------------- Diagnostics (first few subjects) -----------------
# One two-panel figure (B on top, H below) for each of the first eight fitted subjects.
n_diag = 8
for S, (alpha_B, beta0, beta1) in zip(subs[:n_diag], triples_hat[:n_diag]):
    # same initial state as used inside residuals()
    H0 = float(np.clip(S["H0"], 0, 1))
    B0 = float(max(0.05, S["B0"]))
    switch_mid = 0.5 * (gpar_hat[8] + gpar_hat[9])
    q0 = 1.0 if H0 < switch_mid else 0.0

    M, H, B, q = simulate(S["t"], [0.1, H0, B0, q0], gpar_hat)
    if not (np.all(np.isfinite(H)) and np.all(np.isfinite(B))):
        continue   # failed simulation: nothing to plot for this subject

    Bhat = alpha_B * B
    Hhat = np.clip(beta0 + beta1 * H, 0, 1)

    t_grid = S["t"]
    maskB, maskH = S["maskB"], S["maskH"]

    fig, (ax_B, ax_H) = plt.subplots(2, 1, figsize=(7, 6), sharex=True)

    # top panel: model vs. observed B at the finite B positions
    ax_B.plot(t_grid[maskB], Bhat[maskB], label="Model B (scaled)")
    ax_B.scatter(t_grid[maskB], S["B_all"][maskB], s=18, c="k", label="Obs B (z)")
    ax_B.set_ylabel("B intensity (scaled)")
    ax_B.legend()
    ax_B.grid(True, ls=":")

    # bottom panel: model vs. observed H with the two fitted thresholds
    ax_H.plot(t_grid[maskH], Hhat[maskH], label="Model H")
    ax_H.scatter(t_grid[maskH], S["H_all"][maskH], s=18, c="k", label="H proxy")
    ax_H.axhline(gpar_hat[8], ls=":", c="gray", label="H_on")
    ax_H.axhline(gpar_hat[9], ls="--", c="gray", label="H_off")
    ax_H.set_xlabel("time index")
    ax_H.set_ylabel("H")
    ax_H.legend()
    ax_H.grid(True, ls=":")

    plt.tight_layout()
    plt.savefig(os.path.join(OUTDIR, f"diag_{S['sid']}.png"), dpi=180)
    plt.close()

print("✅ Done. See outputs in:", OUTDIR)


In [None]:
# bifurcation_and_basins_smooth.py
# ------------------------------------------------------------
# Uses the smooth-hysteresis model you just fitted to:
#   (A) continue equilibria vs d and test stability (Jacobian eigs)
#   (B) check if the baseline d lies in a bistable region
#   (C) draw a basin-of-attraction map at baseline
#   (D) (optional) repeat with weak positive feedback r_H
#
# Input:  mw_fit_out_smooth/fitted_global_params.csv  (from your smooth calibration)
# Output: mw_bif_smooth/
#   - branches.csv
#   - bifurcation_H_vs_d.png
#   - basins_heatmap.png
#   - hysteresis_sweep.png
#   - diagnosis.txt
# (POS_FEEDBACK=True writes to the same file names; diagnosis.txt records which model was used)
# ------------------------------------------------------------

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import root

# ----------------- Config -----------------
FIT_CSV = "mw_fit_out_smooth/fitted_global_params.csv"
OUTDIR = "mw_bif_smooth"
os.makedirs(OUTDIR, exist_ok=True)

# Smooth switch sharpness (same k you used in calibration; 30–60 works)
KQ = 40.0

# Continuation range around baseline d (you can widen if needed)
D_SPAN_FACTOR = (0.6, 1.5)      # explore from 0.6×d_fit to 1.5×d_fit
N_D_POINTS = 80

# Basins grid at baseline
H_GRID = np.linspace(0.2, 0.95, 17)
Q_GRID = np.linspace(0.0, 1.0, 17)

# Optional: add weak positive feedback r_H into microbe growth
POS_FEEDBACK = True     # True selects rhs_pf (with r_H), False selects rhs_smooth
R_H_VALUE = 0.2         # h^-1 contribution to growth per unit H (small)

# ----------------- Load fitted globals -----------------
# The fit file has the parameter names as index and a single value column.
g = pd.read_csv(FIT_CSV, index_col=0).squeeze("columns")

# params (smooth model without PF): [r_max,K_M,c,d,g,u,p_low,p_high,H_on,H_off,tau_q]
# Each entry is (name in the CSV, value used when the CSV lacks that name).
_PARAM_DEFAULTS = (
    ("r_max", 0.32),
    ("K_M", 1.0),
    ("c", 0.10),
    ("d", 0.12),
    ("g", 0.5),
    ("u", 0.6),
    ("p_low", 0.1),
    ("p_high", 2.5),
    ("H_on", 0.55),
    ("H_off", 0.80),   # you likely refit; read what’s in CSV
    ("tau_q", 4.0),
)
pars = np.array([float(g.get(name, fallback)) for name, fallback in _PARAM_DEFAULTS], float)

# ----------------- Smooth hysteresis helpers -----------------
def q_inf_smooth(H, q, H_on, H_off, k=KQ):
    """Target value of q: a logistic function of H around a q-dependent threshold.

    The threshold equals ``H_on`` at ``q == 0`` and ``H_off`` at ``q == 1``,
    with linear interpolation in between; ``k`` is the logistic steepness.
    """
    theta = (1.0 - q) * H_on + q * H_off
    exponent = -k * (H - theta)
    return 1.0 / (1.0 + np.exp(exponent))

def rhs_smooth(y, p, d_override=None):
    """Baseline smooth model without positive feedback.

    ``y`` is ``[M, H, B, q]``; ``p`` is
    ``[r_max, K_M, c, d, g, u, p_low, p_high, H_on, H_off, tau_q]``.
    When ``d_override`` is given it replaces the ``d`` stored in ``p``.
    Returns the four derivatives as a float array.
    """
    M, H, B, q = y
    r_max, K_M, c, d_stored, gH, u, pL, pH, H_on, H_off, tau = p.copy()
    d = d_stored if d_override is None else d_override

    # production rate interpolates between p_low and p_high with q limited to [0, 1]
    pB = pL + (pH - pL) * np.clip(q, 0, 1)
    q_target = q_inf_smooth(H, q, H_on, H_off, k=KQ)

    derivs = [
        (r_max - c * pB) * M * (1 - M / K_M),   # dM
        gH * B * (1 - H) - d * H,               # dH
        pB * M - u * H * B,                     # dB
        (q_target - q) / tau,                   # dq
    ]
    return np.array(derivs, float)

def rhs_pf(y, p, d_override=None, rH=R_H_VALUE):
    """Smooth model with weak positive feedback: r_eff = r_max + rH*H.

    Same state and parameter layout as ``rhs_smooth``; the only difference is
    that the growth rate in the M equation is ``r_max + rH*H``. When
    ``d_override`` is given it replaces the ``d`` stored in ``p``.
    Returns the four derivatives as a float array.
    """
    M, H, B, q = y
    r_max, K_M, c, d_stored, gH, u, pL, pH, H_on, H_off, tau = p.copy()
    d = d_stored if d_override is None else d_override

    pB = pL + (pH - pL) * np.clip(q, 0, 1)
    r_eff = r_max + rH * H
    q_target = q_inf_smooth(H, q, H_on, H_off, k=KQ)

    derivs = [
        (r_eff - c * pB) * M * (1 - M / K_M),   # dM
        gH * B * (1 - H) - d * H,               # dH
        pB * M - u * H * B,                     # dB
        (q_target - q) / tau,                   # dq
    ]
    return np.array(derivs, float)

# Select the right-hand side used by every analysis below.
if POS_FEEDBACK:
    RHS = rhs_pf
else:
    RHS = rhs_smooth

# ----------------- Utilities -----------------
def jacobian_fd(fun, y, p, d_val=None, eps=1e-7):
    """Forward-difference Jacobian of ``fun(y, p, d_val)`` with respect to ``y``.

    ``y`` must provide ``.copy()`` and in-place element updates (e.g. a float
    ndarray) and is treated as having four components. Column ``j`` of the
    returned 4x4 array is ``(fun(y + eps*e_j) - fun(y)) / eps``.
    """
    base = fun(y, p, d_val)
    J = np.zeros((4, 4))
    for col in range(4):
        bumped = y.copy()
        bumped[col] += eps   # perturb one state component at a time
        J[:, col] = (fun(bumped, p, d_val) - base) / eps
    return J

def find_equilibrium(fun, p, d_val, guess, proj_tol=1e-6):
    """Find a root of ``fun(y, p, d_val)`` that lies inside the physical box.

    The physical box is ``M >= 0``, ``0 <= H <= 1.2``, ``B >= 0`` and
    ``0 <= q <= 1.2``. A root that sits outside the box by no more than
    ``proj_tol`` in every component (numerical noise) is snapped onto it.
    A root that lies further outside is rejected: clipping it would produce a
    point that is no longer a zero of ``fun``, and reporting that point as an
    equilibrium would make any stability analysis at it meaningless.

    Parameters
    ----------
    fun : callable
        ``fun(y, p, d_val)`` returning the four derivatives.
    p : parameter vector passed through to ``fun``.
    d_val : value passed through to ``fun`` as its third argument.
    guess : array-like of length 4
        Starting point for the root finder.
    proj_tol : float, optional
        Largest per-component distance from the box that is still treated as
        round-off.

    Returns
    -------
    (y, ok) : tuple
        ``(equilibrium, True)`` on success, otherwise ``(guess, False)``.
    """
    sol = root(lambda yy: fun(yy, p, d_val), guess, method="hybr")
    if not sol.success:
        return guess, False

    raw = np.asarray(sol.x, float)
    if not np.all(np.isfinite(raw)):
        return guess, False

    # project to physical range
    y = np.array([
        max(0.0, raw[0]),
        np.clip(raw[1], 0.0, 1.2),
        max(0.0, raw[2]),
        np.clip(raw[3], 0.0, 1.2),
    ], float)

    # A projection that moves the point noticeably means the root was non-physical.
    if np.max(np.abs(y - raw)) > proj_tol:
        return guess, False
    return y, True

def relax_to_ss(fun, p, d_val, y0, T=240):
    """Integrate ``fun(y, p, d_val)`` from ``y0`` over ``[0, T]``.

    Returns ``(final_state, sol)`` where ``final_state`` is the last stored
    column of the solution (sampled at 900 evenly spaced times) and ``sol``
    is the full ``solve_ivp`` result. The solver's success flag is not checked.
    """
    def field(_t, state):
        # the model functions are autonomous: time is ignored
        return fun(state, p, d_val)

    sample_times = np.linspace(0, T, 900)
    sol = solve_ivp(field, (0, T), y0, t_eval=sample_times,
                    rtol=1e-6, atol=1e-8, max_step=0.5)
    final_state = sol.y[:, -1]
    return final_state, sol

# ----------------- (A) Continuation in d -----------------
# Baseline d from the fit and the grid of d values explored around it.
d_fit = float(pars[3])
d_vals = np.linspace(d_fit*D_SPAN_FACTOR[0], d_fit*D_SPAN_FACTOR[1], N_D_POINTS)

# Starting points [M, H, B, q] for the root finder.
seeds = [
    np.array([0.2, 0.2, 0.05, 1.0]),   # low-H / q≈1
    np.array([0.2, 0.9, 0.10, 0.0]),   # high-H / q≈0
    np.array([0.6, 0.6, 0.20, 0.5]),   # mid
]

# For every d and every seed: locate an equilibrium and classify its stability
# from the sign of the largest real part of the Jacobian eigenvalues.
rows = []
for d in d_vals:
    for wi, y0 in enumerate(seeds):
        y_eq, ok = find_equilibrium(RHS, pars, d, y0)
        if not ok:
            continue
        J = jacobian_fd(RHS, y_eq, pars, d_val=d)
        eigs = np.linalg.eigvals(J)
        stable = bool(np.max(np.real(eigs)) < 0)
        rows.append({"d": d, "H": float(y_eq[1]), "q": float(y_eq[3]),
                     "seed": wi, "stable": stable})

# Explicit columns keep the frame usable (and the column lookups below valid)
# even when no equilibrium was found for any d.
branches = pd.DataFrame(rows, columns=["d", "H", "q", "seed", "stable"])
if branches.empty:
    print("[warn] No equilibrium found on the d grid; branches.csv and the plot will be empty.")
branches.to_csv(os.path.join(OUTDIR, "branches.csv"), index=False)

# plot branches
plt.figure(figsize=(7.2,5.2))
for wi in sorted(branches["seed"].unique()):
    sub = branches[branches["seed"]==wi]
    plt.plot(sub["d"], sub["H"], ".", ms=3, alpha=0.7, label=f"seed{wi}")
for st, mk in [(True, "o"), (False, "x")]:
    sub = branches[branches["stable"]==st]
    plt.scatter(sub["d"], sub["H"], s=22, marker=mk, alpha=0.6,
                label=("stable" if st else "unstable"))
plt.axvline(d_fit, ls="--", c="gray", label="baseline d")
plt.xlabel("d (1/h)"); plt.ylabel("H* at equilibrium")
plt.legend(); plt.grid(True, ls=":", alpha=0.6)
plt.tight_layout()
plt.savefig(os.path.join(OUTDIR, "bifurcation_H_vs_d.png"), dpi=180)
plt.close()

# ----------------- (B) Is baseline d inside a bistable band? -----------------
# Equilibria are solved at d_fit itself instead of being looked up on the d grid.
# d_fit is generally not a grid point, and an absolute tolerance around it can
# pool several neighbouring d values when d_fit is small; their slightly
# different H* would then be miscounted as distinct coexisting equilibria.
baseline_H = []
for y0 in seeds:
    y_eq, ok = find_equilibrium(RHS, pars, d_fit, y0)
    if not ok:
        continue
    eigs = np.linalg.eigvals(jacobian_fd(RHS, y_eq, pars, d_val=d_fit))
    if np.max(np.real(eigs)) < 0:          # keep stable equilibria only
        baseline_H.append(float(y_eq[1]))

# count distinct stable points (tolerance to avoid duplicate seeds converging to same eq)
Hs = np.sort(np.asarray(baseline_H, float))
distinct = int(Hs.size > 0) + int(np.sum(np.diff(Hs) > 1e-3))
bistable = bool(distinct >= 2)

# ----------------- (C) Basin-of-attraction map at baseline -----------------
# Z[i, j] holds the final H reached from initial (H, q) = (H_GRID[i], Q_GRID[j]),
# with the initial M and B fixed at 0.2 and 0.1.
Z = np.zeros((len(H_GRID), len(Q_GRID)))
for i, j in np.ndindex(Z.shape):
    start = np.array([0.2, H_GRID[i], 0.1, Q_GRID[j]], float)
    end_state, _ = relax_to_ss(RHS, pars, d_fit, start, T=300)
    Z[i, j] = end_state[1]

# Heatmap: initial q on the x-axis, initial H on the y-axis.
plt.figure(figsize=(6.6, 5.4))
plt.imshow(
    Z,
    origin="lower",
    extent=[Q_GRID[0], Q_GRID[-1], H_GRID[0], H_GRID[-1]],
    aspect="auto",
    vmin=0.5,
    vmax=1.0,
    cmap="viridis",
)
plt.colorbar(label="Final H (steady)")
plt.xlabel("initial q")
plt.ylabel("initial H")
plt.title(f"Basins at baseline d={d_fit:.3f}")
plt.tight_layout()
plt.savefig(os.path.join(OUTDIR, "basins_heatmap.png"), dpi=180)
plt.close()

# ----------------- (D) Dynamic hysteresis sweep (for comparison) -----------------
def sweep(fun, p, H0, q0, d_lo, d_hi, T=160):
    """Step d up from ``d_lo`` to ``d_hi`` and back down, relaxing at each step.

    The state starts at ``[0.1, H0, 0.1, q0]`` and is carried over from one d
    value to the next (28 values per direction), so the backward pass starts
    from where the forward pass ended. Returns two DataFrames
    ``(forward, backward)`` with columns ``d, M, H, B, q``.
    """
    columns = ["d", "M", "H", "B", "q"]
    d_up = np.linspace(d_lo, d_hi, 28)

    def _trace(d_sequence, state):
        # relax at every d in turn, continuing from the previous end state
        records = []
        for d in d_sequence:
            state, _ = relax_to_ss(fun, p, d, state, T=T)
            records.append((d, *state))
        return records, state

    state = np.array([0.1, H0, 0.1, q0], float)
    fwd, state = _trace(d_up, state)
    bwd, state = _trace(d_up[::-1], state)
    return pd.DataFrame(fwd, columns=columns), pd.DataFrame(bwd, columns=columns)

# Forward/backward sweep over the same d range as the continuation.
fwd, bwd = sweep(RHS, pars, H0=0.8, q0=0.0,
                 d_lo=d_fit*D_SPAN_FACTOR[0], d_hi=d_fit*D_SPAN_FACTOR[1], T=180)

plt.figure(figsize=(7.6,4.2))
plt.plot(fwd["d"], fwd["H"], "-o", ms=3, label="forward (d↑)")
plt.plot(bwd["d"], bwd["H"], "-s", ms=3, label="backward (d↓)")
plt.axhline(pars[8], ls=":", c="gray", label="H_on")
plt.axhline(pars[9], ls="--", c="gray", label="H_off")
plt.xlabel("d (1/h)"); plt.ylabel("H* (dynamic SS)"); plt.legend()
plt.grid(True, ls=":", alpha=0.6); plt.tight_layout()
plt.savefig(os.path.join(OUTDIR, "hysteresis_sweep.png"), dpi=180)
plt.close()

# ----------------- Report -----------------
# The description is built as its own f-string: nested inside another f-string
# expression as a plain literal, "{R_H_VALUE}" would be written out verbatim.
if POS_FEEDBACK:
    model_desc = f"smooth + positive feedback (r_H={R_H_VALUE})"
else:
    model_desc = "smooth, no PF"

with open(os.path.join(OUTDIR, "diagnosis.txt"), "w", encoding="utf-8") as f:
    f.write(f"Baseline d = {d_fit:.5f}\n")
    f.write(f"Distinct stable equilibria at baseline (by H*): {distinct}\n")
    f.write(f"Bistable at baseline? {'YES' if bistable else 'NO'}\n")
    f.write(f"Model used: {model_desc}\n")

print("Saved results to:", OUTDIR)
print("Bistable at baseline? ", "YES" if bistable else "NO")


In [None]:
# calibrate_hysteresis_from_scored_pf_prior.py
# ------------------------------------------------------------
# Calibrate smooth-hysteresis model with host->microbe positive feedback (r_H)
# and gentle priors; fixed-length residuals; corrected priors block.
# ------------------------------------------------------------

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares

# ----------------- Paths & config -----------------
# Input table and output folder (the folder is created up front).
INPATH = "timeseries/combined_scfas_table_scored.csv"
OUTDIR = "mw_fit_out_pf"
os.makedirs(OUTDIR, exist_ok=True)

USE_H_COL = "H_proxy_meta_smooth"     # fallback to "H_proxy_meta" if not found
USE_B_COLS = ["butyrate"]

MIN_POINTS_PER_SUBJECT = 4
KQ = 40.0          # smooth switch steepness
PENALTY = 1e3      # finite penalty fallback (keeps residual length fixed)

# ---- Priors (edit here if desired) ----
# Each prior enters the residual vector as (value - target) / sd.
H_OFF_TARGET = 0.85
H_OFF_SD = 0.05
R_EFF_H_REF = 0.8       # H at which the effective growth rate r0 + rH*H is evaluated
R_EFF_TARGET = 0.35     # ~ your earlier r_max
R_EFF_SD = 0.05
U_SD = 0.25

# ----------------- Load data -----------------
df = pd.read_csv(INPATH)

needed = {"subject_id", "sample_id"}
missing = [c for c in needed if c not in df.columns]
if missing:
    raise ValueError(f"Missing required columns: {missing}")

# Fall back to the unsmoothed proxy when the smoothed column is absent.
if USE_H_COL not in df.columns:
    alt = "H_proxy_meta"
    if alt not in df.columns:
        raise ValueError("No usable H proxy column (need H_proxy_meta_smooth or H_proxy_meta).")
    print(f"[info] {USE_H_COL} not found; using {alt} instead.")
    USE_H_COL = alt

for c in USE_B_COLS:
    if c not in df.columns:
        raise ValueError(f"SCFA column '{c}' not in data.")

# Keep only the columns used below, drop rows without identifiers, and use the
# 0-based row position within each subject (file order) as the time index.
keep = ["subject_id", "sample_id", USE_H_COL, *USE_B_COLS]
df = df[keep].dropna(subset=["subject_id", "sample_id"]).copy()
df["t_idx"] = df.groupby("subject_id").cumcount().astype(float)

# ----------------- Robust scaling -----------------
def robust_mad_scale(series: pd.Series) -> pd.Series:
    """Center a series on its median and divide by a robust spread estimate.

    Only finite entries are used to estimate the median and the spread. The
    spread is the median absolute deviation (MAD); if it is (numerically)
    zero the inter-quartile range is used, and if that is zero too, the
    standard deviation plus a small epsilon.

    A series without any finite entry is returned as all-NaN rather than as
    zeros, so that downstream ``np.isfinite`` masks do not mistake absent
    data for real observations.

    Parameters
    ----------
    series : pd.Series
        Values to scale; converted to float.

    Returns
    -------
    pd.Series
        Scaled values on the same index as ``series``.
    """
    x = series.astype(float).to_numpy()
    m = np.isfinite(x)
    if m.sum() == 0:
        # Nothing observed: keep everything missing instead of inventing zeros.
        return pd.Series(np.full(x.shape, np.nan), index=series.index)
    xm = x[m]
    med = np.median(xm)
    mad = np.median(np.abs(xm - med))
    if mad < 1e-9:
        # Degenerate MAD (more than half of the values equal the median).
        q75, q25 = np.percentile(xm, [75, 25])
        iqr = q75 - q25
        scale = iqr if iqr > 1e-9 else (np.std(xm) + 1e-9)
    else:
        scale = mad
    # The extra epsilon guards the division even when the scale itself is tiny.
    return pd.Series((x - med) / (scale + 1e-9), index=series.index)

# Per-subject robust z-scores for every selected SCFA column.
zcols = [f"{c}_z" for c in USE_B_COLS]
for c, zc in zip(USE_B_COLS, zcols):
    df[zc] = df.groupby("subject_id")[c].transform(robust_mad_scale)

# B observable: the single z-scored column, or the row-wise mean of several.
if len(USE_B_COLS) == 1:
    df["B_obs"] = df[zcols[0]]
else:
    df["B_obs"] = df[zcols].mean(axis=1)

# H observable: the proxy limited to the interval [0, 1].
df["H_obs"] = df[USE_H_COL].clip(lower=0, upper=1)

# ----------------- Pack per-subject with fixed masks -----------------
def first_finite(a, default):
    """Return the first finite entry of ``a`` as a float, or ``float(default)`` if there is none."""
    finite_positions = np.nonzero(np.isfinite(a))[0]
    if len(finite_positions) == 0:
        return float(default)
    return float(a[finite_positions[0]])

subjects = df["subject_id"].unique().tolist()
subs = []
for sid in subjects:
    # Rows of this subject, ordered by the per-subject time index.
    sub = df.loc[df["subject_id"] == sid].sort_values("t_idx").copy()
    if len(sub) < MIN_POINTS_PER_SUBJECT:
        continue

    t = sub["t_idx"].to_numpy(dtype=float, copy=True)
    H = sub["H_obs"].to_numpy(dtype=float, copy=True)
    B = sub["B_obs"].to_numpy(dtype=float, copy=True)

    # Finite-value masks are computed once and reused in every residual call.
    mH, mB = np.isfinite(H), np.isfinite(B)
    nH, nB = int(mH.sum()), int(mB.sum())
    if min(nH, nB) < 3:
        continue

    subs.append(dict(
        sid=sid,
        t=t,
        H=H,
        B=B,
        maskH=mH, maskB=mB,
        nH=nH, nB=nB,
        # initial values come from the first finite observation of each series
        H0=float(np.clip(first_finite(H, 0.6), 0, 1)),
        B0=float(max(0.05, first_finite(B, 0.1))),
    ))

if len(subs) == 0:
    raise RuntimeError("No subject has enough finite points to fit.")

print(f"[info] Fitting {len(subs)} subjects...")
# Only the first ten subjects are listed to keep the log short.
for S in subs[:10]:
    print(f"  - {S['sid']}: B={S['nB']} H={S['nH']} (of {len(S['t'])})")
if len(subs) > 10:
    print("  ...")

# ----------------- Smooth hysteresis + positive feedback -----------------
# y=[M,H,B,q]
# global p = [r0, rH, K_M, c, d, g, u, p_low, p_high, H_on, H_off, tau_q]

def q_inf_smooth(H, q, H_on, H_off, k=KQ):
    """Target value of q: a logistic function of H around a q-dependent threshold.

    The threshold equals ``H_on`` at ``q == 0`` and ``H_off`` at ``q == 1``,
    with linear interpolation in between; ``k`` is the logistic steepness.
    """
    theta = (1.0 - q)*H_on + q*H_off
    exponent = -k*(H - theta)
    return 1.0 / (1.0 + np.exp(exponent))

def rhs_pf(t, y, p):
    """Right-hand side of the smooth-hysteresis model with positive feedback.

    ``y`` is ``[M, H, B, q]`` and ``p`` is
    ``[r0, rH, K_M, c, d, g, u, p_low, p_high, H_on, H_off, tau_q]``.
    The growth rate in the M equation is ``r0 + rH*H``. ``t`` is unused in
    the body. Returns the four derivatives as a list in the order of ``y``.
    """
    M, H, B, q = y
    r0, rH, K_M, c, d, gH, u, pL, pH, H_on, H_off, tau = p

    # production rate interpolates between p_low and p_high with q limited to [0, 1]
    pB = pL + (pH - pL) * np.clip(q, 0, 1)
    r_eff = r0 + rH * H
    q_target = q_inf_smooth(H, q, H_on, H_off, k=KQ)

    return [
        (r_eff - c * pB) * M * (1 - M / K_M),   # dM
        gH * B * (1 - H) - d * H,               # dH
        pB * M - u * H * B,                     # dB
        (q_target - q) / tau,                   # dq
    ]

def simulate(ts, y0, p):
    """Integrate ``rhs_pf`` over ``ts`` and return the states on that grid.

    Returns ``sol.y`` (one row per state variable, one column per entry of
    ``ts``) on success. If the solver reports failure or any exception is
    raised during integration, a ``(4, len(ts))`` array filled with NaN is
    returned instead.
    """
    def _all_nan():
        # same shape as a successful result, so callers can treat both alike
        return np.vstack([np.full(len(ts), np.nan)] * 4)

    try:
        sol = solve_ivp(
            lambda t, y: rhs_pf(t, y, p),
            (ts[0], ts[-1]),
            y0,
            t_eval=ts,
            rtol=1e-6,
            atol=1e-8,
            max_step=0.5,
        )
    except Exception:
        return _all_nan()

    return sol.y if sol.success else _all_nan()

# ----------------- Parameter boxes -----------------
# Order of the global entries: r0, rH, K_M, c, d, g, u, p_low, p_high, H_on, H_off, tau_q
LBg = np.array([0.05, 0.00, 0.4, 0.02, 0.01, 0.05, 0.2, 0.0, 0.5, 0.2, 0.60, 0.5])
UBg = np.array([0.60, 0.20, 1.6, 0.25, 0.50, 2.00, 1.2, 0.8, 4.0, 0.8, 0.95, 24.0])
x0g = np.array([0.30, 0.08, 1.0, 0.10, 0.12, 0.55, 0.65, 0.10, 2.2, 0.55, 0.85, 4.0])

# Each subject contributes one (alpha_B, beta0_H, beta1_H) triple, in the order of `subs`.
n_subjects = len(subs)
x0s = [1.0, 0.0, 1.0] * n_subjects
LBs = [0.1, -0.5, 0.1] * n_subjects
UBs = [5.0, 0.5, 2.0] * n_subjects

# Full parameter vector: the globals first, then the per-subject triples.
x0 = np.concatenate((x0g, np.asarray(x0s, float)))
LB = np.concatenate((LBg, np.asarray(LBs, float)))
UB = np.concatenate((UBg, np.asarray(UBs, float)))

def unpack(x):
    """Split the flat parameter vector into globals and per-subject triples.

    The first 12 entries are the global parameters; the remainder is split
    into ``len(subs)`` equal pieces, one per subject.
    """
    n_global = 12
    per_subject = np.split(x[n_global:], len(subs))
    return x[:n_global], per_subject

# ----------------- Fixed-length residuals + priors -----------------
N_PRIORS = 3  # H_off, r_eff at H=0.8, u regularization
# One data residual per finite B and per finite H observation, plus the priors.
n_data_residuals = sum(S["nB"] for S in subs) + sum(S["nH"] for S in subs)
total_len = n_data_residuals + N_PRIORS

def residuals(x):
    """Residual vector (data + priors) for ``least_squares``; fixed length.

    Data part: for every subject in ``subs`` the model is simulated on the
    subject's time grid, mapped through the per-subject observation links and
    compared with the observations at the pre-computed finite positions
    (B residuals first, then H residuals). Failed simulations and non-finite
    residuals are replaced by ``PENALTY``.

    Prior part: three trailing entries, each ``(value - target) / sd``, for
    H_off, for r0 + rH*R_EFF_H_REF, and for u (target 0.6).
    """
    gpar, triples = unpack(x)
    H_on, H_off = gpar[9], gpar[10]

    # enforce H_off > H_on: otherwise the whole vector is a penalty
    if not (H_off > H_on):
        return np.full(total_len, PENALTY, float)

    mid_threshold = 0.5 * (H_on + H_off)
    pieces = []

    # ---- data residuals ----
    for S, (alpha_B, beta0, beta1) in zip(subs, triples):
        # initial state: M fixed at 0.1, q chosen from H0 relative to the mid-threshold
        H0 = float(np.clip(S["H0"], 0, 1))
        B0 = float(max(0.05, S["B0"]))
        q0 = 1.0 if H0 < mid_threshold else 0.0

        _, H, B, _ = simulate(S["t"], [0.1, H0, B0, q0], gpar)

        # integration failed -> penalties with the same lengths as real residuals
        if not (np.all(np.isfinite(H)) and np.all(np.isfinite(B))):
            pieces.append(np.full(S["nB"], PENALTY))
            pieces.append(np.full(S["nH"], PENALTY))
            continue

        mB, mH = S["maskB"], S["maskH"]
        Bhat = alpha_B * B
        Hhat = np.clip(beta0 + beta1 * H, 0, 1)

        b_res = Bhat[mB] - S["B"][mB]
        h_res = Hhat[mH] - S["H"][mH]

        # any remaining non-finite entry becomes a finite penalty
        pieces.append(np.where(np.isfinite(b_res), b_res, PENALTY).astype(float))
        pieces.append(np.where(np.isfinite(h_res), h_res, PENALTY).astype(float))

    # ---- Priors (3 numbers, fixed length) ----
    # Indices: r0=0, rH=1, u=6, H_on=9, H_off=10
    r_eff = gpar[0] + gpar[1] * R_EFF_H_REF
    priors = [
        (gpar[10] - H_OFF_TARGET) / (H_OFF_SD + 1e-9),
        (r_eff - R_EFF_TARGET) / (R_EFF_SD + 1e-9),
        (gpar[6] - 0.6) / (U_SD + 1e-9),
    ]
    pieces.append(np.array(priors, float))

    return np.concatenate(pieces)

# ----------------- Fit -----------------
fit = least_squares(
    residuals,
    x0,
    bounds=(LB, UB),
    loss="soft_l1",
    f_scale=1.0,
    max_nfev=900,
    verbose=2,
)
gpar_hat, triples_hat = unpack(fit.x)

# Global parameters: one named value per row.
param_names = ["r0", "rH", "K_M", "c", "d", "g", "u", "p_low", "p_high", "H_on", "H_off", "tau_q"]
global_path = os.path.join(OUTDIR, "fitted_global_params.csv")
pd.Series(gpar_hat, index=param_names).to_csv(global_path)

# Per-subject observation-link parameters: one row per fitted subject.
scale_rows = []
for S, (a_B, b0_H, b1_H) in zip(subs, triples_hat):
    scale_rows.append({"subject_id": S["sid"], "alpha_B": a_B, "beta0_H": b0_H, "beta1_H": b1_H})
scales_path = os.path.join(OUTDIR, "fitted_subject_scales.csv")
pd.DataFrame(scale_rows).to_csv(scales_path, index=False)

print("[info] Fitted global params:", dict(zip(param_names, gpar_hat)))

# ----------------- Quick diagnostics -----------------
# One two-panel figure (B on top, H below) for each of the first eight fitted subjects.
for S,tr in list(zip(subs, triples_hat))[:8]:
    alpha_B, beta0, beta1 = tr

    # same initial state as used inside residuals()
    H0 = float(np.clip(S["H0"], 0, 1))
    B0 = float(max(0.05, S["B0"]))
    M0 = 0.1
    q0 = 1.0 if H0 < 0.5*(gpar_hat[9] + gpar_hat[10]) else 0.0
    y0 = [M0, H0, B0, q0]

    # simulate() returns an all-NaN array when the solver fails or raises.
    # Calling solve_ivp directly and using sol.y without checking sol.success
    # could leave fewer columns than time points, and the boolean masks below
    # would then raise an IndexError.
    M, H, B, q = simulate(S["t"], y0, gpar_hat)
    if not (np.all(np.isfinite(H)) and np.all(np.isfinite(B))):
        print(f"[warn] Simulation failed for subject {S['sid']}; diagnostic plot skipped.")
        continue

    Bhat = alpha_B*B
    Hhat = np.clip(beta0 + beta1*H, 0, 1)
    mB, mH = S["maskB"], S["maskH"]

    fig, ax = plt.subplots(2,1, figsize=(7,6), sharex=True)

    # top panel: model vs. observed B at the finite B positions
    ax[0].plot(S["t"][mB], Bhat[mB], label="Model B (scaled)")
    ax[0].scatter(S["t"][mB], S["B"][mB], s=18, c="k", label="Obs B (z)")
    ax[0].set_ylabel("B intensity (scaled)"); ax[0].legend(); ax[0].grid(True, ls=":")

    # bottom panel: model vs. observed H with the two fitted thresholds
    ax[1].plot(S["t"][mH], Hhat[mH], label="Model H")
    ax[1].scatter(S["t"][mH], S["H"][mH], s=18, c="k", label="H proxy")
    ax[1].axhline(gpar_hat[9], ls=":", c="gray", label="H_on")
    ax[1].axhline(gpar_hat[10], ls="--", c="gray", label="H_off")
    ax[1].set_xlabel("time idx"); ax[1].set_ylabel("H"); ax[1].legend(); ax[1].grid(True, ls=":")

    plt.tight_layout(); plt.savefig(os.path.join(OUTDIR, f"diag_{S['sid']}.png"), dpi=180); plt.close()

print("✅ Done. Outputs in:", OUTDIR)


In [None]:
# bifurcation_and_basins_pf.py
# ------------------------------------------------------------
# Uses the smooth-hysteresis model you just fitted to:
#   (A) continue equilibria vs d and test stability (Jacobian eigs)
#   (B) check if the baseline d lies in a bistable region
#   (C) draw a basin-of-attraction map at baseline
#   (D) (optional) repeat with weak positive feedback r_H
#
# Input:  mw_fit_out_pf/fitted_global_params.csv  (from your positive-feedback calibration)
# Output: mw_bif_pf/
#   - branches.csv
#   - bifurcation_H_vs_d.png
#   - basins_heatmap.png
#   - hysteresis_sweep.png
#   - diagnosis.txt
# (and *_pf.* counterparts if POS_FEEDBACK=True)
# ------------------------------------------------------------

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import root

# ----------------- Config -----------------
FIT_CSV       = "mw_fit_out_pf/fitted_global_params.csv"
OUTDIR        = "mw_bif_pf"
os.makedirs(OUTDIR, exist_ok=True)

# Smooth switch sharpness (same k you used in calibration; 30–60 works)
KQ            = 40.0

# Continuation range around baseline d (you can widen if needed)
D_SPAN_FACTOR = (0.6, 1.5)      # explore from 0.6×d_fit to 1.5×d_fit
N_D_POINTS    = 80

# Basins grid at baseline
H_GRID = np.linspace(0.2, 0.95, 17)
Q_GRID = np.linspace(0.0, 1.0, 17)

# Optional: add weak positive feedback r_H into microbe growth
POS_FEEDBACK  = True    # True selects rhs_pf (with r_H), False selects rhs_smooth
R_H_VALUE     = 0.05    # fallback r_H, used only when the fit file has no "rH" entry

# ----------------- Load fitted globals -----------------
# The fit file has the parameter names as index and a single value column.
g = pd.read_csv(FIT_CSV, index_col=0).squeeze("columns")

# The positive-feedback calibration names its baseline growth rate "r0" and
# its fitted feedback strength "rH"; the smooth calibration uses "r_max".
# Asking only for "r_max" would silently replace the fitted growth rate by
# the 0.32 default, so "r0" is preferred when it is present.
if "r0" in g.index:
    r_base = float(g["r0"])
else:
    r_base = float(g.get("r_max", 0.32))

# Use the fitted feedback strength when the file provides one. This runs
# before rhs_pf is defined, so rhs_pf's default rH picks up the fitted value.
if POS_FEEDBACK and "rH" in g.index:
    R_H_VALUE = float(g["rH"])
    print(f"[info] Using fitted rH={R_H_VALUE:.4g} from {FIT_CSV}")

# params: [r_max (or r0), K_M, c, d, g, u, p_low, p_high, H_on, H_off, tau_q]
pars = np.array([
    r_base,
    float(g.get("K_M", 1.0)),
    float(g.get("c", 0.10)),
    float(g.get("d", 0.12)),
    float(g.get("g", 0.5)),
    float(g.get("u", 0.6)),
    float(g.get("p_low", 0.1)),
    float(g.get("p_high", 2.5)),
    float(g.get("H_on", 0.55)),
    float(g.get("H_off", 0.85)),   # you likely refit; read what’s in CSV
    float(g.get("tau_q", 4.0)),
], float)

# ----------------- Smooth hysteresis helpers -----------------
def q_inf_smooth(H, q, H_on, H_off, k=KQ):
    """Target value of q: a logistic function of H around a q-dependent threshold.

    The threshold equals ``H_on`` at ``q == 0`` and ``H_off`` at ``q == 1``,
    with linear interpolation in between; ``k`` is the logistic steepness.
    """
    theta = (1.0 - q) * H_on + q * H_off
    exponent = -k * (H - theta)
    return 1.0 / (1.0 + np.exp(exponent))

def rhs_smooth(y, p, d_override=None):
    """Baseline smooth model without positive feedback.

    ``y`` is ``[M, H, B, q]``; ``p`` holds eleven values unpacked as
    ``r_max, K_M, c, d, g, u, p_low, p_high, H_on, H_off, tau_q``.
    When ``d_override`` is given it replaces the ``d`` stored in ``p``.
    Returns the four derivatives as a float array.
    """
    M, H, B, q = y
    r_max, K_M, c, d_stored, gH, u, pL, pH, H_on, H_off, tau = p.copy()
    d = d_stored if d_override is None else d_override

    # production rate interpolates between p_low and p_high with q limited to [0, 1]
    pB = pL + (pH - pL) * np.clip(q, 0, 1)
    q_target = q_inf_smooth(H, q, H_on, H_off, k=KQ)

    derivs = [
        (r_max - c * pB) * M * (1 - M / K_M),   # dM
        gH * B * (1 - H) - d * H,               # dH
        pB * M - u * H * B,                     # dB
        (q_target - q) / tau,                   # dq
    ]
    return np.array(derivs, float)

def rhs_pf(y, p, d_override=None, rH=R_H_VALUE):
    """Time derivative of the smooth model with positive feedback.

    Identical to the baseline model except that the growth rate is
    ``r_max + rH * H``. ``y`` is ``[M, H, B, q]``; ``p`` holds the 11
    parameters ``[r_max, K_M, c, d, g, u, p_low, p_high, H_on, H_off,
    tau_q]``; ``d_override`` (if not None) replaces ``d``.
    Returns ``[dM, dH, dB, dq]`` as a float array.
    """
    microbes, host, scfa, memory = y
    (r_max, capacity, cost, decay, gain, uptake,
     prod_low, prod_high, H_on, H_off, tau_q) = p.copy()
    if d_override is not None:
        decay = d_override

    production = prod_low + (prod_high - prod_low) * np.clip(memory, 0, 1)
    growth = r_max + rH * host
    target = q_inf_smooth(host, memory, H_on, H_off, k=KQ)

    derivs = [
        (growth - cost * production) * microbes * (1 - microbes / capacity),
        gain * scfa * (1 - host) - decay * host,
        production * microbes - uptake * host * scfa,
        (target - memory) / tau_q,
    ]
    return np.array(derivs, float)

# choose which RHS to use
# Every analysis below (equilibria, Jacobians, basins, sweep) goes through RHS,
# so POS_FEEDBACK switches the whole script between the two models.
RHS = rhs_pf if POS_FEEDBACK else rhs_smooth

# ----------------- Utilities -----------------
def jacobian_fd(fun, y, p, d_val=None, eps=1e-7):
    """Approximate the Jacobian of ``fun`` at ``y`` by forward differences.

    Parameters
    ----------
    fun : callable
        Called as ``fun(y, p, d_val)``; must return a 1-D array-like.
    y : array-like
        State at which to differentiate. It is copied as float so that an
        integer-typed state is not truncated when ``eps`` is added (which
        would otherwise yield an all-zero Jacobian).
    p, d_val
        Passed through to ``fun`` unchanged.
    eps : float
        Step added to one state component at a time.

    Returns
    -------
    numpy.ndarray
        ``J`` with ``J[:, i] ~= d fun / d y[i]``, shape
        ``(len(fun(y)), len(y))`` — 4 x 4 for the models in this script.
    """
    y = np.array(y, dtype=float)  # private float copy; caller's array is untouched
    f0 = np.asarray(fun(y, p, d_val), dtype=float)
    J = np.zeros((f0.size, y.size))
    for i in range(y.size):
        y_step = y.copy()
        y_step[i] += eps
        J[:, i] = (np.asarray(fun(y_step, p, d_val), dtype=float) - f0) / eps
    return J

def find_equilibrium(fun, p, d_val, guess):
    """Solve ``fun(y, p, d_val) = 0`` from ``guess`` and validate the root.

    Parameters
    ----------
    fun : callable
        Right-hand side, called as ``fun(y, p, d_val)``.
    p, d_val
        Passed through to ``fun`` unchanged.
    guess : array-like
        Starting point ``[M, H, B, q]`` for the solver.

    Returns
    -------
    (y, ok)
        ``(root, True)`` when the solver converged to a finite point inside
        the physical box ``M >= 0, B >= 0, 0 <= H, q <= 1.2`` (up to a tiny
        round-off tolerance); otherwise ``(guess, False)``.
    """
    sol = root(lambda yy: fun(yy, p, d_val), guess, method="hybr")
    if not sol.success:
        return guess, False
    y_raw = np.asarray(sol.x, dtype=float)
    # Test finiteness BEFORE projecting: max(0.0, nan) evaluates to 0.0 and
    # would otherwise turn a NaN component into a plausible-looking zero.
    if not np.all(np.isfinite(y_raw)):
        return guess, False
    # project to physical range (only meant to absorb round-off such as -1e-14)
    y = np.array([
        max(0.0, y_raw[0]),
        np.clip(y_raw[1], 0.0, 1.2),
        max(0.0, y_raw[2]),
        np.clip(y_raw[3], 0.0, 1.2),
    ], float)
    # A root that had to be moved is no longer a root of ``fun``; reporting the
    # projected point would feed a non-equilibrium into the stability analysis.
    if not np.allclose(y, y_raw, rtol=0.0, atol=1e-8):
        return guess, False
    return y, True

def relax_to_ss(fun, p, d_val, y0, T=240):
    """Integrate ``dy/dt = fun(y, p, d_val)`` from ``y0`` over ``[0, T]``.

    The trajectory is sampled at 900 evenly spaced times. Returns
    ``(final_state, solution)`` where ``final_state`` is the last sampled
    column and ``solution`` is the object returned by ``solve_ivp``.
    """
    def vector_field(_t, state):
        return fun(state, p, d_val)

    solver_options = dict(
        t_eval=np.linspace(0, T, 900),
        rtol=1e-6,
        atol=1e-8,
        max_step=0.5,
    )
    trajectory = solve_ivp(vector_field, (0, T), y0, **solver_options)
    final_state = trajectory.y[:, -1]
    return final_state, trajectory

# ----------------- (A) Continuation in d -----------------
# For every d on the grid, solve for an equilibrium from each fixed seed and
# classify it by the sign of the largest real part of the Jacobian eigenvalues.
d_fit = float(pars[3])
d_vals = np.linspace(d_fit*D_SPAN_FACTOR[0], d_fit*D_SPAN_FACTOR[1], N_D_POINTS)

seeds = [
    np.array([0.2, 0.2, 0.05, 1.0]),   # low-H / q≈1
    np.array([0.2, 0.9, 0.10, 0.0]),   # high-H / q≈0
    np.array([0.6, 0.6, 0.20, 0.5]),   # mid
]

rows = []
for d in d_vals:
    for wi, y0 in enumerate(seeds):
        y_eq, ok = find_equilibrium(RHS, pars, d, y0)
        if not ok:
            continue
        J = jacobian_fd(RHS, y_eq, pars, d_val=d)
        eigs = np.linalg.eigvals(J)
        stable = bool(np.max(np.real(eigs)) < 0)
        rows.append({"d": d, "H": float(y_eq[1]), "q": float(y_eq[3]),
                     "seed": wi, "stable": stable})
# Name the columns explicitly so the frame keeps its schema even when no seed
# converged; otherwise the column lookups further down raise KeyError.
branches = pd.DataFrame(rows, columns=["d", "H", "q", "seed", "stable"])
branches.to_csv(os.path.join(OUTDIR, "branches.csv"), index=False)

# plot branches
# Two overlays on one axis: small dots coloured by the seed that produced the
# point, then larger markers encoding stability ("o" stable, "x" unstable).
plt.figure(figsize=(7.2,5.2))
for wi in sorted(branches["seed"].unique()):
    sub = branches[branches["seed"]==wi]
    plt.plot(sub["d"], sub["H"], ".", ms=3, alpha=0.7, label=f"seed{wi}")
for st, mk in [(True, "o"), (False, "x")]:
    sub = branches[branches["stable"]==st]
    plt.scatter(sub["d"], sub["H"], s=22, marker=mk, alpha=0.6,
                label=("stable" if st else "unstable"))
# Vertical guide at the fitted d.
plt.axvline(d_fit, ls="--", c="gray", label="baseline d")
plt.xlabel("d (1/h)"); plt.ylabel("H* at equilibrium")
plt.legend(); plt.grid(True, ls=":", alpha=0.6)
plt.tight_layout()
plt.savefig(os.path.join(OUTDIR, "bifurcation_H_vs_d.png"), dpi=180)
plt.close()

# ----------------- (B) Is baseline d inside a bistable band? -----------------
# Solve for equilibria at exactly d_fit instead of picking rows of `branches`
# that lie within a fixed absolute tolerance of it: d_fit is not a grid point,
# and a fixed tolerance can pool several neighbouring d values, whose slightly
# different H* would then be miscounted as distinct equilibria.
stable_H = []
for y0 in seeds:
    y_eq, ok = find_equilibrium(RHS, pars, d_fit, y0)
    if not ok:
        continue
    eigs = np.linalg.eigvals(jacobian_fd(RHS, y_eq, pars, d_val=d_fit))
    if np.max(np.real(eigs)) < 0:
        stable_H.append(float(y_eq[1]))

# count distinct stable points (tolerance to avoid duplicate seeds converging to same eq)
Hs = np.sort(np.asarray(stable_H, float))
distinct = int(Hs.size > 0) + int(np.sum(np.diff(Hs) > 1e-3))
bistable = bool(distinct >= 2)

# ----------------- (C) Basin-of-attraction map at baseline -----------------
# Z[i, j] is the final H after integrating for T=300 from initial H=H_GRID[i],
# q=Q_GRID[j]; the initial M and B are fixed at 0.2 and 0.1 for every cell.
Z = np.zeros((len(H_GRID), len(Q_GRID)))
for i, H0 in enumerate(H_GRID):
    for j, q0 in enumerate(Q_GRID):
        y0 = np.array([0.2, H0, 0.1, q0], float)
        yss, _ = relax_to_ss(RHS, pars, d_fit, y0, T=300)
        Z[i, j] = yss[1]

# Rows of Z run along the y-axis (initial H), columns along the x-axis
# (initial q). The colour scale is pinned to [0.5, 1.0], so any final H below
# 0.5 is drawn with the lowest colour.
plt.figure(figsize=(6.6,5.4))
plt.imshow(Z, origin="lower",
           extent=[Q_GRID[0], Q_GRID[-1], H_GRID[0], H_GRID[-1]],
           aspect="auto", vmin=0.5, vmax=1.0, cmap="viridis")
plt.colorbar(label="Final H (steady)")
plt.xlabel("initial q"); plt.ylabel("initial H")
plt.title(f"Basins at baseline d={d_fit:.3f}")
plt.tight_layout()
plt.savefig(os.path.join(OUTDIR, "basins_heatmap.png"), dpi=180)
plt.close()

# ----------------- (D) Dynamic hysteresis sweep (for comparison) -----------------
def sweep(fun, p, H0, q0, d_lo, d_hi, T=160):
    """Sweep d up and then back down, carrying the state along.

    Starting from ``[0.1, H0, 0.1, q0]``, d is stepped through 28 evenly
    spaced values from ``d_lo`` to ``d_hi`` and then through the same values
    in reverse. At each step the model is integrated for time ``T`` from the
    state reached at the previous step. Returns ``(forward, backward)``
    DataFrames with columns ``d, M, H, B, q``.
    """
    column_names = ["d", "M", "H", "B", "q"]
    d_grid = np.linspace(d_lo, d_hi, 28)

    def run_pass(d_sequence, state):
        # One direction of the sweep; hands back the table and the last state.
        records = []
        for d_now in d_sequence:
            state, _solution = relax_to_ss(fun, p, d_now, state, T=T)
            records.append((d_now, *state))
        return pd.DataFrame(records, columns=column_names), state

    start = np.array([0.1, H0, 0.1, q0], float)
    forward, turning_state = run_pass(d_grid, start)
    backward, _end_state = run_pass(d_grid[::-1], turning_state)
    return forward, backward

# Run the up/down sweep over the same d range as the equilibrium scan,
# starting from H=0.8, q=0, and relaxing for T=180 at each d.
fwd, bwd = sweep(RHS, pars, H0=0.8, q0=0.0,
                 d_lo=d_fit*D_SPAN_FACTOR[0], d_hi=d_fit*D_SPAN_FACTOR[1], T=180)

# A gap between the forward and backward curves indicates path dependence.
# Horizontal guides mark the fitted H_on (pars[8]) and H_off (pars[9]).
plt.figure(figsize=(7.6,4.2))
plt.plot(fwd["d"], fwd["H"], "-o", ms=3, label="forward (d↑)")
plt.plot(bwd["d"], bwd["H"], "-s", ms=3, label="backward (d↓)")
plt.axhline(pars[8], ls=":", c="gray", label="H_on")
plt.axhline(pars[9], ls="--", c="gray", label="H_off")
plt.xlabel("d (1/h)"); plt.ylabel("H* (dynamic SS)"); plt.legend()
plt.grid(True, ls=":", alpha=0.6); plt.tight_layout()
plt.savefig(os.path.join(OUTDIR, "hysteresis_sweep.png"), dpi=180)
plt.close()

# ----------------- Report -----------------
# Build the model description separately: placing "{R_H_VALUE}" inside a plain
# string nested in an f-string expression leaves it un-interpolated, so the
# report would contain the literal text "{R_H_VALUE}" instead of the number.
model_desc = (f"smooth + positive feedback (r_H={R_H_VALUE})"
              if POS_FEEDBACK else "smooth, no PF")

with open(os.path.join(OUTDIR, "diagnosis.txt"), "w") as f:
    f.write(f"Baseline d = {d_fit:.5f}\n")
    f.write(f"Distinct stable equilibria at baseline (by H*): {distinct}\n")
    f.write(f"Bistable at baseline? {'YES' if bistable else 'NO'}\n")
    f.write(f"Model used: {model_desc}\n")

print("Saved results to:", OUTDIR)
print("Bistable at baseline? ", "YES" if bistable else "NO")


In [None]:
# find_bistable_regions.py
# Scans a small neighborhood around your current fit to find parameter sets
# that yield TWO distinct stable equilibria at the baseline d (true bistability).
# Uses the smooth hysteretic switch (same as your current model) and optional PF.

import numpy as np, pandas as pd
from scipy.optimize import root
from numpy.linalg import eigvals

# Input: fitted global parameters. Output: one CSV row per scanned combination.
FIT = "mw_fit_out_pf/fitted_global_params.csv"   # or mw_fit_out/... if you're using the smooth-no-PF fit
OUT = "mw_scan_bistability.csv"

# ----- Model toggles -----
USE_PF = False       # True: dot M = (r0 + rH*H - c*pB) M (1 - M/KM); False: rH ignored, use r_max
KQ     = 40.0       # smooth switch steepness

# Name -> value Series; every lookup below silently falls back to its default
# when the name is missing from the file.
# NOTE(review): read_csv uses default header handling; if FIT was written
# without a header row the first parameter line would be lost — confirm.
g = pd.read_csv(FIT, index_col=0).squeeze("columns")

# Both branches fill the same 12 slots so that rhs() can unpack either one;
# the only differences are slot 0 (r0 vs r_max) and slot 1 (rH vs a fixed 0).
if USE_PF:
    # params: [r0, rH, K_M, c, d, g, u, p_low, p_high, H_on, H_off, tau_q]
    pars0 = np.array([
        float(g.get("r0", g.get("r_max", 0.30))),
        float(g.get("rH", 0.08)),
        float(g.get("K_M", 1.0)),
        float(g.get("c", 0.10)),
        float(g.get("d", 0.12)),
        float(g.get("g", 0.55)),
        float(g.get("u", 0.60)),
        float(g.get("p_low", 0.10)),
        float(g.get("p_high", 2.20)),
        float(g.get("H_on", 0.55)),
        float(g.get("H_off", 0.85)),
        float(g.get("tau_q", 4.0)),
    ], float)
else:
    # fall back to smooth-no-PF; map to the same slots (rH=0)
    pars0 = np.array([
        float(g.get("r_max", 0.35)),
        0.0,
        float(g.get("K_M", 1.0)),
        float(g.get("c", 0.10)),
        float(g.get("d", 0.12)),
        float(g.get("g", 0.55)),
        float(g.get("u", 0.60)),
        float(g.get("p_low", 0.10)),
        float(g.get("p_high", 2.20)),
        float(g.get("H_on", 0.55)),
        float(g.get("H_off", 0.85)),
        float(g.get("tau_q", 4.0)),
    ], float)

def q_inf_smooth(H,q,H_on,H_off,k=KQ):
    """Return the smooth switch target ``sigmoid(k * (H - theta))``.

    ``theta`` blends ``H_on`` (weight ``1 - q``) and ``H_off`` (weight ``q``);
    ``k`` is the steepness of the sigmoid.
    """
    threshold = H_on * (1 - q) + H_off * q
    drive = k * (H - threshold)
    return 1.0 / (np.exp(-drive) + 1.0)

def rhs(y, p, d_override=None):
    """Time derivative of the smooth-switch model with growth ``r0 + rH*H``.

    ``y`` is ``[M, H, B, q]``; ``p`` holds the 12 parameters
    ``[r0, rH, K_M, c, d, g, u, p_low, p_high, H_on, H_off, tau_q]``
    (``rH`` may be 0, which removes the feedback term). ``d_override``, when
    not None, replaces ``d``. Returns ``[dM, dH, dB, dq]`` as a float array.
    """
    microbes, host, scfa, memory = y
    (r0, rH, capacity, cost, decay, gain, uptake,
     prod_low, prod_high, H_on, H_off, tau_q) = p.copy()
    if d_override is not None:
        decay = d_override

    production = prod_low + (prod_high - prod_low) * np.clip(memory, 0, 1)
    growth = r0 + rH * host
    target = q_inf_smooth(host, memory, H_on, H_off)

    derivs = [
        (growth - cost * production) * microbes * (1 - microbes / capacity),
        gain * scfa * (1 - host) - decay * host,
        production * microbes - uptake * host * scfa,
        (target - memory) / tau_q,
    ]
    return np.array(derivs, float)

def jac_fd(fun,y,p,dval,eps=1e-7):
    """Approximate the Jacobian of ``fun`` at ``y`` by forward differences.

    ``fun`` is called as ``fun(y, p, dval)``. ``y`` is copied as float first,
    so an integer-typed state is not truncated when ``eps`` is added (which
    would produce an all-zero Jacobian). Returns ``J`` with
    ``J[:, i] ~= d fun / d y[i]``, shape ``(len(fun(y)), len(y))``.
    """
    y = np.array(y, dtype=float)  # private float copy
    f0 = np.asarray(fun(y, p, dval), dtype=float)
    J = np.zeros((f0.size, y.size))
    for i in range(y.size):
        y_step = y.copy()
        y_step[i] += eps
        J[:, i] = (np.asarray(fun(y_step, p, dval), dtype=float) - f0) / eps
    return J

def find_eq(p, dval, guess):
    """Solve ``rhs(y, p, dval) = 0`` from ``guess`` and validate the root.

    Returns ``(y, True)`` when the solver converged to a finite point inside
    the physical box ``M >= 0, B >= 0, 0 <= H, q <= 1.2`` (up to round-off);
    otherwise ``(None, False)``.
    """
    sol = root(lambda yy: rhs(yy,p,dval), guess, method="hybr")
    if not sol.success:
        return None, False
    y_raw = np.asarray(sol.x, dtype=float)
    # Check finiteness before projecting: max(0, nan) evaluates to 0 and would
    # hide a NaN component.
    if not np.all(np.isfinite(y_raw)):
        return None, False
    y = np.array([max(0.0, y_raw[0]), np.clip(y_raw[1], 0.0, 1.2),
                  max(0.0, y_raw[2]), np.clip(y_raw[3], 0.0, 1.2)], float)
    # A root that had to be moved into the box is no longer a root; reject it
    # rather than analysing stability at a non-equilibrium point.
    if not np.allclose(y, y_raw, rtol=0.0, atol=1e-8):
        return None, False
    return y, True

# ---- Scan hyper-rectangle around the fit (tight first; enlarge if empty)
# Each grid is centred on the fitted value (index into pars0) and clamped by a
# floor and/or ceiling. With USE_PF False the rH axis collapses to one value.
rH_grid    = [pars0[1]] if not USE_PF else np.linspace(max(0,pars0[1]-0.06), pars0[1]+0.06, 5)
u_grid     = np.linspace(max(0.3,pars0[6]-0.2), pars0[6]+0.2, 7)
Hoff_grid  = np.linspace(max(0.75,pars0[10]-0.08), min(0.95,pars0[10]+0.04), 7)
pH_grid    = np.linspace(max(1.2,pars0[8]-0.7), pars0[8]+0.7, 7)
c_grid     = np.linspace(max(0.05,pars0[3]-0.05), pars0[3]+0.05, 5)

# d is held at its fitted value throughout the scan.
d_fit = float(pars0[4])
# Initial guesses [M, H, B, q] handed to the root solver for every combination.
seeds = [
    np.array([0.2, 0.25, 0.05, 1.0]),
    np.array([0.2, 0.90, 0.10, 0.0]),
    np.array([0.6, 0.60, 0.20, 0.5]),
]

# Visit every grid combination, collect the H* of each stable equilibrium
# reached from the seeds, and record how many of them are distinct.
rows = []
for rH in rH_grid:
    for u in u_grid:
        for H_off in Hoff_grid:
            for pH in pH_grid:
                for c in c_grid:
                    trial = pars0.copy()
                    trial[1], trial[3], trial[6] = rH, c, u
                    trial[8], trial[10] = pH, H_off

                    stable_levels = []
                    for seed in seeds:
                        eq, converged = find_eq(trial, d_fit, seed)
                        if converged:
                            spectrum = eigvals(jac_fd(rhs, eq, trial, d_fit))
                            if np.max(np.real(spectrum)) < 0:
                                stable_levels.append(float(eq[1]))
                    stable_levels = np.array(sorted(stable_levels))

                    # Neighbouring sorted H* values closer than 1e-3 are
                    # treated as the same equilibrium.
                    if stable_levels.size > 0:
                        n_distinct = 1 + sum(
                            1
                            for lower, upper in zip(stable_levels[:-1], stable_levels[1:])
                            if abs(upper - lower) > 1e-3
                        )
                    else:
                        n_distinct = 0

                    rows.append({
                        "rH": rH, "u": u, "H_off": H_off, "p_high": pH, "c": c,
                        "distinct_stable": n_distinct,
                        "bistable": (n_distinct >= 2),
                    })

df = pd.DataFrame(rows)
df.to_csv(OUT, index=False)
print("Saved grid results ->", OUT)
print("Bistable combos at baseline d:", int(df["bistable"].sum()))
if df["bistable"].any():
    print(df[df["bistable"]].head(10))


## Hill host benefit + smooth hysteresis + weak positive feedback

In [None]:
# calibrate_hysteresis_from_scored_hill_pf.py
# -------------------------------------------------------------------
# Fits a Hill nonlinearity in host benefit, smooth hysteresis memory,
# and weak host->microbe positive feedback to your scored time series.
# Fixed-length residuals; gentle priors to avoid monostable extremes.
# -------------------------------------------------------------------

import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares

# ----------------- Paths & basic config -----------------
INPATH = "timeseries/combined_scfas_table_scored.csv"
OUTDIR = "mw_fit_out_hill_pf"
os.makedirs(OUTDIR, exist_ok=True)

# The first candidate column present in the CSV is used as the H proxy.
H_COL_CANDIDATES = ["H_proxy_meta_smooth", "H_proxy_meta"]
SCFA_COLS = ["butyrate"]   # you can add more (will average z-scores)
# Subjects with fewer than MIN_ROWS rows are dropped before fitting.
MIN_ROWS = 4
KQ = 40.0                  # smooth switch steepness
PENALTY = 1e3              # finite penalty for failed integration

# ---- Priors (soft, data-friendly) ----
# Each prior contributes one pseudo-residual (value - target) / sd.
# R_EFF_H_REF is the H at which the effective growth rate r0 + rH*H is
# evaluated for its prior; the target of that prior is written inline in
# residuals().
H_OFF_TARGET = 0.85; H_OFF_SD = 0.05
R_EFF_H_REF = 0.8;   R_EFF_SD = 0.05
U_TARGET = 0.6;      U_SD = 0.25
K_B_TARGET = 0.20;   K_B_SD = 0.08   # Hill half-saturation (in "B" units)
# Must equal the number of prior residuals appended in residuals(); it is
# part of the fixed residual-vector length TOTLEN.
N_PRIORS = 4                           # total pseudo-residuals

# ----------------- Load data -----------------
df = pd.read_csv(INPATH)
if not {"subject_id","sample_id"}.issubset(df.columns):
    raise ValueError("CSV must contain 'subject_id' and 'sample_id'.")

# First H-proxy candidate that exists in the file, or None if neither does.
Hcol = next((c for c in H_COL_CANDIDATES if c in df.columns), None)
if Hcol is None:
    raise ValueError("Need H_proxy_meta_smooth or H_proxy_meta in the CSV.")

for c in SCFA_COLS:
    if c not in df.columns:
        raise ValueError(f"Missing SCFA column '{c}' in CSV.")

# Keep only the columns used below and drop rows lacking either identifier.
keep = ["subject_id","sample_id", Hcol] + SCFA_COLS
df = df[keep].dropna(subset=["subject_id","sample_id"]).copy()
# Time is the 0-based position of each row within its subject, in file order.
# NOTE(review): this assumes rows are already chronological per subject — confirm.
df["t_idx"] = df.groupby("subject_id").cumcount().astype(float)

# robust (MAD) z-scoring per subject for SCFA intensities
def robust_z(series: pd.Series) -> pd.Series:
    """Robustly standardise a series as ``(x - median) / scale``.

    The centre and scale are computed from the finite entries only. The scale
    is the median absolute deviation; if that is degenerate (< 1e-9) the
    interquartile range is used, and if that is degenerate too, the standard
    deviation plus 1e-9.

    Non-finite inputs stay non-finite in the output. A series with no finite
    entry returns all-NaN: returning zeros would turn missing measurements
    into apparently valid observations for any caller that masks on
    ``np.isfinite``.

    Returns a float Series carrying the index of ``series``.
    """
    x = series.astype(float).to_numpy()
    finite = np.isfinite(x)
    if not finite.any():
        return pd.Series(np.full(x.shape, np.nan), index=series.index)
    xm = x[finite]
    med = np.median(xm)
    mad = np.median(np.abs(xm - med))
    if mad < 1e-9:
        q75, q25 = np.percentile(xm, [75, 25])
        iqr = q75 - q25
        scale = iqr if iqr > 1e-9 else (np.std(xm) + 1e-9)
    else:
        scale = mad
    return pd.Series((x - med) / (scale + 1e-9), index=series.index)

# Standardise each SCFA column separately within every subject.
for c in SCFA_COLS:
    df[c+"_z"] = df.groupby("subject_id")[c].transform(robust_z)

# B_obs is the single z-scored column, or the row-wise mean of several.
if len(SCFA_COLS) == 1:
    df["B_obs"] = df[SCFA_COLS[0] + "_z"]
else:
    df["B_obs"] = df[[c+"_z" for c in SCFA_COLS]].mean(axis=1)

# The H proxy is limited to [0, 1].
df["H_obs"] = df[Hcol].clip(0,1)

# ----------------- Build per-subject series (fixed masks) -----------------
def first_finite(a, default):
    """Return the first finite entry of ``a`` as a float.

    Falls back to ``float(default)`` when ``a`` has no finite entry.
    """
    values = np.asarray(a, float)
    finite_positions = np.nonzero(np.isfinite(values))[0]
    if len(finite_positions) == 0:
        return float(default)
    return float(values[finite_positions[0]])

# One record per usable subject. The finite-value masks and their counts are
# computed once here so that the residual vector has the same length on every
# evaluation of the objective.
subs = []
for sid, sub in df.groupby("subject_id"):
    sub = sub.sort_values("t_idx").copy()
    if len(sub) < MIN_ROWS:
        continue
    t = sub["t_idx"].values.astype(float)
    B = sub["B_obs"].values.astype(float)
    H = sub["H_obs"].values.astype(float)
    mB = np.isfinite(B)
    mH = np.isfinite(H)
    # Require at least three finite observations of each series.
    if mB.sum() < 3 or mH.sum() < 3:
        continue
    subs.append({
        "sid": sid,
        "t": t,
        "B": B, "H": H,
        "maskB": mB, "maskH": mH,
        "nB": int(mB.sum()), "nH": int(mH.sum()),
        # Initial conditions come from the first finite observation; H0 is
        # limited to [0, 1] and B0 is floored at 0.05.
        "H0": float(np.clip(first_finite(H, 0.6), 0, 1)),
        "B0": float(max(0.05, first_finite(B, 0.1))),
    })

if not subs:
    raise RuntimeError("No subject passed the minimal data filters (need ≥4 rows, ≥3 finite B & H).")

# Summary of the first ten subjects: finite counts out of total rows.
print(f"[info] Fitting {len(subs)} subjects...")
for S in subs[:10]:
    print(f"  - {S['sid']}: B={S['nB']} H={S['nH']} (of {len(S['t'])})")
if len(subs) > 10: print("  ...")

# ----------------- Model: Hill host benefit + smooth hysteresis + PF -----------------
# y = [M, H, B, q]
# global params p = [r0, rH, K_M, c, d, g, u, p_low, p_high, H_on, H_off, tau_q, K_B]
# NAMES fixes the positional order used when unpacking the parameter vector,
# and supplies the row labels of the saved parameter CSV.
NAMES = ["r0","rH","K_M","c","d","g","u","p_low","p_high","H_on","H_off","tau_q","K_B"]

def q_inf_smooth(H, q, H_on, H_off, k=KQ):
    """Return the smooth switch target ``sigmoid(k * (H - theta))``.

    ``theta`` blends ``H_on`` (weight ``1 - q``) and ``H_off`` (weight ``q``);
    ``k`` is the steepness of the sigmoid.
    """
    threshold = H_on * (1.0 - q) + H_off * q
    drive = k * (H - threshold)
    return 1.0 / (np.exp(-drive) + 1.0)

def dH_term_hill(B, H, gH, d, K_B, n=2):
    """Return dH/dt: a Hill-type gain in B scaled by (1 - H), minus decay d*H.

    The gain is ``gH * B**n / (K_B**n + B**n) * (1 - H)``; ``n`` is the Hill
    exponent and ``K_B`` the half-saturation value.
    """
    B_pow = B**n
    saturation = B_pow / (K_B**n + B_pow)
    return gH * saturation * (1.0 - H) - d*H

def rhs_hill_pf(t, y, p):
    """Time derivative of the Hill + smooth-switch + positive-feedback model.

    ``t`` is accepted for the ODE-solver calling convention and is not used.
    ``y`` is ``[M, H, B, q]``; ``p`` holds the 13 parameters in ``NAMES``
    order. Returns ``[dM, dH, dB, dq]`` as a list.
    """
    microbes, host, scfa, memory = y
    (r0, rH, capacity, cost, decay, gain, uptake,
     prod_low, prod_high, H_on, H_off, tau_q, K_B) = p

    production = prod_low + (prod_high - prod_low) * np.clip(memory, 0, 1)
    growth = r0 + rH * host
    target = q_inf_smooth(host, memory, H_on, H_off, k=KQ)

    return [
        (growth - cost * production) * microbes * (1 - microbes / capacity),
        dH_term_hill(scfa, host, gain, decay, K_B, n=2),
        production * microbes - uptake * host * scfa,
        (target - memory) / tau_q,
    ]

def simulate(ts, y0, p):
    """Integrate the model from ``y0`` and sample it at the times ``ts``.

    Returns the solver's state array (one row per state variable, one column
    per time). If the solver reports failure or raises, a ``4 x len(ts)``
    array of NaN is returned instead so the caller can apply a penalty.
    """
    def nan_block():
        return np.vstack([np.full(len(ts), np.nan)] * 4)

    def vector_field(t, z):
        return rhs_hill_pf(t, z, p)

    try:
        sol = solve_ivp(vector_field, (ts[0], ts[-1]), y0, t_eval=ts,
                        rtol=1e-6, atol=1e-8, max_step=0.5)
        return sol.y if sol.success else nan_block()
    except Exception:
        return nan_block()

# ----------------- Parameter boxes & initials -----------------
# Lower bounds, upper bounds and starting values for the 13 global parameters,
# in NAMES order: r0, rH, K_M, c, d, g, u, p_low, p_high, H_on, H_off, tau_q, K_B.
LBg = np.array([0.05, 0.00, 0.4, 0.02, 0.01, 0.05, 0.15, 0.0, 0.8, 0.2, 0.70, 0.5, 0.05])
UBg = np.array([0.60, 0.25, 1.8, 0.30, 0.60, 2.50, 1.20, 0.8, 4.0, 0.9,  0.95, 24.0, 0.40])
x0g = np.array([0.30, 0.08, 1.0, 0.10, 0.12, 0.55, 0.60, 0.10, 2.2, 0.55, 0.85, 4.0, 0.20])

# Optional per-subject linear observation maps (often ~identity; leave in for flexibility)
# Three entries are appended per subject; residuals() uses them as
# Bhat = alpha_B * B and Hhat = clip(beta0_H + beta1_H * H, 0, 1).
x0s, LBs, UBs = [], [], []
for _ in subs:
    x0s += [1.0, 0.0, 1.0]  # alpha_B, beta0_H, beta1_H
    LBs += [0.1, -0.5, 0.1]
    UBs += [5.0,  0.5,  2.0]

# Full optimisation vector: the globals first, then the per-subject triples in
# the order of `subs`.
x0 = np.concatenate([x0g, np.array(x0s, float)])
LB = np.concatenate([LBg, np.array(LBs, float)])
UB = np.concatenate([UBg, np.array(UBs, float)])

def unpack(x):
    """Split the flat optimisation vector into globals and subject triples.

    Returns ``(gpar, triples)``: the first ``len(NAMES)`` entries, and the
    remainder cut into ``len(subs)`` equal parts (one per subject).
    """
    n_global = len(NAMES)
    global_part, subject_part = x[:n_global], x[n_global:]
    return global_part, np.split(subject_part, len(subs))

# Fixed-length residual accounting
# Total residual count: every finite B and H observation of every subject plus
# the prior pseudo-residuals. Used to size the constant penalty vector that
# residuals() returns for infeasible parameters.
TOTLEN = sum(S["nB"] + S["nH"] for S in subs) + N_PRIORS

def residuals(x):
    """Residual vector for ``least_squares``; always ``TOTLEN`` entries long.

    Layout: for each subject in ``subs`` order, the B residuals at the finite
    B observations followed by the H residuals at the finite H observations,
    then the four prior pseudo-residuals.

    If ``H_off > H_on`` does not hold or ``K_B <= 0.02``, every entry is
    ``PENALTY``. If a subject's simulation contains a non-finite H or B, that
    subject's entries are all ``PENALTY``.
    """
    gpar, triples = unpack(x)
    H_on, H_off, K_B = gpar[9], gpar[10], gpar[12]

    # enforce H_off > H_on; K_B > small
    if (not (H_off > H_on)) or (K_B <= 0.02):
        return np.full(TOTLEN, PENALTY, float)

    pieces = []
    for S, (alpha_B, beta0, beta1) in zip(subs, triples):
        H_init = float(np.clip(S["H0"], 0, 1))
        B_init = float(max(0.05, S["B0"]))
        # Start the switch at 1 when the initial H lies below the midpoint of
        # the two thresholds, else at 0.
        q_init = 1.0 if H_init < 0.5*(gpar[9]+gpar[10]) else 0.0
        trajectory = simulate(S["t"], [0.1, H_init, B_init, q_init], gpar)
        _M_model, H_model, B_model, _q_model = trajectory

        all_finite = np.all(np.isfinite(H_model)) and np.all(np.isfinite(B_model))
        if not all_finite:
            pieces.append(np.full(S["nB"], PENALTY))
            pieces.append(np.full(S["nH"], PENALTY))
            continue

        # Per-subject observation maps, compared only at finite observations.
        keep_B, keep_H = S["maskB"], S["maskH"]
        err_B = (alpha_B*B_model)[keep_B] - S["B"][keep_B]
        err_H = np.clip(beta0 + beta1*H_model, 0, 1)[keep_H] - S["H"][keep_H]
        for err in (err_B, err_H):
            pieces.append(np.where(np.isfinite(err), err, PENALTY).astype(float))

    # --- Priors (4 scalars) ---
    growth_at_ref = gpar[0] + gpar[1] * R_EFF_H_REF
    prior_terms = [
        (H_off - H_OFF_TARGET) / (H_OFF_SD + 1e-9),   # 1) H_off prior
        (growth_at_ref - 0.35) / (R_EFF_SD + 1e-9),   # 2) r0 + rH*H at the reference H
        (gpar[6] - U_TARGET) / (U_SD + 1e-9),         # 3) uptake u (wide)
        (K_B - K_B_TARGET) / (K_B_SD + 1e-9),         # 4) Hill K_B prior
    ]
    pieces.append(np.array(prior_terms, float))

    return np.concatenate(pieces)

# Bounded least-squares fit of all globals and per-subject maps at once, using
# the soft-L1 loss and at most 900 function evaluations.
fit = least_squares(residuals, x0, bounds=(LB, UB),
                    verbose=2, max_nfev=900, loss="soft_l1", f_scale=1.0)

gpar_hat, triples_hat = unpack(fit.x)
# The globals are written as name,value rows WITHOUT a header line, so readers
# of this file must pass header=None.
pd.Series(gpar_hat, index=NAMES).to_csv(os.path.join(OUTDIR, "fitted_global_params.csv"), header=False)
# One row per fitted subject with its three observation-map coefficients.
pd.DataFrame(
    [{"subject_id": S["sid"], "alpha_B": tr[0], "beta0_H": tr[1], "beta1_H": tr[2]}
     for S,tr in zip(subs, triples_hat)]
).to_csv(os.path.join(OUTDIR, "fitted_subject_scales.csv"), index=False)

print("[info] Fitted globals:", dict(zip(NAMES, gpar_hat)))

# quick diagnostics (first few)
# Re-simulate up to the first eight subjects at the fitted parameters and plot
# model against data. The initial state is built exactly as in residuals().
for S,tr in list(zip(subs, triples_hat))[:8]:
    alpha_B, beta0, beta1 = tr
    ts = S["t"]
    H0 = float(np.clip(S["H0"], 0, 1))
    B0 = float(max(0.05, S["B0"]))
    M0 = 0.1
    q0 = 1.0 if H0 < 0.5*(gpar_hat[9]+gpar_hat[10]) else 0.0
    y0 = [M0, H0, B0, q0]
    # Go through simulate() so a failed integration yields NaN instead of a
    # truncated trajectory; a truncated one would not match the boolean masks
    # below and would abort the script with an IndexError.
    Y = simulate(ts, y0, gpar_hat)
    if not np.all(np.isfinite(Y)):
        print(f"[warn] skipping diagnostic plot for {S['sid']}: integration failed")
        continue
    M,H,B,q = Y
    Bhat = alpha_B*B
    Hhat = np.clip(beta0 + beta1*H, 0, 1)
    mB, mH = S["maskB"], S["maskH"]

    fig, ax = plt.subplots(2,1, figsize=(7,6), sharex=True)
    ax[0].plot(ts[mB], Bhat[mB], label="Model B (scaled)")
    ax[0].scatter(ts[mB], S["B"][mB], s=18, c="k", label="Obs B (z)")
    ax[0].legend(); ax[0].grid(True, ls=":")

    ax[1].plot(ts[mH], Hhat[mH], label="Model H")
    ax[1].scatter(ts[mH], S["H"][mH], s=18, c="k", label="H proxy")
    ax[1].axhline(gpar_hat[9], ls=":", c="gray", label="H_on")
    ax[1].axhline(gpar_hat[10], ls="--", c="gray", label="H_off")
    ax[1].legend(); ax[1].grid(True, ls=":")
    ax[1].set_xlabel("time"); ax[0].set_ylabel("B"); ax[1].set_ylabel("H")
    plt.tight_layout(); plt.savefig(os.path.join(OUTDIR, f"diag_{S['sid']}.png"), dpi=180); plt.close()

print("✅ Done. Outputs in:", OUTDIR)


In [None]:
# bifurcation_and_basins_hill.py
# ------------------------------------------------------------
# Continuation in d; stability via Jacobian eigs; basins at baseline
# for the Hill + smooth hysteresis + PF model fitted above.
# ------------------------------------------------------------

import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import root

# Input: parameters written by the Hill + PF calibration. Output directory is
# created if missing.
FIT = "mw_fit_out_hill_pf/fitted_global_params.csv"
OUT = "mw_bif_hill"; os.makedirs(OUT, exist_ok=True)

KQ = 40.0
# header=None: the file is read as headerless name,value rows.
g = pd.read_csv(FIT, index_col=0, header=None).squeeze("columns")
# p = [r0,rH,K_M,c,d,g,u,p_low,p_high,H_on,H_off,tau_q,K_B]
# Values are taken in file row order, not looked up by name, so the file must
# list the parameters in exactly the order shown above.
p = np.array([float(g[k]) for k in g.index.values], float)
d_fit = float(p[4])

def q_inf(H,q,H_on,H_off,k=KQ):
    """Return the smooth switch target ``sigmoid(k * (H - theta))``.

    ``theta`` blends ``H_on`` (weight ``1 - q``) and ``H_off`` (weight ``q``);
    ``k`` is the steepness of the sigmoid.
    """
    threshold = H_on * (1 - q) + H_off * q
    drive = k * (H - threshold)
    return 1.0 / (np.exp(-drive) + 1.0)

def dH_hill(B,H,gH,d,K_B,n=2):
    """Return dH/dt: a Hill-type gain in B scaled by (1 - H), minus decay d*H.

    ``n`` is the Hill exponent and ``K_B`` the half-saturation value.
    """
    B_pow = B**n
    saturation = B_pow / (K_B**n + B_pow)
    gain_term = gH * saturation * (1 - H)
    return gain_term - d*H

def rhs(y, pvec, d_override=None):
    """Time derivative of the Hill + smooth-switch + positive-feedback model.

    ``y`` is ``[M, H, B, q]``; ``pvec`` holds the 13 parameters
    ``[r0, rH, K_M, c, d, g, u, p_low, p_high, H_on, H_off, tau_q, K_B]``.
    ``d_override``, when not None, replaces ``d``.
    Returns ``[dM, dH, dB, dq]`` as a float array.
    """
    microbes, host, scfa, memory = y
    (r0, rH, capacity, cost, decay, gain, uptake,
     prod_low, prod_high, H_on, H_off, tau_q, K_B) = pvec.copy()
    if d_override is not None:
        decay = d_override

    production = prod_low + (prod_high - prod_low) * np.clip(memory, 0, 1)
    growth = r0 + rH * host

    derivs = [
        (growth - cost * production) * microbes * (1 - microbes / capacity),
        dH_hill(scfa, host, gain, decay, K_B, n=2),
        production * microbes - uptake * host * scfa,
        (q_inf(host, memory, H_on, H_off) - memory) / tau_q,
    ]
    return np.array(derivs, float)

def jac_fd(fun,y,pvec,dval,eps=1e-7):
    """Approximate the Jacobian of ``fun`` at ``y`` by forward differences.

    ``fun`` is called as ``fun(y, pvec, dval)``. ``y`` is copied as float
    first, so an integer-typed state is not truncated when ``eps`` is added
    (which would produce an all-zero Jacobian). Returns ``J`` with
    ``J[:, i] ~= d fun / d y[i]``, shape ``(len(fun(y)), len(y))``.
    """
    y = np.array(y, dtype=float)  # private float copy
    f0 = np.asarray(fun(y, pvec, dval), dtype=float)
    J = np.zeros((f0.size, y.size))
    for i in range(y.size):
        y_step = y.copy()
        y_step[i] += eps
        J[:, i] = (np.asarray(fun(y_step, pvec, dval), dtype=float) - f0) / eps
    return J

def find_eq(pvec, dval, guess):
    """Solve ``rhs(y, pvec, dval) = 0`` from ``guess`` and validate the root.

    Returns ``(y, True)`` when the solver converged to a finite point inside
    the physical box ``M >= 0, B >= 0, 0 <= H, q <= 1.2`` (up to round-off);
    otherwise ``(None, False)``.
    """
    sol = root(lambda yy: rhs(yy, pvec, dval), guess, method="hybr")
    if not sol.success:
        return None, False
    y_raw = np.asarray(sol.x, dtype=float)
    # Check finiteness before projecting: max(0, nan) evaluates to 0 and would
    # hide a NaN component.
    if not np.all(np.isfinite(y_raw)):
        return None, False
    y = np.array([max(0.0, y_raw[0]), np.clip(y_raw[1], 0.0, 1.2),
                  max(0.0, y_raw[2]), np.clip(y_raw[3], 0.0, 1.2)], float)
    # A root that had to be moved into the box is no longer a root; reject it
    # rather than analysing stability at a non-equilibrium point.
    if not np.allclose(y, y_raw, rtol=0.0, atol=1e-8):
        return None, False
    return y, True

# (A) continuation in d
# For every d on the grid, solve for an equilibrium from each fixed seed and
# classify it by the sign of the largest real part of the Jacobian eigenvalues.
D_SPAN = (0.6*d_fit, 1.5*d_fit)
d_vals = np.linspace(D_SPAN[0], D_SPAN[1], 90)
seeds = [np.array([0.2,0.25,0.05,1.0]), np.array([0.2,0.9,0.1,0.0]), np.array([0.6,0.6,0.2,0.5])]

rows=[]
for d in d_vals:
    for wi,y0 in enumerate(seeds):
        y, ok = find_eq(p, d, y0)
        if not ok: continue
        J=jac_fd(rhs,y,p,d)
        stable = (np.max(np.real(np.linalg.eigvals(J)))<0)
        rows.append({"d":d,"H":float(y[1]),"q":float(y[3]),"seed":wi,"stable":stable})
# Name the columns explicitly so the frame keeps its schema even when no seed
# converged; otherwise the column lookups further down raise KeyError.
branches = pd.DataFrame(rows, columns=["d","H","q","seed","stable"])
branches.to_csv(os.path.join(OUT,"branches.csv"), index=False)

# Two overlays on one axis: small dots coloured by the seed that produced the
# point, then larger markers encoding stability ("o" stable, "x" unstable).
plt.figure(figsize=(7.2,5.2))
for wi in sorted(branches["seed"].unique()):
    sub=branches[branches["seed"]==wi]
    plt.plot(sub["d"], sub["H"], ".", ms=3, alpha=0.7, label=f"seed{wi}")
for st, mk in [(True,"o"),(False,"x")]:
    sub=branches[branches["stable"]==st]
    plt.scatter(sub["d"], sub["H"], s=22, marker=mk, alpha=0.6, label=("stable" if st else "unstable"))
# Vertical guide at the fitted d.
plt.axvline(d_fit, ls="--", c="gray", label="baseline d")
plt.xlabel("d (1/h)"); plt.ylabel("H*"); plt.legend(); plt.grid(True, ls=":")
plt.tight_layout(); plt.savefig(os.path.join(OUT,"bifurcation_H_vs_d.png"), dpi=180); plt.close()

# (B) bistable at baseline?
# Solve for equilibria at exactly d_fit instead of picking rows of `branches`
# that lie within a fixed absolute tolerance of it: d_fit falls between two
# grid points, and a fixed tolerance can pool both of them, whose slightly
# different H* would then be miscounted as two distinct equilibria.
stable_H = []
for y0 in seeds:
    y, ok = find_eq(p, d_fit, y0)
    if not ok:
        continue
    J = jac_fd(rhs, y, p, d_fit)
    if np.max(np.real(np.linalg.eigvals(J))) < 0:
        stable_H.append(float(y[1]))

# Sorted H* values closer than 1e-3 are treated as the same equilibrium.
Hs = np.sort(np.asarray(stable_H, float))
distinct = int(Hs.size > 0) + int(np.sum(np.diff(Hs) > 1e-3))
bistable = (distinct>=2)

# (C) basins at baseline
# 17 x 17 initial (H, q) pairs. Z[i, j] is the final H reached from
# H=Hs[i], q=qs[j], with the initial M and B fixed at 0.2 and 0.1.
# Note that Hs is re-bound here to the grid of initial H values.
Hs = np.linspace(0.2, 0.95, 17)
qs = np.linspace(0.0, 1.0, 17)
Z = np.zeros((len(Hs), len(qs)))
# relax: integrate the model at d_fit from y0 over [0, T] and return the last
# sampled state. The solver's success flag is not inspected.
def relax(y0,T=300):
    sol=solve_ivp(lambda t,z: rhs(z,p,d_fit),(0,T),y0,t_eval=np.linspace(0,T,900),
                  rtol=1e-6,atol=1e-8,max_step=0.5)
    return sol.y[:,-1]
for i,H0 in enumerate(Hs):
    for j,q0 in enumerate(qs):
        y0=np.array([0.2,H0,0.1,q0],float)
        yss=relax(y0)
        Z[i,j]=yss[1]

# Rows of Z run along the y-axis (initial H), columns along the x-axis
# (initial q). The colour scale is pinned to [0.5, 1.0], so any final H below
# 0.5 is drawn with the lowest colour.
plt.figure(figsize=(6.6,5.2))
plt.imshow(Z, origin="lower", extent=[qs[0],qs[-1],Hs[0],Hs[-1]], aspect="auto", vmin=0.5, vmax=1.0, cmap="viridis")
plt.colorbar(label="Final H (steady)")
plt.xlabel("initial q"); plt.ylabel("initial H")
plt.title(f"Basins at baseline d={d_fit:.3f}")
plt.tight_layout(); plt.savefig(os.path.join(OUT,"basins_heatmap.png"), dpi=180); plt.close()

# Write the three-line text summary next to the figures, then echo the verdict.
with open(os.path.join(OUT,"diagnosis.txt"),"w") as f:
    f.write(f"Baseline d = {d_fit:.5f}\n")
    f.write(f"Distinct stable equilibria at baseline: {distinct}\n")
    f.write(f"Bistable at baseline? {'YES' if bistable else 'NO'}\n")

print("Saved ->", OUT, "| Bistable at baseline? ", "YES" if bistable else "NO")


In [None]:
# find_bistable_regions_hill.py  >>>> Optional
# Scan a small hyper-rectangle around the Hill+PF fit for true bistability at baseline d.

import numpy as np, pandas as pd
from scipy.optimize import root
import numpy.linalg as npl

# Input: parameters written by the Hill + PF calibration. Output: one CSV row
# per scanned parameter combination.
FIT = "mw_fit_out_hill_pf/fitted_global_params.csv"
OUT = "mw_scan_bistability_hill.csv"

# header=None: the file is read as headerless name,value rows. Values are
# taken in file row order, so that order must match the 13-parameter
# unpacking in rhs().
g = pd.read_csv(FIT, index_col=0, header=None).squeeze("columns")
p0 = np.array([float(g[k]) for k in g.index.values], float)
# d is held at its fitted value (slot 4) for the whole scan.
d_fit = float(p0[4])
KQ = 40.0

def q_inf(H,q,H_on,H_off):
    """Return the smooth switch target ``sigmoid(KQ * (H - theta))``.

    ``theta`` blends ``H_on`` (weight ``1 - q``) and ``H_off`` (weight ``q``).
    """
    threshold = H_on * (1 - q) + H_off * q
    drive = KQ * (H - threshold)
    return 1.0 / (np.exp(-drive) + 1.0)
def dH(B,H,gH,d,K_B,n=2):
    """Return dH/dt: a Hill-type gain in B scaled by (1 - H), minus decay d*H.

    ``n`` is the Hill exponent and ``K_B`` the half-saturation value.
    """
    B_pow = B**n
    saturation = B_pow / (K_B**n + B_pow)
    gain_term = gH * saturation * (1 - H)
    return gain_term - d*H
def rhs(y,p,dval=None):
    """Time derivative of the Hill + smooth-switch + positive-feedback model.

    ``y`` is ``[M, H, B, q]``; ``p`` holds the 13 parameters
    ``[r0, rH, K_M, c, d, g, u, p_low, p_high, H_on, H_off, tau_q, K_B]``.
    ``dval``, when not None, replaces ``d``.
    Returns ``[dM, dH, dB, dq]`` as a float array.
    """
    microbes, host, scfa, memory = y
    (r0, rH, capacity, cost, decay, gain, uptake,
     prod_low, prod_high, H_on, H_off, tau_q, K_B) = p.copy()
    if dval is not None:
        decay = dval

    production = prod_low + (prod_high - prod_low) * np.clip(memory, 0, 1)
    growth = r0 + rH * host

    d_microbes = (growth - cost * production) * microbes * (1 - microbes / capacity)
    d_host = dH(scfa, host, gain, decay, K_B, 2)
    d_scfa = production * microbes - uptake * host * scfa
    d_memory = (q_inf(host, memory, H_on, H_off) - memory) / tau_q
    return np.array([d_microbes, d_host, d_scfa, d_memory], float)

def jac_fd(fun,y,p,dval,eps=1e-7):
    """Approximate the Jacobian of ``fun`` at ``y`` by forward differences.

    ``fun`` is called as ``fun(y, p, dval)``. ``y`` is copied as float first,
    so an integer-typed state is not truncated when ``eps`` is added (which
    would produce an all-zero Jacobian). Returns ``J`` with
    ``J[:, i] ~= d fun / d y[i]``, shape ``(len(fun(y)), len(y))``.
    """
    y = np.array(y, dtype=float)  # private float copy
    f0 = np.asarray(fun(y, p, dval), dtype=float)
    J = np.zeros((f0.size, y.size))
    for i in range(y.size):
        y_step = y.copy()
        y_step[i] += eps
        J[:, i] = (np.asarray(fun(y_step, p, dval), dtype=float) - f0) / eps
    return J

def find_eq(p,dval,guess):
    """Solve ``rhs(y, p, dval) = 0`` from ``guess`` and validate the root.

    Returns ``(y, True)`` when the solver converged to a finite point inside
    the physical box ``M >= 0, B >= 0, 0 <= H, q <= 1.2`` (up to round-off);
    otherwise ``(None, False)``.
    """
    sol = root(lambda yy: rhs(yy,p,dval), guess, method="hybr")
    if not sol.success:
        return None, False
    y_raw = np.asarray(sol.x, dtype=float)
    # Check finiteness before projecting: max(0, nan) evaluates to 0 and would
    # hide a NaN component.
    if not np.all(np.isfinite(y_raw)):
        return None, False
    y = np.array([max(0.0, y_raw[0]), np.clip(y_raw[1], 0.0, 1.2),
                  max(0.0, y_raw[2]), np.clip(y_raw[3], 0.0, 1.2)], float)
    # A root that had to be moved into the box is no longer a root; reject it
    # rather than analysing stability at a non-equilibrium point.
    if not np.allclose(y, y_raw, rtol=0.0, atol=1e-8):
        return None, False
    return y, True

# small grids (tight first; widen if empty)
# Each grid is centred on the fitted value (index into p0) and clamped by a
# floor and/or ceiling. The full scan visits 5*7*7*7*5*7 combinations, each
# solved from three seeds.
rH_grid   = np.linspace(max(0.0, p0[1]-0.06), p0[1]+0.06, 5)
u_grid    = np.linspace(max(0.2, p0[6]-0.25), p0[6]+0.25, 7)
Hoff_grid = np.linspace(max(0.75, p0[10]-0.08), min(0.95, p0[10]+0.04), 7)
pH_grid   = np.linspace(max(1.0, p0[8]-0.8), p0[8]+0.8, 7)
c_grid    = np.linspace(max(0.05, p0[3]-0.06), p0[3]+0.06, 5)
KB_grid   = np.linspace(max(0.08, p0[12]-0.10), min(0.40, p0[12]+0.10), 7)

# Initial guesses [M, H, B, q] handed to the root solver for every combination.
seeds=[np.array([0.2,0.25,0.05,1.0]), np.array([0.2,0.90,0.10,0.0]), np.array([0.6,0.60,0.20,0.5])]

# Visit every grid combination, collect the H* of each stable equilibrium
# reached from the seeds, and record how many of them are distinct.
rows = []
for rH in rH_grid:
    for u in u_grid:
        for H_off in Hoff_grid:
            for pH in pH_grid:
                for c in c_grid:
                    for KB in KB_grid:
                        trial = p0.copy()
                        trial[1], trial[3], trial[6] = rH, c, u
                        trial[8], trial[10], trial[12] = pH, H_off, KB

                        stable_levels = []
                        for seed in seeds:
                            eq, converged = find_eq(trial, d_fit, seed)
                            if converged:
                                spectrum = npl.eigvals(jac_fd(rhs, eq, trial, d_fit))
                                if np.max(np.real(spectrum)) < 0:
                                    stable_levels.append(float(eq[1]))
                        stable_levels = np.array(sorted(stable_levels))

                        # Neighbouring sorted H* values closer than 1e-3 are
                        # treated as the same equilibrium.
                        if stable_levels.size:
                            n_distinct = 1 + sum(
                                1
                                for lower, upper in zip(stable_levels[:-1], stable_levels[1:])
                                if abs(upper - lower) > 1e-3
                            )
                        else:
                            n_distinct = 0

                        rows.append({
                            "rH": rH, "u": u, "H_off": H_off, "p_high": pH,
                            "c": c, "K_B": KB,
                            "distinct_stable": n_distinct,
                            "bistable": n_distinct >= 2,
                        })

df = pd.DataFrame(rows)
df.to_csv(OUT, index=False)
print("Saved ->", OUT, "| bistable rows:", int(df["bistable"].sum()))
if df["bistable"].any():
    print(df[df["bistable"]].head(10))


In [None]:
# calibrate_hysteresis_from_scored_hill_pf_constrained.py
# -------------------------------------------------------------------
# Hill host benefit + smooth hysteretic memory + weak host->microbe PF
# Refit with biologically plausible priors and tighter bounds to avoid
# degenerate edge solutions that destroy bistability.
# Fixed-length residuals. Save outputs to mw_fit_out_hill_pf_constrained/
# -------------------------------------------------------------------

import os, numpy as np, pandas as pd
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares

# ----------------- Paths & config -----------------
INPATH = "timeseries/combined_scfas_table_scored.csv"
OUTDIR = "mw_fit_out_hill_pf_constrained"
os.makedirs(OUTDIR, exist_ok=True)

# The first candidate column present in the CSV is used as the H proxy.
H_COLS = ["H_proxy_meta_smooth","H_proxy_meta"]
SCFA_COLS = ["butyrate"]   # can average multiple if you like

MIN_ROWS = 4
KQ = 40.0
PENALTY = 1e3

# ---- Priors (centers & widths; edit if you have better anchors) ----
# Keyed by parameter name; each value is a (center, width) pair.
PRIOR = {
    "d":      (0.12, 0.05),    # decay/inflammation
    "g":      (0.60, 0.30),    # host benefit gain
    "u":      (0.60, 0.20),    # uptake (linear model; Hill uptake can be added later)
    "p_low":  (0.10, 0.05),
    "c":      (0.10, 0.05),
    "rH":     (0.08, 0.05),    # weak PF
    "H_off":  (0.85, 0.05),
    "K_B":    (0.20, 0.08),    # Hill half-saturation in B units (z-scored per subject)
}

# ---- Bounds (tighter than before; still permissive) ----
# Positional order (13 entries):
# r0, rH, K_M, c, d, g, u, p_low, p_high, H_on, H_off, tau_q, K_B.
# The H_on upper bound (0.70) lies below the H_off lower bound (0.75).
LBg = np.array([0.08, 0.00, 0.6, 0.05, 0.03, 0.10, 0.25, 0.02, 0.9, 0.40, 0.75, 1.0, 0.08])
UBg = np.array([0.55, 0.20, 1.6, 0.22, 0.35, 1.20, 0.95, 0.35, 3.2, 0.70, 0.93, 12.0, 0.35])

# Initial guess near priors
x0g = np.array([
    0.30, 0.08, 1.0, 0.10, 0.12, 0.60, 0.60, 0.10, 2.0, 0.55, 0.85, 4.0, 0.20
], float)

# ----------------- Load & prep data -----------------
df = pd.read_csv(INPATH)
if not {"subject_id","sample_id"}.issubset(df.columns):
    raise ValueError("CSV must contain 'subject_id' and 'sample_id'.")

# First H-proxy candidate that exists in the file, or None if neither does.
Hcol = next((c for c in H_COLS if c in df.columns), None)
if Hcol is None:
    raise ValueError("Need H_proxy_meta_smooth or H_proxy_meta in the CSV.")

for c in SCFA_COLS:
    if c not in df.columns:
        raise ValueError(f"Missing SCFA column '{c}'.")

# Keep only the columns used below and drop rows lacking either identifier.
keep = ["subject_id","sample_id",Hcol] + SCFA_COLS
df = df[keep].dropna(subset=["subject_id","sample_id"]).copy()
# Time is the 0-based position of each row within its subject, in file order.
# NOTE(review): this assumes rows are already chronological per subject — confirm.
df["t_idx"] = df.groupby("subject_id").cumcount().astype(float)

def robust_z(series: pd.Series) -> pd.Series:
    """Return a robustly standardised copy of *series* (same index).

    Values are centred on the median of the finite entries and divided by
    their MAD.  When the MAD is (numerically) zero the IQR is used instead,
    and when that is zero too, the standard deviation.  A series without any
    finite entry maps to all zeros; non-finite entries otherwise stay
    non-finite.
    """
    values = series.astype(float).to_numpy()
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return pd.Series(np.zeros_like(values), index=series.index)

    center = np.median(finite)
    scale = np.median(np.abs(finite - center))
    if scale < 1e-9:
        # Degenerate MAD: fall back to IQR, then to the standard deviation.
        upper, lower = np.percentile(finite, [75, 25])
        spread = upper - lower
        scale = spread if spread > 1e-9 else (np.std(finite) + 1e-9)

    return pd.Series((values - center)/(scale + 1e-9), index=series.index)

# Robust z-score of every SCFA column, computed within each subject.
for c in SCFA_COLS:
    df[c+"_z"] = df.groupby("subject_id")[c].transform(robust_z)

# B observable: the single z-scored column, or the row-wise mean of several.
z_cols = [c+"_z" for c in SCFA_COLS]
if len(SCFA_COLS)==1:
    df["B_obs"] = df[z_cols[0]]
else:
    df["B_obs"] = df[z_cols].mean(axis=1)
df["H_obs"] = df[Hcol].clip(0,1)

def first(a,d):
    """Return the first finite entry of *a* as a float, or *d* when there is none."""
    a = np.asarray(a,float)
    finite_idx = np.flatnonzero(np.isfinite(a))
    return float(a[finite_idx[0]]) if finite_idx.size else float(d)

# One record per usable subject: time grid, observations, finite-value masks
# (fixed once, so the residual length never changes) and initial conditions.
subs = []
for sid, sub in df.groupby("subject_id"):
    sub = sub.sort_values("t_idx").copy()
    if len(sub) < MIN_ROWS:
        continue
    t = sub["t_idx"].values.astype(float)
    B = sub["B_obs"].values.astype(float)
    H = sub["H_obs"].values.astype(float)
    mB, mH = np.isfinite(B), np.isfinite(H)
    nB, nH = int(mB.sum()), int(mH.sum())
    # need at least three finite observations of each signal
    if min(nB, nH) < 3:
        continue
    subs.append(dict(
        sid=sid,
        t=t, B=B, H=H,
        maskB=mB, maskH=mH,
        nB=nB, nH=nH,
        H0=float(np.clip(first(H,0.6),0,1)),
        B0=float(max(0.05, first(B,0.1))),
    ))
if not subs:
    raise RuntimeError("No subject passed minimal data filters.")

print(f"[info] Fitting {len(subs)} subjects ...")

# ----------------- Model (Hill + PF + smooth memory) -----------------
# state      y = [M,H,B,q]
# parameters p = [r0,rH,K_M,c,d,g,u,p_low,p_high,H_on,H_off,tau_q,K_B]
NAMES = ["r0","rH","K_M","c","d","g","u","p_low","p_high","H_on","H_off","tau_q","K_B"]

def q_inf_smooth(H,q,H_on,H_off,k=KQ):
    """Sigmoid target for q: 1/(1 + exp(-k*(H - theta))).

    The threshold theta moves linearly with q, from H_on at q=0 to H_off at
    q=1; k sets the steepness of the switch.
    """
    theta = (1.0-q)*H_on + q*H_off
    z = k*(H - theta)
    return 1.0/(1.0 + np.exp(-z))

def dH_hill(B,H,gH,d,K_B,n=2):
    """Rate of change of H: Hill-type benefit from B minus linear decay.

    Returns gH * B^n/(K_B^n + B^n) * (1 - H) - d*H.
    """
    saturation = B**n / (K_B**n + B**n)
    gain = gH*saturation*(1.0 - H)
    loss = d*H
    return gain - loss

def rhs(t,y,p):
    """ODE right-hand side for the state y = [M,H,B,q].

    p is unpacked positionally as
    [r0,rH,K_M,c,d,g,u,p_low,p_high,H_on,H_off,tau_q,K_B]; t is unused.
    Returns the list [dM, dH, dB, dq].
    """
    M,H,B,q = y
    r0,rH,K_M,c,d,gH,u,pL,pH,H_on,H_off,tau,K_B = p
    # production rate interpolates between p_low and p_high with q clipped to [0,1]
    production = pL + (pH - pL)*np.clip(q,0,1)
    return [
        (r0 + rH*H - c*production)*M*(1 - M/K_M),   # logistic growth of M
        dH_hill(B,H,gH,d,K_B, n=2),                 # Hill benefit minus decay
        production*M - u*H*B,                       # production minus linear uptake
        (q_inf_smooth(H,q,H_on,H_off) - q)/tau,     # relaxation of q towards its target
    ]

def simulate(ts,y0,p):
    """Integrate rhs from y0 and return the states at the times *ts*.

    Returns the solver's (4, len(ts)) array on success.  If the solver
    reports failure or raises, a (4, len(ts)) array of NaN is returned
    instead so callers can detect the failure without handling exceptions.
    """
    def nan_block():
        return np.vstack([np.full(len(ts),np.nan)]*4)

    try:
        sol = solve_ivp(lambda t,z: rhs(t,z,p), (ts[0],ts[-1]), y0, t_eval=ts,
                        rtol=1e-6, atol=1e-8, max_step=0.5)
    except Exception:
        return nan_block()
    return sol.y if sol.success else nan_block()

# Per-subject observation maps: one (alpha_B, beta0_H, beta1_H) triple per
# subject, appended after the global parameters (H is weighted higher below).
n_subjects = len(subs)
x0s = [1.0, 0.0, 1.0]*n_subjects
LBs = [0.2, -0.4, 0.5]*n_subjects
UBs = [4.0,  0.4, 1.5]*n_subjects

x0 = np.concatenate([x0g, np.asarray(x0s,float)])
LB = np.concatenate([LBg, np.asarray(LBs,float)])
UB = np.concatenate([UBg, np.asarray(UBs,float)])

def unpack(x):
    """Split *x* into (global parameters, list of per-subject triples)."""
    n_global = len(NAMES)
    return x[:n_global], np.split(x[n_global:], len(subs))

# residual weighting (downweight B a little to prevent d->0)
W_B = 0.6
W_H = 1.0

# fixed-length residual vector size: all finite B and H observations + priors
TOTLEN = sum(S["nB"] for S in subs) + sum(S["nH"] for S in subs) + len(PRIOR)

def residuals(x):
    """Weighted residual vector (always TOTLEN long) for the stacked vector *x*.

    Layout: for each subject, the W_B-weighted B residuals followed by the
    W_H-weighted H residuals (finite observations only), then one
    (value - center)/width pseudo-residual per PRIOR entry.  Rejected
    parameter sets and failed simulations contribute PENALTY instead of a
    model misfit.
    """
    gpar, triples = unpack(x)
    H_on, H_off, K_B = gpar[9], gpar[10], gpar[12]

    # guards: thresholds must be ordered (H_off > H_on) and K_B must exceed 0.05
    if not (H_off > H_on) or K_B <= 0.05:
        return np.full(TOTLEN, PENALTY)

    q_switch = 0.5*(H_on + H_off)
    chunks = []
    for S, (aB, b0H, b1H) in zip(subs, triples):
        H0 = np.clip(S["H0"],0,1)
        B0 = max(0.05,S["B0"])
        q0 = 1.0 if H0 < q_switch else 0.0
        _, Hm, Bm, _ = simulate(S["t"], [0.12, H0, B0, q0], gpar)

        if not (np.all(np.isfinite(Hm)) and np.all(np.isfinite(Bm))):
            # failed simulation: penalise every observation of this subject
            chunks.append(W_B*np.full(S["nB"],PENALTY))
            chunks.append(W_H*np.full(S["nH"],PENALTY))
            continue

        mB, mH = S["maskB"], S["maskH"]
        Bhat = aB*Bm
        Hhat = np.clip(b0H + b1H*Hm, 0, 1)
        b = Bhat[mB] - S["B"][mB]
        h = Hhat[mH] - S["H"][mH]

        # any leftover non-finite difference is replaced by the penalty
        b = np.where(np.isfinite(b), b, PENALTY)
        h = np.where(np.isfinite(h), h, PENALTY)

        chunks.append(W_B*b.astype(float))
        chunks.append(W_H*h.astype(float))

    # ---- Priors as pseudo-residuals (one per entry in PRIOR) ----
    position = {name:i for i,name in enumerate(NAMES)}
    chunks.append(np.array(
        [(gpar[position[name]]-mu)/(sd + 1e-9) for name,(mu,sd) in PRIOR.items()],
        float))

    return np.concatenate(chunks)

# Joint fit of global parameters and per-subject observation maps.
fit = least_squares(residuals, x0, bounds=(LB,UB),
                    verbose=2, max_nfev=1000, loss="soft_l1", f_scale=1.0)
if not fit.success:
    # Keep going (the best iterate is still written out) but make it visible.
    print(f"[warn] least_squares did not converge: {fit.message}")

gpar_hat, triples_hat = unpack(fit.x)

pd.Series(gpar_hat, index=NAMES).to_csv(os.path.join(OUTDIR,"fitted_global_params.csv"), header=False)
pd.DataFrame(
    [{"subject_id": S["sid"], "alpha_B": tr[0], "beta0_H": tr[1], "beta1_H": tr[2]}
     for S,tr in zip(subs, triples_hat)]
).to_csv(os.path.join(OUTDIR,"fitted_subject_scales.csv"), index=False)

print("[info] Fitted globals:", dict(zip(NAMES, gpar_hat)))

# quick small diagnostics (first few subjects)
for S,tr in list(zip(subs, triples_hat))[:8]:
    aB, b0, b1 = tr
    ts=S["t"]; H0=np.clip(S["H0"],0,1); B0=max(0.05,S["B0"])
    M0=0.12; q0 = 1.0 if H0 < 0.5*(gpar_hat[9]+gpar_hat[10]) else 0.0
    y0=[M0,H0,B0,q0]
    # Use simulate() rather than a raw solve_ivp call: on solver failure it
    # returns a full-size NaN array, whereas a raw failed solution is shorter
    # than ts and would break the boolean-mask indexing below.
    Y=simulate(ts, y0, gpar_hat)
    if not np.all(np.isfinite(Y)):
        print(f"[warn] diagnostic simulation failed for subject {S['sid']}; skipping plot.")
        continue
    M,H,B,q = Y
    Bhat=aB*B; Hhat=np.clip(b0 + b1*H,0,1)
    mB,mH=S["maskB"],S["maskH"]

    fig,ax=plt.subplots(2,1,figsize=(7,6),sharex=True)
    ax[0].plot(ts[mB],Bhat[mB],label="Model B (scaled)")
    ax[0].scatter(ts[mB],S["B"][mB],s=18,c="k",label="Obs B (z)")
    ax[0].legend(); ax[0].grid(True,ls=":")

    ax[1].plot(ts[mH],Hhat[mH],label="Model H")
    ax[1].scatter(ts[mH],S["H"][mH],s=18,c="k",label="H proxy")
    ax[1].axhline(gpar_hat[9],ls=":",c="gray",label="H_on")
    ax[1].axhline(gpar_hat[10],ls="--",c="gray",label="H_off")
    ax[1].legend(); ax[1].grid(True,ls=":")
    ax[1].set_xlabel("time"); ax[0].set_ylabel("B"); ax[1].set_ylabel("H")
    plt.tight_layout(); plt.savefig(os.path.join(OUTDIR,f"diag_{S['sid']}.png"),dpi=170); plt.close(fig)

print("✅ Done. Outputs in:", OUTDIR)


In [None]:
# bifurcation_and_basins_hill.py >> constrained
# ------------------------------------------------------------
# Continuation in d; stability via Jacobian eigs; basins at baseline
# for the Hill + smooth hysteresis + PF model fitted above.
# ------------------------------------------------------------

import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import root

FIT = "mw_fit_out_hill_pf_constrained/fitted_global_params.csv"
OUT = "mw_bif_hill"; os.makedirs(OUT, exist_ok=True)

KQ = 40.0
# Parameter layout that rhs() unpacks positionally.
PARAM_NAMES = ["r0","rH","K_M","c","d","g","u","p_low","p_high","H_on","H_off","tau_q","K_B"]

g = pd.read_csv(FIT, index_col=0, header=None).squeeze("columns")
# Fail early if FIT points at a fit of a different model variant: a wrong
# layout would otherwise be mis-assigned silently or crash inside rhs().
_missing = [k for k in PARAM_NAMES if k not in g.index]
_extra = [k for k in g.index if k not in PARAM_NAMES]
if _missing or _extra:
    raise ValueError(
        f"{FIT} does not match the parameter layout expected here; "
        f"missing={_missing}, unexpected={_extra}")
# p = [r0,rH,K_M,c,d,g,u,p_low,p_high,H_on,H_off,tau_q,K_B]
p = g.loc[PARAM_NAMES].to_numpy(dtype=float)
d_fit = float(p[PARAM_NAMES.index("d")])

def q_inf(H,q,H_on,H_off,k=KQ):
    """Sigmoid target for q with a q-dependent threshold.

    The threshold is (1-q)*H_on + q*H_off and k is the sigmoid steepness.
    """
    theta = (1-q)*H_on + q*H_off
    z = k*(H - theta)
    return 1.0/(1.0 + np.exp(-z))

def dH_hill(B,H,gH,d,K_B,n=2):
    """Rate of change of H: gH * B^n/(K_B^n + B^n) * (1 - H) - d*H."""
    saturation = B**n/(K_B**n + B**n)
    gain = gH*saturation*(1-H)
    return gain - d*H

def rhs(y, pvec, d_override=None):
    """Vector field [dM,dH,dB,dq] at state y = [M,H,B,q] (time-independent).

    pvec is unpacked as [r0,rH,K_M,c,d,g,u,p_low,p_high,H_on,H_off,tau_q,K_B].
    When d_override is given it replaces the d stored in pvec.
    """
    M,H,B,q = y
    r0,rH,K_M,c,d_base,gH,u,pL,pH,H_on,H_off,tau,K_B = pvec.copy()
    d = d_base if d_override is None else d_override
    # production rate interpolates between p_low and p_high with q clipped to [0,1]
    production = pL + (pH - pL)*np.clip(q,0,1)
    return np.array([
        (r0 + rH*H - c*production)*M*(1 - M/K_M),
        dH_hill(B,H,gH,d,K_B, n=2),
        production*M - u*H*B,
        (q_inf(H,q,H_on,H_off) - q)/tau,
    ], float)

def jac_fd(fun,y,pvec,dval,eps=1e-7):
    """Forward-difference 4x4 Jacobian of fun(y, pvec, dval) with respect to y.

    Column i is (fun(y + eps*e_i) - fun(y)) / eps.  y itself is not modified.
    """
    base = fun(y,pvec,dval)
    jac = np.zeros((4,4))
    for col in range(4):
        probe = y.copy()
        probe[col] += eps
        jac[:,col] = (fun(probe,pvec,dval) - base)/eps
    return jac

def find_eq(pvec, dval, guess, tol=1e-6):
    """Search for an equilibrium of rhs(., pvec, dval) starting from *guess*.

    Returns (y, True) with y = [M,H,B,q] clipped to M,B >= 0 and
    H,q in [0,1.2], or (None, False) when no acceptable equilibrium is found.

    A candidate is rejected when the root finder fails, when the raw root
    has a non-finite component, or when the clipped point no longer
    satisfies |rhs| <= tol in every component.  Clipping is only meant to
    absorb round-off (e.g. M = -1e-12); a root well outside the box is not
    an equilibrium of the model after being moved onto the boundary.
    """
    sol = root(lambda yy: rhs(yy, pvec, dval), guess, method="hybr")
    if not sol.success:
        return None, False
    raw = np.asarray(sol.x, float)
    # Check finiteness before clipping: max(0, nan) evaluates to 0 and would hide a NaN.
    if not np.all(np.isfinite(raw)):
        return None, False
    y = np.array([max(0,raw[0]), np.clip(raw[1],0,1.2), max(0,raw[2]), np.clip(raw[3],0,1.2)], float)
    resid = rhs(y, pvec, dval)
    if not np.all(np.isfinite(resid)) or np.max(np.abs(resid)) > tol:
        return None, False
    return y, True

# (A) continuation in d: for every d on the grid, start the root finder from
# each seed state [M,H,B,q] and classify the equilibrium by the largest real
# part of the Jacobian eigenvalues.
D_SPAN = (0.6*d_fit, 1.5*d_fit)
d_vals = np.linspace(D_SPAN[0], D_SPAN[1], 90)
seeds = [np.array([0.2,0.25,0.05,1.0]), np.array([0.2,0.9,0.1,0.0]), np.array([0.6,0.6,0.2,0.5])]

rows=[]
for d in d_vals:
    for wi,y0 in enumerate(seeds):
        y, ok = find_eq(p, d, y0)
        if not ok: continue
        J=jac_fd(rhs,y,p,d)
        stable = bool(np.max(np.real(np.linalg.eigvals(J)))<0)
        rows.append({"d":d,"H":float(y[1]),"q":float(y[3]),"seed":wi,"stable":stable})
# Explicit columns keep the CSV and plotting code working when no equilibrium was found.
branches = pd.DataFrame(rows, columns=["d","H","q","seed","stable"])
branches.to_csv(os.path.join(OUT,"branches.csv"), index=False)

plt.figure(figsize=(7.2,5.2))
for wi in sorted(branches["seed"].unique()):
    sub=branches[branches["seed"]==wi]
    plt.plot(sub["d"], sub["H"], ".", ms=3, alpha=0.7, label=f"seed{wi}")
for st, mk in [(True,"o"),(False,"x")]:
    sub=branches[branches["stable"]==st]
    plt.scatter(sub["d"], sub["H"], s=22, marker=mk, alpha=0.6, label=("stable" if st else "unstable"))
plt.axvline(d_fit, ls="--", c="gray", label="baseline d")
plt.xlabel("d (1/h)"); plt.ylabel("H*"); plt.legend(); plt.grid(True, ls=":")
plt.tight_layout(); plt.savefig(os.path.join(OUT,"bifurcation_H_vs_d.png"), dpi=180); plt.close()

# (B) bistable at baseline?
# Solve at exactly d_fit instead of picking grid points "near" d_fit: the
# grid need not contain d_fit, so a tolerance window can select two
# neighbouring d values (one branch counted twice) or none at all.
Hs_stable=[]
for y0 in seeds:
    y, ok = find_eq(p, d_fit, y0)
    if not ok: continue
    if np.max(np.real(np.linalg.eigvals(jac_fd(rhs,y,p,d_fit))))<0:
        Hs_stable.append(float(y[1]))
Hs_stable = np.sort(np.asarray(Hs_stable, float))
# stable equilibria whose H differ by more than 1e-3 count as distinct
distinct = int(Hs_stable.size>0) + int(np.sum(np.diff(Hs_stable)>1e-3))
bistable = (distinct>=2)

# (C) basins at baseline: grid of initial (H, q) values
Hs = np.linspace(0.2, 0.95, 17)
qs = np.linspace(0.0, 1.0, 17)
Z = np.zeros((len(Hs), len(qs)))
def relax(y0,T=300):
    """Integrate the model at d = d_fit from y0 over (0, T); return the last computed state."""
    out_times = np.linspace(0,T,900)
    sol = solve_ivp(lambda t,z: rhs(z,p,d_fit), (0,T), y0, t_eval=out_times,
                    rtol=1e-6, atol=1e-8, max_step=0.5)
    return sol.y[:,-1]
# Relax every (H0, q0) grid point (M0 = 0.2, B0 = 0.1) and store the final H.
for i,H0 in enumerate(Hs):
    for j,q0 in enumerate(qs):
        Z[i,j] = relax(np.array([0.2,H0,0.1,q0],float))[1]

plt.figure(figsize=(6.6,5.2))
plt.imshow(
    Z,
    origin="lower",
    extent=[qs[0],qs[-1],Hs[0],Hs[-1]],
    aspect="auto",
    vmin=0.5,
    vmax=1.0,
    cmap="viridis",
)
plt.colorbar(label="Final H (steady)")
plt.xlabel("initial q")
plt.ylabel("initial H")
plt.title(f"Basins at baseline d={d_fit:.3f}")
plt.tight_layout()
plt.savefig(os.path.join(OUT,"basins_heatmap.png"), dpi=180)
plt.close()

# Short text report of the baseline bistability check.
verdict = "YES" if bistable else "NO"
report = (
    f"Baseline d = {d_fit:.5f}\n"
    f"Distinct stable equilibria at baseline: {distinct}\n"
    f"Bistable at baseline? {verdict}\n"
)
with open(os.path.join(OUT,"diagnosis.txt"),"w") as f:
    f.write(report)

print("Saved ->", OUT, "| Bistable at baseline? ", verdict)


### Add saturable uptake to the Hill + smooth-memory + PF model, keep fixed-length residuals, and use tight, literature-plausible priors/bounds

In [None]:
# calibrate_hysteresis_from_scored_hill_satpf.py
# -------------------------------------------------------------------
# Model:
#   dM = (r0 + rH*H - c*pB)*M*(1 - M/K_M)
#   dH = g * (B^n/(K_B^n + B^n)) * (1 - H) - d*H         [Hill host benefit, n=2]
#   dB = pB*M - u*H * B/(K_u + B)                         [SATURABLE uptake]
#   dq = (q_inf(H,q) - q)/tau_q,  q_inf = sigmoid_k(H - [(1-q)H_on + q H_off])
#
# Robust calibration to scored time series (butyrate intensity z-scores, H proxy).
# Fixed-length residuals; tight priors/bounds from literature to avoid edge fits.
# -------------------------------------------------------------------

import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares

# ----------------- Paths & config -----------------
INPATH = "timeseries/combined_scfas_table_scored.csv"   # <- your scored file
OUTDIR = "mw_fit_out_hill_satpf"
os.makedirs(OUTDIR, exist_ok=True)

# Candidate H-proxy columns, in order of preference; the first one present
# in the CSV is used.
H_COLS = ["H_proxy_meta_smooth","H_proxy_meta"]
SCFA_COLS = ["butyrate"]        # can add more; will average per-subject z-scores

# MIN_ROWS: subjects with fewer rows are skipped.
# HILL_N:   default Hill exponent of dH_hill.
MIN_ROWS  = 4
KQ        = 40.0                # memory smoothness (sigmoid steepness)
PENALTY   = 1e3                 # finite penalty for failed sims to keep residual length fixed
HILL_N    = 2

# ---- Priors (center, sd) — edit if you have better anchors ----
# residuals() adds one pseudo-residual (value - center) / sd per entry.
PRIOR = {
    "d":      (0.12, 0.05),   # host decay/inflammation
    "g":      (0.60, 0.30),   # host benefit gain
    "u":      (0.60, 0.20),   # uptake scale
    "K_u":    (0.20, 0.08),   # uptake half-saturation (in B z-units)
    "K_B":    (0.20, 0.08),   # host benefit half-saturation
    "p_low":  (0.10, 0.05),
    "c":      (0.10, 0.05),
    "rH":     (0.08, 0.05),   # weak PF
    "H_off":  (0.85, 0.05),
}

# ---- Bounds (tighter, literature-plausible) ----
# Entries are positional, in this order:
# p = [r0,rH,K_M,c,d,g,u,K_u,p_low,p_high,H_on,H_off,tau_q,K_B]
LBg = np.array([0.08, 0.00, 0.6, 0.06, 0.06, 0.15, 0.40, 0.10, 0.02, 1.2, 0.40, 0.80, 1.0, 0.10])
UBg = np.array([0.55, 0.20, 1.7, 0.18, 0.24, 1.50, 0.90, 0.35, 0.35, 3.0, 0.70, 0.90, 12.0, 0.35])

# Start near priors/baseline (same positional order as the bounds)
x0g = np.array([0.30, 0.08, 1.0, 0.10, 0.12, 0.60, 0.60, 0.20, 0.10, 2.0, 0.55, 0.85, 4.0, 0.20], float)

# ----------------- Load & prep data -----------------
df = pd.read_csv(INPATH)
if not {"subject_id","sample_id"}.issubset(df.columns):
    raise ValueError("CSV must contain 'subject_id' and 'sample_id'.")

# Pick the first available H-proxy column.
Hcol = next((c for c in H_COLS if c in df.columns), None)
if Hcol is None:
    raise ValueError("Need H_proxy_meta_smooth or H_proxy_meta in the CSV.")

for c in SCFA_COLS:
    if c not in df.columns:
        raise ValueError(f"Missing SCFA column '{c}'.")

keep = ["subject_id","sample_id",Hcol] + SCFA_COLS
df = df[keep].dropna(subset=["subject_id","sample_id"]).copy()
# Time axis is the within-subject row counter (0, 1, 2, ...) in file order.
# NOTE(review): this assumes rows are already chronological per subject — confirm.
df["t_idx"] = df.groupby("subject_id").cumcount().astype(float)

def robust_z(series: pd.Series) -> pd.Series:
    """Return a robustly standardised copy of *series* (same index).

    Values are centred on the median of the finite entries and divided by
    their MAD, falling back to the IQR and then to the standard deviation
    when the preceding scale is (numerically) zero.  A series without any
    finite entry maps to all zeros.
    """
    values = series.astype(float).to_numpy()
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return pd.Series(np.zeros_like(values), index=series.index)

    center = np.median(finite)
    scale = np.median(np.abs(finite - center))
    if scale < 1e-9:
        # Degenerate MAD: fall back to IQR, then to the standard deviation.
        upper, lower = np.percentile(finite, [75, 25])
        spread = upper - lower
        scale = spread if spread > 1e-9 else (np.std(finite) + 1e-9)

    return pd.Series((values - center)/(scale + 1e-9), index=series.index)

# Robust z-score of every SCFA column, computed within each subject.
for c in SCFA_COLS:
    df[c+"_z"] = df.groupby("subject_id")[c].transform(robust_z)

# B observable: the single z-scored column, or the row-wise mean of several.
z_cols = [c+"_z" for c in SCFA_COLS]
if len(SCFA_COLS)==1:
    df["B_obs"] = df[z_cols[0]]
else:
    df["B_obs"] = df[z_cols].mean(axis=1)
df["H_obs"] = df[Hcol].clip(0,1)

def first(a,d):
    """Return the first finite entry of *a* as a float, or *d* when there is none."""
    a = np.asarray(a,float)
    finite_idx = np.flatnonzero(np.isfinite(a))
    return float(a[finite_idx[0]]) if finite_idx.size else float(d)

# One record per usable subject: time grid, observations, finite-value masks
# (fixed once, so the residual length never changes) and initial conditions.
subs = []
for sid, sub in df.groupby("subject_id"):
    sub = sub.sort_values("t_idx").copy()
    if len(sub) < MIN_ROWS:
        continue
    t = sub["t_idx"].values.astype(float)
    B = sub["B_obs"].values.astype(float)
    H = sub["H_obs"].values.astype(float)
    mB, mH = np.isfinite(B), np.isfinite(H)
    nB, nH = int(mB.sum()), int(mH.sum())
    # need at least three finite observations of each signal
    if min(nB, nH) < 3:
        continue
    subs.append(dict(
        sid=sid, t=t, B=B, H=H,
        maskB=mB, maskH=mH,
        nB=nB, nH=nH,
        H0=float(np.clip(first(H,0.6),0,1)),
        B0=float(max(0.05, first(B,0.1))),
    ))
if not subs:
    raise RuntimeError("No subject passed minimal data filters.")

print(f"[info] Fitting {len(subs)} subjects ...")

# ----------------- Model (Hill + SAT uptake + PF + smooth memory) -----------------
# state y=[M,H,B,q]; parameters p=[r0,rH,K_M,c,d,g,u,K_u,p_low,p_high,H_on,H_off,tau_q,K_B]
NAMES = ["r0","rH","K_M","c","d","g","u","K_u","p_low","p_high","H_on","H_off","tau_q","K_B"]

def q_inf_smooth(H, q, H_on, H_off, k=KQ):
    """Sigmoid target for q: 1/(1 + exp(-k*(H - theta))).

    The threshold theta moves linearly with q, from H_on at q=0 to H_off at q=1.
    """
    theta = (1.0 - q)*H_on + q*H_off
    z = k*(H - theta)
    return 1.0/(1.0 + np.exp(-z))

def dH_hill(B, H, gH, d, K_B, n=HILL_N):
    """Rate of change of H: gH * B^n/(K_B^n + B^n) * (1 - H) - d*H."""
    saturation = B**n/(K_B**n + B**n)
    gain = gH*saturation*(1 - H)
    return gain - d*H

def rhs(t, y, p):
    """ODE right-hand side for the state y = [M,H,B,q].

    p is unpacked positionally as
    [r0,rH,K_M,c,d,g,u,K_u,p_low,p_high,H_on,H_off,tau_q,K_B]; t is unused.
    Returns the list [dM, dH, dB, dq].
    """
    M,H,B,q = y
    r0,rH,K_M,c,d,gH,u,K_u,pL,pH,H_on,H_off,tau,K_B = p
    # production rate interpolates between p_low and p_high with q clipped to [0,1]
    production = pL + (pH - pL)*np.clip(q,0,1)
    # saturable uptake; the 1e-9 keeps the denominator away from zero
    uptake = u*H*B/(K_u + B + 1e-9)
    return [
        (r0 + rH*H - c*production)*M*(1 - M/K_M),
        dH_hill(B,H,gH,d,K_B),
        production*M - uptake,
        (q_inf_smooth(H,q,H_on,H_off) - q)/tau,
    ]

def simulate(ts, y0, p):
    """Integrate rhs from y0 and return the states at the times *ts*.

    Returns the solver's (4, len(ts)) array on success.  If the solver
    reports failure or raises, a (4, len(ts)) array of NaN is returned
    instead so callers can detect the failure without handling exceptions.
    """
    def nan_block():
        return np.vstack([np.full(len(ts),np.nan)]*4)

    try:
        sol = solve_ivp(lambda t,z: rhs(t,z,p), (ts[0],ts[-1]), y0, t_eval=ts,
                        rtol=1e-6, atol=1e-8, max_step=0.5)
    except Exception:
        return nan_block()
    return sol.y if sol.success else nan_block()

# Per-subject observation maps: one (alpha_B, beta0_H, beta1_H) triple per
# subject with modest freedom, appended after the global parameters.
n_subjects = len(subs)
x0s = [1.0, 0.0, 1.0]*n_subjects
LBs = [0.5, -0.3, 0.7]*n_subjects
UBs = [2.0,  0.3, 1.3]*n_subjects

x0 = np.concatenate([x0g, np.asarray(x0s,float)])
LB = np.concatenate([LBg, np.asarray(LBs,float)])
UB = np.concatenate([UBg, np.asarray(UBs,float)])

def unpack(x):
    """Split *x* into (global parameters, list of per-subject triples)."""
    n_global = len(NAMES)
    return x[:n_global], np.split(x[n_global:], len(subs))

# weights — give H a little more pull than B (intensity/z is noisier)
W_B = 0.6
W_H = 1.0

# fixed residual length: all finite B and H observations + priors
TOTLEN = sum(S["nB"] for S in subs) + sum(S["nH"] for S in subs) + len(PRIOR)

def residuals(x):
    """Weighted residual vector (always TOTLEN long) for the stacked vector *x*.

    Layout: for each subject, the W_B-weighted B residuals followed by the
    W_H-weighted H residuals (finite observations only), then one
    (value - center)/sd pseudo-residual per PRIOR entry.  Rejected parameter
    sets and failed simulations contribute PENALTY instead of a model misfit.
    """
    gpar, triples = unpack(x)
    K_u, H_on, H_off, K_B = gpar[7], gpar[10], gpar[11], gpar[13]

    # guards: H_off > H_on, and both half-saturation constants above 0.06
    if not (H_off > H_on) or K_B <= 0.06 or K_u <= 0.06:
        return np.full(TOTLEN, PENALTY)

    q_switch = 0.5*(H_on + H_off)
    chunks = []
    for S, (aB, b0H, b1H) in zip(subs, triples):
        H0 = np.clip(S["H0"],0,1)
        B0 = max(0.05, S["B0"])
        q0 = 1.0 if H0 < q_switch else 0.0
        _, Hm, Bm, _ = simulate(S["t"], [0.12, H0, B0, q0], gpar)

        if not (np.all(np.isfinite(Hm)) and np.all(np.isfinite(Bm))):
            # failed simulation: penalise every observation of this subject
            chunks.append(W_B*np.full(S["nB"],PENALTY))
            chunks.append(W_H*np.full(S["nH"],PENALTY))
            continue

        mB, mH = S["maskB"], S["maskH"]
        Bhat = aB*Bm
        Hhat = np.clip(b0H + b1H*Hm, 0, 1)
        b = Bhat[mB] - S["B"][mB]
        h = Hhat[mH] - S["H"][mH]
        # any leftover non-finite difference is replaced by the penalty
        b = np.where(np.isfinite(b), b, PENALTY)
        h = np.where(np.isfinite(h), h, PENALTY)

        chunks.append(W_B*b.astype(float))
        chunks.append(W_H*h.astype(float))

    # priors — one scalar residual each
    position = {name:i for i,name in enumerate(NAMES)}
    chunks.append(np.array(
        [(gpar[position[name]]-mu)/(sd + 1e-9) for name,(mu,sd) in PRIOR.items()],
        float))

    return np.concatenate(chunks)

# Joint fit of global parameters and per-subject observation maps.
fit = least_squares(residuals, x0, bounds=(LB,UB),
                    verbose=2, max_nfev=1000, loss="soft_l1", f_scale=1.0)
if not fit.success:
    # Keep going (the best iterate is still written out) but make it visible.
    print(f"[warn] least_squares did not converge: {fit.message}")

gpar_hat, triples_hat = unpack(fit.x)
pd.Series(gpar_hat, index=NAMES).to_csv(os.path.join(OUTDIR,"fitted_global_params.csv"), header=False)
pd.DataFrame(
    [{"subject_id":S["sid"], "alpha_B":tr[0], "beta0_H":tr[1], "beta1_H":tr[2]}
     for S,tr in zip(subs, triples_hat)]
).to_csv(os.path.join(OUTDIR,"fitted_subject_scales.csv"), index=False)

print("[info] Fitted globals:", dict(zip(NAMES, gpar_hat)))

# quick small diagnostics (first few)
for S,tr in list(zip(subs, triples_hat))[:8]:
    aB,b0,b1 = tr
    ts=S["t"]; H0=np.clip(S["H0"],0,1); B0=max(0.05,S["B0"])
    M0=0.12; q0=1.0 if H0 < 0.5*(gpar_hat[10]+gpar_hat[11]) else 0.0
    y0=[M0,H0,B0,q0]
    # Use simulate() rather than a raw solve_ivp call: on solver failure it
    # returns a full-size NaN array, whereas a raw failed solution is shorter
    # than ts and would break the boolean-mask indexing below.
    Y=simulate(ts, y0, gpar_hat)
    if not np.all(np.isfinite(Y)):
        print(f"[warn] diagnostic simulation failed for subject {S['sid']}; skipping plot.")
        continue
    M,H,B,q = Y
    Bhat=aB*B; Hhat=np.clip(b0 + b1*H,0,1)
    mB,mH=S["maskB"],S["maskH"]

    fig,ax=plt.subplots(2,1,figsize=(7,6),sharex=True)
    ax[0].plot(ts[mB],Bhat[mB],label="Model B (scaled)")
    ax[0].scatter(ts[mB],S["B"][mB],s=18,c="k",label="Obs B (z)")
    ax[0].legend(); ax[0].grid(True,ls=":")

    ax[1].plot(ts[mH],Hhat[mH],label="Model H")
    ax[1].scatter(ts[mH],S["H"][mH],s=18,c="k",label="H proxy")
    ax[1].axhline(gpar_hat[10],ls=":",c="gray",label="H_on")
    ax[1].axhline(gpar_hat[11],ls="--",c="gray",label="H_off")
    ax[1].legend(); ax[1].grid(True,ls=":")
    ax[1].set_xlabel("time"); ax[0].set_ylabel("B"); ax[1].set_ylabel("H")
    plt.tight_layout(); plt.savefig(os.path.join(OUTDIR,f"diag_{S['sid']}.png"),dpi=170); plt.close(fig)

print("✅ Done. Outputs in:", OUTDIR)


In [None]:
# bifurcation_and_basins_hill_sat.py
# Continuation in d; stability via Jacobian eigs; basins at baseline
# for Hill host module + saturable uptake + PF + smooth memory.

import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import root

FIT = "mw_fit_out_hill_satpf/fitted_global_params.csv"   # <- point to the new fit
OUT = "mw_bif_hill_sat"; os.makedirs(OUT, exist_ok=True)

KQ = 40.0
# Parameter layout that rhs() unpacks positionally.
PARAM_NAMES = ["r0","rH","K_M","c","d","g","u","K_u","p_low","p_high","H_on","H_off","tau_q","K_B"]

g = pd.read_csv(FIT, index_col=0, header=None).squeeze("columns")
# Fail early if FIT points at a fit of a different model variant: a wrong
# layout would otherwise be mis-assigned silently or crash inside rhs().
_missing = [k for k in PARAM_NAMES if k not in g.index]
_extra = [k for k in g.index if k not in PARAM_NAMES]
if _missing or _extra:
    raise ValueError(
        f"{FIT} does not match the parameter layout expected here; "
        f"missing={_missing}, unexpected={_extra}")
# p = [r0,rH,K_M,c,d,g,u,K_u,p_low,p_high,H_on,H_off,tau_q,K_B]
p = g.loc[PARAM_NAMES].to_numpy(dtype=float)
d_fit = float(p[PARAM_NAMES.index("d")])

def q_inf(H,q,H_on,H_off):
    """Sigmoid target for q with steepness KQ and threshold (1-q)*H_on + q*H_off."""
    theta = (1-q)*H_on + q*H_off
    z = KQ*(H - theta)
    return 1.0/(1.0 + np.exp(-z))
def dH_hill(B,H,gH,d,K_B,n=2):
    """Rate of change of H: gH * B^n/(K_B^n + B^n) * (1 - H) - d*H."""
    saturation = B**n/(K_B**n + B**n)
    gain = gH*saturation*(1-H)
    return gain - d*H

def rhs(y, pvec, d_override=None):
    """Vector field [dM,dH,dB,dq] at state y = [M,H,B,q] (time-independent).

    pvec is unpacked as
    [r0,rH,K_M,c,d,g,u,K_u,p_low,p_high,H_on,H_off,tau_q,K_B].
    When d_override is given it replaces the d stored in pvec.
    """
    M,H,B,q = y
    r0,rH,K_M,c,d_base,gH,u,K_u,pL,pH,H_on,H_off,tau,K_B = pvec.copy()
    d = d_base if d_override is None else d_override
    # production rate interpolates between p_low and p_high with q clipped to [0,1]
    production = pL + (pH - pL)*np.clip(q,0,1)
    # saturable uptake; the 1e-9 keeps the denominator away from zero
    uptake = u*H*B/(K_u + B + 1e-9)
    return np.array([
        (r0 + rH*H - c*production)*M*(1 - M/K_M),
        dH_hill(B,H,gH,d,K_B,2),
        production*M - uptake,
        (q_inf(H,q,H_on,H_off) - q)/tau,
    ], float)

def jac_fd(fun,y,pvec,dval,eps=1e-7):
    """Forward-difference 4x4 Jacobian of fun(y, pvec, dval) with respect to y.

    Column i is (fun(y + eps*e_i) - fun(y)) / eps.  y itself is not modified.
    """
    base = fun(y,pvec,dval)
    jac = np.zeros((4,4))
    for col in range(4):
        probe = y.copy()
        probe[col] += eps
        jac[:,col] = (fun(probe,pvec,dval) - base)/eps
    return jac

def find_eq(pvec, dval, guess, tol=1e-6):
    """Search for an equilibrium of rhs(., pvec, dval) starting from *guess*.

    Returns (y, True) with y = [M,H,B,q] clipped to M,B >= 0 and
    H,q in [0,1.2], or (None, False) when no acceptable equilibrium is found.

    A candidate is rejected when the root finder fails, when the raw root
    has a non-finite component, or when the clipped point no longer
    satisfies |rhs| <= tol in every component.  Clipping is only meant to
    absorb round-off (e.g. M = -1e-12); a root well outside the box is not
    an equilibrium of the model after being moved onto the boundary.
    """
    sol = root(lambda yy: rhs(yy, pvec, dval), guess, method="hybr")
    if not sol.success:
        return None, False
    raw = np.asarray(sol.x, float)
    # Check finiteness before clipping: max(0, nan) evaluates to 0 and would hide a NaN.
    if not np.all(np.isfinite(raw)):
        return None, False
    y = np.array([max(0,raw[0]), np.clip(raw[1],0,1.2), max(0,raw[2]), np.clip(raw[3],0,1.2)], float)
    resid = rhs(y, pvec, dval)
    if not np.all(np.isfinite(resid)) or np.max(np.abs(resid)) > tol:
        return None, False
    return y, True

# (A) continuation in d: for every d on the grid, start the root finder from
# each seed state [M,H,B,q] and classify the equilibrium by the largest real
# part of the Jacobian eigenvalues.
d_vals = np.linspace(0.6*d_fit, 1.6*d_fit, 90)
seeds = [np.array([0.2,0.25,0.05,1.0]), np.array([0.2,0.90,0.10,0.0]), np.array([0.6,0.60,0.20,0.5])]

rows=[]
for d in d_vals:
    for wi,y0 in enumerate(seeds):
        y, ok = find_eq(p, d, y0)
        if not ok: continue
        J=jac_fd(rhs,y,p,d)
        stable = bool(np.max(np.real(np.linalg.eigvals(J)))<0)
        rows.append({"d":d,"H":float(y[1]),"q":float(y[3]),"seed":wi,"stable":stable})
# Explicit columns keep the CSV and plotting code working when no equilibrium was found.
branches = pd.DataFrame(rows, columns=["d","H","q","seed","stable"])
branches.to_csv(os.path.join(OUT,"branches.csv"), index=False)

plt.figure(figsize=(7.2,5.0))
for wi in sorted(branches["seed"].unique()):
    sub=branches[branches["seed"]==wi]
    plt.plot(sub["d"], sub["H"], ".", ms=3, alpha=0.7, label=f"seed{wi}")
for st, mk in [(True,"o"),(False,"x")]:
    sub=branches[branches["stable"]==st]
    plt.scatter(sub["d"], sub["H"], s=22, marker=mk, alpha=0.6, label=("stable" if st else "unstable"))
plt.axvline(d_fit, ls="--", c="gray", label="baseline d")
plt.xlabel("d (1/h)"); plt.ylabel("H*"); plt.legend(); plt.grid(True, ls=":")
plt.tight_layout(); plt.savefig(os.path.join(OUT,"bifurcation_H_vs_d.png"), dpi=180); plt.close()

# (B) bistable at baseline?
# Solve at exactly d_fit instead of picking grid points "near" d_fit: the
# grid need not contain d_fit, so a tolerance window can select two
# neighbouring d values (one branch counted twice) or none at all.
Hs_stable=[]
for y0 in seeds:
    y, ok = find_eq(p, d_fit, y0)
    if not ok: continue
    if np.max(np.real(np.linalg.eigvals(jac_fd(rhs,y,p,d_fit))))<0:
        Hs_stable.append(float(y[1]))
Hs_stable = np.sort(np.asarray(Hs_stable, float))
# stable equilibria whose H differ by more than 1e-3 count as distinct
distinct = int(Hs_stable.size>0) + int(np.sum(np.diff(Hs_stable)>1e-3))
bistable = (distinct>=2)

# (C) basins at baseline: grid of initial (H, q) values
Hs = np.linspace(0.2, 0.95, 17)
qs = np.linspace(0.0, 1.0, 17)
Z = np.zeros((len(Hs), len(qs)))

def relax(y0,T=320):
    """Integrate the model at d = d_fit from y0 over (0, T); return the last computed state."""
    out_times = np.linspace(0,T,900)
    sol = solve_ivp(lambda t,z: rhs(z,p,d_fit), (0,T), y0, t_eval=out_times,
                    rtol=1e-6, atol=1e-8, max_step=0.5)
    return sol.y[:,-1]

# Relax every (H0, q0) grid point (M0 = 0.2, B0 = 0.1) and store the final H.
for i,H0 in enumerate(Hs):
    for j,q0 in enumerate(qs):
        Z[i,j] = relax(np.array([0.2,H0,0.1,q0],float))[1]

plt.figure(figsize=(6.6,5.2))
plt.imshow(
    Z,
    origin="lower",
    extent=[qs[0],qs[-1],Hs[0],Hs[-1]],
    aspect="auto",
    vmin=0.5,
    vmax=1.0,
    cmap="viridis",
)
plt.colorbar(label="Final H (steady)")
plt.xlabel("initial q")
plt.ylabel("initial H")
plt.title(f"Basins at baseline d={d_fit:.3f}")
plt.tight_layout()
plt.savefig(os.path.join(OUT,"basins_heatmap.png"), dpi=180)
plt.close()

# Short text report of the baseline bistability check.
verdict = "YES" if bistable else "NO"
report = (
    f"Baseline d = {d_fit:.5f}\n"
    f"Distinct stable equilibria at baseline: {distinct}\n"
    f"Bistable at baseline? {verdict}\n"
)
with open(os.path.join(OUT,"diagnosis.txt"),"w") as f:
    f.write(report)

print("Saved ->", OUT, "| Bistable at baseline? ", verdict)


In [None]:
# calibrate_hysteresis_from_scored_hill_satpf_strict.py
# Hill host benefit (n=3) + saturable uptake + smooth memory + weak PF
# Tighter, literature-plausible bounds/priors to avoid monostable corners.

import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares

# ----------------- Paths & config -----------------
INPATH = "timeseries/combined_scfas_table_scored.csv"
OUTDIR = "mw_fit_out_hill_satpf_strict"
os.makedirs(OUTDIR, exist_ok=True)

# Candidate H-proxy columns, in order of preference; the first one present
# in the CSV is used.
H_COLS = ["H_proxy_meta_smooth","H_proxy_meta"]
SCFA_COLS = ["butyrate"]

# MIN_ROWS: subjects with fewer rows are skipped.
# PENALTY:  finite residual value used for rejected parameter sets and
#           failed simulations, so the residual vector keeps a fixed length.
MIN_ROWS  = 4
KQ        = 80.0            # sharper memory
PENALTY   = 1e3
HILL_N    = 3               # stronger curvature in H benefit

# ---- Priors (center, sd) ----
# residuals() adds one pseudo-residual (value - center) / sd per entry;
# the centers also seed the initial guess x0g below.
PRIOR = {
    "r0":    (0.35, 0.08),
    "rH":    (0.08, 0.04),
    "d":     (0.12, 0.05),
    "g":     (0.60, 0.25),
    "u":     (0.60, 0.10),   # tighter around 0.6 to avoid u→upper bound
    "K_u":   (0.20, 0.06),
    "K_B":   (0.20, 0.06),
    "c":     (0.12, 0.04),
    "p_low": (0.12, 0.05),
    "p_high":(2.30, 0.50),
    "H_on":  (0.62, 0.08),
    "H_off": (0.86, 0.04),
    "tau_q": (5.00, 2.00),
}

# ---- Bounds (tighter; keep away from the corners you hit) ----
# Entries are positional, in this order:
# p = [r0,rH,K_M,c,d,g,u,K_u,p_low,p_high,H_on,H_off,tau_q,K_B]
LBg = np.array([0.25, 0.02, 0.7, 0.08, 0.06, 0.20, 0.45, 0.12, 0.06, 1.6, 0.55, 0.80, 2.0, 0.10])
UBg = np.array([0.45, 0.15, 1.5, 0.18, 0.22, 1.20, 0.75, 0.30, 0.25, 3.0, 0.75, 0.92, 10.0, 0.35])

# Initial guess: prior centers in the positional order above; K_M has no
# prior entry and starts at 1.0.
x0g = np.array([
    PRIOR["r0"][0], PRIOR["rH"][0], 1.0, PRIOR["c"][0], PRIOR["d"][0], PRIOR["g"][0],
    PRIOR["u"][0], PRIOR["K_u"][0], PRIOR["p_low"][0], PRIOR["p_high"][0],
    PRIOR["H_on"][0], PRIOR["H_off"][0], PRIOR["tau_q"][0], PRIOR["K_B"][0]
], float)

# ----------------- Load & prep data -----------------
df = pd.read_csv(INPATH)
if not {"subject_id","sample_id"}.issubset(df.columns):
    raise ValueError("CSV must contain 'subject_id' and 'sample_id'.")
# Pick the first available H-proxy column.
Hcol = next((c for c in H_COLS if c in df.columns), None)
if Hcol is None:
    raise ValueError("Need H_proxy_meta_smooth or H_proxy_meta in the CSV.")
for c in SCFA_COLS:
    if c not in df.columns:
        raise ValueError(f"Missing SCFA column '{c}'.")

keep = ["subject_id","sample_id",Hcol] + SCFA_COLS
df = df[keep].dropna(subset=["subject_id","sample_id"]).copy()
# Time axis is the within-subject row counter (0, 1, 2, ...) in file order.
# NOTE(review): this assumes rows are already chronological per subject — confirm.
df["t_idx"] = df.groupby("subject_id").cumcount().astype(float)

def robust_z(s: pd.Series) -> pd.Series:
    """Return a robustly standardised copy of *s* (same index).

    Values are centred on the median of the finite entries and divided by
    their MAD, falling back to the IQR and then to the standard deviation
    when the preceding scale is (numerically) zero.  A series without any
    finite entry maps to all zeros.
    """
    values = s.astype(float).to_numpy()
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return pd.Series(np.zeros_like(values), index=s.index)

    center = np.median(finite)
    scale = np.median(np.abs(finite - center))
    if scale < 1e-9:
        # Degenerate MAD: fall back to IQR, then to the standard deviation.
        upper, lower = np.percentile(finite, [75, 25])
        spread = upper - lower
        scale = spread if spread > 1e-9 else (np.std(finite) + 1e-9)

    return pd.Series((values - center)/(scale + 1e-9), index=s.index)

# Robust z-score of every SCFA column, computed within each subject.
for c in SCFA_COLS:
    df[c+"_z"] = df.groupby("subject_id")[c].transform(robust_z)

# B observable: the single z-scored column, or the row-wise mean of several.
z_cols = [c+"_z" for c in SCFA_COLS]
if len(SCFA_COLS)==1:
    df["B_obs"] = df[z_cols[0]]
else:
    df["B_obs"] = df[z_cols].mean(axis=1)
df["H_obs"] = df[Hcol].clip(0,1)

def first(a,d):
    """Return the first finite entry of *a* as a float, or *d* when there is none."""
    a = np.asarray(a,float)
    finite_idx = np.flatnonzero(np.isfinite(a))
    return float(a[finite_idx[0]]) if finite_idx.size else float(d)

# One record per usable subject: time grid, observations, finite-value masks
# (fixed once, so the residual length never changes) and initial conditions.
subs = []
for sid, sub in df.groupby("subject_id"):
    sub = sub.sort_values("t_idx").copy()
    if len(sub) < MIN_ROWS:
        continue
    t = sub["t_idx"].values.astype(float)
    B = sub["B_obs"].values.astype(float)
    H = sub["H_obs"].values.astype(float)
    mB, mH = np.isfinite(B), np.isfinite(H)
    nB, nH = int(mB.sum()), int(mH.sum())
    # need at least three finite observations of each signal
    if min(nB, nH) < 3:
        continue
    subs.append(dict(
        sid=sid, t=t, B=B, H=H,
        maskB=mB, maskH=mH,
        nB=nB, nH=nH,
        H0=float(np.clip(first(H,0.6),0,1)),
        B0=float(max(0.05, first(B,0.1))),
    ))
if not subs:
    raise RuntimeError("No subject passed minimal data filters.")
print(f"[info] Fitting {len(subs)} subjects...")

# ----------------- Model -----------------
# state y=[M,H,B,q]; parameters p=[r0,rH,K_M,c,d,g,u,K_u,p_low,p_high,H_on,H_off,tau_q,K_B]
NAMES = ["r0","rH","K_M","c","d","g","u","K_u","p_low","p_high","H_on","H_off","tau_q","K_B"]

def q_inf(H,q,H_on,H_off,k=KQ):
    """Sigmoid target for q: 1/(1 + exp(-k*(H - theta))).

    The threshold theta moves linearly with q, from H_on at q=0 to H_off at q=1.
    """
    theta = (1.0-q)*H_on + q*H_off
    z = k*(H - theta)
    return 1.0/(1.0 + np.exp(-z))

def dH_hill(B,H,gH,d,K_B,n=HILL_N):
    """Rate of change of H: gH * B^n/(K_B^n + B^n) * (1 - H) - d*H."""
    saturation = B**n/(K_B**n + B**n)
    gain = gH*saturation*(1 - H)
    return gain - d*H

def rhs(t,y,p):
    """ODE right-hand side for the state y = [M,H,B,q].

    p is unpacked positionally as
    [r0,rH,K_M,c,d,g,u,K_u,p_low,p_high,H_on,H_off,tau_q,K_B]; t is unused.
    Returns the list [dM, dH, dB, dq].
    """
    M,H,B,q = y
    r0,rH,K_M,c,d,gH,u,K_u,pL,pH,H_on,H_off,tau,K_B = p
    # production rate interpolates between p_low and p_high with q clipped to [0,1]
    production = pL + (pH - pL)*np.clip(q,0,1)
    # saturable uptake; the 1e-9 keeps the denominator away from zero
    uptake = u*H*B/(K_u + B + 1e-9)
    return [
        (r0 + rH*H - c*production)*M*(1 - M/K_M),
        dH_hill(B,H,gH,d,K_B),
        production*M - uptake,
        (q_inf(H,q,H_on,H_off) - q)/tau,
    ]

def simulate(ts,y0,p):
    """Integrate rhs from y0 and return the states at the times *ts*.

    Returns the solver's (4, len(ts)) array on success.  If the solver
    reports failure or raises, a (4, len(ts)) array of NaN is returned
    instead so callers can detect the failure without handling exceptions.
    """
    def nan_block():
        return np.vstack([np.full(len(ts),np.nan)]*4)

    try:
        sol = solve_ivp(lambda t,z: rhs(t,z,p), (ts[0],ts[-1]), y0, t_eval=ts,
                        rtol=1e-6, atol=1e-8, max_step=0.5)
    except Exception:
        return nan_block()
    return sol.y if sol.success else nan_block()

# Modest per-subject observation maps (narrow bounds so they cannot absorb
# the dynamics): one (alpha_B, beta0_H, beta1_H) triple per subject.
n_subjects = len(subs)
x0s = [1.0, 0.0, 1.0]*n_subjects
LBs = [0.6, -0.2, 0.8]*n_subjects
UBs = [1.6,  0.2, 1.2]*n_subjects

x0 = np.concatenate([x0g, np.asarray(x0s,float)])
LB = np.concatenate([LBg, np.asarray(LBs,float)])
UB = np.concatenate([UBg, np.asarray(UBs,float)])

def unpack(x):
    """Split *x* into (global parameters, list of per-subject triples)."""
    n_global = len(NAMES)
    return x[:n_global], np.split(x[n_global:], len(subs))

# weights — give H more influence than B
W_B = 0.5
W_H = 1.3

# fixed residual length: all finite B and H observations + priors
TOTLEN = sum(S["nB"] for S in subs) + sum(S["nH"] for S in subs) + len(PRIOR)

def residuals(x):
    """Weighted residual vector (always TOTLEN long) for the stacked vector *x*.

    Layout: for each subject, the W_B-weighted B residuals followed by the
    W_H-weighted H residuals (finite observations only), then one
    (value - center)/sd pseudo-residual per PRIOR entry.  Rejected parameter
    sets and failed simulations contribute PENALTY instead of a model misfit.
    """
    gpar, triples = unpack(x)
    K_u, H_on, H_off, K_B = gpar[7], gpar[10], gpar[11], gpar[13]

    # guards: H_off > H_on, and both half-saturation constants above 0.08
    if not (H_off > H_on) or K_B <= 0.08 or K_u <= 0.08:
        return np.full(TOTLEN, PENALTY)

    q_switch = 0.5*(H_on + H_off)
    chunks = []
    for S, (aB, b0H, b1H) in zip(subs, triples):
        H0 = np.clip(S["H0"],0,1)
        B0 = max(0.05,S["B0"])
        q0 = 1.0 if H0 < q_switch else 0.0
        _, Hm, Bm, _ = simulate(S["t"], [0.12, H0, B0, q0], gpar)

        if not (np.all(np.isfinite(Hm)) and np.all(np.isfinite(Bm))):
            # failed simulation: penalise every observation of this subject
            chunks.append(W_B*np.full(S["nB"],PENALTY))
            chunks.append(W_H*np.full(S["nH"],PENALTY))
            continue

        mB, mH = S["maskB"], S["maskH"]
        Bhat = aB*Bm
        Hhat = np.clip(b0H + b1H*Hm, 0, 1)
        b = Bhat[mB] - S["B"][mB]
        h = Hhat[mH] - S["H"][mH]
        # any leftover non-finite difference is replaced by the penalty
        b = np.where(np.isfinite(b), b, PENALTY)
        h = np.where(np.isfinite(h), h, PENALTY)

        chunks.append(W_B*b.astype(float))
        chunks.append(W_H*h.astype(float))

    # priors — one scalar residual each
    position = {name:i for i,name in enumerate(NAMES)}
    chunks.append(np.array(
        [(gpar[position[name]]-mu)/(sd + 1e-9) for name,(mu,sd) in PRIOR.items()],
        float))

    return np.concatenate(chunks)

# Joint fit of global parameters and per-subject observation maps.
fit = least_squares(residuals, x0, bounds=(LB,UB),
                    verbose=2, max_nfev=1200, loss="soft_l1", f_scale=1.0)
if not fit.success:
    # Keep going (the best iterate is still written out) but make it visible.
    print(f"[warn] least_squares did not converge: {fit.message}")

gpar_hat, triples_hat = unpack(fit.x)
pd.Series(gpar_hat, index=NAMES).to_csv(os.path.join(OUTDIR,"fitted_global_params.csv"), header=False)
pd.DataFrame(
    [{"subject_id":S["sid"], "alpha_B":tr[0], "beta0_H":tr[1], "beta1_H":tr[2]}
     for S,tr in zip(subs, triples_hat)]
).to_csv(os.path.join(OUTDIR,"fitted_subject_scales.csv"), index=False)

print("[info] Fitted globals:", dict(zip(NAMES, gpar_hat)))

# quick diag for first few
for S,tr in list(zip(subs, triples_hat))[:8]:
    aB,b0,b1=tr
    ts=S["t"]; H0=np.clip(S["H0"],0,1); B0=max(0.05,S["B0"])
    M0=0.12; q0=1.0 if H0 < 0.5*(gpar_hat[10]+gpar_hat[11]) else 0.0
    y0=[M0,H0,B0,q0]
    # Use simulate() rather than a raw solve_ivp call: on solver failure it
    # returns a full-size NaN array, whereas a raw failed solution is shorter
    # than ts and would break the boolean-mask indexing below.
    Y=simulate(ts, y0, gpar_hat)
    if not np.all(np.isfinite(Y)):
        print(f"[warn] diagnostic simulation failed for subject {S['sid']}; skipping plot.")
        continue
    M,H,B,q=Y
    Bhat=aB*B; Hhat=np.clip(b0 + b1*H,0,1)
    mB,mH=S["maskB"],S["maskH"]
    fig,ax=plt.subplots(2,1,figsize=(7,6),sharex=True)
    ax[0].plot(ts[mB],Bhat[mB]); ax[0].scatter(ts[mB],S["B"][mB],s=18,c="k")
    ax[0].grid(True,ls=":")
    ax[1].plot(ts[mH],Hhat[mH]); ax[1].scatter(ts[mH],S["H"][mH],s=18,c="k")
    # dotted line: fitted H_on, dashed line: fitted H_off
    ax[1].axhline(gpar_hat[10],ls=":",c="gray"); ax[1].axhline(gpar_hat[11],ls="--",c="gray")
    ax[1].grid(True,ls=":"); ax[1].set_xlabel("time")
    plt.tight_layout(); plt.savefig(os.path.join(OUTDIR,f"diag_{S['sid']}.png"),dpi=170); plt.close(fig)

print("✅ Done. Outputs in:", OUTDIR)


In [None]:
# bifurcation_and_basins_hill_sat_strict.py
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import root

# Input: fitted global parameters (name,value rows, no header). Output directory for this analysis.
FIT = "mw_fit_out_hill_satpf_strict/fitted_global_params.csv"
OUT = "mw_bif_hill_sat_strict"; os.makedirs(OUT, exist_ok=True)

# Steepness of the logistic memory switch used by q_inf below.
KQ=80.0
# The first CSV column becomes the index; squeeze turns the single value column into a Series.
g = pd.read_csv(FIT, index_col=0, header=None).squeeze("columns")
# p=[r0,rH,K_M,c,d,g,u,K_u,p_low,p_high,H_on,H_off,tau_q,K_B]
# NOTE(review): values are taken in file order and unpacked positionally later,
# so the CSV row order must match the layout above — confirm against the fit script.
p = np.array([float(g[k]) for k in g.index.values], float)
# Baseline d is read positionally (index 4); HILL_N is the default Hill exponent of dH.
d_fit = float(p[4]); HILL_N=3

def q_inf(H,q,H_on,H_off):
    """Logistic target value for q.

    The switching threshold is the q-weighted blend of H_on (weight 1-q) and
    H_off (weight q); steepness comes from the module constant KQ.
    """
    threshold = (1-q)*H_on + q*H_off
    exponent = -KQ*(H - threshold)
    return 1.0/(1.0 + np.exp(exponent))
def dH(B,H,gH,d,K_B,n=HILL_N):
    """Rate of change of H: Hill-type gain in B acting on (1 - H), minus decay d*H."""
    hill = B**n/(K_B**n + B**n)
    gain = gH*hill*(1-H)
    return gain - d*H
def rhs(y,pvec,dval=None):
    """Vector field of the 4-state model y = [M, H, B, q].

    pvec follows the 14-entry layout documented next to `p`; when dval is
    given it replaces the decay rate d stored in pvec.
    Returns a float array [dM, dH, dB, dq].
    """
    M,H,B,q = y
    (r0,rH,K_M,c,d_stored,gH,u,K_u,
     pL,pH,H_on,H_off,tau,K_B) = pvec.copy()
    d = d_stored if dval is None else dval

    # Production rate interpolates between pL and pH with q clamped to [0, 1].
    production = pL + (pH-pL)*np.clip(q,0,1)
    growth = r0 + rH*H
    # Saturating uptake; the 1e-9 keeps the denominator away from zero.
    uptake = u*H*B/(K_u + B + 1e-9)

    dM = (growth - c*production)*M*(1 - M/K_M)
    dH_dt = dH(B,H,gH,d,K_B)
    dB = production*M - uptake
    dq = (q_inf(H,q,H_on,H_off)-q)/tau
    return np.array([dM, dH_dt, dB, dq], float)

def jac_fd(fun,y,pvec,dval,eps=1e-7):
    """Forward-difference Jacobian of ``fun(y, pvec, dval)`` with respect to y.

    The matrix size is derived from the inputs (rows = len(fun(...)),
    columns = len(y)) instead of being fixed, and y is coerced to a float
    array so lists and integer arrays are perturbed correctly.

    Parameters: fun - callable(y, pvec, dval) -> 1-D array-like; y - state;
    pvec, dval - passed through to fun unchanged; eps - step size.
    Returns: ndarray J with J[:, i] = (fun(y + eps*e_i) - fun(y)) / eps.
    """
    y = np.asarray(y, dtype=float)
    f0 = np.asarray(fun(y,pvec,dval), dtype=float)
    J = np.zeros((f0.size, y.size))
    for i in range(y.size):
        y_step = y.copy()
        y_step[i] += eps
        J[:,i] = (np.asarray(fun(y_step,pvec,dval), dtype=float) - f0)/eps
    return J

def find_eq(pvec,dval,guess,tol=1e-8):
    """Solve rhs(y, pvec, dval) = 0 from `guess` and return (y, True) or (None, False).

    A root is accepted only if it is finite and lies inside the feasible box
    (M >= 0, B >= 0, 0 <= H <= 1.2, 0 <= q <= 1.2) up to `tol`. Roots outside
    the box are rejected rather than clipped: a clipped point is in general
    no longer an equilibrium, so its Jacobian would be meaningless.
    Finiteness is tested before clipping because max(0, nan) evaluates to 0.
    """
    sol=root(lambda yy: rhs(yy,pvec,dval), guess, method="hybr")
    if not sol.success:
        return None, False
    raw=np.asarray(sol.x,float)
    if not np.all(np.isfinite(raw)):
        return None, False
    y=np.array([max(0,raw[0]), np.clip(raw[1],0,1.2), max(0,raw[2]), np.clip(raw[3],0,1.2)], float)
    # Only round-off-sized excursions outside the box are projected back.
    if np.max(np.abs(y-raw))>tol:
        return None, False
    return y, True

# continuation in d
# NOTE: every d value is solved from the same fixed seeds (no warm start from
# the neighbouring d), so a seed records whichever root the solver reaches.
d_vals=np.linspace(0.7*d_fit, 1.6*d_fit, 110)
# Initial guesses in state order [M, H, B, q].
seeds=[np.array([0.2,0.25,0.05,1.0]),
       np.array([0.2,0.90,0.10,0.0]),
       np.array([0.6,0.60,0.20,0.5]),
       np.array([0.1,0.80,0.05,0.8])]
rows=[]
for d in d_vals:
    for wi,y0 in enumerate(seeds):
        y,ok=find_eq(p,d,y0)
        if not ok: continue
        # Stable when every Jacobian eigenvalue has a negative real part.
        J=jac_fd(rhs,y,p,d)
        stable=(np.max(np.real(np.linalg.eigvals(J)))<0)
        rows.append({"d":d,"H":float(y[1]),"q":float(y[3]),"seed":wi,"stable":stable})
branches=pd.DataFrame(rows)
branches.to_csv(os.path.join(OUT,"branches.csv"), index=False)

# Bifurcation diagram: small dots coloured by seed, overlaid with stability markers.
plt.figure(figsize=(7.2,5.0))
for wi in sorted(branches["seed"].unique()):
    sub=branches[branches["seed"]==wi]
    plt.plot(sub["d"], sub["H"], ".", ms=3, alpha=0.7, label=f"seed{wi}")
for st, mk in [(True,"o"),(False,"x")]:
    sub=branches[branches["stable"]==st]
    plt.scatter(sub["d"], sub["H"], s=22, marker=mk, alpha=0.6, label=("stable" if st else "unstable"))
plt.axvline(d_fit, ls="--", c="gray", label="baseline d")
plt.xlabel("d (1/h)"); plt.ylabel("H*"); plt.legend(); plt.grid(True, ls=":")
plt.tight_layout(); plt.savefig(os.path.join(OUT,"bifurcation_H_vs_d.png"), dpi=180); plt.close()

# Bistability check at the baseline decay rate.
# Only rows from the single grid value of d closest to d_fit are used. Pooling
# every grid point within a tolerance mixes equilibria computed at different d,
# and because H* drifts with d one branch could then be counted twice.
distinct=0
near=branches
if not branches.empty:
    d_grid=branches["d"].to_numpy(float)
    d_near=d_grid[np.argmin(np.abs(d_grid-d_fit))]
    # Keep the original closeness requirement so a distant d is never used.
    if np.isclose(d_near, d_fit, atol=1e-3):
        near=branches[branches["d"]==d_near]
        Hs=np.sort(near.loc[near["stable"].astype(bool),"H"].to_numpy(float))
        if Hs.size:
            # Sorted stable H values further apart than 1e-3 are distinct equilibria.
            distinct=1+int(np.count_nonzero(np.diff(Hs)>1e-3))
    else:
        near=branches.iloc[0:0]
bistable=(distinct>=2)

# basins at baseline
# Grid of initial (H, q) pairs; Z[i, j] stores the final H reached from (Hs[i], qs[j]).
Hs=np.linspace(0.25,0.95,19)
qs=np.linspace(0.0,1.0,21)
Z=np.zeros((len(Hs),len(qs)))
def relax(y0,T=360):
    """Integrate the model from y0 at d = d_fit over [0, T] and return the final state."""
    # NOTE(review): sol.success is not checked; a failed run returns the last state reached.
    sol=solve_ivp(lambda t,z: rhs(z,p,d_fit),(0,T),y0,t_eval=np.linspace(0,T,900),
                  rtol=1e-6,atol=1e-8,max_step=0.5)
    return sol.y[:,-1]
for i,H0 in enumerate(Hs):
    for j,q0 in enumerate(qs):
        # M and B start at fixed values; only H and q are scanned.
        y0=np.array([0.2,H0,0.1,q0],float)
        yss=relax(y0)
        Z[i,j]=yss[1]
plt.figure(figsize=(6.6,5.2))
# Colour scale is limited to [0.5, 1.0]; final H below 0.5 saturates at the lowest colour.
plt.imshow(Z, origin="lower", extent=[qs[0],qs[-1],Hs[0],Hs[-1]],
           aspect="auto", vmin=0.5, vmax=1.0, cmap="viridis")
plt.colorbar(label="Final H (steady)")
plt.xlabel("initial q"); plt.ylabel("initial H")
plt.title(f"Basins at baseline d={d_fit:.3f}")
plt.tight_layout(); plt.savefig(os.path.join(OUT,"basins_heatmap.png"), dpi=180); plt.close()

# Plain-text summary of the baseline bistability check.
with open(os.path.join(OUT,"diagnosis.txt"),"w") as f:
    f.write(f"Baseline d = {d_fit:.5f}\n")
    f.write(f"Distinct stable equilibria at baseline: {distinct}\n")
    f.write(f"Bistable at baseline? {'YES' if bistable else 'NO'}\n")

print("Saved ->", OUT, "| Bistable at baseline? ", "YES" if bistable else "NO")


### two-guild (P,C) + Hill host + saturable uptake + smooth memory

In [None]:
# calibrate_guild_hill_satpf.py
# -------------------------------------------------------------
# Minimal two-guild ecology (Producer P makes butyrate; Competitor C)
# + Hill host benefit (n=2)
# + Saturable uptake of butyrate
# + Smooth "memory" q with H_on/H_off band
#
# Observables: stool butyrate intensity (per-subject z) and a host proxy H.
# We fit global params + per-subject linear observation maps for B and H.
# -------------------------------------------------------------

import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares

# ---------- Config ----------
INPATH = "timeseries/combined_scfas_table_scored.csv"
OUTDIR = "mw_fit_out_guild_hill_satpf"
os.makedirs(OUTDIR, exist_ok=True)

# Candidate host-proxy columns, in order of preference (first one present is used).
H_COLS    = ["H_proxy_meta_smooth","H_proxy_meta"]
SCFA_COLS = ["butyrate"]

MIN_ROWS  = 4            # minimum rows a subject needs to be fitted
KQ        = 60.0         # memory steepness
PENALTY   = 1e3          # residual value substituted for failed/non-finite simulations
HILL_N    = 2            # Hill exponent in host benefit

# ---------- Priors (mu, sd) ----------
# Each entry contributes one pseudo-residual (value - mu) / sd to the fit.
# p = [r0P, rHP, r0C, K_M, gamma, c, d, g, u, K_u, p_low, p_high, H_on, H_off, tau_q, K_B]
PRIOR = {
    "r0P":   (0.32, 0.08),
    "rHP":   (0.07, 0.04),    # weak positive feedback P growth vs H
    "r0C":   (0.28, 0.08),
    "K_M":   (1.00, 0.25),
    "gamma": (0.85, 0.25),    # cross-competition factor (P↔C)
    "c":     (0.12, 0.05),    # cost of production
    "d":     (0.12, 0.05),
    "g":     (0.60, 0.30),
    "u":     (0.60, 0.15),
    "K_u":   (0.20, 0.08),
    "p_low": (0.12, 0.06),
    "p_high":(2.20, 0.60),
    "H_on":  (0.60, 0.08),
    "H_off": (0.86, 0.04),
    "tau_q": (5.00, 2.00),
    "K_B":   (0.20, 0.08),
}

# ---------- Bounds (tight but realistic) ----------
# One lower/upper limit per global parameter, in the same order as PRIOR.
LBg = np.array([0.18, 0.00, 0.15, 0.55, 0.40, 0.06, 0.06, 0.20, 0.45, 0.10, 0.06, 1.3, 0.50, 0.80, 1.0, 0.10])
UBg = np.array([0.46, 0.14, 0.40, 1.60, 1.40, 0.20, 0.22, 1.40, 0.85, 0.40, 0.28, 3.2, 0.74, 0.92, 10., 0.40])

# Initial guess for the globals: the prior mean of every parameter, in vector order.
_X0G_ORDER = (
    "r0P", "rHP", "r0C", "K_M", "gamma",
    "c", "d", "g", "u", "K_u",
    "p_low", "p_high", "H_on", "H_off",
    "tau_q", "K_B",
)
x0g = np.array([PRIOR[name][0] for name in _X0G_ORDER], float)

# ---------- Load & prep data ----------
df = pd.read_csv(INPATH)
if not {"subject_id","sample_id"}.issubset(df.columns):
    raise ValueError("CSV must have subject_id, sample_id")
# Use the first preferred host-proxy column that exists in the file.
Hcol = next((c for c in H_COLS if c in df.columns), None)
if Hcol is None:
    raise ValueError("Need H_proxy_meta_smooth or H_proxy_meta")

for c in SCFA_COLS:
    if c not in df.columns:
        raise ValueError(f"Missing SCFA column {c}")

# Keep only the needed columns and drop rows lacking an id.
df = df[["subject_id","sample_id",Hcol]+SCFA_COLS].dropna(subset=["subject_id","sample_id"]).copy()
# Time is the 0-based row position within each subject (file order), stored as float.
df["t_idx"] = df.groupby("subject_id").cumcount().astype(float)

def robust_z(s: pd.Series) -> pd.Series:
    """Median-centred, robustly scaled copy of s (same index).

    Scale is the MAD of the finite values; if that is below 1e-9 the IQR is
    used, and if the IQR is also below 1e-9 the standard deviation (+1e-9).
    Non-finite inputs stay non-finite in the output.
    NOTE(review): a series with no finite value returns all zeros, which
    downstream code cannot tell apart from observed values.
    """
    values = s.astype(float).to_numpy()
    finite = np.isfinite(values)
    if not finite.any():
        return pd.Series(np.zeros_like(values), index=s.index)

    good = values[finite]
    center = np.median(good)
    mad = np.median(np.abs(good - center))
    if mad >= 1e-9:
        scale = mad
    else:
        q75, q25 = np.percentile(good, [75, 25])
        iqr = q75 - q25
        scale = iqr if iqr > 1e-9 else (np.std(good) + 1e-9)
    return pd.Series((values - center)/(scale + 1e-9), index=s.index)

# Per-subject robust z-score of every SCFA column.
for c in SCFA_COLS:
    df[c+"_z"] = df.groupby("subject_id")[c].transform(robust_z)
# B observable: the single z-scored column, or the row-wise mean when several are configured.
if len(SCFA_COLS)==1:
    df["B_obs"] = df[SCFA_COLS[0]+"_z"]
else:
    df["B_obs"] = df[[name+"_z" for name in SCFA_COLS]].mean(axis=1)
# H observable is the chosen proxy clipped to [0, 1].
df["H_obs"] = df[Hcol].clip(0,1)

def first(a,d):
    """Return the first finite entry of a as a float, or d when there is none."""
    a=np.asarray(a,float)
    finite_idx=np.flatnonzero(np.isfinite(a))
    return float(a[finite_idx[0]]) if finite_idx.size else float(d)

# One record per subject that has enough rows and enough finite observations.
subs=[]
for sid, sub in df.groupby("subject_id"):
    if len(sub)<MIN_ROWS:
        continue
    sub=sub.sort_values("t_idx").copy()
    t=sub["t_idx"].to_numpy(dtype=float)
    B=sub["B_obs"].to_numpy(dtype=float)
    H=sub["H_obs"].to_numpy(dtype=float)
    mB=np.isfinite(B)
    mH=np.isfinite(H)
    if mB.sum()<3 or mH.sum()<3:
        continue
    record={"sid":sid, "t":t, "B":B, "H":H, "maskB":mB, "maskH":mH}
    record["nB"]=int(mB.sum())
    record["nH"]=int(mH.sum())
    # Initial conditions come from the first finite observation (with fallbacks).
    record["H0"]=float(np.clip(first(H,0.6),0,1))
    record["B0"]=float(max(0.05, first(B,0.1)))
    subs.append(record)
if not subs:
    raise RuntimeError("No subject passed minimal filters.")
print(f"[info] Fitting {len(subs)} subjects...")

# ---------- Model ----------
# State y = [P, C, H, B, q]
# p = [r0P, rHP, r0C, K_M, gamma, c, d, g, u, K_u, p_low, p_high, H_on, H_off, tau_q, K_B]
# NAMES fixes the position of every global parameter in the optimizer vector;
# code below indexes globals positionally (e.g. 12 = H_on, 13 = H_off).
NAMES = ["r0P","rHP","r0C","K_M","gamma","c","d","g","u","K_u","p_low","p_high","H_on","H_off","tau_q","K_B"]

def q_inf(H,q,H_on,H_off,k=KQ):
    """Logistic target for q with steepness k.

    The threshold moves linearly from H_on (q = 0) to H_off (q = 1).
    """
    threshold = (1.0-q)*H_on + q*H_off
    exponent = -k*(H - threshold)
    return 1.0/(1.0 + np.exp(exponent))

def dH_hill(B,H,gH,d,K_B,n=HILL_N):
    """dH/dt: Hill response to B (exponent n, half-saturation K_B) scaled by gH*(1 - H), minus d*H."""
    hill = B**n/(K_B**n + B**n)
    gain = gH*hill*(1 - H)
    return gain - d*H

def rhs(t,y,p):
    """Vector field of the two-guild model; returns [dP, dC, dH, dB, dq].

    y = [P, C, H, B, q]; p holds the 16 global parameters in NAMES order.
    t is unused (autonomous system) but required by solve_ivp.
    """
    P,C,H,B,q = y
    (r0P,rHP,r0C,K_M,gamma,c,d,gH,
     u,K_u,pL,pH,H_on,H_off,tau,K_B) = p

    # Per-capita production interpolates between pL and pH with q clamped to [0, 1].
    production = pL + (pH - pL)*np.clip(q,0,1)

    # Competition with a shared K_M and cross-coefficient gamma; only P pays
    # the production cost and receives the H-dependent growth bonus.
    growth_P = r0P + rHP*H - c*production - (P + gamma*C)/K_M
    growth_C = r0C - (C + gamma*P)/K_M

    # Saturating uptake of B, proportional to H.
    consumed = u*H*B/(K_u + B + 1e-9)

    return [
        P*growth_P,
        C*growth_C,
        dH_hill(B,H,gH,d,K_B),
        production*P - consumed,
        (q_inf(H,q,H_on,H_off) - q)/tau,
    ]

def simulate(ts,y0,p):
    """Integrate rhs from y0 and sample it at ts.

    Returns the (5, len(ts)) solution array, or an all-NaN array of that
    shape when the solver reports failure or anything raises (best effort:
    the caller converts NaNs into penalty residuals).
    """
    def _nan_block():
        return np.full((5, len(ts)), np.nan)

    try:
        sol = solve_ivp(
            lambda t, z: rhs(t, z, p),
            (ts[0], ts[-1]),
            y0,
            t_eval=ts,
            rtol=1e-6,
            atol=1e-8,
            max_step=0.5,
        )
    except Exception:
        return _nan_block()
    return sol.y if sol.success else _nan_block()

# Per-subject observation maps (bounded so the dynamics must explain the signal):
#   B_hat = alpha_B * B
#   H_hat = beta0_H + beta1_H * H
# Every subject contributes the triple (alpha_B, beta0_H, beta1_H) in that order.
n_subjects = len(subs)
x0s = [1.0, 0.0, 1.0]*n_subjects
LBs = [0.6, -0.2, 0.8]*n_subjects
UBs = [1.6,  0.2, 1.2]*n_subjects

# Full optimizer vectors: the globals first, then the per-subject triples.
x0, LB, UB = (
    np.concatenate([head, np.array(tail, float)])
    for head, tail in ((x0g, x0s), (LBg, LBs), (UBg, UBs))
)

def unpack(x):
    """Split x into (global parameters, list of per-subject arrays).

    The first len(NAMES) entries are the globals; the remainder is cut into
    len(subs) equal pieces.
    """
    n_global = len(NAMES)
    gpar, per_subject = x[:n_global], x[n_global:]
    return gpar, np.split(per_subject, len(subs))

# Residual weights for the B and H observables.
W_B, W_H = 0.6, 1.2
# Fixed residual-vector length: all finite B and H observations plus one term per prior.
TOTLEN = sum(S["nB"] + S["nH"] for S in subs) + len(PRIOR)

def residuals(x):
    """Residual vector for least_squares; always TOTLEN entries long.

    Layout: for every subject the weighted B residuals followed by the
    weighted H residuals (finite observations only), then one pseudo-residual
    per PRIOR entry. Infeasible parameters give a constant PENALTY vector.
    """
    gpar, triples = unpack(x)
    H_on, H_off = gpar[12], gpar[13]
    K_u, K_B = gpar[9], gpar[15]
    # Require H_off > H_on and both half-saturation constants above 0.06.
    infeasible = (not (H_off > H_on)) or K_u <= 0.06 or K_B <= 0.06
    if infeasible:
        return np.full(TOTLEN, PENALTY)

    midpoint = 0.5*(H_on + H_off)
    pieces = []
    for S, (aB, b0H, b1H) in zip(subs, triples):
        maskB, maskH = S["maskB"], S["maskH"]
        H0 = np.clip(S["H0"], 0, 1)
        B0 = max(0.05, S["B0"])
        # Modest initial producer/competitor densities; q starts "on" below the band midpoint.
        q0 = 1.0 if H0 < midpoint else 0.0
        Y = simulate(S["t"], [0.12, 0.12, H0, B0, q0], gpar)
        if not np.all(np.isfinite(Y)):
            pieces.append(W_B*np.full(S["nB"], PENALTY))
            pieces.append(W_H*np.full(S["nH"], PENALTY))
            continue
        _, _, H_sim, B_sim, _ = Y

        # Map the model states through the subject's observation scales.
        B_pred = aB*B_sim
        H_pred = np.clip(b0H + b1H*H_sim, 0, 1)
        errB = B_pred[maskB] - S["B"][maskB]
        errH = H_pred[maskH] - S["H"][maskH]
        errB = np.where(np.isfinite(errB), errB, PENALTY)
        errH = np.where(np.isfinite(errH), errH, PENALTY)
        pieces.extend([W_B*errB.astype(float), W_H*errH.astype(float)])

    # Priors enter as standardized pseudo-residuals.
    position = {name: i for i, name in enumerate(NAMES)}
    prior_terms = [(gpar[position[name]] - mu)/(sd + 1e-9)
                   for name, (mu, sd) in PRIOR.items()]
    pieces.append(np.array(prior_terms, float))

    return np.concatenate(pieces)

# Bounded robust least-squares fit (soft-L1 loss) starting from x0.
fit = least_squares(residuals, x0, bounds=(LB,UB),
                    verbose=2, max_nfev=1200, loss="soft_l1", f_scale=1.0)

gpar_hat, triples_hat = unpack(fit.x)
# Globals are written as name,value rows without a header (the format the bifurcation scripts read).
pd.Series(gpar_hat, index=NAMES).to_csv(os.path.join(OUTDIR,"fitted_global_params.csv"), header=False)
# One row per subject with its fitted observation-map coefficients.
pd.DataFrame(
    [{"subject_id":S["sid"], "alpha_B":tr[0], "beta0_H":tr[1], "beta1_H":tr[2]}
     for S,tr in zip(subs, triples_hat)]
).to_csv(os.path.join(OUTDIR,"fitted_subject_scales.csv"), index=False)

print("[info] Fitted globals:", dict(zip(NAMES, gpar_hat)))

# Quick diagnostics for (at most) the first eight subjects: fitted trajectories
# mapped through the per-subject observation scales, overlaid on the data.
for S,tr in zip(subs[:8], triples_hat[:8]):
    aB,b0,b1=tr
    ts=S["t"]; H0=np.clip(S["H0"],0,1); B0=max(0.05,S["B0"])
    # Same initial-condition rule as in residuals().
    P0=C0=0.12; q0=1.0 if H0 < 0.5*(gpar_hat[12]+gpar_hat[13]) else 0.0
    y0=[P0,C0,H0,B0,q0]
    # simulate() returns NaNs on solver failure; indexing a truncated solve_ivp
    # result with the boolean masks would raise, so such subjects are skipped.
    Y=simulate(ts,y0,gpar_hat)
    if not np.all(np.isfinite(Y)):
        print(f"[warn] diagnostic simulation failed for subject {S['sid']}; plot skipped.")
        continue
    P,C,H,B,q=Y
    Bhat=aB*B; Hhat=np.clip(b0 + b1*H, 0,1)
    mB,mH=S["maskB"],S["maskH"]
    fig,ax=plt.subplots(2,1,figsize=(7,6),sharex=True)
    # Top: butyrate observable. Bottom: host proxy with H_on (dotted) and H_off (dashed).
    ax[0].plot(ts[mB],Bhat[mB]); ax[0].scatter(ts[mB],S["B"][mB],s=18,c="k"); ax[0].grid(True,ls=":")
    ax[1].plot(ts[mH],Hhat[mH]); ax[1].scatter(ts[mH],S["H"][mH],s=18,c="k"); ax[1].grid(True,ls=":")
    ax[1].axhline(gpar_hat[12],ls=":",c="gray"); ax[1].axhline(gpar_hat[13],ls="--",c="gray")
    ax[1].set_xlabel("time"); ax[0].set_ylabel("B"); ax[1].set_ylabel("H")
    fig.tight_layout(); fig.savefig(os.path.join(OUTDIR,f"diag_{S['sid']}.png"),dpi=170); plt.close(fig)

print("✅ Done. Outputs in:", OUTDIR)


In [None]:
# bifurcation_basins_guild_hill_satpf.py
# Continuation in d and basins at baseline for the two-guild model.

import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.optimize import root
from scipy.integrate import solve_ivp

# Input: globals written by the two-guild calibration. Output directory for this analysis.
FIT = "mw_fit_out_guild_hill_satpf/fitted_global_params.csv"
OUT = "mw_bif_guild_hill_satpf"; os.makedirs(OUT, exist_ok=True)

# NOTE(review): these duplicate constants of the calibration script and are not
# stored in the CSV — keep them in sync by hand.
KQ=60.0; HILL_N=2

# The first CSV column is the parameter name (index); squeeze gives a Series of values.
g = pd.read_csv(FIT, index_col=0, header=None).squeeze("columns")
# p = [r0P,rHP,r0C,K_M,gamma,c,d,g,u,K_u,p_low,p_high,H_on,H_off,tau_q,K_B]
p = np.array([float(g[k]) for k in g.index.values], float)
# Baseline decay rate, read positionally (index 6 = d in the layout above).
d_fit = float(p[6])

def q_inf(H,q,H_on,H_off):
    """Logistic target for q (steepness KQ) around a threshold blended from H_on and H_off by q."""
    blend = (1-q)*H_on + q*H_off
    return 1.0/(1.0 + np.exp(-KQ*(H - blend)))
def dH_hill(B,H,gH,d,K_B,n=HILL_N):
    """dH/dt: Hill-saturated gain from B acting on (1 - H), minus linear decay d*H."""
    saturation = B**n/(K_B**n + B**n)
    return gH*saturation*(1-H) - d*H

def rhs(y,pvec,d_override=None):
    """Two-guild vector field for y = [P, C, H, B, q]; returns a float array of 5 rates.

    pvec holds the 16 globals in the documented order; d_override, when
    given, replaces the decay rate d stored in pvec.
    """
    P,C,H,B,q = y
    (r0P,rHP,r0C,K_M,gamma,c,d_stored,gH,
     u,K_u,pL,pH,H_on,H_off,tau,K_B) = pvec.copy()
    d = d_stored if d_override is None else d_override

    production = pL + (pH - pL)*np.clip(q,0,1)
    consumed = u*H*B/(K_u + B + 1e-9)
    rates = (
        P*( r0P + rHP*H - c*production - (P + gamma*C)/K_M ),
        C*( r0C - (C + gamma*P)/K_M ),
        dH_hill(B,H,gH,d,K_B),
        production*P - consumed,
        (q_inf(H,q,H_on,H_off) - q)/tau,
    )
    return np.array(rates, float)

def jac_fd(fun,y,pvec,dval,eps=1e-7):
    """Forward-difference Jacobian of ``fun(y, pvec, dval)`` with respect to y.

    The shape follows the inputs (rows = len(fun(...)), columns = len(y))
    rather than a fixed 5x5, and y is coerced to a float array so lists and
    integer arrays are perturbed correctly.
    Returns: ndarray J with J[:, i] = (fun(y + eps*e_i) - fun(y)) / eps.
    """
    y = np.asarray(y, dtype=float)
    f0 = np.asarray(fun(y,pvec,dval), dtype=float)
    J = np.zeros((f0.size, y.size))
    for i in range(y.size):
        y_step = y.copy()
        y_step[i] += eps
        J[:,i] = (np.asarray(fun(y_step,pvec,dval), dtype=float) - f0)/eps
    return J

def find_eq(pvec, dval, guess, tol=1e-8):
    """Solve rhs(y, pvec, dval) = 0 from `guess`; return (y, True) or (None, False).

    A root is accepted only if it is finite and inside the feasible box
    (P, C, B >= 0; 0 <= H <= 1.2; 0 <= q <= 1.2) up to `tol`. Roots outside
    the box are rejected instead of clipped, because a clipped point is in
    general not an equilibrium and its Jacobian would be meaningless.
    Finiteness is tested before clipping because max(0, nan) evaluates to 0.
    """
    sol=root(lambda yy: rhs(yy,pvec,dval), guess, method="hybr")
    if not sol.success:
        return None, False
    raw=np.asarray(sol.x,float)
    if not np.all(np.isfinite(raw)):
        return None, False
    y=np.array([max(0,raw[0]), max(0,raw[1]), np.clip(raw[2],0,1.2), max(0,raw[3]), np.clip(raw[4],0,1.2)], float)
    # Only round-off-sized excursions outside the box are projected back.
    if np.max(np.abs(y-raw))>tol:
        return None, False
    return y, True

# (A) continuation in d
# NOTE: every d value is solved from the same fixed seeds (no warm start from
# the neighbouring d), so a seed records whichever root the solver reaches.
d_vals=np.linspace(0.7*d_fit, 1.6*d_fit, 110)
# Initial guesses in state order [P, C, H, B, q].
seeds=[np.array([0.15,0.05,0.3,0.05,1.0]),
       np.array([0.05,0.20,0.9,0.10,0.0]),
       np.array([0.30,0.15,0.6,0.15,0.5]),
       np.array([0.05,0.05,0.5,0.05,0.8])]

rows=[]
for d in d_vals:
    for wi,y0 in enumerate(seeds):
        y, ok = find_eq(p, d, y0)
        if not ok: continue
        # Stable when every Jacobian eigenvalue has a negative real part.
        J=jac_fd(rhs,y,p,d)
        stable=(np.max(np.real(np.linalg.eigvals(J)))<0)
        rows.append({"d":d,"H":float(y[2]),"q":float(y[4]),"seed":wi,"stable":stable})
branches=pd.DataFrame(rows)
branches.to_csv(os.path.join(OUT,"branches.csv"), index=False)

# Bifurcation diagram: small dots per seed, overlaid with stability markers.
plt.figure(figsize=(7.2,5.0))
for wi in sorted(branches["seed"].unique()):
    sub=branches[branches["seed"]==wi]
    plt.plot(sub["d"], sub["H"], ".", ms=3, alpha=0.7, label=f"seed{wi}")
for st, mk in [(True,"o"),(False,"x")]:
    sub=branches[branches["stable"]==st]
    plt.scatter(sub["d"], sub["H"], s=22, marker=mk, alpha=0.6, label=("stable" if st else "unstable"))
plt.axvline(d_fit, ls="--", c="gray", label="baseline d")
plt.xlabel("d (1/h)"); plt.ylabel("H*"); plt.legend(); plt.grid(True, ls=":")
plt.tight_layout(); plt.savefig(os.path.join(OUT,"bifurcation_H_vs_d.png"), dpi=180); plt.close()

# (B) bistable at baseline?
# Only rows from the single grid value of d closest to d_fit are used. Pooling
# every grid point within a tolerance mixes equilibria computed at different d,
# and because H* drifts with d one branch could then be counted twice.
distinct=0
near=branches
if not branches.empty:
    d_grid=branches["d"].to_numpy(float)
    d_near=d_grid[np.argmin(np.abs(d_grid-d_fit))]
    # Keep the original closeness requirement so a distant d is never used.
    if np.isclose(d_near, d_fit, atol=1e-3):
        near=branches[branches["d"]==d_near]
        Hs=np.sort(near.loc[near["stable"].astype(bool),"H"].to_numpy(float))
        if Hs.size:
            # Sorted stable H values further apart than 1e-3 are distinct equilibria.
            distinct=1+int(np.count_nonzero(np.diff(Hs)>1e-3))
    else:
        near=branches.iloc[0:0]
bistable=(distinct>=2)

# (C) basins at baseline
# Grid of initial (H, q) pairs; Z[i, j] will hold the final H reached from (Hs[i], qs[j]).
Hs=np.linspace(0.2,0.95,17)
qs=np.linspace(0.0,1.0,17)
Z=np.zeros((len(Hs),len(qs)))

def relax(y0,T=360):
    """Integrate the model from y0 at d = d_fit over [0, T]; return the last sampled state."""
    sample_times = np.linspace(0, T, 1000)

    def field(t, z):
        return rhs(z, p, d_fit)

    sol = solve_ivp(field, (0, T), y0, t_eval=sample_times,
                    rtol=1e-6, atol=1e-8, max_step=0.5)
    return sol.y[:, -1]

# Relax every grid point; P, C and B start at fixed values, only H and q are scanned.
for i,H0 in enumerate(Hs):
    for j,q0 in enumerate(qs):
        y0=np.array([0.12,0.12,H0,0.10,q0],float)
        yss=relax(y0)
        Z[i,j]=yss[2]

plt.figure(figsize=(6.6,5.2))
# Colour scale is limited to [0.5, 1.0]; final H below 0.5 saturates at the lowest colour.
plt.imshow(Z, origin="lower", extent=[qs[0],qs[-1],Hs[0],Hs[-1]],
           aspect="auto", vmin=0.5, vmax=1.0, cmap="viridis")
plt.colorbar(label="Final H (steady)")
plt.xlabel("initial q"); plt.ylabel("initial H")
plt.title(f"Basins at baseline d={d_fit:.3f}")
plt.tight_layout(); plt.savefig(os.path.join(OUT,"basins_heatmap.png"), dpi=180); plt.close()

# Plain-text summary of the baseline bistability check.
with open(os.path.join(OUT,"diagnosis.txt"),"w") as f:
    f.write(f"Baseline d = {d_fit:.5f}\n")
    f.write(f"Distinct stable equilibria at baseline: {distinct}\n")
    f.write(f"Bistable at baseline? {'YES' if bistable else 'NO'}\n")

print("Saved ->", OUT, "| Bistable at baseline? ", "YES" if bistable else "NO")


In [None]:
# scan_bistability_2d_guild.py
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.optimize import root
import numpy.linalg as npl

# Input: globals written by the two-guild calibration. Output directory for the 2-D scans.
FIT = "mw_fit_out_guild_hill_satpf/fitted_global_params.csv"
OUT = "mw_scan2d_guild"; os.makedirs(OUT, exist_ok=True)

# NOTE(review): duplicated from the calibration script; keep in sync by hand.
KQ = 60.0; HILL_N = 2

# The first CSV column is the parameter name (index); squeeze gives a Series of values.
g = pd.read_csv(FIT, index_col=0, header=None).squeeze("columns")
# p = [r0P,rHP,r0C,K_M,gamma,c,d,g,u,K_u,p_low,p_high,H_on,H_off,tau_q,K_B]
p0 = np.array([float(g[k]) for k in g.index.values], float)
# Baseline decay rate, read positionally (index 6 = d).
d_fit = float(p0[6])

def q_inf(H,q,H_on,H_off):
    """Logistic switch target for q; the threshold slides from H_on to H_off as q goes 0 -> 1."""
    th = (1-q)*H_on + q*H_off
    denom = 1.0 + np.exp(-KQ*(H - th))
    return 1.0/denom
def dH(B,H,gH,d,K_B,n=HILL_N):
    """dH/dt: Hill-type benefit of B (exponent n) scaled by gH*(1 - H), minus decay d*H."""
    benefit = B**n/(K_B**n + B**n)
    return gH*benefit*(1-H) - d*H
def rhs(y,pvec,d_override=None):
    """Two-guild vector field; y = [P, C, H, B, q], returns float array [dP, dC, dH, dB, dq].

    d_override, when not None, is used in place of the decay rate in pvec.
    """
    P,C,H,B,q = y
    (r0P,rHP,r0C,K_M,gamma,c,d_stored,gH,
     u,K_u,pL,pH,H_on,H_off,tau,K_B) = pvec.copy()
    d = d_stored if d_override is None else d_override

    prod = pL + (pH - pL)*np.clip(q,0,1)
    dP = P*( r0P + rHP*H - c*prod - (P + gamma*C)/K_M )
    dC = C*( r0C - (C + gamma*P)/K_M )
    dB = prod*P - u*H*B/(K_u + B + 1e-9)
    dH_dt = dH(B,H,gH,d,K_B)
    dq = (q_inf(H,q,H_on,H_off) - q)/tau
    return np.array([dP, dC, dH_dt, dB, dq], float)

def jac_fd(fun,y,pvec,dval,eps=1e-7):
    """Forward-difference Jacobian of ``fun(y, pvec, dval)`` with respect to y.

    The shape follows the inputs (rows = len(fun(...)), columns = len(y))
    rather than a fixed 5x5, and y is coerced to a float array so lists and
    integer arrays are perturbed correctly.
    Returns: ndarray J with J[:, i] = (fun(y + eps*e_i) - fun(y)) / eps.
    """
    y = np.asarray(y, dtype=float)
    f0 = np.asarray(fun(y,pvec,dval), dtype=float)
    J = np.zeros((f0.size, y.size))
    for i in range(y.size):
        y_step = y.copy()
        y_step[i] += eps
        J[:,i] = (np.asarray(fun(y_step,pvec,dval), dtype=float) - f0)/eps
    return J

def find_eq(pvec, dval, guess, tol=1e-8):
    """Solve rhs(y, pvec, dval) = 0 from `guess`; return (y, True) or (None, False).

    A root is accepted only if it is finite and inside the feasible box
    (P, C, B >= 0; 0 <= H <= 1.2; 0 <= q <= 1.2) up to `tol`. Roots outside
    the box are rejected instead of clipped, because a clipped point is in
    general not an equilibrium and would distort the stable-equilibrium count.
    Finiteness is tested before clipping because max(0, nan) evaluates to 0.
    """
    sol = root(lambda yy: rhs(yy,pvec,dval), guess, method="hybr")
    if not sol.success:
        return None, False
    raw = np.asarray(sol.x, float)
    if not np.all(np.isfinite(raw)):
        return None, False
    y = np.array([max(0,raw[0]), max(0,raw[1]), np.clip(raw[2],0,1.2), max(0,raw[3]), np.clip(raw[4],0,1.2)], float)
    # Only round-off-sized excursions outside the box are projected back.
    if np.max(np.abs(y - raw)) > tol:
        return None, False
    return y, True

def count_stable_at(pvec, dval, seeds):
    """Count distinct stable equilibria reachable from `seeds` at decay rate dval.

    An equilibrium is stable when all Jacobian eigenvalues have negative real
    part; two stable equilibria are distinct when their sorted H values
    differ by more than 1e-3. Returns an int (0 if none is found).
    """
    stable_H = []
    for guess in seeds:
        eq, found = find_eq(pvec, dval, guess)
        if not found:
            continue
        eigvals = npl.eigvals(jac_fd(rhs, eq, pvec, dval))
        if np.max(np.real(eigvals)) < 0:
            stable_H.append(float(eq[2]))
    if not stable_H:
        return 0
    gaps = np.diff(np.sort(np.array(stable_H)))
    return 1 + int(np.count_nonzero(gaps > 1e-3))

# Root-finder seeds in state order [P, C, H, B, q].
seeds = [np.array(s) for s in (
    [0.12,0.12,0.30,0.08,1.0],
    [0.05,0.20,0.85,0.15,0.0],
    [0.30,0.05,0.55,0.10,0.7],
    [0.18,0.18,0.65,0.12,0.4],
)]

# --- planes to scan; each entry is (nameA, nameB, idxA, idxB, gridA, gridB)
# Grids are windows around the fitted value, clamped to fixed limits, 13 points each.
u0, pH0 = p0[8], p0[11]
ga0 = p0[4]
Ku0 = p0[9]
rHP0 = p0[1]

# Every plane uses the same p_high axis.
pH_grid = np.linspace(max(1.4,pH0-0.6), min(3.0,pH0+0.8), 13)

planes = [
    # 1) (u, p_high)
    ("u","p_high", 8,11,
     np.linspace(max(0.45,u0-0.25), min(0.85,u0+0.10), 13), pH_grid),
    # 2) (gamma, p_high)
    ("gamma","p_high", 4,11,
     np.linspace(max(0.5,ga0-0.3), min(1.3,ga0+0.3), 13), pH_grid),
    # 3) (K_u, p_high)
    ("K_u","p_high", 9,11,
     np.linspace(max(0.10,Ku0-0.10), min(0.40,Ku0+0.14), 13), pH_grid),
    # 4) (rHP, p_high)
    ("rHP","p_high", 1,11,
     np.linspace(max(0.00,rHP0-0.05), min(0.15,rHP0+0.06), 13), pH_grid),
]

# For every plane: overwrite the two scanned parameters on a copy of the fitted
# vector, count distinct stable equilibria at the baseline d, then save a CSV
# table and a heatmap (rows = parameter A, columns = parameter B).
for (nameA,nameB,idxA,idxB,gridA,gridB) in planes:
    Z = np.zeros((len(gridA), len(gridB)), int)
    for i,a in enumerate(gridA):
        for j,b in enumerate(gridB):
            p = p0.copy()
            p[idxA] = a; p[idxB] = b
            Z[i,j] = count_stable_at(p, d_fit, seeds)
    # Grid values are rounded to 3 decimals for the table labels only.
    df = pd.DataFrame(Z, index=np.round(gridA,3), columns=np.round(gridB,3))
    df.to_csv(os.path.join(OUT, f"stablecount_{nameA}_vs_{nameB}.csv"))
    plt.figure(figsize=(6.6,5.2))
    # Colour scale is fixed to 0..3 so all planes are comparable.
    plt.imshow(Z, origin="lower",
               extent=[gridB[0], gridB[-1], gridA[0], gridA[-1]],
               aspect="auto", cmap="viridis", vmin=0, vmax=3)
    cbar=plt.colorbar(); cbar.set_label("# distinct stable eq at baseline d")
    plt.xlabel(nameB); plt.ylabel(nameA)
    plt.title(f"Distinct stable equilibria at baseline d={d_fit:.3f}\nplane: {nameA} vs {nameB}")
    plt.tight_layout(); plt.savefig(os.path.join(OUT, f"heatmap_{nameA}_vs_{nameB}.png"), dpi=180); plt.close()

print("Saved ->", OUT)


In [None]:
# calibrate_guild_hill_satpf_targeted.py
# As calibrate_guild_hill_satpf.py, but adds soft targets for a few params
# (from the 2D scan) to gently bias the optimizer toward a bistable pocket.

import os, numpy as np, pandas as pd
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares

INPATH = "timeseries/combined_scfas_table_scored.csv"
OUTDIR = "mw_fit_out_guild_hill_satpf_targeted"; os.makedirs(OUTDIR, exist_ok=True)

# --- choose targets after inspecting the heatmaps from scan_bistability_2d_guild.py ---
# Example targets (EDIT these based on your heatmaps):
# Each entry is (mu, sd) and adds one extra pseudo-residual on top of PRIOR.
TARGETS = {
    "rHP":   (0.117, 0.04),   # weak positive feedback, modest SD
    "u":     (0.725, 0.06),   # a bit lower than your fitted 0.85
    "K_u":   (0.212, 0.06),   # slightly higher saturation half-point
    "gamma": (1.063, 0.18),   # mildly stronger competition
    "p_high":(2.092, 0.30),   # more production headroom
}

# Everything else (data prep/model) mirrors the first two-guild script:
H_COLS    = ["H_proxy_meta_smooth","H_proxy_meta"]
SCFA_COLS = ["butyrate"]
MIN_ROWS  = 4
KQ        = 60.0
PENALTY   = 1e3
HILL_N    = 2

# Baseline priors (kept modest)
PRIOR = {
    "r0P":(0.32,0.08), "rHP":(0.07,0.04), "r0C":(0.28,0.08), "K_M":(1.00,0.25),
    "gamma":(0.85,0.25), "c":(0.12,0.05), "d":(0.12,0.05), "g":(0.60,0.30),
    "u":(0.60,0.15), "K_u":(0.20,0.08), "p_low":(0.12,0.06), "p_high":(2.20,0.60),
    "H_on":(0.60,0.08), "H_off":(0.86,0.04), "tau_q":(5.0,2.0), "K_B":(0.20,0.08)
}

# Bounds (same as earlier)
LBg = np.array([0.18,0.00,0.15,0.55,0.40,0.06,0.06,0.20,0.45,0.10,0.06,1.3,0.50,0.80,1.0,0.10])
UBg = np.array([0.46,0.14,0.40,1.60,1.40,0.20,0.22,1.40,0.85,0.40,0.28,3.2,0.74,0.92,10.,0.40])

# Initial guess for the globals; the values equal the PRIOR means, typed out by hand.
x0g = np.array([0.32,0.07,0.28,1.0,0.85,0.12,0.12,0.60,0.60,0.20,0.12,2.20,0.60,0.86,5.0,0.20], float)

# --- load data (same as before; omitted plots to keep short) ---
df = pd.read_csv(INPATH)
# Fail early with a clear message instead of a KeyError further down.
if not {"subject_id","sample_id"}.issubset(df.columns):
    raise ValueError("CSV must have subject_id, sample_id")
# Use the first preferred host-proxy column that exists in the file.
Hcol = next((c for c in H_COLS if c in df.columns), None)
if Hcol is None:
    raise ValueError("Need H_proxy_meta_smooth or H_proxy_meta")
for c in SCFA_COLS:
    if c not in df.columns: raise ValueError(f"Missing SCFA column {c}")
# Keep only the needed columns and drop rows lacking an id.
df = df[["subject_id","sample_id",Hcol]+SCFA_COLS].dropna(subset=["subject_id","sample_id"]).copy()
# Time is the 0-based row position within each subject (file order), stored as float.
df["t_idx"] = df.groupby("subject_id").cumcount().astype(float)

def robust_z(s):
    """Median-centred, robustly scaled copy of the Series s (same index).

    Scale is the MAD of the finite values, falling back to the IQR and then
    to the standard deviation (+1e-9) when the previous choice is below 1e-9.
    A series without any finite value yields all zeros.
    """
    import numpy as np
    import pandas as pd

    raw = s.astype(float).to_numpy()
    ok = np.isfinite(raw)
    if not ok.any():
        return pd.Series(np.zeros_like(raw), index=s.index)

    kept = raw[ok]
    med = np.median(kept)
    spread = np.median(np.abs(kept - med))
    if spread >= 1e-9:
        scale = spread
    else:
        upper, lower = np.percentile(kept, [75, 25])
        iqr = upper - lower
        scale = iqr if iqr > 1e-9 else (np.std(kept) + 1e-9)
    return pd.Series((raw - med)/(scale + 1e-9), index=s.index)

# Per-subject robust z-score of every SCFA column.
for c in SCFA_COLS:
    df[c+"_z"] = df.groupby("subject_id")[c].transform(robust_z)
# B observable: the single z-scored column, or the row-wise mean when several
# SCFAs are configured (previously every column after the first was ignored).
if len(SCFA_COLS)==1:
    df["B_obs"] = df[SCFA_COLS[0]+"_z"]
else:
    df["B_obs"] = df[[name+"_z" for name in SCFA_COLS]].mean(axis=1)
# H observable is the chosen proxy clipped to [0, 1].
df["H_obs"] = df[Hcol].clip(0,1)

def first(a,d):
    """Return the first finite entry of a as a float, or d when there is none."""
    a=np.asarray(a,float); idx=np.where(np.isfinite(a))[0]
    return float(a[idx[0]]) if len(idx) else float(d)

# One record per subject that has enough rows and enough finite observations.
subs=[]
for sid, sub in df.groupby("subject_id"):
    sub=sub.sort_values("t_idx").copy()
    if len(sub)<MIN_ROWS: continue
    t=sub["t_idx"].values.astype(float)
    B=sub["B_obs"].values.astype(float)
    H=sub["H_obs"].values.astype(float)
    mB=np.isfinite(B); mH=np.isfinite(H)
    if mB.sum()<3 or mH.sum()<3: continue
    subs.append({
        "sid":sid,"t":t,"B":B,"H":H,"maskB":mB,"maskH":mH,
        "nB":int(mB.sum()),"nH":int(mH.sum()),
        # Initial conditions come from the first finite observation (with fallbacks).
        "H0":float(np.clip(first(H,0.6),0,1)),"B0":float(max(0.05, first(B,0.1)))
    })
if not subs: raise RuntimeError("No subject passed minimal filters.")

# --- model (same as before) ---
# NAMES fixes the position of every global parameter in the optimizer vector;
# code below indexes globals positionally (e.g. 12 = H_on, 13 = H_off).
NAMES = ["r0P","rHP","r0C","K_M","gamma","c","d","g","u","K_u","p_low","p_high","H_on","H_off","tau_q","K_B"]

def q_inf(H,q,H_on,H_off,k=None):
    """Logistic target for the memory variable q.

    The threshold moves linearly from H_on (q = 0) to H_off (q = 1).
    k is the steepness; when omitted it is taken from the module constant KQ,
    so editing KQ in the config section actually changes the model (the former
    hard-coded default of 60.0 ignored that setting).
    """
    if k is None:
        k = KQ
    th=(1.0-q)*H_on + q*H_off
    return 1.0/(1.0 + np.exp(-k*(H - th)))

def dH_hill(B,H,gH,d,K_B,n=None):
    """dH/dt: Hill-saturated gain from B acting on (1 - H), minus decay d*H.

    n is the Hill exponent; when omitted it is taken from the module constant
    HILL_N, so the config value is honoured (the former hard-coded default of
    2 ignored it).
    """
    if n is None:
        n = HILL_N
    return gH*(B**n/(K_B**n + B**n))*(1 - H) - d*H

def rhs(t,y,p):
    """Two-guild vector field; y = [P, C, H, B, q], returns [dP, dC, dH, dB, dq].

    p holds the 16 globals in NAMES order; t is unused but required by solve_ivp.
    """
    P,C,H,B,q = y
    (r0P,rHP,r0C,K_M,gamma,c,d,gH,
     u,K_u,pL,pH,H_on,H_off,tau,K_B) = p

    # Production interpolates between pL and pH with q clamped to [0, 1].
    prod = pL + (pH - pL)*np.clip(q,0,1)
    # Saturating uptake of B, proportional to H.
    taken = u*H*B/(K_u + B + 1e-9)

    derivs = [
        P*( r0P + rHP*H - c*prod - (P + gamma*C)/K_M ),
        C*( r0C - (C + gamma*P)/K_M ),
        dH_hill(B,H,gH,d,K_B),
        prod*P - taken,
        (q_inf(H,q,H_on,H_off) - q)/tau,
    ]
    return derivs

from scipy.integrate import solve_ivp

def simulate(ts,y0,p):
    """Integrate rhs from y0, sampled at ts.

    Returns the (5, len(ts)) solution, or an all-NaN array of that shape if
    the solver fails or raises (best effort; the caller turns NaNs into
    penalty residuals).
    """
    def _all_nan():
        return np.full((5, len(ts)), np.nan)

    try:
        sol = solve_ivp(lambda t, z: rhs(t, z, p), (ts[0], ts[-1]), y0,
                        t_eval=ts, rtol=1e-6, atol=1e-8, max_step=0.5)
    except Exception:
        return _all_nan()
    if sol.success:
        return sol.y
    return _all_nan()

# per-subject obs maps: every subject adds (alpha_B, beta0_H, beta1_H) with its bounds
n_subjects = len(subs)
x0s = [1.0, 0.0, 1.0]*n_subjects
LBs = [0.6, -0.2, 0.8]*n_subjects
UBs = [1.6,  0.2, 1.2]*n_subjects

# Full optimizer vectors: the globals first, then the per-subject triples.
x0 = np.concatenate([x0g, np.array(x0s, float)])
LB = np.concatenate([LBg, np.array(LBs, float)])
UB = np.concatenate([UBg, np.array(UBs, float)])

def unpack(x):
    """Return (globals, per-subject pieces): the first len(NAMES) entries, then len(subs) equal splits."""
    cut = len(NAMES)
    return x[:cut], np.split(x[cut:], len(subs))

# Residual weights for the B and H observables.
W_B, W_H = 0.6, 1.2
# Fixed residual-vector length: finite B and H observations plus one term per prior and per target.
TOTLEN = sum(S["nB"] + S["nH"] for S in subs) + len(PRIOR) + len(TARGETS)

def residuals(x):
    """Residual vector for least_squares; always TOTLEN entries long.

    Layout: for every subject the weighted B residuals then the weighted H
    residuals (finite observations only), followed by one pseudo-residual per
    PRIOR entry and one per TARGETS entry. Infeasible parameters or failed
    simulations produce PENALTY-valued residuals (the config constant is used
    instead of a repeated literal), and non-finite residuals are replaced by
    PENALTY so the optimizer never receives NaN or inf.
    """
    gpar, triples = unpack(x)
    # Require H_off > H_on and both half-saturation constants above 0.06.
    if not (gpar[13] > gpar[12]):
        return np.full(TOTLEN, PENALTY)
    if gpar[9] <= 0.06 or gpar[15] <= 0.06:
        return np.full(TOTLEN, PENALTY)

    res=[]
    for S, tr in zip(subs, triples):
        aB,b0H,b1H = tr
        ts=S["t"]; H0=np.clip(S["H0"],0,1); B0=max(0.05,S["B0"])
        # q starts "on" when H0 lies below the midpoint of the H_on/H_off band.
        P0=C0=0.12; q0 = 1.0 if H0 < 0.5*(gpar[12]+gpar[13]) else 0.0
        y0=[P0,C0,H0,B0,q0]
        Y=simulate(ts,y0,gpar)
        if np.any(~np.isfinite(Y)):
            res += [W_B*np.full(S["nB"],PENALTY), W_H*np.full(S["nH"],PENALTY)]
            continue
        P,C,H,B,q = Y
        Bhat=aB*B; Hhat=np.clip(b0H + b1H*H, 0,1)
        b = (Bhat[S["maskB"]] - S["B"][S["maskB"]])
        h = (Hhat[S["maskH"]] - S["H"][S["maskH"]])
        b = np.where(np.isfinite(b), b, PENALTY)
        h = np.where(np.isfinite(h), h, PENALTY)
        res += [W_B*b.astype(float), W_H*h.astype(float)]

    # baseline priors
    idx = {nm:i for i,nm in enumerate(NAMES)}
    for name,(mu,sd) in PRIOR.items():
        res.append( np.array([(gpar[idx[name]] - mu)/(sd + 1e-9)]) )

    # targeted priors (small set; stronger pull)
    for name,(mu,sd) in TARGETS.items():
        res.append( np.array([(gpar[idx[name]] - mu)/(sd + 1e-9)]) )

    return np.concatenate(res)

from scipy.optimize import least_squares
# Bounded robust least-squares fit (soft-L1 loss) starting from x0.
fit = least_squares(residuals, x0, bounds=(LB,UB),
                    verbose=2, max_nfev=1200, loss="soft_l1", f_scale=1.0)
gpar_hat, triples_hat = unpack(fit.x)

# Globals are written as name,value rows without a header line.
pd.Series(gpar_hat, index=NAMES).to_csv(os.path.join(OUTDIR,"fitted_global_params.csv"), header=False)
# One row per subject with its fitted observation-map coefficients.
pd.DataFrame(
    [{"subject_id":S["sid"], "alpha_B":tr[0], "beta0_H":tr[1], "beta1_H":tr[2]}
     for S,tr in zip(subs, triples_hat)]
).to_csv(os.path.join(OUTDIR,"fitted_subject_scales.csv"), index=False)

print("[info] Fitted globals:", dict(zip(NAMES, gpar_hat)))
print("✅ Done. Outputs in:", OUTDIR)


In [None]:
# bifurcation_basins_guild_hill_satpf_TARGETED.py
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.optimize import root
from scipy.integrate import solve_ivp
import numpy.linalg as npl

# Input: globals written by the targeted calibration. Output directory for this analysis.
FIT = "mw_fit_out_guild_hill_satpf_targeted/fitted_global_params.csv"
OUT = "mw_bif_guild_hill_satpf_targeted"; os.makedirs(OUT, exist_ok=True)

# NOTE(review): duplicated from the calibration script; keep in sync by hand.
KQ=60.0; HILL_N=2
# The first CSV column is the parameter name (index); squeeze gives a Series of values.
g = pd.read_csv(FIT, index_col=0, header=None).squeeze("columns")
# Values are taken in file order and unpacked positionally by rhs.
p = np.array([float(g[k]) for k in g.index.values], float)
# Baseline decay rate, read positionally (index 6).
d_fit = float(p[6])

def q_inf(H,q,H_on,H_off):
    """Logistic target for q (steepness KQ); the threshold is (1-q)*H_on + q*H_off."""
    th = (1-q)*H_on + q*H_off
    arg = -KQ*(H - th)
    return 1.0/(1.0 + np.exp(arg))
def dH(B,H,gH,d,K_B,n=HILL_N):
    """dH/dt: gH times a Hill function of B (exponent n) times (1 - H), minus d*H."""
    hill_term = B**n/(K_B**n + B**n)
    return gH*hill_term*(1-H) - d*H
def rhs(y,pvec,dval=None):
    """Two-guild vector field; y = [P, C, H, B, q], returns float array of the five rates.

    dval, when not None, replaces the decay rate d stored in pvec.
    """
    P,C,H,B,q = y
    (r0P,rHP,r0C,K_M,gamma,c,d_stored,gH,
     u,K_u,pL,pH,H_on,H_off,tau,K_B) = pvec.copy()
    d = d_stored if dval is None else dval

    prod = pL+(pH-pL)*np.clip(q,0,1)
    uptake = u*H*B/(K_u + B + 1e-9)

    dP = P*( r0P + rHP*H - c*prod - (P + gamma*C)/K_M )
    dC = C*( r0C - (C + gamma*P)/K_M )
    dH_dt = dH(B,H,gH,d,K_B)
    dB = prod*P - uptake
    dq = (q_inf(H,q,H_on,H_off)-q)/tau
    return np.array([dP, dC, dH_dt, dB, dq], float)

def jac_fd(fun,y,pvec,dval,eps=1e-7):
    """Forward-difference Jacobian of ``fun(y, pvec, dval)`` with respect to y.

    The shape follows the inputs (rows = len(fun(...)), columns = len(y))
    rather than a fixed 5x5, and y is coerced to a float array so lists and
    integer arrays are perturbed correctly.
    Returns: ndarray J with J[:, i] = (fun(y + eps*e_i) - fun(y)) / eps.
    """
    y = np.asarray(y, dtype=float)
    f0 = np.asarray(fun(y,pvec,dval), dtype=float)
    J = np.zeros((f0.size, y.size))
    for i in range(y.size):
        y_step = y.copy()
        y_step[i] += eps
        J[:,i] = (np.asarray(fun(y_step,pvec,dval), dtype=float) - f0)/eps
    return J

def find_eq(pvec,dval,guess,tol=1e-8):
    """Solve rhs(y, pvec, dval) = 0 from `guess`; return (y, True) or (None, False).

    A root is accepted only if it is finite and inside the feasible box
    (P, C, B >= 0; 0 <= H <= 1.2; 0 <= q <= 1.2) up to `tol`. Roots outside
    the box are rejected instead of clipped, because a clipped point is in
    general not an equilibrium and its Jacobian would be meaningless.
    Finiteness is tested before clipping because max(0, nan) evaluates to 0.
    """
    sol=root(lambda yy: rhs(yy,pvec,dval), guess, method="hybr")
    if not sol.success:
        return None, False
    raw=np.asarray(sol.x,float)
    if not np.all(np.isfinite(raw)):
        return None, False
    y=np.array([max(0,raw[0]), max(0,raw[1]), np.clip(raw[2],0,1.2), max(0,raw[3]), np.clip(raw[4],0,1.2)], float)
    # Only round-off-sized excursions outside the box are projected back.
    if np.max(np.abs(y-raw))>tol:
        return None, False
    return y, True

# continuation in d
# NOTE: every d value is solved from the same fixed seeds (no warm start from
# the neighbouring d), so a seed records whichever root the solver reaches.
d_vals=np.linspace(0.7*d_fit, 1.6*d_fit, 120)
# Initial guesses in state order [P, C, H, B, q].
seeds=[np.array([0.15,0.05,0.3,0.05,1.0]),
       np.array([0.05,0.20,0.9,0.10,0.0]),
       np.array([0.30,0.15,0.6,0.15,0.5]),
       np.array([0.05,0.05,0.5,0.05,0.8])]
rows=[]
for d in d_vals:
    for wi,y0 in enumerate(seeds):
        y,ok=find_eq(p,d,y0)
        if not ok: continue
        # Stable when every Jacobian eigenvalue has a negative real part.
        J=jac_fd(rhs,y,p,d)
        stable=(np.max(np.real(npl.eigvals(J)))<0)
        rows.append({"d":d,"H":float(y[2]),"q":float(y[4]),"seed":wi,"stable":stable})
branches=pd.DataFrame(rows)
branches.to_csv(os.path.join(OUT,"branches.csv"), index=False)

# Bifurcation diagram: small dots per seed, overlaid with stability markers.
plt.figure(figsize=(7.2,5.0))
for wi in sorted(branches["seed"].unique()):
    sub=branches[branches["seed"]==wi]
    plt.plot(sub["d"], sub["H"], ".", ms=3, alpha=0.7, label=f"seed{wi}")
for st, mk in [(True,"o"),(False,"x")]:
    sub=branches[branches["stable"]==st]
    plt.scatter(sub["d"], sub["H"], s=22, marker=mk, alpha=0.6, label=("stable" if st else "unstable"))
plt.axvline(d_fit, ls="--", c="gray", label="baseline d")
plt.xlabel("d (1/h)"); plt.ylabel("H*"); plt.legend(); plt.grid(True, ls=":")
plt.tight_layout(); plt.savefig(os.path.join(OUT,"bifurcation_H_vs_d.png"), dpi=180); plt.close()

# bistability at baseline?
# Only rows from the single grid value of d closest to d_fit are used. Pooling
# every grid point within a tolerance mixes equilibria computed at different d,
# and because H* drifts with d one branch could then be counted twice.
distinct=0
near=branches
if not branches.empty:
    d_grid=branches["d"].to_numpy(float)
    d_near=d_grid[np.argmin(np.abs(d_grid-d_fit))]
    # Keep the original closeness requirement so a distant d is never used.
    if np.isclose(d_near, d_fit, atol=1e-3):
        near=branches[branches["d"]==d_near]
        Hs=np.sort(near.loc[near["stable"].astype(bool),"H"].to_numpy(float))
        if Hs.size:
            # Sorted stable H values further apart than 1e-3 are distinct equilibria.
            distinct=1+int(np.count_nonzero(np.diff(Hs)>1e-3))
    else:
        near=branches.iloc[0:0]
bistable=(distinct>=2)

# basins at baseline
from scipy.integrate import solve_ivp
def relax(y0,T=360):
    """Integrate the model at the baseline d_fit from y0 over [0, T].

    Returns the state at the last stored time point.
    """
    out_times=np.linspace(0,T,900)
    field=lambda t,z: rhs(z,p,d_fit)
    traj=solve_ivp(field,(0,T),y0,t_eval=out_times,
                   rtol=1e-6,atol=1e-8,max_step=0.5)
    final_state=traj.y[:,-1]
    return final_state

# Basin scan: relax the model from a 17x17 grid of initial (H, q) values with
# fixed initial P, C, B and record the final H of each run.
Hs=np.linspace(0.2,0.95,17)
qs=np.linspace(0.0,1.0,17)
Z=np.zeros((len(Hs),len(qs)))
for i,H0 in enumerate(Hs):
    for j,q0 in enumerate(qs):
        y0=np.array([0.12,0.12,H0,0.10,q0],float)
        yss=relax(y0)
        Z[i,j]=yss[2]

# Rows of Z follow initial H (y axis), columns follow initial q (x axis).
plt.figure(figsize=(6.6,5.2))
plt.imshow(Z, origin="lower", extent=[qs[0],qs[-1],Hs[0],Hs[-1]],
           aspect="auto", vmin=0.5, vmax=1.0, cmap="viridis")
plt.colorbar(label="Final H (steady)")
plt.xlabel("initial q"); plt.ylabel("initial H")
plt.title(f"Basins at baseline d={d_fit:.3f}  |  Bistable? {'YES' if bistable else 'NO'}")
plt.tight_layout(); plt.savefig(os.path.join(OUT,"basins_heatmap.png"), dpi=180); plt.close()

# Plain-text summary of the baseline equilibrium count.
with open(os.path.join(OUT,"diagnosis.txt"),"w") as f:
    f.write(f"Baseline d = {d_fit:.5f}\n")
    f.write(f"Distinct stable equilibria at baseline: {distinct}\n")
    f.write(f"Bistable at baseline? {'YES' if bistable else 'NO'}\n")

print("Saved ->", OUT, "| Bistable at baseline? ", "YES" if bistable else "NO")


In [None]:
# scan_bistability_guild_wide.py
# Two-guild model (P,C) + Hill host + saturable uptake + memory.
# Scans a wider pocket with more curvature: Hill n in {3,4}, KQ in {80,100}.
# Looks for >=2 distinct stable equilibria at baseline d.

import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.optimize import root
import numpy.linalg as npl

# Input: fitted global parameters (name,value rows, no header); output dir.
FIT = "mw_fit_out_guild_hill_satpf/fitted_global_params.csv"
OUT = "mw_scan_guild_wide"; os.makedirs(OUT, exist_ok=True)

# --- load fitted globals ---
g = pd.read_csv(FIT, index_col=0, header=None).squeeze("columns")
# p = [r0P,rHP,r0C,K_M,gamma,c,d,g,u,K_u,p_low,p_high,H_on,H_off,tau_q,K_B]
# NOTE(review): values are taken in file row order, so the positional layout
# above is assumed to match the CSV -- confirm against the writer.
p0 = np.array([float(g[k]) for k in g.index.values], float)
d_fit = float(p0[6])

def q_inf(H,q,H_on,H_off,KQ):
    """Logistic target for q as a function of H.

    The switching threshold slides linearly from H_on (q=0) to H_off (q=1);
    KQ sets the steepness of the logistic.
    """
    threshold=(1-q)*H_on + q*H_off
    exponent=-KQ*(H - threshold)
    return 1.0/(1.0 + np.exp(exponent))

def rhs(y,pvec,dval,n,KQ):
    """Right-hand side of the 5-state model (P, C, H, B, q).

    pvec holds the 16 global parameters; dval, when not None, replaces the
    decay rate d stored in pvec.  n is the Hill exponent of the B -> H term
    and KQ the steepness passed to q_inf.  Returns [dP, dC, dH, dB, dq].
    """
    P,C,H,B,q = y
    (r0P,rHP,r0C,K_M,gamma,c,d_base,gH,u,K_u,
     pL,pH,H_on,H_off,tau,K_B) = pvec
    d = d_base if dval is None else dval
    # production rate interpolates between pL and pH with q clipped to [0, 1]
    prod = pL + (pH - pL)*np.clip(q,0,1)
    # ecology: two competing guilds
    dP = P*( r0P + rHP*H - c*prod - (P + gamma*C)/K_M )
    dC = C*( r0C - (C + gamma*P)/K_M )
    # butyrate balance: production minus saturable uptake
    dB = prod*P - u*H*B/(K_u + B + 1e-9)
    # host: Hill-type gain in B, linear decay
    hill = B**n/(K_B**n + B**n)
    dH = gH*hill*(1 - H) - d*H
    # memory variable relaxes towards q_inf with time constant tau
    dq = (q_inf(H,q,H_on,H_off,KQ) - q)/tau
    return np.array([dP,dC,dH,dB,dq], float)

def jac_fd(fun,y,args,eps=1e-7):
    """Forward-difference Jacobian of ``fun(y, *args)`` with respect to ``y``.

    Works for any state dimension (the size is taken from ``y`` and from the
    value of ``fun``) instead of assuming a 5-state system.  Column i holds
    (fun(y + eps*e_i) - fun(y)) / eps.
    """
    y=np.asarray(y,float)
    f0=np.asarray(fun(y,*args),float)
    J=np.zeros((f0.size,y.size))
    for i in range(y.size):
        y2=y.copy(); y2[i]+=eps
        J[:,i]=(np.asarray(fun(y2,*args),float)-f0)/eps
    return J

def find_eq(pvec,dval,guess,n,KQ):
    """Search for one equilibrium of ``rhs`` for the given d, Hill n and KQ.

    Runs ``scipy.optimize.root`` (method "hybr") from ``guess`` and returns
    ``(y, True)`` with the state clamped to P, C, B >= 0 and H, q in
    [0, 1.2], or ``(None, False)`` when no trustworthy equilibrium is found.
    """
    f=lambda yy: rhs(yy,pvec,dval,n,KQ)
    sol=root(f, guess, method="hybr")
    if not sol.success: return None, False
    y=np.asarray(sol.x,float)
    # Test finiteness BEFORE clamping: max(0, nan) evaluates to 0, so a NaN
    # in P, C or B would otherwise be turned into a valid-looking zero.
    if not np.all(np.isfinite(y)): return None, False
    y=np.array([max(0,y[0]), max(0,y[1]), np.clip(y[2],0,1.2), max(0,y[3]), np.clip(y[4],0,1.2)], float)
    # Clamping may move the point off the root; keep it only if the clamped
    # state still satisfies rhs(y) ~ 0.
    if not np.all(np.abs(f(y))<=1e-6): return None, False
    return y, True

def count_stable(pvec,dval,n,KQ):
    """Count distinct stable equilibria found from four fixed seed states.

    An equilibrium is stable when all Jacobian eigenvalues have negative real
    part; two stable equilibria are distinct when their sorted H values are
    more than 1e-3 apart.  Returns 0 when no stable equilibrium is found.
    """
    starts = (
        np.array([0.12,0.12,0.30,0.08,1.0]),
        np.array([0.05,0.20,0.85,0.15,0.0]),
        np.array([0.30,0.08,0.55,0.10,0.6]),
        np.array([0.18,0.18,0.65,0.12,0.4]),
    )
    model_args = (pvec,dval,n,KQ)
    stable_H = []
    for start in starts:
        eq,found = find_eq(pvec,dval,start,n,KQ)
        if found:
            eigs = npl.eigvals(jac_fd(rhs,eq,model_args))
            if np.max(np.real(eigs))<0:
                stable_H.append(float(eq[2]))
    if not stable_H:
        return 0
    ordered = np.sort(stable_H)
    gaps = np.abs(np.diff(ordered))
    return 1 + int(np.sum(gaps>1e-3))

# --- scan ranges (wider than before, still plausible) ---
# Each range is centred near the fitted value (p0[...]) and capped by fixed
# lower/upper limits.
grid = {
    "u":      np.linspace(max(0.50,p0[8]-0.25),  min(0.80,p0[8]+0.05),  13),
    "K_u":    np.linspace(max(0.10,p0[9]-0.10),  min(0.30,p0[9]+0.12),  11),
    "gamma":  np.linspace(max(0.70,p0[4]-0.20),  min(1.30,p0[4]+0.44),  13),
    "p_high": np.linspace(max(1.8,p0[11]-0.2),   3.2,                    15),
    "rHP":    np.linspace(max(0.00,p0[1]-0.06),  min(0.15,p0[1]+0.08),  11),
}
hill_list = [3,4]
KQ_list   = [80,100]
# also allow a small baseline d tweak (±10%) to see nearby pockets
d_list    = [d_fit, 0.9*d_fit, 1.1*d_fit]

# Full Cartesian product: 2*2*3*13*11*13*15*11 parameter combinations, each
# requiring four root solves in count_stable -- expect a long runtime.
rows=[]
for n in hill_list:
    for KQ in KQ_list:
        for d0 in d_list:
            for u in grid["u"]:
                for Ku in grid["K_u"]:
                    for ga in grid["gamma"]:
                        for pH in grid["p_high"]:
                            for rHP in grid["rHP"]:
                                # overwrite only the scanned entries of the fitted vector
                                p = p0.copy()
                                p[8]=u; p[9]=Ku; p[4]=ga; p[11]=pH; p[1]=rHP
                                sc = count_stable(p, d0, n, KQ)
                                rows.append({"n":n,"KQ":KQ,"d_eval":d0,"u":u,"K_u":Ku,"gamma":ga,"p_high":pH,"rHP":rHP,
                                             "stable_count":sc})

df=pd.DataFrame(rows)
os.makedirs(OUT, exist_ok=True)
out_csv=os.path.join(OUT,"scan_results.csv")
df.to_csv(out_csv,index=False)

# quick top summary
# Ranked by stable_count, ties broken by larger p_high.
top=df.sort_values(["stable_count","p_high"],ascending=[False,False]).head(20)
top.to_csv(os.path.join(OUT,"top20.csv"),index=False)
print("Saved:", out_csv)
print("Top candidates:\n", top)


In [None]:
# calibrate_and_check_guild_target_from_scan.py
# Uses the same two-guild model and your data, adds soft targets from the scan,
# refits, then immediately rechecks bifurcation & basins at the fitted baseline d.

import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares, root
import numpy.linalg as npl

# Input table and the two output directories (fit results, bifurcation check).
DATA = "timeseries/combined_scfas_table_scored.csv"
FIT_OUT = "mw_fit_out_guild_hill_satpf_target_FROMSCAN"; os.makedirs(FIT_OUT, exist_ok=True)
BIF_OUT = "mw_bif_guild_target_FROMSCAN"; os.makedirs(BIF_OUT, exist_ok=True)

# === Paste a row from mw_scan_guild_wide/top20.csv with stable_count >= 2 ===
# Each fitted target has a centre and an "sd_<name>" width used as a soft
# Gaussian-style prior in residuals().
TARGETS = {
    # example centers; replace with your own from the scan:
    "n": 3,        # Hill exponent used only for bifurcation/check (not fitted)
    "KQ": 80,     # memory steepness used for check (not fitted)
    "u": 0.77,     "sd_u": 0.05,
    "K_u": 0.18,   "sd_K_u": 0.05,
    "gamma": 1.25, "sd_gamma": 0.15,
    "p_high": 3.20,"sd_p_high": 0.30,
    "rHP": 0.0344,   "sd_rHP": 0.04,
}

# === Model pieces (same as calibration script) ===
H_COLS = ["H_proxy_meta_smooth","H_proxy_meta"]
SCFA_COLS=["butyrate"]; MIN_ROWS=4
# NOTE: HILL_N is not read anywhere in this script: rhs_fit hard-codes n=2 and
# the bifurcation check uses TARGETS["n"].
PENALTY=1e3; HILL_N=3  # keep n=2 for fitting; use TARGETS["n"] for the check
KQ_fit=60.0

# priors/bounds (same as earlier two-guild calibration)
# PRIOR maps parameter name -> (mean, sd); LBg/UBg/x0g are ordered like NAMES.
PRIOR = {"r0P":(0.32,0.08),"rHP":(0.07,0.04),"r0C":(0.28,0.08),"K_M":(1.00,0.25),
         "gamma":(0.85,0.25),"c":(0.12,0.05),"d":(0.12,0.05),"g":(0.60,0.30),
         "u":(0.60,0.15),"K_u":(0.20,0.08),"p_low":(0.12,0.06),"p_high":(2.20,0.60),
         "H_on":(0.60,0.08),"H_off":(0.86,0.04),"tau_q":(5.0,2.0),"K_B":(0.20,0.08)}
LBg=np.array([0.18,0.00,0.15,0.55,0.40,0.06,0.06,0.20,0.45,0.10,0.06,1.3,0.50,0.80,1.0,0.10])
UBg=np.array([0.46,0.14,0.40,1.60,1.40,0.20,0.22,1.40,0.85,0.40,0.28,3.2,0.74,0.92,10.,0.40])
x0g=np.array([0.32,0.07,0.28,1.0,0.85,0.12,0.12,0.60,0.60,0.20,0.12,2.2,0.60,0.86,5.0,0.20],float)

def q_inf(H,q,H_on,H_off,KQ):
    """Sigmoid of KQ*(H - theta), where theta blends H_on and H_off by q."""
    theta=(1-q)*H_on + q*H_off
    damp=np.exp(-KQ*(H - theta))
    return 1.0/(1.0 + damp)

def rhs_fit(t,y,p):  # n=2, KQ_fit for calibration
    """Calibration right-hand side in solve_ivp form (t is unused).

    Same 5-state model as the check, but with the Hill exponent fixed at 2
    and the module-level KQ_fit as memory steepness.  Returns a list
    [dP, dC, dH, dB, dq].
    """
    P,C,H,B,q=y
    (r0P,rHP,r0C,K_M,gamma,c,d,gH,u,K_u,
     pL,pH,H_on,H_off,tau,K_B)=p
    prod = pL + (pH - pL)*np.clip(q,0,1)
    growth_P = r0P + rHP*H - c*prod - (P + gamma*C)/K_M
    growth_C = r0C - (C + gamma*P)/K_M
    uptake = u*H*B/(K_u + B + 1e-9)
    hill = B**2/(K_B**2 + B**2)
    return [P*( growth_P ),
            C*( growth_C ),
            gH*hill*(1 - H) - d*H,
            prod*P - uptake,
            (q_inf(H,q,H_on,H_off,KQ_fit) - q)/tau]

def simulate(ts,y0,p):
    """Integrate rhs_fit from y0 and sample the solution at the times ts.

    Returns the (5, len(ts)) state array, or an all-NaN array of that shape
    when the solver reports failure or raises.
    """
    nan_block=np.vstack([np.full(len(ts),np.nan)]*5)
    try:
        sol=solve_ivp(lambda t,z: rhs_fit(t,z,p),(ts[0],ts[-1]),y0,t_eval=ts,
                      rtol=1e-6,atol=1e-8,max_step=0.5)
    except Exception:
        # best-effort: any solver error is reported to the caller as NaNs
        return nan_block
    return sol.y if sol.success else nan_block

# --- load data
df=pd.read_csv(DATA)
# Use the first H proxy column that is present, in H_COLS preference order.
Hcol=next((c for c in H_COLS if c in df.columns), None)
# Fail early with a clear message: a None column name would otherwise only
# surface as an opaque KeyError in the column selection below.
if Hcol is None: raise ValueError(f"No usable H proxy column found (need one of {H_COLS})")
for c in SCFA_COLS:
    if c not in df.columns: raise ValueError(f"Missing {c}")
df=df[["subject_id","sample_id",Hcol]+SCFA_COLS].dropna(subset=["subject_id","sample_id"]).copy()
# Time axis is the per-subject running sample index (file order), not clock time.
df["t_idx"]=df.groupby("subject_id").cumcount().astype(float)

def robust_z(s):
    """Robust z-score of a Series: (x - median) / scale, index preserved.

    The scale is the MAD of the finite values; when the MAD is <= 1e-9 it
    falls back to the interquartile range, and to the standard deviation
    (+1e-9) when that is zero too.  Non-finite inputs stay NaN, except that a
    series without any finite value is returned as all zeros.
    """
    values=s.astype(float).to_numpy()
    finite=values[np.isfinite(values)]
    if finite.size==0:
        return pd.Series(np.zeros_like(values), index=s.index)
    center=np.median(finite)
    spread=np.median(np.abs(finite-center))
    if spread<=1e-9:
        iqr=np.percentile(finite,75)-np.percentile(finite,25)
        spread=iqr if iqr else np.std(finite)+1e-9
    return pd.Series((values-center)/(spread+1e-9), index=s.index)

# Per-subject robust z-score of the first SCFA column; H proxy clipped to [0, 1].
df["B_z"]=df.groupby("subject_id")[SCFA_COLS[0]].transform(robust_z)
df["H_obs"]=df[Hcol].clip(0,1)

def first(a,d):
    """Return the first finite entry of ``a`` as a float, or ``d`` if none exists."""
    a=np.asarray(a,float); idx=np.where(np.isfinite(a))[0]
    return float(a[idx[0]]) if len(idx) else float(d)

# One record per usable subject: time axis, observations, finite-value masks
# (fixed once so the residual vector length never changes) and initial H/B.
subs=[]
for sid,sub in df.groupby("subject_id"):
    sub=sub.sort_values("t_idx").copy()
    # use the declared threshold instead of a repeated magic number
    if len(sub)<MIN_ROWS: continue
    t=sub["t_idx"].values.astype(float)
    B=sub["B_z"].values.astype(float)
    H=sub["H_obs"].values.astype(float)
    mB=np.isfinite(B); mH=np.isfinite(H)
    # need at least three finite observations of each signal
    if mB.sum()<3 or mH.sum()<3: continue
    subs.append({"sid":sid,"t":t,"B":B,"H":H,"maskB":mB,"maskH":mH,
                 "nB":int(mB.sum()),"nH":int(mH.sum()),
                 "H0":float(np.clip(first(H,0.6),0,1)),
                 "B0":float(max(0.05, first(B,0.1)))})
if not subs: raise RuntimeError("no usable subjects")

# per-subject obs maps
# Three observation-map parameters per subject (unpacked as aB, b0, b1 in
# residuals): start value, lower bound and upper bound for each.
x0s=[];LBs=[];UBs=[]
for _ in subs:
    x0s += [1.0, 0.0, 1.0]
    LBs += [0.6, -0.2, 0.8]
    UBs += [1.6,  0.2, 1.2]

# Full optimisation vector: 16 globals followed by the per-subject triples.
x0=np.concatenate([x0g, np.array(x0s,float)])
LB=np.concatenate([LBg, np.array(LBs,float)])
UB=np.concatenate([UBg, np.array(UBs,float)])

# Names of the global parameters, in the positional order used everywhere.
NAMES=["r0P","rHP","r0C","K_M","gamma","c","d","g","u","K_u","p_low","p_high","H_on","H_off","tau_q","K_B"]
def unpack(x):
    """Split the optimisation vector into (globals, per-subject chunks).

    The first len(NAMES) entries are the global parameters; the remainder is
    cut into len(subs) equal parts, one per subject.
    """
    n_glob=len(NAMES)
    per_subject=np.split(x[n_glob:], len(subs))
    return x[:n_glob], per_subject

# Residual weights for the butyrate (B) and host (H) data terms.
W_B,W_H=0.6,1.2
# Fixed residual length: all finite B and H observations, one baseline prior
# per PRIOR entry, plus the five targeted priors appended in residuals().
TOT = sum(S["nB"]+S["nH"] for S in subs) + len(PRIOR) + 5  # 5 targeted priors

def residuals(x):
    """Residual vector for least_squares: data misfit plus priors.

    Layout (always TOT entries): per subject the weighted B residuals then
    the weighted H residuals at the finite observations, followed by one
    baseline-prior term per PRIOR entry and the five targeted priors
    (u, K_u, gamma, p_high, rHP).  Infeasible parameter sets and failed
    simulations are replaced by PENALTY-valued entries of the same length.
    """
    gpar,triples=unpack(x)
    ordered = gpar[13] > gpar[12]          # H_off must exceed H_on
    if (not ordered) or gpar[9] <= 0.06 or gpar[15] <= 0.06:
        return np.full(TOT, PENALTY)
    q_switch = 0.5*(gpar[12]+gpar[13])
    pieces=[]
    for S,(aB,b0,b1) in zip(subs, triples):
        H0=np.clip(S["H0"],0,1); B0=max(0.05,S["B0"])
        # start with q=1 when H begins below the midpoint of the two thresholds
        q0 = 1.0 if H0 < q_switch else 0.0
        Y=simulate(S["t"],[0.12,0.12,H0,B0,q0],gpar)
        if not np.all(np.isfinite(Y)):
            pieces.append(W_B*np.full(S["nB"],PENALTY))
            pieces.append(W_H*np.full(S["nH"],PENALTY))
            continue
        H_sim,B_sim = Y[2],Y[3]
        # per-subject observation maps: scaled B, affine H clipped to [0, 1]
        B_hat = aB*B_sim
        H_hat = np.clip(b0 + b1*H_sim,0,1)
        mB,mH = S["maskB"],S["maskH"]
        pieces.append(W_B*(B_hat[mB] - S["B"][mB]).astype(float))
        pieces.append(W_H*(H_hat[mH] - S["H"][mH]).astype(float))
    # baseline priors
    pos={nm:i for i,nm in enumerate(NAMES)}
    for nm,(mu,sd) in PRIOR.items():
        pieces.append( np.array([(gpar[pos[nm]]-mu)/(sd+1e-9)]) )
    # targeted priors
    for nm in ("u","K_u","gamma","p_high","rHP"):
        pieces.append( np.array([(gpar[pos[nm]] - TARGETS[nm])/(TARGETS["sd_"+nm]+1e-9)]) )
    return np.concatenate(pieces)

# Bounded robust least squares (soft_l1 loss) over globals + per-subject maps.
fit = least_squares(residuals, x0, bounds=(LB,UB), verbose=2, max_nfev=1400, loss="soft_l1", f_scale=1.0)
gpar_hat,triples_hat=unpack(fit.x)
# Persist the fitted globals as headerless name,value rows.
pd.Series(gpar_hat,index=NAMES).to_csv(os.path.join(FIT_OUT,"fitted_global_params.csv"), header=False)

print("[fit] globals:", dict(zip(NAMES, gpar_hat)))

# === Recheck bifurcation at the new fit using n=TARGETS['n'], KQ=TARGETS['KQ'] ===
def rhs_check(y,pvec,dval):
    """Model right-hand side used for the post-fit bifurcation check.

    Hill exponent and memory steepness come from TARGETS["n"] and
    TARGETS["KQ"] (both cast to int); dval, when not None, replaces the
    decay rate stored in pvec.  Returns [dP, dC, dH, dB, dq].
    """
    hill_n=int(TARGETS["n"]); steep=int(TARGETS["KQ"])
    P,C,H,B,q=y
    (r0P,rHP,r0C,K_M,gamma,c,d_base,gH,u,K_u,
     pL,pH,H_on,H_off,tau,K_B) = pvec
    d = d_base if dval is None else dval
    pB=pL+(pH-pL)*np.clip(q,0,1)
    uptake=u*H*B/(K_u + B + 1e-9)
    # q-dependent switching threshold and the logistic target for q
    theta=(1-q)*H_on + q*H_off
    q_target=1.0/(1.0+np.exp(-steep*((H - theta))))
    hill=B**hill_n/(K_B**hill_n + B**hill_n)
    derivs=[P*( r0P + rHP*H - c*pB - (P + gamma*C)/K_M ),
            C*( r0C - (C + gamma*P)/K_M ),
            gH*hill*(1-H) - d*H,
            pB*P - uptake,
            (q_target-q)/tau]
    return np.array(derivs,float)

def jac_fd(fun,y,args,eps=1e-7):
    """Forward-difference Jacobian of ``fun(y, *args)`` with respect to ``y``.

    The matrix size follows ``y`` and the value of ``fun`` rather than being
    fixed at 5x5.  Column i is (fun(y + eps*e_i) - fun(y)) / eps.
    """
    y=np.asarray(y,float)
    f0=np.asarray(fun(y,*args),float)
    J=np.zeros((f0.size,y.size))
    for i in range(y.size):
        y2=y.copy(); y2[i]+=eps
        J[:,i]=(np.asarray(fun(y2,*args),float)-f0)/eps
    return J

from scipy.optimize import root
def find_eq(pvec,dval,guess):
    """Search for one equilibrium of ``rhs_check`` at decay rate ``dval``.

    Runs ``scipy.optimize.root`` (method "hybr") from ``guess`` and returns
    ``(y, True)`` with the state clamped to P, C, B >= 0 and H, q in
    [0, 1.2], or ``(None, False)`` when no trustworthy equilibrium is found.
    """
    f=lambda yy: rhs_check(yy,pvec,dval)
    sol=root(f, guess, method="hybr")
    if not sol.success: return None,False
    y=np.asarray(sol.x,float)
    # Test finiteness BEFORE clamping: max(0, nan) evaluates to 0, so a NaN
    # in P, C or B would otherwise be turned into a valid-looking zero.
    if not np.all(np.isfinite(y)): return None,False
    y=np.array([max(0,y[0]), max(0,y[1]), np.clip(y[2],0,1.2), max(0,y[3]), np.clip(y[4],0,1.2)], float)
    # Clamping may move the point off the root; keep it only if the clamped
    # state still satisfies rhs_check(y) ~ 0.
    if not np.all(np.abs(f(y))<=1e-6): return None,False
    return y,True

# Sweep d from 0.7x to 1.6x the freshly fitted baseline; every grid value is
# solved independently from the same four seed states.
d_fit_new=float(gpar_hat[6])
ds=np.linspace(0.7*d_fit_new, 1.6*d_fit_new, 120)
seeds=[np.array([0.12,0.12,0.3,0.1,1.0]), np.array([0.05,0.2,0.85,0.1,0.0]),
       np.array([0.3,0.05,0.55,0.1,0.6]), np.array([0.15,0.15,0.65,0.1,0.4])]

rows=[]
for dval in ds:
    for wi,y0 in enumerate(seeds):
        y,ok=find_eq(gpar_hat,dval,y0)
        if not ok: continue
        # stable when all Jacobian eigenvalues have negative real part
        J=jac_fd(rhs_check,y,(gpar_hat,dval))
        stable=(np.max(np.real(npl.eigvals(J)))<0)
        rows.append({"d":dval,"H":float(y[2]),"q":float(y[4]),"seed":wi,"stable":stable})
branches=pd.DataFrame(rows)
branches.to_csv(os.path.join(BIF_OUT,"branches.csv"), index=False)

# Bifurcation diagram: H* against d, by seed, with stable/unstable overlay.
plt.figure(figsize=(7.2,5.0))
for wi in sorted(branches["seed"].unique()):
    sub=branches[branches["seed"]==wi]
    plt.plot(sub["d"], sub["H"], ".", ms=3, alpha=0.7, label=f"seed{wi}")
for st, mk in [(True,"o"),(False,"x")]:
    sub=branches[branches["stable"]==st]
    plt.scatter(sub["d"], sub["H"], s=22, marker=mk, alpha=0.6, label=("stable" if st else "unstable"))
plt.axvline(d_fit_new, ls="--", c="gray", label="baseline d")
plt.xlabel("d (1/h)"); plt.ylabel("H*"); plt.legend(); plt.grid(True, ls=":")
plt.tight_layout(); plt.savefig(os.path.join(BIF_OUT,"bifurcation_H_vs_d.png"), dpi=180); plt.close()

# Count distinct stable equilibria at the baseline.  Equilibria are only
# comparable at one and the same d: the previous np.isclose(..., atol=1e-3)
# window spans several grid points, so one branch sampled at neighbouring d
# values (H* differing by more than 1e-3) could be counted twice.  Restrict
# the count to the single grid value closest to d_fit_new.
distinct=0
near=branches.iloc[0:0]
if not branches.empty:
    d_near=ds[np.argmin(np.abs(ds-d_fit_new))]
    near=branches[branches["d"]==d_near]
    Hs=np.sort(near.loc[near["stable"],"H"].to_numpy(float))
    if Hs.size:
        # sorted stable H values further than 1e-3 apart are distinct states
        distinct=1+int(np.sum(np.diff(Hs)>1e-3))
bistable=(distinct>=2)
with open(os.path.join(BIF_OUT,"diagnosis.txt"),"w") as f:
    f.write(f"Baseline d = {d_fit_new:.5f}\n")
    f.write(f"Distinct stable equilibria at baseline: {distinct}\n")
    f.write(f"Bistable at baseline? {'YES' if bistable else 'NO'}\n")
print("Saved:", BIF_OUT, "| Bistable at baseline? ", "YES" if bistable else "NO")


In [None]:
# bifurcation_guild_use_targets.py
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.optimize import root
import numpy.linalg as npl

# Input: previously fitted globals; output directory for this check.
FIT = "mw_fit_out_guild_hill_satpf/fitted_global_params.csv"
OUT = "mw_bif_guild_use_targets"; os.makedirs(OUT, exist_ok=True)

# <<< paste your chosen row from the scan here >>>
TARGETS = {
    "n": 3,    # Hill exponent for host benefit
    "KQ": 80,  # memory steepness
    "u": 0.77,
    "K_u": 0.18,
    "gamma": 1.25,
    "p_high": 3.20,
    "rHP": 0.0344,
    # optionally, if your top row used d_eval != d_fit, set it here:
    "d_override": None,   # e.g., 0.9 * d_fit  -> put a float to test
}

g = pd.read_csv(FIT, index_col=0, header=None).squeeze("columns")
# p = [r0P,rHP,r0C,K_M,gamma,c,d,g,u,K_u,p_low,p_high,H_on,H_off,tau_q,K_B]
# NOTE(review): values are read in file row order; the positional layout
# above is assumed to match the CSV -- confirm against the writer.
p = np.array([float(g[k]) for k in g.index.values], float)

# Override with targets
# Only the five scanned parameters are replaced; all others keep fitted values.
idx = {"r0P":0,"rHP":1,"r0C":2,"K_M":3,"gamma":4,"c":5,"d":6,"g":7,"u":8,"K_u":9,
       "p_low":10,"p_high":11,"H_on":12,"H_off":13,"tau_q":14,"K_B":15}
p[idx["rHP"]]   = TARGETS["rHP"]
p[idx["u"]]     = TARGETS["u"]
p[idx["K_u"]]   = TARGETS["K_u"]
p[idx["gamma"]] = TARGETS["gamma"]
p[idx["p_high"]]= TARGETS["p_high"]

# d_fit is the d value that is tested below: the fitted d unless overridden.
d_fit = float(p[6])
if TARGETS.get("d_override") is not None:
    d_fit = float(TARGETS["d_override"])

def q_inf(H,q,H_on,H_off,KQ):
    """Logistic switch in H with a threshold interpolated between H_on and H_off by q."""
    switch_at=(1-q)*H_on + q*H_off
    arg=-KQ*(H - switch_at)
    denom=1.0 + np.exp(arg)
    return 1.0/denom

def rhs(y,pvec,dval):
    """Right-hand side of the 5-state model with n and KQ taken from TARGETS.

    dval, when not None, replaces the decay rate d stored in pvec.
    Returns [dP, dC, dH, dB, dq] as a float array.
    """
    hill_n = int(TARGETS["n"])
    steep  = int(TARGETS["KQ"])
    P,C,H,B,q = y
    (r0P,rHP,r0C,K_M,gamma,c,d_base,gH,u,K_u,
     pL,pH,H_on,H_off,tau,K_B) = pvec
    d = d_base if dval is None else dval
    prod = pL + (pH - pL)*np.clip(q,0,1)
    # ecology
    dP = P*( r0P + rHP*H - c*prod - (P + gamma*C)/K_M )
    dC = C*( r0C - (C + gamma*P)/K_M )
    # host & butyrate
    dB = prod*P - u*H*B/(K_u + B + 1e-9)
    hill = B**hill_n/(K_B**hill_n + B**hill_n)
    dH = gH*hill*(1 - H) - d*H
    # memory
    dq = (q_inf(H,q,H_on,H_off,steep) - q)/tau
    return np.array([dP,dC,dH,dB,dq], float)

def jac_fd(fun,y,args,eps=1e-7):
    """Forward-difference Jacobian of ``fun(y, *args)`` with respect to ``y``.

    Sized from ``y`` and the value of ``fun`` instead of a fixed 5x5 matrix.
    Column i is (fun(y + eps*e_i) - fun(y)) / eps.
    """
    y=np.asarray(y,float)
    f0=np.asarray(fun(y,*args),float)
    J=np.zeros((f0.size,y.size))
    for i in range(y.size):
        y2=y.copy(); y2[i]+=eps
        J[:,i]=(np.asarray(fun(y2,*args),float)-f0)/eps
    return J

def find_eq(pvec,dval,guess):
    """Search for one equilibrium of ``rhs`` at decay rate ``dval``.

    Runs ``scipy.optimize.root`` (method "hybr") from ``guess`` and returns
    ``(y, True)`` with the state clamped to P, C, B >= 0 and H, q in
    [0, 1.2], or ``(None, False)`` when no trustworthy equilibrium is found.
    """
    f=lambda yy: rhs(yy,pvec,dval)
    sol=root(f, guess, method="hybr")
    if not sol.success: return None,False
    y=np.asarray(sol.x,float)
    # Test finiteness BEFORE clamping: max(0, nan) evaluates to 0, so a NaN
    # in P, C or B would otherwise be turned into a valid-looking zero.
    if not np.all(np.isfinite(y)): return None,False
    y=np.array([max(0,y[0]), max(0,y[1]), np.clip(y[2],0,1.2), max(0,y[3]), np.clip(y[4],0,1.2)], float)
    # Clamping may move the point off the root; keep it only if the clamped
    # state still satisfies rhs(y) ~ 0.
    if not np.all(np.abs(f(y))<=1e-6): return None,False
    return y,True

# continuation in d around (possibly overridden) baseline
# NOTE: each d on the grid is solved independently from the same seeds.
ds=np.linspace(0.7*d_fit, 1.6*d_fit, 120)
seeds=[np.array([0.12,0.12,0.3,0.1,1.0]),
       np.array([0.05,0.2,0.9,0.1,0.0]),
       np.array([0.3,0.05,0.55,0.1,0.6]),
       np.array([0.15,0.15,0.65,0.1,0.4])]
rows=[]
for d in ds:
    for wi,y0 in enumerate(seeds):
        y,ok=find_eq(p,d,y0)
        if not ok: continue
        # largest real part of the Jacobian spectrum; negative means stable
        lam_max = np.max(np.real(npl.eigvals(jac_fd(rhs,y,(p,d)))))
        rows.append({"d":d,"H":float(y[2]),"seed":wi,"stable":(lam_max<0)})

branches=pd.DataFrame(rows)
os.makedirs(OUT, exist_ok=True)
branches.to_csv(os.path.join(OUT,"branches.csv"), index=False)

# Bifurcation diagram: H* against d, by seed, with stable/unstable overlay.
plt.figure(figsize=(7.2,5.0))
for wi in sorted(branches["seed"].unique()):
    sub=branches[branches["seed"]==wi]
    plt.plot(sub["d"], sub["H"], ".", ms=3, alpha=0.7, label=f"seed{wi}")
for st, mk in [(True,"o"),(False,"x")]:
    sub=branches[branches["stable"]==st]
    plt.scatter(sub["d"], sub["H"], s=22, marker=mk, alpha=0.6, label=("stable" if st else "unstable"))
plt.axvline(d_fit, ls="--", c="gray", label="baseline d (tested)")
plt.xlabel("d (1/h)"); plt.ylabel("H*"); plt.legend(); plt.grid(True, ls=":")
plt.tight_layout(); plt.savefig(os.path.join(OUT,"bifurcation_H_vs_d.png"), dpi=180); plt.close()

# Count distinct stable equilibria at the tested d.  Equilibria are only
# comparable at one and the same d: the previous np.isclose(..., atol=1e-3)
# window spans several grid points, so one branch sampled at neighbouring d
# values (H* differing by more than 1e-3) could be counted twice.  Restrict
# the count to the single grid value closest to d_fit.
distinct=0
near=branches.iloc[0:0]
if not branches.empty:
    d_near=ds[np.argmin(np.abs(ds-d_fit))]
    near=branches[branches["d"]==d_near]
    Hs=np.sort(near.loc[near["stable"],"H"].to_numpy(float))
    if Hs.size:
        # sorted stable H values further than 1e-3 apart are distinct states
        distinct=1+int(np.sum(np.diff(Hs)>1e-3))
bistable=(distinct>=2)
with open(os.path.join(OUT,"diagnosis.txt"),"w") as f:
    f.write(f"d_tested = {d_fit:.5f}\n")
    f.write(f"Distinct stable equilibria at this d: {distinct}\n")
    f.write(f"Bistable at tested d? {'YES' if bistable else 'NO'}\n")
print("Saved:", OUT, "| Bistable at tested d? ", "YES" if bistable else "NO")


In [None]:
# calibrate_guild_hard_targets.py
import os, numpy as np, pandas as pd
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares

# Input table and output directory for the hard-target calibration.
INPATH = "timeseries/combined_scfas_table_scored.csv"
OUTDIR = "mw_fit_out_guild_hard_targets"; os.makedirs(OUTDIR, exist_ok=True)

# === anchor around the pocket that gave 2 stable eq in the scan ===
# Each target has a centre and an "sd_<name>" width; residuals() multiplies
# the resulting prior terms by WT.
TARGETS = {
    "u": 0.77,     "sd_u": 0.03,    # tighter
    "K_u": 0.18,   "sd_K_u": 0.03,
    "gamma": 1.25, "sd_gamma": 0.10,
    "p_high": 3.20,"sd_p_high": 0.20,
    "rHP": 0.0344, "sd_rHP": 0.02,
    # If your best pocket appeared at d_eval different from baseline, anchor d too:
    # "d": 0.054,   "sd_d": 0.006,
}

# Candidate H proxy columns (first one present is used) and the SCFA column.
H_COLS=["H_proxy_meta_smooth","H_proxy_meta"]
SCFA=["butyrate"]; MIN_ROWS=4
KQ=80.0; HILL_N=3
PENALTY=1e3

# Baseline priors (kept modest)
# name -> (mean, sd)
PRIOR={"r0P":(0.32,0.08),"rHP":(0.07,0.04),"r0C":(0.28,0.08),"K_M":(1.0,0.25),
       "gamma":(0.85,0.25),"c":(0.12,0.05),"d":(0.12,0.05),"g":(0.60,0.30),
       "u":(0.60,0.15),"K_u":(0.20,0.08),"p_low":(0.12,0.06),"p_high":(2.20,0.60),
       "H_on":(0.60,0.08),"H_off":(0.86,0.04),"tau_q":(5.0,2.0),"K_B":(0.20,0.08)}

# Start from your previous two-guild bounds, but **shrink** key ones around TARGETS
def band(c, w_lo, w_hi):
    """Return the interval (c - w_lo, c + w_hi) around the centre c."""
    lower = c - w_lo
    upper = c + w_hi
    return lower, upper
# Narrow bounds for the anchored parameters: (centre - w_lo, centre + w_hi).
uL,uU   = band(TARGETS["u"],   0.05, 0.05)
KuL,KuU = band(TARGETS["K_u"], 0.06, 0.06)
gaL,gaU = band(TARGETS["gamma"], 0.20, 0.15)
pHL,pHU = band(TARGETS["p_high"], 0.40, 0.30)
# other bounds (safe defaults)
# Positional order: r0P,rHP,r0C,K_M,gamma,c,d,g,u,K_u,p_low,p_high,H_on,H_off,tau_q,K_B
LBg=np.array([0.18,0.00,0.15,0.55,gaL,0.06,0.06,0.20,uL,KuL,0.06,pHL,0.50,0.80,1.0,0.10])
UBg=np.array([0.46,0.14,0.40,1.60,gaU,0.20,0.22,1.40,uU,KuU,0.28,pHU,0.74,0.92,10.,0.40])
# Optional: also narrow d if you know the pocket needed d slightly lower/higher.
if "d" in TARGETS:
    LBg[6], UBg[6] = band(TARGETS["d"], 0.01, 0.01)

# Start point: anchored parameters begin at their target centres.
x0g=np.array([0.32,0.05,0.28,1.0,1.10,0.12,0.10,0.70, TARGETS["u"], TARGETS["K_u"],
              0.12, TARGETS["p_high"], 0.60,0.86,5.5,0.20], float)

# --- data prep (same as before; omitted plotting for brevity) ---
df=pd.read_csv(INPATH)
# Use the first H proxy column that is present, in H_COLS preference order.
Hcol=next((c for c in H_COLS if c in df.columns), None)
if Hcol is None: raise ValueError("Need H proxy col")
for c in SCFA:
    if c not in df.columns: raise ValueError(f"Missing {c}")
df=df[["subject_id","sample_id",Hcol]+SCFA].dropna(subset=["subject_id","sample_id"]).copy()
# Time axis is the per-subject running sample index (file order), not clock time.
df["t_idx"]=df.groupby("subject_id").cumcount().astype(float)

def robust_z(s):
    """Median/MAD z-score of a Series, keeping the original index.

    Scale fallback chain: MAD, then interquartile range when the MAD is
    <= 1e-9, then standard deviation (+1e-9) when the IQR is zero.  Non-finite
    entries stay NaN; a series with no finite entry becomes all zeros.
    """
    raw=s.astype(float).to_numpy()
    good=np.isfinite(raw)
    if not good.any():
        return pd.Series(np.zeros_like(raw), index=s.index)
    kept=raw[good]
    med=np.median(kept)
    mad=np.median(np.abs(kept-med))
    if mad>1e-9:
        scale=mad
    else:
        iqr=np.percentile(kept,75)-np.percentile(kept,25)
        scale=iqr or np.std(kept)+1e-9
    return pd.Series((raw-med)/(scale+1e-9), index=s.index)

# Per-subject robust z-score of the SCFA column; H proxy clipped to [0, 1].
df["B_z"]=df.groupby("subject_id")[SCFA[0]].transform(robust_z)
df["H_obs"]=df[Hcol].clip(0,1)

def first(a,d):
    """Return the first finite entry of ``a`` as a float, or ``d`` if none exists."""
    a=np.asarray(a,float)
    finite_pos=np.flatnonzero(np.isfinite(a))
    if len(finite_pos):
        return float(a[finite_pos[0]])
    return float(d)

# One record per usable subject: time axis, observations, finite-value masks
# and the initial H / B used to start the simulation.
subs=[]
for sid,sub in df.groupby("subject_id"):
    sub=sub.sort_values("t_idx")
    if len(sub)<MIN_ROWS: continue
    t=sub["t_idx"].to_numpy(float)
    B=sub["B_z"].to_numpy(float)
    H=sub["H_obs"].to_numpy(float)
    mB=np.isfinite(B)
    mH=np.isfinite(H)
    nB=int(mB.sum()); nH=int(mH.sum())
    # need at least three finite observations of each signal
    if nB<3 or nH<3: continue
    H_start=float(np.clip(first(H,0.6),0,1))
    B_start=float(max(0.05, first(B,0.1)))
    subs.append({"sid":sid,"t":t,"B":B,"H":H,"maskB":mB,"maskH":mH,
                 "nB":nB,"nH":nH,"H0":H_start,"B0":B_start})
if not subs: raise RuntimeError("no subjects")

# --- model (same RHS as before for calibration; n=3, KQ=80) ---
def q_inf(H,q,H_on,H_off,KQ):
    """Steady-state target of q: logistic in H around a q-weighted threshold."""
    level=(1-q)*H_on + q*H_off
    return 1.0/(1.0 + np.exp(-KQ*(H - level)))

def rhs(t,y,p):
    """Calibration right-hand side in solve_ivp form (t is unused).

    State y = (P, C, H, B, q); p holds the 16 global parameters.  The Hill
    exponent is read from the module-level HILL_N (previously a hard-coded 3
    that silently ignored the configuration constant) and the memory
    steepness from the module-level KQ.  Returns [dP, dC, dH, dB, dq].
    """
    P,C,H,B,q=y
    r0P,rHP,r0C,K_M,gamma,c,d,gH,u,K_u,pL,pH,H_on,H_off,tau,K_B = p
    # production rate interpolates between pL and pH with q clipped to [0, 1]
    pB = pL + (pH - pL)*np.clip(q,0,1)
    dP = P*( r0P + rHP*H - c*pB - (P + gamma*C)/K_M )
    dC = C*( r0C           -        (C + gamma*P)/K_M )
    # saturable uptake; the 1e-9 keeps the denominator away from zero
    uptake = u*H*B/(K_u + B + 1e-9)
    dB = pB*P - uptake
    n = HILL_N
    dH = gH*(B**n/(K_B**n + B**n))*(1 - H) - d*H
    dq = (q_inf(H,q,H_on,H_off,KQ) - q)/tau
    return [dP,dC,dH,dB,dq]

def simulate(ts,y0,p):
    """Integrate rhs from y0 and sample the solution at the times ts.

    Returns the (5, len(ts)) state array, or an all-NaN array of that shape
    when the solver reports failure or raises.
    """
    fallback=np.vstack([np.full(len(ts),np.nan)]*5)
    try:
        sol=solve_ivp(lambda t,z: rhs(t,z,p),(ts[0],ts[-1]),y0,t_eval=ts,
                      rtol=1e-6,atol=1e-8,max_step=0.5)
    except Exception:
        # best-effort: any solver error is reported to the caller as NaNs
        return fallback
    if sol.success:
        return sol.y
    return fallback

# per-subject obs maps
# Three observation-map parameters per subject (unpacked as aB, b0, b1 in
# residuals): start value, lower bound and upper bound for each.
x0s=[]; LBs=[]; UBs=[]
for _ in subs: x0s += [1.0, 0.0, 1.0]; LBs += [0.6, -0.2, 0.8]; UBs += [1.6, 0.2, 1.2]
# Full optimisation vector: 16 globals followed by the per-subject triples.
x0=np.concatenate([x0g, np.array(x0s,float)])
LB=np.concatenate([LBg, np.array(LBs,float)])
UB=np.concatenate([UBg, np.array(UBs,float)])

# Names of the global parameters, in the positional order used everywhere.
NAMES=["r0P","rHP","r0C","K_M","gamma","c","d","g","u","K_u","p_low","p_high","H_on","H_off","tau_q","K_B"]
def unpack(x):
    """Split the optimisation vector into (globals, per-subject chunks).

    The leading len(NAMES) entries are the global parameters; the rest is
    divided into len(subs) equal parts, one per subject.
    """
    cut=len(NAMES)
    head,tail=x[:cut],x[cut:]
    return head, np.split(tail, len(subs))

# Residual weights for the butyrate (B) and host (H) data terms.
W_B,W_H=0.6,1.2
# heavier weights for target priors
# WT multiplies each TARGETS prior residual in residuals().
WT=6.0

def residuals(x):
    """Residual vector for least_squares: data misfit, priors and hard targets.

    Layout: per subject the weighted B residuals then the weighted H
    residuals at the finite observations, one baseline-prior term per PRIOR
    entry, five WT-weighted target terms (u, K_u, gamma, p_high, rHP) and an
    optional sixth for d when TARGETS contains "d".

    The infeasible-parameter branch now returns a penalty vector of exactly
    that length.  It used to return a fixed 1000 entries, which does not
    match the regular residual length that least_squares expects on every
    call.
    """
    gpar, triples = unpack(x)
    n_total = (sum(S["nB"]+S["nH"] for S in subs) + len(PRIOR) + 5
               + (1 if "d" in TARGETS else 0))
    # H_off must exceed H_on for the hysteresis window to make sense
    if not (gpar[13] > gpar[12]): return np.full(n_total, float(PENALTY))
    res=[]
    for S,tr in zip(subs, triples):
        aB,b0,b1=tr
        ts=S["t"]; H0=np.clip(S["H0"],0,1); B0=max(0.05,S["B0"])
        # start with q=1 when H begins below the midpoint of the two thresholds
        P0=C0=0.12; q0=1.0 if H0 < 0.5*(gpar[12]+gpar[13]) else 0.0
        y0=[P0,C0,H0,B0,q0]
        Y=simulate(ts,y0,gpar)
        if np.any(~np.isfinite(Y)):
            # failed simulation: same number of entries, all at the penalty
            res += [W_B*np.full(S["nB"],PENALTY), W_H*np.full(S["nH"],PENALTY)]
            continue
        P,C,H,B,q=Y
        # per-subject observation maps: scaled B, affine H clipped to [0, 1]
        Bh=aB*B; Hh=np.clip(b0 + b1*H,0,1)
        res += [W_B*(Bh[S["maskB"]] - S["B"][S["maskB"]]),
                W_H*(Hh[S["maskH"]] - S["H"][S["maskH"]])]
    # modest baseline priors
    idx={nm:i for i,nm in enumerate(NAMES)}
    for nm,(mu,sd) in PRIOR.items():
        res.append(np.array([(gpar[idx[nm]]-mu)/(sd+1e-9)]))
    # HARD targets
    for nm in ("u","K_u","gamma","p_high","rHP"):
        res.append(WT*np.array([(gpar[idx[nm]] - TARGETS[nm])/(TARGETS["sd_"+nm]+1e-9)]))
    if "d" in TARGETS:
        res.append(WT*np.array([(gpar[idx["d"]] - TARGETS["d"])/(TARGETS["sd_d"]+1e-9)]))
    return np.concatenate(res)

# Bounded robust least squares (soft_l1 loss) over globals + per-subject maps.
fit=least_squares(residuals, x0, bounds=(LB,UB), verbose=2, max_nfev=1500, loss="soft_l1", f_scale=1.0)
gpar_hat,_=unpack(fit.x)
# Persist the fitted globals as headerless name,value rows.
pd.Series(gpar_hat,index=NAMES).to_csv(os.path.join(OUTDIR,"fitted_global_params.csv"), header=False)
print("[info] globals:", dict(zip(NAMES, gpar_hat)))
print("Saved:", OUTDIR)


In [None]:
# bifurcation_basins_from_hard_fit.py
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.optimize import root
import numpy.linalg as npl
from scipy.integrate import solve_ivp

FIT = "mw_fit_out_guild_hard_targets/fitted_global_params.csv"   # <- your latest fit
OUT = "mw_bif_guild_hard_fit"; os.makedirs(OUT, exist_ok=True)

# Curvature / memory for the check (tune if needed)
# N_HILL and KQ are read as module globals by rhs() and q_inf() below.
N_HILL = 4
KQ     = 80

# Load fitted globals
g = pd.read_csv(FIT, index_col=0, header=None).squeeze("columns")
# p = [r0P,rHP,r0C,K_M,gamma,c,d,g,u,K_u,p_low,p_high,H_on,H_off,tau_q,K_B]
# NOTE(review): values are read in file row order; the positional layout
# above is assumed to match the CSV -- confirm against the writer.
p = np.array([float(g[k]) for k in g.index.values], float)
d_fit = float(p[6])

def q_inf(H,q,H_on,H_off):
    """Logistic target for q; steepness is the module-level KQ.

    The threshold moves linearly from H_on (q=0) to H_off (q=1).
    """
    threshold=(1-q)*H_on + q*H_off
    exponent=-KQ*(H - threshold)
    return 1.0/(1.0 + np.exp(exponent))

def rhs(y,pvec,d_override=None):
    """Right-hand side of the 5-state model (P, C, H, B, q).

    Uses the module-level N_HILL as Hill exponent; d_override, when not
    None, replaces the decay rate d stored in pvec.
    Returns [dP, dC, dH, dB, dq] as a float array.
    """
    P,C,H,B,q = y
    (r0P,rHP,r0C,K_M,gamma,c,d_base,gH,u,K_u,
     pL,pH,H_on,H_off,tau,K_B) = pvec
    d = d_base if d_override is None else d_override
    prod = pL + (pH - pL)*np.clip(q,0,1)
    # ecology
    dP = P*( r0P + rHP*H - c*prod - (P + gamma*C)/K_M )
    dC = C*( r0C - (C + gamma*P)/K_M )
    # butyrate & host
    dB = prod*P - u*H*B/(K_u + B + 1e-9)
    hill = B**N_HILL/(K_B**N_HILL + B**N_HILL)
    dH = gH*hill*(1 - H) - d*H
    # memory
    dq = (q_inf(H,q,H_on,H_off) - q)/tau
    return np.array([dP,dC,dH,dB,dq], float)

def jac_fd(fun,y,args,eps=1e-7):
    """Forward-difference Jacobian of ``fun(y, *args)`` with respect to ``y``.

    Sized from ``y`` and the value of ``fun`` instead of a fixed 5x5 matrix.
    Column i is (fun(y + eps*e_i) - fun(y)) / eps.
    """
    y=np.asarray(y,float)
    f0=np.asarray(fun(y,*args),float)
    J=np.zeros((f0.size,y.size))
    for i in range(y.size):
        y2=y.copy(); y2[i]+=eps
        J[:,i]=(np.asarray(fun(y2,*args),float)-f0)/eps
    return J

def find_eq(pvec,dval,guess):
    """Search for one equilibrium of ``rhs`` at decay rate ``dval``.

    Runs ``scipy.optimize.root`` (method "hybr") from ``guess`` and returns
    ``(y, True)`` with the state clamped to P, C, B >= 0 and H, q in
    [0, 1.2], or ``(None, False)`` when no trustworthy equilibrium is found.
    """
    f=lambda yy: rhs(yy,pvec,dval)
    sol=root(f, guess, method="hybr")
    if not sol.success:
        return None, False
    y=np.asarray(sol.x,float)
    # Test finiteness BEFORE clamping: max(0, nan) evaluates to 0, so a NaN
    # in P, C or B would otherwise be turned into a valid-looking zero.
    if not np.all(np.isfinite(y)): return None, False
    y=np.array([max(0,y[0]), max(0,y[1]),
                np.clip(y[2],0,1.2), max(0,y[3]), np.clip(y[4],0,1.2)], float)
    # Clamping may move the point off the root; keep it only if the clamped
    # state still satisfies rhs(y) ~ 0.
    if not np.all(np.abs(f(y))<=1e-6): return None, False
    return y, True

# ---- Continuation in d
# NOTE: each d on the grid is solved independently from the same seeds.
d_vals = np.linspace(0.7*d_fit, 1.6*d_fit, 140)
seeds = [
    np.array([0.12,0.12,0.30,0.10,1.0]),
    np.array([0.05,0.20,0.90,0.12,0.0]),
    np.array([0.30,0.08,0.55,0.10,0.6]),
    np.array([0.15,0.15,0.65,0.12,0.4]),
]
rows=[]
for d in d_vals:
    for wi,y0 in enumerate(seeds):
        y, ok = find_eq(p, d, y0)
        if not ok: continue
        # largest real part of the Jacobian spectrum; negative means stable
        lam_max = np.max(np.real(npl.eigvals(jac_fd(rhs, y, (p,d)))))
        rows.append({"d":d, "H":float(y[2]), "seed":wi, "stable":(lam_max<0)})

branches = pd.DataFrame(rows)
branches.to_csv(os.path.join(OUT,"branches.csv"), index=False)

# Plot bifurcation
# H* against d, by seed, with stable ("o") / unstable ("x") overlay.
plt.figure(figsize=(7.6,5.2))
for wi in sorted(branches["seed"].unique()):
    sub=branches[branches["seed"]==wi]
    plt.plot(sub["d"], sub["H"], ".", ms=3, alpha=0.7, label=f"seed{wi}")
for st, mk in [(True,"o"),(False,"x")]:
    sub=branches[branches["stable"]==st]
    plt.scatter(sub["d"], sub["H"], s=24, marker=mk, alpha=0.7, label=("stable" if st else "unstable"))
plt.axvline(d_fit, ls="--", c="gray", label="baseline d")
plt.xlabel("d (1/h)"); plt.ylabel("H*"); plt.legend(); plt.grid(True, ls=":")
plt.tight_layout(); plt.savefig(os.path.join(OUT,"bifurcation_H_vs_d.png"), dpi=180); plt.close()

# Count distinct stable equilibria at baseline
# Equilibria are only comparable at one and the same d.  The previous
# np.isclose(..., atol=1e-3) window spans several grid points, so a single
# branch sampled at neighbouring d values (whose H* differ by more than 1e-3)
# could be counted as two "distinct" equilibria.  Restrict the count to the
# one grid value closest to d_fit instead.
distinct=0
near = branches.iloc[0:0]
if not branches.empty:
    d_near = d_vals[np.argmin(np.abs(d_vals-d_fit))]
    near = branches[branches["d"]==d_near]
    Hs = np.sort(near.loc[near["stable"],"H"].to_numpy(float))
    if Hs.size:
        # sorted stable H values further than 1e-3 apart are distinct states
        distinct = 1+int(np.sum(np.diff(Hs)>1e-3))
bistable = (distinct>=2)

# ---- Basins at baseline
def relax(y0,T=360):
    """Integrate the model at the baseline d_fit from y0 over [0, T].

    Returns the state at the last stored time point.
    """
    sample_times=np.linspace(0,T,900)
    flow=lambda t,z: rhs(z,p,d_fit)
    run=solve_ivp(flow,(0,T),y0,t_eval=sample_times,
                  rtol=1e-6, atol=1e-8, max_step=0.5)
    end_state=run.y[:,-1]
    return end_state

# Basin scan: relax the model from a 19x19 grid of initial (H, q) values with
# fixed initial P, C, B and record the final H of each run.
Hs = np.linspace(0.20, 0.95, 19)
qs = np.linspace(0.0, 1.0, 19)
Z = np.zeros((len(Hs), len(qs)))
for i,H0 in enumerate(Hs):
    for j,q0 in enumerate(qs):
        y0 = np.array([0.12,0.12,H0,0.10,q0], float)
        yss = relax(y0)
        Z[i,j] = yss[2]

# Rows of Z follow initial H (y axis), columns follow initial q (x axis).
plt.figure(figsize=(6.8,5.4))
plt.imshow(Z, origin="lower", extent=[qs[0],qs[-1],Hs[0],Hs[-1]],
           aspect="auto", vmin=0.4, vmax=1.0, cmap="viridis")
plt.colorbar(label="Final H (steady)")
plt.xlabel("initial q"); plt.ylabel("initial H")
plt.title(f"Basins at baseline d={d_fit:.3f} | Bistable? {'YES' if bistable else 'NO'}")
plt.tight_layout(); plt.savefig(os.path.join(OUT,"basins_heatmap.png"), dpi=180); plt.close()

# Plain-text summary of the baseline equilibrium count.
with open(os.path.join(OUT,"diagnosis.txt"),"w") as f:
    f.write(f"Baseline d = {d_fit:.5f}\n")
    f.write(f"Distinct stable equilibria at baseline: {distinct}\n")
    f.write(f"Bistable at baseline? {'YES' if bistable else 'NO'}\n")

print("Saved ->", OUT, "| Bistable at baseline? ", "YES" if bistable else "NO")


In [None]:
# export_equilibria_and_basins_fallback.py
# Robust equilibria export with basin-scan fallback for stiff/hysteretic pockets.

import os, numpy as np, pandas as pd, numpy.linalg as npl
from scipy.optimize import root
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

from mw_model_constants import FIT_PATH, N_HILL, KQ, D_OVERRIDE

OUT = "mw_eq_export"
os.makedirs(OUT, exist_ok=True)

# ----- Load fitted globals -----
# FIT_PATH is read without a header row; the first column becomes the index and
# the remaining column is squeezed to a Series of values.
g = pd.read_csv(FIT_PATH, index_col=0, header=None).squeeze("columns")
# Parameter vector, in file order:
# p = [r0P,rHP,r0C,K_M,gamma,c,d,g,u,K_u,p_low,p_high,H_on,H_off,tau_q,K_B]
p = np.array([float(g[name]) for name in g.index.values], float)
# An explicit override replaces the fitted d (index 6).
if D_OVERRIDE is not None:
    p[6] = float(D_OVERRIDE)
d_fit = float(p[6])

# ----- Model -----
def q_inf(H, q, H_on, H_off):
    """Return the logistic target value for q given H.

    The switching threshold is the q-weighted blend (1-q)*H_on + q*H_off, so it
    moves from H_on at q = 0 to H_off at q = 1; the module constant KQ is the
    logistic steepness.
    """
    threshold = (1 - q) * H_on + q * H_off
    exponent = -KQ * (H - threshold)
    return 1.0 / (1.0 + np.exp(exponent))

def rhs(y, pvec):
    """Evaluate the time derivative of the five-component state y = [P, C, H, B, q].

    pvec holds the 16 parameters in the order
    [r0P, rHP, r0C, K_M, gamma, c, d, gH, u, K_u, pL, pH, H_on, H_off, tau, K_B].
    Returns a float array [dP, dC, dH, dB, dq] in the same order as the state.
    """
    P, C, H, B, q = y
    (r0P, rHP, r0C, K_M, gamma, c, d, gH, u, K_u,
     pL, pH, H_on, H_off, tau, K_B) = pvec

    # production rate interpolates between pL and pH with q limited to [0, 1]
    pB = pL + (pH - pL) * np.clip(q, 0, 1)

    # ecology: growth minus production cost (P only) minus shared crowding
    crowding_P = (P + gamma * C) / K_M
    crowding_C = (C + gamma * P) / K_M
    dP = P * (r0P + rHP * H - c * pB - crowding_P)
    dC = C * (r0C - crowding_C)

    # butyrate & host: saturating uptake (1e-9 keeps the denominator non-zero at B = -K_u)
    uptake = u * H * B / (K_u + B + 1e-9)
    dB = pB * P - uptake
    B_pow = B ** N_HILL
    hill = B_pow / (K_B ** N_HILL + B_pow)
    dH = gH * hill * (1 - H) - d * H

    # q relaxes towards its smooth target on time scale tau
    dq = (q_inf(H, q, H_on, H_off) - q) / tau
    return np.array([dP, dC, dH, dB, dq], float)

def jac_fd(fun, y, args=(), eps=1e-6):
    """Forward-difference Jacobian of `fun` at `y`.

    Parameters
    ----------
    fun : callable
        Called as fun(y, *args); must return a 1-D vector.
    y : array-like
        Point at which to differentiate. Any length is accepted and the input
        is converted to a float array, so lists and integer arrays work too.
    args : tuple
        Extra positional arguments forwarded to `fun`.
    eps : float
        Step added to one coordinate at a time.

    Returns
    -------
    ndarray of shape (len(fun(y)), len(y)); column i holds
    (fun(y + eps*e_i) - fun(y)) / eps.
    """
    y = np.asarray(y, dtype=float)
    f0 = np.asarray(fun(y, *args), dtype=float)
    # Size the matrix from the actual input/output lengths instead of assuming 5x5.
    J = np.zeros((f0.size, y.size))
    for i in range(y.size):
        y2 = y.copy()
        y2[i] += eps
        J[:, i] = (np.asarray(fun(y2, *args), dtype=float) - f0) / eps
    return J

def clamp_state(y):
    """Return a float copy of the first five entries of y, limited to the allowed ranges.

    Entries 0, 1 and 3 are floored at 0; entries 2 and 4 are clipped to [0, 1.2].
    """
    def _floor_at_zero(v):
        # keeps v only when it is strictly positive
        return v if v > 0 else 0

    upper = 1.2
    clamped = [
        _floor_at_zero(y[0]),
        _floor_at_zero(y[1]),
        np.clip(y[2], 0, upper),
        _floor_at_zero(y[3]),
        np.clip(y[4], 0, upper),
    ]
    return np.array(clamped, float)

# ----- Root-finding attempt (multi-start) -----
# Each seed is a full initial state [P, C, H, B, q] handed to a 'hybr' root solve.
seeds = [
    np.array([0.12,0.12,0.30,0.10,1.0]),
    np.array([0.05,0.20,0.90,0.12,0.0]),
    np.array([0.30,0.08,0.55,0.10,0.6]),
    np.array([0.15,0.15,0.65,0.12,0.4]),
    np.array([0.25,0.05,0.75,0.15,0.8]),
    np.array([0.08,0.25,0.85,0.10,0.2]),
]
rows=[]
for s in seeds:
    sol = root(lambda yy: rhs(yy, p), s, method="hybr")
    # Test finiteness on the raw solution: clamp_state floors entries with
    # max(0, v), which silently maps NaN to 0 and would hide a bad root.
    if not sol.success or not np.all(np.isfinite(sol.x)):
        continue
    y = clamp_state(sol.x)
    lam = np.real(npl.eigvals(jac_fd(lambda z: rhs(z,p), y)))
    rows.append({"P":y[0],"C":y[1],"H":y[2],"B":y[3],"q":y[4],
                 "lam_max":float(np.max(lam)),"stable":bool(np.max(lam)<0)})

# Explicit columns keep the frame well-formed when no seed converges; without
# them sort_values("H") raises KeyError and the basin-scan fallback is never reached.
eqs = pd.DataFrame(
    rows, columns=["P","C","H","B","q","lam_max","stable"]
).sort_values("H").reset_index(drop=True)

# De-duplicate by H
def dedup_by_H(df, tol=1e-3):
    """Drop rows whose H lies within `tol` of the previously kept row.

    Rows are scanned in their existing order (the caller sorts by H first). The
    first row is always kept; each later row is kept only if its H differs from
    the last kept H by more than `tol`. An empty frame is returned unchanged;
    otherwise the result has a fresh 0..n-1 index.
    """
    if df.empty:
        return df
    survivors = []
    for _, row in df.iterrows():
        if not survivors or abs(row["H"] - survivors[-1]["H"]) > tol:
            survivors.append(row)
    return pd.DataFrame(survivors).reset_index(drop=True)

eqs = dedup_by_H(eqs)

# If we already have ≥2 stables, save & plot basins and return
def save_all_and_exit(eqs_df, note="root-based"):
    """Write the equilibria table and a basins heatmap into OUT.

    eqs_df is saved as `equilibria.csv`. A 25x25 grid of initial (H, q) values
    is then integrated to T = 900 and the final H of each run is drawn as
    `basins.png`; `note` appears in the plot title and the console message.

    Despite its name the function returns normally; call sites raise
    SystemExit themselves when they want to stop.
    """
    eqs_df.to_csv(f"{OUT}/equilibria.csv", index=False)

    def _final_H(H0, q0, T=900):
        # integrate from a fixed (P, C, B) start and report the last sampled H
        start = np.array([0.12, 0.12, H0, 0.10, q0], float)
        run = solve_ivp(lambda t, z: rhs(z, p), (0, T), start,
                        t_eval=np.linspace(0, T, 1200),
                        rtol=1e-6, atol=1e-8, max_step=0.5)
        return run.y[:, -1][2]

    # Basins heatmap: rows follow H0, columns follow q0.
    Hs = np.linspace(0.15, 0.95, 25)
    qs = np.linspace(0.0, 1.0, 25)
    Z = np.zeros((len(Hs), len(qs)))
    for i, j in np.ndindex(Z.shape):
        Z[i, j] = _final_H(Hs[i], qs[j])

    plt.figure(figsize=(6.6, 5.2))
    plt.imshow(Z, origin="lower", extent=[qs[0], qs[-1], Hs[0], Hs[-1]],
               aspect="auto", cmap="viridis")
    plt.colorbar(label="final H")
    plt.xlabel("q0")
    plt.ylabel("H0")
    plt.title(f"Basins @ baseline (exported via {note})")
    plt.tight_layout()
    plt.savefig(f"{OUT}/basins.png")
    plt.close()
    print(f"Saved -> {OUT} | equilibria.csv + basins.png ({note})")

# Two or more stable roots already demonstrate bistability: export them and stop here.
n_stable_roots = 0 if eqs.empty else eqs["stable"].sum()
if n_stable_roots >= 2:
    save_all_and_exit(eqs, note="root-only")
    raise SystemExit

# ----- Fallback: basin scan to infer two attractors -----
print("[info] Root-finder recovered <2 stable eq. Trying basin-scan fallback...")
def relax(y0, T=1200):
    """Integrate the model from y0 over [0, T] and return the clamped last sampled state.

    The trajectory is sampled at 1400 evenly spaced times; the final column is
    passed through clamp_state. Solver success is not checked.
    """
    def _field(t, z):
        return rhs(z, p)

    sample_times = np.linspace(0, T, 1400)
    trajectory = solve_ivp(_field, (0, T), y0, t_eval=sample_times,
                           rtol=1e-6, atol=1e-8, max_step=0.5)
    return clamp_state(trajectory.y[:, -1])

# Dense 29x29 grid of initial (H, q) values; increase the resolution if needed.
# Every start is relaxed to T = 1400 and its clamped endpoint collected, H-major.
Hs = np.linspace(0.12, 0.98, 29)
qs = np.linspace(0.0,  1.0,  29)
endpoints = [
    relax(np.array([0.12, 0.12, H0, 0.10, q0], float), T=1400)
    for H0 in Hs
    for q0 in qs
]
EP = np.array(endpoints)          # shape (N,5)
H_end = EP[:, 2]                  # final H of every run
H_sorted = np.sort(H_end)

# Detect a gap in the sorted final H values that would split them into two clusters.
gaps = np.diff(H_sorted)
no_clear_split = gaps.size == 0 or np.max(gaps) < 0.03   # require at least ~0.03 separation
if no_clear_split:
    # No two-basin structure: export whatever single-state information exists and stop.
    if eqs.empty:
        # No root was found either: fabricate a single 'equilibrium' by relaxing
        # from a mid-range initial state.
        y_mid = relax(np.array([0.12,0.12,0.5,0.10,0.5],float))
        lam = np.real(npl.eigvals(jac_fd(lambda z: rhs(z,p), y_mid)))
        lam_top = np.max(lam)
        df = pd.DataFrame([{"P":y_mid[0],"C":y_mid[1],"H":y_mid[2],"B":y_mid[3],"q":y_mid[4],
                            "lam_max":float(lam_top),"stable":bool(lam_top<0)}])
        save_all_and_exit(df, note="fabricated-single (no split)")
    else:
        save_all_and_exit(eqs, note="root-only (monostable)")
    raise SystemExit

# Split the endpoints at the midpoint of the largest gap in sorted final H.
k = np.argmax(gaps)
H_cut = 0.5*(H_sorted[k] + H_sorted[k+1])
low_idxs  = np.flatnonzero(H_end <= H_cut)
high_idxs = np.flatnonzero(H_end >  H_cut)
# Component-wise mean state of each cluster.
low_mean  = EP[low_idxs].mean(axis=0)
high_mean = EP[high_idxs].mean(axis=0)
# representatives nearest to cluster means
def nearest_idx(X, v):
    """Return the index of the row of X with the smallest squared Euclidean distance to v."""
    offsets = X - v
    sq_dist = (offsets ** 2).sum(axis=1)
    return np.argmin(sq_dist)
# Pick, inside each cluster, the actual endpoint closest to the cluster mean.
i_low  = nearest_idx(EP[low_idxs],  low_mean)
i_high = nearest_idx(EP[high_idxs], high_mean)
y_low0  = EP[low_idxs[i_low]].copy()
y_high0 = EP[high_idxs[i_high]].copy()

# short refinement via root (optional but helpful)
def refine(y_start):
    """Polish a candidate equilibrium with a root solve and report its leading eigenvalue.

    A 'hybr' root solve of rhs is started from y_start. When it succeeds the
    clamped solution is used; when it fails or raises, the clamped starting
    point is kept instead (best-effort refinement).

    Returns (y, lam_max): the state and the largest real part among the
    eigenvalues of the finite-difference Jacobian at y.
    """
    y = clamp_state(y_start)  # fallback used when the solve fails or raises
    try:
        sol = root(lambda yy: rhs(yy, p), y_start, method="hybr")
        if sol.success:
            y = clamp_state(sol.x)
    except Exception:
        # deliberate best-effort: keep the unrefined point
        pass
    eig_real = np.real(npl.eigvals(jac_fd(lambda z: rhs(z, p), y)))
    return y, float(np.max(eig_real))

y_low,  lam_low  = refine(y_low0)
y_high, lam_high = refine(y_high0)

# One export row per equilibrium, using the same columns as the root-based table.
fallback_rows = [
    {"P":y_low[0],"C":y_low[1],"H":y_low[2],"B":y_low[3],"q":y_low[4],
     "lam_max":lam_low,  "stable": bool(lam_low  < 0)},
    {"P":y_high[0],"C":y_high[1],"H":y_high[2],"B":y_high[3],"q":y_high[4],
     "lam_max":lam_high, "stable": bool(lam_high < 0)},
]

# (optional) attempt to locate an intermediate saddle by root from the midpoint state
mid_guess = 0.5*(y_low + y_high)
try:
    sol_mid = root(lambda yy: rhs(yy,p), mid_guess, method="hybr")
    if sol_mid.success and np.all(np.isfinite(sol_mid.x)):
        y_mid = clamp_state(sol_mid.x)
        # The solve often falls back onto one of the two attractors; only a point
        # whose H differs from both by more than 1e-3 (the tolerance dedup_by_H
        # uses) is exported, so the table has no duplicate rows.
        h_sep = min(abs(y_mid[2] - y_low[2]), abs(y_mid[2] - y_high[2]))
        if h_sep > 1e-3:
            lam_mid = np.real(npl.eigvals(jac_fd(lambda z: rhs(z,p), y_mid)))
            fallback_rows.append(
                {"P":y_mid[0],"C":y_mid[1],"H":y_mid[2],"B":y_mid[3],"q":y_mid[4],
                 "lam_max":float(np.max(lam_mid)),"stable": bool(np.max(lam_mid)<0)})
except Exception as err:
    # The saddle search is optional: report the failure instead of hiding it,
    # then export the two attractors that were found.
    print(f"[warn] intermediate-saddle search failed: {err}")

eq_fallback = pd.DataFrame(fallback_rows).sort_values("H").reset_index(drop=True)

save_all_and_exit(eq_fallback, note="basin-fallback")


In [None]:
# 07_mw_baseline_export.py
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from mw_model_core import load_params, find_equilibria, relax_to_ss
from mw_model_constants import N_HILL, KQ

OUT = "results/baseline_export"
os.makedirs(OUT, exist_ok=True)
p = load_params()

# Equilibria table
eqs = find_equilibria(p)
eqs.to_csv(f"{OUT}/equilibria.csv", index=False)

# Basins heatmap in (H0, q0): relax every grid start to T = 900 and keep the
# third component of the returned steady state (plotted as "final H").
Hs = np.linspace(0.15, 0.95, 25)
qs = np.linspace(0.0, 1.0, 25)
Z = np.zeros((len(Hs), len(qs)))
for i, j in np.ndindex(Z.shape):
    y0 = np.array([0.12, 0.12, Hs[i], 0.10, qs[j]], float)
    yss, _ = relax_to_ss(p, y0, T=900)
    Z[i, j] = yss[2]

# Rows of Z follow H0 (vertical axis); columns follow q0 (horizontal axis).
plt.figure(figsize=(6.6, 5.2))
plt.imshow(Z, origin="lower", extent=[qs[0], qs[-1], Hs[0], Hs[-1]],
           aspect="auto", cmap="viridis")
plt.colorbar(label="final H")
plt.xlabel("q0")
plt.ylabel("H0")
plt.title(f"Basins @ baseline | N_HILL={N_HILL}, KQ={KQ}")
plt.tight_layout()
plt.savefig(f"{OUT}/basins.png")
plt.close()

print("Saved ->", OUT)


In [None]:
# 01_mw_phase_diagrams.py
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from mw_model_core import load_params, bistability_by_multistart

OUT = "results/phase_diagrams"; os.makedirs(OUT, exist_ok=True)
p0 = load_params()

def scan_pair(change_fn, grid_x, grid_y, label_x, label_y, tag, n_inits=80):
    """Map bistability over a two-parameter grid and save a CSV plus a heatmap.

    change_fn(p_base, x, y) -> new_p receives a copy of the baseline vector p0
    and returns the parameter vector to test. A grid cell is marked bistable
    when bistability_by_multistart reports at least two states.

    Outputs in OUT: `{tag}_grid.csv` (one row per cell, columns label_x,
    label_y, "bistable") and `{tag}_heatmap.png`.
    """
    n_rows = len(grid_y)
    A = np.zeros((n_rows, len(grid_x)))  # 1 if bistable else 0; rows follow grid_y
    records = []
    for j, y in enumerate(grid_y):
        for i, x in enumerate(grid_x):
            trial = change_fn(p0.copy(), x, y)
            n_states, _ = bistability_by_multistart(trial, n_inits=n_inits)
            A[j, i] = 1 if n_states >= 2 else 0
            records.append({label_x: x, label_y: y, "bistable": bool(A[j, i])})
        print(f"[{tag}] row {j+1}/{len(grid_y)} done.")

    table = pd.DataFrame(records)
    table.to_csv(os.path.join(OUT, f"{tag}_grid.csv"), index=False)

    plt.figure(figsize=(6.6, 5.2))
    plt.imshow(A, origin="lower",
               extent=[grid_x[0], grid_x[-1], grid_y[0], grid_y[-1]],
               aspect="auto", cmap="Blues")
    plt.colorbar(label="bistable (1) / mono (0)")
    plt.xlabel(label_x)
    plt.ylabel(label_y)
    plt.title(f"Bistability area: {tag} (fraction={A.mean():.2f})")
    plt.tight_layout()
    plt.savefig(os.path.join(OUT, f"{tag}_heatmap.png"))
    plt.close()
    print(f"[{tag}] bistable fraction = {A.mean():.3f}")

# 1) (d, p_high)
dx = np.linspace(0.7*p0[6], 1.6*p0[6], 20)
py = np.linspace(0.6*p0[11], 1.4*p0[11], 20)
def ch1(p, x, y):
    """Set d (index 6) to x and p_high (index 11) to y in place; return the same vector."""
    p[6] = x
    p[11] = y
    return p
scan_pair(ch1, dx, py, "d", "p_high", "d_vs_p_high")

# 2) (gamma, p_high)
gx = np.linspace(0.6*p0[4], 1.6*p0[4], 20)
def ch2(p, x, y):
    """Set gamma (index 4) to x and p_high (index 11) to y in place; return the same vector."""
    p[4] = x
    p[11] = y
    return p
scan_pair(ch2, gx, py, "gamma", "p_high", "gamma_vs_p_high")

# 3) (N_HILL, KQ) – treat as integers; here: 2..6 and 60..140
nx = np.array([2,3,4,5,6], int)
kqx = np.array([60,80,100,120,140], int)
def ch3(p, x, y):
    """Return p unchanged; x and y are ignored.

    For the (N_HILL, KQ) scan the two values are passed to
    bistability_by_multistart as keyword arguments instead.
    """
    return p
def scan_pair_hill_kq(nx, kqx, tag):
    """Map bistability over (N_HILL, KQ) at the baseline parameters.

    Every (nv, KQv) pair is forwarded to bistability_by_multistart through the
    N_HILL_local / KQ_local keyword arguments, with 60 initial conditions and a
    copy of p0. A cell is marked bistable when at least two states are reported.

    Outputs in OUT: `{tag}_grid.csv` and `{tag}_heatmap.png`.
    """
    A = np.zeros((len(kqx), len(nx)))  # rows follow KQ, columns follow N_HILL
    records = []
    for j, KQv in enumerate(kqx):
        for i, nv in enumerate(nx):
            n_states, _ = bistability_by_multistart(
                p0.copy(), n_inits=60, KQ_local=KQv, N_HILL_local=nv)
            A[j, i] = 1 if n_states >= 2 else 0
            records.append({"N_HILL": int(nv), "KQ": int(KQv), "bistable": bool(A[j, i])})
        print(f"[{tag}] row {j+1}/{len(kqx)} done.")

    table = pd.DataFrame(records)
    table.to_csv(os.path.join(OUT, f"{tag}_grid.csv"), index=False)

    plt.figure(figsize=(6.6, 5.2))
    plt.imshow(A, origin="lower", extent=[nx[0], nx[-1], kqx[0], kqx[-1]],
               aspect="auto", cmap="Blues")
    plt.colorbar(label="bistable (1) / mono (0)")
    plt.xlabel("N_HILL")
    plt.ylabel("KQ")
    plt.title(f"Bistability area: {tag} (fraction={A.mean():.2f})")
    plt.tight_layout()
    plt.savefig(os.path.join(OUT, f"{tag}_heatmap.png"))
    plt.close()
    print(f"[{tag}] bistable fraction = {A.mean():.3f}")

scan_pair_hill_kq(nx, kqx, "N_HILL_vs_KQ")


In [None]:
# 02_mw_attractor_counting.py
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from mw_model_core import load_params, relax_to_ss

OUT = "results/attractor_counting"; os.makedirs(OUT, exist_ok=True)
p = load_params()

def many_inits(pvec, n=200, T=1200, tag="baseline"):
    """Relax n random initial states and estimate the number of distinct attractors.

    Initial states are drawn uniformly from a fixed five-component box with a
    generator that is re-seeded with 1 on every call. Each start is handed to
    relax_to_ss; the third component of every endpoint (column "H" in the CSV)
    is sorted, and each gap of at least 0.03 between neighbours counts as one
    extra state.

    Outputs in OUT: `{tag}_endpoints.csv` and `{tag}_trajectories.png`
    (about 30 evenly spaced H trajectories). Returns the estimate as an int.
    """
    rng = np.random.default_rng(1)
    box = ((0.05,0.30),(0.05,0.30),(0.10,0.95),(0.05,0.20),(0.0,1.0))

    endpoints = []
    H_trajs = []
    for _ in range(n):
        y0 = np.array([rng.uniform(lo, hi) for (lo, hi) in box], float)
        yss, sol = relax_to_ss(pvec, y0, T=T)
        endpoints.append(yss)
        H_trajs.append(sol.y[2, :])

    EP = np.array(endpoints)
    gaps = np.diff(np.sort(EP[:, 2]))
    n_states = 1 + np.sum(gaps >= 0.03)
    endpoint_table = pd.DataFrame(EP, columns=["P","C","H","B","q"])
    endpoint_table.to_csv(os.path.join(OUT, f"{tag}_endpoints.csv"), index=False)

    # small panel of trajectories, evenly spaced through the run list
    plt.figure(figsize=(7, 4))
    picks = np.linspace(0, len(H_trajs) - 1, 30, dtype=int)
    for idx in picks:
        plt.plot(H_trajs[idx], alpha=0.4)
    plt.xlabel("time step")
    plt.ylabel("H(t)")
    plt.title(f"Attractor test ({tag}) | distinct≈{n_states}")
    plt.tight_layout()
    plt.savefig(os.path.join(OUT, f"{tag}_trajectories.png"))
    plt.close()
    print(f"[{tag}] distinct steady states (by gaps in H): ~{int(n_states)}")
    return int(n_states)

# Baseline
many_inits(p, n=200, T=1200, tag="baseline")

# Slight offsets around baseline (±10%): only d (index 6) is rescaled.
p1 = p.copy()
p1[6] = 0.9*p[6]   # d down
p2 = p.copy()
p2[6] = 1.1*p[6]   # d up
for shifted, label in ((p1, "d_minus10"), (p2, "d_plus10")):
    many_inits(shifted, n=150, tag=label)


In [None]:
# 03_mw_sensitivity_curves.py
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from mw_model_core import bistability_by_multistart, load_params

OUT = "results/sensitivity_curves"; os.makedirs(OUT, exist_ok=True)
p0 = load_params()

def frac_bistable_vs_param(idx, vals, label, tag):
    """Sweep one parameter and record, per value, whether the model is bistable.

    For every v in vals a copy of p0 gets entry `idx` replaced by v and is
    tested with bistability_by_multistart (80 initial conditions); the result
    is 1.0 when at least two states are reported, else 0.0.

    Outputs in OUT: `{tag}.csv` (columns `label`, "bistable") and `{tag}.png`.
    """
    fracs = []
    for v in vals:
        trial = p0.copy()
        trial[idx] = v
        n_states, _ = bistability_by_multistart(trial, n_inits=80)
        fracs.append(1.0 if n_states >= 2 else 0.0)

    table = pd.DataFrame({label: vals, "bistable": fracs})
    table.to_csv(os.path.join(OUT, f"{tag}.csv"), index=False)

    plt.figure(figsize=(6.8, 4.2))
    plt.plot(vals, fracs, "-o")
    plt.xlabel(label)
    plt.ylabel("bistable fraction")
    plt.ylim(-0.05, 1.05)
    plt.grid(True, ls=":")
    plt.title(tag)
    plt.tight_layout()
    plt.savefig(os.path.join(OUT, f"{tag}.png"))
    plt.close()
    print(f"[{tag}] mean={np.mean(fracs):.3f}")

# d (index 6)
d_vals = np.linspace(0.6*p0[6], 1.6*p0[6], 24)
frac_bistable_vs_param(6, d_vals, "d", "bistability_vs_d")

# gamma (index 4)
g_vals = np.linspace(0.6*p0[4], 1.6*p0[4], 24)
frac_bistable_vs_param(4, g_vals, "gamma", "bistability_vs_gamma")


In [None]:
# 04_mw_gamma_heterogeneity.py
import os, numpy as np, pandas as pd
from mw_model_core import load_params, bistability_by_multistart

OUT = "results/gamma_heterogeneity"; os.makedirs(OUT, exist_ok=True)
p0 = load_params()
rng = np.random.default_rng(7)

def experiment(mean_gamma, sd_gamma, trials=20):
    """Return the fraction of random gamma draws that yield a bistable model.

    Each trial draws gamma from Normal(mean_gamma, sd_gamma) using the
    module-level rng, clips it to [0.2, 3.0], writes it into index 4 of a copy
    of p0 and tests it with bistability_by_multistart (60 initial conditions).
    """
    outcomes = []
    for _ in range(trials):
        draw = rng.normal(mean_gamma, sd_gamma)
        trial = p0.copy()
        trial[4] = float(np.clip(draw, 0.2, 3.0))
        n_states, _ = bistability_by_multistart(trial, n_inits=60)
        outcomes.append(1 if n_states >= 2 else 0)
    return np.mean(outcomes)

# Sweep the mean (0.6x .. 1.6x baseline gamma) and the standard deviation of the
# gamma distribution; every cell averages 30 random draws.
means = np.linspace(0.6*p0[4], 1.6*p0[4], 8)
sds   = np.linspace(0.0, 0.4, 6)
rows = [
    {"mean_gamma": m, "sd_gamma": s, "bistable_frac": experiment(m, s, trials=30)}
    for m in means
    for s in sds
]
df = pd.DataFrame(rows)
df.to_csv(os.path.join(OUT, "gamma_heterogeneity_grid.csv"), index=False)
print("Saved ->", OUT)


In [None]:
# 05_mw_epistasis_analog.py
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from mw_model_core import load_params, bistability_by_multistart

OUT = "results/epistasis_analog"; os.makedirs(OUT, exist_ok=True)
p0 = load_params()

# We vary rHP (index 1) for the Producer, and emulate an artificial
# "competitor host coupling" (sign analog) through r0C (index 2).
# A dynamic version would shift the effective r0C by k*H during integration, or
# add delta to r0C only when H is high at steady state (a two-stage test).
# Neither is implemented here. For the area maps we use a simpler static proxy:
# a direct parameter offset to r0C: neutral (0), beneficial (+δ), costly (-δ).

def area_bistable_over_grid(deltaC, rHP_scale_grid, tag):
    """Scan (rHP scale, p_high) with r0C shifted by deltaC and report the bistable area.

    deltaC in {-dlt, 0, +dlt} approximates sign epistasis. For every scale s in
    rHP_scale_grid and each of 12 p_high values between 0.7x and 1.4x baseline,
    a copy of p0 is modified (index 1 scaled by s, index 11 set to p_high,
    index 2 shifted by deltaC) and tested with bistability_by_multistart.

    Outputs in OUT: `{tag}.csv` and `{tag}.png`; the bistable fraction is printed.
    """
    p_high_axis = np.linspace(0.7*p0[11], 1.4*p0[11], 12)
    records = []
    cnt = 0
    total = 0
    for s in rHP_scale_grid:
        for ph in p_high_axis:
            trial = p0.copy()
            trial[1] = s * p0[1]          # rHP scaled
            trial[11] = ph
            trial[2] = p0[2] + deltaC     # competitor base growth shift (proxy for sign)
            n_states, _ = bistability_by_multistart(trial, n_inits=60)
            is_bi = (n_states >= 2)
            records.append({"rHP_scale": s, "p_high": ph, "bistable": bool(is_bi)})
            cnt += int(is_bi)
            total += 1

    df = pd.DataFrame(records)
    df.to_csv(os.path.join(OUT, f"{tag}.csv"), index=False)
    frac = 0.0 if total <= 0 else cnt/total
    print(f"[{tag}] bistable fraction={frac:.3f}")

    # quick plot: rebuild the grid from the table, keeping the first row per cell
    S = sorted(df["rHP_scale"].unique())
    P = sorted(df["p_high"].unique())
    first_flag = {}
    for s_val, ph_val, flag in zip(df["rHP_scale"], df["p_high"], df["bistable"]):
        first_flag.setdefault((s_val, ph_val), bool(flag))
    W = np.zeros((len(P), len(S)))
    for j, ph in enumerate(P):
        for i, s in enumerate(S):
            W[j, i] = 1.0 if first_flag[(s, ph)] else 0.0

    plt.figure(figsize=(6.4, 5.0))
    plt.imshow(W, origin="lower", extent=[min(S), max(S), min(P), max(P)],
               aspect="auto", cmap="Blues")
    plt.colorbar(label="bistable")
    plt.xlabel("rHP scale")
    plt.ylabel("p_high")
    plt.title(tag + f" (fraction={frac:.2f})")
    plt.tight_layout()
    plt.savefig(os.path.join(OUT, f"{tag}.png"))
    plt.close()

rHP_scales = np.linspace(0.5, 1.8, 12)
dlt = 0.05

area_bistable_over_grid( 0.0, rHP_scales, "magnitude_only")
area_bistable_over_grid(+dlt, rHP_scales, "sign_beneficial_for_competitor")
area_bistable_over_grid(-dlt, rHP_scales, "sign_costly_for_competitor")


In [None]:
# 06_mw_metacommunity_mosaic.py
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from mw_model_core import load_params, relax_to_ss

OUT = "results/metacommunity"; os.makedirs(OUT, exist_ok=True)
p = load_params()
rng = np.random.default_rng(11)

# Grid of patches: each cell is relaxed independently (no coupling between patches).
G = (12, 12)  # 12x12 patches
H_final = np.zeros(G)

for i, j in np.ndindex(G):
    # Checkerboard of two typical starts, with small Gaussian jitter, to
    # populate both basins: low-H/high-q on even cells, high-H/low-q on odd cells.
    H_centre, q_centre = (0.25, 0.8) if (i + j) % 2 == 0 else (0.75, 0.2)
    H0 = H_centre + 0.05*rng.normal()
    q0 = q_centre + 0.05*rng.normal()
    y0 = np.array([0.12, 0.12, np.clip(H0, 0.1, 0.95), 0.10, np.clip(q0, 0, 1)], float)
    yss, _ = relax_to_ss(p, y0, T=1100)
    H_final[i, j] = yss[2]

np.save(os.path.join(OUT, "H_mosaic.npy"), H_final)
plt.figure(figsize=(6.2, 5.2))
plt.imshow(H_final, origin="lower", cmap="viridis", vmin=0.1, vmax=1.0)
plt.colorbar(label="final H*")
plt.title("Metacommunity mosaic (independent patches)")
plt.tight_layout()
plt.savefig(os.path.join(OUT, "mosaic.png"))
plt.close()

print("Saved ->", OUT)


In [None]:
# 08_mw_bootstrap_robustness.py
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from mw_model_core import load_params, bistability_by_multistart

OUT = "results/bootstrap_robustness"; os.makedirs(OUT, exist_ok=True)
p0 = load_params()
rng = np.random.default_rng(23)

def area_fraction_dp(p_base, n_x=10, n_y=10, jitter=None):
    """Return the bistable fraction of an (d, p_high) grid around p_base.

    d (index 6) spans 0.8x..1.3x and p_high (index 11) spans 0.7x..1.3x of the
    values in p_base, with n_x and n_y points respectively. Each cell is tested
    with bistability_by_multistart (50 initial conditions) and counts as
    bistable when at least two states are reported.

    `jitter` is accepted for backward compatibility but is not used.
    An empty grid (n_x == 0 or n_y == 0) yields 0.0 rather than dividing by
    zero, matching the guard in area_bistable_over_grid.
    """
    dx = np.linspace(0.8*p_base[6], 1.3*p_base[6], n_x)     # d
    ph = np.linspace(0.7*p_base[11], 1.3*p_base[11], n_y)   # p_high
    cnt = 0
    tot = 0
    for x in dx:
        for y in ph:
            p = p_base.copy()
            p[6] = x
            p[11] = y
            n_states, _ = bistability_by_multistart(p, n_inits=50)
            cnt += int(n_states >= 2)
            tot += 1
    return cnt / tot if tot > 0 else 0.0

def jitter_params(p, cv=0.08):
    """Return a copy of p with multiplicative noise on selected entries.

    Entries g(7), u(8), K_u(9) and p_high(11) are each multiplied by
    exp(N(0, cv)), drawn in that order from the module-level rng. The input
    vector is left untouched.
    """
    jittered = p.copy()
    for k in (7, 8, 9, 11):
        factor = float(np.exp(rng.normal(0, cv)))
        jittered[k] = jittered[k] * factor
    return jittered

# Bootstrap: B jittered parameter sets, each scored by its bistable area
# fraction on a 9x9 (d, p_high) grid.
B = 40
vals = []
for draw_no in range(1, B + 1):
    pB = jitter_params(p0, cv=0.10)  # 10% lognormal CV
    frac = area_fraction_dp(pB, n_x=9, n_y=9)
    vals.append(frac)
    print(f"[bootstrap] {draw_no}/{B}: {frac:.3f}")

pd.DataFrame({"area_fraction": vals}).to_csv(
    os.path.join(OUT, "bootstrap_area_fractions.csv"), index=False)

# Trace of the area fraction per draw, with the mean as a dashed reference line.
plt.figure(figsize=(5.0, 4.2))
plt.plot(np.arange(1, B+1), vals, "-o", alpha=0.7)
plt.axhline(np.mean(vals), ls="--", c="gray", label=f"mean={np.mean(vals):.2f}")
plt.xlabel("bootstrap draw")
plt.ylabel("bistable area fraction")
plt.legend()
plt.grid(True, ls=":")
plt.tight_layout()
plt.savefig(os.path.join(OUT, "bootstrap_trace.png"))
plt.close()
print("Saved ->", OUT)
