Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ async def aput(
type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint)
metadata = metadata.copy()
metadata.update(config.get("metadata", {}))
doc = {
doc: dict[str, Any] = {
"parent_checkpoint_id": config["configurable"].get("checkpoint_id"),
"type": type_,
"checkpoint": serialized_checkpoint,
Expand Down Expand Up @@ -410,6 +410,7 @@ async def aput_writes(
"$set" if all(w[0] in WRITES_IDX_MAP for w in writes) else "$setOnInsert"
)
operations = []
now = datetime.now()
for idx, (channel, value) in enumerate(writes):
upsert_query = {
"thread_id": thread_id,
Expand All @@ -419,19 +420,22 @@ async def aput_writes(
"task_path": task_path,
"idx": WRITES_IDX_MAP.get(channel, idx),
}
if self.ttl:
upsert_query["created_at"] = datetime.now()

type_, serialized_value = self.serde.dumps_typed(value)

update_doc: dict[str, Any] = {
"channel": channel,
"type": type_,
"value": serialized_value,
}

if self.ttl:
update_doc["created_at"] = now

operations.append(
UpdateOne(
upsert_query,
{
set_method: {
"channel": channel,
"type": type_,
"value": serialized_value,
}
},
filter=upsert_query,
update={set_method: update_doc},
upsert=True,
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def put(
"checkpoint_id": checkpoint_id,
}
if self.ttl:
upsert_query["created_at"] = datetime.now()
doc["created_at"] = datetime.now()

self.checkpoint_collection.update_one(upsert_query, {"$set": doc}, upsert=True)
return {
Expand Down Expand Up @@ -425,6 +425,7 @@ def put_writes(
"$set" if all(w[0] in WRITES_IDX_MAP for w in writes) else "$setOnInsert"
)
operations = []
now = datetime.now()
for idx, (channel, value) in enumerate(writes):
upsert_query = {
"thread_id": thread_id,
Expand All @@ -434,20 +435,22 @@ def put_writes(
"task_path": task_path,
"idx": WRITES_IDX_MAP.get(channel, idx),
}
if self.ttl:
upsert_query["created_at"] = datetime.now()

type_, serialized_value = self.serde.dumps_typed(value)

update_doc: dict[str, Any] = {
"channel": channel,
"type": type_,
"value": serialized_value,
}

if self.ttl:
update_doc["created_at"] = now

operations.append(
UpdateOne(
upsert_query,
{
set_method: {
"channel": channel,
"type": type_,
"value": serialized_value,
}
},
filter=upsert_query,
update={set_method: update_doc},
upsert=True,
)
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import os
from collections.abc import AsyncGenerator

import pytest
import pytest_asyncio
from langchain_core.runnables import RunnableConfig
from pymongo import MongoClient

from langgraph.checkpoint.mongodb import MongoDBSaver
from langgraph.types import Interrupt

# Connection string for the test MongoDB deployment; falls back to a local
# standalone server when MONGODB_URI is not set in the environment.
MONGODB_URI = os.environ.get(
    "MONGODB_URI", "mongodb://127.0.0.1:27017?directConnection=true"
)
# Database and collection names used by the fixtures/tests in this module.
DB_NAME: str = "test_langgraph_db"
COLLECTION_NAME: str = "checkpoints_interrupts"
WRITES_COLLECTION_NAME: str = "writes_interrupts"
# Time-to-live passed to MongoDBSaver, in seconds (1 hour).
TTL: int = 60 * 60


@pytest_asyncio.fixture()
async def async_saver(request: pytest.FixtureRequest) -> AsyncGenerator:
    """Yield a MongoDBSaver backed by a freshly-emptied test database.

    Every collection in DB_NAME is dropped before the saver is created so
    each test starts from a clean state.

    NOTE(review): `request` is unused in this fixture body — confirm it is
    needed (e.g. for future parametrization) or drop it.
    """
    # Use sync client and checkpointer with async methods run in executor
    client: MongoClient = MongoClient(MONGODB_URI)
    db = client[DB_NAME]
    # Clean slate: remove any collections left behind by earlier runs.
    for clxn in db.list_collection_names():
        db.drop_collection(clxn)
    with MongoDBSaver.from_conn_string(
        MONGODB_URI, DB_NAME, COLLECTION_NAME, WRITES_COLLECTION_NAME, TTL
    ) as checkpointer:
        yield checkpointer
    # Teardown: the context manager above closes the saver's own client;
    # this closes the extra client used only for the cleanup.
    client.close()


async def test_put_writes_on_interrupt(async_saver: MongoDBSaver) -> None:
    """Test that no error is raised when interrupted workflow updates writes."""
    # Regression scenario: two successive aput_writes calls for the SAME
    # thread/checkpoint/task on the special "__interrupt__" channel must
    # upsert (replace) the stored write rather than raise (e.g. on a
    # duplicate key during insert) — presumably the bug this PR fixes;
    # verify against the saver implementation.
    config: RunnableConfig = {
        "configurable": {
            "checkpoint_id": "check1",
            "thread_id": "thread1",
            "checkpoint_ns": "",
        }
    }
    task_id = "task_id"
    task_path = "~__pregel_pull, human_feedback"

    # First interrupt recorded for this task.
    writes1 = [
        (
            "__interrupt__",
            (
                Interrupt(
                    value="please provide input",
                ),
            ),
        )
    ]
    await async_saver.aput_writes(config, writes1, task_id, task_path)

    # Second interrupt for the same task: must update in place, not fail.
    writes2 = [("__interrupt__", (Interrupt(value="please provide another input"),))]
    await async_saver.aput_writes(config, writes2, task_id, task_path)
163 changes: 163 additions & 0 deletions libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_time_travel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import os
from collections.abc import Generator
from operator import add
from typing import Annotated, Any, TypedDict

import pytest
from langchain_core.runnables import RunnableConfig
from pymongo import MongoClient
from typing_extensions import NotRequired

from langgraph.checkpoint.mongodb import MongoDBSaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import StateSnapshot

# Test configuration
# Falls back to a local standalone MongoDB when MONGODB_URI is not set.
MONGODB_URI = os.environ.get(
    "MONGODB_URI", "mongodb://127.0.0.1:27017?directConnection=true"
)


class ExpenseState(TypedDict):
amount: NotRequired[int]
version: NotRequired[int]
approved: NotRequired[bool]
messages: Annotated[list[str], add]


def add_expense_node(state: ExpenseState) -> dict[str, Any]:
"""Node adds expense and a message"""
return dict(amount=100, version=1, approved=False, messages=["Added new expense"])


def validate_expense_node(state: ExpenseState) -> dict[str, Any]:
"""Node that processes data based on current state"""
if state.get("amount") == 200:
return dict(approved=True, messages=["expense approved"])
else:
return dict(approved=False, messages=["expense denied"])


@pytest.fixture(
    params=[None, 60 * 60],
    ids=["ttl_none", "ttl_3600"],
)
def checkpointer(request: Any) -> Generator[MongoDBSaver]:
    """Yield a MongoDBSaver against a dedicated, freshly-dropped database.

    Parametrized over ``ttl`` (disabled, and 1 hour) so the time-travel test
    exercises both configurations.
    """
    db_name = "langgraph_timetravel_db"
    checkpoint_collection_name = "checkpoints"
    writes_collection_name = "checkpoint_writes"

    # Initialize MongoDB checkpointer
    client: MongoClient = MongoClient(MONGODB_URI)

    # Clean up any existing test data.
    client.drop_database(db_name)

    # Fix: keyword names must match MongoDBSaver's parameters
    # (`checkpoint_collection_name` / `writes_collection_name`).  The previous
    # `collection_name=` / `WRITES_COLLECTION_NAME=` spellings would be
    # swallowed by the constructor's **kwargs, leaving the saver on its
    # default collection names while the teardown below dropped the wrong
    # (nonexistent) collections.
    saver = MongoDBSaver(
        client=client,
        db_name=db_name,
        checkpoint_collection_name=checkpoint_collection_name,
        writes_collection_name=writes_collection_name,
        ttl=request.param,
    )

    yield saver

    # Teardown: drop only the collections this fixture created, then close.
    client[db_name].drop_collection(checkpoint_collection_name)
    client[db_name].drop_collection(writes_collection_name)
    client.close()


def test(checkpointer: MongoDBSaver) -> None:
    """Test ability to use checkpointer to update exact state of graph.

    In this simple example, we assume an initial state has been set incorrectly.
    To fix this, instead of rerunning from start, we find the incorrect node,
    update_state, and continue (by passing None to invoke or stream).

    This example does not use interrupt/resume as one might, for example,
    in an expense report approval workflow.
    """
    initial_state: ExpenseState = dict(
        amount=0, version=0, approved=False, messages=["Initial state"]
    )
    config: RunnableConfig = dict(configurable=dict(thread_id="test-time-travel"))

    # Create the graph, which should be a 2-step procedure
    workflow = StateGraph(ExpenseState)
    workflow.add_node("add_expense", add_expense_node)
    workflow.add_node("validate_expense", validate_expense_node)
    workflow.add_edge(START, "add_expense")
    workflow.add_edge("validate_expense", END)
    workflow.add_edge("add_expense", "validate_expense")
    graph = workflow.compile(checkpointer=checkpointer)

    # Run the graph to completion.
    graph.invoke(input=initial_state, config=config)  # type:ignore[arg-type]

    # add_expense_node writes amount=100, so validate_expense_node denies it.
    final_state = graph.get_state(config=config)
    assert not final_state.values["approved"]

    # Time travel: walk the checkpoint history for the snapshot taken after
    # add_expense ran (metadata step == 1) but before validation.
    checkpoints: list[StateSnapshot] = list(graph.get_state_history(config))
    target_checkpoint = None
    for checkpoint in checkpoints:
        if (
            checkpoint.metadata and checkpoint.metadata.get("step") == 1
        ):  # Before validate node
            target_checkpoint = checkpoint
            break

    # Get state at that checkpoint
    assert target_checkpoint
    past_state = graph.get_state(target_checkpoint.config)

    # Rebuild the past state with an amount that validate_expense approves.
    updated_state = dict(past_state.values)
    updated_state["amount"] = 200
    updated_state["version"] = 2
    # Use `+` rather than `+=` so the list shared with past_state.values
    # (dict() is a shallow copy) is not mutated in place.
    updated_state["messages"] = updated_state["messages"] + ["Updated state"]

    updated_config = graph.update_state(
        config=target_checkpoint.config, values=updated_state
    )

    # Continue from the updated checkpoint by streaming with a None input.
    final_step = None
    for step in graph.stream(None, updated_config):
        final_step = step

    # Verify the final result: the 200 expense is now approved.
    assert isinstance(final_step, dict)
    assert final_step["validate_expense"]["approved"]
    # Note that all values are not in the final step
    assert "amount" not in final_step["validate_expense"]
    # They ARE available from graph.get_state
    final_state = graph.get_state(updated_config)
    assert final_state.values["amount"] == 200
    assert set(final_state.values.keys()) == {
        "amount",
        "version",
        "messages",
        "approved",
    }
Loading