In [None]:
import os
import tempfile

import matplotlib.pyplot as plt
import mlflow
import mlflow.sklearn
import pandas as pd
import seaborn as sns
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.model_selection import RandomizedSearchCV

def mlflow_train_and_log(estimator, param_grid, experiment_name, run_name,
                          X_train, X_test, y_train, y_test,
                          scoring='neg_mean_absolute_error', n_iter=20, cv=5,
                          random_state=None):
    """Tune an estimator with a randomized search and log the outcome to MLflow.

    Inside a single MLflow run, this fits a ``RandomizedSearchCV`` on the
    training data, then logs the best hyper-parameters, R2 / MSE / MAE on
    both the train and test sets, a CSV of test residuals, a histogram of
    those residuals, and the refitted best estimator.

    Args:
        estimator: Estimator handed to ``RandomizedSearchCV``.
        param_grid: Passed as ``param_distributions`` to the search.
        experiment_name: MLflow experiment to activate before the run.
        run_name: Name given to the MLflow run.
        X_train: Training features used to fit the search.
        X_test: Held-out features used for the test metrics and residuals.
        y_train: Training targets.
        y_test: Held-out targets.
        scoring: Scoring strategy forwarded to the search.
        n_iter: Number of parameter settings sampled by the search.
        cv: Cross-validation setting forwarded to the search.
        random_state: Seed forwarded to the search so the sampled
            candidates can be reproduced; ``None`` keeps the sampling
            unseeded.

    Returns:
        None. All results are recorded in the MLflow run.
    """
    mlflow.set_experiment(experiment_name)

    with mlflow.start_run(run_name=run_name):
        # Hyper-parameter search
        search = RandomizedSearchCV(
            estimator=estimator,
            param_distributions=param_grid,
            n_iter=n_iter,
            cv=cv,
            n_jobs=-1,
            scoring=scoring,
            verbose=1,
            random_state=random_state,
        )
        search.fit(X_train, y_train)
        best_model = search.best_estimator_

        # Log best parameters
        mlflow.log_params(search.best_params_)

        # Predictions
        y_train_pred = best_model.predict(X_train)
        y_test_pred = best_model.predict(X_test)

        # Metrics
        metrics = {
            "r2_train": r2_score(y_train, y_train_pred),
            "r2_test": r2_score(y_test, y_test_pred),
            "mse_train": mean_squared_error(y_train, y_train_pred),
            "mse_test": mean_squared_error(y_test, y_test_pred),
            "mae_train": mean_absolute_error(y_train, y_train_pred),
            "mae_test": mean_absolute_error(y_test, y_test_pred)
        }
        mlflow.log_metrics(metrics)

        # Residuals on the test set
        residuals = y_test - y_test_pred
        res_df = pd.DataFrame({'y_true': y_test, 'y_pred': y_test_pred, 'residual': residuals})

        # Stage artifacts in a throwaway directory so nothing is left in the
        # working directory and simultaneous runs cannot overwrite each
        # other's files. The file names (and thus the artifact names) are
        # the same as before.
        with tempfile.TemporaryDirectory() as artifact_dir:
            csv_path = os.path.join(artifact_dir, "residuals.csv")
            res_df.to_csv(csv_path, index=False)
            mlflow.log_artifact(csv_path)

            # Diagnostic plot
            plot_path = os.path.join(artifact_dir, "residual_plot.png")
            fig, ax = plt.subplots(figsize=(8, 5))
            try:
                sns.histplot(residuals, bins=30, kde=True, color='skyblue', ax=ax)
                ax.set_title("Distribution des résidus")
                fig.tight_layout()
                fig.savefig(plot_path)
            finally:
                # Close this specific figure even if plotting or saving fails.
                plt.close(fig)
            mlflow.log_artifact(plot_path)

        # Log the model
        mlflow.sklearn.log_model(best_model, "best_model")

        print("✅ Run MLflow terminé et loggué avec succès.")
