# Information Retriever: Building Blocks

## Env set up

In [0]:
# Show which Python 3 version the `python3` executable of this environment reports
!python3 --version

In [0]:
%pip install -U -qqq langgraph langgraph-prebuilt langgraph-sdk langgraph-checkpoint-sqlite langsmith langchain-community langchain-core langchain-openai notebook langchain-tavily wikipedia trustcall langgraph-cli[inmem] transformers

%pip install -U -qqq databricks-agents mlflow-skinny[databricks] databricks-langchain

%pip install -U -qqq databricks-sdk
# NOTE(review): all installs above use -U without version pins, so re-running this cell
# later may pull different package versions — consider pinning.
# Restart the Python process after the installs; this clears the notebook's Python state,
# so it must run before any variables are defined.
dbutils.library.restartPython()

In [0]:
## Variables containing the pointers to catalog, schema, lakebase

# Catalog and schema used for USE CATALOG / USE SCHEMA and to build the fully
# qualified names of the tables and functions created in this notebook.
catalog_ = 'users'
schema_ = 'gabriele_albini'
# Name of the database instance passed as `database_instance_name` when creating synced tables.
lakebase_instance_ = 'shared-instance-size-4'

## Example data

In [0]:
## Creating a dataframe with examples of clients data
# Column names, in the same order as the values of each tuple in examples_ below.
columns = ["customer_id", "name", "surname", "email", "address"]
# Ten example customers; customer_id runs from 1 to 10 and is referenced by the order examples.
examples_ = [
    (1, "Lena", "Schmidt", "lena.schmidt@example.com",
     "Alexanderplatz 5, 10178 Berlin, Germany"),
    (2, "Jonas", "Müller", "jonas.mueller@example.com",
     "221B Baker Street, NW1 6XE London, United Kingdom"),
    (3, "Sofia", "Rossi", "sofia.rossi@example.com",
     "Via Milano 11, 20126 Milano, Italy"),
    (4, "Marek", "Nowak", "marek.nowak@example.com",
     "ul. Marszałkowska 10, 00-590 Warsaw, Poland"),
    (5, "Hannah", "Dubois", "hannah.dubois@example.com",
     "12 Rue de la Paix, 75002 Paris, France"),
    (6, "David", "Nguyen", "david.nguyen@example.com",
     "Damrak 45, 1012 LL Amsterdam, Netherlands"),
    (7, "Amira", "Hassan", "amira.hassan@example.com",
     "Gran Vía 28, 28013 Madrid, Spain"),
    (8, "Felix", "Kovac", "felix.kovac@example.com",
     "Ringstrasse 3, 1010 Vienna, Austria"),
    (9, "Carla", "Lindberg", "carla.lindberg@example.com",
     "Sveavägen 15, 111 57 Stockholm, Sweden"),
    (10, "Tobias", "Nielsen", "tobias.nielsen@example.com",
     "Rådhuspladsen 1, 1550 Copenhagen, Denmark")
]

# Build the Spark DataFrame from the tuples and render it in the notebook.
df_customer_example = spark.createDataFrame(examples_, columns)
display(df_customer_example)

In [0]:
## Creating a dataframe with examples of order data
# Column names, in the same order as the values of each tuple in examples_ below.
# "date" is given as a text value, "cart" as a list of product names, and
# "customer_id" takes values 1-10, matching the ids used in the customer examples.
order_columns = ["Order_ID", "date", "cart", "amount", "customer_id"]
# NOTE: this reuses (and overwrites) the examples_ name from the customer cell.
examples_ = [
    ("ID-45892171", "2025-11-02",
     ["Office Chair", "Standing Desk", "Desk Lamp"],
     1299.90, 1),

    ("ID-45892172", "2025-11-03",
     ["Mechanical Keyboard", "Wireless Mouse", "Mouse Pad"],
     189.50, 2),

    ("ID-45892173", "2025-11-05",
     ["A4 Paper (500 sheets)", "Stapler", "Staples Pack", "Highlighters Set"],
     64.30, 3),

    ("ID-45892174", "2025-11-06",
     ["Whiteboard", "Whiteboard Markers Set", "Eraser"],
     215.00, 4),

    ("ID-45892175", "2025-11-08",
     ["Desk Organizer", "Notebooks Pack", "Ballpoint Pens (20x)"],
     72.40, 5),

    ("ID-45892176", "2025-11-09",
     ["Monitor 27\"", "HDMI Cable", "Monitor Arm"],
     459.99, 6),

    ("ID-45892177", "2025-11-10",
     ["Office Chair", "Footrest"],
     389.00, 7),

    ("ID-45892178", "2025-11-11",
     ["Laser Printer", "Printer Paper (1000 sheets)"],
     329.50, 8),

    ("ID-45892179", "2025-11-12",
     ["Inkjet Printer Cartridges Set", "Label Printer"],
     248.75, 9),

    ("ID-45892180", "2025-11-13",
     ["Conference Speakerphone", "Webcam HD"],
     312.20, 10),

    ("ID-45892181", "2025-11-14",
     ["Filing Cabinet", "Hanging Folders (25x)"],
     410.00, 1),

    ("ID-45892182", "2025-11-15",
     ["Desk Lamp", "LED Bulbs (4x)", "Cable Management Kit"],
     98.60, 2),

    ("ID-45892183", "2025-11-16",
     ["Mouse Pad", "Ergonomic Wrist Rest"],
     42.30, 3),

    ("ID-45892184", "2025-11-18",
     ["A4 Paper (500 sheets)", "Sticky Notes Pack", "Markers Set"],
     53.10, 4),

    ("ID-45892185", "2025-11-19",
     ["Flipchart", "Flipchart Paper (3x)", "Markers Set"],
     187.40, 5),

    ("ID-45892186", "2025-11-20",
     ["Network Switch 8-port", "Ethernet Cables (5x)"],
     224.80, 6),

    ("ID-45892187", "2025-11-21",
     ["External SSD 1TB", "USB Hub"],
     189.99, 7),

    ("ID-45892188", "2025-11-22",
     ["Office Headset", "Laptop Stand"],
     154.60, 8),

    ("ID-45892189", "2025-11-23",
     ["Desk Organizer", "Pen Holder", "Document Tray"],
     61.25, 9),

    ("ID-45892190", "2025-11-24",
     ["Monitor 24\"", "DisplayPort Cable"],
     269.90, 10),

    ("ID-45892191", "2025-11-25",
     ["Whiteboard Markers Set", "Cleaning Spray", "Microfiber Cloths"],
     39.80, 1),

    ("ID-45892192", "2025-11-26",
     ["Notebooks Pack", "Ballpoint Pens (50x)", "Highlighters Set"],
     88.15, 2),

    ("ID-45892193", "2025-11-27",
     ["Office Chair", "Seat Cushion"],
     345.00, 3),

    ("ID-45892194", "2025-11-28",
     ["Paper Shredder", "Trash Bags (Office)"],
     219.50, 4),

    ("ID-45892195", "2025-11-29",
     ["Laminator", "Laminating Pouches (100x)"],
     132.40, 5)
]

# Build the Spark DataFrame from the tuples and render it in the notebook.
df_order_example = spark.createDataFrame(examples_, order_columns)
display(df_order_example)

In [0]:
## Create some retrieval requests
# Each row is one incoming email: a numeric id, the sender address and the email text.
# Rows 1 and 2 come from senders present in the customer examples; rows 3 and 4 come from
# an address that is not in the customer examples, and row 4 mentions an order id
# (ID-99999999) that is not in the order examples — these exercise the "not found" paths.
# NOTE: this reuses (and overwrites) the columns / examples_ names from the cells above.
columns = ["id", "sender_email", "email_body"]
examples_ = [
  (1, "sofia.rossi@example.com", "Dear sir or madam, I would like to get my residence address updated to: v. bella 1, Milano (MI), Italia"),

  (2, "lena.schmidt@example.com", "Hello, I'd like to get an updated delivery date for shipment ID-45892171, given that I didn't receive it yet. Thank you"),

  (3, "roberta.verdi@example.com", "Understood, I will wait a couple of days thanks"),

  (4, "roberta.verdi@example.com", "Can you provide an update on the refund related to order ID-99999999?")
]
df_example = spark.createDataFrame(examples_, columns)
display(df_example)

## Persist Data
Execute this section only when the data needs to be recreated.

Sync delta tables to Lakebase ([doc](https://docs.databricks.com/aws/en/oltp/instances/sync-data/sync-table?language=Python+SDK))

In [0]:
# Switch guarding the cells below: when True, the Delta tables are rewritten and the
# synced tables are created; when False, those cells only select the catalog/schema
# and report the status of the existing synced tables.
regenerate_data = False

In [0]:
## Generate delta tables
if regenerate_data:
  # Make sure catalog and schema exist, select them, and drop earlier copies of both tables.
  # The statements run strictly in the order listed.
  for setup_statement in (
      f"CREATE CATALOG IF NOT EXISTS {catalog_}",
      f"USE CATALOG {catalog_}",
      f"CREATE SCHEMA IF NOT EXISTS {schema_}",
      f"USE SCHEMA {schema_}",
      "DROP TABLE IF EXISTS classificator_agent_customers",
      "DROP TABLE IF EXISTS classificator_agent_orders",
  ):
    spark.sql(setup_statement)

  # Customers: write the table, then declare customer_id as NOT NULL and as primary key
  customers_writer = df_customer_example.write.mode("overwrite")
  customers_writer.saveAsTable("classificator_agent_customers")
  spark.sql(
  """
    ALTER TABLE classificator_agent_customers ALTER COLUMN customer_id SET NOT NULL
  """)
  spark.sql(
  """
    ALTER TABLE classificator_agent_customers ADD CONSTRAINT customers_pk PRIMARY KEY (customer_id)
  """)

  # Orders: write the table, then declare Order_ID as NOT NULL and as primary key
  orders_writer = df_order_example.write.mode("overwrite")
  orders_writer.saveAsTable("classificator_agent_orders")
  spark.sql(
  """
    ALTER TABLE classificator_agent_orders ALTER COLUMN Order_ID SET NOT NULL
  """)
  spark.sql(
  """
    ALTER TABLE classificator_agent_orders ADD CONSTRAINT orders_pk PRIMARY KEY (Order_ID)
  """)

else:
  # Nothing to rebuild: only select the catalog and schema for the rest of the notebook
  for use_statement in (f"USE CATALOG {catalog_}", f"USE SCHEMA {schema_}"):
    spark.sql(use_statement)

In [0]:
## Sync Delta tables with Lakebase Instance: Customers

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.database import DatabaseInstance
from databricks.sdk.service.database import SyncedDatabaseTable, SyncedTableSpec, NewPipelineSpec, SyncedTableSchedulingPolicy
w = WorkspaceClient()

# Build the fully qualified names once, so the create call and the status lookup
# always refer to the same source / synced tables.
source_table_name = f"{catalog_}.{schema_}.classificator_agent_customers"
synced_table_name = f"{source_table_name}_synced"

if regenerate_data:

    # Create a synced table in a standard UC catalog
    synced_table = w.database.create_synced_database_table(
        SyncedDatabaseTable(
            name=synced_table_name,  # Full three-part name
            database_instance_name=lakebase_instance_,  # Required for standard catalogs
            logical_database_name="databricks_postgres",  # Required for standard catalogs
            spec=SyncedTableSpec(
                source_table_full_name=source_table_name,
                primary_key_columns=["customer_id"],
                scheduling_policy=SyncedTableSchedulingPolicy.SNAPSHOT,
                create_database_objects_if_missing=True,  # Create database/schema if needed
                new_pipeline_spec=NewPipelineSpec(
                    storage_catalog=catalog_,
                    storage_schema=schema_
                )
            ),
        )
    )
    print(f"Created synced table: {synced_table.name}")

# Check the status of the synced table (a single lookup, whether or not it was just created)
status = w.database.get_synced_database_table(name=synced_table_name)

if regenerate_data:
    # Right after creation the synchronization status may not be populated yet;
    # guard against None instead of failing with an AttributeError.
    sync_status = status.data_synchronization_status
    if sync_status is None:
        print("Synced table status: not reported yet")
    else:
        print(f"Synced table status: {sync_status.detailed_state}")
        print(f"Status message: {sync_status.message}")

print(status)

In [0]:
## Sync Delta tables with Lakebase Instance: Orders

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.database import DatabaseInstance
from databricks.sdk.service.database import SyncedDatabaseTable, SyncedTableSpec, NewPipelineSpec, SyncedTableSchedulingPolicy
w = WorkspaceClient()

# Build the fully qualified names once, so the create call and the status lookup
# always refer to the same source / synced tables.
source_table_name = f"{catalog_}.{schema_}.classificator_agent_orders"
synced_table_name = f"{source_table_name}_synced"

if regenerate_data:

    # Create a synced table in a standard UC catalog
    synced_table = w.database.create_synced_database_table(
        SyncedDatabaseTable(
            name=synced_table_name,  # Full three-part name
            database_instance_name=lakebase_instance_,  # Required for standard catalogs
            logical_database_name="databricks_postgres",  # Required for standard catalogs
            spec=SyncedTableSpec(
                source_table_full_name=source_table_name,
                primary_key_columns=["Order_ID"],
                scheduling_policy=SyncedTableSchedulingPolicy.SNAPSHOT,
                create_database_objects_if_missing=True,  # Create database/schema if needed
                new_pipeline_spec=NewPipelineSpec(
                    storage_catalog=catalog_,
                    storage_schema=schema_
                )
            ),
        )
    )
    print(f"Created synced table: {synced_table.name}")

# Check the status of the synced table (a single lookup, whether or not it was just created)
status = w.database.get_synced_database_table(name=synced_table_name)

if regenerate_data:
    # Right after creation the synchronization status may not be populated yet;
    # guard against None instead of failing with an AttributeError.
    sync_status = status.data_synchronization_status
    if sync_status is None:
        print("Synced table status: not reported yet")
    else:
        print(f"Synced table status: {sync_status.detailed_state}")
        print(f"Status message: {sync_status.message}")

print(status)

## Create the Retriever functions

Create UC functions that can be used as retriever tools by the agent ([doc](https://docs.databricks.com/aws/en/generative-ai/agent-framework/structured-retrieval-tools), Databricks Demos [example in Notebook 5.1](https://notebooks.databricks.com/demos/lakehouse-iot-platform/index.html#))

In [0]:
# Select the catalog and schema in which the retriever functions below are created
spark.sql(f"USE CATALOG {catalog_}")
spark.sql(f"USE SCHEMA {schema_}")

In [0]:
%sql
-- Remove any earlier versions of the two retriever functions before (re)creating them.
DROP FUNCTION IF EXISTS classificator_agent_customer_retriever;
DROP FUNCTION IF EXISTS classificator_agent_order_retriever;

-- Customer retriever: given an email address, selects from classificator_agent_customers_synced
-- the row whose email equals the argument and returns its columns packed into one struct.
-- LIMIT 1 keeps the result to a single row; presumably the result is NULL when nothing
-- matches — confirm against the smoke-test cells below.
CREATE OR REPLACE FUNCTION classificator_agent_customer_retriever(customer_email STRING)
RETURNS STRUCT<customer_id BIGINT, name STRING, surname STRING, email STRING, address STRING>
LANGUAGE SQL
COMMENT 'Returns customer details based on the customer email address'
RETURN (
  SELECT struct(customer_id, name, surname, email, address)
  FROM classificator_agent_customers_synced
  WHERE email = classificator_agent_customer_retriever.customer_email
  LIMIT 1
);

-- Order retriever: given an order id, selects from classificator_agent_orders_synced
-- the row whose Order_ID equals the argument and returns its columns packed into one struct.
-- NOTE(review): cart is declared STRING here, while the example data builds it as a list of
-- product names — confirm the column type exposed by the synced table.
CREATE OR REPLACE FUNCTION classificator_agent_order_retriever(customer_order_id STRING)
RETURNS STRUCT<Order_ID STRING, date STRING, cart STRING, amount DOUBLE, customer_id BIGINT>
LANGUAGE SQL
COMMENT 'Returns order details based on the order ID'
RETURN (
  SELECT struct(Order_ID, date, cart, amount, customer_id)
  FROM classificator_agent_orders_synced
  WHERE Order_ID = classificator_agent_order_retriever.customer_order_id
  LIMIT 1
);

In [0]:
%sql
-- Smoke test: this email belongs to customer_id 2 in the example customer data.
SELECT classificator_agent_customer_retriever('jonas.mueller@example.com') AS customer_details;

In [0]:
%sql
-- Smoke test: this order id belongs to customer_id 8 in the example order data.
SELECT classificator_agent_order_retriever('ID-45892178') AS Order_Details;

## Create a ReAct tool calling retriever agent

In [0]:
## LLM Set up: Variable containing the LLM Endpoint to use
# NOTE: the second assignment overrides the first, so 'databricks-claude-sonnet-4-5'
# is the endpoint actually used; reorder or comment out a line to switch.
llm_endpoint = 'databricks-gpt-5-1'
llm_endpoint = 'databricks-claude-sonnet-4-5'

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Graph State
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages, AnyMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

## Set up the agent state to concatenate messages in memory
# State schema of the ReAct graph built below.
class AgentState(TypedDict):
    # Conversation history; the add_messages reducer merges each node's returned
    # messages into this list instead of overwriting it.
    messages: Annotated[list[AnyMessage], add_messages]
    # NOTE(review): customer_context and order_context are declared here but are not
    # written by the assistant node defined in this section — confirm whether they are needed.
    customer_context: str
    order_context: str

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Base LLM
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
from databricks_langchain import ChatDatabricks
# Chat model backed by the serving endpoint named in llm_endpoint (temperature fixed at 0).
model = ChatDatabricks(endpoint = llm_endpoint, temperature=0) 

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# UC functions as Tools
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
from databricks_langchain import UCFunctionToolkit

# Fully qualified (catalog.schema.function) names of the two UC retriever functions
retriever_function_names = [
    f"{catalog_}.{schema_}.{function_name}"
    for function_name in (
        "classificator_agent_customer_retriever",
        "classificator_agent_order_retriever",
    )
]

# Wrap the UC functions in a toolkit and keep the resulting tool list for the agent
toolkit = UCFunctionToolkit(function_names=retriever_function_names)
tools = toolkit.tools

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Create Tool Binding
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Bind the retriever tools to the chat model so its replies can contain tool calls for them.
model_with_tools = model.bind_tools(tools)

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Reasoning Node (Assistant)
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

from langgraph.graph import MessagesState
from langchain_core.messages import HumanMessage, SystemMessage

# System message
# Prepended to the conversation on every call of the assistant node.
# NOTE(review): the prompt text contains the typo "emmails" (for "emails"); it is left
# untouched here because it is part of the text sent to the model.
sys_msg = SystemMessage(content="""
    You are a customer support assistant, tasked with providing more context to incoming customer communications, such as emmails.\n
    Use the available tools to perform information retrieval.\n
    If you don't have the required information to use a tool, DO NOT use it and assume that content cannot be retrieved.\n
    If a tool you've invoked returns null or an error, also assume that content cannot be retrieved.\n
    Once you've retrieved all the available information, provide a summary of the information and STOP.""")

# Reasoning Node
def assistant(state: AgentState):
  """Call the tool-bound model on the system prompt followed by the conversation so far.

  Returns a partial state update holding only the model's reply; the add_messages
  reducer on `messages` appends it to the existing history.
  """
  conversation = [sys_msg, *state["messages"]]
  reply = model_with_tools.invoke(conversation)
  return {"messages": [reply]}

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# ReAct Graph
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

from langgraph.graph import START, StateGraph
from langgraph.prebuilt import tools_condition # Check if the model's output is a tool call
from langgraph.prebuilt import ToolNode
from IPython.display import Image, display

# Graph
# State graph over AgentState with two nodes: the assistant (model call) and the tools node.
builder = StateGraph(AgentState)

# Define nodes: these do the work
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))

# Define edges: these determine how the control flow moves
builder.add_edge(START, "assistant")
builder.add_conditional_edges(
    "assistant",
    # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools
    # If the latest message (result) from assistant is not a tool call -> tools_condition routes to END
    tools_condition,
)
# After the tools node runs, control goes back to the assistant (ReAct loop)
builder.add_edge("tools", "assistant")
# Compiled without a checkpointer, so no state is persisted between invocations
react_retriever = builder.compile()  # No memory needed

# Display the Graph
display(Image(react_retriever.get_graph().draw_mermaid_png()))

In [0]:
## Testing
from pyspark.sql.functions import *
from langchain_core.messages import HumanMessage
import mlflow

# Extract test case from the dataframe with examples
id_ = 4
config_ = {"configurable": {"thread_id": str(id_)}}
example_row = df_example.filter(col("ID") == id_).first()
# first() returns None when no row matches the id; fail with a clear message instead
# of an AttributeError while building the prompt below.
if example_row is None:
  raise ValueError(f"No example request found with id {id_}")
message_ = f"We received from: {example_row.sender_email} the following email: {example_row.email_body}"

# Invoke the Agent
# Enable MLflow autologging for LangChain and group this invocation under a named run
mlflow.langchain.autolog()
with mlflow.start_run(run_name="React Agent Retriever Test"):
  request = [
    HumanMessage(content = message_)
  ]
  # invoke() returns the final graph state; its "messages" entry holds the full exchange
  messages = react_retriever.invoke({"messages": request}, config_)

  for m in messages['messages']:
      m.pretty_print()

## v2: Create a ReAct tool calling retriever agent with Output schema

Adds a `respond` node to the graph above, so that the final reply is formatted according to a predefined output schema.

In [0]:
## LLM Set up: Variable containing the LLM Endpoint to use
# NOTE: the second assignment overrides the first, so 'databricks-claude-sonnet-4-5'
# is the endpoint actually used; reorder or comment out a line to switch.
llm_endpoint = 'databricks-gpt-5-1'
llm_endpoint = 'databricks-claude-sonnet-4-5'

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Output Schema
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

from pydantic import BaseModel, Field
from typing import Literal

# Output schema that the respond node asks the model to fill in.
# NOTE(review): every field is typed `str` and the descriptions ask for NULL when a value
# is missing, so a missing value presumably arrives as text rather than as None — confirm
# how consumers of this schema are expected to detect "not found".
class RetrievedInfoOutput(BaseModel):
  Customer_Email: str = Field(description = "Email of the customer who sent the message.")
  Customer_Id: str = Field(description = "Id of the customer, related to the email. If not found, return NULL.")
  Order_Id: str = Field(description = "Id of the order found in the system, if mentioned in the customer's email. If the customer didn't mention any order or the Id was not found, return NULL.")
  Customer_Details: str = Field(description = "Retrieved customer information. If not found, return NULL.")
  Order_Details: str = Field(description = "Retrieved order information. If not found, return NULL.")

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Graph State
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages, AnyMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

## Set up the agent state to concatenate messages in memory
# State schema of the v2 graph; this redefinition replaces the AgentState class from the first agent.
class AgentState(TypedDict):
    # Conversation history; the add_messages reducer merges each node's returned
    # messages into this list instead of overwriting it.
    messages: Annotated[list[AnyMessage], add_messages]
    # Structured answer written by the respond node; absent/None until that node has run.
    final_response: RetrievedInfoOutput | None

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Base LLM
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
from databricks_langchain import ChatDatabricks
# Chat model backed by the serving endpoint named in llm_endpoint (temperature fixed at 0).
model = ChatDatabricks(endpoint = llm_endpoint, temperature=0) 

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# UC functions as Tools
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
from databricks_langchain import UCFunctionToolkit

# Fully qualified (catalog.schema.function) names of the two UC retriever functions
retriever_function_names = [
    f"{catalog_}.{schema_}.{function_name}"
    for function_name in (
        "classificator_agent_customer_retriever",
        "classificator_agent_order_retriever",
    )
]

# Wrap the UC functions in a toolkit and keep the resulting tool list for the agent
toolkit = UCFunctionToolkit(function_names=retriever_function_names)
tools = toolkit.tools

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Create Tool Binding
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Bind the retriever tools to the chat model so its replies can contain tool calls for them.
model_with_tools = model.bind_tools(tools)

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Create Respond Model (LLM + Output schema)
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Variant of the base model (without the retriever tools bound) whose answers are
# requested in the RetrievedInfoOutput schema; used by the respond node.
model_with_output_schema = model.with_structured_output(RetrievedInfoOutput)

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Reasoning Node (Assistant)
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

from langgraph.graph import MessagesState
from langchain_core.messages import HumanMessage, SystemMessage

# System message
# Prepended to the conversation on every call of the assistant node.
# Unlike the first agent's prompt, this one ends with "proceed" rather than asking for a
# summary, since a separate respond node formats the final answer.
# NOTE(review): the prompt text contains the typo "emmails" (for "emails"); it is left
# untouched here because it is part of the text sent to the model.
sys_msg = SystemMessage(content="""
    You are a customer support assistant, tasked with providing more context to incoming customer communications, such as emmails.\n
    Use the available tools to perform information retrieval.\n
    If you don't have the required information to use a tool, DO NOT use it and assume that content cannot be retrieved.\n
    If a tool you've invoked returns null or an error, also assume that content cannot be retrieved.\n
    Once you've retrieved all the available information, proceed.""")

# Reasoning Node
def assistant(state: AgentState) -> AgentState:
  """Call the tool-bound model on the system prompt followed by the conversation so far.

  Args:
    state: current graph state; only `messages` is read.

  Returns:
    A partial state update containing just the model's new reply.
  """
  result = model_with_tools.invoke([sys_msg] + state["messages"])
  # `messages` is reduced with add_messages, which appends what a node returns to the
  # existing history. Return only the new reply (as the first agent's assistant does)
  # rather than re-sending the whole history and relying on id-based de-duplication.
  return {"messages": [result]}

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Respond Node: enforce schema on final answer
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

def respond_node(state: AgentState) -> AgentState:
    """Convert the most recent message into a RetrievedInfoOutput.

    The graph routes here when the assistant's latest reply contains no tool call,
    so the last entry of state["messages"] is the assistant's own closing message.
    Only that message's content is forwarded to the structured-output model.

    NOTE(review): earlier messages (original email, tool results) are not passed on, so
    the structured answer depends on what the closing message repeats — confirm this is intended.

    Returns:
        A partial state update setting `final_response`.
    """
    closing_message = state["messages"][-1]
    prompt = [HumanMessage(content=closing_message.content)]
    structured_answer = model_with_output_schema.invoke(prompt)
    return {"final_response": structured_answer}

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# ReAct Graph
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

from langgraph.graph import START, StateGraph, END
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from IPython.display import Image, display

# Graph
# State graph over AgentState with three nodes: assistant (model call), tools, and respond
# (structured final answer).
builder = StateGraph(AgentState)

# Define nodes: these do the work
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_node("respond", respond_node)

# Define the conditional edge: if there is a tool call, route to tools, else to the respond node

# Define edges: these determine how the control flow moves
builder.add_edge(START, "assistant")
builder.add_conditional_edges(
    "assistant",
    tools_condition,
    # tools_condition yields "tools" or "__end__"; the mapping sends the "__end__" outcome
    # to the respond node instead of finishing the run.
    {"tools": "tools", "__end__": "respond"} # customize mapping which routes to end by default
)
builder.add_edge("tools", "assistant") # ReAct: return tools output to assistant!
# The respond node is the last step before the graph ends
builder.add_edge("respond", END)
# Compiled without a checkpointer, so no state is persisted between invocations
react_retriever_with_output = builder.compile() # No memory needed

# Display the Graph
display(Image(react_retriever_with_output.get_graph().draw_mermaid_png()))

In [0]:
## Testing
from pyspark.sql.functions import *
from langchain_core.messages import HumanMessage
import mlflow

# Extract test case from the dataframe with examples
id_ = 2
config_ = {"configurable": {"thread_id": str(id_)}}
example_row = df_example.filter(col("ID") == id_).first()
# first() returns None when no row matches the id; fail with a clear message instead
# of an AttributeError while building the prompt below.
if example_row is None:
  raise ValueError(f"No example request found with id {id_}")
message_ = f"We received from: {example_row.sender_email} the following email: {example_row.email_body}"

# Invoke the Agent
# Enable MLflow autologging for LangChain and group this invocation under a named run
mlflow.langchain.autolog()
with mlflow.start_run(run_name="React Agent Retriever with Output schema Test"):
  request = [
    HumanMessage(content = message_)
  ]
  # invoke() returns the final graph state: the message history plus "final_response"
  messages = react_retriever_with_output.invoke({"messages": request}, config_)

  for m in messages['messages']:
      m.pretty_print()

  # Show the structured answer written by the respond node — the point of this v2 graph,
  # which the message printout alone does not include.
  print(messages.get("final_response"))