In [1]:
import pandas as pd
import mlflow
import mlflow.xgboost
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from xgboost import XGBClassifier

In [2]:
# Load the dataset from the cleaned CSV file.
df = pd.read_csv('data/data-cleaned.csv')

# The "output" column is the target; every other column is a feature.
y = df["output"]
X = df.drop(columns=["output"])

# Hold out 20% of the rows for evaluation; the fixed seed keeps the split
# identical across re-runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

In [11]:
# Point the MLflow client at the tracking server and select the experiment
# that subsequent runs are recorded under.
MLFLOW_TRACKING_URI = "http://mlflow:5000"
MLFLOW_EXPERIMENT_NAME = "heart_attack_training"

mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)

def training(max_depth, learning_rate, n_estimators):
    """Train an XGBoost classifier, log it to MLflow, and register the model.

    Fits on the notebook-level ``X_train``/``y_train`` split, scores accuracy
    on ``X_test``/``y_test``, and records hyperparameters, the metric, and the
    model artifact in a single MLflow run. The logged model is then registered
    under the name ``heart_attack_model``.

    Args:
        max_depth: Forwarded to ``XGBClassifier(max_depth=...)``.
        learning_rate: Forwarded to ``XGBClassifier(learning_rate=...)``.
        n_estimators: Forwarded to ``XGBClassifier(n_estimators=...)``.

    Returns:
        None. The registered model version is printed.
    """
    hyperparams = {
        "max_depth": max_depth,
        "learning_rate": learning_rate,
        "n_estimators": n_estimators,
    }
    model_name = "heart_attack_model"
    # Single source of truth for where the model lives inside the run, so the
    # path given to log_model and the runs:/ URI below cannot drift apart.
    artifact_path = "models"

    with mlflow.start_run() as run:
        # Log hyperparameters before fitting so that a run which fails during
        # training still records the configuration that was attempted.
        mlflow.log_params(hyperparams)

        # Fit the model
        model = XGBClassifier(**hyperparams)
        model.fit(X_train, y_train)

        # Evaluate on the held-out split
        predictions = model.predict(X_test)
        accuracy = accuracy_score(y_test, predictions)
        mlflow.log_metric("accuracy", accuracy)

        # Log the model artifact to this run
        mlflow.xgboost.log_model(model, artifact_path=artifact_path)

        # Register the logged model. The run id comes from the handle returned
        # by start_run() rather than the global active-run lookup.
        model_uri = f"runs:/{run.info.run_id}/{artifact_path}"
        registered_model = mlflow.register_model(model_uri, model_name)

        print(f"Model registered with version: {registered_model.version}")

# Run a single training job with the chosen hyperparameters.
hyperparameters = {
    "max_depth": 3,
    "learning_rate": 0.1,
    "n_estimators": 100,
}
training(**hyperparameters)

Registered model 'heart_attack_model' already exists. Creating a new version of this model...
2024/11/10 13:35:02 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation. Model name: heart_attack_model, version 2


Model registered with version: 2


Created version '2' of model 'heart_attack_model'.
