# In [None]:
# -*- coding: utf-8 -*-
import os
import numpy as np
import pandas as pd
from collections import Counter
from scipy.io import loadmat

from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.model_selection import StratifiedKFold, cross_val_score, cross_val_predict
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay

from xgboost import XGBClassifier
import matplotlib.pyplot as plt


# ===================== Paths (relative) =====================
# Folder that contains this script.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# Repository root: the parent folder when the script sits in a folder named
# "scripts", otherwise the script folder itself.
if os.path.basename(SCRIPT_DIR) == "scripts":
    REPO_DIR = os.path.abspath(os.path.join(SCRIPT_DIR, ".."))
else:
    REPO_DIR = SCRIPT_DIR

# Input locations, resolved against the repository root.
EXCEL_PATH = os.path.join(REPO_DIR, "data", "pea_cluster.xlsx")
MAT_DIR = os.path.join(REPO_DIR, "data", "mats")

# Worksheet to read, e.g. "RVA"; if None, auto-pick by name contains "rva".
SHEET_NAME_EXPLICIT = None


# ===================== Helpers =====================
def pick_rva_sheet(excel_path, explicit=None):
    """
    Choose the worksheet to read from the workbook.

    Args:
      - excel_path: path to the Excel workbook.
      - explicit: sheet name to return as-is. When falsy (None or ""), the
        workbook's sheet names are scanned instead.

    Returns:
      - `explicit` when given; otherwise the first sheet whose name contains
        "rva" (case-insensitive).

    Raises:
      - ValueError: if no sheet name contains "rva".
    """
    if explicit:
        return explicit

    # The context manager closes the workbook handle as soon as the sheet
    # names have been read, so the file is not left open (or locked).
    with pd.ExcelFile(excel_path) as workbook:
        sheet_names = list(workbook.sheet_names)

    for name in sheet_names:
        if "rva" in str(name).lower():
            return name
    raise ValueError("No sheet name contains 'rva'. Set SHEET_NAME_EXPLICIT.")


def build_labels_from_sheet(excel_path, sheet_name):
    """
    Read the class-assignment sheet and map every sample ID to a numeric label.

    Sheet format:
      - Columns from the 5th onward: each column corresponds to one class (e.g., I/II/III/...),
        cells contain sample IDs. Columns that are entirely empty are ignored.

    Returns:
      - sample_ids: List[int], sorted ascending
      - labels_num: List[int] (1,2,3,... mapped by column order), aligned with sample_ids
      - class_cols: list of the original class column names that were kept

    Raises:
      - ValueError: if a non-empty cell cannot be parsed as an integer sample ID.

    Notes:
      - If the same sample ID is listed under more than one class column, the
        later column wins (as before) and a warning is printed so the
        conflict does not go unnoticed.
    """
    df = pd.read_excel(excel_path, sheet_name=sheet_name)
    class_cols = [c for c in df.columns[4:] if not df[c].isna().all()]

    sample_to_class = {}
    for idx, col in enumerate(class_cols, start=1):
        for raw in df[col].dropna():
            # Numeric cells next to blanks are read back as floats ("12.0"),
            # so keep only the part before the decimal point.
            try:
                sid = int(str(raw).strip().split(".")[0])
            except ValueError as err:
                raise ValueError(
                    f"Cannot parse sample ID {raw!r} in column {col!r} of sheet {sheet_name!r}."
                ) from err

            previous = sample_to_class.get(sid)
            if previous is not None and previous != idx:
                print(
                    f"[WARN] Sample {sid} is listed under both {class_cols[previous - 1]!r} "
                    f"and {col!r}; using {col!r}."
                )
            sample_to_class[sid] = idx

    sample_ids = sorted(sample_to_class)
    labels_num = [sample_to_class[sid] for sid in sample_ids]
    return sample_ids, labels_num, class_cols


def _load_mask_vector(mat_dir, sid):
    """Return the flattened boolean mask from prep_mask{sid}.mat, or None if that file does not exist."""
    mask_path = os.path.join(mat_dir, f"prep_mask{sid}.mat")
    if not os.path.exists(mask_path):
        return None

    contents = loadmat(mask_path)
    mask_var = f"prep_mask{sid}"
    if mask_var not in contents:
        # Fall back to the first user variable (loadmat's own keys start with "__").
        candidates = [k for k in contents.keys() if not k.startswith("__")]
        if not candidates:
            raise KeyError(f"No mask variable found in {mask_path}.")
        mask_var = candidates[0]

    return np.array(contents[mask_var]).reshape(-1).astype(bool)


def load_sample_mean_spectrum(mat_dir, sid, use_mask=True):
    """
    Load prep_sample{sid}.mat and return the mean spectrum over its valid pixels.

    The data variable is expected to be named prep_sample{sid}; if it is not
    present and the file holds exactly one other variable, that one is used.
    Supports (H,W,B) / (Npix,B) / (B,Npix); for 2-D input the shorter axis is
    taken to be the band axis.

    Args:
      - mat_dir: directory containing the prep_sample*.mat / prep_mask*.mat files.
      - sid: sample ID used to build the file and variable names.
      - use_mask: when True and prep_mask{sid}.mat exists, keep only the pixels
        where the mask is non-zero. A mask whose element count differs from
        the pixel count is ignored with a printed warning.

    Returns:
      - 1D mean spectrum (B,), computed over pixels that contain no NaN.

    Raises:
      - FileNotFoundError: the sample file is missing.
      - KeyError: the data variable cannot be identified, or the mask file
        contains no variable.
      - ValueError: the array is neither 2-D nor 3-D, or no valid pixel remains.
    """
    samp_fname = os.path.join(mat_dir, f"prep_sample{sid}.mat")
    if not os.path.exists(samp_fname):
        raise FileNotFoundError(f"Missing file: {samp_fname}")

    d = loadmat(samp_fname)
    varname = f"prep_sample{sid}"
    if varname not in d:
        candidates = [k for k in d.keys() if not k.startswith("__")]
        if len(candidates) == 1:
            varname = candidates[0]
        else:
            raise KeyError(f"Variable '{varname}' not found in {samp_fname}. Candidates: {candidates}")

    arr = np.array(d[varname])

    # Bring every supported layout to (pixels, bands).
    if arr.ndim == 3:
        pixxbands = arr.reshape(-1, arr.shape[-1])
    elif arr.ndim == 2:
        n0, n1 = arr.shape
        pixxbands = arr.T if n0 < n1 else arr
    else:
        raise ValueError(f"Unsupported array shape for sample {sid}: {arr.shape}")

    if use_mask:
        mask_vec = _load_mask_vector(mat_dir, sid)
        if mask_vec is not None:
            if mask_vec.size == pixxbands.shape[0]:
                pixxbands = pixxbands[mask_vec]
            else:
                # Best effort: keep all pixels, but make the mismatch visible.
                print(
                    f"[WARN] Mask for sample {sid} has {mask_vec.size} elements but the sample "
                    f"has {pixxbands.shape[0]} pixels; mask ignored."
                )

    # Drop pixels with a NaN in any band before averaging.
    pixxbands = pixxbands[~np.isnan(pixxbands).any(axis=1)]
    if pixxbands.size == 0:
        raise ValueError(f"Empty valid pixels for sample {sid}. Check data/mask.")

    return pixxbands.mean(axis=0)


# ===================== Main =====================
if __name__ == "__main__":
    # Fail early with a clear message when the expected inputs are missing.
    if not os.path.exists(EXCEL_PATH):
        raise FileNotFoundError(f"EXCEL_PATH not found: {EXCEL_PATH}")
    if not os.path.isdir(MAT_DIR):
        raise FileNotFoundError(f"MAT_DIR not found: {MAT_DIR}")

    sheet = pick_rva_sheet(EXCEL_PATH, explicit=SHEET_NAME_EXPLICIT)
    print(f"[INFO] Sheet: {sheet}")

    sample_ids, labels_num, class_cols = build_labels_from_sheet(EXCEL_PATH, sheet)

    # One mean spectrum per sample; samples that fail to load are reported
    # and skipped so a single bad file does not abort the whole run.
    X_list, y_num_list, kept_ids = [], [], []
    for sid, lab_num in zip(sample_ids, labels_num):
        try:
            spec = load_sample_mean_spectrum(MAT_DIR, sid, use_mask=True)
        except Exception as e:
            print(f"[WARN] Skip sample {sid}: {repr(e)}")
            continue
        X_list.append(spec)
        y_num_list.append(lab_num)
        kept_ids.append(sid)

    # np.vstack would fail with a cryptic message in both of these cases.
    if not X_list:
        raise ValueError(
            f"No sample spectra could be loaded from {MAT_DIR}. Check the sheet IDs and .mat files."
        )
    band_counts = sorted({int(np.size(spec)) for spec in X_list})
    if len(band_counts) > 1:
        raise ValueError(
            f"Samples have inconsistent band counts {band_counts}; cannot stack them into one matrix."
        )

    X = np.vstack(X_list).astype(np.float32)
    y_raw_num = np.array(y_num_list, dtype=int)
    kept_ids = np.array(kept_ids, dtype=int)
    print(f"[INFO] Samples used: {len(y_raw_num)} | X shape: {X.shape}")

    # Re-encode labels to 0..K-1 over the classes that actually have samples.
    le = LabelEncoder()
    y = le.fit_transform(y_raw_num)

    # le.classes_ holds the original 1-based column numbers (1,2,3,...).
    class_names_used = [str(class_cols[original_num - 1]) for original_num in le.classes_]

    # Stratified CV needs at least n_splits samples in the smallest class.
    cnt = Counter(y)
    min_class_count = min(cnt.values())
    n_splits = min(5, min_class_count)
    if n_splits < 2:
        raise ValueError(f"Min class count = {min_class_count}. Stratified CV cannot run. Merge classes or add samples.")

    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

    # Scaling and PCA live inside the pipeline so they are re-fitted on each
    # training fold during cross-validation.
    clf = Pipeline([
        ("scaler", StandardScaler(with_mean=True, with_std=True)),
        ("pca",    PCA(n_components=0.99, random_state=42)),
        ("xgb",    XGBClassifier(
            n_estimators=400,
            max_depth=6,
            learning_rate=0.05,
            subsample=0.85,
            colsample_bytree=0.85,
            reg_lambda=1.0,
            reg_alpha=0.0,
            objective="multi:softprob",
            eval_metric="mlogloss",
            random_state=42,
            n_jobs=-1
        ))
    ])

    scores = cross_val_score(clf, X, y, cv=cv, scoring="accuracy")
    print(f"[CV] n_splits: {n_splits}")
    print(f"[CV] fold accuracy: {np.round(scores, 4)}")
    print(f"[CV] mean accuracy: {scores.mean():.4f} ± {scores.std():.4f}")

    # In-sample fit: the confusion matrix and report below describe the
    # training data, not held-out performance.
    clf.fit(X, y)
    y_pred_train = clf.predict(X)

    labels_sorted = sorted(np.unique(y))
    target_names = [class_names_used[i] for i in labels_sorted]

    cm = confusion_matrix(y, y_pred_train, labels=labels_sorted)
    print("\n[TRAIN] Confusion matrix (rows=true, cols=pred):\n", cm)
    print("\n[TRAIN] Report:\n", classification_report(y, y_pred_train, target_names=target_names))

    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=target_names)
    disp.plot(values_format="d", xticks_rotation=45)
    plt.tight_layout()
    plt.show()

    # Out-of-fold predictions for every sample, using the same CV splitter.
    y_pred_cv = cross_val_predict(clf, X, y, cv=cv, method="predict")

    def idx_to_name(idx):
        """Map an encoded label index to its class column name."""
        return target_names[labels_sorted.index(idx)]

    out_dir = os.path.join(REPO_DIR, "outputs", "rva")
    os.makedirs(out_dir, exist_ok=True)

    out_pred_path = os.path.join(out_dir, "rva_cv_predictions.csv")
    out_scores_path = os.path.join(out_dir, "rva_cv_scores.csv")
    out_cm_path = os.path.join(out_dir, "rva_train_confusion_matrix.csv")

    pd.DataFrame({
        "sample_id": kept_ids,
        "true_label_idx": y,
        "cv_pred_label_idx": y_pred_cv,
        "true_label": [idx_to_name(k) for k in y],
        "cv_pred_label": [idx_to_name(k) for k in y_pred_cv]
    }).to_csv(out_pred_path, index=False)

    pd.DataFrame({"cv_fold_acc": scores}).to_csv(out_scores_path, index=False)

    # This is the training confusion matrix computed above.
    pd.DataFrame(
        cm,
        index=[f"True_{n}" for n in target_names],
        columns=[f"Pred_{n}" for n in target_names]
    ).to_csv(out_cm_path)

    print(f"\n[SAVE] {out_pred_path}")
    print(f"[SAVE] {out_scores_path}")
    print(f"[SAVE] {out_cm_path}")
