In [3]:
import os
import joblib
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import StratifiedKFold, train_test_split, cross_val_score
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.preprocessing import MinMaxScaler


# Load and Preprocess Data for Multi-Class
def load_and_preprocess(directory, label=None):
    data_frames = []
    for file in os.listdir(directory):
        if file.endswith(".csv"):
            df = pd.read_csv(os.path.join(directory, file))
            if label is not None:
                df["Label"] = label  # Assign class labels
            data_frames.append(df)
    df = pd.concat(data_frames, ignore_index=True)
    
    # Encode categorical data
    for col in df.select_dtypes(include=['object']).columns:
        df[col] = LabelEncoder().fit_transform(df[col])
    
    # Handle missing values
    df.fillna(df.median(), inplace=True)
    
    return df

# Load Data
dataset_path_healthy = "../split_fif/healthy_csv"
dataset_path_mdd = "../split_fif/mdd_csv"
dataset_path_other = "../split_fif/other_csv"  # New dataset folder for multi-class

df_healthy = load_and_preprocess(dataset_path_healthy, label=0)  # 0 for Healthy
df_mdd = load_and_preprocess(dataset_path_mdd, label=1)  # 1 for MDD
df_other = load_and_preprocess(dataset_path_other, label=2)  # 2 for Other (you can update labels as necessary)

# Combine all datasets into one
df = pd.concat([df_healthy, df_mdd, df_other], axis=0).reset_index(drop=True)

# Split Features and Labels
X = df.drop(columns=['Label']).values
y = df['Label'].values

# Standardize Data
scaler = MinMaxScaler(feature_range=(0, 100))  # Min-Max scaling
X_scaled = scaler.fit_transform(X)

# Split into Train, Validation & Test
X_train, X_temp, y_train, y_temp = train_test_split(X_scaled, y, test_size=0.2, random_state=42, stratify=y)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.3, random_state=42, stratify=y_temp)

X_val_scaled = scaler.transform(X_val)
X_test_scaled = scaler.transform(X_test)

joblib.dump(scaler, "multi_class_best_xgb_scaler.pkl")

# Define XGBoost Model for Multi-Class Classification
xgb_model = xgb.XGBClassifier(
    objective="multi:softmax",  # Change to multi-class classification
    tree_method="hist",
    device="cuda",  # Use GPU if available
    eval_metric="mlogloss",  # Logarithmic loss for multi-class classification
    learning_rate=0.005,
    max_depth=25,
    gamma=0.2,
    subsample=0.95,
    colsample_bytree=0.97,
    min_child_weight=1,
    reg_alpha=0.8,
    reg_lambda=3.0,
    n_estimators=6000,
    verbosity=1,
    num_class=3  # Number of classes (update this based on your dataset)
)

# Perform Cross-Validation
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=100)
cv_scores = cross_val_score(xgb_model, X_train, y_train, cv=cv, scoring="accuracy")

# Train Model
xgb_model.fit(X_train, y_train)
joblib.dump(xgb_model, "multi_class_best_xgb_model.pkl")

# Evaluate on Validation Set
val_preds = xgb_model.predict(X_val)
val_probs = xgb_model.predict_proba(X_val)  # Multi-class probabilities
accuracy = accuracy_score(y_val, val_preds)
roc_auc = roc_auc_score(y_val, val_probs, multi_class="ovr")  # One-vs-rest ROC AUC for multi-class
report = classification_report(y_val, val_preds)

# Print Results
print(f"Validation Accuracy: {accuracy:.4f}")
print(f"ROC AUC Score (One-vs-Rest): {roc_auc:.4f}")
print("\nClassification Report:\n", report)
print(f"Cross-Validation Scores: {cv_scores}")


Validation Accuracy: 0.8108
ROC AUC Score (One-vs-Rest): 0.9341

Classification Report:
               precision    recall  f1-score   support

           0       0.73      0.67      0.70        12
           1       0.78      0.82      0.80        17
           2       1.00      1.00      1.00         8

    accuracy                           0.81        37
   macro avg       0.84      0.83      0.83        37
weighted avg       0.81      0.81      0.81        37

Cross-Validation Scores: [0.92857143 0.76190476 0.80952381 0.87804878 0.87804878]
