# In [None]:
# EIS Single-Frequency ML Pipeline — v5 + Group-Safe Split + Flexible Z + Group-aware CV & Learning Curves
# ---------------------------------------------------------------------------------------------------------
# Changes vs previous cell:
# - RandomizedSearchCV now uses GroupKFold and fits with groups=...
# - Learning curves use GroupKFold and groups=...
# - Optional rounding for (C,T) groups: CONFIG["split"]["group_round_C"/"group_round_T"]

import os, re, json, math, random, warnings
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd

from tqdm.auto import tqdm

import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams["figure.dpi"] = 130

from sklearn.model_selection import GroupShuffleSplit, GroupKFold, RandomizedSearchCV, learning_curve
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, PolynomialFeatures, MinMaxScaler, FunctionTransformer
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error, make_scorer
from sklearn.multioutput import MultiOutputRegressor
from sklearn.base import clone, BaseEstimator, RegressorMixin

from sklearn.linear_model import Ridge, ElasticNet, HuberRegressor, TheilSenRegressor
from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor, GradientBoostingRegressor
from sklearn.svm import SVR
from sklearn.neural_network import MLPRegressor
from sklearn.kernel_ridge import KernelRidge
from sklearn.cross_decomposition import PLSRegression

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, RationalQuadratic, DotProduct, WhiteKernel, ConstantKernel as C

import joblib
warnings.filterwarnings("ignore")

# Optional symbolic regression backends (auto-disabled if not installed).
# Each backend is probed independently: previously gplearn was only imported
# when the PySR import failed, so an explicit backend="gplearn" request raised
# NameError whenever PySR was installed. backend="auto" still prefers PySR.
_HAVE_PYSR = False
_HAVE_GPLEARN = False
try:
    from pysr import PySRRegressor
    _HAVE_PYSR = True
except Exception:
    # Broad on purpose: any import-time failure simply disables the backend.
    pass
try:
    from gplearn.genetic import SymbolicRegressor as GPLearnSR
    _HAVE_GPLEARN = True
except Exception:
    pass

def arrhenius_features(X):
    """Expand raw inputs into Arrhenius-style features.

    X holds the columns [concentration_mM, temperature_C, frequency_Hz]; a
    single 1-D sample is promoted to one row. The result has the columns
    [C, T_C, f, T_K, 1/T_K, log(C), C/T_K].
    """
    arr = np.asarray(X)
    if arr.ndim == 1:
        arr = arr.reshape(1, -1)
    conc, temp_c, freq = arr[:, 0], arr[:, 1], arr[:, 2]

    temp_k = temp_c + 273.15
    recip_temp_k = 1.0 / temp_k
    # Clip before the log so zero/negative concentrations stay finite.
    log_conc = np.log(np.clip(conc, 1e-9, None))
    conc_over_temp = conc * recip_temp_k

    return np.column_stack(
        [conc, temp_c, freq, temp_k, recip_temp_k, log_conc, conc_over_temp]
    )

# =====================
# ====== CONFIG =======
# =====================
# Single source of truth for the pipeline: paths, column aliases, feature and
# target names, split/CV budget, plotting switches and the model zoo. Each
# model entry carries an estimator ("pipeline"/"pipeline_template") and the
# "param_distributions" handed to the randomized search.
CONFIG: Dict = {
    "paths": {
        "input_root": "/Users/hosseinostovar/Desktop/BACKUP/Data/single_frequency_collected",
        "output_dir": "/Users/hosseinostovar/Desktop/BACKUP/Data/ml_outputs_singleFrequency_v5_groupsafe_flexZ_groupCV",
    },
    "data": {
        # Flexible header matching. Add your new header names here if needed.
        "column_aliases": {
            "frequency": [
                "Frequency (Hz)", "frequency (hz)", "freq (hz)", "frequency", "f (hz)"
            ],
            "z_real": [
                "Z' (Ω)", "Z' (ohm)", "z' (ω)", "z_real (Ω)", "zreal (Ω)", "re(z) (Ω)", "re(z)"
            ],
            # Already -Im(Z):
            "z_imag_neg": [
                "-Z'' (Ω)", "-Z'' (ohm)", "-z'' (ω)", "-z_imag (Ω)", "-im(z) (Ω)", "-imag (Ω)"
            ],
            # +Im(Z) (we will flip sign):
            "z_imag_pos": [
                "Z'' (Ω)", "Z'' (ohm)", "z'' (ω)", "z_imag (Ω)", "im(z) (Ω)", "imag (Ω)"
            ],
        },
        # Legacy exacts (used only if aliases fail)
        "col_frequency": "Frequency (Hz)",
        "col_z_real":    "Z' (Ω)",
        "col_z_imag_neg":"-Z'' (Ω)",

        "frequency_filter_hz": None,   # e.g., 1000.0 for single-frequency training
        "accept_nested_summary": True,
        "accept_flat_collected": True,
        "read_csv_kwargs": {"encoding": "utf-8"},

        # Optional training-domain filter
        "train_filters": {
            "conc_mM_min": 5.0, "conc_mM_max": 15.0,
            "temp_C_min": None, "temp_C_max": None
        },
    },
    "features": {
        "use_features": ["concentration_mM", "temperature_C", "frequency_Hz"],
        "poly_degrees": [2],
    },
    "feature_engineering": {"arrhenius": {"use": True}},
    "targets": {"target_columns": ["Z_real", "Z_imag_neg"]},
    "normalization": { "targets": {"enabled": False, "method": "standard"} },
    # NEW: group rounding for (C,T)
    "split": {"test_size": 0.2, "random_state": 42, "group_round_C": 0, "group_round_T": 0},
    # CV budget
    "cv": {"n_splits": 5, "n_iter": 25, "n_jobs": -1, "verbose": 1},
    "progress": {"enable_bars": True, "show_file_scan": True},
    "plots": {
        "save_all_models": True,
        "top_k_models_to_plot": 3,
        "make_learning_curve": True,
        "learning_curve_train_sizes": [0.2, 0.4, 0.6, 0.8, 1.0],
        "show_inline": False
    },
    "models": {
        "ridge": {
            "enabled": True,
            "pipeline": Pipeline([("scaler", StandardScaler()), ("reg", Ridge())]),
            "param_distributions": {"reg__alpha": np.logspace(-3, 3, 50)},
        },
        "elasticnet": {
            "enabled": True,
            "pipeline": Pipeline([("scaler", StandardScaler()), ("reg", ElasticNet(max_iter=8000))]),
            "param_distributions": {
                "reg__alpha": np.logspace(-4, 2, 50),
                "reg__l1_ratio": np.linspace(0.05, 0.95, 19),
            },
        },
        "pls": {
            "enabled": True,
            "pipeline": Pipeline([("scaler", StandardScaler(with_mean=True, with_std=True)), ("reg", PLSRegression())]),
            # PLS cannot use more components than input features. With the three
            # "use_features" above, candidates 4..9 could never fit, so the grid
            # stops at 3. Extend it if more features are added.
            "param_distributions": {"reg__n_components": [1, 2, 3]},
        },
        "poly_ridge": {
            "enabled": True,
            "pipeline_template": Pipeline([("poly", PolynomialFeatures(degree=2, include_bias=False)),
                                           ("scaler", StandardScaler()),
                                           ("reg", Ridge())]),
            "param_distributions": {"reg__alpha": np.logspace(-3, 3, 50)},
        },
        "huber": {
            "enabled": True,
            "pipeline": Pipeline([("scaler", StandardScaler()), ("reg", MultiOutputRegressor(HuberRegressor()))]),
            "param_distributions": {
                "reg__estimator__epsilon": [1.1, 1.35, 1.5, 1.75],
                "reg__estimator__alpha": np.logspace(-6, -2, 5)
            },
        },
        "theilsen": {
            "enabled": True,
            "pipeline": Pipeline([
                ("scaler", StandardScaler()),
                ("reg", MultiOutputRegressor(
                    TheilSenRegressor(random_state=42, max_subpopulation=1e4)
                ))
            ]),
            "param_distributions": {
                "reg__estimator__max_subpopulation": [1e4, 2e4, 5e4],
                "reg__estimator__max_iter": [200, 500, 1000],
                "reg__estimator__tol": np.logspace(-4, -2, 3),
            },
        },
        "arrhenius_poly_ridge": {
            "enabled": True,
            "pipeline": Pipeline([
                ("fe", FunctionTransformer(arrhenius_features, validate=False, feature_names_out='one-to-one')),
                ("poly", PolynomialFeatures(degree=2, include_bias=False)),
                ("scaler", StandardScaler()),
                ("reg", Ridge())
            ]),
            "param_distributions": {"reg__alpha": np.logspace(-3, 3, 50)},
        },
        "svr_rbf": {
            "enabled": True,
            "pipeline": Pipeline([("scaler", StandardScaler()), ("reg", MultiOutputRegressor(SVR(kernel="rbf")))]),
            "param_distributions": {
                "reg__estimator__C": np.logspace(-2, 3, 15),
                "reg__estimator__epsilon": np.logspace(-3, -0.3, 8),
                "reg__estimator__gamma": ["scale", "auto"] + list(np.logspace(-4, 1, 8)),
            },
        },
        "kernel_ridge_rbf": {
            "enabled": True,
            "pipeline": Pipeline([("scaler", StandardScaler()), ("reg", KernelRidge(kernel="rbf"))]),
            "param_distributions": {"reg__alpha": np.logspace(-4, 1, 10), "reg__gamma": np.logspace(-4, 1, 10)},
        },
        "mlp_ann": {
            "enabled": True,
            "pipeline": Pipeline([("scaler", StandardScaler()), ("reg", MLPRegressor(max_iter=4000, random_state=42))]),
            "param_distributions": {
                "reg__hidden_layer_sizes": [(64,), (128,), (64,64), (128,64), (128,128)],
                "reg__alpha": np.logspace(-6, -2, 5),
                "reg__activation": ["relu", "tanh"],
                "reg__learning_rate_init": np.logspace(-4, -2, 3),
            },
        },
        "gpr": {
            "enabled": True,
            "pipeline": MultiOutputRegressor(GaussianProcessRegressor(
                kernel=(C(1.0, (1e-3, 1e3)) * RationalQuadratic(alpha=1.0, length_scale=10.0,
                        alpha_bounds=(1e-3, 1e3), length_scale_bounds=(1e-2, 1e3))
                        + C(1.0, (1e-3, 1e3)) * RBF(length_scale=10.0, length_scale_bounds=(1e-2, 1e3))
                        + DotProduct())
                + WhiteKernel(noise_level=1e-3, noise_level_bounds=(1e-8, 1e-1)),
                normalize_y=True, random_state=42
            )),
            "param_distributions": {
                "estimator__alpha": [1e-10, 1e-6, 1e-4, 1e-3],
                "estimator__normalize_y": [True, False],
            },
        },
        "random_forest": {
            "enabled": True,
            "pipeline": RandomForestRegressor(random_state=42),
            "param_distributions": {
                "n_estimators": list(range(200, 801, 100)),
                "max_depth": [None] + list(range(5, 41, 5)),
                "min_samples_split": [2, 5, 10],
                "min_samples_leaf": [1, 2, 4, 8],
                "max_features": ["sqrt", "log2"],
            },
        },
        "extra_trees": {
            "enabled": True,
            "pipeline": ExtraTreesRegressor(random_state=42),
            "param_distributions": {
                "n_estimators": list(range(300, 1001, 100)),
                "max_depth": [None] + list(range(5, 51, 5)),
                "min_samples_split": [2, 5, 10],
                "min_samples_leaf": [1, 2, 4, 8],
                "max_features": ["sqrt", "log2"],
                "bootstrap": [False],
            },
        },
        "gbr": {
            "enabled": True,
            "pipeline": MultiOutputRegressor(GradientBoostingRegressor(random_state=42)),
            "param_distributions": {
                "estimator__n_estimators": list(range(200, 801, 100)),
                "estimator__learning_rate": np.logspace(-3, -0.3, 10),
                "estimator__max_depth": list(range(2, 9)),
                "estimator__subsample": [0.7, 0.85, 1.0],
            },
        },
        "symbolic": {"enabled": True, "backend": "auto", "n_iter": 200},
    },
    "outputs": {"save_all_model_artifacts": True}
}

# With neither PySR nor gplearn importable, switch the symbolic model off.
if not _HAVE_PYSR and not _HAVE_GPLEARN:
    CONFIG["models"]["symbolic"]["enabled"] = False

# =====================
# ===== UTILITIES =====
# =====================
# Number followed by a concentration unit; matched case-insensitively.
CONC_RE = re.compile(r'(?P<val>[-+]?[0-9]*\.?[0-9]+)\s*(?P<unit>mM|M|uM|µM)', re.IGNORECASE)
# Number followed by 'c' or 'C'.
TEMP_RE = re.compile(r'(?P<val>[-+]?[0-9]*\.?[0-9]+)\s*[cC]')
# Multiplier that converts each unit to millimolar.
UNIT_SCALE = {"M": 1000.0, "mM": 1.0, "uM": 0.001, "µM": 0.001}

def parse_concentration_to_mM(text: str) -> float:
    """Parse the first '<number><unit>' in *text* and return it in millimolar.

    Returns NaN when *text* is not a string or contains no match.

    CONC_RE is case-insensitive, so the captured unit may be spelled "um",
    "UM", "MM", ... An exact-case lookup is tried first; otherwise the
    multi-letter units are matched case-insensitively, so micromolar variants
    get 0.001 instead of silently falling back to 1.0. A bare lowercase "m"
    is ambiguous and keeps the legacy fallback scale of 1.0.
    """
    if not isinstance(text, str):
        return np.nan
    m = CONC_RE.search(text)
    if m is None:
        return np.nan
    unit = m.group("unit")
    scale = UNIT_SCALE.get(unit)
    if scale is None:
        # casefold() also unifies the micro sign (U+00B5) with Greek mu.
        folded = {k.casefold(): v for k, v in UNIT_SCALE.items() if len(k) > 1}
        scale = folded.get(unit.casefold(), 1.0)
    return float(m.group("val")) * scale

def parse_temperature_to_C(text: str) -> float:
    """Return the first number in *text* that TEMP_RE finds before a 'c'/'C'.

    Yields NaN for non-string input or when nothing matches.
    """
    if isinstance(text, str):
        hit = TEMP_RE.search(text)
        if hit:
            return float(hit.group("val"))
    return np.nan

def _normalize(s: str) -> str:
    """Canonicalise a column header for fuzzy alias matching.

    Lower-cases, then removes whitespace, underscores, brackets, pipes, the
    degree sign and ohm markers ("Ω", "ω", the word "ohm"). Prime marks and a
    leading minus sign are kept, because they are the only thing that tells
    Z' from Z'' and -Z'' from Z''. A double quote is treated as two primes.

    The previous version put "ohm" inside a character class (deleting every
    o/h/m letter) and also dropped '-' and quotes, so the real, +imag and
    -imag headers all collapsed to the same key.
    """
    t = s.strip().lower()
    # Unify typographic variants of minus and prime marks.
    t = (t.replace("\u2212", "-")
          .replace("\u2033", "''")
          .replace("\u2032", "'")
          .replace('"', "''"))
    negative = t.startswith("-")
    # Inner hyphens are separators; only the leading sign is meaningful.
    t = re.sub(r"ohm|[\s\-_()\[\]{}|°Ωω]", "", t)
    return "-" + t if negative else t

def _find_col_by_aliases(df: pd.DataFrame, aliases: List[str]) -> str | None:
    cols = list(df.columns)
    low_map = {c.lower(): c for c in cols}
    for a in aliases:
        if a.lower() in low_map:
            return low_map[a.lower()]
    norm_map = {_normalize(c): c for c in cols}
    for a in aliases:
        na = _normalize(a)
        if na in norm_map:
            return norm_map[na]
    return None

def resolve_eis_columns(df: pd.DataFrame) -> tuple[str, str, pd.Series]:
    """Return (freq_col_name, z_real_col_name, z_imag_neg_series) with -Im(Z) convention.

    Columns are located through CONFIG["data"]["column_aliases"]; the legacy
    exact names are a fallback for frequency, Z' and -Z''. When only a +Im(Z)
    column exists its sign is flipped. Raises KeyError if frequency, Z' or any
    imaginary column cannot be found.
    """
    data_cfg = CONFIG["data"]
    alias_table = data_cfg["column_aliases"]

    def _with_legacy(found, legacy_key):
        # Use the legacy exact header only when alias matching found nothing.
        if found is None and data_cfg[legacy_key] in df.columns:
            return data_cfg[legacy_key]
        return found

    freq_name = _find_col_by_aliases(df, alias_table["frequency"])
    real_name = _find_col_by_aliases(df, alias_table["z_real"])
    neg_imag_name = _find_col_by_aliases(df, alias_table["z_imag_neg"])
    pos_imag_name = _find_col_by_aliases(df, alias_table["z_imag_pos"])

    freq_name = _with_legacy(freq_name, "col_frequency")
    real_name = _with_legacy(real_name, "col_z_real")
    neg_imag_name = _with_legacy(neg_imag_name, "col_z_imag_neg")

    no_imag = neg_imag_name is None and pos_imag_name is None
    if freq_name is None or real_name is None or no_imag:
        raise KeyError(f"Could not resolve columns. Found={list(df.columns)}")

    if neg_imag_name is None:
        # Only +Im(Z) is available: negate it to get the -Im(Z) convention.
        imag_series = -pd.to_numeric(df[pos_imag_name], errors="coerce")
    else:
        imag_series = pd.to_numeric(df[neg_imag_name], errors="coerce")
    return freq_name, real_name, imag_series

def discover_files(root: Path) -> List[Tuple[Path, str, str]]:
    """Find input CSVs under *root* as (path, concentration_label, temperature_label).

    Two layouts are searched recursively, each gated by a CONFIG["data"] flag:
    nested `<conc>/<temp>/single_frequency_summary.csv`, where the labels come
    from the two parent folder names, and flat
    `single_frequency_<conc>_<temp>...csv`, where they come from the file name.
    Duplicates are removed, keeping the first occurrence.
    """
    summary_name = "single_frequency_summary.csv"
    found = []

    if CONFIG["data"]["accept_nested_summary"]:
        found.extend(
            (path, path.parent.parent.name, path.parent.name)
            for path in root.rglob(summary_name)
        )

    if CONFIG["data"]["accept_flat_collected"]:
        for path in root.rglob("single_frequency_*_*.csv"):
            fname = path.name
            if fname == summary_name or not fname.endswith(".csv"):
                continue
            stem = fname.replace("single_frequency_", "").replace(".csv", "")
            tokens = stem.split("_")
            if len(tokens) < 2:
                continue
            found.append((path, tokens[0], tokens[1]))

    # Order-preserving de-duplication; setdefault keeps the first hit.
    first_seen = {}
    for entry in found:
        first_seen.setdefault((str(entry[0]), entry[1], entry[2]), entry)
    return list(first_seen.values())

def load_one_csv(path: Path) -> pd.DataFrame:
    """Read one CSV using the keyword arguments in CONFIG["data"]["read_csv_kwargs"]."""
    return pd.read_csv(path, **CONFIG["data"]["read_csv_kwargs"])

def build_dataset(root: Path) -> pd.DataFrame:
    """Compile every discovered CSV under *root* into one tidy DataFrame.

    The result has the columns frequency_Hz, Z_real, Z_imag_neg,
    concentration_mM, temperature_C and source_file. Rows with a non-numeric
    or missing value in any of the five numeric columns are dropped, then the
    optional CONFIG["data"]["train_filters"] bounds are applied.

    Files whose columns cannot be resolved are skipped, and each skip is now
    reported via tqdm.write instead of being swallowed. A warning is also
    printed when a concentration or temperature label cannot be parsed, since
    those rows are dropped later.

    Raises:
        FileNotFoundError: no candidate CSV exists under *root*.
        RuntimeError: files were found but none contributed any rows.
    """
    files = discover_files(root)
    if not files:
        raise FileNotFoundError(f"No CSV files found under: {root}")
    rows = []
    freq_filter = CONFIG["data"]["frequency_filter_hz"]

    it = tqdm(files, desc="Reading CSV files", unit="file") if CONFIG["progress"]["show_file_scan"] else files
    for csv_path, conc_str, temp_str in it:
        df = load_one_csv(csv_path)
        try:
            f_col, zr_col, zi_neg_series = resolve_eis_columns(df)
        except Exception as exc:
            # Best-effort scan: keep going, but say which file was dropped and why.
            tqdm.write(f"   -> skipped {csv_path}: {exc}")
            continue
        tmp = pd.DataFrame({
            "frequency_Hz": pd.to_numeric(df[f_col], errors="coerce"),
            "Z_real":       pd.to_numeric(df[zr_col], errors="coerce"),
            "Z_imag_neg":   pd.to_numeric(zi_neg_series, errors="coerce"),
        })
        if freq_filter is not None:
            tmp = tmp.loc[np.isclose(tmp["frequency_Hz"].astype(float), float(freq_filter))]
        if tmp.empty:
            continue
        conc_mM = parse_concentration_to_mM(conc_str)
        temp_C = parse_temperature_to_C(temp_str)
        if np.isnan(conc_mM) or np.isnan(temp_C):
            tqdm.write(
                f"   -> warning: could not parse concentration/temperature "
                f"({conc_str!r}, {temp_str!r}) for {csv_path}; its rows will be dropped"
            )
        tmp["concentration_mM"] = conc_mM
        tmp["temperature_C"] = temp_C
        tmp["source_file"] = str(csv_path)
        rows.append(tmp)

    if not rows:
        raise RuntimeError("Discovered files, but none had the required columns.")
    data = pd.concat(rows, ignore_index=True)

    numeric_cols = ["concentration_mM", "temperature_C", "frequency_Hz", "Z_real", "Z_imag_neg"]
    for c in numeric_cols:
        data[c] = pd.to_numeric(data[c], errors="coerce")
    data = data.dropna(subset=numeric_cols)

    # Optional training-domain filter
    fcfg = CONFIG["data"]["train_filters"]
    if any(v is not None for v in fcfg.values()):
        m = pd.Series(True, index=data.index)
        if fcfg["conc_mM_min"] is not None:
            m &= data["concentration_mM"] >= float(fcfg["conc_mM_min"])
        if fcfg["conc_mM_max"] is not None:
            m &= data["concentration_mM"] <= float(fcfg["conc_mM_max"])
        if fcfg["temp_C_min"] is not None:
            m &= data["temperature_C"] >= float(fcfg["temp_C_min"])
        if fcfg["temp_C_max"] is not None:
            m &= data["temperature_C"] <= float(fcfg["temp_C_max"])
        data = data.loc[m].copy()

    return data

# =====================
# === Metrics & scorer
# =====================
def evaluate_predictions(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Compute R2, MAE and RMSE for each target plus their means.

    Column 0 is reported as "Z_real" and column 1 as "Z_imag_neg". Keys are
    "<METRIC>_<target>" followed by "<METRIC>_mean".
    """
    collected = {"R2": [], "MAE": [], "RMSE": []}
    report = {}
    for col, target in enumerate(("Z_real", "Z_imag_neg")):
        truth, pred = y_true[:, col], y_pred[:, col]
        scores = {
            "R2": r2_score(truth, pred),
            "MAE": mean_absolute_error(truth, pred),
            "RMSE": math.sqrt(mean_squared_error(truth, pred)),
        }
        for metric, value in scores.items():
            report[f"{metric}_{target}"] = value
            collected[metric].append(value)
    for metric, values in collected.items():
        report[f"{metric}_mean"] = float(np.mean(values))
    return report

def _rmse_mean(y_true, y_pred):
    """Return the mean of the per-target RMSEs.

    Inputs may be any array-like. 1-D inputs are treated as a single target
    column, so single-output estimators and list/DataFrame targets no longer
    fail on the column indexing.
    """
    yt = np.asarray(y_true)
    yp = np.asarray(y_pred)
    if yt.ndim == 1:
        yt = yt.reshape(-1, 1)
    if yp.ndim == 1:
        yp = yp.reshape(-1, 1)
    per_target = [
        np.sqrt(np.mean((yt[:, j] - yp[:, j]) ** 2)) for j in range(yt.shape[1])
    ]
    return float(np.mean(per_target))

# Scorer built from _rmse_mean. greater_is_better=False marks it as a loss, so
# the callers in this file negate the returned scores to recover RMSE.
RMSE_MEAN_SCORER = make_scorer(_rmse_mean, greater_is_better=False)

def to_serializable(obj):
    """Recursively convert *obj* into something `json.dump` accepts.

    - Python scalars and None pass through unchanged.
    - NumPy integer/floating/bool scalars become native Python scalars.
      np.bool_ previously fell through to repr() and came out as "True".
    - Path objects become plain path strings.
    - Lists and tuples become lists; dict keys are stringified.
    - ndarrays become nested lists, and their elements are converted too.
    - Estimators are reduced to their class name.
    - Anything else falls back to repr().
    """
    if obj is None or isinstance(obj, (str, int, float, bool)):
        return obj
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        return obj.item()
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, (list, tuple)):
        return [to_serializable(x) for x in obj]
    if isinstance(obj, dict):
        return {str(k): to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, np.ndarray):
        # tolist() yields a scalar for 0-d arrays, hence the recursive call.
        return to_serializable(obj.tolist())
    if isinstance(obj, BaseEstimator):
        return obj.__class__.__name__
    return repr(obj)

# ==========================
# === Target normalization
# ==========================
class YScaledRegressor(BaseEstimator, RegressorMixin):
    """Wrap a regressor so it is trained on scaled targets.

    `method` selects the target scaler: "standard" (StandardScaler) or
    "minmax" (MinMaxScaler). Predictions are mapped back to the original
    target scale.
    """

    def __init__(self, estimator, method="standard"):
        self.estimator = estimator
        self.method = method
        self._yscaler = None

    def fit(self, X, y):
        """Fit the target scaler on *y*, then the wrapped estimator on scaled *y*."""
        targets = np.asarray(y)
        if targets.ndim == 1:
            targets = targets.reshape(-1, 1)
        for label, scaler_cls in (("standard", StandardScaler), ("minmax", MinMaxScaler)):
            if self.method == label:
                self._yscaler = scaler_cls()
                break
        else:
            raise ValueError(f"Unknown target normalization method: {self.method}")
        self.estimator.fit(X, self._yscaler.fit_transform(targets))
        return self

    def predict(self, X):
        """Predict with the wrapped estimator and undo the target scaling."""
        scaled = np.asarray(self.estimator.predict(X))
        if scaled.ndim == 1:
            scaled = scaled.reshape(-1, 1)
        return self._yscaler.inverse_transform(scaled)

def maybe_wrap_target_normalization(estimator, param_distributions):
    """Optionally wrap *estimator* in YScaledRegressor, per CONFIG["normalization"]["targets"].

    When target normalization is enabled, every search-space key is prefixed
    with "estimator__" so it still reaches the wrapped estimator. Otherwise
    both arguments are returned untouched.
    """
    target_cfg = CONFIG.get("normalization", {}).get("targets", {})
    if target_cfg.get("enabled", False):
        scaled = YScaledRegressor(
            estimator=estimator,
            method=target_cfg.get("method", "standard"),
        )
        prefixed = {f"estimator__{key}": dist for key, dist in param_distributions.items()}
        return scaled, prefixed
    return estimator, param_distributions

# =====================
# ===== PLOTTING ======
# =====================
def _maybe_show():
    """Show the current figure if CONFIG["plots"]["show_inline"] is set, then close it."""
    show_inline = CONFIG["plots"]["show_inline"]
    if show_inline:
        plt.show()
    plt.close()

# (result key, folder name) for every per-plot-type sub-directory.
_PLOT_SUBDIRS = (
    ("parity", "parity"),
    ("residuals", "residuals"),
    ("hist", "hist"),
    ("learning", "learning"),
    ("split", "split_info"),
)

def _ensure_subdirs(base: Path):
    """Create each sub-directory of *base* listed in _PLOT_SUBDIRS; return {key: path}."""
    created = {}
    for key, folder in _PLOT_SUBDIRS:
        target = base / folder
        target.mkdir(exist_ok=True, parents=True)
        created[key] = target
    return created

def _ensure_plot_dirs(base: Path):
    """Create the image output folders under *base*.

    Returns a dict with the keys parity, residuals, hist, learning and split.
    """
    return _ensure_subdirs(base)

def _ensure_data_dirs(base: Path):
    """Create the plot-data (CSV) output folders under *base*.

    Uses the same layout and keys as `_ensure_plot_dirs`; the two bodies used
    to be verbatim copies.
    """
    return _ensure_subdirs(base)

def parity_plot_and_csv(y_true, y_pred, target_name, img_path: Path, csv_path: Path):
    """Write the (y_true, y_pred) pairs to *csv_path* and a parity scatter to *img_path*.

    The dashed diagonal spans the combined min/max of both arrays.
    """
    pairs = pd.DataFrame({"y_true": y_true.ravel(), "y_pred": y_pred.ravel()})
    pairs.to_csv(csv_path, index=False)

    fig, ax = plt.subplots()
    ax.scatter(y_true, y_pred, s=12, alpha=0.7)
    lower = min(float(np.min(y_true)), float(np.min(y_pred)))
    upper = max(float(np.max(y_true)), float(np.max(y_pred)))
    ax.plot([lower, upper], [lower, upper], linestyle="--")
    ax.set_xlabel(f"True {target_name}")
    ax.set_ylabel(f"Pred {target_name}")
    ax.set_title(f"Parity: {target_name}")
    fig.tight_layout()
    fig.savefig(img_path, dpi=180)
    _maybe_show()

def residual_plot_and_csv(y_true, y_pred, target_name, img_path: Path, csv_path: Path):
    """Write residuals (pred - true) to *csv_path* and a residual-vs-prediction scatter to *img_path*."""
    flat_residual = (y_pred - y_true).ravel()
    table = pd.DataFrame({"y_pred": y_pred.ravel(), "residual": flat_residual})
    table.to_csv(csv_path, index=False)

    fig, ax = plt.subplots()
    ax.scatter(y_pred, flat_residual, s=12, alpha=0.7)
    ax.axhline(0, linestyle="--")
    ax.set_xlabel(f"Pred {target_name}")
    ax.set_ylabel("Residual (Pred - True)")
    ax.set_title(f"Residuals: {target_name}")
    fig.tight_layout()
    fig.savefig(img_path, dpi=180)
    _maybe_show()

def _gauss_pdf(x, m, s):
    """Gaussian density with mean *m* and std *s* evaluated at *x*.

    A non-positive *s* yields zeros shaped like *x*.
    """
    if s <= 0:
        return np.zeros_like(x)
    z = (x - m) / s
    norm = 1.0 / (s * np.sqrt(2 * np.pi))
    return norm * np.exp(-0.5 * z ** 2)

def error_hist_plot_and_csv(y_true, y_pred, target_name, img_path: Path, bins: int, hist_csv_path: Path, fit_csv_path: Path):
    """Histogram the prediction errors and overlay a mode-aligned Gaussian.

    Writes three artifacts:
      * hist_csv_path: one row per bin (left edge, right edge, center, count).
      * fit_csv_path: one row with the Gaussian parameters and summary stats.
      * img_path: the histogram with the Gaussian curve and an annotation box.

    The Gaussian is centred on the most populated bin and uses a MAD-based
    sigma, with fallbacks to the sample standard deviation and then 1e-9.
    """
    # Signed error, flattened across all samples.
    err = (y_pred - y_true).ravel()
    counts, edges = np.histogram(err, bins=bins)
    centers = 0.5 * (edges[:-1] + edges[1:])
    bin_w = float(edges[1] - edges[0]) if len(edges) > 1 else 1.0
    # Center of the most populated bin (first one on ties, per np.argmax).
    mode_center = float(centers[np.argmax(counts)])

    # Robust spread: 1.4826 * MAD, or the sample std when the MAD is zero.
    med = float(np.median(err))
    mad = float(np.median(np.abs(err - med)))
    sigma_rob = 1.4826 * mad if mad > 0 else float(np.std(err, ddof=1))

    # Moment-based summary statistics; skew/kurtosis are NaN when the std is zero.
    mean_bias = float(np.mean(err))
    sigma_std = float(np.std(err, ddof=1)) if err.size > 1 else 0.0
    skew = float(np.mean(((err - mean_bias) / (sigma_std if sigma_std>0 else 1))**3)) if sigma_std>0 else np.nan
    kurt_excess = float(np.mean(((err - mean_bias) / (sigma_std if sigma_std>0 else 1))**4) - 3.0) if sigma_std>0 else np.nan

    # Gaussian parameters: mode as the centre, robust sigma with fallbacks.
    mu = mode_center
    sigma = sigma_rob if sigma_rob > 0 else (sigma_std if sigma_std > 0 else 1e-9)
    ci_low = mu - 1.96 * sigma; ci_high = mu + 1.96 * sigma

    pd.DataFrame({"bin_left": edges[:-1], "bin_right": edges[1:], "bin_center": centers, "count": counts}).to_csv(hist_csv_path, index=False)
    pd.DataFrame([{
        "mu_mode": mu, "sigma_robust": sigma,
        "mean_bias": mean_bias, "sigma_std": sigma_std,
        "ci95_low": ci_low, "ci95_high": ci_high,
        "skew": skew, "kurtosis_excess": kurt_excess,
        "n": int(len(err)), "bin_width": bin_w
    }]).to_csv(fit_csv_path, index=False)

    # Scale the density by n * bin_width so the curve is in histogram counts.
    x_grid = np.linspace(edges[0], edges[-1], 600)
    gauss_counts_grid = len(err) * bin_w * _gauss_pdf(x_grid, mu, sigma)

    plt.figure()
    plt.hist(err, bins=bins)
    plt.plot(x_grid, gauss_counts_grid, linewidth=2)
    txt = (f"Gaussian (mode-aligned)\n"
           f"μ≈{mu:.4g}, σ≈{sigma:.4g}, 95%≈[{ci_low:.4g},{ci_high:.4g}]\n"
           f"mean(bias)={mean_bias:.4g}, skew={skew:.3g}, kurt_ex={kurt_excess:.3g}")
    plt.text(0.02, 0.98, txt, transform=plt.gca().transAxes, va="top", ha="left",
             bbox=dict(boxstyle="round", fc="white", ec="0.7", alpha=0.9))
    plt.xlabel("Error (Pred - True)"); plt.ylabel("Count"); plt.title(f"Error Distribution: {target_name}")
    plt.tight_layout(); plt.savefig(img_path, dpi=180); _maybe_show()

# --- Group-aware learning curve ---
def plot_learning_and_csv_grouped(estimator, X, y, groups, img_path: Path, csv_path: Path, train_sizes):
    """Compute a GroupKFold learning curve, save it as CSV and as a plot.

    The fold count is CONFIG["cv"]["n_splits"], capped by the number of
    distinct groups and floored at 2. Scores come from RMSE_MEAN_SCORER and
    are negated back into RMSE.
    """
    group_arr = np.asarray(groups)
    distinct = len(np.unique(group_arr))
    folds = max(2, min(CONFIG["cv"]["n_splits"], distinct))

    sizes, fit_scores, val_scores = learning_curve(
        estimator, X, y,
        train_sizes=train_sizes,
        cv=GroupKFold(n_splits=folds),
        scoring=RMSE_MEAN_SCORER,
        groups=group_arr,
        n_jobs=CONFIG["cv"]["n_jobs"],
    )

    # Scorer returns negative RMSE; flip the sign of the means only.
    fit_rmse = -np.mean(fit_scores, axis=1)
    fit_spread = np.std(fit_scores, axis=1)
    val_rmse = -np.mean(val_scores, axis=1)
    val_spread = np.std(val_scores, axis=1)

    curve = pd.DataFrame({
        "train_size": sizes,
        "train_rmse_mean": fit_rmse,
        "train_rmse_std": fit_spread,
        "cv_rmse_mean": val_rmse,
        "cv_rmse_std": val_spread,
    })
    curve.to_csv(csv_path, index=False)

    fig, ax = plt.subplots()
    ax.plot(sizes, fit_rmse, marker="o", label="Train RMSE (mean)")
    ax.plot(sizes, val_rmse, marker="o", label="CV RMSE (mean)")
    ax.set_xlabel("Training Samples")
    ax.set_ylabel("RMSE")
    ax.legend()
    ax.set_title("Learning Curve (Group-aware)")
    fig.tight_layout()
    fig.savefig(img_path, dpi=180)
    _maybe_show()

# ============================
# ===== Group utilities  =====
# ============================
def make_group_ids(df: pd.DataFrame, c_col="concentration_mM", t_col="temperature_C"):
    """Build one "<C>|<T>" group label per row of *df*.

    When CONFIG["split"]["group_round_C"/"group_round_T"] is positive, the
    matching column is snapped to multiples of that step before labelling.
    Returns an object-dtype array of strings.
    """
    split_cfg = CONFIG["split"]
    step_c = float(split_cfg.get("group_round_C", 0) or 0)
    step_t = float(split_cfg.get("group_round_T", 0) or 0)

    def _snapped(column, step):
        # Coerce to numeric, then round to the nearest multiple of `step` if set.
        values = pd.to_numeric(df[column], errors="coerce").values
        return np.round(values / step) * step if step > 0 else values

    conc = _snapped(c_col, step_c)
    temp = _snapped(t_col, step_t)
    labels = [f"{c:.12g}|{t:.12g}" for c, t in zip(conc, temp)]
    return np.array(labels, dtype=object)

# ============================
# ===== Train / Predict ======
# ============================
# --- Group-aware randomized search ---
def random_search_grouped(estimator, param_distributions: dict, X, y, groups):
    """Run a RandomizedSearchCV with GroupKFold and return the fitted search.

    The fold count is CONFIG["cv"]["n_splits"], capped by the number of
    distinct groups and floored at 2. Budget, parallelism and verbosity come
    from CONFIG["cv"]; the seed comes from CONFIG["split"]["random_state"].
    """
    group_arr = np.asarray(groups)
    cv_cfg = CONFIG["cv"]
    folds = max(2, min(cv_cfg["n_splits"], len(np.unique(group_arr))))

    search = RandomizedSearchCV(
        estimator=estimator,
        param_distributions=param_distributions,
        n_iter=cv_cfg["n_iter"],
        scoring=RMSE_MEAN_SCORER,
        n_jobs=cv_cfg["n_jobs"],
        cv=GroupKFold(n_splits=folds),
        random_state=CONFIG["split"]["random_state"],
        verbose=cv_cfg["verbose"],
        refit=True,
        return_train_score=False,
    )
    # Groups must be passed at fit time for GroupKFold to keep them intact.
    search.fit(X, y, groups=group_arr)
    return search

class SymbolicMultiOutput(BaseEstimator, RegressorMixin):
    """Multi-target symbolic regression: one independent model per target column.

    Parameters:
        backend: "pysr", "gplearn" or "auto" (PySR if importable, else gplearn).
        n_iter: iteration budget; passed as `niterations` to PySR and mapped
            to max(10, n_iter // 10) generations for gplearn.
        random_state: seed forwarded to the backend.
    """

    def __init__(self, backend="auto", n_iter=200, random_state=42):
        self.backend = backend
        self.n_iter = n_iter
        self.random_state = random_state
        self.models_ = []

    def _make_single(self):
        """Build one unfitted single-target regressor for the chosen backend.

        Raises RuntimeError when the requested backend could not be imported
        (previously a NameError) or when no backend is available at all.
        """
        want_pysr = self.backend == "pysr" or (self.backend == "auto" and _HAVE_PYSR)
        want_gplearn = self.backend == "gplearn" or (self.backend == "auto" and _HAVE_GPLEARN)
        if want_pysr:
            if not _HAVE_PYSR:
                raise RuntimeError("Symbolic backend 'pysr' was requested but PySR could not be imported.")
            # NOTE(review): newer PySR releases may prefer `elementwise_loss`
            # over `loss`; confirm against the installed version.
            return PySRRegressor(
                niterations=self.n_iter, maxsize=20,
                unary_operators=["sin", "cos", "exp", "log"],
                binary_operators=["+", "-", "*", "/"],
                loss="loss(x, y) = (x - y)^2", random_state=self.random_state,
                progress=False, verbosity=0,
            )
        if want_gplearn:
            if not _HAVE_GPLEARN:
                raise RuntimeError("Symbolic backend 'gplearn' was requested but gplearn could not be imported.")
            return GPLearnSR(
                population_size=1000, generations=max(10, self.n_iter // 10),
                tournament_size=20, stopping_criteria=0.0, metric="mse",
                parsimony_coefficient=0.001, random_state=self.random_state,
                n_jobs=1, verbose=0,
            )
        raise RuntimeError("No symbolic regression backend available.")

    def fit(self, X, y):
        """Fit one backend model per column of *y* (a 1-D *y* counts as one column)."""
        y = np.asarray(y)
        if y.ndim == 1:
            y = y.reshape(-1, 1)
        self.models_ = []
        for j in range(y.shape[1]):
            m = self._make_single()
            m.fit(X, y[:, j])
            self.models_.append(m)
        return self

    def predict(self, X):
        """Return an (n_samples, n_targets) array of per-target predictions."""
        if not self.models_:
            # Same exception type np.hstack raised on an empty list, clearer message.
            raise ValueError("SymbolicMultiOutput is not fitted; call fit() first.")
        preds = [np.asarray(m.predict(X)).reshape(-1, 1) for m in self.models_]
        return np.hstack(preds)

def train_all():
    """Run the full training pipeline and write all artifacts to the output dir.

    Steps (mirroring the progress messages printed below):
      1. Build the dataset and save it as ``compiled_dataset.csv``.
      2. Split into train/test with ``GroupShuffleSplit`` on the ids returned
         by ``make_group_ids`` for (concentration, temperature).
      3. Fit every enabled model in ``CONFIG["models"]`` (group-aware
         randomized search, or a direct fit for the symbolic model), score it
         on the test split and save per-model predictions/artifacts.
      4. Write ``model_report.csv``, the best model, the feature/target
         column lists and ``models/model_list.txt``.
      5. Produce parity / residual / error-histogram plots (+CSVs) and,
         optionally, learning curves.
      6. Save the best model's test-set predictions.
      7. Write a train/test group-overlap diagnostic.

    Raises
    ------
    RuntimeError
        If no model ends up being trained (nothing enabled / all skipped).
    """
    np.random.seed(CONFIG["split"]["random_state"]); random.seed(CONFIG["split"]["random_state"])
    out_dir = Path(CONFIG["paths"]["output_dir"]); out_dir.mkdir(parents=True, exist_ok=True)
    plots_dir = out_dir / "plots"; plots_dir.mkdir(exist_ok=True)
    plots_data_dir = out_dir / "plots_data"; plots_data_dir.mkdir(exist_ok=True)
    models_dir = out_dir / "models"; models_dir.mkdir(exist_ok=True)

    img_dirs = _ensure_plot_dirs(plots_dir)
    csv_dirs = _ensure_data_dirs(plots_data_dir)

    # Snapshot of the configuration used for this run.
    with open(out_dir / "config_used.json", "w", encoding="utf-8") as f:
        json.dump(to_serializable(CONFIG), f, indent=2, ensure_ascii=False)

    print("[1/7] Building dataset …")
    data = build_dataset(Path(CONFIG["paths"]["input_root"]))
    data.to_csv(out_dir / "compiled_dataset.csv", index=False)
    print(f"   -> {len(data)} rows")

    feature_cols = CONFIG["features"]["use_features"]
    target_cols  = CONFIG["targets"]["target_columns"]
    X = data[feature_cols].values
    y = data[target_cols].values

    # --- Group-safe train/test split (by (C,T)) ---
    print("[2/7] Group-safe Train/Test split (by (C,T))")
    groups_all = make_group_ids(data, "concentration_mM", "temperature_C")
    gss = GroupShuffleSplit(n_splits=1, test_size=CONFIG["split"]["test_size"],
                            random_state=CONFIG["split"]["random_state"])
    train_idx, test_idx = next(gss.split(X, y, groups=groups_all))
    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]
    groups_train, groups_test = groups_all[train_idx], groups_all[test_idx]

    print("[3/7] Training models (group-aware CV) …")
    # Each task is (label, config key, polynomial degree or None).
    tasks = []
    for name, spec in CONFIG["models"].items():
        if not spec.get("enabled", False): continue
        if name == "symbolic":
            if not (_HAVE_PYSR or _HAVE_GPLEARN):
                tqdm.write("   -> symbolic: skipped (no backend found)")
                continue
            tasks.append(("symbolic", "symbolic", None))
        elif "pipeline_template" in spec:
            # Template pipelines are expanded once per configured polynomial degree.
            for deg in CONFIG["features"]["poly_degrees"]:
                tasks.append((f"{name}_deg{deg}", name, deg))
        else:
            tasks.append((name, name, None))

    model_rows = []; all_searches = {}
    saved_labels = []  # labels whose estimator was dumped to models_dir
    bar = tqdm(total=len(tasks), desc="Model families", unit="model") if CONFIG["progress"]["enable_bars"] else None
    best_name = None; best_score = None

    for label, name, deg in tasks:
        if name == "symbolic":
            sym_est = SymbolicMultiOutput(backend=CONFIG["models"]["symbolic"]["backend"],
                                          n_iter=CONFIG["models"]["symbolic"]["n_iter"],
                                          random_state=CONFIG["split"]["random_state"])
            sym_est.fit(X_train, y_train)
            # Minimal stand-in exposing the attributes used below on a fitted search.
            class DummySearch:
                best_estimator_ = sym_est
                best_params_ = {"backend": ("PySR" if _HAVE_PYSR else ("gplearn" if _HAVE_GPLEARN else "none")),
                                "n_iter": CONFIG["models"]["symbolic"]["n_iter"]}
                # Go through the attribute bound at class creation rather than
                # a closure over a loop variable, which later iterations would
                # rebind before steps 5/6 call predict() again.
                def predict(self, X): return self.best_estimator_.predict(X)
            search = DummySearch()
        else:
            if "pipeline_template" in CONFIG["models"][name]:
                pipe = clone(CONFIG["models"][name]["pipeline_template"])
                if deg is not None: pipe.set_params(poly__degree=deg)
                dist = CONFIG["models"][name]["param_distributions"]
            else:
                pipe = CONFIG["models"][name]["pipeline"]
                dist = CONFIG["models"][name]["param_distributions"]
            pipe, dist = maybe_wrap_target_normalization(pipe, dist)
            search = random_search_grouped(pipe, dist, X_train, y_train, groups_train)

        all_searches[label] = search

        y_pred = search.predict(X_test)
        metrics = evaluate_predictions(y_test, y_pred)
        metrics["model"] = label
        metrics["best_params"] = str(getattr(search, "best_params_", {}))
        model_rows.append(metrics)

        # Per-model test predictions; columns 0/1 of y are written as Z_real / Z_imag_neg.
        pd.DataFrame({
            "concentration_mM": data.iloc[test_idx]["concentration_mM"].values,
            "temperature_C":    data.iloc[test_idx]["temperature_C"].values,
            "frequency_Hz":     data.iloc[test_idx]["frequency_Hz"].values,
            "Z_real_true":      y_test[:,0], "Z_imag_neg_true": y_test[:,1],
            "Z_real_pred":      y_pred[:,0], "Z_imag_neg_pred": y_pred[:,1]
        }).to_csv(out_dir / f"test_predictions_{label}.csv", index=False)

        if CONFIG["outputs"]["save_all_model_artifacts"]:
            joblib.dump(getattr(search, "best_estimator_", search), models_dir / f"{label}.joblib")
            saved_labels.append(label)

        # Higher is better: negate the mean RMSE so the best model has the largest score.
        current_score = -metrics["RMSE_mean"]
        if (best_score is None) or (current_score > best_score):
            best_score = current_score; best_name = label

        if bar: bar.update(1)
    if bar: bar.close()

    if best_name is None:
        # Without this guard the report/lookup below fails with an opaque KeyError.
        raise RuntimeError("No models were trained; enable at least one entry in CONFIG['models'].")

    print("[4/7] Writing report")
    report_df = pd.DataFrame(model_rows).sort_values(by="RMSE_mean")
    report_df.to_csv(out_dir / "model_report.csv", index=False)

    best_search = all_searches[best_name]
    best_model = getattr(best_search, "best_estimator_", best_search)
    with open(out_dir / "feature_columns.json", "w", encoding="utf-8") as f: json.dump(feature_cols, f)
    with open(out_dir / "target_columns.json", "w", encoding="utf-8") as f: json.dump(target_cols, f)
    joblib.dump(best_model, out_dir / "best_model.joblib")
    with open(out_dir / "best_model_name.txt", "w") as f: f.write(best_name)
    if saved_labels:
        # evaluate_concentration_range_all_models() reads this index
        # (one label per line) to find the saved per-model artifacts.
        with open(models_dir / "model_list.txt", "w") as f:
            f.write("\n".join(saved_labels) + "\n")

    print("[5/7] Plots + CSVs for models")
    if CONFIG["plots"]["save_all_models"]:
        models_to_plot = list(report_df["model"])
    else:
        models_to_plot = list(report_df["model"].head(CONFIG["plots"]["top_k_models_to_plot"]))

    for label in models_to_plot:
        est = getattr(all_searches[label], "best_estimator_", all_searches[label])
        y_pred = all_searches[label].predict(X_test)
        for j, tname in enumerate(target_cols):
            parity_plot_and_csv(y_test[:, j], y_pred[:, j], tname,
                img_path = img_dirs["parity"] / f"{label}_parity_{tname}.png",
                csv_path = csv_dirs["parity"] / f"{label}_parity_{tname}.csv")
            residual_plot_and_csv(y_test[:, j], y_pred[:, j], tname,
                img_path = img_dirs["residuals"] / f"{label}_residual_{tname}.png",
                csv_path = csv_dirs["residuals"] / f"{label}_residual_{tname}.csv")
            error_hist_plot_and_csv(y_test[:, j], y_pred[:, j], tname,
                img_path = img_dirs["hist"] / f"{label}_errhist_{tname}.png",
                bins = 30,
                hist_csv_path = csv_dirs["hist"] / f"{label}_errhist_{tname}_bins.csv",
                fit_csv_path  = csv_dirs["hist"] / f"{label}_errhist_{tname}_fit.csv")
        if CONFIG["plots"]["make_learning_curve"]:
            plot_learning_and_csv_grouped(est, X_train, y_train, groups_train,
                img_path = img_dirs["learning"] / f"{label}_learning_curve.png",
                csv_path = csv_dirs["learning"] / f"{label}_learning_curve.csv",
                train_sizes = CONFIG["plots"]["learning_curve_train_sizes"])

    print("[6/7] Saving *best* model test-set predictions")
    y_best = all_searches[best_name].predict(X_test)
    pd.DataFrame({
        "concentration_mM": data.iloc[test_idx]["concentration_mM"].values,
        "temperature_C":    data.iloc[test_idx]["temperature_C"].values,
        "frequency_Hz":     data.iloc[test_idx]["frequency_Hz"].values,
        "Z_real_true":      y_test[:,0], "Z_imag_neg_true": y_test[:,1],
        "Z_real_pred":      y_best[:,0], "Z_imag_neg_pred": y_best[:,1]
    }).to_csv(out_dir / "test_predictions_best.csv", index=False)

    # Split diagnostics: a group-safe split should report zero overlap.
    ct_train = set(groups_train.tolist()); ct_test = set(groups_test.tolist())
    overlap = ct_train.intersection(ct_test)
    split_info_dir = plots_data_dir / "split_info"
    split_info_dir.mkdir(parents=True, exist_ok=True)  # do not rely on a helper having created it
    with open((split_info_dir / "ct_overlap_check.txt"), "w") as fh:
        fh.write(f"unique (C,T) in train: {len(ct_train)}\n")
        fh.write(f"unique (C,T) in test:  {len(ct_test)}\n")
        fh.write(f"overlap count:         {len(overlap)}\n")

    print(f"[7/7] Done. Best model: {best_name}")
    print(f"Outputs -> {out_dir.resolve()}")

def load_model_by_label(label: str):
    """Load a saved per-model artifact and the training feature column list.

    Parameters
    ----------
    label : str
        Model label; the artifact is read from ``<output_dir>/models/<label>.joblib``.

    Returns
    -------
    tuple
        ``(model, feats)`` where ``feats`` is the content of
        ``<output_dir>/feature_columns.json``.
    """
    out_dir = Path(CONFIG["paths"]["output_dir"]); models_dir = out_dir / "models"
    # NOTE: joblib.load unpickles the file; only load artifacts you produced yourself.
    model = joblib.load(models_dir / f"{label}.joblib")
    # Context manager closes the handle deterministically; the file is written as UTF-8.
    with open(out_dir / "feature_columns.json", encoding="utf-8") as fh:
        feats = json.load(fh)
    return model, feats

def load_best_model():
    """Load ``best_model.joblib`` and the training feature column list.

    Returns
    -------
    tuple
        ``(model, feats)`` where ``feats`` is the content of
        ``<output_dir>/feature_columns.json``.
    """
    out_dir = Path(CONFIG["paths"]["output_dir"])
    # NOTE: joblib.load unpickles the file; only load artifacts you produced yourself.
    model = joblib.load(out_dir / "best_model.joblib")
    # Context manager closes the handle deterministically; the file is written as UTF-8.
    with open(out_dir / "feature_columns.json", encoding="utf-8") as fh:
        feats = json.load(fh)
    return model, feats

def predict_one_with(label: str, conc_mM: float, temp_C: float, freq_Hz: float):
    """Predict a single point with the saved model ``label`` and print the result.

    The input row is ordered according to the saved feature column list when
    every saved feature is one of ``concentration_mM``, ``temperature_C`` or
    ``frequency_Hz``; otherwise the legacy fixed order
    (concentration, temperature, frequency) is used.

    Returns
    -------
    tuple
        The first two prediction columns for the single row, ``(y[0,0], y[0,1])``.
    """
    model, feats = load_model_by_label(label)
    raw_inputs = {
        "concentration_mM": conc_mM,
        "temperature_C": temp_C,
        "frequency_Hz": freq_Hz,
    }
    if feats and all(col in raw_inputs for col in feats):
        # Follow the column order the model was trained with.
        row = [raw_inputs[col] for col in feats]
    else:
        # Unknown/derived feature names: keep the original fixed ordering.
        row = [conc_mM, temp_C, freq_Hz]
    X = np.array([row], dtype=float)
    y = model.predict(X)
    print(f"[{label}] conc={conc_mM} mM, T={temp_C} °C, f={freq_Hz} Hz -> Z'={y[0,0]:.6g}, -Z''={y[0,1]:.6g}")
    return y[0,0], y[0,1]

def evaluate_concentration_range_all_models(min_c: float, max_c: float, tag: str = None):
    """Evaluate every saved model on all data within a concentration range.

    The dataset is rebuilt with the training filters temporarily cleared, then
    restricted to ``min_c <= concentration_mM <= max_c``. For each saved model
    the predictions, metrics and parity/residual/error-histogram plots (+CSVs)
    are written under ``<output_dir>/extrapolation_range/<tag>``.

    Parameters
    ----------
    min_c, max_c : float
        Inclusive concentration bounds (compared against ``concentration_mM``).
    tag : str, optional
        Sub-directory name; defaults to ``conc<min_c>_<max_c>``.

    Returns
    -------
    None
        Results are written to disk; a message is printed when there is
        nothing to evaluate.
    """
    out_dir = Path(CONFIG["paths"]["output_dir"])
    models_dir = out_dir / "models"
    model_list_path = models_dir / "model_list.txt"
    if model_list_path.exists():
        with open(model_list_path) as f:
            labels = [ln.strip() for ln in f if ln.strip()]
    elif models_dir.is_dir():
        # No index file: fall back to the artifacts themselves, which are
        # stored as "<label>.joblib" (see load_model_by_label).
        labels = sorted(p.stem for p in models_dir.glob("*.joblib"))
    else:
        labels = []
    if not labels:
        print("No saved models found. Train first.")
        return

    # Temporarily clear the training filters so the full dataset is built;
    # the original filters are restored even if build_dataset raises.
    original_filters = CONFIG["data"]["train_filters"].copy()
    CONFIG["data"]["train_filters"] = {"conc_mM_min": None, "conc_mM_max": None, "temp_C_min": None, "temp_C_max": None}
    try:
        data_all = build_dataset(Path(CONFIG["paths"]["input_root"]))
    finally:
        CONFIG["data"]["train_filters"] = original_filters

    m = (data_all["concentration_mM"] >= float(min_c)) & (data_all["concentration_mM"] <= float(max_c))
    eval_df = data_all.loc[m].copy()
    if eval_df.empty:
        print(f"No rows found in {min_c}-{max_c} mM.")
        return

    feat_cols = CONFIG["features"]["use_features"]
    target_cols = CONFIG["targets"]["target_columns"]
    X_eval = eval_df[feat_cols].astype(float).values
    y_true = eval_df[target_cols].astype(float).values

    base = out_dir / "extrapolation_range" / (tag or f"conc{min_c:g}_{max_c:g}")
    (base / "plots").mkdir(parents=True, exist_ok=True)
    (base / "plots_data").mkdir(parents=True, exist_ok=True)
    img_dirs = _ensure_plot_dirs(base / "plots")
    csv_dirs = _ensure_data_dirs(base / "plots_data")

    all_metrics = []
    for label in labels:
        # Best-effort: a model that cannot be loaded is reported and skipped.
        try:
            model, _ = load_model_by_label(label)
        except Exception as e:
            print(f"Skipping {label}: {e}"); continue

        y_pred = model.predict(X_eval)
        pd.DataFrame({
            "concentration_mM": eval_df["concentration_mM"].values,
            "temperature_C":    eval_df["temperature_C"].values,
            "frequency_Hz":     eval_df["frequency_Hz"].values,
            "Z_real_true":      y_true[:,0],
            "Z_imag_neg_true":  y_true[:,1],
            "Z_real_pred":      y_pred[:,0],
            "Z_imag_neg_pred":  y_pred[:,1],
        }).to_csv(base / f"predictions_{label}.csv", index=False)

        metrics = evaluate_predictions(y_true, y_pred); metrics["model"] = label
        all_metrics.append(metrics)

        for j, tname in enumerate(target_cols):
            parity_plot_and_csv(y_true[:, j], y_pred[:, j], f"{tname}",
                img_path = img_dirs["parity"] / f"{label}_parity_{tname}.png",
                csv_path = csv_dirs["parity"] / f"{label}_parity_{tname}.csv")
            residual_plot_and_csv(y_true[:, j], y_pred[:, j], f"{tname}",
                img_path = img_dirs["residuals"] / f"{label}_residual_{tname}.png",
                csv_path = csv_dirs["residuals"] / f"{label}_residual_{tname}.csv")
            error_hist_plot_and_csv(y_true[:, j], y_pred[:, j], f"{tname}",
                img_path = img_dirs["hist"] / f"{label}_errhist_{tname}.png",
                bins = 30,
                hist_csv_path = csv_dirs["hist"] / f"{label}_errhist_{tname}_bins.csv",
                fit_csv_path  = csv_dirs["hist"] / f"{label}_errhist_{tname}_fit.csv")

    if all_metrics:
        pd.DataFrame(all_metrics).sort_values(by="RMSE_mean").to_csv(base / "metrics_all_models.csv", index=False)
        print(f"Saved extrapolation results for {len(all_metrics)} models -> {base.resolve()}")

# ==============================
# ======= RUN TRAIN HERE =======
# ==============================
# Module-level switch: when True, the full training pipeline runs as soon as
# this cell/module is executed. Set to False to only define the functions
# (e.g. to use the load/predict helpers on previously saved outputs).
RUN_TRAIN = True
if RUN_TRAIN:
    train_all()
