In [0]:
# %pip install -U -qqqq databricks-vectorsearch mlflow
# %restart_python

In [0]:
from databricks.vector_search.client import VectorSearchClient
from datetime import timedelta
import time

In [0]:
import os
import sys

import yaml

# Make the parent directory importable so the `configs` package resolves.
# The membership check keeps sys.path from accumulating duplicate entries
# when this cell is re-run in the same kernel.
repo_root = os.path.abspath('..')
if repo_root not in sys.path:
    sys.path.append(repo_root)
from configs.project import ProjectConfig

# Read the YAML with an explicit encoding so parsing does not depend on the
# platform's default locale encoding.
with open("../configs/project.yml", "r", encoding="utf-8") as file:
    data = yaml.safe_load(file)

# Build the project configuration from the parsed YAML, passing its
# top-level keys as keyword arguments.
projectConfig = ProjectConfig(**data)

In [0]:
# Select the vector search settings stored under the "id_1" key.
vs_config = projectConfig.vector_search_attributes["id_1"]

# Print every field of the selected settings, one "name value" pair per line.
for field_name, field_value in vs_config.model_dump().items():
    print(field_name, field_value)

In [0]:
# Echo the three settings that the following cells read from vs_config.
for setting in ("source_table_name", "endpoint_name", "index_name"):
    print(setting, getattr(vs_config, setting))


In [0]:
# Client created with no explicit arguments. Optional arguments, kept here
# for reference:
#   disable_notice=True,
#   workspace_url = "https://e2-demo-field-eng.cloud.databricks.com/",
#   service_principal_client_id = dbutils.secrets.get("felix-flory", "SERVICE_PRINCIPAL_ID"),
#   service_principal_client_secret = dbutils.secrets.get("felix-flory", "SERVICE_PRINCIPAL_SECRET"),
vsc = VectorSearchClient()

# Handle to the index named in the project configuration.
index = vsc.get_index(
    endpoint_name=vs_config.endpoint_name,
    index_name=vs_config.index_name,
)

similarity_search: https://api-docs.databricks.com/python/vector-search/databricks.vector_search.html#databricks.vector_search.index.VectorSearchIndex.similarity_search

Filters: https://docs.databricks.com/aws/en/generative-ai/vector-search

In [0]:
# Column names of the source table; passed as the `columns` argument of the
# similarity search below.
columns = spark.table(vs_config.source_table_name).columns

# Filters passed to the similarity search.
# Alternative company value: "apple"
filters = dict(
    company="american express",
    doc_section="Item 6: Selected Financial Data",
)

In [0]:
# Run a filtered similarity search against the index.
results = index.similarity_search(
    query_text="What was the last earnings results from American Express",
    columns=columns,
    filters=filters,
    num_results=10,
    query_type="ANN",  # "ANN" or "HYBRID"
    score_threshold=0.50,
)

# Show the returned rows, or a short message when the search reports none.
result_payload = results["result"]
if not result_payload["row_count"] > 0:
    print("No records")
else:
    display(result_payload["data_array"])


# Integrate into LangChain

API docs: https://api-docs.databricks.com/python/databricks-ai-bridge/latest/databricks_langchain.html#databricks_langchain.VectorSearchRetrieverTool


In [0]:
from databricks_langchain import ChatDatabricks, VectorSearchRetrieverTool

# Expose the vector search index as a retriever tool.
# Optional explicit client settings, kept here for reference:
#   client_args = {
#     "disable_notice": True,
#     "workspace_url": "https://e2-demo-field-eng.cloud.databricks.com/",
#     "service_principal_client_id": dbutils.secrets.get("felix-flory", "SERVICE_PRINCIPAL_ID"),
#     "service_principal_client_secret": dbutils.secrets.get("felix-flory", "SERVICE_PRINCIPAL_SECRET")
#   }
vs_tool = VectorSearchRetrieverTool(
    tool_name="docs_retriever",
    tool_description="Retrieves information about SEC filings",
    index_name=vs_config.index_name,
)

# Call the tool directly with a sample question as a quick local test.
# NOTE(review): `disable_notice` is passed here as an extra keyword to
# invoke(); confirm that it has an effect at this call site.
sample_question = "What was the last earnings results from American Express"
vs_tool.invoke(sample_question, disable_notice=True)

In [0]:
# Create the chat model and attach the retriever tool to it.
llm = ChatDatabricks(
    temperature=0.0,
    endpoint="databricks-meta-llama-3-1-70b-instruct",
)
llm_with_tools = llm.bind_tools([vs_tool])

# Send a sample question to exercise the tool-calling path.
llm_with_tools.invoke("What was the last earnings results from American Express")