# LangChain MongoDB Integration - Memory and Semantic Caching for RAG

This notebook is a companion to the [Memory and Semantic Caching](https://www.mongodb.com/docs/atlas/atlas-vector-search/ai-integrations/langchain/memory-semantic-cache/) tutorial. Refer to the page for set-up instructions and detailed explanations.

<a target="_blank" href="https://colab.research.google.com/github/mongodb/docs-notebooks/blob/main/ai-integrations/langchain-memory-and-semantic-caching.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [None]:
pip install --quiet --upgrade langchain langchain-community langchain-core langchain-mongodb langchain-voyageai langchain-openai pypdf

In [None]:
import os

# Placeholder credentials: replace each value with your own before running.
os.environ.update(
    {
        "OPENAI_API_KEY": "<openai-key>",
        "VOYAGE_API_KEY": "<voyage-key>",
    }
)
MONGODB_URI = "<connection-string>"

## Configure the Vector Store

In [None]:
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_voyageai import VoyageAIEmbeddings

# Embed text with the voyage-3 model
embedding_model = VoyageAIEmbeddings(model="voyage-3")

# Instantiate the vector store on the langchain_db.rag_with_memory namespace
vector_store = MongoDBAtlasVectorSearch.from_connection_string(
    connection_string=MONGODB_URI,
    embedding=embedding_model,
    namespace="langchain_db.rag_with_memory",
)

In [None]:
from langchain_community.document_loaders import PyPDFLoader
# Import the splitter from its own package: the legacy `langchain.text_splitter`
# path is not part of the `langchain` 1.x namespace.
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Load the PDF from its URL
loader = PyPDFLoader("https://investors.mongodb.com/node/13176/pdf")
data = loader.load()

# Split the PDF pages into chunks of up to 200 characters with a 20-character overlap
text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=20)
docs = text_splitter.split_documents(data)

# Add the chunks to the vector store
vector_store.add_documents(docs)

In [None]:
# Build the vector search index through the vector store's helper method
vector_store.create_vector_search_index(
    # The dimensions of the vector embeddings to be indexed
    dimensions=1024,
    # Number of seconds to wait for the index to build (can take around a minute)
    wait_until_complete=60,
)

## Implement RAG with Memory

In [None]:
from langchain_openai import ChatOpenAI

# Chat completion model shared by the chains below
llm = ChatOpenAI(model="gpt-4o")

In [None]:
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import MessagesPlaceholder
         
# Look up the chat message history for a session
def get_session_history(session_id: str) -> MongoDBChatMessageHistory:
    """Return the MongoDB-backed chat message history for ``session_id``.

    The history object is configured with the ``langchain_db`` database and
    the ``rag_with_memory`` collection on the ``MONGODB_URI`` deployment.
    """
    history = MongoDBChatMessageHistory(
        connection_string=MONGODB_URI,
        session_id=session_id,
        database_name="langchain_db",
        collection_name="rag_with_memory",
    )
    return history

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Parser that converts the model output to a plain string
parse_output = StrOutputParser()

# System instructions for turning a follow-up question into a standalone question
standalone_system_prompt = """
  Given a chat history and a follow-up question, rephrase the follow-up question to be a standalone question.
  Do NOT answer the question, just reformulate it if needed, otherwise return it as is.
  Only return the final standalone question.
"""

# Prompt layout: system instructions, then the chat history, then the new question
standalone_question_prompt = ChatPromptTemplate.from_messages([
    ("system", standalone_system_prompt),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{question}"),
])

# Chain: prompt -> chat model -> string output
question_chain = standalone_question_prompt | llm | parse_output

In [None]:
from langchain_core.runnables import RunnablePassthrough

# Retriever that returns the 5 most similar documents
retriever = vector_store.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 5},
)


def _join_page_contents(docs):
    """Concatenate the page content of the documents, separated by blank lines."""
    return "\n\n".join([doc.page_content for doc in docs])


# Rewrite the question using the history, retrieve documents for it, and
# add their combined text to the input under the "context" key
retriever_chain = RunnablePassthrough.assign(
    context=question_chain | retriever | _join_page_contents
)

In [None]:
# System instructions that restrict the answer to the retrieved context
rag_system_prompt = """Answer the question based only on the following context:
{context}
"""

# Prompt layout: context instructions, then the chat history, then the question
rag_prompt = ChatPromptTemplate.from_messages([
    ("system", rag_system_prompt),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{question}"),
])

In [None]:
# RAG chain: add context -> fill the prompt -> call the model -> parse to a string
rag_chain = retriever_chain | rag_prompt | llm | parse_output

# Attach per-session message history to the chain
rag_with_memory = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    history_messages_key="history",
    input_messages_key="question",
)

In [None]:
# Ask the first question in session "user_1"
response_1 = rag_with_memory.invoke(
    {"question": "What was MongoDB's latest acquisition?"},
    config={"configurable": {"session_id": "user_1"}},
)
print(response_1)

In [None]:
# Ask a follow-up in the same session; it refers back to the previous question
response_2 = rag_with_memory.invoke(
    {"question": "Why did they do it?"},
    config={"configurable": {"session_id": "user_1"}},
)
print(response_2)

## Add Semantic Caching

The semantic cache stores LLM responses keyed by the prompt sent to the LLM. In a 
retrieval chain, that prompt includes the retrieved documents, which can change between 
runs, so semantically similar queries can still result in cache misses.

In [None]:
from langchain_mongodb.cache import MongoDBAtlasSemanticCache
from langchain_core.globals import set_llm_cache

# Configure the semantic cache and register it as the global LLM cache
set_llm_cache(
    MongoDBAtlasSemanticCache(
        connection_string=MONGODB_URI,
        database_name="langchain_db",
        collection_name="semantic_cache",
        embedding=embedding_model,
        index_name="vector_index",
        # NOTE(review): the constructor's threshold parameter is `score_threshold`;
        # confirm against the installed langchain-mongodb version.
        score_threshold=0.5,  # Adjust based on your requirements
    )
)

In [None]:
%%time

# First query in session "user_2" (not cached)
rag_with_memory.invoke(
    {"question": "What was MongoDB's latest acquisition?"},
    config={"configurable": {"session_id": "user_2"}},
)

In [None]:
%%time

# Second, semantically similar query in session "user_2" (cached)
rag_with_memory.invoke(
    {"question": "What company did MongoDB acquire recently?"},
    config={"configurable": {"session_id": "user_2"}},
)