From c314222796798545f168f6ff7e750eb24e8edd51 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Tue, 25 Jun 2024 23:17:10 -0400
Subject: [PATCH] Add a conversation memory that combines an (optionally
 persistent) vectorstore history with a token buffer (#22155)

**langchain: ConversationVectorStoreTokenBufferMemory**

- **Description:** This PR adds ConversationVectorStoreTokenBufferMemory. It
  is similar in concept to ConversationSummaryBufferMemory: it maintains an
  in-memory buffer of messages up to a preset token limit. After the limit is
  hit, timestamped messages are written into a vectorstore retriever rather
  than into a summary. The user's prompt is then used to retrieve relevant
  fragments of the previous conversation. By persisting the vectorstore, one
  can maintain memory from session to session (a usage sketch follows the
  `__init__.py` diff below).
- **Issue:** n/a
- **Dependencies:** none
- **Twitter handle:** Please no!!!

- [X] **Add tests and docs**: I looked to see how the unit tests were written
  for the other ConversationMemory modules, but couldn't find anything other
  than a test for successful import. I need to know whether you are using
  pytest.mock or another fixture to simulate the LLM and vectorstore. In
  addition, I would like guidance on where to place the documentation. Should
  it be a notebook file in docs/docs?

- [X] **Lint and test**: I am seeing some linting errors from a couple of
  modules unrelated to this PR.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Co-authored-by: Lincoln Stein
Co-authored-by: isaac hershenson
---
 .../integration_template/document_loaders.py  |   1 +
 libs/langchain/langchain/memory/__init__.py   |   4 +
 .../memory/vectorstore_token_buffer_memory.py | 184 ++++++++++++++++++
 .../tests/unit_tests/memory/test_imports.py   |   1 +
 4 files changed, 190 insertions(+)
 create mode 100644 libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py

diff --git a/libs/cli/langchain_cli/integration_template/integration_template/document_loaders.py b/libs/cli/langchain_cli/integration_template/integration_template/document_loaders.py
index ecf044f71e2162..62269b8d5d67a5 100644
--- a/libs/cli/langchain_cli/integration_template/integration_template/document_loaders.py
+++ b/libs/cli/langchain_cli/integration_template/integration_template/document_loaders.py
@@ -1,6 +1,7 @@
 """__ModuleName__ document loader."""
 
 from typing import Iterator
+
 from langchain_core.document_loaders.base import BaseLoader
 from langchain_core.documents import Document
 
diff --git a/libs/langchain/langchain/memory/__init__.py b/libs/langchain/langchain/memory/__init__.py
index e59cbf1cef918e..296b1f7d28391d 100644
--- a/libs/langchain/langchain/memory/__init__.py
+++ b/libs/langchain/langchain/memory/__init__.py
@@ -48,6 +48,9 @@
 from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
 from langchain.memory.token_buffer import ConversationTokenBufferMemory
 from langchain.memory.vectorstore import VectorStoreRetrieverMemory
+from langchain.memory.vectorstore_token_buffer_memory import (
+    ConversationVectorStoreTokenBufferMemory,  # avoid circular import
+)
 
 if TYPE_CHECKING:
     from langchain_community.chat_message_histories import (
@@ -122,6 +125,7 @@ def __getattr__(name: str) -> Any:
     "ConversationSummaryBufferMemory",
     "ConversationSummaryMemory",
     "ConversationTokenBufferMemory",
+    "ConversationVectorStoreTokenBufferMemory",
     "CosmosDBChatMessageHistory",
     "DynamoDBChatMessageHistory",
     "ElasticsearchChatMessageHistory",
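Not part of the diff: a minimal sketch of the session-to-session persistence
mentioned in the description, assuming Chroma's `persist_directory` argument
behaves as documented. The `./chat_history` path and the collection name are
illustrative placeholders.

.. code-block:: python

    from langchain_community.embeddings import HuggingFaceInstructEmbeddings
    from langchain_community.vectorstores import Chroma
    from langchain_openai import OpenAI

    from langchain.memory import ConversationVectorStoreTokenBufferMemory

    # Pointing Chroma at a directory makes the conversation history
    # survive process restarts.
    chroma = Chroma(
        collection_name="conversation_memory",
        embedding_function=HuggingFaceInstructEmbeddings(
            query_instruction="Represent the query for retrieval: "
        ),
        persist_directory="./chat_history",
    )

    memory = ConversationVectorStoreTokenBufferMemory(
        return_messages=True,
        llm=OpenAI(),
        retriever=chroma.as_retriever(search_kwargs={"k": 5}),
        max_token_limit=1000,
    )

    # ... converse, calling memory.save_context(...) after each turn ...

    # Before the session ends, flush the turns still held in the token
    # buffer into the persistent vectorstore so the next session can
    # retrieve them.
    memory.save_remainder()

A new session constructed the same way against the same `persist_directory`
will then pull relevant excerpts of the earlier conversation into the prompt
via `load_memory_variables`.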
diff --git a/libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py b/libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py
new file mode 100644
index 00000000000000..0995bb3e34a672
--- /dev/null
+++ b/libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py
@@ -0,0 +1,184 @@
+"""
+Class for a conversation memory buffer with older messages stored in a vectorstore.
+
+This implements a conversation memory in which the messages are stored in a memory
+buffer up to a specified token limit. When the limit is exceeded, older messages are
+saved to a vectorstore backing database. The vectorstore can be made persistent across
+sessions.
+"""
+
+import warnings
+from datetime import datetime
+from typing import Any, Dict, List
+
+from langchain_core.messages import BaseMessage
+from langchain_core.prompts.chat import SystemMessagePromptTemplate
+from langchain_core.pydantic_v1 import Field, PrivateAttr
+from langchain_core.vectorstores import VectorStoreRetriever
+
+from langchain.memory import ConversationTokenBufferMemory, VectorStoreRetrieverMemory
+from langchain.memory.chat_memory import BaseChatMemory
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+DEFAULT_HISTORY_TEMPLATE = """
+Current date and time: {current_time}.
+
+Potentially relevant timestamped excerpts of previous conversations (you
+do not need to use these if irrelevant):
+{previous_history}
+
+"""
+
+TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S %Z"
+
+
+class ConversationVectorStoreTokenBufferMemory(ConversationTokenBufferMemory):
+    """Conversation chat memory with token limit and vectordb backing.
+
+    load_memory_variables() will return a dict with the key "history".
+    It contains background information retrieved from the vector store
+    plus recent lines of the current conversation.
+
+    To help the LLM understand the part of the conversation stored in the
+    vectorstore, each interaction is timestamped and the current date and
+    time is also provided in the history. A side effect of this is that the
+    LLM will have access to the current date and time.
+
+    Initialization arguments:
+
+    This class accepts all the initialization arguments of
+    ConversationTokenBufferMemory, such as `llm`. In addition, it
+    accepts the following additional arguments:
+
+        retriever: (required) A VectorStoreRetriever object to use
+            as the vector backing store.
+
+        split_chunk_size: (optional, 1000) Token chunk split size
+            for long messages generated by the AI.
+
+        previous_history_template: (optional) Template used to format
+            the contents of the prompt history.
+
+    Example using ChromaDB:
+
+    .. code-block:: python
+
+        from langchain.memory.vectorstore_token_buffer_memory import (
+            ConversationVectorStoreTokenBufferMemory,
+        )
+        from langchain_community.vectorstores import Chroma
+        from langchain_community.embeddings import HuggingFaceInstructEmbeddings
+        from langchain_openai import OpenAI
+
+        embedder = HuggingFaceInstructEmbeddings(
+            query_instruction="Represent the query for retrieval: "
+        )
+        chroma = Chroma(
+            collection_name="demo",
+            embedding_function=embedder,
+            collection_metadata={"hnsw:space": "cosine"},
+        )
+
+        retriever = chroma.as_retriever(
+            search_type="similarity_score_threshold",
+            search_kwargs={
+                "k": 5,
+                "score_threshold": 0.75,
+            },
+        )
+
+        conversation_memory = ConversationVectorStoreTokenBufferMemory(
+            return_messages=True,
+            llm=OpenAI(),
+            retriever=retriever,
+            max_token_limit=1000,
+        )
+
+        conversation_memory.save_context(
+            {"Human": "Hi there"},
+            {"AI": "Nice to meet you!"},
+        )
+        conversation_memory.save_context(
+            {"Human": "Nice day isn't it?"},
+            {"AI": "I love Wednesdays."},
+        )
+        conversation_memory.load_memory_variables({"input": "What time is it?"})
+    """
+
+    retriever: VectorStoreRetriever = Field(exclude=True)
+    memory_key: str = "history"
+    previous_history_template: str = DEFAULT_HISTORY_TEMPLATE
+    split_chunk_size: int = 1000
+
+    _memory_retriever: VectorStoreRetrieverMemory = PrivateAttr(default=None)
+    _timestamps: List[datetime] = PrivateAttr(default_factory=list)
+
+    @property
+    def memory_retriever(self) -> VectorStoreRetrieverMemory:
+        """Return a memory retriever from the passed retriever object."""
+        if self._memory_retriever is not None:
+            return self._memory_retriever
+        self._memory_retriever = VectorStoreRetrieverMemory(retriever=self.retriever)
+        return self._memory_retriever
+
+    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Return history and memory buffer."""
+        try:
+            with warnings.catch_warnings():
+                # Suppress retriever warnings, e.g. when no stored document
+                # clears the relevance threshold.
+                warnings.simplefilter("ignore")
+                memory_variables = self.memory_retriever.load_memory_variables(inputs)
+            previous_history = memory_variables[self.memory_retriever.memory_key]
+        except AssertionError:  # happens when db is empty
+            previous_history = ""
+        current_history = super().load_memory_variables(inputs)
+        template = SystemMessagePromptTemplate.from_template(
+            self.previous_history_template
+        )
+        messages = [
+            template.format(
+                previous_history=previous_history,
+                current_time=datetime.now().astimezone().strftime(TIMESTAMP_FORMAT),
+            )
+        ]
+        messages.extend(current_history[self.memory_key])
+        return {self.memory_key: messages}
+
+    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
+        """Save context from this conversation to buffer. Pruned."""
+        BaseChatMemory.save_context(self, inputs, outputs)
+        self._timestamps.append(datetime.now().astimezone())
+        # Prune buffer if it exceeds max token limit
+        buffer = self.chat_memory.messages
+        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+        if curr_buffer_length > self.max_token_limit:
+            while curr_buffer_length > self.max_token_limit:
+                self._pop_and_store_interaction(buffer)
+                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+
+    def save_remainder(self) -> None:
+        """
+        Save the remainder of the conversation buffer to the vector store.
+
+        This is useful if you have made the vectorstore persistent, in which
+        case this can be called before the end of the session to store the
+        remainder of the conversation.
+        """
+        buffer = self.chat_memory.messages
+        while len(buffer) > 0:
+            self._pop_and_store_interaction(buffer)
+
+    def _pop_and_store_interaction(self, buffer: List[BaseMessage]) -> None:
+        # The buffer alternates human and AI messages, so the two oldest
+        # entries form one complete interaction.
+        input = buffer.pop(0)
+        output = buffer.pop(0)
+        timestamp = self._timestamps.pop(0).strftime(TIMESTAMP_FORMAT)
+        # Split AI output into smaller chunks to avoid creating documents
+        # that will overflow the context window
+        ai_chunks = self._split_long_ai_text(str(output.content))
+        for index, chunk in enumerate(ai_chunks):
+            self.memory_retriever.save_context(
+                {"Human": f"<{timestamp}/00> {str(input.content)}"},
+                {"AI": f"<{timestamp}/{index:02}> {chunk}"},
+            )
+
+    def _split_long_ai_text(self, text: str) -> List[str]:
+        splitter = RecursiveCharacterTextSplitter(chunk_size=self.split_chunk_size)
+        return [chunk.page_content for chunk in splitter.create_documents([text])]
diff --git a/libs/langchain/tests/unit_tests/memory/test_imports.py b/libs/langchain/tests/unit_tests/memory/test_imports.py
index a42684a46e045f..bba351d8ad0f88 100644
--- a/libs/langchain/tests/unit_tests/memory/test_imports.py
+++ b/libs/langchain/tests/unit_tests/memory/test_imports.py
@@ -13,6 +13,7 @@
     "ConversationSummaryBufferMemory",
     "ConversationSummaryMemory",
     "ConversationTokenBufferMemory",
+    "ConversationVectorStoreTokenBufferMemory",
     "CosmosDBChatMessageHistory",
     "DynamoDBChatMessageHistory",
     "ElasticsearchChatMessageHistory",
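On the open testing question above: one possible approach, sketched rather
than prescribed, is to fake both collaborators directly instead of reaching
for pytest.mock. This assumes `FakeListLLM` and `FakeEmbeddings` from
langchain_community and `InMemoryVectorStore` from langchain_core are
available in the test environment; `WordCountingFakeLLM` is a hypothetical
helper defined only for this sketch.

.. code-block:: python

    from typing import List

    from langchain_community.embeddings import FakeEmbeddings
    from langchain_community.llms.fake import FakeListLLM
    from langchain_core.messages import BaseMessage, get_buffer_string
    from langchain_core.vectorstores import InMemoryVectorStore

    from langchain.memory import ConversationVectorStoreTokenBufferMemory


    class WordCountingFakeLLM(FakeListLLM):
        """Hypothetical fake LLM whose "token" count is a whitespace word count."""

        def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
            return len(get_buffer_string(messages).split())


    def test_overflow_is_written_to_vectorstore() -> None:
        store = InMemoryVectorStore(embedding=FakeEmbeddings(size=8))
        memory = ConversationVectorStoreTokenBufferMemory(
            return_messages=True,
            llm=WordCountingFakeLLM(responses=["unused"]),
            retriever=store.as_retriever(search_kwargs={"k": 2}),
            max_token_limit=10,
        )
        for i in range(5):
            memory.save_context({"Human": f"question {i}"}, {"AI": f"answer {i}"})
        # The live buffer stays within the (tiny) token limit ...
        buffer = memory.chat_memory.messages
        assert memory.llm.get_num_tokens_from_messages(buffer) <= 10
        # ... and evicted turns end up retrievable from the backing store.
        assert len(store.similarity_search("question 0", k=1)) == 1

Counting whitespace-separated words keeps the token arithmetic deterministic
without pulling a real tokenizer into the unit tests.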