In [None]:
import mlflow
from pyspark.sql import SparkSession

from hotel_reservation.config import ProjectConfig, Tags
from hotel_reservation.models.basic_model import BasicModel

In [None]:
# Route MLflow tracking to the Databricks workspace and the model registry to Unity Catalog.
TRACKING_URI = "databricks"
REGISTRY_URI = "databricks-uc"
mlflow.set_tracking_uri(TRACKING_URI)
mlflow.set_registry_uri(REGISTRY_URI)

In [None]:
# Load project settings from YAML and attach to (or create) the active Spark session.
# NOTE(review): the config path is relative to the current working directory, so this
# cell assumes the notebook runs one directory below the project root — confirm.
config = ProjectConfig.from_yaml(config_path="../project_config.yml")
spark = SparkSession.builder.getOrCreate()
# Run tags handed to BasicModel; presumably attached to the logged MLflow run.
# NOTE(review): "abcd12345" looks like a hard-coded placeholder rather than the real
# commit SHA — consider injecting the actual value.
tags = Tags(git_sha="abcd12345", branch="week2")

In [None]:
# Build the model wrapper from the loaded config object, the run tags, and the Spark session.
basic_model = BasicModel(spark=spark, config=config, tags=tags)

In [None]:
# Load the data and prepare features; both steps are implemented in BasicModel.
basic_model.load_data()
basic_model.prepare_features()

In [None]:
# Fit the model, then log it (presumably as an MLflow run tagged with `tags` — confirm in BasicModel.log_model).
basic_model.train()
basic_model.log_model()

In [None]:
# Find the most recent run on the "week2" branch and load its logged pipeline model.
# NOTE(review): the experiment name says "house-prices" although this is the
# hotel_reservation project — confirm it matches the experiment log_model() writes to.
runs = mlflow.search_runs(
    experiment_names=["/Shared/house-prices-basic"],
    filter_string="tags.branch='week2'",
)
if runs.empty:
    # Fail with a clear message instead of an opaque KeyError/IndexError on `.run_id[0]`.
    raise RuntimeError(
        "No MLflow runs found in '/Shared/house-prices-basic' with tags.branch='week2'"
    )
# search_runs sorts by start_time DESC by default, so the first row is the latest run.
run_id = runs.run_id.iloc[0]

model = mlflow.sklearn.load_model(f"runs:/{run_id}/lightgbm-pipeline-model")

In [None]:
# Retrieve the dataset recorded for the current run. The return value is not assigned,
# so the notebook only displays it as the cell output.
basic_model.retrieve_current_run_dataset()

In [None]:
# Retrieve metadata for the current run. The return value is not assigned,
# so the notebook only displays it as the cell output.
basic_model.retrieve_current_run_metadata()

In [None]:
# Register the logged model in the model registry. The registry URI is set to
# "databricks-uc" at the top of this notebook.
basic_model.register_model()

In [None]:
# Score a small sample of the test table with the latest model.
# Note: `.limit(10)` without an ordering does not guarantee which rows are returned.
test_table_name = f"{config.catalog_name}.{config.schema_name}.test_set"
test_set = spark.table(test_table_name).limit(10)

# Drop the target column so only feature columns are sent for prediction.
features_sdf = test_set.drop(config.target)
X_test = features_sdf.toPandas()

predictions_df = basic_model.load_latest_model_and_predict(X_test)