# In [1]:  (Jupyter notebook cell marker left over from the export)
# ------------------------------------------------------
# IMPORTS
# ------------------------------------------------------
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import RobustScaler, LabelEncoder
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_score, recall_score
from xgboost import XGBClassifier
from imblearn.over_sampling import SMOTE
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")

# ------------------------------------------------------
# CONFIG
# ------------------------------------------------------
# Seed reused by every split, resampler and model in this script.
SEED = 42

# CSV file read below as the feature / label table.
FEATURES_PATH = "/workspaces/datasciencetest_reco_plante/dataset/plantvillage/csv/clean_data_plantvillage_segmented_all_with_features.csv"

# Named XGBoost hyper-parameter sets; each inner dict is unpacked as keyword
# arguments when the classifier is built.
xgb_configs = dict([
    ("Baseline", dict(n_estimators=200, learning_rate=0.1, max_depth=6)),
    ("Deep Trees", dict(n_estimators=300, learning_rate=0.05, max_depth=10)),
    ("Shallow Trees", dict(n_estimators=500, learning_rate=0.01, max_depth=3)),
])

# ------------------------------------------------------
# DATA LOADING
# ------------------------------------------------------
df = pd.read_csv(FEATURES_PATH)

# Labels: encode the `nom_maladie` column as integer class ids (0..n_classes-1).
label_col = 'nom_maladie'
le = LabelEncoder()
y = le.fit_transform(df[label_col])
n_classes = len(le.classes_)
print("Classes détectées :", le.classes_)

# Numeric features: every numeric column except the explicitly excluded ones.
# The label column is always excluded as well, so that a numerically-encoded
# target can never end up among the predictors.
# NOTE(review): other target-derived numeric columns (if the CSV has any)
# would still leak into X — confirm against the CSV schema.
exclude_cols = ['width_img', 'height_img', 'is_black']  # adapt as needed
_non_feature_cols = set(exclude_cols) | {label_col}
numeric_columns = [
    c for c in df.select_dtypes(include=np.number).columns.tolist()
    if c not in _non_feature_cols
]

# Fill missing numeric values with the per-column median.
df[numeric_columns] = df[numeric_columns].fillna(df[numeric_columns].median())

# Final feature matrix; column order matches `numeric_columns`.
X = df[numeric_columns].values

print("Shape de X :", X.shape)
print("Shape de y :", y.shape)

# ------------------------------------------------------
# SCALER + FEATURE PIPELINES
# ------------------------------------------------------
# Every transformer is fitted on the TRAINING rows only and then applied to
# all rows. Fitting on the full dataset (and, for the supervised LDA, on the
# test labels too) would leak test information into the features and inflate
# the reported test scores.
#
# The training rows are obtained with the same stratified split parameters
# (test_size=0.2, stratify=y, random_state=SEED) that the main loop uses.
# For a fixed seed the split indices depend only on the number of rows and on
# `y`, so the partition computed here is the one the main loop recomputes.
# Keep both calls in sync if either is changed.
# NOTE(review): inside the 5-fold CV the validation folds still share scaler /
# PCA / LDA statistics with the rest of the training split.
_fit_idx, _ = train_test_split(
    np.arange(len(y)), test_size=0.2, stratify=y, random_state=SEED
)

scaler = RobustScaler()
scaler.fit(X[_fit_idx])
X_scaled = scaler.transform(X)

pca = PCA(n_components=min(50, X_scaled.shape[1]), random_state=SEED)
pca.fit(X_scaled[_fit_idx])

# LDA yields at most n_classes - 1 discriminant components.
lda = LDA(n_components=min(n_classes - 1, X_scaled.shape[1]))
lda.fit(X_scaled[_fit_idx], y[_fit_idx])

# Pipeline name -> feature matrix covering ALL rows (same row order as `y`).
pipelines = {
    "XGBoost": X_scaled,
    "XGBoost + PCA": pca.transform(X_scaled),
    "XGBoost + LDA": lda.transform(X_scaled),
}

# ------------------------------------------------------
# EVALUATION FUNCTION
# ------------------------------------------------------
def evaluate_metrics(y_true, y_pred, labels=None):
    """Compute global and per-class classification metrics.

    Args:
        y_true: 1-D array-like of true class ids (non-negative integers).
        y_pred: 1-D array-like of predicted class ids, same length as y_true.
        labels: optional sequence of class ids that fixes the order and length
            of the per-class outputs. When omitted, ``0 .. max(label)`` is
            used so that position ``i`` always refers to class id ``i``.

    Returns:
        Tuple ``(accuracy, f1_macro, f1_weighted, precision, recall,
        f1_per_class, support)``. The first three are scalars; the last four
        are arrays with one entry per label, all aligned with each other.
        ``support`` is the number of true samples of each label.
    """
    # Local import: keeps this block self-contained without touching the
    # file-level import section.
    from sklearn.metrics import precision_recall_fscore_support

    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    if labels is None:
        # Without an explicit label set, per-class scores only cover labels
        # that occur in the data, whereas a bincount-based support covers
        # 0..max(y_true): the arrays could then differ in length and meaning.
        labels = np.arange(int(max(y_true.max(), y_pred.max())) + 1)

    acc = accuracy_score(y_true, y_pred)
    f1_macro = f1_score(y_true, y_pred, average="macro")
    f1_weighted = f1_score(y_true, y_pred, average="weighted")

    # One pass yields all per-class arrays, guaranteed to share the same order.
    precision, recall, f1_per_class, support = precision_recall_fscore_support(
        y_true, y_pred, labels=labels, average=None
    )
    return acc, f1_macro, f1_weighted, precision, recall, f1_per_class, support

# ------------------------------------------------------
# MAIN LOOP: pipelines x configs
# ------------------------------------------------------
results = []
class_results = []


def _new_model(params):
    """Return an untrained XGBoost classifier for one hyper-parameter set."""
    return XGBClassifier(use_label_encoder=False, eval_metric="mlogloss", random_state=SEED, n_jobs=-1, **params)


def _cv_weighted_f1(X_train, y_train, params):
    """Return the weighted F1 of each fold of a 5-fold stratified CV.

    SMOTE is applied to the training part of each fold only, so validation
    samples are never synthetic.
    """
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
    fold_scores = []
    for train_idx, val_idx in skf.split(X_train, y_train):
        X_tr_bal, y_tr_bal = SMOTE(random_state=SEED).fit_resample(
            X_train[train_idx], y_train[train_idx]
        )
        fold_model = _new_model(params)
        fold_model.fit(X_tr_bal, y_tr_bal)
        y_val_pred = fold_model.predict(X_train[val_idx])
        fold_scores.append(f1_score(y_train[val_idx], y_val_pred, average="weighted"))
    return fold_scores


for config_name, params in xgb_configs.items():
    print(f"\n===== Configuration : {config_name} =====")

    for pipe_name, X_proc in pipelines.items():
        print(f"\n🚀 Pipeline: {pipe_name}")

        # Train/test split (stratified, fixed seed -> same partition for
        # every pipeline and every config).
        X_train, X_test, y_train, y_test = train_test_split(X_proc, y, test_size=0.2, stratify=y, random_state=SEED)

        # 5-fold cross-validation on the training split.
        f1_scores = _cv_weighted_f1(X_train, y_train, params)
        f1_mean, f1_std = np.mean(f1_scores), np.std(f1_scores)

        # Final model: a fresh estimator trained on the whole (SMOTE-balanced)
        # training split, instead of re-fitting whatever object the last CV
        # fold happened to leave behind.
        X_train_bal, y_train_bal = SMOTE(random_state=SEED).fit_resample(X_train, y_train)
        model = _new_model(params)
        model.fit(X_train_bal, y_train_bal)
        y_test_pred = model.predict(X_test)

        # Global and per-class evaluation on the held-out test split.
        acc, f1_macro, f1_weighted, precision, recall, f1_per_class, support = evaluate_metrics(y_test, y_test_pred)

        # Global results (the fitted model and test data are kept for later
        # inspection, e.g. feature importances).
        results.append({
            "Pipeline": pipe_name,
            "Config": config_name,
            "CV_F1_mean": f1_mean,
            "CV_F1_std": f1_std,
            "Test_Accuracy": acc,
            "Test_F1_macro": f1_macro,
            "Test_F1_weighted": f1_weighted,
            "Model": model,
            "X_test": X_test,
            "y_test": y_test,
            "y_pred": y_test_pred
        })

        # Per-class table; position i of each metric array is class id i.
        for i, classe in enumerate(le.classes_):
            class_results.append({
                "Pipeline": pipe_name,
                "Config": config_name,
                "Classe": classe,
                "Precision": precision[i],
                "Recall": recall[i],
                "F1_score": f1_per_class[i],
                "Support": support[i]
            })

        # Confusion matrix. `labels` pins the matrix to one row/column per
        # known class so it always lines up with the `le.classes_` tick labels,
        # even if a class is missing from both y_test and the predictions.
        cm = confusion_matrix(y_test, y_test_pred, labels=np.arange(n_classes))
        plt.figure(figsize=(8,6))
        sns.heatmap(cm, annot=True, fmt='d', xticklabels=le.classes_, yticklabels=le.classes_, cmap="Blues")
        plt.title(f"Matrice de confusion - {pipe_name} ({config_name})")
        plt.xlabel("Prédictions")
        plt.ylabel("Vraies classes")
        plt.show()

        # Worst-predicted classes: per-class recall read off the confusion
        # matrix. Classes without any test sample get NaN instead of
        # triggering a division by zero.
        row_totals = cm.sum(axis=1)
        class_acc = np.divide(
            cm.diagonal(),
            row_totals,
            out=np.full(cm.shape[0], np.nan),
            where=row_totals > 0,
        )
        class_acc_df = pd.DataFrame({"Classe": le.classes_, "Accuracy": class_acc}).sort_values(by="Accuracy")
        plt.figure(figsize=(8,6))
        sns.barplot(x="Accuracy", y="Classe", data=class_acc_df, palette="Reds_r")
        plt.xlim(0,1)
        plt.title(f"Classes les moins bien prédites - {pipe_name} ({config_name})")
        plt.show()

# ------------------------------------------------------
# Global comparison table
# ------------------------------------------------------
# Scalar, human-readable columns of `results`; the remaining keys hold Python
# objects (fitted model, test arrays) that are neither printable nor
# CSV-friendly.
metric_cols = ["Pipeline","Config","CV_F1_mean","CV_F1_std","Test_Accuracy","Test_F1_macro","Test_F1_weighted"]

results_df = pd.DataFrame(results).sort_values(by="Test_F1_weighted", ascending=False)
print("\n📊 Résultats globaux :")
print(results_df[metric_cols])

# ------------------------------------------------------
# Per-class summary table
# ------------------------------------------------------
class_results_df = pd.DataFrame(class_results)
print("\n📊 Résultats par classe :")
print(class_results_df)

# Optional: export. Only the metric columns are written; dumping the model
# and array columns would fill the CSV with truncated object reprs.
results_df[metric_cols].to_csv("global_results.csv", index=False)
class_results_df.to_csv("class_results.csv", index=False)

# ------------------------------------------------------
# Top 20 features for the best plain-XGBoost model
# ------------------------------------------------------
# Overall best run (rows are sorted by weighted test F1, best first).
best_model_row = results_df.iloc[0]

# Feature importances map back to `numeric_columns` only for the pipeline
# without PCA/LDA, so pick the best run among those rows rather than skipping
# the plot whenever a PCA/LDA pipeline ranks first overall.
xgb_rows = results_df[results_df["Pipeline"] == "XGBoost"]
if not xgb_rows.empty:
    best_xgb_row = xgb_rows.iloc[0]
    best_model = best_xgb_row["Model"]
    importances = best_model.feature_importances_
    feat_df = pd.DataFrame({"Feature": numeric_columns, "Importance": importances}).sort_values(by="Importance", ascending=False)
    plt.figure(figsize=(12,8))
    sns.barplot(x="Importance", y="Feature", data=feat_df.head(20))
    plt.title(f"Top 20 features - {best_xgb_row['Pipeline']} ({best_xgb_row['Config']})")
    plt.show()


# KeyboardInterrupt:  (cell output captured in the export — the run appears to have been interrupted before finishing)