# Unity Catalog Checkpointer Example

This notebook demonstrates how to use the `UnityCatalogCheckpointSaver` to persist LangGraph state in Databricks Unity Catalog.

In [None]:
# Create widgets for configuration.
# `dbutils` is injected only by the Databricks notebook runtime. When this notebook
# runs somewhere else (for example a local Jupyter kernel configured through the
# .env file loaded in the next cell), skip widget creation instead of aborting
# with a NameError.
try:
    dbutils.widgets.text("catalog", "", "Catalog Name")
    dbutils.widgets.text("schema", "", "Schema Name")
    dbutils.widgets.text("warehouse_id", "", "Warehouse ID")
except NameError:
    print(
        "dbutils is not available; skipping widgets "
        "(set UC_CATALOG, UC_SCHEMA and DATABRICKS_WAREHOUSE_ID instead)"
    )

In [None]:
# Pull settings from the nearest .env file into the process environment
from dotenv import find_dotenv, load_dotenv

# Locate the file first, then load it; the final expression stays the cell's output.
_dotenv_path = find_dotenv()
load_dotenv(_dotenv_path)

In [None]:
# Add the src directory to the Python path for custom code imports.
import sys
import os

# Resolved against the current working directory of the kernel.
_src_dir = os.path.abspath('../src')

# Guard the append so that re-running this cell does not stack duplicate
# entries on sys.path.
if _src_dir not in sys.path:
    sys.path.append(_src_dir)

In [None]:
# Enable nested async support for Jupyter notebooks.
# NOTE(review): presumably needed because the notebook kernel already runs an
# asyncio event loop and code used later starts nested loops — confirm which
# component actually requires it.
import nest_asyncio
# Patches asyncio in place; takes effect for the rest of the kernel session.
nest_asyncio.apply()
print("✓ Nested asyncio support enabled")

## Setup and Configuration


In [None]:
import os
from databricks.sdk import WorkspaceClient
from langgraph_unity_catalog_checkpoint import UnityCatalogCheckpointSaver

# Initialize Databricks WorkspaceClient
workspace_client: WorkspaceClient = WorkspaceClient()


def _read_setting(env_var: str, widget_name: str) -> str:
    """Return a configuration value, preferring the environment over widgets.

    Args:
        env_var: Environment variable consulted first.
        widget_name: Databricks widget read when the variable is unset or empty.

    Returns:
        The resolved value, or "" when neither source provides one.
    """
    value = os.getenv(env_var)
    if value:
        return value
    try:
        return dbutils.widgets.get(widget_name)
    except NameError:
        # `dbutils` only exists inside the Databricks runtime; elsewhere the
        # environment variables are the only configuration source.
        return ""


# Configuration for Unity Catalog - prefer environment variables over widgets
catalog: str = _read_setting("UC_CATALOG", "catalog")
schema: str = _read_setting("UC_SCHEMA", "schema")
# The warehouse is optional: an empty value is normalised to None.
warehouse_id: str | None = _read_setting("DATABRICKS_WAREHOUSE_ID", "warehouse_id") or None

# Fail fast: the widgets default to "", and an empty catalog/schema would
# otherwise only surface later as hard-to-read errors from downstream calls.
_missing = [name for name, value in (("catalog", catalog), ("schema", schema)) if not value]
if _missing:
    raise ValueError(
        f"Missing required configuration: {', '.join(_missing)}. "
        "Set UC_CATALOG / UC_SCHEMA or fill in the notebook widgets."
    )

print(f"Using catalog: {catalog}")
print(f"Using schema: {schema}")
print(f"Using warehouse_id: {warehouse_id}")

In [None]:
import os
from typing import Annotated
from databricks.sdk import WorkspaceClient
from langchain_core.messages import HumanMessage, BaseMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict
from databricks_langchain import ChatDatabricks

# `workspace_client` is already created in the "Setup and Configuration" cell;
# it is reused here instead of constructing a second, identical client.

# Initialize ChatDatabricks with Llama model
llm: ChatDatabricks = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct")

In [None]:
# Enable MLflow autologging for LangChain.
# NOTE(review): presumably this has to run before the graph is invoked so that
# those calls are captured — confirm against the MLflow documentation.
import mlflow
mlflow.langchain.autolog()
print("✓ MLflow LangChain autologging enabled")

## Create Catalog and Schema


In [None]:
# Ensure the catalog exists (Catalog API). Creation is best effort: an
# "already exists" error counts as success, anything else is only reported.
try:
    workspace_client.catalogs.create(
        name=catalog,
        comment="Unity Catalog for LangGraph persistence",
    )
    print(f"✓ Created catalog '{catalog}'")
except Exception as exc:
    catalog_present = "already exists" in str(exc).lower()
    print(
        f"✓ Catalog '{catalog}' already exists"
        if catalog_present
        else f"Warning: Could not create catalog: {exc}"
    )

# Ensure the schema exists inside that catalog (Schema API), same policy.
try:
    workspace_client.schemas.create(
        name=schema,
        catalog_name=catalog,
        comment="Schema for LangGraph checkpoints and stores",
    )
    print(f"✓ Created schema '{catalog}.{schema}'")
except Exception as exc:
    schema_present = "already exists" in str(exc).lower()
    print(
        f"✓ Schema '{catalog}.{schema}' already exists"
        if schema_present
        else f"Warning: Could not create schema: {exc}"
    )

## Define the Graph State and Nodes


In [None]:
# Define the state schema
class State(TypedDict):
    """State for the agent graph.

    The state has a single key, ``messages``, holding the conversation as a
    list of ``BaseMessage`` objects.
    """
    # `add_messages` is attached through `Annotated` as the reducer for this key;
    # presumably it merges node output into the existing list instead of
    # replacing it — confirm against the LangGraph documentation.
    messages: Annotated[list[BaseMessage], add_messages]

# Define the chatbot node
def chatbot(state: State) -> dict[str, list[BaseMessage]]:
    """Run the Databricks chat model on the conversation so far.

    Args:
        state: Current graph state; only its "messages" entry is read.

    Returns:
        A partial state update whose "messages" list holds the single model reply.
    """
    return {"messages": [llm.invoke(state["messages"])]}


## Create the Checkpointer


In [None]:
# Build the checkpoint saver without overriding any table names.
# According to the notebook's original note, the defaults follow the
# PostgreSQL/LangGraph conventions:
#   checkpoints_table      -> "checkpoints"
#   checkpoint_blobs_table -> "checkpoint_blobs"
#   writes_table           -> "checkpoint_writes"
checkpointer = UnityCatalogCheckpointSaver(
    workspace_client=workspace_client,
    catalog=catalog,
    schema=schema,
    warehouse_id=warehouse_id,
)

# Report the fully qualified table names exposed by the saver.
print(
    "✓ Checkpointer created",
    f"  Checkpoints table: {checkpointer.full_checkpoints_table}",
    f"  Writes table: {checkpointer.full_writes_table}",
    sep="\n",
)

## Build and Compile the Graph

In [None]:
# Assemble a one-node graph: START -> chatbot -> END
graph_builder = StateGraph(State)
graph_builder.add_node("chatbot", chatbot)
for _source, _target in ((START, "chatbot"), ("chatbot", END)):
    graph_builder.add_edge(_source, _target)

# Passing the checkpointer at compile time is what enables persistence
graph = graph_builder.compile(checkpointer=checkpointer)

print("✓ Graph compiled with checkpoint persistence")


## Run Conversation - First Interaction


In [None]:
# Thread-scoped configuration: later cells pass the same config to stay on this thread
config = dict(configurable=dict(thread_id="conversation_1"))

# First interaction
print("First interaction:")
_greeting = HumanMessage(content="Hello, my name is Alice")
result = graph.invoke({"messages": [_greeting]}, config=config)

# The last message in the returned state is the model's reply
_reply = result["messages"][-1]
print(f"Response: {_reply.content}")

## Second Interaction - State is Preserved


In [None]:
# Second interaction - reuses `config`, so it runs on the same thread as the first one
print("Second interaction:")
_question = HumanMessage(content="What's my name?")
result = graph.invoke({"messages": [_question]}, config=config)

# The last message in the returned state is the model's reply
_reply = result["messages"][-1]
print(f"Response: {_reply.content}")

## View Conversation History


In [None]:
# View the conversation history
print("Conversation history:")

# Use the synchronous accessor, matching the synchronous graph.invoke() and
# checkpointer.list() calls in the rest of this notebook. This avoids a
# top-level `await` (only valid under IPython's autoawait) and does not depend
# on the checkpointer providing async methods.
state = graph.get_state(config)

for msg in state.values["messages"]:
    # Anything that is not a HumanMessage is labelled as coming from the model.
    msg_type = "Human" if isinstance(msg, HumanMessage) else "AI"
    print(f"  {msg_type}: {msg.content}")


## List All Checkpoints for This Thread


In [None]:
# Enumerate every checkpoint stored for the thread identified by `config`
print("\nCheckpoints for this thread:")
for i, checkpoint_tuple in enumerate(checkpointer.list(config), start=1):
    # Each entry carries its own config (with the checkpoint id) and its metadata
    checkpoint_id = checkpoint_tuple.config["configurable"]["checkpoint_id"]
    metadata = checkpoint_tuple.metadata
    print(
        f"\nCheckpoint {i}:",
        f"  ID: {checkpoint_id}",
        f"  Metadata: {metadata}",
        sep="\n",
    )