In [0]:
%pip install -q -r ./../requirements.txt
# Restart the Python process right after the install above.
# NOTE(review): presumably needed so the freshly installed packages become
# importable; keep this cell first, before any imports or variables are defined.
dbutils.library.restartPython()

In [0]:
from operator import itemgetter
import mlflow
import os
from typing import Any, Callable, Dict, Generator, List, Optional

from databricks.vector_search.client import VectorSearchClient

from vector_search_utils.self_querying_retriever import load_self_querying_retriever
#from supervisor_utils.decomposer import load_decomposer

from databricks_langchain import ChatDatabricks
from databricks_langchain import DatabricksVectorSearch

from langchain_core.runnables import RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    PromptTemplate,
    ChatPromptTemplate,
    MessagesPlaceholder,
)

from langgraph_supervisor import create_supervisor
from langchain.chat_models import init_chat_model

from langchain_core.runnables import RunnablePassthrough, RunnableBranch
from langchain_core.messages import HumanMessage, AIMessage
from langchain.retrievers.multi_query import MultiQueryRetriever

from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel, Field

from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)

In [0]:
## MLflow tracing
# Point MLflow at the "databricks" tracking URI and turn on LangChain autologging.
mlflow.set_tracking_uri("databricks")
mlflow.langchain.autolog()

## Agent configuration
# Everything below is read from the development config file.
model_config = mlflow.models.ModelConfig(development_config="./../configs/agent.yaml")

# Top-level config sections used by this notebook.
databricks_config = model_config.get("databricks_config")
doc_agent_config = model_config.get("doc_agent_config")
retriever_config = model_config.get("retriever_config")

# Unity Catalog location used to build fully qualified index names.
catalog = databricks_config["catalog"]
schema = databricks_config["schema"]

vector_search_schema = retriever_config.get("schema")

## Document-retrieval chat model
# Endpoint name and extra parameters both come from the agent's "llm_config" section.
_doc_llm_config = doc_agent_config.get("llm_config")
doc_retrieval_model = ChatDatabricks(
    endpoint=_doc_llm_config.get("llm_endpoint_name"),
    extra_params=_doc_llm_config.get("llm_parameters"),
    model_kwargs={"parallel_tool_calls": False},
)

In [0]:
# Show which Vector Search endpoint and index the loaded configuration names.
for _setting in ("vector_search_endpoint", "vector_search_index"):
    print(retriever_config[_setting])

In [0]:
# Smoke test: query the Vector Search index directly through the client.
client = VectorSearchClient()
index = client.get_index(
    endpoint_name=retriever_config["vector_search_endpoint"],
    # NOTE(review): the index name is hard-coded here although the config also
    # provides retriever_config["vector_search_index"] — confirm whether the two
    # are meant to be the same index.
    index_name=f"{catalog}.{schema}.sec_doc_chunks_index_v1",
)

# Hybrid search restricted to one company and one filing year.
# `score_threshold` is intentionally left at the client default: the previous
# value was an empty string, which is not a valid numeric threshold.
results = index.similarity_search(
    query_text="American Express Revenue 2022",  # text query rather than precomputed query embeddings
    columns=["chunk_id", "document_type", "path", "resolved_company", "doc_content", "year"],
    num_results=10,
    query_type="HYBRID",
    filters={"resolved_company": ["AMERICANEXPRESS"], "year": "2022"},
)

# Print one row per hit. Fall back to an empty list so a response without
# "result" / "data_array" (e.g. zero hits) does not raise.
for row in (results.get("result") or {}).get("data_array") or []:
    print(row)


In [0]:
from databricks_langchain import VectorSearchRetrieverTool, ChatDatabricks

# Initialize the retriever tool that wraps the SEC document chunk index.
vs_tool = VectorSearchRetrieverTool(
    index_name=f"{catalog}.{schema}.sec_doc_chunks_index_v1",
    tool_name="databricks_docs_retriever",
    tool_description="Retrieves information about SEC filings for different companies",
    query_type="ANN",  # can also be HYBRID
    num_results=10,  # integer count of hits to return (was the string "10")
)

# Run a query against the vector search index locally for testing.
# Left as the last expression so the notebook displays the tool's result.
vs_tool.invoke("What was American Express's revenue in 2022?")

In [0]:
# Sample request: a single user turn in the {"messages": [...]} chat format.
example_input = dict(
    messages=[
        dict(
            role="user",
            content="What was American Express's revenue in 2022?",
        ),
    ],
)

In [0]:
from langgraph.prebuilt import create_react_agent

# System prompt handed to the agent.
_SEC_ANALYST_PROMPT = (
    "You are an SEC Docs analyst. Use the retrieved SEC documents to answer the following question"
)

# Chat model that drives the agent.
# NOTE(review): the endpoint is hard-coded here, while `doc_retrieval_model` is
# built from the config earlier in the notebook — confirm which one is intended.
llm = ChatDatabricks(endpoint="databricks-claude-3-7-sonnet")

# Agent whose only tool is the vector-search retriever.
retriever_agent = create_react_agent(
    model=llm,
    tools=[vs_tool],
    prompt=_SEC_ANALYST_PROMPT,
    name="sec_metrics_analyst",
)

# Run the sample request through the agent.
response = retriever_agent.invoke(example_input)