# In [1]:
# =======================================================================
#  PINN for EIS — Non-sweeping, TL physics
#  Model: Z(ω) = Rs + Zarc(Rp, Y0, n0) + Z_TL(r, y0, n1, L)
#  - Reads one folder of CSVs (parses C,T from names)
#  - Flexible header resolver for Frequency / Z' / −Z''
#  - Options: target normalization, loss rebalancing (MSE/Huber + weights)
#  - Physics priors: Arrhenius (Rs,Rp), θ-invariance across f, monotonic priors
#  - Optional teacher priors for TL parameters
#  - Saves full report in output_dir (plots, CSVs, model)
# =======================================================================

import os, re, json, math, random, warnings, gc, time
from pathlib import Path
from typing import Dict, List, Tuple
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

from tqdm.auto import tqdm, trange
import matplotlib
matplotlib.rcParams["figure.dpi"] = 130
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from sklearn.base import BaseEstimator, RegressorMixin

# -------------------
# Config (EDIT PATHS)
# -------------------
# Single module-level settings dictionary. The helpers in this file read it
# directly (e.g. CONFIG["data"][...]) rather than receiving it as an argument.
CONFIG: Dict = {
    "paths": {
        # <<< EDIT THESE >>>
        # NOTE: absolute, machine-specific paths; adjust before running elsewhere.
        "input_root": "/Users/hosseinostovar/Desktop/BACKUP/Data_H2SO4_NPG/data/Single_frequencies_whole_spectrum/data",              # folder with many CSVs
        "fitparams_root": None,                                    # optional: folder with EIS_FitParams_*mM_*C.csv
        "output_dir": "/Users/hosseinostovar/Desktop/BACKUP/Data_H2SO4_NPG/data/Single_frequencies_whole_spectrum/PINN_report",            # where results go
    },
    "data": {
        # Candidate header spellings, consumed by resolve_eis_columns().
        # "z_imag_pos" columns are sign-flipped there to produce -Z''.
        "column_aliases": {
            "frequency": ["Frequency (Hz)", "frequency (hz)", "freq (hz)", "frequency", "f (hz)", "f_hz", "f"],
            "z_real":    ["Z' (Ω)", "Z' (ohm)", "z' (ω)", "z_real (Ω)", "zreal (Ω)", "re(z) (Ω)", "re(z)", "z_real", "zre"],
            "z_imag_neg":["-Z'' (Ω)", "-Z'' (ohm)", "-z'' (ω)", "-z_imag (Ω)", "-im(z) (Ω)", "-imag (Ω)", "-z_imag", "-imag", "-zim"],
            "z_imag_pos":["Z'' (Ω)", "Z'' (ohm)", "z'' (ω)", "z_imag (Ω)", "im(z) (Ω)", "imag (Ω)", "z_imag", "imag", "zim"],
        },
        # Legacy fallbacks (used only if aliases fail)
        "col_frequency": "Frequency (Hz)",
        "col_z_real":    "Z' (Ω)",
        "col_z_imag_neg":"-Z'' (Ω)",

        # Compared with np.isclose against the frequency column in build_dataset().
        "frequency_filter_hz": None,         # None = use all f; or set single Hz to train single-frequency PINN
        # The two switches below select the discovery modes in discover_files().
        "accept_nested_summary": True,       # allow */<conc>/*/<temp>/*/single_frequency_summary.csv
        "accept_flat_collected": True,       # allow arbitrary CSVs if columns match
        # Forwarded verbatim to pandas.read_csv by load_one_csv().
        "read_csv_kwargs": {"encoding": "utf-8"},

        # Optional training-domain filter
        # Bounds are inclusive; a None bound is not applied (see build_dataset()).
        "train_filters": {"conc_mM_min": 5.0, "conc_mM_max": 20.0, "temp_C_min": None, "temp_C_max": None},
    },
    "targets": {"target_columns": ["Z_real", "Z_imag_neg"]},
    "split": {"test_size": 0.2, "random_state": 42, "group_round_C": 0, "group_round_T": 0},
    # "show_file_scan" toggles the tqdm bar in build_dataset().
    "progress": {"enable_bars": True, "show_file_scan": True},
    # "show_inline" is read by _maybe_show(); figures are closed either way.
    "plots": {"make_learning_curve": False, "learning_curve_train_sizes": [0.3, 0.6, 1.0], "show_inline": False},

    # ------------------- PINN knobs -------------------
    # NOTE(review): the key names below mirror the keyword arguments of
    # EEC_PINN.__init__; the code that forwards them is not visible in this
    # part of the file - confirm against the caller.
    "pinn": {
        "teacher": {
            "use": False,                               # set True + provide fitparams_root for TL teacher priors
            "filename_glob": "EIS_FitParams_*mM_*C.csv",
            "weight": 0.15,
            # Log these teacher params (positive-only); do NOT log n0/n1
            "log_params": ["Rs","Rp","Y0","r","y0","L"],
            "param_col": "parameter",
            "value_col": "value",
        },
        "train": {
            "epochs": 5000, "lr": 2e-4, "width": 256, "depth": 4,
            "weight_decay": 1e-6, "batch_size": 512, "verbose": 1,
            "arrhenius_reg": 1e-4, "smooth_reg": 1e-5, "seed": 42, "device": "auto",

            # NEW: target normalization
            "targets_norm_enabled": False,          # True/False
            "targets_norm_method":  "standard",     # "standard" | "minmax"
            "targets_norm_clip":    True,           # for minmax, clip back to [min,max] on inverse

            # NEW: loss rebalancing / robust
            "loss_mode":   "mse",                   # "mse" | "huber"
            "huber_delta": 1.0,                     # δ for Huber
            "loss_weights": {"Z_real": 1.0, "Z_imag_neg": 1.0},  # per-target weights

            "extra": {
                "consistency_w": 1e-3, "group_round": 1e-6,
                # choose monotonicities if you know trends; default neutral except Rs:-1 vs T
                # Sign convention: +1 / -1 request a non-decreasing / non-increasing
                # parameter, 0 disables the prior for that parameter.
                "mono_wC": 0.0, "mono_wT": 0.0,
                "mono_sign_C": {"Rs":0, "Rp":0, "Y0":0, "n0":0, "r":0, "y0":0, "n1":0, "L":0},
                "mono_sign_T": {"Rs":-1, "Rp":0, "Y0":0, "n0":0, "r":0, "y0":0, "n1":0, "L":0},
                "teacher_nn": True, "teacher_nn_sigma": 0.75
            }
        }
    },

    "extrapolation": {"range_conc_mM": (5.0, 20.0), "tag": "conc5_20"},
    "heatmap": {"C_min": 5.0, "C_max": 20.0, "C_points": 64, "T_min": 26.0, "T_max": 50.0, "T_points": 64}
}

# ----------------------
# Header resolution + C/T parsing
# ----------------------
# Concentration token: a number followed by a molar unit. The pattern is
# case-insensitive, so the captured unit text can arrive in any letter case.
CONC_RE = re.compile(r'(?P<val>[-+]?[0-9]*\.?[0-9]+)\s*(?P<unit>mM|M|uM|µM)', re.IGNORECASE)
# Temperature token: a number followed by "c" or "C".
TEMP_RE = re.compile(r'(?P<val>[-+]?[0-9]*\.?[0-9]+)\s*[cC]')
# Multipliers that convert each canonical unit spelling to millimolar.
UNIT_SCALE = {"M": 1000.0, "mM": 1.0, "uM": 0.001, "µM": 0.001}
# Case-folded micromolar spellings. These are unambiguous in any letter case,
# unlike "m"/"mm", which cannot be told apart from molar/millimolar.
_MICRO_UNITS_FOLDED = frozenset(u.casefold() for u in ("uM", "µM"))


def _unit_scale_to_mM(unit: str) -> float:
    """Return the factor converting `unit` to mM.

    Exact-case spellings are looked up in UNIT_SCALE first. Because CONC_RE is
    case-insensitive, other casings can reach this function: micromolar
    spellings ("um", "UM", ...) map to 0.001; any other casing keeps the
    historical factor of 1.0 (value taken as mM).
    """
    exact = UNIT_SCALE.get(unit)
    if exact is not None:
        return exact
    if unit.casefold() in _MICRO_UNITS_FOLDED:
        return UNIT_SCALE["uM"]
    return 1.0


def parse_concentration_to_mM(text: str) -> float:
    """Extract the first concentration token from `text` and return it in mM.

    Returns NaN when `text` is not a string or contains no
    "<number><unit>" token recognised by CONC_RE.
    """
    if not isinstance(text, str):
        return np.nan
    m = CONC_RE.search(text)
    if m is None:
        return np.nan
    return float(m.group("val")) * _unit_scale_to_mM(m.group("unit"))


def parse_temperature_to_C(text: str) -> float:
    """Extract the first "<number>C" token from `text` as a float.

    Returns NaN when `text` is not a string or no token matches TEMP_RE.
    """
    if not isinstance(text, str):
        return np.nan
    m = TEMP_RE.search(text)
    return float(m.group("val")) if m else np.nan

def _normalize(s: str) -> str:
    """Reduce a column header to a canonical key for fuzzy alias matching.

    Lower-cases the text and removes whitespace, brackets, underscores, pipes,
    the degree sign and the ohm unit ("ohm" / omega). Characters that carry
    meaning in EIS headers are preserved:

    * primes: Z' and Z'' stay distinct (a double quote is read as two primes);
    * a leading minus sign (ASCII "-" or U+2212): -Z'' stays distinct from Z''.

    Hyphens elsewhere in the header are treated as separators and dropped.
    """
    text = s.strip().lower().replace("\u2212", "-")
    negative = text.startswith("-")
    # Remove the unit word as a whole; stripping the letters o/h/m one by one
    # would also mangle unrelated words such as "hz" or "imag".
    text = text.replace("ohm", "").replace('"', "''")
    text = re.sub(r"[\s\-_()\[\]{}|°Ωω]", "", text)
    return "-" + text if negative else text

def _find_col_by_aliases(df: pd.DataFrame, aliases: List[str]) -> str | None:
    cols = list(df.columns)
    low_map = {c.lower(): c for c in cols}
    for a in aliases:
        if a.lower() in low_map: return low_map[a.lower()]
    norm_map = {_normalize(c): c for c in cols}
    for a in aliases:
        na = _normalize(a)
        if na in norm_map: return norm_map[na]
    return None

def resolve_eis_columns(df: pd.DataFrame) -> tuple[str, str, pd.Series]:
    """Locate the frequency, Z' and -Z'' data in `df`.

    Column names are looked up through CONFIG["data"]["column_aliases"], with
    the legacy CONFIG["data"]["col_*"] names as a fallback.

    Returns:
        (frequency column name, Z' column name, numeric -Z'' series). When
        only a +Z'' column is present its values are negated.

    Raises:
        KeyError: if frequency, Z', or any imaginary column cannot be found.
    """
    data_cfg = CONFIG["data"]
    alias_table = data_cfg["column_aliases"]

    freq_name = _find_col_by_aliases(df, alias_table["frequency"])
    real_name = _find_col_by_aliases(df, alias_table["z_real"])
    neg_imag_name = _find_col_by_aliases(df, alias_table["z_imag_neg"])
    pos_imag_name = _find_col_by_aliases(df, alias_table["z_imag_pos"])

    # Legacy exact names are consulted only where the aliases found nothing.
    if freq_name is None and data_cfg["col_frequency"] in df.columns:
        freq_name = data_cfg["col_frequency"]
    if real_name is None and data_cfg["col_z_real"] in df.columns:
        real_name = data_cfg["col_z_real"]
    if neg_imag_name is None and data_cfg["col_z_imag_neg"] in df.columns:
        neg_imag_name = data_cfg["col_z_imag_neg"]

    no_imaginary = neg_imag_name is None and pos_imag_name is None
    if freq_name is None or real_name is None or no_imaginary:
        raise KeyError(f"Could not resolve columns. Found={list(df.columns)}")

    if neg_imag_name is None:
        # Only +Z'' is available: flip the sign to obtain -Z''.
        neg_imag = -pd.to_numeric(df[pos_imag_name], errors="coerce")
    else:
        neg_imag = pd.to_numeric(df[neg_imag_name], errors="coerce")
    return freq_name, real_name, neg_imag

# ----------------------
# File discovery (flat + nested)
# ----------------------
def discover_files(root: Path) -> List[Tuple[Path, str, str]]:
    """Collect candidate CSV files under `root`.

    Two discovery modes, each enabled by a CONFIG["data"] switch:

    * accept_nested_summary: every ``single_frequency_summary.csv``; the
      concentration hint is the grandparent directory name and the
      temperature hint is the parent directory name.
    * accept_flat_collected: every other ``*.csv`` whose lower-cased file name
      contains none of the report-artifact tokens; the file name serves as
      both hints.

    Returns:
        De-duplicated ``(path, concentration_hint, temperature_hint)`` tuples.
        Each mode's matches are sorted by path so the result does not depend
        on the filesystem's directory enumeration order.
    """
    data_cfg = CONFIG["data"]
    summary_name = "single_frequency_summary.csv"
    # Substrings marking files written by this pipeline's own reports.
    artifact_tokens = ("metrics", "predictions_", "theta_grid", "model_report",
                       "parity", "residual", "errhist", "learning", "training",
                       "compiled_dataset", "rows_with_split")

    found = []
    if data_cfg["accept_nested_summary"]:
        for p in sorted(root.rglob(summary_name)):
            found.append((p, p.parent.parent.name, p.parent.name))
    if data_cfg["accept_flat_collected"]:
        for p in sorted(root.rglob("*.csv")):
            name = p.name.lower()
            if name == summary_name:
                continue
            if any(tok in name for tok in artifact_tokens):
                continue
            found.append((p, p.name, p.name))

    # Keep the first occurrence of each path.
    uniq, seen = [], set()
    for item in found:
        key = str(item[0])
        if key not in seen:
            seen.add(key)
            uniq.append(item)
    return uniq

def load_one_csv(path: Path) -> pd.DataFrame:
    """Read one CSV with the keyword arguments from CONFIG["data"]["read_csv_kwargs"]."""
    read_options = CONFIG["data"]["read_csv_kwargs"]
    return pd.read_csv(path, **read_options)

def _first_parsed(parser, *texts) -> float:
    """Return the first non-NaN value `parser` yields over `texts`, else NaN."""
    for text in texts:
        value = parser(text)
        if not np.isnan(value):
            return value
    return np.nan


def build_dataset(root: Path) -> pd.DataFrame:
    """Assemble one long table of EIS rows from every usable CSV under `root`.

    For each discovered file the frequency / Z' / -Z'' columns are resolved,
    optionally restricted to CONFIG["data"]["frequency_filter_hz"], and tagged
    with the concentration (mM) and temperature (C) parsed from, in order, the
    full path, the discovery hint, and the parent directory name. Files that
    cannot be read or whose columns cannot be resolved are skipped.

    Returns:
        DataFrame with columns frequency_Hz, Z_real, Z_imag_neg,
        concentration_mM, temperature_C and source_file. Rows with a missing
        numeric value are dropped and CONFIG["data"]["train_filters"]
        (inclusive bounds) is applied.

    Raises:
        FileNotFoundError: no candidate CSV exists under `root`.
        RuntimeError: candidates exist but none yielded any rows.
    """
    files = discover_files(root)
    if not files:
        raise FileNotFoundError(f"No CSV files found under: {root}")
    rows = []
    fsel = CONFIG["data"]["frequency_filter_hz"]
    it = tqdm(files, desc="Reading CSV files", unit="file") if CONFIG["progress"]["show_file_scan"] else files
    for csv_path, conc_hint, temp_hint in it:
        # Best-effort scan: unreadable or foreign CSVs are skipped, not fatal.
        try:
            df = load_one_csv(csv_path)
        except Exception:
            continue
        try:
            f_col, zr_col, zi_neg_series = resolve_eis_columns(df)
        except Exception:
            continue
        tmp = pd.DataFrame({
            "frequency_Hz": pd.to_numeric(df[f_col], errors="coerce"),
            "Z_real":       pd.to_numeric(df[zr_col], errors="coerce"),
            "Z_imag_neg":   pd.to_numeric(zi_neg_series, errors="coerce"),
        })
        if fsel is not None:
            tmp = tmp.loc[np.isclose(tmp["frequency_Hz"].astype(float), float(fsel))]
        if tmp.empty:
            continue
        # Explicit NaN checks: NaN is truthy, so chaining the parsers with
        # `or` would never reach the later candidates (and would wrongly
        # discard a legitimate 0.0).
        parent_name = csv_path.parent.name
        c_val = _first_parsed(parse_concentration_to_mM, str(csv_path), conc_hint, parent_name)
        t_val = _first_parsed(parse_temperature_to_C, str(csv_path), temp_hint, parent_name)
        tmp["concentration_mM"] = c_val
        tmp["temperature_C"]    = t_val
        tmp["source_file"]      = str(csv_path)
        rows.append(tmp)
    if not rows:
        raise RuntimeError("Discovered files, but none had required columns / valid C,T.")
    data = pd.concat(rows, ignore_index=True)
    numeric_cols = ["concentration_mM", "temperature_C", "frequency_Hz", "Z_real", "Z_imag_neg"]
    for c in numeric_cols:
        data[c] = pd.to_numeric(data[c], errors="coerce")
    # Rows whose C or T could not be parsed are removed here.
    data = data.dropna(subset=numeric_cols)
    fcfg = CONFIG["data"]["train_filters"]
    if any(v is not None for v in fcfg.values()):
        m = pd.Series(True, index=data.index)
        if fcfg["conc_mM_min"] is not None: m &= data["concentration_mM"] >= float(fcfg["conc_mM_min"])
        if fcfg["conc_mM_max"] is not None: m &= data["concentration_mM"] <= float(fcfg["conc_mM_max"])
        if fcfg["temp_C_min"]  is not None: m &= data["temperature_C"]    >= float(fcfg["temp_C_min"])
        if fcfg["temp_C_max"]  is not None: m &= data["temperature_C"]    <= float(fcfg["temp_C_max"])
        data = data.loc[m].copy()
    return data

# --------------------------
# Teacher parameter loader (TL names + aliases, optional)
# --------------------------
# Pulls concentration (mM) and temperature (C) out of a fit-parameter file stem.
_C_RE = re.compile(r'EIS_FitParams_(?P<c>[-+]?[0-9]*\.?[0-9]+)mM_(?P<t>[-+]?[0-9]*\.?[0-9]+)C', re.I)

# Accepted spelling -> canonical parameter name, built from one
# (canonical, spellings) entry per circuit parameter.
_TEACHER_KEY_ALIASES = {
    spelling: canonical
    for canonical, spellings in (
        ("Rs", ("Rs", "Rs (Ω)")),
        ("Rp", ("Rp", "Rp (Ω)")),
        ("Y0", ("Y0", "Y0_ZARC", "Y0_ZARC (Ω^-1 s^n0)", "Y0 (Ω^-1 s^n0)")),
        ("n0", ("n0", "n0 (-)")),
        ("r",  ("r", "r_line", "r_line (Ω/len)", "r (Ω/len)")),
        ("y0", ("y0", "y0_line", "y0_line (Ω^-1 s^n1 /len)", "y0 (Ω^-1 s^n1/len)")),
        ("n1", ("n1", "n1 (-)")),
        ("L",  ("L", "L (len)")),
    )
    for spelling in spellings
}


def _normalize_teacher_key(k: str) -> str:
    """Map a teacher-file parameter label to its canonical name.

    The label is stringified and stripped; labels without a known alias are
    returned in that stripped form. Matching is case-sensitive, which keeps
    "Y0" and "y0" distinct.
    """
    label = str(k).strip()
    if label in _TEACHER_KEY_ALIASES:
        return _TEACHER_KEY_ALIASES[label]
    return label

def load_teacher_param_grid(root_dir: str, filename_glob: str, param_col="parameter", value_col="value") -> pd.DataFrame:
    """Load per-(C, T) fitted circuit parameters ("teacher" values).

    Every file in `root_dir` (non-recursive) matching `filename_glob` whose
    stem matches _C_RE contributes one row. A file is read as CSV, then as
    Excel if that fails; unreadable files and files with fewer than two
    columns are skipped.

    Args:
        root_dir: directory holding the fit-parameter files; falsy or
            non-existent directories yield an empty frame.
        filename_glob: glob pattern applied inside `root_dir`.
        param_col: header of the parameter-label column.
        value_col: header of the parameter-value column. When either header is
            absent from a file, that file's first two columns are used.

    Returns:
        DataFrame with concentration_mM, temperature_C and the eight
        parameters Rs, Rp, Y0, n0, r, y0, n1, L (NaN where missing or
        non-numeric). An empty DataFrame when nothing could be loaded.
    """
    if not root_dir:
        return pd.DataFrame()
    root = Path(root_dir)
    if not root.exists():
        return pd.DataFrame()

    theta_names = ["Rs", "Rp", "Y0", "n0", "r", "y0", "n1", "L"]
    rows = []
    for fp in sorted(root.glob(filename_glob)):
        m = _C_RE.search(fp.stem)
        if not m:
            continue
        c = float(m.group("c"))
        t = float(m.group("t"))
        try:
            dfp = pd.read_csv(fp)
        except Exception:
            try:
                dfp = pd.read_excel(fp)
            except Exception:
                continue
        cols = [str(x).strip() for x in dfp.columns]
        dfp.columns = cols
        # Resolve the columns per file. Overwriting the function arguments
        # here would make one file's fallback leak into every later file.
        if param_col in cols and value_col in cols:
            p_col, v_col = param_col, value_col
        elif len(cols) >= 2:
            p_col, v_col = cols[0], cols[1]
        else:
            continue
        # Coerce instead of float(): one non-numeric cell becomes NaN rather
        # than aborting the whole load.
        values = pd.to_numeric(dfp[v_col], errors="coerce")
        dmap = {_normalize_teacher_key(k): float(v) for k, v in zip(dfp[p_col], values)}
        row = {"concentration_mM": c, "temperature_C": t}
        for k in theta_names:
            row[k] = dmap.get(k, np.nan)
        rows.append(row)

    if not rows:
        # A frame built from no rows has no columns, so the dropna(subset=...)
        # below would raise KeyError.
        return pd.DataFrame()

    grid = pd.DataFrame(rows)
    for c in ["concentration_mM", "temperature_C"] + theta_names:
        if c in grid.columns:
            grid[c] = pd.to_numeric(grid[c], errors="coerce")
    return grid.dropna(subset=["concentration_mM", "temperature_C"], how="any")

# --------------------------
# Metrics + plotting helpers
# --------------------------
def evaluate_predictions(y_true: np.ndarray, y_pred: np.ndarray, names: List[str]) -> Dict[str, float]:
    """Compute R2, MAE and RMSE per target column plus their means.

    Column j of `y_true` / `y_pred` is scored under the label `names[j]`.
    The result holds ``R2_<name>``, ``MAE_<name>``, ``RMSE_<name>`` for each
    target, followed by ``R2_mean``, ``MAE_mean`` and ``RMSE_mean``.
    """
    metrics: Dict[str, float] = {}
    collected = {"R2": [], "MAE": [], "RMSE": []}
    for col, label in enumerate(names):
        truth = y_true[:, col]
        guess = y_pred[:, col]
        scores = {
            "R2": r2_score(truth, guess),
            "MAE": mean_absolute_error(truth, guess),
            "RMSE": math.sqrt(mean_squared_error(truth, guess)),
        }
        for kind, value in scores.items():
            metrics[f"{kind}_{label}"] = value
            collected[kind].append(value)
    for kind, values in collected.items():
        metrics[f"{kind}_mean"] = float(np.mean(values))
    return metrics

def _maybe_show():
    """Show the current figure when CONFIG["plots"]["show_inline"] is set, then close it."""
    show_inline = CONFIG["plots"]["show_inline"]
    if show_inline:
        plt.show()
    plt.close()

def parity_plot_and_csv(y_true, y_pred, target_name, img_path: Path, csv_path: Path):
    """Write the (y_true, y_pred) pairs to `csv_path` and a parity scatter to `img_path`.

    The plot carries a dashed identity line spanning the joint data range.
    """
    table = pd.DataFrame({"y_true": y_true.ravel(), "y_pred": y_pred.ravel()})
    table.to_csv(csv_path, index=False)

    # End points of the identity line.
    true_lo, pred_lo = float(np.min(y_true)), float(np.min(y_pred))
    true_hi, pred_hi = float(np.max(y_true)), float(np.max(y_pred))
    lo = min(true_lo, pred_lo)
    hi = max(true_hi, pred_hi)

    plt.figure()
    plt.scatter(y_true, y_pred, s=12, alpha=0.7)
    plt.plot([lo, hi], [lo, hi], linestyle="--")
    plt.xlabel(f"True {target_name}")
    plt.ylabel(f"Pred {target_name}")
    plt.title(f"Parity: {target_name}")
    plt.tight_layout()
    plt.savefig(img_path, dpi=180)
    _maybe_show()

def residual_plot_and_csv(y_true, y_pred, target_name, img_path: Path, csv_path: Path):
    """Write predictions with residuals (pred - true) to `csv_path` and plot them to `img_path`."""
    flat_residual = (y_pred - y_true).ravel()
    table = pd.DataFrame({"y_pred": y_pred.ravel(), "residual": flat_residual})
    table.to_csv(csv_path, index=False)

    plt.figure()
    plt.scatter(y_pred, flat_residual, s=12, alpha=0.7)
    plt.axhline(0, linestyle="--")  # zero-error reference
    plt.xlabel(f"Pred {target_name}")
    plt.ylabel("Residual (Pred - True)")
    plt.title(f"Residuals: {target_name}")
    plt.tight_layout()
    plt.savefig(img_path, dpi=180)
    _maybe_show()

def _gauss_pdf(x, m, s):
    """Normal density with mean `m` and standard deviation `s`, evaluated at `x`.

    Returns zeros shaped like `x` when `s` is zero or negative.
    """
    if s <= 0:
        return np.zeros_like(x)
    standardized = (x - m) / s
    peak = 1.0 / (s * np.sqrt(2 * np.pi))
    return peak * np.exp(-0.5 * standardized ** 2)

def error_hist_plot_and_csv(y_true, y_pred, target_name, img_path: Path, bins: int, hist_csv_path: Path, fit_csv_path: Path):
    """Histogram the errors (pred - true), save its tables, and plot it with a Gaussian overlay.

    Outputs:
        hist_csv_path: one row per bin (left edge, right edge, center, count).
        fit_csv_path: one row of summary statistics - modal bin center, robust
            sigma (1.4826 * MAD), mean bias, sample standard deviation, a
            +/-1.96 sigma interval around the mode, n and the bin width.
        img_path: the histogram with a Gaussian curve centred on the modal
            bin, using the robust sigma and scaled to counts.

    When the MAD is zero the robust sigma falls back to the sample standard
    deviation, which is itself 0.0 for fewer than two samples; this keeps
    NaN out of the saved statistics for a single-sample input.
    """
    err = (y_pred - y_true).ravel()
    counts, edges = np.histogram(err, bins=bins)
    centers = 0.5 * (edges[:-1] + edges[1:])
    bin_w = float(edges[1] - edges[0]) if len(edges) > 1 else 1.0
    mode_center = float(centers[np.argmax(counts)])

    med = float(np.median(err))
    mad = float(np.median(np.abs(err - med)))
    mean_bias = float(np.mean(err))
    # ddof=1 is undefined for a single sample, hence the size guard.
    sigma_std = float(np.std(err, ddof=1)) if err.size > 1 else 0.0
    sigma_rob = 1.4826 * mad if mad > 0 else sigma_std
    ci_low = mode_center - 1.96 * sigma_rob
    ci_high = mode_center + 1.96 * sigma_rob

    pd.DataFrame({"bin_left": edges[:-1], "bin_right": edges[1:], "bin_center": centers, "count": counts}).to_csv(hist_csv_path, index=False)
    pd.DataFrame([{"mu_mode": mode_center, "sigma_robust": sigma_rob, "mean_bias": mean_bias, "sigma_std": sigma_std,
                   "ci95_low": ci_low, "ci95_high": ci_high, "n": int(len(err)), "bin_width": bin_w}]).to_csv(fit_csv_path, index=False)

    xg = np.linspace(edges[0], edges[-1], 600)
    # n * bin_width converts the density into expected counts per bin.
    gauss_counts = len(err) * bin_w * _gauss_pdf(xg, mode_center, sigma_rob)
    plt.figure()
    plt.hist(err, bins=bins)
    plt.plot(xg, gauss_counts, linewidth=2)
    plt.xlabel("Error (Pred - True)")
    plt.ylabel("Count")
    plt.title(f"Error Distribution: {target_name}")
    plt.tight_layout()
    plt.savefig(img_path, dpi=180)
    _maybe_show()

# --------------------------
# Physics core — TL model in torch (complex)
# --------------------------
def _j_like(x):
    """Return the imaginary unit as a 0-dim complex tensor matching `x`'s real dtype and device."""
    real_part = x.new_zeros(())
    imag_part = x.new_ones(())
    return torch.complex(real_part, imag_part)

def torch_coth(z, eps=1e-12, large=20.0):
    """Element-wise hyperbolic cotangent that stays finite for large |Re z|.

    Three regimes are combined with ``torch.where``:

    * ``|Re z| > large``: coth(z) = s * (1 + e) / (1 - e), where s = sign(Re z)
      and e = exp(-2 * s * z). |e| < 1 here, so nothing overflows. The plain
      cosh/sinh quotient becomes inf/inf = NaN once sinh and cosh overflow.
    * ``|sinh z| < eps``: the series 1/z + z/3.
    * otherwise: cosh(z) / sinh(z).

    Each branch is evaluated on inputs replaced by a harmless placeholder
    outside its own regime, so the unselected branch cannot contribute
    inf/NaN values or gradients.

    Args:
        z: tensor (complex in this file's use).
        eps: threshold on |sinh z| below which the series is used.
        large: |Re z| above which the overflow-free form is used. At the
            default the two forms differ by roughly exp(-40).
    """
    re = z.real if torch.is_complex(z) else z
    far = torch.abs(re) > large
    one = torch.ones_like(z)

    # Overflow-free branch, folded onto Re >= 0 via coth(-z) = -coth(z).
    z_far = torch.where(far, z, one)
    re_far = z_far.real if torch.is_complex(z_far) else z_far
    flip = re_far < 0
    z_pos = torch.where(flip, -z_far, z_far)
    decay = torch.exp(-2.0 * z_pos)
    asym = (1.0 + decay) / (1.0 - decay)
    asym = torch.where(flip, -asym, asym)

    # Direct and series branches on the remaining elements.
    z_near = torch.where(far, one, z)
    sz = torch.sinh(z_near)
    cz = torch.cosh(z_near)
    small = (torch.abs(sz) < eps) & ~far
    ratio = cz / torch.where(small, one, sz)
    z_small = torch.where(small, z, one)
    series = 1.0 / z_small + z_small / 3.0

    return torch.where(far, asym, torch.where(small, series, ratio))

def torch_zarc(Rp, Y0, n, w):
    """Impedance 1 / (1/Rp + Y0 * (j*w)**n) of a resistor in parallel with a CPE.

    `Rp` and `Y0` are clamped from below at 1e-18 before use.
    """
    jw = _j_like(w) * w
    resistor_admittance = 1.0 / torch.clamp(Rp, min=1e-18)
    cpe_admittance = torch.clamp(Y0, min=1e-18) * jw ** n
    return 1.0 / (resistor_admittance + cpe_admittance)

def torch_tl_impedance(r, y0, n, L, w):
    """Transmission-line impedance Z0 * coth(L * gamma).

    With x = (j*w)**n: gamma = sqrt(r * y0 * x) and Z0 = sqrt(r / (y0 * x)).
    `r` and `y0` are clamped from below at 1e-18 before use.
    """
    r_safe = torch.clamp(r, min=1e-18)
    y_safe = torch.clamp(y0, min=1e-18)
    jw_pow = (_j_like(w) * w) ** n  # shared by gamma and Z0
    gamma = torch.sqrt(r_safe * y_safe * jw_pow)
    Z0 = torch.sqrt(r_safe / (y_safe * jw_pow))
    return Z0 * torch_coth(L * gamma)

def torch_impedance_rs_zarc_tl(omega, Rs, Rp, Y0, n0, r, y0, n1, L):
    """Total model impedance: Rs + Zarc(Rp, Y0, n0) + Z_TL(r, y0, n1, L) at `omega`."""
    return (
        Rs
        + torch_zarc(Rp, Y0, n0, omega)
        + torch_tl_impedance(r, y0, n1, L, omega)
    )

# --------------------------
# Network θ(C,T) → [Rs, Rp, Y0, n0, r, y0, n1, L]
# --------------------------
class ThetaNet(nn.Module):
    """MLP mapping two input features to the eight circuit parameters.

    The backbone is `depth` Linear+ReLU layers of size `width` (identity when
    `depth` is 0) followed by an 8-unit linear head. Head outputs are
    constrained as follows: Rs, Rp, Y0, r, y0 and L pass through softplus
    plus 1e-9 (strictly positive); n0 and n1 pass through a sigmoid (in (0, 1)).
    """

    def __init__(self, in_dim=2, width=64, depth=3, dtype=torch.float64):
        super().__init__()
        hidden = []
        fan_in = in_dim
        for _layer in range(depth):
            hidden.append(nn.Linear(fan_in, width, dtype=dtype))
            hidden.append(nn.ReLU())
            fan_in = width
        self.backbone = nn.Sequential(*hidden) if hidden else nn.Identity()
        # Output order: Rs, Rp, Y0, n0, r, y0, n1, L
        self.head = nn.Linear(fan_in, 8, dtype=dtype)
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, Cn, Tn):
        """Return (Rs, Rp, Y0, n0, r, y0, n1, L), one 1-D tensor each, for 1-D inputs `Cn`, `Tn`."""
        features = self.backbone(torch.stack([Cn, Tn], dim=1))
        raw = self.head(features)
        floor = 1e-9
        Rs, Rp, Y0, r, y0, L = (
            self.softplus(raw[:, idx]) + floor for idx in (0, 1, 2, 4, 5, 7)
        )
        n0 = self.sigmoid(raw[:, 3])
        n1 = self.sigmoid(raw[:, 6])
        return Rs, Rp, Y0, n0, r, y0, n1, L

# --------------------------
# EEC_PINN (TL) with normalization + loss options
# --------------------------
class EEC_PINN(BaseEstimator, RegressorMixin):
    def __init__(self, teacher_df=None, teacher_use=False, teacher_weight=0.15, teacher_log_params=None,
                 epochs=3500, lr=7.5e-4, width=64, depth=3, weight_decay=1e-6,
                 batch_size=512, verbose=1, arrhenius_reg=1e-4, smooth_reg=1e-5,
                 seed=42, device="auto",
                 consistency_w=1e-3, group_round=1e-6,
                 mono_wC=0.0, mono_wT=0.0,
                 mono_sign_C=None, mono_sign_T=None,
                 teacher_nn=True, teacher_nn_sigma=0.75,
                 targets_norm_enabled=False, targets_norm_method="standard", targets_norm_clip=True,
                 loss_mode="mse", huber_delta=1.0, loss_weights=None):
        """Store the hyper-parameters (coerced to their basic types) and reset internal state.

        Arguments left as None (`teacher_log_params`, `mono_sign_C`,
        `mono_sign_T`, `loss_weights`) are replaced by the defaults built
        below. No model is created here.
        """
        # --- teacher / prior data ---
        self.teacher_df = teacher_df
        self.teacher_use = bool(teacher_use)
        self.teacher_weight = float(teacher_weight)
        self.teacher_log_params = teacher_log_params or ["Rs", "Rp", "Y0", "r", "y0", "L"]
        self.teacher_nn = bool(teacher_nn)
        self.teacher_nn_sigma = float(teacher_nn_sigma)

        # --- optimisation ---
        self.epochs = int(epochs)
        self.lr = float(lr)
        self.width = int(width)
        self.depth = int(depth)
        self.weight_decay = float(weight_decay)
        self.batch_size = int(batch_size)
        self.verbose = int(verbose)
        self.arrhenius_reg = float(arrhenius_reg)
        self.smooth_reg = float(smooth_reg)
        self.seed = int(seed)
        self.device = device

        # --- physics priors ---
        self.consistency_w = float(consistency_w)
        self.group_round = float(group_round)
        self.mono_wC = float(mono_wC)
        self.mono_wT = float(mono_wT)
        neutral_signs = dict.fromkeys(("Rs", "Rp", "Y0", "n0", "r", "y0", "n1", "L"), 0)
        self.mono_sign_C = mono_sign_C or dict(neutral_signs)
        self.mono_sign_T = mono_sign_T or {**neutral_signs, "Rs": -1}

        # --- target scaling and data loss ---
        self.targets_norm_enabled = bool(targets_norm_enabled)
        self.targets_norm_method = str(targets_norm_method)
        self.targets_norm_clip = bool(targets_norm_clip)
        self.loss_mode = str(loss_mode)
        self.huber_delta = float(huber_delta)
        self.loss_weights = loss_weights or {"Z_real": 1.0, "Z_imag_neg": 1.0}

        # --- internal state, filled in later ---
        self._dtype = torch.float64
        self._xmu = None
        self._xstd = None
        self._net = None
        self._teacher_map = {}
        self._teacher_CT = None
        self._teacher_TH = None
        self._y_norm = {"enabled": False}
        self.loss_history = []

    def get_params(self, deep=True):
        """Return the constructor hyper-parameters as a name -> current value dict.

        `deep` is accepted for signature compatibility and is not used.
        """
        exposed = (
            "teacher_df", "teacher_use", "teacher_weight",
            "teacher_log_params", "teacher_nn", "teacher_nn_sigma",
            "epochs", "lr", "width", "depth",
            "weight_decay", "batch_size", "verbose",
            "arrhenius_reg", "smooth_reg",
            "seed", "device",
            "consistency_w", "group_round",
            "mono_wC", "mono_wT",
            "mono_sign_C", "mono_sign_T",
            "targets_norm_enabled",
            "targets_norm_method",
            "targets_norm_clip",
            "loss_mode", "huber_delta",
            "loss_weights",
        )
        return {name: getattr(self, name) for name in exposed}
    def set_params(self, **params):
        """Assign each keyword as an attribute of the same name and return self."""
        for attr_name, attr_value in params.items():
            setattr(self, attr_name, attr_value)
        return self

    def _dev(self):
        """Resolve `self.device` to a torch.device; "auto" means CUDA when available, else CPU."""
        choice = self.device
        if choice == "auto":
            choice = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(choice)
    def _std_CT(self, C, T):
        """Standardize concentration and temperature arrays.

        The mean and standard deviation (plus 1e-12) are computed on the first
        call and cached in `_xmu` / `_xstd`; later calls reuse the cached
        statistics.

        Returns:
            (standardized C, standardized T)
        """
        features = np.stack([C, T], axis=1)
        if self._xmu is None:
            self._xmu = features.mean(axis=0)
            self._xstd = features.std(axis=0) + 1e-12
        scaled = (features - self._xmu) / self._xstd
        return scaled[:, 0], scaled[:, 1]

    # --- target normalization (Z', -Z'') ---
    def _fit_y_norm(self, y):
        """Record per-column target statistics in `_y_norm`.

        When target normalization is disabled, `_y_norm` is reset to
        ``{"enabled": False}``. Otherwise it stores the method name and the
        column-wise mean, std (plus 1e-12), min and max of `y`.
        """
        if not self.targets_norm_enabled:
            self._y_norm = {"enabled": False}
            return
        targets = np.asarray(y, float)
        self._y_norm = {
            "enabled": True,
            "method": self.targets_norm_method,
            "mu": targets.mean(axis=0),
            "std": targets.std(axis=0) + 1e-12,
            "min": targets.min(axis=0),
            "max": targets.max(axis=0),
        }
    def _y_transform(self, y):
        """Scale targets with the statistics stored by `_fit_y_norm`.

        "standard" gives (y - mu) / std; "minmax" gives (y - min) / (max - min)
        with zero ranges replaced by 1. `y` is returned unchanged when
        normalization is disabled or the method is unrecognised.
        """
        stats = self._y_norm
        if not stats.get("enabled", False):
            return y
        method = stats["method"]
        if method == "standard":
            return (y - stats["mu"]) / stats["std"]
        if method == "minmax":
            span = stats["max"] - stats["min"]
            span[span == 0] = 1.0  # constant column: avoid division by zero
            return (y - stats["min"]) / span
        return y
    def _y_inverse(self, y_norm):
        """Undo `_y_transform`.

        For "minmax" the result is clipped to the stored [min, max] when
        `targets_norm_clip` is set. The input is returned unchanged when
        normalization is disabled or the method is unrecognised.
        """
        stats = self._y_norm
        if not stats.get("enabled", False):
            return y_norm
        method = stats["method"]
        if method == "standard":
            return y_norm * stats["std"] + stats["mu"]
        if method == "minmax":
            span = stats["max"] - stats["min"]
            span[span == 0] = 1.0  # same zero-range substitution as the forward transform
            restored = y_norm * span + stats["min"]
            if self.targets_norm_clip:
                restored = np.clip(restored, stats["min"], stats["max"])
            return restored
        return y_norm

    def _build_teacher_map(self):
        """Rebuild `_teacher_map`: (C, T) rounded to 6 decimals -> list of the eight teacher values.

        The map stays empty when teacher use is off or no teacher table is
        available. Rows with any missing parameter are left out.
        """
        self._teacher_map = {}
        if not self.teacher_use or self.teacher_df is None or self.teacher_df.empty:
            return
        theta_names = ("Rs", "Rp", "Y0", "n0", "r", "y0", "n1", "L")
        for _, row in self.teacher_df.iterrows():
            theta = [row.get(name, np.nan) for name in theta_names]
            if np.isnan(theta).any():
                continue
            key = (round(float(row["concentration_mM"]), 6),
                   round(float(row["temperature_C"]), 6))
            self._teacher_map[key] = theta
    def _theta_teacher_for(self, C,T):
        """Look up the teacher values stored for (C, T) rounded to 6 decimals; None when absent."""
        key = (round(float(C), 6), round(float(T), 6))
        return self._teacher_map.get(key, None)

    def fit(self, X, y):
        torch.manual_seed(self.seed); np.random.seed(self.seed); random.seed(self.seed)
        self.loss_history = []
        dev=self._dev(); dtype=self._dtype
        C, T, f = X[:,0].astype(float), X[:,1].astype(float), X[:,2].astype(float)
        Zr, Zim = y[:,0].astype(float), y[:,1].astype(float)
        Cn, Tn = self._std_CT(C, T); w = 2*np.pi*f

        # fit target normalization on training y
        Y_train = np.column_stack([Zr, Zim]); self._fit_y_norm(Y_train)

        C_t=torch.tensor(Cn,dtype=dtype,device=dev); T_t=torch.tensor(Tn,dtype=dtype,device=dev)
        w_t=torch.tensor(w,dtype=dtype,device=dev)
        Zr_t=torch.tensor(Zr,dtype=dtype,device=dev); Zim_t=torch.tensor(Zim,dtype=dtype,device=dev)

        # teacher buffers
        if self.teacher_use and (self.teacher_df is not None):
            self._build_teacher_map()
            if self.teacher_nn and (not self.teacher_df.empty):
                Ct = self.teacher_df[["concentration_mM","temperature_C"]].to_numpy(float)
                Ctn, Ttn = self._std_CT(Ct[:,0], Ct[:,1])
                self._teacher_CT = torch.tensor(np.column_stack([Ctn,Ttn]), dtype=dtype, device=dev)
                self._teacher_TH = torch.tensor(self.teacher_df[["Rs","Rp","Y0","n0","r","y0","n1","L"]].to_numpy(float),
                                                dtype=dtype, device=dev)

        self._net = ThetaNet(in_dim=2, width=self.width, depth=self.depth, dtype=dtype).to(dev)
        opt = optim.Adam(self._net.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        def _weighted_loss(diff_re, diff_im):
            if self.loss_mode == "huber":
                delta = self.huber_delta
                huber_re = torch.where(diff_re.abs() <= delta, 0.5*diff_re.pow(2), delta*(diff_re.abs()-0.5*delta))
                huber_im = torch.where(diff_im.abs() <= delta, 0.5*diff_im.pow(2), delta*(diff_im.abs()-0.5*delta))
                return self.loss_weights["Z_real"]*huber_re.mean() + self.loss_weights["Z_imag_neg"]*huber_im.mean()
            else:  # mse
                return self.loss_weights["Z_real"]*(diff_re.pow(2).mean()) + self.loss_weights["Z_imag_neg"]*(diff_im.pow(2).mean())

        N=len(C); nb=max(1, N//self.batch_size)
        loop=trange(self.epochs, disable=(self.verbose==0))
        for _ in loop:
            perm=torch.randperm(N, device=dev); epoch_loss=0.0
            for b in range(nb):
                sel=perm[b*self.batch_size:(b+1)*self.batch_size]
                Cb,Tb,wb = C_t[sel],T_t[sel],w_t[sel]
                Zr_b,Zim_b = Zr_t[sel], Zim_t[sel]

                Rs,Rp,Y0,n0,r,y0,n1,L = self._net(Cb,Tb)
                Zc = torch_impedance_rs_zarc_tl(wb, Rs,Rp,Y0,n0,r,y0,n1,L)
                yhat_re = Zc.real; yhat_im = -Zc.imag

                # targets normalization (apply on both yhat & ytrue)
                if self.targets_norm_enabled:
                    yhat_np = np.column_stack([yhat_re.detach().cpu().numpy(),
                                               yhat_im.detach().cpu().numpy()])
                    ytrue_np= np.column_stack([Zr_b.detach().cpu().numpy(),
                                               Zim_b.detach().cpu().numpy()])
                    yhat_n  = self._y_transform(yhat_np); ytrue_n = self._y_transform(ytrue_np)
                    diff_re = torch.tensor(yhat_n[:,0]-ytrue_n[:,0], dtype=dtype, device=dev)
                    diff_im = torch.tensor(yhat_n[:,1]-ytrue_n[:,1], dtype=dtype, device=dev)
                else:
                    diff_re = (yhat_re - Zr_b); diff_im = (yhat_im - Zim_b)

                loss = _weighted_loss(diff_re, diff_im)

                # --- physics-informed extras ---

                # exact teacher at same (C,T)
                if self.teacher_use and (self.teacher_weight > 0) and len(self._teacher_map) > 0:
                    Cbo = (Cb.cpu().numpy()*self._xstd[0] + self._xmu[0])
                    Tbo = (Tb.cpu().numpy()*self._xstd[1] + self._xmu[1])
                    idxs, th_list = [], []
                    for i in range(len(Cbo)):
                        th = self._theta_teacher_for(Cbo[i], Tbo[i])
                        if th is not None: idxs.append(i); th_list.append(th)
                    if idxs:
                        targ=torch.tensor(np.array(th_list),dtype=dtype,device=dev)
                        pred=torch.stack([Rs,Rp,Y0,n0,r,y0,n1,L],dim=1)[idxs]
                        names=["Rs","Rp","Y0","n0","r","y0","n1","L"]
                        log_mask=np.array([n in self.teacher_log_params for n in names],bool)
                        pred_s=pred.clone(); targ_s=targ.clone()
                        for j,islog in enumerate(log_mask):
                            if islog:
                                pred_s[:,j]=torch.log(pred_s[:,j]+1e-12)
                                targ_s[:,j]=torch.log(targ_s[:,j]+1e-12)
                        loss = loss + self.teacher_weight*torch.mean((pred_s - targ_s)**2)

                # nearest-teacher soft prior
                if self.teacher_use and (self.teacher_weight > 0) and (self._teacher_CT is not None):
                    Q = torch.stack([Cb, Tb], dim=1)
                    d2 = torch.cdist(Q, self._teacher_CT).pow(2)
                    wts = torch.softmax(-d2 / (2*self.teacher_nn_sigma**2), dim=1)
                    theta_pred = torch.stack([Rs,Rp,Y0,n0,r,y0,n1,L], dim=1)
                    theta_targ = wts @ self._teacher_TH
                    names = ["Rs","Rp","Y0","n0","r","y0","n1","L"]
                    log_mask = torch.tensor([n in self.teacher_log_params for n in names], device=dev)
                    pred_s = theta_pred.clone(); targ_s = theta_targ.clone()
                    pred_s[:, log_mask] = torch.log(pred_s[:, log_mask] + 1e-12)
                    targ_s[:, log_mask] = torch.log(targ_s[:, log_mask] + 1e-12)
                    loss = loss + 0.5*self.teacher_weight*torch.mean((pred_s - targ_s)**2)

                # θ-invariance across identical (C,T) within batch
                if self.consistency_w > 0:
                    with torch.no_grad():
                        keys = torch.round(torch.stack([Cb, Tb], dim=1) / self.group_round)
                        uniq, inv = torch.unique(keys, dim=0, return_inverse=True)
                    theta_b = torch.stack([Rs,Rp,Y0,n0,r,y0,n1,L], dim=1)
                    var_sum = 0.0
                    for gid in range(uniq.size(0)):
                        mask = (inv == gid)
                        if mask.sum() > 1:
                            var_sum = var_sum + theta_b[mask].var(dim=0, unbiased=False).mean()
                    loss = loss + self.consistency_w * var_sum

                # monotonic priors (optional)
                def _mono_pen(grad, sign):  # penalize wrong sign
                    return torch.relu((-sign) * grad).mean()
                if (self.mono_wC > 0) or (self.mono_wT > 0):
                    Cb_req = Cb.clone().detach().requires_grad_(self.mono_wC > 0)
                    Tb_req = Tb.clone().detach().requires_grad_(self.mono_wT > 0)
                    Rs2,Rp2,Y02,n02,r2,y02,n12,L2 = self._net(Cb_req, Tb_req)
                    thetas = {"Rs":Rs2,"Rp":Rp2,"Y0":Y02,"n0":n02,"r":r2,"y0":y02,"n1":n12,"L":L2}
                    if self.mono_wC > 0:
                        penC = 0.0
                        for name, th in thetas.items():
                            sgn = self.mono_sign_C.get(name, 0)
                            if sgn != 0:
                                gC = torch.autograd.grad(th.sum(), Cb_req, retain_graph=True, create_graph=False)[0]
                                penC = penC + _mono_pen(gC, sgn)
                        loss = loss + self.mono_wC * penC
                    if self.mono_wT > 0:
                        penT = 0.0
                        for name, th in thetas.items():
                            sgn = self.mono_sign_T.get(name, 0)
                            if sgn != 0:
                                gT = torch.autograd.grad(th.sum(), Tb_req, retain_graph=True, create_graph=False)[0]
                                penT = penT + _mono_pen(gT, sgn)
                        loss = loss + self.mono_wT * penT

                # Arrhenius + smoothness
                if self.arrhenius_reg > 0:
                    TK = (Tb * self._xstd[1] + self._xmu[1]) + 273.15
                    invTK = 1.0 / (TK + 1e-9)
                    def lin_resid(x, y):
                        X1 = torch.stack([torch.ones_like(x), x], dim=1)
                        beta, *_ = torch.linalg.lstsq(X1, y.unsqueeze(1))
                        yhat = (X1 @ beta).squeeze(1)
                        return torch.mean((y - yhat)**2)
                    loss = loss + self.arrhenius_reg * (lin_resid(invTK, torch.log(Rs)) + lin_resid(invTK, torch.log(Rp)))
                if self.smooth_reg > 0 and sel.numel() > 1:
                    params = torch.stack([Rs,Rp,Y0,n0,r,y0,n1,L], dim=1)
                    loss = loss + self.smooth_reg*torch.mean(params.var(dim=0))

                opt.zero_grad(); loss.backward(); opt.step()
                epoch_loss += float(loss.detach())
            if self.verbose: loop.set_description(f"PINN loss {epoch_loss/nb:.4g}")
            self.loss_history.append(float(epoch_loss/nb))
        return self

    def predict(self, X):
        """Predict the impedance targets for rows of (C, T, f).

        Parameters
        ----------
        X : array-like of shape (n_samples, >=3)
            Column 0 is the concentration, column 1 the temperature and
            column 2 the frequency; the frequency is multiplied by 2*pi
            before it is handed to the circuit model.  Any array-like
            (ndarray, nested list, DataFrame) is accepted.

        Returns
        -------
        numpy.ndarray of shape (n_samples, 2)
            Columns are ``[Re(Z), -Im(Z)]`` of the circuit impedance.  When
            ``self.targets_norm_enabled`` is true the array is passed
            through ``self._y_inverse`` before being returned.
        """
        # Coerce first so that column slicing works for lists / DataFrames too,
        # mirroring what predict_theta does with its grid arguments.
        X = np.asarray(X)
        dev = self._dev()
        dtype = self._dtype
        C = X[:, 0].astype(float)
        T = X[:, 1].astype(float)
        f = X[:, 2].astype(float)
        # The network is fed standardized (C, T) as produced by _std_CT.
        Cn, Tn = self._std_CT(C, T)
        with torch.no_grad():
            C_t = torch.tensor(Cn, dtype=dtype, device=dev)
            T_t = torch.tensor(Tn, dtype=dtype, device=dev)
            w_t = torch.tensor(2 * np.pi * f, dtype=dtype, device=dev)  # angular frequency
            Rs, Rp, Y0, n0, r, y0, n1, L = self._net(C_t, T_t)
            Zc = torch_impedance_rs_zarc_tl(w_t, Rs, Rp, Y0, n0, r, y0, n1, L)
            # Second target is the negated imaginary part.
            y = torch.stack([Zc.real, -Zc.imag], dim=1).cpu().numpy()
        return self._y_inverse(y) if self.targets_norm_enabled else y

    def predict_theta(self, C_grid, T_grid):
        """Evaluate the eight circuit parameters on a grid of (C, T) points.

        Parameters
        ----------
        C_grid, T_grid : array-like
            Concentration and temperature values.  Both are flattened before
            evaluation, so they must contain the same number of elements.
            Any array-like (ndarray, nested list, scalar) is accepted.

        Returns
        -------
        dict[str, numpy.ndarray]
            Keys ``"Rs","Rp","Y0","n0","r","y0","n1","L"``; every value has
            the shape of ``C_grid``.
        """
        C_arr = np.asarray(C_grid)
        T_arr = np.asarray(T_grid)
        # Take the output shape from the coerced array: the raw argument may be
        # a list (no ``.shape`` attribute).
        out_shape = C_arr.shape

        dev = self._dev()
        dtype = self._dtype
        # The network is fed standardized (C, T) as produced by _std_CT.
        Cn, Tn = self._std_CT(C_arr.ravel(), T_arr.ravel())
        with torch.no_grad():
            C_t = torch.tensor(Cn, dtype=dtype, device=dev)
            T_t = torch.tensor(Tn, dtype=dtype, device=dev)
            # Explicit unpack keeps the "exactly eight outputs" check.
            Rs, Rp, Y0, n0, r, y0, n1, L = self._net(C_t, T_t)
            names = ("Rs", "Rp", "Y0", "n0", "r", "y0", "n1", "L")
            values = (Rs, Rp, Y0, n0, r, y0, n1, L)
            return {
                name: val.cpu().numpy().reshape(out_shape)
                for name, val in zip(names, values)
            }

# --------------------------
# Helpers: grouped split
# --------------------------
def make_group_ids(df: pd.DataFrame, c_col="concentration_mM", t_col="temperature_C") -> np.ndarray:
    """Build one string label per row identifying its (C, T) group.

    Each column is cast to float and, when the matching step in
    ``CONFIG["split"]`` (``group_round_C`` / ``group_round_T``) is positive,
    snapped to the nearest multiple of that step.  The two values are then
    joined as ``"<C>|<T>"``.
    """
    split_cfg = CONFIG["split"]
    step_c = float(split_cfg["group_round_C"])
    step_t = float(split_cfg["group_round_T"])

    def _bucket(column, step):
        # A non-positive step disables rounding for that column.
        values = df[column].astype(float)
        if step > 0:
            values = np.round(values / step) * step
        return values.astype(str)

    conc_labels = _bucket(c_col, step_c)
    temp_labels = _bucket(t_col, step_t)
    return (conc_labels + "|" + temp_labels).to_numpy()

# --------------------------
# Train + evaluate + save
# --------------------------
def train_pinn_and_report():
    """Train the PINN on a grouped (C,T) split and write the report to ``output_dir``.

    Everything is driven by the module-level ``CONFIG``.  In order, the
    function:

    1. creates the output folder tree and dumps ``CONFIG`` to
       ``config_used.json``;
    2. builds the dataset with ``build_dataset`` and saves it as
       ``compiled_dataset.csv``;
    3. splits rows into train/test with ``GroupShuffleSplit`` using the
       labels from ``make_group_ids`` and saves split diagnostics;
    4. optionally loads teacher parameters, then fits an ``EEC_PINN``;
    5. saves the loss history (raw and moving average), test metrics, test
       predictions, per-target diagnostic plots and the model checkpoint.

    Returns
    -------
    tuple
        ``(pinn, (X_train, X_test, y_train, y_test))``.
    """
    out_dir = Path(CONFIG["paths"]["output_dir"]).resolve()
    plots_dir = out_dir / "plots"
    plots_dir.mkdir(parents=True, exist_ok=True)
    plots_data_dir = out_dir / "plots_data"
    plots_data_dir.mkdir(parents=True, exist_ok=True)
    for sub in ["parity", "residuals", "hist", "learning", "training", "split_info"]:
        (plots_dir / sub).mkdir(exist_ok=True)
        (plots_data_dir / sub).mkdir(exist_ok=True)

    # Keep a copy of the exact configuration next to the results.
    with open(out_dir / "config_used.json", "w", encoding="utf-8") as f:
        json.dump(CONFIG, f, indent=2, ensure_ascii=False)

    print("[1/5] Loading dataset…")
    data = build_dataset(Path(CONFIG["paths"]["input_root"]))
    data.to_csv(out_dir / "compiled_dataset.csv", index=False)
    print(f"   -> {len(data)} rows")

    feat_cols = ["concentration_mM", "temperature_C", "frequency_Hz"]
    targ_cols = CONFIG["targets"]["target_columns"]
    X = data[feat_cols].values
    y = data[targ_cols].values

    # NOTE: the "(80/20)" in the message is a fixed label; the actual test
    # fraction is CONFIG["split"]["test_size"].
    print("[2/5] Grouped train/test split (80/20) by unique (C,T)")
    groups = make_group_ids(data, "concentration_mM", "temperature_C")
    gss = GroupShuffleSplit(n_splits=1, test_size=CONFIG["split"]["test_size"],
                            random_state=CONFIG["split"]["random_state"])
    train_idx, test_idx = next(gss.split(X, y, groups=groups))
    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]

    # save split diagnostics
    df_train = data.iloc[train_idx].copy()
    df_train["split"] = "train"
    df_test = data.iloc[test_idx].copy()
    df_test["split"] = "test"
    pd.concat([df_train, df_test], axis=0).to_csv(
        plots_data_dir / "split_info" / "rows_with_split.csv", index=False)
    # Count (C,T) pairs shared by both sides of the split (values rounded to 6 decimals).
    ct_train = set((round(c, 6), round(t, 6))
                   for c, t in zip(df_train.concentration_mM, df_train.temperature_C))
    ct_test = set((round(c, 6), round(t, 6))
                  for c, t in zip(df_test.concentration_mM, df_test.temperature_C))
    with open(plots_data_dir / "split_info" / "ct_overlap_check.txt", "w") as fh:
        fh.write(f"Unique (C,T) train: {len(ct_train)}\nUnique (C,T) test:  {len(ct_test)}\nOverlap count: {len(ct_train.intersection(ct_test))}\n")

    # Teacher parameters are only loaded when enabled AND a folder is configured.
    teacher_df = pd.DataFrame()
    if CONFIG["pinn"]["teacher"]["use"] and CONFIG["paths"]["fitparams_root"]:
        print("[3/5] Loading teacher parameters…")
        teacher_df = load_teacher_param_grid(
            CONFIG["paths"]["fitparams_root"],
            CONFIG["pinn"]["teacher"]["filename_glob"],
            CONFIG["pinn"]["teacher"]["param_col"],
            CONFIG["pinn"]["teacher"]["value_col"],
        )
        if teacher_df.empty:
            print("   (no teacher files found; continuing without)")

    print("[4/5] Training PINN…")
    p = CONFIG["pinn"]["train"]
    extra = p["extra"]
    pinn = EEC_PINN(
        # teacher
        teacher_df=teacher_df if (CONFIG["pinn"]["teacher"]["use"] and not teacher_df.empty) else None,
        teacher_use=CONFIG["pinn"]["teacher"]["use"],
        teacher_weight=CONFIG["pinn"]["teacher"]["weight"],
        teacher_log_params=CONFIG["pinn"]["teacher"]["log_params"],
        teacher_nn=extra["teacher_nn"], teacher_nn_sigma=extra["teacher_nn_sigma"],
        # training
        epochs=p["epochs"], lr=p["lr"], width=p["width"], depth=p["depth"],
        weight_decay=p["weight_decay"], batch_size=p["batch_size"], verbose=p["verbose"],
        arrhenius_reg=p["arrhenius_reg"], smooth_reg=p["smooth_reg"],
        seed=p["seed"], device=p.get("device", "auto"),
        # priors
        consistency_w=extra["consistency_w"], group_round=extra["group_round"],
        mono_wC=extra["mono_wC"], mono_wT=extra["mono_wT"],
        mono_sign_C=extra["mono_sign_C"], mono_sign_T=extra["mono_sign_T"],
        # NEW: normalization + loss
        targets_norm_enabled=p["targets_norm_enabled"],
        targets_norm_method=p["targets_norm_method"],
        targets_norm_clip=p["targets_norm_clip"],
        loss_mode=p["loss_mode"], huber_delta=p["huber_delta"],
        loss_weights=p["loss_weights"]
    )
    t0 = time.time()
    pinn.fit(X_train, y_train)
    dt = time.time() - t0  # wall-clock training time in seconds

    # save training loss
    loss_csv = plots_data_dir / "training" / "pinn_loss_history.csv"
    loss_png = plots_dir / "training" / "pinn_loss_history.png"
    epochs = np.arange(1, len(pinn.loss_history) + 1)
    pd.DataFrame({"epoch": epochs, "loss": pinn.loss_history}).to_csv(loss_csv, index=False)
    plt.figure()
    plt.plot(epochs, pinn.loss_history)
    plt.xlabel("Epoch")
    plt.ylabel("Training loss")
    plt.title("PINN training loss vs epoch")
    plt.tight_layout()
    plt.savefig(loss_png, dpi=180)
    _maybe_show()

    n_epochs = len(pinn.loss_history)
    if n_epochs >= 10:
        # np.convolve(mode="valid") swaps its operands when the kernel is
        # longer than the signal, which would plot a meaningless curve for
        # 10..24 epochs.  Keep the 25-epoch window whenever it fits and fall
        # back to half the history otherwise.
        k = 25 if n_epochs >= 25 else max(2, n_epochs // 2)
        kernel = np.ones(k) / k
        smooth = np.convolve(pinn.loss_history, kernel, mode="valid")
        # smooth[i] averages epochs i+1 .. i+k, so it is plotted at epoch i+k.
        plt.figure()
        plt.plot(np.arange(k, k + len(smooth)), smooth)
        plt.xlabel("Epoch")
        plt.ylabel("Training loss (moving avg)")
        plt.title(f"PINN training loss (window={k})")
        plt.tight_layout()
        plt.savefig(plots_dir / "training" / "pinn_loss_history_smooth.png", dpi=180)
        _maybe_show()

    print("[5/5] Evaluating… (grouped split)")
    y_pred = pinn.predict(X_test)
    metrics = evaluate_predictions(y_test, y_pred, targ_cols)
    pd.DataFrame([metrics | {"train_seconds": dt}]).to_csv(out_dir / "metrics_test.csv", index=False)
    print("   Test metrics:", metrics)

    # The fixed column names below assume two targets ordered as [Z', -Z''].
    pd.DataFrame({
        "concentration_mM": X_test[:, 0], "temperature_C": X_test[:, 1], "frequency_Hz": X_test[:, 2],
        "Z_real_true": y_test[:, 0], "Z_imag_neg_true": y_test[:, 1],
        "Z_real_pred": y_pred[:, 0], "Z_imag_neg_pred": y_pred[:, 1]
    }).to_csv(out_dir / "test_predictions_pinn.csv", index=False)

    for j, tname in enumerate(targ_cols):
        parity_plot_and_csv(y_test[:, j], y_pred[:, j], tname,
            img_path=plots_dir / "parity" / f"pinn_parity_{tname}.png",
            csv_path=plots_data_dir / "parity" / f"pinn_parity_{tname}.csv")
        residual_plot_and_csv(y_test[:, j], y_pred[:, j], tname,
            img_path=plots_dir / "residuals" / f"pinn_residual_{tname}.png",
            csv_path=plots_data_dir / "residuals" / f"pinn_residual_{tname}.csv")
        error_hist_plot_and_csv(y_test[:, j], y_pred[:, j], tname,
            img_path=plots_dir / "hist" / f"pinn_errhist_{tname}.png",
            bins=30,
            hist_csv_path=plots_data_dir / "hist" / f"pinn_errhist_{tname}_bins.csv",
            fit_csv_path=plots_data_dir / "hist" / f"pinn_errhist_{tname}_fit.csv")

    # Save model (also store y-normalization to enable inverse at inference)
    torch.save({"state_dict": pinn._net.state_dict(), "xmu": pinn._xmu, "xstd": pinn._xstd,
                "train_config": CONFIG["pinn"]["train"], "y_norm": pinn._y_norm,
                "teacher_used": (pinn.teacher_df is not None and pinn.teacher_use)},
               out_dir / "pinn_model.pt")
    return pinn, (X_train, X_test, y_train, y_test)

# --------------------------
# Extrapolation + Heatmaps
# --------------------------
def evaluate_extrapolation(pinn, min_c: float, max_c: float, tag: str = None):
    """Score a trained PINN on all rows with concentration in ``[min_c, max_c]``.

    The dataset is rebuilt with every entry of ``CONFIG["data"]["train_filters"]``
    set to ``None``; the previous filters are put back afterwards, even if
    loading fails.  Results are written below
    ``<output_dir>/extrapolation_range/<tag>``; when ``tag`` is falsy the
    folder is named ``conc{min_c:g}_{max_c:g}``.

    Written files: ``selection_info.txt``, ``predictions_pinn.csv``,
    ``metrics_pinn.csv`` and, per target, parity / residual / error-histogram
    plots with their CSVs.  If no row falls inside the range, only
    ``selection_info.txt`` is written, a message is printed and the function
    returns ``None``.
    """
    report_root = Path(CONFIG["paths"]["output_dir"]).resolve()
    run_name = tag or f"conc{min_c:g}_{max_c:g}"
    base = report_root / "extrapolation_range" / run_name
    plot_dir = base / "plots"
    table_dir = base / "plots_data"
    for folder in (plot_dir, table_dir):
        folder.mkdir(parents=True, exist_ok=True)

    # Temporarily drop the training filters so the full dataset is loaded.
    saved_filters = CONFIG["data"]["train_filters"].copy()
    CONFIG["data"]["train_filters"] = dict.fromkeys(
        ("conc_mM_min", "conc_mM_max", "temp_C_min", "temp_C_max")
    )
    try:
        data_all = build_dataset(Path(CONFIG["paths"]["input_root"]))
    finally:
        CONFIG["data"]["train_filters"] = saved_filters

    # Inclusive concentration window.
    above_min = data_all["concentration_mM"] >= float(min_c)
    below_max = data_all["concentration_mM"] <= float(max_c)
    eval_df = data_all.loc[above_min & below_max].copy()

    with open(base / "selection_info.txt", "w") as fh:
        fh.write(f"Found rows in [{min_c},{max_c}] mM: {len(eval_df)}\n")

    if eval_df.empty:
        print(f"No rows found in {min_c}-{max_c} mM.")
        return

    feat_cols = ["concentration_mM", "temperature_C", "frequency_Hz"]
    targ_cols = CONFIG["targets"]["target_columns"]
    X_eval = eval_df[feat_cols].astype(float).values
    y_true = eval_df[targ_cols].astype(float).values
    y_pred = pinn.predict(X_eval)

    # Feature columns first, then truth and prediction for both targets.
    table = {name: eval_df[name].values for name in feat_cols}
    table["Z_real_true"] = y_true[:, 0]
    table["Z_imag_neg_true"] = y_true[:, 1]
    table["Z_real_pred"] = y_pred[:, 0]
    table["Z_imag_neg_pred"] = y_pred[:, 1]
    pd.DataFrame(table).to_csv(base / "predictions_pinn.csv", index=False)

    metrics = evaluate_predictions(y_true, y_pred, targ_cols)
    pd.DataFrame([metrics]).to_csv(base / "metrics_pinn.csv", index=False)
    print(f"Extrapolation rows: {len(eval_df)}  ->  metrics: {metrics}")

    for j, tname in enumerate(targ_cols):
        truth_j = y_true[:, j]
        pred_j = y_pred[:, j]
        parity_plot_and_csv(
            truth_j, pred_j, tname,
            img_path=plot_dir / f"pinn_parity_{tname}.png",
            csv_path=table_dir / f"pinn_parity_{tname}.csv",
        )
        residual_plot_and_csv(
            truth_j, pred_j, tname,
            img_path=plot_dir / f"pinn_residual_{tname}.png",
            csv_path=table_dir / f"pinn_residual_{tname}.csv",
        )
        error_hist_plot_and_csv(
            truth_j, pred_j, tname,
            img_path=plot_dir / f"pinn_errhist_{tname}.png",
            bins=30,
            hist_csv_path=table_dir / f"pinn_errhist_{tname}_bins.csv",
            fit_csv_path=table_dir / f"pinn_errhist_{tname}_fit.csv",
        )

def visualize_theta_heatmaps(pinn):
    """Plot every circuit parameter over the (C, T) grid from ``CONFIG["heatmap"]``.

    The grid is evaluated with ``pinn.predict_theta``.  One PNG per
    parameter goes to ``<output_dir>/theta_heatmaps`` and a long-format
    table (one row per grid point) to
    ``<output_dir>/theta_heatmaps_csv/theta_grid_long.csv``.
    """
    report_root = Path(CONFIG["paths"]["output_dir"]).resolve()
    heat_dir = report_root / "theta_heatmaps"
    heat_dir.mkdir(parents=True, exist_ok=True)
    heat_csv = report_root / "theta_heatmaps_csv"
    heat_csv.mkdir(parents=True, exist_ok=True)

    cfg = CONFIG["heatmap"]
    conc_axis = np.linspace(cfg["C_min"], cfg["C_max"], cfg["C_points"])
    temp_axis = np.linspace(cfg["T_min"], cfg["T_max"], cfg["T_points"])
    # "xy" indexing: rows follow temperature, columns follow concentration.
    conc_mesh, temp_mesh = np.meshgrid(conc_axis, temp_axis, indexing="xy")

    theta = pinn.predict_theta(conc_mesh, temp_mesh)

    # Long table: concentration varies slowest, hence the transpose before ravel.
    columns = {
        "concentration_mM": np.repeat(conc_axis, temp_axis.size),
        "temperature_C": np.tile(temp_axis, conc_axis.size),
    }
    for key in ("Rs", "Rp", "Y0", "n0", "r", "y0", "n1", "L"):
        columns[key] = theta[key].T.ravel()
    pd.DataFrame(columns).to_csv(heat_csv / "theta_grid_long.csv", index=False)

    def _render(grid, label, filename, cmap="viridis"):
        # Draw one parameter surface and save it under heat_dir.
        plt.figure(figsize=(6, 4.8))
        bounds = [conc_axis.min(), conc_axis.max(), temp_axis.min(), temp_axis.max()]
        image = plt.imshow(grid, origin="lower", aspect="auto", extent=bounds, cmap=cmap)
        bar = plt.colorbar(image)
        bar.set_label(label)
        plt.xlabel("Concentration (mM)")
        plt.ylabel("Temperature (°C)")
        plt.title(f"{label} vs (C,T)")
        plt.tight_layout()
        plt.savefig(heat_dir / filename, dpi=200)
        _maybe_show()

    panels = [
        ("Rs", "Rs (Ω)"),
        ("Rp", "Rp (Ω)"),
        ("Y0", "Y0 (Ω⁻¹ sⁿ⁰)"),
        ("n0", "n0 (–)"),
        ("r", "r (Ω/len)"),
        ("y0", "y0 (Ω⁻¹ sⁿ¹/len)"),
        ("n1", "n1 (–)"),
        ("L", "L (len)"),
    ]
    for key, label in panels:
        _render(theta[key], label, f"theta_{key}.png")

    print(f"Saved θ(C,T) heatmaps -> {heat_dir}")
    print(f"Saved θ(C,T) grid CSV -> {heat_csv / 'theta_grid_long.csv'}")

# ==========================
# Run (train → extrapolate → heatmaps)
# ==========================
# Train on the grouped split; the returned split arrays are not needed here.
pinn, _ = train_pinn_and_report()

# Score the trained model on the concentration window from CONFIG["extrapolation"].
# NOTE(review): in the recorded run this selected 20800 rows, the same count as
# the full compiled dataset, so those metrics would include training rows —
# confirm the configured range is the intended hold-out window.
lo, hi = CONFIG["extrapolation"]["range_conc_mM"]
evaluate_extrapolation(pinn, lo, hi, tag=CONFIG["extrapolation"]["tag"])

# Plot the learned θ(C,T) surfaces and save them as a long-format CSV.
visualize_theta_heatmaps(pinn)

# Close every open matplotlib figure, then run a garbage-collection pass.
plt.close("all"); gc.collect()


[1/5] Loading dataset…


Reading CSV files:   0%|          | 0/260 [00:00<?, ?file/s]

   -> 20800 rows
[2/5] Grouped train/test split (80/20) by unique (C,T)
[4/5] Training PINN…


  0%|          | 0/5000 [00:00<?, ?it/s]

[5/5] Evaluating… (grouped split)
   Test metrics: {'R2_Z_real': 0.999918909851281, 'MAE_Z_real': 0.27602801049562403, 'RMSE_Z_real': 0.45518183985156613, 'R2_Z_imag_neg': 0.9996785797017844, 'MAE_Z_imag_neg': 0.09053714159454425, 'RMSE_Z_imag_neg': 0.1825440852910845, 'R2_mean': 0.9997987447765326, 'MAE_mean': 0.18328257604508413, 'RMSE_mean': 0.3188629625713253}


Reading CSV files:   0%|          | 0/260 [00:00<?, ?file/s]

Extrapolation rows: 20800  ->  metrics: {'R2_Z_real': 0.9999797372742756, 'MAE_Z_real': 0.08189447136066935, 'RMSE_Z_real': 0.2098234625677146, 'R2_Z_imag_neg': 0.9999182766237888, 'MAE_Z_imag_neg': 0.037848139273767414, 'RMSE_Z_imag_neg': 0.09083007713731185, 'R2_mean': 0.9999490069490322, 'MAE_mean': 0.05987130531721838, 'RMSE_mean': 0.15032676985251323}
Saved θ(C,T) heatmaps -> /Users/hosseinostovar/Desktop/BACKUP/Data_H2SO4_NPG/data/Single_frequencies_whole_spectrum/PINN_report/theta_heatmaps
Saved θ(C,T) grid CSV -> /Users/hosseinostovar/Desktop/BACKUP/Data_H2SO4_NPG/data/Single_frequencies_whole_spectrum/PINN_report/theta_heatmaps_csv/theta_grid_long.csv


1113