# Mosaic AI Agent Framework: Author and deploy a Stateful Agent with Long-term memory using Databricks Lakebase as a Store
This notebook demonstrates how to build a stateful agent that stores and retrieves user preferences using the Mosaic AI Agent Framework, with Lakebase as the agent's memory store.

In this notebook, you will:
1. Author a LangGraph agent with long-term memory backed by Lakebase, which stores and recalls users' preferences (via semantic search in the store)
2. Wrap the LangGraph agent with `ResponsesAgent` interface to ensure compatibility with Databricks features
3. Test the agent's behavior locally
4. Log the agent, register the model to Unity Catalog, and deploy it for use in apps and Playground
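
As a rough illustration of step 1, the store-and-recall pattern can be sketched with a toy in-memory stand-in. This is not `DatabricksStore`: the `ToyMemoryStore` class and its word-overlap ranking are hypothetical and only mimic the namespace and key layout, whereas the real store ranks results by embedding similarity.

```python
import json


class ToyMemoryStore:
    """Hypothetical in-memory stand-in for a namespaced memory store."""

    def __init__(self):
        # Maps (namespace tuple, key) -> stored dict value.
        self._items = {}

    def put(self, namespace, key, value):
        self._items[(namespace, key)] = value

    def delete(self, namespace, key):
        self._items.pop((namespace, key), None)

    def search(self, namespace, query, limit=5):
        # Rank by number of shared lowercase words: a crude substitute
        # for the embedding-based semantic search of the real store.
        query_words = set(query.lower().split())
        scored = []
        for (ns, key), value in self._items.items():
            if ns != namespace:
                continue
            text_words = set(json.dumps(value).lower().replace('"', " ").split())
            score = len(query_words & text_words)
            if score > 0:
                scored.append((score, key, value))
        scored.sort(key=lambda item: item[0], reverse=True)
        return [(key, value) for _, key, value in scored[:limit]]


store = ToyMemoryStore()
alice = ("user_memories", "alice")
store.put(alice, "platform", {"data_platform": "Databricks"})
store.put(alice, "pet", {"dog_name": "Fluffy"})
matches = store.search(alice, "which data_platform do I use")
print(matches)
```

Because every operation is scoped by the namespace tuple, memories saved for one user are never returned for another.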

## Prerequisites
- Have a Lakebase instance ready and running, see Databricks documentation ([AWS](https://docs.databricks.com/aws/en/oltp/create/) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/oltp/create/)). 
- You can create a Lakebase instance by going to SQL Warehouses -> Lakebase Postgres -> Create database instance. You will need to retrieve values from the "Connection details" section of your Lakebase to fill out this notebook.
- Complete all the "TODO"s throughout this notebook

### Install dependencies

In [0]:
%pip install -U -qqqq uv databricks-agents mlflow-skinny[databricks] databricks-langchain[memory]
dbutils.library.restartPython()

## First time setup only: Set up store tables for your Lakebase instance

In [0]:
from databricks.sdk import WorkspaceClient
from databricks_langchain import DatabricksStore

# TODO: Fill in your Lakebase config values
LAKEBASE_INSTANCE_NAME = "lakebase-name"

store = DatabricksStore(instance_name=LAKEBASE_INSTANCE_NAME)
store.setup()

# Define the agent in code

## Write agent code to file agent.py
Define the agent code in a single cell below. This lets you write the agent code to a local Python file, using the `%%writefile` magic command, for subsequent logging and deployment.

## Wrap the LangGraph agent using the ResponsesAgent interface
For compatibility with Databricks AI features, the `LangGraphResponsesAgent` class implements the `ResponsesAgent` interface to wrap the LangGraph agent.

Databricks recommends using `ResponsesAgent` as it simplifies authoring multi-turn conversational agents using an open source standard. See MLflow's [ResponsesAgent documentation](https://www.mlflow.org/docs/latest/llms/responses-agent-intro/).
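
The agent in this notebook implements `predict` by draining `predict_stream` and keeping only the completed output items. A minimal standalone sketch of that pattern follows; `ToyStreamEvent` and `ToyAgent` are hypothetical stand-ins, not MLflow classes, and only the two event type strings are taken from the real interface.

```python
from dataclasses import dataclass
from typing import Any, Iterator


@dataclass
class ToyStreamEvent:
    """Hypothetical stand-in for a streaming event."""

    type: str
    item: Any = None
    delta: str = ""


class ToyAgent:
    def predict_stream(self, text: str) -> Iterator[ToyStreamEvent]:
        # Emit incremental text deltas first, then one completed item.
        words = f"You said: {text}".split(" ")
        for word in words:
            yield ToyStreamEvent(type="response.output_text.delta", delta=word + " ")
        yield ToyStreamEvent(
            type="response.output_item.done",
            item={"role": "assistant", "content": " ".join(words)},
        )

    def predict(self, text: str) -> list:
        # Non-streaming call: drain the stream and keep only finished items.
        return [
            event.item
            for event in self.predict_stream(text)
            if event.type == "response.output_item.done"
        ]


agent = ToyAgent()
outputs = agent.predict("hello")
print(outputs)
```

Keeping a single code path for both modes means streaming and non-streaming callers see the same final items.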

In [0]:
%%writefile agent.py
import json
import logging
import os
from typing import Annotated, Any, Generator, Optional, Sequence, TypedDict

import mlflow
from databricks_langchain import (
    ChatDatabricks,
    DatabricksStore,
    UCFunctionToolkit,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
)
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import tool
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
    output_to_responses_items_stream,
    to_chat_completions_input,
)

logger = logging.getLogger(__name__)
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))


############################################
# Define your LLM endpoint and system prompt
############################################
# TODO: Replace with your model serving endpoint
LLM_ENDPOINT_NAME = "databricks-claude-3-7-sonnet"

# TODO: Update with your system prompt
SYSTEM_PROMPT = "You are a helpful assistant. Use the available tools to answer questions."

# TODO: Fill in values for your lakebase instance for agent to use here
LAKEBASE_INSTANCE_NAME = "lakebase-name"

# TODO: Update with your desired embedding configuration values for semantic memory search
# Example Model Serving endpoint for text embeddings https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models#gte-large-en
EMBEDDING_ENDPOINT = "databricks-gte-large-en"  
EMBEDDING_DIMS = 1024

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/en/generative-ai/agent-framework/agent-tool.html
###############################################################################

tools = []

# Example UC tools; add your own as needed
UC_TOOL_NAMES: list[str] = []
if UC_TOOL_NAMES:
    uc_toolkit = UCFunctionToolkit(function_names=UC_TOOL_NAMES)
    tools.extend(uc_toolkit.tools)

# Use Databricks vector search indexes as tools
# See https://docs.databricks.com/en/generative-ai/agent-framework/unstructured-retrieval-tools.html#locally-develop-vector-search-retriever-tools-with-ai-bridge
# List to store vector search tool instances for unstructured retrieval.
VECTOR_SEARCH_TOOLS = []

# To add vector search retriever tools,
# import VectorSearchRetrieverTool from databricks_langchain,
# then append an instance to VECTOR_SEARCH_TOOLS.
# Example:
# VECTOR_SEARCH_TOOLS.append(
#     VectorSearchRetrieverTool(
#         index_name="",
#         # filters="..."
#     )
# )

tools.extend(VECTOR_SEARCH_TOOLS)

#####################
## Define agent logic
#####################


class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], add_messages]
    custom_inputs: Optional[dict[str, Any]]
    custom_outputs: Optional[dict[str, Any]]
    user_id: Optional[str]


class LangGraphResponsesAgent(ResponsesAgent):
    """Stateless agent using ResponsesAgent with user-based long-term memory.

    Features:
    - Connection pooling with credential rotation via DatabricksStore
    - User-based long-term memory persistence (memories stored under the ("user_memories", user_id) namespace) in the "store" table
    - Tool support with UC functions
    - Automatic connection management - borrows connections per operation for scalability
    """

    def __init__(self):
        self.lakebase_instance_name = LAKEBASE_INSTANCE_NAME
        self.system_prompt = SYSTEM_PROMPT
        self.model = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

        self._store = None
        self._memory_tools = None

    @property
    def store(self):
        """Lazy initialization of DatabricksStore with semantic search support."""
        if self._store is None:
            logger.info(f"Initializing DatabricksStore with instance: {self.lakebase_instance_name} and embedding endpoint {EMBEDDING_ENDPOINT} with dims {EMBEDDING_DIMS}")
            self._store = DatabricksStore(
                instance_name=self.lakebase_instance_name,
                embedding_endpoint=EMBEDDING_ENDPOINT,
                embedding_dims=EMBEDDING_DIMS,
            )
            self._store.setup()
        return self._store

    @property
    def memory_tools(self):
        """Lazy initialization of memory tools."""
        if self._memory_tools is None:
            logger.info("Creating memory tools")
            self._memory_tools = self._create_memory_tools()
        return self._memory_tools

    @property
    def model_with_all_tools(self):
        all_tools = tools + self.memory_tools
        return self.model.bind_tools(all_tools) if all_tools else self.model

    def _create_memory_tools(self):
        """Create tools for reading and writing long-term memory."""

        @tool
        def get_user_memory(query: str, config: RunnableConfig) -> str:
            """Search for relevant information about the user from long-term memory using semantic search via vector embeddings.

            Use this tool to retrieve previously saved information about the user,
            such as their preferences, facts they've shared, or other personal details.

            Args:
                query: Natural-language description of the information to look up (e.g., "favorite color")
            """
            user_id = config.get("configurable", {}).get("user_id")
            if not user_id:
                return "Memory not available - no user_id provided."

            namespace = ("user_memories", user_id.replace(".", "-"))

            results = self.store.search(namespace, query=query, limit=5)

            if not results:
                return "No memories found for this user."

            memory_items = []
            for item in results:
                memory_items.append(f"- [{item.key}]: {json.dumps(item.value)}")

            return f"Found {len(results)} relevant memories (ranked by semantic similarity):\n" + "\n".join(memory_items)

        @tool
        def save_user_memory(memory_key: str, memory_data_json: str, config: RunnableConfig) -> str:
            """Save information about the user to long-term memory with vector embeddings.

            Use this tool to remember important information the user shares about themselves,
            such as preferences, facts, or other personal details.

            Args:
                memory_key: A descriptive key for this memory (e.g., "preferences", "favorite_color", "background_info")
                memory_data_json: JSON string with the information to remember.
                    Example: '{"favorite_color": "purple"}'
            """
            user_id = config.get("configurable", {}).get("user_id")
            if not user_id:
                return "Cannot save memory - no user_id provided."

            namespace = ("user_memories", user_id.replace(".", "-"))

            try:
                memory_data = json.loads(memory_data_json)
                # Validate that memory_data is a dictionary (not a list or other type)
                if not isinstance(memory_data, dict):
                    return f"Failed to save memory: memory_data must be a JSON object (dictionary), not {type(memory_data).__name__}. Example: '{{\"key\": \"value\"}}'"
                self.store.put(namespace, memory_key, memory_data)
                return f"Successfully saved memory with key '{memory_key}' for user."
            except json.JSONDecodeError as e:
                return f"Failed to save memory: Invalid JSON format - {str(e)}"

        @tool
        def delete_user_memory(memory_key: str, config: RunnableConfig) -> str:
            """Delete a specific memory from the user's long-term memory.

            Use this tool when the user asks you to forget something or remove
            a piece of information from their memory.

            Args:
                memory_key: The key of the memory to delete (e.g., "preferences", "likes", "background_info")
            """
            user_id = config.get("configurable", {}).get("user_id")
            if not user_id:
                return "Cannot delete memory - no user_id provided."

            namespace = ("user_memories", user_id.replace(".", "-"))

            self.store.delete(namespace, memory_key)
            return f"Successfully deleted memory with key '{memory_key}' for user."

        return [get_user_memory, save_user_memory, delete_user_memory]

    def _create_graph(self):
        """Create the LangGraph workflow"""
        def should_continue(state: AgentState):
            messages = state["messages"]
            last_message = messages[-1]
            if isinstance(last_message, AIMessage) and last_message.tool_calls:
                return "continue"
            return "end"

        model_with_tools = self.model_with_all_tools

        if self.system_prompt:
            preprocessor = RunnableLambda(
                lambda state: [{"role": "system", "content": self.system_prompt}] + state["messages"]
            )
        else:
            preprocessor = RunnableLambda(lambda state: state["messages"])

        model_runnable = preprocessor | model_with_tools

        def call_model(state: AgentState, config: RunnableConfig):
            response = model_runnable.invoke(state, config)
            return {"messages": [response]}

        workflow = StateGraph(AgentState)
        workflow.add_node("agent", RunnableLambda(call_model))

        active_tools = (tools + self.memory_tools)

        if active_tools:
            workflow.add_node("tools", ToolNode(active_tools))
            workflow.add_conditional_edges(
                "agent",
                should_continue,
                {"continue": "tools", "end": END}
            )
            workflow.add_edge("tools", "agent")
        else:
            workflow.add_edge("agent", END)

        workflow.set_entry_point("agent")

        return workflow.compile()

    def _get_user_id(self, request: ResponsesAgentRequest) -> Optional[str]:
        """
        Use user_id from chat context if available, return None if not provided
        """
        # User id from chat context as user id to store memories
        # https://mlflow.org/docs/latest/api_reference/python_api/mlflow.types.html#mlflow.types.agent.ChatContext
        if request.context and getattr(request.context, "user_id", None):
            return request.context.user_id
        return None

    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        """Non-streaming prediction"""
        outputs = [
            event.item
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
        return ResponsesAgentResponse(output=outputs)

    def predict_stream(
        self,
        request: ResponsesAgentRequest,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """Streaming prediction"""
        user_id = self._get_user_id(request)

        # If there is no user_id, we cannot retrieve memories
        if not user_id:
            logger.error(
                "Cannot store or retrieve memories without a user_id."
            )

        ci = dict(request.custom_inputs or {})
        if user_id:
            ci["user_id"] = user_id
        request.custom_inputs = ci

        cc_msgs = to_chat_completions_input([i.model_dump() for i in request.input])

        run_config = {"configurable": {}}
        if user_id:
            run_config["configurable"]["user_id"] = user_id

        graph = self._create_graph()

        state_input = {"messages": cc_msgs}
        if user_id:
            state_input["user_id"] = user_id

        # Stream the graph execution
        for event in graph.stream(
            state_input,
            run_config,
            stream_mode=["updates", "messages"]
        ):
            if event[0] == "updates":
                for node_data in event[1].values():
                    if len(node_data.get("messages", [])) > 0:
                        yield from output_to_responses_items_stream(node_data["messages"])
            # Stream message chunks for real-time text generation
            elif event[0] == "messages":
                try:
                    chunk = event[1][0]
                    if isinstance(chunk, AIMessageChunk) and (content := chunk.content):
                        yield ResponsesAgentStreamEvent(
                            **self.create_text_delta(delta=content, item_id=chunk.id),
                        )
                except Exception as e:
                    logger.error(f"Error streaming chunk: {e}")

# ----- Export model -----
mlflow.langchain.autolog()
AGENT = LangGraphResponsesAgent()
mlflow.models.set_model(AGENT)

# Test the Agent locally
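The test cells in this section print the full response dictionary, which can be hard to scan. As a convenience, a small helper along these lines could pull out just the assistant text. It is a sketch that assumes the dumped response has an `output` list whose assistant `message` items carry `output_text` blocks; `extract_assistant_text` and the sample data are hypothetical and not part of MLflow.

```python
def extract_assistant_text(response: dict) -> str:
    """Join the text parts of assistant message items in a response dict."""
    parts = []
    for item in response.get("output", []):
        if item.get("type") != "message" or item.get("role") != "assistant":
            continue  # skip tool calls and tool outputs
        content = item.get("content", [])
        if isinstance(content, str):
            parts.append(content)
            continue
        for block in content:
            if block.get("type") == "output_text":
                parts.append(block.get("text", ""))
    return "\n".join(parts)


# Hand-written sample shaped like the assumed output; not real agent output.
sample = {
    "output": [
        {"type": "function_call", "name": "save_user_memory", "arguments": "{}", "call_id": "call_1"},
        {"type": "function_call_output", "call_id": "call_1", "output": "Successfully saved memory"},
        {
            "type": "message",
            "role": "assistant",
            "content": [{"type": "output_text", "text": "Got it, I will remember that."}],
        },
    ]
}
text = extract_assistant_text(sample)
print(text)
```

With a real result, the call would look like `extract_assistant_text(result.model_dump(exclude_none=True))`.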

In [0]:
dbutils.library.restartPython()

In [0]:
# example using user_id from ChatContext as input user_id
# https://mlflow.org/docs/latest/api_reference/python_api/mlflow.types.html#mlflow.types.agent.ChatContext
from agent import AGENT
import mlflow
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ChatContext
)

req = ResponsesAgentRequest(
    input=[{"role": "user", "content": "Please remember I use Databricks and I am a python developer who likes pistachios and has a dog named Fluffy"}],
    context=ChatContext(
        conversation_id="abc",
        user_id="email@databricks.com"
    ),
)
result = AGENT.predict(req)

print(result.model_dump(exclude_none=True))

In [0]:
# Recall memory example

req = ResponsesAgentRequest(
    input=[{"role": "user", "content": "What data platform do I use?"}],
    context=ChatContext(
        conversation_id="abc",
        user_id="email@databricks.com"
    ),
)
result = AGENT.predict(req)

print(result.model_dump(exclude_none=True))

# Log the agent as an MLflow model
Log the agent as code from the agent.py file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code).

## Enable automatic authentication for Databricks resources
For the most common Databricks resource types, Databricks supports and recommends declaring resource dependencies for the agent upfront during logging. This enables automatic authentication passthrough when you deploy the agent. With automatic authentication passthrough, Databricks automatically provisions, rotates, and manages short-lived credentials to securely access these resource dependencies from within the agent endpoint.

To enable automatic authentication, specify the dependent Databricks resources when calling `mlflow.pyfunc.log_model()`.

**TODO:** 
- Add lakebase as a resource type
- If your Unity Catalog tool queries a [vector search index](https://docs.databricks.com/docs%20link) or leverages [external functions](https://docs.databricks.com/docs%20link), you need to include the dependent vector search index and UC connection objects, respectively, as resources. See docs ([AWS](https://docs.databricks.com/generative-ai/agent-framework/log-agent.html#specify-resources-for-automatic-authentication-passthrough) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/log-agent#resources)).

In [0]:
# Determine Databricks resources to specify for automatic auth passthrough at deployment time
import mlflow
from agent import tools, LLM_ENDPOINT_NAME, LAKEBASE_INSTANCE_NAME
from databricks_langchain import VectorSearchRetrieverTool
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint, DatabricksLakebase
from unitycatalog.ai.langchain.toolkit import UnityCatalogTool
from importlib.metadata import version

resources = [DatabricksServingEndpoint(LLM_ENDPOINT_NAME), DatabricksLakebase(database_instance_name=LAKEBASE_INSTANCE_NAME)]

for tool in tools:
    if isinstance(tool, VectorSearchRetrieverTool):
        resources.extend(tool.resources)
    elif isinstance(tool, UnityCatalogTool):
        resources.append(DatabricksFunction(function_name=tool.uc_function_name))

input_example = {
    "input": [
        {
            "role": "user",
            "content": "What is an LLM agent?"
        }
    ],
}

logged_agent_info = mlflow.pyfunc.log_model(
    name="agent",
    python_model="agent.py",
    input_example=input_example,
    resources=resources,
    pip_requirements=[
        "mlflow==3.6.0",
        # Pin databricks-langchain to the version installed in this notebook environment
        f"databricks-langchain[memory]=={version('databricks-langchain')}",
    ]
)

# Evaluate the agent with Agent Evaluation
Use Mosaic AI Agent Evaluation to evaluate the agent's responses based on expected responses and other evaluation criteria. Use the evaluation criteria you specify to guide iterations, using MLflow to track the computed quality metrics. See Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-evaluation) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-evaluation/)).

To evaluate your tool calls, add custom metrics. See Databricks documentation ([AWS](https://docs.databricks.com/en/generative-ai/agent-evaluation/custom-metrics.html#evaluating-tool-calls) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-evaluation/custom-metrics#evaluating-tool-calls)).
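
The core of a tool-call metric is a check over the response's output items. The sketch below shows only that check in plain Python, assuming tool invocations appear as `function_call` items with a `name` and a JSON-encoded `arguments` string. The helper names `tool_calls` and `saved_memory_check` are hypothetical, and wiring the check into a custom MLflow metric is not shown.

```python
import json


def tool_calls(output_items: list, tool_name: str) -> list:
    """Return the parsed arguments of every call to tool_name, in order."""
    calls = []
    for item in output_items:
        if item.get("type") == "function_call" and item.get("name") == tool_name:
            # Arguments arrive as a JSON string; treat a missing value as {}.
            calls.append(json.loads(item.get("arguments") or "{}"))
    return calls


def saved_memory_check(output_items: list) -> bool:
    """Pass if save_user_memory was called at least once with a memory_key."""
    return any(
        "memory_key" in args for args in tool_calls(output_items, "save_user_memory")
    )


# Hand-written sample items; not real agent output.
sample_items = [
    {
        "type": "function_call",
        "name": "save_user_memory",
        "call_id": "call_1",
        "arguments": json.dumps(
            {"memory_key": "pet", "memory_data_json": '{"dog_name": "Fluffy"}'}
        ),
    },
    {
        "type": "message",
        "role": "assistant",
        "content": [{"type": "output_text", "text": "Saved."}],
    },
]
print(saved_memory_check(sample_items))
```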

In [0]:
import mlflow
from mlflow.genai.scorers import RelevanceToQuery, RetrievalGroundedness, RetrievalRelevance, Safety

eval_dataset = [
    {
        "inputs": {"input": [{"role": "user", "content": "Calculate the 15th Fibonacci number"}]},
        "expected_response": "The 15th Fibonacci number is 610.",
    }
]

eval_results = mlflow.genai.evaluate(
    data=eval_dataset,
    predict_fn=lambda input: AGENT.predict({"input": input}),
    scorers=[RelevanceToQuery(), Safety()],  # add more scorers here if they're applicable
)

# Review the evaluation results in the MLflow UI (see console output)

# Pre-deployment agent validation
Before registering and deploying the agent, perform pre-deployment checks using the `mlflow.models.predict()` API.

In [0]:
mlflow.models.predict(
    model_uri=logged_agent_info.model_uri,
    input_data={"input": [{"role": "user", "content": "I am working on stateful agents"}]},
    env_manager="uv",
)

# Register the model to Unity Catalog
Update the `catalog`, `schema`, and `model_name` below to register the MLflow model to Unity Catalog.

In [0]:
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
catalog = "catalog"
schema = "schema"
model_name = "long-term-memory-agent"

UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)

# Deploy the agent

In [0]:
from databricks import agents
agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version, tags={"endpointSource": "docs"})

# Next steps
Deployment takes around 15 minutes. After the agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application.

With long-term memory in place, the agent can recall a user's saved preferences in later conversations.

You can query your Lakebase instance to see a record of your user memories. Here is a basic query to see items in the store:
```
select *
from public.store
order by updated_at desc
limit 50;
```
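
To narrow the query to one user, it may help to filter on the namespace. The sketch below assumes the `store` table keeps the namespace in a `prefix` column with its parts joined by periods, as LangGraph's Postgres store does, which would also explain why the agent replaces periods in `user_id` with hyphens. The helper names and the `%s` placeholder style (psycopg convention) are illustrative assumptions.

```python
def memory_prefix(user_id: str) -> str:
    """Build the assumed prefix value for a user's memories."""
    # Same namespace derivation as the memory tools in agent.py.
    namespace = ("user_memories", user_id.replace(".", "-"))
    return ".".join(namespace)


def user_memories_query(user_id: str, limit: int = 50):
    """Return a parameterized SQL string and its parameters."""
    sql = (
        "select key, value, updated_at "
        "from public.store "
        "where prefix = %s "
        "order by updated_at desc "
        "limit %s;"
    )
    return sql, (memory_prefix(user_id), limit)


sql, params = user_memories_query("email@databricks.com")
print(sql)
print(params)
```

Passing the prefix as a query parameter rather than formatting it into the SQL string avoids quoting problems with characters such as `@`.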