In [9]:
# === JUPYTER CELL: Full pipeline with frequency-range and band-diversity knobs ===
import json, math, random, copy
import numpy as np, pandas as pd
import torch, torch.nn as nn, torch.optim as optim
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm

# -----------------------
# PATHS 
# -----------------------
# NOTE: these are absolute, machine-specific paths; edit them before running elsewhere.
MODEL_PATH   = "/Users/hosseinostovar/Desktop/BACKUP/Data_H2SO4_NPG/data/single_frequency/Single_frequencies_whole_spectrum/PINN_report/pinn_model.pt"
DATA_ALL     = "/Users/hosseinostovar/Desktop/BACKUP/Data_H2SO4_NPG/data/single_frequency/Single_frequencies_whole_spectrum/PINN_report/compiled_dataset.csv"        # full spectra table (train+test)
DATA_TEST    = "/Users/hosseinostovar/Desktop/BACKUP/Data_H2SO4_NPG/data/single_frequency/Single_frequencies_whole_spectrum/PINN_report/test_predictions_pinn.csv"   # held-out test spectra (per-frequency rows)
OUT_DIR      = Path("/Users/hosseinostovar/Desktop/BACKUP/Data_H2SO4_NPG/data/single_frequency/Single_frequencies_whole_spectrum/inverse_reports/operando_piplines/high_f_1000-10000Hz")
# Side effect when this cell runs: the output directory and its plots/ subfolder are created.
OUT_DIR.mkdir(parents=True, exist_ok=True)
(OUT_DIR/"plots").mkdir(exist_ok=True)

# Targets: calibration MAE thresholds a frequency set must meet to be accepted
TARGET_C_MAE = 0.2   # mM
TARGET_T_MAE = 0.5   # °C

# Bounds (training domain); also used as the box constraints of the inverse solver
C_MIN, C_MAX = 5.0, 20.0
T_MIN, T_MAX = 26.0, 50.0

# Candidate K values to try for top-K frequency sets
K_LIST = [1, 3, 6, 10]
TEST_ALL_K = True   # set False to keep "best only" behavior


# -----------------------
# Frequency selection knobs (NEW)
# -----------------------
# A) Allow only specific frequency windows; leave [] for whole range
#    Example to avoid sub-Hz entirely: ALLOWED_WINDOWS = [(1.0, 1e4)]
#    Window edges are inclusive on both sides.
ALLOWED_WINDOWS = []   # [] => no restriction

# B) Enforce band-diversity quotas (round-robin pick). Leave [] to disable.
#    Example: one subHz, one 1–100, one 100–1000, one 1k–10k; remainder auto-filled.
#    Band edges are inclusive on both sides, so a frequency lying exactly on a shared
#    edge belongs to both neighbouring bands (it is still picked at most once).
#    As configured below only the 1k–10k band has a non-zero quota, and that quota (20)
#    exceeds every K in K_LIST, so picks come from 1k–10k first; other bands are used
#    only if that band has fewer than K candidates.
BANDS = [
    {"name":"subHz",  "min":0.0,   "max":1.0,    "quota":0},
    {"name":"1_100",  "min":1.0,   "max":100.0,  "quota":0},
    {"name":"100_1k", "min":100.0, "max":1000.0, "quota":0},
    {"name":"1k_10k", "min":1000.0,"max":1e4,    "quota":20},
]

# -----------------------
# Forward physics (same as your training)
# -----------------------
def _j_like(x):
    return torch.complex(torch.zeros((), dtype=x.dtype, device=x.device),
                         torch.ones( (), dtype=x.dtype, device=x.device))

def torch_coth(z, eps=1e-12, large=20.0):
    """Element-wise hyperbolic cotangent that stays finite for tiny and huge arguments.

    Three regimes are used, each evaluated only on the elements that belong to it
    (so an overflowing formula never touches, or back-propagates NaN through,
    elements handled by another regime):

    * ``|sinh(z)| < eps``  -> series ``1/z + z/3`` (avoids 0/0-like cancellation);
    * ``|Re z| > large``   -> ``s * (1 + t) / (1 - t)`` with ``s = sign(Re z)`` and
      ``t = exp(-2 s z)``; this form cannot overflow, whereas ``cosh/sinh``
      becomes inf/inf = NaN once ``|Re z|`` exceeds the floating-point exp range;
    * otherwise            -> ``cosh(z) / sinh(z)``.

    Parameters
    ----------
    z : torch.Tensor
        Real or complex tensor.
    eps : float
        Threshold on ``|sinh(z)|`` below which the small-argument series is used.
    large : float
        Threshold on ``|Re z|`` above which the overflow-safe exponential form is
        used. At the default of 20, ``|t| <= exp(-40)``, far below double precision,
        so switching formulas does not change the result beyond rounding.

    Returns
    -------
    torch.Tensor
        Tensor with the same shape and dtype as ``z``.
    """
    # Masks are built from detached values: they only select branches and must
    # not contribute to (or poison) the autograd graph.
    z_det = z.detach()
    re = z_det.real if z_det.is_complex() else z_det
    big = torch.abs(re) > large
    small = (torch.abs(torch.sinh(z_det)) < eps) & ~big
    mid = ~(small | big)

    out = torch.empty_like(z)

    z_mid = z[mid]
    out[mid] = torch.cosh(z_mid) / torch.sinh(z_mid)

    z_small = z[small]
    out[small] = 1.0 / z_small + z_small / 3.0

    # coth is odd, so fold the argument into the right half-plane before using exp(-2z).
    z_big = z[big]
    sgn = torch.sign(re[big])
    t = torch.exp(-2.0 * sgn * z_big)
    out[big] = sgn * (1.0 + t) / (1.0 - t)
    return out

def torch_zarc(Rp, Y0, n, w):
    """Impedance of a resistance ``Rp`` in parallel with a constant-phase term.

    Computes ``1 / (1/Rp + Y0 * (j*w)**n)``; ``Rp`` and ``Y0`` are floored at
    1e-18 before use so neither the reciprocal nor the product degenerates.
    """
    floor = 1e-18
    jw = _j_like(w) * w
    admittance = 1.0 / torch.clamp(Rp, min=floor) + torch.clamp(Y0, min=floor) * jw**n
    return 1.0 / admittance

def torch_tl_impedance(r, y0, n, L, w):
    """Transmission-line impedance ``Z0 * coth(L * gamma)``.

    With ``r`` and ``y0`` floored at 1e-18:
    ``gamma = sqrt(r * y0 * (j*w)**n)`` and ``Z0 = sqrt(r / (y0 * (j*w)**n))``.
    """
    floor = 1e-18
    unit_j = _j_like(w)
    r_pos = torch.clamp(r, min=floor)
    y_pos = torch.clamp(y0, min=floor)
    propagation = torch.sqrt(r_pos * y_pos * (unit_j * w) ** n)
    characteristic = torch.sqrt(r_pos / (y_pos * (unit_j * w) ** n))
    return characteristic * torch_coth(L * propagation)

def torch_impedance_rs_zarc_tl(omega, Rs, Rp, Y0, n0, r, y0, n1, L):
    """Total impedance: ``Rs`` plus the ZARC term plus the transmission-line term.

    ``(Rp, Y0, n0)`` feed ``torch_zarc`` and ``(r, y0, n1, L)`` feed
    ``torch_tl_impedance``; both are evaluated at angular frequency ``omega``.
    """
    return Rs + torch_zarc(Rp, Y0, n0, omega) + torch_tl_impedance(r, y0, n1, L, omega)

class ThetaNet(nn.Module):
    """MLP mapping two normalised inputs to eight constrained circuit parameters.

    The backbone is ``depth`` blocks of ``Linear(width) + ReLU`` (an identity when
    ``depth`` is 0), followed by a linear head with 8 outputs ordered as
    ``Rs, Rp, Y0, n0, r, y0, n1, L``. In ``forward`` the exponents ``n0``/``n1``
    go through a sigmoid (open interval (0, 1)); every other output goes through
    softplus plus 1e-9 (strictly positive).
    """

    def __init__(self, in_dim=2, width=64, depth=3, dtype=torch.float64):
        super().__init__()
        hidden = []
        fan_in = in_dim
        for _ in range(depth):
            hidden.append(nn.Linear(fan_in, width, dtype=dtype))
            hidden.append(nn.ReLU())
            fan_in = width
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.backbone = nn.Sequential(*hidden) if hidden else nn.Identity()
        self.head = nn.Linear(fan_in, 8, dtype=dtype)  # Rs, Rp, Y0, n0, r, y0, n1, L
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, Cn, Tn):
        """Return the 8-tuple ``(Rs, Rp, Y0, n0, r, y0, n1, L)`` for 1-D inputs ``Cn``, ``Tn``."""
        features = torch.stack((Cn, Tn), dim=1)
        columns = self.head(self.backbone(features)).unbind(dim=1)

        def positive(v):
            return self.softplus(v) + 1e-9

        unit = self.sigmoid
        # One transform per head output, in head order.
        transforms = (positive, positive, positive, unit,
                      positive, positive, unit, positive)
        return tuple(fn(col) for fn, col in zip(transforms, columns))

class PINNForward:
    """Checkpoint-backed forward model: (C, T, omega) -> [Z', -Z''].

    Loads a ``ThetaNet`` plus input-normalisation statistics (``xmu``/``xstd``)
    and an optional output de-normalisation spec (``y_norm``) from a checkpoint.
    """

    def __init__(self, model_path, device="cpu"):
        # Support PyTorch 2.6+ change in default weights_only.
        # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
        load_kwargs = {"map_location": device}
        try:
            ckpt = torch.load(model_path, weights_only=False, **load_kwargs)
        except TypeError:
            # Older torch versions do not know the weights_only keyword.
            ckpt = torch.load(model_path, **load_kwargs)

        self.xmu = np.array(ckpt.get("xmu", [0, 0]), float)
        self.xstd = np.array(ckpt.get("xstd", [1, 1]), float)

        train_cfg = ckpt.get("train_config", {})
        self.net = ThetaNet(
            in_dim=2,
            width=int(train_cfg.get("width", 64)),
            depth=int(train_cfg.get("depth", 3)),
            dtype=torch.float64,
        ).to(device)
        self.net.load_state_dict(ckpt["state_dict"])
        self.net.eval()

        self.device = torch.device(device)
        self.y_norm = ckpt.get("y_norm", {"enabled": False})
        self._dtype = torch.float64

    def predict_torch(self, C_t, T_t, w_t):
        """Return an ``(N, 2)`` tensor ``[Z', -Z'']`` for 1-D tensors ``C_t``, ``T_t``, ``w_t``.

        Inputs are standardised with ``xmu``/``xstd``. If ``y_norm`` is enabled, the
        stacked output is mapped through the inverse of the "standard" or "minmax"
        scaling; any other method name leaves the output unchanged.
        """
        mu_c, mu_t = float(self.xmu[0]), float(self.xmu[1])
        sd_c = float(self.xstd[0]) + 1e-12
        sd_t = float(self.xstd[1]) + 1e-12
        theta = self.net((C_t - mu_c) / sd_c, (T_t - mu_t) / sd_t)
        Zc = torch_impedance_rs_zarc_tl(w_t, *theta)
        y = torch.stack([Zc.real, -Zc.imag], dim=1)  # [Z', -Z'']

        yn = self.y_norm
        if not yn.get("enabled", False):
            return y

        def _stat(key):
            return torch.tensor(yn[key], dtype=self._dtype, device=self.device)

        method = yn.get("method", "standard")
        if method == "standard":
            mu = _stat("mu")
            std = _stat("std")
            return y * std + mu
        if method == "minmax":
            y_min = _stat("min")
            y_max = _stat("max")
            return y * (y_max - y_min) + y_min
        return y

# -----------------------
# Helpers: grouping spectra
# -----------------------
def norm_col(df, names, fallback=None):
    """Return the actual column of ``df`` matching any candidate in ``names``.

    Matching is case-insensitive and candidates are tried in order; the column
    is returned with its original spelling. ``fallback`` is returned when no
    candidate matches.
    """
    by_lower = {}
    for column in df.columns:
        by_lower[column.lower()] = column
    matches = (by_lower[cand.lower()] for cand in names if cand.lower() in by_lower)
    return next(matches, fallback)

def read_full_dataset(path):
    """Load the compiled spectra CSV and return it with canonical column names.

    Columns are resolved case-insensitively from a list of known aliases and
    renamed to ``frequency_Hz``, ``Z_real``, ``Z_imag_neg``, ``concentration_mM``,
    ``temperature_C`` and (when present) ``source_file``. Renaming the
    concentration/temperature columns matters because downstream grouping uses
    those canonical names literally.

    Rows with a missing frequency or impedance value are dropped and the index
    is reset.

    Raises
    ------
    RuntimeError
        If the frequency or impedance columns cannot be resolved.
    KeyError
        If the concentration or temperature columns cannot be resolved.
    """
    df = pd.read_csv(path)
    fcol  = norm_col(df, ["frequency_Hz","frequency (Hz)","freq (hz)","f (hz)","f","f_hz"], None)
    zrcol = norm_col(df, ["Z_real","Z' (Ω)","z_real","zre","re(z)"], None)
    zinc  = norm_col(df, ["Z_imag_neg","-Z'' (Ω)","-z_imag","-imag","-zim"], None)
    if fcol is None or zrcol is None or zinc is None:
        raise RuntimeError("Could not resolve frequency/Z columns in compiled_dataset.csv")

    ccol = norm_col(df, ["concentration_mM","concentration (mm)","conc_mm","c_mm"], None)
    tcol = norm_col(df, ["temperature_C","temp_c","t_c"], None)
    if ccol is None or tcol is None:
        raise KeyError("Could not resolve concentration/temperature columns in compiled_dataset.csv")

    # Map each resolved source column to its canonical name (dict order = output column order).
    rename = {
        fcol: "frequency_Hz",
        zrcol: "Z_real",
        zinc: "Z_imag_neg",
        ccol: "concentration_mM",
        tcol: "temperature_C",
    }
    # The provenance column is not used for grouping, so it is carried along only if available.
    scol = norm_col(df, ["source_file","file","path"], None)
    if scol is not None:
        rename[scol] = "source_file"

    tmp = df[list(rename)].copy()
    tmp.columns = list(rename.values())
    tmp = tmp.dropna(subset=["frequency_Hz","Z_real","Z_imag_neg"]).reset_index(drop=True)
    return tmp

def group_by_CT(df):
    """Split a long-format table into one spectrum dict per (concentration, temperature).

    Each dict holds ``id`` (``"C<c>_T<t>"``), the float ``C`` and ``T`` values, and
    float arrays ``f``, ``Zr``, ``Zim_neg`` sorted by ascending frequency.
    """
    def _floats(frame, column):
        return frame[column].astype(float).to_numpy()

    def _pack(key, rows):
        conc, temp = (float(v) for v in key)
        ordered = rows.sort_values("frequency_Hz")
        return {
            "id": f"C{conc:g}_T{temp:g}",
            "C": conc,
            "T": temp,
            "f": _floats(ordered, "frequency_Hz"),
            "Zr": _floats(ordered, "Z_real"),
            "Zim_neg": _floats(ordered, "Z_imag_neg"),
        }

    grouped = df.groupby(["concentration_mM", "temperature_C"])
    return [_pack(key, rows) for key, rows in grouped]

# -----------------------
# Inverse solver (ALL freqs or subset)
# -----------------------
def sigmoid_to_range(x, lo, hi):
    """Map an unconstrained tensor ``x`` into the interval between ``lo`` and ``hi`` via a sigmoid."""
    span = hi - lo
    return lo + span * torch.sigmoid(x)

def init_raw_from_guess(c0, t0, lo_c, hi_c, lo_t, hi_t):
    """Convert a (C, T) starting guess into unconstrained optimisation variables.

    Each guess is clipped 1e-6 inside its bounds and passed through the inverse
    of ``sigmoid_to_range`` (a logit). Returns a float64 tensor of shape (2,)
    with ``requires_grad=True``.
    """
    margin = 1e-6

    def _logit_in_range(value, lo, hi):
        inside = float(np.clip(value, lo + margin, hi - margin))
        return math.log((inside - lo) / (hi - inside))

    raw = [_logit_in_range(c0, lo_c, hi_c), _logit_in_range(t0, lo_t, hi_t)]
    return torch.tensor(raw, dtype=torch.float64, requires_grad=True)

def invert_spectrum(forward, f, zr, zim_neg,
                    lo_c=C_MIN, hi_c=C_MAX, lo_t=T_MIN, hi_t=T_MAX,
                    restarts=10, steps_adam=300, steps_lbfgs=80,
                    wr=1.0, wi=1.0):
    """Estimate (C, T) from measured impedance by multi-start least squares.

    C and T are parameterised through a sigmoid so they stay inside
    ``[lo_c, hi_c]`` / ``[lo_t, hi_t]``. Each restart runs Adam followed by
    L-BFGS on a loss that is the weighted mean of squared residuals, with the
    real and imaginary residuals each scaled by the median magnitude of the
    corresponding measured values.

    Parameters
    ----------
    forward : object
        Forward model exposing ``device`` and ``predict_torch(C, T, w)`` returning
        an ``(N, 2)`` tensor ``[Z', -Z'']``.
    f, zr, zim_neg : array-like
        Frequencies (Hz), measured Z' and measured -Z''.
    lo_c, hi_c, lo_t, hi_t : float
        Box bounds for C and T.
    restarts, steps_adam, steps_lbfgs : int
        Number of starting points and optimiser budgets per start.
    wr, wi : float
        Weights of the real and imaginary loss terms.

    Returns
    -------
    dict
        Always contains ``C``, ``T``, ``loss``, ``se_C``, ``se_T``; contains ``raw``
        (the optimum in unconstrained coordinates) when at least one restart
        produced a finite loss. ``C``/``T`` are NaN when no restart succeeded or
        when no frequencies were supplied.
    """
    device = forward.device
    f = np.asarray(f, float); zr = np.asarray(zr, float); zi = -np.asarray(zim_neg, float)
    if f.size == 0:
        return {"C": np.nan, "T": np.nan, "loss": np.nan, "se_C": np.nan, "se_T": np.nan}

    w = torch.tensor(2*np.pi*f, dtype=torch.float64, device=device)
    zr_t = torch.tensor(zr, dtype=torch.float64, device=device)
    zi_t = torch.tensor(zi, dtype=torch.float64, device=device)
    n_pts = w.numel()

    # Residual scales depend only on the measurements, so compute them once.
    sr = torch.clamp(zr_t.abs().median(), min=1e-9)
    si = torch.clamp(zi_t.abs().median(), min=1e-9)

    def residuals(raw):
        """Scaled (real, imaginary) residual vectors at unconstrained point ``raw``."""
        C = sigmoid_to_range(raw[0], lo_c, hi_c)
        T = sigmoid_to_range(raw[1], lo_t, hi_t)
        y = forward.predict_torch(C.repeat(n_pts), T.repeat(n_pts), w)
        yzr, yzi = y[:, 0], -y[:, 1]
        return (yzr - zr_t) / sr, (yzi - zi_t) / si

    def loss_from(raw):
        res_r, res_i = residuals(raw)
        return wr*torch.mean(res_r**2) + wi*torch.mean(res_i**2)

    # Defaults cover the case where every restart ends with a non-finite loss.
    best = {"C": np.nan, "T": np.nan, "loss": float("inf")}

    # multi-start: small grid + random
    n_side = max(2, int(math.sqrt(restarts)))
    grid_c = np.linspace(lo_c, hi_c, n_side)
    grid_t = np.linspace(lo_t, hi_t, n_side)
    seeds = [(float(c), float(t)) for c in grid_c for t in grid_t]
    while len(seeds) < restarts:
        seeds.append((random.uniform(lo_c, hi_c), random.uniform(lo_t, hi_t)))

    for c0, t0 in seeds[:restarts]:
        # `.to(device)` on a requires_grad tensor yields a non-leaf on a different
        # device, which optimisers reject; rebuild a leaf on the target device.
        raw = (init_raw_from_guess(c0, t0, lo_c, hi_c, lo_t, hi_t)
               .detach().to(device).requires_grad_(True))

        # Adam: coarse descent
        opt = optim.Adam([raw], lr=0.08)
        for _ in range(steps_adam):
            opt.zero_grad()
            loss_from(raw).backward()
            opt.step()

        # LBFGS: local polish
        opt2 = optim.LBFGS([raw], lr=1.0, max_iter=steps_lbfgs, line_search_fn="strong_wolfe")

        def closure():
            opt2.zero_grad()
            loss_val = loss_from(raw)
            loss_val.backward()
            return loss_val

        opt2.step(closure)

        with torch.no_grad():
            Lf = loss_from(raw).item()
            Cf = float(sigmoid_to_range(raw[0], lo_c, hi_c).cpu().numpy())
            Tf = float(sigmoid_to_range(raw[1], lo_t, hi_t).cpu().numpy())
        if Lf < best["loss"]:
            best = {"C": Cf, "T": Tf, "loss": Lf, "raw": raw.detach().cpu().numpy()}

    # crude SE via Gauss-Newton: cov ~ rmse^2 * (J^T J)^-1 in raw space, then the
    # sigmoid's derivative maps it to C/T units. The wr/wi weights are not applied here.
    best["se_C"] = np.nan
    best["se_T"] = np.nan
    if "raw" in best:
        try:
            raw_best = torch.tensor(best["raw"], dtype=torch.float64, device=device)

            def stacked_residuals(raw):
                return torch.cat(residuals(raw), dim=0)

            # Full (2N x 2) Jacobian of the residual vector w.r.t. the raw variables.
            J = torch.autograd.functional.jacobian(stacked_residuals, raw_best)
            J = J.detach().cpu().numpy()
            with torch.no_grad():
                res = stacked_residuals(raw_best)
                rmse = float(torch.sqrt(torch.mean(res**2)).cpu().numpy())
            JTJ = J.T @ J + 1e-10*np.eye(2)
            cov = np.linalg.inv(JTJ)
            se_raw = np.sqrt(np.diag(cov)) * rmse
            s = 1/(1+np.exp(-best["raw"]))
            dC = s*(1-s)*(hi_c - lo_c); dT = s*(1-s)*(hi_t - lo_t)
            best["se_C"] = float(abs(dC[0])*se_raw[0])
            best["se_T"] = float(abs(dT[1])*se_raw[1])
        except (RuntimeError, ValueError, np.linalg.LinAlgError):
            # Best-effort diagnostic: keep the point estimate, report SEs as NaN.
            best["se_C"] = np.nan; best["se_T"] = np.nan
    return best

# -----------------------
# Load data & split into calibration vs test by (C,T)
# -----------------------
# Full spectra table with canonical column names, and the held-out test table as stored on disk.
df_all  = read_full_dataset(DATA_ALL)
df_test = pd.read_csv(DATA_TEST)  # has: frequency_Hz, Z_real_true, Z_imag_neg_true, concentration_mM, temperature_C

# Normalize test column names
# `need` maps each canonical name to the column actually present in df_test; it starts
# as the identity and is overwritten with a case-insensitive match when the exact
# spelling is missing.
need = {"frequency_Hz":"frequency_Hz","Z_real_true":"Z_real_true","Z_imag_neg_true":"Z_imag_neg_true",
        "concentration_mM":"concentration_mM","temperature_C":"temperature_C"}
low = {c.lower(): c for c in df_test.columns}
for k in list(need):
    if k not in df_test.columns:
        for c in list(low.values()):
            if c.lower() == k.lower():
                need[k] = c
                break
# Columns that could not be matched keep their identity mapping, so the rename is a no-op for them.
df_test = df_test.rename(columns={need["frequency_Hz"]:"frequency_Hz",
                                  need["Z_real_true"]:"Z_real_true",
                                  need["Z_imag_neg_true"]:"Z_imag_neg_true",
                                  need["concentration_mM"]:"concentration_mM",
                                  need["temperature_C"]:"temperature_C"})

# Build test spectra (true Z from file)
# One dict per (C, T) pair, with arrays sorted by ascending frequency.
test_groups = []
for (c,t), sub in df_test.groupby(["concentration_mM","temperature_C"]):
    sub = sub.sort_values("frequency_Hz")
    test_groups.append({
        "id": f"C{float(c):g}_T{float(t):g}",
        "C": float(c), "T": float(t),
        "f": sub["frequency_Hz"].astype(float).to_numpy(),
        "Zr": sub["Z_real_true"].astype(float).to_numpy(),
        "Zim_neg": sub["Z_imag_neg_true"].astype(float).to_numpy()
    })
# Set of (C, T) pairs that are held out.
test_CT = {(g["C"], g["T"]) for g in test_groups}

# Calibration = everything in compiled_dataset that is NOT in the test (by (C,T))
# NOTE: membership is tested on the exact (C, T) values, so the two files must store
# identical numbers for a pair to be recognised as held out.
cal_df = df_all[~df_all.set_index(["concentration_mM","temperature_C"]).index.isin(test_CT)].copy()
cal_groups = group_by_CT(cal_df)

# Also capture the master frequency grid from data
FREQS_ALL = np.sort(df_all["frequency_Hz"].unique())
# Persist basic dataset statistics next to the other outputs.
Path(OUT_DIR/"_meta.json").write_text(json.dumps({
    "n_calib_spectra": len(cal_groups),
    "n_test_spectra":  len(test_groups),
    "n_unique_freqs":  int(len(FREQS_ALL)),
    "freqs_min_max":   [float(FREQS_ALL.min()), float(FREQS_ALL.max())]
}, indent=2))

# -----------------------
# Frequency ranking by sensitivity (fast)
# -----------------------
# Use the GPU when available; `fwd` is the module-level forward model that
# jacobian_scores() and eval_inverse_on_groups() read as a global.
DEV  = "cuda" if torch.cuda.is_available() else "cpu"
fwd  = PINNForward(MODEL_PATH, device=DEV)

def jacobian_scores(Cval, Tval, freqs):
    """Score each frequency by the local sensitivity of (Z', Z'') to (C, T).

    For every frequency the 2x2 Jacobian J (rows: Z', Z''; columns: C, T) of the
    module-level forward model ``fwd`` is evaluated at ``(Cval, Tval)``. Returns a
    DataFrame with one row per frequency and columns:

    * ``f``     - the frequency;
    * ``D``     - ``det(J^T J)``;
    * ``E``     - smallest eigenvalue of ``J^T J``;
    * ``angle`` - angle in degrees between the C and T columns of J (0.0 when
      either column has zero norm).
    """
    tensor_kw = {"dtype": torch.float64, "device": fwd.device}
    w = torch.tensor(2*np.pi*freqs, **tensor_kw)
    C = torch.tensor(Cval, requires_grad=True, **tensor_kw)
    T = torch.tensor(Tval, requires_grad=True, **tensor_kw)
    n_pts = w.numel()
    y = fwd.predict_torch(C.repeat(n_pts), T.repeat(n_pts), w)
    Zr, Zim = y[:, 0], -y[:, 1]

    def _grad(output, wrt):
        # The graph is reused for every frequency, so it must be retained.
        return torch.autograd.grad(output, wrt, retain_graph=True)[0]

    records = []
    for i, freq in enumerate(freqs):
        gC = torch.stack([_grad(Zr[i], C), _grad(Zim[i], C)])
        gT = torch.stack([_grad(Zr[i], T), _grad(Zim[i], T)])
        J = torch.stack([gC, gT], dim=1).detach().cpu().numpy()  # 2x2
        JTJ = J.T @ J
        det = float(np.linalg.det(JTJ))
        # min eigenvalue helps conditioning; angle helps C vs T separation
        emin = float(np.min(np.linalg.eigvalsh(JTJ)))
        col_c, col_t = J[:, 0], J[:, 1]
        norm_c, norm_t = np.linalg.norm(col_c), np.linalg.norm(col_t)
        if norm_c > 0 and norm_t > 0:
            cosine = np.dot(col_c, col_t)/(max(norm_c, 1e-12)*max(norm_t, 1e-12))
            angle = float(np.degrees(np.arccos(np.clip(cosine, -1, 1))))
        else:
            angle = 0.0
        records.append({"f": float(freq), "D": det, "E": emin, "angle": angle})
    return pd.DataFrame(records)

# Score on a coarse grid of (C,T) to avoid biasing with test
# 6 x 6 grid of operating points spanning the training domain.
C_grid = np.linspace(C_MIN, C_MAX, 6)
T_grid = np.linspace(T_MIN, T_MAX, 6)
blocks=[]
with tqdm(total=len(C_grid)*len(T_grid), desc="Sensitivity sweep (ranking)") as pbar:
    for C0 in C_grid:
        for T0 in T_grid:
            # One per-frequency score table for each grid point.
            blocks.append(jacobian_scores(C0, T0, FREQS_ALL))
            pbar.update(1)
rank_df = pd.concat(blocks, ignore_index=True)
# Average each score over the grid, then rank frequencies: primary key D_mean,
# ties broken by E_mean and then angle_mean (all descending).
global_rank = (rank_df.groupby("f", as_index=False)
               .agg(D_mean=("D","mean"), E_mean=("E","mean"), angle_mean=("angle","mean"))
               .sort_values(["D_mean","E_mean","angle_mean"], ascending=[False,False,False])
               .reset_index(drop=True))
global_rank.to_csv(OUT_DIR/"global_frequency_ranking.csv", index=False)

# -----------------------
# Constrained frequency selection (NEW)
# -----------------------
def _apply_windows(df_rank, windows):
    if not windows: 
        return df_rank.copy()
    mask = np.zeros(len(df_rank), dtype=bool)
    fvals = df_rank["f"].values
    for (lo, hi) in windows:
        mask |= (fvals >= float(lo)) & (fvals <= float(hi))
    out = df_rank.loc[mask].reset_index(drop=True)
    return out

def _apply_bands_pick(df_rank, K, bands):
    """
    Round-robin pick under band quotas (highest-ranked per band).
    If quotas don't fill K, remaining slots are filled by overall ranking among unused.
    """
    if not bands:
        return df_rank.head(K)["f"].to_numpy()

    bands_local = copy.deepcopy(bands)  # don't mutate user quotas
    per = []
    used = set()
    for b in bands_local:
        sub = df_rank[(df_rank["f"] >= float(b["min"])) & (df_rank["f"] <= float(b["max"]))].copy().reset_index(drop=True)
        per.append({"band": b, "freqs": sub["f"].tolist(), "i": 0})

    picked = []
    # First pass: honor quotas in round-robin
    while len(picked) < K and any(p["i"] < len(p["freqs"]) and p["band"]["quota"] > 0 for p in per):
        for p in per:
            if len(picked) >= K:
                break
            q = p["band"]["quota"]
            # advance to next unused freq in this band
            while p["i"] < len(p["freqs"]) and p["freqs"][p["i"]] in used:
                p["i"] += 1
            if q > 0 and p["i"] < len(p["freqs"]):
                f = p["freqs"][p["i"]]
                picked.append(f); used.add(f)
                p["band"]["quota"] = q - 1
                p["i"] += 1

    # Second pass: fill remaining from overall rank among unused
    if len(picked) < K:
        for f in df_rank["f"].tolist():
            if f not in used:
                picked.append(f); used.add(f)
            if len(picked) >= K:
                break

    return np.array(picked[:K], float)

def choose_frequencies(global_rank, K, windows=None, bands=None, out_json_path=None):
    """Select K frequencies from the global ranking under window/band constraints.

    Steps: restrict the ranking to ``windows`` (if any), then pick with band
    quotas (if any bands are given) or simply take the top K. When
    ``out_json_path`` is set, the inputs and the selection are written there
    as JSON.

    Raises
    ------
    RuntimeError
        If the window filter leaves no candidate frequencies.
    """
    window_list = windows or []
    band_list = bands or []

    # 1) filter by windows (if any)
    candidates = _apply_windows(global_rank, window_list)
    if candidates.empty:
        raise RuntimeError("No candidate frequencies remain after applying ALLOWED_WINDOWS.")

    # 2) band quotas (if any), else top-K
    if band_list:
        freqs = _apply_bands_pick(candidates, K, band_list)
    else:
        freqs = candidates.head(K)["f"].to_numpy()

    # optional: save selection reasoning
    if out_json_path:
        report = {
            "K": int(K),
            "windows": window_list,
            "bands": band_list,
            "picked_Hz": list(map(float, freqs)),
            "n_candidates_after_windows": int(len(candidates)),
        }
        Path(out_json_path).write_text(json.dumps(report, indent=2))
    return freqs

# -----------------------
# Helper: evaluate inverse MAE on a set of spectra using a selected freq set
# -----------------------
def eval_inverse_on_groups(groups, freq_set, desc="eval"):
    """Run the inverse solver on each spectrum restricted to ``freq_set``.

    Parameters
    ----------
    groups : list of dict
        Spectra with keys ``id``, ``C``, ``T``, ``f``, ``Zr``, ``Zim_neg``.
    freq_set : array-like
        Frequencies to keep; a spectrum point is used when its frequency is an
        exact member of this set.
    desc : str
        Progress-bar label.

    Returns
    -------
    (pandas.DataFrame, dict)
        Per-spectrum predictions/errors, and ``{"N", "C_MAE", "T_MAE"}`` where the
        MAEs ignore NaN errors and are NaN if no spectrum produced a finite error.
    """
    rows=[]
    with tqdm(groups, desc=desc) as pbar:
        for g in pbar:
            m = np.isin(g["f"], freq_set)
            f  = g["f"][m]; Zr = g["Zr"][m]; Zin = g["Zim_neg"][m]
            res = invert_spectrum(fwd, f, Zr, Zin)
            # A failed inversion may come back without C/T/loss; treat those as NaN
            # instead of aborting the whole evaluation with a KeyError.
            c_pred = res.get("C", np.nan)
            t_pred = res.get("T", np.nan)
            rows.append({"id": g["id"], "C_true": g["C"], "T_true": g["T"],
                         "C_pred": c_pred, "T_pred": t_pred,
                         "C_err": (c_pred-g["C"]) if np.isfinite(c_pred) else np.nan,
                         "T_err": (t_pred-g["T"]) if np.isfinite(t_pred) else np.nan,
                         "n_freq": int(len(f)), "loss": res.get("loss", np.nan),
                         "se_C": res.get("se_C", np.nan), "se_T": res.get("se_T", np.nan)})
    df = pd.DataFrame(rows)
    def mae(a):
        a = np.asarray(a, float)
        # Guard avoids nanmean's all-NaN warning and makes the NaN result explicit.
        return float(np.nanmean(np.abs(a))) if np.isfinite(a).any() else np.nan
    metrics = {"N": int(len(df)), "C_MAE": mae(df["C_err"]), "T_MAE": mae(df["T_err"])}
    return df, metrics

# -----------------------
# 1) Choose minimal K meeting targets on CALIBRATION
# -----------------------
# Evaluate every candidate K on the calibration spectra; the first K (in K_LIST
# order) that meets both MAE targets becomes the chosen set.
best_choice = None
cal_results = []
for K in K_LIST:
    topK = choose_frequencies(
        global_rank, K,
        windows=ALLOWED_WINDOWS,   # [] ⇒ whole range
        bands=BANDS,               # [] ⇒ no band quotas
        out_json_path=OUT_DIR/f"selection_debug_top{K}.json"
    )
    df_cal, m_cal = eval_inverse_on_groups(cal_groups, topK, desc=f"Calibration inverse @ top-{K}")
    m_cal["K"] = K
    cal_results.append(m_cal)
    df_cal.to_csv(OUT_DIR/f"calibration_inverse_top{K}.csv", index=False)
    Path(OUT_DIR/f"calibration_metrics_top{K}.json").write_text(json.dumps(m_cal, indent=2))
    # NaN metrics compare False here, so a K without usable results is never accepted.
    if (m_cal["C_MAE"] <= TARGET_C_MAE) and (m_cal["T_MAE"] <= TARGET_T_MAE) and (best_choice is None):
        best_choice = {"K": K, "freqs": topK, "cal_metrics": m_cal}

if best_choice is None:
    # pick the best trade-off (minimize weighted sum)
    def _tradeoff_score(m):
        """Target-normalised error sum; NaN maps to +inf so sorting stays well-defined."""
        score = m["C_MAE"]/TARGET_C_MAE + m["T_MAE"]/TARGET_T_MAE
        return score if math.isfinite(score) else float("inf")

    scores = [(_tradeoff_score(m), m["K"]) for m in cal_results]
    scores.sort()
    bestK = scores[0][1]
    best_choice = {"K": bestK,
                   "freqs": choose_frequencies(global_rank, bestK, windows=ALLOWED_WINDOWS, bands=BANDS),
                   "cal_metrics": [m for m in cal_results if m["K"]==bestK][0]}

# Persist the decision together with the constraints that produced it.
Path(OUT_DIR/"chosen_frequency_set.json").write_text(json.dumps({
    "K": int(best_choice["K"]), "freqs_Hz": list(map(float, best_choice["freqs"])),
    "targets": {"C_MAE_mM": TARGET_C_MAE, "T_MAE_C": TARGET_T_MAE},
    "calibration_metrics": best_choice["cal_metrics"],
    "windows": ALLOWED_WINDOWS,
    "bands": BANDS
}, indent=2))

# -----------------------
# 2) Lock the chosen set and evaluate on TEST (held-out)
# -----------------------
# Rows of the per-K summary table; filled only when TEST_ALL_K is True.
test_summ_rows = []

if TEST_ALL_K:
    # Evaluate test for every K in K_LIST (under same window/band constraints)
    for K in K_LIST:
        selK = choose_frequencies(
            global_rank, K,
            windows=ALLOWED_WINDOWS,
            bands=BANDS,
            out_json_path=OUT_DIR/f"selection_debug_TEST_top{K}.json"
        )
        df_test_invK, m_testK = eval_inverse_on_groups(test_groups, selK, desc=f"TEST inverse @ top-{K}")
        df_test_invK.to_csv(OUT_DIR/f"test_inverse_predictions_top{K}.csv", index=False)
        Path(OUT_DIR/f"test_inverse_metrics_top{K}.json").write_text(json.dumps({
            "K": int(K), "C_MAE_mM": m_testK["C_MAE"], "T_MAE_C": m_testK["T_MAE"], "N": m_testK["N"]
        }, indent=2))
        test_summ_rows.append({"K": int(K), "C_MAE_mM": m_testK["C_MAE"], "T_MAE_C": m_testK["T_MAE"], "N": m_testK["N"]})

    # Save a compact summary over all K
    pd.DataFrame(test_summ_rows).sort_values("K").to_csv(OUT_DIR/"test_inverse_metrics_summary.csv", index=False)

    # Also keep the “best” K products to preserve your current UX
    # NOTE: this re-runs the inversion for the chosen K rather than reusing the loop's
    # result; because the solver may add random restarts, the two runs need not be
    # bit-identical.
    freq_set = best_choice["freqs"]
    df_test_inv, m_test = eval_inverse_on_groups(test_groups, freq_set, desc=f"TEST inverse @ top-{best_choice['K']}")
    df_test_inv.to_csv(OUT_DIR/"test_inverse_predictions_bestK.csv", index=False)
    Path(OUT_DIR/"test_inverse_metrics_bestK.json").write_text(json.dumps({
        "K": int(best_choice["K"]), "C_MAE_mM": m_test["C_MAE"], "T_MAE_C": m_test["T_MAE"], "N": m_test["N"]
    }, indent=2))

else:
    # Original behavior: only the best K
    # (output file names differ from the TEST_ALL_K branch: no "_bestK" suffix)
    freq_set = best_choice["freqs"]
    df_test_inv, m_test = eval_inverse_on_groups(test_groups, freq_set, desc=f"TEST inverse @ top-{best_choice['K']}")
    df_test_inv.to_csv(OUT_DIR/"test_inverse_predictions.csv", index=False)
    Path(OUT_DIR/"test_inverse_metrics.json").write_text(json.dumps({
        "K": int(best_choice["K"]), "C_MAE_mM": m_test["C_MAE"], "T_MAE_C": m_test["T_MAE"], "N": m_test["N"]
    }, indent=2))


# -----------------------
# 3) Bar plots (True vs Pred) on TEST
# -----------------------
def _save_true_vs_pred_bars(df, quantity, unit, k_label, out_path):
    """Save a grouped bar chart of true vs predicted ``quantity`` for each spectrum.

    ``df`` must provide the columns ``id``, ``<quantity>_true`` and
    ``<quantity>_pred``; ``unit`` is the y-axis label and ``k_label`` is the
    frequency count shown in the title. The figure is written to ``out_path``
    and closed.
    """
    labels = df["id"].tolist()
    x = np.arange(len(labels)); width = 0.38

    # Widen the figure with the number of spectra so tick labels stay readable.
    fig, ax = plt.subplots(figsize=(max(10, 0.4*len(labels)), 5))
    ax.bar(x - width/2, df[f"{quantity}_true"], width, label=f"True {quantity}")
    ax.bar(x + width/2, df[f"{quantity}_pred"], width, label=f"Pred {quantity}")
    ax.set_xticks(x); ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_ylabel(unit); ax.set_title(f"TEST — True vs Pred {quantity} (top-{k_label} freqs)")
    ax.legend(); fig.tight_layout()
    fig.savefig(out_path, dpi=180); plt.close(fig)


if TEST_ALL_K:
    # One pair of plots per K, read back from the per-K prediction CSVs.
    for K in K_LIST:
        dfK_path = OUT_DIR/f"test_inverse_predictions_top{K}.csv"
        if not Path(dfK_path).exists():
            continue
        dfK = pd.read_csv(dfK_path)
        for quantity, unit in (("C", "mM"), ("T", "°C")):
            _save_true_vs_pred_bars(
                dfK, quantity, unit, K,
                OUT_DIR/f"plots/bar_true_vs_pred_{quantity}_top{K}.png",
            )
else:
    # Best-K only: plot the in-memory predictions.
    for quantity, unit in (("C", "mM"), ("T", "°C")):
        _save_true_vs_pred_bars(
            df_test_inv, quantity, unit, best_choice['K'],
            OUT_DIR/f"plots/bar_true_vs_pred_{quantity}.png",
        )


# -----------------------
# 4) Save a tiny runtime inverse module + config
# -----------------------
RUNTIME_PATH = OUT_DIR/"ct_inverse_runtime.py"
RUNTIME_PATH.write_text(f"""
import json, math, numpy as np, torch, torch.nn as nn, torch.optim as optim

def _j_like(x):
    return torch.complex(torch.zeros((), dtype=x.dtype, device=x.device),
                         torch.ones( (), dtype=x.dtype, device=x.device))
def torch_coth(z, eps=1e-12):
    sz = torch.sinh(z); cz = torch.cosh(z)
    small = torch.abs(sz) < eps
    out = torch.empty_like(z)
    out[~small] = cz[~small] / sz[~small]
    out[small] = 1.0/z[small] + z[small]/3.0
    return out
def torch_zarc(Rp, Y0, n, w):
    j = _j_like(w)
    return 1.0 / (1.0/torch.clamp(Rp, min=1e-18) + torch.clamp(Y0, min=1e-18) * (j*w)**n)
def torch_tl_impedance(r, y0, n, L, w):
    j = _j_like(w)
    r_ = torch.clamp(r,  min=1e-18); y0_= torch.clamp(y0, min=1e-18)
    gamma = torch.sqrt(r_ * y0_ * (j*w)**n)
    Z0    = torch.sqrt(r_ / (y0_ * (j*w)**n))
    return Z0 * torch_coth(L * gamma)
def torch_impedance_rs_zarc_tl(omega, Rs, Rp, Y0, n0, r, y0, n1, L):
    Zarc = torch_zarc(Rp, Y0, n0, omega)
    Ztl  = torch_tl_impedance(r, y0, n1, L, omega)
    return Rs + Zarc + Ztl

class ThetaNet(nn.Module):
    def __init__(self, in_dim=2, width=64, depth=3, dtype=torch.float64):
        super().__init__()
        layers, d = [], in_dim
        for _ in range(depth):
            layers += [nn.Linear(d, width, dtype=dtype), nn.ReLU()]
            d = width
        self.backbone = nn.Sequential(*layers) if layers else nn.Identity()
        self.head = nn.Linear(d, 8, dtype=dtype)
        self.softplus = nn.Softplus(); self.sigmoid = nn.Sigmoid()
    def forward(self, Cn, Tn):
        h = self.backbone(torch.stack([Cn, Tn], dim=1))
        raw = self.head(h)
        Rs_r, Rp_r, Y0_r, n0_r, r_r, y0_r, n1_r, L_r = torch.unbind(raw, dim=1)
        eps = 1e-9
        Rs  = self.softplus(Rs_r)  + eps
        Rp  = self.softplus(Rp_r)  + eps
        Y0  = self.softplus(Y0_r)  + eps
        n0  = self.sigmoid(n0_r)
        r   = self.softplus(r_r)   + eps
        y0  = self.softplus(y0_r)  + eps
        n1  = self.sigmoid(n1_r)
        L   = self.softplus(L_r)   + eps
        return Rs, Rp, Y0, n0, r, y0, n1, L

class ForwardPINN:
    def __init__(self, model_path, device='cpu'):
        try:
            ckpt = torch.load(model_path, map_location=device, weights_only=False)
        except TypeError:
            ckpt = torch.load(model_path, map_location=device)
        self.xmu  = np.array(ckpt.get('xmu',[0,0]), float)
        self.xstd = np.array(ckpt.get('xstd',[1,1]), float)
        tr = ckpt.get('train_config', {{}})
        width = int(tr.get('width',64)); depth = int(tr.get('depth',3))
        self.net = ThetaNet(in_dim=2, width=width, depth=depth, dtype=torch.float64).to(device)
        self.net.load_state_dict(ckpt['state_dict']); self.net.eval()
        self.device = torch.device(device)
        self.y_norm = ckpt.get('y_norm', {{'enabled': False}})
        self._dtype = torch.float64
    def predict_torch(self, C_t, T_t, w_t):
        Cn = (C_t - float(self.xmu[0])) / (float(self.xstd[0]) + 1e-12)
        Tn = (T_t - float(self.xmu[1])) / (float(self.xstd[1]) + 1e-12)
        Rs,Rp,Y0,n0,r,y0,n1,L = self.net(Cn, Tn)
        Zc = torch_impedance_rs_zarc_tl(w_t, Rs,Rp,Y0,n0,r,y0,n1,L)
        y = torch.stack([Zc.real, -Zc.imag], dim=1)
        yn = self.y_norm
        if yn.get('enabled', False):
            method = yn.get('method','standard')
            if method == 'standard':
                mu  = torch.tensor(yn['mu'],  dtype=self._dtype, device=self.device)
                std = torch.tensor(yn['std'], dtype=self._dtype, device=self.device)
                y = y*std + mu
            elif method == 'minmax':
                y_min = torch.tensor(yn['min'], dtype=self._dtype, device=self.device)
                y_max = torch.tensor(yn['max'], dtype=self._dtype, device=self.device)
                y = y*(y_max - y_min) + y_min
        return y

# Smoothly map an unconstrained value x into the open interval (lo, hi).
def _sigmoid_to_range(x, lo, hi): return lo + (hi - lo) * torch.sigmoid(x)

# Estimates concentration C and temperature T from impedance values measured
# at a fixed set of frequencies, by fitting the ForwardPINN prediction to the
# measurement with a bounded multi-start Adam + L-BFGS search.
class CTInverseEstimator:
    # model_path:   checkpoint handed to ForwardPINN
    # chosen_freqs: frequencies (Hz) the estimator is allowed to use
    # c_bounds, t_bounds: (low, high) search limits for C and T
    # device:       torch device for the forward model and the fit
    def __init__(self, model_path, chosen_freqs, c_bounds=({C_MIN}, {C_MAX}), t_bounds=({T_MIN}, {T_MAX}), device='cpu'):
        self.forward = ForwardPINN(model_path, device=device)
        self.freqs = np.array(chosen_freqs, float)
        self.c_bounds = c_bounds; self.t_bounds = t_bounds

    # freqs, Zr, Zim: equally long sequences holding one measured spectrum.
    #                 Only rows whose frequency is exactly one of the chosen
    #                 frequencies are used.
    # restarts:       number of starting points for the search
    # steps_adam:     Adam iterations per start
    # steps_lbfgs:    max_iter of the L-BFGS polish per start
    # wr, wi:         weights of the real and imaginary misfit terms
    # Returns a dict. On success: ok=True, C_pred, T_pred, se_C, se_T,
    # used_freqs, loss. se_C and se_T are NaN when the linearised error
    # estimate cannot be computed. On failure: ok=False and a reason string.
    def estimate(self, freqs, Zr, Zim, restarts=8, steps_adam=250, steps_lbfgs=80, wr=1.0, wi=1.0):
        freqs = np.asarray(freqs, float); Zr=np.asarray(Zr,float); Zim=np.asarray(Zim,float)
        m = np.isin(freqs, self.freqs)
        f = freqs[m]; zr = Zr[m]; zin = -Zim[m]  # convert to -Z'' convention internally
        if len(f)==0: return {{'ok': False, 'reason':'no overlap with chosen frequencies'}}
        device = self.forward.device
        w = torch.tensor(2*np.pi*f, dtype=torch.float64, device=device)
        zr_t = torch.tensor(zr, dtype=torch.float64, device=device)
        zi_t = torch.tensor(zin, dtype=torch.float64, device=device)
        lo_c,hi_c = self.c_bounds; lo_t,hi_t = self.t_bounds
        # Residual scales depend only on the data, so compute them once
        # instead of on every loss evaluation.
        sr = torch.clamp(zr_t.abs().median(), min=1e-9)
        si = torch.clamp(zi_t.abs().median(), min=1e-9)

        # Scaled real and imaginary residuals at the unconstrained point raw.
        def residuals(raw):
            C = _sigmoid_to_range(raw[0], lo_c, hi_c); T = _sigmoid_to_range(raw[1], lo_t, hi_t)
            y = self.forward.predict_torch(C.repeat(w.numel()), T.repeat(w.numel()), w)
            yzr, yzi = y[:,0], -y[:,1]
            return (yzr - zr_t)/sr, (yzi - zi_t)/si

        def loss_from(raw):
            rr, ri = residuals(raw)
            return wr*torch.mean(rr**2) + wi*torch.mean(ri**2)

        # Build the unconstrained start vector for a (C, T) seed.
        def init_raw(c0,t0):
            eps=1e-6
            c0=float(np.clip(c0, lo_c+eps, hi_c-eps)); t0=float(np.clip(t0, lo_t+eps, hi_t-eps))
            invsig=lambda y, lo, hi: math.log((y-lo)/(hi-y))
            # Create the tensor directly on the target device: calling .to(device)
            # afterwards would return a non-leaf copy on any non-CPU device,
            # which the optimisers refuse to update.
            return torch.tensor([invsig(c0,lo_c,hi_c), invsig(t0,lo_t,hi_t)],
                                dtype=torch.float64, device=device, requires_grad=True)

        # Seeds: a coarse grid over the box first, then random points if the
        # grid has fewer entries than the requested number of restarts.
        grid_c = np.linspace(lo_c, hi_c, max(2, int(math.sqrt(restarts))))
        grid_t = np.linspace(lo_t, hi_t, max(2, int(math.sqrt(restarts))))
        seeds = [(float(c),float(t)) for c in grid_c for t in grid_t]
        while len(seeds)<restarts:
            seeds.append((np.random.uniform(lo_c,hi_c), np.random.uniform(lo_t,hi_t)))

        best={{'loss': float('inf')}}
        for c0,t0 in seeds[:restarts]:
            raw = init_raw(c0,t0)
            opt = optim.Adam([raw], lr=0.08)
            for _ in range(steps_adam):
                opt.zero_grad(); loss=loss_from(raw); loss.backward(); opt.step()
            opt2 = optim.LBFGS([raw], lr=1.0, max_iter=steps_lbfgs, line_search_fn='strong_wolfe')
            def closure():
                opt2.zero_grad(); loss2=loss_from(raw); loss2.backward(); return loss2
            opt2.step(closure)
            with torch.no_grad():
                Lf = loss_from(raw).item()
                Cf = float(_sigmoid_to_range(raw[0], lo_c, hi_c).cpu().numpy())
                Tf = float(_sigmoid_to_range(raw[1], lo_t, hi_t).cpu().numpy())
            # A NaN loss never compares as smaller, so diverged starts are ignored.
            if Lf<best['loss']:
                best={{'loss':Lf,'C_pred':Cf,'T_pred':Tf,'raw':raw.detach().cpu().numpy()}}

        if 'raw' not in best:
            # Every start ended with a non-finite loss, so there is nothing to report.
            return {{'ok': False, 'reason':'no restart reached a finite loss'}}

        # Linearised standard errors from the residual Jacobian at the optimum.
        seC=float('nan'); seT=float('nan')
        try:
            raw = torch.tensor(best['raw'], dtype=torch.float64, device=device, requires_grad=True)
            res = torch.cat(residuals(raw), dim=0)
            # autograd.grad only accepts a scalar output without grad_outputs,
            # so differentiate one residual at a time with respect to the
            # leaf raw. Each call yields one Jacobian row of length 2.
            J = torch.stack([torch.autograd.grad(r_k, raw, retain_graph=True)[0] for r_k in res], dim=0)
            J = J.detach().cpu().numpy()
            JTJ=J.T@J+1e-10*np.eye(2); cov=np.linalg.inv(JTJ)
            rmse=float(torch.sqrt(torch.mean(res.detach()**2)).item())
            se_raw=np.sqrt(np.diag(cov))*rmse
            # Propagate from the unconstrained variables to C and T through the
            # derivative of the sigmoid range map.
            s=1/(1+np.exp(-best['raw'])); dC=s*(1-s)*(hi_c-lo_c); dT=s*(1-s)*(hi_t-lo_t)
            seC=float(abs(dC[0])*se_raw[0]); seT=float(abs(dT[1])*se_raw[1])
        except (RuntimeError, ValueError, np.linalg.LinAlgError):
            # Best effort: keep the point estimate and report unknown uncertainty.
            seC=float('nan'); seT=float('nan')
        return {{'ok': True, 'C_pred':best['C_pred'], 'T_pred':best['T_pred'],
                 'se_C':seC, 'se_T':seT, 'used_freqs':list(map(float,f)),
                 'loss':best['loss']}}
""")

# Persist the settings the generated inverse estimator needs at run time
# (model location, selected frequencies, search bounds, selection knobs),
# then print a short end-of-run summary.
chosen_freqs_hz = [float(freq) for freq in freq_set]
runtime_config = {
    "model_path": MODEL_PATH,
    "chosen_freqs_Hz": chosen_freqs_hz,
    "C_bounds_mM": [C_MIN, C_MAX],
    "T_bounds_C":  [T_MIN, T_MAX],
    "windows": ALLOWED_WINDOWS,
    "bands": BANDS,
}
config_path = OUT_DIR / "ct_inverse_runtime_config.json"
config_path.write_text(json.dumps(runtime_config, indent=2))

test_summary = {"C_MAE_mM": m_test["C_MAE"], "T_MAE_C": m_test["T_MAE"]}
print("\n=== DONE ===")
print("Chosen frequency set (Hz):", chosen_freqs_hz)
print("Calibration metrics:", best_choice["cal_metrics"])
print("Test metrics:", test_summary)
print("Artifacts saved in:", str(OUT_DIR))


Sensitivity sweep (ranking):   0%|          | 0/36 [00:00<?, ?it/s]

Calibration inverse @ top-1:   0%|          | 0/166 [00:00<?, ?it/s]

Calibration inverse @ top-3:   0%|          | 0/166 [00:00<?, ?it/s]

Calibration inverse @ top-6:   0%|          | 0/166 [00:00<?, ?it/s]

Calibration inverse @ top-10:   0%|          | 0/166 [00:00<?, ?it/s]

TEST inverse @ top-1:   0%|          | 0/42 [00:00<?, ?it/s]

TEST inverse @ top-3:   0%|          | 0/42 [00:00<?, ?it/s]

TEST inverse @ top-6:   0%|          | 0/42 [00:00<?, ?it/s]

TEST inverse @ top-10:   0%|          | 0/42 [00:00<?, ?it/s]

TEST inverse @ top-3:   0%|          | 0/42 [00:00<?, ?it/s]


=== DONE ===
Chosen frequency set (Hz): [1120.0386846932463, 1267.006317736862, 8053.13803910887]
Calibration metrics: {'N': 166, 'C_MAE': 0.23179246465847766, 'T_MAE': 2.028279560927099, 'K': 3}
Test metrics: {'C_MAE_mM': 0.3046269289635919, 'T_MAE_C': 2.5793108670117184}
Artifacts saved in: /Users/hosseinostovar/Desktop/BACKUP/Data_H2SO4_NPG/data/single_frequency/Single_frequencies_whole_spectrum/inverse_reports/operando_piplines/high_f_1000-10000Hz
