[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mongodb-developer/GenAI-Showcase/blob/main/notebooks/advanced_techniques/langchain_parent_document_retrieval.ipynb)

[![View Article](https://img.shields.io/badge/View%20Article-blue)](https://www.mongodb.com/developer/products/atlas/parent-doc-retrieval/?utm_campaign=devrel&utm_source=cross-post&utm_medium=organic_social&utm_content=https%3A%2F%2Fgithub.com%2Fmongodb-developer%2FGenAI-Showcase&utm_term=apoorva.joshi)

# Parent Document Retrieval Using MongoDB and LangChain

This notebook shows you how to implement parent document retrieval in your RAG application using MongoDB's LangChain integration.
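To make the idea concrete before touching MongoDB: parent document retrieval searches over small child chunks but returns the larger parent documents those chunks came from. The sketch below is a toy, in-memory illustration only. It uses word overlap as a stand-in for vector similarity, and every name in it (`split_words`, `overlap_score`, `build_index`, `retrieve_parents`, the `toy_*` variables) is hypothetical and not part of the LangChain or MongoDB APIs.

```python
from typing import Dict, List, Tuple


def split_words(text: str, size: int) -> List[str]:
    """Split text into child chunks of at most `size` words."""
    words = text.split()
    return [" ".join(words[i : i + size]) for i in range(0, len(words), size)]


def overlap_score(query: str, chunk: str) -> int:
    """Count shared lowercase words (a stand-in for vector similarity)."""
    return len(set(query.lower().split()) & set(chunk.lower().split()))


def build_index(parents: Dict[str, str], chunk_size: int) -> List[Tuple[str, str]]:
    """Store every child chunk together with the ID of its parent."""
    index = []
    for parent_id, text in parents.items():
        for chunk in split_words(text, chunk_size):
            index.append((parent_id, chunk))
    return index


def retrieve_parents(
    query: str,
    parents: Dict[str, str],
    index: List[Tuple[str, str]],
    top_k: int = 2,
) -> List[str]:
    """Rank child chunks, then return their de-duplicated parent documents."""
    ranked = sorted(index, key=lambda item: overlap_score(query, item[1]), reverse=True)
    seen: List[str] = []
    for parent_id, _ in ranked[:top_k]:
        if parent_id not in seen:
            seen.append(parent_id)
    return [parents[parent_id] for parent_id in seen]


toy_parents = {
    "doc-a": "Indexes speed up queries. Create indexes that match your query patterns.",
    "doc-b": "Backups protect data. Schedule snapshots for every cluster.",
}
toy_index = build_index(toy_parents, chunk_size=5)
toy_results = retrieve_parents("how to speed up queries", toy_parents, toy_index)
print(toy_results)
```

The match happens on a five-word chunk, but the caller receives the full parent text, which is the behavior the rest of this notebook sets up with real embeddings.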

## Step 1: Install required libraries

- **datasets**: Python package to download datasets from Hugging Face

- **pymongo**: Python driver for MongoDB

- **langchain**: Python package for LangChain's core modules

- **langchain-openai**: Python package to use OpenAI models via LangChain

- **langgraph**: Python package to orchestrate LLM workflows as graphs

- **langchain-mongodb**: Python package to use MongoDB features in LangChain

In [1]:
! pip install -qU datasets pymongo langchain langgraph langchain-mongodb langchain-openai

## Step 2: Set up prerequisites

- **Set the MongoDB connection string**: Follow the steps [here](https://www.mongodb.com/docs/manual/reference/connection-string/) to get the connection string from the Atlas UI.

- **Set the OpenAI API key**: Steps to obtain an API key are [here](https://help.openai.com/en/articles/4936850-where-do-i-find-my-openai-api-key).

- **Set the Hugging Face token**: Steps to create a token are [here](https://huggingface.co/docs/hub/en/security-tokens#how-to-manage-user-access-tokens). You only need a **read** token for this tutorial.
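If you rerun the notebook often, it can be convenient to reuse credentials that are already exported in the environment and only prompt when they are missing. The helper below is a minimal sketch of that pattern; `get_secret` and the `EXAMPLE_SECRET` variable are hypothetical names used for illustration, not something the notebook defines.

```python
import getpass
import os


def get_secret(env_var: str, prompt: str) -> str:
    """Return the value of env_var, prompting for it only if it is unset or empty."""
    value = os.environ.get(env_var)
    if not value:
        # getpass hides the typed value so it does not end up in notebook output
        value = getpass.getpass(prompt)
        os.environ[env_var] = value
    return value


# Demo with a placeholder value so the example does not block on a prompt
os.environ.setdefault("EXAMPLE_SECRET", "placeholder-value")
example_secret = get_secret("EXAMPLE_SECRET", "Enter the example secret:")
```

With this in place, a call such as `get_secret("OPENAI_API_KEY", "Enter your OpenAI API Key:")` would skip the prompt whenever the variable is already set.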

In [2]:
import getpass
import os

from pymongo import MongoClient

In [3]:
os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API Key:")

In [4]:
MONGODB_URI = getpass.getpass("Enter your MongoDB connection string:")
mongodb_client = MongoClient(
    MONGODB_URI, appname="devrel.showcase.parent_doc_retrieval"
)
mongodb_client.admin.command("ping")

{'ok': 1.0,
 '$clusterTime': {'clusterTime': Timestamp(1734037711, 1),
  'signature': {'hash': b'v\xa2\xc7\xf6\xc4\xc5z\x97%Q_\xc1\xa5\xaf}\x05(\x92\x80\xc2',
   'keyId': 7390069253761662978}},
 'operationTime': Timestamp(1734037711, 1)}

In [5]:
os.environ["HF_TOKEN"] = getpass.getpass("Enter your HF Access Token:")

## Step 3: Load the dataset

In [6]:
import pandas as pd
from datasets import load_dataset

In [7]:
data = load_dataset("mongodb-eai/docs", streaming=True, split="train")
data_head = data.take(1000)
df = pd.DataFrame(data_head)

In [8]:
df.head()

|  | updated | _id | metadata | action | sourceName | body | url | format | title |
|---|---|---|---|---|---|---|---|---|---|
| 0 | {'$date': '2024-05-20T17:30:49.148Z'} | {'$oid': '664b88c96e4f895074208162'} | {'contentType': None, 'pageDescription': None,... | created | snooty-cloud-docs | # View Database Access History\n\n- This featu... | https://mongodb.com/docs/atlas/access-tracking/ | md | View Database Access History |
| 1 | {'$date': '2024-05-20T17:30:49.148Z'} | {'$oid': '664b88c96e4f895074208178'} | {'contentType': None, 'pageDescription': None,... | created | snooty-cloud-docs | # Manage Organization Teams\n\nYou can create ... | https://mongodb.com/docs/atlas/access/manage-t... | md | Manage Organization Teams |
| 2 | {'$date': '2024-05-20T17:30:49.148Z'} | {'$oid': '664b88c96e4f895074208183'} | {'contentType': None, 'pageDescription': None,... | created | snooty-cloud-docs | # Manage Organizations\n\nIn the organizations... | https://mongodb.com/docs/atlas/access/orgs-cre... | md | Manage Organizations |
| 3 | {'$date': '2024-05-20T17:30:49.148Z'} | {'$oid': '664b88c96e4f89507420818f'} | {'contentType': None, 'pageDescription': None,... | created | snooty-cloud-docs | # Alert Basics\n\nAtlas provides built-in tool... | https://mongodb.com/docs/atlas/alert-basics/ | md | Alert Basics |
| 4 | {'$date': '2024-05-20T17:30:49.148Z'} | {'$oid': '664b88c96e4f89507420819d'} | {'contentType': None, 'pageDescription': None,... | created | snooty-cloud-docs | # Resolve Alerts\n\nAtlas issues alerts for th... | https://mongodb.com/docs/atlas/alert-resolutions/ | md | Resolve Alerts |


## Step 4: Convert dataset to LangChain Documents

In [9]:
from langchain_core.documents import Document

In [10]:
docs = []
metadata_fields = ["updated", "url", "title"]
for _, row in df.iterrows():
    content = row["body"]
    metadata = row["metadata"]
    for field in metadata_fields:
        metadata[field] = row[field]
    docs.append(Document(page_content=content, metadata=metadata))

In [11]:
docs[0]

Document(metadata={'contentType': None, 'pageDescription': None, 'productName': 'MongoDB Atlas', 'tags': ['atlas', 'docs'], 'version': None, 'updated': {'$date': '2024-05-20T17:30:49.148Z'}, 'url': 'https://mongodb.com/docs/atlas/access-tracking/', 'title': 'View Database Access History'}, page_content='# View Database Access History\n\n- This feature is not available for `M0` free clusters, `M2`, and `M5` clusters. To learn more, see Atlas M0 (Free Cluster), M2, and M5 Limits.\n\n- This feature is not supported on Serverless instances at this time. To learn more, see Serverless Instance Limitations.\n\n## Overview\n\nAtlas parses the MongoDB database logs to collect a list of authentication requests made against your clusters through the following methods:\n\n- `mongosh`\n\n- Compass\n\n- Drivers\n\nAuthentication requests made with API Keys through the Atlas Administration API are not logged.\n\nAtlas logs the following information for each authentication request within the last 7 da

In [12]:
len(docs)

1000

## Step 5: Instantiate the retriever

In [16]:
from langchain_mongodb.retrievers import (
    MongoDBAtlasParentDocumentRetriever,
)
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

In [17]:
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")

In [18]:
DB_NAME = "langchain"
COLLECTION_NAME = "parent_doc"

In [19]:
def get_splitter(chunk_size: int) -> RecursiveCharacterTextSplitter:
    """
    Returns a token-based text splitter with overlap.

    Args:
        chunk_size (int): Chunk size in number of tokens

    Returns:
        RecursiveCharacterTextSplitter: Recursive text splitter object
    """
    return RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        encoding_name="cl100k_base",
        chunk_size=chunk_size,
        # Overlap is 15% of the chunk size, cast to an integer number of tokens
        chunk_overlap=int(0.15 * chunk_size),
    )
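To see what a 15% overlap means in practice, the sketch below slides a fixed-size window over a list of placeholder tokens. It is only an illustration of the chunk-size and overlap arithmetic: `RecursiveCharacterTextSplitter` additionally splits on separators and merges pieces, which this toy does not attempt to reproduce. The names `window_tokens`, `toy_tokens`, and `toy_windows` are hypothetical.

```python
from typing import List


def window_tokens(
    tokens: List[str], chunk_size: int, overlap_ratio: float = 0.15
) -> List[List[str]]:
    """Split tokens into windows of chunk_size that share `overlap` tokens."""
    overlap = int(overlap_ratio * chunk_size)
    step = chunk_size - overlap  # how far each new window advances
    windows = []
    for start in range(0, len(tokens), step):
        windows.append(tokens[start : start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # the last window already reaches the end of the input
    return windows


toy_tokens = [f"t{i}" for i in range(50)]
toy_windows = window_tokens(toy_tokens, chunk_size=20)
for window in toy_windows:
    print(window[0], "...", window[-1], f"({len(window)} tokens)")
```

With `chunk_size=20` the overlap is 3 tokens, so each window starts 17 tokens after the previous one. By the same arithmetic, `get_splitter(200)` asks for an overlap of 30 tokens.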

### Parent document retriever

In [20]:
parent_doc_retriever = MongoDBAtlasParentDocumentRetriever.from_connection_string(
    connection_string=MONGODB_URI,
    embedding_model=embedding_model,
    child_splitter=get_splitter(200),
    database_name=DB_NAME,
    collection_name=COLLECTION_NAME,
    text_key="page_content",
    search_kwargs={"top_k": 10},
)

In [None]:
# # Parent chunk retriever
# parent_chunk_retriever = MongoDBAtlasParentDocumentRetriever.from_connection_string(
#     connection_string=MONGODB_URI,
#     embedding_model=embedding_model,
#     child_splitter=get_splitter(200),
#     parent_splitter=get_splitter(800),
#     database_name=DB_NAME,
#     collection_name=COLLECTION_NAME,
#     text_key="page_content",
#     search_kwargs={"top_k": 10},
# )

## Step 6: Ingest documents into MongoDB

In [22]:
import asyncio
from typing import Generator, List

In [23]:
BATCH_SIZE = 256
MAX_CONCURRENCY = 4

In [24]:
async def process_batch(batch: List[Document], semaphore: asyncio.Semaphore) -> None:
    """
    Ingest a batch of documents into MongoDB

    Args:
        batch (List[Document]): Batch of documents to ingest
        semaphore (asyncio.Semaphore): Semaphore that limits concurrent ingestion tasks
    """
    async with semaphore:
        await parent_doc_retriever.aadd_documents(batch)
        print(f"Processed {len(batch)} documents")

In [25]:
def get_batches(docs: List[Document], batch_size: int) -> Generator:
    """
    Return batches of documents to ingest into MongoDB

    Args:
        docs (List[Document]): List of LangChain documents
        batch_size (int): Batch size

    Yields:
        List[Document]: Batch of documents
    """
    for i in range(0, len(docs), batch_size):
        yield docs[i : i + batch_size]

In [26]:
async def process_docs(docs: List[Document]) -> List[None]:
    """
    Asynchronously ingest LangChain documents into MongoDB

    Args:
        docs (List[Document]): List of LangChain documents

    Returns:
        List[None]: Results of the task executions
    """
    semaphore = asyncio.Semaphore(MAX_CONCURRENCY)
    batches = get_batches(docs, BATCH_SIZE)

    tasks = []
    for batch in batches:
        tasks.append(process_batch(batch, semaphore))
    # Gather results from all tasks
    results = await asyncio.gather(*tasks)
    return results

In [27]:
collection = mongodb_client[DB_NAME][COLLECTION_NAME]
# Delete any existing documents from the collection
collection.delete_many({})
print("Deletion complete.")
# Ingest LangChain documents into MongoDB
results = await process_docs(docs)

Deletion complete.
Processed 256 documents
Processed 256 documents
Processed 256 documents
Processed 232 documents
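The batch sizes printed above follow from slicing 1,000 documents into groups of 256, and the semaphore caps how many batches are in flight at once. The following sketch reproduces that pattern without MongoDB, replacing the real ingestion call with a short sleep. `fake_ingest`, `run_batches`, and the `toy_*` names are hypothetical, and the concurrency limit of 2 is chosen only to make the cap easy to observe.

```python
import asyncio
from typing import List, Tuple


async def fake_ingest(
    batch: List[int], semaphore: asyncio.Semaphore, active: List[int], peak: List[int]
) -> int:
    """Pretend to ingest a batch while tracking how many batches run at once."""
    async with semaphore:
        active[0] += 1
        peak[0] = max(peak[0], active[0])
        await asyncio.sleep(0.01)  # stand-in for the real network call
        active[0] -= 1
        return len(batch)


async def run_batches(
    items: List[int], batch_size: int, max_concurrency: int
) -> Tuple[List[int], int]:
    """Ingest all batches concurrently, never exceeding max_concurrency."""
    semaphore = asyncio.Semaphore(max_concurrency)
    active, peak = [0], [0]
    batches = [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
    sizes = await asyncio.gather(
        *(fake_ingest(batch, semaphore, active, peak) for batch in batches)
    )
    return list(sizes), peak[0]


# In a script; inside a notebook cell use `await run_batches(...)` instead
toy_sizes, toy_peak = asyncio.run(run_batches(list(range(1000)), 256, 2))
print(toy_sizes, toy_peak)
```

`asyncio.gather` returns results in the order the tasks were passed in, which is why the sizes line up with the batch order even though batches may finish at different times.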


## Step 7: Create a vector search index

In [28]:
from pymongo.errors import OperationFailure
from pymongo.operations import SearchIndexModel

In [29]:
VS_INDEX_NAME = "vector_index"

In [30]:
# Vector search index definition
model = SearchIndexModel(
    definition={
        "fields": [
            {
                "type": "vector",
                "path": "embedding",
                "numDimensions": 1536,
                "similarity": "cosine",
            }
        ]
    },
    name=VS_INDEX_NAME,
    type="vectorSearch",
)

In [None]:
# Create the index; an OperationFailure is treated as a sign that it already exists
try:
    collection.create_search_index(model=model)
    print(
        f"Successfully created index {VS_INDEX_NAME} for collection {COLLECTION_NAME}"
    )
except OperationFailure:
    print(
        f"Duplicate index {VS_INDEX_NAME} found for collection {COLLECTION_NAME}. Skipping index creation."
    )

Successfully created index vector_index for collection parent_doc
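Atlas builds search indexes in the background, so a query issued immediately after creation may return no results. One way to guard against this is to poll until the index reports that it can serve queries. The sketch below assumes that `collection.list_search_indexes(name)` yields documents containing a boolean `queryable` field; treat that field name, the helper `wait_until_queryable`, and the timeout values as assumptions to verify against your PyMongo and Atlas versions.

```python
import time
from typing import Any


def wait_until_queryable(
    collection: Any, index_name: str, timeout_s: float = 120.0, poll_s: float = 5.0
) -> bool:
    """Poll the search index until it is queryable or the timeout expires."""
    deadline = time.monotonic() + timeout_s
    while True:
        # Assumption: each returned document has a boolean "queryable" field
        indexes = list(collection.list_search_indexes(index_name))
        if indexes and indexes[0].get("queryable"):
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_s)
```

In this notebook the call would look like `wait_until_queryable(collection, VS_INDEX_NAME)`, run before moving on to Step 8.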


## Step 8: Usage

### In a RAG application

In [31]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

In [32]:
# Retrieve and parse documents
retrieve = {
    "context": parent_doc_retriever
    | (lambda docs: "\n\n".join([d.page_content for d in docs])),
    "question": RunnablePassthrough(),
}
template = """Answer the question based only on the following context. If no context is provided, respond with I DON'T KNOW: \
{context}

Question: {question}
"""
# Define the chat prompt
prompt = ChatPromptTemplate.from_template(template)
# Define the model to be used for chat completion
llm = ChatOpenAI(temperature=0, model="gpt-4o-2024-11-20")
# Parse output as a string
parse_output = StrOutputParser()
# Naive RAG chain
rag_chain = retrieve | prompt | llm | parse_output
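The first two stages of the chain can be hard to picture from the pipe syntax alone: the question is sent to the retriever, the returned documents are joined with blank lines, and the result is substituted into the template alongside the original question. The sketch below mimics that data flow in plain Python. `TOY_TEMPLATE`, `build_prompt`, and `toy_retriever` are hypothetical stand-ins that return fixed strings, not LangChain objects.

```python
from typing import Callable, List

TOY_TEMPLATE = (
    "Answer the question based only on the following context: {context}\n\n"
    "Question: {question}\n"
)


def build_prompt(question: str, retriever: Callable[[str], List[str]]) -> str:
    """Retrieve documents, join them into one context string, and fill the template."""
    context = "\n\n".join(retriever(question))
    return TOY_TEMPLATE.format(context=context, question=question)


def toy_retriever(question: str) -> List[str]:
    """Stand-in retriever that always returns the same two parent documents."""
    return ["Parent document one.", "Parent document two."]


toy_prompt = build_prompt("How do I improve slow queries?", toy_retriever)
print(toy_prompt)
```

In the real chain, the filled prompt is then passed to the chat model and the response is parsed into a string.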

In [33]:
# Test the RAG chain
print(rag_chain.invoke("How do I improve slow queries in MongoDB?"))

To improve slow queries in MongoDB, you can follow these steps:

1. **Use the Performance Advisor**:
   - The Performance Advisor monitors slow queries and suggests indexes to improve performance.
   - Create the suggested indexes, especially those with high Impact scores and low Average Query Targeting scores.

2. **Analyze Query Performance**:
   - Use the **Query Profiler** to identify slow-running operations and their key performance statistics.
   - Use the **Real-Time Performance Panel (RTPP)** to evaluate query execution times and the ratio of documents scanned to documents returned.
   - Use **Namespace Insights** to monitor collection-level query latency.

3. **Optimize Indexes**:
   - Create indexes that support your queries to reduce the time needed to search for results.
   - Remove unused or inefficient indexes to improve write performance and free storage space.
   - Perform rolling index builds to minimize performance impact on replica sets and sharded clusters.

4. **Fi

### In an AI agent

In [34]:
from typing import Annotated, Dict

from langchain_core.tools import tool
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from typing_extensions import TypedDict

In [35]:
# Converting the retriever into an agent tool
@tool
def get_info_about_mongodb(user_query: str) -> str:
    """
    Retrieve information about MongoDB.

    Args:
    user_query (str): The user's query string.

    Returns:
    str: The retrieved information formatted as a string.
    """
    docs = parent_doc_retriever.invoke(user_query)
    context = "\n\n".join([d.page_content for d in docs])
    return context

In [36]:
tools = [get_info_about_mongodb]

In [37]:
# Define the LLM to use as the brain of the agent
llm = ChatOpenAI(temperature=0, model="gpt-4o-2024-11-20")
# Agent prompt
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful AI assistant."
            " You are provided with tools to answer questions about MongoDB."
            " Think step-by-step and use these tools to get the information required to answer the user query."
            " Do not re-run tools unless absolutely necessary."
            " If you are not able to get enough information using the tools, reply with I DON'T KNOW."
            " You have access to the following tools: {tool_names}."
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)
# Partial the prompt with tool names
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
# Bind tools to LLM
llm_with_tools = prompt | llm.bind_tools(tools)

In [38]:
# Define graph state
class GraphState(TypedDict):
    messages: Annotated[list, add_messages]

In [39]:
def agent(state: GraphState) -> Dict[str, List]:
    """
    Agent node

    Args:
        state (GraphState): Graph state

    Returns:
        Dict[str, List]: Updates to the graph state
    """
    messages = state["messages"]
    response = llm_with_tools.invoke(messages)
    # We return a list, because this will get added to the existing list
    return {"messages": [response]}

In [40]:
# Convert tools into a graph node
tool_node = ToolNode(tools)

In [41]:
# Parameterize the graph with the state
graph = StateGraph(GraphState)
# Add graph nodes
graph.add_node("agent", agent)
graph.add_node("tools", tool_node)
# Add graph edges
graph.add_edge(START, "agent")
graph.add_edge("tools", "agent")
graph.add_conditional_edges(
    "agent",
    tools_condition,
    {"tools": "tools", END: END},
)
# Compile the graph
app = graph.compile()
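The graph above loops between the agent and the tools until the agent replies without requesting a tool call. The sketch below imitates that control flow with plain dictionaries, as a rough mental model rather than a description of LangGraph internals. `route`, `run_loop`, `toy_agent`, and `toy_tool` are hypothetical names, and the `max_steps` guard is an addition for the toy that the notebook's graph does not define.

```python
from typing import Callable, Dict, List

Message = Dict[str, object]


def route(messages: List[Message]) -> str:
    """Go to the tools node if the last message requests tool calls, else stop."""
    return "tools" if messages[-1].get("tool_calls") else "end"


def run_loop(
    question: str,
    agent_fn: Callable[[List[Message]], Message],
    tool_fn: Callable[[str], str],
    max_steps: int = 5,
) -> List[Message]:
    """Alternate between agent and tools until the agent gives a final answer."""
    messages: List[Message] = [{"role": "user", "content": question}]
    for _ in range(max_steps):
        messages.append(agent_fn(messages))  # agent node
        if route(messages) == "end":  # conditional edge
            break
        for call in messages[-1]["tool_calls"]:  # tools node
            messages.append({"role": "tool", "content": tool_fn(call)})
    return messages


def toy_agent(messages: List[Message]) -> Message:
    """Request the tool once, then answer using whatever the tool returned."""
    if any(m["role"] == "tool" for m in messages):
        return {"role": "assistant", "content": "Create indexes that support your queries."}
    return {"role": "assistant", "content": "", "tool_calls": [messages[0]["content"]]}


def toy_tool(query: str) -> str:
    """Stand-in for the retriever tool."""
    return f"context for: {query}"


toy_messages = run_loop("How do I improve slow queries?", toy_agent, toy_tool)
print([m["role"] for m in toy_messages])
```

The message list grows with each step, which mirrors how the `add_messages` annotation appends new messages to the graph state instead of overwriting it.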

In [42]:
# Execute the agent and view outputs
inputs = {
    "messages": [
        ("user", "How do I improve slow queries in MongoDB?"),
    ]
}

for output in app.stream(inputs):
    for key, value in output.items():
        print(f"Node {key}:")
        print(value)
print("---FINAL ANSWER---")
print(value["messages"][-1].content)

Node agent:
{'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_sifH0mrhbpesQie4BTnQytNk', 'function': {'arguments': '{"user_query":"How do I improve slow queries in MongoDB?"}', 'name': 'get_info_about_mongodb'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 27, 'prompt_tokens': 165, 'total_tokens': 192, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-11-20', 'system_fingerprint': 'fp_d924043139', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-bc1db263-f4f5-40ba-a6ba-b18a3585e095-0', tool_calls=[{'name': 'get_info_about_mongodb', 'args': {'user_query': 'How do I improve slow queries in MongoDB?'}, 'id': 'call_sifH0mrhbpesQie4BTnQytNk', 'type': 'tool_call'}], usage_metadata={'input_tokens': 165, 'output_tokens': 2