In [0]:
from operator import itemgetter
import mlflow
import os
from typing import Any, Callable, Dict, Generator, List, Optional

from databricks.vector_search.client import VectorSearchClient

from vector_search_utils.self_querying_retriever import load_self_querying_retriever
#from supervisor_utils.decomposer import load_decomposer

from databricks_langchain import ChatDatabricks
from databricks_langchain import DatabricksVectorSearch

from langchain_core.runnables import RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    PromptTemplate,
    ChatPromptTemplate,
    MessagesPlaceholder,
)

from langgraph_supervisor import create_supervisor
from langchain.chat_models import init_chat_model

from langchain_core.runnables import RunnablePassthrough, RunnableBranch
from langchain_core.messages import HumanMessage, AIMessage
from langchain.retrievers.multi_query import MultiQueryRetriever

from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel, Field

from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)

In [0]:
## Enable MLflow Tracing

mlflow.set_tracking_uri("databricks")
mlflow.langchain.autolog()

# Load the chain's configuration
model_config = mlflow.models.ModelConfig(development_config="./configs/agent.yaml")

databricks_config = model_config.get("databricks_config")

doc_agent_config = model_config.get("doc_agent_config")
genie_agent_config = model_config.get("genie_agent_config")
supervisor_config = model_config.get("supervisor_agent_config")

retriever_config = model_config.get("retriever_config")
vector_search_schema = retriever_config.get("schema")


def _chat_model_from_config(agent_config: Dict[str, Any]) -> ChatDatabricks:
    """Build a ChatDatabricks client from one agent section of the model config.

    Args:
        agent_config: An agent section (e.g. ``doc_agent_config``) whose
            ``llm_config`` entry is expected to hold ``llm_endpoint_name`` and,
            optionally, ``llm_parameters``.

    Returns:
        A ChatDatabricks bound to ``llm_endpoint_name``, with
        ``llm_parameters`` passed through as ``extra_params``.

    Raises:
        ValueError: if the section has no (or an empty) ``llm_config`` entry.
    """
    llm_config = agent_config.get("llm_config")
    if not llm_config:
        # Fail with a readable message instead of AttributeError on None.
        raise ValueError("agent config is missing an 'llm_config' section")
    return ChatDatabricks(
        endpoint=llm_config.get("llm_endpoint_name"),
        extra_params=llm_config.get("llm_parameters"),
    )


doc_retrieval_model = _chat_model_from_config(doc_agent_config)
genie_agent_model = _chat_model_from_config(genie_agent_config)
supervisor_model = _chat_model_from_config(supervisor_config)

In [0]:
from databricks.sdk import WorkspaceClient
from databricks_langchain.genie import GenieAgent

from langgraph.prebuilt import create_react_agent
from supervisor_utils.handoff_tools import create_handoff_tool

# Self-querying retriever over the document vector index, driven by the doc LLM.
sq_retriever = load_self_querying_retriever(
    doc_retrieval_model,
    databricks_config,
    retriever_config,
)

# ReAct worker whose only tool is the retriever above.
# NOTE(review): as_tool() is called without name/description, so the tool name
# the LLM sees is presumably derived from the retriever runnable — consider
# setting both explicitly.
research_agent = create_react_agent(
    name="research_agent",
    model=doc_retrieval_model,
    prompt=doc_agent_config.get("prompt"),
    tools=[sq_retriever.as_tool()],
)

# Genie agent
class GenieInput(BaseModel):
    """OpenAI-format chat history for Genie."""

    # Required (no default); each entry is one chat message as a dict.
    messages: List[Dict[str, Any]] = Field(
        default=...,
        description="Chat messages in {'role','content'} format",
    )

# Genie space exposed as a LangChain tool; its arguments are validated by GenieInput.
genie_agent_tool = GenieAgent(
    genie_space_id=genie_agent_config.get("genie_space_id"),
    genie_agent_name="sec_metrics_analyst",
    description=genie_agent_config.get("genie_space_description"),
    client=WorkspaceClient(),
).as_tool(
    name="sec_metrics_analyst",
    description="Ask Genie questions about public-company filings and metrics",
    args_schema=GenieInput,
)

# ReAct worker that answers metric questions through the Genie tool.
genie_agent = create_react_agent(
    name="sec_metrics_analyst",
    model=genie_agent_model,
    prompt=genie_agent_config.get("prompt"),
    tools=[genie_agent_tool],
)

# Handoff tools the supervisor calls to route work; each agent_name matches a
# worker agent's `name` above.
assign_to_research_agent, assign_to_analyst_agent = (
    create_handoff_tool(agent_name=target, description=summary)
    for target, summary in (
        ("research_agent", "Assign task to the doc researcher agent."),
        ("sec_metrics_analyst", "Assign task to the sec metrics agent."),
    )
)

# Supervisor ReAct agent whose only tools are the two handoffs.
supervisor_agent = create_react_agent(
    name="supervisor",
    model=supervisor_model,
    prompt=supervisor_config.get("prompt"),
    tools=[assign_to_research_agent, assign_to_analyst_agent],
)

In [0]:
from typing import Annotated
from langchain_core.tools import tool, InjectedToolCallId
from langgraph.prebuilt import InjectedState
from langgraph.graph import StateGraph, START, MessagesState, END
from langgraph.types import Command

def _build_supervisor_graph():
    """Wire the supervisor and both worker agents into a compiled graph.

    Execution starts at the supervisor node, and each worker has an edge
    back to the supervisor.
    """
    builder = StateGraph(MessagesState)
    # NOTE: `destinations` is only needed for visualization and doesn't affect runtime behavior
    builder.add_node(supervisor_agent, destinations=("research_agent", "sec_metrics_analyst", END))
    builder.add_node(research_agent)
    builder.add_node(genie_agent)
    builder.add_edge(START, "supervisor")
    # always return back to the supervisor
    for worker_name in ("research_agent", "sec_metrics_analyst"):
        builder.add_edge(worker_name, "supervisor")
    return builder.compile()


# Define the multi-agent supervisor graph
supervisor = _build_supervisor_graph()

In [0]:
from mlflow.entities import SpanType


class SECAgent(ChatAgent):
    """MLflow ``ChatAgent`` that serves a compiled LangGraph multi-agent graph.

    ``predict`` runs the graph to completion and returns only the last message
    with non-empty content. ``predict_stream`` yields one chunk for every node
    update whose newest message has non-empty content.
    """

    def __init__(self, agent):
        # Compiled graph; only its ``invoke`` and ``stream`` methods are used.
        self.agent = agent

    @staticmethod
    def _build_request(messages) -> dict:
        """Convert incoming chat messages into the graph's input state.

        ``ChatAgentMessage`` objects are serialized with ``model_dump``, with
        ``None`` fields dropped. Plain dict messages are passed through unchanged.

        Raises:
            ValueError: if ``messages`` is empty.
        """
        if not messages:
            raise ValueError("SECAgent requires at least one input message")
        return {
            "messages": [
                message if isinstance(message, dict) else message.model_dump(exclude_none=True)
                for message in messages
            ]
        }

    def get_last_valid_message(self, messages: list):
        """Return the most recent message whose content is non-empty.

        Returns ``None`` when ``messages`` is empty or every message is empty.
        """
        for message in reversed(messages):
            if message.content:
                return message
        return None

    @mlflow.trace(span_type=SpanType.AGENT)
    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Run the graph to completion and return its final non-empty message.

        ``context`` and ``custom_inputs`` are accepted for the ``ChatAgent``
        interface but are not used.

        Raises:
            ValueError: if ``messages`` is empty, or if the graph produced no
                message with non-empty content.
        """
        agent_response = self.agent.invoke(self._build_request(messages))

        last_valid_agent_message = self.get_last_valid_message(agent_response["messages"])
        if last_valid_agent_message is None:
            raise ValueError("Agent returned no message with non-empty content")

        response = [
            {
                "role": "assistant",
                "id": last_valid_agent_message.id,
                "content": last_valid_agent_message.content,
            }
        ]

        return ChatAgentResponse(messages=response)

    @mlflow.trace(span_type=SpanType.AGENT)
    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        """Stream the graph, yielding the newest non-empty message of each node update.

        ``context`` and ``custom_inputs`` are accepted for the ``ChatAgent``
        interface but are not used.

        Raises:
            ValueError: if ``messages`` is empty (raised on first iteration).
        """
        request = self._build_request(messages)

        # Request per-node updates explicitly. In this mode each event is
        # ``{node_name: state_update}``, so the messages sit one level below
        # the event rather than at ``event["messages"]``.
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_update in event.values():
                # A node may emit no update, or an update that has no messages.
                if not isinstance(node_update, dict):
                    continue
                node_messages = node_update.get("messages") or []
                if not node_messages:
                    continue
                latest = node_messages[-1]
                if latest.content:
                    yield ChatAgentChunk(
                        delta=ChatAgentMessage(
                            content=latest.content,
                            role="assistant",
                            id=latest.id,
                        )
                    )


## Tell MLflow logging where to find your chain.
# mlflow.langchain.autolog() is already called in the configuration cell, so
# it is not repeated here.
AGENT = SECAgent(supervisor)

mlflow.models.set_model(AGENT)