In [3]:
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment; presumably
# this supplies the OpenAI API key that ChatOpenAI reads below (TODO confirm .env contents).
load_dotenv()

True

In [4]:
from typing import Annotated

from typing_extensions import TypedDict

from langgraph.graph.message import add_messages


class State(TypedDict):
    """Graph state passed between the agent and tool nodes."""

    # Conversation history. The add_messages reducer combines each node's returned
    # "messages" update with the existing list rather than overwriting it.
    messages: Annotated[list, add_messages]

In [5]:
from langchain_core.tools import tool


@tool
def search(query: str):
    """Call to surf the web."""
    # NOTE: @tool uses the docstring above as the tool description sent to the LLM,
    # so its wording affects when the model decides to call this tool.
    # Placeholder: `query` is ignored and the same canned result is always returned.
    # This is a placeholder for the actual implementation
    return ["The answer to your question lies within."]


# Tools available to the agent: their schemas are bound to the chat model
# (bind_tools) and they are executed by the ToolNode.
tools = [search]

In [6]:
from langchain_openai import ChatOpenAI


# OpenAI chat model with temperature=0 to minimise sampling randomness; no model
# name is passed, so the library default model is used.
model = ChatOpenAI(temperature=0, streaming=True)

# The same model with the `tools` schemas attached, so the LLM can emit tool calls.
bound_model = model.bind_tools(tools)

In [7]:
from langgraph.prebuilt import ToolNode

# Prebuilt graph node that executes the tool calls requested by the latest AI message.
tool_node = ToolNode(tools)

In [8]:
from typing import Literal


def should_continue(state: State) -> Literal["action", "__end__"]:
    """Choose the node that runs after the agent.

    Args:
        state: Graph state; only the last entry of ``state["messages"]`` is read.

    Returns:
        ``"action"`` when that message carries tool calls (so the tool node runs
        next), otherwise ``"__end__"`` to finish the run.
    """
    latest = state["messages"][-1]
    return "action" if latest.tool_calls else "__end__"


def call_model(state: State):
    """Run the tool-enabled chat model on the conversation so far.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update ``{"messages": [response]}``. The ``add_messages``
        reducer on ``State.messages`` combines it with the existing history.
    """
    # Must use ``bound_model`` (tools attached), not the bare ``model``: without the
    # tool schemas the LLM can never emit tool calls, so should_continue would never
    # route to the "action" node.
    response = bound_model.invoke(state["messages"])
    return {"messages": [response]}

In [9]:
import psycopg2
from psycopg2 import sql
from contextlib import contextmanager
from langgraph.checkpoint.base import BaseCheckpointSaver, CheckpointTuple
from langchain_core.runnables import RunnableConfig
from typing import Optional, Iterator, AsyncIterator
from datetime import datetime, timezone


class PostgresSaver(BaseCheckpointSaver):
    """LangGraph checkpoint saver backed by a single Postgres table.

    Every checkpoint is one row of ``checkpoints`` keyed by
    ``(thread_id, thread_ts)``. ``parent_ts`` links a row to the checkpoint it
    was derived from. Checkpoint and metadata payloads are serialized with
    ``self.serde`` and stored as BYTEA.

    All methods share one psycopg2 connection. Each ``with self.cursor()``
    scope is its own transaction: committed on success, rolled back on error.
    The ``a*`` coroutine methods delegate to the synchronous ones, so they
    block the event loop while the query runs.
    """

    def __init__(self, connection):
        """Wrap an already-open psycopg2 connection."""
        # Run the base-class initialiser so whatever it sets up (presumably the
        # default ``serde``) is initialised as for the built-in savers.
        super().__init__()
        self.connection = connection

    @classmethod
    def from_conn_string(cls, conn_string):
        """Open a psycopg2 connection from a libpq connection string and wrap it.

        The caller should call ``close()`` on the returned saver when done.
        """
        return cls(psycopg2.connect(conn_string))

    @contextmanager
    def cursor(self):
        """Provide a transactional scope around a series of operations.

        Yields a cursor on the shared connection. Commits when the block exits
        normally, rolls back and re-raises on an exception, and always closes
        the cursor.
        """
        cursor = self.connection.cursor()
        try:
            yield cursor
            self.connection.commit()
        except Exception:
            self.connection.rollback()
            raise
        finally:
            cursor.close()

    def setup(self) -> None:
        """Create the ``checkpoints`` table if it does not already exist (idempotent)."""
        with self.cursor() as cursor:
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS checkpoints (
                    thread_id TEXT NOT NULL,
                    thread_ts TEXT NOT NULL,
                    parent_ts TEXT,
                    checkpoint BYTEA,
                    metadata BYTEA,
                    PRIMARY KEY (thread_id, thread_ts)
                );
                """
            )

    def get_latest_timestamp(self, thread_id: str) -> Optional[str]:
        """Return the newest ``thread_ts`` stored for ``thread_id``, or ``None`` if none."""
        with self.cursor() as cursor:
            cursor.execute(
                "SELECT thread_ts FROM checkpoints WHERE thread_id = %s "
                "ORDER BY thread_ts DESC LIMIT 1",
                (thread_id,),
            )
            row = cursor.fetchone()
        return row[0] if row else None

    @staticmethod
    def _parent_config(thread_id, parent_ts):
        """Build the config that points at a parent checkpoint, or ``None`` if there is none."""
        if not parent_ts:
            return None
        return {"configurable": {"thread_id": thread_id, "thread_ts": parent_ts}}

    def _to_tuple(self, config, thread_id, parent_ts, checkpoint, metadata):
        """Deserialize one stored row into a ``CheckpointTuple`` reported under ``config``."""
        return CheckpointTuple(
            config,
            self.serde.loads(bytes(checkpoint)),
            # ``metadata`` is a nullable column; treat a missing value as empty metadata.
            self.serde.loads(bytes(metadata)) if metadata else {},
            self._parent_config(thread_id, parent_ts),
        )

    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Load one checkpoint for ``config["configurable"]["thread_id"]``.

        If the config carries a ``thread_ts``, that exact checkpoint is loaded.
        Otherwise the thread's most recent checkpoint is loaded, and the returned
        tuple's config is extended with that checkpoint's ``thread_ts``.

        Returns:
            The matching ``CheckpointTuple``, or ``None`` if nothing is stored.
        """
        configurable = config["configurable"]
        thread_id = configurable["thread_id"]
        thread_ts = configurable.get("thread_ts")

        with self.cursor() as cursor:
            if thread_ts:
                cursor.execute(
                    "SELECT thread_ts, parent_ts, checkpoint, metadata FROM checkpoints "
                    "WHERE thread_id = %s AND thread_ts = %s",
                    (thread_id, thread_ts),
                )
            else:
                # Select the newest row directly instead of first looking up the
                # latest timestamp in a separate query/transaction.
                cursor.execute(
                    "SELECT thread_ts, parent_ts, checkpoint, metadata FROM checkpoints "
                    "WHERE thread_id = %s ORDER BY thread_ts DESC LIMIT 1",
                    (thread_id,),
                )
            row = cursor.fetchone()

        if row is None:
            return None
        found_ts, parent_ts, checkpoint, metadata = row
        if not thread_ts:
            # Report which checkpoint was actually loaded. A config passed on to
            # put() then records it as the parent of the next checkpoint.
            config = {**config, "configurable": {**configurable, "thread_ts": found_ts}}
        return self._to_tuple(config, thread_id, parent_ts, checkpoint, metadata)

    def list(
        self,
        config: RunnableConfig,
        *,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> Iterator[CheckpointTuple]:
        """Yield the thread's checkpoints, newest first.

        Args:
            config: Selects the thread via ``config["configurable"]["thread_id"]``.
            before: If given, only checkpoints with a ``thread_ts`` strictly lower
                than ``before["configurable"]["thread_ts"]`` are returned.
            limit: If truthy, at most this many checkpoints are returned.
        """
        thread_id = config["configurable"]["thread_id"]
        query = (
            "SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata "
            "FROM checkpoints WHERE thread_id = %s"
        )
        params = [thread_id]
        if before:
            query += " AND thread_ts < %s"
            params.append(before["configurable"]["thread_ts"])
        query += " ORDER BY thread_ts DESC"
        if limit:
            # Bind LIMIT as a query parameter instead of formatting it into the SQL text.
            query += " LIMIT %s"
            params.append(limit)

        # Buffer the rows and end the transaction before yielding. Otherwise a slowly
        # consumed or abandoned generator would leave the shared connection mid-transaction.
        with self.cursor() as cursor:
            cursor.execute(query, params)
            rows = cursor.fetchall()

        for row_thread_id, thread_ts, parent_ts, checkpoint, metadata in rows:
            yield self._to_tuple(
                {"configurable": {"thread_id": row_thread_id, "thread_ts": thread_ts}},
                row_thread_id,
                parent_ts,
                checkpoint,
                metadata,
            )

    def put(
        self, config: RunnableConfig, checkpoint: dict, metadata: dict
    ) -> RunnableConfig:
        """Store ``checkpoint`` as a new row for the config's thread.

        The new row's ``parent_ts`` is the ``thread_ts`` in ``config`` (if any).
        Its own ``thread_ts`` is the current UTC time.

        Returns:
            A config identifying the newly stored checkpoint.
        """
        thread_id = config["configurable"]["thread_id"]
        # Fixed-width (always microsecond) UTC ISO-8601 strings, so the lexicographic
        # ORDER BY thread_ts used by the readers is chronological.
        thread_ts = datetime.now(timezone.utc).isoformat(timespec="microseconds")
        parent_ts = config["configurable"].get("thread_ts")

        with self.cursor() as cursor:
            cursor.execute(
                """
                INSERT INTO checkpoints (thread_id, thread_ts, parent_ts, checkpoint, metadata)
                VALUES (%s, %s, %s, %s, %s)
                ON CONFLICT (thread_id, thread_ts) DO UPDATE
                SET parent_ts = EXCLUDED.parent_ts, checkpoint = EXCLUDED.checkpoint, metadata = EXCLUDED.metadata
                """,
                (
                    thread_id,
                    thread_ts,
                    parent_ts,
                    self.serde.dumps(checkpoint),
                    self.serde.dumps(metadata),
                ),
            )

        return {
            "configurable": {
                "thread_id": thread_id,
                "thread_ts": thread_ts,
            }
        }

    async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Async facade over ``get_tuple`` (runs the blocking query inline)."""
        return self.get_tuple(config)

    async def alist(
        self,
        config: RunnableConfig,
        *,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> AsyncIterator[CheckpointTuple]:
        """Async facade over ``list`` (runs the blocking query inline)."""
        for checkpoint_tuple in self.list(config, before=before, limit=limit):
            yield checkpoint_tuple

    async def aput(
        self, config: RunnableConfig, checkpoint: dict, metadata: dict
    ) -> RunnableConfig:
        """Async facade over ``put`` (runs the blocking query inline)."""
        return self.put(config, checkpoint, metadata)

    def close(self):
        """Close the underlying psycopg2 connection."""
        self.connection.close()

In [17]:
from langgraph.graph import StateGraph

graph = StateGraph(State)

# Nodes: the LLM step and the tool-execution step.
graph.add_node("agent", call_model)
graph.add_node("action", tool_node)

# Every run starts at the agent.
graph.set_entry_point("agent")

# After the agent, should_continue returns "action" (tool calls pending) or "__end__".
graph.add_conditional_edges(
    "agent",
    should_continue,
)

# Tool results are fed back to the agent for its next reply.
graph.add_edge("action", "agent")

In [18]:
import os

# Connection settings for the local Postgres instance. Set PG_CONN_STRING in the
# environment (e.g. in the .env file read by load_dotenv) instead of editing
# credentials in the notebook; the literal below is only the local-dev fallback.
conn_string = os.environ.get(
    "PG_CONN_STRING", "dbname=vdb user=vdb password=vdb host=localhost port=5432"
)

memory = PostgresSaver.from_conn_string(conn_string)
# Create the checkpoints table if it is missing (CREATE TABLE IF NOT EXISTS),
# so Restart & Run All also works against a fresh database.
memory.setup()

runnable = graph.compile(checkpointer=memory)

In [19]:
from langchain_core.messages import HumanMessage

# Thread "1": introduce ourselves. The checkpointer stores this exchange
# under thread_id "1".
config = {"configurable": {"thread_id": "1"}}
input_message = HumanMessage(content="Hello, I am John")

runnable.invoke({"messages": input_message}, config=config)

{'messages': [HumanMessage(content='Hello, I am John', id='73ef9d02-a794-4036-9fe7-8c7462c5df80'),
  AIMessage(content='Nice to meet you, John! How can I assist you today?', response_metadata={'finish_reason': 'stop', 'model_name': 'gpt-3.5-turbo-0125'}, id='run-441fcda9-307c-48f8-9012-9d3404528fde-0')]}

In [20]:
from langchain_core.messages import HumanMessage

# A different thread ("42") has no stored history, so the model cannot know
# about the introduction made in thread "1" (see the output below).
config = {"configurable": {"thread_id": "42"}}
input_message = HumanMessage(content="Did I already introduce myself?")

runnable.invoke({"messages": input_message}, config=config)

{'messages': [HumanMessage(content='Did I already introduce myself?', id='b1d00924-718f-4eee-b59b-60e3c7eb3ac4'),
  AIMessage(content="I'm sorry, I don't have the ability to remember previous interactions. Please go ahead and introduce yourself again.", response_metadata={'finish_reason': 'stop', 'model_name': 'gpt-3.5-turbo-0125'}, id='run-6d6d7f45-e98c-4648-b8a2-ebf0e2effa9b-0')]}

In [21]:
from langchain_core.messages import HumanMessage

# Back on thread "1": the checkpointer restores the earlier messages, so the
# returned state contains the whole conversation and the model can answer.
config = {"configurable": {"thread_id": "1"}}
input_message = HumanMessage(content="Did I already introduce myself?")

runnable.invoke({"messages": input_message}, config=config)

{'messages': [HumanMessage(content='Hello, I am John', id='73ef9d02-a794-4036-9fe7-8c7462c5df80'),
  AIMessage(content='Nice to meet you, John! How can I assist you today?', response_metadata={'finish_reason': 'stop', 'model_name': 'gpt-3.5-turbo-0125'}, id='run-441fcda9-307c-48f8-9012-9d3404528fde-0'),
  HumanMessage(content='Did I already introduce myself?', id='2658d326-086b-421a-82be-921a17cab00a'),
  AIMessage(content='Yes, you introduced yourself as John. Is there anything else you would like to know or discuss?', response_metadata={'finish_reason': 'stop', 'model_name': 'gpt-3.5-turbo-0125'}, id='run-768f9ca9-115d-48e9-9336-85110d869728-0')]}

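### How to manage memory

The `MemoryManager` below works directly on the `checkpoints` table that `PostgresSaver` writes. It counts the stored checkpoints per thread and deletes them, either for a single thread or for all threads.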

In [23]:
import psycopg2
from psycopg2 import sql
from contextlib import contextmanager


class MemoryManager:
    """Inspect and delete rows of the ``checkpoints`` table.

    Every public operation opens its own psycopg2 connection from
    ``conn_string``, runs inside a single transaction and closes the
    connection afterwards.
    """

    def __init__(self, conn_string):
        # libpq connection string used for every new connection.
        self.conn_string = conn_string

    @contextmanager
    def connection(self):
        """Yield a fresh connection: commit on success, roll back on error, always close."""
        conn = psycopg2.connect(self.conn_string)
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    @contextmanager
    def cursor(self):
        """Yield a cursor on a fresh transactional connection; the cursor is closed on exit."""
        with self.connection() as conn:
            with conn.cursor() as cur:
                yield cur

    def delete_by_thread_id(self, thread_id: str) -> None:
        """Remove every stored checkpoint belonging to one thread.

        Args:
            thread_id (str): Thread whose checkpoint rows are deleted.
        """
        statement = sql.SQL("DELETE FROM checkpoints WHERE thread_id = %s")
        with self.cursor() as cur:
            cur.execute(statement, (thread_id,))

    def count_checkpoints_by_thread_id(self) -> None:
        """Print how many checkpoints each thread has, ordered by thread ID.

        Output goes to stdout; nothing is returned.
        """
        with self.cursor() as cur:
            count_query = """
            SELECT thread_id, COUNT(*) AS count
            FROM checkpoints
            GROUP BY thread_id
            ORDER BY thread_id;
            """
            cur.execute(count_query)
            per_thread = cur.fetchall()
            print("Checkpoint counts by thread ID:")
            for thread, total in per_thread:
                print(f"Thread ID: {thread}, Count: {total}")

    def delete_all(self) -> None:
        """Remove every row from the checkpoints table, for all threads."""
        with self.cursor() as cur:
            cur.execute("DELETE FROM checkpoints")

In [24]:
import os

# Same database defaults as the checkpointer. Set PG_CONN_STRING in the environment
# rather than hard-coding credentials; the literal is only the local-dev fallback.
conn_string = os.environ.get(
    "PG_CONN_STRING", "dbname=vdb user=vdb password=vdb host=localhost port=5432"
)
memory_manager = MemoryManager(conn_string)

In [25]:
# Show how many checkpoint rows each thread currently has.
memory_manager.count_checkpoints_by_thread_id()

Checkpoint counts by thread ID:
Thread ID: 1, Count: 6
Thread ID: 42, Count: 3


In [26]:
# Forget conversation thread "1" by deleting all of its checkpoints.
thread_id_to_delete = "1"
memory_manager.delete_by_thread_id(thread_id_to_delete)

In [27]:
# Thread "1" should no longer appear in the counts.
memory_manager.count_checkpoints_by_thread_id()

Checkpoint counts by thread ID:
Thread ID: 42, Count: 3


In [28]:
# Remove every stored checkpoint for all threads.
memory_manager.delete_all()

In [29]:
# After delete_all() only the header line is printed; no thread rows remain.
memory_manager.count_checkpoints_by_thread_id()

Checkpoint counts by thread ID:
