In [0]:
%load_ext autoreload
%autoreload 2

In [0]:
%pip install databricks-sdk --upgrade

In [0]:
%pip install --quiet mlflow==2.19 databricks-feature-engineering==0.8.0

In [0]:
%restart_python

In [0]:
# Notebook parameters: the catalog and schema this notebook operates on.
dbutils.widgets.text(
    "catalog_use",
    "datascience_dev",
    label="Catalog to Use",
)
dbutils.widgets.text(
    "schema_use",
    "main",
    label="Schema to Use",
)

In [0]:
# Read the catalog/schema widget values and issue a USE statement for them.
catalog_use = dbutils.widgets.get("catalog_use")
schema_use = dbutils.widgets.get("schema_use")
spark.sql("USE " + catalog_use + "." + schema_use)

In [0]:
%sql
-- Sanity check: display the catalog and schema that are currently active.
select current_catalog(), current_schema();

In [0]:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import EndpointCoreConfigInput, ServedEntityInput, ServingModelWorkloadType, AiGatewayConfig, AiGatewayInferenceTableConfig, AiGatewayUsageTrackingConfig
from databricks.sdk.service.catalog import OnlineTable, OnlineTableSpec, OnlineTableSpecTriggeredSchedulingPolicy

# Workspace API client, constructed with no explicit arguments; used further
# down for the online-table and serving-endpoint calls.
w = WorkspaceClient()

In [0]:
# Repeats the widget lookup and USE statement from the earlier parameter cell,
# so the same catalog/schema are (re)selected before the steps that follow.
catalog_use, schema_use = (
    dbutils.widgets.get(widget_name)
    for widget_name in ("catalog_use", "schema_use")
)
spark.sql("USE {}.{}".format(catalog_use, schema_use))

In [0]:
%sql
-- Display the active catalog and schema again after re-issuing the USE statement.
select current_catalog(), current_schema();

In [0]:
# Widget: short (unqualified) name of the model.
dbutils.widgets.text("short_model_name", "advanced_mlops_churn_model", label="Short Model Name")

# Widget: feature table that stores the computed features. Its default is the
# three-level name built from the catalog and schema read from the widgets.
dbutils.widgets.text(
    "advanced_churn_feature_table",
    ".".join([catalog_use, schema_use, "advanced_churn_feature_table"]),
    label="Feature Table",
)

In [0]:
# Turn the widget values into the names used by the remaining cells.
feature_table = dbutils.widgets.get("advanced_churn_feature_table")
model_name = dbutils.widgets.get("short_model_name")
# Three-level model name: <catalog>.<schema>.<model>
full_model_name = ".".join((catalog_use, schema_use, model_name))

In [0]:
# Echo the resolved names so the cell output shows what this run targets.
print(
    "\n".join(
        [
            " ",
            f"  model_name = {model_name}",
            f"  full_model_name = {full_model_name}",
            f"  feature_table = {feature_table}",
            "",
        ]
    )
)

In [0]:
import mlflow

In [0]:
# The model is addressed by a three-level <catalog>.<schema>.<model> name and
# looked up through the "champion" alias further down, so point the client at
# the Unity Catalog model registry explicitly instead of relying on whatever
# registry URI happens to be the default for this workspace/cluster.
client = mlflow.MlflowClient(registry_uri="databricks-uc")

In [0]:
# Online table specification: the source feature table, the columns that key
# its rows, and the triggered scheduling policy.
spec = OnlineTableSpec(
    source_table_full_name=feature_table,  # Delta feature table used as the source
    primary_key_columns=[
        "customer_id",
        "transaction_ts",
    ],  # Replace with the primary key(s) of your feature table
    run_triggered=OnlineTableSpecTriggeredSchedulingPolicy.from_dict(
        {"triggered": "true"}
    ),
)

In [0]:
# Turn on the Delta change data feed table property for the feature table.
spark.sql(
    f"ALTER TABLE {feature_table} "
    "SET TBLPROPERTIES (delta.enableChangeDataFeed = true)"
)

In [0]:
# The online table is named after the feature table with an "_online" suffix.
online_table_name = f"{feature_table}_online"
online_table = OnlineTable(name=online_table_name, spec=spec)

# Create the online table (and wait for it) only when the catalog does not
# already contain a table with that name.
online_table_missing = not spark.catalog.tableExists(online_table_name)
if online_table_missing:
    w.online_tables.create_and_wait(table=online_table)

In [0]:
# Create a serving endpoint named after the model unless one with that name
# already exists. Only existence matters, so scan the endpoint listing once
# rather than issuing an additional `get` request for every match.
endpoint_exists = any(
    endpoint.name == model_name for endpoint in w.serving_endpoints.list()
)

if endpoint_exists:
    print(f"Serving endpoint '{model_name}' already exists; skipping creation.")
else:
    # Resolve the model version that currently carries the "champion" alias.
    champion_version = client.get_model_version_by_alias(
        full_model_name, "champion"
    ).version

    # NOTE(review): the result of create() is not waited on here — confirm
    # whether the notebook should block until the endpoint is ready.
    w.serving_endpoints.create(
        name=model_name,
        config=EndpointCoreConfigInput(
            name=model_name,
            served_entities=[
                ServedEntityInput(
                    entity_name=full_model_name,
                    entity_version=champion_version,
                    environment_vars={},
                    scale_to_zero_enabled=True,
                    workload_size="Small",
                    workload_type=ServingModelWorkloadType.CPU,
                )
            ],
        ),
        # AI Gateway: enable the inference table in the selected catalog/schema
        # and enable usage tracking for the endpoint.
        ai_gateway=AiGatewayConfig(
            inference_table_config=AiGatewayInferenceTableConfig(
                catalog_name=catalog_use,
                schema_name=schema_use,
                table_name_prefix=None,
                enabled=True,
            ),
            usage_tracking_config=AiGatewayUsageTrackingConfig(enabled=True),
        ),
    )