# Build and Deploy Supply Chain Agent
This notebook demonstrates how to:
- Author a supply chain agent that uses data access, data analysis and optimization tools.
- Manually test the agent's output.
- Evaluate the agent with Mosaic AI Agent Evaluation.
- Log and deploy the agent.

## Cluster Configuration
This notebook was tested on the following Databricks cluster configuration:
- **Databricks Runtime Version:** 16.4 LTS ML (includes Apache Spark 3.5.2, Scala 2.12)
- **Single Node** 
    - Azure: Standard_DS4_v2 (28 GB Memory, 8 Cores)
    - AWS: m5d.2xlarge (32 GB Memory, 8 Cores)

In [0]:
%pip install -U -qqqq mlflow-skinny[databricks] langgraph==0.3.4 databricks-langchain databricks-agents uv
%pip install -r ../requirements.txt --quiet
dbutils.library.restartPython()

## Define the agent in code
The cell below defines the agent in a single cell and uses the `%%writefile` magic command to write it to a local Python file (`supply_chain_agent.py`), which is later logged and deployed as code.

For more examples of tools to add to your agent, see [docs](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/agent-tool).
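The cell below wires the model and tools into a LangGraph loop. The `agent` node calls the LLM. `routing_logic` sends the conversation to the `tools` node whenever the last message contains tool calls. The tool results then flow back to the `agent` node until the model answers without requesting a tool.

Here is a rough, framework-free sketch of that control flow. It makes a few simplifying assumptions:
- It uses plain dicts instead of LangChain messages.
- A scripted `fake_model` stands in for the LLM.
- `max_steps` is a hypothetical safeguard.

This is only an illustration of the routing idea, not how LangGraph is implemented internally.

```python
from typing import Any, Callable

# Hypothetical message shape for this sketch: plain dicts with "role", "content",
# and, for assistant messages that request tools, a "tool_calls" list.
Message = dict[str, Any]


def routing_logic(messages: list[Message]) -> str:
    """Same rule as the notebook: go to the tools while the model requests tool calls."""
    return "continue" if messages[-1].get("tool_calls") else "end"


def run_agent_loop(
    call_model: Callable[[list[Message]], Message],
    tools: dict[str, Callable[..., Any]],
    messages: list[Message],
    max_steps: int = 10,
) -> list[Message]:
    """Alternate between the "agent" step (model call) and the "tools" step until the model stops."""
    for _ in range(max_steps):
        # "agent" node: the model sees the full history and appends one message.
        messages = messages + [call_model(messages)]
        if routing_logic(messages) == "end":
            return messages
        # "tools" node: run each requested tool and append its result as a tool message.
        for call in messages[-1]["tool_calls"]:
            result = tools[call["name"]](**call["args"])
            messages = messages + [
                {"role": "tool", "content": str(result), "tool_call_id": call["id"]}
            ]
    raise RuntimeError(f"Agent did not finish within {max_steps} steps")


# Usage with a scripted stand-in for the LLM (hypothetical, for illustration only).
def fake_model(messages: list[Message]) -> Message:
    if messages[-1]["role"] == "user":
        return {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {"id": "call_1", "name": "optimization_tool", "args": {"disrupted": ["T2_8"], "ttr": 9.0}}
            ],
        }
    return {"role": "assistant", "content": "Summary: " + messages[-1]["content"]}


fake_tools = {"optimization_tool": lambda disrupted, ttr: f"disrupted={disrupted},ttr={ttr}"}
history = run_agent_loop(fake_model, fake_tools, [{"role": "user", "content": "What if T2_8 is disrupted?"}])
print(history[-1]["content"])  # prints: Summary: disrupted=['T2_8'],ttr=9.0
```

The resulting history has the same user → assistant (tool call) → tool → assistant shape that the real graph produces for a single-tool question.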

In [0]:
%%writefile supply_chain_agent.py
from typing import Any, Optional, Sequence, Union
import os
import json
import mlflow
import pandas as pd
from databricks_langchain import ChatDatabricks
from databricks_langchain import (
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
from databricks.sdk import WorkspaceClient
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool, tool
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.models import ModelConfig
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)

import random, string, math
from itertools import product
from collections import defaultdict
import pyomo.environ as pyo
import scripts.utils as utils


mlflow.langchain.autolog()

############################################
# Define your LLM endpoint and system prompt
############################################
LLM_ENDPOINT = "gpt-5"

config = {
    "endpoint_name": LLM_ENDPOINT,
    "catalog": "supply_chain_stress_test",
    "database": "data",
    "volume": "operational",
    "temperature": 0.01,
    "max_tokens": 10000,
    "system_prompt": """
    You are a helpful assistant that answers questions about a supply chain network. Questions outside this topic are considered irrelevant. You use a set of tools to provide answers and, if needed, ask the user follow-up questions to clarify their request. You may need to execute multiple tools in sequence to build up the final answer. When you receive a request, first plan the steps carefully and then execute them.

    When interpreting the output of the optimization tool, use the following definitions of the parameters, decision variables, and metrics.

    - Below are the definitions of the important metrics:
    Metrics               | What it represents                                                                              |
    TTR                   | Stands for time to recover. Recovery time for a node or group after a disruption.               |
    TTS                   | Stands for time to survive. TTS indicates how long the network can meet demand with no loss.    |

    - Below are the definitions of the model parameters:
    Parameters            | What it represents                                                                                      | 
    tier1 / tier2 / tier3 | Lists of node IDs in each tier.                                                                         |
    edges                 | Directed links `(source, target)` showing which node supplies which.                                    |
    material_type         | List of all material types.                                                                             |
    supplier_material_type| Material type each supplier produces and supplies.                                                      |
    f                     | Profit margin for each Tier 1 node's finished product.                                                  |
    s                     | On-hand inventory units at every node.                                                                  |
    d                     | Demand per time unit for Tier 1 products.                                                               |
    c                     | Production capacity per time unit at each node.                                                         |
    r                     | Number of material types (k) required to make one unit of node j.                                       |
    N_minus               | For each node j (Tier 1 or 2), the set of material types it requires.                                   |
    N_plus                | For each supplier i (Tier 2 or 3), the set of downstream nodes j it feeds.                              |
    P                     | For each (j, material_part) pair, a list of upstream suppliers i that provide it (multi-sourcing view). |
    
    - Below are the definitions of decision variables:
    Decision Variables    | What it represents                                                                                 | 
    l                     | Lost production volume of the node's product.                                                     |
    u                     | Total production volume of the node during the TTR.                                                |
    y                     | Allocation from an upstream node to a downstream node during the TTR.                              |
    
    Report the profit loss during the recovery period. When giving recommendations, compare the optimized network with and without the disruption and base them on differences in the decision variables. Include detailed action plans, a summary of the best actions for this scenario, and precise numbers whenever possible. Finally, the users of this tool are business analysts, so keep the language simple and avoid technical terms.
    """,
}

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/agent-tool
###############################################################################
@tool
def data_access_tool(
    catalog: str = config["catalog"],
    database: str = config["database"],
    volume: str = config["volume"],
) -> dict:
    """
    Accesses the supply chain dataset stored in the specified Unity Catalog Volume and returns the data.

    Parameters:
    catalog (str): Catalog name for Unity Catalog.
    database (str): Database (schema) name in Unity Catalog.
    volume (str): Volume name in Unity Catalog.

    Returns:
    dict: The parsed JSON dataset.
    """
    # Get the operational data from Unity Catalog Volume
    w = WorkspaceClient(host=os.getenv("HOST"), token=os.getenv("TOKEN"))
    path = f'/Volumes/{catalog}/{database}/{volume}/dataset_small.json'
    resp = w.files.download(path)
    with resp.contents as fh:
        dataset = json.load(fh)

    return dataset

def _variable_records(model) -> list[dict]:
    """Collect the name, index, and solved value of every active Pyomo variable in a model."""
    return [
        {
            "var_name": v.parent_component().name,
            "index": v.index(),
            "value": pyo.value(v),
        }
        for v in model.component_data_objects(ctype=pyo.Var, active=True)
    ]


@tool
def optimization_tool(
    disrupted: list[str],
    ttr: float,
) -> str:
    """
    Runs optimization algorithms for a given disruption scenario and returns the results as a string. This function first solves the time-to-recover (TTR) problem without the disruption, then solves it with the disruption, and finally computes the time to survive (TTS) under the disruption. The returned string contains the results of all three optimizations.

    Parameters:
    disrupted (list[str]): List of disrupted nodes in the scenario.
    ttr (float): Time to recover (TTR) for the disruption scenario.

    Returns:
    str: A string representation of the optimization results.
    """
    # Get the operational data from the Unity Catalog Volume
    dataset = data_access_tool.func()

    # Solve the TTR model without disruption (baseline network)
    df_normal = utils.build_and_solve_ttr(dataset, [], ttr, True)
    records_normal = _variable_records(df_normal["model"].values[0])

    # Solve the TTR model with the disruption
    df_disrupted = utils.build_and_solve_ttr(dataset, disrupted, ttr, True)
    records_disrupted = _variable_records(df_disrupted["model"].values[0])

    # Replace the Pyomo model object with the decision-variable values of both runs
    df_disrupted = df_disrupted.drop(["model"], axis=1)
    df_disrupted["optimized_network_without_disruption"] = str(records_normal)
    df_disrupted["optimized_network_with_disruption"] = str(records_disrupted)

    # Solve the TTS model with the disruption
    df_tts = utils.build_and_solve_tts(dataset, disrupted, False)
    df_disrupted["tts"] = df_tts["tts"].values[0]

    # Serialize the single result row as comma-separated "column=value" pairs
    row_str = ",".join(f"{k}={v}" for k, v in df_disrupted.iloc[0].astype(str).items())

    return row_str


def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[ToolNode, Sequence[BaseTool]],
    agent_prompt: Optional[str] = None,
) -> CompiledGraph:
    model = model.bind_tools(tools)

    def routing_logic(state: ChatAgentState):
        last_message = state["messages"][-1]
        if last_message.get("tool_calls"):
            return "continue"
        else:
            return "end"

    if agent_prompt:
        system_message = {"role": "system", "content": agent_prompt}
        preprocessor = RunnableLambda(
            lambda state: [system_message] + state["messages"]
        )
    else:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    model_runnable = preprocessor | model

    def call_model(
        state: ChatAgentState,
        config: RunnableConfig,
    ):
        response = model_runnable.invoke(state, config)

        return {"messages": [response]}

    workflow = StateGraph(ChatAgentState)
    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ChatAgentToolNode(tools))
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        routing_logic,
        {
            "continue": "tools",
            "end": END,
        },
    )
    workflow.add_edge("tools", "agent")

    return workflow.compile()


class SupplyChainAgent(ChatAgent):
    def __init__(self, config, tools):
        # Load config
        # When this agent is deployed to Model Serving, the configuration loaded here is replaced with the config passed to mlflow.pyfunc.log_model(model_config=...)
        self.config = ModelConfig(development_config=config)
        self.tools = tools
        self.agent = self._build_agent_from_config()

    def _build_agent_from_config(self):
        llm = ChatDatabricks(
            endpoint=self.config.get("endpoint_name"),
            #temperature=self.config.get("temperature"),
            max_tokens=self.config.get("max_tokens"),
        )
        agent = create_tool_calling_agent(
            llm,
            tools=self.tools,
            agent_prompt=self.config.get("system_prompt"),
        )
        return agent

    @mlflow.trace(name="SupplyChainAgent", span_type=mlflow.entities.SpanType.AGENT)
    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        # ChatAgent has a built-in helper method to help convert framework-specific messages, like langchain BaseMessage to a python dictionary
        request = {"messages": self._convert_messages_to_dict(messages)}

        output = self.agent.invoke(request)
        # 'output' is the final graph state, a dict whose "messages" list matches the ChatAgentResponse schema.
        # Wrapping it in a new ChatAgentResponse makes the ChatAgent return type explicit.
        return ChatAgentResponse(**output)
    
tools = [data_access_tool, optimization_tool]

AGENT = SupplyChainAgent(config, tools)
mlflow.models.set_model(AGENT)

## Test the agent

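Interact with the agent to test its output. First, restart Python so that the newly written `supply_chain_agent.py` is imported fresh rather than from a cached module.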
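`AGENT.predict` takes a dict with a `messages` list, as in the cells below. It returns the conversation, which typically includes the intermediate tool-call and tool-result messages as well as the final reply.

The hypothetical helpers below build a request and pull out the final assistant reply. They assume the response has already been turned into a plain dict, as the JSON returned by a serving endpoint is. The example response is hand-written, not real agent output.

```python
from typing import Any


def build_request(question: str) -> dict[str, Any]:
    """Wrap one user question in the {"messages": [...]} shape passed to AGENT.predict."""
    return {"messages": [{"role": "user", "content": question}]}


def final_answer(response: dict[str, Any]) -> str:
    """Return the content of the last non-empty assistant message in a response dict."""
    for message in reversed(response["messages"]):
        if message.get("role") == "assistant" and message.get("content"):
            return message["content"]
    raise ValueError("response contains no assistant message with content")


# Example with a hand-written response dict (not real agent output).
example_response = {
    "messages": [
        {"role": "assistant", "content": "", "tool_calls": [{"id": "call_1"}]},
        {"role": "tool", "content": "<optimization results>"},
        {"role": "assistant", "content": "<final answer>"},
    ]
}
print(final_answer(example_response))  # prints: <final answer>
```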

In [0]:
dbutils.library.restartPython()

Define the environment variables (`HOST` and `TOKEN`) that the agent's `data_access_tool` uses to authenticate with the workspace.
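`data_access_tool` reads `HOST` and `TOKEN` from the environment and downloads `dataset_small.json` from `/Volumes/<catalog>/<database>/<volume>/`. Checking these inputs before calling the agent gives a clearer error than a failed download.

Here is a small sketch of such a check. The helper names are hypothetical and are not part of the notebook.

```python
import os
from typing import Mapping, Optional


def volume_file_path(catalog: str, database: str, volume: str, filename: str = "dataset_small.json") -> str:
    """Build the Unity Catalog Volume path that data_access_tool downloads."""
    return f"/Volumes/{catalog}/{database}/{volume}/{filename}"


def missing_auth_vars(env: Optional[Mapping[str, str]] = None) -> list[str]:
    """Return which of the HOST/TOKEN variables the agent reads are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in ("HOST", "TOKEN") if not env.get(name)]


print(volume_file_path("supply_chain_stress_test", "data", "operational"))
# prints: /Volumes/supply_chain_stress_test/data/operational/dataset_small.json
print(missing_auth_vars({"HOST": "https://example.cloud.databricks.com"}))
# prints: ['TOKEN']
```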

In [0]:
import os
import mlflow
from dbruntime.databricks_repl_context import get_context

user = spark.sql('select current_user() as user').collect()[0]['user'] # User email address

# TODO: set WORKSPACE_URL manually if it cannot be inferred from the current notebook
WORKSPACE_URL = None
if WORKSPACE_URL is None:
    workspace_url_hostname = get_context().browserHostName
    assert workspace_url_hostname is not None, "Unable to look up the current workspace URL. This can happen when running on serverless compute. Set WORKSPACE_URL manually above, or run this notebook on classic compute."
    WORKSPACE_URL = f"https://{workspace_url_hostname}"

# TODO: set secret_scope and secret_key to the secret scope and key that store your personal access token (PAT)
secret_scope = "ryuta"
secret_key = "token"

os.environ["HOST"] = WORKSPACE_URL
os.environ["TOKEN"] = dbutils.secrets.get(scope=secret_scope, key=secret_key)

In [0]:
from supply_chain_agent import AGENT

AGENT.predict({"messages": [{"role": "user", "content": "List all downstream sites for the raw material supplied by T3_10, and include any related information about these sites. Visualize the results clearly."}]})

In [0]:
AGENT.predict({"messages": [{"role": "user", "content": "Tell me what happens if T2_8 is disrupted and requires 9 weeks to recover. What should I do?"}]})

## Log the `agent` as an MLflow model
Log the agent as code from the `supply_chain_agent.py` file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code).
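`code_paths` must include the `scripts` folder that `supply_chain_agent.py` imports as `scripts.utils`. Deriving that folder by substituting `"agent"` with `"scripts"` in the working directory only works when `agent` appears exactly once in the path, as the last folder name.

The sketch below compares that substitution with taking the parent folder and appending `scripts`. Both example paths are hypothetical. The layout, with the notebook's folder sitting next to a `scripts` folder, is an assumption.

```python
from pathlib import PurePosixPath


def replace_based_code_path(cwd: str) -> str:
    """Substitute every "agent" substring in the path with "scripts"."""
    return cwd.replace("agent", "scripts")


def sibling_code_path(cwd: str, folder: str = "scripts") -> str:
    """Point at a folder that sits next to the current one."""
    return str(PurePosixPath(cwd).parent / folder)


for cwd in ["/Workspace/Users/me/supply-chain/agent", "/Workspace/Users/me/agentic-demo/agent"]:
    print(replace_based_code_path(cwd), "|", sibling_code_path(cwd))
# prints:
# /Workspace/Users/me/supply-chain/scripts | /Workspace/Users/me/supply-chain/scripts
# /Workspace/Users/me/scriptsic-demo/scripts | /Workspace/Users/me/agentic-demo/scripts
```

In the notebook itself, the parent-folder approach corresponds to `os.path.join(os.path.dirname(os.getcwd()), "scripts")`.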

In [0]:
import os
import mlflow
from supply_chain_agent import LLM_ENDPOINT, config, tools
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint
from unitycatalog.ai.langchain.toolkit import UnityCatalogTool

resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT)]
for tool in tools:
    if isinstance(tool, UnityCatalogTool):
        resources.append(DatabricksFunction(function_name=tool.uc_function_name))

# The "scripts" helper package sits next to this notebook's folder
code_path = os.path.join(os.path.dirname(os.getcwd()), "scripts")

with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        python_model="supply_chain_agent.py",
        name="agent",
        model_config=config,
        resources=resources,
        pip_requirements=[
            "mlflow",
            "langchain",
            "langgraph==0.3.4",
            "databricks-langchain",
            "unitycatalog-langchain[databricks]",
            "pydantic",
            "-r ../requirements.txt",
        ],
        code_paths=[code_path],
        input_example={
            "messages": [{"role": "user", "content": "Tell me what happens if T2_8 is disrupted and requires 9 weeks to recover to its normal state, and what to do."}]
        },
    )

## Evaluate the agent with [Agent Evaluation](https://learn.microsoft.com/azure/databricks/mlflow3/genai/eval-monitor)

You can edit the requests or expected responses in your evaluation dataset and rerun the evaluation as you iterate on your agent, using MLflow to track the computed quality metrics.

Evaluate your agent with one of our [predefined LLM scorers](https://learn.microsoft.com/azure/databricks/mlflow3/genai/eval-monitor/predefined-judge-scorers), or try adding [custom metrics](https://learn.microsoft.com/azure/databricks/mlflow3/genai/eval-monitor/custom-scorers).
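The system prompt asks the agent to report the profit loss with precise numbers. Below is a crude, hypothetical heuristic for that requirement, written as a plain function. The example strings are made up. The sketch assumes `predict_fn` returns the final answer as a string.

```python
import re


def reports_profit_loss(outputs: str) -> bool:
    """Crude check: the answer talks about profit and quotes at least one number."""
    text = outputs.lower()
    return "profit" in text and re.search(r"\d", text) is not None


print(reports_profit_loss("Expected profit loss over the 9-week recovery: $120,000."))  # True
print(reports_profit_loss("Supplier T2_8 is disrupted."))  # False: no mention of profit
```

To use it in the evaluation cell, decorate it with `@scorer` (from `mlflow.genai.scorers` in MLflow 3) and add it to the `scorers` list. An LLM-based judge is likely more robust than keyword matching.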

In [0]:
'''
import mlflow
from mlflow.genai.scorers import RelevanceToQuery, Safety
from supply_chain_agent import AGENT

eval_dataset = [
    {
        "inputs": {
            "messages": [
                {
                    "role": "user",
                    "content": "Tell me what happens if T2_8 is disrupted and requires 9 weeks to recover. What should I do?"
                }
            ]
        }
    }
]

eval_results = mlflow.genai.evaluate(
    data=eval_dataset,
    # Each "inputs" dict is passed to predict_fn as keyword arguments; return the final answer text
    predict_fn=lambda messages: AGENT.predict({"messages": messages}).messages[-1].content,
    scorers=[RelevanceToQuery(), Safety()],  # add more scorers here if they're applicable
)

# Review the evaluation results in the MLflow UI (see console output)
'''

## Register the agent to Unity Catalog

Update the `catalog`, `schema`, and `agent_name` below to register the MLflow model to Unity Catalog.
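The cell below interpolates `catalog` and `schema` directly into SQL. Plain names such as `supply_chain_stress_test` work as is. A name containing a hyphen or another special character, however, needs backtick quoting in Databricks SQL.

A small sketch of that quoting follows. The hyphenated name is hypothetical, and the quoting only concerns the SQL statements.

```python
def quote_identifier(name: str) -> str:
    """Wrap a SQL identifier in backticks, doubling any backticks inside it."""
    return "`" + name.replace("`", "``") + "`"


def create_statements(catalog: str, schema: str) -> list[str]:
    """SQL used to make sure the target catalog and schema exist."""
    return [
        f"CREATE CATALOG IF NOT EXISTS {quote_identifier(catalog)}",
        f"CREATE SCHEMA IF NOT EXISTS {quote_identifier(catalog)}.{quote_identifier(schema)}",
    ]


for statement in create_statements("supply-chain-stress-test", "agents"):
    print(statement)  # each statement would be passed to spark.sql(...)
# prints:
# CREATE CATALOG IF NOT EXISTS `supply-chain-stress-test`
# CREATE SCHEMA IF NOT EXISTS `supply-chain-stress-test`.`agents`
```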

In [0]:
import mlflow
from databricks import agents

# Connect to the Unity catalog model registry
mlflow.set_registry_uri("databricks-uc")

catalog = "supply_chain_stress_test"    # Change here
schema = "agents"                       # Change here
agent_name = "supply_chain_agent"       # Change here

# Make sure that the catalog and the schema exist
_ = spark.sql(f"CREATE CATALOG IF NOT EXISTS {catalog}") 
_ = spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{schema}") 

MODEL_NAME = f"{catalog}.{schema}.{agent_name}"

# Register to Unity catalog
uc_registered_model_info = mlflow.register_model(model_uri=model_info.model_uri, name=MODEL_NAME)

## Deploy the agent
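`agents.deploy` passes `HOST` and `TOKEN` to the serving endpoint as environment variables. The `TOKEN` value is a Databricks secret reference of the form `{{secrets/<scope>/<key>}}`, which Model Serving resolves at serving time, so the token itself is not stored in the endpoint configuration.

Because `{` and `}` must be doubled inside an f-string, the reference in the cell below needs quadruple braces. Building the reference by concatenation, as in this sketch, produces the same string. The helper names are hypothetical.

```python
def secret_reference(scope: str, key: str) -> str:
    """Model Serving secret reference: {{secrets/<scope>/<key>}}."""
    return "{{secrets/" + scope + "/" + key + "}}"


def serving_environment(workspace_url: str, scope: str, key: str) -> dict[str, str]:
    """Environment variables read by data_access_tool inside the serving container."""
    return {"HOST": workspace_url, "TOKEN": secret_reference(scope, key)}


scope, key = "ryuta", "token"
env = serving_environment("https://example.cloud.databricks.com", scope, key)
print(env["TOKEN"])  # prints: {{secrets/ryuta/token}}
# The notebook's doubled-brace f-string produces the same value.
assert env["TOKEN"] == f"{{{{secrets/{scope}/{key}}}}}"
```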

In [0]:
# Deploy to enable the review app and create an API endpoint
deployment_info = agents.deploy(
    MODEL_NAME, 
    uc_registered_model_info.version,
    environment_vars={
        "HOST": f"{WORKSPACE_URL}",
        "TOKEN": f"{{{{secrets/{secret_scope}/{secret_key}}}}}",
    },
)

## Next steps

After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with subject matter experts (SMEs) in your organization for feedback, or embed it in a production application. See [docs](https://learn.microsoft.com/azure/databricks/generative-ai/deploy-agent) for details.
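Once deployed, the endpoint can also be called over REST at `<workspace-url>/serving-endpoints/<endpoint-name>/invocations`, with the same `{"messages": [...]}` payload used in this notebook.

Here is a minimal sketch that only builds the request. The workspace URL and endpoint name are placeholders; look up the real endpoint name in the Serving UI.

```python
import json


def invocation_request(workspace_url: str, endpoint_name: str, question: str) -> tuple[str, bytes]:
    """Return the REST URL and JSON body for one chat request to a serving endpoint."""
    url = f"{workspace_url.rstrip('/')}/serving-endpoints/{endpoint_name}/invocations"
    body = json.dumps({"messages": [{"role": "user", "content": question}]}).encode("utf-8")
    return url, body


url, body = invocation_request(
    "https://example.cloud.databricks.com/",  # hypothetical workspace URL
    "my-agent-endpoint",                      # hypothetical endpoint name
    "Tell me what happens if T2_8 is disrupted and requires 9 weeks to recover.",
)
print(url)  # prints: https://example.cloud.databricks.com/serving-endpoints/my-agent-endpoint/invocations

# To send it (not run here), POST the body with a bearer token, for example:
# import urllib.request
# req = urllib.request.Request(url, data=body, method="POST",
#     headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"})
# with urllib.request.urlopen(req) as resp:
#     answer = json.load(resp)
```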