From 1dfc8d755b47f99e57e6148bcd115c1790faf672 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 17 Nov 2025 10:43:47 -0500 Subject: [PATCH 1/3] INTPYTHON-725 - Remove deprecated AsyncMongoDBSaver from langgraph-checkpoint-mongodb --- libs/langgraph-checkpoint-mongodb/README.md | 10 +- .../langgraph/checkpoint/mongodb/__init__.py | 3 +- .../langgraph/checkpoint/mongodb/aio.py | 561 ------------------ .../langgraph/checkpoint/mongodb/saver.py | 24 +- .../integration_tests/test_highlevel_graph.py | 29 +- .../tests/unit_tests/test_async.py | 73 +-- .../tests/unit_tests/test_delete_thread.py | 45 +- 7 files changed, 62 insertions(+), 683 deletions(-) delete mode 100644 libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py diff --git a/libs/langgraph-checkpoint-mongodb/README.md b/libs/langgraph-checkpoint-mongodb/README.md index fb270115..eb25eee1 100644 --- a/libs/langgraph-checkpoint-mongodb/README.md +++ b/libs/langgraph-checkpoint-mongodb/README.md @@ -53,9 +53,15 @@ with MongoDBSaver.from_conn_string(MONGODB_URI, DB_NAME) as checkpointer: ### Async ```python -from langgraph.checkpoint.pymongo import AsyncMongoDBSaver +from langgraph.checkpoint.mongodb import MongoDBSaver + +write_config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} +read_config = {"configurable": {"thread_id": "1"}} -async with AsyncMongoDBSaver.from_conn_string(MONGODB_URI) as checkpointer: +MONGODB_URI = "mongodb://localhost:27017" +DB_NAME = "checkpoint_example" + +with MongoDBSaver.from_conn_string(MONGODB_URI, DB_NAME) as checkpointer: checkpoint = { "v": 1, "ts": "2024-07-31T20:14:19.804150+00:00", diff --git a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/__init__.py b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/__init__.py index eb7c22d2..aa6f4701 100644 --- a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/__init__.py +++ b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/__init__.py @@ -1,4 +1,3 @@ -from .aio import AsyncMongoDBSaver from .saver import MongoDBSaver -__all__ = ["MongoDBSaver", "AsyncMongoDBSaver"] +__all__ = ["MongoDBSaver"] diff --git a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py deleted file mode 100644 index 4fd12d1e..00000000 --- a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py +++ /dev/null @@ -1,561 +0,0 @@ -from __future__ import annotations - -import asyncio -import warnings -from collections.abc import AsyncIterator, Iterator, Sequence -from contextlib import asynccontextmanager -from datetime import datetime -from typing import Any, Optional, cast - -from langchain_core.runnables import RunnableConfig -from langgraph.checkpoint.base import ( - WRITES_IDX_MAP, - BaseCheckpointSaver, - ChannelVersions, - Checkpoint, - CheckpointMetadata, - CheckpointTuple, - get_checkpoint_id, -) -from pymongo import UpdateOne -from pymongo.asynchronous.database import AsyncDatabase -from pymongo.asynchronous.mongo_client import AsyncMongoClient - -from .utils import ( - DRIVER_METADATA, - _append_client_metadata, - dumps_metadata, - loads_metadata, -) - -__all__ = ["AsyncMongoDBSaver"] - - -class AsyncMongoDBSaver(BaseCheckpointSaver): - """A checkpoint saver that stores checkpoints in a MongoDB database asynchronously. - - The synchronous MongoDBSaver has extended documentation, but - Asynchronous usage is shown below. - - Examples: - >>> import asyncio - >>> from langgraph.checkpoint.mongodb.aio import AsyncMongoDBSaver - >>> from langgraph.graph import StateGraph - - >>> async def main() -> None: - >>> builder = StateGraph(int) - >>> builder.add_node("add_one", lambda x: x + 1) - >>> builder.set_entry_point("add_one") - >>> builder.set_finish_point("add_one") - >>> async with AsyncMongoDBSaver.from_conn_string("mongodb://localhost:27017") as memory: - >>> graph = builder.compile(checkpointer=memory) - >>> config = {"configurable": {"thread_id": "1"}} - >>> input = 3 - >>> output = await graph.ainvoke(input, config) - >>> print(f"{input=}, {output=}") - - >>> if __name__ == "__main__": - >>> asyncio.run(main()) - input=3, output=4 - """ - - client: AsyncMongoClient - db: AsyncDatabase - - def __init__( - self, - client: AsyncMongoClient, - db_name: str = "checkpointing_db", - checkpoint_collection_name: str = "checkpoints_aio", - writes_collection_name: str = "checkpoint_writes_aio", - ttl: Optional[int] = None, - **kwargs: Any, - ) -> None: - warnings.warn( - f"{self.__class__.__name__} is deprecated and will be removed in 0.3.0 release. " - "Please use the async methods of MongoDBSaver instead.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__() - self.client = client - self.db = self.client[db_name] - self.checkpoint_collection = self.db[checkpoint_collection_name] - self.writes_collection = self.db[writes_collection_name] - self._setup_future: asyncio.Future | None = None - self.loop = asyncio.get_running_loop() - self.ttl = ttl - - # append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions - _append_client_metadata(self.client) - - async def _setup(self) -> None: - """Create indexes if not present.""" - if self._setup_future is not None: - return await self._setup_future - self._setup_future = asyncio.Future() - if isinstance(self.client, AsyncMongoClient): - num_indexes = len( - await (await self.checkpoint_collection.list_indexes()).to_list() - ) - else: - num_indexes = len(await self.checkpoint_collection.list_indexes().to_list()) - if num_indexes < 2: - await self.checkpoint_collection.create_index( - keys=[("thread_id", 1), ("checkpoint_ns", 1), ("checkpoint_id", -1)], - unique=True, - ) - if self.ttl: - await self.checkpoint_collection.create_index( - keys=[("created_at", 1)], - expireAfterSeconds=self.ttl, - ) - if isinstance(self.client, AsyncMongoClient): - num_indexes = len( - await (await self.writes_collection.list_indexes()).to_list() - ) - else: - num_indexes = len(await self.writes_collection.list_indexes().to_list()) - if num_indexes < 2: - await self.writes_collection.create_index( - keys=[ - ("thread_id", 1), - ("checkpoint_ns", 1), - ("checkpoint_id", -1), - ("task_id", 1), - ("idx", 1), - ], - unique=True, - ) - if self.ttl: - await self.writes_collection.create_index( - keys=[("created_at", 1)], - expireAfterSeconds=self.ttl, - ) - self._setup_future.set_result(None) - - @classmethod - @asynccontextmanager - async def from_conn_string( - cls, - conn_string: str, - db_name: str = "checkpointing_db", - checkpoint_collection_name: str = "checkpoints_aio", - writes_collection_name: str = "checkpoint_writes_aio", - ttl: Optional[int] = None, - **kwargs: Any, - ) -> AsyncIterator[AsyncMongoDBSaver]: - """Create asynchronous checkpointer - - This includes creation of collections and indexes if they don't exist - """ - client: Optional[AsyncMongoClient] = None - try: - client = AsyncMongoClient( - conn_string, - driver=DRIVER_METADATA, - ) - saver = AsyncMongoDBSaver( - client, - db_name, - checkpoint_collection_name, - writes_collection_name, - ttl, - **kwargs, - ) - await saver._setup() - yield saver - - finally: - if client: - await client.close() - - async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: - """Get a checkpoint tuple from the database asynchronously. - - This method retrieves a checkpoint tuple from the MongoDB database based on the - provided config. If the config contains a "checkpoint_id" key, the checkpoint with - the matching thread ID and checkpoint ID is retrieved. Otherwise, the latest checkpoint - for the given thread ID is retrieved. - - Args: - config (RunnableConfig): The config to use for retrieving the checkpoint. - - Returns: - Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. - """ - await self._setup() - thread_id = config["configurable"]["thread_id"] - checkpoint_ns = config["configurable"].get("checkpoint_ns", "") - if checkpoint_id := get_checkpoint_id(config): - query = { - "thread_id": thread_id, - "checkpoint_ns": checkpoint_ns, - "checkpoint_id": checkpoint_id, - } - else: - query = {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns} - - result = self.checkpoint_collection.find( - query, sort=[("checkpoint_id", -1)], limit=1 - ) - async for doc in result: - config_values = { - "thread_id": thread_id, - "checkpoint_ns": checkpoint_ns, - "checkpoint_id": doc["checkpoint_id"], - } - checkpoint = self.serde.loads_typed((doc["type"], doc["checkpoint"])) - serialized_writes = self.writes_collection.find(config_values) - pending_writes = [ - ( - wrt["task_id"], - wrt["channel"], - self.serde.loads_typed((wrt["type"], wrt["value"])), - ) - async for wrt in serialized_writes - ] - return CheckpointTuple( - {"configurable": config_values}, - checkpoint, - loads_metadata(doc["metadata"]), - ( - { - "configurable": { - "thread_id": thread_id, - "checkpoint_ns": checkpoint_ns, - "checkpoint_id": doc["parent_checkpoint_id"], - } - } - if doc.get("parent_checkpoint_id") - else None - ), - pending_writes, - ) - - async def alist( - self, - config: Optional[RunnableConfig], - *, - filter: Optional[dict[str, Any]] = None, - before: Optional[RunnableConfig] = None, - limit: Optional[int] = None, - ) -> AsyncIterator[CheckpointTuple]: - """List checkpoints from the database asynchronously. - - This method retrieves a list of checkpoint tuples from the MongoDB database based - on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). - - Args: - config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. - filter (Optional[dict[str, Any]]): Additional filtering criteria for metadata. - before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. - limit (Optional[int]): Maximum number of checkpoints to return. - - Yields: - AsyncIterator[CheckpointTuple]: An asynchronous iterator of matching checkpoint tuples. - """ - await self._setup() - query = {} - if config is not None: - if "thread_id" in config["configurable"]: - query["thread_id"] = config["configurable"]["thread_id"] - if "checkpoint_ns" in config["configurable"]: - query["checkpoint_ns"] = config["configurable"]["checkpoint_ns"] - - if filter: - for key, value in filter.items(): - query[f"metadata.{key}"] = dumps_metadata(value) - - if before is not None: - query["checkpoint_id"] = {"$lt": before["configurable"]["checkpoint_id"]} - - result = self.checkpoint_collection.find( - query, limit=0 if limit is None else limit, sort=[("checkpoint_id", -1)] - ) - - async for doc in result: - config_values = { - "thread_id": doc["thread_id"], - "checkpoint_ns": doc["checkpoint_ns"], - "checkpoint_id": doc["checkpoint_id"], - } - serialized_writes = self.writes_collection.find(config_values) - pending_writes = [ - ( - wrt["task_id"], - wrt["channel"], - self.serde.loads_typed((wrt["type"], wrt["value"])), - ) - async for wrt in serialized_writes - ] - - yield CheckpointTuple( - config={ - "configurable": { - "thread_id": doc["thread_id"], - "checkpoint_ns": doc["checkpoint_ns"], - "checkpoint_id": doc["checkpoint_id"], - } - }, - checkpoint=self.serde.loads_typed((doc["type"], doc["checkpoint"])), - metadata=loads_metadata(doc["metadata"]), - parent_config=( - { - "configurable": { - "thread_id": doc["thread_id"], - "checkpoint_ns": doc["checkpoint_ns"], - "checkpoint_id": doc["parent_checkpoint_id"], - } - } - if doc.get("parent_checkpoint_id") - else None - ), - pending_writes=pending_writes, - ) - - async def aput( - self, - config: RunnableConfig, - checkpoint: Checkpoint, - metadata: CheckpointMetadata, - new_versions: ChannelVersions, - ) -> RunnableConfig: - """Save a checkpoint to the database asynchronously. - - This method saves a checkpoint to the MongoDB database. The checkpoint is associated - with the provided config and its parent config (if any). - - Args: - config (RunnableConfig): The config to associate with the checkpoint. - checkpoint (Checkpoint): The checkpoint to save. - metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. - new_versions (ChannelVersions): New channel versions as of this write. - - Returns: - RunnableConfig: Updated configuration after storing the checkpoint. - """ - await self._setup() - thread_id = config["configurable"]["thread_id"] - checkpoint_ns = config["configurable"]["checkpoint_ns"] - checkpoint_id = checkpoint["id"] - type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint) - metadata = metadata.copy() - metadata.update(config.get("metadata", {})) - doc: dict[str, Any] = { - "parent_checkpoint_id": config["configurable"].get("checkpoint_id"), - "type": type_, - "checkpoint": serialized_checkpoint, - "metadata": dumps_metadata(metadata), - } - upsert_query = { - "thread_id": thread_id, - "checkpoint_ns": checkpoint_ns, - "checkpoint_id": checkpoint_id, - } - - if self.ttl: - doc["created_at"] = datetime.now() - # Perform your operations here - await self.checkpoint_collection.update_one( - upsert_query, {"$set": doc}, upsert=True - ) - return { - "configurable": { - "thread_id": thread_id, - "checkpoint_ns": checkpoint_ns, - "checkpoint_id": checkpoint_id, - } - } - - async def aput_writes( - self, - config: RunnableConfig, - writes: Sequence[tuple[str, Any]], - task_id: str, - task_path: str = "", - ) -> None: - """Store intermediate writes linked to a checkpoint asynchronously. - - This method saves intermediate writes associated with a checkpoint to the database. - - Args: - config (RunnableConfig): Configuration of the related checkpoint. - writes (Sequence[tuple[str, Any]]): List of writes to store, each as (channel, value) pair. - task_id (str): Identifier for the task creating the writes. - task_path (str): Path of the task creating the writes. - """ - await self._setup() - thread_id = config["configurable"]["thread_id"] - checkpoint_ns = config["configurable"]["checkpoint_ns"] - checkpoint_id = config["configurable"]["checkpoint_id"] - set_method = ( # Allow replacement on existing writes only if there were errors. - "$set" if all(w[0] in WRITES_IDX_MAP for w in writes) else "$setOnInsert" - ) - operations = [] - now = datetime.now() - for idx, (channel, value) in enumerate(writes): - upsert_query = { - "thread_id": thread_id, - "checkpoint_ns": checkpoint_ns, - "checkpoint_id": checkpoint_id, - "task_id": task_id, - "task_path": task_path, - "idx": WRITES_IDX_MAP.get(channel, idx), - } - - type_, serialized_value = self.serde.dumps_typed(value) - - update_doc: dict[str, Any] = { - "channel": channel, - "type": type_, - "value": serialized_value, - } - - if self.ttl: - update_doc["created_at"] = now - - operations.append( - UpdateOne( - filter=upsert_query, - update={set_method: update_doc}, - upsert=True, - ) - ) - await self.writes_collection.bulk_write(operations) - - async def adelete_thread( - self, - thread_id: str, - ) -> None: - """Delete all checkpoints and writes associated with a specific thread ID asynchronously. - - Args: - thread_id (str): The thread ID whose checkpoints should be deleted. - """ - # Delete all checkpoints associated with the thread ID - await self.checkpoint_collection.delete_many({"thread_id": thread_id}) - - # Delete all writes associated with the thread ID - await self.writes_collection.delete_many({"thread_id": thread_id}) - - def delete_thread( - self, - thread_id: str, - ) -> None: - """Delete all checkpoints and writes associated with a specific thread ID. - - Args: - thread_id (str): The thread ID whose checkpoints should be deleted. - """ - return asyncio.run_coroutine_threadsafe( - self.adelete_thread(thread_id), self.loop - ).result() - - def list( - self, - config: Optional[RunnableConfig], - *, - filter: Optional[dict[str, Any]] = None, - before: Optional[RunnableConfig] = None, - limit: Optional[int] = None, - ) -> Iterator[CheckpointTuple]: - """List checkpoints from the database. - - This method retrieves a list of checkpoint tuples from the MongoDB database - based on the provided config. The checkpoints are ordered by checkpoint ID in - descending order (newest first). - - Args: - config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. - filter (Optional[dict[str, Any]]): Additional filtering criteria for metadata. - before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. - limit (Optional[int]): Maximum number of checkpoints to return. - - Yields: - Iterator[CheckpointTuple]: An iterator of matching checkpoint tuples. - """ - aiter_ = self.alist(config, filter=filter, before=before, limit=limit) - while True: - try: - yield asyncio.run_coroutine_threadsafe( - cast(Any, anext(aiter_)), - self.loop, - ).result() - except StopAsyncIteration: - break - - def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: - """Get a checkpoint tuple from the database. - - This method retrieves a checkpoint tuple from the MongoDB database based on - the provided config. If the config contains a "checkpoint_id" key, the - checkpoint with the matching thread ID and "checkpoint_id" is retrieved. - Otherwise, the latest checkpoint for the given thread ID is retrieved. - - Args: - config (RunnableConfig): The config to use for retrieving the checkpoint. - - Returns: - Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. - """ - try: - # check if we are in the main thread, only bg threads can block - # we don't check in other methods to avoid the overhead - if asyncio.get_running_loop() is self.loop: - raise asyncio.InvalidStateError( - "Synchronous calls to AsyncMongoDBSaver are only allowed from a " - "different thread. From the main thread, use the async interface." - "For example, use `await checkpointer.aget_tuple(...)` or `await " - "graph.ainvoke(...)`." - ) - except RuntimeError: - pass - return asyncio.run_coroutine_threadsafe( - self.aget_tuple(config), self.loop - ).result() - - def put( - self, - config: RunnableConfig, - checkpoint: Checkpoint, - metadata: CheckpointMetadata, - new_versions: ChannelVersions, - ) -> RunnableConfig: - """Save a checkpoint to the database. - - This method saves a checkpoint to the MongoDB database. The checkpoint - is associated with the provided config and its parent config (if any). - - Args: - config (RunnableConfig): The config to associate with the checkpoint. - checkpoint (Checkpoint): The checkpoint to save. - metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. - new_versions (ChannelVersions): New channel versions as of this write. - - Returns: - RunnableConfig: Updated configuration after storing the checkpoint. - """ - return asyncio.run_coroutine_threadsafe( - self.aput(config, checkpoint, metadata, new_versions), self.loop - ).result() - - def put_writes( - self, - config: RunnableConfig, - writes: Sequence[tuple[str, Any]], - task_id: str, - task_path: str = "", - ) -> None: - """Store intermediate writes linked to a checkpoint. - - This method saves intermediate writes associated with a checkpoint to the database. - - Args: - config (RunnableConfig): Configuration of the related checkpoint. - writes (Sequence[tuple[str, Any]]): List of writes to store, each as (channel, value) pair. - task_id (str): Identifier for the task creating the writes. - """ - return asyncio.run_coroutine_threadsafe( - self.aput_writes(config, writes, task_id, task_path), self.loop - ).result() diff --git a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py index fcf77289..e07c5d00 100644 --- a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py +++ b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py @@ -470,20 +470,6 @@ def delete_thread( # Delete all writes associated with the thread ID self.writes_collection.delete_many({"thread_id": thread_id}) - async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: - """Asynchronously fetch a checkpoint tuple using the given configuration. - - Asynchronously wraps the blocking `self.get_tuple` method. - - Args: - config: Configuration specifying which checkpoint to retrieve. - - Returns: - Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found. - - """ - return await run_in_executor(None, self.get_tuple, config) - async def alist( self, config: Optional[RunnableConfig], @@ -586,3 +572,13 @@ async def adelete_thread( thread_id: The thread ID whose checkpoints should be deleted. """ return await run_in_executor(None, self.delete_thread, thread_id) + + async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Get a checkpoint tuple from the database. + + Asynchronously wraps the blocking `self.get_tuple` method. + + Args: + config (RunnableConfig): The config to use for retrieving the checkpoint. + """ + return await run_in_executor(None, self.get_tuple, config) diff --git a/libs/langgraph-checkpoint-mongodb/tests/integration_tests/test_highlevel_graph.py b/libs/langgraph-checkpoint-mongodb/tests/integration_tests/test_highlevel_graph.py index 38957f09..cbb46953 100644 --- a/libs/langgraph-checkpoint-mongodb/tests/integration_tests/test_highlevel_graph.py +++ b/libs/langgraph-checkpoint-mongodb/tests/integration_tests/test_highlevel_graph.py @@ -15,7 +15,7 @@ import operator import os import time -from collections.abc import AsyncGenerator, Generator +from collections.abc import Generator from typing import Annotated import pytest @@ -26,7 +26,7 @@ from langgraph.graph import END, StateGraph from typing_extensions import TypedDict -from langgraph.checkpoint.mongodb import AsyncMongoDBSaver, MongoDBSaver +from langgraph.checkpoint.mongodb import MongoDBSaver # --- Configuration --- MONGODB_URI = os.environ.get( @@ -119,21 +119,6 @@ def checkpointer_mongodb() -> Generator[MongoDBSaver, None, None]: checkpointer.writes_collection.drop() -@pytest.fixture(scope="function") -async def checkpointer_mongodb_async() -> AsyncGenerator[AsyncMongoDBSaver, None]: - async with AsyncMongoDBSaver.from_conn_string( - MONGODB_URI, - db_name=DB_NAME, - checkpoint_collection_name=CHECKPOINT_CLXN_NAME + "_async", - writes_collection_name=WRITES_CLXN_NAME + "_async", - ) as checkpointer: - await checkpointer.checkpoint_collection.delete_many({}) - await checkpointer.writes_collection.delete_many({}) - yield checkpointer - await checkpointer.checkpoint_collection.drop() - await checkpointer.writes_collection.drop() - - @pytest.fixture(autouse=True) def disable_langsmith() -> None: """Disable LangSmith tracing for all tests""" @@ -144,12 +129,10 @@ def disable_langsmith() -> None: async def test_fanout( joke_subjects: OverallState, checkpointer_mongodb: MongoDBSaver, - checkpointer_mongodb_async: AsyncMongoDBSaver, checkpointer_memory: InMemorySaver, ) -> None: checkpointers = { "mongodb": checkpointer_mongodb, - "mongodb_async": checkpointer_mongodb_async, "in_memory": checkpointer_memory, "in_memory_async": checkpointer_memory, } @@ -178,9 +161,7 @@ async def test_fanout( print(f"{cname}: {end - start:.4f} seconds") -async def test_custom_properties_async( - checkpointer_mongodb: MongoDBSaver, checkpointer_mongodb_async: AsyncMongoDBSaver -) -> None: +async def test_custom_properties_async(checkpointer_mongodb: MongoDBSaver) -> None: # Create the state graph state_graph = fanout_to_subgraph() @@ -196,7 +177,7 @@ async def test_custom_properties_async( } # Compile the state graph with the provided checkpointing mechanism - compiled_state_graph = state_graph.compile(checkpointer=checkpointer_mongodb_async) + compiled_state_graph = state_graph.compile(checkpointer=checkpointer_mongodb) # Invoke the compiled state graph with user input await compiled_state_graph.ainvoke( @@ -206,7 +187,7 @@ async def test_custom_properties_async( debug=False, ) - checkpoint_tuple = await checkpointer_mongodb_async.aget_tuple(config) + checkpoint_tuple = await checkpointer_mongodb.aget_tuple(config) assert checkpoint_tuple is not None assert checkpoint_tuple.metadata["user_id"] == user_id assert checkpoint_tuple.metadata["assistant_id"] == assistant_id diff --git a/libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_async.py b/libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_async.py index 1d92ce70..298a957d 100644 --- a/libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_async.py +++ b/libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_async.py @@ -1,13 +1,13 @@ import os from collections.abc import AsyncGenerator -from typing import Any, Union +from typing import Any import pytest import pytest_asyncio from bson.errors import InvalidDocument -from pymongo import AsyncMongoClient, MongoClient +from pymongo import MongoClient -from langgraph.checkpoint.mongodb import AsyncMongoDBSaver, MongoDBSaver +from langgraph.checkpoint.mongodb import MongoDBSaver MONGODB_URI = os.environ.get( "MONGODB_URI", "mongodb://localhost:27017/?directConnection=true" @@ -16,50 +16,35 @@ COLLECTION_NAME = "sync_checkpoints_aio" -@pytest_asyncio.fixture(params=["run_in_executor", "aio"]) -async def async_saver(request: pytest.FixtureRequest) -> AsyncGenerator: - if request.param == "aio": - # Use async client and checkpointer - aclient: AsyncMongoClient = AsyncMongoClient(MONGODB_URI) - adb = aclient[DB_NAME] - for clxn in await adb.list_collection_names(): - await adb.drop_collection(clxn) - async with AsyncMongoDBSaver.from_conn_string( - MONGODB_URI, DB_NAME, COLLECTION_NAME - ) as checkpointer: - yield checkpointer - await aclient.close() - else: - # Use sync client and checkpointer with async methods run in executor - client: MongoClient = MongoClient(MONGODB_URI) - db = client[DB_NAME] - for clxn in db.list_collection_names(): - db.drop_collection(clxn) - with MongoDBSaver.from_conn_string( - MONGODB_URI, DB_NAME, COLLECTION_NAME - ) as checkpointer: - yield checkpointer - client.close() +@pytest_asyncio.fixture +async def saver(request: pytest.FixtureRequest) -> AsyncGenerator: + client: MongoClient = MongoClient(MONGODB_URI) + db = client[DB_NAME] + for clxn in db.list_collection_names(): + db.drop_collection(clxn) + with MongoDBSaver.from_conn_string( + MONGODB_URI, DB_NAME, COLLECTION_NAME + ) as checkpointer: + yield checkpointer + client.close() @pytest.mark.asyncio -async def test_asearch( - input_data: dict[str, Any], async_saver: Union[AsyncMongoDBSaver, MongoDBSaver] -) -> None: +async def test_asearch(input_data: dict[str, Any], saver: MongoDBSaver) -> None: # save checkpoints - await async_saver.aput( + await saver.aput( input_data["config_1"], input_data["chkpnt_1"], input_data["metadata_1"], {}, ) - await async_saver.aput( + await saver.aput( input_data["config_2"], input_data["chkpnt_2"], input_data["metadata_2"], {}, ) - await async_saver.aput( + await saver.aput( input_data["config_3"], input_data["chkpnt_3"], input_data["metadata_3"], @@ -75,23 +60,23 @@ async def test_asearch( query_3: dict[str, Any] = {} # search by no keys, return all checkpoints query_4 = {"source": "update", "step": 1} # no match - search_results_1 = [c async for c in async_saver.alist(None, filter=query_1)] + search_results_1 = [c async for c in saver.alist(None, filter=query_1)] assert len(search_results_1) == 1 assert search_results_1[0].metadata == input_data["metadata_1"] - search_results_2 = [c async for c in async_saver.alist(None, filter=query_2)] + search_results_2 = [c async for c in saver.alist(None, filter=query_2)] assert len(search_results_2) == 1 assert search_results_2[0].metadata == input_data["metadata_2"] - search_results_3 = [c async for c in async_saver.alist(None, filter=query_3)] + search_results_3 = [c async for c in saver.alist(None, filter=query_3)] assert len(search_results_3) == 3 - search_results_4 = [c async for c in async_saver.alist(None, filter=query_4)] + search_results_4 = [c async for c in saver.alist(None, filter=query_4)] assert len(search_results_4) == 0 # search by config (defaults to checkpoints across all namespaces) search_results_5 = [ - c async for c in async_saver.alist({"configurable": {"thread_id": "thread-2"}}) + c async for c in saver.alist({"configurable": {"thread_id": "thread-2"}}) ] assert len(search_results_5) == 2 assert { @@ -101,29 +86,27 @@ async def test_asearch( @pytest.mark.asyncio -async def test_null_chars( - input_data: dict[str, Any], async_saver: Union[AsyncMongoDBSaver, MongoDBSaver] -) -> None: +async def test_null_chars(input_data: dict[str, Any], saver: MongoDBSaver) -> None: """In MongoDB string *values* can be any valid UTF-8 including nulls. *Field names*, however, cannot contain nulls characters.""" null_str = "\x00abc" # string containing null character # 1. null string in field *value* - null_value_cfg = await async_saver.aput( + null_value_cfg = await saver.aput( input_data["config_1"], input_data["chkpnt_1"], {"my_key": null_str}, {}, ) - null_tuple = await async_saver.aget_tuple(null_value_cfg) + null_tuple = await saver.aget_tuple(null_value_cfg) assert null_tuple.metadata["my_key"] == null_str # type: ignore - cps = [c async for c in async_saver.alist(None, filter={"my_key": null_str})] + cps = [c async for c in saver.alist(None, filter={"my_key": null_str})] assert cps[0].metadata["my_key"] == null_str # 2. null string in field *name* with pytest.raises(InvalidDocument): - await async_saver.aput( + await saver.aput( input_data["config_1"], input_data["chkpnt_1"], {null_str: "my_value"}, # type: ignore diff --git a/libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_delete_thread.py b/libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_delete_thread.py index f99b59b6..525a9461 100644 --- a/libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_delete_thread.py +++ b/libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_delete_thread.py @@ -5,7 +5,6 @@ from pymongo import MongoClient from langgraph.checkpoint.mongodb import MongoDBSaver -from langgraph.checkpoint.mongodb.aio import AsyncMongoDBSaver # Setup: MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017") @@ -99,7 +98,7 @@ async def test_adelete_thread() -> None: db[CHKPT_COLLECTION_NAME].delete_many({}) db[WRITES_COLLECTION_NAME].delete_many({}) - async with AsyncMongoDBSaver.from_conn_string( + with MongoDBSaver.from_conn_string( MONGODB_URI, DB_NAME, CHKPT_COLLECTION_NAME, WRITES_COLLECTION_NAME ) as saver: # Thread 1 data @@ -130,7 +129,7 @@ async def test_adelete_thread() -> None: "writes": {"baz": "qux"}, } - assert await saver.checkpoint_collection.count_documents({}) == 0 + assert saver.checkpoint_collection.count_documents({}) == 0 # Save checkpoints for both threads await saver.aput(config_1, chkpnt_1, metadata_1, {}) @@ -146,25 +145,13 @@ async def test_adelete_thread() -> None: # Verify we have write data assert ( - await saver.checkpoint_collection.count_documents( - {"thread_id": thread_1_id} - ) - > 0 - ) - assert ( - await saver.writes_collection.count_documents({"thread_id": thread_1_id}) - > 0 - ) - assert ( - await saver.checkpoint_collection.count_documents( - {"thread_id": thread_2_id} - ) - > 0 + saver.checkpoint_collection.count_documents({"thread_id": thread_1_id}) > 0 ) + assert saver.writes_collection.count_documents({"thread_id": thread_1_id}) > 0 assert ( - await saver.writes_collection.count_documents({"thread_id": thread_2_id}) - > 0 + saver.checkpoint_collection.count_documents({"thread_id": thread_2_id}) > 0 ) + assert saver.writes_collection.count_documents({"thread_id": thread_2_id}) > 0 # Delete thread 1 await saver.adelete_thread(thread_1_id) @@ -172,25 +159,13 @@ async def test_adelete_thread() -> None: # Verify thread 1 data is gone assert await saver.aget_tuple(config_1) is None assert ( - await saver.checkpoint_collection.count_documents( - {"thread_id": thread_1_id} - ) - == 0 - ) - assert ( - await saver.writes_collection.count_documents({"thread_id": thread_1_id}) - == 0 + saver.checkpoint_collection.count_documents({"thread_id": thread_1_id}) == 0 ) + assert saver.writes_collection.count_documents({"thread_id": thread_1_id}) == 0 # Verify thread 2 data still exists assert await saver.aget_tuple(config_2) is not None assert ( - await saver.checkpoint_collection.count_documents( - {"thread_id": thread_2_id} - ) - > 0 - ) - assert ( - await saver.writes_collection.count_documents({"thread_id": thread_2_id}) - > 0 + saver.checkpoint_collection.count_documents({"thread_id": thread_2_id}) > 0 ) + assert saver.writes_collection.count_documents({"thread_id": thread_2_id}) > 0 From cd076020d69db9550bc906a19a77f24b92e7b20e Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 17 Nov 2025 10:46:00 -0500 Subject: [PATCH 2/3] Fix ordering --- .../langgraph/checkpoint/mongodb/saver.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py index e07c5d00..dae081a0 100644 --- a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py +++ b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py @@ -470,6 +470,16 @@ def delete_thread( # Delete all writes associated with the thread ID self.writes_collection.delete_many({"thread_id": thread_id}) + async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Get a checkpoint tuple from the database. + + Asynchronously wraps the blocking `self.get_tuple` method. + + Args: + config (RunnableConfig): The config to use for retrieving the checkpoint. + """ + return await run_in_executor(None, self.get_tuple, config) + async def alist( self, config: Optional[RunnableConfig], @@ -572,13 +582,3 @@ async def adelete_thread( thread_id: The thread ID whose checkpoints should be deleted. """ return await run_in_executor(None, self.delete_thread, thread_id) - - async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: - """Get a checkpoint tuple from the database. - - Asynchronously wraps the blocking `self.get_tuple` method. - - Args: - config (RunnableConfig): The config to use for retrieving the checkpoint. - """ - return await run_in_executor(None, self.get_tuple, config) From f47fd9442cd77ef3046b25cda33f5225a7e03e78 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 17 Nov 2025 10:46:45 -0500 Subject: [PATCH 3/3] Fix docstring --- .../langgraph/checkpoint/mongodb/saver.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py index dae081a0..fcf77289 100644 --- a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py +++ b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py @@ -471,12 +471,16 @@ def delete_thread( self.writes_collection.delete_many({"thread_id": thread_id}) async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: - """Get a checkpoint tuple from the database. + """Asynchronously fetch a checkpoint tuple using the given configuration. Asynchronously wraps the blocking `self.get_tuple` method. - Args: - config (RunnableConfig): The config to use for retrieving the checkpoint. + Args: + config: Configuration specifying which checkpoint to retrieve. + + Returns: + Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found. + """ return await run_in_executor(None, self.get_tuple, config)