## Log the Agent to MLflow, validate inference, and deploy to model serving

Install the `uv` library so that we can test the model deployment from within the notebook, using a virtual environment built from this project's `requirements.txt` file.

In [0]:
# uv is the env manager used by the mlflow.models.predict(..., env_manager="uv") validation step below.
%pip install uv
# Restart the Python process after the install.
%restart_python

In [0]:
# Load IPython's autoreload extension; mode 2 picks up edits to imported
# modules without needing to restart the kernel.
%load_ext autoreload
%autoreload 2

Load the compiled LangGraph agent and view its graph. This graph will be wrapped in an MLflow ChatAgent model for deployment to Databricks model serving. The ChatAgent ensures that all chat messages received and returned conform to the MLflow chat message formats, which Databricks model serving requires.

Note that the genie and rag agents can only return control to the supervisor. They cannot directly end the graph's execution, as shown in the image below.

In [0]:
import os
import mlflow
from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex
from databricks.agents import deploy
from IPython.display import Image
from swarm_agent import swarm

# Render the swarm graph as a Mermaid PNG. Passing output_file_path also
# targets swarm_graph.png, the file logged as a run artifact further down.
graph_png = swarm.get_graph().draw_mermaid_png(output_file_path="swarm_graph.png")
display(Image(graph_png))

Load secrets for authentication to the Databricks Genie space

In [0]:
# Read the workspace host and token from the "mlc_credentials" secret scope.
DATABRICKS_HOST = dbutils.secrets.get(scope="mlc_credentials", key="databricks_host")
DATABRICKS_TOKEN = dbutils.secrets.get(scope="mlc_credentials", key="databricks_token")

# Export both credentials as process environment variables in a single step.
os.environ.update(
    {
        "DATABRICKS_HOST": DATABRICKS_HOST,
        "DATABRICKS_TOKEN": DATABRICKS_TOKEN,
    }
)

# Sample chat request in the {"messages": [...]} shape.
messages = {
    "messages": [
        {"role": "user", "content": "What is Apache Spark?"},
    ]
}

Load configuration related to logging

In [0]:
# Load the project configuration from config.yaml.
config = mlflow.models.ModelConfig(development_config="config.yaml")
agents_config = config.get("agents")

# Each agent section is indexed as a list; the first entry holds its settings.
genie_llm = agents_config.get("genie")[0]["llm"]

rag_config = agents_config.get("rag")[0]
rag_llm = rag_config["llm"]
index_location = rag_config["index_location"]

# Both agents may share one LLM endpoint, so de-duplicate the names before
# declaring them as serving-endpoint resources.
llms = list({genie_llm, rag_llm})
serving_endpoints = [
    DatabricksServingEndpoint(endpoint_name=endpoint_name) for endpoint_name in llms
]

# MLflow settings: experiment location, Unity Catalog model name, example request.
mlflow_config = config.get("mlflow")
experiment = mlflow_config["experiment_location"]
uc_model = mlflow_config["uc_model"]
input_example = mlflow_config["input_example"]

# Point MLflow at the experiment and use Unity Catalog as the model registry.
mlflow.set_experiment(experiment)
mlflow.set_registry_uri("databricks-uc")

Test the agent locally, then log it to MLflow

In [0]:
# Smoke-test the compiled graph directly: stream one question and print each
# update as it is produced.
messages = {
    "messages": [
        {"role": "user", "content": "What are our top 3 forecasted raw material shortages?"},
    ]
}

from swarm_agent import swarm

update_stream = swarm.stream(messages, stream_mode="updates")
for event in update_stream:
    print(event)

In [0]:
# Repeat the smoke test through AGENT.predict_stream, using the same request.
# NOTE(review): AGENT is presumably the ChatAgent wrapper around the graph; confirm in swarm_agent.py.
from swarm_agent import AGENT

agent_stream = AGENT.predict_stream(messages)
for event in agent_stream:
    print(event)

In [0]:
# Log the agent as a models-from-code pyfunc model, attach the graph image to
# the run, then reload the logged model and run one prediction as a sanity check.
with mlflow.start_run(run_name="swarm"):
    # Resources declared for the logged model: the vector search index and the
    # LLM serving endpoints collected from the config.
    agent_resources = [
        DatabricksVectorSearchIndex(index_name=index_location),
        *serving_endpoints,
    ]

    model_info = mlflow.pyfunc.log_model(
        python_model="swarm_agent.py",
        streamable=True,
        model_config="config.yaml",
        artifact_path="graph",
        input_example=input_example,
        # See the caveats for code_paths: https://mlflow.org/docs/latest/ml/model/dependencies#caveats-of-code_paths-option
        code_paths=["agents"],
        resources=agent_resources,
        pip_requirements="requirements.txt",
    )

    mlflow.log_artifact("swarm_graph.png")

    model_uri = model_info.model_uri

    # Round-trip check: the logged model must load and answer the example request.
    loaded_app = mlflow.pyfunc.load_model(model_uri)
    loaded_app.predict(input_example)

print(model_uri)

Validate inference within a virtual environment based on the requirements.txt file

In [0]:
# Request used to validate the logged model.
messages = {
    "messages": [
        {"role": "user", "content": "What are our top 3 forecasted raw material shortages?"},
    ]
}

# Run the prediction with the "uv" environment manager rather than in the
# notebook's own Python environment.
mlflow.models.predict(model_uri=model_uri, input_data=messages, env_manager="uv")

Register agent version to Unity Catalog; this is a requirement for model serving

In [0]:
# Register the logged model under the Unity Catalog name from the config.
# NOTE: this rebinds model_info from the log_model result to the
# register_model result; the deploy step reads .version from this object.
version_tags = {"architecture": "swarm", "stage": "production"}
model_info = mlflow.register_model(model_uri, name=uc_model, tags=version_tags)

Deploy the agent to model serving

In [0]:
# Environment variables for the serving endpoint. The {{secrets/...}} values
# are secret references, so the raw credentials are never written here.
serving_env_vars = {
    "DATABRICKS_HOST": "{{secrets/mlc_credentials/databricks_host}}",
    "DATABRICKS_TOKEN": "{{secrets/mlc_credentials/databricks_token}}",
}

deployment_info = deploy(
    model_name=uc_model,
    model_version=model_info.version,
    environment_vars=serving_env_vars,
)

Query the agent endpoint

In [0]:
import os
import requests
import numpy as np
import pandas as pd
import json

# Credentials used to call the serving endpoint over REST.
host = dbutils.secrets.get('mlc_credentials', 'databricks_host')
token = dbutils.secrets.get('mlc_credentials', 'databricks_token')

# Build the invocations URL from the endpoint that deploy() actually created,
# instead of a hard-coded endpoint name that can drift from the configured
# uc_model. Strip a trailing slash from the host so the path has no "//".
url = f"{host.rstrip('/')}/serving-endpoints/{deployment_info.endpoint_name}/invocations"

def score_model(dataset, timeout=300):
    """POST a request payload to the agent serving endpoint and return the parsed reply.

    Uses the notebook-level ``url`` and ``token`` variables.

    Args:
        dataset: JSON-serialisable request body, e.g. ``{"messages": [...]}``.
        timeout: Seconds to wait for the endpoint before giving up. Without a
            timeout ``requests`` waits indefinitely, which can hang the notebook.

    Returns:
        The endpoint's response body decoded from JSON.

    Raises:
        Exception: If the endpoint responds with a status code other than 200.
        requests.exceptions.Timeout: If no response arrives within ``timeout``.
    """
    headers = {'Authorization': f'Bearer {token}', 'Content-Type': 'application/json'}
    data_json = json.dumps(dataset, allow_nan=True)
    response = requests.post(url, headers=headers, data=data_json, timeout=timeout)
    if response.status_code != 200:
        raise Exception(f'Request failed with status {response.status_code}, {response.text}')
    return response.json()

In [0]:
# Display the example request payload read from the "mlflow" section of config.yaml.
input_example

In [0]:
# Send the configured example payload to the serving endpoint.
score_model(input_example)

In [0]:

# Bind the request in this cell. The original bare dict literal was evaluated
# and discarded, so the call relied on `messages` left over from an earlier cell.
messages = {"messages": [{"role": "user", "content": "What are our top 3 forecasted raw material shortages?"}]}
score_model(messages)