# A 03 I: MIT Hyperparameter tuning: Grid Search for Best Model + Sampling on MIT

Fine-tune the best model with the best sampling method from A 02 02.

## Content

A) MIT-BIH Arrhythmia Dataset

1. train/test split: 80%, 20% -> as defined at the beginning of the project to ensure result reproducibility, no duplicates or missing values present
2. Hyperparameter tuning using GridSearchCV with cross-validation for the selected baseline models and oversampling techniques
 


## 1. Imports

In [None]:
import os 
from typing import Dict, Optional
import random 

from src.utils import evaluate_model
from src.visualization import save_cv_diagnostics, save_overfit_diagnostic, save_model_diagnostics, save_roc_curve
from src.utils.model_saver import create_model_saver

# external 
import pandas as pd

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix
)
from sklearn.model_selection import RepeatedStratifiedKFold, GridSearchCV
from scipy.stats import loguniform, randint, uniform
import numpy as np
import re
import json

# Models
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
import xgboost as xgb

# Samplers

from imblearn.over_sampling import RandomOverSampler, SMOTE, ADASYN
from imblearn.combine import SMOTETomek, SMOTEENN
from imblearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from src.utils.preprocessing import (
    _normalize_sampling_method_name,
    _SAMPLING_REGISTRY
)
import mlflow
from mlflow.tracking import MlflowClient

# Seed NumPy's and Python's global random number generators with one shared
# constant so that repeated runs of this notebook use the same seeds.
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
random.seed(RANDOM_STATE)


## 3. Methods

In [None]:
def create_leak_free_pipeline(
    model_name: str,
    estimator,
    sampling_method: Optional[str] = "none",
    sampler_kwargs: Optional[Dict] = None,
    random_state: Optional[int] = 42,
) -> Pipeline:
    """
    Assemble a two-step pipeline: a resampler followed by the classifier.

    The resampler is placed inside the pipeline (the ``Pipeline`` class this
    notebook imports from imblearn) so that, during cross-validation, it is
    fitted on the training part of each fold only.

    Args:
        model_name: Name of the model (currently not used by this function).
        estimator: Classifier instance that becomes the ``"classifier"`` step.
        sampling_method: Sampling method name; it is normalized and then looked
            up in ``_SAMPLING_REGISTRY``.
        sampler_kwargs: Keyword arguments for the sampler constructor. The
            caller's dict is copied, never mutated.
        random_state: Seed injected into the sampler kwargs unless the caller
            already supplied a ``"random_state"`` entry; ``None`` disables the
            injection.

    Returns:
        Pipeline with the steps ``("sampler", ...)`` and ``("classifier", ...)``.
    """
    # Private copy: the default seed below must not leak into the caller's dict.
    resampler_options = dict(sampler_kwargs or {})
    if random_state is not None:
        resampler_options.setdefault("random_state", random_state)

    registry_key = _normalize_sampling_method_name(sampling_method)
    resampler = _SAMPLING_REGISTRY[registry_key](**resampler_options)

    steps = [("sampler", resampler), ("classifier", estimator)]
    # Notebook helper: show the assembled steps in the cell output.
    display(steps)
    return Pipeline(steps)

In [None]:
def configure_mlflow_for_minio(experiment_name, mlflow_tracking_uri, minio_endpoint_url):
    """
    Point MLflow at the tracking server and the MinIO artifact endpoint.

    Credentials are read from ``.aws/credentials`` in the working directory,
    falling back to ``~/.aws/credentials``, and exported through the
    ``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY`` environment variables.

    Args:
        experiment_name: Name of the MLflow experiment to look up or create.
        mlflow_tracking_uri: URI of the MLflow tracking server.
        minio_endpoint_url: Endpoint exported as ``MLFLOW_S3_ENDPOINT_URL``.

    Returns:
        The ID of the existing experiment, or of the newly created one.

    Raises:
        LookupError: If neither credentials file exists.
    """
    import os
    from configparser import ConfigParser
    from pathlib import Path

    # Prefer project-local credentials over the ones in the home directory.
    project_creds = Path('.aws/credentials')
    creds_file = project_creds if project_creds.exists() else Path.home() / '.aws' / 'credentials'
    if not creds_file.exists():
        raise LookupError("No MINIO Credentials found!")

    config = ConfigParser()
    config.read(creds_file)
    if 'default' in config:
        profile = config['default']
        for env_var, option in (
            ('AWS_ACCESS_KEY_ID', 'aws_access_key_id'),
            ('AWS_SECRET_ACCESS_KEY', 'aws_secret_access_key'),
        ):
            value = profile.get(option, '')
            # Only export keys that are actually present: writing an empty
            # string would clobber credentials already set in the environment.
            if value:
                os.environ[env_var] = value

    os.environ['MLFLOW_S3_ENDPOINT_URL'] = minio_endpoint_url

    mlflow.set_tracking_uri(mlflow_tracking_uri)

    # Initialize client
    client = MlflowClient(tracking_uri=mlflow_tracking_uri)

    # Reuse the experiment when it already exists; otherwise create it.
    experiment = mlflow.get_experiment_by_name(experiment_name)
    if experiment:
        experiment_id = experiment.experiment_id
        print(f"Found experiment: {experiment_name} (ID: {experiment_id})")
    else:
        print(f"Experiment '{experiment_name}' not found, creating new experiment!")
        experiment_id = client.create_experiment(experiment_name)

    return experiment_id

In [None]:
def _stringify_mlflow_param(value) -> str:
    """Render a parameter value as a string (JSON for lists/dicts, "None" for None)."""
    if isinstance(value, (list, dict)):
        return json.dumps(value)
    if value is None:
        return "None"
    return str(value)


def _stringify_sampler_param(value) -> str:
    """Render a sampler kwarg: plain values keep their value, other objects log their class name."""
    if value is None or isinstance(value, (list, dict, str, bool, int, float, np.generic)):
        return _stringify_mlflow_param(value)
    return value.__class__.__name__


def log_results_to_mlflow(
    model_name: str,
    search,
    eval_results: dict,
    summary: dict,
    sampling_method: str,
    remove_outliers: bool,
    results_path: str,
    experiment_id: str,
    n_iter: int,
    cv,
    refit_metric: str,
    random_state: int,
    X_train,
    dataset_name: str,
    dataset_source: str,
    mlflow_tracking_uri: str = "http://mlflow.home.lan",
    sampler_kwargs: Optional[Dict] = None,
):
    """
    Log the results of one hyperparameter search as a single MLflow run.

    The run records the training dataset (best effort), tags, the best
    hyperparameters, the experiment configuration, the sampler parameters,
    summary metrics, any diagnostic plots / CSV files found on disk, and the
    refitted best estimator.

    Args:
        model_name: Name of the model/classifier.
        search: Fitted search object; ``best_params_`` and ``best_estimator_``
            are read from it.
        eval_results: Evaluation results; ``eval_results["labels"]`` drives the
            per-class F1 metrics.
        summary: Dictionary with the summary metrics logged below.
        sampling_method: Name of the sampling method used.
        remove_outliers: Whether outliers were removed.
        results_path: Path to the results CSV file (also used to derive the
            plot and CV-result file names).
        experiment_id: MLflow experiment ID.
        n_iter: Iteration count, logged as-is.
        cv: Cross-validation object; its ``n_splits`` is logged.
        refit_metric: Metric used for refitting.
        random_state: Random state used.
        X_train: Training features; logged as dataset input and used for the
            model's input example.
        dataset_name: Dataset name used to label the logged dataset input.
        dataset_source: Source passed along with the logged dataset input.
        mlflow_tracking_uri: MLflow tracking server URI (only used to print the
            link to the run).
        sampler_kwargs: Sampler parameters (e.g. k_neighbors, n_neighbors).
    """
    run_name = f"{model_name}_{sampling_method}_outliers_{remove_outliers}"

    with mlflow.start_run(run_name=run_name, experiment_id=experiment_id):
        # Dataset logging is best effort: a failure here must not lose the run.
        try:
            dataset = mlflow.data.from_pandas(
                X_train,
                source=dataset_source,
                name=f"{dataset_name} Training Dataset"
            )
            mlflow.log_input(dataset, context="training")
        except Exception as e:
            print(f"Warning: Could not log dataset: {e}")

        # Log tags
        mlflow.set_tags({
            "dataset": "MIT-BIH",
            "phase": "baseline_models",
            "model_type": model_name,
            "sampling_method": sampling_method,
            "outlier_removal": str(remove_outliers),
        })

        # Log hyperparameters (best_params), stringified for MLflow
        mlflow.log_params({
            key: _stringify_mlflow_param(value)
            for key, value in search.best_params_.items()
        })

        # Log experiment configuration parameters
        mlflow.log_params({
            "sampling_method": sampling_method,
            "remove_outliers": str(remove_outliers),
            "n_iter": str(n_iter),
            "cv_n_splits": str(cv.n_splits),
            "refit_metric": refit_metric,
            "random_state": str(random_state),
        })

        # Log sampler parameters if available
        if sampler_kwargs:
            sampler_params = {}
            for key, value in sampler_kwargs.items():
                if key == "smote":
                    # 'smote' holds a sampler instance (for SMOTETomek, SMOTEENN):
                    # log its relevant attributes instead of the object itself.
                    sampler_params["sampler_smote_k_neighbors"] = str(getattr(value, "k_neighbors", "N/A"))
                    sampler_params["sampler_smote_random_state"] = str(getattr(value, "random_state", "N/A"))
                else:
                    # Plain values are logged by value; the previous
                    # hasattr(value, "__class__") test was true for every
                    # object and turned e.g. k_neighbors=5 into "int".
                    sampler_params[f"sampler_{key}"] = _stringify_sampler_param(value)

            mlflow.log_params(sampler_params)

        # Log summary metrics
        metric_keys = (
            "best_cv_score",
            "test_f1_macro",
            "train_f1_macro",
            "test_accuracy",
            "train_test_diff",
            "cv_mean_val_f1_macro",
            "cv_std_val_f1_macro",
            "cv_mean_train_f1_macro",
            "cv_std_train_f1_macro",
            "cv_diff_train_val_f1_macro",
            "cv_mean_val_bal_acc",
            "cv_std_val_bal_acc",
            "mean_fit_time",
            "std_fit_time",
        )
        mlflow.log_metrics({key: summary[key] for key in metric_keys})

        # Log ROC-AUC if available
        if summary["roc_auc"] is not None:
            mlflow.log_metric("test_roc_auc", summary["roc_auc"])

        # Log per-class F1 scores as metrics
        for lbl in eval_results["labels"]:
            mlflow.log_metric(f"test_f1_class_{lbl}", summary[f"test_f1_class_{lbl}"])
            mlflow.log_metric(f"train_f1_class_{lbl}", summary[f"train_f1_class_{lbl}"])

        # Log artifacts (plots and CSV files); files that do not exist are skipped
        base = results_path.replace(".csv", "")
        plot_prefix = f"{base}_{model_name}_{sampling_method}"
        plot_suffixes = (
            "cv_tradeoff",
            "cv_spread",
            "cv_learning_curve",
            "overfit_diag",
            "roc_curve",
            "model_diagnostics",
        )
        for suffix in plot_suffixes:
            plot_path = f"{plot_prefix}_{suffix}.png"
            if os.path.exists(plot_path):
                mlflow.log_artifact(plot_path, "diagnostics")

        # Log CSV files
        cv_full_path = results_path.replace(".csv", '_'+model_name+'_'+sampling_method.lower()+'_outliers_'+str(remove_outliers)+"_cv_results.csv")
        if os.path.exists(cv_full_path):
            mlflow.log_artifact(cv_full_path, "data")
        if os.path.exists(results_path):
            mlflow.log_artifact(results_path, "data")

        # Log the trained model
        mlflow.sklearn.log_model(
            search.best_estimator_,
            name="model",
            input_example=X_train.iloc[:1].values if hasattr(X_train, 'iloc') else X_train[:1],
            registered_model_name=f"{model_name}_{sampling_method}",
        )

        run_id = mlflow.active_run().info.run_id
        print(f"✅ Logged to MLflow: {run_name}")
        print(f"   Run ID: {run_id}")
        print(f"   View at: {mlflow_tracking_uri}/#/experiments/{experiment_id}/runs/{run_id}")

In [None]:
def run_grid_search(
    model_name,
    estimator,
    params,
    X_train,
    y_train,
    X_test,
    y_test,
    cv,
    results_path,
    sampling_method,
    sampler_kwargs,
    remove_outliers,
    model_saver,
    SCORING,
    verbose,
    refit_metric,
    log_to_mlflow,
    dataset_name,
    dataset_source,
    mlflow_experiment_id,
    mlflow_tracking_uri="http://mlflow.home.lan",
) -> Optional[Dict]:
    """
    Run GridSearchCV for a specific model and sampling method.

    The best estimator is evaluated on train and test data, a one-row summary
    is appended to ``results_path``, the full CV results and diagnostic plots
    are written next to it, and the run is optionally logged to MLflow and
    stored through ``model_saver``.

    Args:
        model_name: Name of the model to train.
        estimator: Base estimator; wrapped in a sampler pipeline unless
            ``sampling_method`` is ``"No_Sampling"``.
        params: Parameter grid for the estimator. Keys are prefixed with
            ``classifier__`` automatically when the pipeline is used.
        X_train, y_train: Training data passed to the search.
        X_test, y_test: Hold-out data used to evaluate the best estimator.
        cv: Cross-validation splitter.
        results_path: CSV file the summary row is appended to.
        sampling_method: Sampling method to use.
        sampler_kwargs: Keyword arguments for the sampler.
        remove_outliers: Whether outliers were removed (recorded in names/results).
        model_saver: Model saver instance; falsy disables skip check and saving.
        SCORING: Scoring mapping for GridSearchCV. The summary reads the
            ``f1_macro`` and ``bal_acc`` entries from the CV results.
        verbose: Verbosity passed to GridSearchCV.
        refit_metric: Scorer name used to select and refit the best model.
        log_to_mlflow: Whether to log the run to MLflow.
        dataset_name: Dataset name forwarded to MLflow logging.
        dataset_source: Dataset source forwarded to MLflow logging.
        mlflow_experiment_id: MLflow experiment ID; ``None`` disables logging.
        mlflow_tracking_uri: MLflow tracking server URI.

    Returns:
        Dictionary with the summary row, or ``None`` when the model was
        skipped because it had already been saved.
    """
    print(f"\n{'='*80}")
    print(f"Running GridSearchCV for {model_name} ({sampling_method})")
    print(f"Outlier removal: {remove_outliers}")
    print(f"{'='*80}")

    # Name under which the model, its metadata and the CV results are stored
    experiment_name = f"{sampling_method.lower()}_outliers_{remove_outliers}"

    # --- SKIP if model already exists ---
    if model_saver and model_saver.model_exists(model_name, experiment_name):
        print(f"  Skipping {model_name} ({experiment_name}) - model already saved.")
        try:
            meta = model_saver.load_metadata(model_name, experiment_name)
            if meta:
                print(f"    Existing model best_score={meta.get('best_score'):.4f}, "
                      f"params={meta.get('best_params')}")
        except Exception as e:
            # Metadata is informational only; never fail the skip path on it.
            print(f"  (Could not load metadata: {e})")
        return None
    # ---------------------------------------------------------------

    # Create leak-free pipeline - only applies for sampling methods
    if sampling_method != "No_Sampling":
        estimator = create_leak_free_pipeline(model_name, estimator, sampling_method, sampler_kwargs)
        # Adjust parameter names for pipeline
        params = {f'classifier__{param_name}': param_values
                  for param_name, param_values in params.items()}

    # Run the search
    search = GridSearchCV(
        estimator=estimator,
        param_grid=params,
        scoring=SCORING,
        refit=refit_metric,
        cv=cv,
        n_jobs=-1,
        verbose=verbose,
        return_train_score=True,
    )
    search.fit(X_train, y_train)

    # Evaluate best model
    eval_results = evaluate_model(search.best_estimator_, X_train, y_train, X_test, y_test)
    train_eval = eval_results["train"]
    test_eval = eval_results["test"]

    # Summary table (1 row per model)
    summary = {
        "model": model_name,
        "sampling_method": sampling_method,
        "remove_outliers": remove_outliers,
        # Best mean validation score from CV (based on refit_metric), higher is better
        "best_cv_score": round(search.best_score_, 4),
        "best_params": json.dumps(search.best_params_),
        # Macro-F1 on training data: how well the model fits seen data across all classes
        "train_f1_macro": round(train_eval["f1_macro"], 4),
        # Macro-F1 on test data: balanced generalization across all classes
        "test_f1_macro": round(test_eval["f1_macro"], 4),
        # Overall proportion of correct predictions
        "test_accuracy": round(test_eval["accuracy"], 4),
        # Gap between train and test (over-/underfitting indicator), smaller is better
        "train_test_diff": round(train_eval["f1_macro"] - test_eval["f1_macro"], 4),
        # ROC-AUC on test data (class separation), closer to 1 is better
        "roc_auc": round(test_eval["roc_auc"], 4) if not np.isnan(test_eval["roc_auc"]) else None,
    }

    # Cross-fold metrics of the best candidate
    cv_df = pd.DataFrame(search.cv_results_)
    best_idx = search.best_index_

    def best_stat(column):
        """Value of a cv_results_ column for the best candidate, rounded to 4 digits."""
        return round(cv_df[column][best_idx], 4)

    # High train score is fine; much higher than validation hints at overfitting
    summary["cv_mean_train_f1_macro"] = best_stat("mean_train_f1_macro")
    summary["cv_std_train_f1_macro"] = best_stat("std_train_f1_macro")  # low = stable across folds
    summary["cv_mean_val_f1_macro"] = best_stat("mean_test_f1_macro")  # balanced per-class performance
    summary["cv_std_val_f1_macro"] = best_stat("std_test_f1_macro")  # should be low
    summary["cv_diff_train_val_f1_macro"] = round(
        cv_df["mean_train_f1_macro"][best_idx] - cv_df["mean_test_f1_macro"][best_idx], 4
    )
    summary["cv_mean_val_bal_acc"] = best_stat("mean_test_bal_acc")  # average recall per class
    summary["cv_std_val_bal_acc"] = best_stat("std_test_bal_acc")  # should be low
    summary["mean_fit_time"] = best_stat("mean_fit_time")
    summary["std_fit_time"] = best_stat("std_fit_time")

    for lbl, f1_val in zip(eval_results["labels"], test_eval["f1_per_class"]):
        summary[f"test_f1_class_{lbl}"] = round(float(f1_val), 4)

    for lbl, f1_val in zip(eval_results["labels"], train_eval["f1_per_class"]):
        summary[f"train_f1_class_{lbl}"] = round(float(f1_val), 4)

    # os.makedirs("") raises, so only create a directory when the path has one.
    results_dir = os.path.dirname(results_path)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)
    # Append one row; write the header only when the file is created.
    pd.DataFrame([summary]).to_csv(results_path, mode="a", header=not os.path.exists(results_path), index=False)

    # Save full CV results for analysis
    cv_full_path = results_path.replace(".csv", '_'+model_name+'_'+experiment_name+"_cv_results.csv")
    cv_df.to_csv(cv_full_path, index=False)

    # Generate diagnostics / graphics
    save_overfit_diagnostic(cv_df, model_name, sampling_method, results_path)
    save_cv_diagnostics(cv_df, model_name, sampling_method, results_path)
    save_model_diagnostics(eval_results, model_name, sampling_method, results_path)
    save_roc_curve(search.best_estimator_, X_test, y_test, model_name, sampling_method, results_path)

    print(f"Saved unified results to {results_path}")

    # Log to MLflow if enabled
    if log_to_mlflow and mlflow_experiment_id is not None:
        log_results_to_mlflow(
            model_name=model_name,
            search=search,
            eval_results=eval_results,
            summary=summary,
            sampling_method=sampling_method,
            sampler_kwargs=sampler_kwargs,
            remove_outliers=remove_outliers,
            results_path=results_path,
            experiment_id=mlflow_experiment_id,
            n_iter=0,  # grid search has no iteration budget
            cv=cv,
            refit_metric=refit_metric,
            random_state=RANDOM_STATE,
            dataset_name=dataset_name,
            dataset_source=dataset_source,
            X_train=X_train,
            mlflow_tracking_uri=mlflow_tracking_uri,
        )

    if model_saver:
        meta = {
            "best_params": search.best_params_,
            "best_score": search.best_score_,
            "cv_results": search.cv_results_,
            "experiment": experiment_name,
            "classifier": model_name,
        }
        model_saver.save_model(model_name, search, experiment_name, meta)
        # Only report a save when one actually happened.
        print(f"Saved model {model_name} ({experiment_name})!")

    return summary


## 2. Constants & Param Spaces

In [None]:

# Search configuration per model. Each entry provides:
#   "estimator": the base classifier instance,
#   "params":    the hyperparameter grid for that classifier (plain parameter names),
#   "cv":        the cross-validation splitter (5 splits x 3 repeats = 15 fits per candidate).
PARAM_SPACES = {
    "XGBoost": {
        "estimator": xgb.XGBClassifier(
            objective="multi:softmax",
            num_class=5,
            random_state=RANDOM_STATE,
            n_jobs=-1,
            eval_metric="mlogloss",
        ),
        "params": {
            "n_estimators": [150, 200, 250, 350, 500],
            "max_depth": [8, 9],
            "learning_rate": [0.2],
            "subsample": [0.7, 0.8],
            "colsample_bytree": [0.9],
            "reg_alpha": [0.1, 0.2],
            "reg_lambda": [0.0, 0.05],
            "min_child_weight": [5],
            "gamma": [0.0, 0.05],
        },
        "cv": RepeatedStratifiedKFold(n_splits=5, n_repeats=3, random_state=RANDOM_STATE),
    },
    "ANN": {
        "estimator": MLPClassifier(
            max_iter=300,
            early_stopping=True,
            random_state=RANDOM_STATE,
            n_iter_no_change=10,
            solver="adam",
        ),
        "params": {
            "hidden_layer_sizes": [(128, 64)],
            "activation": ["relu"],
            "alpha": [3e-4],
            "learning_rate_init": [0.001, 0.0015],
            "batch_size": [96, 128],
            "beta_1": [0.9, 0.91],
            "beta_2": [0.97, 0.974],
            "validation_fraction": [0.1],
        },
        "cv": RepeatedStratifiedKFold(n_splits=5, n_repeats=3, random_state=RANDOM_STATE),
    },
    # Previous best (presumably from an earlier search - TODO confirm):
    # {'clf__kernel': 'rbf', 'clf__gamma': 0.5, 'clf__C': 10}
    # The gamma grid below brackets that value.
    "SVM": {
        "estimator": SVC(),
        "params": {
            "kernel": ["rbf"],
            "C": [10],
            "gamma": [0.4, 0.5, 0.6],
        },
        "cv": RepeatedStratifiedKFold(n_splits=5, n_repeats=3, random_state=RANDOM_STATE),
    },
}


In [None]:
# --- Run configuration -------------------------------------------------------
RANDOM_STATE = 42
dataset_name = "MIT-BIH"
EXPERIMENT_NAME = "MIT_03_01_GS_SAMPLING"
REDUCED_DATASET = True  # 5% of original for testing
dirname = 'MIT_03_01_baseline_models_grid_search'
results_csv = f"reports/03_model_testing_results/{dirname}.csv"
# Scorer names used by the grid search (keys become cv_results_ column suffixes)
SCORING = {'f1_macro': 'f1_macro', 'bal_acc': 'balanced_accuracy', 'f1_weighted': 'f1_weighted'}
RESULTS_PATH = f"reports/03_baseline_models/{EXPERIMENT_NAME}/A_03_01.csv"
minio_endpoint_url = "http://192.168.178.78:9500"
MLFLOW_TRACKING_URI = "http://mlflow.home.lan"

dataset_train = "data/original/mitbih_train.csv"


# --- Data loading ------------------------------------------------------------
# Original MIT-BIH files (no header row; column 187 holds the label)
df_mitbih_train = pd.read_csv(dataset_train, header=None)
df_mitbih_test = pd.read_csv('data/original/mitbih_test.csv', header=None)

X_train = pd.read_csv('data/processed/mitbih/X_train.csv')
y_train = pd.read_csv('data/processed/mitbih/y_train.csv')
y_train = y_train['187']

# Pre-computed SMOTE-oversampled training split
X_train_sm = pd.read_csv('data/processed/mitbih/X_train_sm.csv')
y_train_sm = pd.read_csv('data/processed/mitbih/y_train_sm.csv')
y_train_sm = y_train_sm['187']

X_val = pd.read_csv('data/processed/mitbih/X_val.csv')
y_val = pd.read_csv('data/processed/mitbih/y_val.csv')

X_test = df_mitbih_test.drop(187, axis=1)
y_test = df_mitbih_test[187]

print("MITBIH dataset - SMOTE")
print(f"\tTraining ORIG  size: {df_mitbih_train.shape}")
print(f"\tTraining SMOTE size: {X_train_sm.shape}, {y_train_sm.shape}")
print(f"\tTest size: {X_test.shape}, {y_test.shape}")
print(f"\tVal size: {X_val.shape}, {y_val.shape}")


if REDUCED_DATASET:
    # Keep reduced test runs in their own MLflow experiment. The previous bare
    # expression `EXPERIMENT_NAME + '_RED'` discarded its result.
    EXPERIMENT_NAME += '_RED'

    # Subsample training set to 5 % (stratified, so all classes are kept)
    X_train_small, _, y_train_small, _ = train_test_split(
        X_train, y_train,
        train_size=0.05,
        stratify=y_train,
        random_state=RANDOM_STATE
    )

    # Subsample test set to 5 % as well
    X_test_small, _, y_test_small, _ = train_test_split(
        X_test, y_test,
        train_size=0.05,
        stratify=y_test,
        random_state=RANDOM_STATE
    )

    print("Reduced MIT-BIH dataset")
    print(f"\tTraining size: {X_train_small.shape}, {y_train_small.shape}")
    print(f"\tTest size: {X_test_small.shape}, {y_test_small.shape}")

    # Assign back for the pipeline.
    # NOTE: X_train_sm / y_train_sm are not subsampled here.
    X_train, y_train = X_train_small, y_train_small
    X_test,  y_test  = X_test_small,  y_test_small


experiment_id = configure_mlflow_for_minio(EXPERIMENT_NAME, MLFLOW_TRACKING_URI, minio_endpoint_url)

model_saver = create_model_saver(f"src/models/{dirname}")

# Sampling method name -> keyword arguments for the sampler
sampling_methods = {
    'SMOTE': {"random_state": RANDOM_STATE, "k_neighbors": 5},
}

best_models = ["XGBoost", "SVM", "KNN"]


In [None]:
# Run one grid search per (model, sampling method) combination.
for model_name, param_dict in PARAM_SPACES.items():
    for sampler_name, sampler_kwargs in sampling_methods.items():
        if sampler_name not in ("SMOTE", "No_Sampling"):
            raise ValueError(f"Unsupported sampling method: {sampler_name}")

        # Always pass the un-resampled training data. For every method other
        # than "No_Sampling", run_grid_search wraps the estimator in a pipeline
        # that applies the sampler inside each CV fold. Feeding it the
        # pre-oversampled X_train_sm would put synthetic samples into the
        # validation folds and would ignore the REDUCED_DATASET subsampling.
        X_train_ = X_train
        y_train_ = y_train

        run_grid_search(
            model_name,
            estimator=param_dict["estimator"],
            params=param_dict["params"],
            X_train=X_train_,
            y_train=y_train_,
            X_test=X_test,
            y_test=y_test,
            cv=param_dict["cv"],
            results_path=RESULTS_PATH,
            sampling_method=sampler_name,
            sampler_kwargs=sampler_kwargs,
            remove_outliers=False,
            model_saver=model_saver,
            SCORING=SCORING,
            verbose=3,
            refit_metric="f1_macro",
            log_to_mlflow=True,
            dataset_name=dataset_name,
            dataset_source=dataset_train,
            mlflow_experiment_id=experiment_id,
        )