# How to create a custom checkpointer using MongoDB

When creating LangGraph agents, you can also set them up so that they persist their state. This allows you to do things like interact with an agent multiple times and have it remember previous interactions.

This example shows how to use `MongoDB` as the backend for persisting checkpoint state.

NOTE: this is just an example implementation. You can implement your own checkpointer using a different database or modify this one as long as it conforms to the `BaseCheckpointSaver` interface.

## Checkpointer implementation

In [None]:
%%capture --no-stderr
%pip install -U langgraph pymongo

In [1]:
import pickle
from contextlib import AbstractContextManager
from types import TracebackType
from typing import Any, Dict, Iterator, Optional

from langchain_core.runnables import RunnableConfig
from typing_extensions import Self

from langgraph.checkpoint.base import (
    BaseCheckpointSaver,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
    SerializerProtocol,
)
from langgraph.serde.jsonplus import JsonPlusSerializer
from pymongo import MongoClient


class JsonPlusSerializerCompat(JsonPlusSerializer):
    """A serializer that supports loading pickled checkpoints for backwards compatibility.

    This serializer extends the JsonPlusSerializer and adds support for loading pickled
    checkpoints. If the input data starts with b"\x80" and ends with b".", it is treated
    as a pickled checkpoint and loaded using pickle.loads(). Otherwise, the default
    JsonPlusSerializer behavior is used.

    Examples:
        >>> import pickle
        >>> from langgraph.checkpoint.sqlite import JsonPlusSerializerCompat
        >>>
        >>> serializer = JsonPlusSerializerCompat()
        >>> pickled_data = pickle.dumps({"key": "value"})
        >>> loaded_data = serializer.loads(pickled_data)
        >>> print(loaded_data)  # Output: {"key": "value"}
        >>>
        >>> json_data = '{"key": "value"}'.encode("utf-8")
        >>> loaded_data = serializer.loads(json_data)
        >>> print(loaded_data)  # Output: {"key": "value"}
    """

    def loads(self, data: bytes) -> Any:
        if data.startswith(b"\x80") and data.endswith(b"."):
            return pickle.loads(data)
        return super().loads(data)


class MongoDBSaver(AbstractContextManager, BaseCheckpointSaver):
    """A checkpoint saver that stores checkpoints in a MongoDB database.

    Args:
        client (pymongo.MongoClient): The MongoDB client.
        db_name (str): The name of the database to use.
        collection_name (str): The name of the collection to use.
        serde (Optional[SerializerProtocol]): The serializer to use for serializing and deserializing checkpoints. Defaults to JsonPlusSerializerCompat.

    Examples:

        >>> from pymongo import MongoClient
        >>> from langgraph.checkpoint.mongodb import MongoDBSaver
        >>> from langgraph.graph import StateGraph
        >>>
        >>> builder = StateGraph(int)
        >>> builder.add_node("add_one", lambda x: x + 1)
        >>> builder.set_entry_point("add_one")
        >>> builder.set_finish_point("add_one")
        >>> client = MongoClient("mongodb://localhost:27017/")
        >>> memory = MongoDBSaver(client, "checkpoints", "checkpoints")
        >>> graph = builder.compile(checkpointer=memory)
        >>> config = {"configurable": {"thread_id": "1"}}
        >>> graph.get_state(config)
        >>> result = graph.invoke(3, config)
        >>> graph.get_state(config)
        StateSnapshot(values=4, next=(), config={'configurable': {'thread_id': '1', 'thread_ts': '2024-05-04T06:32:42.235444+00:00'}}, parent_config=None)
    """

    serde = JsonPlusSerializerCompat()

    client: MongoClient
    db_name: str
    collection_name: str

    def __init__(
        self,
        client: MongoClient,
        db_name: str,
        collection_name: str,
        *,
        serde: Optional[SerializerProtocol] = None,
    ) -> None:
        super().__init__(serde=serde)
        self.client = client
        self.db_name = db_name
        self.collection_name = collection_name
        self.collection = client[db_name][collection_name]

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        __exc_type: Optional[type[BaseException]],
        __exc_value: Optional[BaseException],
        __traceback: Optional[TracebackType],
    ) -> Optional[bool]:
        return True

    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Get a checkpoint tuple from the database.

        This method retrieves a checkpoint tuple from the MongoDB database based on the
        provided config. If the config contains a "thread_ts" key, the checkpoint with
        the matching thread ID and timestamp is retrieved. Otherwise, the latest checkpoint
        for the given thread ID is retrieved.

        Args:
            config (RunnableConfig): The config to use for retrieving the checkpoint.

        Returns:
            Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
        """
        if config["configurable"].get("thread_ts"):
            query = {
                "thread_id": config["configurable"]["thread_id"],
                "thread_ts": config["configurable"]["thread_ts"],
            }
        else:
            query = {"thread_id": config["configurable"]["thread_id"]}
        result = self.collection.find(query).sort("thread_ts", -1).limit(1)
        for doc in result:
            return CheckpointTuple(
                config,
                self.serde.loads(doc["checkpoint"]),
                self.serde.loads(doc["metadata"]),
                (
                    {
                        "configurable": {
                            "thread_id": doc["thread_id"],
                            "thread_ts": doc["parent_ts"],
                        }
                    }
                    if doc.get("parent_ts")
                    else None
                ),
            )

    def list(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> Iterator[CheckpointTuple]:
        """List checkpoints from the database.

        This method retrieves a list of checkpoint tuples from the MongoDB database based
        on the provided config. The checkpoints are ordered by timestamp in descending order.

        Args:
            config (RunnableConfig): The config to use for listing the checkpoints.
            before (Optional[RunnableConfig]): If provided, only checkpoints before the specified timestamp are returned. Defaults to None.
            limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None.

        Yields:
            Iterator[CheckpointTuple]: An iterator of checkpoint tuples.
        """
        query = {}
        if config is not None:
            query["thread_id"] = config["configurable"]["thread_id"]
        if filter:
            for key, value in filter.items():
                query[f"metadata.{key}"] = value
        if before is not None:
            query["thread_ts"] = {"$lt": before["configurable"]["thread_ts"]}
        result = self.collection.find(query).sort("thread_ts", -1).limit(limit)
        for doc in result:
            yield CheckpointTuple(
                {
                    "configurable": {
                        "thread_id": doc["thread_id"],
                        "thread_ts": doc["thread_ts"],
                    }
                },
                self.serde.loads(doc["checkpoint"]),
                self.serde.loads(doc["metadata"]),
                (
                    {
                        "configurable": {
                            "thread_id": doc["thread_id"],
                            "thread_ts": doc["parent_ts"],
                        }
                    }
                    if doc.get("parent_ts")
                    else None
                ),
            )

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
    ) -> RunnableConfig:
        """Save a checkpoint to the database.

        This method saves a checkpoint to the MongoDB database. The checkpoint is associated
        with the provided config and its parent config (if any).

        Args:
            config (RunnableConfig): The config to associate with the checkpoint.
            checkpoint (Checkpoint): The checkpoint to save.
            metadata (Optional[dict[str, Any]]): Additional metadata to save with the checkpoint. Defaults to None.

        Returns:
            RunnableConfig: The updated config containing the saved checkpoint's timestamp.
        """
        doc = {
            "thread_id": config["configurable"]["thread_id"],
            "thread_ts": checkpoint["id"],
            "checkpoint": self.serde.dumps(checkpoint),
            "metadata": self.serde.dumps(metadata),
        }
        if config["configurable"].get("thread_ts"):
            doc["parent_ts"] = config["configurable"]["thread_ts"]
        self.collection.insert_one(doc)
        return {
            "configurable": {
                "thread_id": config["configurable"]["thread_id"],
                "thread_ts": checkpoint["id"],
            }
        }

## MongoDB connection

In [4]:
MONGO_URI = "mongodb://localhost:27017/"

## Basic example using graph

In [3]:
from langgraph.graph import StateGraph, START, END

checkpointer = MongoDBSaver(
    MongoClient(MONGO_URI), "checkpoints_db", "checkpoints_collection"
)
builder = StateGraph(int)
builder.add_node("add_one", lambda x: x + 1)
builder.add_edge(START, "add_one")
builder.add_edge("add_one", END)
graph = builder.compile(checkpointer=checkpointer)
config = {"configurable": {"thread_id": "123"}}
graph.get_state(config)
result = graph.invoke(3, config)
graph.get_state(config)

StateSnapshot(values=4, next=(), config={'configurable': {'thread_id': '123'}}, metadata={'source': 'loop', 'step': 1, 'writes': {'add_one': 4}}, created_at='2024-07-09T15:56:06.885848+00:00', parent_config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3e0bc-09c1-6c26-8000-b9e1d26417ff'}})

In [4]:
result

4

In [5]:
checkpointer.get(config)

{'v': 1,
 'ts': '2024-07-09T15:56:06.885848+00:00',
 'id': '1ef3e0bc-09d1-6a75-8001-8f750e9a0782',
 'channel_values': {'__root__': 4, 'add_one': 'add_one'},
 'channel_versions': {'__start__': 2,
  '__root__': 3,
  'start:add_one': 3,
  'add_one': 3},
 'versions_seen': {'__start__': {'__start__': 1},
  'add_one': {'start:add_one': 2}},
 'pending_sends': []}

In [12]:
list = checkpointer.list(config, limit=3)
for item in list:
    print(item)

CheckpointTuple(config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3e0bc-09d1-6a75-8001-8f750e9a0782'}}, checkpoint={'v': 1, 'ts': '2024-07-09T15:56:06.885848+00:00', 'id': '1ef3e0bc-09d1-6a75-8001-8f750e9a0782', 'channel_values': {'__root__': 4, 'add_one': 'add_one'}, 'channel_versions': {'__start__': 2, '__root__': 3, 'start:add_one': 3, 'add_one': 3}, 'versions_seen': {'__start__': {'__start__': 1}, 'add_one': {'start:add_one': 2}}, 'pending_sends': []}, metadata={'source': 'loop', 'step': 1, 'writes': {'add_one': 4}}, parent_config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3e0bc-09c1-6c26-8000-b9e1d26417ff'}})
CheckpointTuple(config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3e0bc-09c1-6c26-8000-b9e1d26417ff'}}, checkpoint={'v': 1, 'ts': '2024-07-09T15:56:06.878338+00:00', 'id': '1ef3e0bc-09c1-6c26-8000-b9e1d26417ff', 'channel_values': {'__root__': 3, 'start:add_one': '__start__'}, 'channel_versions': {'__start__': 2, '__root__': 2, 'start:add_on

In [13]:
checkpointer.get_tuple(config)

CheckpointTuple(config={'configurable': {'thread_id': '123'}}, checkpoint={'v': 1, 'ts': '2024-07-09T13:22:19.610402+00:00', 'id': '1ef3df64-4ba9-6b58-8001-ab084cc01a30', 'channel_values': {'__root__': 4, 'add_one': 'add_one'}, 'channel_versions': {'__start__': 2, '__root__': 3, 'start:add_one': 3, 'add_one': 3}, 'versions_seen': {'__start__': {'__start__': 1}, 'add_one': {'start:add_one': 2}}, 'pending_sends': []}, metadata={'source': 'loop', 'step': 1, 'writes': {'add_one': 4}}, parent_config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3df64-4ba2-660c-8000-569999697ff3'}})

## Setup environment

In [14]:
import getpass
import os


def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


_set_env("OPENAI_API_KEY")

## Setup model and tools for the graph

In [None]:
%pip install langchain_openai

In [15]:
from typing import Literal
from langchain_core.runnables import ConfigurableField
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent


@tool
def get_weather(city: Literal["nyc", "sf"]):
    """Use this to get weather information."""
    if city == "nyc":
        return "It might be cloudy in nyc"
    elif city == "sf":
        return "It's always sunny in sf"
    else:
        raise AssertionError("Unknown city")


tools = [get_weather]
model = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

In [16]:
graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)
config = {"configurable": {"thread_id": "1"}}
res = graph.invoke({"messages": [("human", "what's the weather in sf")]}, config)

In [17]:
res

{'messages': [HumanMessage(content="what's the weather in sf", id='a624d383-13c6-499c-8f03-31ed11fa0cfb'),
  AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_wapm4s91KQUQqE9y1L53QmmE', 'function': {'arguments': '{"city":"sf"}', 'name': 'get_weather'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 14, 'prompt_tokens': 58, 'total_tokens': 72}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': None, 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-614bc54a-ad37-4f17-9047-b80752bdf66e-0', tool_calls=[{'name': 'get_weather', 'args': {'city': 'sf'}, 'id': 'call_wapm4s91KQUQqE9y1L53QmmE'}]),
  ToolMessage(content="It's always sunny in sf", name='get_weather', id='e58dd97d-b50d-4b0a-9492-7155106c975a', tool_call_id='call_wapm4s91KQUQqE9y1L53QmmE'),
  AIMessage(content='The weather in San Francisco is always sunny!', response_metadata={'token_usage': {'completion_tokens': 10, 'prompt_tokens': 86, 'total_tokens': 96}, 'model_name':

In [18]:
checkpointer.get_tuple(config)

CheckpointTuple(config={'configurable': {'thread_id': '1'}}, checkpoint={'v': 1, 'ts': '2024-07-09T13:22:49.794047+00:00', 'id': '1ef3df65-6b84-63fd-8003-888bcef289e3', 'channel_values': {'messages': [HumanMessage(content="what's the weather in sf", id='a624d383-13c6-499c-8f03-31ed11fa0cfb'), AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_wapm4s91KQUQqE9y1L53QmmE', 'function': {'arguments': '{"city":"sf"}', 'name': 'get_weather'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 14, 'prompt_tokens': 58, 'total_tokens': 72}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': None, 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-614bc54a-ad37-4f17-9047-b80752bdf66e-0', tool_calls=[{'name': 'get_weather', 'args': {'city': 'sf'}, 'id': 'call_wapm4s91KQUQqE9y1L53QmmE'}]), ToolMessage(content="It's always sunny in sf", name='get_weather', id='e58dd97d-b50d-4b0a-9492-7155106c975a', tool_call_id='call_wapm4s91KQUQqE9y1L53QmmE'), A

### Checkpoints saved in MongoDB

In [19]:
client = MongoClient(MONGO_URI)
database = client["checkpoints_db"]
collection = database["checkpoints_collection"]

for doc in collection.find():
    print(doc)

# The checkpoints from both the examples have been saved in the database.

{'_id': ObjectId('668d398bb975d3e766de42ce'), 'thread_id': '123', 'thread_ts': '1ef3df64-4b98-68b4-bfff-592f97570cf6', 'checkpoint': b'{"v": 1, "ts": "2024-07-09T13:22:19.603371+00:00", "id": "1ef3df64-4b98-68b4-bfff-592f97570cf6", "channel_values": {"__start__": 3}, "channel_versions": {"__start__": 1}, "versions_seen": {}, "pending_sends": []}', 'metadata': b'{"source": "input", "step": -1, "writes": 3}'}
{'_id': ObjectId('668d398bb975d3e766de42cf'), 'thread_id': '123', 'thread_ts': '1ef3df64-4ba2-660c-8000-569999697ff3', 'checkpoint': b'{"v": 1, "ts": "2024-07-09T13:22:19.607399+00:00", "id": "1ef3df64-4ba2-660c-8000-569999697ff3", "channel_values": {"__root__": 3, "start:add_one": "__start__"}, "channel_versions": {"__start__": 2, "__root__": 2, "start:add_one": 2}, "versions_seen": {"__start__": {"__start__": 1}, "add_one": {}}, "pending_sends": []}', 'metadata': b'{"source": "loop", "step": 0, "writes": null}', 'parent_ts': '1ef3df64-4b98-68b4-bfff-592f97570cf6'}
{'_id': ObjectId

## Asynchronous implementation

In [None]:
# Async package for MongoDB
%pip install motor

In [2]:
import pickle
from contextlib import AbstractContextManager
from types import TracebackType
from typing import Any, Dict, Optional, AsyncIterator

from langchain_core.runnables import RunnableConfig
from typing_extensions import Self

from langgraph.checkpoint.base import (
    BaseCheckpointSaver,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
    SerializerProtocol,
)
from langgraph.serde.jsonplus import JsonPlusSerializer
from motor.motor_asyncio import AsyncIOMotorClient


class JsonPlusSerializerCompat(JsonPlusSerializer):
    """A serializer that supports loading pickled checkpoints for backwards compatibility.

    This serializer extends the JsonPlusSerializer and adds support for loading pickled
    checkpoints. If the input data starts with b"\x80" and ends with b".", it is treated
    as a pickled checkpoint and loaded using pickle.loads(). Otherwise, the default
    JsonPlusSerializer behavior is used.

    Examples:
        >>> import pickle
        >>> from langgraph.checkpoint.sqlite import JsonPlusSerializerCompat
        >>>
        >>> serializer = JsonPlusSerializerCompat()
        >>> pickled_data = pickle.dumps({"key": "value"})
        >>> loaded_data = serializer.loads(pickled_data)
        >>> print(loaded_data)  # Output: {"key": "value"}
        >>>
        >>> json_data = '{"key": "value"}'.encode("utf-8")
        >>> loaded_data = serializer.loads(json_data)
        >>> print(loaded_data)  # Output: {"key": "value"}
    """

    def loads(self, data: bytes) -> Any:
        if data.startswith(b"\x80") and data.endswith(b"."):
            return pickle.loads(data)
        return super().loads(data)


class MongoDBSaver(AbstractContextManager, BaseCheckpointSaver):
    """A checkpoint saver that stores checkpoints in a MongoDB database.

    Args:
        client (AsyncIOMotorClient): The Async MongoDB client.
        db_name (str): The name of the database to use.
        collection_name (str): The name of the collection to use.
        serde (Optional[SerializerProtocol]): The serializer to use for serializing and deserializing checkpoints. Defaults to JsonPlusSerializerCompat.

    Examples:

        >>> from motor.motor_asyncio import AsyncIOMotorClient
        >>> from langgraph.checkpoint.mongodb import MongoDBSaver
        >>> from langgraph.graph import StateGraph
        >>>
        >>> builder = StateGraph(int)
        >>> builder.add_node("add_one", lambda x: x + 1)
        >>> builder.set_entry_point("add_one")
        >>> builder.set_finish_point("add_one")
        >>> client = AsyncIOMotorClient("mongodb://localhost:27017/")
        >>> memory = MongoDBSaver(client, "checkpoints", "checkpoints")
        >>> graph = builder.compile(checkpointer=memory)
        >>> config = {"configurable": {"thread_id": "1"}}
        >>> result = graph.ainvoke(3, config)
    """

    serde = JsonPlusSerializerCompat()

    client: AsyncIOMotorClient
    db_name: str
    collection_name: str

    def __init__(
        self,
        client: AsyncIOMotorClient,
        db_name: str,
        collection_name: str,
        *,
        serde: Optional[SerializerProtocol] = None,
    ) -> None:
        super().__init__(serde=serde)
        self.client = client
        self.db_name = db_name
        self.collection_name = collection_name
        self.collection = client[db_name][collection_name]

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        __exc_type: Optional[type[BaseException]],
        __exc_value: Optional[BaseException],
        __traceback: Optional[TracebackType],
    ) -> Optional[bool]:
        return True

    async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Get a checkpoint tuple from the database.

        This method retrieves a checkpoint tuple from the MongoDB database based on the
        provided config. If the config contains a "thread_ts" key, the checkpoint with
        the matching thread ID and timestamp is retrieved. Otherwise, the latest checkpoint
        for the given thread ID is retrieved.

        Args:
            config (RunnableConfig): The config to use for retrieving the checkpoint.

        Returns:
            Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
        """
        if config["configurable"].get("thread_ts"):
            query = {
                "thread_id": config["configurable"]["thread_id"],
                "thread_ts": config["configurable"]["thread_ts"],
            }
        else:
            query = {"thread_id": config["configurable"]["thread_id"]}
        result = self.collection.find(query).sort("thread_ts", -1).limit(1)
        async for doc in result:
            return CheckpointTuple(
                config,
                self.serde.loads(doc["checkpoint"]),
                self.serde.loads(doc["metadata"]),
                (
                    {
                        "configurable": {
                            "thread_id": doc["thread_id"],
                            "thread_ts": doc["parent_ts"],
                        }
                    }
                    if doc.get("parent_ts")
                    else None
                ),
            )

    async def alist(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> AsyncIterator[CheckpointTuple]:
        """List checkpoints from the database.

        This method retrieves a list of checkpoint tuples from the MongoDB database based
        on the provided config. The checkpoints are ordered by timestamp in descending order.

        Args:
            config (RunnableConfig): The config to use for listing the checkpoints.
            before (Optional[RunnableConfig]): If provided, only checkpoints before the specified timestamp are returned. Defaults to None.
            limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None.

        Yields:
            AsyncIterator[CheckpointTuple]: An Async iterator of checkpoint tuples.
        """
        query = {}
        if config is not None:
            query["thread_id"] = config["configurable"]["thread_id"]
        if filter:
            for key, value in filter.items():
                query[f"metadata.{key}"] = value
        if before is not None:
            query["thread_ts"] = {"$lt": before["configurable"]["thread_ts"]}
        result = self.collection.find(query).sort("thread_ts", -1).limit(limit)
        if limit is not None:
            result = result.limit(limit)
        async for doc in result:
            yield CheckpointTuple(
                {
                    "configurable": {
                        "thread_id": doc["thread_id"],
                        "thread_ts": doc["thread_ts"],
                    }
                },
                self.serde.loads(doc["checkpoint"]),
                self.serde.loads(doc["metadata"]),
                (
                    {
                        "configurable": {
                            "thread_id": doc["thread_id"],
                            "thread_ts": doc["parent_ts"],
                        }
                    }
                    if doc.get("parent_ts")
                    else None
                ),
            )

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
    ) -> RunnableConfig:
        """Save a checkpoint to the database.

        This method saves a checkpoint to the MongoDB database. The checkpoint is associated
        with the provided config and its parent config (if any).

        Args:
            config (RunnableConfig): The config to associate with the checkpoint.
            checkpoint (Checkpoint): The checkpoint to save.
            metadata (Optional[dict[str, Any]]): Additional metadata to save with the checkpoint. Defaults to None.

        Returns:
            RunnableConfig: The updated config containing the saved checkpoint's timestamp.
        """
        doc = {
            "thread_id": config["configurable"]["thread_id"],
            "thread_ts": checkpoint["id"],
            "checkpoint": self.serde.dumps(checkpoint),
            "metadata": self.serde.dumps(metadata),
        }
        if config["configurable"].get("thread_ts"):
            doc["parent_ts"] = config["configurable"]["thread_ts"]
        await self.collection.insert_one(doc)
        return {
            "configurable": {
                "thread_id": config["configurable"]["thread_id"],
                "thread_ts": checkpoint["id"],
            }
        }

## Example with basic graph

In [5]:
from langgraph.graph import StateGraph, START

checkpointer = MongoDBSaver(
    AsyncIOMotorClient(MONGO_URI), "checkpoints_db", "checkpoints_collection"
)
builder = StateGraph(int)
builder.add_node("add_one", lambda x: x + 1)
builder.add_edge(START, "add_one")
builder.add_edge("add_one", END)
graph = builder.compile(checkpointer=checkpointer)
config = {"configurable": {"thread_id": "123"}}
res = await graph.ainvoke(3, config)

In [6]:
res

4

In [9]:
await checkpointer.aget(config)

{'v': 1,
 'ts': '2024-07-10T11:34:28.485660+00:00',
 'id': '1ef3eb05-e0d1-651b-8004-15f129f5f4fb',
 'channel_values': {'__root__': 4, 'add_one': 'add_one'},
 'channel_versions': {'__start__': 5,
  '__root__': 6,
  'start:add_one': 6,
  'add_one': 6},
 'versions_seen': {'__start__': {'__start__': 4},
  'add_one': {'start:add_one': 5}},
 'pending_sends': []}

In [10]:
await checkpointer.aget_tuple(config)

CheckpointTuple(config={'configurable': {'thread_id': '123'}}, checkpoint={'v': 1, 'ts': '2024-07-10T11:34:28.485660+00:00', 'id': '1ef3eb05-e0d1-651b-8004-15f129f5f4fb', 'channel_values': {'__root__': 4, 'add_one': 'add_one'}, 'channel_versions': {'__start__': 5, '__root__': 6, 'start:add_one': 6, 'add_one': 6}, 'versions_seen': {'__start__': {'__start__': 4}, 'add_one': {'start:add_one': 5}}, 'pending_sends': []}, metadata={'source': 'loop', 'step': 4, 'writes': {'add_one': 4}}, parent_config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3eb05-e0bd-6c9c-8003-aa9cb0fdedc1'}})

In [12]:
list = checkpointer.alist(config, limit=3)
async for item in list:
    print(item)

CheckpointTuple(config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3eb05-e0d1-651b-8004-15f129f5f4fb'}}, checkpoint={'v': 1, 'ts': '2024-07-10T11:34:28.485660+00:00', 'id': '1ef3eb05-e0d1-651b-8004-15f129f5f4fb', 'channel_values': {'__root__': 4, 'add_one': 'add_one'}, 'channel_versions': {'__start__': 5, '__root__': 6, 'start:add_one': 6, 'add_one': 6}, 'versions_seen': {'__start__': {'__start__': 4}, 'add_one': {'start:add_one': 5}}, 'pending_sends': []}, metadata={'source': 'loop', 'step': 4, 'writes': {'add_one': 4}}, parent_config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3eb05-e0bd-6c9c-8003-aa9cb0fdedc1'}})
CheckpointTuple(config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3eb05-e0bd-6c9c-8003-aa9cb0fdedc1'}}, checkpoint={'v': 1, 'ts': '2024-07-10T11:34:28.477660+00:00', 'id': '1ef3eb05-e0bd-6c9c-8003-aa9cb0fdedc1', 'channel_values': {'__root__': 3, 'start:add_one': '__start__'}, 'channel_versions': {'__start__': 5, '__root__': 5, 'start:add_on

## Checkpoints saved in MongoDB

In [14]:
from pymongo import MongoClient

client = MongoClient(MONGO_URI)
database = client["checkpoints_db"]
collection = database["checkpoints_collection"]

for doc in collection.find():
    print(doc)

{'_id': ObjectId('668e57930f55bbe62f358531'), 'thread_id': '123', 'thread_ts': '1ef3ea0c-18a5-67a6-bfff-0d85b77e4a09', 'checkpoint': b'{"v": 1, "ts": "2024-07-10T09:42:43.453328+00:00", "id": "1ef3ea0c-18a5-67a6-bfff-0d85b77e4a09", "channel_values": {"__start__": 3}, "channel_versions": {"__start__": 1}, "versions_seen": {}, "pending_sends": []}', 'metadata': b'{"source": "input", "step": -1, "writes": 3}'}
{'_id': ObjectId('668e57930f55bbe62f358532'), 'thread_id': '123', 'thread_ts': '1ef3ea0c-18a7-6ea3-8000-9a52ba553d0c', 'checkpoint': b'{"v": 1, "ts": "2024-07-10T09:42:43.454326+00:00", "id": "1ef3ea0c-18a7-6ea3-8000-9a52ba553d0c", "channel_values": {"__root__": 3, "start:add_one": "__start__"}, "channel_versions": {"__start__": 2, "__root__": 2, "start:add_one": 2}, "versions_seen": {"__start__": {"__start__": 1}, "add_one": {}}, "pending_sends": []}', 'metadata': b'{"source": "loop", "step": 0, "writes": null}', 'parent_ts': '1ef3ea0c-18a5-67a6-bfff-0d85b77e4a09'}
{'_id': ObjectId