# In [None]:
# -*- coding: utf-8 -*-
"""
Pea hyperspectral classification
Patch-level training + sample-level soft voting
Supports KMeans labels with 3 or 4 classes
"""

import os
import re
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

from scipy.signal import savgol_filter
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import PCA
from sklearn.model_selection import GroupKFold
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix

from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier, VotingClassifier
from xgboost import XGBClassifier

try:
    from lightgbm import LGBMClassifier
    HAS_LGBM = True
except Exception:
    HAS_LGBM = False


# =========================
# Configuration
# =========================
# Number of KMeans clusters used as class labels. It selects the label column
# "kmeans_k{N_CLASSES}_roman" in the label workbook and names the output folder.
N_CLASSES = 3  # set to 3 or 4

# Inputs: the patch table (identifier columns "patch_id" / "sample_id", every
# other column is read as a spectral band) and the per-sample label workbook.
PATCH_CSV = os.path.join("data", "pea_patch_dataset.csv")
LABEL_XLSX = os.path.join("data", "kmeans_labels.xlsx")

# All result files are written here; the directory is created as soon as this
# section runs.
OUT_DIR = os.path.join("results", f"kmeans_{N_CLASSES}class_patch_voting")
os.makedirs(OUT_DIR, exist_ok=True)

# Seed shared by PCA and every classifier; PCA_DIM is an upper bound on the
# number of principal components (it is capped by the data shape later on).
RANDOM_STATE = 42
PCA_DIM = 60

# Savitzky-Golay settings forwarded to sg_derivative() when USE_SG is true.
USE_SG = True
SG_WINDOW = 21
SG_POLY = 3
SG_DERIV = 1

# Number of GroupKFold splits; NumPy's global RNG is seeded as well.
N_SPLITS = 5
np.random.seed(RANDOM_STATE)


# =========================
# Utility functions
# =========================
def snv(X):
    """Apply standard normal variate scaling to every row of ``X``.

    Each row is centred on its own mean and divided by its own standard
    deviation. A tiny constant is added to the deviation so that a constant
    row yields zeros instead of a division by zero.

    Args:
        X: 2-D array with one spectrum per row.

    Returns:
        Array of the same shape holding the row-wise standardised spectra.
    """
    row_mean = X.mean(axis=1, keepdims=True)
    row_std = X.std(axis=1, keepdims=True)
    centred = X - row_mean
    return centred / (row_std + 1e-12)


def sg_derivative(X, window=21, poly=3, deriv=1):
    """Return the Savitzky-Golay derivative of every row of ``X``.

    The requested window is made odd, clamped to the largest odd length that
    fits in a spectrum, and raised to the smallest odd length that is at
    least 5 and strictly greater than ``poly`` (the filter requires
    ``polyorder < window_length``).

    Args:
        X: 2-D array with one spectrum per row.
        window: Requested filter window length in bands.
        poly: Order of the fitted polynomial.
        deriv: Order of the derivative to compute.

    Returns:
        Floating-point array with the same shape as ``X``. float32 input
        stays float32; every other dtype is promoted to float64 so that
        integer spectra are not truncated.

    Raises:
        ValueError: If the spectra have fewer bands than the smallest window
            usable with ``poly``.
    """
    X = np.asarray(X)
    # Work in floating point: writing derivatives back into an integer array
    # would silently truncate them.
    if X.dtype not in (np.float32, np.float64):
        X = X.astype(np.float64)

    n_bands = X.shape[1]

    if window % 2 == 0:
        window += 1
    largest_odd = n_bands if n_bands % 2 else n_bands - 1
    window = min(window, largest_odd)

    # Smallest odd window that keeps polyorder < window_length, never below 5.
    min_window = max(5, poly + 1 + poly % 2)
    window = max(window, min_window)

    if X.shape[0] == 0:
        return np.empty_like(X)

    if window > n_bands:
        raise ValueError(
            f"spectra have {n_bands} bands but a Savitzky-Golay window of "
            f"at least {window} is required for polyorder={poly}"
        )

    # One call filters every row along the band axis.
    return savgol_filter(
        X,
        window_length=window,
        polyorder=poly,
        deriv=deriv,
        axis=1
    )


def load_labels(excel_path, n_classes):
    """Read the per-sample KMeans labels from an Excel workbook.

    The sheet must contain a ``sample_id`` column and a
    ``kmeans_k{n_classes}_roman`` column (header matching ignores case and
    surrounding whitespace). Labels are Roman numerals ``I``..``IV``, matched
    case-insensitively. Rows with a missing id or label, or with a label
    outside that set, are skipped.

    Args:
        excel_path: Path of the ``.xlsx`` file.
        n_classes: Number of classes; selects the label column and bounds the
            accepted label values.

    Returns:
        Dict mapping integer sample id to integer class label (1-based).

    Raises:
        ValueError: If a required column is missing, if no row carries a
            usable label, or if a label is larger than ``n_classes``.
    """
    df = pd.read_excel(excel_path, engine="openpyxl")
    # Excel headers are not always strings (numbers, dates), so stringify
    # them before normalising.
    cols = [str(c).lower().strip() for c in df.columns]

    label_col = f"kmeans_k{n_classes}_roman"
    roman_map = {"i": 1, "ii": 2, "iii": 3, "iv": 4}

    if "sample_id" not in cols:
        raise ValueError("sample_id column not found")

    if label_col not in cols:
        raise ValueError(f"{label_col} column not found")

    # Map the normalised names back to the headers as spelled in the sheet.
    sid_col = df.columns[cols.index("sample_id")]
    lab_col = df.columns[cols.index(label_col)]

    tmp = df[[sid_col, lab_col]].dropna()
    sample_ids = tmp[sid_col].astype(int)
    labels = tmp[lab_col].astype(str).str.lower().str.strip().map(roman_map)

    # Labels that are not a known Roman numeral map to NaN and are skipped.
    usable = labels.notna()
    if not usable.any():
        raise ValueError(f"no valid labels found in {label_col}")
    sample_ids = sample_ids[usable]
    labels = labels[usable].astype(int)

    if labels.max() > n_classes:
        raise ValueError("label exceeds class number")

    return dict(zip(sample_ids, labels))


def aggregate_sample_probs(df_probs):
    """Collapse patch-level probabilities into one prediction per sample.

    Probability columns are those named ``p<digits>``. They are averaged
    within each ``sample_id`` (soft voting) and the position of the largest
    mean becomes the sample prediction. The sample's true label is the most
    frequent ``y_true`` among its patches.

    Args:
        df_probs: Frame with ``sample_id``, ``y_true`` (non-negative
            integers) and the ``p<digits>`` probability columns.

    Returns:
        Tuple ``(y_true, y_pred)`` of arrays, both ordered by sorted
        ``sample_id``.
    """
    prob_cols = []
    for col in df_probs.columns:
        if re.fullmatch(r"p\d+", col):
            prob_cols.append(col)

    by_sample = df_probs.groupby("sample_id")

    averaged = by_sample[prob_cols].mean()
    y_pred = np.argmax(averaged.values, axis=1)

    def _majority(labels):
        return np.bincount(labels).argmax()

    y_true = by_sample["y_true"].agg(_majority).values
    return y_true, y_pred


# =========================
# Load data
# =========================
df_patch = pd.read_csv(PATCH_CSV)

# Every column other than the two identifiers is treated as a spectral band.
spec_cols = [c for c in df_patch.columns if c not in ["patch_id", "sample_id"]]
df_patch[spec_cols] = df_patch[spec_cols].apply(pd.to_numeric, errors="coerce")

# errors="coerce" turns unparseable cells into NaN. A single NaN poisons the
# whole spectrum in SNV and makes PCA fail later, so such patches are left
# out of the modelling table (df_patch itself is not modified).
bad_rows = df_patch[spec_cols].isna().any(axis=1)
if bad_rows.any():
    print(
        "Dropping", int(bad_rows.sum()),
        "patches with missing or non-numeric spectral values"
    )

id2y = load_labels(LABEL_XLSX, N_CLASSES)

# Keep only clean patches whose sample has a label.
patch_sample_ids = df_patch["sample_id"].astype(int)
keep = patch_sample_ids.isin(id2y.keys()) & ~bad_rows
df = df_patch[keep].copy()
if df.empty:
    raise ValueError(
        "no usable patches: nothing is left after removing patches with "
        "invalid spectra and patches whose sample_id has no label"
    )
df["y"] = df["sample_id"].astype(int).map(id2y)

X_raw = df[spec_cols].values
y_raw = df["y"].values
# Patches of one sample share a group so they never straddle CV folds.
groups = df["sample_id"].values

# Encode the 1-based class labels as 0..K-1 for the classifiers.
le = LabelEncoder()
y_enc = le.fit_transform(y_raw)


# =========================
# Preprocessing
# =========================
# Row-wise spectral preprocessing: SNV, then an optional Savitzky-Golay
# derivative. Both act on each spectrum independently.
X = snv(X_raw)

if USE_SG:
    X = sg_derivative(X, window=SG_WINDOW, poly=SG_POLY, deriv=SG_DERIV)

# Column-wise standardisation followed by PCA.
# NOTE(review): both are fitted on every patch, so anything that later
# cross-validates on X_std / X_pca has already seen statistics of its
# held-out rows -- confirm that consumers refit per fold where it matters.
scaler = StandardScaler()
X_std = scaler.fit_transform(X)

# PCA cannot return more components than features or than (rows - 1).
pca_dim = min(PCA_DIM, X_std.shape[1], X_std.shape[0] - 1)
pca = PCA(n_components=pca_dim, random_state=RANDOM_STATE)
X_pca = pca.fit_transform(X_std)


# =========================
# Models
# =========================
# Candidate classifiers as (name, estimator) pairs. Every estimator shares
# RANDOM_STATE; the class_weight settings compensate for class imbalance.
models = [
    (
        "logreg",
        LogisticRegression(
            max_iter=6000,
            solver="saga",
            penalty="elasticnet",
            l1_ratio=0.5,
            C=2.0,
            class_weight="balanced",
            n_jobs=-1,
            random_state=RANDOM_STATE
        )
    ),
    (
        "rf",
        RandomForestClassifier(
            n_estimators=900,
            class_weight="balanced_subsample",
            n_jobs=-1,
            random_state=RANDOM_STATE
        )
    ),
    (
        "et",
        ExtraTreesClassifier(
            n_estimators=1200,
            class_weight="balanced_subsample",
            n_jobs=-1,
            random_state=RANDOM_STATE
        )
    ),
    (
        "xgb",
        XGBClassifier(
            n_estimators=900,
            max_depth=8,
            learning_rate=0.05,
            subsample=0.85,
            colsample_bytree=0.85,
            objective="multi:softprob",
            eval_metric="mlogloss",
            n_jobs=-1,
            random_state=RANDOM_STATE
        )
    ),
]

# LightGBM is optional; it is added only when the import succeeded.
if HAS_LGBM:
    models.append((
        "lgbm",
        LGBMClassifier(
            objective="multiclass",
            n_estimators=1200,
            num_leaves=63,
            learning_rate=0.05,
            subsample=0.85,
            colsample_bytree=0.85,
            class_weight="balanced",
            random_state=RANDOM_STATE
        )
    ))

# Give the ensemble its own copy of the list. The estimator keeps a reference
# to whatever list it receives, so passing `models` itself and then appending
# the voter to `models` would make the ensemble a member of itself and send
# cloning in fit() into endless recursion.
voter = VotingClassifier(estimators=list(models), voting="soft", n_jobs=-1)
models.append(("voting_soft", voter))


# =========================
# Cross-validation
# =========================
# Grouped K-fold: all patches of one sample stay in the same fold, so
# sample-level voting never mixes training and test patches.
gkf = GroupKFold(n_splits=N_SPLITS)
n_labels = len(le.classes_)

# Fit the scaler and PCA on the training rows of each fold only and reuse the
# projected folds for every model. Fitting them once on all rows would leak
# statistics of the held-out patches into training and bias the CV scores.
fold_data = []
for fold, (tr, te) in enumerate(gkf.split(X, y_enc, groups), 1):
    fold_scaler = StandardScaler().fit(X[tr])
    X_tr_std = fold_scaler.transform(X[tr])
    X_te_std = fold_scaler.transform(X[te])

    # PCA cannot return more components than features or (train rows - 1).
    fold_dim = min(PCA_DIM, X_tr_std.shape[1], X_tr_std.shape[0] - 1)
    fold_pca = PCA(n_components=fold_dim, random_state=RANDOM_STATE).fit(X_tr_std)

    fold_data.append(
        (fold, tr, te, fold_pca.transform(X_tr_std), fold_pca.transform(X_te_std))
    )

summary = []
best_score = (-1, -1)  # (macro F1, accuracy) of the best model so far
best_name = None
best_true = None
best_pred = None
oof_rows = []

for name, clf in models:
    sample_true_all = []
    sample_pred_all = []

    for fold, tr, te, X_tr, X_te in fold_data:
        clf.fit(X_tr, y_enc[tr])
        raw_proba = clf.predict_proba(X_te)

        # predict_proba columns follow clf.classes_. Spread them over all
        # encoded labels so column i always means label i, even when a label
        # is absent from this training fold.
        proba = np.zeros((raw_proba.shape[0], n_labels), dtype=raw_proba.dtype)
        proba[:, np.asarray(clf.classes_, dtype=int)] = raw_proba
        y_hat = np.argmax(proba, axis=1)

        df_fold = pd.DataFrame(proba, columns=[f"p{i}" for i in range(proba.shape[1])])
        df_fold["sample_id"] = groups[te]
        df_fold["y_true"] = y_enc[te]
        df_fold["y_pred"] = y_hat
        df_fold["model"] = name
        df_fold["fold"] = fold
        oof_rows.append(df_fold)

        # Soft voting: average patch probabilities within each sample.
        y_true_s, y_pred_s = aggregate_sample_probs(df_fold)
        sample_true_all.append(y_true_s)
        sample_pred_all.append(y_pred_s)

    y_true = np.concatenate(sample_true_all)
    y_pred = np.concatenate(sample_pred_all)

    acc = accuracy_score(y_true, y_pred)
    f1m = f1_score(y_true, y_pred, average="macro")

    summary.append({
        "model": name,
        "sample_acc": acc,
        "sample_f1": f1m
    })

    # Rank by macro F1 first, accuracy as the tie-breaker.
    if (f1m, acc) > best_score:
        best_score = (f1m, acc)
        best_name = name
        best_true = y_true
        best_pred = y_pred


# =========================
# Save results
# =========================
# Per-model sample-level scores, best model first.
summary_df = pd.DataFrame(summary)
summary_df = summary_df.sort_values(["sample_f1", "sample_acc"], ascending=False)
summary_df.to_csv(os.path.join(OUT_DIR, "model_comparison.csv"), index=False)

# Out-of-fold patch-level probabilities and predictions of every model.
oof_df = pd.concat(oof_rows, ignore_index=True)
oof_path = os.path.join(OUT_DIR, "oof_patch_predictions.csv")
oof_df.to_csv(oof_path, index=False)

# Decode the best model's sample-level labels back to the original classes
# once and reuse them for both reports.
best_true_labels = le.inverse_transform(best_true)
best_pred_labels = le.inverse_transform(best_pred)

report = classification_report(
    best_true_labels,
    best_pred_labels,
    output_dict=True,
    digits=4
)
report_path = os.path.join(OUT_DIR, f"{best_name}_classification_report.csv")
pd.DataFrame(report).T.to_csv(report_path)

# Rows and columns follow le.classes_; the CSV is written without labels.
cm = confusion_matrix(best_true_labels, best_pred_labels, labels=le.classes_)
cm_path = os.path.join(OUT_DIR, f"{best_name}_confusion_matrix.csv")
pd.DataFrame(cm).to_csv(cm_path, index=False)

# best_score is stored as (macro F1, accuracy).
best_f1, best_acc = best_score
print("Best model:", best_name)
print("Accuracy:", round(best_acc, 4))
print("Macro F1:", round(best_f1, 4))
print("Results saved to:", OUT_DIR)
