diff --git a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py
index fcf77289..8b8ebd37 100644
--- a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py
+++ b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py
@@ -17,6 +17,8 @@
     CheckpointTuple,
     get_checkpoint_id,
 )
+from langgraph.checkpoint.serde.base import SerializerProtocol
+from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
 from pymongo import ASCENDING, MongoClient, UpdateOne
 from pymongo.database import Database as MongoDatabase
 
@@ -81,6 +83,7 @@ def __init__(
         checkpoint_collection_name: str = "checkpoints",
         writes_collection_name: str = "checkpoint_writes",
         ttl: Optional[int] = None,
+        serde: Optional[SerializerProtocol] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__()
@@ -89,6 +92,11 @@ def __init__(
         self.checkpoint_collection = self.db[checkpoint_collection_name]
         self.writes_collection = self.db[writes_collection_name]
         self.ttl = ttl
+        # Use the caller's serializer when given; otherwise default to JSON+.
+        if serde is not None:
+            self.serde = serde
+        else:
+            self.serde = JsonPlusSerializer()
 
         # Create indexes if not present
         if len(self.checkpoint_collection.list_indexes().to_list()) < 2:
@@ -236,7 +244,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
             return CheckpointTuple(
                 {"configurable": config_values},
                 checkpoint,
-                loads_metadata(doc["metadata"]),
+                loads_metadata(self.serde, doc["metadata"]),
                 (
                     {
                         "configurable": {
@@ -291,7 +299,7 @@ def list(
 
         if filter:
             for key, value in filter.items():
-                query[f"metadata.{key}"] = dumps_metadata(value)
+                query[f"metadata.{key}"] = dumps_metadata(self.serde, value)
 
         if before is not None:
             query["checkpoint_id"] = {"$lt": before["configurable"]["checkpoint_id"]}
@@ -325,7 +333,7 @@ def list(
                     }
                 },
                 checkpoint=self.serde.loads_typed((doc["type"], doc["checkpoint"])),
-                metadata=loads_metadata(doc["metadata"]),
+                metadata=loads_metadata(self.serde, doc["metadata"]),
                 parent_config=(
                     {
                         "configurable": {
@@ -381,7 +389,7 @@ def put(
             "parent_checkpoint_id": config["configurable"].get("checkpoint_id"),
             "type": type_,
             "checkpoint": serialized_checkpoint,
-            "metadata": dumps_metadata(metadata),
+            "metadata": dumps_metadata(self.serde, metadata),
         }
         upsert_query = {
             "thread_id": thread_id,
diff --git a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/utils.py b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/utils.py
index e19ac0ed..86ad0536 100644
--- a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/utils.py
+++ b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/utils.py
@@ -8,12 +8,9 @@
 
 from langgraph.checkpoint.base import CheckpointMetadata
 from langgraph.checkpoint.serde.base import SerializerProtocol
-from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
 from pymongo import AsyncMongoClient
 from pymongo.driver_info import DriverInfo
 
-serde: SerializerProtocol = JsonPlusSerializer()
-
 DRIVER_METADATA = DriverInfo(
     name="Langgraph", version=version("langgraph-checkpoint-mongodb")
 )
@@ -25,7 +22,9 @@ def _append_client_metadata(client: AsyncMongoClient) -> None:
         client.append_metadata(DRIVER_METADATA)
 
 
-def loads_metadata(metadata: dict[str, Any]) -> CheckpointMetadata:
+def loads_metadata(
+    serde: SerializerProtocol, metadata: dict[str, Any]
+) -> CheckpointMetadata:
     """Deserialize metadata document
 
     The CheckpointMetadata class itself cannot be stored directly in MongoDB,
@@ -38,13 +37,14 @@ def loads_metadata(metadata: dict[str, Any]) -> CheckpointMetadata:
     if isinstance(metadata, dict):
         output = dict()
         for key, value in metadata.items():
-            output[key] = loads_metadata(value)
+            output[key] = loads_metadata(serde, value)
         return output
     else:
         return serde.loads_typed(metadata)
 
 
 def dumps_metadata(
+    serde: SerializerProtocol,
     metadata: Union[CheckpointMetadata, Any],
 ) -> Union[bytes, dict[str, Any]]:
     """Serialize all values in metadata dictionary.
@@ -54,7 +54,7 @@ def dumps_metadata(
     if isinstance(metadata, dict):
         output = dict()
         for key, value in metadata.items():
-            output[key] = dumps_metadata(value)
+            output[key] = dumps_metadata(serde, value)
         return output
     else:
         return serde.dumps_typed(metadata)
diff --git a/libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_serde.py b/libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_serde.py
new file mode 100644
index 00000000..f6799089
--- /dev/null
+++ b/libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_serde.py
@@ -0,0 +1,60 @@
+import os
+from typing import Any
+
+from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
+from pymongo import MongoClient
+
+from langgraph.checkpoint.mongodb import MongoDBSaver
+
+MONGODB_URI = os.environ.get(
+    "MONGODB_URI", "mongodb://localhost:27017/?directConnection=true"
+)
+DB_NAME = os.environ.get("DB_NAME", "langgraph-test")
+COLLECTION_NAME = "serde_checkpoints"
+
+
+class CustomSerializer(JsonPlusSerializer):
+    """JsonPlusSerializer that records whether its typed methods were called."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.dumps_called = False
+        self.loads_called = False
+
+    def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
+        self.dumps_called = True
+        return super().dumps_typed(obj)
+
+    def loads_typed(self, obj: tuple[str, bytes]) -> Any:
+        self.loads_called = True
+        return super().loads_typed(obj)
+
+
+def test_custom_serde(input_data: dict[str, Any]) -> None:
+    """Check that a custom ``serde`` is used when saving and loading."""
+    # Start from an empty collection; the context manager closes the client
+    # once the collection has been dropped.
+    with MongoClient(MONGODB_URI) as client:
+        client[DB_NAME].drop_collection(COLLECTION_NAME)
+
+    custom_serializer = CustomSerializer()
+
+    with MongoDBSaver.from_conn_string(
+        MONGODB_URI, DB_NAME, COLLECTION_NAME, serde=custom_serializer
+    ) as saver:
+        put_config = saver.put(
+            input_data["config_1"],
+            input_data["chkpnt_1"],
+            input_data["metadata_1"],
+            {},
+        )
+
+        assert custom_serializer.dumps_called
+
+        retrieved_checkpoint_tuple = saver.get_tuple(put_config)
+
+        assert custom_serializer.loads_called
+
+        assert retrieved_checkpoint_tuple is not None
+        assert retrieved_checkpoint_tuple.checkpoint == input_data["chkpnt_1"]
+        assert retrieved_checkpoint_tuple.metadata == input_data["metadata_1"]