In [0]:
# Bootstrap the notebook environment: install the `uv` tool with pip, run
# `uv sync`, then restart the Python process.
!pip install uv --quiet
# NOTE(review): `--active` presumably makes uv sync into the currently active
# environment instead of a project-local one — confirm against the uv docs.
!uv sync --active --quiet
# Restart Python via the Databricks utility; presumably so that the packages
# installed above are visible to the cells that follow — confirm.
dbutils.library.restartPython()

In [0]:
# %run ../tools/metric_calculator.py

In [0]:
%run ../tools/database_searcher.py

In [0]:
%run ../tools/web_news_searcher.py

In [0]:
%run ../agent_config/callback_handler.py

In [0]:
import os
import mlflow
import toml
import pyspark.sql.functions as F
from databricks_langchain import ChatDatabricks
import datetime
from typing import Any, Generator, Optional, Sequence, Union
from databricks_langchain import ChatDatabricks
from langchain_core.tools import BaseTool, tool
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.language_models import LanguageModelLike
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import ChatAgentMessage, ChatAgentResponse, ChatAgentChunk, ChatContext
from mlflow.types.llm import ChatCompletionResponse, ChatChoice, ChatMessage, ChatCompletionChunk, ChatChunkChoice, ChatChoiceDelta
from langchain_core.tools import StructuredTool

In [0]:
# Load the shared settings file into a dict. `LLM_ENDPOINT_NAME` is read from it
# when the LLM client is created. The path is relative, so it is resolved
# against the notebook's current working directory.
env_vars = toml.load("../../conf/env_vars.toml")

In [0]:
# Endpoint name handed to ChatDatabricks, taken from the loaded settings.
LLM_ENDPOINT_NAME = env_vars["LLM_ENDPOINT_NAME"]

# Callback handlers attached to the chat model.
callbacks = [LoggingHandler()]

# Chat model used by the agent.
llm = ChatDatabricks(
    callbacks=callbacks,
    endpoint=LLM_ENDPOINT_NAME,
)

# system_prompt = "Você é um analista de dados em saúde que consulta dados do SUS (Sistema Único de Saúde do Brasil) sobre síndrome respiratória grave e retorna um relatório diário sobre a situação da doença trazendo métricas relevantes e as respectivas explicações que ajudem a explicar o cenário atual."

In [0]:
# System prompt (in Portuguese) sent as the system message on every model call.
# It frames the agent as a health-data analyst reporting on SRAG with SUS data
# and explains how to use the tools.
#
# The prompt must only advertise tools that are actually registered in `tools`:
# `get_month_mortality_rate` is currently commented out and not registered, so
# it is not mentioned here. The stray double quotes left over from the former
# single-line string were removed as well.
#
# NOTE(review): the prompt calls the tools `database_searcher` and
# `search_query` — confirm these match the names exposed by the tool objects.
system_prompt = """
Você é um analista de dados em saúde que consulta dados do SUS (Sistema Único de Saúde do Brasil) sobre síndrome respiratória grave e retorna um relatório diário sobre a situação da doença trazendo métricas relevantes e as respectivas explicações que ajudem a explicar o cenário atual.

Para resolver a tarefa você tem acesso às ferramentas:

Caso não houver uma ferramenta específica para calcular uma métrica, você usar a ferramenta database_searcher para buscar informações no banco de dados de SRAG. Você pode fazer a pergunta diretamente ao agente database_searcher, conferir as suas repostas e logs, e se necessário fazer queries de READ em SQL na tabela srag_features.

Utilize a tool search_query para fazer uma busca na internet sobre o cenário atual da Sindrome Respiratória Aguda Grave a fim de contextualizar o resultado das métricas encontradas (se houverem).

Em seguida sumarize os resultados e forneça a resposta final da análise efetuada.
"""

In [0]:
# metric_calculator = SRAGMetrics()
# metric_calculator_tool = StructuredTool.from_function(func=metric_calculator.run)

In [0]:
# @tool
# def get_month_mortality_rate(month: Optional[str | int] = None, year: Optional[str | int] = None) -> str:
#     """
#     Call to returns the mortality rate(%) for the given month of a year or for the current month.
#     """
#     today = datetime.date.today()
#     if month is None:
#         month = today.month
#     if year is None:
#         year = today.year
#     else:
#         month = int(month)
#         year = int(year)
#     feature_store_table_name = f'{env_vars["CATALOG"]}.{env_vars["SCHEMA"]}.srag_features'
#     feature_store_table = spark.read.table(feature_store_table_name)
#     mortality_rate_percent = feature_store_table.filter(
#         (F.year(F.col("DT_EVOLUCA")).cast("int") == year) &
#         (F.month(F.col("DT_EVOLUCA")).cast("int") == month)
#         ).select(
#         F.last("mortalidade_mensal_perc")
#     ).collect()[0][0]
#     return f"The mortality rate (%) for {month} of {year} is {mortality_rate_percent}."

In [0]:
# import getpass
# import os
# from dotenv import load_dotenv
# from langchain_core.tools import BaseTool, tool
# from langchain_tavily import TavilySearch
# from tavily import TavilyClient

# @tool
# def search_news(search_query:str):
#     """Use this tool to search for news articles related to the current topic. The input to this tool should be a search query and the output will be a list of news articles."""
#     load_dotenv()
#     tavily_api_key = os.environ.get("TAVILY_API_KEY")
#     tavily_client = TavilyClient(api_key=tavily_api_key)
    
#     resp = tavily_client.search(
#         query=search_query,
#         max_results=3,
#         search_depth="basic",
#         time_range="month",
#         country="brazil",
#         )
#     context = []
#     context.append({
#         "news": [
#             { "title": result["title"], "title": result["content"], "url": result["url"]} for result in resp["results"]
#         ]
#     })
#     return context

In [0]:
# Instantiate the two tools the agent may call.
# NOTE(review): `SparkSQLQueryTool` and `TavilyTool` are not defined in this
# notebook; presumably they come from the `%run` cells for
# ../tools/database_searcher.py and ../tools/web_news_searcher.py — confirm.
database_searcher = SparkSQLQueryTool()
web_search_query = TavilyTool()


In [0]:
# Tools exposed to the agent, in registration order.
# `metric_calculator_tool` is not registered: its definition is commented out.
tools = [database_searcher, web_search_query]

# Define the agent logic.

In [0]:
from langgraph.graph.message import add_messages
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    convert_to_openai_messages,
)
from pydantic import BaseModel, create_model
from typing import Annotated, TypedDict

In [0]:
# The state for the agent workflow, including the conversation and any custom data
class AgentState(TypedDict):
    """Typed dictionary describing an agent workflow state.

    NOTE(review): the graph built in this notebook is created from
    ``ChatAgentState``; this class is not referenced by the visible code.
    """
    # Conversation messages; `add_messages` is attached as metadata via Annotated
    # (presumably the LangGraph reducer that merges new messages — confirm).
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Optional dict of custom inputs with string keys.
    custom_inputs: Optional[dict[str, Any]]
    # Optional dict of custom outputs with string keys.
    custom_outputs: Optional[dict[str, Any]]

# Define the LangGraph agent that can call tools
def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[ToolNode, Sequence[BaseTool]],
    system_prompt: Optional[str] = None,
) -> CompiledStateGraph:
    """Build and compile a two-node LangGraph workflow for a tool-calling model.

    The graph starts at the ``agent`` node, which invokes the model. If the last
    message in the state carries tool calls, control moves to the ``tools`` node
    and then back to ``agent``; otherwise the workflow ends.

    Args:
        model: Chat model; ``tools`` are bound to it with ``bind_tools``.
        tools: Tools bound to the model and handed to the tools node.
        system_prompt: Optional text prepended as a system message before every
            model call. Nothing is prepended when it is empty or ``None``.

    Returns:
        The compiled state graph.
    """
    tool_bound_model = model.bind_tools(tools)

    def routing_logic(state: ChatAgentState):
        # Only the most recent message decides whether another tool round is needed.
        latest_message = state["messages"][-1]
        return "continue" if latest_message.get("tool_calls") else "end"

    # Input adapter: turn the graph state into the message list sent to the
    # model, with the system prompt in front when one was supplied.
    preprocessor = RunnableLambda(
        (lambda state: [{"role": "system", "content": system_prompt}] + state["messages"])
        if system_prompt
        else (lambda state: state["messages"])
    )
    model_runnable = preprocessor | tool_bound_model

    def call_model(state: ChatAgentState, config: RunnableConfig):
        # Return the model reply as a one-element message update for the state.
        return {"messages": [model_runnable.invoke(state, config)]}

    workflow = StateGraph(ChatAgentState)
    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ChatAgentToolNode(tools))
    workflow.set_entry_point("agent")

    # "continue" -> run the requested tools; "end" -> finish the workflow.
    route_targets = {"continue": "tools", "end": END}
    workflow.add_conditional_edges("agent", routing_logic, route_targets)

    # After the tools ran, hand their results back to the model.
    workflow.add_edge("tools", "agent")

    return workflow.compile()

In [0]:
class LangGraphAgent(ChatAgent):
    """MLflow ``ChatAgent`` wrapper around a compiled LangGraph workflow."""

    def __init__(self, agent: CompiledStateGraph):
        # Compiled graph that is streamed on every prediction.
        self.agent = agent

    def predict(
        self,
        messages: Sequence[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Run the graph on a conversation and collect every message it emits.

        Args:
            messages: Conversation to send to the graph.
            context: Accepted as part of the signature; not used by this method.
            custom_inputs: Accepted as part of the signature; not used by this method.

        Returns:
            A response holding the messages found under the ``"messages"`` key of
            each node update, in the order the graph streamed them.
        """
        request = {"messages": self._convert_messages_to_dict(messages)}
        updates = self.agent.stream(request, stream_mode="updates")

        # Flatten: stream event -> per-node output -> individual message dicts.
        emitted = [
            ChatAgentMessage(**raw_message)
            for update in updates
            for node_output in update.values()
            for raw_message in node_output.get("messages", [])
        ]
        return ChatAgentResponse(messages=emitted)

In [0]:

# Turn on MLflow autologging for LangChain before the graph is built.
mlflow.langchain.autolog()

# Assemble the tool-calling graph from the LLM, the tool list and the system prompt.
agent = create_tool_calling_agent(
    model=llm,
    tools=tools,
    system_prompt=system_prompt,
)

# Wrap the graph in the ChatAgent interface and hand it to mlflow.models.set_model.
AGENT = LangGraphAgent(agent=agent)
mlflow.models.set_model(AGENT)

In [0]:
from IPython.display import Image, display

# Render the compiled graph as a Mermaid PNG. This is an optional visual aid:
# rendering needs extra dependencies, so a failure must not stop the notebook.
# The reason is reported instead of being silently discarded, so a missing
# picture can be diagnosed.
try:
    display(Image(agent.get_graph(xray=True).draw_mermaid_png()))
except Exception as exc:
    print(f"Skipping agent graph rendering (optional step): {exc!r}")

In [0]:
# Smoke test: ask the agent for this month's SRAG mortality rate.
AGENT.predict(
    {
        "messages": [
            {
                "role": "user",
                "content": "Qual a taxa de mortalidade por SRAG esse mês no Brazil?",
            }
        ]
    }
)

In [0]:
# Smoke test: ask the agent a definitional question ("What is SRAG?").
AGENT.predict(
    {
        "messages": [
            {
                "role": "user",
                "content": "O que é SRAG?",
            }
        ]
    }
)