# How to create a custom checkpointer using Redis

When creating LangGraph agents, you can also set them up so that they persist their state. This allows you to do things like interact with an agent multiple times and have it remember previous interactions. Make sure that you have Redis running on port `6379` for going through this tutorial

This example shows how to use `Redis` as the backend for persisting checkpoint state.

NOTE: this is just an example implementation. You can implement your own checkpointer using a different database or modify this one as long as it conforms to the `BaseCheckpointSaver` interface.

## Install the necessary libraries for Redis on Python

In [1]:
%%capture --no-stderr
%pip install -U redis langgraph

## Checkpointer implementation

In [2]:
"""Implementation of a langgraph checkpoint saver using Redis."""
from contextlib import asynccontextmanager, contextmanager
from typing import Any, AsyncGenerator, Generator, Union, Tuple, Optional

import redis
from redis.asyncio import Redis as AsyncRedis, ConnectionPool as AsyncConnectionPool
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata, CheckpointTuple
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class JsonAndBinarySerializer(JsonPlusSerializer):
    def _default(self, obj: Any) -> Any:
        if isinstance(obj, (bytes, bytearray)):
            return self._encode_constructor_args(
                obj.__class__, method="fromhex", args=[obj.hex()]
            )
        return super()._default(obj)

    def dumps(self, obj: Any) -> str:
        try:
            if isinstance(obj, (bytes, bytearray)):
                return obj.hex()
            return super().dumps(obj)
        except Exception as e:
            logger.error(f"Serialization error: {e}")
            raise

    def loads(self, s: str, is_binary: bool = False) -> Any:
        try:
            if is_binary:
                return bytes.fromhex(s)
            return super().loads(s)
        except Exception as e:
            logger.error(f"Deserialization error: {e}")
            raise


def initialize_sync_pool(
    host: str = "localhost", port: int = 6379, db: int = 0, **kwargs
) -> redis.ConnectionPool:
    """Initialize a synchronous Redis connection pool."""
    try:
        pool = redis.ConnectionPool(host=host, port=port, db=db, **kwargs)
        logger.info(
            f"Synchronous Redis pool initialized with host={host}, port={port}, db={db}"
        )
        return pool
    except Exception as e:
        logger.error(f"Error initializing sync pool: {e}")
        raise


def initialize_async_pool(
    url: str = "redis://localhost", **kwargs
) -> AsyncConnectionPool:
    """Initialize an asynchronous Redis connection pool."""
    try:
        pool = AsyncConnectionPool.from_url(url, **kwargs)
        logger.info(f"Asynchronous Redis pool initialized with url={url}")
        return pool
    except Exception as e:
        logger.error(f"Error initializing async pool: {e}")
        raise


@contextmanager
def _get_sync_connection(
    connection: Union[redis.Redis, redis.ConnectionPool, None]
) -> Generator[redis.Redis, None, None]:
    conn = None
    try:
        if isinstance(connection, redis.Redis):
            yield connection
        elif isinstance(connection, redis.ConnectionPool):
            conn = redis.Redis(connection_pool=connection)
            yield conn
        else:
            raise ValueError("Invalid sync connection object.")
    except redis.ConnectionError as e:
        logger.error(f"Sync connection error: {e}")
        raise
    finally:
        if conn:
            conn.close()


@asynccontextmanager
async def _get_async_connection(
    connection: Union[AsyncRedis, AsyncConnectionPool, None]
) -> AsyncGenerator[AsyncRedis, None]:
    conn = None
    try:
        if isinstance(connection, AsyncRedis):
            yield connection
        elif isinstance(connection, AsyncConnectionPool):
            conn = AsyncRedis(connection_pool=connection)
            yield conn
        else:
            raise ValueError("Invalid async connection object.")
    except redis.ConnectionError as e:
        logger.error(f"Async connection error: {e}")
        raise
    finally:
        if conn:
            await conn.aclose()


class RedisSaver(BaseCheckpointSaver):
    sync_connection: Optional[Union[redis.Redis, redis.ConnectionPool]] = None
    async_connection: Optional[Union[AsyncRedis, AsyncConnectionPool]] = None

    def __init__(
        self,
        sync_connection: Optional[Union[redis.Redis, redis.ConnectionPool]] = None,
        async_connection: Optional[Union[AsyncRedis, AsyncConnectionPool]] = None,
    ):
        super().__init__(serde=JsonAndBinarySerializer())
        self.sync_connection = sync_connection
        self.async_connection = async_connection

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
    ) -> RunnableConfig:
        thread_id = config["configurable"]["thread_id"]
        parent_ts = config["configurable"].get("thread_ts")
        key = f"checkpoint:{thread_id}:{checkpoint['ts']}"
        try:
            with _get_sync_connection(self.sync_connection) as conn:
                conn.hset(
                    key,
                    mapping={
                        "checkpoint": self.serde.dumps(checkpoint),
                        "metadata": self.serde.dumps(metadata),
                        "parent_ts": parent_ts if parent_ts else "",
                    },
                )
                logger.info(
                    f"Checkpoint stored successfully for thread_id: {thread_id}, ts: {checkpoint['ts']}"
                )
        except Exception as e:
            logger.error(f"Failed to put checkpoint: {e}")
            raise
        return {
            "configurable": {
                "thread_id": thread_id,
                "thread_ts": checkpoint["ts"],
            },
        }

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
    ) -> RunnableConfig:
        thread_id = config["configurable"]["thread_id"]
        parent_ts = config["configurable"].get("thread_ts")
        key = f"checkpoint:{thread_id}:{checkpoint['ts']}"
        try:
            async with _get_async_connection(self.async_connection) as conn:
                await conn.hset(
                    key,
                    mapping={
                        "checkpoint": self.serde.dumps(checkpoint),
                        "metadata": self.serde.dumps(metadata),
                        "parent_ts": parent_ts if parent_ts else "",
                    },
                )
                logger.info(
                    f"Checkpoint stored successfully for thread_id: {thread_id}, ts: {checkpoint['ts']}"
                )
        except Exception as e:
            logger.error(f"Failed to aput checkpoint: {e}")
            raise
        return {
            "configurable": {
                "thread_id": thread_id,
                "thread_ts": checkpoint["ts"],
            },
        }

    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        thread_id = config["configurable"]["thread_id"]
        thread_ts = config["configurable"].get("thread_ts", None)
        try:
            with _get_sync_connection(self.sync_connection) as conn:
                if thread_ts:
                    key = f"checkpoint:{thread_id}:{thread_ts}"
                else:
                    all_keys = conn.keys(f"checkpoint:{thread_id}:*")
                    if not all_keys:
                        logger.info(f"No checkpoints found for thread_id: {thread_id}")
                        return None
                    latest_key = max(all_keys, key=lambda k: k.decode().split(":")[-1])
                    key = latest_key.decode()
                checkpoint_data = conn.hgetall(key)
                if not checkpoint_data:
                    logger.info(f"No valid checkpoint data found for key: {key}")
                    return None
                checkpoint = self.serde.loads(checkpoint_data[b"checkpoint"].decode())
                metadata = self.serde.loads(checkpoint_data[b"metadata"].decode())
                parent_ts = checkpoint_data.get(b"parent_ts", b"").decode()
                parent_config = (
                    {"configurable": {"thread_id": thread_id, "thread_ts": parent_ts}}
                    if parent_ts
                    else None
                )
                logger.info(
                    f"Checkpoint retrieved successfully for thread_id: {thread_id}, ts: {thread_ts}"
                )
                return CheckpointTuple(
                    config=config,
                    checkpoint=checkpoint,
                    metadata=metadata,
                    parent_config=parent_config,
                )
        except Exception as e:
            logger.error(f"Failed to get checkpoint tuple: {e}")
            raise

    async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        thread_id = config["configurable"]["thread_id"]
        thread_ts = config["configurable"].get("thread_ts", None)
        try:
            async with _get_async_connection(self.async_connection) as conn:
                if thread_ts:
                    key = f"checkpoint:{thread_id}:{thread_ts}"
                else:
                    all_keys = await conn.keys(f"checkpoint:{thread_id}:*")
                    if not all_keys:
                        logger.info(f"No checkpoints found for thread_id: {thread_id}")
                        return None
                    latest_key = max(all_keys, key=lambda k: k.decode().split(":")[-1])
                    key = latest_key.decode()
                checkpoint_data = await conn.hgetall(key)
                if not checkpoint_data:
                    logger.info(f"No valid checkpoint data found for key: {key}")
                    return None
                checkpoint = self.serde.loads(checkpoint_data[b"checkpoint"].decode())
                metadata = self.serde.loads(checkpoint_data[b"metadata"].decode())
                parent_ts = checkpoint_data.get(b"parent_ts", b"").decode()
                parent_config = (
                    {"configurable": {"thread_id": thread_id, "thread_ts": parent_ts}}
                    if parent_ts
                    else None
                )
                logger.info(
                    f"Checkpoint retrieved successfully for thread_id: {thread_id}, ts: {thread_ts}"
                )
                return CheckpointTuple(
                    config=config,
                    checkpoint=checkpoint,
                    metadata=metadata,
                    parent_config=parent_config,
                )
        except Exception as e:
            logger.error(f"Failed to get checkpoint tuple: {e}")
            raise

    def list(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> Generator[CheckpointTuple, None, None]:
        thread_id = config["configurable"]["thread_id"] if config else "*"
        pattern = f"checkpoint:{thread_id}:*"
        try:
            with _get_sync_connection(self.sync_connection) as conn:
                keys = conn.keys(pattern)
                if before:
                    keys = [
                        k
                        for k in keys
                        if k.decode().split(":")[-1]
                        < before["configurable"]["thread_ts"]
                    ]
                keys = sorted(
                    keys, key=lambda k: k.decode().split(":")[-1], reverse=True
                )
                if limit:
                    keys = keys[:limit]
                for key in keys:
                    data = conn.hgetall(key)
                    if data and "checkpoint" in data and "metadata" in data:
                        thread_ts = key.decode().split(":")[-1]
                        yield CheckpointTuple(
                            config={
                                "configurable": {
                                    "thread_id": thread_id,
                                    "thread_ts": thread_ts,
                                }
                            },
                            checkpoint=self.serde.loads(data["checkpoint"].decode()),
                            metadata=self.serde.loads(data["metadata"].decode()),
                            parent_config={
                                "configurable": {
                                    "thread_id": thread_id,
                                    "thread_ts": data.get("parent_ts", b"").decode(),
                                }
                            }
                            if data.get("parent_ts")
                            else None,
                        )
                        logger.info(
                            f"Checkpoint listed for thread_id: {thread_id}, ts: {thread_ts}"
                        )
        except Exception as e:
            logger.error(f"Failed to list checkpoints: {e}")
            raise

    async def alist(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> AsyncGenerator[CheckpointTuple, None]:
        thread_id = config["configurable"]["thread_id"] if config else "*"
        pattern = f"checkpoint:{thread_id}:*"
        try:
            async with _get_async_connection(self.async_connection) as conn:
                keys = await conn.keys(pattern)
                if before:
                    keys = [
                        k
                        for k in keys
                        if k.decode().split(":")[-1]
                        < before["configurable"]["thread_ts"]
                    ]
                keys = sorted(
                    keys, key=lambda k: k.decode().split(":")[-1], reverse=True
                )
                if limit:
                    keys = keys[:limit]
                for key in keys:
                    data = await conn.hgetall(key)
                    if data and "checkpoint" in data and "metadata" in data:
                        thread_ts = key.decode().split(":")[-1]
                        yield CheckpointTuple(
                            config={
                                "configurable": {
                                    "thread_id": thread_id,
                                    "thread_ts": thread_ts,
                                }
                            },
                            checkpoint=self.serde.loads(data["checkpoint"].decode()),
                            metadata=self.serde.loads(data["metadata"].decode()),
                            parent_config={
                                "configurable": {
                                    "thread_id": thread_id,
                                    "thread_ts": data.get("parent_ts", b"").decode(),
                                }
                            }
                            if data.get("parent_ts")
                            else None,
                        )
                        logger.info(
                            f"Checkpoint listed for thread_id: {thread_id}, ts: {thread_ts}"
                        )
        except Exception as e:
            logger.error(f"Failed to list checkpoints: {e}")
            raise

## Checkpointer implementation

## Setup environment

In [3]:
import getpass
import os


def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


_set_env("OPENAI_API_KEY")

## Setup model and tools for the graph

In [4]:
from typing import Literal
from langchain_core.runnables import ConfigurableField
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent


@tool
def get_weather(city: Literal["nyc", "sf"]):
    """Use this to get weather information."""
    if city == "nyc":
        return "It might be cloudy in nyc"
    elif city == "sf":
        return "It's always sunny in sf"
    else:
        raise AssertionError("Unknown city")


tools = [get_weather]
model = ChatOpenAI(model_name="gpt-4o", temperature=0)

## Use sync connection

### With a connection pool

In [5]:
sync_pool = initialize_sync_pool(host="172.25.0.4", port=6379, db=0)

INFO:__main__:Synchronous Redis pool initialized with host=172.25.0.4, port=6379, db=0


In [6]:
checkpointer = RedisSaver(sync_connection=sync_pool)

In [7]:
graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)
config = {"configurable": {"thread_id": "1"}}
res = graph.invoke({"messages": [("human", "what's the weather in sf")]}, config)

INFO:__main__:Checkpoint retrieved successfully for thread_id: 1, ts: None
INFO:__main__:Checkpoint stored successfully for thread_id: 1, ts: 2024-07-09T08:22:48.417492+00:00
INFO:__main__:Checkpoint stored successfully for thread_id: 1, ts: 2024-07-09T08:22:48.420714+00:00
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:__main__:Checkpoint stored successfully for thread_id: 1, ts: 2024-07-09T08:22:49.458951+00:00
INFO:__main__:Checkpoint stored successfully for thread_id: 1, ts: 2024-07-09T08:22:49.465101+00:00
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:__main__:Checkpoint stored successfully for thread_id: 1, ts: 2024-07-09T08:22:50.084141+00:00


In [8]:
res

{'messages': [HumanMessage(content="what's the weather in sf", id='64df8e19-0b9f-47f7-928f-4db3255485aa'),
  AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_n2XQOZHfpXpaNaviJakmjo82', 'function': {'arguments': '{"city":"sf"}', 'name': 'get_weather'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 14, 'prompt_tokens': 57, 'total_tokens': 71}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_ce0793330f', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-9057607f-6fa7-452b-95c7-f8f9832cb343-0', tool_calls=[{'name': 'get_weather', 'args': {'city': 'sf'}, 'id': 'call_n2XQOZHfpXpaNaviJakmjo82'}], usage_metadata={'input_tokens': 57, 'output_tokens': 14, 'total_tokens': 71}),
  ToolMessage(content="It's always sunny in sf", name='get_weather', id='444d80db-8230-440a-b0eb-46a3f4db1006', tool_call_id='call_n2XQOZHfpXpaNaviJakmjo82'),
  AIMessage(content='The weather in San Francisco is currently sunny.', response_metadata={

In [9]:
checkpointer.get(config)

INFO:__main__:Checkpoint retrieved successfully for thread_id: 1, ts: None


{'v': 1,
 'ts': '2024-07-08T12:21:14.392158+00:00',
 'id': '1ef3d249-1acf-60ed-bfff-248e42e4d9f5',
 'channel_values': {'messages': [],
  '__start__': {'messages': [['human', "what's the weather in sf"]]}},
 'channel_versions': {'__start__': 1},
 'versions_seen': {},
 'pending_sends': []}

### With a connection

In [10]:
import redis

# Initialize the Redis synchronous direct connection
sync_redis_direct = redis.Redis(host="172.25.0.4", port=6379, db=0)

# Initialize the RedisSaver with the synchronous direct connection
checkpointer = RedisSaver(sync_connection=sync_redis_direct)

graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)
config = {"configurable": {"thread_id": "2"}}
res = graph.invoke({"messages": [("human", "what's the weather in sf")]}, config)

checkpoint_tuple = checkpointer.get_tuple(config)

INFO:__main__:Checkpoint retrieved successfully for thread_id: 2, ts: None
INFO:__main__:Checkpoint stored successfully for thread_id: 2, ts: 2024-07-09T08:22:50.132262+00:00
INFO:__main__:Checkpoint stored successfully for thread_id: 2, ts: 2024-07-09T08:22:50.135993+00:00
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:__main__:Checkpoint stored successfully for thread_id: 2, ts: 2024-07-09T08:22:50.875540+00:00
INFO:__main__:Checkpoint retrieved successfully for thread_id: 2, ts: None


## Use async connection

### With a connection pool

In [13]:
# Initialize a synchronous Redis connection pool
async_pool = initialize_async_pool(url="redis://172.25.0.4:6379/0")

checkpointer = RedisSaver(async_connection=async_pool)

INFO:__main__:Asynchronous Redis pool initialized with url=redis://172.25.0.4:6379/0


In [14]:
graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)
config = {"configurable": {"thread_id": "3"}}
res = await graph.ainvoke(
    {"messages": [("human", "what's the weather in nyc")]}, config
)

INFO:__main__:Checkpoint retrieved successfully for thread_id: 3, ts: None
INFO:__main__:Checkpoint stored successfully for thread_id: 3, ts: 2024-07-09T08:22:50.949172+00:00
INFO:__main__:Checkpoint stored successfully for thread_id: 3, ts: 2024-07-09T08:22:50.951824+00:00
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:__main__:Checkpoint stored successfully for thread_id: 3, ts: 2024-07-09T08:22:51.698633+00:00
INFO:__main__:Checkpoint stored successfully for thread_id: 3, ts: 2024-07-09T08:22:51.702156+00:00
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:__main__:Checkpoint stored successfully for thread_id: 3, ts: 2024-07-09T08:22:53.530983+00:00


In [15]:
checkpoint_tuple = await checkpointer.aget_tuple(config)

INFO:__main__:Checkpoint retrieved successfully for thread_id: 3, ts: None


In [16]:
checkpoint_tuple

CheckpointTuple(config={'configurable': {'thread_id': '3'}}, checkpoint={'v': 1, 'ts': '2024-07-08T12:21:18.866666+00:00', 'id': '1ef3d249-457b-62d3-bfff-b0e787336a7c', 'channel_values': {'messages': [], '__start__': {'messages': [['human', "what's the weather in nyc"]]}}, 'channel_versions': {'__start__': 1}, 'versions_seen': {}, 'pending_sends': []}, metadata={'source': 'input', 'step': -1, 'writes': {'messages': [['human', "what's the weather in nyc"]]}}, parent_config=None)

### Use connection

In [17]:
from redis.asyncio import Redis as AsyncRedis

async with await AsyncRedis(host="172.25.0.4", port=6379, db=0) as conn:
    checkpointer = RedisSaver(async_connection=conn)
    graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)
    config = {"configurable": {"thread_id": "4"}}
    res = await graph.ainvoke(
        {"messages": [("human", "what's the weather in nyc")]}, config
    )
    checkpoint_tuples = [c async for c in checkpointer.alist(config)]

INFO:__main__:Checkpoint retrieved successfully for thread_id: 4, ts: None
INFO:__main__:Checkpoint stored successfully for thread_id: 4, ts: 2024-07-09T08:22:53.585109+00:00
INFO:__main__:Checkpoint stored successfully for thread_id: 4, ts: 2024-07-09T08:22:53.587207+00:00
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:__main__:Checkpoint stored successfully for thread_id: 4, ts: 2024-07-09T08:22:54.932663+00:00
INFO:__main__:Checkpoint stored successfully for thread_id: 4, ts: 2024-07-09T08:22:54.936425+00:00
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:__main__:Checkpoint stored successfully for thread_id: 4, ts: 2024-07-09T08:22:55.982495+00:00
