In [0]:
# Importing libraries
import os
import sys
import mlflow
import mlflow.sklearn
import joblib
import json
import pandas as pd
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from mlflow.models.signature import infer_signature

def predict(model, X_test):
    """Return the model's predictions for the held-out features.

    Args:
        model: Any object exposing a ``predict(X)`` method (in this script,
            the estimator loaded with ``joblib.load``).
        X_test: Features handed unchanged to ``model.predict``.

    Returns:
        Exactly what ``model.predict(X_test)`` returns.
    """
    return model.predict(X_test)

def evaluate(y_test, y_test_pred):
    """Score regression predictions against the ground truth.

    Args:
        y_test: True target values.
        y_test_pred: Predictions aligned element-wise with ``y_test``.

    Returns:
        Tuple ``(mae, mse, rmse, r2)``: mean absolute error, mean squared
        error, its square root, and the coefficient of determination.
    """
    abs_err = mean_absolute_error(y_test, y_test_pred)
    sq_err = mean_squared_error(y_test, y_test_pred)
    # RMSE is derived from the MSE already computed instead of a second sklearn call.
    return abs_err, sq_err, np.sqrt(sq_err), r2_score(y_test, y_test_pred)

def log_model(mae, mse, rmse, r2):
    """Record the four evaluation metrics with ``mlflow.log_metric``.

    Despite its name, this function logs metrics only; the model artifact
    itself is logged by ``register_model``. In this script it is called
    inside an ``mlflow.start_run()`` context.

    Args:
        mae: Mean absolute error, logged as ``"MAE"``.
        mse: Mean squared error, logged as ``"MSE"``.
        rmse: Root mean squared error, logged as ``"RMSE"``.
        r2: Coefficient of determination, logged as ``"R2"``.
    """
    metrics = (
        ("MAE", mae),
        ("MSE", mse),
        ("RMSE", rmse),
        ("R2", r2),
    )
    for metric_name, metric_value in metrics:
        mlflow.log_metric(metric_name, metric_value)

# Log the model to MLflow and register it in the model registry.
def register_model(
    model,
    X_test,
    *,
    registered_model_name="mlops.public.Ice_Creame_Model",
    input_example=None,
    artifact_path="model",
):
    """Log a scikit-learn model to MLflow and register it under a registry name.

    The signature is inferred from ``X_test`` and the model's own predictions
    on it, so the logged model records its expected input/output schema.

    Args:
        model: Fitted scikit-learn estimator to log.
        X_test: Features used to infer the signature; ``model.predict`` is
            called on them.
        registered_model_name: Registry name to register under. The default
            is the previously hard-coded three-level (``catalog.schema.model``)
            name, so existing callers register to the same place.
        input_example: Optional small sample of inputs (e.g. ``X_test[:5]``)
            to store alongside the model. ``None`` (the default) logs no
            example, matching the previous behaviour.
        artifact_path: Run-relative path under which the model is stored.

    Returns:
        The ``ModelInfo`` object returned by ``mlflow.sklearn.log_model``
        (previously discarded; callers ignoring the return are unaffected).
    """
    # Log the model with signature
    signature = infer_signature(X_test, model.predict(X_test))
    return mlflow.sklearn.log_model(
        sk_model=model,
        artifact_path=artifact_path,
        signature=signature,
        input_example=input_example,
        registered_model_name=registered_model_name,
    )
    
if __name__ == "__main__":
    # Load the trained model and the held-out test split.
    # NOTE: joblib.load unpickles these files, which can execute arbitrary
    # code -- only point these paths at artifacts you produced yourself.
    model_path = "/Workspace/Users/mahesh2139@gmail.com/Model/model.pkl"
    test_data_path = "/Workspace/Users/mahesh2139@gmail.com/Model/test_data.pkl"
    model = joblib.load(model_path)
    # test_data.pkl must hold an (X_test, y_test) pair, unpacked here.
    X_test, y_test = joblib.load(test_data_path)

    # Track runs in the Databricks workspace and register models in Unity Catalog.
    mlflow.set_tracking_uri("databricks")
    mlflow.set_registry_uri("databricks-uc")
    mlflow.set_experiment("/Users/mahesh2139@gmail.com/Ice_Creame_Model")

    # First 5 rows of input features.
    # NOTE(review): input_example is built but never passed to MLflow, so no
    # example is currently logged with the model.
    input_example = X_test[:5]

    with mlflow.start_run(run_name="ice_cream_regression_model"):
        y_test_pred = predict(model, X_test)
        mae, mse, rmse, r2 = evaluate(y_test, y_test_pred)
        print(f"mae: {mae}, mse: {mse}, rmse : {rmse}, r2: {r2}")
        log_model(mae, mse, rmse, r2)
        register_model(model, X_test)
        print("Run logged and model registered successfully")
        print("Artifact location:", mlflow.get_artifact_uri())
