diff --git a/examples/langgraph-checkpointer/README.md b/examples/langgraph-checkpointer/README.md new file mode 100644 index 00000000..895287f4 --- /dev/null +++ b/examples/langgraph-checkpointer/README.md @@ -0,0 +1,72 @@ +# Dapr For Agents - LangGraph Checkpointer + +Supporting Dapr backed Checkpointer for LangGraph based Agents. + +## Pre-requisites + +- [Dapr CLI and initialized environment](https://docs.dapr.io/getting-started) +- [Install Python 3.10+](https://www.python.org/downloads/) + +## Install Dapr python-SDK + + + +```bash +pip3 install -r requirements.txt +``` + +## Run the example + +Export your `OPENAI_API_KEY`: + +```bash +export OPENAI_API_KEY="SK-..." +``` + +Run the following command in a terminal/command prompt: + + + +```bash +# 1. Run the LangGraph agent +dapr run --app-id langgraph-checkpointer --app-port 5001 -- python3 agent.py +``` + + + +## Cleanup + +Either press CTRL + C to quit the app or run the following command in a new terminal to stop the app: + +```bash +dapr stop --app-id langgraph-checkpointer +``` + diff --git a/examples/langgraph-checkpointer/agent.py b/examples/langgraph-checkpointer/agent.py new file mode 100644 index 00000000..1060c609 --- /dev/null +++ b/examples/langgraph-checkpointer/agent.py @@ -0,0 +1,68 @@ +import os + +from dapr.ext.langgraph import DaprCheckpointer +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_openai import ChatOpenAI +from langgraph.graph import START, MessagesState, StateGraph +from langgraph.prebuilt import ToolNode, tools_condition + + +def add(a: int, b: int) -> int: + """Adds a and b. + + Args: + a: first int + b: second int + """ + return a + b + + +def multiply(a: int, b: int) -> int: + """Multiply a and b. + + Args: + a: first int + b: second int + """ + return a * b + + +tools = [add, multiply] +llm = ChatOpenAI(model='gpt-4o', api_key=os.environ['OPENAI_API_KEY']) +llm_with_tools = llm.bind_tools(tools) + +sys_msg = SystemMessage( + content='You are a helpful assistant tasked with performing arithmetic on a set of inputs.' +) + + +def assistant(state: MessagesState): + return {'messages': [llm_with_tools.invoke([sys_msg] + state['messages'])]} + + +builder = StateGraph(MessagesState) + +builder.add_node('assistant', assistant) +builder.add_node('tools', ToolNode(tools)) + +builder.add_edge(START, 'assistant') +builder.add_conditional_edges( + 'assistant', + tools_condition, +) +builder.add_edge('tools', 'assistant') + +memory = DaprCheckpointer(store_name='statestore', key_prefix='dapr') +react_graph_memory = builder.compile(checkpointer=memory) + +config = {'configurable': {'thread_id': '1'}} + +messages = [HumanMessage(content='Add 3 and 4.')] +messages = react_graph_memory.invoke({'messages': messages}, config) +for m in messages['messages']: + m.pretty_print() + +messages = [HumanMessage(content='Multiply that by 2.')] +messages = react_graph_memory.invoke({'messages': messages}, config) +for m in messages['messages']: + m.pretty_print() diff --git a/examples/langgraph-checkpointer/requirements.txt b/examples/langgraph-checkpointer/requirements.txt new file mode 100644 index 00000000..11f8124c --- /dev/null +++ b/examples/langgraph-checkpointer/requirements.txt @@ -0,0 +1,5 @@ +langchain-core>=1.0.7 +langchain-openai>=1.0.3 +langgraph>=1.0.3 +dapr-ext-workflow>=1.16.0.dev +dapr>=1.16.0.dev \ No newline at end of file diff --git a/ext/dapr-ext-langgraph/dapr/ext/langgraph/dapr_checkpointer.py b/ext/dapr-ext-langgraph/dapr/ext/langgraph/dapr_checkpointer.py index 123b313d..a18de1c3 100644 --- a/ext/dapr-ext-langgraph/dapr/ext/langgraph/dapr_checkpointer.py +++ b/ext/dapr-ext-langgraph/dapr/ext/langgraph/dapr_checkpointer.py @@ -1,11 +1,23 @@ +import base64 import json -from typing import Any, Sequence, Tuple +import time +from typing import Any, Dict, List, Optional, Sequence, Tuple, cast -from langchain_core.load import dumps +import msgpack +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.runnables import RunnableConfig +from ulid import ULID from dapr.clients import DaprClient -from langgraph.checkpoint.base import BaseCheckpointSaver, Checkpoint, CheckpointTuple +from langgraph.checkpoint.base import ( + WRITES_IDX_MAP, + BaseCheckpointSaver, + ChannelVersions, + Checkpoint, + CheckpointMetadata, + CheckpointTuple, +) +from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer class DaprCheckpointer(BaseCheckpointSaver[Checkpoint]): @@ -19,7 +31,9 @@ class DaprCheckpointer(BaseCheckpointSaver[Checkpoint]): def __init__(self, store_name: str, key_prefix: str): self.store_name = store_name self.key_prefix = key_prefix + self.serde = JsonPlusSerializer() self.client = DaprClient() + self._key_cache: Dict[str, str] = {} # helper: construct Dapr key for a thread def _get_key(self, config: RunnableConfig) -> str: @@ -36,59 +50,89 @@ def _get_key(self, config: RunnableConfig) -> str: return f'{self.key_prefix}:{thread_id}' - # restore a checkpoint - def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None: - key = self._get_key(config) - - resp = self.client.get_state(store_name=self.store_name, key=key) - if not resp.data: - return None - - wrapper = json.loads(resp.data) - cp_data = wrapper.get('checkpoint', wrapper) - metadata = wrapper.get('metadata', {'step': 0}) - if 'step' not in metadata: - metadata['step'] = 0 - - cp = Checkpoint(**cp_data) - return CheckpointTuple( - config=config, - checkpoint=cp, - parent_config=None, - metadata=metadata, - ) - - # save a full checkpoint snapshot def put( self, config: RunnableConfig, checkpoint: Checkpoint, - parent_config: RunnableConfig | None, - metadata: dict[str, Any], - ) -> None: - key = self._get_key(config) + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + thread_id = config['configurable']['thread_id'] + checkpoint_ns = config['configurable'].get('checkpoint_ns', '') + config_checkpoint_id = config['configurable'].get('checkpoint_id', '') + thread_ts = config['configurable'].get('thread_ts', '') + + checkpoint_id = config_checkpoint_id or thread_ts or checkpoint.get('id', '') + + parent_checkpoint_id = None + if ( + checkpoint.get('id') + and config_checkpoint_id + and checkpoint.get('id') != config_checkpoint_id + ): + parent_checkpoint_id = config_checkpoint_id + checkpoint_id = checkpoint['id'] + + storage_safe_thread_id = self._safe_id(thread_id) + storage_safe_checkpoint_ns = self._safe_ns(checkpoint_ns) + storage_safe_checkpoint_id = self._safe_id(checkpoint_id) + + copy = checkpoint.copy() + next_config = { + 'configurable': { + 'thread_id': thread_id, + 'checkpoint_ns': checkpoint_ns, + 'checkpoint_id': checkpoint_id, + } + } - checkpoint_serializable = { - 'v': checkpoint['v'], - 'id': checkpoint['id'], - 'ts': checkpoint['ts'], - 'channel_values': checkpoint['channel_values'], - 'channel_versions': checkpoint['channel_versions'], - 'versions_seen': checkpoint['versions_seen'], + checkpoint_ts = None + if checkpoint_id: + try: + ulid_obj = ULID.from_str(checkpoint_id) + checkpoint_ts = ulid_obj.timestamp + except Exception: + checkpoint_ts = time.time() * 1000 + + checkpoint_data = { + 'thread_id': storage_safe_thread_id, + 'checkpoint_ns': storage_safe_checkpoint_ns, + 'checkpoint_id': storage_safe_checkpoint_id, + 'parent_checkpoint_id': ( + '00000000-0000-0000-0000-000000000000' + if (parent_checkpoint_id if parent_checkpoint_id else '') == '' + else parent_checkpoint_id + ), + 'checkpoint_ts': checkpoint_ts, + 'checkpoint': self._dump_checkpoint(copy), + 'metadata': self._dump_metadata(metadata), + 'has_writes': False, } - wrapper = {'checkpoint': checkpoint_serializable, 'metadata': metadata} + # Guard case where metadata is None + metadata = metadata or {} - self.client.save_state(self.store_name, key, dumps(wrapper)) + if all(key in metadata for key in ['source', 'step']): + checkpoint_data['source'] = metadata['source'] + checkpoint_data['step'] = metadata['step'] - reg_resp = self.client.get_state(store_name=self.store_name, key=self.REGISTRY_KEY) - registry = json.loads(reg_resp.data) if reg_resp.data else [] + checkpoint_key = self._make_safe_checkpoint_key( + thread_id=thread_id, checkpoint_ns=checkpoint_ns, checkpoint_id=checkpoint_id + ) + + _, data = self.serde.dumps_typed(checkpoint_data) + self.client.save_state(store_name=self.store_name, key=checkpoint_key, value=data) + + latest_pointer_key = ( + f'checkpoint_latest:{storage_safe_thread_id}:{storage_safe_checkpoint_ns}' + ) + + self.client.save_state( + store_name=self.store_name, key=latest_pointer_key, value=checkpoint_key + ) - if key not in registry: - registry.append(key) - self.client.save_state(self.store_name, self.REGISTRY_KEY, json.dumps(registry)) + return next_config - # incremental persistence (for streamed runs) def put_writes( self, config: RunnableConfig, @@ -96,24 +140,50 @@ def put_writes( task_id: str, task_path: str = '', ) -> None: - _ = task_id, task_path - - key = self._get_key(config) + """Store intermediate writes linked to a checkpoint with integrated key registry.""" + thread_id = config['configurable']['thread_id'] + checkpoint_ns = config['configurable'].get('checkpoint_ns', '') + checkpoint_id = config['configurable'].get('checkpoint_id', '') + storage_safe_thread_id = (self._safe_id(thread_id),) + storage_safe_checkpoint_ns = self._safe_ns(checkpoint_ns) + + writes_objects: List[Dict[str, Any]] = [] + for idx, (channel, value) in enumerate(writes): + type_, blob = self.serde.dumps_typed(value) + write_obj: Dict[str, Any] = { + 'thread_id': storage_safe_thread_id, + 'checkpoint_ns': storage_safe_checkpoint_ns, + 'checkpoint_id': self._safe_id(checkpoint_id), + 'task_id': task_id, + 'task_path': task_path, + 'idx': WRITES_IDX_MAP.get(channel, idx), + 'channel': channel, + 'type': type_, + 'blob': self._encode_blob(blob), + } + writes_objects.append(write_obj) + + for write_obj in writes_objects: + idx_value = write_obj['idx'] + assert isinstance(idx_value, int) + key = self._make_safe_checkpoint_key( + thread_id=thread_id, checkpoint_ns=checkpoint_ns, checkpoint_id=checkpoint_id + ) - resp = self.client.get_state(store_name=self.store_name, key=key) - if not resp.data: - return + self.client.save_state(store_name=self.store_name, key=key, value=json.dumps(write_obj)) - wrapper = json.loads(resp.data) - cp = wrapper.get('checkpoint', {}) + checkpoint_key = self._make_safe_checkpoint_key( + thread_id=thread_id, checkpoint_ns=checkpoint_ns, checkpoint_id=checkpoint_id + ) - for field, value in writes: - cp['channel_values'][field] = value + latest_pointer_key = ( + f'checkpoint_latest:{storage_safe_thread_id}:{storage_safe_checkpoint_ns}' + ) - wrapper['checkpoint'] = cp - self.client.save_state(self.store_name, key, json.dumps(wrapper)) + self.client.save_state( + store_name=self.store_name, key=latest_pointer_key, value=checkpoint_key + ) - # enumerate all saved checkpoints def list(self, config: RunnableConfig) -> list[CheckpointTuple]: reg_resp = self.client.get_state(store_name=self.store_name, key=self.REGISTRY_KEY) if not reg_resp.data: @@ -143,7 +213,6 @@ def list(self, config: RunnableConfig) -> list[CheckpointTuple]: return checkpoints - # remove a checkpoint and update the registry def delete_thread(self, config: RunnableConfig) -> None: key = self._get_key(config) @@ -162,3 +231,179 @@ def delete_thread(self, config: RunnableConfig) -> None: key=self.REGISTRY_KEY, value=json.dumps(registry), ) + + def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + thread_id = config['configurable']['thread_id'] + checkpoint_ns = config['configurable'].get('checkpoint_ns', '') + + storage_safe_thread_id = self._safe_id(thread_id) + storage_safe_checkpoint_ns = self._safe_ns(checkpoint_ns) + + key = ':'.join( + [ + 'checkpoint_latest', + storage_safe_thread_id, + storage_safe_checkpoint_ns, + ] + ) + + # First we extract the latest checkpoint key + checkpoint_key = self.client.get_state(store_name=self.store_name, key=key) + if not checkpoint_key.data: + return None + + # To then derive the checkpoint data + checkpoint_data = self.client.get_state( + store_name=self.store_name, + # checkpoint_key.data can either be str or bytes + key=checkpoint_key.data.decode() + if isinstance(checkpoint_key.data, bytes) + else checkpoint_key.data, + ) + + if not checkpoint_data.data: + return None + + if isinstance(checkpoint_data.data, bytes): + unpacked = msgpack.unpackb(checkpoint_data.data) + + checkpoint_values = unpacked[b'checkpoint'] + channel_values = checkpoint_values[b'channel_values'] + + decoded_messages = [] + for item in channel_values[b'messages']: + if isinstance(item, msgpack.ExtType): + decoded_messages.append( + self._convert_checkpoint_message( + self._load_metadata(msgpack.unpackb(item.data)) + ) + ) + else: + decoded_messages.append(item) + + checkpoint_values[b'channel_values'][b'messages'] = decoded_messages + + mdata = unpacked.get(b'metadata') + if isinstance(mdata, bytes): + mdata = self._load_metadata(msgpack.unpackb(mdata)) + + metadata = { + k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v + for k, v in mdata.items() + } + + checkpoint_obj = Checkpoint( + **{ + key.decode() if isinstance(key, bytes) else key: value + for key, value in checkpoint_values.items() + } + ) + + checkpoint = self._decode_bytes(checkpoint_obj) + elif isinstance(checkpoint_data.data, str): + unpacked = json.loads(checkpoint_data.data) + checkpoint = unpacked.get('checkpoint', None) + metadata = unpacked.get('metadata', None) + + if not metadata or not checkpoint: + return None + else: + return None + + return CheckpointTuple( + config=config, + checkpoint=checkpoint, + metadata=metadata, + parent_config=None, + pending_writes=[], + ) + + def _safe_id(self, id) -> str: + return '00000000-0000-0000-0000-000000000000' if id == '' else id + + def _safe_ns(self, ns) -> str: + return '__empty__' if ns == '' else ns + + def _convert_checkpoint_message(self, msg_item): + _, _, data_dict, _ = msg_item + data_dict = self._decode_bytes(data_dict) + + msg_type = data_dict.get('type') + + if msg_type == 'human': + return HumanMessage(**data_dict) + elif msg_type == 'ai': + return AIMessage(**data_dict) + elif msg_type == 'tool': + return ToolMessage(**data_dict) + else: + raise ValueError(f'Unknown message type: {msg_type}') + + def _decode_bytes(self, obj): + if isinstance(obj, bytes): + try: + s = obj.decode() + # Convert to int if it's a number, the unpacked channel_version holds \xa1 which unpacks as strings + # LangGraph needs Ints for '>' comparison + if s.isdigit(): + return int(s) + return s + except Exception: + return obj + if isinstance(obj, dict): + return {self._decode_bytes(k): self._decode_bytes(v) for k, v in obj.items()} + if isinstance(obj, list): + return [self._decode_bytes(v) for v in obj] + if isinstance(obj, tuple): + return tuple(self._decode_bytes(v) for v in obj) + return obj + + def _encode_blob(self, blob: Any) -> str: + if isinstance(blob, bytes): + return base64.b64encode(blob).decode() + return blob + + def _dump_checkpoint(self, checkpoint: Checkpoint) -> dict[str, Any]: + type_, data = self.serde.dumps_typed(checkpoint) + + if type_ == 'json': + checkpoint_data = cast(dict, json.loads(data)) + else: + checkpoint_data = cast(dict, self.serde.loads_typed((type_, data))) + + if 'channel_values' in checkpoint_data: + for key, value in checkpoint_data['channel_values'].items(): + if isinstance(value, bytes): + checkpoint_data['channel_values'][key] = { + '__bytes__': self._encode_blob(value) + } + + if 'channel_versions' in checkpoint_data: + checkpoint_data['channel_versions'] = { + k: str(v) for k, v in checkpoint_data['channel_versions'].items() + } + + return {'type': type_, **checkpoint_data, 'pending_sends': []} + + def _load_metadata(self, metadata: dict[str, Any]) -> CheckpointMetadata: + type_str, data_bytes = self.serde.dumps_typed(metadata) + return self.serde.loads_typed((type_str, data_bytes)) + + def _dump_metadata(self, metadata: CheckpointMetadata) -> str: + _, serialized_bytes = self.serde.dumps_typed(metadata) + return serialized_bytes + + def _make_safe_checkpoint_key( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + ) -> str: + return ':'.join( + [ + 'checkpoint', + thread_id, + checkpoint_ns, + checkpoint_id, + ] + ) diff --git a/ext/dapr-ext-langgraph/setup.cfg b/ext/dapr-ext-langgraph/setup.cfg index bb32e782..3a06237f 100644 --- a/ext/dapr-ext-langgraph/setup.cfg +++ b/ext/dapr-ext-langgraph/setup.cfg @@ -27,6 +27,8 @@ install_requires = dapr >= 1.16.1rc1 langgraph >= 0.3.6 langchain >= 0.1.17 + python-ulid >= 3.0.0 + msgpack-python >= 0.4.5 [options.packages.find] include = diff --git a/ext/dapr-ext-langgraph/tests/test_checkpointer.py b/ext/dapr-ext-langgraph/tests/test_checkpointer.py index 05184f8a..fc51d918 100644 --- a/ext/dapr-ext-langgraph/tests/test_checkpointer.py +++ b/ext/dapr-ext-langgraph/tests/test_checkpointer.py @@ -1,10 +1,12 @@ # -*- coding: utf-8 -*- +import base64 import json import unittest from datetime import datetime from unittest import mock +import msgpack from dapr.ext.langgraph.dapr_checkpointer import DaprCheckpointer from langgraph.checkpoint.base import Checkpoint @@ -61,17 +63,37 @@ def test_put_saves_checkpoint_and_registry(self, mock_client_cls): mock_client.get_state.return_value.data = json.dumps([]) cp = DaprCheckpointer(self.store, self.prefix) - cp.put(self.config, self.checkpoint, None, {'step': 10}) - - first_call = mock_client.save_state.call_args_list[0][0] - assert first_call[0] == 'statestore' - assert first_call[1] == 'lg:t1' - saved_payload = json.loads(first_call[2]) + cp.put(self.config, self.checkpoint, {'step': 10}, None) + + first_call = mock_client.save_state.call_args_list[0] + first_call_kwargs = first_call.kwargs + assert first_call_kwargs['store_name'] == 'statestore' + assert first_call_kwargs['key'] == 'checkpoint:t1::cp1' + unpacked = msgpack.unpackb(first_call_kwargs['value']) # We're packing bytes + saved_payload = {} + for k, v in unpacked.items(): + k = k.decode() if isinstance(k, bytes) else k + if ( + k == 'checkpoint' or k == 'metadata' + ): # Need to convert b'' on checkpoint/metadata dict key/values + if k == 'metadata': + v = msgpack.unpackb(v) # Metadata value is packed + val = {} + for sk, sv in v.items(): + sk = sk.decode() if isinstance(sk, bytes) else sk + sv = sv.decode() if isinstance(sv, bytes) else sv + val[sk] = sv + else: + val = v.decode() if isinstance(v, bytes) else v + saved_payload[k] = val assert saved_payload['metadata']['step'] == 10 - second_call = mock_client.save_state.call_args_list[1][0] - assert second_call[0] == 'statestore' - assert second_call[1] == DaprCheckpointer.REGISTRY_KEY + second_call = mock_client.save_state.call_args_list[1] + second_call_kwargs = second_call.kwargs + assert second_call_kwargs['store_name'] == 'statestore' + assert ( + second_call_kwargs['value'] == 'checkpoint:t1::cp1' + ) # Here we're testing if the last checkpoint is the first_call above def test_put_writes_updates_channel_values(self, mock_client_cls): mock_client = mock_client_cls.return_value @@ -93,9 +115,12 @@ def test_put_writes_updates_channel_values(self, mock_client_cls): cp.put_writes(self.config, writes=[('a', 99)], task_id='task1') # save_state is called with updated checkpoint - call = mock_client.save_state.call_args[0] - saved = json.loads(call[2]) - assert saved['checkpoint']['channel_values']['a'] == 99 + call = mock_client.save_state.call_args_list[0] + # As we're using named input params we've got to fetch through kwargs + kwargs = call.kwargs + saved = json.loads(kwargs['value']) + # As the value obj is base64 encoded in 'blob' we got to unpack it + assert msgpack.unpackb(base64.b64decode(saved['blob'])) == 99 def test_list_returns_all_checkpoints(self, mock_client_cls): mock_client = mock_client_cls.return_value diff --git a/tox.ini b/tox.ini index 7c31dd8a..0697a408 100644 --- a/tox.ini +++ b/tox.ini @@ -61,6 +61,7 @@ commands = ./validate.sh demo_workflow ./validate.sh workflow ./validate.sh jobs + ./validate.sh langgraph-checkpointer ./validate.sh ../ commands_pre = pip3 install -e {toxinidir}/