# In [None]:
# -*- coding: utf-8 -*-
"""
Predict protein content based on patch-level hyperspectral data,
with 5-fold grouped cross-validation performed at the sample level.

- Input: pea_patch_dataset.csv (patch_id, sample_id, quality attributes, spectra)
- Features: spectral columns ending with 'nm'
- Target: protein content (column detected automatically; English column names only)
- Preprocessing: row-wise SNV followed by column-wise standardization (StandardScaler)
- Model: PLSRegression
  - Outer CV: GroupKFold (5 folds, grouped by sample_id)
  - Inner CV: GroupKFold (3 folds) for selecting n_components
- Evaluation: patch-level predictions aggregated to sample-level
              (median by sample_id), reporting R² and RMSE
"""

import re
import numpy as np
import pandas as pd
from pathlib import Path

from sklearn.model_selection import GroupKFold
from sklearn.cross_decomposition import PLSRegression
from sklearn.preprocessing import StandardScaler, FunctionTransformer
from sklearn.pipeline import Pipeline
from sklearn.metrics import r2_score, root_mean_squared_error

# ---------------- Paths (repo-relative) ----------------
# Directory that holds the input CSV. When executed as a script this is the
# script's own directory; `__file__` is undefined when the code runs as a
# notebook cell, so fall back to the current working directory there.
try:
    HERE = Path(__file__).resolve().parent
except NameError:
    HERE = Path.cwd()
PATCH_CSV = HERE / "pea_patch_dataset.csv"

# NOTE: not used by the CV code below -- GroupKFold is constructed without
# shuffling, so the splits are deterministic regardless of this seed.
RANDOM_STATE = 42

# ---------------- Helper functions ----------------
def snv(X):
    """Apply the Standard Normal Variate transform to each row of ``X``.

    Every row (one spectrum) is shifted to zero mean and divided by its own
    population standard deviation. Rows whose standard deviation is exactly
    zero are divided by 1 instead, so they become all zeros rather than NaN.

    Parameters
    ----------
    X : array-like, 2-D
        Input converted to a float ndarray; statistics are taken along axis 1.

    Returns
    -------
    numpy.ndarray
        New array with the same shape as ``X``.
    """
    arr = np.asarray(X, dtype=float)
    centred = arr - arr.mean(axis=1, keepdims=True)
    spread = arr.std(axis=1, keepdims=True)
    # Guard against division by zero for perfectly flat spectra.
    spread = np.where(spread == 0, 1.0, spread)
    return centred / spread

def root_mse(y_true, y_pred):
    """Return the root-mean-squared error between ``y_true`` and ``y_pred``.

    Thin alias for ``sklearn.metrics.root_mean_squared_error`` so the rest of
    the script can call a short, local name.
    """
    rmse = root_mean_squared_error(y_true, y_pred)
    return rmse

def find_protein_col(columns):
    """
    Identify the protein content column among ``columns`` (English names only).

    Matching is case-insensitive and ignores surrounding whitespace. A column
    whose whole name is a protein label is preferred, for example:
    - protein
    - protein content / protein_content
    - crude protein / crude_protein
    - Protein (%) / protein%
    If no such column exists, the first column whose name merely contains
    "protein" is returned.

    Parameters
    ----------
    columns : iterable
        Column labels (e.g. ``df.columns``). Non-string labels are matched
        through ``str()``.

    Returns
    -------
    The matching label exactly as it appears in ``columns`` (not stripped or
    converted), so it can be used directly to index the DataFrame.

    Raises
    ------
    KeyError
        If no column name contains "protein".
    """
    pats = [
        r'^protein([\s_]*content)?[\s_]*(\(\s*%\s*\)|%)?$',
        r'^crude[\s_]*protein([\s_]*content)?[\s_]*(\(\s*%\s*\)|%)?$',
    ]
    # Pair each original label with its normalised form: the normalised form
    # is what we match on, the original label is what we return (returning
    # the stripped string would make df[label] fail for e.g. " Protein ").
    labelled = [(c, str(c).strip().lower()) for c in columns]

    for c, cl in labelled:
        if any(re.match(p, cl) for p in pats):
            return c

    for c, cl in labelled:
        if "protein" in cl:
            return c

    raise KeyError("Protein content column not found. Please check column names.")

def sample_level_metrics(y_true_patch, y_pred_patch, sample_ids_patch):
    """Collapse patch-level values to one value per sample and score them.

    Both the reference and the predicted values are reduced with the median
    over all patches sharing a ``sample_id``. Samples come out in
    ``groupby`` (sorted) order.

    Returns
    -------
    tuple
        ``(y_true_per_sample, y_pred_per_sample, r2, rmse)``, where the first
        two are numpy arrays and the metrics are computed on those arrays.
    """
    frame = pd.DataFrame({
        "sample_id": sample_ids_patch,
        "y_true":    y_true_patch,
        "y_pred":    y_pred_patch
    })
    per_sample = frame.groupby("sample_id")[["y_true", "y_pred"]].median()
    truth = per_sample["y_true"]
    pred = per_sample["y_pred"]
    return truth.values, pred.values, r2_score(truth, pred), root_mse(truth, pred)

# ---------------- Load patch-level data ----------------
df = pd.read_csv(PATCH_CSV)

# Spectral bands are the columns whose (whitespace-stripped) name ends in
# "nm"; they are put into wavelength order further below.
spec_cols = list(filter(lambda col: str(col).strip().endswith("nm"), df.columns))

def _w2f(c):
    try:
        return float(str(c).replace("nm", ""))
    except:
        return np.inf

spec_cols = sorted(spec_cols, key=_w2f)
if not spec_cols:
    # Fail early with a clear message instead of letting the PLS pipeline
    # fail later on a zero-feature matrix.
    raise ValueError("No spectral columns found (expected column names ending with 'nm').")

# Target column: protein
protein_col = find_protein_col(df.columns)
# Keep only patches that have both a protein value and a sample_id. A missing
# sample_id would otherwise become the string "nan" after .astype(str) below
# and silently pool unrelated patches into one fake sample for grouping and
# sample-level aggregation.
df = df.loc[df[protein_col].notna() & df["sample_id"].notna()].copy()

X_all = df[spec_cols].to_numpy(dtype=float)
y_all = df[protein_col].astype(float).to_numpy()
# Group labels are compared as strings.
groups_all = df["sample_id"].astype(str).to_numpy()

print(
    f"Number of patches: {len(df)} | "
    f"Number of samples: {len(np.unique(groups_all))} | "
    f"Number of bands: {len(spec_cols)}"
)
print(f"Protein target column: {protein_col}")

# ---------------- Outer 5-fold GroupKFold ----------------
N_OUTER_SPLITS = 5    # outer GroupKFold folds (grouped by sample_id)
N_INNER_SPLITS = 3    # inner GroupKFold folds, capped by available groups
MAX_COMPONENTS = 20   # upper bound on PLS components searched


def build_pls(a):
    """Return an unfitted SNV -> StandardScaler -> PLSRegression pipeline.

    Parameters
    ----------
    a : int
        Number of PLS latent components.
    """
    return Pipeline([
        ("snv",    FunctionTransformer(snv, validate=False)),
        ("scaler", StandardScaler(with_mean=True, with_std=True)),
        ("pls",    PLSRegression(n_components=a)),
    ])


def _select_n_components(X_tr, y_tr, g_tr):
    """Choose the number of PLS components by grouped inner CV on one training fold.

    Each candidate ``a`` in ``1..max_comp`` is scored by the mean sample-level
    R² over the inner folds (ties keep the smaller ``a``). Inner folds whose
    R² is not finite (e.g. a validation fold containing a single sample) are
    ignored. If no candidate has any finite R², the candidate with the lowest
    mean sample-level RMSE is used instead.

    Returns
    -------
    int
        Selected number of components (always >= 1).

    Raises
    ------
    RuntimeError
        If the fold is too small to fit even a one-component model.
    """
    n_groups = len(np.unique(g_tr))
    max_comp = int(min(MAX_COMPONENTS, X_tr.shape[1], n_groups - 1, X_tr.shape[0] - 1))
    if max_comp < 1:
        raise RuntimeError("Insufficient training data for PLS fitting.")

    # GroupKFold needs at least n_splits distinct groups; n_groups >= 2 here.
    # The split does not depend on `a`, so compute it once.
    inner_cv = GroupKFold(n_splits=min(N_INNER_SPLITS, n_groups))
    inner_splits = list(inner_cv.split(X_tr, y_tr, groups=g_tr))

    best_a, best_r2 = None, -np.inf
    fallback_a, best_rmse = 1, np.inf
    for a in range(1, max_comp + 1):
        mdl = build_pls(a)
        r2s, rmses = [], []
        for tri, vai in inner_splits:
            mdl.fit(X_tr[tri], y_tr[tri])
            pred_va = mdl.predict(X_tr[vai]).ravel()
            _, _, r2_va, rmse_va = sample_level_metrics(y_tr[vai], pred_va, g_tr[vai])
            r2s.append(r2_va)
            rmses.append(rmse_va)

        # A single NaN fold would make the plain mean NaN, and `NaN > x` is
        # always False -- so such a candidate could never be selected.
        finite_r2 = [s for s in r2s if np.isfinite(s)]
        if finite_r2:
            mean_r2 = float(np.mean(finite_r2))
            if mean_r2 > best_r2:
                best_r2, best_a = mean_r2, a

        mean_rmse = float(np.mean(rmses))
        if mean_rmse < best_rmse:
            best_rmse, fallback_a = mean_rmse, a

    return best_a if best_a is not None else fallback_a


outer_cv = GroupKFold(n_splits=N_OUTER_SPLITS)
fold_metrics = []                          # (sample-level R², RMSE, n_components) per fold
cv_preds_sample, cv_truth_sample = [], []  # pooled out-of-fold sample-level values

for fold_id, (tr_idx, te_idx) in enumerate(
    outer_cv.split(X_all, y_all, groups=groups_all), start=1
):
    X_tr, X_te = X_all[tr_idx], X_all[te_idx]
    y_tr, y_te = y_all[tr_idx], y_all[te_idx]
    g_tr, g_te = groups_all[tr_idx], groups_all[te_idx]

    # ---- Inner CV: select optimal n_components ----
    best_a = _select_n_components(X_tr, y_tr, g_tr)

    # ---- Final training and evaluation ----
    final_model = build_pls(best_a)
    final_model.fit(X_tr, y_tr)
    y_pred_te_patch = final_model.predict(X_te).ravel()

    y_true_s, y_pred_s, r2_s, rmse_s = sample_level_metrics(y_te, y_pred_te_patch, g_te)
    fold_metrics.append((r2_s, rmse_s, best_a))

    cv_preds_sample.extend(y_pred_s)
    cv_truth_sample.extend(y_true_s)

    print(
        f"[Fold {fold_id}] best n_components = {best_a:2d} | "
        f"sample-level R²={r2_s:.4f} RMSE={rmse_s:.4f}"
    )

# ---------------- Summary ----------------
# Unzip the per-fold (R², RMSE, n_components) tuples into parallel arrays.
r2_folds, rmse_folds, best_as = (np.array(col) for col in zip(*fold_metrics))

# Spread is numpy's default population std (ddof=0) across folds.
print("\n===== 5-fold CV (sample-level) results =====")
print(f"R²  : mean={r2_folds.mean():.4f} ± {r2_folds.std():.4f}")
print(f"RMSE: mean={rmse_folds.mean():.4f} ± {rmse_folds.std():.4f}")
print(f"Selected n_components per fold: {best_as.tolist()}")

# Pooled metrics over every out-of-fold sample-level prediction.
r2_overall, rmse_overall = (
    r2_score(cv_truth_sample, cv_preds_sample),
    root_mse(cv_truth_sample, cv_preds_sample),
)
print(f"\nOverall CV on samples -> R²={r2_overall:.4f} RMSE={rmse_overall:.4f}")
