# In [None]:
# -*- coding: utf-8 -*-

import numpy as np
import pandas as pd
from pathlib import Path
from collections import Counter
from scipy.io import loadmat

from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.model_selection import StratifiedKFold, cross_val_score, cross_val_predict
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay

from xgboost import XGBClassifier
import matplotlib.pyplot as plt


# ===================== paths (relative) =====================
# All inputs and outputs are located relative to the directory that contains
# this script. ``__file__`` is undefined when the code is executed
# interactively (for example pasted into a notebook cell), so fall back to the
# current working directory there instead of failing with NameError.
try:
    ROOT = Path(__file__).resolve().parent
except NameError:
    ROOT = Path.cwd()
DATA_DIR = ROOT / "data"
EXCEL_PATH = DATA_DIR / "pea_cluster.xlsx"  # workbook holding the class columns
MAT_DIR = DATA_DIR / "mat"  # directory with prep_sample*.mat / prep_mask*.mat
# Set to a sheet name to bypass the keyword-based sheet auto-detection.
SHEET_NAME_EXPLICIT = None


# ===================== helpers =====================
def pick_moisture_sheet(excel_path: Path, explicit: str | None = None) -> str:
    if explicit:
        return explicit
    xls = pd.ExcelFile(excel_path)
    keys = ["moisture", "moist", "water content"]
    for s in xls.sheet_names:
        low = str(s).strip().lower()
        if any(k in low for k in keys):
            return s
    raise ValueError("No moisture-related sheet name found. Set SHEET_NAME_EXPLICIT explicitly.")


def build_labels_from_sheet(
    excel_path: Path,
    sheet_name: str,
    first_class_col: int = 4,
) -> tuple[list[int], list[int], list[str]]:
    """Derive per-sample class labels from the class columns of a worksheet.

    Every non-empty column from position ``first_class_col`` onwards is
    treated as one class; the numeric cells of that column are the sample IDs
    belonging to it. Classes are numbered 1, 2, ... in column order, counting
    only the columns that are not entirely empty.

    Args:
        excel_path: Path of the Excel workbook.
        sheet_name: Worksheet to read.
        first_class_col: Zero-based position of the first class column.
            Defaults to 4, i.e. the first four columns are skipped.

    Returns:
        A tuple ``(sample_ids, labels_num, class_names)`` where ``sample_ids``
        is sorted ascending, ``labels_num[i]`` is the 1-based class number of
        ``sample_ids[i]``, and ``class_names[k - 1]`` is the column header of
        class ``k``.
    """
    df = pd.read_excel(excel_path, sheet_name=sheet_name)
    class_cols = [c for c in df.columns[first_class_col:] if not df[c].isna().all()]

    sample_to_class: dict[int, int] = {}
    for idx, col in enumerate(class_cols, start=1):
        # Non-numeric cells become NaN and are dropped; remaining values are
        # truncated to int by ``astype``.
        ids = pd.to_numeric(df[col], errors="coerce").dropna().astype(int)
        # An ID listed in several columns keeps the class of the last column.
        sample_to_class.update(dict.fromkeys(ids.tolist(), idx))

    sample_ids = sorted(sample_to_class)
    labels_num = [sample_to_class[sid] for sid in sample_ids]
    return sample_ids, labels_num, [str(c) for c in class_cols]


def load_sample_mean_spectrum(mat_dir: Path, sid: int, use_mask: bool = True) -> np.ndarray:
    samp_path = mat_dir / f"prep_sample{sid}.mat"
    if not samp_path.exists():
        raise FileNotFoundError(f"Missing file: {samp_path}")

    d = loadmat(samp_path)
    varname = f"prep_sample{sid}"
    if varname not in d:
        candidates = [k for k in d.keys() if not k.startswith("__")]
        if len(candidates) == 1:
            varname = candidates[0]
        else:
            raise KeyError(f"Variable '{varname}' not found in {samp_path.name}. Candidates: {candidates}")
    arr = np.asarray(d[varname])

    if arr.ndim == 3:
        H, W, B = arr.shape
        pixxbands = arr.reshape(-1, B)

        if use_mask:
            mpath = mat_dir / f"prep_mask{sid}.mat"
            if mpath.exists():
                dm = loadmat(mpath)
                mvar = f"prep_mask{sid}"
                if mvar not in dm:
                    cm = [k for k in dm.keys() if not k.startswith("__")]
                    mvar = cm[0]
                m = np.asarray(dm[mvar]).reshape(-1).astype(bool)
                if m.size == pixxbands.shape[0]:
                    pixxbands = pixxbands[m]

    elif arr.ndim == 2:
        n0, n1 = arr.shape
        pixxbands = arr.T if n0 < n1 else arr

        if use_mask:
            mpath = mat_dir / f"prep_mask{sid}.mat"
            if mpath.exists():
                dm = loadmat(mpath)
                mvar = f"prep_mask{sid}"
                if mvar not in dm:
                    cm = [k for k in dm.keys() if not k.startswith("__")]
                    mvar = cm[0]
                m = np.asarray(dm[mvar]).reshape(-1).astype(bool)
                if m.size == pixxbands.shape[0]:
                    pixxbands = pixxbands[m]
    else:
        raise ValueError(f"Unsupported array shape for sample {sid}: {arr.shape}")

    pixxbands = pixxbands[~np.isnan(pixxbands).any(axis=1)]
    if pixxbands.size == 0:
        raise ValueError(f"Empty valid pixels for sample {sid}. Check data/mask.")

    return pixxbands.mean(axis=0)


def ensure_min_splits(y: np.ndarray, max_splits: int = 5) -> int:
    """Choose a fold count for StratifiedKFold that every class can support.

    Args:
        y: Iterable of hashable class labels, one per sample.
        max_splits: Upper bound on the number of folds; must be at least 2.

    Returns:
        ``min(max_splits, size of the smallest class)``.

    Raises:
        ValueError: If ``max_splits`` is below 2, ``y`` is empty, or the
            smallest class has fewer than 2 samples.
    """
    if max_splits < 2:
        raise ValueError(f"max_splits must be at least 2 for StratifiedKFold, got {max_splits}.")

    class_counts = Counter(y)
    if not class_counts:
        # Without this guard, min() below fails with an unhelpful message.
        raise ValueError("Cannot choose n_splits for StratifiedKFold: the label array is empty.")

    min_count = min(class_counts.values())
    n_splits = min(max_splits, min_count)
    if n_splits < 2:
        raise ValueError(f"Too few samples in the smallest class ({min_count}) for StratifiedKFold.")
    return n_splits


# ===================== main =====================
# Script entry point: build the (spectrum, class) dataset, cross-validate a
# StandardScaler -> PCA -> XGBoost pipeline, and write the results to
# ROOT / "outputs" / "moisture".
if __name__ == "__main__":
    # Fail fast when the expected input locations are missing.
    if not EXCEL_PATH.exists():
        raise FileNotFoundError(f"Excel not found: {EXCEL_PATH}")
    if not MAT_DIR.exists():
        raise FileNotFoundError(f"MAT directory not found: {MAT_DIR}")

    sheet = pick_moisture_sheet(EXCEL_PATH, explicit=SHEET_NAME_EXPLICIT)
    print(f"[INFO] Sheet: {sheet}")

    # labels_num holds 1-based class numbers; class_cols[k - 1] names class k.
    sample_ids, labels_num, class_cols = build_labels_from_sheet(EXCEL_PATH, sheet)

    # Load one mean spectrum per sample. This is best-effort: a sample whose
    # files cannot be loaded is reported and skipped rather than aborting.
    X_list, y_list, kept_ids = [], [], []
    for sid, lab_num in zip(sample_ids, labels_num):
        try:
            spec = load_sample_mean_spectrum(MAT_DIR, int(sid), use_mask=True)
            X_list.append(spec)
            y_list.append(int(lab_num))
            kept_ids.append(int(sid))
        except Exception as e:
            print(f"[WARN] Skip sample {sid}: {e}")

    if len(X_list) == 0:
        raise RuntimeError("No valid samples were loaded. Check MAT files and masks.")

    # Stack the spectra into a (samples, bands) float32 matrix.
    # NOTE(review): vstack requires every spectrum to have the same number of
    # bands -- confirm all .mat files share one band count.
    X = np.vstack(X_list).astype(np.float32)
    y_raw = np.asarray(y_list, dtype=int)
    kept_ids = np.asarray(kept_ids, dtype=int)
    print(f"[INFO] Samples: {len(y_raw)} | X shape: {X.shape}")

    # Re-encode the class numbers that actually survived loading as
    # consecutive integers starting at 0.
    le = LabelEncoder()
    y = le.fit_transform(y_raw)

    # class_names[k] is the column header of encoded class k; the "- 1"
    # converts the 1-based class number back to an index into class_cols.
    class_names = []
    for k in range(len(le.classes_)):
        original_num = le.classes_[k]
        class_names.append(str(class_cols[original_num - 1]))

    # The fold count is capped by the size of the smallest class.
    n_splits = ensure_min_splits(y, max_splits=5)
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

    # Standardize each band, reduce with PCA (a float n_components asks
    # scikit-learn to keep enough components for that fraction of variance),
    # then classify with gradient-boosted trees.
    clf = Pipeline([
        ("scaler", StandardScaler(with_mean=True, with_std=True)),
        ("pca", PCA(n_components=0.99, random_state=42)),
        ("xgb", XGBClassifier(
            n_estimators=400,
            max_depth=6,
            learning_rate=0.05,
            subsample=0.85,
            colsample_bytree=0.85,
            reg_lambda=1.0,
            reg_alpha=0.0,
            objective="multi:softprob",
            eval_metric="mlogloss",
            random_state=42,
            n_jobs=-1,
        )),
    ])

    # Per-fold accuracy.
    scores = cross_val_score(clf, X, y, cv=cv, scoring="accuracy")
    print(f"[CV] n_splits: {n_splits}")
    print(f"[CV] fold accuracy: {np.round(scores, 4)}")
    print(f"[CV] mean ± std: {scores.mean():.4f} ± {scores.std():.4f}")

    # Out-of-fold prediction for every sample. This is a second, independent
    # cross-validation run over the same `cv` object, so the pipeline is
    # refit for every fold again.
    y_pred_cv = cross_val_predict(clf, X, y, cv=cv, method="predict")

    # Encoded labels in ascending order and their display names.
    labels_sorted = sorted(np.unique(y))
    target_names = [class_names[i] for i in labels_sorted]

    cm_cv = confusion_matrix(y, y_pred_cv, labels=labels_sorted)
    print("\n[CV] Confusion matrix:\n", cm_cv)
    print("\n[CV] Report:\n", classification_report(y, y_pred_cv, target_names=target_names, digits=4))

    disp = ConfusionMatrixDisplay(confusion_matrix=cm_cv, display_labels=target_names)
    disp.plot(values_format="d", xticks_rotation=45)
    plt.tight_layout()

    # All artifacts go to ROOT/outputs/moisture (created on demand).
    OUT_DIR = ROOT / "outputs" / "moisture"
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    # Save the figure before showing it.
    plt.savefig(OUT_DIR / "confusion_matrix.png", dpi=220)
    plt.show()

    # Per-sample table: true vs. out-of-fold predicted class (index and name).
    pd.DataFrame({
        "sample_id": kept_ids,
        "true_idx": y,
        "cv_pred_idx": y_pred_cv,
        "true_label": [target_names[labels_sorted.index(i)] for i in y],
        "cv_pred_label": [target_names[labels_sorted.index(i)] for i in y_pred_cv],
    }).to_csv(OUT_DIR / "moisture_cv_predictions.csv", index=False)

    # One row per fold with its accuracy.
    pd.DataFrame({"cv_fold_acc": scores}).to_csv(OUT_DIR / "moisture_cv_scores.csv", index=False)

    # Confusion matrix with rows = true class, columns = predicted class.
    pd.DataFrame(
        cm_cv,
        index=[f"True_{n}" for n in target_names],
        columns=[f"Pred_{n}" for n in target_names],
    ).to_csv(OUT_DIR / "moisture_cv_confusion_matrix.csv", index=True)

    print(f"\n[SAVE] {OUT_DIR}")
