<a href="https://colab.research.google.com/github/bits05368/mlops_group53/blob/main/src/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Notebook-only setup: install MLflow, which the training cell below imports.
%pip install mlflow

In [None]:
import mlflow
import mlflow.sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from model import load_data, evaluate_model

def main(data_path="data/iris.csv", experiment_name="Iris-Experiment"):
    """Train candidate classifiers, track each in MLflow, and register the best.

    The dataset returned by ``load_data(data_path)`` is split 80/20 with a
    fixed seed. Each candidate gets its own MLflow run in which its
    hyperparameters, the score returned by ``evaluate_model`` (logged as
    ``"accuracy"``), and the model artifact are recorded. The candidate with
    the highest score is then registered under the name ``"IrisClassifier"``.

    Args:
        data_path: Location handed to ``load_data``.
        experiment_name: MLflow experiment that the runs are recorded under.

    Returns:
        None. Results are printed and persisted through MLflow.
    """
    features, labels = load_data(data_path)
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=0.2, random_state=42
    )

    mlflow.set_experiment(experiment_name)

    # (run name, estimator class, constructor kwargs); order matters because
    # the first entry wins when scores tie.
    candidates = (
        ("LogisticRegression", LogisticRegression, {"max_iter": 200}),
        (
            "RandomForest",
            RandomForestClassifier,
            {"n_estimators": 100, "random_state": 42},
        ),
    )

    leaderboard = []  # (model_name, accuracy, run_id) per finished run
    for label, estimator_cls, hyperparams in candidates:
        with mlflow.start_run(run_name=label) as active_run:
            mlflow.log_params(hyperparams)
            estimator = estimator_cls(**hyperparams)
            # NOTE(review): evaluate_model presumably fits the estimator in
            # place before scoring, so the artifact logged below is the
            # trained model -- confirm against model.py.
            score = evaluate_model(estimator, X_train, X_test, y_train, y_test)
            mlflow.log_metric("accuracy", score)
            mlflow.sklearn.log_model(estimator, artifact_path="model")
            leaderboard.append((label, score, active_run.info.run_id))

    # Highest score wins; max() keeps the earliest entry on a tie.
    winner = max(leaderboard, key=lambda entry: entry[1])
    best_model_name, best_acc, best_run_id = winner
    print(f" Best model: {best_model_name} (accuracy={best_acc:.4f})")

    # Registration needs MLFLOW_TRACKING_URI to point at a tracking server
    # that supports the Model Registry.
    mlflow.register_model(f"runs:/{best_run_id}/model", "IrisClassifier")
    print(f" Registered {best_model_name} in MLflow Model Registry.")

# Entry point: run the full train/track/register pipeline with the default
# dataset path and experiment name.
if __name__ == "__main__":
    main()