Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions libs/langgraph-checkpoint-mongodb/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

---

## Changes in version 0.2.0 (TBD)

- Implements async methods of MongoDBSaver.
- Deprecates AsyncMongoDBSaver, to be removed in 0.3.0.

## Changes in version 0.1.4 (2025/06/13)

- Add TTL (time-to-live) indexes for automatic deletion of old checkpoints and writes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import builtins
import sys
import warnings
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import asynccontextmanager
from datetime import datetime
Expand Down Expand Up @@ -84,6 +85,12 @@ def __init__(
ttl: Optional[int] = None,
**kwargs: Any,
) -> None:
warnings.warn(
f"{self.__class__.__name__} is deprecated and will be removed in 0.3.0 release. "
"Please use the async methods of MongoDBSaver instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.client = client
self.db = self.client[db_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Iterator, Sequence
import asyncio
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import contextmanager
from datetime import datetime
from importlib.metadata import version
Expand All @@ -7,7 +8,7 @@
Optional,
)

from langchain_core.runnables import RunnableConfig
from langchain_core.runnables import RunnableConfig, run_in_executor
from pymongo import ASCENDING, MongoClient, UpdateOne
from pymongo.database import Database as MongoDatabase
from pymongo.driver_info import DriverInfo
Expand Down Expand Up @@ -468,3 +469,120 @@ def delete_thread(

# Delete all writes associated with the thread ID
self.writes_collection.delete_many({"thread_id": thread_id})

async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
    """Fetch a checkpoint tuple from async code.

    Delegates to the synchronous `self.get_tuple`, dispatched through
    `run_in_executor` so the caller can await the result.

    Args:
        config: Configuration specifying which checkpoint to retrieve.

    Returns:
        Optional[CheckpointTuple]: Whatever `self.get_tuple(config)` returns.
    """
    checkpoint_tuple = await run_in_executor(None, self.get_tuple, config)
    return checkpoint_tuple

async def alist(
    self,
    config: Optional[RunnableConfig],
    *,
    filter: Optional[dict[str, Any]] = None,
    before: Optional[RunnableConfig] = None,
    limit: Optional[int] = None,
) -> AsyncIterator[CheckpointTuple]:
    """Asynchronously list checkpoints that match the given criteria.

    Asynchronously wraps the blocking `self.list` generator.

    `self.list(...)` is iterated by a producer function dispatched through
    `run_in_executor`. Each item is handed to the event loop via an
    `asyncio.Queue` and yielded as soon as it arrives, so results are
    streamed instead of being buffered until the whole listing has finished.

    Args:
        config: Configuration object passed to `self.list`.
        filter: Optional filter dictionary passed to `self.list`.
        before: Optional config passed to `self.list` as `before`.
        limit: Optional maximum number of results, passed to `self.list`.

    Yields:
        CheckpointTuple: Each item produced by `self.list`, in order.

    Raises:
        Exception: Anything raised by `self.list` is re-raised here after
            the items produced before the failure have been yielded.
    """
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue[Any] = asyncio.Queue()
    sentinel = object()  # unique end-of-stream marker
    stop_requested = False  # set once the consumer stops iterating

    def produce() -> None:
        try:
            for item in self.list(
                config, filter=filter, before=before, limit=limit
            ):
                if stop_requested:
                    # The consumer went away; no point forwarding more items.
                    break
                # asyncio.Queue is not thread-safe, so hop onto the loop.
                loop.call_soon_threadsafe(queue.put_nowait, item)
        finally:
            # Always signal completion, even on error, so the consumer
            # never waits on the queue forever.
            loop.call_soon_threadsafe(queue.put_nowait, sentinel)

    # Start the producer WITHOUT awaiting it; awaiting here would delay the
    # first yield until the whole listing had been collected.
    producer = asyncio.ensure_future(run_in_executor(None, produce))
    try:
        while True:
            item = await queue.get()
            if item is sentinel:
                break
            yield item
    finally:
        stop_requested = True
    # Reached only after the sentinel: surface any error raised by self.list.
    await producer

async def aput(
    self,
    config: RunnableConfig,
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
    new_versions: ChannelVersions,
) -> RunnableConfig:
    """Store a checkpoint, with its configuration and metadata, from async code.

    Delegates to the synchronous `self.put`, dispatched through
    `run_in_executor`.

    Args:
        config: Configuration for the checkpoint.
        checkpoint: The checkpoint to store.
        metadata: Additional metadata for the checkpoint.
        new_versions: New channel versions as of this write.

    Returns:
        RunnableConfig: The configuration returned by `self.put`.
    """
    put_args = (config, checkpoint, metadata, new_versions)
    return await run_in_executor(None, self.put, *put_args)

async def aput_writes(
    self,
    config: RunnableConfig,
    writes: Sequence[tuple[str, Any]],
    task_id: str,
    task_path: str = "",
) -> None:
    """Store intermediate writes linked to a checkpoint from async code.

    Delegates to the synchronous `self.put_writes`, dispatched through
    `run_in_executor`.

    Args:
        config: Configuration of the related checkpoint.
        writes: List of writes to store.
        task_id: Identifier for the task creating the writes.
        task_path: Path of the task creating the writes.
    """
    outcome = await run_in_executor(
        None, self.put_writes, config, writes, task_id, task_path
    )
    return outcome

async def adelete_thread(
    self,
    thread_id: str,
) -> None:
    """Delete every checkpoint and write stored for one thread ID.

    Delegates to the synchronous `self.delete_thread`, dispatched through
    `run_in_executor`.

    Args:
        thread_id: The thread ID whose checkpoints should be deleted.
    """
    outcome = await run_in_executor(None, self.delete_thread, thread_id)
    return outcome
1 change: 1 addition & 0 deletions libs/langgraph-checkpoint-mongodb/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ addopts = "--strict-markers --strict-config --durations=5 -vv"
markers = [
"requires: mark tests as requiring a specific library",
"compile: mark placeholder test used to compile integration tests without running them",
"asyncio: mark a test as asyncio",
]
asyncio_mode = "auto"

Expand Down
204 changes: 114 additions & 90 deletions libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_async.py
Original file line number Diff line number Diff line change
@@ -1,107 +1,131 @@
import os
from typing import Any
from collections.abc import AsyncGenerator
from typing import Any, Union

import pytest
import pytest_asyncio
from bson.errors import InvalidDocument
from pymongo import AsyncMongoClient
from pymongo import AsyncMongoClient, MongoClient

from langgraph.checkpoint.mongodb.aio import AsyncMongoDBSaver
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver, MongoDBSaver

MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017")
MONGODB_URI = os.environ.get(
"MONGODB_URI", "mongodb://localhost:27017/?directConnection=true"
)
DB_NAME = os.environ.get("DB_NAME", "langgraph-test")
COLLECTION_NAME = "sync_checkpoints_aio"


async def test_asearch(input_data: dict[str, Any]) -> None:
    """Check that `alist` filters checkpoints by metadata and by config.

    Drops every collection in the test database, stores the three checkpoints
    from `input_data` through `AsyncMongoDBSaver`, then searches them.
    """
    # Clear collections if they exist. The client is only needed for this
    # cleanup, so close it as soon as the cleanup is done (or fails).
    client: AsyncMongoClient = AsyncMongoClient(MONGODB_URI)
    try:
        db = client[DB_NAME]
        for clxn in await db.list_collection_names():
            await db.drop_collection(clxn)
    finally:
        await client.close()

    async with AsyncMongoDBSaver.from_conn_string(
        MONGODB_URI, DB_NAME, COLLECTION_NAME
    ) as saver:
        # save checkpoints
        await saver.aput(
            input_data["config_1"],
            input_data["chkpnt_1"],
            input_data["metadata_1"],
            {},
        )
        await saver.aput(
            input_data["config_2"],
            input_data["chkpnt_2"],
            input_data["metadata_2"],
            {},
        )
        await saver.aput(
            input_data["config_3"],
            input_data["chkpnt_3"],
            input_data["metadata_3"],
            {},
        )

        # call method / assertions
        query_1 = {"source": "input"}  # search by 1 key
        query_2 = {
            "step": 1,
            "writes": {"foo": "bar"},
        }  # search by multiple keys
        query_3: dict[str, Any] = {}  # search by no keys, return all checkpoints
        query_4 = {"source": "update", "step": 1}  # no match

        search_results_1 = [c async for c in saver.alist(None, filter=query_1)]
        assert len(search_results_1) == 1
        assert search_results_1[0].metadata == input_data["metadata_1"]

        search_results_2 = [c async for c in saver.alist(None, filter=query_2)]
        assert len(search_results_2) == 1
        assert search_results_2[0].metadata == input_data["metadata_2"]

        search_results_3 = [c async for c in saver.alist(None, filter=query_3)]
        assert len(search_results_3) == 3

        search_results_4 = [c async for c in saver.alist(None, filter=query_4)]
        assert len(search_results_4) == 0

        # search by config (defaults to checkpoints across all namespaces)
        search_results_5 = [
            c async for c in saver.alist({"configurable": {"thread_id": "thread-2"}})
        ]
        assert len(search_results_5) == 2
        assert {
            search_results_5[0].config["configurable"]["checkpoint_ns"],
            search_results_5[1].config["configurable"]["checkpoint_ns"],
        } == {"", "inner"}


async def test_null_chars(input_data: dict[str, Any]) -> None:
@pytest_asyncio.fixture(params=["run_in_executor", "aio"])
async def async_saver(request: pytest.FixtureRequest) -> AsyncGenerator:
    """Yield a checkpointer exposing async methods, once per implementation.

    With param "aio" the fixture yields an `AsyncMongoDBSaver`; with
    "run_in_executor" it yields a `MongoDBSaver`. In both cases every
    collection in the test database is dropped first.

    The client used for that cleanup is closed in a `finally` right after the
    collections are dropped, so it is released even if setup fails and is not
    held open while the test runs.
    """
    if request.param == "aio":
        # Use async client and checkpointer
        aclient: AsyncMongoClient = AsyncMongoClient(MONGODB_URI)
        try:
            adb = aclient[DB_NAME]
            for clxn in await adb.list_collection_names():
                await adb.drop_collection(clxn)
        finally:
            await aclient.close()
        async with AsyncMongoDBSaver.from_conn_string(
            MONGODB_URI, DB_NAME, COLLECTION_NAME
        ) as checkpointer:
            yield checkpointer
    else:
        # Use sync client and checkpointer with async methods run in executor
        client: MongoClient = MongoClient(MONGODB_URI)
        try:
            db = client[DB_NAME]
            for clxn in db.list_collection_names():
                db.drop_collection(clxn)
        finally:
            client.close()
        with MongoDBSaver.from_conn_string(
            MONGODB_URI, DB_NAME, COLLECTION_NAME
        ) as checkpointer:
            yield checkpointer


@pytest.mark.asyncio
async def test_asearch(
    input_data: dict[str, Any], async_saver: Union[AsyncMongoDBSaver, MongoDBSaver]
) -> None:
    """Check that `alist` filters checkpoints by metadata and by config.

    Stores the three checkpoints from `input_data` through the saver supplied
    by the `async_saver` fixture, then searches them.
    """

    async def search(config: Any, **kwargs: Any) -> list:
        # Drain the async iterator into a list for easy assertions.
        return [hit async for hit in async_saver.alist(config, **kwargs)]

    # save checkpoints
    for cfg_key, chkpnt_key, meta_key in (
        ("config_1", "chkpnt_1", "metadata_1"),
        ("config_2", "chkpnt_2", "metadata_2"),
        ("config_3", "chkpnt_3", "metadata_3"),
    ):
        await async_saver.aput(
            input_data[cfg_key],
            input_data[chkpnt_key],
            input_data[meta_key],
            {},
        )

    # search by 1 key
    by_source = await search(None, filter={"source": "input"})
    assert len(by_source) == 1
    assert by_source[0].metadata == input_data["metadata_1"]

    # search by multiple keys
    by_step_and_writes = await search(
        None, filter={"step": 1, "writes": {"foo": "bar"}}
    )
    assert len(by_step_and_writes) == 1
    assert by_step_and_writes[0].metadata == input_data["metadata_2"]

    # search by no keys, return all checkpoints
    match_all: dict[str, Any] = {}
    everything = await search(None, filter=match_all)
    assert len(everything) == 3

    # no match
    nothing = await search(None, filter={"source": "update", "step": 1})
    assert len(nothing) == 0

    # search by config (defaults to checkpoints across all namespaces)
    by_thread = await search({"configurable": {"thread_id": "thread-2"}})
    assert len(by_thread) == 2
    namespaces = {hit.config["configurable"]["checkpoint_ns"] for hit in by_thread}
    assert namespaces == {"", "inner"}


@pytest.mark.asyncio
async def test_null_chars(
input_data: dict[str, Any], async_saver: Union[AsyncMongoDBSaver, MongoDBSaver]
) -> None:
"""In MongoDB string *values* can be any valid UTF-8 including nulls.
*Field names*, however, cannot contain nulls characters."""
async with AsyncMongoDBSaver.from_conn_string(
MONGODB_URI, DB_NAME, COLLECTION_NAME
) as saver:
null_str = "\x00abc" # string containing null character

# 1. null string in field *value*
null_value_cfg = await saver.aput(
null_str = "\x00abc" # string containing null character

# 1. null string in field *value*
null_value_cfg = await async_saver.aput(
input_data["config_1"],
input_data["chkpnt_1"],
{"my_key": null_str},
{},
)
null_tuple = await async_saver.aget_tuple(null_value_cfg)
assert null_tuple.metadata["my_key"] == null_str # type: ignore
cps = [c async for c in async_saver.alist(None, filter={"my_key": null_str})]
assert cps[0].metadata["my_key"] == null_str

# 2. null string in field *name*
with pytest.raises(InvalidDocument):
await async_saver.aput(
input_data["config_1"],
input_data["chkpnt_1"],
{"my_key": null_str},
{null_str: "my_value"}, # type: ignore
{},
)
null_tuple = await saver.aget_tuple(null_value_cfg)
assert null_tuple.metadata["my_key"] == null_str # type: ignore
cps = [c async for c in saver.alist(None, filter={"my_key": null_str})]
assert cps[0].metadata["my_key"] == null_str

# 2. null string in field *name*
with pytest.raises(InvalidDocument):
await saver.aput(
input_data["config_1"],
input_data["chkpnt_1"],
{null_str: "my_value"}, # type: ignore
{},
)
Loading
Loading