# In [None]:
"""
01_RGC_synthetic_data_generation.ipynb

Rank–Gaussian Copula (RGC) synthetic data generator for geothermal hydrogeochemistry.
"""

# =========================
# 0) Imports & basic I/O
# =========================
import os
import shutil
import zipfile
import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import norm, rankdata, ks_2samp
from scipy.spatial.distance import cdist

from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, roc_curve

# Global seed for reproducibility: the legacy NumPy global RNG is seeded here,
# and the explicit seeds passed to the generators further down are derived
# from this same value.
GLOBAL_SEED = 42
np.random.seed(seed=GLOBAL_SEED)

# -------------------------
# 1) Data loading
# -------------------------

# Users can adjust this path to their local data layout.
DATA_DIR = "data"
TRAIN_CSV = os.path.join(DATA_DIR, "training_dataset.csv")
TEST_CSV  = os.path.join(DATA_DIR, "testing_dataset.csv")  # not required for RGC

# Expected columns in training_dataset.csv:
FEATURES = [
    "pH",
    "EC (microS/cm)",
    "K (mg/l)",
    "Na (mg/l)",
    "Boron (mg/l)",
    "SiO2 (mg/l)",
    "Cl (mg/l)",
    "Reservoir temperature (°C)",
]

train = pd.read_csv(TRAIN_CSV)

# Fail early with a readable message if the file does not have the schema
# the rest of the pipeline relies on.
_missing_cols = [c for c in FEATURES if c not in train.columns]
if _missing_cols:
    raise KeyError(f"{TRAIN_CSV} is missing expected columns: {_missing_cols}")

# Only the training set is used for RGC fitting. Rows with any missing feature
# are dropped: the rank transform, covariance estimate and KS tests used
# downstream all need complete numeric rows (a single NaN would otherwise turn
# the latent covariance into NaN and break the copula fit).
df = train[FEATURES].dropna().copy()
_n_dropped = len(train) - len(df)
if _n_dropped:
    print(f"Dropped {_n_dropped} row(s) with missing feature values.")

# -------------------------
# 2) Configuration
# -------------------------
# All tunable settings of the RGC pipeline, gathered in one dictionary.
CONFIG = dict(
    # --- stochastic generation ---------------------------------------------
    n_runs=10,                 # number of independent stochastic RGC runs
    n_samples=1000,            # synthetic rows drawn per run
    # --- temperature conditioning ------------------------------------------
    temp_conditional=True,     # fit one model per reservoir-temperature bin
    n_temp_bins=4,             # number of temperature quantile bins
    # --- marginals ----------------------------------------------------------
    # heavy-tailed columns whose empirical marginals are built on log1p scale
    log_cols=[
        "EC (microS/cm)",
        "K (mg/l)",
        "Na (mg/l)",
        "Boron (mg/l)",
        "SiO2 (mg/l)",
    ],
    zero_mass_threshold=0.02,  # zero share above which a column is zero-inflated
    # --- copula -------------------------------------------------------------
    cov_eps=1e-3,              # shrinkage of the latent matrix toward identity
    # --- plotting -----------------------------------------------------------
    bins=15,                   # histogram bin count
    # --- winsorization ------------------------------------------------------
    winsor_cols=["Na (mg/l)", "Cl (mg/l)"],
    winsor_lower_q=0.0,        # lower quantile (0.0 = no lower-side truncation)
    winsor_upper_q=0.99,       # upper quantile (clips the top 1%)
    # --- outputs ------------------------------------------------------------
    out_dir="rgc_outputs",
    energy_max_synth=300,      # max synthetic rows used for the energy distance
    make_zip=True,             # also bundle all outputs into a ZIP archive
)

# Prepare output directories: any previous contents of out_dir are removed,
# then out_dir and its "figs" sub-directory are (re)created.
figs_dir = os.path.join(CONFIG["out_dir"], "figs")
if os.path.exists(CONFIG["out_dir"]):
    shutil.rmtree(CONFIG["out_dir"])
# makedirs creates the intermediate out_dir as well
os.makedirs(figs_dir, exist_ok=True)

# =========================
# 3) Helper functions
# =========================
def _to_log(x, col):
    """
    Return ``log1p(max(x, 0))`` when ``col`` is one of CONFIG["log_cols"];
    for any other column, return ``x`` itself unchanged.

    ``x`` must support pandas-style ``clip(lower=...)`` (e.g. a Series).
    """
    if col not in CONFIG["log_cols"]:
        return x
    return np.log1p(x.clip(lower=0.0))


def _from_log(x, col):
    """
    Map values back from log1p scale with ``expm1`` when ``col`` is one of
    CONFIG["log_cols"]; any other column is returned unchanged.

    Negative values clipped to 0 by ``_to_log`` are not recovered.
    """
    if col not in CONFIG["log_cols"]:
        return x
    return np.expm1(x)


def _empirical_ppf(u, data_sorted):
    """
    Empirical quantile interpolation for u in (0, 1),
    using sorted sample values.
    """
    n = len(data_sorted)
    x = np.interp(u, (np.arange(1, n + 1) - 0.5) / n, data_sorted)
    return x


def fit_marginal(x, col, zero_thr=0.02):
    """
    Fit an empirical marginal distribution for one column.

    Missing values are dropped first. If the share of exact zeros exceeds
    ``zero_thr``, the column is treated as zero-inflated: that share is
    stored as a point mass at zero, and the continuous part is built from
    the remaining nonzero values. The continuous part is stored sorted and,
    for columns in CONFIG["log_cols"], on log1p scale (via ``_to_log``).

    Parameters
    ----------
    x : pandas.Series
        Raw column values.
    col : str
        Column name (selects whether the log transform is applied).
    zero_thr : float, default 0.02
        Zero share above which the zero-inflated treatment is used.

    Returns
    -------
    dict
        Keys ``col``, ``zero_inflated`` (bool), ``zero_mass`` (float; 0.0
        when not zero-inflated), ``x_sorted_tr`` (sorted, possibly
        log-transformed continuous part), and ``min`` / ``max`` of the
        observed non-missing values.

    Raises
    ------
    ValueError
        If ``x`` contains no non-missing values.
    """
    x = x.dropna().to_numpy(dtype=float)
    if x.size == 0:
        raise ValueError(f"fit_marginal: column {col!r} has no non-missing values")

    zero_mass = float((x == 0).mean())
    zero_inflated = zero_mass > zero_thr
    if zero_inflated:
        # Use every nonzero value for the continuous part; filtering on x > 0
        # would silently drop negative observations here, whereas the
        # non-zero-inflated branch keeps them.
        x_cont = x[x != 0]
        if x_cont.size == 0:
            # degenerate case: the column is all zeros (zero_mass == 1.0)
            x_cont = np.array([0.0])
    else:
        zero_mass = 0.0
        x_cont = x

    x_tr_sorted = np.sort(_to_log(pd.Series(x_cont), col).to_numpy())

    return {
        "col": col,
        "zero_inflated": bool(zero_inflated),
        "zero_mass": float(zero_mass),
        "x_sorted_tr": x_tr_sorted,
        "min": float(np.min(x)),
        "max": float(np.max(x)),
    }


def sample_marginal(u, marg):
    """
    Turn uniform draws ``u`` into values from the fitted marginal ``marg``
    (a dict as returned by ``fit_marginal``).

    For a zero-inflated marginal, draws with ``u < zero_mass`` become exact
    zeros, and the rest are rescaled onto (0, 1) for the continuous part.
    The continuous part is read off the empirical quantile function, mapped
    back from log scale where applicable, and clipped to [min, max].
    """
    lo, hi = 1e-8, 1 - 1e-8

    if not marg["zero_inflated"]:
        zero_draw = np.zeros_like(u, dtype=bool)
        u_cont = np.clip(u, lo, hi)
    else:
        p0 = marg["zero_mass"]
        zero_draw = u < p0
        # spread the remaining probability mass over the nonzero part
        u_cont = np.clip((u - p0) / max(1e-8, (1 - p0)), lo, hi)

    on_tr_scale = _empirical_ppf(u_cont, marg["x_sorted_tr"])
    values = np.clip(_from_log(on_tr_scale, marg["col"]), marg["min"], marg["max"])
    values[zero_draw] = 0.0
    return values


def gaussianize(df_block, marginals=None):
    """
    Column-wise rank-normal scores of ``df_block``.

    Each column is ranked (ties share their average rank), mapped to
    u = (rank - 0.5) / n, clipped to [1e-6, 1 - 1e-6], and sent through the
    standard normal inverse CDF. ``marginals`` is accepted but not used.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``df_block``.
    """
    n_rows, n_cols = df_block.shape
    scores = np.zeros((n_rows, n_cols), dtype=float)
    for idx, name in enumerate(df_block.columns):
        values = df_block[name].to_numpy(dtype=float)
        ranks = rankdata(values, method="average")
        probs = np.clip((ranks - 0.5) / len(values), 1e-6, 1 - 1e-6)
        scores[:, idx] = norm.ppf(probs)
    return scores


def fit_copula(df_block, marginals, eps=1e-3):
    """
    Fit a Gaussian copula to ``df_block`` in rank-normal (Z) space.

    A Gaussian copula is parameterised by a correlation matrix: every latent
    coordinate must be standard normal so that ``norm.cdf`` maps it back to
    a uniform margin. Rank-based normal scores have a sample variance below
    one (ties and finite sample size shrink it). Using their raw covariance
    would therefore under-disperse the sampled margins, so it is first
    standardised to a correlation matrix, then ridge-regularised.

    Parameters
    ----------
    df_block : pandas.DataFrame
        Data whose dependence structure is fitted.
    marginals : list
        Forwarded to ``gaussianize`` (which does not use it).
    eps : float, default 1e-3
        Shrinkage toward the identity: ``C = (1 - eps) * R + eps * I``;
        keeps ``C`` positive definite.

    Returns
    -------
    mu : numpy.ndarray, shape (d,)
        Latent mean (a zero vector for a copula).
    C : numpy.ndarray, shape (d, d)
        Regularised correlation matrix (unit diagonal).
    """
    Z = gaussianize(df_block, marginals)
    n_obs, n_dim = Z.shape
    mu = np.zeros(n_dim)

    if n_obs < 2:
        # no covariance can be estimated from fewer than two rows
        R = np.eye(n_dim)
    else:
        # atleast_2d: np.cov returns a 0-d array for a single feature
        S = np.atleast_2d(np.cov(Z, rowvar=False))
        sd = np.sqrt(np.diag(S))
        # a constant column has zero variance; give it unit scale so it ends
        # up uncorrelated with the others instead of producing 0/0
        sd[sd == 0] = 1.0
        R = S / np.outer(sd, sd)
        np.fill_diagonal(R, 1.0)
        R = np.clip(R, -1.0, 1.0)  # guard against round-off beyond [-1, 1]

    C = (1 - eps) * R + eps * np.eye(n_dim)
    return mu, C


def sample_copula(n, mu, C, marginals, random_state=None):
    """
    Draw ``n`` rows from a Gaussian copula and map them to data space.

    Latent vectors ~ N(mu, C) are pushed through the standard normal CDF.
    Column j of the resulting uniforms is then transformed with
    ``sample_marginal`` using ``marginals[j]``. Columns without a
    corresponding marginal stay at 0.0.

    Returns
    -------
    numpy.ndarray
        Array of shape (n, len(mu)).
    """
    generator = np.random.default_rng(random_state)
    latent = generator.multivariate_normal(mu, C, n)
    uniforms = norm.cdf(latent)

    out = np.zeros_like(uniforms)
    for col_idx, marg in enumerate(marginals):
        out[:, col_idx] = sample_marginal(uniforms[:, col_idx], marg)
    return out


def ks_report(real_df, syn_df):
    """
    Two-sample Kolmogorov–Smirnov test for every column of ``real_df``.

    Returns a DataFrame with one row per column and the fields
    ``feature``, ``ks_stat`` and ``p_value``.
    """
    records = []
    for name in real_df.columns:
        stat_val, p_val = ks_2samp(real_df[name], syn_df[name])
        records.append(dict(feature=name, ks_stat=stat_val, p_value=p_val))
    return pd.DataFrame(records)


def plot_hist_pair(real, syn, col, outpath, bins=15):
    """
    Save overlaid density histograms (plus KDE curves when they can be
    drawn) of ``col`` for the real and synthetic data.

    Both histograms share one set of bin edges, computed from the pooled
    finite values of both samples. Without this, each histogram would pick
    its own edges and the bar heights would not be directly comparable.

    Parameters
    ----------
    real, syn : pandas.DataFrame
        Real and synthetic tables containing ``col``.
    col : str
        Column to plot.
    outpath : str
        Path of the image file to write.
    bins : int, str or sequence, default 15
        Any value accepted by ``numpy.histogram_bin_edges``.
    """
    import warnings

    pooled = pd.concat([real[col], syn[col]], ignore_index=True).to_numpy(dtype=float)
    pooled = pooled[np.isfinite(pooled)]
    edges = np.histogram_bin_edges(pooled, bins=bins) if pooled.size else bins

    series = (
        (real[col], "#4C72B0", "Real"),
        (syn[col], "#DD8452", "Synthetic"),
    )

    fig = plt.figure(figsize=(10, 6))
    try:
        for data, colour, label in series:
            sns.histplot(
                data,
                bins=edges,
                stat="density",
                color=colour,
                alpha=0.45,
                edgecolor="black",
                label=label,
            )
        # KDEs are optional decoration: a failure for one sample must not
        # prevent the other curve or the figure itself from being produced.
        for data, colour, label in series:
            try:
                sns.kdeplot(data, color=colour, lw=2)
            except Exception as err:
                warnings.warn(f"KDE skipped for {label} '{col}': {err}")
        plt.title(f"{col} — Real vs Synthetic")
        plt.legend()
        plt.tight_layout()
        plt.savefig(outpath, dpi=160)
    finally:
        plt.close(fig)  # never leak the figure, even if saving fails


def corr_heatmap(df_in, title, outpath):
    """
    Draw the pairwise correlation matrix of ``df_in`` (pandas ``corr()``
    defaults) as a heatmap on a fixed [-1, 1] colour scale, and save it
    to ``outpath``.
    """
    style = dict(
        vmin=-1,
        vmax=1,
        cmap="coolwarm",
        square=True,
        cbar_kws={"shrink": 0.8},
    )
    plt.figure(figsize=(8, 6))
    sns.heatmap(df_in.corr(), **style)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(outpath, dpi=160)
    plt.close()


def scatter_real_vs_synth(real, syn, xcol, ycol, outpath):
    """
    Save a scatter plot of ``ycol`` against ``xcol``, with the real points
    drawn first and the synthetic points overlaid, to ``outpath``.
    """
    marker_style = dict(s=40, alpha=0.7, edgecolors="none")

    plt.figure(figsize=(8, 6))
    for frame, name in ((real, "Real"), (syn, "Synthetic")):
        plt.scatter(frame[xcol], frame[ycol], label=name, **marker_style)
    plt.xlabel(xcol)
    plt.ylabel(ycol)
    plt.legend()
    plt.title(f"{ycol} vs {xcol} — Real vs Synthetic")
    plt.tight_layout()
    plt.savefig(outpath, dpi=160)
    plt.close()


def energy_distance(X, Y, n_syn_max=300, random_state=0):
    """
    Energy distance 2·E|X−Y| − E|X−X'| − E|Y−Y'| between two samples.

    The expectations are plain means over all pairwise Euclidean distances,
    including the zero self-distances on the diagonals (V-statistic form).
    If ``Y`` has more than ``n_syn_max`` rows, a random subset of
    ``n_syn_max`` rows (drawn without replacement) is used to bound the
    cost. Returned as a scalar diagnostic; no p-value is computed.
    """
    gen = np.random.default_rng(random_state)
    n_y = Y.shape[0]
    if n_y > n_syn_max:
        Y = Y[gen.choice(n_y, size=n_syn_max, replace=False)]

    cross = cdist(X, Y).mean()
    within_x = cdist(X, X).mean()
    within_y = cdist(Y, Y).mean()
    return 2 * cross - within_x - within_y


# =========================
# 4) Conditional (temperature-bin) generator + winsorization
# =========================
def apply_winsorization(df_in):
    """
    Return a copy of ``df_in`` in which each column in CONFIG["winsor_cols"]
    is clipped to its own [winsor_lower_q, winsor_upper_q] empirical
    quantiles. Listed columns absent from ``df_in`` are ignored.
    """
    out = df_in.copy()
    for name in CONFIG["winsor_cols"]:
        if name not in out.columns:
            continue
        column = out[name]
        lower = column.quantile(CONFIG["winsor_lower_q"])
        upper = column.quantile(CONFIG["winsor_upper_q"])
        out[name] = column.clip(lower=lower, upper=upper)
    return out


def _allocate_counts(weights, n_total):
    """Split ``n_total`` into per-bin integer counts proportional to
    ``weights`` that sum exactly to ``n_total`` (largest-remainder rounding)."""
    raw = np.asarray(weights, dtype=float) * n_total
    counts = np.floor(raw).astype(int)
    shortfall = int(n_total - counts.sum())
    if shortfall > 0:
        # hand the leftover samples to the bins with the largest remainders
        order = np.argsort(-(raw - counts), kind="stable")
        counts[order[:shortfall]] += 1
    return counts


def fit_and_sample(df_real, n_samples, seed=0):
    """
    Fit the RGC model (optionally conditioned on temperature bins) and draw
    exactly ``n_samples`` synthetic rows.

    The data are first winsorized (``apply_winsorization``).

    If CONFIG["temp_conditional"] is true:
      - rows are split into quantile bins of reservoir temperature, computed
        from the unwinsorized ``df_real``;
      - each non-empty bin gets its own marginals and Gaussian copula;
      - the sample budget is split across bins in proportion to their size,
        with largest-remainder rounding so the total is exactly ``n_samples``;
      - the pooled rows are shuffled.

    Otherwise a single global model is fitted.

    Parameters
    ----------
    df_real : pandas.DataFrame
        Real data containing all FEATURES columns.
    n_samples : int
        Number of synthetic rows to return.
    seed : int, default 0
        Seed for all random draws made here.

    Returns
    -------
    pandas.DataFrame
        ``n_samples`` rows with columns FEATURES, each clipped to the
        observed min–max range of ``df_real``.

    Raises
    ------
    ValueError
        In conditional mode, if no row has a usable reservoir temperature.
    """
    rng = np.random.default_rng(seed)
    df_proc = apply_winsorization(df_real)
    temp_col = "Reservoir temperature (°C)"

    if CONFIG["temp_conditional"]:
        # Quantile-based bin edges from the real data. The outer edges are
        # opened to +/-inf so every finite temperature falls into some bin.
        q = np.linspace(0, 1, CONFIG["n_temp_bins"] + 1)
        # explicit copy: under pandas Copy-on-Write, .values may be read-only
        edges = df_real[temp_col].quantile(q).to_numpy(dtype=float, copy=True)
        edges[0] = -np.inf
        edges[-1] = np.inf

        blocks = []
        for lo, hi in zip(edges[:-1], edges[1:]):
            mask = (df_real[temp_col] > lo) & (df_real[temp_col] <= hi)
            block_proc = df_proc.loc[mask, FEATURES]
            if len(block_proc) > 0:  # duplicate quantile edges give empty bins
                blocks.append(block_proc)
        if not blocks:
            raise ValueError(
                "fit_and_sample: no rows with a valid reservoir temperature"
            )

        sizes = np.array([len(b) for b in blocks], dtype=float)
        # Rounding n_samples * weight independently per bin can give a total
        # above or below n_samples; allocate exactly instead.
        k_per_bin = _allocate_counts(sizes / sizes.sum(), int(n_samples))

        pieces = []
        for block, k in zip(blocks, k_per_bin):
            if k < 1:
                continue  # this bin receives no samples: skip fitting it
            margs = [
                fit_marginal(block[c], c, CONFIG["zero_mass_threshold"])
                for c in FEATURES
            ]
            mu, C = fit_copula(block, margs, eps=CONFIG["cov_eps"])
            X = sample_copula(
                int(k),
                mu,
                C,
                margs,
                random_state=rng.integers(1_000_000_000),
            )
            pieces.append(pd.DataFrame(X, columns=FEATURES))

        if pieces:
            syn = pd.concat(pieces, ignore_index=True)
            # shuffle so the rows are not grouped by temperature bin
            syn = syn.sample(
                n=len(syn),
                random_state=rng.integers(1_000_000_000),
            ).reset_index(drop=True)
        else:  # only possible when n_samples == 0
            syn = pd.DataFrame(columns=FEATURES, dtype=float)
    else:
        # Single global RGC model
        df_proc_full = df_proc[FEATURES]
        margs = [
            fit_marginal(df_proc_full[c], c, CONFIG["zero_mass_threshold"])
            for c in FEATURES
        ]
        mu, C = fit_copula(df_proc_full, margs, eps=CONFIG["cov_eps"])
        X = sample_copula(n_samples, mu, C, margs, random_state=seed)
        syn = pd.DataFrame(X, columns=FEATURES)

    # Clip all synthetic variables to the observed real min–max range
    for c in FEATURES:
        mn, mx = df_real[c].min(), df_real[c].max()
        syn[c] = syn[c].clip(lower=mn, upper=mx)
    return syn


# =========================
# 5) Multiple runs and selection of the best RGC realisation
# =========================
# Each run refits and resamples with its own seed. The run with the smallest
# mean per-feature KS statistic (the first one, on ties) is kept in `best`.
summary_rows = []
best = {"avg_ks": np.inf, "run": None, "syn": None, "ks_df": None}

for run in range(1, CONFIG["n_runs"] + 1):
    run_seed = GLOBAL_SEED + 37 * run
    syn = fit_and_sample(df, CONFIG["n_samples"], seed=run_seed)
    ks_df = ks_report(df, syn)
    avg_ks = ks_df["ks_stat"].mean()
    summary_rows.append(dict(run=run, avg_ks=avg_ks))
    print(f"Run {run:02d}: avg KS = {avg_ks:.4f}")
    if not avg_ks < best["avg_ks"]:
        continue
    best = dict(avg_ks=avg_ks, run=run, syn=syn.copy(), ks_df=ks_df.copy())

summary = pd.DataFrame(summary_rows).sort_values("avg_ks")

print(
    "\nBest run (before pH shift):",
    best["run"],
    "| Average KS:",
    round(best["avg_ks"], 4),
)

# =========================
# 6) pH mean-shift correction (post-hoc)
# =========================
# Shift the best realisation's pH so its mean equals the real mean, then clip
# back into the observed pH range (clipping can leave a small residual
# offset). This happens after run selection, so only the best run is
# re-scored.
real_pH_mean = df["pH"].mean()
synth_pH_mean = best["syn"]["pH"].mean()
delta_pH = real_pH_mean - synth_pH_mean

pH_min, pH_max = df["pH"].min(), df["pH"].max()
best["syn"]["pH"] = (best["syn"]["pH"] + delta_pH).clip(lower=pH_min, upper=pH_max)

# Recompute KS statistics and average KS after pH correction
best["ks_df"] = ks_report(df, best["syn"])
best["avg_ks"] = best["ks_df"]["ks_stat"].mean()

# The best run's row now holds a post-shift score while every other row is
# pre-shift: flag that explicitly, and re-sort so the table stays ordered by
# avg_ks (the shift can move the best run's score in either direction).
summary["pH_shifted"] = summary["run"] == best["run"]
summary.loc[summary["pH_shifted"], "avg_ks"] = best["avg_ks"]
summary = summary.sort_values("avg_ks")

print(
    "Best run (after pH shift):",
    best["run"],
    "| New average KS:",
    round(best["avg_ks"], 4),
)

# Save CSV outputs: run summary, per-feature KS table of the best run, and the
# best run's synthetic data (all without the DataFrame index).
for frame, fname in (
    (summary, "summary_runs.csv"),
    (best["ks_df"], "ks_results.csv"),
    (best["syn"], "synthetic_rgc_train_only.csv"),
):
    frame.to_csv(os.path.join(CONFIG["out_dir"], fname), index=False)

# =========================
# 7) Plots (histograms, correlations, scatter diagrams)
# =========================

# One histogram/KDE figure per feature. "/" (as in "mg/l") is replaced so it
# is not interpreted as a directory separator in the file name.
for c in FEATURES:
    outpath = os.path.join(figs_dir, "{}.png".format(c.replace("/", "_")))
    plot_hist_pair(df, best["syn"], c, outpath, bins=CONFIG["bins"])

# Correlation heatmaps for the real and the synthetic table.
for frame, title, fname in (
    (df, "Real Correlation", "corr_real.png"),
    (best["syn"], "Synthetic Correlation", "corr_synth.png"),
):
    corr_heatmap(frame, title, os.path.join(figs_dir, fname))

# Bivariate scatter plots to assess joint structure: (x column, y column, file).
scatter_specs = [
    ("Cl (mg/l)", "Na (mg/l)", "scatter_Na_vs_Cl.png"),
    ("Reservoir temperature (°C)", "EC (microS/cm)", "scatter_EC_vs_T.png"),
    ("Reservoir temperature (°C)", "Boron (mg/l)", "scatter_B_vs_T.png"),
]
for xcol, ycol, fname in scatter_specs:
    scatter_real_vs_synth(
        df,
        best["syn"],
        xcol=xcol,
        ycol=ycol,
        outpath=os.path.join(figs_dir, fname),
    )

# =========================
# 8) Multivariate diagnostics in latent Z-space
# =========================

df_real = df[FEATURES].copy()
df_syn = best["syn"][FEATURES].copy()

# --- 8.1 Joint rank-Gaussianised (Z) space ---
# Real and synthetic rows are stacked and rank-transformed together, so both
# samples are scored on one common scale. The result is then split back by
# row position (real rows first).
df_all = pd.concat([df_real, df_syn], axis=0, ignore_index=True)
Z_all = gaussianize(df_all[FEATURES])
n_real = len(df_real)
Z_real, Z_syn = Z_all[:n_real, :], Z_all[n_real:, :]

# Per-feature means and population variances (ddof=0), and the Frobenius
# distance between the two sample covariance matrices (np.cov, ddof=1).
mu_real, mu_syn = Z_real.mean(axis=0), Z_syn.mean(axis=0)
var_real, var_syn = Z_real.var(axis=0), Z_syn.var(axis=0)
C_real, C_syn = np.cov(Z_real.T), np.cov(Z_syn.T)
fro_norm_diff = np.linalg.norm(C_syn - C_real, ord="fro")

latent_stats = pd.DataFrame(
    dict(
        feature=FEATURES,
        mu_real=mu_real,
        mu_synth=mu_syn,
        var_real=var_real,
        var_synth=var_syn,
    )
)
latent_stats_path = os.path.join(CONFIG["out_dir"], "latent_stats_Z_space.csv")
latent_stats.to_csv(latent_stats_path, index=False)

# 2-D PCA fitted on the real rows only; synthetic rows are projected onto the
# same axes.
pca = PCA(n_components=2, random_state=GLOBAL_SEED).fit(Z_real)
Zr_2d, Zs_2d = pca.transform(Z_real), pca.transform(Z_syn)

pca_fig_path = os.path.join(figs_dir, "pca_Z_real_vs_synth.png")
plt.figure(figsize=(8, 6))
for pts, lbl in ((Zr_2d, "Real"), (Zs_2d, "Synthetic")):
    plt.scatter(pts[:, 0], pts[:, 1], alpha=0.7, label=lbl, s=40)
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.title("PCA in latent Z space — Real vs Synthetic")
plt.legend()
plt.tight_layout()
plt.savefig(pca_fig_path, dpi=160)
plt.close()

# --- 8.2 Energy distance in the joint Z-space (one axis per feature) ---
ed = energy_distance(
    Z_real,
    Z_syn,
    n_syn_max=CONFIG["energy_max_synth"],
    random_state=GLOBAL_SEED,
)

# --- 8.3 Classifier two-sample test (Random Forest, ROC-AUC) ---
# Real rows are labelled 0 and synthetic rows 1. A held-out ROC-AUC near 0.5
# means the forest cannot separate the two samples in Z-space; values close
# to 1 indicate detectable differences.
X_clf = np.vstack([Z_real, Z_syn])
y_clf = np.concatenate(
    [np.zeros(Z_real.shape[0]), np.ones(Z_syn.shape[0])]
)

# Stratified 70/30 split keeps the real/synthetic ratio in both parts.
X_tr, X_te, y_tr, y_te = train_test_split(
    X_clf,
    y_clf,
    test_size=0.3,
    random_state=GLOBAL_SEED,
    stratify=y_clf,
)

# class_weight="balanced" compensates for the (likely unequal) numbers of
# real and synthetic rows during training.
rf = RandomForestClassifier(
    n_estimators=200,
    max_depth=None,
    random_state=GLOBAL_SEED,
    n_jobs=-1,
    class_weight="balanced",
)
rf.fit(X_tr, y_tr)
proba_te = rf.predict_proba(X_te)[:, 1]
auc = roc_auc_score(y_te, proba_te)

fpr, tpr, _ = roc_curve(y_te, proba_te)

plt.figure(figsize=(6, 6))
plt.plot(fpr, tpr, label=f"RF AUC = {auc:.3f}")
plt.plot([0, 1], [0, 1], "k--", label="Random (AUC=0.5)")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Real vs Synthetic classifier — ROC curve")
plt.legend()
plt.tight_layout()
roc_fig_path = os.path.join(figs_dir, "roc_real_vs_synth_RF.png")
plt.savefig(roc_fig_path, dpi=160)
plt.close()

# The dimensionality label is derived from the data (one latent axis per
# feature) rather than hard-coded, so it stays correct if FEATURES changes.
latent_dim = Z_real.shape[1]
two_sample_metrics = pd.DataFrame(
    [
        {
            "space": f"latent_Z_joint_{latent_dim}D",
            "energy_distance": ed,
            "cov_frobenius_diff": fro_norm_diff,
            "rf_roc_auc": auc,
            "n_real": Z_real.shape[0],
            "n_synth": Z_syn.shape[0],
        }
    ]
)
two_sample_path = os.path.join(
    CONFIG["out_dir"], "two_sample_multivariate_metrics.csv"
)
two_sample_metrics.to_csv(two_sample_path, index=False)

print("\n--- Multivariate diagnostics ---")
print(f"Energy distance (Z space): {ed:.4f}")
print(f"Frobenius norm ||C_syn - C_real||: {fro_norm_diff:.4f}")
print(f"Random Forest ROC-AUC (real vs synth, Z space): {auc:.4f}")

# =========================
# 9) Optional: archive outputs as ZIP
# =========================
if CONFIG["make_zip"]:
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    zip_name = f"rgc_outputs_{timestamp}.zip"
    # The archive is written next to out_dir rather than inside it
    # (presumably so it survives the out_dir wipe on the next run).
    zip_path = os.path.join(CONFIG["out_dir"], "..", zip_name)

    csv_files = [
        "summary_runs.csv",
        "ks_results.csv",
        "synthetic_rgc_train_only.csv",
        "latent_stats_Z_space.csv",
        "two_sample_multivariate_metrics.csv",
    ]

    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        # root-level CSV tables; files that were not produced are skipped
        for fn in csv_files:
            full_path = os.path.join(CONFIG["out_dir"], fn)
            if not os.path.exists(full_path):
                continue
            zf.write(full_path, arcname=fn)

        # every file of the figures directory, stored under "figs/"
        fig_names = os.listdir(figs_dir) if os.path.exists(figs_dir) else []
        for fn in fig_names:
            zf.write(
                os.path.join(figs_dir, fn),
                arcname=os.path.join("figs", fn),
            )

    print(f"\nOutput archive created: {zip_path}")

print(f"All outputs have been written to '{CONFIG['out_dir']}'.")