In [0]:
%pip install -U -qqqq mlflow langchain langgraph==0.3.4 databricks-langchain databricks-agents unitycatalog-langchain[databricks] uv langgraph-checkpoint-postgres==2.0.21 psycopg[binary,pool]
# Restart the Python process so the freshly installed package versions are the ones imported below.
dbutils.library.restartPython()

In [0]:
%load_ext autoreload
%autoreload 2
# autoreload re-imports edited modules (e.g. the agent.py written further down) before running code.
import warnings

# Silence every Python warning so the demo output stays readable.
# NOTE(review): this also hides deprecation warnings from the pinned libraries.
warnings.simplefilter(action="ignore")

# Mosaic AI Agent Framework: Author and deploy a Stateful Agent using Databricks Lakebase and LangGraph
This notebook demonstrates how to build a stateful agent using the Mosaic AI Agent Framework and LangGraph, with Lakebase as the agent’s durable memory and checkpoint store. In this notebook, you will:
1. Author a stateful agent graph with LangGraph, using Lakebase (the Postgres database in Databricks) to persist state per thread ID
2. Wrap the LangGraph agent with MLflow ChatAgent to ensure compatibility with Databricks features
3. Test the agent's behavior locally
4. Log the agent with MLflow, register it to Unity Catalog, and deploy it for use in apps and the AI Playground

We use LangGraph's `PostgresSaver` checkpointer ([PostgresSaver in Langgraph](https://api.python.langchain.com/en/latest/checkpoint/langchain_postgres.checkpoint.PostgresSaver.html)): we open a connection to Lakebase, wrap it in a `PostgresSaver`, and compile the LangGraph agent with that checkpointer.

## Why use Lakebase?
Stateful agents need a place to persist, resume, and inspect their work. Lakebase provides a managed, UC-governed store for agent state:
- Durable, resumable state. Automatically capture threads, intermediate checkpoints, tool outputs, and node state after each graph step—so you can resume, branch, or replay any point in time.
- Queryable & observable. Because state lands in the Lakehouse, you can use SQL (or notebooks) to audit conversations and build on other Databricks functionality, such as dashboards.
- Governed by Unity Catalog. Apply data permissions, lineage, and auditing to AI state, just like any other table.

## What are Stateful Agents?
Unlike stateless LLM calls, a stateful agent keeps and reuses context across steps and sessions. Each new conversation is tracked with a thread ID, which represents the logical task or dialogue stream. This way, you can pick up an existing thread and continue the conversation with your Agent.

## Prerequisites
- Create a Lakebase instance, see Databricks documentation ([AWS](https://docs.databricks.com/aws/en/oltp/create/) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/oltp/create/)). 
- You can create a Lakebase instance by going to SQL Warehouses -> Lakebase Postgres -> Create database instance. You will need to retrieve values from the "Connection details" section of your Lakebase to fill out this notebook.

In [0]:
# Unity Catalog location of the registered agent. Each widget holds ONE name part:
# the model is registered as <catalog>.<schema>.<model>, so a dotted default such as
# "catalog.schema" would produce an invalid four-part name.
# TODO: set "catalog" to a catalog you can create models in.
dbutils.widgets.text(name="catalog", defaultValue="main", label="catalog")
dbutils.widgets.text(name="schema", defaultValue="agents", label="schema")
dbutils.widgets.text(name="model", defaultValue="memory_agent", label="model")

# Optional service-principal OAuth credentials; stored in the secret scope when provided.
dbutils.widgets.text(
    name="DATABRICKS_CLIENT_ID", defaultValue="", label="DATABRICKS_CLIENT_ID"
)
dbutils.widgets.text(
    name="DATABRICKS_CLIENT_SECRET", defaultValue="", label="DATABRICKS_CLIENT_SECRET"
)
dbutils.widgets.text(name="secret_scope", defaultValue="dbdemos", label="secret_scope")

In [0]:
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
model = dbutils.widgets.get("model")
assert (
    len(catalog) > 0 and len(schema) > 0 and len(model) > 0
), "Please provide a valid catalog, schema, and model name"
# Each widget must be a single identifier; a dotted value (e.g. "catalog.schema")
# would turn the three-level Unity Catalog name below into an invalid four-part name.
assert not any(
    "." in part for part in (catalog, schema, model)
), "catalog, schema, and model must each be a single name without '.'"
three_tiered_model_name = f"{catalog}.{schema}.{model}"
print(f"{three_tiered_model_name=}")

In [0]:
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()

DATABRICKS_HOST = w.config.host

secret_scope_name = dbutils.widgets.get("secret_scope")

# Create the secret scope only when it does not exist yet. Checking first makes the cell
# safe to re-run (create_scope fails on an existing scope) and also creates the default
# "dbdemos" scope when a workspace does not have it.
existing_scope_names = {scope.name for scope in w.secrets.list_scopes()}
if secret_scope_name in existing_scope_names:
    print(f"Using existing secret scope: {secret_scope_name}")
else:
    w.secrets.create_scope(scope=secret_scope_name)
    print(f"Created secret scope: {secret_scope_name}")

In [0]:
# Store the service-principal OAuth credentials (when provided through the widgets) and
# the workspace host in the secret scope; later cells and the deployed endpoint read
# them back from there.
for secret_key in ("DATABRICKS_CLIENT_ID", "DATABRICKS_CLIENT_SECRET"):
    secret_value = dbutils.widgets.get(secret_key)
    if secret_value == "":
        # Report the key that is actually missing (previously both branches said CLIENT_ID).
        print(f"no {secret_key} is provided")
    else:
        w.secrets.put_secret(
            scope=secret_scope_name,
            key=secret_key,
            string_value=secret_value,
        )
w.secrets.put_secret(
    scope=secret_scope_name, key="DATABRICKS_HOST", string_value=DATABRICKS_HOST
)

In [0]:
# import os

# os.environ["DATABRICKS_CLIENT_ID"] = dbutils.secrets.get(
#     scope=secret_scope_name, key="DATABRICKS_CLIENT_ID"
# )
# os.environ["DATABRICKS_CLIENT_SECRET"] = dbutils.secrets.get(
#     scope=secret_scope_name, key="DATABRICKS_CLIENT_SECRET"
# )

# os.unsetenv("DATABRICKS_CLIENT_ID")
# os.unsetenv("DATABRICKS_CLIENT_SECRET")

## Lakebase Config
- Enable Postgres native role login on the Lakebase instance.
- After creating or changing Postgres roles, you may need to wait a few minutes for them to take effect.
- For querying purposes, create a new catalog from the Lakebase instance that points at the PostgreSQL database `databricks_postgres`.

In [0]:
# Look up the Lakebase instance and display its details. The lookup fails if the name is wrong.
# TODO: replace with the name of your own Lakebase instance.
lakebase_instance_name = "bo-test-lakebase-3"
w.database.get_database_instance(name=lakebase_instance_name)

In [0]:
import uuid
import os
import psycopg
from psycopg_pool import ConnectionPool
from langgraph.checkpoint.postgres import PostgresSaver

# Authenticate as the service principal whose OAuth credentials were stored in the
# secret scope; its client id is also the Postgres role used to log in below.
w = WorkspaceClient(
    host=dbutils.secrets.get(scope=secret_scope_name, key="DATABRICKS_HOST"),
    client_id=dbutils.secrets.get(scope=secret_scope_name, key="DATABRICKS_CLIENT_ID"),
    client_secret=dbutils.secrets.get(
        scope=secret_scope_name, key="DATABRICKS_CLIENT_SECRET"
    ),
)

# TODO: replace with the values from your Lakebase instance's "Connection details".
instance_name = "bo-test-lakebase-3"
conn_host = (
    "instance-71597a8a-7e99-4c85-b29a-f751a73ecb85.database.cloud.databricks.com"
)
conn_db_name = "databricks_postgres"
conn_ssl_mode = "require"
conn_port = "5432"
pg_user = dbutils.secrets.get(scope=secret_scope_name, key="DATABRICKS_CLIENT_ID")


def db_password_provider() -> str:
    """
    Ask Databricks to mint a fresh DB credential for this instance.
    Called only when the pool needs a new physical connection.
    """
    cred = w.database.generate_database_credential(
        request_id=str(uuid.uuid4()),
        instance_names=[instance_name],
    )
    return cred.token


class CustomConnection(psycopg.Connection):
    """
    A psycopg Connection subclass that injects a fresh password
    *at connection time* (only when the pool creates a new connection).
    """

    @classmethod
    def connect(cls, conninfo="", **kwargs):
        # Database credentials are short-lived tokens, so mint one per physical connection
        # instead of baking a password into the conninfo string.
        kwargs["password"] = db_password_provider()
        return super().connect(conninfo, **kwargs)


# The pool is only needed here to create the checkpoint tables, so use it as a context
# manager: it is closed (with its connections) when the block exits instead of leaking.
with ConnectionPool(
    conninfo=f"dbname={conn_db_name} user={pg_user} host={conn_host} port={conn_port} sslmode={conn_ssl_mode}",
    connection_class=CustomConnection,
    min_size=1,
    max_size=10,
    open=True,
    kwargs={
        "autocommit": True,
        "keepalives": 1,
        "keepalives_idle": 30,
        "keepalives_interval": 10,
        "keepalives_count": 5,
    },
) as pool:
    with pool.connection() as conn:
        # Cheap round-trip to fail fast if authentication or networking is misconfigured.
        with conn.cursor() as cur:
            cur.execute("select 1")

        checkpointer = PostgresSaver(conn)
        checkpointer.setup()
        print("✅ Pool connected and checkpoint tables are ready.")

In [0]:
%%writefile agent.py
from typing import Any, Generator, Optional, Sequence, Union
import mlflow
from databricks_langchain import (
    ChatDatabricks,
    VectorSearchRetrieverTool,
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph, MessagesState, START
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.prebuilt import create_react_agent
import os
from langchain_core.messages import HumanMessage, AIMessage
import uuid
from databricks.sdk import WorkspaceClient
import urllib.parse
from databricks_ai_bridge import ModelServingUserCredentials
from psycopg_pool import ConnectionPool

# Trace LangChain / LangGraph calls made by this agent with MLflow.
mlflow.langchain.autolog()

# Make the Unity Catalog function tools defined below use this Databricks function client.
client = DatabricksFunctionClient()
set_uc_function_client(client)

############################################
# Define your LLM endpoint and system prompt
############################################
LLM_ENDPOINT_NAME = "databricks-claude-3-7-sonnet"

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/generative-ai/agent-framework/agent-tool.html
###############################################################################

# Unity Catalog functions exposed to the agent as tools.
# TODO: point these at UC functions in your own catalog/schema.
uc_tool_names = [
    "bo_cheng_dnb_demos.agents.get_cyber_threat_info",
    "bo_cheng_dnb_demos.agents.get_user_info",
]
uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)

# Start the tool list from the UC toolkit; more tools (e.g. vector search) can be appended.
tools = list(uc_toolkit.tools)

# # (Optional) Use Databricks vector search indexes as tools
# # See https://docs.databricks.com/generative-ai/agent-framework/unstructured-retrieval-tools.html
# # for details
#
# # TODO: Add vector search indexes as tools or delete this block
# vector_search_index_tools = [
#     VectorSearchRetrieverTool(
#         index_name="bo_cheng_dnb_demos.agents.poc_customer_support_index",
#         num_results=3,
#         tool_name="customer_support_retriever",
#         tool_description="Retrieves information about customer support responses",
#         query_type="ANN",
#     )
# ]
# tools.extend(vector_search_index_tools)

#####################
## Define agent logic
#####################


class LangGraphChatAgent(ChatAgent):
    """MLflow ChatAgent running a LangGraph tool-calling loop with Lakebase-backed memory.

    Conversation state is checkpointed in a Lakebase (Postgres) database through
    LangGraph's ``PostgresSaver``, keyed by the ``thread_id`` passed in
    ``custom_inputs``; reusing a thread_id continues that conversation.

    Args:
        config: dict with required keys ``conn_db_name``, ``conn_ssl_mode``,
            ``conn_host`` and ``instance_name``, and optional keys
            ``llm_model_serving_endpoint_name``, ``llm_prompt_template`` and
            ``conn_port`` (defaults to 5432).
        tools: LangChain tools the model may call.
    """

    def __init__(self, config, tools):
        self.config = config
        self.conn_db_name = self.config["conn_db_name"]
        self.conn_ssl_mode = self.config["conn_ssl_mode"]
        self.conn_host = self.config["conn_host"]
        self.conn_port = self.config.get("conn_port", 5432)
        self.instance_name = self.config["instance_name"]
        self.tools = tools
        # Bind the tool schemas once; predict() reuses this runnable for every request.
        self.model = ChatDatabricks(
            endpoint=self.config.get("llm_model_serving_endpoint_name"), temperature=0.1
        ).bind_tools(self.tools)
        self.system_prompt = self._render_system_prompt(
            self.config.get("llm_prompt_template"), self.tools
        )
        # Pool sizing can be tuned per deployment through environment variables.
        self.pool_min_size = int(os.getenv("DB_POOL_MIN_SIZE", "1"))
        self.pool_max_size = int(os.getenv("DB_POOL_MAX_SIZE", "10"))
        self.pool_timeout = float(os.getenv("DB_POOL_TIMEOUT", "30.0"))

    @staticmethod
    def _render_system_prompt(template, tools):
        """Fill the ``{tools}`` placeholder of the prompt template with the tool names."""
        if not template:
            return template
        tool_names = ", ".join(getattr(tool, "name", str(tool)) for tool in tools)
        # str.replace instead of str.format: the template may contain other braces.
        return template.replace("{tools}", tool_names)

    @staticmethod
    def _resolve_thread_id(custom_inputs):
        """Return the caller's thread_id as a string, or a fresh UUID when none is given."""
        thread_id = (custom_inputs or {}).get("thread_id")
        # A missing or empty id starts a new conversation instead of reaching the
        # checkpointer as None.
        if thread_id is None or thread_id == "":
            return str(uuid.uuid4())
        return str(thread_id)

    def _get_oauth_connection_string(self):
        """Mint a fresh Lakebase OAuth credential and return a Postgres connection URI.

        The Postgres user is the service principal's application id when the ambient
        WorkspaceClient identity is a service principal; otherwise it is the current
        user's name (URL-encoded), which is what happens during local notebook testing.
        """
        self.w = WorkspaceClient()
        try:
            sp_username = self.w.current_service_principal.me().application_id
        except Exception:
            # Not running as a service principal: fall back to the user's own login.
            user = self.w.current_user.me()
            sp_username = urllib.parse.quote_plus(user.user_name)

        pg_credential = self.w.database.generate_database_credential(
            request_id=str(uuid.uuid4()), instance_names=[self.instance_name]
        )
        # Encode the token so any reserved URI character cannot corrupt the URI.
        password = urllib.parse.quote_plus(pg_credential.token)

        return (
            f"postgresql://{sp_username}:{password}@{self.conn_host}:{self.conn_port}"
            f"/{self.conn_db_name}?sslmode={self.conn_ssl_mode}"
        )

    def _build_graph(self, checkpointer):
        """Compile the agent <-> tools LangGraph loop, persisting state via ``checkpointer``."""
        if self.system_prompt:
            preprocessor = RunnableLambda(
                lambda state: [{"role": "system", "content": self.system_prompt}]
                + state["messages"]
            )
        else:
            preprocessor = RunnableLambda(lambda state: state["messages"])
        model_runnable = preprocessor | self.model

        def should_continue(state: ChatAgentState):
            # Route to the tool node while the model keeps requesting tool calls.
            last_message = state["messages"][-1]
            return "continue" if last_message.get("tool_calls") else "end"

        def call_model(
            state: ChatAgentState,
            config: RunnableConfig,
        ):
            return {"messages": [model_runnable.invoke(state, config)]}

        workflow = StateGraph(ChatAgentState)
        workflow.add_node("agent", RunnableLambda(call_model))
        workflow.add_node("tools", ChatAgentToolNode(self.tools))
        workflow.set_entry_point("agent")
        workflow.add_conditional_edges(
            "agent",
            should_continue,
            {
                "continue": "tools",
                "end": END,
            },
        )
        workflow.add_edge("tools", "agent")
        return workflow.compile(checkpointer=checkpointer)

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Run one conversational turn on the thread named by ``custom_inputs["thread_id"]``.

        Returns every message produced by the graph during this turn.
        """
        thread_id = self._resolve_thread_id(custom_inputs)
        request = {"messages": self._convert_messages_to_dict(messages)}
        checkpoint_config = {"configurable": {"thread_id": thread_id}}

        # Credentials are short-lived, so build a fresh connection string per request.
        conn_info = self._get_oauth_connection_string()

        response_messages = []
        # Enter the pool itself as a context manager so it (and its connections) is closed
        # after the request; releasing only the borrowed connection leaked one pool per call.
        with ConnectionPool(
            conninfo=conn_info,
            min_size=self.pool_min_size,
            max_size=self.pool_max_size,
            timeout=self.pool_timeout,
            kwargs={
                "autocommit": True,
                "keepalives": 1,
                "keepalives_idle": 30,
                "keepalives_interval": 10,
                "keepalives_count": 5,
            },
        ) as pool:
            with pool.connection() as conn:
                graph = self._build_graph(PostgresSaver(conn))
                for event in graph.stream(
                    request, checkpoint_config, stream_mode="updates"
                ):
                    for node_data in event.values():
                        response_messages.extend(
                            ChatAgentMessage(**msg)
                            for msg in (node_data or {}).get("messages", [])
                        )
        return ChatAgentResponse(messages=response_messages)


# Agent configuration.
# TODO: replace the Lakebase values with the "Connection details" of your own instance.
config = {
    # Reuse the module-level constant so the endpoint name is defined in one place.
    "llm_model_serving_endpoint_name": LLM_ENDPOINT_NAME,
    "llm_prompt_template": """
    You are an cybersecurity assistant.
    You are given a task and you must complete it.
    Use the following routine to support the customer.
    # Routine:
    1. Provide the get_cyber_threat_info tool the type of threat being asked about.
    2. Use the source ip address provided in step 1 as input for the get_user_info tool to retrieve user specific info.
    Use the following tools to complete the task:
    {tools}""",
    "conn_db_name": "databricks_postgres",
    "conn_ssl_mode": "require",
    "conn_host": "instance-71597a8a-7e99-4c85-b29a-f751a73ecb85.database.cloud.databricks.com",
    "conn_port": 5432,
    "instance_name": "bo-test-lakebase-3",
}

# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model()
AGENT = LangGraphChatAgent(config=config, tools=tools)
mlflow.models.set_model(AGENT)

In [0]:
from agent import AGENT

# Local smoke test. Reusing the same thread_id continues that conversation from its
# checkpointed state; other prompts worth trying on one thread:
#   "My name is Bo" followed by "What is my name?"
local_test_request = {
    "messages": [
        {
            "role": "user",
            "content": "Who committed the latest malware threat?",
        }
    ],
    "custom_inputs": {"thread_id": "4"},
}
AGENT.predict(local_test_request)

In [0]:
# Determine Databricks resources to specify for automatic auth passthrough at deployment time
import mlflow
from databricks_langchain import VectorSearchRetrieverTool
from mlflow.models.resources import (
    DatabricksFunction,
    DatabricksServingEndpoint,
    DatabricksLakebase,
    DatabricksVectorSearchIndex,
)  # we are adding DatabricksLakebase resource type
from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy
from unitycatalog.ai.langchain.toolkit import UnityCatalogTool
from agent import LLM_ENDPOINT_NAME, config as agent_config, tools

# Declare the Lakebase instance the agent actually connects to (agent.py `config`).
# A hard-coded, different name here would grant passthrough access to the wrong instance.
# TODO: Manually include additional underlying resources if needed
resources = [
    DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME),
    DatabricksLakebase(database_instance_name=agent_config["instance_name"]),
    # DatabricksVectorSearchIndex(
    #     index_name="catalog.schema.agents.poc_customer_support_index"
    # ),
]
for tool in tools:
    if isinstance(tool, VectorSearchRetrieverTool):
        resources.extend(tool.resources)
    elif isinstance(tool, UnityCatalogTool):
        resources.append(DatabricksFunction(function_name=tool.uc_function_name))

# The two policies below only take effect if `auth_policy=` is enabled in log_model;
# with the current call, `resources=` alone drives auth passthrough.
# System policy: resources accessed with system credentials
system_policy = SystemAuthPolicy(resources=resources)

# User policy: API scopes for OBO access
api_scopes = [
    "sql.statement-execution",
    "mcp.genie",
    "mcp.external",
    "catalog.connections",
    "mcp.vectorsearch",
    "vectorsearch.vector-search-indexes",
    "iam.current-user:read",
    "sql.warehouses",
    "dashboards.genie",
    "serving.serving-endpoints",
    "iam.access-control:read",
    "apps.apps",
    "mcp.functions",
    "vectorsearch.vector-search-endpoints",
]
user_policy = UserAuthPolicy(api_scopes=api_scopes)

input_example = {
    "messages": [{"role": "user", "content": "What is an LLM agent?"}],
    "custom_inputs": {"thread_id": "1"},
}

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        artifact_path="agent",
        python_model="agent.py",
        input_example=input_example,
        extra_pip_requirements=[
            "databricks-connect",
            "databricks-agents",
            "databricks-langchain",
            "unitycatalog-langchain[databricks]",
            "psycopg[binary,pool]",
            "langgraph-checkpoint-postgres==2.0.21",
            "langgraph==0.3.4",
            "langchain",
        ],
        # auth_policy=AuthPolicy(
        #     system_auth_policy=system_policy, user_auth_policy=user_policy
        # ),
        resources=resources,
    )

In [0]:
# Run the logged model once in a fresh environment built with uv, so missing
# dependencies surface before the model is registered and deployed.
logged_model_uri = f"runs:/{logged_agent_info.run_id}/agent"
mlflow.models.predict(
    model_uri=logged_model_uri,
    input_data=input_example,
    env_manager="uv",
)

In [0]:
# Register in Unity Catalog (not the workspace model registry).
mlflow.set_registry_uri("databricks-uc")

# Three-level UC name assembled from the catalog / schema / model widgets.
UC_MODEL_NAME = f"{catalog}.{schema}.{model}"

uc_registered_model_info = mlflow.register_model(
    name=UC_MODEL_NAME,
    model_uri=logged_agent_info.model_uri,
)

In [0]:
from databricks import agents


def _secret_ref(key: str) -> str:
    """Return a serving-endpoint secret reference ``{{secrets/<scope>/<key>}}`` for ``key``."""
    # Use the scope chosen in the `secret_scope` widget; it is where the secrets were stored.
    return "{{secrets/" + secret_scope_name + "/" + key + "}}"


agents.deploy(
    UC_MODEL_NAME,
    uc_registered_model_info.version,
    environment_vars={
        key: _secret_ref(key)
        for key in (
            "DATABRICKS_HOST",
            "DATABRICKS_CLIENT_ID",
            "DATABRICKS_CLIENT_SECRET",
        )
    },
    tags={"endpointSource": "playground"},
)

In [0]:
from databricks import agents

# Grant CAN_QUERY on the deployed agent. Note that you can specify individual users or groups.
# NOTE(review): "users" is presumably meant as the workspace-wide `users` group; confirm that
# agents.set_permissions accepts group names in the `users` argument.
agents.set_permissions(
    model_name=UC_MODEL_NAME,
    users=["users"],
    permission_level=agents.PermissionLevel.CAN_QUERY,
)

## Next steps
After your agent is deployed, you can chat with it in the AI Playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application. See the docs for details.

In [0]:
import time
from databricks.sdk.service.serving import EndpointStateReady, EndpointStateConfigUpdate
from databricks.sdk import WorkspaceClient

endpoint_name: str = f"agents_{catalog}-{schema}-{model}"
# Give up after this long instead of polling forever (deployments usually take 10-20 min).
max_wait_seconds = 60 * 60
poll_interval_seconds = 30

print("\nWaiting for endpoint to deploy.  This can take 10 - 20 minutes.", end="")
w = WorkspaceClient()
deadline = time.monotonic() + max_wait_seconds
while True:
    # Fetch the state once per poll so both checks below see the same snapshot.
    state = w.serving_endpoints.get(endpoint_name).state
    if state.config_update == EndpointStateConfigUpdate.UPDATE_FAILED:
        # Without this check a failed update that leaves the endpoint NOT_READY loops forever.
        raise RuntimeError(
            f"Serving endpoint {endpoint_name} failed to update; check its events and logs."
        )
    still_deploying = (
        state.ready == EndpointStateReady.NOT_READY
        or state.config_update == EndpointStateConfigUpdate.IN_PROGRESS
    )
    if not still_deploying:
        break
    if time.monotonic() >= deadline:
        raise TimeoutError(
            f"Serving endpoint {endpoint_name} not ready after {max_wait_seconds} seconds."
        )
    print(".", end="")
    time.sleep(poll_interval_seconds)
print("\nEndpoint is ready.")

In [0]:
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
endpoint_name: str = f"agents_{catalog}-{schema}-{model}"

# Follow-up question on thread "4", the same thread id used in the local test above,
# presumably so the answer can draw on that thread's checkpointed history.
follow_up_request = {
    "messages": [
        {
            "role": "user",
            "content": "Who was just mentioned for a cybersecurity incident?",
        }
    ],
    "custom_inputs": {"thread_id": "4"},
}
response = w.serving_endpoints.query(
    name=endpoint_name,
    dataframe_records=[follow_up_request],
    temperature=0.1,
)

In [0]:
# Display the raw response returned by the serving endpoint query.
response