## Training Models

### Reading Dataset

In [None]:
# Mount Google Drive at /content/drive (google.colab is only importable inside Colab).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import pandas as pd

# Placeholder: replace 'path' with the real location of the dataset CSV.
dataset_path = 'path'
df = pd.read_csv(dataset_path)
# Bare last expression: the notebook renders the frame as this cell's output.
df

### Gradient Boosting Models

In [None]:
!pip install optuna

Collecting optuna
  Downloading optuna-4.5.0-py3-none-any.whl.metadata (17 kB)
Collecting colorlog (from optuna)
  Downloading colorlog-6.10.1-py3-none-any.whl.metadata (11 kB)
Downloading optuna-4.5.0-py3-none-any.whl (400 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/400.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m400.9/400.9 kB[0m [31m27.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading colorlog-6.10.1-py3-none-any.whl (11 kB)
Installing collected packages: colorlog, optuna
Successfully installed colorlog-6.10.1 optuna-4.5.0


In [None]:
!pip install optuna-integration[xgboost]

Collecting optuna-integration[xgboost]
  Downloading optuna_integration-4.5.0-py3-none-any.whl.metadata (12 kB)
Downloading optuna_integration-4.5.0-py3-none-any.whl (99 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/99.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m99.1/99.1 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: optuna-integration
Successfully installed optuna-integration-4.5.0


#### XGBoost

In [None]:
# ===============================================================
# XGBoost (GPU / A100) Multiclass — rock-solid (no category errors)
# - Colab-ready; saves under RESULTS/XGBoost_GPU/<timestamp>
# - Hybrid categorical encoding:
#     * ≤ MAX_ONEHOT_CARD uniques -> one-hot
#     * > MAX_ONEHOT_CARD -> hashing (K bins)  ✅ no unseen-category errors
# - No pandas StringDtype; pure numeric matrices
# - Datetime expansion, numeric downcast, stratified 70/15/15
# - Optuna (TPE) + EarlyStopping; device='cuda' (requires XGBoost ≥ 2.0)
# - Full metrics & artifacts; inference helper (labels 1..10)
# ===============================================================

import os, re, json, warnings, datetime, hashlib
from pathlib import Path
import numpy as np
import pandas as pd
import pandas.api.types as pdt

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, roc_auc_score,
    log_loss, classification_report, confusion_matrix
)
from sklearn.utils.class_weight import compute_class_weight

import xgboost as xgb
from xgboost import callback as xgb_cb
import optuna
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner
from optuna.integration import XGBoostPruningCallback

import joblib
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore")

# ------------------------ COLAB DRIVE SETUP ------------------------
# Placeholder: base folder under which all run outputs are written.
drive_folder = 'path'
# Best-effort mount: if google.colab cannot be imported or the mount fails, the
# error is swallowed and the run continues with `drive_folder` resolved as-is.
try:
    from google.colab import drive as _colab_drive
    # Mount only when /content/drive is missing or empty.
    if not Path("/content/drive").exists() or not any(Path("/content/drive").iterdir()):
        _colab_drive.mount("/content/drive")
except Exception:
    pass

# Subfolder for this run (timestamped to the second, so reruns do not overwrite each other)
RUN_NAME = f"XGBoost_GPU_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
OUTPUT_DIR = Path(drive_folder) / "XGBoost_GPU" / RUN_NAME
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
SPLIT_DIR  = OUTPUT_DIR / "splits"; SPLIT_DIR.mkdir(exist_ok=True)

print(f"Saving all outputs to: {OUTPUT_DIR.resolve()}")

# ------------------------ CONFIG ------------------------
SEED = 42
np.random.seed(SEED)
# NOTE(review): set after xgboost has already been imported above — confirm it
# still restricts the visible GPU as intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# >>>>>>>>>>>> SET THIS TO YOUR FILE <<<<<<<<<<<<
DATA_CSV = "path"

# Candidate target column names, tried in this order by resolve_target_name().
TARGET_ASCII   = "target_risk_class"
TARGET_PERSIAN = "risk_score"
NUM_CLASSES = 10

# Fractions of the whole labelled dataset; the remainder is the training split.
TEST_SIZE = 0.15
VAL_SIZE  = 0.15

# GPU training controls
N_TRIALS            = 30      # Optuna trials
N_ESTIMATORS_LIMIT  = 3000    # upper bound on boosting rounds
EARLY_STOP_ROUNDS   = 300     # patience passed to the EarlyStopping callback
NTHREAD             = os.cpu_count()
TIMEOUT_SEC         = None    # passed as Optuna's `timeout`; None = no time limit

# Encoding thresholds
MAX_ONEHOT_CARD     = 128   # ≤ this -> one-hot
HASH_BINS           = 32    # > MAX_ONEHOT_CARD -> hashing to K bins

# When True, the final model is fitted on train+val instead of train only.
REFIT_ON_TRAINVAL   = True

# Tokens (for normalizing text before encoding)
MISSING_TOKEN = "Missing"

# ------------------------ UTILITIES ------------------------
def normalize_ws(x):
    """Clean up whitespace in a string; return any non-string value untouched.

    Removes zero-width non-joiner and LTR/RTL marks, turns non-breaking spaces
    into regular spaces, collapses whitespace runs to one space and trims.
    """
    if not isinstance(x, str):
        return x
    without_marks = re.sub(r"[\u200c\u200f\u200e]", "", x)   # Persian ZWNJ/RTL marks
    plain_spaces = without_marks.replace("\u00a0", " ")
    return re.sub(r"\s+", " ", plain_spaces).strip()

def downcast_numeric(df):
    """Return a copy of `df` with float/int columns downcast to the smallest fitting dtype.

    Only float64, int64, int32 and float32 columns are touched; everything
    else is carried over unchanged. The input frame is not modified.
    """
    shrunk = df.copy()
    candidates = shrunk.select_dtypes(include=["float64","int64","int32","float32"]).columns
    for name in candidates:
        kind = "float" if pdt.is_float_dtype(shrunk[name]) else "integer"
        shrunk[name] = pd.to_numeric(shrunk[name], downcast=kind)
    return shrunk

def parse_possible_datetimes(df):
    """Return a copy of `df` in which date-looking object columns are parsed to datetimes.

    An object column is converted when more than 20% of its values contain a
    `YYYY-M-D` / `YYYY/M/D` pattern. Values that fail to parse become NaT.
    All other columns are returned unchanged.
    """
    df = df.copy()
    for c in df.columns:
        if df[c].dtype != object:
            continue
        s = df[c].astype(object)
        looks_like_date = s.astype(str).str.contains(
            r"\d{4}[-/]\d{1,2}[-/]\d{1,2}", regex=True, na=False
        )
        if looks_like_date.mean() > 0.2:
            try:
                # `infer_datetime_format` is deprecated since pandas 2.0 (strict
                # format inference is the default), so it is no longer passed.
                # Passing it risked a TypeError being swallowed below once the
                # argument is removed, silently disabling datetime parsing.
                df[c] = pd.to_datetime(s, errors="coerce")
            except Exception:
                # Best-effort by design: an unparseable column stays as text.
                pass
    return df

def expand_datetimes(df):
    """Return a copy of `df` with each datetime column replaced by calendar parts.

    For a datetime column `c` the columns `c__year`, `c__month`, `c__day`,
    `c__dow`, `c__hour`, `c__mstart` and `c__mend` are appended (nullable
    integer dtypes) and `c` itself is dropped. Missing hours become 0.
    """
    out = df.copy()
    datetime_cols = [name for name in out.columns if pdt.is_datetime64_any_dtype(out[name])]
    for name in datetime_cols:
        stamps = out[name]
        parts = [
            ("year",   stamps.dt.year,             "Int16"),
            ("month",  stamps.dt.month,            "Int8"),
            ("day",    stamps.dt.day,              "Int8"),
            ("dow",    stamps.dt.dayofweek,        "Int8"),
            ("hour",   stamps.dt.hour.fillna(0),   "Int8"),
            ("mstart", stamps.dt.is_month_start,   "Int8"),
            ("mend",   stamps.dt.is_month_end,     "Int8"),
        ]
        for suffix, values, dtype in parts:
            out[f"{name}__{suffix}"] = values.astype(dtype)
    if datetime_cols:
        out = out.drop(columns=datetime_cols)
    return out

def resolve_target_name(df):
    """Return the name of the target column in `df`, or None when none is found.

    Candidates are tried in priority order, first by exact column name and
    then by a relaxed comparison that lower-cases and treats runs of `_`, `-`
    and whitespace as a single space. Both passes use the same candidate list,
    so changing TARGET_ASCII / TARGET_PERSIAN is honoured by the relaxed pass
    as well (it previously used hard-coded names).
    """
    candidates = [TARGET_ASCII, TARGET_PERSIAN, "risk_class", "label", "target", "class", "y"]

    # Pass 1: exact match.
    for name in candidates:
        if name in df.columns:
            return name

    # Pass 2: separator- and case-insensitive match.
    def _key(text):
        return re.sub(r"[_\-\s]+", " ", str(text)).strip().lower()

    by_key = {_key(c): c for c in df.columns}
    for name in candidates:
        key = _key(name)
        if key in by_key:
            return by_key[key]
    return None

# Stable 32-bit hash for strings
def stable_hash32(val):
    """Map a value to a reproducible unsigned 32-bit integer.

    Missing values are hashed as MISSING_TOKEN and non-strings via `str()`.
    The result is the first 8 hex digits of the SHA-1 of the UTF-8 text, so it
    is identical across processes and runs (unlike the built-in `hash`).
    """
    text = MISSING_TOKEN if pd.isna(val) else val
    if not isinstance(text, str):
        text = str(text)
    digest = hashlib.sha1(text.encode("utf-8")).hexdigest()
    return int(digest[:8], 16)

# ------------------------ PREPROCESSOR (one-hot + hashing) ------------------------
class TabularPreprocessor:
    """
    Learns from TRAIN only:
      - which text columns are parsed as datetimes (re-applied as-is in transform)
      - numeric columns + medians
      - low-card categorical columns and their categories (for fixed one-hot columns)
      - high-card categorical columns and number of hash bins
      - final feature_names_ (fixed order)
    Transform returns a purely numeric float32 DataFrame aligned to feature_names_.

    NOTE(review): one-hot feature names embed the raw category text; XGBoost is
    understood to reject feature names containing '[', ']' or '<' — confirm the
    categories are free of these characters.
    """
    def __init__(self, max_onehot=MAX_ONEHOT_CARD, hash_bins=HASH_BINS):
        self.max_onehot = int(max_onehot)
        self.hash_bins  = int(hash_bins)
        self.num_cols_  = []
        self.cat_low_   = []            # columns one-hot encoded
        self.cat_low_categories_ = {}   # col -> sorted categories list (incl MISSING)
        self.cat_high_  = []            # columns hashed
        self.num_median_ = {}           # col -> median
        self.dt_parsed_cols_ = []       # text columns that fit() parsed as datetimes
        self.feature_names_ = []
        self.fitted_ = False

    @staticmethod
    def _as_tokens(s):
        """Return `s` as an object Series of strings, missing values -> MISSING_TOKEN."""
        s = pd.Series(s, copy=False).astype(object)
        s = s.where(pd.notna(s), MISSING_TOKEN)
        # Coerce stray non-string values so sorting/matching never mixes types.
        return s.map(lambda v: v if isinstance(v, str) else str(v))

    def _prep_base(self, df, learn=False):
        """Shared cleaning for fit/transform.

        With `learn=True` (used by fit) the datetime-looking text columns are
        detected and remembered. Otherwise exactly the remembered columns are
        parsed, so a small or skewed batch cannot flip a column between the
        "categorical" and "datetime" encodings.
        """
        d = df.copy()
        # normalize text-ish columns
        for c in d.columns:
            if d[c].dtype == object:
                d[c] = d[c].map(normalize_ws)

        if learn:
            already_dt = {c for c in d.columns if pdt.is_datetime64_any_dtype(d[c])}
            d = parse_possible_datetimes(d)
            self.dt_parsed_cols_ = [
                c for c in d.columns
                if c not in already_dt and pdt.is_datetime64_any_dtype(d[c])
            ]
        else:
            learned = getattr(self, "dt_parsed_cols_", None)
            if learned is None:
                # Instance pickled before dt_parsed_cols_ existed: legacy behaviour.
                d = parse_possible_datetimes(d)
            else:
                for c in learned:
                    if c in d.columns and not pdt.is_datetime64_any_dtype(d[c]):
                        d[c] = pd.to_datetime(d[c], errors="coerce")

        d = expand_datetimes(d)
        # booleans -> int8
        for c in d.columns:
            if pdt.is_bool_dtype(d[c]):
                d[c] = d[c].astype("int8")
        d = downcast_numeric(d)
        return d

    def fit(self, X):
        """Learn column roles, categories, medians and the output feature order from `X`."""
        d = self._prep_base(X, learn=True)

        # Reset learned state so re-fitting never keeps stale entries.
        self.cat_low_  = []
        self.cat_high_ = []
        self.cat_low_categories_ = {}
        self.num_median_ = {}

        # Decide numeric vs categorical by dtype after base prep
        # Anything not numeric -> treat as categorical (object or category)
        self.num_cols_ = [c for c in d.columns if pdt.is_numeric_dtype(d[c])]
        cat_cols = [c for c in d.columns if not pdt.is_numeric_dtype(d[c])]

        # Low vs high cardinality split (TRAIN only)
        for c in cat_cols:
            s = self._as_tokens(d[c])
            if int(s.nunique(dropna=False)) <= self.max_onehot:
                # freeze categories for one-hot (sorted for stability)
                cats = sorted(s.unique().tolist())
                if MISSING_TOKEN not in cats:
                    cats.append(MISSING_TOKEN)
                self.cat_low_.append(c)
                self.cat_low_categories_[c] = cats
            else:
                self.cat_high_.append(c)

        # Numeric medians. Cast to float64 first: nullable-integer columns
        # (the expanded date parts) return pd.NA for an all-missing column,
        # which float() cannot convert.
        for c in self.num_cols_:
            as_float = pd.to_numeric(d[c], errors="coerce").astype("float64")
            self.num_median_[c] = float(as_float.median())

        # Fixed output order: numeric, then one-hot blocks, then hash blocks.
        feats = list(self.num_cols_)
        for c in self.cat_low_:
            feats.extend([f"{c}__oh__{v}" for v in self.cat_low_categories_[c]])
        for c in self.cat_high_:
            feats.extend([f"{c}__h{b}" for b in range(self.hash_bins)])

        self.feature_names_ = feats
        self.fitted_ = True
        return self

    def _encode_onehot(self, s, col):
        """One-hot encode `s` against the categories frozen for `col`; unknown -> MISSING_TOKEN."""
        cats = self.cat_low_categories_[col]
        s = self._as_tokens(s)
        s = s.where(s.isin(cats), MISSING_TOKEN)
        # get dummies, then reindex to fixed columns
        dummies = pd.get_dummies(s, prefix=f"{col}__oh", prefix_sep="__", dummy_na=False)
        # Column names are like f"{col}__oh__value"
        target_cols = [f"{col}__oh__{v}" for v in cats]
        return dummies.reindex(columns=target_cols, fill_value=0)

    def _encode_hash(self, s, col):
        """Hash `s` into `hash_bins` indicator columns using the stable SHA-1 hash."""
        s = self._as_tokens(s)
        # Hash each distinct value once instead of once per row.
        codes, uniques = pd.factorize(s)
        bin_of_unique = np.fromiter(
            (stable_hash32(u) % self.hash_bins for u in uniques),
            dtype=np.int64, count=len(uniques),
        )
        idx = bin_of_unique[codes]          # int64 even for empty input
        mat = np.zeros((len(s), self.hash_bins), dtype=np.float32)
        mat[np.arange(len(s)), idx] = 1.0
        cols = [f"{col}__h{b}" for b in range(self.hash_bins)]
        return pd.DataFrame(mat, columns=cols, index=s.index)

    def transform(self, X):
        """Encode `X` into a float32 DataFrame whose columns equal `feature_names_`.

        Columns absent from `X` are filled with the train median (numeric) or
        MISSING_TOKEN (categorical). Raises AssertionError if fit() was not called.
        """
        assert self.fitted_, "Call fit() first."
        d = self._prep_base(X)
        n_rows = len(d)

        # Collect plain arrays and assemble the frame once; this avoids the
        # repeated pd.concat / column insertion of a growing DataFrame.
        arrays, names = [], []

        # 1) numeric
        for c in self.num_cols_:
            median = self.num_median_[c]
            if c in d.columns:
                # float64 first: fillna with a fractional median raises on
                # nullable Int8/Int16 columns.
                col = pd.to_numeric(d[c], errors="coerce").astype("float64")
                values = col.fillna(median).to_numpy(dtype=np.float32)
            else:
                values = np.full(n_rows, median, dtype=np.float32)
            arrays.append(values.reshape(-1, 1))
            names.append(c)

        # 2) one-hot categoricals
        for c in self.cat_low_:
            s = d[c] if c in d.columns else pd.Series([MISSING_TOKEN]*n_rows, index=d.index)
            block = self._encode_onehot(s, c)
            arrays.append(block.to_numpy(dtype=np.float32))
            names.extend(block.columns.tolist())

        # 3) hashed categoricals
        for c in self.cat_high_:
            s = d[c] if c in d.columns else pd.Series([MISSING_TOKEN]*n_rows, index=d.index)
            block = self._encode_hash(s, c)
            arrays.append(block.to_numpy(dtype=np.float32))
            names.extend(block.columns.tolist())

        mat = np.hstack(arrays) if arrays else np.empty((n_rows, 0), dtype=np.float32)
        out = pd.DataFrame(mat, index=d.index, columns=names)

        # Final align & type
        out = out.reindex(columns=self.feature_names_, fill_value=0).astype(np.float32, copy=False)
        return out

# ------------------------ LOAD & TARGET ------------------------
df = pd.read_csv(DATA_CSV, low_memory=False)
tgt = resolve_target_name(df)
if tgt is None:
    raise KeyError("Could not find target column. Expected 'target_risk_class' or 'risk_score'.")

# Keep only rows whose label is a whole number in 1..NUM_CLASSES.
# `isin` drops NaN, out-of-range and non-integer values (e.g. 3.5) in one step;
# the previous astype("Int64") raised on non-integer labels and the range was
# hard-coded to 10 instead of following NUM_CLASSES.
y_num = pd.to_numeric(df[tgt], errors="coerce")
mask = y_num.isin(np.arange(1, NUM_CLASSES + 1))
df = df.loc[mask].copy()
y_1_10 = y_num.loc[mask].astype("int16")
y = (y_1_10 - 1).astype("int16")  # zero-based labels: 0..NUM_CLASSES-1

X = df.drop(columns=[tgt])

# ------------------------ SPLIT (70/15/15) ------------------------
# First split: hold out TEST_SIZE of all rows as the test set (stratified on y).
sss1 = StratifiedShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
trainval_idx, test_idx = next(sss1.split(X, y))
X_trainval, X_test = X.iloc[trainval_idx], X.iloc[test_idx]
y_trainval, y_test = y.iloc[trainval_idx], y.iloc[test_idx]

# Second split: VAL_SIZE is a fraction of the whole dataset, so rescale it to
# a fraction of the remaining train+val pool.
val_rel = VAL_SIZE / (1.0 - TEST_SIZE)
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=val_rel, random_state=SEED)
train_idx, val_idx = next(sss2.split(X_trainval, y_trainval))
X_train, X_val = X_trainval.iloc[train_idx], X_trainval.iloc[val_idx]
y_train, y_val = y_trainval.iloc[train_idx], y_trainval.iloc[val_idx]

# Save indices (index labels of X, i.e. of the rows kept after label filtering)
pd.Series(X_train.index, name="index").to_csv(SPLIT_DIR/"train_indices.csv", index=False)
pd.Series(X_val.index,   name="index").to_csv(SPLIT_DIR/"val_indices.csv",   index=False)
pd.Series(X_test.index,  name="index").to_csv(SPLIT_DIR/"test_indices.csv",  index=False)

print(f"Shapes -> train: {X_train.shape}, val: {X_val.shape}, test: {X_test.shape}")

# ------------------------ PREPROCESS ------------------------
# The preprocessor is fitted on TRAIN only and then applied to every split.
pp = TabularPreprocessor(max_onehot=MAX_ONEHOT_CARD, hash_bins=HASH_BINS).fit(X_train)
Xtr = pp.transform(X_train)
Xva = pp.transform(X_val)
Xte = pp.transform(X_test)

# class-balanced weights, computed from TRAIN labels and mapped to per-row weights
classes_present = np.unique(y_train)
class_weights = compute_class_weight(class_weight="balanced", classes=classes_present, y=y_train)
cw_map = {int(c): float(w) for c, w in zip(classes_present, class_weights)}
w_train = y_train.map(cw_map).astype("float32")
# NOTE(review): a validation label that never occurs in y_train would map to NaN here.
w_val   = y_val.map(cw_map).astype("float32")

# ------------------------ DMATRIX (pure numeric) ------------------------
def to_dmatrix_num(X, y=None, w=None):
    """Wrap an already-numeric feature frame in an `xgb.DMatrix`.

    `X` is cast to float32 (no copy when it already is); `y` and `w` are
    passed through as label and per-row weight and may be None.
    """
    features = X.astype(np.float32, copy=False)
    return xgb.DMatrix(features, label=y, weight=w, nthread=NTHREAD)

# ------------------------ OPTUNA (xgb.train) ------------------------
def suggest_xgb_params(trial):
    """Sample one XGBoost parameter set from the Optuna `trial`.

    Returns a dict for `xgb.train` containing the fixed multiclass/GPU settings
    plus the sampled hyper-parameters. When the trial picks `use_dart`, the
    DART-specific parameters are sampled as well; otherwise the booster is
    `gbtree`.
    """
    params = {}

    # Fixed learning task.
    params["objective"] = "multi:softprob"
    params["num_class"] = NUM_CLASSES
    params["eval_metric"] = "mlogloss"

    # Searched hyper-parameters (suggest order kept stable for reproducibility).
    params["eta"] = trial.suggest_float("eta", 0.02, 0.2, log=True)
    params["max_depth"] = trial.suggest_int("max_depth", 4, 14)
    params["min_child_weight"] = trial.suggest_float("min_child_weight", 1e-2, 20.0, log=True)
    params["subsample"] = trial.suggest_float("subsample", 0.6, 1.0)
    params["colsample_bytree"] = trial.suggest_float("colsample_bytree", 0.6, 1.0)
    params["colsample_bylevel"] = trial.suggest_float("colsample_bylevel", 0.6, 1.0)
    params["gamma"] = trial.suggest_float("gamma", 0.0, 5.0)
    params["lambda"] = trial.suggest_float("lambda", 1e-8, 50.0, log=True)
    params["alpha"] = trial.suggest_float("alpha", 1e-8, 10.0, log=True)
    params["max_bin"] = trial.suggest_int("max_bin", 64, 512)

    # Histogram tree builder on the CUDA device.
    params["tree_method"] = "hist"
    params["device"] = "cuda"
    params["nthread"] = NTHREAD
    params["random_state"] = SEED

    wants_dart = trial.suggest_categorical("use_dart", [False, True])
    if not wants_dart:
        params["booster"] = "gbtree"
        return params

    params["booster"] = "dart"
    params["rate_drop"] = trial.suggest_float("rate_drop", 0.0, 0.3)
    params["skip_drop"] = trial.suggest_float("skip_drop", 0.0, 0.3)
    params["sample_type"] = trial.suggest_categorical("sample_type", ["uniform", "weighted"])
    params["normalize_type"] = trial.suggest_categorical("normalize_type", ["tree", "forest"])
    return params

def train_booster(params, X_train, y_train, w_train, X_val, y_val, w_val, pruning_cb=None):
    """Train a booster with early stopping on the given validation data.

    Runs at most N_ESTIMATORS_LIMIT rounds, monitors the set named
    "validation_0", and stops after EARLY_STOP_ROUNDS rounds without
    improvement (`save_best=True`). `pruning_cb`, when given, is added as an
    extra callback. Returns the trained `xgb.Booster`.
    """
    stopper = xgb_cb.EarlyStopping(rounds=EARLY_STOP_ROUNDS, save_best=True)
    active_callbacks = [stopper] if pruning_cb is None else [stopper, pruning_cb]

    train_matrix = to_dmatrix_num(X_train, y_train, w_train)
    val_matrix = to_dmatrix_num(X_val, y_val, w_val)

    return xgb.train(
        params=params,
        dtrain=train_matrix,
        num_boost_round=N_ESTIMATORS_LIMIT,
        evals=[(val_matrix, "validation_0")],
        callbacks=active_callbacks,
    )

def objective(trial):
    """Optuna objective: train one candidate and return its validation log-loss.

    The booster is trained on the encoded TRAIN split with early stopping and
    Optuna pruning on "validation_0-mlogloss". The returned value is the
    unweighted multiclass log-loss on the VAL split.
    """
    params = suggest_xgb_params(trial)
    pruning_cb = XGBoostPruningCallback(trial, "validation_0-mlogloss")
    booster = train_booster(params, Xtr, y_train, w_train, Xva, y_val, w_val, pruning_cb=pruning_cb)
    dval = to_dmatrix_num(Xva, y_val)
    try:
        br = getattr(booster, "best_iteration", None)
        if br is not None:
            # best_iteration is a 0-based index and the end of iteration_range
            # is exclusive, so +1 is needed to include the best round itself.
            proba = booster.predict(dval, iteration_range=(0, int(br) + 1))
        else:
            proba = booster.predict(dval)
    except TypeError:
        # Older XGBoost without `iteration_range`.
        bntl = getattr(booster, "best_ntree_limit", 0)
        proba = booster.predict(dval, ntree_limit=bntl if bntl > 0 else 0)
    if proba.ndim == 1:
        proba = np.column_stack([1 - proba, proba])
    return log_loss(y_val, proba, labels=list(range(NUM_CLASSES)))

study = optuna.create_study(direction="minimize",
                            sampler=TPESampler(seed=SEED),
                            pruner=MedianPruner(n_warmup_steps=10))
study.optimize(objective, n_trials=N_TRIALS, timeout=TIMEOUT_SEC, show_progress_bar=False)

# study.best_params holds only the *suggested* values. That includes the
# search-only flag "use_dart" but not the "booster" key derived from it, so the
# booster choice has to be rebuilt here. Without this, a winning DART trial
# would be refitted as the default gbtree booster.
best_params = dict(study.best_params)
use_dart = bool(best_params.pop("use_dart", False))
best_params["booster"] = "dart" if use_dart else "gbtree"
best_params.update({
    "objective": "multi:softprob",
    "num_class": NUM_CLASSES,
    "eval_metric": "mlogloss",
    "tree_method": "hist",
    "device": "cuda",
    "nthread": NTHREAD, "random_state": SEED,
})
with open(OUTPUT_DIR / "best_params.json", "w", encoding="utf-8") as f:
    json.dump(best_params, f, ensure_ascii=False, indent=2)
print("Best params:", json.dumps(best_params, indent=2, ensure_ascii=False))

# ------------------------ FINAL TRAIN ------------------------
# Stage 1: fit on TRAIN only and early-stop on the untouched VAL split to find
# the number of boosting rounds. Early-stopping a train+val fit on VAL (as
# before) monitors rows the model is being trained on, so the stopping point
# is not a genuine hold-out estimate.
stage1_booster = train_booster(best_params, Xtr, y_train, w_train, Xva, y_val, w_val, pruning_cb=None)
best_iter = getattr(stage1_booster, "best_iteration", None)
if best_iter is None:
    best_iter = stage1_booster.num_boosted_rounds() - 1
best_iter = int(best_iter)
print("Best iteration:", best_iter)

X_eval_use, y_eval_use, w_eval_use = Xva, y_val, w_val

if REFIT_ON_TRAINVAL:
    # Stage 2: refit on TRAIN+VAL for exactly the round count found above,
    # without early stopping. VAL metrics reported after this refit are
    # in-sample; only the TEST split remains a clean hold-out.
    X_refit = pd.concat([X_train, X_val], axis=0)
    y_refit = pd.concat([y_train,  y_val], axis=0)
    w_refit = pd.concat([w_train,  w_val], axis=0)
    X_refit_enc = pp.transform(X_refit)
    X_train_use, y_train_use, w_train_use = X_refit_enc, y_refit, w_refit
    final_booster = xgb.train(
        params=best_params,
        dtrain=to_dmatrix_num(X_train_use, y_train_use, w_train_use),
        num_boost_round=best_iter + 1,
    )
else:
    X_train_use, y_train_use, w_train_use = Xtr, y_train, w_train
    final_booster = stage1_booster

# ------------------------ EVALUATION ------------------------
train_feature_names = list(X_train_use.columns)

def predict_proba_booster(booster, Xs):
    """Return class probabilities from `booster` for the encoded feature frame `Xs`.

    When the booster carries a `best_iteration`, prediction uses rounds
    0..best_iteration inclusive; otherwise all trees are used.
    """
    dm = to_dmatrix_num(Xs)
    try:
        br = getattr(booster, "best_iteration", None)
        if br is None:
            return booster.predict(dm)
        # best_iteration is 0-based and iteration_range's end is exclusive:
        # (0, br) would drop the best round, hence the +1.
        return booster.predict(dm, iteration_range=(0, int(br) + 1))
    except TypeError:
        # Older XGBoost without `iteration_range`.
        bntl = getattr(booster, "best_ntree_limit", 0)
        return booster.predict(dm, ntree_limit=bntl if bntl > 0 else 0)

def eval_split(name, Xs_raw, ys_zero):
    """Score `final_booster` on one split and write its report files.

    `Xs_raw` is the raw (un-encoded) frame and `ys_zero` the zero-based labels.
    Writes `classification_report_<name>.txt` and `confusion_matrix_<name>.csv`
    (both in 1-based labels) to OUTPUT_DIR. Log-loss and ROC-AUC are
    best-effort: on failure the metric is stored as NaN.

    Returns `(metrics_dict, predicted_labels_1_based, probabilities)`.
    """
    zero_based = list(range(NUM_CLASSES))
    one_based = list(range(1, NUM_CLASSES+1))

    proba = predict_proba_booster(final_booster, pp.transform(Xs_raw))
    if proba.ndim == 1:
        proba = np.column_stack([1 - proba, proba])
    pred0 = np.argmax(proba, axis=1)

    metrics = {
        "split": name,
        "n_samples": int(len(ys_zero)),
        "accuracy": float(accuracy_score(ys_zero, pred0)),
    }
    for avg in ("macro", "weighted"):
        prec, rec, fscore, _ = precision_recall_fscore_support(
            ys_zero, pred0, average=avg, zero_division=0
        )
        metrics.update({
            f"precision_{avg}": float(prec),
            f"recall_{avg}": float(rec),
            f"f1_{avg}": float(fscore),
        })

    try:
        metrics["log_loss"] = float(log_loss(ys_zero, proba, labels=zero_based))
    except Exception:
        metrics["log_loss"] = np.nan

    try:
        indicator = pd.get_dummies(pd.Categorical(ys_zero, categories=zero_based))
        auc = roc_auc_score(indicator.values, proba, average="macro", multi_class="ovr")
        metrics["roc_auc_ovr_macro"] = float(auc)
    except Exception:
        metrics["roc_auc_ovr_macro"] = np.nan

    # Reports are written in the original 1-based label space.
    ys_one = ys_zero + 1
    pred_one = pred0 + 1

    report_text = classification_report(ys_one, pred_one, labels=one_based, zero_division=0)
    with open(OUTPUT_DIR / f"classification_report_{name}.txt", "w", encoding="utf-8") as f:
        f.write(report_text)

    cm = confusion_matrix(ys_one, pred_one, labels=one_based)
    cm_frame = pd.DataFrame(cm, index=range(1, NUM_CLASSES+1), columns=range(1, NUM_CLASSES+1))
    cm_frame.to_csv(OUTPUT_DIR / f"confusion_matrix_{name}.csv")

    return metrics, pred_one, proba

# Score every split with the final booster; each call also writes that split's
# classification report and confusion matrix to OUTPUT_DIR.
metrics_train, yhat_train_1_10, proba_train = eval_split("train", X_train, y_train)
metrics_val,   yhat_val_1_10,   proba_val   = eval_split("val",   X_val,   y_val)
metrics_test,  yhat_test_1_10,  proba_test  = eval_split("test",  X_test,  y_test)

# One row per split, saved as CSV and echoed to the cell output.
metrics_df = pd.DataFrame([metrics_train, metrics_val, metrics_test])
metrics_df.to_csv(OUTPUT_DIR / "metrics_xgb_gpu.csv", index=False)
print(metrics_df)

# ------------------------ FEATURE IMPORTANCE ------------------------
# Gain-based importance as reported by the booster, sorted descending.
score_gain = final_booster.get_score(importance_type='gain')
imp = pd.DataFrame({"feature": list(score_gain.keys()), "gain": list(score_gain.values())}).sort_values("gain", ascending=False)
imp.to_csv(OUTPUT_DIR / "feature_importance_gain.csv", index=False)

# Horizontal bar chart of the top 40 features; figure height scales with the
# number of bars, clamped to the range 4..16.
plt.figure(figsize=(10, max(4, min(16, len(imp.head(40)) * 0.25))))
topk = imp.head(40).iloc[::-1]  # reversed so the largest gain is drawn at the top
plt.barh(topk["feature"], topk["gain"])
plt.title("XGBoost (GPU) Feature Importance (gain) - Top 40")
plt.tight_layout()
plt.savefig(OUTPUT_DIR / "feature_importance_gain_top40.png", dpi=150)
plt.close()

# ------------------------ SAVE ARTIFACTS ------------------------
final_booster.save_model(str(OUTPUT_DIR / "xgb_gpu_multiclass.json"))
# NOTE(review): the preprocessor is pickled via joblib; loading it elsewhere
# presumably requires the TabularPreprocessor class to be importable — confirm.
joblib.dump(pp, OUTPUT_DIR / "preprocessor.pkl")
with open(OUTPUT_DIR / "training_meta.json", "w", encoding="utf-8") as f:
    json.dump({
        # Re-reads the CSV header row to resolve the target column name again.
        "target_name": resolve_target_name(pd.read_csv(DATA_CSV, nrows=1)),
        "num_classes": NUM_CLASSES,
        "label_order_zero_indexed": list(range(NUM_CLASSES)),
        "seed": SEED,
        "best_iteration": best_iter,
        "n_trials": N_TRIALS,
        "refit_on_trainval": REFIT_ON_TRAINVAL,
        # Paths are relative to OUTPUT_DIR.
        "splits": {"train": "splits/train_indices.csv", "val": "splits/val_indices.csv", "test": "splits/test_indices.csv"},
        "data_csv": DATA_CSV,
        "train_feature_names": train_feature_names,
        "max_onehot_card": MAX_ONEHOT_CARD,
        "hash_bins": HASH_BINS
    }, f, ensure_ascii=False, indent=2)

print(f"\n✅ All artifacts saved to Google Drive at: {OUTPUT_DIR.resolve()}")

# ------------------------ INFERENCE HELPER ------------------------
def predict_target_risk_class(df_new: pd.DataFrame,
                              model_path=OUTPUT_DIR / "xgb_gpu_multiclass.json",
                              preproc_path=OUTPUT_DIR / "preprocessor.pkl",
                              meta_path=OUTPUT_DIR / "training_meta.json") -> pd.Series:
    """Predict risk classes (labels 1..10) for raw, un-preprocessed rows.

    Loads the saved booster and the pickled preprocessor, encodes `df_new`
    and returns the arg-max class + 1 as a Series named
    "pred_target_risk_class", indexed like `df_new`.

    `meta_path` is accepted but not read by this function.
    The preprocessor is loaded with joblib (pickle): only load files you trust.
    """
    model = xgb.Booster()
    model.load_model(str(model_path))
    encoder = joblib.load(preproc_path)

    features = encoder.transform(df_new).astype(np.float32, copy=False)
    scores = model.predict(xgb.DMatrix(features, nthread=NTHREAD))
    if scores.ndim == 1:
        scores = np.column_stack([1 - scores, scores])

    labels_1_based = np.argmax(scores, axis=1) + 1
    return pd.Series(labels_1_based, index=df_new.index, name="pred_target_risk_class")

#### LightGBM (trained on CPU; run locally in VS Code)

In [None]:
# ===============================================================
# Full, edited, single-cell pipeline (final):
# - Labels 1..10 are shifted to 0..9 for LightGBM, then +1 on outputs
# - Robust preprocessing (string/object -> categorical, datetime expansion, downcast)
# - Stratified 70/15/15 split + saved indices
# - Subset-based Optuna tuning (memory-light) with pruning (fixed valid name)
# - Final LightGBM multiclass training + metrics & artifacts
# ===============================================================

import os, re, json, warnings
from pathlib import Path
import numpy as np
import pandas as pd

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, roc_auc_score,
    log_loss, classification_report, confusion_matrix
)
from sklearn.utils.class_weight import compute_class_weight

import lightgbm as lgb
import joblib
import matplotlib.pyplot as plt
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler

warnings.filterwarnings("ignore")

# ------------------------ configuration ------------------------
SEED = 42
np.random.seed(SEED)

# >>>>>>>>>>>> SET YOUR FILE HERE (prefer model_ready_ascii.csv) <<<<<<<<<<<<
DATA_CSV = "path"

# Placeholder output folder; it and its "splits" subfolder are created on the spot.
OUTPUT_DIR = Path("path"); OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
SPLIT_DIR  = OUTPUT_DIR / "splits"; SPLIT_DIR.mkdir(exist_ok=True)

# Candidate target column names, tried in this order by resolve_target_name().
TARGET_ASCII = "target_risk_class"
TARGET_PERSIAN = "risk score"
NUM_CLASSES = 10

TEST_SIZE = 0.15
VAL_SIZE  = 0.15  # of the whole dataset (applied after test split)

# ---- Resource-friendly tuning/training settings ----
SUBSET_FRAC = 0.20        # fraction of train+val used for Optuna tuning
N_TRIALS    = 15          # tuning trials (increase on stronger machine)
N_ESTIMATORS = 2500       # trees upper bound
EARLY_STOP_ROUNDS = 300   # early-stopping patience (rounds without improvement)
NUM_THREADS = 4           # reduce to 2 on laptop
TIMEOUT_SEC = None        # e.g. 1800 for 30-min guard

# ---- Eval wiring (fixes your valid-name mismatch) ----
# The same name/metric pair is used for lgb.train's valid_names, the pruning
# callback and the best_score lookup, so they cannot drift apart.
VALID_NAME  = "val"
METRIC_NAME = "multi_logloss"

# ------------------------ helpers ------------------------
def normalize_ws(x):
    """Return `x` with whitespace normalised when it is a string, else `x` unchanged.

    Zero-width non-joiner and RTL/LTR marks are removed, non-breaking spaces
    become plain spaces, whitespace runs collapse to one space and the ends
    are trimmed.
    """
    if not isinstance(x, str):
        return x
    cleaned = re.sub(r"[\u200c\u200f\u200e]", "", x).replace("\u00a0", " ")
    return re.sub(r"\s+", " ", cleaned).strip()

def resolve_target_name(df):
    """Return the target column of `df`, or None when no candidate matches.

    Matching is case-insensitive and treats runs of `_`, `-` and whitespace as
    a single space. Candidates are tried in priority order; when several
    columns share a normalised name, the last such column wins.
    """
    def canon(text):
        return re.sub(r"[_\-\s]+"," ", text).strip().lower()

    column_by_key = {}
    for col in df.columns:
        column_by_key[canon(str(col))] = col

    wanted = [TARGET_ASCII, TARGET_PERSIAN, "risk_class", "label", "target", "class", "y"]
    for name in wanted:
        key = canon(name)
        if key in column_by_key:
            return column_by_key[key]
    return None

def downcast_numeric(df):
    """Return a copy of `df` with float64/int64 columns downcast to smaller dtypes.

    Columns of any other dtype are left untouched, and the input frame is not
    modified.
    """
    result = df.copy()
    wide_cols = result.select_dtypes(include=["float64", "int64"]).columns
    for name in wide_cols:
        is_float = pd.api.types.is_float_dtype(result[name])
        result[name] = pd.to_numeric(result[name], downcast="float" if is_float else "integer")
    return result

def parse_possible_datetimes(df):
    """Lightweight parse for obvious date strings; skip heavy conversions.

    Returns a copy of `df`. An object column is converted to datetime when
    more than 20% of its values contain a `YYYY-M-D` / `YYYY/M/D` pattern;
    values that fail to parse become NaT. Other columns are unchanged.
    """
    df = df.copy()
    for c in df.columns:
        if df[c].dtype != "object":
            continue
        s = df[c].astype("string")
        date_like = s.str.contains(r"\d{4}[-/]\d{1,2}[-/]\d{1,2}", regex=True, na=False)
        if date_like.mean() > 0.2:
            try:
                # `infer_datetime_format` is deprecated since pandas 2.0 (strict
                # format inference is the default), so it is no longer passed.
                # Passing it risked a TypeError being swallowed below once the
                # argument is removed, silently disabling datetime parsing.
                df[c] = pd.to_datetime(s, errors="coerce")
            except Exception:
                # Best-effort by design: an unparseable column stays as text.
                pass
    return df

def expand_datetimes(df):
    """Return a copy of `df` where every datetime column is replaced by calendar parts.

    For a datetime column `c` this appends `c__year`, `c__month`, `c__day`,
    `c__dow`, `c__hour`, `c__mstart` and `c__mend` as nullable integers and
    drops `c`. Missing hours become 0.
    """
    out = df.copy()
    datetime_cols = [
        name for name in out.columns
        if pd.api.types.is_datetime64_any_dtype(out[name])
    ]
    for name in datetime_cols:
        stamps = out[name]
        parts = [
            ("year",   stamps.dt.year,            "Int16"),
            ("month",  stamps.dt.month,           "Int8"),
            ("day",    stamps.dt.day,             "Int8"),
            ("dow",    stamps.dt.dayofweek,       "Int8"),
            ("hour",   stamps.dt.hour.fillna(0),  "Int8"),
            ("mstart", stamps.dt.is_month_start,  "Int8"),
            ("mend",   stamps.dt.is_month_end,    "Int8"),
        ]
        for suffix, values, dtype in parts:
            out[f"{name}__{suffix}"] = values.astype(dtype)
    return out.drop(columns=datetime_cols)

def pd_cat_fix(series, allowed):
    """Coerce `series` to an unordered Categorical over the fixed `allowed` list.

    Values are converted to strings first; missing values and values not in
    `allowed` are replaced by 'Missing'.
    """
    text = series.astype("string").fillna("Missing")
    unseen = ~text.isin(allowed)
    text = text.mask(unseen, "Missing")
    return pd.Categorical(text, categories=allowed, ordered=False)

# ------------------------ Preprocessor ------------------------
class TabularPreprocessor:
    """
    Fits on train, then applies consistent transforms to val/test/inference:
    - normalize whitespace
    - downcast numeric
    - parse + expand datetimes, drop originals (the set of parsed columns is
      learned in fit and re-applied as-is in transform)
    - impute numeric gaps with the medians learned in fit
    - convert non-numeric to categorical with a 'Missing' bucket
    - store feature order, cat columns & categories, numeric columns
    """
    def __init__(self):
        self.feature_names_ = None
        self.cat_cols_ = []
        self.cat_categories_ = {}
        self.num_cols_ = []
        self.num_median_ = {}        # col -> median learned from the fit data
        self.dt_parsed_cols_ = []    # text columns that fit() parsed as datetimes
        self.fitted_ = False

    def _prep_base(self, df, learn=False):
        """Shared cleaning for fit/transform.

        With `learn=True` (used by fit) the datetime-looking text columns are
        detected and remembered. Otherwise exactly the remembered columns are
        parsed, so a small or skewed batch cannot change how a column is encoded.
        """
        d = df.copy()
        for c in d.columns:
            if d[c].dtype == "object":
                d[c] = d[c].map(normalize_ws)

        if learn:
            already_dt = {c for c in d.columns if pd.api.types.is_datetime64_any_dtype(d[c])}
            d = parse_possible_datetimes(d)
            self.dt_parsed_cols_ = [
                c for c in d.columns
                if c not in already_dt and pd.api.types.is_datetime64_any_dtype(d[c])
            ]
        else:
            learned = getattr(self, "dt_parsed_cols_", None)
            if learned is None:
                # Instance pickled before dt_parsed_cols_ existed: legacy behaviour.
                d = parse_possible_datetimes(d)
            else:
                for c in learned:
                    if c in d.columns and not pd.api.types.is_datetime64_any_dtype(d[c]):
                        d[c] = pd.to_datetime(d[c], errors="coerce")

        d = expand_datetimes(d)
        for c in d.columns:
            if pd.api.types.is_bool_dtype(d[c]):
                d[c] = d[c].astype("int8")
        d = downcast_numeric(d)
        return d

    def fit(self, X: pd.DataFrame):
        """Learn column roles, numeric medians, categories and feature order from `X`."""
        d = self._prep_base(X, learn=True)
        self.num_cols_  = [c for c in d.columns if pd.api.types.is_numeric_dtype(d[c])]
        self.cat_cols_  = [c for c in d.columns if c not in self.num_cols_]

        # Reset learned state so re-fitting never keeps stale entries.
        self.num_median_ = {}
        self.cat_categories_ = {}

        # Cast to float64 first: nullable-integer columns (the expanded date
        # parts) return pd.NA for an all-missing column, which float() rejects.
        for c in self.num_cols_:
            as_float = pd.to_numeric(d[c], errors="coerce").astype("float64")
            self.num_median_[c] = float(as_float.median())

        for c in self.cat_cols_:
            s = d[c].astype("string").fillna("Missing")
            categories = pd.Index(pd.unique(s.dropna()))
            if "Missing" not in categories:
                categories = categories.insert(len(categories), "Missing")
            self.cat_categories_[c] = categories.tolist()

        self.feature_names_ = self.num_cols_ + self.cat_cols_
        self.fitted_ = True
        return self

    def transform(self, X: pd.DataFrame):
        """Return `X` with columns `feature_names_`, imputed and categorical-encoded.

        Numeric gaps are filled with the median learned in fit (not the median
        of the batch being transformed); unseen or missing categories become
        'Missing'. Raises AssertionError if fit() was not called.
        """
        assert self.fitted_, "Preprocessor not fitted."
        d = self._prep_base(X)
        for c in self.feature_names_:
            if c not in d.columns:
                d[c] = np.nan
        d = d[self.feature_names_].copy()

        medians = getattr(self, "num_median_", None) or {}
        for c in self.num_cols_:
            col = pd.to_numeric(d[c], errors="coerce")
            if col.isna().any():
                # float64 first: fillna with a fractional median raises on
                # nullable Int8/Int16 columns.
                col = col.astype("float64")
                # Legacy pickles without learned medians fall back to the batch median.
                fill = medians[c] if c in medians else col.median()
                col = col.fillna(fill)
            d[c] = col

        for c in self.cat_cols_:
            allowed = self.cat_categories_[c]
            d[c] = pd_cat_fix(d[c], allowed)
        return d

# ------------------------ data load & target ------------------------
df_raw = pd.read_csv(DATA_CSV, low_memory=False)
tgt = resolve_target_name(df_raw)
if tgt is None:
    raise KeyError("Could not find target column. Expected 'target_risk_class' or 'risk score'.")

# sanitize target: keep only rows whose label is a whole number in 1..NUM_CLASSES.
# `isin` drops NaN, out-of-range and non-integer values (e.g. 3.5) in one step;
# the previous astype("Int64") raised on non-integer labels and the range was
# hard-coded to 10 instead of following NUM_CLASSES.
y_num = pd.to_numeric(df_raw[tgt], errors="coerce")
mask_labeled = y_num.isin(np.arange(1, NUM_CLASSES + 1))
df_raw = df_raw.loc[mask_labeled].copy()
y_all_1_10 = y_num.loc[mask_labeled].astype("int16")

# ---- shift to zero-based labels (0..NUM_CLASSES-1) for LightGBM ----
y_all = (y_all_1_10 - 1).astype("int16")

X_all = df_raw.drop(columns=[tgt])

# ------------------------ split (stratified 70/15/15) ------------------------
# First split: hold out TEST_SIZE of all labelled rows as the test set.
sss1 = StratifiedShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
trainval_idx, test_idx = next(sss1.split(X_all, y_all))
X_trainval, X_test = X_all.iloc[trainval_idx], X_all.iloc[test_idx]
y_trainval, y_test = y_all.iloc[trainval_idx], y_all.iloc[test_idx]

# Second split: VAL_SIZE is a fraction of the whole dataset, so rescale it to
# a fraction of the remaining train+val pool.
val_rel = VAL_SIZE / (1.0 - TEST_SIZE)
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=val_rel, random_state=SEED)
train_idx, val_idx = next(sss2.split(X_trainval, y_trainval))
X_train, X_val = X_trainval.iloc[train_idx], X_trainval.iloc[val_idx]
y_train, y_val = y_trainval.iloc[train_idx], y_trainval.iloc[val_idx]

# save split indices (index labels of the labelled subset, not row positions)
pd.Series(X_train.index, name="index").to_csv(SPLIT_DIR/"train_indices.csv", index=False)
pd.Series(X_val.index,   name="index").to_csv(SPLIT_DIR/"val_indices.csv",   index=False)
pd.Series(X_test.index,  name="index").to_csv(SPLIT_DIR/"test_indices.csv",  index=False)

print(f"Shapes -> train: {X_train.shape}, val: {X_val.shape}, test: {X_test.shape}")

# ------------------------ SUBSET-BASED tuning set ------------------------
# Pick a stratified subset of TRAIN to reduce memory during Optuna.
# It must come from TRAIN only: the full VAL split is the hold-out used for
# early stopping and pruning during tuning, so sampling from the train+val
# pool (as before) put validation rows into the tuning training data.
sss_sub = StratifiedShuffleSplit(n_splits=1, train_size=SUBSET_FRAC, random_state=SEED)
subset_idx, _ = next(sss_sub.split(X_train, y_train))
X_train_sub = X_train.iloc[subset_idx]
y_train_sub = y_train.iloc[subset_idx]
print(f"Tuning subset: {X_train_sub.shape} from train {X_train.shape}")

# ------------------------ preprocessing ------------------------
# Preprocessor for tuning subset
pp_tune = TabularPreprocessor().fit(X_train_sub)
Xtr_sub = pp_tune.transform(X_train_sub)   # tuning train
Xva_sub = pp_tune.transform(X_val)         # use full VAL as holdout during tuning

# Preprocessor for FINAL model (fit on full TRAIN)
pp = TabularPreprocessor().fit(X_train)
Xtr = pp.transform(X_train)
Xva = pp.transform(X_val)
Xte = pp.transform(X_test)

# Positional indices of the categorical columns.
# isinstance(..., pd.CategoricalDtype) replaces the deprecated
# pd.api.types.is_categorical_dtype.
cat_idx_sub = [i for i, dt in enumerate(Xtr_sub.dtypes) if isinstance(dt, pd.CategoricalDtype)]
cat_idx = [i for i, dt in enumerate(Xtr.dtypes) if isinstance(dt, pd.CategoricalDtype)]

# class weights -> sample weights (zero-index labels), computed from full TRAIN
classes_present = np.unique(y_train)
class_weights = compute_class_weight(class_weight="balanced", classes=classes_present, y=y_train)
class_weight_map = {int(c): float(w) for c, w in zip(classes_present, class_weights)}
w_train_sub = y_train_sub.map(class_weight_map).astype("float32")
w_train = y_train.map(class_weight_map).astype("float32")

# LightGBM Dataset objects for TUNING
dtrain_sub = lgb.Dataset(
    Xtr_sub, label=y_train_sub.values, weight=w_train_sub.values,
    categorical_feature=cat_idx_sub, free_raw_data=False
)
dval_sub = lgb.Dataset(
    Xva_sub, label=y_val.values,
    categorical_feature=cat_idx_sub, reference=dtrain_sub, free_raw_data=False
)
# The former `dtrain_sub.save_binary = True` / `dval_sub.save_binary = True`
# lines were removed: `save_binary` is a Dataset *method*, so the assignment
# only replaced that method with a bool and configured nothing.

# ------------------------ Optuna tuning (on subset) ------------------------
def objective(trial: optuna.Trial):
    """Optuna objective: train LightGBM on the tuning subset and score it on VAL.

    Hyperparameters are sampled from ``trial``; the model is trained on
    ``dtrain_sub`` and monitored on ``dval_sub`` under the name ``VALID_NAME``.

    Args:
        trial: Optuna trial used to sample hyperparameters. It is also handed
            to the LightGBM pruning callback, which may stop the trial early.

    Returns:
        float: ``METRIC_NAME`` on ``VALID_NAME`` as stored in
        ``Booster.best_score``; if that entry is missing, the last value
        recorded for the metric during training.
    """
    params = {
        "objective": "multiclass",
        "num_class": NUM_CLASSES,
        "metric": METRIC_NAME,
        "verbosity": -1,
        "boosting": trial.suggest_categorical("boosting", ["gbdt", "dart"]),
        "learning_rate": trial.suggest_float("learning_rate", 0.02, 0.15, log=True),
        "num_leaves": trial.suggest_int("num_leaves", 31, 127, step=8),
        "max_depth": trial.suggest_int("max_depth", -1, 12),
        "min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 300, 1000),
        "feature_fraction": trial.suggest_float("feature_fraction", 0.6, 0.9),
        "bagging_fraction": trial.suggest_float("bagging_fraction", 0.6, 0.9),
        "bagging_freq": trial.suggest_int("bagging_freq", 1, 5),
        "lambda_l1": trial.suggest_float("lambda_l1", 1e-8, 5.0, log=True),
        "lambda_l2": trial.suggest_float("lambda_l2", 1e-8, 10.0, log=True),
        "min_gain_to_split": trial.suggest_float("min_gain_to_split", 0.0, 2.0),
        "extra_trees": trial.suggest_categorical("extra_trees", [False, True]),
        "max_bin": trial.suggest_int("max_bin", 63, 127),
        "seed": SEED,
        "force_row_wise": True,
        "deterministic": True,
        "num_threads": NUM_THREADS,
    }
    # drop_rate is only meaningful for DART, so it is sampled conditionally.
    if params["boosting"] == "dart":
        params["drop_rate"] = trial.suggest_float("drop_rate", 0.05, 0.2)

    # Per-iteration metric history. A Booster returned by lgb.train has no
    # `evals_result_` attribute (that belongs to the sklearn wrappers), so the
    # history must be captured explicitly with record_evaluation.
    evals_result = {}
    callbacks = [
        lgb.early_stopping(stopping_rounds=EARLY_STOP_ROUNDS, verbose=False),
        lgb.log_evaluation(period=0),
        lgb.record_evaluation(evals_result),
        # IMPORTANT: valid_name matches valid_names passed to lgb.train
        optuna.integration.LightGBMPruningCallback(trial, METRIC_NAME, valid_name=VALID_NAME),
    ]

    model = lgb.train(
        params,
        dtrain_sub,
        num_boost_round=N_ESTIMATORS,
        valid_sets=[dval_sub],
        valid_names=[VALID_NAME],
        callbacks=callbacks,
    )

    # Prefer the score LightGBM stored for the same VALID_NAME/METRIC_NAME;
    # otherwise fall back to the last recorded evaluation.
    score = model.best_score.get(VALID_NAME, {}).get(METRIC_NAME)
    if score is None:
        score = evals_result[VALID_NAME][METRIC_NAME][-1]
    return float(score)

# Minimise the validation metric returned by `objective`. The TPE sampler is
# seeded with SEED; MedianPruner ignores the first 10 reported steps of a trial.
study = optuna.create_study(
    direction="minimize",
    sampler=TPESampler(seed=SEED),
    pruner=MedianPruner(n_warmup_steps=10),
)
# The search stops after N_TRIALS trials or TIMEOUT_SEC seconds.
study.optimize(objective, n_trials=N_TRIALS, timeout=TIMEOUT_SEC, show_progress_bar=False)

# Take the sampled hyperparameters of the best trial and add back the fixed
# runtime settings that `objective` sets but Optuna does not sample.
best_params = study.best_params
best_params.update({"num_threads": NUM_THREADS, "force_row_wise": True, "deterministic": True})
# Persist the merged parameter set next to the other run artifacts.
with open(OUTPUT_DIR / "best_params.json", "w", encoding="utf-8") as f:
    json.dump(best_params, f, ensure_ascii=False, indent=2)
print("Best params:", json.dumps(best_params, indent=2, ensure_ascii=False))
# NOTE(review): the label is hard-coded; the value is whatever METRIC_NAME is.
print("Best val multi_logloss:", study.best_value)

# ------------------------ Final training on FULL TRAIN (validate on VAL) ------------------------
# Build the final datasets from FULL train/val (preprocessed with pp).
dtrain = lgb.Dataset(
    Xtr, label=y_train.values, weight=w_train.values,
    categorical_feature=cat_idx, free_raw_data=False
)
dval = lgb.Dataset(
    Xva, label=y_val.values,
    categorical_feature=cat_idx, reference=dtrain, free_raw_data=False
)
# NOTE: the former `dtrain.save_binary = True` / `dval.save_binary = True`
# assignments were removed: `save_binary` is a Dataset *method*, so assigning a
# bool only shadowed it on the instance and cached nothing.

# Fixed settings first, tuned values last: `best_params` wins on any key
# collision because it is unpacked at the end.
final_params = {
    "objective": "multiclass",
    "num_class": NUM_CLASSES,
    "metric": METRIC_NAME,
    "verbosity": -1,
    "seed": SEED,
    **best_params,
}

# Early stopping monitors VAL; progress is logged every 100 rounds.
callbacks = [
    lgb.early_stopping(stopping_rounds=EARLY_STOP_ROUNDS, verbose=True),
    lgb.log_evaluation(period=100),
]

final_model = lgb.train(
    final_params,
    dtrain,
    num_boost_round=N_ESTIMATORS,
    valid_sets=[dval],
    valid_names=[VALID_NAME],
    callbacks=callbacks,
)
# Iteration count later passed to `predict(num_iteration=...)`.
best_iter = final_model.best_iteration
print("Best iteration:", best_iter)

# ------------------------ Evaluation ------------------------
def eval_split(name, Xs: pd.DataFrame, ys_zero: pd.Series):
    """Score ``final_model`` on one split and write its reports to OUTPUT_DIR.

    Args:
        name: split label; used in the metrics row and in the report file names.
        Xs: raw (un-preprocessed) features; transformed here with ``pp``.
        ys_zero: zero-indexed labels (0..NUM_CLASSES-1) aligned with ``Xs``.

    Returns:
        tuple: ``(metrics, pred_one, proba)`` where ``metrics`` is a dict of
        scalar scores, ``pred_one`` holds predictions shifted to the 1-based
        label space and ``proba`` is the raw model output.

    Side effects:
        Writes ``classification_report_<name>.txt`` and
        ``confusion_matrix_<name>.csv`` (both in 1-based label space).
    """
    zero_based = list(range(NUM_CLASSES))
    one_based = list(range(1, NUM_CLASSES+1))

    features = pp.transform(Xs)
    proba = final_model.predict(features, num_iteration=best_iter)
    pred_zero = np.argmax(proba, axis=1)

    # Metrics are computed in zero-index space, matching the proba columns.
    metrics = {"split": name, "n_samples": int(len(ys_zero))}
    metrics["accuracy"] = float(accuracy_score(ys_zero, pred_zero))
    for avg in ("macro", "weighted"):
        prec, rec, fscore, _support = precision_recall_fscore_support(
            ys_zero, pred_zero, average=avg, zero_division=0
        )
        metrics.update({
            f"precision_{avg}": float(prec),
            f"recall_{avg}": float(rec),
            f"f1_{avg}": float(fscore),
        })

    # Probabilistic metrics are best-effort: any failure is recorded as NaN.
    try:
        loss_value = float(log_loss(ys_zero, proba, labels=zero_based))
    except Exception:
        loss_value = np.nan
    metrics["log_loss"] = loss_value

    try:
        onehot = pd.get_dummies(pd.Categorical(ys_zero, categories=zero_based))
        auc_value = float(roc_auc_score(onehot.values, proba, average="macro", multi_class="ovr"))
    except Exception:
        auc_value = np.nan
    metrics["roc_auc_ovr_macro"] = auc_value

    # Human-readable reports use the 1-based label space.
    ys_one = ys_zero + 1
    pred_one = pred_zero + 1
    report_text = classification_report(
        ys_one, pred_one, labels=one_based, zero_division=0, output_dict=False
    )
    (OUTPUT_DIR / f"classification_report_{name}.txt").write_text(report_text, encoding="utf-8")

    cm = confusion_matrix(ys_one, pred_one, labels=one_based)
    cm_frame = pd.DataFrame(cm, index=range(1, NUM_CLASSES+1), columns=range(1, NUM_CLASSES+1))
    cm_frame.to_csv(OUTPUT_DIR / f"confusion_matrix_{name}.csv")
    return metrics, pred_one, proba

# Evaluate the final model on each split. Each call also writes that split's
# classification report and confusion matrix to OUTPUT_DIR.
metrics_train, yhat_train_1_10, proba_train = eval_split("train", X_train, y_train)
metrics_val,   yhat_val_1_10,   proba_val   = eval_split("val",   X_val,   y_val)
metrics_test,  yhat_test_1_10,  proba_test  = eval_split("test",  X_test,  y_test)

# One row per split, saved as a single CSV.
metrics_df = pd.DataFrame([metrics_train, metrics_val, metrics_test])
metrics_df.to_csv(OUTPUT_DIR / "metrics_lightgbm_optuna.csv", index=False)
print(metrics_df)

# ------------------------ Feature importance ------------------------
# Gain-based importance, sorted from most to least important.
imp = pd.DataFrame({
    "feature": final_model.feature_name(),
    "gain": final_model.feature_importance(importance_type="gain")
}).sort_values("gain", ascending=False)
imp.to_csv(OUTPUT_DIR / "feature_importance_gain.csv", index=False)

# Figure height scales with the number of bars, clamped to the range [4, 16].
plt.figure(figsize=(10, max(4, min(16, len(imp.head(40)) * 0.25))))
# Reverse the top 40 so the largest bar ends up at the top of the barh plot.
topk = imp.head(40).iloc[::-1]
plt.barh(topk["feature"], topk["gain"])
plt.title("LightGBM Feature Importance (gain) - Top 40")
plt.tight_layout()
plt.savefig(OUTPUT_DIR / "feature_importance_gain_top40.png", dpi=150)
plt.close()

# ------------------------ Save artifacts ------------------------
# Model and preprocessor are pickled via joblib; the inference helper loads
# them from these same paths.
joblib.dump(final_model, OUTPUT_DIR / "lightgbm_booster.pkl")
joblib.dump(pp,           OUTPUT_DIR / "preprocessor.pkl")
# Run metadata: target column, label space, best iteration and split files.
with open(OUTPUT_DIR / "training_meta.json", "w", encoding="utf-8") as f:
    json.dump({
        "target_name": tgt,
        "num_classes": NUM_CLASSES,
        "label_order_zero_indexed": list(range(NUM_CLASSES)),  # zero-indexed labels used for training
        "best_iteration": int(best_iter),
        "seed": SEED,
        "num_threads": NUM_THREADS,
        "splits": {"train": "splits/train_indices.csv", "val": "splits/val_indices.csv", "test": "splits/test_indices.csv"},
        "data_csv": DATA_CSV
    }, f, ensure_ascii=False, indent=2)

print(f"\nArtifacts saved in: {OUTPUT_DIR.resolve()}")

# ------------------------ Inference helper ------------------------
def predict_target_risk_class(df_new: pd.DataFrame,
                              model_path=OUTPUT_DIR / "lightgbm_booster.pkl",
                              preproc_path=OUTPUT_DIR / "preprocessor.pkl") -> pd.Series:
    """Predict the risk class for new rows with the saved booster.

    Args:
        df_new: raw frame with the same columns as the training data BEFORE
            preprocessing.
        model_path: joblib file holding the trained booster. The default is
            bound to this run's OUTPUT_DIR when the function is defined.
        preproc_path: joblib file holding the fitted preprocessor.

    Returns:
        pd.Series: predicted labels shifted back to the 1-based label space,
        indexed like ``df_new`` and named ``pred_target_risk_class``.
    """
    # SECURITY: joblib.load unpickles arbitrary objects; only load artifacts
    # produced by a trusted training run.
    loaded_booster = joblib.load(model_path)
    loaded_preproc = joblib.load(preproc_path)

    features = loaded_preproc.transform(df_new)
    # Use the stored best iteration when the booster exposes one.
    n_iter = getattr(loaded_booster, "best_iteration", None)
    class_proba = loaded_booster.predict(features, num_iteration=n_iter)

    labels_one_based = np.argmax(class_proba, axis=1) + 1  # map back to 1-based labels
    return pd.Series(labels_one_based, index=df_new.index, name="pred_target_risk_class")

#### CatBoost

In [None]:
!pip install catboost

Collecting catboost
  Downloading catboost-1.2.8-cp312-cp312-manylinux2014_x86_64.whl.metadata (1.2 kB)
Downloading catboost-1.2.8-cp312-cp312-manylinux2014_x86_64.whl (99.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m99.2/99.2 MB[0m [31m22.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: catboost
Successfully installed catboost-1.2.8


In [None]:
# ===============================================================
# CatBoost (GPU / A100) Multiclass Pipeline (1..10 -> 0..9 labels)
# - Colab-ready: saves under RESULTS/CatBoost_GPU/<timestamp>
# - Robust preprocessing (datetime expansion, numeric downcast, strings->object)
# - CatBoost native categorical hashing (handles unseen categories)
# - Stratified 70/15/15 split + saved indices
# - Optuna tuning (GPU-safe search space)
# - Early stopping via od_type='Iter'
# - Metrics: accuracy, macro/weighted P/R/F1, log loss, ROC-AUC OvR macro
# - Artifacts: model.cbm, preprocessor.pkl, best_params.json, metrics.csv, reports, confusions, importances
# - Inference helper returns labels in 1..10
# ===============================================================

import os, re, json, warnings, datetime
from pathlib import Path
import numpy as np
import pandas as pd
import pandas.api.types as pdt

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, roc_auc_score,
    log_loss, classification_report, confusion_matrix
)
from sklearn.utils.class_weight import compute_class_weight

from catboost import CatBoostClassifier, Pool
import optuna
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner

import joblib
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore")

# ------------------------ COLAB DRIVE SETUP ------------------------
# Base folder under which this run's outputs are written.
drive_folder = 'path'
# Best-effort Drive mount: attempted only when /content/drive is missing or
# empty. Any failure (including google.colab not being importable) is silently
# ignored, so the cell keeps running without a mount.
try:
    from google.colab import drive as _colab_drive
    if not Path("/content/drive").exists() or not any(Path("/content/drive").iterdir()):
        _colab_drive.mount("/content/drive")
except Exception:
    pass

# Subfolder for this run (timestamped so repeated runs never overwrite each other)
RUN_NAME = f"CatBoost_GPU_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
OUTPUT_DIR = Path(drive_folder) / "CatBoost_GPU" / RUN_NAME
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Split index CSVs are written here.
SPLIT_DIR  = OUTPUT_DIR / "splits"; SPLIT_DIR.mkdir(exist_ok=True)

print(f"Saving all outputs to: {OUTPUT_DIR.resolve()}")

# ------------------------ CONFIG ------------------------
SEED = 42
np.random.seed(SEED)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"   # expose only GPU 0 (A100 in Colab)

# >>>>>>>>>>>> SET THIS TO YOUR FILE <<<<<<<<<<<<
DATA_CSV = "path"

# Candidate target column names tried first by resolve_target_name().
TARGET_ASCII   = "target_risk_class"
TARGET_PERSIAN = "path"
NUM_CLASSES = 10

# Fractions of the FULL dataset held out for test and validation.
TEST_SIZE = 0.15
VAL_SIZE  = 0.15

# Tuning / training controls
N_TRIALS            = 50          # increase if you want deeper search
N_ESTIMATORS_LIMIT  = 5000        # upper cap; early stopping stops earlier
EARLY_STOP_ROUNDS   = 400         # passed to CatBoost as od_wait
NTHREAD             = os.cpu_count()
TIMEOUT_SEC         = None        # no wall-clock limit on the Optuna study
REFIT_ON_TRAINVAL   = True        # final model is trained on TRAIN+VAL when True

# Placeholder written into categorical cells that are missing.
MISSING_TOKEN = "Missing"

# ------------------------ UTILITIES ------------------------
def normalize_ws(x):
    """Clean up whitespace in a single cell value.

    Strings lose ZWNJ / RTL / LTR marks, have non-breaking spaces turned into
    plain spaces, runs of whitespace collapsed to one space, and both ends
    trimmed. Non-string values are returned unchanged.
    """
    if not isinstance(x, str):
        return x
    without_marks = re.sub(r"[\u200c\u200f\u200e]", "", x)   # Persian ZWNJ/RTL/LTR marks
    plain_spaces = without_marks.replace("\u00a0", " ")
    return re.sub(r"\s+", " ", plain_spaces).strip()

def downcast_numeric(df):
    """Return a copy of ``df`` with numeric columns downcast to smaller dtypes.

    Only columns of dtype float64/int64/int32/float32 are touched: float
    columns are downcast with ``downcast="float"``, the others with
    ``downcast="integer"``. The input frame is not modified.
    """
    out = df.copy()
    candidates = out.select_dtypes(include=["float64","int64","int32","float32"]).columns
    for col in candidates:
        target_kind = "float" if pdt.is_float_dtype(out[col]) else "integer"
        out[col] = pd.to_numeric(out[col], downcast=target_kind)
    return out

def parse_possible_datetimes(df):
    """Return a copy of ``df`` with date-like object columns parsed to datetimes.

    An object column is treated as date-like when more than 20% of its values
    (rendered as strings) contain a ``YYYY-MM-DD`` or ``YYYY/MM/DD`` style
    pattern. Such columns are converted with ``pd.to_datetime``; values that
    cannot be parsed become ``NaT``. Other columns are left untouched.

    Args:
        df: input frame; it is not modified.

    Returns:
        pd.DataFrame: the converted copy.
    """
    df = df.copy()
    for c in df.columns:
        if df[c].dtype == object:
            s = df[c].astype(object)
            if s.astype(str).str.contains(r"\d{4}[-/]\d{1,2}[-/]\d{1,2}", regex=True, na=False).mean() > 0.2:
                try:
                    # `infer_datetime_format` is deliberately not passed: it is
                    # deprecated, and once pandas rejects it the resulting
                    # TypeError would be swallowed below, silently disabling
                    # date parsing for every column.
                    df[c] = pd.to_datetime(s, errors="coerce")
                except Exception:
                    # Best effort: a column that cannot be converted stays as-is.
                    pass
    return df

def expand_datetimes(df):
    """Return a copy of ``df`` with datetime columns replaced by calendar parts.

    For every datetime column ``c`` the copy gains ``c__year``, ``c__month``,
    ``c__day``, ``c__dow``, ``c__hour``, ``c__mstart`` and ``c__mend`` (nullable
    integer dtypes; a missing hour is filled with 0), and ``c`` itself is
    dropped. The input frame is not modified.
    """
    out = df.copy()
    datetime_cols = [name for name in out.columns if pdt.is_datetime64_any_dtype(out[name])]
    if not datetime_cols:
        return out

    for name in datetime_cols:
        parts = out[name].dt
        derived = {
            f"{name}__year":   parts.year.astype("Int16"),
            f"{name}__month":  parts.month.astype("Int8"),
            f"{name}__day":    parts.day.astype("Int8"),
            f"{name}__dow":    parts.dayofweek.astype("Int8"),
            f"{name}__hour":   parts.hour.fillna(0).astype("Int8"),
            f"{name}__mstart": parts.is_month_start.astype("Int8"),
            f"{name}__mend":   parts.is_month_end.astype("Int8"),
        }
        for new_col, values in derived.items():
            out[new_col] = values

    return out.drop(columns=datetime_cols)

def resolve_target_name(df):
    """Find the target column of ``df``.

    Exact column names are tried first (the configured ASCII/Persian names,
    then a few generic ones). If none matches, candidates are compared
    case-insensitively after collapsing runs of underscores, hyphens and
    whitespace into a single space.

    Returns:
        The matching column label as it appears in ``df``, or ``None``.
    """
    exact_candidates = (TARGET_ASCII, TARGET_PERSIAN, "risk_class", "label", "target", "class", "y")
    for candidate in exact_candidates:
        if candidate in df.columns:
            return candidate

    def _fold(text):
        # Separator-insensitive, case-insensitive comparison key.
        return re.sub(r"[_\-\s]+"," ", text).strip().lower()

    # When two columns fold to the same key, the later column wins.
    folded_to_column = {_fold(str(col)): col for col in df.columns}
    loose_candidates = ("target_risk_class", "طبقه خطر", "risk class", "label", "target", "class", "y")
    for candidate in loose_candidates:
        key = _fold(candidate)
        if key in folded_to_column:
            return folded_to_column[key]
    return None

# ------------------------ PREPROCESSOR (CatBoost-friendly) ------------------------
class TabularPreprocessor:
    """
    CatBoost-friendly tabular preprocessor.

    - Expands datetimes, normalizes whitespace
    - Numerics -> float32 with median impute (medians learned in ``fit``)
    - Categoricals -> object column of Python ``str`` with Missing filled
      (CatBoost hashes them)
    - Keeps fixed feature order; exposes cat feature indices for CatBoost Pool
    """
    def __init__(self):
        self.num_cols_  = []        # numeric feature names, in training order
        self.cat_cols_  = []        # non-numeric feature names, in training order
        self.num_median_ = {}       # numeric column -> median seen during fit
        self.feature_names_ = []    # output column order: numerics, then categoricals
        self.fitted_ = False

    def _prep_base(self, df):
        """Shared cleaning applied in both ``fit`` and ``transform``.

        Works on a copy: normalizes whitespace in object columns, parses and
        expands date-like columns, converts bools to int8 and downcasts
        numerics.
        """
        d = df.copy()
        for c in d.columns:
            if d[c].dtype == object:
                d[c] = d[c].map(normalize_ws)
        d = parse_possible_datetimes(d)
        d = expand_datetimes(d)
        for c in d.columns:
            if pdt.is_bool_dtype(d[c]):
                d[c] = d[c].astype("int8")
        d = downcast_numeric(d)
        return d

    def fit(self, X):
        """Learn the column layout and numeric medians from ``X``.

        Returns:
            self, so calls can be chained.
        """
        d = self._prep_base(X)
        self.num_cols_ = [c for c in d.columns if pdt.is_numeric_dtype(d[c])]
        self.cat_cols_ = [c for c in d.columns if not pdt.is_numeric_dtype(d[c])]
        # Rebuilt from scratch so that re-fitting never keeps medians of
        # columns that belonged to a previous fit.
        self.num_median_ = {
            c: float(pd.to_numeric(d[c], errors="coerce").median())
            for c in self.num_cols_
        }
        self.feature_names_ = self.num_cols_ + self.cat_cols_
        self.fitted_ = True
        return self

    def transform(self, X):
        """Apply the fitted layout to ``X``.

        Columns seen in ``fit`` but absent from ``X`` are created as missing;
        extra columns are dropped. Numerics become float32 with the fitted
        median filled in; categoricals become ``str`` values with
        ``MISSING_TOKEN`` in place of missing cells.

        Raises:
            AssertionError: if called before ``fit``.
        """
        assert self.fitted_, "TabularPreprocessor.transform called before fit"
        d = self._prep_base(X)
        # ensure all training features exist
        for c in self.feature_names_:
            if c not in d.columns:
                d[c] = np.nan
        d = d[self.feature_names_].copy()
        # numerics -> float32 + median
        for c in self.num_cols_:
            d[c] = pd.to_numeric(d[c], errors="coerce").astype("float32")
            if d[c].isna().any():
                d[c] = d[c].fillna(self.num_median_[c])
        # categoricals -> object column of str; fill Missing.
        # The explicit str cast matters: an object column can still carry
        # non-string values (e.g. numbers where a column that was categorical
        # at fit time is numeric in X), which CatBoost rejects as cat features.
        for c in self.cat_cols_:
            s = d[c].astype(object)
            filled = s.where(pd.notna(s), MISSING_TOKEN)
            d[c] = filled.map(str).astype(object)
        return d

    def cat_feature_indices(self):
        """Positions of the categorical columns within ``feature_names_``."""
        return [self.feature_names_.index(c) for c in self.cat_cols_]

# ------------------------ LOAD & TARGET ------------------------
df = pd.read_csv(DATA_CSV, low_memory=False)
tgt = resolve_target_name(df)
if tgt is None:
    # NOTE(review): the message names 'risk sore', which looks like a typo — confirm the intended column name.
    raise KeyError("Could not find target column. Expected 'target_risk_class' or 'risk sore'.")

# Coerce the target to integers and keep only rows whose label is in 1..10;
# everything else (non-numeric, out of range, missing) is dropped.
y_1_10 = pd.to_numeric(df[tgt], errors="coerce").astype("Int64")
y_1_10 = y_1_10.where((y_1_10>=1) & (y_1_10<=10))
mask = y_1_10.notna()
df = df.loc[mask].copy()
y_1_10 = y_1_10.loc[mask].astype("int16")
y = (y_1_10 - 1).astype("int16")  # 0..9

X = df.drop(columns=[tgt])

# ------------------------ SPLIT (70/15/15) ------------------------
# First carve out TEST, stratified on the label.
sss1 = StratifiedShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
trainval_idx, test_idx = next(sss1.split(X, y))
X_trainval, X_test = X.iloc[trainval_idx], X.iloc[test_idx]
y_trainval, y_test = y.iloc[trainval_idx], y.iloc[test_idx]

# VAL_SIZE is a fraction of the full data, so rescale it to the remaining pool.
val_rel = VAL_SIZE / (1.0 - TEST_SIZE)
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=val_rel, random_state=SEED)
train_idx, val_idx = next(sss2.split(X_trainval, y_trainval))
X_train, X_val = X_trainval.iloc[train_idx], X_trainval.iloc[val_idx]
y_train, y_val = y_trainval.iloc[train_idx], y_trainval.iloc[val_idx]

# Save indices (index labels of the filtered frame, one CSV per split)
pd.Series(X_train.index, name="index").to_csv(SPLIT_DIR/"train_indices.csv", index=False)
pd.Series(X_val.index,   name="index").to_csv(SPLIT_DIR/"val_indices.csv",   index=False)
pd.Series(X_test.index,  name="index").to_csv(SPLIT_DIR/"test_indices.csv",  index=False)

print(f"Shapes -> train: {X_train.shape}, val: {X_val.shape}, test: {X_test.shape}")

# ------------------------ PREPROCESS ------------------------
# The preprocessor is fitted on TRAIN only and then applied to every split.
pp = TabularPreprocessor().fit(X_train)
Xtr = pp.transform(X_train)
Xva = pp.transform(X_val)
Xte = pp.transform(X_test)
cat_idx = pp.cat_feature_indices()

# class-balanced weights (length = NUM_CLASSES for 0..9)
# Classes absent from TRAIN get a neutral weight of 1.0.
classes_present = np.unique(y_train)
class_weights_arr = compute_class_weight(class_weight="balanced", classes=classes_present, y=y_train)
cw_map = {int(c): float(w) for c, w in zip(classes_present, class_weights_arr)}
class_weights = [cw_map.get(k, 1.0) for k in range(NUM_CLASSES)]

# Pools (each Pool is built with the categorical feature indices)
train_pool = Pool(Xtr, label=y_train, cat_features=cat_idx)
val_pool   = Pool(Xva, label=y_val,   cat_features=cat_idx)
test_pool  = Pool(Xte, label=y_test,  cat_features=cat_idx)

# ------------------------ OPTUNA TUNING ------------------------
def suggest_cat_params(trial):
    """Sample one CatBoost parameter set for an Optuna trial.

    The bootstrap type is sampled first because it decides which sampling
    knob is searched: ``bagging_temperature`` for Bayesian, ``subsample`` for
    Bernoulli. The remaining tuned values are learning rate, depth, L2
    regularisation and random strength; everything else is fixed.

    Returns:
        dict: keyword arguments for ``CatBoostClassifier``.
    """
    # GPU-safe bootstrap options for multiclass
    bootstrap_type = trial.suggest_categorical("bootstrap_type", ["Bayesian", "Bernoulli"])

    # Exactly one sampling knob is attached, depending on the bootstrap type.
    if bootstrap_type == "Bayesian":
        sampling_kwargs = {"bagging_temperature": trial.suggest_float("bagging_temperature", 0.0, 2.0)}
    else:  # Bernoulli
        sampling_kwargs = {"subsample": trial.suggest_float("subsample", 0.6, 1.0)}

    params = {
        "loss_function": "MultiClass",
        "classes_count": NUM_CLASSES,
        "eval_metric": "MultiClass",
    }

    # Tuned core parameters (sampled in this order).
    params["learning_rate"] = trial.suggest_float("learning_rate", 0.02, 0.2, log=True)
    params["depth"] = trial.suggest_int("depth", 4, 10)
    params["l2_leaf_reg"] = trial.suggest_float("l2_leaf_reg", 1e-3, 100.0, log=True)
    params["random_strength"] = trial.suggest_float("random_strength", 0.0, 2.0)

    params.update({
        # GPU
        "task_type": "GPU",
        "devices": "0",
        "gpu_ram_part": 0.95,

        # Sampling
        "bootstrap_type": bootstrap_type,
        "one_hot_max_size": 128,

        # Optional speedup for large data
        "border_count": 128,

        "random_seed": SEED,
        "thread_count": NTHREAD,
        "verbose": False,
        "allow_writing_files": False,

        # Iteration cap plus overfitting detector on the eval set
        "iterations": N_ESTIMATORS_LIMIT,
        "od_type": "Iter",
        "od_wait": EARLY_STOP_ROUNDS,
        "use_best_model": True,
        "class_weights": class_weights,
    })
    params.update(sampling_kwargs)

    # NOTE: Do NOT set `rsm` on GPU multiclass (unsupported).
    # NOTE: Do NOT use `MVS` on GPU multiclass (unsupported).
    return params

def objective(trial):
    """Optuna objective: fit CatBoost on TRAIN and return the log loss on VAL.

    No intermediate values are reported to the trial; the returned log loss is
    the only signal Optuna receives.
    """
    clf = CatBoostClassifier(**suggest_cat_params(trial))
    clf.fit(train_pool, eval_set=val_pool, verbose=False)
    val_proba = clf.predict_proba(val_pool)
    all_labels = list(range(NUM_CLASSES))
    return log_loss(y_val, val_proba, labels=all_labels)

# Minimise validation log loss with a seeded TPE sampler.
# NOTE(review): `objective` never calls trial.report(), so the MedianPruner
# configured here has no intermediate values to act on.
study = optuna.create_study(direction="minimize",
                            sampler=TPESampler(seed=SEED),
                            pruner=MedianPruner(n_warmup_steps=10))
study.optimize(objective, n_trials=N_TRIALS, timeout=TIMEOUT_SEC, show_progress_bar=False)

# Copy so the guard below does not mutate the study's own record.
best = study.best_params.copy()

# Safety guard in case an old artifact had MVS or other incompatible options
if best.get("bootstrap_type") not in ("Bayesian", "Bernoulli"):
    best["bootstrap_type"] = "Bernoulli"
    if "subsample" not in best:
        best["subsample"] = 0.8

# Normalize final params: fixed settings mirror suggest_cat_params (except
# verbose, which logs every 100 iterations here), followed by the tuned values.
best_params = dict(
    loss_function="MultiClass",
    classes_count=NUM_CLASSES,
    eval_metric="MultiClass",

    task_type="GPU",
    devices="0",
    gpu_ram_part=0.95,
    one_hot_max_size=128,
    border_count=128,

    random_seed=SEED,
    thread_count=NTHREAD,
    verbose=100,
    allow_writing_files=False,

    iterations=N_ESTIMATORS_LIMIT,
    od_type="Iter",
    od_wait=EARLY_STOP_ROUNDS,
    use_best_model=True,
    class_weights=class_weights,

    # tuned core params
    learning_rate=best.get("learning_rate"),
    depth=best.get("depth"),
    l2_leaf_reg=best.get("l2_leaf_reg"),
    random_strength=best.get("random_strength"),
    bootstrap_type=best.get("bootstrap_type"),
)

# Attach only the sampling knob that belongs to the chosen bootstrap type.
if best_params["bootstrap_type"] == "Bernoulli":
    if "subsample" in best: best_params["subsample"] = best["subsample"]
else:  # Bayesian
    if "bagging_temperature" in best:
        best_params["bagging_temperature"] = best["bagging_temperature"]

# Persist and echo the final parameter set.
with open(OUTPUT_DIR / "best_params.json", "w", encoding="utf-8") as f:
    json.dump(best_params, f, ensure_ascii=False, indent=2)
print("Best params:", json.dumps(best_params, indent=2, ensure_ascii=False))

# ------------------------ FINAL TRAIN ------------------------
# Stage 1: fit on TRAIN with VAL as a genuine hold-out. This is the only place
# where early stopping / use_best_model are meaningful, because VAL rows are
# not part of the training pool.
holdout_model = CatBoostClassifier(**best_params)
holdout_model.fit(train_pool, eval_set=val_pool)  # verbose handled by params
best_iter = getattr(holdout_model, "best_iteration_", None)
print("Best iteration:", best_iter)

if REFIT_ON_TRAINVAL:
    # Stage 2: refit on TRAIN+VAL for the iteration count found in stage 1.
    # Previously the refit used VAL as eval_set although VAL was also in the
    # training pool, so early stopping was driven by training rows and the
    # run trained far past the honest optimum.
    X_refit = pd.concat([X_train, X_val], axis=0)
    y_refit = pd.concat([y_train,  y_val], axis=0)
    X_refit_enc = pp.transform(X_refit)
    train_pool_final = Pool(X_refit_enc, label=y_refit, cat_features=cat_idx)
    eval_pool_final  = None  # no untouched hold-out remains for the refit

    refit_params = dict(best_params)
    # The overfitting detector and best-model selection need an eval set.
    refit_params.pop("od_type", None)
    refit_params.pop("od_wait", None)
    refit_params["use_best_model"] = False
    if best_iter is not None:
        # best_iteration_ is a zero-based index, hence the +1 for a tree count.
        refit_params["iterations"] = int(best_iter) + 1

    final_model = CatBoostClassifier(**refit_params)
    final_model.fit(train_pool_final)  # verbose handled by params
else:
    train_pool_final = train_pool
    eval_pool_final  = val_pool
    final_model = holdout_model

# ------------------------ EVALUATION ------------------------
def eval_split(name, Xs_raw, ys_zero):
    """Score ``final_model`` on one split and write its reports to OUTPUT_DIR.

    Args:
        name: split label; used in the metrics row and in the report file names.
        Xs_raw: raw (un-preprocessed) features; transformed here with ``pp``.
        ys_zero: zero-indexed labels aligned with ``Xs_raw``.

    Returns:
        tuple: ``(metrics, pred_one, proba)`` — a dict of scalar scores, the
        predictions shifted to the 1-based label space, and the class
        probabilities.

    Side effects:
        Writes ``classification_report_<name>.txt`` and
        ``confusion_matrix_<name>.csv`` (both in 1-based label space).
    """
    zero_based = list(range(NUM_CLASSES))
    one_based = list(range(1, NUM_CLASSES+1))

    encoded = pp.transform(Xs_raw)
    split_pool = Pool(encoded, label=ys_zero, cat_features=cat_idx)
    proba = final_model.predict_proba(split_pool)
    pred0 = np.argmax(proba, axis=1)

    metrics = {"split": name, "n_samples": int(len(ys_zero))}
    metrics["accuracy"] = float(accuracy_score(ys_zero, pred0))
    for avg in ("macro", "weighted"):
        prec, rec, fscore, _support = precision_recall_fscore_support(
            ys_zero, pred0, average=avg, zero_division=0
        )
        metrics.update({
            f"precision_{avg}": float(prec),
            f"recall_{avg}": float(rec),
            f"f1_{avg}": float(fscore),
        })

    # Probabilistic metrics are best-effort: any failure is recorded as NaN.
    try:
        loss_value = float(log_loss(ys_zero, proba, labels=zero_based))
    except Exception:
        loss_value = float("nan")
    metrics["log_loss"] = loss_value

    try:
        onehot = pd.get_dummies(pd.Categorical(ys_zero, categories=zero_based))
        auc_value = float(roc_auc_score(onehot.values, proba, average="macro", multi_class="ovr"))
    except Exception:
        auc_value = float("nan")
    metrics["roc_auc_ovr_macro"] = auc_value

    # Human-friendly reports in 1-based label space
    ys_one = ys_zero + 1
    pred_one = pred0 + 1
    report_text = classification_report(ys_one, pred_one, labels=one_based, zero_division=0)
    (OUTPUT_DIR / f"classification_report_{name}.txt").write_text(report_text, encoding="utf-8")

    cm = confusion_matrix(ys_one, pred_one, labels=one_based)
    cm_frame = pd.DataFrame(cm, index=range(1, NUM_CLASSES+1), columns=range(1, NUM_CLASSES+1))
    cm_frame.to_csv(OUTPUT_DIR / f"confusion_matrix_{name}.csv")
    return metrics, pred_one, proba

# Evaluate the final model on each split; only the metrics dict is kept.
# NOTE: when REFIT_ON_TRAINVAL is True the final model has been trained on VAL
# rows, so the "val" row is not an unbiased estimate — rely on "test".
metrics_train, _, _ = eval_split("train", X_train, y_train)
metrics_val,   _, _ = eval_split("val",   X_val,   y_val)
metrics_test,  _, _ = eval_split("test",  X_test,  y_test)

# One row per split, saved as a single CSV.
metrics_df = pd.DataFrame([metrics_train, metrics_val, metrics_test])
metrics_df.to_csv(OUTPUT_DIR / "metrics_catboost_gpu.csv", index=False)
print(metrics_df)

# ------------------------ FEATURE IMPORTANCE ------------------------
# Importances are computed against the pool the final model was trained on
# and paired with the preprocessor's feature order.
imp_vals = final_model.get_feature_importance(train_pool_final, type="PredictionValuesChange")
imp = pd.DataFrame({"feature": pp.feature_names_, "importance": imp_vals}).sort_values("importance", ascending=False)
imp.to_csv(OUTPUT_DIR / "feature_importance_prediction_values_change.csv", index=False)

# Figure height scales with the number of bars, clamped to the range [4, 16].
plt.figure(figsize=(10, max(4, min(16, len(imp.head(40)) * 0.25))))
# Reverse the top 40 so the largest bar ends up at the top of the barh plot.
topk = imp.head(40).iloc[::-1]
plt.barh(topk["feature"], topk["importance"])
plt.title("CatBoost (GPU) Feature Importance (PredictionValuesChange) - Top 40")
plt.tight_layout()
plt.savefig(OUTPUT_DIR / "feature_importance_top40.png", dpi=150)
plt.close()

# ------------------------ SAVE ARTIFACTS ------------------------
# Model in CatBoost's native format, preprocessor via joblib; the inference
# helper loads both from these paths.
final_model.save_model(str(OUTPUT_DIR / "catboost_multiclass.cbm"))
joblib.dump(pp, OUTPUT_DIR / "preprocessor.pkl")
# Run metadata: target column, label space, split files and feature layout.
with open(OUTPUT_DIR / "training_meta.json", "w", encoding="utf-8") as f:
    json.dump({
        "target_name": tgt,
        "num_classes": NUM_CLASSES,
        "label_order_zero_indexed": list(range(NUM_CLASSES)),
        "seed": SEED,
        "best_iteration": int(best_iter) if best_iter is not None else None,
        "n_trials": N_TRIALS,
        "refit_on_trainval": REFIT_ON_TRAINVAL,
        "splits": {"train": "splits/train_indices.csv", "val": "splits/val_indices.csv", "test": "splits/test_indices.csv"},
        "data_csv": DATA_CSV,
        "feature_names": pp.feature_names_,
        "cat_features": pp.cat_cols_,
        "num_features": pp.num_cols_
    }, f, ensure_ascii=False, indent=2)

print(f"\n✅ All artifacts saved to Google Drive at: {OUTPUT_DIR.resolve()}")

# ------------------------ INFERENCE HELPER ------------------------
def predict_target_risk_class(df_new: pd.DataFrame,
                              model_path=OUTPUT_DIR / "catboost_multiclass.cbm",
                              preproc_path=OUTPUT_DIR / "preprocessor.pkl",
                              meta_path=OUTPUT_DIR / "training_meta.json") -> pd.Series:
    """Predict on new risks (returns labels in 1..10).

    Args:
        df_new: raw frame BEFORE preprocessing, with the training columns.
        model_path: CatBoost model file. The default is bound to this run's
            OUTPUT_DIR when the function is defined.
        preproc_path: joblib file holding the fitted preprocessor.
        meta_path: kept for backward compatibility; no longer read. Everything
            inference needs (feature order, categorical columns) lives on the
            preprocessor, so a missing metadata file must not break prediction.

    Returns:
        pd.Series: predicted labels in 1..10, indexed like ``df_new`` and
        named ``pred_target_risk_class``.
    """
    model = CatBoostClassifier()
    model.load_model(str(model_path))
    # SECURITY: joblib.load unpickles arbitrary objects; only load artifacts
    # produced by a trusted training run.
    preproc = joblib.load(preproc_path)

    features = preproc.transform(df_new)
    pool = Pool(features, cat_features=preproc.cat_feature_indices())
    proba = model.predict_proba(pool)
    pred1 = np.argmax(proba, axis=1) + 1  # zero-indexed class -> 1..10
    return pd.Series(pred1, index=df_new.index, name="pred_target_risk_class")

### Transformer-Based Models

#### TabM

##### CPU/TPU ready

In [None]:
# ===============================================================
# TabM (official) – Low-RAM, FIXED for hyphenated columns & unique categories
# ===============================================================

import os, re, json, warnings, datetime, math, time, random, subprocess, sys, gc
from pathlib import Path
import numpy as np
import pandas as pd
import pandas.api.types as pdt

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, roc_auc_score,
    log_loss, classification_report, confusion_matrix
)
from sklearn.utils.class_weight import compute_class_weight

import joblib
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

warnings.filterwarnings("ignore")

# ------------------------ COLAB DRIVE SETUP ------------------------
# Base folder under which this run's outputs are written.
drive_folder = 'path'
# Best-effort Drive mount: attempted only when /content/drive is missing or
# empty. Any failure (including google.colab not being importable) is silently
# ignored, so the cell keeps running without a mount.
try:
    from google.colab import drive as _colab_drive
    if not Path("/content/drive").exists() or not any(Path("/content/drive").iterdir()):
        _colab_drive.mount("/content/drive")
except Exception:
    pass

# ------------------------ Install official packages ------------------------
# Installed into the running interpreter; check=True aborts the cell if pip
# fails, so the imports below never run against a half-installed environment.
subprocess.run([sys.executable, "-m", "pip", "install", "-q",
                "tabm>=0.0.3", "rtdl_num_embeddings>=0.0.12"], check=True)
from tabm import TabM
from rtdl_num_embeddings import LinearReLUEmbeddings

# ------------------------ RUN/OUTPUT DIR ------------------------
# Timestamped run folder so repeated runs never overwrite each other.
RUN_NAME = f"TabM_LowRAM_Fixed_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
OUTPUT_DIR = Path(drive_folder) / "TabM_Package" / RUN_NAME
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
SPLIT_DIR  = OUTPUT_DIR / "splits"; SPLIT_DIR.mkdir(exist_ok=True)
print(f"Saving all outputs to: {OUTPUT_DIR.resolve()}")

# ------------------------ CONFIG ------------------------
# Seed Python, NumPy and PyTorch RNGs.
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
# Use half of the available CPU cores (at least one; assumes 4 if unknown).
torch.set_num_threads(max(1, (os.cpu_count() or 4)//2))

# >>>>>>>>>>>> SET THIS TO YOUR FILE <<<<<<<<<<<<
DATA_CSV = "path"

TARGET_NAME = "target_risk_class"   # label column
NUM_CLASSES = 10

TEST_SIZE = 0.15
VAL_SIZE  = 0.15

# ------------------------ LOW-RAM PROFILE ------------------------
# Model-size knobs — presumably forwarded to the TabM constructor; TODO confirm
# against the model-building code further down.
TABM_D_BLOCK   = 256
TABM_K         = 8
TABM_N_BLOCKS  = 1
TABM_DROPOUT   = 0.10

# Optimisation / schedule knobs — consumed by the training loop outside this
# section; exact semantics to be confirmed there.
BATCH_SIZE       = 256
GRAD_ACC_STEPS   = 2
MAX_EPOCHS       = 60
BASE_LR          = 2e-3
WEIGHT_DECAY     = 3e-4
PATIENCE         = 10
MIN_DELTA        = 1e-4
WARMUP_EPOCHS    = 3

MISSING_TOKEN    = "Missing"   # placeholder category for missing/unknown values
MAX_CAT_CARD     = 500         # cap on categories kept per column (incl. MISSING_TOKEN)
USE_NUM_EMBEDDINGS = False
STORE_NUM_AS_FP16  = True
STORE_CAT_AS_INT32 = True
NUM_WORKERS      = 0

# ------------------------ helpers ------------------------
def canon_col(name: str) -> str:
    """Reduce a column label to ``[0-9A-Za-z_]`` characters.

    Every run of other characters becomes a single underscore, repeated
    underscores are collapsed, and leading/trailing underscores are removed.
    A label with no ASCII letters or digits (e.g. a Persian header) therefore
    yields an empty string.
    """
    s = re.sub(r"[^0-9A-Za-z_]+", "_", str(name))
    s = re.sub(r"_+", "_", s).strip("_")
    return s

def canon_cols_inplace(df: pd.DataFrame) -> pd.DataFrame:
    """Canonicalize the column labels of ``df`` in place and return ``df``.

    Labels are passed through ``canon_col`` and then made non-empty and
    unique, because duplicate or empty labels break per-column access
    downstream (``df[c]`` would return a DataFrame instead of a Series):

    - a label that canonicalizes to ``""`` becomes ``col_<position>``;
    - a label that collides with an earlier one gets a ``_2``, ``_3``, ...
      suffix.

    Frames whose canonical labels are already unique and non-empty get exactly
    the names ``canon_col`` produces. The fallback and suffix depend on column
    position/order, so frames must share the training column order for the
    renamed columns to line up.
    """
    taken = set()
    renamed = []
    for position, raw in enumerate(df.columns):
        base = canon_col(raw) or f"col_{position}"
        candidate = base
        suffix = 1
        while candidate in taken:
            suffix += 1
            candidate = f"{base}_{suffix}"
        taken.add(candidate)
        renamed.append(candidate)
    df.columns = renamed
    return df

def _unique_in_order(seq):
    seen = set()
    out = []
    for x in seq:
        x = str(x)
        if x not in seen:
            seen.add(x)
            out.append(x)
    return out

def normalize_ws(x):
    """Clean up whitespace in a single cell value.

    Strings lose ZWNJ / RTL / LTR marks, have non-breaking spaces turned into
    plain spaces, runs of whitespace collapsed to one space, and both ends
    trimmed. Non-string values are returned unchanged.
    """
    if not isinstance(x, str):
        return x
    without_marks = re.sub(r"[\u200c\u200f\u200e]", "", x)
    plain_spaces = without_marks.replace("\u00a0", " ")
    return re.sub(r"\s+", " ", plain_spaces).strip()

def downcast_numeric(df):
    """Return a copy of ``df`` with numeric columns downcast to smaller dtypes.

    Only columns of dtype float64/int64/int32/float32 are touched: float
    columns are downcast with ``downcast="float"``, the others with
    ``downcast="integer"``. The input frame is not modified.
    """
    out = df.copy()
    candidates = out.select_dtypes(include=["float64","int64","int32","float32"]).columns
    for col in candidates:
        target_kind = "float" if pdt.is_float_dtype(out[col]) else "integer"
        out[col] = pd.to_numeric(out[col], downcast=target_kind)
    return out

def parse_possible_datetimes(df):
    """Return a copy of ``df`` with date-like object columns parsed to datetimes.

    An object column is treated as date-like when more than 20% of its values
    (rendered as strings) contain a ``YYYY-MM-DD`` or ``YYYY/MM/DD`` style
    pattern. Such columns are converted with ``pd.to_datetime``; values that
    cannot be parsed become ``NaT``. Other columns are left untouched.

    Args:
        df: input frame; it is not modified.

    Returns:
        pd.DataFrame: the converted copy.
    """
    df = df.copy()
    for c in df.columns:
        if df[c].dtype == object:
            s = df[c].astype(object)
            if s.astype(str).str.contains(r"\d{4}[-/]\d{1,2}[-/]\d{1,2}", regex=True, na=False).mean() > 0.2:
                try:
                    # `infer_datetime_format` is deliberately not passed: it is
                    # deprecated, and once pandas rejects it the resulting
                    # TypeError would be swallowed below, silently disabling
                    # date parsing for every column.
                    df[c] = pd.to_datetime(s, errors="coerce")
                except Exception:
                    # Best effort: a column that cannot be converted stays as-is.
                    pass
    return df

def expand_datetimes(df):
    """Return a copy of ``df`` with datetime columns replaced by calendar parts.

    For every datetime column ``c`` the copy gains ``c__year``, ``c__month``,
    ``c__day``, ``c__dow``, ``c__hour``, ``c__mstart`` and ``c__mend`` (nullable
    integer dtypes; a missing hour is filled with 0), and ``c`` itself is
    dropped. The input frame is not modified.
    """
    out = df.copy()
    datetime_cols = [name for name in out.columns if pdt.is_datetime64_any_dtype(out[name])]
    if not datetime_cols:
        return out

    for name in datetime_cols:
        parts = out[name].dt
        derived = {
            f"{name}__year":   parts.year.astype("Int16"),
            f"{name}__month":  parts.month.astype("Int8"),
            f"{name}__day":    parts.day.astype("Int8"),
            f"{name}__dow":    parts.dayofweek.astype("Int8"),
            f"{name}__hour":   parts.hour.fillna(0).astype("Int8"),
            f"{name}__mstart": parts.is_month_start.astype("Int8"),
            f"{name}__mend":   parts.is_month_end.astype("Int8"),
        }
        for new_col, values in derived.items():
            out[new_col] = values

    return out.drop(columns=datetime_cols)

def pd_cat_fix(series, allowed):
    """Turn ``series`` into an unordered ``pd.Categorical`` restricted to ``allowed``.

    The category list is ``allowed`` (as unique strings, in order) with
    ``MISSING_TOKEN`` guaranteed to appear once. Missing values and values
    outside that list are replaced by ``MISSING_TOKEN``.
    """
    categories = _unique_in_order([*allowed, MISSING_TOKEN])
    as_text = series.astype("string").fillna(MISSING_TOKEN)
    is_known = as_text.isin(categories)
    cleaned = as_text.where(is_known, MISSING_TOKEN)
    return pd.Categorical(cleaned, categories=pd.Index(categories), ordered=False)

# ------------------------ Preprocessor (canon names + cap cats) ------------------------
class TabularPreprocessor:
    """Fit-once cleaner that yields float32 numeric columns and Categorical columns.

    ``fit`` learns, from the training frame only, which columns are numeric or
    categorical, the per-column median used for imputation, and the allowed
    category list of each categorical column (capped by ``MAX_CAT_CARD``).
    ``transform`` applies the same column layout to any frame.
    """

    def __init__(self):
        self.num_cols_ = []          # numeric feature names, in output order
        self.cat_cols_ = []          # categorical feature names, in output order
        self.num_median_ = {}        # column -> imputation value
        self.cat_categories_ = {}    # column -> ordered list of allowed category strings
        self.feature_names_ = []     # num_cols_ followed by cat_cols_
        self.fitted_ = False

    def _prep_base(self, X):
        """Shared cleaning used by both ``fit`` and ``transform``; works on a copy."""
        d = X.copy()
        canon_cols_inplace(d)  # <<< canonicalize EVERY time (train/val/test/inference)
        for c in d.columns:
            if d[c].dtype == object:
                d[c] = d[c].map(normalize_ws)
        d = parse_possible_datetimes(d)
        d = expand_datetimes(d)
        for c in d.columns:
            if pdt.is_bool_dtype(d[c]):
                d[c] = d[c].astype("int8")
        d = downcast_numeric(d)
        return d

    @staticmethod
    def _finite_median(col):
        """Median of ``col`` as a float, or 0.0 when the median is undefined.

        An all-missing column has a NaN (or ``pd.NA``) median; imputing with it
        would leave NaNs in the output, so a constant is used instead.
        """
        med = pd.to_numeric(col, errors="coerce").astype("float64").median()
        return float(med) if np.isfinite(med) else 0.0

    def fit(self, X):
        """Learn column roles, medians and category lists from ``X``; returns ``self``."""
        d = self._prep_base(X)
        # Start from clean dicts so a second fit() does not keep stale columns.
        self.num_median_ = {}
        self.cat_categories_ = {}
        self.num_cols_ = [c for c in d.columns if pdt.is_numeric_dtype(d[c])]
        self.cat_cols_ = [c for c in d.columns if not pdt.is_numeric_dtype(d[c])]

        for c in self.num_cols_:
            self.num_median_[c] = self._finite_median(d[c])

        for c in self.cat_cols_:
            s = d[c].astype("string").fillna(MISSING_TOKEN)
            vc = s.value_counts(dropna=False)
            if MAX_CAT_CARD and len(vc) > (MAX_CAT_CARD - 1):
                # Keep the most frequent values and reserve one slot for MISSING_TOKEN.
                top = vc.index.astype("string").tolist()[:MAX_CAT_CARD - 1]
                cats = _unique_in_order(top + [MISSING_TOKEN])
            else:
                cats = _unique_in_order(pd.unique(s).astype("string").tolist() + [MISSING_TOKEN])
            self.cat_categories_[c] = cats

        self.feature_names_ = self.num_cols_ + self.cat_cols_
        self.fitted_ = True
        return self

    def transform(self, X):
        """Return ``X`` cleaned and laid out as ``feature_names_``.

        Columns absent from ``X`` are created as all-missing. Numeric columns
        are float32 with missing values replaced by the fitted median;
        categorical columns are ``pd.Categorical`` with unseen/missing values
        mapped to ``MISSING_TOKEN``.
        """
        assert self.fitted_
        d = self._prep_base(X)
        for c in self.feature_names_:
            if c not in d.columns:
                d[c] = np.nan
        d = d[self.feature_names_].copy()

        for c in self.num_cols_:
            d[c] = pd.to_numeric(d[c], errors="coerce").astype("float32")
            if d[c].isna().any():
                d[c] = d[c].fillna(self.num_median_[c])
        for c in self.cat_cols_:
            d[c] = pd_cat_fix(d[c], self.cat_categories_[c])
        return d

# ------------------------ Torch encoder (compact arrays) ------------------------
class TorchTabEncoder:
    """Turns a preprocessed frame into a standardized numeric array and an integer code array.

    Numeric columns are z-scored with statistics learned in ``fit``;
    categorical columns (pandas Categoricals) are emitted as their category
    codes, with negative codes redirected to the ``MISSING_TOKEN`` position.
    """

    def __init__(self, num_cols, cat_cols, cat_categories):
        self.num_cols = list(num_cols)
        self.cat_cols = list(cat_cols)
        self.cat_categories = {col: list(levels) for col, levels in cat_categories.items()}
        self.num_mean_ = None
        self.num_std_  = None
        self.cat_cardinalities_ = {col: len(self.cat_categories[col]) for col in self.cat_cols}

    def fit(self, df_proc):
        """Learn per-column mean/std of the numeric columns; returns ``self``."""
        if not self.num_cols:
            self.num_mean_ = np.array([], dtype="float32")
            self.num_std_  = np.array([], dtype="float32")
            return self
        values = df_proc[self.num_cols].astype("float32").values
        self.num_mean_ = values.mean(axis=0).astype("float32")
        raw_std = values.std(axis=0).astype("float32")
        # (Near-)constant columns get std 1.0 so the division below stays finite.
        self.num_std_ = np.where(raw_std < 1e-6, 1.0, raw_std).astype("float32")
        return self

    def transform(self, df_proc):
        """Return ``(Xn, Xc)``: standardized numerics and stacked category codes."""
        num_dtype = "float16" if STORE_NUM_AS_FP16 else "float32"
        cat_dtype = "int32" if STORE_CAT_AS_INT32 else "int64"

        if self.num_cols:
            raw = df_proc[self.num_cols].astype("float32").values
            Xn = (raw - self.num_mean_) / self.num_std_
            if STORE_NUM_AS_FP16:
                Xn = Xn.astype("float16")
        else:
            Xn = np.zeros((len(df_proc), 0), dtype=num_dtype)

        def _codes_for(col):
            raw_codes = df_proc[col].cat.codes.to_numpy(copy=False)
            missing_code = self.cat_categories[col].index(MISSING_TOKEN)
            # pandas uses -1 for "not a category"; send those to the missing slot.
            return np.where(raw_codes < 0, missing_code, raw_codes).astype(cat_dtype)

        code_columns = [_codes_for(col) for col in self.cat_cols]
        if code_columns:
            Xc = np.stack(code_columns, axis=1)
        else:
            Xc = np.zeros((len(df_proc), 0), dtype=cat_dtype)
        return Xn, Xc

    def save_meta(self, path_json):
        """Write column lists, category lists and numeric statistics to ``path_json`` as JSON."""
        meta = {
            "num_cols": self.num_cols,
            "cat_cols": self.cat_cols,
            "cat_categories": self.cat_categories,
            "num_mean": self.num_mean_.tolist(),
            "num_std": self.num_std_.tolist(),
            "cat_cardinalities": self.cat_cardinalities_,
        }
        with open(path_json, "w", encoding="utf-8") as fh:
            json.dump(meta, fh, ensure_ascii=False, indent=2)

    @staticmethod
    def load_meta(path_json):
        """Rebuild an encoder from a JSON file written by ``save_meta``."""
        with open(path_json, "r", encoding="utf-8") as fh:
            meta = json.load(fh)
        restored = TorchTabEncoder(meta["num_cols"], meta["cat_cols"], meta["cat_categories"])
        restored.num_mean_ = np.array(meta["num_mean"], dtype="float32")
        restored.num_std_  = np.array(meta["num_std"], dtype="float32")
        restored.cat_cardinalities_ = {k:int(v) for k,v in meta["cat_cardinalities"].items()}
        return restored

# ------------------------ Dataset ------------------------
class TabDataset(Dataset):
    """Index-aligned view over numeric features, categorical codes and optional labels.

    Items are ``(Xn[i], Xc[i])`` when no labels were given, otherwise
    ``(Xn[i], Xc[i], y[i])`` with ``y`` stored as int64.
    """

    def __init__(self, Xn, Xc, y=None):
        self.Xn, self.Xc = Xn, Xc
        self.y = y.astype("int64") if y is not None else None

    def __len__(self):
        return len(self.Xn)

    def __getitem__(self, i):
        features = (self.Xn[i], self.Xc[i])
        if self.y is None:
            return features
        return (*features, self.y[i])

# ------------------------ LOAD & TARGET ------------------------
# Read the raw CSV and canonicalize its column names so TARGET_NAME can be looked up.
df = pd.read_csv(DATA_CSV, low_memory=False)
canon_cols_inplace(df)                  # <<< canonicalize once at read time too
if TARGET_NAME not in df.columns:
    raise KeyError(f"Expected label column '{TARGET_NAME}' after canonicalization. "
                   f"Got columns like: {list(df.columns)[:20]}")

# Coerce the label to a nullable integer; unparsable entries become <NA>.
y1 = pd.to_numeric(df[TARGET_NAME], errors="coerce").astype("Int64")
# Keep only labels in 1..10; everything else is turned into <NA> and dropped below.
# NOTE(review): the bound 10 is hard-coded here rather than taken from NUM_CLASSES — confirm they agree.
y1 = y1.where((y1>=1) & (y1<=10))
mask = y1.notna()
df = df.loc[mask].copy()
y1 = y1.loc[mask].astype("int16")
y  = (y1 - 1).astype("int16")           # zero-based labels
X  = df.drop(columns=[TARGET_NAME])
# The raw frame is no longer needed once X / y exist.
del df; gc.collect()

# ------------------------ SPLIT ------------------------
# First stratified split: carve the test set (TEST_SIZE of all rows) out of the data.
sss1 = StratifiedShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
trainval_idx, test_idx = next(sss1.split(X, y))
X_trainval, X_test = X.iloc[trainval_idx], X.iloc[test_idx]
y_trainval, y_test = y.iloc[trainval_idx], y.iloc[test_idx]

# Second stratified split: VAL_SIZE is a fraction of the full data, so rescale it
# to a fraction of the remaining train+val rows.
val_rel = VAL_SIZE / (1.0 - TEST_SIZE)
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=val_rel, random_state=SEED)
train_idx, val_idx = next(sss2.split(X_trainval, y_trainval))
X_train, X_val = X_trainval.iloc[train_idx], X_trainval.iloc[val_idx]
y_train, y_val = y_trainval.iloc[train_idx], y_trainval.iloc[val_idx]

# Persist the DataFrame index labels of each split (these are index labels of the
# filtered frame, not positional row numbers).
pd.Series(X_train.index, name="index").to_csv(SPLIT_DIR/"train_indices.csv", index=False)
pd.Series(X_val.index,   name="index").to_csv(SPLIT_DIR/"val_indices.csv",   index=False)
pd.Series(X_test.index,  name="index").to_csv(SPLIT_DIR/"test_indices.csv",  index=False)
print(f"Shapes -> train: {X_train.shape}, val: {X_val.shape}, test: {X_test.shape}")

# ------------------------ PREPROCESS & ENCODE ------------------------
# Fit the preprocessor on the training split only, then apply it to all three splits.
pp = TabularPreprocessor().fit(X_train)
Xtr_df = pp.transform(X_train); Xva_df = pp.transform(X_val); Xte_df = pp.transform(X_test)

# The encoder's standardization statistics also come from the training split only.
enc = TorchTabEncoder(pp.num_cols_, pp.cat_cols_, pp.cat_categories_).fit(Xtr_df)
Xtr_num, Xtr_cat = enc.transform(Xtr_df)
Xva_num, Xva_cat = enc.transform(Xva_df)
Xte_num, Xte_cat = enc.transform(Xte_df)

# free raw frames early
# (y_train / y_val / y_test are kept; they are read just below.)
del X_train, X_val, X_test, X_trainval, Xtr_df, Xva_df, Xte_df, X, y, y1; gc.collect()

y_tr = y_train.values.astype("int64")
y_va = y_val.values.astype("int64")
y_te = y_test.values.astype("int64")

ds_tr = TabDataset(Xtr_num, Xtr_cat, y_tr)
ds_va = TabDataset(Xva_num, Xva_cat, y_va)
ds_te = TabDataset(Xte_num, Xte_cat, y_te)

# Only the training loader shuffles; validation/test keep row order.
dl_tr = DataLoader(ds_tr, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS)
dl_va = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
dl_te = DataLoader(ds_te, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# ------------------------ CLASS WEIGHTS ------------------------
# "balanced" weights are computed only for the classes that occur in the training labels.
classes_present = np.unique(y_tr)
cw = compute_class_weight(class_weight="balanced", classes=classes_present, y=y_tr)
cw_map = {int(c): float(w) for c, w in zip(classes_present, cw)}
# Classes absent from the training split keep weight 1.0.
class_weights = np.ones(NUM_CLASSES, dtype="float32")
for c, w in cw_map.items(): class_weights[c] = w
# Rescale so the weights average to 1.
class_weights = class_weights / class_weights.mean()
class_weights_t = torch.tensor(class_weights, dtype=torch.float32)

# ------------------------ MODEL ------------------------
# This cell trains on CPU.
device = torch.device("cpu")
n_num = Xtr_num.shape[1]
# Per-column category counts in encoder column order; None when there are no categoricals.
cat_cards = [enc.cat_cardinalities_[c] for c in enc.cat_cols] if enc.cat_cols else None
# Numeric embeddings are optional and only built when numeric features exist.
num_emb = LinearReLUEmbeddings(n_num) if (USE_NUM_EMBEDDINGS and n_num > 0) else None

model = TabM.make(
    n_num_features=n_num,
    cat_cardinalities=cat_cards,
    num_embeddings=num_emb,
    d_out=NUM_CLASSES,
    k=TABM_K,
    n_blocks=TABM_N_BLOCKS,
    d_block=TABM_D_BLOCK,
    dropout=TABM_DROPOUT,
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=WEIGHT_DECAY)
def cosine_factor(epoch, max_epochs=MAX_EPOCHS, warmup=WARMUP_EPOCHS):
    """Learning-rate multiplier for ``epoch``: linear warm-up, then half-cosine decay.

    During the first ``warmup`` epochs the factor is ``(epoch + 1) / warmup``;
    afterwards it follows ``0.5 * (1 + cos(pi * t))`` where ``t`` is the
    fraction of the post-warm-up schedule already elapsed.
    """
    if epoch >= warmup:
        progress = (epoch - warmup) / max(1, max_epochs - warmup)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return (epoch + 1) / max(1, warmup)
# LambdaLR multiplies the base learning rate by cosine_factor(<scheduler epoch counter>).
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda e: cosine_factor(e))

# Early-stopping state: best validation loss so far, its epoch, and epochs without improvement.
best_val = float("inf"); best_epoch = -1; pat = 0
# Per-epoch log, plotted after training.
history = {"epoch": [], "train_loss": [], "val_loss": [], "lr": []}
t0 = time.time()

# ------------------------ TRAIN ------------------------
# Move the class weights to the device once instead of once per ensemble member per batch.
_ce_weight = class_weights_t.to(device)
# Used to flush the last, possibly incomplete, gradient-accumulation group of an epoch.
_steps_per_epoch = len(dl_tr)

for epoch in range(1, MAX_EPOCHS + 1):
    # ---- training pass ----
    model.train()
    total, n = 0.0, 0
    optimizer.zero_grad(set_to_none=True)
    for step, (xnum_np, xcat_np, yb) in enumerate(dl_tr, 1):
        xnum = torch.as_tensor(xnum_np, device=device).float()
        xcat = torch.as_tensor(xcat_np, device=device).long()
        yb   = yb.to(device)

        # NOTE(review): when there are no categorical columns this still passes an
        # empty `xcat` to the model — confirm TabM accepts that when it was built
        # with cat_cardinalities=None.
        y_pred = model(xnum, xcat) if (xcat.shape[1] > 0 or n_num > 0) else model(xnum)
        B, K, C = y_pred.shape

        # Average the weighted cross-entropy over the K ensemble members.
        loss = 0.0
        for k in range(K):
            loss = loss + F.cross_entropy(y_pred[:, k, :], yb, weight=_ce_weight)
        loss = loss / K

        (loss / GRAD_ACC_STEPS).backward()
        # Step every GRAD_ACC_STEPS batches AND on the final batch of the epoch.
        # Without the second condition, gradients from a trailing partial group
        # were never applied and were wiped by zero_grad() at the next epoch.
        if step % GRAD_ACC_STEPS == 0 or step == _steps_per_epoch:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        total += loss.item() * B
        n += B
        del xnum, xcat

    train_loss = total / max(1, n)

    # ---- validation pass ----
    model.eval()
    vtotal, vn = 0.0, 0
    with torch.no_grad():
        for xnum_np, xcat_np, yb in dl_va:
            xnum = torch.as_tensor(xnum_np, device=device).float()
            xcat = torch.as_tensor(xcat_np, device=device).long()
            yb   = yb.to(device)
            y_pred = model(xnum, xcat) if (xcat.shape[1] > 0 or n_num > 0) else model(xnum)
            B, K, C = y_pred.shape
            vloss = 0.0
            for k in range(K):
                vloss = vloss + F.cross_entropy(y_pred[:, k, :], yb, weight=_ce_weight)
            vloss = vloss / K
            vtotal += vloss.item() * B
            vn += B
            del xnum, xcat
    val_loss = vtotal / max(1, vn)
    scheduler.step()

    # The logged lr is read after scheduler.step(), i.e. it is the rate for the next epoch.
    history["epoch"].append(epoch)
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["lr"].append(optimizer.param_groups[0]["lr"])
    print(f"Epoch {epoch:03d} | train {train_loss:.4f} | val {val_loss:.4f} | lr {optimizer.param_groups[0]['lr']:.2e}")

    # ---- checkpoint on improvement, otherwise count towards early stopping ----
    if val_loss + MIN_DELTA < best_val:
        best_val = val_loss; best_epoch = epoch; pat = 0
        torch.save({"state_dict": model.state_dict(),
                    "tabm_config": {
                        "n_num": n_num, "cat_cardinalities": cat_cards,
                        "d_out": NUM_CLASSES, "k": TABM_K,
                        "n_blocks": TABM_N_BLOCKS, "d_block": TABM_D_BLOCK, "dropout": TABM_DROPOUT,
                        "use_num_embeddings": bool(USE_NUM_EMBEDDINGS and n_num > 0),
                        "num_embedding_type": "LinearReLUEmbeddings" if (USE_NUM_EMBEDDINGS and n_num > 0) else None
                    }},
                   OUTPUT_DIR / "tabm_model.pt")
    else:
        pat += 1
        if pat >= PATIENCE:
            print(f"Early stopping at epoch {epoch} (best @ {best_epoch} | val {best_val:.4f})")
            break

elapsed = time.time() - t0
print(f"Training time: {elapsed/60:.1f} min; best epoch: {best_epoch}")

# ------------------------ Evaluation ------------------------
def predict_proba_dl(model, dl):
    """Run ``model`` over every batch of ``dl`` and return stacked class probabilities.

    Batches may be ``(x_num, x_cat)`` or ``(x_num, x_cat, y)``; labels are
    ignored. Per-member softmax outputs are averaged over the ensemble axis
    (dim 1) and the per-batch results are stacked row-wise into one array.
    """
    model.eval()
    chunks = []
    with torch.no_grad():
        for batch in dl:
            # Drop the label when present; otherwise the batch must be exactly a pair.
            xnum_np, xcat_np = batch[:2] if len(batch) == 3 else batch
            xnum = torch.as_tensor(xnum_np, device=device).float()
            xcat = torch.as_tensor(xcat_np, device=device).long()
            if xcat.shape[1] > 0 or n_num > 0:
                y_pred = model(xnum, xcat)
            else:
                y_pred = model(xnum)
            member_probs = F.softmax(y_pred, dim=-1)
            chunks.append(member_probs.mean(dim=1).cpu().numpy())
            del xnum, xcat
    return np.vstack(chunks)

# Restore the weights saved at the best validation epoch before evaluating.
ckpt = torch.load(OUTPUT_DIR / "tabm_model.pt", map_location=device)
model.load_state_dict(ckpt["state_dict"])

# Evaluation loaders: fixed batch size, no shuffling, loaded in the main process.
dl_tr_eval = DataLoader(ds_tr, batch_size=1024, shuffle=False, num_workers=0)
dl_va_eval = DataLoader(ds_va, batch_size=1024, shuffle=False, num_workers=0)
dl_te_eval = DataLoader(ds_te, batch_size=1024, shuffle=False, num_workers=0)

def eval_split(name, dl, y_zero):
    """Score the global ``model`` on one split and write its report files.

    Args:
        name: split label, used in the metrics row and in output file names.
        dl: DataLoader over the split, in the same row order as ``y_zero``.
        y_zero: zero-based true labels.

    Returns:
        ``(metrics, pred_one, proba)`` — a dict of scalar metrics, the
        one-based predicted labels, and the class-probability array.

    Side effects: writes ``classification_report_<name>.txt`` and
    ``confusion_matrix_<name>.csv`` under ``OUTPUT_DIR``. ``log_loss`` and
    ``roc_auc_ovr_macro`` are NaN when their computation raises.
    """
    proba = predict_proba_dl(model, dl)
    pred0 = np.argmax(proba, axis=1)

    metrics = {
        "split": name,
        "n_samples": int(len(y_zero)),
        "accuracy": float(accuracy_score(y_zero, pred0)),
    }
    for avg in ("macro", "weighted"):
        scores = precision_recall_fscore_support(y_zero, pred0, average=avg, zero_division=0)
        metrics[f"precision_{avg}"] = float(scores[0])
        metrics[f"recall_{avg}"]    = float(scores[1])
        metrics[f"f1_{avg}"]        = float(scores[2])

    try:
        ll_value = float(log_loss(y_zero, proba, labels=list(range(NUM_CLASSES))))
    except Exception:
        ll_value = float("nan")
    metrics["log_loss"] = ll_value

    try:
        y_bin = pd.get_dummies(pd.Categorical(y_zero, categories=list(range(NUM_CLASSES))))
        auc_value = float(roc_auc_score(y_bin.values, proba, average="macro", multi_class="ovr"))
    except Exception:
        auc_value = float("nan")
    metrics["roc_auc_ovr_macro"] = auc_value

    # Reports use the original one-based class labels.
    ys_one = y_zero + 1
    pred_one = pred0 + 1
    one_based_labels = list(range(1, NUM_CLASSES+1))

    report_text = classification_report(ys_one, pred_one, labels=one_based_labels, zero_division=0)
    with open(OUTPUT_DIR / f"classification_report_{name}.txt", "w", encoding="utf-8") as fh:
        fh.write(report_text)

    cm = confusion_matrix(ys_one, pred_one, labels=one_based_labels)
    cm_frame = pd.DataFrame(cm, index=range(1, NUM_CLASSES+1), columns=range(1, NUM_CLASSES+1))
    cm_frame.to_csv(OUTPUT_DIR / f"confusion_matrix_{name}.csv")
    return metrics, pred_one, proba

# Evaluate the restored model on all three splits; only the metric dicts are kept.
m_train, _, _ = eval_split("train", dl_tr_eval, y_tr)
m_val,   _, _ = eval_split("val",   dl_va_eval, y_va)
m_test,  _, _ = eval_split("test",  dl_te_eval, y_te)

# One CSV row per split.
pd.DataFrame([m_train, m_val, m_test]).to_csv(OUTPUT_DIR / "metrics_tabm_lowram_fixed.csv", index=False)

# ------------------------ Curves ------------------------
# Train vs. validation loss per epoch, saved to disk (the figure is closed, not shown).
plt.figure(figsize=(7,4))
plt.plot(history["epoch"], history["train_loss"], label="train")
plt.plot(history["epoch"], history["val_loss"],   label="val")
plt.xlabel("epoch"); plt.ylabel("CE loss"); plt.title("TabM (Low-RAM, Fixed) Loss"); plt.legend(); plt.tight_layout()
plt.savefig(OUTPUT_DIR / "loss_curves.png", dpi=150); plt.close()

# ------------------------ Save artifacts ------------------------
# The preprocessor is pickled; the encoder is stored as plain JSON metadata.
joblib.dump(pp, OUTPUT_DIR / "preprocessor.pkl")
enc.save_meta(OUTPUT_DIR / "encoder_meta.json")
# Run description: label convention, seed, split files, model config and feature lists.
with open(OUTPUT_DIR / "training_meta.json", "w", encoding="utf-8") as f:
    json.dump({
        "target_name": TARGET_NAME,
        "num_classes": NUM_CLASSES,
        "label_order_zero_indexed": list(range(NUM_CLASSES)),
        "seed": SEED,
        "best_epoch": int(best_epoch),
        "splits": {"train": "splits/train_indices.csv", "val": "splits/val_indices.csv", "test": "splits/test_indices.csv"},
        "data_csv": DATA_CSV,
        "tabm_cfg": ckpt["tabm_config"],
        "train_time_min": round(elapsed/60, 2),
        "feature_names": pp.feature_names_,
        "cat_features": pp.cat_cols_,
        "num_features": pp.num_cols_,
        "column_names_canonicalized": True
    }, f, ensure_ascii=False, indent=2)

print(f"\n✅ All artifacts saved to Google Drive at: {OUTPUT_DIR.resolve()}")

# ------------------------ Inference helper ------------------------
def predict_target_risk_class(df_new: pd.DataFrame,
                              model_path=OUTPUT_DIR / "tabm_model.pt",
                              preproc_path=OUTPUT_DIR / "preprocessor.pkl",
                              encoder_meta_path=OUTPUT_DIR / "encoder_meta.json",
                              batch_size=4096) -> pd.Series:
    """Predict one-based risk classes for ``df_new`` from the saved artifacts, on CPU.

    Args:
        df_new: raw feature frame; it is copied, not modified.
        model_path: checkpoint holding ``state_dict`` and ``tabm_config``.
        preproc_path: pickled ``TabularPreprocessor``.
        encoder_meta_path: JSON written by ``TorchTabEncoder.save_meta``.
        batch_size: rows per inference batch.

    Returns:
        Series named ``pred_target_risk_class`` indexed like the preprocessed
        frame; values are the arg-max of the ensemble-averaged probabilities,
        shifted by +1.
    """
    cpu = torch.device("cpu")
    # joblib.load unpickles the file: only point this at trusted artifacts.
    preprocessor = joblib.load(preproc_path)
    encoder = TorchTabEncoder.load_meta(encoder_meta_path)
    checkpoint = torch.load(model_path, map_location=cpu)
    cfg = checkpoint["tabm_config"]

    # Ensure incoming columns match training convention
    frame = canon_cols_inplace(df_new.copy())

    # Rebuild the architecture from the encoder metadata and the stored config.
    n_num = len(encoder.num_cols)
    cat_cards = None
    if encoder.cat_cols:
        cat_cards = [encoder.cat_cardinalities_[c] for c in encoder.cat_cols]
    wants_embeddings = cfg.get("use_num_embeddings") and n_num > 0
    num_emb = LinearReLUEmbeddings(n_num) if wants_embeddings else None
    net = TabM.make(
        n_num_features=n_num,
        cat_cardinalities=cat_cards,
        num_embeddings=num_emb,
        d_out=cfg["d_out"],
        k=cfg["k"],
        n_blocks=cfg["n_blocks"],
        d_block=cfg["d_block"],
        dropout=cfg["dropout"],
    ).to(cpu)
    net.load_state_dict(checkpoint["state_dict"])
    net.eval()

    dproc = preprocessor.transform(frame)
    Xn, Xc = encoder.transform(dproc)
    loader = DataLoader(TabDataset(Xn, Xc, y=None), batch_size=batch_size, shuffle=False, num_workers=0)

    batch_labels = []
    with torch.no_grad():
        for xnum_np, xcat_np in loader:
            xnum = torch.as_tensor(xnum_np, device=cpu).float()
            xcat = torch.as_tensor(xcat_np, device=cpu).long()
            if xcat.shape[1] > 0 or n_num > 0:
                y_pred = net(xnum, xcat)
            else:
                y_pred = net(xnum)
            mean_probs = F.softmax(y_pred, dim=-1).mean(dim=1)
            batch_labels.append(torch.argmax(mean_probs, dim=-1).cpu().numpy())
    yhat0 = np.concatenate(batch_labels, axis=0).astype("int16")
    return pd.Series(yhat0 + 1, index=dproc.index, name="pred_target_risk_class")

##### GPU ready

In [None]:
%pip install -q -U pip setuptools wheel
%pip install -q --no-cache-dir tabm rtdl-num-embeddings

In [None]:
import tabm, rtdl_num_embeddings
print("tabm:", tabm.__version__)
print("rtdl-num-embeddings:", rtdl_num_embeddings.__version__)

tabm: 0.0.3
rtdl-num-embeddings: 0.0.12


In [None]:
# ===============================================================
# TabM (PyTorch • CUDA) Multiclass Training Pipeline — Colab GPU
# - Colab-ready; saves under RESULTS/TabM_GPU/<timestamp>
# - Robust preprocessing: datetime expansion, float32 downcast, stable ordinal encoding (unseen->'Missing')
# - Stratified 70/15/15 split + saved indices
# - Model: TabM (parameter-efficient ensemble MLP) with numeric embeddings
# - GPU: CUDA mixed precision (fp16) + GradScaler; gradient accumulation
# - Metrics: accuracy, macro/weighted P/R/F1, log loss + reports + confusions
# - Artifacts: tabm_model.pt, preprocessor.pkl, config.json, metrics.csv
# - Inference helper returns labels in 1..10 (averaging ensemble probabilities)
# ===============================================================

# ------------------------ IMPORTS ------------------------
import numpy as np
import pandas as pd
import pandas.api.types as pdt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, log_loss, classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
import joblib
import os, subprocess, sys, json, warnings, re, shutil
import datetime
from pathlib import Path

# TabM
from tabm import TabM
from rtdl_num_embeddings import PiecewiseLinearEmbeddings, LinearReLUEmbeddings

# ------------------------ DRIVE SETUP ------------------------
# Root folder under which this run's output directory is created.
drive_folder = 'path'
try:
    from google.colab import drive as _colab_drive
    # Mount only when /content/drive is missing or empty.
    if not Path("/content/drive").exists() or not any(Path("/content/drive").iterdir()):
        _colab_drive.mount("/content/drive")
except Exception:
    # Best effort: any failure here (including the import) is ignored.
    pass

# Timestamped run name so repeated runs write to separate folders.
RUN_NAME = f"TabM_GPU_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
OUTPUT_DIR = Path(drive_folder) / "TabM_GPU" / RUN_NAME
(OUTPUT_DIR / "splits").mkdir(parents=True, exist_ok=True)
print(f"Saving all outputs to: {OUTPUT_DIR.resolve()}")

# ------------------------ CONFIG ------------------------
# Seed NumPy and PyTorch.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.benchmark = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    # Ignored when the call fails (e.g. unavailable in the installed torch).
    pass

# >>>>>>>>>>>> SET THIS TO YOUR FILE <<<<<<<<<<<<
DATA_CSV = "path"

# Preferred label column names, tried first by resolve_target_name().
TARGET_ASCII   = "target_risk_class"
TARGET_PERSIAN = "risk score"
NUM_CLASSES = 10

# Fractions of the full dataset used for the test and validation splits.
TEST_SIZE = 0.15
VAL_SIZE  = 0.15

# Training controls (tuned for ~15GB VRAM, e.g., T4)
MAX_EPOCHS      = 80
PATIENCE        = 8
BATCH_SIZE      = 1024          # reduce if you still see OOM; try 512
ACCUM_STEPS     = 2             # effective batch = BATCH_SIZE * ACCUM_STEPS
LEARNING_RATE   = 2e-3
WEIGHT_DECAY    = 1e-4
NUM_WORKERS     = 2
MIXED_PRECISION = True          # CUDA autocast(fp16) + GradScaler
K_ENSEMBLE      = 16            # TabM's k (8..16 fits better on T4)
# NOTE(review): EMBEDDING_TYPE does not appear to be read by the model-construction
# code in this cell — confirm whether it is used further down.
EMBEDDING_TYPE  = "piecewise"   # numeric embedding type

# Placeholder category for missing / unseen categorical values.
MISSING_TOKEN = "Missing"

# ------------------------ GPU / DEVICE ------------------------
# Use CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ------------------------ UTILITIES ------------------------
def normalize_ws(x):
    """Tidy whitespace in a string; any non-string value is returned untouched.

    Removes ZWNJ / RLM / LRM marks, turns non-breaking spaces into plain
    spaces, collapses every whitespace run to one space and trims both ends.
    """
    if not isinstance(x, str):
        return x
    without_marks = re.sub(r"[\u200c\u200f\u200e]", "", x)
    plain_spaces = without_marks.replace("\u00a0", " ")
    return re.sub(r"\s+", " ", plain_spaces).strip()

def downcast_numeric(df):
    """Return a copy of ``df`` whose float64/float32/int64/int32 columns are downcast.

    Float columns go through ``pd.to_numeric(downcast="float")``, the other
    selected columns through ``downcast="integer"``. The input is not modified.
    """
    out = df.copy()
    candidates = out.select_dtypes(include=["float64","int64","int32","float32"]).columns
    for col in candidates:
        target = "float" if pdt.is_float_dtype(out[col]) else "integer"
        out[col] = pd.to_numeric(out[col], downcast=target)
    return out

def parse_possible_datetimes(df):
    """Return a copy of ``df`` where date-looking object columns become datetimes.

    An object-dtype column is converted when more than 20% of its values
    (rendered as strings) contain a ``YYYY-MM-DD`` / ``YYYY/MM/DD``-like
    pattern. Values that cannot be parsed become ``NaT``. Columns that do not
    qualify, and the input frame itself, are left unchanged.
    """
    df = df.copy()
    date_pattern = r"\d{4}[-/]\d{1,2}[-/]\d{1,2}"
    for c in df.columns:
        if df[c].dtype != object:
            continue
        s = df[c].astype(object)
        looks_like_date = s.astype(str).str.contains(date_pattern, regex=True, na=False)
        if looks_like_date.mean() > 0.2:
            try:
                # `infer_datetime_format` is deprecated since pandas 2.0 (format
                # inference is the default there). Passing it inside this
                # try/except meant that a pandas which rejects the keyword would
                # silently skip datetime parsing altogether.
                df[c] = pd.to_datetime(s, errors="coerce")
            except Exception:
                # Best effort: keep the column as text if conversion fails.
                pass
    return df

def expand_datetimes(df):
    """Return a copy of ``df`` with each datetime column replaced by calendar parts.

    For a datetime column ``c`` the columns ``c__year`` (Int16) and
    ``c__month``, ``c__day``, ``c__dow``, ``c__hour``, ``c__mstart``,
    ``c__mend`` (Int8) are appended; a missing hour is filled with 0.
    The original datetime columns are dropped afterwards.
    """
    out = df.copy()
    datetime_cols = [name for name in out.columns if pdt.is_datetime64_any_dtype(out[name])]
    for c in datetime_cols:
        parts = out[c].dt
        out[f"{c}__year"]   = parts.year.astype("Int16")
        out[f"{c}__month"]  = parts.month.astype("Int8")
        out[f"{c}__day"]    = parts.day.astype("Int8")
        out[f"{c}__dow"]    = parts.dayofweek.astype("Int8")
        out[f"{c}__hour"]   = parts.hour.fillna(0).astype("Int8")
        out[f"{c}__mstart"] = parts.is_month_start.astype("Int8")
        out[f"{c}__mend"]   = parts.is_month_end.astype("Int8")
    if not datetime_cols:
        return out
    return out.drop(columns=datetime_cols)

def resolve_target_name(df):
    """Find the label column of ``df`` and return its actual name, or ``None``.

    Two passes are made:
      1. exact match against the configured names (``TARGET_ASCII``,
         ``TARGET_PERSIAN``) followed by a few generic fallbacks;
      2. a match that ignores case and treats runs of ``_``, ``-`` and
         whitespace as one space.

    The second pass now also honours ``TARGET_ASCII`` / ``TARGET_PERSIAN``;
    previously it used hard-coded spellings, so a reconfigured target name was
    only ever matched exactly.
    """
    def _norm(name):
        return re.sub(r"[_\-\s]+"," ", str(name)).strip().lower()

    for k in [TARGET_ASCII, TARGET_PERSIAN, "risk_class", "label", "target", "class", "y"]:
        if k in df.columns: return k

    # If several columns normalize to the same key, the last one wins.
    norm = {_norm(c): c for c in df.columns}
    # Configured names first, then the original fixed list in its original order.
    fuzzy_candidates = [TARGET_ASCII, TARGET_PERSIAN,
                        "target_risk_class", "risk score", "risk class", "label", "target", "class", "y"]
    for k in fuzzy_candidates:
        kk = _norm(k)
        if kk in norm: return norm[kk]
    return None

# ------------------------ PREPROCESSOR ------------------------
class TabularPreprocessor:
    """
    - Expands datetimes, normalizes whitespace
    - numerics as float32, median impute
    - categoricals -> stable ordinal codes per column; unseen -> 'Missing'

    All statistics (column roles, medians, category maps) are learned in
    ``fit`` and reused unchanged by ``transform``.
    """
    def __init__(self):
        self.num_cols_, self.cat_cols_ = [], []
        # column -> imputation value / value->code map / number of codes
        self.num_median_, self.cat_maps_, self.cat_cardinalities_ = {}, {}, {}
        self.feature_names_, self.fitted_ = [], False

    def _prep_base(self, df):
        """Shared cleaning used by both ``fit`` and ``transform``; works on a copy."""
        d = df.copy()
        for c in d.columns:
            if d[c].dtype == object:
                d[c] = d[c].map(normalize_ws)
        d = parse_possible_datetimes(d)
        d = expand_datetimes(d)
        for c in d.columns:
            if pdt.is_bool_dtype(d[c]):
                d[c] = d[c].astype("int8")
        d = downcast_numeric(d)
        return d

    @staticmethod
    def _finite_median(col):
        """Median of ``col`` as a float, or 0.0 when the median is undefined.

        An all-missing column has a NaN (or ``pd.NA``) median; imputing with it
        would leave NaNs in the output, so a constant is used instead.
        """
        med = pd.to_numeric(col, errors="coerce").astype("float64").median()
        return float(med) if np.isfinite(med) else 0.0

    def fit(self, X):
        """Learn column roles, medians and category code maps from ``X``; returns ``self``."""
        d = self._prep_base(X)
        # Start from clean dicts so a second fit() does not keep stale columns.
        self.num_median_, self.cat_maps_, self.cat_cardinalities_ = {}, {}, {}
        self.num_cols_ = [c for c in d.columns if pdt.is_numeric_dtype(d[c])]
        self.cat_cols_ = [c for c in d.columns if not pdt.is_numeric_dtype(d[c])]
        # stats
        for c in self.num_cols_:
            self.num_median_[c] = self._finite_median(d[c])
        # build categorical maps (ensure MISSING_TOKEN present)
        for c in self.cat_cols_:
            s = d[c].astype("string").fillna(MISSING_TOKEN)
            # MISSING_TOKEN is prepended, so it always receives code 0.
            cats = pd.Index(pd.unique(pd.concat([pd.Series([MISSING_TOKEN]), s])))
            cmap = {v: i for i, v in enumerate(cats)}
            self.cat_maps_[c] = cmap
            self.cat_cardinalities_[c] = len(cmap)
        self.feature_names_ = self.num_cols_ + self.cat_cols_
        self.fitted_ = True
        return self

    def transform(self, X):
        """Return ``X`` cleaned and laid out as ``feature_names_``.

        Columns absent from ``X`` are created as all-missing. Numeric columns
        are float32 with missing values replaced by the fitted median;
        categorical columns are int32 codes, with missing or unseen values
        mapped to the ``MISSING_TOKEN`` code.
        """
        assert self.fitted_
        d = self._prep_base(X)
        for c in self.feature_names_:
            if c not in d.columns:
                d[c] = np.nan
        d = d[self.feature_names_].copy()
        # numerics
        for c in self.num_cols_:
            d[c] = pd.to_numeric(d[c], errors="coerce").astype("float32")
            if d[c].isna().any():
                d[c] = d[c].fillna(self.num_median_[c])
        # categoricals
        for c in self.cat_cols_:
            s = d[c].astype("string").fillna(MISSING_TOKEN)
            cmap = self.cat_maps_[c]
            d[c] = s.map(cmap).fillna(cmap[MISSING_TOKEN]).astype("int32")
        return d

# ------------------------ LOAD & TARGET ------------------------
# Read the raw CSV and locate the label column by name.
df = pd.read_csv(DATA_CSV, low_memory=False)
tgt = resolve_target_name(df)
if tgt is None:
    raise KeyError("Could not find target column. Expected 'target_risk_class' or 'risk score'.")

# Coerce the label to a nullable integer; unparsable entries become <NA>.
y_1_10 = pd.to_numeric(df[tgt], errors="coerce").astype("Int64")
# Keep only labels in 1..10; everything else becomes <NA> and is dropped below.
# NOTE(review): the bound 10 is hard-coded rather than derived from NUM_CLASSES.
y_1_10 = y_1_10.where((y_1_10>=1) & (y_1_10<=10))
mask = y_1_10.notna()
df = df.loc[mask].copy()
y_1_10 = y_1_10.loc[mask].astype("int16")
y = (y_1_10 - 1).astype("int64")  # 0..9
X = df.drop(columns=[tgt])

# ------------------------ SPLIT (70/15/15) ------------------------
# First stratified split: carve the test set (TEST_SIZE of all rows) out of the data.
sss1 = StratifiedShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
trainval_idx, test_idx = next(sss1.split(X, y))
X_trainval, X_test = X.iloc[trainval_idx], X.iloc[test_idx]
y_trainval, y_test  = y.iloc[trainval_idx], y.iloc[test_idx]

# VAL_SIZE is a fraction of the full data; rescale it to the remaining train+val rows.
val_rel = VAL_SIZE / (1.0 - TEST_SIZE)
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=val_rel, random_state=SEED)
train_idx, val_idx = next(sss2.split(X_trainval, y_trainval))
X_train, X_val = X_trainval.iloc[train_idx], X_trainval.iloc[val_idx]
y_train, y_val = y_trainval.iloc[train_idx], y_trainval.iloc[val_idx]

# Persist the DataFrame index labels of each split (labels of the filtered frame,
# not positional row numbers).
pd.Series(X_train.index, name="index").to_csv(OUTPUT_DIR/"splits/train_indices.csv", index=False)
pd.Series(X_val.index,   name="index").to_csv(OUTPUT_DIR/"splits/val_indices.csv",   index=False)
pd.Series(X_test.index,  name="index").to_csv(OUTPUT_DIR/"splits/test_indices.csv",  index=False)

print(f"Shapes -> train: {X_train.shape}, val: {X_val.shape}, test: {X_test.shape}")

# ------------------------ PREPROCESS ------------------------
# Fit on the training split only, then apply the same transform to every split.
pp = TabularPreprocessor().fit(X_train)
Xtr = pp.transform(X_train)
Xva = pp.transform(X_val)
Xte = pp.transform(X_test)

num_cols = pp.num_cols_
cat_cols = pp.cat_cols_
# Number of codes per categorical column, in cat_cols order.
cat_cardinalities = [pp.cat_cardinalities_[c] for c in cat_cols]
print(f"Numerics: {len(num_cols)} | Categoricals: {len(cat_cols)}")

# ------------------------ DATASETS ------------------------
class TabDataset(data.Dataset):
    def __init__(self, X_enc: pd.DataFrame, y_arr: np.ndarray | None, num_cols, cat_cols):
        self.X = X_enc
        self.y = None if y_arr is None else np.asarray(y_arr, dtype=np.int64)
        self.num_cols = num_cols
        self.cat_cols = cat_cols
    def __len__(self): return len(self.X)
    def __getitem__(self, idx):
        row = self.X.iloc[idx]
        x_num = torch.tensor(row[self.num_cols].values, dtype=torch.float32) if self.num_cols else torch.empty( (0,), dtype=torch.float32 )
        x_cat = torch.tensor(row[self.cat_cols].values, dtype=torch.long)     if self.cat_cols else torch.empty( (0,), dtype=torch.long )
        if self.y is None:
            return x_num, x_cat
        return x_num, x_cat, torch.tensor(self.y[idx], dtype=torch.long)

# One dataset per split, all sharing the same column lists.
ds_tr = TabDataset(Xtr, y_train.values, num_cols, cat_cols)
ds_va = TabDataset(Xva, y_val.values,   num_cols, cat_cols)
ds_te = TabDataset(Xte, y_test.values,  num_cols, cat_cols)

# Loader options shared by all splits; pinned memory is only enabled on CUDA and
# worker processes are kept alive between epochs when workers are used at all.
dl_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=(device.type == "cuda"),
    persistent_workers=(NUM_WORKERS > 0),
    drop_last=False
)
# Only the training loader shuffles.
dl_tr = data.DataLoader(ds_tr, shuffle=True,  **dl_kwargs)
dl_va = data.DataLoader(ds_va, shuffle=False, **dl_kwargs)
dl_te = data.DataLoader(ds_te, shuffle=False, **dl_kwargs)

# ------------------------ MODEL ------------------------
d_out = NUM_CLASSES

# Numeric embeddings (VRAM-friendly defaults)
# The embedding module used to be built into `num_emb_inf` only, while
# `num_embeddings` stayed None, so TabM was always created WITHOUT numeric
# embeddings. The module is now actually handed to the model.
# NOTE(review): this changes the checkpoint layout — any code that reloads
# tabm_model.pt must build the model with the same embedding module.
num_embeddings = None
num_emb_inf = None   # kept as an alias for code that may reference the old name
if len(num_cols) > 0:
    d_emb = 16
    try:
        num_emb_inf = PiecewiseLinearEmbeddings(len(num_cols), d_emb, version="B")
    except TypeError:
        try:
            num_emb_inf = PiecewiseLinearEmbeddings(len(num_cols), d_emb)
        except Exception:
            num_emb_inf = LinearReLUEmbeddings(len(num_cols), d_emb)
    # NOTE(review): PiecewiseLinearEmbeddings appears to expect precomputed bins
    # (see rtdl_num_embeddings.compute_bins) rather than a feature count, so this
    # chain presumably ends in the LinearReLUEmbeddings fallback — confirm against
    # the installed version before relying on EMBEDDING_TYPE = "piecewise".
    num_embeddings = num_emb_inf
    print(f"Numeric embeddings: {type(num_embeddings).__name__} (d_embedding={d_emb})")

model = TabM.make(
    n_num_features=len(num_cols),
    num_embeddings=num_embeddings,
    cat_cardinalities=cat_cardinalities if len(cat_cols)>0 else None,
    d_out=NUM_CLASSES,
    k=K_ENSEMBLE,                 # ensemble size (reduce if OOM)
    arch_type="batch-ensemble",
)
model.to(device)

# Optimizer & schedule
opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Cosine decay of the learning rate over MAX_EPOCHS scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=MAX_EPOCHS)
# Loss scaling is only active for mixed precision on CUDA; otherwise the scaler is a no-op.
scaler = torch.cuda.amp.GradScaler(enabled=(MIXED_PRECISION and device.type=="cuda"))

# ------------------------ TRAINING UTILS ------------------------
def _forward_logits(x_num, x_cat):
    """Call the module-level `model` with whichever feature tensors are usable.

    Args:
        x_num: numeric feature tensor, or None when there are no numeric features.
        x_cat: categorical index tensor, or None / an empty tensor when unused.

    Returns:
        Whatever `model` returns; the surrounding code treats it as (B, k, C) logits.
    """
    if x_num is None:
        # Categorical-only input.
        return model(None, x_cat)         # (B, k, C)
    if x_cat is None or x_cat.numel() == 0:
        # Numeric-only input.
        return model(x_num)               # (B, k, C)
    return model(x_num, x_cat)            # (B, k, C)

def _step_batch(batch, train: bool):
    """Forward one mini-batch, compute the ensemble cross-entropy and optionally backprop.

    Args:
        batch: `(x_num, x_cat, y)` as yielded by the loaders built on TabDataset.
            A two-element `(features, y)` batch is also accepted when only one
            feature kind is in use.
        train: when True, call `backward()` on the loss (through the GradScaler
            when it is enabled). The optimizer step is left to the caller so
            gradients can be accumulated over several batches.

    Returns:
        `(loss_value, logits_k, yb)`: the loss as a Python float, the detached
        logits, and the detached labels on `device`.

    Raises:
        ValueError: if both `num_cols` and `cat_cols` are empty.
    """
    use_num = len(num_cols) > 0
    use_cat = len(cat_cols) > 0
    if not (use_num or use_cat):
        raise ValueError("No features!")

    # The dataset yields (x_num, x_cat, y): numeric features come first and the
    # categorical block sits right before the label. Indexing the categorical
    # block from the end fixes the cat-only case, which previously picked up the
    # (empty) numeric tensor at position 0.
    yb = batch[-1]
    x_num = batch[0] if use_num else None
    x_cat = batch[-2] if use_cat else None

    if x_num is not None: x_num = x_num.to(device, non_blocking=True)
    if x_cat is not None and x_cat.numel()>0: x_cat = x_cat.to(device, non_blocking=True)
    yb = yb.to(device, non_blocking=True)

    autocast_ctx = torch.cuda.amp.autocast(enabled=(MIXED_PRECISION and device.type=="cuda"), dtype=torch.float16)
    with autocast_ctx:
        logits_k = _forward_logits(x_num, x_cat)   # (B, k, C)
        B, K, C = logits_k.shape
        # Every ensemble member is trained against the same label: flatten the
        # (B, k) axes and repeat each label k times to match.
        loss = F.cross_entropy(
            logits_k.reshape(B*K, C),
            yb.unsqueeze(1).repeat(1, K).reshape(B*K),
            reduction="mean"
        )
    if train:
        if scaler.is_enabled():
            scaler.scale(loss).backward()
        else:
            loss.backward()
    return loss.item(), logits_k.detach(), yb.detach()

@torch.no_grad()
def evaluate(dataloader):
    """Score the module-level `model` on every batch of `dataloader`.

    Ensemble probabilities are obtained by applying softmax per member and
    averaging over the k axis.

    Args:
        dataloader: yields `(x_num, x_cat, y)` batches (a two-element
            `(features, y)` batch is accepted when only one feature kind is used).

    Returns:
        `(metrics, probs, preds, ys, mean_loss)` where `metrics` is a dict with
        accuracy, macro/weighted precision/recall/F1 and log loss, `probs` is the
        averaged probability matrix, `preds` its argmax, `ys` the labels, and
        `mean_loss` the sample-weighted mean cross-entropy.

    Raises:
        ValueError: if both `num_cols` and `cat_cols` are empty.
    """
    use_num = len(num_cols) > 0
    use_cat = len(cat_cols) > 0
    if not (use_num or use_cat):
        raise ValueError("No features!")
    amp_enabled = (MIXED_PRECISION and device.type=="cuda")

    model.eval()
    all_probs, all_y = [], []
    total_loss = 0.0
    for batch in dataloader:
        # Batches are (x_num, x_cat, y); take the categorical block from the end
        # so the cat-only case no longer reads the empty numeric tensor.
        yb = batch[-1]
        x_num = batch[0] if use_num else None
        x_cat = batch[-2] if use_cat else None

        if x_num is not None: x_num = x_num.to(device, non_blocking=True)
        if x_cat is not None and x_cat.numel()>0: x_cat = x_cat.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
            logits_k = _forward_logits(x_num, x_cat)
            B, K, C = logits_k.shape
            batch_loss = F.cross_entropy(logits_k.reshape(B*K, C),
                                         yb.unsqueeze(1).repeat(1, K).reshape(B*K),
                                         reduction="mean")

        # Weight by batch size so the final mean is per sample, not per batch.
        total_loss += batch_loss.item() * yb.size(0)
        probs_k = F.softmax(logits_k.float(), dim=-1)   # (B, k, C)
        probs   = probs_k.mean(dim=1)                   # (B, C)
        all_probs.append(probs.cpu().numpy())
        all_y.append(yb.cpu().numpy())

    probs = np.concatenate(all_probs, axis=0)
    ys    = np.concatenate(all_y,    axis=0)
    preds = probs.argmax(axis=1)
    metrics = {
        "n_samples": int(len(ys)),
        "accuracy": float(accuracy_score(ys, preds)),
    }
    for avg in ["macro", "weighted"]:
        p, r, f1, _ = precision_recall_fscore_support(ys, preds, average=avg, zero_division=0)
        metrics[f"precision_{avg}"] = float(p)
        metrics[f"recall_{avg}"]    = float(r)
        metrics[f"f1_{avg}"]        = float(f1)
    try:
        metrics["log_loss"] = float(log_loss(ys, probs, labels=list(range(NUM_CLASSES))))
    except Exception:
        # Best-effort metric: keep the rest of the report if log loss cannot be computed.
        metrics["log_loss"] = float("nan")
    return metrics, probs, preds, ys, total_loss / max(1, len(ys))

# ------------------------ TRAIN LOOP ------------------------
best_val = float("inf")
best_state = None
epochs_no_improve = 0


def _apply_accumulated_grads():
    """Take one optimizer step with the gradients accumulated so far, then clear them."""
    if scaler.is_enabled():
        scaler.step(opt)
        scaler.update()
    else:
        opt.step()
    opt.zero_grad(set_to_none=True)


for epoch in range(1, MAX_EPOCHS+1):
    model.train()
    epoch_loss = 0.0
    n_seen = 0
    it = 0  # stays 0 if the loader yields nothing, so the tail check below is safe
    opt.zero_grad(set_to_none=True)
    for it, batch in enumerate(dl_tr, start=1):
        l, _, yb = _step_batch(batch, train=True)
        bs = yb.size(0)
        n_seen += bs
        epoch_loss += l * bs

        # gradient accumulation: step once every ACCUM_STEPS batches
        if it % ACCUM_STEPS == 0:
            _apply_accumulated_grads()

    # Handle last partial accumulation step
    if (it % ACCUM_STEPS) != 0:
        _apply_accumulated_grads()

    train_loss = epoch_loss / max(1, n_seen)
    val_metrics, _, _, _, val_loss = evaluate(dl_va)
    scheduler.step()

    print(f"Epoch {epoch:03d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_acc={val_metrics['accuracy']:.4f}")

    # Early stopping on val_loss
    if val_loss < best_val - 1e-4:
        best_val = val_loss
        # `.cpu()` returns the very same tensor when the model already lives on
        # the CPU, which would make this "snapshot" track later training steps.
        # `copy=True` forces an independent copy on every device.
        best_state = {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE:
            print(f"Early stopping at epoch {epoch}")
            break

# Load best
if best_state is not None:
    model.load_state_dict(best_state)

# ------------------------ EVALUATION & SAVE ------------------------
def eval_split(name, dl, X_raw, y0):
    """Evaluate one split and write its classification report and confusion matrix.

    Labels are shifted by +1 before reporting, so the files use classes
    1..NUM_CLASSES.

    Args:
        name: split tag used in the output file names.
        dl: DataLoader passed to `evaluate`.
        X_raw, y0: accepted for call-site compatibility; not used here.

    Returns:
        The metrics dict produced by `evaluate`.
    """
    metrics, _probs, preds, ys, _mean_loss = evaluate(dl)
    class_range = range(1, NUM_CLASSES+1)
    true_one_based = ys + 1
    pred_one_based = preds + 1

    report_text = classification_report(
        true_one_based, pred_one_based, labels=list(class_range), zero_division=0
    )
    with open(OUTPUT_DIR / f"classification_report_{name}.txt", "w", encoding="utf-8") as report_file:
        report_file.write(report_text)

    counts = confusion_matrix(true_one_based, pred_one_based, labels=list(class_range))
    counts_frame = pd.DataFrame(counts, index=class_range, columns=class_range)
    counts_frame.to_csv(OUTPUT_DIR / f"confusion_matrix_{name}.csv")
    return metrics

# Per-split evaluation (each call also writes that split's report files).
metrics_train = eval_split("train", dl_tr, X_train, y_train)
metrics_val   = eval_split("val",   dl_va, X_val,   y_val)
metrics_test  = eval_split("test",  dl_te, X_test,  y_test)

# One row per split, with the split tag as the first column.
metrics_df = pd.DataFrame([
    {"split": split_tag, **split_metrics}
    for split_tag, split_metrics in (
        ("train", metrics_train),
        ("val", metrics_val),
        ("test", metrics_test),
    )
])
metrics_df.to_csv(OUTPUT_DIR / "metrics.csv", index=False)
print(metrics_df)

# Save artifacts: model weights, fitted preprocessor, and the run configuration.
torch.save(model.state_dict(), OUTPUT_DIR / "tabm_model.pt")
joblib.dump(pp, OUTPUT_DIR / "preprocessor.pkl")
with open(OUTPUT_DIR / "config.json", "w", encoding="utf-8") as config_file:
    run_config = dict(
        timestamp=RUN_NAME,
        seed=SEED,
        data_csv=DATA_CSV,
        num_classes=NUM_CLASSES,
        k_ensemble=K_ENSEMBLE,
        batch_size=BATCH_SIZE,
        accum_steps=ACCUM_STEPS,
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        max_epochs=MAX_EPOCHS,
        patience=PATIENCE,
        mixed_precision=MIXED_PRECISION,
        numeric_cols=num_cols,
        categorical_cols=cat_cols,
        cat_cardinalities=cat_cardinalities,
    )
    json.dump(run_config, config_file, ensure_ascii=False, indent=2)

print(f"\n✅ All artifacts saved to: {OUTPUT_DIR.resolve()}")

# ------------------------ INFERENCE HELPER ------------------------
def predict_target_risk_class_tabm(df_new: pd.DataFrame,
                                   model_path=OUTPUT_DIR / "tabm_model.pt",
                                   preproc_path=OUTPUT_DIR / "preprocessor.pkl",
                                   batch_size=BATCH_SIZE):
    """Predict 1-based class labels for raw (not yet preprocessed) rows.

    The saved preprocessor is applied first, then a fresh TabM model is built with
    the same architecture arguments as the training cell, the saved weights are
    loaded, and ensemble probabilities are averaged across k before the argmax.

    Args:
        df_new: raw feature frame; its index is reused for the returned Series.
        model_path: path to the `state_dict` saved by the training cell.
        preproc_path: path to the joblib-pickled preprocessor.
        batch_size: inference batch size.

    Returns:
        pd.Series named "pred_target_risk_class" with labels in 1..NUM_CLASSES
        (empty when `df_new` has no rows).
    """
    import copy  # function-scope import: keeps this helper independent of the cell's import list

    if len(df_new) == 0:
        # Nothing to score; torch.cat on an empty list would raise below.
        return pd.Series([], index=df_new.index, dtype="int64", name="pred_target_risk_class")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mdl = TabM.make(
        n_num_features=len(num_cols),
        # Mirror the embedding setup the training model was created with (the
        # module-level `num_embeddings`, None in the training cell). Building a
        # different embedding module here would give the model parameters that
        # the saved state_dict does not contain, and load_state_dict would fail.
        # deepcopy keeps this model from sharing a module with the trained one.
        num_embeddings=copy.deepcopy(num_embeddings),
        cat_cardinalities=cat_cardinalities if len(cat_cols)>0 else None,
        d_out=NUM_CLASSES,
        k=K_ENSEMBLE,
        arch_type="batch-ensemble",
    )
    mdl.load_state_dict(torch.load(model_path, map_location="cpu"))
    mdl.to(device)
    mdl.eval()
    # NOTE: joblib.load unpickles arbitrary objects — only load preprocessor files
    # produced by this notebook or another trusted source.
    preproc = joblib.load(preproc_path)
    d = preproc.transform(df_new)
    ds = TabDataset(d, None, num_cols, cat_cols)
    dl = data.DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=(device.type=="cuda"))
    preds_all = []
    with torch.no_grad():
        for (xb_num, xb_cat) in dl:
            x_num = xb_num.to(device, non_blocking=True) if len(num_cols)>0 else None
            x_cat = xb_cat.to(device, non_blocking=True) if len(cat_cols)>0 else None
            with torch.cuda.amp.autocast(enabled=(device.type=="cuda"), dtype=torch.float16):
                # Same dispatch as training: both inputs, numeric only, or categorical only.
                if x_num is not None and x_cat is not None and x_cat.numel()>0:
                    logits_k = mdl(x_num, x_cat)
                elif x_num is not None:
                    logits_k = mdl(x_num)
                else:
                    logits_k = mdl(None, x_cat)
                probs = F.softmax(logits_k.float(), dim=-1).mean(dim=1)  # average across k
            preds_all.append(probs.argmax(dim=1).cpu())
    pred0 = torch.cat(preds_all, dim=0).numpy()
    pred1 = pred0 + 1  # back to 1-based class labels
    return pd.Series(pred1, index=df_new.index, name="pred_target_risk_class")

#### Tab-Transformer (CPU/VSCODE)

In [None]:
# ===============================================================
# Tab-Transformer (CPU, PyTorch) – Multiclass Pipeline (FAST v2)
# - Local paths; timestamped outputs under ./RESULTS/tabtransformer_outputs
# - Canonicalized column names; label column = target_risk_class (values 1..10)
# - Preprocess: datetime expansion, numeric downcast+median, capped & unique categoricals
# - CPU-optimized: tiny encoder, mean pooling, capped batches/epoch, no dataloader workers
# - Training: AdamW, cosine LR, early stopping (val CE)
# - Metrics: accuracy, macro/weighted P/R/F1, log loss, ROC-AUC OvR macro
# - Artifacts: tabtransformer_model.pt, preprocessor.pkl, encoder_meta.json, metrics.csv,
#   loss_curves.png, reports + confusions, training_meta.json
# - Inference helper returns labels in 1..10
# ===============================================================

import os, re, json, warnings, datetime, math, time, random, gc
from pathlib import Path
import numpy as np
import pandas as pd
import pandas.api.types as pdt

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, roc_auc_score,
    log_loss, classification_report, confusion_matrix
)
from sklearn.utils.class_weight import compute_class_weight

import joblib
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

warnings.filterwarnings("ignore")

# ------------------------ PATHS ------------------------
# Input CSV plus the output folder and its "splits" sub-folder (both created if missing).
DATA_CSV = "path"
OUTPUT_DIR = Path("path")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
SPLIT_DIR = OUTPUT_DIR / "splits"
SPLIT_DIR.mkdir(exist_ok=True)
_run_stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
RUN_NAME = f"TabTransformer_CPU_{_run_stamp}"
print(f"Saving all outputs to: {OUTPUT_DIR.resolve()}")

# ------------------------ CONFIG ------------------------
# Seed every RNG this cell relies on (python, numpy, torch).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# On CPU, fewer threads sometimes yields better throughput: use half the cores, at least one.
_cpu_total = os.cpu_count() or 4
torch.set_num_threads(max(1, _cpu_total // 2))

# Label column (after name canonicalization) and number of classes.
TARGET_NAME = "target_risk_class"
NUM_CLASSES = 10

# Held-out fractions of the full dataset.
TEST_SIZE = 0.15
VAL_SIZE = 0.15

# ---- Low-RAM & training profile (FAST) ----
BATCH_SIZE = 4096              # larger batch since model is small
GRAD_ACC_STEPS = 1
MAX_EPOCHS = 10
BASE_LR = 2e-3
WEIGHT_DECAY = 3e-4
PATIENCE = 2
MIN_DELTA = 1e-4
WARMUP_EPOCHS = 2
NUM_WORKERS = 0                # Windows/CPU: avoid multiprocessing overhead
MAX_BATCHES_PER_EPOCH = 100    # caps compute per epoch (~100*4096 ≈ 410k samples)

# ---- Evaluation toggles ----
EVAL_TRAIN = False             # set True if you really want train metrics

# ---- Preprocessing ----
MISSING_TOKEN = "Missing"
MAX_CAT_CARD = 200             # cap cardinality to shrink embeddings
STORE_NUM_AS_FP16 = True
STORE_CAT_AS_INT32 = True

# ---- Tab-Transformer hyperparams (tiny-but-strong) ----
D_TOKEN = 32                   # must be divisible by N_HEADS
N_HEADS = 4
N_LAYERS = 1
D_FF = 128
ATTN_DROPOUT = 0.10
FFN_DROPOUT = 0.10
RESID_DROPOUT = 0.10
POOLING = "mean"               # "concat" | "mean" (mean keeps 32 dims)

# ---- Numeric path ----
USE_NUMERIC = True
D_NUM_PROJ = 64
HEAD_HIDDEN = 128
HEAD_DEPTH = 1
HEAD_DROPOUT = 0.10
ACTIVATION = "silu"            # "relu" | "gelu" | "silu"

# ---- sanity ----
assert D_TOKEN % N_HEADS == 0, "D_TOKEN must be divisible by N_HEADS"

# ------------------------ helpers ------------------------
def canon_col(name: str) -> str:
    """Turn a column label into an identifier-like string.

    Every run of characters outside [0-9A-Za-z_] becomes a single underscore,
    repeated underscores are collapsed, and leading/trailing underscores are removed.
    """
    replaced = re.sub(r"[^0-9A-Za-z_]+", "_", str(name))
    collapsed = re.sub(r"_+", "_", replaced)
    return collapsed.strip("_")

def canon_cols_inplace(df: pd.DataFrame) -> pd.DataFrame:
    """Rename the columns of `df` in place with `canon_col` and return the same frame."""
    df.columns = list(map(canon_col, df.columns))
    return df

def _unique_in_order(seq):
    seen = set(); out = []
    for x in seq:
        x = str(x)
        if x not in seen:
            seen.add(x); out.append(x)
    return out

def normalize_ws(x):
    """Clean up whitespace in a string; any non-string value is returned untouched.

    Removes the characters U+200C, U+200F and U+200E, turns U+00A0 into a plain
    space, collapses whitespace runs to one space and trims both ends.
    """
    if not isinstance(x, str):
        return x
    without_marks = re.sub(r"[\u200c\u200f\u200e]", "", x)
    plain_spaces = without_marks.replace("\u00a0", " ")
    return re.sub(r"\s+", " ", plain_spaces).strip()

def downcast_numeric(df):
    """Return a copy of `df` whose float64/float32/int64/int32 columns are downcast.

    Float columns go through `pd.to_numeric(..., downcast="float")`, the other
    selected columns through `downcast="integer"`. All other columns are left as is.
    """
    out = df.copy()
    candidates = out.select_dtypes(include=["float64","int64","int32","float32"]).columns
    for col in candidates:
        target_kind = "float" if pdt.is_float_dtype(out[col]) else "integer"
        out[col] = pd.to_numeric(out[col], downcast=target_kind)
    return out

def parse_possible_datetimes(df):
    """Return a copy of `df` with date-looking object columns parsed to datetimes.

    An object column is converted when more than 20% of its values (as strings)
    contain a `YYYY-MM-DD` / `YYYY/MM/DD`-like pattern. Values that cannot be
    parsed become NaT (`errors="coerce"`). Conversion is best-effort: if parsing
    raises, the column is left unchanged.
    """
    out = df.copy()
    for c in out.columns:
        if out[c].dtype != object:
            continue
        s = out[c].astype(object)
        share_date_like = s.astype(str).str.contains(r"\d{4}[-/]\d{1,2}[-/]\d{1,2}", regex=True, na=False).mean()
        if share_date_like > 0.2:
            try:
                # `infer_datetime_format` is deprecated (pandas 2.0: passing it has
                # no effect). It is dropped here because, inside this blanket
                # `except`, a future removal of the keyword would otherwise silently
                # turn date parsing off for every column.
                out[c] = pd.to_datetime(s, errors="coerce")
            except Exception:
                # Deliberate best-effort: keep the original column on any parsing failure.
                pass
    return out

def expand_datetimes(df):
    """Return a copy of `df` where each datetime column is replaced by calendar parts.

    For a datetime column `c` the columns `c__year`, `c__month`, `c__day`,
    `c__dow`, `c__hour`, `c__mstart` and `c__mend` are appended (nullable integer
    dtypes) and `c` itself is dropped.
    """
    out = df.copy()
    dt_cols = [c for c in out.columns if pdt.is_datetime64_any_dtype(out[c])]
    for c in dt_cols:
        stamp = out[c].dt
        parts = {
            "year":   stamp.year.astype("Int16"),
            "month":  stamp.month.astype("Int8"),
            "day":    stamp.day.astype("Int8"),
            "dow":    stamp.dayofweek.astype("Int8"),
            "hour":   stamp.hour.fillna(0).astype("Int8"),  # missing hour is stored as 0
            "mstart": stamp.is_month_start.astype("Int8"),
            "mend":   stamp.is_month_end.astype("Int8"),
        }
        for suffix, values in parts.items():
            out[f"{c}__{suffix}"] = values
    if dt_cols:
        out = out.drop(columns=dt_cols)
    return out

def pd_cat_fix(series, allowed):
    """Map `series` onto a fixed category list, sending everything else to MISSING_TOKEN.

    The category list is `allowed` followed by MISSING_TOKEN (de-duplicated, order
    kept). Missing values and values outside that list become MISSING_TOKEN.

    Returns:
        An unordered `pd.Categorical` over exactly that category list.
    """
    categories = _unique_in_order(list(allowed) + [MISSING_TOKEN])
    as_text = series.astype("string").fillna(MISSING_TOKEN)
    is_known = as_text.isin(categories)
    cleaned = as_text.where(is_known, MISSING_TOKEN)
    return pd.Categorical(cleaned, categories=pd.Index(categories), ordered=False)

def _act(name):
    return {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}[name]()

# ------------------------ Preprocessor ------------------------
class TabularPreprocessor:
    """Learns column roles, fill values and category lists on one frame and applies them to others.

    `fit` decides which cleaned columns are numeric vs. categorical, stores a
    median per numeric column and a (capped) category list per categorical column.
    `transform` reproduces the same column set and order on any frame, filling
    numeric gaps with the stored medians and mapping categoricals onto the stored
    category lists.
    """

    def __init__(self):
        self.num_cols_ = []        # numeric feature names, in output order
        self.cat_cols_ = []        # categorical feature names, in output order
        self.num_median_ = {}      # numeric column -> fill value learned in fit()
        self.cat_categories_ = {}  # categorical column -> ordered category list
        self.feature_names_ = []   # num_cols_ followed by cat_cols_
        self.fitted_ = False

    def _prep_base(self, X):
        """Stateless cleaning shared by fit() and transform(); works on a copy of `X`."""
        d = X.copy()
        canon_cols_inplace(d)
        for c in d.columns:
            if d[c].dtype == object:
                d[c] = d[c].map(normalize_ws)
        d = parse_possible_datetimes(d)
        d = expand_datetimes(d)
        for c in d.columns:
            if pdt.is_bool_dtype(d[c]):
                d[c] = d[c].astype("int8")
        d = downcast_numeric(d)
        return d

    def fit(self, X):
        """Learn column roles, numeric medians and category lists from `X`; returns self."""
        d = self._prep_base(X)
        self.num_cols_ = [c for c in d.columns if pdt.is_numeric_dtype(d[c])]
        self.cat_cols_ = [c for c in d.columns if not pdt.is_numeric_dtype(d[c])]

        # Start from clean state so refitting cannot keep entries for columns
        # that no longer exist.
        self.num_median_ = {}
        self.cat_categories_ = {}

        for c in self.num_cols_:
            med = pd.to_numeric(d[c], errors="coerce").median()
            # A column with no observed values has no median (NaN / <NA>). Filling
            # with that would leave NaNs in the features (or fail in float()), so
            # fall back to 0.0.
            self.num_median_[c] = 0.0 if pd.isna(med) else float(med)

        for c in self.cat_cols_:
            s = d[c].astype("string").fillna(MISSING_TOKEN)
            vc = s.value_counts(dropna=False)
            if MAX_CAT_CARD and len(vc) > (MAX_CAT_CARD - 1):
                # Too many levels: keep the most frequent ones and leave one slot
                # for MISSING_TOKEN, which also absorbs the dropped levels at
                # transform time.
                kept = vc.index.astype("string").tolist()[:MAX_CAT_CARD - 1]
            else:
                kept = pd.unique(s).astype("string").tolist()
            # _unique_in_order already guarantees uniqueness and keeps MISSING_TOKEN present.
            self.cat_categories_[c] = _unique_in_order(kept + [MISSING_TOKEN])

        self.feature_names_ = self.num_cols_ + self.cat_cols_
        self.fitted_ = True
        return self

    def transform(self, X):
        """Return a frame with exactly `feature_names_` columns, cleaned and filled.

        Columns missing from `X` are created as all-NaN first, so they end up as
        the stored median (numeric) or MISSING_TOKEN (categorical).
        """
        assert self.fitted_, "TabularPreprocessor.transform() called before fit()"
        d = self._prep_base(X)
        for c in self.feature_names_:
            if c not in d.columns:
                d[c] = np.nan
        d = d[self.feature_names_].copy()
        for c in self.num_cols_:
            d[c] = pd.to_numeric(d[c], errors="coerce").astype("float32")
            if d[c].isna().any():
                d[c] = d[c].fillna(self.num_median_[c])
        for c in self.cat_cols_:
            d[c] = pd_cat_fix(d[c], self.cat_categories_[c])
        return d

# ------------------------ Torch encoder ------------------------
class TorchTabEncoder:
    """Turns a preprocessed frame into a standardized numeric array and an integer code array.

    Numeric columns are z-scored with statistics learned in `fit`; categorical
    columns are expected to be pandas categoricals and are emitted as their codes.
    """

    def __init__(self, num_cols, cat_cols, cat_categories):
        self.num_cols = list(num_cols)
        self.cat_cols = list(cat_cols)
        self.cat_categories = {col: list(levels) for col, levels in cat_categories.items()}
        # Filled in by fit() (or load_meta()).
        self.num_mean_ = None
        self.num_std_ = None
        self.cat_cardinalities_ = {col: len(self.cat_categories[col]) for col in self.cat_cols}

    def fit(self, df_proc):
        """Learn per-column mean and standard deviation of the numeric columns; returns self."""
        if not self.num_cols:
            self.num_mean_ = np.array([], dtype="float32")
            self.num_std_ = np.array([], dtype="float32")
            return self
        values = df_proc[self.num_cols].astype("float32").values
        self.num_mean_ = values.mean(axis=0).astype("float32")
        spread = values.std(axis=0).astype("float32")
        # (Near-)constant columns would divide by ~0; use a unit scale for them instead.
        self.num_std_ = np.where(spread < 1e-6, 1.0, spread).astype("float32")
        return self

    def transform(self, df_proc):
        """Return `(Xn, Xc)`: standardized numeric matrix and categorical code matrix."""
        n_rows = len(df_proc)

        if self.num_cols:
            raw = df_proc[self.num_cols].astype("float32").values
            scaled = (raw - self.num_mean_) / self.num_std_
            Xn = scaled.astype("float16") if STORE_NUM_AS_FP16 else scaled
        else:
            Xn = np.zeros((n_rows, 0), dtype="float16" if STORE_NUM_AS_FP16 else "float32")

        code_columns = []
        for col in self.cat_cols:
            raw_codes = df_proc[col].cat.codes.to_numpy(copy=False)
            missing_code = self.cat_categories[col].index(MISSING_TOKEN)
            # pandas uses a negative code for values outside the categories; send those to MISSING_TOKEN.
            patched = np.where(raw_codes < 0, missing_code, raw_codes)
            code_columns.append(patched.astype("int32" if STORE_CAT_AS_INT32 else "int64"))

        if code_columns:
            Xc = np.stack(code_columns, axis=1)
        else:
            Xc = np.zeros((n_rows, 0), dtype="int32" if STORE_CAT_AS_INT32 else "int64")
        return Xn, Xc

    def save_meta(self, path_json):
        """Write column lists, category lists and numeric statistics to a JSON file."""
        payload = dict(
            num_cols=self.num_cols,
            cat_cols=self.cat_cols,
            cat_categories=self.cat_categories,
            num_mean=self.num_mean_.tolist(),
            num_std=self.num_std_.tolist(),
            cat_cardinalities=self.cat_cardinalities_,
        )
        with open(path_json, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=2)

    @staticmethod
    def load_meta(path_json):
        """Rebuild an encoder from a JSON file written by `save_meta`."""
        with open(path_json, "r", encoding="utf-8") as handle:
            stored = json.load(handle)
        restored = TorchTabEncoder(stored["num_cols"], stored["cat_cols"], stored["cat_categories"])
        restored.num_mean_ = np.array(stored["num_mean"], dtype="float32")
        restored.num_std_ = np.array(stored["num_std"], dtype="float32")
        restored.cat_cardinalities_ = {
            col: int(size) for col, size in stored["cat_cardinalities"].items()
        }
        return restored

# ------------------------ Dataset ------------------------
class TabDataset(Dataset):
    """Row-aligned view over pre-encoded numeric and categorical arrays, with optional labels.

    Items are `(Xn[i], Xc[i])` when no labels were given, otherwise
    `(Xn[i], Xc[i], y[i])`. Labels are stored as int64.
    """

    def __init__(self, Xn, Xc, y=None):
        self.Xn = Xn
        self.Xc = Xc
        self.y = y.astype("int64") if y is not None else None

    def __len__(self):
        return len(self.Xn)

    def __getitem__(self, i):
        features = (self.Xn[i], self.Xc[i])
        if self.y is None:
            return features
        return features + (self.y[i],)

# ------------------------ Tab-Transformer model ------------------------
class CatTokenEmbeddings(nn.Module):
    """One embedding table per categorical column plus a learned per-column offset.

    Input: integer codes of shape (B, n_cat) (a 1-D tensor is treated as (B, 1)).
    Output: tokens of shape (B, n_cat, d_token).
    """

    def __init__(self, cat_cardinalities, d_token):
        super().__init__()
        self.n_cat = len(cat_cardinalities)
        self.d_token = d_token
        self.value_embs = nn.ModuleList(
            [nn.Embedding(cardinality, d_token) for cardinality in cat_cardinalities]
        )
        self.col_emb = nn.Parameter(torch.zeros(self.n_cat, d_token))
        nn.init.trunc_normal_(self.col_emb, std=0.02)

    def forward(self, x_cat):
        if x_cat.ndim == 1:
            x_cat = x_cat.unsqueeze(1)
        batch_size, n_columns = x_cat.shape
        if n_columns == 0:
            # No categorical columns: return an empty float token tensor.
            return x_cat.new_zeros((batch_size, 0, self.d_token)).float()  # safe dtype
        per_column = []
        for j, table in enumerate(self.value_embs):
            per_column.append(table(x_cat[:, j]))
        stacked = torch.stack(per_column, dim=1)      # (B, n_cat, d_token)
        return stacked + self.col_emb.unsqueeze(0)    # add column embeddings

class TransformerEncoder(nn.Module):
    """Thin wrapper around `nn.TransformerEncoder` that tolerates zero layers / zero tokens.

    Only `resid_dropout` is forwarded to the underlying layer (as its `dropout`);
    `attn_dropout` and `ffn_dropout` are accepted but not used in this class.
    """

    def __init__(self, d_token, n_heads, d_ff, n_layers, attn_dropout, ffn_dropout, resid_dropout):
        super().__init__()
        if n_layers != 0:
            layer = nn.TransformerEncoderLayer(
                d_model=d_token,
                nhead=n_heads,
                dim_feedforward=d_ff,
                dropout=resid_dropout,
                activation="gelu",
                batch_first=True,
                norm_first=True,
            )
            self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        else:
            self.encoder = None

    def forward(self, x_tokens):
        # Pass-through when there is nothing to encode or no encoder was built.
        nothing_to_do = x_tokens.size(1) == 0 or self.encoder is None
        return x_tokens if nothing_to_do else self.encoder(x_tokens)

class NumericProjector(nn.Module):
    """Two-layer MLP (Linear → LayerNorm → activation, twice) over the numeric features.

    With `in_dim == 0` no network is built and `forward` returns an empty
    (batch, 0) tensor.
    """

    def __init__(self, in_dim, out_dim, hidden=256, p=0.10, act="silu"):
        super().__init__()
        if in_dim != 0:
            stages = [
                nn.Linear(in_dim, hidden), nn.LayerNorm(hidden), _act(act), nn.Dropout(p),
                nn.Linear(hidden, out_dim), nn.LayerNorm(out_dim), _act(act),
            ]
            self.net = nn.Sequential(*stages)
        else:
            self.net = None

    def forward(self, x):
        if self.net is not None:
            return self.net(x)
        return x.new_zeros((x.size(0), 0))

class HeadMLP(nn.Module):
    """Classification head: `depth - 1` hidden stages followed by a final Linear layer.

    Each hidden stage is Linear → LayerNorm → activation → Dropout. With
    `depth <= 1` the head is a single Linear(in_dim, out_dim).

    Raises:
        ValueError: if `in_dim` is 0.
    """

    def __init__(self, in_dim, hidden, depth, out_dim, p=0.10, act="silu"):
        super().__init__()
        if in_dim == 0:
            raise ValueError("No input features: fused_in == 0 (both numeric and categorical are empty).")
        n_hidden_stages = max(0, depth-1)
        # Layer widths from the input through every hidden stage.
        widths = [in_dim] + [hidden] * n_hidden_stages
        modules = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            modules.extend([nn.Linear(width_in, width_out), nn.LayerNorm(width_out), _act(act), nn.Dropout(p)])
        modules.append(nn.Linear(widths[-1], out_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)

class TabTransformer(nn.Module):
    """Categorical tokens → transformer encoder → pooling, fused with a numeric projection.

    The pooled categorical representation (mean over tokens, or all tokens
    flattened when `pooling == "concat"`) is concatenated with the projected
    numeric features and passed to an MLP head that outputs `d_out` logits.

    Raises:
        ValueError: at construction if neither feature kind contributes any input
            to the head; in `forward` if no usable input tensor is given.
    """

    def __init__(self, n_num, cat_cardinalities, d_token, n_heads, n_layers,
                 d_ff, attn_dropout, ffn_dropout, resid_dropout,
                 pooling, d_num_proj, head_hidden, head_depth, d_out, act="silu"):
        super().__init__()
        self.n_num = n_num
        self.pooling = pooling
        self.d_token = d_token

        self.cat_emb = CatTokenEmbeddings(cat_cardinalities, d_token) if cat_cardinalities else None
        self.encoder  = TransformerEncoder(d_token, n_heads, d_ff, n_layers, attn_dropout, ffn_dropout, resid_dropout)
        # When the numeric path is disabled, an empty projector keeps forward() uniform.
        self.num_proj = NumericProjector(n_num, d_num_proj, hidden=max(128, d_num_proj*2), p=HEAD_DROPOUT, act=act) if (USE_NUMERIC and n_num>0) else NumericProjector(0,0)

        cat_out_dim = 0
        if cat_cardinalities:
            cat_out_dim = (len(cat_cardinalities) * d_token) if pooling == "concat" else d_token
        fused_in = cat_out_dim + (d_num_proj if (USE_NUMERIC and n_num>0) else 0)
        print(f"[TabTransformer] cat_out_dim={cat_out_dim}  num_proj={(d_num_proj if (USE_NUMERIC and n_num>0) else 0)}  fused_in={fused_in}")
        if fused_in == 0:
            raise ValueError("No input features: both numeric and categorical are empty after preprocessing.")
        self.head = HeadMLP(fused_in, head_hidden, head_depth, d_out, p=HEAD_DROPOUT, act=act)

    def forward(self, x_num, x_cat):
        """Return class logits of shape (batch, d_out).

        Args:
            x_num: float tensor (batch, n_num), or None when there are no numeric inputs.
            x_cat: integer tensor (batch, n_cat), or None when there are no categorical inputs.
        """
        cat_repr = None
        if self.cat_emb is not None and x_cat is not None and x_cat.size(1) > 0:
            tok = self.cat_emb(x_cat)
            enc = self.encoder(tok)
            cat_repr = enc.mean(dim=1) if self.pooling == "mean" else enc.flatten(1)

        num_repr = None
        if self.n_num > 0 and x_num is not None and x_num.size(1) > 0:
            num_repr = self.num_proj(x_num)

        # Empty (batch, 0) placeholders were previously always derived from x_num,
        # which crashed when x_num was None even though the branches above test
        # for that. Use x_num when it exists (unchanged behaviour) and otherwise
        # the categorical representation.
        ref = x_num if x_num is not None else cat_repr
        if ref is None:
            raise ValueError("TabTransformer.forward needs a numeric tensor or a non-empty categorical tensor.")
        if cat_repr is None:
            cat_repr = ref.new_zeros((ref.size(0), 0))
        if num_repr is None:
            num_repr = ref.new_zeros((ref.size(0), 0))

        z = torch.cat([cat_repr, num_repr], dim=1)
        return self.head(z)

# ------------------------ LOAD & TARGET ------------------------
df = pd.read_csv(DATA_CSV, low_memory=False)
canon_cols_inplace(df)
if TARGET_NAME not in df.columns:
    raise KeyError(f"Expected label column '{TARGET_NAME}' after canonicalization; got first columns {list(df.columns)[:20]}")

# Keep only rows whose label is a whole number in 1..NUM_CLASSES. The upper bound
# follows NUM_CLASSES (it was a literal 10) so the label filter cannot drift from
# the model's output size. Non-numeric and fractional labels are dropped instead
# of failing the integer cast.
_y_num = pd.to_numeric(df[TARGET_NAME], errors="coerce")
mask = _y_num.notna() & (_y_num % 1 == 0) & _y_num.between(1, NUM_CLASSES)
df = df.loc[mask].copy()
y1 = _y_num.loc[mask].astype("int16")   # labels as stored: 1..NUM_CLASSES
y  = (y1 - 1).astype("int16")           # 0-based labels used for training
X  = df.drop(columns=[TARGET_NAME])
del df, _y_num; gc.collect()

# ------------------------ SPLIT (70/15/15) ------------------------
# Stage 1: hold out the test share, stratified on the label.
sss1 = StratifiedShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
trainval_idx, test_idx = next(sss1.split(X, y))
X_trainval, y_trainval = X.iloc[trainval_idx], y.iloc[trainval_idx]
X_test, y_test = X.iloc[test_idx], y.iloc[test_idx]

# Stage 2: split the remainder into train / validation. VAL_SIZE is a fraction of
# the full data, so it is rescaled to a fraction of the train+val part.
val_rel = VAL_SIZE / (1.0 - TEST_SIZE)
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=val_rel, random_state=SEED)
train_idx, val_idx = next(sss2.split(X_trainval, y_trainval))
X_train, y_train = X_trainval.iloc[train_idx], y_trainval.iloc[train_idx]
X_val, y_val = X_trainval.iloc[val_idx], y_trainval.iloc[val_idx]

# Persist the original row labels of every split.
for _split_name, _split_index in (("train", X_train.index), ("val", X_val.index), ("test", X_test.index)):
    pd.Series(_split_index, name="index").to_csv(SPLIT_DIR / f"{_split_name}_indices.csv", index=False)
print(f"Shapes -> train: {X_train.shape}, val: {X_val.shape}, test: {X_test.shape}")

# ------------------------ PREPROCESS & ENCODE ------------------------
# Fit the preprocessor on the training split only, then apply it to all three splits.
pp = TabularPreprocessor()
pp.fit(X_train)
Xtr_df, Xva_df, Xte_df = (pp.transform(raw_part) for raw_part in (X_train, X_val, X_test))

# Standardization statistics also come from the training split only.
enc = TorchTabEncoder(pp.num_cols_, pp.cat_cols_, pp.cat_categories_)
enc.fit(Xtr_df)
Xtr_num, Xtr_cat = enc.transform(Xtr_df)
Xva_num, Xva_cat = enc.transform(Xva_df)
Xte_num, Xte_cat = enc.transform(Xte_df)

# free raw frames
del X_train, X_val, X_test, X_trainval, Xtr_df, Xva_df, Xte_df, X, y, y1
gc.collect()

# Labels as int64 numpy arrays.
y_tr, y_va, y_te = (
    label_part.values.astype("int64") for label_part in (y_train, y_val, y_test)
)

ds_tr = TabDataset(Xtr_num, Xtr_cat, y_tr)
ds_va = TabDataset(Xva_num, Xva_cat, y_va)
ds_te = TabDataset(Xte_num, Xte_cat, y_te)

# Windows/CPU: set num_workers=0
_dl_workers = NUM_WORKERS
_loader_opts = {"batch_size": BATCH_SIZE, "num_workers": _dl_workers}

dl_tr = DataLoader(ds_tr, shuffle=True, **_loader_opts)
dl_va = DataLoader(ds_va, shuffle=False, **_loader_opts)
dl_te = DataLoader(ds_te, shuffle=False, **_loader_opts)

# ------------------------ CLASS WEIGHTS ------------------------
# "balanced" weights for the classes seen in training; classes that never occur
# keep weight 1. The vector is then rescaled to have mean 1.
classes_present = np.unique(y_tr)
cw = compute_class_weight(class_weight="balanced", classes=classes_present, y=y_tr)
cw_map = {int(cls): float(weight) for cls, weight in zip(classes_present, cw)}
class_weights = np.ones(NUM_CLASSES, dtype="float32")
class_weights[list(cw_map.keys())] = list(cw_map.values())
class_weights = class_weights / class_weights.mean()
class_weights_t = torch.tensor(class_weights, dtype=torch.float32)

# ------------------------ MODEL / OPTIM ------------------------
device = torch.device("cpu")
n_num = Xtr_num.shape[1]
# Cardinalities in the encoder's categorical column order (empty list when there are none).
cat_cards = [enc.cat_cardinalities_[col] for col in enc.cat_cols]

_numeric_enabled = USE_NUMERIC and n_num > 0
_tt_kwargs = dict(
    n_num=n_num,
    cat_cardinalities=cat_cards,
    d_token=D_TOKEN,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    d_ff=D_FF,
    attn_dropout=ATTN_DROPOUT,
    ffn_dropout=FFN_DROPOUT,
    resid_dropout=RESID_DROPOUT,
    pooling=POOLING,
    d_num_proj=D_NUM_PROJ if _numeric_enabled else 0,
    head_hidden=HEAD_HIDDEN,
    head_depth=HEAD_DEPTH,
    d_out=NUM_CLASSES,
    act=ACTIVATION,
)
model = TabTransformer(**_tt_kwargs).to(device)

total_params = sum(param.numel() for param in model.parameters())
print(f"Tab-Transformer params: {total_params/1e6:.2f}M | cats: {len(cat_cards)} | nums: {n_num}")

optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=WEIGHT_DECAY)

def cosine_factor(epoch, max_epochs=MAX_EPOCHS, warmup=WARMUP_EPOCHS):
    """Learning-rate multiplier for a 0-based `epoch`.

    During the first `warmup` epochs the factor ramps linearly up to 1;
    afterwards it follows a half cosine from 1 towards 0 at `max_epochs`.
    """
    if epoch >= warmup:
        progress = (epoch - warmup) / max(1, max_epochs - warmup)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return (epoch + 1) / max(1, warmup)

# The scheduler multiplies BASE_LR by cosine_factor(epoch index) after every scheduler.step().
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_factor)

# ------------------------ TRAIN (early stopping + capped batches) ------------------------
best_val = float("inf"); best_epoch = -1; pat = 0
history = {"epoch": [], "train_loss": [], "val_loss": [], "lr": []}
t0 = time.time()

# Class weights never change during training: move them to the device once
# instead of on every step.
_loss_weights = class_weights_t.to(device)


def _optimizer_update():
    """Clip gradients, take one optimizer step and clear the gradients."""
    torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)


for epoch in range(1, MAX_EPOCHS + 1):
    model.train()
    total, n = 0.0, 0
    # Read the LR before scheduler.step(): this is the value actually used for
    # this epoch's updates (reading it afterwards reports the next epoch's LR).
    epoch_lr = optimizer.param_groups[0]["lr"]
    grads_pending = False
    optimizer.zero_grad(set_to_none=True)

    for step, (xnum_np, xcat_np, yb) in enumerate(dl_tr, 1):
        xnum = torch.as_tensor(xnum_np, device=device).float()
        xcat = torch.as_tensor(xcat_np, device=device).long()
        yb   = yb.to(device)

        logits = model(xnum, xcat)
        loss = F.cross_entropy(logits, yb, weight=_loss_weights)

        (loss / GRAD_ACC_STEPS).backward()
        grads_pending = True
        if step % GRAD_ACC_STEPS == 0:
            _optimizer_update()
            grads_pending = False

        total += loss.item() * yb.size(0); n += yb.size(0)
        del xnum, xcat

        if step >= MAX_BATCHES_PER_EPOCH:
            break

    # Apply gradients left over from an incomplete accumulation window (end of
    # data or batch cap reached); they were previously discarded.
    if grads_pending:
        _optimizer_update()

    train_loss = total / max(1, n)

    # ---- validation
    model.eval()
    vtotal, vn = 0.0, 0
    with torch.no_grad():
        for xnum_np, xcat_np, yb in dl_va:
            xnum = torch.as_tensor(xnum_np, device=device).float()
            xcat = torch.as_tensor(xcat_np, device=device).long()
            yb   = yb.to(device)
            logits = model(xnum, xcat)
            vloss = F.cross_entropy(logits, yb, weight=_loss_weights)
            vtotal += vloss.item() * yb.size(0); vn += yb.size(0)
            del xnum, xcat
    val_loss = vtotal / max(1, vn)
    scheduler.step()

    history["epoch"].append(epoch)
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["lr"].append(epoch_lr)
    print(f"Epoch {epoch:03d} | train {train_loss:.4f} | val {val_loss:.4f} | lr {epoch_lr:.2e}")

    if val_loss + MIN_DELTA < best_val:
        best_val = val_loss; best_epoch = epoch; pat = 0
        # Checkpoint the best model together with the arguments needed to rebuild it.
        torch.save({"state_dict": model.state_dict(),
                    "tt_config": {
                        "n_num": n_num, "cat_cardinalities": cat_cards,
                        "d_token": D_TOKEN, "n_heads": N_HEADS, "n_layers": N_LAYERS,
                        "d_ff": D_FF, "attn_dropout": ATTN_DROPOUT, "ffn_dropout": FFN_DROPOUT, "resid_dropout": RESID_DROPOUT,
                        "pooling": POOLING, "d_num_proj": (D_NUM_PROJ if (USE_NUMERIC and n_num>0) else 0),
                        "head_hidden": HEAD_HIDDEN, "head_depth": HEAD_DEPTH,
                        "d_out": NUM_CLASSES, "activation": ACTIVATION
                    }},
                   OUTPUT_DIR / "tabtransformer_model.pt")
    else:
        pat += 1
        if pat >= PATIENCE:
            print(f"Early stopping at epoch {epoch} (best @ {best_epoch} | val {best_val:.4f})")
            break

elapsed = time.time() - t0
print(f"Training time: {elapsed/60:.1f} min; best epoch: {best_epoch}")

# If no epoch ever improved (e.g. the validation loss was NaN from the start, or
# MAX_EPOCHS is 0), this run wrote no checkpoint. Stop here rather than let the
# evaluation step load a stale file from an earlier run.
if best_epoch < 0:
    raise RuntimeError("Training produced no checkpoint: validation loss never improved on its initial value.")

# ------------------------ Evaluation ------------------------

def predict_proba_dl(model, dl):
    """Run ``model`` over every batch of ``dl`` and return class probabilities.

    Args:
        model: callable module taking ``(x_num, x_cat)`` and returning logits;
            it is switched to eval mode before inference.
        dl: iterable of batches shaped either ``(x_num, x_cat, y)`` or
            ``(x_num, x_cat)``; a label element, when present, is ignored.

    Returns:
        A NumPy array with the softmax probabilities of all batches stacked
        row-wise, in loader order.
    """
    model.eval()
    chunks = []
    with torch.no_grad():
        for batch in dl:
            # Three-element batches carry labels last; two-element ones are features only.
            num_part, cat_part = batch[:2] if len(batch) == 3 else batch
            feats_num = torch.as_tensor(num_part, device=device).float()
            feats_cat = torch.as_tensor(cat_part, device=device).long()
            batch_proba = F.softmax(model(feats_num, feats_cat), dim=-1)
            chunks.append(batch_proba.cpu().numpy())
    return np.vstack(chunks)

# Reload the checkpoint written by the training loop so evaluation uses the
# saved weights rather than whatever the model held when training stopped.
ckpt = torch.load(OUTPUT_DIR / "tabtransformer_model.pt", map_location=device)
model.load_state_dict(ckpt["state_dict"])

# leaner eval settings
_eval_bs = 4096

# Evaluation loaders: fixed order (shuffle=False) and no worker processes.
dl_va_eval = DataLoader(ds_va, batch_size=_eval_bs, shuffle=False, num_workers=0)
dl_te_eval = DataLoader(ds_te, batch_size=_eval_bs, shuffle=False, num_workers=0)
# The train loader is only built when train metrics were requested.
if EVAL_TRAIN:
    dl_tr_eval = DataLoader(ds_tr, batch_size=_eval_bs, shuffle=False, num_workers=0)


def eval_split(name, dl, y_zero):
    """Score one data split and write its report files.

    Args:
        name: split label; used in the metrics row and in the output file names.
        dl: loader for the split, passed to ``predict_proba_dl``.
        y_zero: zero-indexed ground-truth labels for the split.

    Returns:
        ``(metrics, pred_one, proba)``: the metrics dict, the predictions
        shifted to one-indexed labels, and the probability matrix.

    Side effects:
        Writes ``classification_report_<name>.txt`` and
        ``confusion_matrix_<name>.csv`` into ``OUTPUT_DIR``.
    """
    class_ids = list(range(NUM_CLASSES))
    one_based = list(range(1, NUM_CLASSES + 1))

    proba = predict_proba_dl(model, dl)
    pred0 = proba.argmax(axis=1)

    metrics = {
        "split": name,
        "n_samples": int(len(y_zero)),
        "accuracy": float(accuracy_score(y_zero, pred0)),
    }
    for avg in ("macro", "weighted"):
        prec, rec, fscore, _ = precision_recall_fscore_support(y_zero, pred0, average=avg, zero_division=0)
        metrics.update({
            f"precision_{avg}": float(prec),
            f"recall_{avg}": float(rec),
            f"f1_{avg}": float(fscore),
        })

    # Probability-based scores are best-effort: any failure is recorded as NaN.
    try:
        ll_value = float(log_loss(y_zero, proba, labels=class_ids))
    except Exception:
        ll_value = float("nan")
    metrics["log_loss"] = ll_value

    try:
        onehot = pd.get_dummies(pd.Categorical(y_zero, categories=class_ids))
        auc_value = float(roc_auc_score(onehot.values, proba, average="macro", multi_class="ovr"))
    except Exception:
        auc_value = float("nan")
    metrics["roc_auc_ovr_macro"] = auc_value

    # Reports are written with one-indexed labels.
    ys_one = y_zero + 1
    pred_one = pred0 + 1

    report_text = classification_report(ys_one, pred_one, labels=one_based, zero_division=0)
    with open(OUTPUT_DIR / f"classification_report_{name}.txt", "w", encoding="utf-8") as handle:
        handle.write(report_text)

    cm_frame = pd.DataFrame(
        confusion_matrix(ys_one, pred_one, labels=one_based),
        index=range(1, NUM_CLASSES+1),
        columns=range(1, NUM_CLASSES+1),
    )
    cm_frame.to_csv(OUTPUT_DIR / f"confusion_matrix_{name}.csv")
    return metrics, pred_one, proba

# One metrics row per evaluated split; the train split is optional.
metrics_rows = []
if EVAL_TRAIN:
    m_train, _, _ = eval_split("train", dl_tr_eval, y_tr); metrics_rows.append(m_train)
m_val,   _, _ = eval_split("val",   dl_va_eval, y_va);    metrics_rows.append(m_val)
m_test,  _, _ = eval_split("test",  dl_te_eval, y_te);   metrics_rows.append(m_test)

pd.DataFrame(metrics_rows).to_csv(OUTPUT_DIR / "metrics_tabtransformer.csv", index=False)
print(pd.DataFrame(metrics_rows))

# ------------------------ Curves ------------------------
# Train vs. validation loss per epoch, saved to disk and closed without being shown.
plt.figure(figsize=(7,4))
plt.plot(history["epoch"], history["train_loss"], label="train")
plt.plot(history["epoch"], history["val_loss"],   label="val")
plt.xlabel("epoch"); plt.ylabel("CE loss"); plt.title("Tab-Transformer Loss"); plt.legend(); plt.tight_layout()
plt.savefig(OUTPUT_DIR / "loss_curves.png", dpi=150); plt.close()

# ------------------------ Save artifacts ------------------------
# Persist the fitted preprocessor (pickle via joblib), the encoder metadata (JSON)
# and a JSON summary of the run; the model config is copied from the checkpoint.
joblib.dump(pp, OUTPUT_DIR / "preprocessor.pkl")
enc.save_meta(OUTPUT_DIR / "encoder_meta.json")
with open(OUTPUT_DIR / "training_meta.json", "w", encoding="utf-8") as f:
    json.dump({
        "target_name": TARGET_NAME,
        "num_classes": NUM_CLASSES,
        "label_order_zero_indexed": list(range(NUM_CLASSES)),
        "seed": SEED,
        "best_epoch": int(best_epoch),
        "splits": {"train": "splits/train_indices.csv", "val": "splits/val_indices.csv", "test": "splits/test_indices.csv"},
        "data_csv": DATA_CSV,
        "tt_config": ckpt["tt_config"],
        "train_time_min": round(elapsed/60, 2),
        "feature_names": pp.feature_names_,
        "cat_features": pp.cat_cols_,
        "num_features": pp.num_cols_,
        "column_names_canonicalized": True
    }, f, ensure_ascii=False, indent=2)

print(f"\n✅ All artifacts saved to: {OUTPUT_DIR.resolve()}")

# ------------------------ Inference helper ------------------------

def predict_target_risk_class(df_new: pd.DataFrame,
                              model_path=OUTPUT_DIR / "tabtransformer_model.pt",
                              preproc_path=OUTPUT_DIR / "preprocessor.pkl",
                              encoder_meta_path=OUTPUT_DIR / "encoder_meta.json",
                              batch_size=4096) -> pd.Series:
    """Predict on new raw rows (returns labels in 1..10).

    Args:
        df_new: raw feature rows; the frame is copied, so the caller's object
            is left untouched.
        model_path: checkpoint holding ``state_dict`` and ``tt_config``.
        preproc_path: joblib dump of the fitted preprocessor.
        encoder_meta_path: JSON written by the encoder's ``save_meta``.
        batch_size: number of rows per inference batch.

    Returns:
        A Series named ``pred_target_risk_class`` of one-indexed class labels,
        indexed like the preprocessed frame. An empty input yields an empty
        int16 Series instead of raising.
    """
    device = torch.device("cpu")
    # NOTE: joblib.load unpickles arbitrary objects; only pass preprocessor files you trust.
    pp = joblib.load(preproc_path)
    enc = TorchTabEncoder.load_meta(encoder_meta_path)
    ckpt = torch.load(model_path, map_location=device)
    cfg = ckpt["tt_config"]

    df_new = df_new.copy(); canon_cols_inplace(df_new)

    # Rebuild the network with the hyperparameters stored inside the checkpoint.
    n_num = len(enc.num_cols)
    cat_cards = [enc.cat_cardinalities_[c] for c in enc.cat_cols] if enc.cat_cols else []
    model = TabTransformer(
        n_num=n_num, cat_cardinalities=cat_cards,
        d_token=cfg["d_token"], n_heads=cfg["n_heads"], n_layers=cfg["n_layers"],
        d_ff=cfg["d_ff"], attn_dropout=cfg["attn_dropout"], ffn_dropout=cfg["ffn_dropout"],
        resid_dropout=cfg["resid_dropout"], pooling=cfg["pooling"],
        d_num_proj=cfg["d_num_proj"], head_hidden=cfg["head_hidden"], head_depth=cfg["head_depth"],
        d_out=cfg["d_out"], act=cfg.get("activation","silu")
    ).to(device)
    model.load_state_dict(ckpt["state_dict"]); model.eval()

    dproc = pp.transform(df_new)
    Xn, Xc = enc.transform(dproc)
    ds = TabDataset(Xn, Xc, y=None)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)

    preds = []
    with torch.no_grad():
        for xnum_np, xcat_np in dl:
            xnum = torch.as_tensor(xnum_np, device=device).float()
            xcat = torch.as_tensor(xcat_np, device=device).long()
            logits = model(xnum, xcat)
            # softmax is strictly increasing, so the arg-max of the logits is the predicted class.
            preds.append(torch.argmax(logits, dim=-1).cpu().numpy())

    if not preds:
        # Zero rows produce zero batches; np.concatenate([]) would raise ValueError.
        return pd.Series([], index=dproc.index, dtype="int16", name="pred_target_risk_class")

    yhat0 = np.concatenate(preds, axis=0).astype("int16")
    return pd.Series(yhat0 + 1, index=dproc.index, name="pred_target_risk_class")

#### FT-Transformer (CPU/VSCODE)

In [None]:
# ===============================================================
# FT-Transformer (CPU, PyTorch, no external rtdl) – Multiclass Pipeline (FAST v2)
# - Outputs under ./RESULTS/fttransformer_outputs/<timestamp>
# - Canonicalized column names; label column = target_risk_class (values 1..10)
# - Preprocess: datetime expansion, numeric downcast+median; capped & unique categoricals
# - CPU-optimized: tiny encoder, capped batches/epoch, no dataloader workers
# - Training: AdamW, cosine LR, early stopping (val CE)
# - Metrics: accuracy, macro/weighted P/R/F1, log loss, ROC-AUC OvR macro
# - Artifacts: fttransformer_model.pt, preprocessor.pkl, encoder_meta.json, metrics.csv,
#   loss_curves.png, reports + confusions, training_meta.json
# - Inference helper returns labels in 1..10
# ===============================================================

import os, re, json, warnings, datetime, math, time, random, gc
from pathlib import Path
import numpy as np
import pandas as pd
import pandas.api.types as pdt

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, roc_auc_score,
    log_loss, classification_report, confusion_matrix
)
from sklearn.utils.class_weight import compute_class_weight

import joblib
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

warnings.filterwarnings("ignore")

# ------------------------ PATHS ------------------------
# Input CSV and a timestamped per-run output folder. The literal "path" values
# look like placeholders; set both before running.
DATA_CSV = "path"
RUN_NAME = f"FTTransformer_CPU_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
OUTPUT_DIR = Path("path") / RUN_NAME
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
SPLIT_DIR  = OUTPUT_DIR / "splits"; SPLIT_DIR.mkdir(exist_ok=True)
print(f"Saving all outputs to: {OUTPUT_DIR.resolve()}")

# ------------------------ CONFIG ------------------------
# Seed the Python, NumPy and PyTorch random generators.
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
# On CPU, fewer threads can be faster
torch.set_num_threads(max(1, (os.cpu_count() or 4)//2))

# Label column (name after canonicalization) and number of classes.
TARGET_NAME = "target_risk_class"
NUM_CLASSES = 10

# Fractions of the full dataset held out for test and validation.
TEST_SIZE = 0.15
VAL_SIZE  = 0.15

# ---- Training profile (FAST) ----
BATCH_SIZE       = 4096        # larger batch for tiny model
GRAD_ACC_STEPS   = 1
MAX_EPOCHS       = 10
BASE_LR          = 2e-3
WEIGHT_DECAY     = 3e-4
PATIENCE         = 2
MIN_DELTA        = 1e-4
WARMUP_EPOCHS    = 2
NUM_WORKERS      = 0           # Windows/CPU: avoid multiprocessing overhead
MAX_BATCHES_PER_EPOCH = 100    # caps work per epoch (~100*4096 ≈ 410k samples)

# ---- Evaluation toggles ----
EVAL_TRAIN = False             # set True if you need train metrics too

# ---- Preprocessing ----
MISSING_TOKEN      = "Missing"
MAX_CAT_CARD       = 200       # cap cardinality to shrink embeddings
STORE_NUM_AS_FP16  = True
STORE_CAT_AS_INT32 = True

# ---- FT-style model hyperparams (tiny-but-strong) ----
D_TOKEN            = 32
N_BLOCKS           = 1
ATTN_N_HEADS       = 4
FFN_D_HIDDEN       = 128
ATTN_DROPOUT       = 0.10
FFN_DROPOUT        = 0.10
RESID_DROPOUT      = 0.10
USE_CLS_TOKEN      = False     # mean pooling keeps sequence length small

# Checked up front so a bad D_TOKEN / ATTN_N_HEADS pair fails with a clear message.
assert D_TOKEN % ATTN_N_HEADS == 0, "D_TOKEN must be divisible by ATTN_N_HEADS"

# ------------------------ helpers ------------------------
def canon_col(name: str) -> str:
    """Return a canonical identifier for a column name.

    The name is stringified, every run of characters outside ``[0-9A-Za-z_]``
    becomes one underscore, repeated underscores are collapsed, and leading or
    trailing underscores are removed.
    """
    underscored = re.sub(r"[^0-9A-Za-z_]+", "_", str(name))
    return re.sub(r"_+", "_", underscored).strip("_")

def canon_cols_inplace(df: pd.DataFrame) -> pd.DataFrame:
    """Canonicalize every column label of ``df`` in place and return ``df``."""
    df.columns = list(map(canon_col, df.columns))
    return df

def _unique_in_order(seq):
    """Stringify the items of ``seq`` and drop repeats, keeping first-seen order."""
    # dict keys are unique and remember insertion order.
    return list(dict.fromkeys(str(item) for item in seq))

def normalize_ws(x):
    """Tidy whitespace in a string; return any non-string value untouched.

    For strings: the characters U+200C, U+200F and U+200E are removed, U+00A0
    becomes a plain space, whitespace runs collapse to one space and the ends
    are trimmed.
    """
    if not isinstance(x, str):
        return x
    visible = re.sub(r"[\u200c\u200f\u200e]", "", x).replace("\u00a0", " ")
    return re.sub(r"\s+", " ", visible).strip()

def downcast_numeric(df):
    """Return a copy of ``df`` with its numeric columns downcast.

    Only float64/float32/int64/int32 columns are touched: float columns go
    through ``pd.to_numeric(downcast="float")``, the others through
    ``downcast="integer"``.
    """
    out = df.copy()
    candidates = out.select_dtypes(include=["float64","int64","int32","float32"]).columns
    for col in candidates:
        target = "float" if pdt.is_float_dtype(out[col]) else "integer"
        out[col] = pd.to_numeric(out[col], downcast=target)
    return out

def parse_possible_datetimes(df):
    """Return a copy of ``df`` with date-like object columns parsed to datetimes.

    An object-dtype column is treated as date-like when more than 20% of its
    stringified values contain a ``YYYY-M-D`` / ``YYYY/M/D`` pattern. Such a
    column is converted with ``pd.to_datetime(errors="coerce")``, so values
    that cannot be parsed become ``NaT``. All other columns are left alone.
    """
    df = df.copy()
    for c in df.columns:
        if df[c].dtype != object:
            continue
        s = df[c].astype(object)
        looks_like_date = s.astype(str).str.contains(
            r"\d{4}[-/]\d{1,2}[-/]\d{1,2}", regex=True, na=False
        )
        if looks_like_date.mean() > 0.2:
            try:
                # `infer_datetime_format` is deliberately not passed: pandas 2.0
                # deprecated it as a no-op, and on a release that rejects the
                # keyword the resulting TypeError would be swallowed below,
                # silently turning date parsing off.
                df[c] = pd.to_datetime(s, errors="coerce")
            except Exception:
                # Best effort: keep the column unparsed if conversion fails.
                pass
    return df

def expand_datetimes(df):
    """Return a copy of ``df`` where each datetime column is replaced by parts.

    For a datetime column ``c`` the columns ``c__year``, ``c__month``,
    ``c__day``, ``c__dow``, ``c__hour``, ``c__mstart`` and ``c__mend`` are
    appended (nullable integer dtypes; a missing hour is stored as 0), and the
    original datetime column is dropped.
    """
    out = df.copy()
    datetime_cols = [name for name in out.columns if pdt.is_datetime64_any_dtype(out[name])]
    for name in datetime_cols:
        parts = out[name].dt
        derived = {
            "year":   parts.year.astype("Int16"),
            "month":  parts.month.astype("Int8"),
            "day":    parts.day.astype("Int8"),
            "dow":    parts.dayofweek.astype("Int8"),
            "hour":   parts.hour.fillna(0).astype("Int8"),
            "mstart": parts.is_month_start.astype("Int8"),
            "mend":   parts.is_month_end.astype("Int8"),
        }
        for suffix, values in derived.items():
            out[f"{name}__{suffix}"] = values
    if datetime_cols:
        out = out.drop(columns=datetime_cols)
    return out

def pd_cat_fix(series, allowed):
    """Convert ``series`` to a Categorical restricted to a fixed vocabulary.

    The vocabulary is ``allowed`` plus ``MISSING_TOKEN`` (stringified,
    de-duplicated, order kept). Missing values and values outside the
    vocabulary are mapped to ``MISSING_TOKEN``.
    """
    vocab = _unique_in_order([*allowed, MISSING_TOKEN])
    text = series.astype("string").fillna(MISSING_TOKEN)
    known = text.isin(vocab)
    text = text.mask(~known, MISSING_TOKEN)
    return pd.Categorical(text, categories=pd.Index(vocab), ordered=False)

# ------------------------ Preprocessor ------------------------
class TabularPreprocessor:
    """Fit/transform preprocessor turning raw frames into model-ready frames.

    ``fit`` learns, from the training frame, which columns are numeric and
    which are categorical, a fill value (median) per numeric column and a
    capped category vocabulary per categorical column. ``transform`` applies
    the same cleaning and returns a frame with exactly ``feature_names_``
    columns: float32 numerics with missing values filled, and pandas
    Categoricals restricted to the learned vocabularies.
    """

    def __init__(self):
        self.num_cols_ = []; self.cat_cols_ = []
        self.num_median_ = {}; self.cat_categories_ = {}
        self.feature_names_ = []; self.fitted_ = False

    def _prep_base(self, X):
        """Shared cleaning for fit and transform; works on a copy of ``X``."""
        d = X.copy()
        canon_cols_inplace(d)
        for c in d.columns:
            if d[c].dtype == object:
                d[c] = d[c].map(normalize_ws)
        d = parse_possible_datetimes(d)
        d = expand_datetimes(d)
        for c in d.columns:
            if pdt.is_bool_dtype(d[c]):
                d[c] = d[c].astype("int8")
        d = downcast_numeric(d)
        return d

    def fit(self, X):
        """Learn column roles, numeric fill values and category vocabularies.

        Returns ``self`` so the call can be chained.
        """
        d = self._prep_base(X)
        self.num_cols_ = [c for c in d.columns if pdt.is_numeric_dtype(d[c])]
        self.cat_cols_ = [c for c in d.columns if not pdt.is_numeric_dtype(d[c])]

        # Start from empty dicts so refitting never keeps entries for columns
        # that only existed in an earlier fit.
        self.num_median_ = {}
        self.cat_categories_ = {}

        for c in self.num_cols_:
            med = float(pd.to_numeric(d[c], errors="coerce").median())
            # A column with no observed value has a NaN median; filling NaN with
            # NaN would leave NaNs in the features, so fall back to 0.0.
            self.num_median_[c] = med if np.isfinite(med) else 0.0

        for c in self.cat_cols_:
            s = d[c].astype("string").fillna(MISSING_TOKEN)
            vc = s.value_counts(dropna=False)
            if MAX_CAT_CARD and len(vc) > (MAX_CAT_CARD - 1):
                # Keep the most frequent levels and reserve one slot for MISSING_TOKEN.
                top = vc.index.astype("string").tolist()[:MAX_CAT_CARD - 1]
                cats = _unique_in_order(top + [MISSING_TOKEN])
            else:
                cats = _unique_in_order(pd.unique(s).astype("string").tolist() + [MISSING_TOKEN])
            self.cat_categories_[c] = cats

        self.feature_names_ = self.num_cols_ + self.cat_cols_
        self.fitted_ = True
        return self

    def transform(self, X):
        """Apply the fitted preprocessing to ``X`` and return a new frame.

        Columns seen during ``fit`` but absent from ``X`` are created as all
        missing, so they receive the fill value / ``MISSING_TOKEN``.

        Raises:
            AssertionError: if called before ``fit``.
        """
        assert self.fitted_, "TabularPreprocessor.transform called before fit"
        d = self._prep_base(X)
        for c in self.feature_names_:
            if c not in d.columns:
                d[c] = np.nan
        d = d[self.feature_names_].copy()
        for c in self.num_cols_:
            d[c] = pd.to_numeric(d[c], errors="coerce").astype("float32")
            if d[c].isna().any():
                d[c] = d[c].fillna(self.num_median_[c])
        for c in self.cat_cols_:
            d[c] = pd_cat_fix(d[c], self.cat_categories_[c])
        return d

# ------------------------ Encoder (standardize nums + codes for cats) ------------------------
class TorchTabEncoder:
    """Turn a preprocessed frame into a standardized numeric matrix and a
    matrix of integer category codes.

    ``fit`` learns per-column mean and standard deviation of the numeric
    columns; ``transform`` returns ``(Xn, Xc)``. The learned state can be
    written to / restored from JSON with ``save_meta`` / ``load_meta``.
    """

    def __init__(self, num_cols, cat_cols, cat_categories):
        self.num_cols = [*num_cols]
        self.cat_cols = [*cat_cols]
        self.cat_categories = {name: [*levels] for name, levels in cat_categories.items()}
        self.num_mean_ = None
        self.num_std_ = None
        self.cat_cardinalities_ = {name: len(self.cat_categories[name]) for name in self.cat_cols}

    def fit(self, df_proc):
        """Learn the standardization statistics from ``df_proc``; returns ``self``."""
        if not self.num_cols:
            self.num_mean_ = np.array([], dtype="float32")
            self.num_std_  = np.array([], dtype="float32")
            return self
        values = df_proc[self.num_cols].astype("float32").values
        self.num_mean_ = values.mean(axis=0).astype("float32")
        spread = values.std(axis=0).astype("float32")
        # Near-constant columns get a unit scale so the division stays finite.
        self.num_std_ = np.where(spread < 1e-6, 1.0, spread).astype("float32")
        return self

    def _scaled_numeric(self, df_proc):
        """Standardized numeric block (float16 when STORE_NUM_AS_FP16 is set)."""
        if not self.num_cols:
            return np.zeros((len(df_proc), 0), dtype="float16" if STORE_NUM_AS_FP16 else "float32")
        raw = df_proc[self.num_cols].astype("float32").values
        scaled = (raw - self.num_mean_) / self.num_std_
        return scaled.astype("float16") if STORE_NUM_AS_FP16 else scaled

    def _category_codes(self, df_proc):
        """Integer code block, one column per categorical feature."""
        code_dtype = "int32" if STORE_CAT_AS_INT32 else "int64"
        if not self.cat_cols:
            return np.zeros((len(df_proc),0), dtype=code_dtype)
        per_column = []
        for name in self.cat_cols:
            raw_codes = df_proc[name].cat.codes.to_numpy(copy=False)
            # Negative codes (pandas' marker for "no category") map to MISSING_TOKEN.
            missing_code = self.cat_categories[name].index(MISSING_TOKEN)
            per_column.append(np.where(raw_codes < 0, missing_code, raw_codes).astype(code_dtype))
        return np.stack(per_column, axis=1)

    def transform(self, df_proc):
        """Return ``(Xn, Xc)`` for ``df_proc``: numeric matrix, then code matrix."""
        return self._scaled_numeric(df_proc), self._category_codes(df_proc)

    def save_meta(self, path_json):
        """Write the column lists, vocabularies and statistics to ``path_json``."""
        payload = {
            "num_cols": self.num_cols,
            "cat_cols": self.cat_cols,
            "cat_categories": self.cat_categories,
            "num_mean": self.num_mean_.tolist(),
            "num_std": self.num_std_.tolist(),
            "cat_cardinalities": self.cat_cardinalities_,
        }
        with open(path_json, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=2)

    @staticmethod
    def load_meta(path_json):
        """Rebuild an encoder from a JSON file written by ``save_meta``."""
        with open(path_json, "r", encoding="utf-8") as handle:
            stored = json.load(handle)
        restored = TorchTabEncoder(stored["num_cols"], stored["cat_cols"], stored["cat_categories"])
        restored.num_mean_ = np.array(stored["num_mean"], dtype="float32")
        restored.num_std_  = np.array(stored["num_std"], dtype="float32")
        restored.cat_cardinalities_ = {key: int(val) for key, val in stored["cat_cardinalities"].items()}
        return restored

# ------------------------ Dataset ------------------------
class TabDataset(Dataset):
    """Index-aligned view over a numeric matrix, a code matrix and optional labels.

    Items are ``(Xn[i], Xc[i], y[i])`` when labels were given, otherwise
    ``(Xn[i], Xc[i])``. Labels are stored as int64.
    """

    def __init__(self, Xn, Xc, y=None):
        self.Xn, self.Xc = Xn, Xc
        self.y = y.astype("int64") if y is not None else None

    def __len__(self):
        return len(self.Xn)

    def __getitem__(self, i):
        features = (self.Xn[i], self.Xc[i])
        if self.y is None:
            return features
        return features + (self.y[i],)

# ------------------------ FT-style Model (no rtdl) ------------------------
class NumericFeatureTokenizer(nn.Module):
    """
    Maps every numeric feature to its own token: token_j = x_j * W_j + b_j,
    with W_j and b_j vectors of length d_token.

    With ``n_num == 0`` the module owns no parameters and ``forward`` returns
    an empty ``(B, 0, d_token)`` tensor.
    """

    def __init__(self, n_num: int, d_token: int):
        super().__init__()
        self.n_num = n_num
        self.d_token = d_token
        if n_num <= 0:
            # Register the names anyway so `weight` / `bias` exist as None.
            self.register_parameter("weight", None)
            self.register_parameter("bias",   None)
            return
        self.weight = nn.Parameter(torch.empty(n_num, d_token))
        self.bias   = nn.Parameter(torch.zeros(n_num, d_token))  # already all zeros
        nn.init.trunc_normal_(self.weight, std=0.02)

    def forward(self, x_num: torch.Tensor):  # (B, n_num)
        if self.n_num == 0:
            return x_num.new_zeros((x_num.size(0), 0, self.d_token))
        # Broadcast (B, n_num, 1) against (1, n_num, d_token).
        return x_num[..., None] * self.weight[None] + self.bias[None]

class CategoricalFeatureTokenizer(nn.Module):
    """One embedding table per categorical feature; each code becomes a token.

    ``cat_cardinalities[j]`` is the number of rows of table ``j``. ``forward``
    takes integer codes shaped ``(B, n_cat)`` and returns ``(B, n_cat, d_token)``.
    """

    def __init__(self, cat_cardinalities, d_token):
        super().__init__()
        self.n_cat = len(cat_cardinalities)
        self.d_token = d_token
        # All tables are created first and re-initialised afterwards.
        tables = [nn.Embedding(card, d_token) for card in cat_cardinalities]
        self.embs = nn.ModuleList(tables)
        for table in tables:
            nn.init.trunc_normal_(table.weight, std=0.02)

    def forward(self, x_cat: torch.Tensor):  # (B, n_cat)
        if self.n_cat == 0:
            return x_cat.new_zeros((x_cat.size(0), 0, self.d_token), dtype=torch.float32)
        per_feature = [self.embs[j](x_cat[:, j]) for j in range(self.n_cat)]
        return torch.stack(per_feature, dim=1)   # (B, n_cat, d)

class FTTransformerNoRTDL(nn.Module):
    """FT-Transformer-style classifier assembled from plain ``torch.nn`` parts.

    Each numeric feature and each categorical feature becomes one
    ``d_token``-wide token. The token sequence (optionally prefixed with a
    learned CLS token) passes through a pre-norm ``nn.TransformerEncoder`` and
    is pooled (CLS token, or mean over tokens) before a LayerNorm + Linear
    head produces ``d_out`` logits.

    Args:
        n_num: number of numeric features.
        cat_cardinalities: vocabulary size of each categorical feature.
        d_token: token / model width.
        n_blocks: number of encoder layers.
        n_heads: attention heads per layer.
        ffn_d_hidden: hidden width of the feed-forward sub-layer.
        attn_dropout: dropout applied to the attention weights.
        ffn_dropout: dropout applied inside the feed-forward sub-layer.
        resid_dropout: dropout applied to each sub-layer output before the
            residual addition, and to the pooled vector before the head.
        use_cls_token: pool from a learned CLS token instead of the mean.
        d_out: number of output logits.
    """

    def __init__(self,
                 n_num: int,
                 cat_cardinalities: list,
                 d_token: int,
                 n_blocks: int,
                 n_heads: int,
                 ffn_d_hidden: int,
                 attn_dropout: float,
                 ffn_dropout: float,
                 resid_dropout: float,
                 use_cls_token: bool,
                 d_out: int):
        super().__init__()
        self.n_num = n_num
        self.use_cls = use_cls_token
        self.d_token = d_token

        self.num_tok = NumericFeatureTokenizer(n_num, d_token)
        self.cat_tok = CategoricalFeatureTokenizer(cat_cardinalities, d_token) if cat_cardinalities else None

        self.cls = nn.Parameter(torch.zeros(1, 1, d_token)) if use_cls_token else None
        if self.cls is not None:
            nn.init.trunc_normal_(self.cls, std=0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_token,
            nhead=n_heads,
            dim_feedforward=ffn_d_hidden,
            dropout=resid_dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        # nn.TransformerEncoderLayer takes a single `dropout` value and applies
        # it everywhere, which previously left attn_dropout / ffn_dropout
        # unused. Route them to the matching sub-modules; dropout1 / dropout2
        # (the residual branches) keep resid_dropout. This is done before
        # nn.TransformerEncoder clones the layer so every block inherits it.
        encoder_layer.self_attn.dropout = attn_dropout
        encoder_layer.dropout.p = ffn_dropout

        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_blocks)
        self.dropout = nn.Dropout(resid_dropout)

        self.head = nn.Sequential(
            nn.LayerNorm(d_token),
            nn.Linear(d_token, d_out)
        )

    def forward(self, x_num: torch.Tensor, x_cat: torch.Tensor):
        """Return logits of shape ``(B, d_out)``.

        Either input may be ``None`` or have zero columns, but not both.

        Raises:
            ValueError: if neither numeric nor categorical tokens are available.
        """
        num_tokens = self.num_tok(x_num) if x_num is not None and x_num.size(1) > 0 else None
        cat_tokens = self.cat_tok(x_cat) if (self.cat_tok is not None and x_cat is not None and x_cat.size(1) > 0) else None
        if num_tokens is None and cat_tokens is None:
            raise ValueError("Both numeric and categorical tokens are empty.")

        # Numeric tokens come first, categorical tokens second.
        if num_tokens is None:
            tokens = cat_tokens
        elif cat_tokens is None:
            tokens = num_tokens
        else:
            tokens = torch.cat([num_tokens, cat_tokens], dim=1)

        if self.use_cls:
            B = tokens.size(0)
            cls_tok = self.cls.expand(B, -1, -1)
            tokens = torch.cat([cls_tok, tokens], dim=1)

        z = self.encoder(tokens)
        pooled = z[:, 0, :] if self.use_cls else z.mean(dim=1)
        logits = self.head(self.dropout(pooled))
        return logits

# ------------------------ LOAD & TARGET ------------------------
# Read the raw table and canonicalize its column names so TARGET_NAME can be looked up.
df = pd.read_csv(DATA_CSV, low_memory=False)
canon_cols_inplace(df)
if TARGET_NAME not in df.columns:
    raise KeyError(f"Expected label column '{TARGET_NAME}' after canonicalization; got first columns {list(df.columns)[:20]}")

# Labels are one-indexed in the file. Non-numeric values become missing, and
# anything outside 1..NUM_CLASSES is discarded. The upper bound follows the
# NUM_CLASSES setting instead of a second hard-coded copy of it.
y1 = pd.to_numeric(df[TARGET_NAME], errors="coerce").astype("Int64")
y1 = y1.where((y1>=1) & (y1<=NUM_CLASSES))
mask = y1.notna()
df = df.loc[mask].copy()
y1 = y1.loc[mask].astype("int16")
y  = (y1 - 1).astype("int16")   # zero-indexed labels used for training
X  = df.drop(columns=[TARGET_NAME])
del df; gc.collect()

# ------------------------ SPLIT (70/15/15) ------------------------
# First stratified split: hold out TEST_SIZE of all rows as the test set.
sss1 = StratifiedShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
trainval_idx, test_idx = next(sss1.split(X, y))
X_trainval, X_test = X.iloc[trainval_idx], X.iloc[test_idx]
y_trainval, y_test = y.iloc[trainval_idx], y.iloc[test_idx]

# Second stratified split on the remainder. VAL_SIZE is rescaled so the
# validation set is VAL_SIZE of the full dataset, not of the remainder.
val_rel = VAL_SIZE / (1.0 - TEST_SIZE)
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=val_rel, random_state=SEED)
train_idx, val_idx = next(sss2.split(X_trainval, y_trainval))
X_train, X_val = X_trainval.iloc[train_idx], X_trainval.iloc[val_idx]
y_train, y_val = y_trainval.iloc[train_idx], y_trainval.iloc[val_idx]

# Save the index labels of each split so the partition can be reproduced.
pd.Series(X_train.index, name="index").to_csv(SPLIT_DIR/"train_indices.csv", index=False)
pd.Series(X_val.index,   name="index").to_csv(SPLIT_DIR/"val_indices.csv",   index=False)
pd.Series(X_test.index,  name="index").to_csv(SPLIT_DIR/"test_indices.csv",  index=False)
print(f"Shapes -> train: {X_train.shape}, val: {X_val.shape}, test: {X_test.shape}")

# ------------------------ PREPROCESS & ENCODE ------------------------
# Both the preprocessor and the encoder are fitted on the training split only
# and then applied unchanged to validation and test.
pp = TabularPreprocessor().fit(X_train)
Xtr_df = pp.transform(X_train); Xva_df = pp.transform(X_val); Xte_df = pp.transform(X_test)

enc = TorchTabEncoder(pp.num_cols_, pp.cat_cols_, pp.cat_categories_).fit(Xtr_df)
Xtr_num, Xtr_cat = enc.transform(Xtr_df)
Xva_num, Xva_cat = enc.transform(Xva_df)
Xte_num, Xte_cat = enc.transform(Xte_df)

# free raw frames
del X_train, X_val, X_test, X_trainval, Xtr_df, Xva_df, Xte_df, X, y, y1; gc.collect()

# The per-split label Series were not deleted above; convert them to int64 arrays.
y_tr = y_train.values.astype("int64")
y_va = y_val.values.astype("int64")
y_te = y_test.values.astype("int64")

ds_tr = TabDataset(Xtr_num, Xtr_cat, y_tr)
ds_va = TabDataset(Xva_num, Xva_cat, y_va)
ds_te = TabDataset(Xte_num, Xte_cat, y_te)

# Windows/CPU: num_workers=0
_dl_workers = NUM_WORKERS

# Only the training loader shuffles.
dl_tr = DataLoader(ds_tr, batch_size=BATCH_SIZE, shuffle=True,  num_workers=_dl_workers)
dl_va = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, num_workers=_dl_workers)
dl_te = DataLoader(ds_te, batch_size=BATCH_SIZE, shuffle=False, num_workers=_dl_workers)

# ------------------------ CLASS WEIGHTS ------------------------
# "balanced" weights are computed for the classes present in the training
# labels; classes absent from training keep weight 1. The vector is then
# rescaled to mean 1.
classes_present = np.unique(y_tr)
cw = compute_class_weight(class_weight="balanced", classes=classes_present, y=y_tr)
cw_map = {int(c): float(w) for c, w in zip(classes_present, cw)}
class_weights = np.ones(NUM_CLASSES, dtype="float32")
for c, w in cw_map.items(): class_weights[c] = w
class_weights = class_weights / class_weights.mean()
class_weights_t = torch.tensor(class_weights, dtype=torch.float32)

# ------------------------ MODEL / OPTIM ------------------------
# This cell runs on CPU only.
device = torch.device("cpu")
# Feature counts are taken from the encoded training data / the fitted encoder.
n_num = Xtr_num.shape[1]
cat_cards = [enc.cat_cardinalities_[c] for c in enc.cat_cols] if enc.cat_cols else []

# Fail early when preprocessing left nothing to feed the model.
if (n_num == 0) and (len(cat_cards) == 0):
    raise ValueError("No input features: both numeric and categorical are empty after preprocessing.")

model = FTTransformerNoRTDL(
    n_num=n_num,
    cat_cardinalities=cat_cards,
    d_token=D_TOKEN,
    n_blocks=N_BLOCKS,
    n_heads=ATTN_N_HEADS,
    ffn_d_hidden=FFN_D_HIDDEN,
    attn_dropout=ATTN_DROPOUT,
    ffn_dropout=FFN_DROPOUT,
    resid_dropout=RESID_DROPOUT,
    use_cls_token=USE_CLS_TOKEN,
    d_out=NUM_CLASSES
).to(device)

# Parameter count, printed in millions, as a quick size check.
total_params = sum(p.numel() for p in model.parameters())
print(f"FT-style params: {total_params/1e6:.2f}M | cats: {len(cat_cards)} | nums: {n_num}")

optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=WEIGHT_DECAY)

def cosine_factor(epoch, max_epochs=MAX_EPOCHS, warmup=WARMUP_EPOCHS):
    """Learning-rate multiplier for a zero-based ``epoch``.

    During the first ``warmup`` epochs the factor rises linearly to 1;
    afterwards it follows half a cosine wave from 1 down to 0, reaching 0 at
    ``max_epochs``.
    """
    still_warming_up = epoch < warmup
    if still_warming_up:
        return (epoch + 1) / max(1, warmup)
    progress = (epoch - warmup) / max(1, max_epochs - warmup)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

# LambdaLR multiplies BASE_LR by cosine_factor(epoch) using the default schedule bounds.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_factor)

# ------------------------ TRAIN (early stopping + capped batches) ------------------------
# Early-stopping state, per-epoch history and the training clock.
best_val = float("inf"); best_epoch = -1; pat = 0
history = {"epoch": [], "train_loss": [], "val_loss": [], "lr": []}
t0 = time.time()

# The class-weight vector never changes, so place it on the device once
# instead of re-sending it for every batch.
ce_weight = class_weights_t.to(device)

for epoch in range(1, MAX_EPOCHS + 1):
    model.train()
    total, n = 0.0, 0
    pending = 0  # micro-batches back-propagated since the last optimizer step
    optimizer.zero_grad(set_to_none=True)

    for step, (xnum_np, xcat_np, yb) in enumerate(dl_tr, 1):
        xnum = torch.as_tensor(xnum_np, device=device).float()
        xcat = torch.as_tensor(xcat_np, device=device).long()
        yb   = yb.to(device)

        logits = model(xnum if n_num>0 else None, xcat if len(cat_cards)>0 else None)
        loss = F.cross_entropy(logits, yb, weight=ce_weight)

        # Scale so the accumulated gradient averages over GRAD_ACC_STEPS micro-batches.
        (loss / GRAD_ACC_STEPS).backward()
        pending += 1
        if pending == GRAD_ACC_STEPS:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            pending = 0

        total += loss.item() * yb.size(0); n += yb.size(0)
        del xnum, xcat

        # Cap the amount of work per epoch.
        if step >= MAX_BATCHES_PER_EPOCH:
            break

    if pending:
        # Apply gradients left over when the epoch ended in the middle of an
        # accumulation window; they would otherwise be discarded by the
        # zero_grad at the start of the next epoch.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    train_loss = total / max(1, n)

    # ---- validation
    model.eval()
    vtotal, vn = 0.0, 0
    with torch.no_grad():
        for xnum_np, xcat_np, yb in dl_va:
            xnum = torch.as_tensor(xnum_np, device=device).float()
            xcat = torch.as_tensor(xcat_np, device=device).long()
            yb   = yb.to(device)
            logits = model(xnum if n_num>0 else None, xcat if len(cat_cards)>0 else None)
            vloss = F.cross_entropy(logits, yb, weight=ce_weight)
            vtotal += vloss.item() * yb.size(0); vn += yb.size(0)
            del xnum, xcat
    val_loss = vtotal / max(1, vn)

    # Read the learning rate before stepping the scheduler, so the logged value
    # is the one this epoch actually trained with (not the next epoch's).
    lr_used = optimizer.param_groups[0]["lr"]
    scheduler.step()

    history["epoch"].append(epoch)
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["lr"].append(lr_used)
    print(f"Epoch {epoch:03d} | train {train_loss:.4f} | val {val_loss:.4f} | lr {lr_used:.2e}")

    # Checkpoint on every improvement larger than MIN_DELTA; otherwise count
    # towards early stopping.
    if val_loss + MIN_DELTA < best_val:
        best_val = val_loss; best_epoch = epoch; pat = 0
        torch.save({
            "state_dict": model.state_dict(),
            "ftt_config": {
                "n_num": n_num, "cat_cardinalities": cat_cards,
                "d_token": D_TOKEN, "n_blocks": N_BLOCKS, "attention_n_heads": ATTN_N_HEADS,
                "ffn_d_hidden": FFN_D_HIDDEN, "attention_dropout": ATTN_DROPOUT,
                "ffn_dropout": FFN_DROPOUT, "residual_dropout": RESID_DROPOUT,
                "use_cls_token": USE_CLS_TOKEN,
                "d_out": NUM_CLASSES
            }
        }, OUTPUT_DIR / "fttransformer_model.pt")
    else:
        pat += 1
        if pat >= PATIENCE:
            print(f"Early stopping at epoch {epoch} (best @ {best_epoch} | val {best_val:.4f})")
            break

elapsed = time.time() - t0
print(f"Training time: {elapsed/60:.1f} min; best epoch: {best_epoch}")

if best_epoch < 0:
    # A NaN validation loss never compares as an improvement, so no checkpoint
    # exists; stop here with a clear message instead of failing on torch.load.
    raise RuntimeError("No checkpoint was written: the validation loss never improved (is it NaN?).")

# ------------------------ Evaluation ------------------------
def predict_proba_dl(model, dl):
    """Run ``model`` over every batch of ``dl`` and return class probabilities.

    Args:
        model: module called as ``model(x_num, x_cat)``; an input is replaced
            by ``None`` when the corresponding feature group is empty. The
            module is switched to eval mode first.
        dl: iterable of ``(x_num, x_cat, y)`` or ``(x_num, x_cat)`` batches;
            a label element, when present, is ignored.

    Returns:
        A NumPy array with the softmax probabilities of all batches stacked
        row-wise, in loader order.
    """
    model.eval()
    has_num = n_num>0
    has_cat = len(cat_cards)>0
    chunks = []
    with torch.no_grad():
        for batch in dl:
            # Three-element batches carry labels last; two-element ones are features only.
            num_part, cat_part = batch[:2] if len(batch) == 3 else batch
            feats_num = torch.as_tensor(num_part, device=device).float()
            feats_cat = torch.as_tensor(cat_part, device=device).long()
            logits = model(feats_num if has_num else None, feats_cat if has_cat else None)
            chunks.append(F.softmax(logits, dim=-1).cpu().numpy())
    return np.vstack(chunks)

# Reload the checkpoint written by the training loop so evaluation uses the
# saved weights rather than whatever the model held when training stopped.
ckpt = torch.load(OUTPUT_DIR / "fttransformer_model.pt", map_location=device)
model.load_state_dict(ckpt["state_dict"])

# Evaluation loaders: fixed order (shuffle=False) and no worker processes.
_eval_bs = 4096
dl_va_eval = DataLoader(ds_va, batch_size=_eval_bs, shuffle=False, num_workers=0)
dl_te_eval = DataLoader(ds_te, batch_size=_eval_bs, shuffle=False, num_workers=0)
# The train loader is only built when train metrics were requested.
if EVAL_TRAIN:
    dl_tr_eval = DataLoader(ds_tr, batch_size=_eval_bs, shuffle=False, num_workers=0)

def eval_split(name, dl, y_zero):
    """Score one data split and write its report files.

    Args:
        name: split label; used in the metrics row and in the output file names.
        dl: loader for the split, passed to ``predict_proba_dl``.
        y_zero: zero-indexed ground-truth labels for the split.

    Returns:
        ``(metrics, pred_one, proba)``: the metrics dict, the predictions
        shifted to one-indexed labels, and the probability matrix.

    Side effects:
        Writes ``classification_report_<name>.txt`` and
        ``confusion_matrix_<name>.csv`` into ``OUTPUT_DIR``.
    """
    class_ids = list(range(NUM_CLASSES))
    one_based = list(range(1, NUM_CLASSES + 1))

    proba = predict_proba_dl(model, dl)
    pred0 = proba.argmax(axis=1)

    metrics = {
        "split": name,
        "n_samples": int(len(y_zero)),
        "accuracy": float(accuracy_score(y_zero, pred0)),
    }
    for avg in ("macro", "weighted"):
        prec, rec, fscore, _ = precision_recall_fscore_support(y_zero, pred0, average=avg, zero_division=0)
        metrics.update({
            f"precision_{avg}": float(prec),
            f"recall_{avg}": float(rec),
            f"f1_{avg}": float(fscore),
        })

    # Probability-based scores are best-effort: any failure is recorded as NaN.
    try:
        ll_value = float(log_loss(y_zero, proba, labels=class_ids))
    except Exception:
        ll_value = float("nan")
    metrics["log_loss"] = ll_value

    try:
        onehot = pd.get_dummies(pd.Categorical(y_zero, categories=class_ids))
        auc_value = float(roc_auc_score(onehot.values, proba, average="macro", multi_class="ovr"))
    except Exception:
        auc_value = float("nan")
    metrics["roc_auc_ovr_macro"] = auc_value

    # Reports are written with one-indexed labels.
    ys_one = y_zero + 1
    pred_one = pred0 + 1

    report_text = classification_report(ys_one, pred_one, labels=one_based, zero_division=0)
    with open(OUTPUT_DIR / f"classification_report_{name}.txt", "w", encoding="utf-8") as handle:
        handle.write(report_text)

    cm_frame = pd.DataFrame(
        confusion_matrix(ys_one, pred_one, labels=one_based),
        index=range(1, NUM_CLASSES+1),
        columns=range(1, NUM_CLASSES+1),
    )
    cm_frame.to_csv(OUTPUT_DIR / f"confusion_matrix_{name}.csv")
    return metrics, pred_one, proba

# One metrics row per evaluated split; the train split is optional.
metrics_rows = []
if EVAL_TRAIN:
    m_train, _, _ = eval_split("train", dl_tr_eval, y_tr); metrics_rows.append(m_train)
m_val,   _, _ = eval_split("val",   dl_va_eval, y_va);    metrics_rows.append(m_val)
m_test,  _, _ = eval_split("test",  dl_te_eval, y_te);   metrics_rows.append(m_test)

pd.DataFrame(metrics_rows).to_csv(OUTPUT_DIR / "metrics_fttransformer.csv", index=False)
print(pd.DataFrame(metrics_rows))

# ------------------------ Curves ------------------------
# Train vs. validation loss per epoch, saved to disk and closed without being shown.
plt.figure(figsize=(7,4))
plt.plot(history["epoch"], history["train_loss"], label="train")
plt.plot(history["epoch"], history["val_loss"],   label="val")
plt.xlabel("epoch"); plt.ylabel("CE loss"); plt.title("FT-Transformer (no rtdl) Loss"); plt.legend(); plt.tight_layout()
plt.savefig(OUTPUT_DIR / "loss_curves.png", dpi=150); plt.close()

# ------------------------ Save artifacts ------------------------
# Persist the fitted preprocessor (pickle via joblib), the encoder metadata (JSON)
# and a JSON summary of the run. "ftt_config" repeats the keys and values that
# the training loop stores inside the checkpoint.
joblib.dump(pp, OUTPUT_DIR / "preprocessor.pkl")
enc.save_meta(OUTPUT_DIR / "encoder_meta.json")
with open(OUTPUT_DIR / "training_meta.json", "w", encoding="utf-8") as f:
    json.dump({
        "target_name": TARGET_NAME,
        "num_classes": NUM_CLASSES,
        "label_order_zero_indexed": list(range(NUM_CLASSES)),
        "seed": SEED,
        "best_epoch": int(best_epoch),
        "splits": {"train": "splits/train_indices.csv", "val": "splits/val_indices.csv", "test": "splits/test_indices.csv"},
        "data_csv": DATA_CSV,
        "ftt_config": {
            "n_num": n_num, "cat_cardinalities": cat_cards,
            "d_token": D_TOKEN, "n_blocks": N_BLOCKS, "attention_n_heads": ATTN_N_HEADS,
            "ffn_d_hidden": FFN_D_HIDDEN, "attention_dropout": ATTN_DROPOUT,
            "ffn_dropout": FFN_DROPOUT, "residual_dropout": RESID_DROPOUT,
            "use_cls_token": USE_CLS_TOKEN, "d_out": NUM_CLASSES
        },
        "train_time_min": round(elapsed/60, 2),
        "feature_names": pp.feature_names_,
        "cat_features": pp.cat_cols_,
        "num_features": pp.num_cols_,
        "column_names_canonicalized": True
    }, f, ensure_ascii=False, indent=2)

print(f"\n✅ All artifacts saved to: {OUTPUT_DIR.resolve()}")

# ------------------------ Inference helper ------------------------
def predict_target_risk_class(
    df_new: pd.DataFrame,
    model_path=OUTPUT_DIR / "fttransformer_model.pt",
    preproc_path=OUTPUT_DIR / "preprocessor.pkl",
    encoder_meta_path=OUTPUT_DIR / "encoder_meta.json",
    batch_size=4096
) -> pd.Series:
    """Predict the risk class for new raw rows with a saved FT-Transformer run.

    Parameters
    ----------
    df_new : pd.DataFrame
        Raw feature rows; the frame is copied and its column names are
        canonicalized before preprocessing.
    model_path : path-like
        Checkpoint containing a ``"state_dict"`` entry.
    preproc_path : path-like
        joblib pickle of the fitted preprocessor.
    encoder_meta_path : path-like
        JSON metadata consumed by ``TorchTabEncoder.load_meta``.
    batch_size : int
        Number of rows scored per forward pass.

    Returns
    -------
    pd.Series
        int16 labels shifted back to the 1-based convention (argmax index + 1),
        indexed like the preprocessed frame and named
        ``"pred_target_risk_class"``. Empty input yields an empty Series.

    Notes
    -----
    The architecture is rebuilt from the ``"ftt_config"`` entry of
    ``training_meta.json``. That file is looked up next to ``model_path`` so
    that a checkpoint from another run is paired with its own configuration;
    if it is not found there, the current ``OUTPUT_DIR`` is used as before.
    """
    from pathlib import Path  # local import keeps the helper self-contained

    device = torch.device("cpu")
    # SECURITY: joblib.load and torch.load unpickle arbitrary objects — only
    # point these paths at artifacts you produced or otherwise trust.
    pp_inf = joblib.load(preproc_path)
    enc_inf = TorchTabEncoder.load_meta(encoder_meta_path)

    # Prefer the metadata stored alongside the checkpoint; fall back to OUTPUT_DIR.
    meta_path = Path(model_path).parent / "training_meta.json"
    if not meta_path.exists():
        meta_path = OUTPUT_DIR / "training_meta.json"
    with open(meta_path, "r", encoding="utf-8") as f:  # context manager closes the handle
        cfg = json.load(f)["ftt_config"]

    n_num = len(enc_inf.num_cols)
    cat_cards = [enc_inf.cat_cardinalities_[c] for c in enc_inf.cat_cols] if enc_inf.cat_cols else []
    model = FTTransformerNoRTDL(
        n_num=n_num,
        cat_cardinalities=cat_cards,
        d_token=cfg["d_token"],
        n_blocks=cfg["n_blocks"],
        n_heads=cfg["attention_n_heads"],
        ffn_d_hidden=cfg["ffn_d_hidden"],
        attn_dropout=cfg["attention_dropout"],
        ffn_dropout=cfg["ffn_dropout"],
        resid_dropout=cfg["residual_dropout"],
        use_cls_token=cfg.get("use_cls_token", False),
        d_out=cfg["d_out"]
    ).to(device)
    ckpt = torch.load(model_path, map_location=device)
    model.load_state_dict(ckpt["state_dict"]); model.eval()

    # Preprocess -> encode
    df_new = df_new.copy(); canon_cols_inplace(df_new)
    dproc = pp_inf.transform(df_new)
    Xn, Xc = enc_inf.transform(dproc)
    ds = TabDataset(Xn, Xc, y=None)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)

    preds = []
    with torch.no_grad():
        for xnum_np, xcat_np in dl:
            xnum = torch.as_tensor(xnum_np, device=device).float()
            xcat = torch.as_tensor(xcat_np, device=device).long()
            logits = model(xnum if n_num>0 else None, xcat if len(cat_cards)>0 else None)
            # softmax is monotonic, so the argmax of the raw logits is the same class.
            preds.append(torch.argmax(logits, dim=-1).cpu().numpy())

    if not preds:
        # No batches were produced (empty input); np.concatenate([]) would raise.
        return pd.Series(np.empty(0, dtype="int16"), index=dproc.index, name="pred_target_risk_class")

    yhat0 = np.concatenate(preds, axis=0).astype("int16")
    return pd.Series(yhat0 + 1, index=dproc.index, name="pred_target_risk_class")

### Deep Learning (MLP) Based Models

#### TabNet

In [None]:
# ===============================================================
# TabNet (CPU) – Multiclass Pipeline (FAST, robust)
# - Mirrors your other pipelines: same preprocessing + artifacts
# - Preprocess: datetime expansion; numeric downcast+median; capped & unique categoricals
# - Encoding: standardized numerics + integer codes for categoricals
# - Model: TabNet with categorical embeddings, early stopping on val
# - Metrics: accuracy, macro/weighted P/R/F1, log loss, ROC-AUC OvR macro
# - Artifacts: tabnet_model.zip, preprocessor.pkl, encoder_meta.json, metrics_tabnet.csv,
#              loss_curves.png, classification reports + confusion matrices, training_meta.json
# - Inference helper returns labels in 1..10
# ===============================================================

import os, re, json, warnings, datetime, time, random, gc, math
from pathlib import Path
import numpy as np
import pandas as pd
import pandas.api.types as pdt

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, roc_auc_score,
    log_loss, classification_report, confusion_matrix
)
from sklearn.utils.class_weight import compute_class_weight

import joblib
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")

# ------------------------ TabNet ------------------------
from pytorch_tabnet.tab_model import TabNetClassifier
import torch
import torch.nn as nn

# ------------------------ PATHS ------------------------
# DATA_CSV and the Path("path") base below are placeholders; set them before running.
DATA_CSV = "path"
# Each run gets its own timestamped output directory, so re-running never overwrites earlier artifacts.
RUN_NAME = f"TabNet_CPU_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
OUTPUT_DIR = Path("path") / RUN_NAME
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Sub-directory that receives the train/val/test index CSVs.
SPLIT_DIR  = OUTPUT_DIR / "splits"; SPLIT_DIR.mkdir(exist_ok=True)
print(f"Saving all outputs to: {OUTPUT_DIR.resolve()}")

# ------------------------ CONFIG ------------------------
# One seed is shared by Python's random, NumPy, torch and the split/model constructors below.
SEED = 42
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED)
# Use half of the reported CPU count (assuming 4 when it cannot be determined), at least one thread.
torch.set_num_threads(max(1, (os.cpu_count() or 4)//2))

TARGET_NAME = "target_risk_class"   # label column name after canonicalization
NUM_CLASSES = 10                    # size of the class-weight vector and of the metric label lists

# Fractions of the full dataset; the remainder is the training split.
TEST_SIZE = 0.15
VAL_SIZE  = 0.15

# ---- Preprocessing ----
MISSING_TOKEN      = "Missing"   # category used for missing / unseen categorical values
MAX_CAT_CARD       = 500     # cap very high-card categorical columns
STORE_CAT_AS_INT32 = True    # for memory

# ---- TabNet hyperparams (CPU-friendly) ----
# The TN_* values are forwarded to TabNetClassifier(...) / model.fit(...) further down.
TN_N_D            = 32
TN_N_A            = 32
TN_N_STEPS        = 3
TN_GAMMA          = 1.5
TN_N_INDEPENDENT  = 1
TN_N_SHARED       = 1
TN_MOMENTUM       = 0.02
TN_VBS            = 4096               # virtual batch size
TN_BATCH_SIZE     = 65536              # real batch size (CPU-friendly large batch)
TN_MAX_EPOCHS     = 60
TN_PATIENCE       = 10
TN_LR             = 3e-2
TN_WEIGHT_DECAY   = 1e-5
TN_SPARSEMAP      = True               # True -> mask_type "sparsemax", False -> "entmax"

# ------------------------ helpers ------------------------
def canon_col(name: str) -> str:
    """Return a canonical column name made only of ASCII letters, digits and underscores.

    Every run of other characters becomes a single underscore, repeated
    underscores are collapsed, and leading/trailing underscores are removed.
    Non-string inputs are converted with ``str`` first.
    """
    replaced = re.sub(r"[^0-9A-Za-z_]+", "_", str(name))
    collapsed = re.sub(r"_+", "_", replaced)
    return collapsed.strip("_")

def canon_cols_inplace(df: pd.DataFrame) -> pd.DataFrame:
    """Rename the columns of ``df`` in place via ``canon_col`` and return the same frame."""
    df.columns = list(map(canon_col, df.columns))
    return df

def _unique_in_order(seq):
    seen = set(); out = []
    for x in seq:
        xs = str(x)
        if xs not in seen:
            seen.add(xs); out.append(xs)
    return out

def normalize_ws(x):
    """Clean whitespace in a string; any non-string value is returned untouched.

    Removes U+200C / U+200F / U+200E, turns non-breaking spaces into regular
    spaces, collapses whitespace runs to one space and strips both ends.
    """
    if not isinstance(x, str):
        return x
    without_marks = re.sub(r"[\u200c\u200f\u200e]", "", x)
    plain_spaces = without_marks.replace("\u00a0", " ")
    return re.sub(r"\s+", " ", plain_spaces).strip()

def downcast_numeric(df):
    """Return a copy of ``df`` with float/int columns downcast to smaller dtypes.

    Only columns selected as float64/int64/int32/float32 are touched; float
    columns are downcast with ``downcast="float"``, the others with
    ``downcast="integer"``. The input frame is not modified.
    """
    out = df.copy()
    numeric_names = out.select_dtypes(include=["float64","int64","int32","float32"]).columns
    for name in numeric_names:
        target = "float" if pdt.is_float_dtype(out[name]) else "integer"
        out[name] = pd.to_numeric(out[name], downcast=target)
    return out

def parse_possible_datetimes(df):
    """Return a copy of ``df`` where date-like object columns are parsed to datetimes.

    An object-dtype column is treated as date-like when more than 20% of its
    values (as strings) contain a ``YYYY-MM-DD`` / ``YYYY/MM/DD`` style
    pattern. Values that cannot be parsed become ``NaT``. Other columns and
    the input frame itself are left unchanged.
    """
    df = df.copy()
    for c in df.columns:
        if df[c].dtype != object:
            continue
        s = df[c]
        looks_like_date = s.astype(str).str.contains(r"\d{4}[-/]\d{1,2}[-/]\d{1,2}", regex=True, na=False)
        if looks_like_date.mean() > 0.2:
            try:
                # `infer_datetime_format` is deliberately not passed: it is deprecated in
                # pandas 2.x, and a TypeError from an unsupported keyword would be hidden
                # by the best-effort handler below, silently disabling date parsing.
                df[c] = pd.to_datetime(s, errors="coerce")
            except Exception:
                # Best effort: if parsing fails, keep the column as-is so it is
                # handled as a categorical feature downstream.
                pass
    return df

def expand_datetimes(df):
    """Return a copy of ``df`` with each datetime column replaced by calendar parts.

    For a datetime column ``c`` the columns ``c__year`` (Int16) and
    ``c__month``, ``c__day``, ``c__dow``, ``c__hour``, ``c__mstart``,
    ``c__mend`` (all Int8) are appended, then ``c`` itself is dropped.
    Missing hours are filled with 0 before casting.
    """
    out = df.copy()
    datetime_names = [name for name in out.columns if pdt.is_datetime64_any_dtype(out[name])]
    for name in datetime_names:
        parts = out[name].dt
        # suffix -> (values, nullable integer dtype); dict order fixes the column order
        derived = {
            "year":   (parts.year, "Int16"),
            "month":  (parts.month, "Int8"),
            "day":    (parts.day, "Int8"),
            "dow":    (parts.dayofweek, "Int8"),
            "hour":   (parts.hour.fillna(0), "Int8"),
            "mstart": (parts.is_month_start, "Int8"),
            "mend":   (parts.is_month_end, "Int8"),
        }
        for suffix, (values, dtype) in derived.items():
            out[f"{name}__{suffix}"] = values.astype(dtype)
    if not datetime_names:
        return out
    return out.drop(columns=datetime_names)

def pd_cat_fix(series, allowed):
    """Coerce ``series`` to an unordered Categorical restricted to ``allowed``.

    ``MISSING_TOKEN`` is always part of the category list. Missing values and
    values outside the list are both mapped to ``MISSING_TOKEN``.
    """
    vocab = _unique_in_order([*allowed, MISSING_TOKEN])
    as_text = series.astype("string").fillna(MISSING_TOKEN)
    known = as_text.isin(vocab)
    cleaned = as_text.where(known, MISSING_TOKEN)
    return pd.Categorical(cleaned, categories=pd.Index(vocab), ordered=False)

# ------------------------ Preprocessor ------------------------
class TabularPreprocessor:
    """Cleans raw feature frames and learns per-column fill values / category lists.

    ``fit`` learns, from the training frame only:
      * which columns are numeric and which are categorical (after cleaning),
      * the median of every numeric column (used to fill missing values),
      * the allowed categories of every categorical column (capped by
        ``MAX_CAT_CARD`` and always containing ``MISSING_TOKEN``).

    ``transform`` applies the same cleaning, adds any missing expected column,
    returns columns in ``feature_names_`` order, numerics as float32 with
    missing values filled, categoricals as pandas Categoricals.
    """

    def __init__(self):
        self.num_cols_ = []; self.cat_cols_ = []
        self.num_median_ = {}; self.cat_categories_ = {}
        self.feature_names_ = []; self.fitted_ = False

    def _prep_base(self, X):
        """Shared cleaning: canonical names, whitespace, datetimes, bools, downcast."""
        d = X.copy()
        canon_cols_inplace(d)
        for c in d.columns:
            if d[c].dtype == object:
                d[c] = d[c].map(normalize_ws)
        d = parse_possible_datetimes(d)
        d = expand_datetimes(d)
        for c in d.columns:
            if pdt.is_bool_dtype(d[c]):
                d[c] = d[c].astype("int8")
        d = downcast_numeric(d)
        return d

    def fit(self, X):
        """Learn column roles, numeric medians and category lists from ``X``; returns ``self``."""
        d = self._prep_base(X)
        # Start from a clean slate so a refit never keeps entries of columns
        # that are no longer present.
        self.num_median_ = {}; self.cat_categories_ = {}
        self.num_cols_ = [c for c in d.columns if pdt.is_numeric_dtype(d[c])]
        self.cat_cols_ = [c for c in d.columns if not pdt.is_numeric_dtype(d[c])]
        for c in self.num_cols_:
            med = pd.to_numeric(d[c], errors="coerce").median()
            # A column with no observed value has no median (NaN / pd.NA). Fall back
            # to 0.0 so transform() never leaves NaNs in the numeric features.
            self.num_median_[c] = float(med) if pd.notna(med) else 0.0
        for c in self.cat_cols_:
            s = d[c].astype("string").fillna(MISSING_TOKEN)
            vc = s.value_counts(dropna=False)
            if MAX_CAT_CARD and len(vc) > (MAX_CAT_CARD - 1):
                # Keep the most frequent values; the rest map to MISSING_TOKEN at transform time.
                top = vc.index.astype("string").tolist()[:MAX_CAT_CARD - 1]
                cats = _unique_in_order(top + [MISSING_TOKEN])
            else:
                cats = _unique_in_order(pd.unique(s).astype("string").tolist() + [MISSING_TOKEN])
            # _unique_in_order already guarantees unique entries.
            self.cat_categories_[c] = cats
        self.feature_names_ = self.num_cols_ + self.cat_cols_
        self.fitted_ = True
        return self

    def transform(self, X):
        """Return the cleaned frame restricted to (and ordered by) ``feature_names_``."""
        assert self.fitted_, "TabularPreprocessor.transform called before fit"
        d = self._prep_base(X)
        for c in self.feature_names_:
            if c not in d.columns:
                # Column absent from the input: numerics get the median, categoricals MISSING_TOKEN.
                d[c] = np.nan
        d = d[self.feature_names_].copy()
        for c in self.num_cols_:
            d[c] = pd.to_numeric(d[c], errors="coerce").astype("float32")
            if d[c].isna().any():
                d[c] = d[c].fillna(self.num_median_[c])
        for c in self.cat_cols_:
            d[c] = pd_cat_fix(d[c], self.cat_categories_[c])
        return d

# ------------------------ Encoder for TabNet ------------------------
class TabNetEncoder:
    """
    Turns a preprocessed frame into the single float32 matrix TabNet consumes.

    - Numeric columns are z-scored with the mean/std learned in ``fit``.
    - Categorical columns become integer codes (negative codes are replaced
      by the position of ``MISSING_TOKEN``), then cast to float32.
    - ``transform`` returns:
        X_all (float32)  : numeric block followed by the categorical codes
        cat_idxs         : column positions of the categorical block in X_all
        cat_dims         : number of categories per categorical column
    """
    def __init__(self, num_cols, cat_cols, cat_categories):
        self.num_cols = [*num_cols]
        self.cat_cols = [*cat_cols]
        self.cat_categories = {name: [*levels] for name, levels in cat_categories.items()}
        self.num_mean_ = None
        self.num_std_ = None
        self.cat_cardinalities_ = {name: len(self.cat_categories[name]) for name in self.cat_cols}

    def fit(self, df_proc):
        """Learn per-column mean and std of the numeric columns; returns ``self``."""
        if not self.num_cols:
            self.num_mean_ = np.array([], dtype="float32")
            self.num_std_  = np.array([], dtype="float32")
            return self
        values = df_proc[self.num_cols].astype("float32").values
        self.num_mean_ = values.mean(axis=0).astype("float32")
        spread = values.std(axis=0).astype("float32")
        # Near-constant columns get std 1.0 so the z-score does not divide by ~0.
        self.num_std_ = np.where(spread < 1e-6, 1.0, spread).astype("float32")
        return self

    def transform(self, df_proc):
        """Encode ``df_proc`` and return ``(X_all, cat_idxs, cat_dims)``."""
        n_rows = len(df_proc)

        # numeric block: z-score, or an empty (n_rows, 0) block
        if self.num_cols:
            num_block = df_proc[self.num_cols].astype("float32").values
            num_block = (num_block - self.num_mean_) / self.num_std_
        else:
            num_block = np.zeros((n_rows, 0), dtype="float32")

        # categorical block: one integer-code column per categorical feature
        code_columns = []
        for name in self.cat_cols:
            raw_codes = df_proc[name].cat.codes.to_numpy(copy=False)
            missing_code = self.cat_categories[name].index(MISSING_TOKEN)
            patched = np.where(raw_codes < 0, missing_code, raw_codes)
            code_columns.append(patched.astype("int32" if STORE_CAT_AS_INT32 else "int64"))
        if code_columns:
            cat_block = np.stack(code_columns, axis=1)
        else:
            cat_block = np.zeros((n_rows, 0), dtype="int32" if STORE_CAT_AS_INT32 else "int64")

        # numerics first, then the codes as float32 (skipped when there are no categoricals)
        if cat_block.shape[1] == 0:
            X_all = num_block
        else:
            X_all = np.hstack([num_block, cat_block.astype("float32")]).astype("float32")

        first_cat = num_block.shape[1]
        cat_idxs = list(range(first_cat, first_cat + cat_block.shape[1]))
        cat_dims = [self.cat_cardinalities_[name] for name in self.cat_cols]
        return X_all, cat_idxs, cat_dims

    def save_meta(self, path_json):
        """Write column lists, category lists and numeric statistics to a JSON file."""
        payload = dict(
            num_cols=self.num_cols,
            cat_cols=self.cat_cols,
            cat_categories=self.cat_categories,
            num_mean=self.num_mean_.tolist(),
            num_std=self.num_std_.tolist(),
            cat_cardinalities=self.cat_cardinalities_,
        )
        with open(path_json, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=2)

    @staticmethod
    def load_meta(path_json):
        """Rebuild an encoder from a JSON file written by ``save_meta``."""
        with open(path_json, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        restored = TabNetEncoder(payload["num_cols"], payload["cat_cols"], payload["cat_categories"])
        restored.num_mean_ = np.array(payload["num_mean"], dtype="float32")
        restored.num_std_  = np.array(payload["num_std"], dtype="float32")
        restored.cat_cardinalities_ = {key: int(val) for key, val in payload["cat_cardinalities"].items()}
        return restored

# ------------------------ LOAD & TARGET ------------------------
df = pd.read_csv(DATA_CSV, low_memory=False)
canon_cols_inplace(df)
if TARGET_NAME not in df.columns:
    raise KeyError(f"Expected label column '{TARGET_NAME}' after canonicalization; got first columns {list(df.columns)[:20]}")

# Keep only rows whose label is a whole number in 1..NUM_CLASSES. Unparseable,
# fractional and out-of-range labels are all dropped the same way (a fractional
# label would otherwise make a nullable-integer cast raise), and the upper bound
# follows NUM_CLASSES instead of a hard-coded 10.
y_num = pd.to_numeric(df[TARGET_NAME], errors="coerce")
mask = y_num.notna() & (y_num % 1 == 0) & y_num.between(1, NUM_CLASSES)
df = df.loc[mask].copy()
y1 = y_num.loc[mask].astype("int16")   # labels as stored in the file (1-based)
y  = (y1 - 1).astype("int16")  # 0..K-1
X  = df.drop(columns=[TARGET_NAME])
del df, y_num; gc.collect()

# ------------------------ SPLIT (70/15/15) ------------------------
# Stage 1: hold out the test fraction (TEST_SIZE of all rows), stratified on the label.
sss1 = StratifiedShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
trainval_idx, test_idx = next(sss1.split(X, y))
X_trainval, X_test = X.iloc[trainval_idx], X.iloc[test_idx]
y_trainval, y_test = y.iloc[trainval_idx], y.iloc[test_idx]

# Stage 2: VAL_SIZE is expressed as a fraction of ALL rows, so it is rescaled to a
# fraction of the remaining train+val pool before the second stratified split.
val_rel = VAL_SIZE / (1.0 - TEST_SIZE)
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=val_rel, random_state=SEED)
train_idx, val_idx = next(sss2.split(X_trainval, y_trainval))
X_train, X_val = X_trainval.iloc[train_idx], X_trainval.iloc[val_idx]
y_train, y_val = y_trainval.iloc[train_idx], y_trainval.iloc[val_idx]

# Persist the DataFrame index labels of each split (not the positional indices)
# so the exact partition can be recovered later.
pd.Series(X_train.index, name="index").to_csv(SPLIT_DIR/"train_indices.csv", index=False)
pd.Series(X_val.index,   name="index").to_csv(SPLIT_DIR/"val_indices.csv", index=False)
pd.Series(X_test.index,  name="index").to_csv(SPLIT_DIR/"test_indices.csv", index=False)
print(f"Shapes -> train: {X_train.shape}, val: {X_val.shape}, test: {X_test.shape}")

# ------------------------ PREPROCESS & ENCODE ------------------------
# The preprocessor is fitted on the training split only; val/test are transformed
# with the medians and category lists learned from train.
pp = TabularPreprocessor().fit(X_train)
Xtr_df = pp.transform(X_train); Xva_df = pp.transform(X_val); Xte_df = pp.transform(X_test)

# Same for the encoder: numeric mean/std come from the training frame. cat_idxs and
# cat_dims are taken from the first call; the later calls discard theirs.
enc = TabNetEncoder(pp.num_cols_, pp.cat_cols_, pp.cat_categories_).fit(Xtr_df)
X_tr, cat_idxs, cat_dims = enc.transform(Xtr_df)
X_va, _, _               = enc.transform(Xva_df)
X_te, _, _               = enc.transform(Xte_df)

n_num = len(pp.num_cols_)
n_cat = len(pp.cat_cols_)
print(f"[TabNet] nums={n_num} cats={n_cat}  total_in={X_tr.shape[1]}  cat_embs={len(cat_dims)}")

# free raw frames
# (y_train / y_val / y_test are intentionally kept: they are converted to arrays just below)
del X_train, X_val, X_test, X_trainval, Xtr_df, Xva_df, Xte_df, X, y, y1; gc.collect()

# Label vectors as int64 NumPy arrays (0-based class ids).
y_tr = y_train.values.astype("int64")
y_va = y_val.values.astype("int64")
y_te = y_test.values.astype("int64")

# ------------------------ CLASS WEIGHTS / SAMPLE WEIGHTS ------------------------
# "balanced" weights for the classes seen in training; classes never seen keep
# weight 1.0. The vector is rescaled to mean 1 so the loss scale stays comparable.
classes_present = np.unique(y_tr)
cw = compute_class_weight(class_weight="balanced", classes=classes_present, y=y_tr)
cw_map = {int(c): float(w) for c, w in zip(classes_present, cw)}
class_weights = np.ones(NUM_CLASSES, dtype="float32")
for c, w in cw_map.items(): class_weights[c] = w
class_weights = class_weights / class_weights.mean()

# Per-sample views of the class weights. They are kept for inspection only and
# are NOT passed to fit() (see the note at the fit call).
sw_tr = class_weights[y_tr]
sw_va = class_weights[y_va]

# loss with class weights — the single place where class imbalance is corrected
loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(class_weights, dtype=torch.float32))

# ---- cat embedding sizes: half the cardinality (rounded up), capped at 32
def _emb_size(card): return int(min(32, (card + 1) // 2))
cat_emb_dim = [_emb_size(c) for c in cat_dims]

# ------------------------ MODEL / TRAIN ------------------------
mask_type = "sparsemax" if TN_SPARSEMAP else "entmax"
model = TabNetClassifier(
    n_d=TN_N_D, n_a=TN_N_A, n_steps=TN_N_STEPS, gamma=TN_GAMMA,
    n_independent=TN_N_INDEPENDENT, n_shared=TN_N_SHARED,
    cat_idxs=cat_idxs, cat_dims=cat_dims, cat_emb_dim=cat_emb_dim,
    optimizer_fn=torch.optim.AdamW,
    optimizer_params=dict(lr=TN_LR, weight_decay=TN_WEIGHT_DECAY),
    mask_type=mask_type,
    momentum=TN_MOMENTUM,
    verbose=10,  # TabNet internal verbosity
    scheduler_params={"step_size": 50, "gamma": 0.9},  # mild LR decay
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    seed=SEED,
)

t0 = time.time()
model.fit(
    X_train=X_tr, y_train=y_tr,
    eval_set=[(X_va, y_va)],
    eval_name=["val"],
    eval_metric=["logloss"],         # monitor neg log-likelihood
    max_epochs=TN_MAX_EPOCHS,
    patience=TN_PATIENCE,
    batch_size=TN_BATCH_SIZE,
    virtual_batch_size=TN_VBS,
    num_workers=max(1, (os.cpu_count() or 4)//2),
    loss_fn=loss_fn,
    # In pytorch-tabnet, `weights` configures a weighted random *sampler*. Passing the
    # class-balanced sample weights here on top of the class-weighted loss corrected
    # the imbalance twice (rare classes were both over-sampled and up-weighted).
    # Sampling is therefore left uniform; the weighted loss does the balancing.
    weights=0,
)
elapsed = time.time() - t0
best_epoch = getattr(model, "best_epoch", None)
print(f"Training time: {elapsed/60:.1f} min; best epoch: {best_epoch}")

# ------------------------ Evaluation ------------------------
def predict_proba_tabnet(m, X):
    """Return ``m.predict_proba(X)`` — thin wrapper so evaluation code has one entry point."""
    probabilities = m.predict_proba(X)
    return probabilities

def eval_split(name, Xmat, y_zero):
    """Score the global ``model`` on one split and write its report files.

    Parameters
    ----------
    name : str
        Split name; stored in the metrics and used in the output file names.
    Xmat : array
        Encoded feature matrix passed to ``predict_proba_tabnet``.
    y_zero : array
        True labels, 0-based.

    Returns
    -------
    (metrics, pred_one, proba)
        ``metrics`` is a flat dict (accuracy, macro/weighted precision, recall,
        F1, log loss, one-vs-rest macro ROC-AUC); ``pred_one`` holds the
        predicted labels shifted to 1-based; ``proba`` is the probability matrix.

    Side effects: writes ``classification_report_<name>.txt`` and
    ``confusion_matrix_<name>.csv`` into ``OUTPUT_DIR``. Log loss and ROC-AUC
    are best-effort: if either cannot be computed it is recorded as NaN.
    """
    proba = predict_proba_tabnet(model, Xmat)
    pred0 = np.argmax(proba, axis=1)

    zero_based = list(range(NUM_CLASSES))
    one_based = list(range(1, NUM_CLASSES+1))

    metrics = {"split": name}
    metrics["n_samples"] = int(len(y_zero))
    metrics["accuracy"] = float(accuracy_score(y_zero, pred0))

    for avg in ("macro", "weighted"):
        scores = precision_recall_fscore_support(y_zero, pred0, average=avg, zero_division=0)
        for label, value in zip(("precision", "recall", "f1"), scores[:3]):
            metrics[f"{label}_{avg}"] = float(value)

    try:
        ll_value = float(log_loss(y_zero, proba, labels=zero_based))
    except Exception:
        ll_value = float("nan")
    metrics["log_loss"] = ll_value

    try:
        # One-hot truth over the full class list so the columns line up with proba.
        y_bin = pd.get_dummies(pd.Categorical(y_zero, categories=zero_based))
        auc_value = float(roc_auc_score(y_bin.values, proba, average="macro", multi_class="ovr"))
    except Exception:
        auc_value = float("nan")
    metrics["roc_auc_ovr_macro"] = auc_value

    # Reports use the 1-based labels.
    ys_one = y_zero + 1
    pred_one = pred0 + 1

    report = classification_report(ys_one, pred_one, labels=one_based, zero_division=0)
    with open(OUTPUT_DIR / f"classification_report_{name}.txt", "w", encoding="utf-8") as f:
        f.write(report)

    cm = confusion_matrix(ys_one, pred_one, labels=one_based)
    cm_frame = pd.DataFrame(cm, index=range(1, NUM_CLASSES+1), columns=range(1, NUM_CLASSES+1))
    cm_frame.to_csv(OUTPUT_DIR / f"confusion_matrix_{name}.csv")

    return metrics, pred_one, proba

# Score the validation and test splits; only the metrics dict of each result is kept.
m_val,   _, _ = eval_split("val",  X_va, y_va)
m_test,  _, _ = eval_split("test", X_te, y_te)
# Build the metrics table once, then save and echo it.
metrics_df = pd.DataFrame([m_val, m_test])
metrics_df.to_csv(OUTPUT_DIR / "metrics_tabnet.csv", index=False)
print(metrics_df)

# ------------------------ Curves ------------------------
# Plotting is best-effort (it must not abort the run), but a failure is reported
# instead of being silently swallowed, and the figure is always closed.
try:
    hist = model.history
    train_loss = hist["loss"]
    val_loss   = hist["val_logloss"]
    iters = np.arange(1, len(train_loss)+1)
    fig = plt.figure(figsize=(7,4))
    try:
        plt.plot(iters, train_loss, label="train")
        plt.plot(iters, val_loss[:len(iters)], label="val")
        plt.xlabel("epoch"); plt.ylabel("loss"); plt.title("TabNet Loss"); plt.legend(); plt.tight_layout()
        plt.savefig(OUTPUT_DIR / "loss_curves.png", dpi=150)
    finally:
        plt.close(fig)
except Exception as exc:
    print(f"[TabNet] loss curves were not saved: {exc!r}")

# ------------------------ Save artifacts ------------------------
# save_model receives a path WITHOUT extension; per the message printed at the end
# of this section the resulting file is tabnet_model.zip.
model.save_model(str(OUTPUT_DIR / "tabnet_model"))
# The fitted preprocessor is pickled; the encoder stores its state as plain JSON.
joblib.dump(pp,  OUTPUT_DIR / "preprocessor.pkl")
enc.save_meta(OUTPUT_DIR / "encoder_meta.json")

# training_meta.json: a human-readable record of how this run was configured.
with open(OUTPUT_DIR / "training_meta.json", "w", encoding="utf-8") as f:
    json.dump({
        "target_name": TARGET_NAME,
        "num_classes": NUM_CLASSES,
        "label_order_zero_indexed": list(range(NUM_CLASSES)),
        "seed": SEED,
        # best_epoch comes from getattr(model, "best_epoch", None), so it may be None.
        "best_epoch": int(best_epoch) if best_epoch is not None else None,
        # Split paths are written relative to the run directory.
        "splits": {"train": "splits/train_indices.csv", "val": "splits/val_indices.csv", "test": "splits/test_indices.csv"},
        "data_csv": DATA_CSV,
        "tabnet_config": {
            "n_d": TN_N_D, "n_a": TN_N_A, "n_steps": TN_N_STEPS, "gamma": TN_GAMMA,
            "n_independent": TN_N_INDEPENDENT, "n_shared": TN_N_SHARED,
            "batch_size": TN_BATCH_SIZE, "virtual_batch_size": TN_VBS,
            "lr": TN_LR, "weight_decay": TN_WEIGHT_DECAY,
            "mask_type": ("sparsemax" if TN_SPARSEMAP else "entmax"),
            "cat_idxs": cat_idxs, "cat_dims": cat_dims, "cat_emb_dim": cat_emb_dim
        },
        "train_time_min": round(elapsed/60, 2),
        # Same order as the columns of the encoded matrix: numerics first, then categoricals.
        "feature_names": (pp.num_cols_ + pp.cat_cols_),
        "cat_features": pp.cat_cols_,
        "num_features": pp.num_cols_,
        "column_names_canonicalized": True
    }, f, ensure_ascii=False, indent=2)

print(f"\n✅ All artifacts saved to: {OUTPUT_DIR.resolve()} (model: tabnet_model.zip)")

# ------------------------ Inference helper ------------------------
def predict_target_risk_class_tabnet(
    df_new: pd.DataFrame,
    model_prefix=OUTPUT_DIR / "tabnet_model",          # save_model() prefix, or the full .zip path
    preproc_path=OUTPUT_DIR / "preprocessor.pkl",
    encoder_meta_path=OUTPUT_DIR / "encoder_meta.json",
    batch_size=200000
) -> pd.Series:
    """
    Predict the risk class for new raw rows with a saved TabNet run.

    Parameters
    ----------
    df_new : pd.DataFrame
        Raw feature rows; the frame is copied and its column names are
        canonicalized before preprocessing.
    model_prefix : path-like
        Either the prefix that was given to ``save_model`` (no extension) or
        the path of the ``.zip`` archive itself.
    preproc_path : path-like
        joblib pickle of the fitted preprocessor.
    encoder_meta_path : path-like
        JSON metadata consumed by ``TabNetEncoder.load_meta``.
    batch_size : int
        Number of rows scored per ``predict_proba`` call; must be positive.

    Returns
    -------
    pd.Series
        int16 labels shifted to the 1-based convention (argmax index + 1),
        indexed like the preprocessed frame, named ``"pred_target_risk_class"``.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # save_model(prefix) writes "<prefix>.zip", while load_model opens the archive
    # path it is given. Resolve the prefix to the actual file so the default works.
    model_file = str(model_prefix)
    if not model_file.lower().endswith(".zip"):
        model_file += ".zip"

    # reload model
    tabnet = TabNetClassifier()
    tabnet.load_model(model_file)

    # SECURITY: joblib.load unpickles arbitrary objects — only load trusted artifacts.
    pp_inf  = joblib.load(preproc_path)
    enc_inf = TabNetEncoder.load_meta(encoder_meta_path)

    # Same preprocessing and encoding as training.
    df_new = df_new.copy(); canon_cols_inplace(df_new)
    dproc = pp_inf.transform(df_new)
    X_all, _, _ = enc_inf.transform(dproc)

    # batch predict for huge inputs
    N = X_all.shape[0]
    preds = np.empty(N, dtype="int16")
    for s in range(0, N, batch_size):
        e = min(N, s+batch_size)
        proba = tabnet.predict_proba(X_all[s:e])
        preds[s:e] = np.argmax(proba, axis=1).astype("int16")
    return pd.Series(preds + 1, index=dproc.index, name="pred_target_risk_class")