# How to create a custom checkpointer using Firestore

When creating LangGraph agents, you can also set them up so that they persist their state. This allows you to do things like interact with an agent multiple times and have it remember previous interactions.

This example shows how to use Firestore as the backend for persisting checkpoint state.

NOTE: This is just an example implementation. You can implement your own checkpointer on top of a different database, or modify this one, as long as it conforms to the `BaseCheckpointSaver` interface.

## Checkpointer implementation

In [1]:
%%capture --no-stderr
%pip install -U langgraph google-cloud-firestore

In [2]:
from typing import Any, AsyncIterator, Dict, Iterator, Optional
from google.cloud import firestore
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata, CheckpointTuple
import pickle

class JsonPlusSerializerCompat(JsonPlusSerializer):
    """JSON-plus serializer that can also read pickled payloads.

    Payloads that start with ``b"\\x80"`` and end with ``b"."`` are treated
    as pickles; everything else is delegated to ``JsonPlusSerializer``.
    """

    def loads(self, data: bytes) -> Any:
        """Deserialize *data*, using pickle when it looks like a pickle stream."""
        looks_pickled = data.startswith(b"\x80") and data.endswith(b".")
        if not looks_pickled:
            return super().loads(data)
        # SECURITY NOTE: pickle.loads can execute arbitrary code. Only use this
        # serializer when every writer to the backing store is trusted.
        return pickle.loads(data)

class FirestoreSaver(BaseCheckpointSaver):
    """Checkpoint saver that persists LangGraph checkpoints in Firestore.

    Storage layout: one document per thread (document id == ``thread_id``)
    inside ``collection_name``. Every checkpoint is a top-level map field of
    that document, keyed by the checkpoint id, holding:

    * ``checkpoint`` - the serialized checkpoint,
    * ``metadata``   - the serialized checkpoint metadata,
    * ``parent_ts``  - id of the parent checkpoint, ``""`` when there is none.

    The "latest" checkpoint and the listing order are derived by comparing
    checkpoint ids as strings, so ids are assumed to sort chronologically
    -- TODO confirm for the ids generated by the installed langgraph version.

    NOTE(review): all checkpoints of a thread live in a single document, so
    very long threads may hit Firestore's per-document size limit.
    """

    serde = JsonPlusSerializerCompat()

    def __init__(self, project_id: str, collection_name: str = "checkpoints", *, serde: Optional[Any] = None) -> None:
        """Create sync and async Firestore clients for *project_id*.

        Args:
            project_id: GCP project that hosts the Firestore database.
            collection_name: Collection holding one document per thread.
            serde: Optional serializer forwarded to ``BaseCheckpointSaver``.
        """
        super().__init__(serde=serde)
        self.db: firestore.Client = firestore.Client(project=project_id)
        self.async_db: firestore.AsyncClient = firestore.AsyncClient(project=project_id)
        self.collection_name: str = collection_name

    # ------------------------------------------------------------------
    # Small private helpers shared by the sync and async code paths.
    # ------------------------------------------------------------------

    def _doc_ref(self, thread_id: str) -> "firestore.DocumentReference":
        """Return the sync reference of the document that stores *thread_id*."""
        return self.db.collection(self.collection_name).document(thread_id)

    def _async_doc_ref(self, thread_id: str) -> "firestore.AsyncDocumentReference":
        """Return the async reference of the document that stores *thread_id*."""
        return self.async_db.collection(self.collection_name).document(thread_id)

    @staticmethod
    def _require_thread_id(thread_id: Optional[str]) -> str:
        """Return *thread_id*, raising ``ValueError`` when it is missing.

        Without this guard ``document(None)`` would silently target a freshly
        generated document id and cost a pointless round trip.
        """
        if thread_id is None:
            raise ValueError("A config with a thread_id is required to list checkpoints")
        return thread_id

    def _build_payload(self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata) -> Dict[str, Any]:
        """Build the document fragment written by ``put`` / ``aput``."""
        parent_ts: Optional[str] = config["configurable"].get("thread_ts")
        return {
            checkpoint["id"]: {
                "checkpoint": self.serde.dumps(checkpoint),
                "metadata": self.serde.dumps(metadata),
                "parent_ts": parent_ts or "",
            }
        }

    def _build_tuple(self, thread_id: str, ts: str, checkpoint_data: Dict[str, Any]) -> CheckpointTuple:
        """Deserialize one stored entry into a ``CheckpointTuple``."""
        parent_ts: str = checkpoint_data.get("parent_ts", "")
        parent_config: Optional[RunnableConfig] = (
            {"configurable": {"thread_id": thread_id, "thread_ts": parent_ts}} if parent_ts else None
        )
        return CheckpointTuple(
            config={"configurable": {"thread_id": thread_id, "thread_ts": ts}},
            checkpoint=self.serde.loads(checkpoint_data["checkpoint"]),
            metadata=self.serde.loads(checkpoint_data["metadata"]),
            parent_config=parent_config,
        )

    # ------------------------------------------------------------------
    # BaseCheckpointSaver interface.
    # ------------------------------------------------------------------

    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Return the requested (or latest) checkpoint of a thread.

        Reads ``thread_id`` and the optional ``thread_ts`` from
        ``config["configurable"]``. Returns ``None`` when the thread has no
        document or the requested checkpoint is not stored.
        """
        thread_id: str = config["configurable"]["thread_id"]
        thread_ts: Optional[str] = config["configurable"].get("thread_ts")

        doc = self._doc_ref(thread_id).get()
        if not doc.exists:
            return None
        return self._process_checkpoint_data(thread_id, thread_ts, doc.to_dict())

    async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Async counterpart of ``get_tuple`` with the same return contract."""
        thread_id: str = config["configurable"]["thread_id"]
        thread_ts: Optional[str] = config["configurable"].get("thread_ts")

        doc = await self._async_doc_ref(thread_id).get()
        # Mirror get_tuple: an unknown thread yields None instead of crashing
        # on the None returned by to_dict().
        if not doc.exists:
            return None
        return await self._aprocess_checkpoint_data(thread_id, thread_ts, doc.to_dict())

    def list(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> Iterator[CheckpointTuple]:
        """Yield the checkpoints of a thread, newest id first.

        Args:
            config: Config carrying ``configurable.thread_id``.
            filter: Not supported; any truthy value raises.
            before: Only yield checkpoints whose id sorts before this
                config's ``thread_ts``.
            limit: Maximum number of checkpoints to yield.

        Raises:
            NotImplementedError: If *filter* is given.
            ValueError: If no thread id is available or the thread has no
                stored checkpoints.
        """
        thread_id: Optional[str] = config["configurable"]["thread_id"] if config else None
        if filter:
            raise NotImplementedError("Filtering is not implemented for FirestoreSaver")
        thread_id = self._require_thread_id(thread_id)

        doc = self._doc_ref(thread_id).get()
        if not doc.exists:
            raise ValueError(f"No checkpoints found for thread_id: {thread_id}")

        yield from self._process_checkpoint_list(thread_id, doc.to_dict(), before, limit)

    async def alist(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> AsyncIterator[CheckpointTuple]:
        """Async counterpart of ``list``; same arguments and exceptions."""
        thread_id: Optional[str] = config["configurable"]["thread_id"] if config else None
        if filter:
            raise NotImplementedError("Filtering is not implemented for FirestoreSaver")
        thread_id = self._require_thread_id(thread_id)

        doc = await self._async_doc_ref(thread_id).get()
        if not doc.exists:
            raise ValueError(f"No checkpoints found for thread_id: {thread_id}")

        async for item in self._aprocess_checkpoint_list(thread_id, doc.to_dict(), before, limit):
            yield item

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
    ) -> RunnableConfig:
        """Store *checkpoint* and return the config that points at it.

        The entry is merged into the thread's document under the checkpoint
        id; ``config``'s current ``thread_ts`` (if any) is recorded as parent.
        """
        thread_id: str = config["configurable"]["thread_id"]
        # merge=True keeps the checkpoints already stored in the document.
        self._doc_ref(thread_id).set(self._build_payload(config, checkpoint, metadata), merge=True)

        return {
            "configurable": {
                "thread_id": thread_id,
                "thread_ts": checkpoint["id"],
            },
        }

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
    ) -> RunnableConfig:
        """Async counterpart of ``put``."""
        thread_id: str = config["configurable"]["thread_id"]
        # merge=True keeps the checkpoints already stored in the document.
        await self._async_doc_ref(thread_id).set(self._build_payload(config, checkpoint, metadata), merge=True)

        return {
            "configurable": {
                "thread_id": thread_id,
                "thread_ts": checkpoint["id"],
            },
        }

    # ------------------------------------------------------------------
    # Document post-processing (kept for backward compatibility).
    # ------------------------------------------------------------------

    def _process_checkpoint_data(self, thread_id: str, thread_ts: Optional[str], data: Dict[str, Any]) -> Optional[CheckpointTuple]:
        """Sync entry point; delegates to ``_process_checkpoint_data_common``."""
        return self._process_checkpoint_data_common(thread_id, thread_ts, data)

    async def _aprocess_checkpoint_data(self, thread_id: str, thread_ts: Optional[str], data: Dict[str, Any]) -> Optional[CheckpointTuple]:
        """Async entry point; the processing itself is CPU-only and sync."""
        return self._process_checkpoint_data_common(thread_id, thread_ts, data)

    def _process_checkpoint_list(self, thread_id: str, data: Dict[str, Any], before: Optional[RunnableConfig], limit: Optional[int]):
        """Sync generator over ``_process_checkpoint_list_common``."""
        yield from self._process_checkpoint_list_common(thread_id, data, before, limit)

    async def _aprocess_checkpoint_list(self, thread_id: str, data: Dict[str, Any], before: Optional[RunnableConfig], limit: Optional[int]):
        """Async generator over ``_process_checkpoint_list_common``."""
        for item in self._process_checkpoint_list_common(thread_id, data, before, limit):
            yield item

    def _process_checkpoint_data_common(self, thread_id: str, thread_ts: Optional[str], data: Dict[str, Any]) -> Optional[CheckpointTuple]:
        """Pick one checkpoint out of a thread document.

        Returns the entry stored under *thread_ts*, or the entry with the
        greatest id when *thread_ts* is falsy. Returns ``None`` when the
        document is empty/missing or the requested id is not stored.
        """
        if not data:
            return None

        if thread_ts:
            checkpoint_data = data.get(thread_ts)
            if checkpoint_data is None:
                return None
        else:
            thread_ts = max(data)
            checkpoint_data = data[thread_ts]

        return self._build_tuple(thread_id, thread_ts, checkpoint_data)

    def _process_checkpoint_list_common(self, thread_id: str, data: Dict[str, Any], before: Optional[RunnableConfig], limit: Optional[int]):
        """Yield the document's checkpoints in descending id order.

        Entries whose id is ``>=`` the ``thread_ts`` of *before* are skipped;
        at most *limit* tuples are produced (``limit=0`` produces none).
        """
        before_ts: Optional[str] = (before or {}).get("configurable", {}).get("thread_ts")
        count = 0
        for ts in sorted(data or {}, reverse=True):
            # Check the limit before yielding so limit=0 yields nothing.
            if limit is not None and count >= limit:
                break
            if before_ts is not None and ts >= before_ts:
                continue
            yield self._build_tuple(thread_id, ts, data[ts])
            count += 1

## GCP authentication

Uncomment and run the two commands below, then log in with your Google account to populate [ADC credentials](https://cloud.google.com/docs/authentication/provide-credentials-adc#local-dev).

In [3]:
#!gcloud auth login

In [4]:
#!gcloud auth application-default login

In [5]:
# ID of the GCP project whose Firestore database stores the checkpoints.
# Left empty here as a placeholder -- fill in your own project ID.
GCP_PROJECT_ID = ""
# Name of the Firestore collection passed to FirestoreSaver below.
COLLECTION_NAME = "checkpoints"

# Basic example using a graph

In [6]:
from uuid import uuid4

from langgraph.graph import StateGraph

# Persist every checkpoint of the graph below in Firestore.
checkpointer = FirestoreSaver(project_id=GCP_PROJECT_ID, collection_name=COLLECTION_NAME)

# Minimal single-node graph: the state is an int and the only node adds one.
builder = StateGraph(int)
builder.add_node("add_one", lambda x: x + 1)
builder.set_entry_point("add_one")
builder.set_finish_point("add_one")
graph = builder.compile(checkpointer=checkpointer)

# A fresh random thread id, so this run starts with no stored checkpoints.
config = {"configurable": {"thread_id": str(uuid4())}}

result = graph.invoke(3, config)

## Synchronous usage

In [7]:
# Read the thread's state back through the compiled graph.
graph.get_state(config)

StateSnapshot(values=4, next=(), config={'configurable': {'thread_id': 'd4e38642-ee2e-445c-9dc1-238a9e4d2e03', 'thread_ts': '1ef443cf-f502-6366-8001-653ad92e7f5b'}}, metadata={'source': 'loop', 'step': 1, 'writes': {'add_one': 4}}, created_at='2024-07-17T13:03:44.394318+00:00', parent_config={'configurable': {'thread_id': 'd4e38642-ee2e-445c-9dc1-238a9e4d2e03', 'thread_ts': '1ef443cf-f4f7-6100-8000-e74d88d270ef'}})

In [8]:
# Fetch only the checkpoint dict of the thread's latest checkpoint.
checkpointer.get(config)

{'v': 1,
 'ts': '2024-07-17T13:03:44.394318+00:00',
 'id': '1ef443cf-f502-6366-8001-653ad92e7f5b',
 'channel_values': {'__root__': 4, 'add_one': 'add_one'},
 'channel_versions': {'__start__': 2,
  '__root__': 3,
  'start:add_one': 3,
  'add_one': 3},
 'versions_seen': {'__start__': {'__start__': 1},
  'add_one': {'start:add_one': 2}},
 'pending_sends': []}

In [9]:
# Iterate over (at most) three checkpoints of the thread, newest id first.
# The result is bound to `checkpoints` rather than `list` so the builtin
# `list` is not shadowed for the rest of the notebook session.
checkpoints = checkpointer.list(config, limit=3)
for item in checkpoints:
    print(item)

CheckpointTuple(config={'configurable': {'thread_id': 'd4e38642-ee2e-445c-9dc1-238a9e4d2e03', 'thread_ts': '1ef443cf-f502-6366-8001-653ad92e7f5b'}}, checkpoint={'v': 1, 'ts': '2024-07-17T13:03:44.394318+00:00', 'id': '1ef443cf-f502-6366-8001-653ad92e7f5b', 'channel_values': {'__root__': 4, 'add_one': 'add_one'}, 'channel_versions': {'__start__': 2, '__root__': 3, 'start:add_one': 3, 'add_one': 3}, 'versions_seen': {'__start__': {'__start__': 1}, 'add_one': {'start:add_one': 2}}, 'pending_sends': []}, metadata={'source': 'loop', 'step': 1, 'writes': {'add_one': 4}}, parent_config={'configurable': {'thread_id': 'd4e38642-ee2e-445c-9dc1-238a9e4d2e03', 'thread_ts': '1ef443cf-f4f7-6100-8000-e74d88d270ef'}}, pending_writes=None)
CheckpointTuple(config={'configurable': {'thread_id': 'd4e38642-ee2e-445c-9dc1-238a9e4d2e03', 'thread_ts': '1ef443cf-f4f7-6100-8000-e74d88d270ef'}}, checkpoint={'v': 1, 'ts': '2024-07-17T13:03:44.389752+00:00', 'id': '1ef443cf-f4f7-6100-8000-e74d88d270ef', 'channel_v

In [10]:
# Fetch the latest checkpoint together with its config, metadata and parent config.
checkpointer.get_tuple(config)


CheckpointTuple(config={'configurable': {'thread_id': 'd4e38642-ee2e-445c-9dc1-238a9e4d2e03', 'thread_ts': '1ef443cf-f502-6366-8001-653ad92e7f5b'}}, checkpoint={'v': 1, 'ts': '2024-07-17T13:03:44.394318+00:00', 'id': '1ef443cf-f502-6366-8001-653ad92e7f5b', 'channel_values': {'__root__': 4, 'add_one': 'add_one'}, 'channel_versions': {'__start__': 2, '__root__': 3, 'start:add_one': 3, 'add_one': 3}, 'versions_seen': {'__start__': {'__start__': 1}, 'add_one': {'start:add_one': 2}}, 'pending_sends': []}, metadata={'source': 'loop', 'step': 1, 'writes': {'add_one': 4}}, parent_config={'configurable': {'thread_id': 'd4e38642-ee2e-445c-9dc1-238a9e4d2e03', 'thread_ts': '1ef443cf-f4f7-6100-8000-e74d88d270ef'}}, pending_writes=None)

## Asynchronous usage

In [11]:
# Async variant: fetch only the checkpoint dict of the latest checkpoint.
await checkpointer.aget(config)

{'v': 1,
 'ts': '2024-07-17T13:03:44.394318+00:00',
 'id': '1ef443cf-f502-6366-8001-653ad92e7f5b',
 'channel_values': {'__root__': 4, 'add_one': 'add_one'},
 'channel_versions': {'__start__': 2,
  '__root__': 3,
  'start:add_one': 3,
  'add_one': 3},
 'versions_seen': {'__start__': {'__start__': 1},
  'add_one': {'start:add_one': 2}},
 'pending_sends': []}

In [12]:
# Async variant: fetch the latest checkpoint with its config, metadata and parent config.
await checkpointer.aget_tuple(config)

CheckpointTuple(config={'configurable': {'thread_id': 'd4e38642-ee2e-445c-9dc1-238a9e4d2e03', 'thread_ts': '1ef443cf-f502-6366-8001-653ad92e7f5b'}}, checkpoint={'v': 1, 'ts': '2024-07-17T13:03:44.394318+00:00', 'id': '1ef443cf-f502-6366-8001-653ad92e7f5b', 'channel_values': {'__root__': 4, 'add_one': 'add_one'}, 'channel_versions': {'__start__': 2, '__root__': 3, 'start:add_one': 3, 'add_one': 3}, 'versions_seen': {'__start__': {'__start__': 1}, 'add_one': {'start:add_one': 2}}, 'pending_sends': []}, metadata={'source': 'loop', 'step': 1, 'writes': {'add_one': 4}}, parent_config={'configurable': {'thread_id': 'd4e38642-ee2e-445c-9dc1-238a9e4d2e03', 'thread_ts': '1ef443cf-f4f7-6100-8000-e74d88d270ef'}}, pending_writes=None)

In [13]:
list = checkpointer.alist(config, limit=3)
async for item in list:
    print(item)

CheckpointTuple(config={'configurable': {'thread_id': 'd4e38642-ee2e-445c-9dc1-238a9e4d2e03', 'thread_ts': '1ef443cf-f502-6366-8001-653ad92e7f5b'}}, checkpoint={'v': 1, 'ts': '2024-07-17T13:03:44.394318+00:00', 'id': '1ef443cf-f502-6366-8001-653ad92e7f5b', 'channel_values': {'__root__': 4, 'add_one': 'add_one'}, 'channel_versions': {'__start__': 2, '__root__': 3, 'start:add_one': 3, 'add_one': 3}, 'versions_seen': {'__start__': {'__start__': 1}, 'add_one': {'start:add_one': 2}}, 'pending_sends': []}, metadata={'source': 'loop', 'step': 1, 'writes': {'add_one': 4}}, parent_config={'configurable': {'thread_id': 'd4e38642-ee2e-445c-9dc1-238a9e4d2e03', 'thread_ts': '1ef443cf-f4f7-6100-8000-e74d88d270ef'}}, pending_writes=None)
CheckpointTuple(config={'configurable': {'thread_id': 'd4e38642-ee2e-445c-9dc1-238a9e4d2e03', 'thread_ts': '1ef443cf-f4f7-6100-8000-e74d88d270ef'}}, checkpoint={'v': 1, 'ts': '2024-07-17T13:03:44.389752+00:00', 'id': '1ef443cf-f4f7-6100-8000-e74d88d270ef', 'channel_v