# In [None]:
# -*- coding: utf-8 -*-
"""
Patch-level training with sample-level soft voting (3 classes)

- Input:
  Patch CSV with columns:
    ['patch_id', 'sample_id', <spectral features ...>]

  Excel file with class columns:
    I / II / III -> sample_id lists

- Preprocessing:
  SNV (+ optional Savitzky-Golay first derivative)
  StandardScaler -> PCA

- Models:
  Logistic Regression (ElasticNet)
  Random Forest
  Extra Trees
  XGBoost
  Optional LightGBM
  Soft Voting Ensemble

- Evaluation:
  GroupKFold (by sample_id, no leakage)

- Output:
  Patch-level and sample-level Accuracy / Macro-F1
  Confusion matrices and CSV reports
"""

import os
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

from scipy.signal import savgol_filter
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import PCA
from sklearn.model_selection import GroupKFold
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix

from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier, VotingClassifier
from xgboost import XGBClassifier

# Optional LightGBM
try:
    from lightgbm import LGBMClassifier
    HAS_LGBM = True
except Exception:
    HAS_LGBM = False


# ===================== Paths (relative) =====================
# All inputs and outputs are resolved relative to the directory holding this
# script. `__file__` is not defined when the code is executed as a notebook
# cell or in an interactive session, so fall back to the current working
# directory instead of failing with NameError.
try:
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    BASE_DIR = os.getcwd()

PATCH_CSV = os.path.join(BASE_DIR, "data", "pea_patch_dataset.csv")   # patch-level spectra
EXCEL_3CLS = os.path.join(BASE_DIR, "data", "pea_cluster.xlsx")       # class -> sample_id lists
OUT_DIR = os.path.join(BASE_DIR, "results_patch_sample_voting")       # CSV reports go here

# Create the output directory up front so the report writers can assume it exists.
os.makedirs(OUT_DIR, exist_ok=True)


# ===================== Global parameters =====================
RANDOM_STATE = 42   # seed passed to PCA and the tree/boosting models, and to NumPy below
PCA_DIM = 60        # upper bound on PCA components (capped by the number of features)
USE_SG = True       # apply the Savitzky-Golay derivative after SNV
SG_WINDOW = 21      # Savitzky-Golay window length passed to sg_derivative
SG_POLY = 3         # Savitzky-Golay polynomial order
SG_DERIV = 1        # derivative order computed by the Savitzky-Golay filter
N_SPLITS = 5        # number of GroupKFold folds

# Seed NumPy's global RNG so components without an explicit seed stay repeatable.
np.random.seed(RANDOM_STATE)


# ===================== Utilities =====================
def snv(X):
    """
    Standard Normal Variate: centre and scale every row of X independently.

    Each row has its own mean subtracted and is divided by its own standard
    deviation; a 1e-12 offset keeps constant rows from dividing by zero.
    Returns an array with the same shape as X.
    """
    centered = X - X.mean(axis=1, keepdims=True)
    row_scale = X.std(axis=1, keepdims=True) + 1e-12
    return centered / row_scale


def sg_derivative(X, window=21, poly=3, deriv=1):
    """
    Apply a Savitzky-Golay derivative filter along every row of X.

    Parameters
    ----------
    X : 2-D array-like
        One spectrum per row; rows are filtered independently.
    window : int
        Requested window length. It is made odd, clamped to the largest odd
        value that fits in the number of columns, and never set below 5.
    poly : int
        Polynomial order of the filter (must be smaller than the window).
    deriv : int
        Order of the derivative to compute.

    Returns
    -------
    numpy.ndarray
        Floating-point array with the same shape as X.

    Raises
    ------
    ValueError
        If the final window does not fit into the number of columns, or if
        SciPy rejects the window / polynomial-order combination.
    """
    X = np.asarray(X)
    n_bands = X.shape[1]

    if window % 2 == 0:
        window += 1
    # Largest odd window that still fits: n_bands if odd, n_bands - 1 if even.
    window = min(window, n_bands - (1 - n_bands % 2))
    if window < 5:
        window = 5

    if window > n_bands:
        raise ValueError(
            f"Savitzky-Golay window ({window}) exceeds the number of "
            f"spectral bands ({n_bands})"
        )

    if X.shape[0] == 0:
        # Nothing to filter; keep the shape but hand back a float array.
        return X.astype(float)

    # Filtering the whole matrix along axis=1 replaces the per-row Python loop.
    # The result is always floating point, so integer-valued spectra are no
    # longer truncated by writing into a buffer of the input's dtype.
    return savgol_filter(
        X,
        window_length=window,
        polyorder=poly,
        deriv=deriv,
        axis=1
    )


def _find_class_column(columns, numeral):
    """Return the header in *columns* that denotes class *numeral*, or None."""
    # An exact header always wins.
    if numeral in columns:
        return numeral
    # Otherwise accept a header that starts with the numeral, provided the
    # numeral is not just the beginning of a longer one: "I" must not match
    # "II" / "III", and "II" must not match "III".
    for col in columns:
        if col.startswith(numeral) and not col[len(numeral):].startswith("I"):
            return col
    return None


def load_labels_from_excel(excel_path):
    """
    Read the class-assignment sheet and build a sample_id -> class map.

    The sheet must contain one column per class whose header is (or starts
    with) I, II and III; headers are compared after stripping whitespace and
    upper-casing. Each column lists the sample ids of that class. Cells that
    are not numeric are skipped.

    Parameters
    ----------
    excel_path : str
        Path of the .xlsx file (read with the openpyxl engine).

    Returns
    -------
    dict
        Maps each integer sample id to 1, 2 or 3 (for I, II, III). If an id
        is listed under several classes, the last class column wins.

    Raises
    ------
    ValueError
        If any of the three class columns cannot be identified.
    """
    df = pd.read_excel(excel_path, engine="openpyxl")
    df.columns = [str(c).strip().upper() for c in df.columns]

    headers = list(df.columns)
    class_columns = {
        numeral: _find_class_column(headers, numeral)
        for numeral in ("I", "II", "III")
    }
    missing = [numeral for numeral, col in class_columns.items() if col is None]
    if missing:
        # Raise explicitly: an assert would vanish under `python -O`.
        raise ValueError(
            "Expected 3 class columns (I / II / III); "
            f"missing: {', '.join(missing)}"
        )

    id_to_label = {}
    for lab, col in enumerate(class_columns.values(), start=1):
        # Columns have different lengths, so shorter ones are padded with
        # NaN; coerce + dropna discards the padding and any stray text.
        ids = pd.to_numeric(df[col], errors="coerce").dropna().astype(int).tolist()
        for sid in ids:
            id_to_label[sid] = lab

    return id_to_label


def aggregate_sample_probs(df_probs):
    """
    Collapse patch-level probabilities into one prediction per sample.

    df_probs needs a 'sample_id' column, a non-negative integer 'y_true'
    column and probability columns whose names start with "p". For every
    sample the true label is the most frequent patch label and the predicted
    label is the position of the largest mean probability column.

    Returns (true_labels, predicted_labels) as arrays ordered by sample_id.
    """
    per_sample = df_probs.groupby("sample_id")

    # Majority vote over the patch labels of each sample.
    sample_truth = per_sample["y_true"].agg(
        lambda labels: np.bincount(labels).argmax()
    )

    # Soft voting: average the patch probabilities, then take the best class.
    prob_cols = list(filter(lambda col: col.startswith("p"), df_probs.columns))
    averaged = per_sample[prob_cols].mean()
    sample_guess = averaged.values.argmax(axis=1)

    return sample_truth.values, sample_guess


# ===================== Load data =====================
print("[INFO] Loading patch-level CSV...")
df = pd.read_csv(PATCH_CSV)

# Every column that is not an identifier is treated as a spectral band.
non_spec = ["patch_id", "sample_id"]
spec_cols = [col for col in df.columns if col not in non_spec]
df[spec_cols] = df[spec_cols].astype(float)

print(f"[INFO] Patches: {len(df)}, Bands: {len(spec_cols)}, Samples: {df['sample_id'].nunique()}")

# Keep only patches whose sample has a class assignment in the Excel sheet,
# then attach that class (1 / 2 / 3) as column "y".
label_map = load_labels_from_excel(EXCEL_3CLS)
has_label = df["sample_id"].astype(int).isin(label_map.keys())
df = df[has_label].copy()
df["y"] = df["sample_id"].astype(int).map(label_map)

X_raw = df[spec_cols].to_numpy()
y_raw = df["y"].to_numpy()
groups = df["sample_id"].to_numpy()   # grouping key for GroupKFold

print(f"[INFO] After filtering: patches={len(df)}, class distribution={np.bincount(y_raw)[1:]}")

# Re-encode the 1-based class ids as 0-based indices for the classifiers.
le = LabelEncoder()
y_enc = le.fit_transform(y_raw)


# ===================== Preprocessing =====================
print("[INFO] Preprocessing...")

# Row-wise steps: SNV, optionally followed by the Savitzky-Golay derivative.
X = snv(X_raw)
if USE_SG:
    X = sg_derivative(X, window=SG_WINDOW, poly=SG_POLY, deriv=SG_DERIV)

# Column-wise steps: standardise each band, then project with PCA.
# NOTE(review): the scaler and the PCA are fitted on all patches at once, so
# anything evaluated on X_pca has already seen the held-out patches' statistics.
scaler = StandardScaler()
X_std = scaler.fit_transform(X)

n_components = min(PCA_DIM, X_std.shape[1])   # cannot exceed the feature count
pca = PCA(n_components=n_components, random_state=RANDOM_STATE)
X_pca = pca.fit_transform(X_std)

print(f"[INFO] PCA shape: {X_pca.shape}")


# ===================== Models =====================
# (name, estimator) pairs; the same list feeds the soft-voting ensemble below.
estimators = [
    (
        "logreg",
        LogisticRegression(
            max_iter=5000,
            solver="saga",
            penalty="elasticnet",
            l1_ratio=0.5,
            C=2.0,
            class_weight="balanced",
            n_jobs=-1,
            # saga shuffles the data; seed it explicitly like every other
            # model instead of relying on the global NumPy RNG state.
            random_state=RANDOM_STATE
        )
    ),
    (
        "rf",
        RandomForestClassifier(
            n_estimators=800,
            class_weight="balanced_subsample",
            n_jobs=-1,
            random_state=RANDOM_STATE
        )
    ),
    (
        "et",
        ExtraTreesClassifier(
            n_estimators=1000,
            class_weight="balanced_subsample",
            n_jobs=-1,
            random_state=RANDOM_STATE
        )
    ),
    (
        "xgb",
        XGBClassifier(
            n_estimators=800,
            max_depth=8,
            learning_rate=0.05,
            subsample=0.85,
            colsample_bytree=0.85,
            objective="multi:softprob",
            eval_metric="mlogloss",
            n_jobs=-1,
            random_state=RANDOM_STATE
        )
    ),
]

# LightGBM only joins the line-up when its import succeeded.
if HAS_LGBM:
    estimators.append((
        "lgbm",
        LGBMClassifier(
            objective="multiclass",
            n_estimators=1000,
            learning_rate=0.05,
            class_weight="balanced",
            random_state=RANDOM_STATE
        )
    ))

# Soft voting averages the class probabilities of all base estimators.
voter = VotingClassifier(
    estimators=estimators,
    voting="soft",
    n_jobs=-1
)

# Evaluate every base model on its own, plus the ensemble.
models = estimators + [("voting_soft", voter)]


# ===================== GroupKFold evaluation =====================
# Folds are grouped by sample_id, so all patches of a sample stay together.
gkf = GroupKFold(n_splits=N_SPLITS)
n_classes = len(le.classes_)

# Fit StandardScaler + PCA on the training patches of each fold only. The
# globally fitted X_pca was computed from every patch, including those that
# are later held out, so evaluating on it would leak test-fold statistics
# into the features. X (the SNV / Savitzky-Golay output) is safe to share
# because those steps work on one spectrum at a time.
# The splits and projections do not depend on the classifier, so they are
# computed once here and reused for every model.
fold_data = []
for tr, te in gkf.split(X, y_enc, groups):
    fold_scaler = StandardScaler().fit(X[tr])
    fold_pca = PCA(
        # PCA cannot return more components than features or training rows.
        n_components=min(PCA_DIM, X.shape[1], len(tr)),
        random_state=RANDOM_STATE
    )
    X_tr = fold_pca.fit_transform(fold_scaler.transform(X[tr]))
    X_te = fold_pca.transform(fold_scaler.transform(X[te]))
    fold_data.append((tr, te, X_tr, X_te))

summary = []
best_name = None
best_scores = (-1, -1)   # (sample macro-F1, sample accuracy) of the best model so far
best_true = None
best_pred = None

for name, clf in models:
    print(f"\n[MODEL] {name}")
    patch_true, patch_pred = [], []
    sample_true, sample_pred = [], []

    for fold, (tr, te, X_tr, X_te) in enumerate(fold_data, 1):
        clf.fit(X_tr, y_enc[tr])

        # predict_proba columns follow clf.classes_. Scatter them into a
        # matrix with one column per encoded class so that column i always
        # means class i, even if a class is absent from a training fold.
        proba = np.zeros((len(te), n_classes))
        proba[:, np.asarray(clf.classes_, dtype=int)] = clf.predict_proba(X_te)
        yhat = proba.argmax(axis=1)

        patch_true.append(y_enc[te])
        patch_pred.append(yhat)

        # Sample-level soft voting: average patch probabilities per sample.
        df_fold = pd.DataFrame(
            proba, columns=[f"p{i}" for i in range(proba.shape[1])]
        )
        df_fold["sample_id"] = groups[te]
        df_fold["y_true"] = y_enc[te]

        yt_s, yp_s = aggregate_sample_probs(df_fold)
        sample_true.append(yt_s)
        sample_pred.append(yp_s)

        acc = accuracy_score(yt_s, yp_s)
        f1m = f1_score(yt_s, yp_s, average="macro")
        print(f"  Fold {fold}: sample-ACC={acc:.4f}, Macro-F1={f1m:.4f}")

    # Pool the out-of-fold predictions of all folds before scoring.
    ypt = np.concatenate(patch_true)
    ypp = np.concatenate(patch_pred)
    acc_p = accuracy_score(ypt, ypp)
    f1_p = f1_score(ypt, ypp, average="macro")

    yst = np.concatenate(sample_true)
    ysp = np.concatenate(sample_pred)
    acc_s = accuracy_score(yst, ysp)
    f1_s = f1_score(yst, ysp, average="macro")

    summary.append({
        "model": name,
        "patch_acc": acc_p,
        "patch_macro_f1": f1_p,
        "sample_acc": acc_s,
        "sample_macro_f1": f1_s
    })

    # Rank by sample-level macro-F1, breaking ties with sample-level accuracy.
    if (f1_s, acc_s) > best_scores:
        best_scores = (f1_s, acc_s)
        best_name = name
        best_true = yst
        best_pred = ysp


# ===================== Save results =====================
# Model comparison table, best model (by sample macro-F1, then accuracy) first.
ranking_keys = ["sample_macro_f1", "sample_acc"]
summary_df = pd.DataFrame(summary).sort_values(ranking_keys, ascending=False)
summary_path = os.path.join(OUT_DIR, "model_comparison.csv")
summary_df.to_csv(summary_path, index=False)

# Map the encoded sample-level labels of the best model back to class ids.
y_true_lab = le.inverse_transform(best_true)
y_pred_lab = le.inverse_transform(best_pred)

# Per-class precision / recall / F1, one row per class or average.
report_dict = classification_report(
    y_true_lab, y_pred_lab, digits=4, output_dict=True
)
report_df = pd.DataFrame(report_dict).T

# Confusion matrix with rows = true class and columns = predicted class.
cm = confusion_matrix(y_true_lab, y_pred_lab, labels=le.classes_)
row_names = [f"True_{c}" for c in le.classes_]
col_names = [f"Pred_{c}" for c in le.classes_]
cm_df = pd.DataFrame(cm, index=row_names, columns=col_names)

report_path = os.path.join(OUT_DIR, f"{best_name}_classification_report.csv")
cm_path = os.path.join(OUT_DIR, f"{best_name}_confusion_matrix.csv")
report_df.to_csv(report_path)
cm_df.to_csv(cm_path)

print(f"\n[BEST MODEL] {best_name}")
print(f"Sample ACC={best_scores[1]:.4f}, Macro-F1={best_scores[0]:.4f}")
print(f"Results saved to: {OUT_DIR}")
