In [0]:
# Read the Delta-format table stored at the mounted path into a Spark DataFrame.
df = (
    spark.read
    .format("delta")
    .load("/mnt/datamount/delta_table")
)

In [0]:
# Install the mlflow package by shelling out to pip.
# NOTE(review): no version is pinned, so a re-run may pick up a different mlflow
# release — consider pinning (e.g. mlflow==<version>) for reproducibility.
# NOTE(review): `%pip` is the usual way to target the kernel's own environment, but
# on Databricks it may reset notebook state (which would drop `df` from the cell
# above) — confirm and move the install to the first cell before switching.
!pip install mlflow

In [0]:
with open('tokenfile', 'w') as f:
    f.write(dbutils.secrets.get(scope="creds", key="dbtoken"))
!databricks configure --host https://adb-6724577987585661.1.azuredatabricks.net/ --token-file tokenfile

In [0]:
# Import necessary libraries
import mlflow
import mlflow.sklearn
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

# Convert the Spark DataFrame loaded earlier into a pandas DataFrame
data = df.toPandas()

# Columns used as model inputs, and the label column to predict
features = ["Pclass", "Age", "SibSp", "Parch", "Fare"]
target = "Survived"

# Hold out 20% of the rows for evaluation; the fixed random_state keeps the
# split reproducible across runs
X, y = data[features], data[target]
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Set the active MLflow experiment that the run below is recorded under
mlflow.set_experiment("/Users/bhaveshkak26122000@gmail.com/my_experiment")

# Train, evaluate, log and register the model inside a single MLflow run.
# Binding the run to `run` gives direct access to its id without a second
# lookup through mlflow.active_run().
with mlflow.start_run() as run:
    # Train a machine learning model (Random Forest in this example);
    # random_state is fixed so repeated runs build the same forest
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)

    # Make predictions on the held-out test set
    y_pred = model.predict(X_test)

    # Calculate and log accuracy
    accuracy = accuracy_score(y_test, y_pred)
    mlflow.log_metric("accuracy", accuracy)

    # Log the fitted model as a run artifact under the path "model"
    mlflow.sklearn.log_model(model, "model")

    # Save the feature columns for reference
    mlflow.log_param("features", features)

    # URI of the artifact logged above; it must use the same "model" path
    model_uri = f"runs:/{run.info.run_id}/model"
    registered_model_name = "titanic_model"  # Replace with your desired model name

    # Placeholder tags and description attached to the registered version
    model_tags = {"key1": "value1", "key2": "value2"}
    model_description = "Description of the registered model"

    # Register the logged model as a new version in the MLflow Model Registry
    registered_model = mlflow.register_model(model_uri, registered_model_name, tags=model_tags)

    # register_model() has no description argument, so the description
    # (previously defined but never used) is applied to the new version here.
    mlflow.tracking.MlflowClient().update_model_version(
        name=registered_model.name,
        version=registered_model.version,
        description=model_description,
    )

    # Print the registered model information
    print(f"Registered model: {registered_model.name} (Version {registered_model.version})")
