## 04. UCに登録されたモデルをモデルサービングエンドポイントにデプロイ

02でUCに登録したモデルを、リモートから利用できるようにモデルサービングエンドポイントにデプロイします。

<!-- <img src='https://sajpstorage.blob.core.windows.net/maruyama/public_share/demo_end2end/4_model_endpoint.png' width='800' /> -->
<img src='https://github.com/komae5519pv/komae_dbdemos/blob/main/e2e_ML_20250629/_data/_imgs/4_model_endpoint.png?raw=true' width='1200' />

In [0]:
%run ./00_setup

### 1. モデルサービングエンドポイントにデプロイ
MLflow Deployments SDKを使用して、UCに登録されたモデルをモデルサービングエンドポイントにデプロイします。

In [0]:
import mlflow
from mlflow.deployments import get_deploy_client
from mlflow.exceptions import MlflowException
from requests.exceptions import HTTPError

# Clients: one for the Unity Catalog model registry, one for Databricks model serving.
mlflow_client = mlflow.MlflowClient()
deploy_client = get_deploy_client("databricks")

# Fully qualified UC model name, shared by the alias lookup and the endpoint config.
uc_model_name = f"{MY_CATALOG}.{MY_SCHEMA}.{MODEL_NAME}"
model_version = mlflow_client.get_model_version_by_alias(uc_model_name, "prod")

# Endpoint name is truncated to at most 63 characters.
endpoint_name = f"komae_{MODEL_NAME}"[:63]

# Endpoint configuration: serve the concrete version that the "prod" alias resolved to.
served_entity = {
    "entity_name": uc_model_name,
    "entity_version": model_version.version,
    "workload_size": "Small",
}
config = {"served_entities": [served_entity]}

# Create the endpoint only when it does not exist yet; an existing endpoint is left untouched.
try:
    existing_endpoint = deploy_client.get_endpoint(endpoint_name)
except HTTPError as e:
    # Anything other than "not found" is a real failure and is propagated unchanged.
    if e.response.status_code != 404:
        raise
    endpoint = deploy_client.create_endpoint(name=endpoint_name, config=config)
    print(f"エンドポイント '{endpoint_name}' を新規作成しました。")
except MlflowException as e:
    print(f"エラーが発生しました: {e}")
    raise
else:
    print(f"既存のエンドポイント '{endpoint_name}' が見つかりました。更新は行いません。")

モデルサービングエンドポイントが「READY」状態になるまでノートブックで監視します。<br>
MLflow Deployments SDK（またはREST API）を使ってエンドポイント状態をポーリングします。

In [0]:
import time

def wait_for_endpoint_ready(client, endpoint_name, timeout=1000, check_interval=10):
    """
    Poll a model serving endpoint until it reports the READY state.

    :param client: Deployments client used to query the endpoint; it only needs
        to expose ``get_endpoint(name)`` returning a dict-like object.
    :param endpoint_name: Name of the endpoint to monitor.
    :param timeout: Maximum time to wait, in seconds.
    :param check_interval: Delay between two state checks, in seconds.
    :return: True once the endpoint is READY, False if the timeout elapses first.
    """
    # Monotonic clock: the timeout is not affected by wall-clock adjustments.
    start_time = time.monotonic()
    while True:
        # Query through the client that was passed in, not a notebook-level global.
        endpoint = client.get_endpoint(endpoint_name)
        # A missing or null "state" is treated the same as "not ready yet".
        state_ready = (endpoint.get("state") or {}).get("ready", "NOT_READY")
        print(f"Endpoint state: {state_ready}")
        if state_ready == "READY":
            print("Endpoint is ready to receive traffic.")
            return True
        if time.monotonic() - start_time > timeout:
            print("Timeout waiting for endpoint to be ready.")
            return False
        time.sleep(check_interval)

# Run the monitor. Stop the notebook here if the endpoint never becomes READY,
# so the test-inference cells do not run against an endpoint that cannot serve.
if not wait_for_endpoint_ready(deploy_client, endpoint_name):
    raise TimeoutError(
        f"Endpoint '{endpoint_name}' did not become READY before the timeout."
    )


### 2. テスト推論
モデルサービングエンドポイントにデプロイされたモデルにクエリを実行し、正常に推論できるかテストします。

In [0]:
# Smoke-test the serving endpoint from SQL: score the first 100 feature rows with ai_query.
# The endpoint name is interpolated from `endpoint_name` (the name the endpoint was
# created under) instead of a hard-coded literal, so the query keeps working when
# MODEL_NAME changes.
results = spark.sql(f'''
SELECT
  * EXCEPT (churn),
  ai_query(
    '{endpoint_name}',
    named_struct(
      "seniorCitizen", seniorCitizen,
      "tenure", tenure,
      "monthlyCharges", monthlyCharges,
      "totalCharges", totalCharges,
      "gender_Female", gender_Female,
      "gender_Male", gender_Male,
      "partner_No", partner_No,
      "partner_Yes", partner_Yes,
      "dependents_No", dependents_No,
      "dependents_Yes", dependents_Yes,
      "phoneService_No", phoneService_No,
      "phoneService_Yes", phoneService_Yes,
      "multipleLines_No", multipleLines_No,
      "multipleLines_Nophoneservice", multipleLines_Nophoneservice,
      "multipleLines_Yes", multipleLines_Yes,
      "internetService_DSL", internetService_DSL,
      "internetService_Fiberoptic", internetService_Fiberoptic,
      "internetService_No", internetService_No,
      "onlineSecurity_No", onlineSecurity_No,
      "onlineSecurity_Nointernetservice", onlineSecurity_Nointernetservice,
      "onlineSecurity_Yes", onlineSecurity_Yes,
      "onlineBackup_No", onlineBackup_No,
      "onlineBackup_Nointernetservice", onlineBackup_Nointernetservice,
      "onlineBackup_Yes", onlineBackup_Yes,
      "deviceProtection_No", deviceProtection_No,
      "deviceProtection_Nointernetservice", deviceProtection_Nointernetservice,
      "deviceProtection_Yes", deviceProtection_Yes,
      "techSupport_No", techSupport_No,
      "techSupport_Nointernetservice", techSupport_Nointernetservice,
      "techSupport_Yes", techSupport_Yes,
      "streamingTV_No", streamingTV_No,
      "streamingTV_Nointernetservice", streamingTV_Nointernetservice,
      "streamingTV_Yes", streamingTV_Yes,
      "streamingMovies_No", streamingMovies_No,
      "streamingMovies_Nointernetservice", streamingMovies_Nointernetservice,
      "streamingMovies_Yes", streamingMovies_Yes,
      "contract_Month-to-month", `contract_Month-to-month`,
      "contract_Oneyear", contract_Oneyear,
      "contract_Twoyear", contract_Twoyear,
      "paperlessBilling_No", paperlessBilling_No,
      "paperlessBilling_Yes", paperlessBilling_Yes,
      "paymentMethod_Banktransfer-automatic", `paymentMethod_Banktransfer-automatic`,
      "paymentMethod_Creditcard-automatic", `paymentMethod_Creditcard-automatic`,
      "paymentMethod_Electroniccheck", paymentMethod_Electroniccheck,
      "paymentMethod_Mailedcheck", paymentMethod_Mailedcheck
    ),
    'FLOAT'  -- モデルが返す型に合わせて指定
  ) AS prediction
FROM {MY_CATALOG}.{MY_SCHEMA}.churn_features
LIMIT 100
''')

# Row count, output schema, and a small preview of the predictions.
print(results.count())
print(results.columns)
display(results.limit(10))

In [0]:
from pyspark.sql.functions import expr

# Load the feature table that will be scored.
df = spark.table(f"{MY_CATALOG}.{MY_SCHEMA}.churn_features")

# Same inference as the SQL cell, expressed through the DataFrame API via expr().
# The endpoint name is interpolated from `endpoint_name` (the name the endpoint was
# created under) instead of a hard-coded literal, so it stays correct when
# MODEL_NAME changes.
results = df.withColumn(
    "prediction",
    expr(f"""
    ai_query(
        '{endpoint_name}',
        named_struct(
            'seniorCitizen', seniorCitizen,
            'tenure', tenure,
            'monthlyCharges', monthlyCharges,
            'totalCharges', totalCharges,
            'gender_Female', gender_Female,
            'gender_Male', gender_Male,
            'partner_No', partner_No,
            'partner_Yes', partner_Yes,
            'dependents_No', dependents_No,
            'dependents_Yes', dependents_Yes,
            'phoneService_No', phoneService_No,
            'phoneService_Yes', phoneService_Yes,
            'multipleLines_No', multipleLines_No,
            'multipleLines_Nophoneservice', multipleLines_Nophoneservice,
            'multipleLines_Yes', multipleLines_Yes,
            'internetService_DSL', internetService_DSL,
            'internetService_Fiberoptic', internetService_Fiberoptic,
            'internetService_No', internetService_No,
            'onlineSecurity_No', onlineSecurity_No,
            'onlineSecurity_Nointernetservice', onlineSecurity_Nointernetservice,
            'onlineSecurity_Yes', onlineSecurity_Yes,
            'onlineBackup_No', onlineBackup_No,
            'onlineBackup_Nointernetservice', onlineBackup_Nointernetservice,
            'onlineBackup_Yes', onlineBackup_Yes,
            'deviceProtection_No', deviceProtection_No,
            'deviceProtection_Nointernetservice', deviceProtection_Nointernetservice,
            'deviceProtection_Yes', deviceProtection_Yes,
            'techSupport_No', techSupport_No,
            'techSupport_Nointernetservice', techSupport_Nointernetservice,
            'techSupport_Yes', techSupport_Yes,
            'streamingTV_No', streamingTV_No,
            'streamingTV_Nointernetservice', streamingTV_Nointernetservice,
            'streamingTV_Yes', streamingTV_Yes,
            'streamingMovies_No', streamingMovies_No,
            'streamingMovies_Nointernetservice', streamingMovies_Nointernetservice,
            'streamingMovies_Yes', streamingMovies_Yes,
            'contract_Month-to-month', `contract_Month-to-month`,
            'contract_Oneyear', contract_Oneyear,
            'contract_Twoyear', contract_Twoyear,
            'paperlessBilling_No', paperlessBilling_No,
            'paperlessBilling_Yes', paperlessBilling_Yes,
            'paymentMethod_Banktransfer-automatic', `paymentMethod_Banktransfer-automatic`,
            'paymentMethod_Creditcard-automatic', `paymentMethod_Creditcard-automatic`,
            'paymentMethod_Electroniccheck', paymentMethod_Electroniccheck,
            'paymentMethod_Mailedcheck', paymentMethod_Mailedcheck
        ),
        'FLOAT'
    )
    """)
)

# Row count, output schema, and a small preview of the predictions.
print(results.count())
print(results.columns)
display(results.limit(10))

# Optional: persist the results as a table.
# results.write.mode("overwrite").saveAsTable(f"{MY_CATALOG}.{MY_SCHEMA}.churn_predictions")
