In [0]:
from operator import itemgetter
import mlflow
import os
from typing import Any, Callable, Dict, Generator, List, Optional

from databricks.vector_search.client import VectorSearchClient

from vector_search_utils.self_querying_retriever import load_self_querying_retriever
#from supervisor_utils.decomposer import load_decomposer

from databricks_langchain import ChatDatabricks
from databricks_langchain import DatabricksVectorSearch

from langchain_core.runnables import RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    PromptTemplate,
    ChatPromptTemplate,
    MessagesPlaceholder,
)

from langgraph_supervisor import create_supervisor
from langchain.chat_models import init_chat_model

from langchain_core.runnables import RunnablePassthrough, RunnableBranch
from langchain_core.messages import HumanMessage, AIMessage
from langchain.retrievers.multi_query import MultiQueryRetriever

from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel, Field

from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)

In [0]:
## MLflow tracing for the LangChain components used in this notebook
mlflow.langchain.autolog()

# Chain configuration; the YAML path below is the development-time config.
model_config = mlflow.models.ModelConfig(
    development_config="./configs/rag_agent.yaml"
)

# Top-level sections read from the configuration.
databricks_config = model_config.get("databricks_config")
retriever_config = model_config.get("retriever_config")
llm_config = model_config.get("llm_config")

# "schema" entry of the retriever section.
vector_search_schema = retriever_config.get("schema")

# Chat model used for generation: endpoint name and extra parameters both
# come from the "llm_config" section.
model = ChatDatabricks(
    endpoint=llm_config.get("llm_endpoint_name"),
    extra_params=llm_config.get("llm_parameters"),
)

In [0]:
from langgraph.prebuilt import create_react_agent

# Retriever built by the project helper from the chat model and the
# Databricks / retriever configuration sections loaded above.
sq_retriever = load_self_querying_retriever(model, databricks_config, retriever_config)

# Worker agent: a ReAct agent whose only tool is the retriever.
research_agent = create_react_agent(
    model=model,
    tools=[sq_retriever.as_tool()],
    prompt=(
        "You are a research agent.\n\n"
        "INSTRUCTIONS:\n"
        "- After you're done with your tasks, respond to the supervisor directly\n"
        "- Respond ONLY with the results of your work, do NOT include ANY other text."
    ),
    name="research_agent",
)

# Supervisor graph that routes work to the research agent.
# The prompt must describe exactly the agents passed in `agents`; it previously
# announced "two agents" while only one is registered, which can lead the model
# to try handing off to an agent that does not exist.
supervisor = create_supervisor(
    model=model,
    agents=[research_agent],
    prompt=(
        "You are a supervisor managing the following agent:\n"
        "- An SEC research agent: This agent has access to SEC filings for specific companies. Task is to look up specific questions on an organization's strategy, risks, management decisions, and legal and financial disclosures.\n"
        "- Assign work to one agent at a time, do not call agents in parallel.\n"
        "- Once an agent has responded to you, check if there is more work to do. Call additional agents as needed to answer the question"
    ),
    add_handoff_back_messages=True,
    output_mode="full_history",
).compile()

In [0]:
class SECAgent(ChatAgent):
    """MLflow ``ChatAgent`` adapter around a compiled agent graph.

    The wrapped object is called through ``invoke`` and ``stream`` with a
    ``{"messages": [...]}`` request, and its results are read back from the
    ``"messages"`` key of the returned state.
    """

    def __init__(self, agent):
        """Keep a reference to the graph that answers requests.

        Args:
            agent: Object exposing ``invoke(request)`` and
                ``stream(request, stream_mode=...)``.
        """
        self.agent = agent

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Run the agent to completion and return its final message.

        Args:
            messages: Conversation so far.
            context: Accepted for interface compatibility; not used.
            custom_inputs: Accepted for interface compatibility; not used.

        Returns:
            A response holding a single assistant message built from the last
            entry of the agent's ``"messages"`` list.
        """
        request = {"messages": self._convert_messages_to_dict(messages)}
        agent_response = self.agent.invoke(request)

        final_message = agent_response["messages"][-1]
        response = [
            {
                "role": "assistant",
                "id": final_message.id,
                "content": final_message.content,
            }
        ]

        return ChatAgentResponse(messages=response)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        """Stream the agent's intermediate and final messages.

        Args:
            messages: Conversation so far.
            context: Accepted for interface compatibility; not used.
            custom_inputs: Accepted for interface compatibility; not used.

        Yields:
            One chunk per newly produced, non-empty message, labelled with the
            ``assistant`` role and carrying the originating message id.
        """
        request = {"messages": self._convert_messages_to_dict(messages)}
        last_emitted_id = None

        # Ask for "values" explicitly: compiled state graphs stream "updates"
        # by default, which yields {node_name: update} dicts with no top-level
        # "messages" key, so the lookup below would raise KeyError.
        for event in self.agent.stream(request, stream_mode="values"):
            history = event.get("messages") or []
            if not history:
                continue

            latest = history[-1]

            # The first "values" event is the input state; do not echo the
            # caller's own message back as an assistant chunk.
            if isinstance(latest, HumanMessage):
                continue

            # Skip empty content (e.g. pure tool-call messages).
            if not latest.content:
                continue

            # Consecutive states can end with the same message; emitting it
            # again under the same id would duplicate it for the consumer.
            if latest.id is not None and latest.id == last_emitted_id:
                continue
            last_emitted_id = latest.id

            yield ChatAgentChunk(
                delta=ChatAgentMessage(
                    content=latest.content,
                    role="assistant",
                    id=latest.id,
                )
            )

## Wrap the compiled supervisor graph in the ChatAgent adapter.
AGENT = SECAgent(supervisor)

# Tracing is enabled here as well, right before the agent is registered.
mlflow.langchain.autolog()

# Register the agent object as this file's model for MLflow logging.
mlflow.models.set_model(AGENT)