In [None]:
# calibrate_with_bistability_penalty.py
import os, numpy as np, pandas as pd
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares

INPATH = "timeseries/combined_scfas_table_scored.csv"
OUTDIR = "mw_fit_out_bistable_enhanced"; os.makedirs(OUTDIR, exist_ok=True)

# === Target ranges with focus on bistability ===
# Each target X comes with a standard deviation "sd_X"; the residual function
# turns these into weighted Gaussian-prior z-scores.
TARGETS = {
    "u": 0.77,     "sd_u": 0.08,
    "K_u": 0.18,   "sd_K_u": 0.06,
    "gamma": 1.25, "sd_gamma": 0.20,
    "p_high": 3.20,"sd_p_high": 0.40,
    "rHP": 0.0344, "sd_rHP": 0.03,
}

H_COLS=["H_proxy_meta_smooth","H_proxy_meta"]  # host proxy columns, in order of preference
SCFA=["butyrate"]; MIN_ROWS=4                   # SCFA column(s) to fit; min samples per subject
KQ=80.0; HILL_N=3   # KQ: logistic gain of the q switch; HILL_N: Hill exponent intended for the B->H term
PENALTY=1e3         # residual value used when a simulation fails or a hard constraint is violated

# Wider priors that favor bistability: name -> (mean, sd), one per global parameter.
# (A stray backslash before "d" on the second line used to make this a SyntaxError.)
PRIOR={"r0P":(0.32,0.12),"rHP":(0.07,0.06),"r0C":(0.28,0.12),"K_M":(1.0,0.4),
       "gamma":(0.85,0.35),"c":(0.12,0.08),"d":(0.12,0.08),"g":(0.60,0.40),
       "u":(0.60,0.25),"K_u":(0.20,0.12),"p_low":(0.12,0.08),"p_high":(2.20,0.90),
       "H_on":(0.55,0.12),"H_off":(0.75,0.12),"tau_q":(5.0,3.0),"K_B":(0.20,0.12)}

# Box bounds for the 16 global parameters (same order as NAMES further down).
# The targeted parameters get asymmetric windows around their targets.
def band(c, w_lo, w_hi):
    """Return the interval (c - w_lo, c + w_hi) around the centre c."""
    return (c - w_lo, c + w_hi)

uL, uU   = band(TARGETS["u"],      0.15, 0.15)
KuL, KuU = band(TARGETS["K_u"],    0.10, 0.10)
gaL, gaU = band(TARGETS["gamma"],  0.35, 0.30)
pHL, pHU = band(TARGETS["p_high"], 0.80, 0.60)

# (name, lower, upper) per global parameter; wide ranges around bistable regions.
_GLOBAL_BOUNDS = [
    ("r0P",    0.10, 0.60),
    ("rHP",    0.00, 0.25),
    ("r0C",    0.10, 0.60),
    ("K_M",    0.40, 2.00),
    ("gamma",  gaL,  gaU),
    ("c",      0.03, 0.25),
    ("d",      0.02, 0.25),
    ("g",      0.15, 1.50),
    ("u",      uL,   uU),
    ("K_u",    KuL,  KuU),
    ("p_low",  0.03, 0.30),
    ("p_high", pHL,  pHU),
    ("H_on",   0.45, 0.75),
    ("H_off",  0.70, 0.95),
    ("tau_q",  2.0,  15.0),
    ("K_B",    0.05, 0.40),
]
LBg = np.array([lo for _name, lo, _hi in _GLOBAL_BOUNDS])
UBg = np.array([hi for _name, _lo, hi in _GLOBAL_BOUNDS])

# Multi-start initial guesses (global parameters only, NAMES order).
STARTING_POINTS = [
    # Conservative bistability
    [0.32, 0.05, 0.28, 1.0, 1.10, 0.12, 0.08, 0.70, 0.77, 0.18, 0.12, 3.20, 0.55, 0.80, 5.5, 0.20],
    # Wider hysteresis
    [0.45, 0.02, 0.35, 1.2, 1.50, 0.08, 0.05, 0.90, 0.65, 0.15, 0.08, 3.50, 0.50, 0.85, 8.0, 0.15],
    # Strong memory effect
    [0.28, 0.08, 0.25, 0.8, 0.90, 0.15, 0.12, 0.50, 0.85, 0.22, 0.15, 2.80, 0.60, 0.90, 12.0, 0.25],
]

# --- Data prep ---------------------------------------------------------------
# Keep subject/sample ids, the first available host proxy column and the SCFA
# column(s); rows missing an id are dropped.
df = pd.read_csv(INPATH)
Hcol = next(filter(lambda col: col in df.columns, H_COLS), None)
if Hcol is None:
    raise ValueError("Need H proxy col")
for c in SCFA:
    if c not in df.columns:
        raise ValueError(f"Missing {c}")
keep_cols = ["subject_id", "sample_id", Hcol, *SCFA]
df = df[keep_cols].dropna(subset=["subject_id", "sample_id"]).copy()
# Within-subject sample counter (0, 1, 2, ... in file order), used as the time axis.
df["t_idx"] = df.groupby("subject_id").cumcount().astype(float)

def robust_z(s):
    """Robustly standardise a Series: (x - median) / scale.

    The scale is the MAD when it exceeds 1e-9. Otherwise it is the IQR, and if
    that is also zero, std + 1e-9. Non-finite inputs stay non-finite in the
    output. If the Series has no finite values at all, a Series of zeros is
    returned. The index of ``s`` is preserved.
    """
    x = s.astype(float).to_numpy()
    finite = np.isfinite(x)
    if not finite.any():
        return pd.Series(np.zeros_like(x), index=s.index)
    xf = x[finite]
    med = np.median(xf)
    mad = np.median(np.abs(xf - med))
    if mad > 1e-9:
        scale = mad
    else:
        iqr = np.percentile(xf, 75) - np.percentile(xf, 25)
        scale = iqr or (np.std(xf) + 1e-9)
    return pd.Series((x - med) / (scale + 1e-9), index=s.index)

df["B_z"] = df.groupby("subject_id")[SCFA[0]].transform(robust_z)
df["H_obs"] = df[Hcol].clip(0, 1)


def first(a, d):
    """Return the first finite element of *a* as a float, or float(d) if none."""
    a = np.asarray(a, float)
    idx = np.where(np.isfinite(a))[0]
    return float(a[idx[0]]) if len(idx) else float(d)


# One record per subject with at least MIN_ROWS samples and >= 3 finite B and H
# observations. t is the per-subject sample index (not necessarily real time).
subs = []
for sid, sub in df.groupby("subject_id"):
    sub = sub.sort_values("t_idx")
    if len(sub) < MIN_ROWS:
        continue
    t = sub["t_idx"].to_numpy(float)
    B = sub["B_z"].to_numpy(float)
    H = sub["H_obs"].to_numpy(float)
    mB = np.isfinite(B)
    mH = np.isfinite(H)
    if mB.sum() < 3 or mH.sum() < 3:
        continue
    subs.append({"sid": sid, "t": t, "B": B, "H": H, "maskB": mB, "maskH": mH,
                 "nB": int(mB.sum()), "nH": int(mH.sum()),
                 # initial conditions: first finite observation, clamped
                 "H0": float(np.clip(first(H, 0.6), 0, 1)),
                 "B0": float(max(0.05, first(B, 0.1)))})
if not subs:
    raise RuntimeError("no subjects")

# --- Model ---------------------------------------------------------------------
def q_inf(H, q, H_on, H_off, KQ):
    """Target value of the memory variable q for host state H.

    A logistic switch of gain KQ whose threshold moves linearly from H_on
    (at q = 0) to H_off (at q = 1).
    """
    threshold = (1 - q) * H_on + q * H_off
    return 1.0 / (1.0 + np.exp(KQ * (threshold - H)))

def rhs(t, y, p):
    """Right-hand side of the 5-state model used for calibration.

    y = (P, C, H, B, q); p = the 16 global parameters in NAMES order.
    Returns [dP/dt, dC/dt, dH/dt, dB/dt, dq/dt] in the same order as y.
    (Slot 2 used to hold H itself instead of dH/dt, so the fit integrated the
    wrong ODE.)
    """
    P, C, H, B, q = y
    r0P, rHP, r0C, K_M, gamma, c, d, gH, u, K_u, pL, pH, H_on, H_off, tau, K_B = p
    # q interpolates butyrate production between p_low (q=0) and p_high (q=1)
    pB = pL + (pH - pL) * np.clip(q, 0, 1)
    dP = P * (r0P + rHP * H - c * pB - (P + gamma * C) / K_M)
    dC = C * (r0C - (C + gamma * P) / K_M)
    uptake = u * H * B / (K_u + B + 1e-9)
    dB = pB * P - uptake
    # NOTE(review): the analysis cells use a Hill exponent of 4, not HILL_N (3);
    # confirm which one is intended so fit and analysis share one model.
    n = HILL_N
    dH = gH * (B**n / (K_B**n + B**n)) * (1 - H) - d * H
    dq = (q_inf(H, q, H_on, H_off, KQ) - q) / tau
    return [dP, dC, dH, dB, dq]

def simulate(ts, y0, p):
    """Integrate the calibration model on the sample times ts.

    Returns the (5, len(ts)) solution array, or an all-NaN array of the same
    shape when solve_ivp raises or reports failure.
    """
    try:
        sol = solve_ivp(lambda t, z: rhs(t, z, p), (ts[0], ts[-1]), y0,
                        t_eval=ts, rtol=1e-6, atol=1e-8, max_step=0.5)
    except Exception:
        sol = None
    if sol is None or not sol.success:
        return np.full((5, len(ts)), np.nan)
    return sol.y

# Per-subject observation maps (aB, b0, b1): B_pred = aB*B_sim and
# H_pred = clip(b0 + b1*H_sim, 0, 1). Start at the identity map.
_OBS_X0 = (1.0, 0.0, 1.0)
_OBS_LB = (0.6, -0.2, 0.8)
_OBS_UB = (1.6, 0.2, 1.2)
x0s = [v for _s in subs for v in _OBS_X0]
LBs = [v for _s in subs for v in _OBS_LB]
UBs = [v for _s in subs for v in _OBS_UB]

NAMES=["r0P","rHP","r0C","K_M","gamma","c","d","g","u","K_u","p_low","p_high","H_on","H_off","tau_q","K_B"]

def unpack(x):
    """Split the flat vector into (16 globals, list of per-subject triples)."""
    n_global = len(NAMES)
    return x[:n_global], np.split(x[n_global:], len(subs))

# Residual weights: B data, H data, and the target priors.
W_B, W_H = 0.6, 1.2
WT = 4.0

def bistability_penalty(gpar):
    """Return a scalar penalty steering the fit toward a wide hysteresis window.

    Reads gpar[12] = H_on, gpar[13] = H_off and gpar[14] = tau_q, and returns
    the sum of whichever terms apply (0.0 when none do):

    * 50 * (0.15 - width)  if width = H_off - H_on < 0.15;
    * 10                   if H_on  < 0.4 or H_on  > 0.7;
    * 10                   if H_off < 0.7 or H_off > 0.95;
    * 5                    if tau_q < 2 or tau_q > 15.

    The last three are step functions: they add cost but no gradient.
    """
    h_on, h_off, tau_q = gpar[12], gpar[13], gpar[14]
    width = h_off - h_on

    total = 0.0
    if width < 0.15:
        total += (0.15 - width) * 50.0
    if h_on < 0.4 or h_on > 0.7:
        total += 10.0
    if h_off < 0.7 or h_off > 0.95:
        total += 10.0
    if tau_q < 2.0 or tau_q > 15.0:
        total += 5.0
    return total

def residuals(x):
    """Stacked, weighted residual vector minimised by least_squares.

    The layout always has the same length, which least_squares requires:
      * per subject: W_B * (aB*B_sim - B_obs) on finite B samples, then
        W_H * (clip(b0 + b1*H_sim, 0, 1) - H_obs) on finite H samples
        (filled with PENALTY when the ODE solve fails);
      * one prior z-score per PRIOR entry;
      * five WT-weighted target z-scores (u, K_u, gamma, p_high, rHP);
      * one bistability_penalty entry (0.0 when nothing is penalised).
    If H_off <= H_on, every entry is PENALTY.
    """
    gpar, triples = unpack(x)
    target_keys = ("u", "K_u", "gamma", "p_high", "rHP")
    n_res = (sum(S["nB"] + S["nH"] for S in subs)
             + len(PRIOR) + len(target_keys) + 1)

    # Hard constraint: H_off must be > H_on. Keep the normal residual length:
    # the finite-difference Jacobian subtracts residual vectors, so a length
    # change makes least_squares fail.
    if not (gpar[13] > gpar[12]):
        return np.full(n_res, PENALTY)

    res = []
    for S, tr in zip(subs, triples):
        aB, b0, b1 = tr
        ts = S["t"]
        H0 = np.clip(S["H0"], 0, 1)
        B0 = max(0.05, S["B0"])
        P0 = C0 = 0.12
        # q0 = 1 below the midpoint of the two thresholds, 0 above it
        q0 = 1.0 if H0 < 0.5 * (gpar[12] + gpar[13]) else 0.0
        Y = simulate(ts, [P0, C0, H0, B0, q0], gpar)
        if np.any(~np.isfinite(Y)):
            res += [W_B * np.full(S["nB"], PENALTY), W_H * np.full(S["nH"], PENALTY)]
            continue
        _, _, H, B, _ = Y
        Bh = aB * B
        Hh = np.clip(b0 + b1 * H, 0, 1)
        mB, mH = S["maskB"], S["maskH"]
        res += [W_B * (Bh[mB] - S["B"][mB]),
                W_H * (Hh[mH] - S["H"][mH])]

    idx = {nm: i for i, nm in enumerate(NAMES)}
    # Baseline priors
    res += [np.array([(gpar[idx[nm]] - mu) / (sd + 1e-9)])
            for nm, (mu, sd) in PRIOR.items()]
    # Target priors
    res += [WT * np.array([(gpar[idx[nm]] - TARGETS[nm]) / (TARGETS["sd_" + nm] + 1e-9)])
            for nm in target_keys]
    # Bistability penalty: always present (possibly 0.0) to keep the length fixed
    res.append(np.array([bistability_penalty(gpar)]))

    return np.concatenate(res)

# Multi-start fit: keep the lowest-cost fit whose hysteresis window
# (H_off - H_on) is at least 0.1. Also remember the lowest-cost fit overall,
# to use as a fallback.
LB = np.concatenate([LBg, np.array(LBs, float)])
UB = np.concatenate([UBg, np.array(UBs, float)])

best_fit = None
best_cost = np.inf
overall_best = None

for i, start_point in enumerate(STARTING_POINTS):
    print(f"\n=== Trying starting point {i+1}/{len(STARTING_POINTS)} ===")

    x0 = np.concatenate([start_point, np.array(x0s, float)])
    fit = least_squares(residuals, x0, bounds=(LB, UB), verbose=1,
                        max_nfev=1000, loss="soft_l1", f_scale=1.0)

    # Calculate hysteresis width for this fit
    gpar_hat, _ = unpack(fit.x)
    hyst_width = gpar_hat[13] - gpar_hat[12]

    print(f"Cost: {fit.cost:.4f}, Hysteresis width: {hyst_width:.3f}")

    if overall_best is None or fit.cost < overall_best.cost:
        overall_best = fit

    if fit.cost < best_cost and hyst_width >= 0.1:  # Prefer solutions with reasonable hysteresis
        best_cost = fit.cost
        best_fit = fit
        print(f"New best fit with hysteresis width: {hyst_width:.3f}")

if best_fit is None:
    # Fall back to the best of the runs above. Re-fitting with a different loss
    # and budget would neither be "best overall" nor give a comparable cost.
    print("No satisfactory fit found, using best overall...")
    best_fit = overall_best
    best_cost = overall_best.cost

gpar_hat, _ = unpack(best_fit.x)
final_hyst_width = gpar_hat[13] - gpar_hat[12]

print("\n=== FINAL RESULTS ===")
print(f"Best cost: {best_cost:.4f}")
print(f"Hysteresis width: {final_hyst_width:.3f}")
print(f"H_on: {gpar_hat[12]:.3f}, H_off: {gpar_hat[13]:.3f}")
print("All parameters:", dict(zip(NAMES, gpar_hat)))

# Headerless (name, value) CSV; the analysis cells read it positionally.
pd.Series(gpar_hat, index=NAMES).to_csv(os.path.join(OUTDIR,"fitted_global_params.csv"), header=False)

print("Saved:", OUTDIR)

In [None]:
# basins_enhanced_sampling.py
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.integrate import solve_ivp

FIT = "mw_fit_out_bistable_enhanced/fitted_global_params.csv"  # Use your best fit
OUT = "mw_basins_enhanced"
os.makedirs(OUT, exist_ok=True)

# Fitted globals: a headerless (name, value) CSV. Values are used positionally
# below, so the row order must match the unpacking order in rhs.
g = pd.read_csv(FIT, index_col=0, header=None).squeeze("columns")
p = np.fromiter((float(g[name]) for name in g.index), dtype=float)
d_fit = float(p[6])  # slot 6 holds d

hyst_report = [
    f"Loaded parameters: d_fit = {d_fit:.4f}",
    f"H_on = {p[12]:.3f}, H_off = {p[13]:.3f}",
    f"Hysteresis width: {p[13] - p[12]:.3f}",
]
for line in hyst_report:
    print(line)

def q_inf(H, q, H_on, H_off):
    """Logistic memory target with a q-dependent threshold (gain fixed at 80)."""
    gain = 80  # high gain -> near step-like switching
    threshold = (1 - q) * H_on + q * H_off
    return 1.0 / (1.0 + np.exp(gain * (threshold - H)))

def rhs(y, pvec, d_override=None):
    """Time derivative of the 5-state model (P, C, H, B, q).

    pvec holds the 16 fitted globals in fit order. If d_override is given, it
    replaces the fitted d (pvec itself is not modified).

    NOTE(review): the Hill exponent here is 4, but the calibration cell's rhs
    uses 3 (HILL_N), so basins are computed for a slightly different model
    than the one that was fitted. Confirm which is intended.
    """
    P, C, H, B, q = y
    (r0P, rHP, r0C, K_M, gamma, c, d, gH, u, K_u,
     pL, pH, H_on, H_off, tau, K_B) = pvec.copy()
    d_eff = d if d_override is None else d_override

    # q interpolates butyrate production between p_low (q=0) and p_high (q=1)
    pB = pL + (pH - pL) * np.clip(q, 0, 1)

    # ecology: P and C share the capacity term (P + gamma*C) / K_M
    dP = P * (r0P + rHP * H - c * pB - (P + gamma * C) / K_M)
    dC = C * (r0C - (C + gamma * P) / K_M)

    # butyrate: produced by P, taken up by the host (saturating in B)
    dB = pB * P - u * H * B / (K_u + B + 1e-9)

    # host: Hill-type (n = 4) activation by B, linear loss at rate d
    activation = B**4 / (K_B**4 + B**4)
    dH = gH * activation * (1 - H) - d_eff * H

    dq = (q_inf(H, q, H_on, H_off) - q) / tau
    return np.array([dP, dC, dH, dB, dq], float)

def relax(y0, T=500, d_val=None):
    """Integrate from y0 toward a steady state and return the final state.

    Integrates over [0, T] (d defaults to the fitted d_fit). If the state still
    moves by >= 1e-4 over the last 10% of the horizon, integration continues
    from the end state for another T (rhs is time-independent, so this covers
    the same total horizon of 2T without recomputing the first leg).

    Best-effort: returns y0 unchanged if the first solve fails or raises. If
    only the continuation fails, returns the first leg's end state.
    """
    if d_val is None:
        d_val = d_fit

    def field(t, z):
        return rhs(z, p, d_val)

    try:
        sol = solve_ivp(field, (0, T), y0, t_eval=np.linspace(0, T, 1000),
                        rtol=1e-6, atol=1e-8, max_step=0.5)
    except Exception:
        return y0  # Return initial state on error
    if not sol.success:
        return y0  # Return initial state if integration fails

    y_final = sol.y[:, -1]
    # Compare with the state ~10% of the horizon earlier (100 of 1000 samples)
    if np.linalg.norm(y_final - sol.y[:, -100]) < 1e-4:
        return y_final

    # Not settled yet: continue from where we are
    try:
        sol_more = solve_ivp(field, (0, T), y_final, t_eval=[T],
                             rtol=1e-6, atol=1e-8)
    except Exception:
        return y_final
    return sol_more.y[:, -1] if sol_more.success else y_final

# ---- Enhanced Basin Sampling ----
print("Computing enhanced basins of attraction...")

# Dense 80 x 80 grid of initial (H, q); P, C, B always start at 0.12, 0.12, 0.10.
Hs = np.linspace(0.05, 0.98, 80)
qs = np.linspace(0.0, 1.0, 80)
grid_shape = (len(Hs), len(qs))

Z_H = np.zeros(grid_shape)
Z_B = np.zeros(grid_shape)
Z_q = np.zeros(grid_shape)
# Outcome label: 1 = healthy (final H > 0.7), 0 = dysbiotic (final H < 0.3),
# 0.5 = intermediate. If relax fails it returns the initial state unchanged.
Z_state = np.zeros(grid_shape)

total = len(Hs) * len(qs)
completed = 0

for i, H0 in enumerate(Hs):
    for j, q0 in enumerate(qs):
        y0 = np.array([0.12, 0.12, H0, 0.10, q0], float)
        yss = relax(y0, T=600)  # longer horizon than the default T=500
        H_end = yss[2]
        Z_H[i, j], Z_B[i, j], Z_q[i, j] = H_end, yss[3], yss[4]

        if H_end > 0.7:
            label = 1.0
        elif H_end < 0.3:
            label = 0.0
        else:
            label = 0.5
        Z_state[i, j] = label

        completed += 1
        if completed % 500 == 0:
            print(f"Progress: {completed}/{total} ({completed/total*100:.1f}%)")

# Save enhanced basin data: one row per grid cell, H0 outer / q0 inner
# (row-major order, matching Z_*[i, j]).
H_grid, q_grid = np.meshgrid(Hs, qs, indexing="ij")
basin_df = pd.DataFrame({
    'H0': H_grid.ravel(), 'q0': q_grid.ravel(),
    'H_final': Z_H.ravel(), 'B_final': Z_B.ravel(), 'q_final': Z_q.ravel(),
    'state': Z_state.ravel(),
})
basin_df.to_csv(os.path.join(OUT, "enhanced_basins_detailed.csv"), index=False)

# ---- Enhanced Visualization ----
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
extent = [qs[0], qs[-1], Hs[0], Hs[-1]]
# Colour convention for every panel: red = dysbiotic / low H, blue = healthy /
# high H. The H maps therefore use "RdYlBu" (not the reversed map) so they
# agree with the state-classification panel.

# Final H state
im1 = axes[0,0].imshow(Z_H, origin="lower", extent=extent,
                       aspect="auto", vmin=0.1, vmax=0.9, cmap="RdYlBu")
axes[0,0].set_xlabel("Initial q"); axes[0,0].set_ylabel("Initial H")
axes[0,0].set_title("Final Host Health (H)")
plt.colorbar(im1, ax=axes[0,0])

# Final B state
im2 = axes[0,1].imshow(Z_B, origin="lower", extent=extent,
                       aspect="auto", cmap="viridis")
axes[0,1].set_xlabel("Initial q"); axes[0,1].set_ylabel("Initial H")
axes[0,1].set_title("Final Butyrate (B)")
plt.colorbar(im2, ax=axes[0,1])

# State classification
im3 = axes[0,2].imshow(Z_state, origin="lower", extent=extent,
                       aspect="auto", cmap="RdYlBu", vmin=0, vmax=1)
axes[0,2].set_xlabel("Initial q"); axes[0,2].set_ylabel("Initial H")
axes[0,2].set_title("State Classification\n(Red=Dysbiotic, Blue=Healthy)")
plt.colorbar(im3, ax=axes[0,2])

# Final q state
im4 = axes[1,0].imshow(Z_q, origin="lower", extent=extent,
                       aspect="auto", vmin=0, vmax=1, cmap="coolwarm")
axes[1,0].set_xlabel("Initial q"); axes[1,0].set_ylabel("Initial H")
axes[1,0].set_title("Final Memory (q)")
plt.colorbar(im4, ax=axes[1,0])

# Basin boundary detection
from skimage import measure
contours = measure.find_contours(Z_state, 0.5)

axes[1,1].imshow(Z_H, origin="lower", extent=extent,
                 aspect="auto", vmin=0.1, vmax=0.9, cmap="RdYlBu", alpha=0.7)
# find_contours returns (row, col) in fractional array-index units; index k
# corresponds to Hs[k] along rows and qs[k] along columns. With origin="lower",
# row 0 is drawn at Hs[0], so map indices straight onto the sample grids: no
# vertical flip, and an axis of N samples spans N-1 index steps.
row_idx = np.arange(len(Hs))
col_idx = np.arange(len(qs))
for k, contour in enumerate(contours):
    q_contour = np.interp(contour[:, 1], col_idx, qs)
    H_contour = np.interp(contour[:, 0], row_idx, Hs)
    axes[1,1].plot(q_contour, H_contour, 'k-', linewidth=2,
                   label='Basin Boundary' if k == 0 else None)
if contours:
    axes[1,1].legend(loc='upper right')
axes[1,1].set_xlabel("Initial q"); axes[1,1].set_ylabel("Initial H")
axes[1,1].set_title("Basins with Boundary")

# Statistics: fraction of sampled initial conditions in each outcome class
n_cells = Z_state.size
healthy_basin = np.count_nonzero(Z_state > 0.75) / n_cells
dysbiotic_basin = np.count_nonzero(Z_state < 0.25) / n_cells
intermediate = 1 - healthy_basin - dysbiotic_basin

ax_sizes = axes[1, 2]
ax_sizes.bar(['Dysbiotic', 'Intermediate', 'Healthy'],
             [dysbiotic_basin, intermediate, healthy_basin],
             color=['red', 'yellow', 'blue'])
ax_sizes.set_ylabel('Fraction of Phase Space')
ax_sizes.set_title('Basin Sizes')
ax_sizes.grid(True, alpha=0.3)

# Text summary below the bar chart
summary_text = "\n".join([
    "Basin Analysis Summary:",
    f"- Healthy basin: {healthy_basin:.1%}",
    f"- Dysbiotic basin: {dysbiotic_basin:.1%}  ",
    f"- Intermediate: {intermediate:.1%}",
    f"- Hysteresis: H_on={p[12]:.3f}, H_off={p[13]:.3f}",
    f"- d = {d_fit:.4f}",
])
ax_sizes.text(0.5, -0.3, summary_text, transform=ax_sizes.transAxes,
              fontsize=10, ha='center', va='top')

plt.suptitle(f"Enhanced Basin Analysis | Hysteresis Width: {p[13]-p[12]:.3f}", fontsize=16)
plt.tight_layout()
plt.savefig(os.path.join(OUT, "enhanced_basins_analysis.png"), dpi=150, bbox_inches='tight')
plt.close()

for message in (
    "\n=== BASIN ANALYSIS COMPLETE ===",
    f"Healthy basin size: {healthy_basin:.1%}",
    f"Dysbiotic basin size: {dysbiotic_basin:.1%}",
    f"Intermediate region: {intermediate:.1%}",
    f"Saved to: {OUT}",
):
    print(message)

In [None]:
# intervention_analysis_enhanced.py
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.interpolate import interp1d

FIT = "mw_fit_out_bistable_enhanced/fitted_global_params.csv"
OUT = "mw_interventions_enhanced"
os.makedirs(OUT, exist_ok=True)

# Globals written by the calibration cell: headerless (name, value) rows whose
# order defines the positional layout unpacked in rhs below.
g = pd.read_csv(FIT, index_col=0, header=None).squeeze("columns")
p = np.fromiter((float(g[name]) for name in g.index), dtype=float)
d_fit = float(p[6])  # slot 6 holds d

print(f"Baseline d: {d_fit:.4f}")

def rhs(y, pvec, d_val=None, intervention=None, intervention_strength=0.0):
    """Time derivative of the 5-state model (P, C, H, B, q) under an intervention.

    pvec: the 16 fitted globals in fit order (not modified).
    d_val: if given, replaces the fitted d before any intervention is applied.
    intervention / intervention_strength:
      * "anti_inflammatory": d is scaled by (1 - strength);
      * "probiotic": p_low scaled by (1 + 0.5*strength), p_high by (1 + 0.3*strength);
      * "postbiotic": a constant 0.1*strength is added to dB/dt;
      * anything else: no change.
    Returns np.array([dP, dC, dH, dB, dq]).

    NOTE(review): Hill exponent 4 here vs 3 in the calibration cell; confirm.
    """
    P,C,H,B,q = y
    r0P,rHP,r0C,K_M,gamma,c,d,gH,u,K_u,pL,pH,H_on,H_off,tau,K_B = pvec.copy()

    if d_val is not None:
        d = d_val

    # Interventions acting on parameters
    if intervention == "anti_inflammatory":
        d = d * (1.0 - intervention_strength)  # Reduce inflammation
    elif intervention == "probiotic":
        # Increase butyrate production
        pL = pL * (1.0 + intervention_strength * 0.5)
        pH = pH * (1.0 + intervention_strength * 0.3)

    pB = pL + (pH - pL)*np.clip(q,0,1)
    dP = P*( r0P + rHP*H - c*pB - (P + gamma*C)/K_M )
    dC = C*( r0C           -        (C + gamma*P)/K_M )

    uptake = u*H*B/(K_u + B + 1e-9)
    dB = pB*P - uptake
    if intervention == "postbiotic":
        dB += intervention_strength * 0.1  # Direct butyrate supplementation

    n = 4
    dH = gH*(B**n/(K_B**n + B**n))*(1 - H) - d*H
    KQ = 80
    th = (1-q)*H_on + q*H_off
    q_inf = 1.0/(1.0 + np.exp(-KQ*(H - th)))
    dq = (q_inf - q)/tau

    return np.array([dP,dC,dH,dB,dq], float)

def calculate_intervention_metrics(H_trajectory, t_points, baseline_H, healthy_threshold=0.7):
    """Summarise an H(t) trajectory relative to a baseline H.

    Returns a dict with:
      recovery_time     first time H > healthy_threshold (inf if never);
      max_effect        max(H) - baseline_H;
      sustained_effect  mean H over the last 20% of samples (at least one) - baseline_H;
      auc_improvement   trapezoidal integral of H minus baseline_H * duration;
      final_improvement H[-1] - baseline_H;
      resilience        fraction of time after recovery spent above the threshold
                        (0 if never recovered or no time remains).
    t_points may be non-uniform and may contain repeated times (zero-width steps).
    """
    H = np.asarray(H_trajectory, dtype=float)
    t = np.asarray(t_points, dtype=float)

    metrics = {}

    # 1. Recovery time (time to reach healthy state)
    healthy_mask = H > healthy_threshold
    if np.any(healthy_mask):
        metrics['recovery_time'] = t[int(np.argmax(healthy_mask))]
    else:
        metrics['recovery_time'] = np.inf

    # 2. Maximum effect
    metrics['max_effect'] = np.max(H) - baseline_H

    # 3. Sustained effect (last 20% of samples; never an empty or whole-array slice)
    sustained_window = max(1, int(0.2 * len(H)))
    metrics['sustained_effect'] = np.mean(H[-sustained_window:]) - baseline_H

    # 4. Area under curve improvement (np.trapz is deprecated in NumPy 2)
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    baseline_auc = baseline_H * (t[-1] - t[0])
    metrics['auc_improvement'] = trapezoid(H, t) - baseline_auc

    # 5. Final state improvement
    metrics['final_improvement'] = H[-1] - baseline_H

    # 6. Resilience: time-weight each interval by its own width (the grid is
    #    not uniform), counting an interval as healthy if its left end is.
    if metrics['recovery_time'] < np.inf:
        time_in_healthy = float(np.sum(np.diff(t)[healthy_mask[:-1]]))
        total_time = t[-1] - metrics['recovery_time']
        metrics['resilience'] = time_in_healthy / total_time if total_time > 0 else 0
    else:
        metrics['resilience'] = 0

    return metrics

def simulate_intervention(intervention_type, strength, duration=100, y0=None, d_val=None):
    """Simulate an untreated baseline (first 20%) followed by an intervention.

    Returns a dict with:
      'time'       200 baseline + 800 intervention sample times (the switch
                   time appears twice);
      'H'          list of host-health values on those times;
      'metrics'    calculate_intervention_metrics output;
      'baseline_H' mean H over the last 50 baseline samples (initial H if the
                   baseline solve fails).
    Failed solves are filled with baseline_H rather than raising.
    """
    if d_val is None:
        d_val = d_fit

    if y0 is None:
        # Default start: low H, q = 1
        y0 = np.array([0.12, 0.12, 0.1, 0.05, 1.0])

    baseline_duration = duration * 0.2
    t_baseline = np.linspace(0, baseline_duration, 200)
    t_intervention = np.linspace(baseline_duration, duration, 800)
    H_values = []

    # Baseline phase (no intervention)
    try:
        sol_baseline = solve_ivp(
            lambda t, y: rhs(y, p, d_val),
            (0, baseline_duration), y0, t_eval=t_baseline,
            rtol=1e-6, atol=1e-8, max_step=0.5
        )
    except Exception:
        sol_baseline = None

    if sol_baseline is not None and sol_baseline.success:
        y_baseline = sol_baseline.y[:, -1]
        H_baseline = np.mean(sol_baseline.y[2, -50:])  # Average of last 50 points
        H_values.extend(sol_baseline.y[2].tolist())
    else:
        y_baseline = y0
        H_baseline = y0[2]
        H_values.extend([H_baseline] * len(t_baseline))

    # Intervention phase, continuing from the end of the baseline
    try:
        sol_intervention = solve_ivp(
            lambda t, y: rhs(y, p, d_val, intervention_type, strength),
            (baseline_duration, duration), y_baseline, t_eval=t_intervention,
            rtol=1e-6, atol=1e-8, max_step=0.5
        )
    except Exception:
        sol_intervention = None

    if sol_intervention is not None and sol_intervention.success:
        H_values.extend(sol_intervention.y[2].tolist())
    else:
        H_values.extend([H_baseline] * len(t_intervention))

    # Calculate metrics
    all_t_points = np.concatenate([t_baseline, t_intervention])
    metrics = calculate_intervention_metrics(H_values, all_t_points, H_baseline)

    return {
        'time': all_t_points,
        'H': H_values,
        'metrics': metrics,
        'baseline_H': H_baseline
    }

# ---- Enhanced Intervention Comparison ----
print("Running enhanced intervention analysis...")

# Strength grid per intervention (see rhs for how each strength is applied)
interventions = {
    'anti_inflammatory': [0.3, 0.5, 0.7],  # 30%, 50%, 70% reduction in inflammation
    'probiotic': [0.2, 0.4, 0.6],          # 20%, 40%, 60% increase in production
    'postbiotic': [0.1, 0.2, 0.3]          # Low, medium, high supplementation
}

results = {}       # intervention -> strength -> simulate_intervention output
all_metrics = []   # long-format rows: one per (intervention, strength, metric)

for interv_name, strengths in interventions.items():
    per_strength = results.setdefault(interv_name, {})
    print(f"\nSimulating {interv_name}...")

    for strength in strengths:
        print(f"  Strength: {strength:.1f}")
        result = simulate_intervention(interv_name, strength, duration=100)
        per_strength[strength] = result
        all_metrics.extend(
            {
                'intervention': interv_name,
                'strength': strength,
                'metric': name,
                'value': value,
                'baseline_H': result['baseline_H'],
            }
            for name, value in result['metrics'].items()
        )

metrics_df = pd.DataFrame(all_metrics)
metrics_df.to_csv(os.path.join(OUT, "intervention_metrics_detailed.csv"), index=False)

# ---- Enhanced Visualization ----
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

colors = {'anti_inflammatory': 'blue', 'probiotic': 'green', 'postbiotic': 'red'}

# "Medium" strength = the middle entry of each intervention's own strength grid.
# A single shared value (0.4) only exists in the probiotic grid, which dropped
# the other interventions from Plot 1 and the ranking, and made the summary
# raise KeyError. These medium doses differ between interventions, so the
# comparison is not dose-matched.
medium_strengths = {name: s[len(s) // 2] for name, s in interventions.items()}

# Plot 1: Time trajectories at each intervention's medium strength
for interv_name, mid in medium_strengths.items():
    result = results[interv_name][mid]
    axes[0,0].plot(result['time'], result['H'], color=colors[interv_name],
                   linewidth=2, label=f"{interv_name} ({mid})")
    # Mark recovery time if applicable
    rec_time = result['metrics']['recovery_time']
    if rec_time < np.inf:
        axes[0,0].axvline(rec_time, color=colors[interv_name], linestyle='--', alpha=0.7)

axes[0,0].axhline(0.7, color='black', linestyle=':', label='Healthy threshold')
axes[0,0].set_xlabel('Time')
axes[0,0].set_ylabel('Host Health (H)')
axes[0,0].set_title('Intervention Trajectories (Medium Strength)')
axes[0,0].legend()
axes[0,0].grid(True, alpha=0.3)

# Plot 2: Metric comparison heatmap (each metric min-max scaled, higher = better)
metric_pivot = metrics_df.pivot_table(
    index=['intervention', 'strength'],
    columns='metric',
    values='value'
).reset_index()

key_metrics = ['recovery_time', 'max_effect', 'sustained_effect', 'resilience']
heatmap_data = metric_pivot[['intervention', 'strength'] + key_metrics].copy()
for metric in key_metrics:
    col = heatmap_data[metric].astype(float)
    if metric == 'recovery_time':
        # Non-recovery (inf) scores 10% worse than the slowest recovery; the
        # axis is then inverted so that faster recovery scores higher.
        finite_vals = col[np.isfinite(col)]
        max_val = finite_vals.max() if len(finite_vals) > 0 else 100
        col = max_val - col.replace(np.inf, max_val * 1.1)
    lo, hi = col.min(), col.max()
    if hi > lo:
        col = (col - lo) / (hi - lo)
    heatmap_data[metric] = col

interv_strengths = [(name, s) for name, strengths in interventions.items() for s in strengths]
heatmap_matrix = np.zeros((len(interv_strengths), len(key_metrics)))
for i, (interv, strength) in enumerate(interv_strengths):
    mask = (heatmap_data['intervention'] == interv) & (heatmap_data['strength'] == strength)
    if mask.any():
        heatmap_matrix[i] = heatmap_data.loc[mask, key_metrics].to_numpy(dtype=float)[0]

im = axes[0,1].imshow(heatmap_matrix, aspect='auto', cmap='viridis')
axes[0,1].set_xticks(range(len(key_metrics)))
axes[0,1].set_xticklabels(key_metrics, rotation=45)
axes[0,1].set_yticks(range(len(interv_strengths)))
axes[0,1].set_yticklabels([f"{name}\n({s})" for name, s in interv_strengths])
axes[0,1].set_title('Intervention Efficacy Heatmap\n(Higher = Better)')
plt.colorbar(im, ax=axes[0,1])

# Plot 3: Strength vs efficacy (solid = max effect, dashed = sustained effect)
for metric in ['max_effect', 'sustained_effect']:
    for interv_name in interventions:
        interv_data = metrics_df[
            (metrics_df['intervention'] == interv_name) &
            (metrics_df['metric'] == metric)
        ].sort_values('strength')
        if not interv_data.empty:
            # marker/linestyle as keywords: an 'o-' format string plus linestyle=
            # defines the line style twice and matplotlib warns about it.
            axes[1,0].plot(interv_data['strength'], interv_data['value'],
                           marker='o',
                           linestyle='--' if 'sustained' in metric else '-',
                           color=colors[interv_name],
                           label=f'{interv_name} ({metric})')

axes[1,0].set_xlabel('Intervention Strength')
axes[1,0].set_ylabel('Effect Size')
axes[1,0].set_title('Dose-Response Relationships')
axes[1,0].legend()
axes[1,0].grid(True, alpha=0.3)

# Plot 4: Overall intervention ranking (weighted sum at each medium strength)
weights = {'recovery_time': -0.3, 'max_effect': 0.25,
           'sustained_effect': 0.35, 'resilience': 0.1}
composite_scores = []
for interv_name, mid in medium_strengths.items():
    rows = metrics_df[(metrics_df['intervention'] == interv_name) &
                      (metrics_df['strength'] == mid)]
    score = 0
    for metric, weight in weights.items():
        vals = rows.loc[rows['metric'] == metric, 'value']
        if vals.empty:
            continue
        value = vals.iloc[0]
        if metric == 'recovery_time' and value == np.inf:
            value = 1000  # Penalize non-recovery
        score += weight * value
    composite_scores.append({'intervention': interv_name, 'composite_score': score})

composite_df = pd.DataFrame(composite_scores).sort_values('composite_score', ascending=False)

axes[1,1].bar(composite_df['intervention'], composite_df['composite_score'],
              color=[colors[i] for i in composite_df['intervention']])
axes[1,1].set_ylabel('Composite Efficacy Score')
axes[1,1].set_title('Overall Intervention Ranking')
axes[1,1].grid(True, alpha=0.3)

plt.suptitle(f'Enhanced Intervention Analysis | Baseline d = {d_fit:.4f}', fontsize=16)
plt.tight_layout()
plt.savefig(os.path.join(OUT, "enhanced_intervention_analysis.png"), dpi=150, bbox_inches='tight')
plt.close()

# ---- Save Summary Report ----
print("\n=== INTERVENTION ANALYSIS COMPLETE ===")

summary_data = []
for interv_name, mid in medium_strengths.items():
    medium_result = results[interv_name][mid]
    metrics = medium_result['metrics']
    rec = metrics['recovery_time']

    summary_data.append({
        'Intervention': interv_name,
        'Strength': mid,
        'Recovery_Time': f"{rec:.1f}" if rec < np.inf else "No recovery",
        'Max_Effect': f"{metrics['max_effect']:.3f}",
        'Sustained_Effect': f"{metrics['sustained_effect']:.3f}",
        'Resilience': f"{metrics['resilience']:.2f}",
        'Final_H': f"{medium_result['H'][-1]:.3f}"
    })

summary_df = pd.DataFrame(summary_data)
print("\nIntervention Efficacy Summary (Medium Strength):")
print(summary_df.to_string(index=False))

summary_df.to_csv(os.path.join(OUT, "intervention_summary.csv"), index=False)

print(f"\nSaved detailed results to: {OUT}")

In [4]:
# paper_analysis_main.py  (5-variable enhanced model version)

import os, numpy as np, pandas as pd, matplotlib.pyplot as plt, seaborn as sns
from scipy.integrate import solve_ivp
from scipy.optimize import root
import numpy.linalg as npl
import warnings
warnings.filterwarnings('ignore')  # NOTE(review): silences *all* warnings, incl. solver/overflow ones

sns.set_palette("viridis")

# Publication defaults: 300-dpi output, tight bounding boxes, modest font sizes
_PAPER_RC = {
    'font.size': 12,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.titlesize': 16,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
}
plt.rcParams.update(_PAPER_RC)

OUTDIR = "paper_figures"
os.makedirs(OUTDIR, exist_ok=True)

# =============================================================================
# LOAD PARAMETERS (16-param enhanced model)
# =============================================================================
FIT_FILE = "mw_fit_out_bistable_enhanced/fitted_global_params.csv"
g = pd.read_csv(FIT_FILE, index_col=0, header=None).squeeze("columns")
p = np.fromiter((float(g[name]) for name in g.index), dtype=float)

# Positional unpacking; the order follows the calibration cell's NAMES list.
r0P, rHP, r0C, K_M, gamma, c, d_fit, gH, u, K_u, pL, pH, H_on, H_off, tau, K_B = p

# Keep H_on <= H_off, both in the named copies and in p itself.
if H_off < H_on:
    print("⚠️ Warning: H_off < H_on detected — swapping for consistency.")
    H_on, H_off = H_off, H_on
    p[12], p[13] = H_on, H_off

print("=== GUT MICROBIOME BISTABILITY ANALYSIS (5-var model) ===")
print(f"Baseline inflammatory rate (d): {d_fit:.4f}")
print(f"H_on = {H_on:.3f}, H_off = {H_off:.3f}, width = {H_off-H_on:.3f}")

# =============================================================================
# DEFINE MODEL
# =============================================================================
def rhs_model(y, pvec, d_val=None):
    """Time derivative of the 5-state host-microbiome model (P, C, H, B, q).

    pvec: 16 parameters in fit order. d_val, if given, replaces pvec's d.
    Returns np.array([dP, dC, dH, dB, dq]).

    NOTE(review): Hill exponent 4 here vs 3 in the calibration cell; confirm.
    """
    P, C, H, B, q = y
    (r0P, rHP, r0C, K_M, gamma, c, d, gH, u, K_u,
     pL, pH, H_on, H_off, tau, K_B) = pvec
    d_eff = d if d_val is None else d_val

    # production switches between p_low (q=0) and p_high (q=1)
    pB = pL + (pH - pL) * np.clip(q, 0, 1)

    dP = P * (r0P + rHP * H - c * pB - (P + gamma * C) / K_M)
    dC = C * (r0C - (C + gamma * P) / K_M)
    dB = pB * P - u * H * B / (K_u + B + 1e-9)
    dH = gH * (B**4 / (K_B**4 + B**4)) * (1 - H) - d_eff * H

    # memory: logistic switch (gain 80) with a q-dependent threshold
    threshold = (1 - q) * H_on + q * H_off
    q_target = 1.0 / (1.0 + np.exp(80 * (threshold - H)))
    dq = (q_target - q) / tau

    return np.array([dP, dC, dH, dB, dq])

# =============================================================================
# FIND EQUILIBRIA / STABILITY
# =============================================================================
def find_equilibria(pvec, d_val, n_seeds=25, rng=None, res_tol=1e-6, dedup_tol=1e-3):
    """Locate steady states of ``rhs_model`` by multi-start root finding.

    Parameters
    ----------
    pvec : sequence of 16 floats
        Model parameters (see ``rhs_model``).
    d_val : float
        Inflammatory rate passed to ``rhs_model``.
    n_seeds : int, default 25
        Number of random starting points added to four fixed seeds.
    rng : numpy.random.Generator, optional
        Source of the random seeds; defaults to the global ``np.random`` state.
    res_tol : float, default 1e-6
        Maximum ``||rhs||`` for a converged point to be accepted as a root.
    dedup_tol : float, default 1e-3
        Points closer than this (Euclidean) are treated as the same equilibrium.

    Returns
    -------
    list of numpy.ndarray
        Distinct equilibria ``(P, C, H, B, q)`` with non-negative components
        and H, q within [0, 1].
    """
    from scipy.optimize import root

    sampler = np.random if rng is None else rng
    seeds = [
        # healthy-like
        np.array([0.3, 0.05, 0.8, 0.15, 0.1]),
        np.array([0.2, 0.1, 0.9, 0.18, 0.2]),
        # dysbiotic-like
        np.array([0.05, 0.3, 0.15, 0.05, 0.9]),
        np.array([0.1, 0.1, 0.25, 0.08, 0.7]),
    ]
    # intermediates
    seeds += [sampler.uniform([0, 0, 0, 0, 0], [0.5, 0.5, 1, 0.3, 1])
              for _ in range(n_seeds)]

    def residual(y):
        return rhs_model(y, pvec, d_val)

    feas_tol = 1e-6
    equilibria = []
    for seed in seeds:
        try:
            sol = root(residual, seed, method='lm', tol=1e-5)
        except (ValueError, ArithmeticError, np.linalg.LinAlgError):
            continue
        if not sol.success:
            continue
        x = sol.x
        # 'lm' minimises ||f||^2 and can report success at a local minimum
        # that is not a root, so verify the residual explicitly.
        if not np.all(np.isfinite(x)) or np.linalg.norm(residual(x)) > res_tol:
            continue
        # Reject points outside the physical domain rather than clipping
        # them onto it: a clipped point is generally not an equilibrium.
        if np.any(x < -feas_tol) or x[2] > 1 + feas_tol or x[4] > 1 + feas_tol:
            continue
        # Only round-off remains to be removed at this point.
        y_eq = np.clip(x, 0.0, None)
        y_eq[[2, 4]] = np.minimum(y_eq[[2, 4]], 1.0)
        if not any(np.linalg.norm(y_eq - eq) < dedup_tol for eq in equilibria):
            equilibria.append(y_eq)
    return equilibria

def check_stability(y_eq, pvec, d_val, eps=1e-6, tol=1e-5):
    """Classify an equilibrium by the eigenvalues of a numerical Jacobian.

    The Jacobian of ``rhs_model`` is estimated with central differences.
    The step is ``eps * max(1, |y_i|)``. Central differences are used because
    the memory switch is very steep, and a one-sided difference's O(eps)
    truncation error can rival ``tol``.

    Parameters
    ----------
    y_eq : array-like
        Equilibrium state ``(P, C, H, B, q)``.
    pvec : sequence of 16 floats
        Model parameters.
    d_val : float
        Inflammatory rate passed to ``rhs_model``.
    eps : float, default 1e-6
        Relative finite-difference step.
    tol : float, default 1e-5
        Stable iff every eigenvalue has real part below ``-tol``.

    Returns
    -------
    (bool, numpy.ndarray)
        Stability flag and the Jacobian eigenvalues.
    """
    y0 = np.asarray(y_eq, dtype=float)
    n = y0.size
    J = np.zeros((n, n))
    for i in range(n):
        h = eps * max(1.0, abs(y0[i]))
        y_plus = y0.copy()
        y_plus[i] += h
        y_minus = y0.copy()
        y_minus[i] -= h
        J[:, i] = (rhs_model(y_plus, pvec, d_val)
                   - rhs_model(y_minus, pvec, d_val)) / (2.0 * h)
    eig = np.linalg.eigvals(J)
    return np.all(np.real(eig) < -tol), eig

# =============================================================================
# FIGURE 1: BIFURCATION AND HYSTERESIS
# =============================================================================
print("\nGenerating Figure 1: Bistability and Hysteresis...")
# Sweep the inflammatory rate d and record every equilibrium located at each
# value, together with its linear-stability verdict.
d_values=np.linspace(0.02,0.25,80)
records=[]
for d_val in d_values:
    # NOTE(review): find_equilibria draws part of its seeds from the global
    # NumPy RNG, so the equilibria found (and this plot) can vary run to run.
    eqs=find_equilibria(p,d_val)
    for eq in eqs:
        stable,eigs=check_stability(eq,p,d_val)
        # Keep H, B, q of the equilibrium, the stability flag and the largest
        # real part among the Jacobian eigenvalues.
        records.append({'d':d_val,'H':eq[2],'B':eq[3],'q':eq[4],
                        'stable':stable,'max_eig':np.max(np.real(eigs))})
bif=pd.DataFrame(records)

# With no records the DataFrame has no columns, so check before indexing.
if bif.empty or 'stable' not in bif.columns:
    print("⚠️ No equilibria found—skipping bifurcation plot.")
else:
    fig,ax=plt.subplots(figsize=(8,6))
    # Split rows by stability (this rebinds ``stable`` to a DataFrame).
    stable=bif[bif['stable']]; unstable=bif[~bif['stable']]
    if not stable.empty:
        ax.scatter(stable['d'],stable['H'],c='blue',s=25,alpha=0.7,label='Stable')
    if not unstable.empty:
        ax.scatter(unstable['d'],unstable['H'],c='red',s=25,alpha=0.6,marker='x',label='Unstable')
    # Reference lines: fitted d and the two memory-switch thresholds.
    ax.axvline(d_fit,color='k',ls='--',alpha=0.7,label=f'd={d_fit:.3f}')
    ax.axhline(H_on,color='green',ls=':',alpha=0.6,label='H_on')
    ax.axhline(H_off,color='purple',ls=':',alpha=0.6,label='H_off')
    ax.set_xlabel('Inflammatory rate (d)')
    ax.set_ylabel('Host Health (H)')
    ax.set_title('Figure 1A – Bifurcation Diagram (5-variable Model)')
    ax.legend(); ax.grid(True,alpha=0.3)
    plt.tight_layout()
    plt.savefig(f"{OUTDIR}/Figure1_Bifurcation_Enhanced.png",dpi=300)
    plt.close()
    print("✓ Figure 1A saved")


=== GUT MICROBIOME BISTABILITY ANALYSIS (5-var model) ===
Baseline inflammatory rate (d): 0.1200
H_on = 0.550, H_off = 0.750, width = 0.200

Generating Figure 1: Bistability and Hysteresis...
✓ Figure 1A saved


In [8]:
# =============================================================================
# FIGURE 2: BASINS OF ATTRACTION AND CRITICAL TRANSITIONS
# =============================================================================
print("\nGenerating Figure 2: Basins of Attraction...")

def simulate_relax(pvec, y0, d_val, T=400):
    """Integrate the model toward steady state and return the final state.

    Integrates over ``[0, T]`` and checks whether the state still moves over
    the last ~20 output samples. If it does, integration continues from the
    end point for another ``2*T``.

    Parameters
    ----------
    pvec : sequence of 16 floats
        Model parameters (see ``rhs_model``).
    y0 : array-like of 5 floats
        Initial state ``(P, C, H, B, q)``.
    d_val : float
        Inflammatory rate passed to ``rhs_model``.
    T : float, default 400
        Length of the first integration window.

    Returns
    -------
    numpy.ndarray
        The relaxed state. If the first integration fails, an array of NaNs
        is returned, so threshold tests such as ``H > 0.7`` classify the run
        as neither healthy nor dysbiotic. Previously the initial condition
        was returned, which was then misread as the outcome.
    """
    failed = np.full(np.asarray(y0).shape, np.nan)

    def rhs(t, y):
        return rhs_model(y, pvec, d_val)

    try:
        sol = solve_ivp(rhs, [0, T], y0, t_eval=np.linspace(0, T, 800),
                        rtol=1e-6, atol=1e-8)
    except (ValueError, ArithmeticError):
        return failed
    if not sol.success:
        return failed

    yfin, yprev = sol.y[:, -1], sol.y[:, -20]
    if np.linalg.norm(yfin - yprev) < 1e-4:
        return yfin

    # Still drifting: extend the integration from where we stopped.
    try:
        sol2 = solve_ivp(rhs, [0, 2 * T], yfin, t_eval=[2 * T],
                         rtol=1e-6, atol=1e-8)
    except (ValueError, ArithmeticError):
        return yfin
    return sol2.y[:, -1] if sol2.success else yfin

def basin_map(pvec, d_val, Hs=np.linspace(0.05,0.95,60), qs=np.linspace(0,1,60)):
    """Classify a grid of initial (H, q) states by where they relax to.

    Every start uses P = C = 0.12 and B = 0.10. Rows of the returned grids
    follow ``Hs`` and columns follow ``qs``.

    Returns
    -------
    (Hs, qs, ZH, Zstate)
        ``ZH`` holds the relaxed H. ``Zstate`` is 1.0 where that H > 0.7,
        0.0 where it is < 0.3, and 0.5 otherwise.
    """
    ZH = np.empty((len(Hs), len(qs)))
    for row, h_init in enumerate(Hs):
        for col, q_init in enumerate(qs):
            start = np.array([0.12, 0.12, h_init, 0.10, q_init])
            ZH[row, col] = simulate_relax(pvec, start, d_val)[2]
    # Conditions are evaluated in order, so > 0.7 takes precedence.
    Zstate = np.select([ZH > 0.7, ZH < 0.3], [1.0, 0.0], default=0.5)
    return Hs, qs, ZH, Zstate

print("Computing basins...")
# Relax a grid of initial (H, q) at the fitted d; Zstate labels each start as
# 1 (final H > 0.7), 0 (final H < 0.3) or 0.5 otherwise.
Hs,qs,ZH,Zstate=basin_map(p,d_fit)

fig,axes=plt.subplots(1,2,figsize=(13,6))
# Zstate rows follow Hs and columns follow qs, so with origin="lower" the
# x-axis is initial q and the y-axis is initial H.
im1=axes[0].imshow(Zstate,origin="lower",
                   extent=[qs[0],qs[-1],Hs[0],Hs[-1]],
                   aspect="auto",cmap="RdYlBu",vmin=0,vmax=1)
axes[0].set_xlabel("Initial memory (q)")
axes[0].set_ylabel("Initial host health (H)")
axes[0].set_title("Figure 2A – Basins of Attraction")
plt.colorbar(im1,ax=axes[0],ticks=[0,0.5,1],label="State")
axes[0].grid(False)

# Critical-transition sweep
def resilience_vs_d(pvec, drange, npert=25):
    """Estimate, for each d, how often a perturbed healthy state stays healthy.

    For every ``dval`` in ``drange``, ``npert`` starts are drawn as the
    healthy reference ``(0.2, 0.15, 0.8, 0.15, 0.1)`` plus N(0, 0.05) noise
    (global NumPy RNG), clipped to [0, 1]. Each start is integrated to t=250.
    A run counts as recovered when the solve succeeds and final H > 0.7.

    Returns
    -------
    pandas.DataFrame
        Columns ``d`` and ``resilience`` (fraction of recovered runs).
    """
    reference = np.array([0.2, 0.15, 0.8, 0.15, 0.1])
    t_grid = np.linspace(0, 250, 600)
    table = []
    for dval in drange:
        recovered = 0
        for _ in range(npert):
            start = np.clip(reference + np.random.normal(0, 0.05, 5), 0, 1)
            traj = solve_ivp(lambda t, y: rhs_model(y, pvec, dval), [0, 250], start,
                             t_eval=t_grid, rtol=1e-6, atol=1e-8)
            recovered += int(bool(traj.success and traj.y[2, -1] > 0.7))
        table.append({'d': dval, 'resilience': recovered / npert})
    return pd.DataFrame(table)

print("Computing resilience curve...")
# Fraction of perturbed healthy starts that end with H > 0.7, for each d
# (random perturbations come from the global NumPy RNG).
d_range=np.linspace(0.04,0.20,25)
res_df=resilience_vs_d(p,d_range)
axes[1].plot(res_df['d'],res_df['resilience'],'o-',color='red',lw=2)
axes[1].axvline(d_fit,ls='--',c='k',alpha=0.7)
axes[1].set_xlabel("Inflammatory rate (d)")
axes[1].set_ylabel("Resilience (fraction recovering)")
axes[1].set_title("Figure 2B – Resilience Curve")
axes[1].grid(True,alpha=0.3)

# Lays out and saves the current figure: the two-panel ``fig`` created for
# Figure 2A.
plt.tight_layout()
plt.savefig(f"{OUTDIR}/Figure2_Basins_Criticality_Enhanced.png",dpi=300)
plt.close()
print("✓ Figure 2 saved")

# =============================================================================
# FIGURE 3: INTERVENTION STRATEGIES
# =============================================================================
print("\nGenerating Figure 3: Intervention Strategies...")

def simulate_intervention(pvec, d_val, mode, strength, duration=180, y0=None):
    """Integrate the model under a constant intervention from a dysbiotic start.

    The intervention acts over the whole window ``[0, duration]``:

    * ``"anti_inflammatory"`` scales the inflammatory rate to
      ``d_val * (1 - strength)``;
    * ``"probiotic"`` adds ``strength * 0.3 * P * (1 - P)`` to dP;
    * ``"postbiotic"`` adds the constant ``strength * 0.05`` to dB;
    * any other ``mode`` runs the untreated model.

    Parameters
    ----------
    pvec : sequence of 16 floats
        Model parameters (see ``rhs_model``).
    d_val : float
        Baseline inflammatory rate.
    mode : str
        Intervention type (see above).
    strength : float
        Intervention intensity.
    duration : float, default 180
        End time of the integration.
    y0 : array-like of 5 floats, optional
        Initial state; defaults to the dysbiotic ``(0.08, 0.18, 0.15, 0.05, 0.9)``.

    Returns
    -------
    scipy.integrate OdeResult
        ``solve_ivp`` result sampled at 800 evenly spaced times.
    """
    if y0 is None:
        y0 = np.array([0.08, 0.18, 0.15, 0.05, 0.9])
    # Anti-inflammatory therapy only changes d, so resolve it once instead of
    # evaluating the model twice per RHS call.
    d_eff = d_val * (1 - strength) if mode == "anti_inflammatory" else d_val

    def rhs_int(t, y):
        dy = rhs_model(y, pvec, d_eff)
        if mode == "probiotic":
            dy[0] += strength * 0.3 * y[0] * (1 - y[0])
        elif mode == "postbiotic":
            dy[3] += strength * 0.05
        return dy

    return solve_ivp(rhs_int, [0, duration], y0,
                     t_eval=np.linspace(0, duration, 800),
                     rtol=1e-6, atol=1e-8)

def metrics_from_sol(sol, t_start=0.0, healthy_thresh=0.7, tail_frac=0.25):
    """Summarise a host-health trajectory produced by an intervention run.

    ``simulate_intervention`` applies treatment from t = 0, so by default the
    baseline is H at t = 0 and recovery time is measured from t = 0. The
    previous version averaged H over t < 20 (already treated) and subtracted
    20, which could even yield negative recovery times.

    Parameters
    ----------
    sol : object with ``t`` (times) and ``y`` (state rows; row 2 is H)
        Typically a ``solve_ivp`` result.
    t_start : float, default 0.0
        Intervention onset. The baseline is the mean H over ``t <= t_start``,
        and recovery is only counted from ``t_start`` on.
    healthy_thresh : float, default 0.7
        H above this value counts as recovered.
    tail_frac : float, default 0.25
        Trailing fraction of samples averaged for the sustained effect.

    Returns
    -------
    dict
        ``time_to_recovery`` (``np.inf`` if never recovered), ``max_effect``
        and ``sustained_effect`` (both relative to the baseline).
    """
    H = np.asarray(sol.y[2], dtype=float)
    t = np.asarray(sol.t, dtype=float)
    if H.size == 0:
        return {'time_to_recovery': np.inf, 'max_effect': np.nan,
                'sustained_effect': np.nan}

    pre = t <= t_start
    base = np.mean(H[pre]) if np.any(pre) else H[0]

    recovered = (H > healthy_thresh) & (t >= t_start)
    if np.any(recovered):
        time_to_recovery = t[np.argmax(recovered)] - t_start
    else:
        time_to_recovery = np.inf

    tail = H[int((1 - tail_frac) * len(H)):]
    return {
        'time_to_recovery': time_to_recovery,
        'max_effect': np.max(H) - base,
        'sustained_effect': np.mean(tail) - base,
    }

modes=['anti_inflammatory','probiotic','postbiotic']
strengths=[0.3,0.5,0.7]
colors={'anti_inflammatory':'blue','probiotic':'green','postbiotic':'red'}

fig,axes=plt.subplots(2,2,figsize=(14,10))
# Panel A – trajectories
# H(t) for each intervention at strength 0.5 and the fitted d.
for m in modes:
    sol=simulate_intervention(p,d_fit,m,0.5)
    axes[0,0].plot(sol.t,sol.y[2],label=m.replace('_',' ').title(),
                   color=colors[m],lw=2)
# Same 0.7 threshold that metrics_from_sol uses to call a run recovered.
axes[0,0].axhline(0.7,ls='--',c='k',alpha=0.6)
axes[0,0].set_xlabel("Time")
axes[0,0].set_ylabel("Host health (H)")
axes[0,0].set_title("Figure 3A – Intervention Trajectories")
axes[0,0].legend(); axes[0,0].grid(True,alpha=0.3)

# Panel B – dose response
# One metrics row per (mode, strength); ``df`` is also used by Panel D below.
rows = []
for m in modes:
    for s in strengths:
        sol = simulate_intervention(p, d_fit, m, s)
        met = metrics_from_sol(sol)
        if isinstance(met, dict):  # ensure metrics exist
            met['mode'] = m
            met['strength'] = s
            rows.append(met)

df = pd.DataFrame(rows)
print("Intervention metrics dataframe columns:", df.columns.tolist())

# An empty ``rows`` list gives a DataFrame without a 'mode' column.
if 'mode' not in df.columns:
    print("⚠️ No valid intervention metrics found — skipping dose-response plot.")
else:
    for m in modes:
        sub = df[df['mode'] == m]
        if not sub.empty:
            axes[0,1].plot(sub['strength'], sub['max_effect'],
                           'o-', color=colors[m],
                           label=m.replace('_', ' ').title())
    axes[0,1].set_xlabel("Intervention strength")
    axes[0,1].set_ylabel("Max ΔH")
    axes[0,1].set_title("Figure 3B – Dose Response")
    axes[0,1].legend(); axes[0,1].grid(True, alpha=0.3)
# Panel C – success window across d
def success_window(pvec, mode, s, drange):
    """Check, for each d in ``drange``, whether one intervention run ends healthy.

    One deterministic ``simulate_intervention`` run is made per d. A run
    succeeds only if the integration itself succeeded and the final H
    exceeds 0.7. A failed solve stops early, and its last sample is not the
    end-of-treatment state.

    Parameters
    ----------
    pvec : sequence of 16 floats
        Model parameters.
    mode : str
        Intervention type passed to ``simulate_intervention``.
    s : float
        Intervention strength.
    drange : iterable of float
        Inflammatory rates to test.

    Returns
    -------
    pandas.DataFrame
        Columns ``d``, ``success`` (bool) and ``success_rate`` (0/1 int). The
        columns are present even when ``drange`` is empty.
    """
    rows = []
    for dval in drange:
        sol = simulate_intervention(pvec, dval, mode, s)
        ok = bool(sol.success) and bool(sol.y[2, -1] > 0.7)
        rows.append({'d': dval, 'success': ok})
    dfw = pd.DataFrame(rows, columns=['d', 'success'])
    dfw['success_rate'] = dfw['success'].astype(int)
    return dfw

drange=np.linspace(0.05,0.20,25)
# Each point is a single deterministic run at that d (success is 0 or 1),
# not an average over replicates.
for m in modes:
    dfw=success_window(p,m,0.5,drange)
    axes[1,0].plot(dfw.d,dfw.success_rate,'s-',color=colors[m],
                   label=m.replace('_',' ').title())
axes[1,0].axvline(d_fit,ls='--',c='k',alpha=0.7)
axes[1,0].set_xlabel("Inflammatory rate (d)")
axes[1,0].set_ylabel("Success rate")
axes[1,0].set_title("Figure 3C – Therapeutic Windows")
axes[1,0].legend(); axes[1,0].grid(True,alpha=0.3)

# Panel D – simple ranking
# Guard before grouping: when no metrics were collected, ``df`` has no
# 'mode' column and groupby would raise KeyError.
if 'mode' in df.columns and not df.empty:
    rank = (df.groupby('mode', as_index=False)['max_effect']
              .mean()
              .sort_values('max_effect'))
else:
    rank = pd.DataFrame(columns=['mode', 'max_effect'])

if not rank.empty:
    axes[1,1].barh(rank['mode'].str.replace('_',' ').str.title(), rank['max_effect'],
                   color=[colors.get(m, 'gray') for m in rank['mode']])
    axes[1,1].set_xlabel("Mean Max ΔH")
    axes[1,1].set_title("Figure 3D – Efficacy Ranking")
else:
    # Axes-relative coordinates centre the message regardless of data limits.
    axes[1,1].text(0.5, 0.5, "No data available", ha='center', va='center',
                   fontsize=12, transform=axes[1,1].transAxes)


plt.tight_layout()
plt.savefig(f"{OUTDIR}/Figure3_Interventions_Enhanced.png",dpi=300)
plt.close()
print("✓ Figure 3 saved")



Generating Figure 2: Basins of Attraction...
Computing basins...
Computing resilience curve...
✓ Figure 2 saved

Generating Figure 3: Intervention Strategies...
Intervention metrics dataframe columns: ['time_to_recovery', 'max_effect', 'sustained_effect', 'mode', 'strength']
✓ Figure 3 saved


In [7]:
# =============================================================================
# SUPPLEMENTARY FIGURE: PARAMETER SENSITIVITY ANALYSIS
# =============================================================================
print("\nGenerating Supplementary Figure: Parameter Sensitivity Analysis...")

def parameter_sensitivity(pvec, idx, label, sweep, d_val=d_fit, nsamples=30):
    """Sweep one parameter and estimate the healthy/dysbiotic basin fractions.

    For every value in ``sweep``, ``pvec[idx]`` is replaced and ``nsamples``
    random starts are relaxed with ``simulate_relax``. Each start has
    H0 ~ U(0.1, 0.9) and q0 ~ U(0, 1) from the global NumPy RNG, with
    P = C = 0.12 and B = 0.10. A final H > 0.7 counts as healthy and
    H < 0.3 as dysbiotic. Other outcomes are ignored, and both fractions are
    normalised by the number of classified runs.

    Parameters
    ----------
    pvec : numpy.ndarray of 16 floats
        Baseline parameters.
    idx : int
        Index of the parameter to sweep.
    label : str
        Stored in the ``label`` column.
    sweep : iterable of float
        Values to assign to ``pvec[idx]``.
    d_val : float, default d_fit
        Inflammatory rate passed to the model. When ``idx`` is the position
        of d (6), the swept value is used instead. ``rhs_model`` replaces
        ``pvec[6]`` with ``d_val`` whenever ``d_val`` is given, so without
        this rerouting a d sweep would have no effect.
    nsamples : int, default 30
        Random initial conditions per swept value.

    Returns
    -------
    pandas.DataFrame
        Columns ``param_value``, ``label``, ``healthy_frac``, ``dys_frac`` and
        ``bistability_index`` (the smaller of the two fractions).
    """
    d_index = 6  # position of d in (r0P, rHP, r0C, K_M, gamma, c, d, ...)
    results = []
    for val in sweep:
        pmod = pvec.copy()
        pmod[idx] = val
        d_use = val if idx == d_index else d_val
        healthy = 0
        dys = 0
        for _ in range(nsamples):
            H0 = np.random.uniform(0.1, 0.9)
            q0 = np.random.uniform(0, 1)
            y0 = np.array([0.12, 0.12, H0, 0.10, q0])
            H_final = simulate_relax(pmod, y0, d_use)[2]
            if H_final > 0.7:
                healthy += 1
            elif H_final < 0.3:
                dys += 1
        tot = max(1, healthy + dys)
        results.append({
            'param_value': val,
            'label': label,
            'healthy_frac': healthy / tot,
            'dys_frac': dys / tot,
            'bistability_index': min(healthy / tot, dys / tot),
        })
    return pd.DataFrame(results)

# Select representative parameters to test
# Each entry is (index into p, axis label, sweep values). Indices follow the
# unpacking order of p: 0=r0P, 6=d, 7=gH, 14=tau.
# NOTE(review): rhs_model replaces pvec[6] with d_val whenever d_val is given,
# so the d sweep only has an effect if parameter_sensitivity forwards the
# swept value as d_val — verify when reading the first panel.
param_sweeps=[
    (6,  "Inflammation rate d",       np.linspace(0.05,0.20,12)),
    (0,  "r₀ᴾ (max growth P)",        np.linspace(0.15,0.40,12)),
    (7,  "gᴴ (host gain from B)",     np.linspace(0.3,1.0,12)),
    (14, "τ (memory timescale)",      np.linspace(1.0,10.0,12))
]

all_df=[]
for idx,label,grid in param_sweeps:
    print(f"  sweeping {label}...")
    # NOTE: rebinds the module-level ``df`` that the Figure 3 cell also uses.
    df=parameter_sensitivity(p,idx,label,grid)
    all_df.append(df)

sens=pd.concat(all_df,ignore_index=True)

# Plot results
# One panel per sweep, filled row-major; the 2x2 grid fits at most 4 sweeps.
fig,axes=plt.subplots(2,2,figsize=(12,9))
uniq=sens['label'].unique()
for k,label in enumerate(uniq):
    ax=axes[k//2,k%2]
    sub=sens[sens.label==label]
    ax.plot(sub['param_value'],sub['healthy_frac'],'b-o',lw=2,label='Healthy basin')
    ax.plot(sub['param_value'],sub['dys_frac'],'r-s',lw=2,label='Dysbiotic basin')
    ax.plot(sub['param_value'],sub['bistability_index'],'k--',lw=1.5,label='Bistability index')
    # Pairs the k-th unique label with the k-th sweep: relies on unique()
    # returning labels in first-appearance order and on labels being distinct.
    base_val=p[param_sweeps[k][0]]
    ax.axvline(base_val,ls='--',c='gray',alpha=0.7,label=f'Baseline {base_val:.2f}')
    ax.set_xlabel(label); ax.set_ylabel('Fraction of initial conditions')
    ax.set_title(f"Sensitivity: {label}")
    ax.legend(fontsize=8); ax.grid(True,alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTDIR}/Supplementary_Sensitivity_Enhanced.png",dpi=300)
plt.close()
print("✓ Supplementary Sensitivity figure saved")

# =============================================================================
# SUMMARY METRICS
# =============================================================================
print("\nSummary statistics (approximate):")
# Fractions of the Figure 2A (H0, q0) grid at d_fit; requires the basin cell
# (which defines Zstate) to have run. Cells labelled 0.5 count toward neither.
healthy_frac=np.mean(Zstate==1)
dys_frac=np.mean(Zstate==0)
print(f"Healthy basin fraction ≈ {healthy_frac:.2%}")
print(f"Dysbiotic basin fraction ≈ {dys_frac:.2%}")
# Gap between the q-dependent switch thresholds (a parameter difference, not
# a measured hysteresis loop).
print(f"Hysteresis width ≈ {H_off-H_on:.3f}")
print("Analysis complete. All outputs in:",OUTDIR)



Generating Supplementary Figure: Parameter Sensitivity Analysis...
  sweeping Inflammation rate d...
  sweeping r₀ᴾ (max growth P)...
  sweeping gᴴ (host gain from B)...
  sweeping τ (memory timescale)...
✓ Supplementary Sensitivity figure saved

Summary statistics (approximate):
Healthy basin fraction ≈ 0.00%
Dysbiotic basin fraction ≈ 100.00%
Hysteresis width ≈ 0.200
Analysis complete. All outputs in: paper_figures
