# Parent Document Retrieval Using MongoDB and LangChain

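This notebook shows you how to implement parent document retrieval in your RAG application using MongoDB's LangChain integration. With this technique, small child chunks are embedded and searched, and the larger parent documents (or parent chunks) they were split from are returned as context.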

## Step 1: Install required libraries

- **datasets**: Python package to download datasets from Hugging Face

- **pymongo**: Python driver for MongoDB

- **langchain**: Python package for LangChain's core modules

- **langchain-openai**: Python package to use OpenAI models via LangChain

- **langgraph**: Python package to orchestrate LLM workflows as graphs

- **langchain-mongodb**: Python package to use MongoDB features in LangChain

In [None]:
! pip install -qU datasets pymongo langchain langgraph langchain-mongodb langchain-openai

## Step 2: Set up prerequisites

- **Set the MongoDB connection string**: Follow the steps [here](https://www.mongodb.com/docs/manual/reference/connection-string/) to get the connection string from the Atlas UI.

- **Set the OpenAI API key**: Steps to obtain an API key are [here](https://help.openai.com/en/articles/4936850-where-do-i-find-my-openai-api-key).

- **Set the Hugging Face token**: Steps to create a token are [here](https://huggingface.co/docs/hub/en/security-tokens#how-to-manage-user-access-tokens). You only need a **read** token for this tutorial.

In [None]:
import getpass
import os

from pymongo import MongoClient

In [None]:
os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API Key:")

Enter your OpenAI API Key:··········


In [None]:
MONGODB_URI = getpass.getpass("Enter your MongoDB connection string:")
mongodb_client = MongoClient(
    MONGODB_URI, appname="devrel.showcase.parent_doc_retrieval"
)
mongodb_client.admin.command("ping")

Enter your MongoDB connection string:··········


{'ok': 1}

In [None]:
os.environ["HF_TOKEN"] = getpass.getpass("Enter your HF Access Token:")

Enter your HF Access Token:··········


## Step 3: Load the dataset

In [None]:
import pandas as pd
from datasets import load_dataset

In [None]:
data = load_dataset("mongodb-eai/docs", streaming=True, split="train")
data_head = data.take(500)
df = pd.DataFrame(data_head)

In [None]:
df.head()

| | _id | sourceName | url | action | body | format | metadata | title | updated |
|---|---|---|---|---|---|---|---|---|---|
| 0 | {'$oid': '664b88c96e4f895074208162'} | snooty-cloud-docs | https://mongodb.com/docs/atlas/access-tracking/ | created | # View Database Access History\n\n- This featu... | md | {'contentType': None, 'pageDescription': None,... | View Database Access History | {'$date': '2024-05-20T17:30:49.148Z'} |
| 1 | {'$oid': '664b88c96e4f895074208178'} | snooty-cloud-docs | https://mongodb.com/docs/atlas/access/manage-t... | created | # Manage Organization Teams\n\nYou can create ... | md | {'contentType': None, 'pageDescription': None,... | Manage Organization Teams | {'$date': '2024-05-20T17:30:49.148Z'} |
| 2 | {'$oid': '664b88c96e4f895074208183'} | snooty-cloud-docs | https://mongodb.com/docs/atlas/access/orgs-cre... | created | # Manage Organizations\n\nIn the organizations... | md | {'contentType': None, 'pageDescription': None,... | Manage Organizations | {'$date': '2024-05-20T17:30:49.148Z'} |
| 3 | {'$oid': '664b88c96e4f89507420818f'} | snooty-cloud-docs | https://mongodb.com/docs/atlas/alert-basics/ | created | # Alert Basics\n\nAtlas provides built-in tool... | md | {'contentType': None, 'pageDescription': None,... | Alert Basics | {'$date': '2024-05-20T17:30:49.148Z'} |
| 4 | {'$oid': '664b88c96e4f89507420819d'} | snooty-cloud-docs | https://mongodb.com/docs/atlas/alert-resolutions/ | created | # Resolve Alerts\n\nAtlas issues alerts for th... | md | {'contentType': None, 'pageDescription': None,... | Resolve Alerts | {'$date': '2024-05-20T17:30:49.148Z'} |


## Step 4: Convert dataset to LangChain Documents

In [None]:
from langchain_core.documents import Document

In [None]:
docs = []
metadata_fields = ["updated", "url", "title"]
for _, row in df.iterrows():
    content = row["body"]
    metadata = row["metadata"]
    for field in metadata_fields:
        metadata[field] = row[field]
    docs.append(Document(page_content=content, metadata=metadata))
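
The loop above writes the extra fields directly into each row's `metadata` dictionary. As an optional variation (not what this notebook does), the sketch below copies the dictionary first so the DataFrame row is left untouched, and parses the `{'$date': ...}` wrapper into a `datetime`. The helper name `build_metadata` and the `example_row` data are hypothetical and only mirror the shape of the rows shown in Step 3.

```python
from datetime import datetime
from typing import Any, Dict, List


def build_metadata(row: Dict[str, Any], fields: List[str]) -> Dict[str, Any]:
    """Merge selected top-level fields into a copy of the row's metadata."""
    # Shallow copy so the source row's metadata dict is not mutated
    metadata = dict(row["metadata"])
    for field in fields:
        value = row[field]
        # Assumption: dates arrive as {"$date": "<ISO 8601 string ending in Z>"}
        if isinstance(value, dict) and "$date" in value:
            value = datetime.fromisoformat(value["$date"].replace("Z", "+00:00"))
        metadata[field] = value
    return metadata


# Hypothetical row shaped like the dataset rows in Step 3
example_row = {
    "body": "# View Database Access History",
    "metadata": {"productName": "MongoDB Atlas", "tags": ["atlas", "docs"]},
    "updated": {"$date": "2024-05-20T17:30:49.148Z"},
    "url": "https://mongodb.com/docs/atlas/access-tracking/",
    "title": "View Database Access History",
}
example_metadata = build_metadata(example_row, ["updated", "url", "title"])
```

The function accepts any mapping-like row, so it should also work with the pandas rows yielded by `df.iterrows()`.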

In [None]:
docs[0]

Document(metadata={'contentType': None, 'pageDescription': None, 'productName': 'MongoDB Atlas', 'tags': ['atlas', 'docs'], 'version': None, 'updated': {'$date': '2024-05-20T17:30:49.148Z'}, 'url': 'https://mongodb.com/docs/atlas/access-tracking/', 'title': 'View Database Access History'}, page_content='# View Database Access History\n\n- This feature is not available for `M0` free clusters, `M2`, and `M5` clusters. To learn more, see Atlas M0 (Free Cluster), M2, and M5 Limits.\n\n- This feature is not supported on Serverless instances at this time. To learn more, see Serverless Instance Limitations.\n\n## Overview\n\nAtlas parses the MongoDB database logs to collect a list of authentication requests made against your clusters through the following methods:\n\n- `mongosh`\n\n- Compass\n\n- Drivers\n\nAuthentication requests made with API Keys through the Atlas Administration API are not logged.\n\nAtlas logs the following information for each authentication request within the last 7 da

In [None]:
len(docs)

500

## Step 5: Instantiate the retriever

In [None]:
from langchain_mongodb.retrievers import (
    MongoDBAtlasParentDocumentRetriever,
)
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

In [None]:
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")

In [None]:
DB_NAME = "langchain"
COLLECTION_NAME = "parent_doc"

In [None]:
def get_splitter(chunk_size: int) -> RecursiveCharacterTextSplitter:
    """
    Returns a token-based text splitter with overlap

    Args:
        chunk_size (int): Chunk size in number of tokens

    Returns:
        RecursiveCharacterTextSplitter: Recursive text splitter object
    """
    return RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        encoding_name="cl100k_base",
        chunk_size=chunk_size,
        chunk_overlap=int(0.15 * chunk_size),  # overlap is a whole number of tokens
    )
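
To make the 15% overlap concrete, here is a minimal sketch of fixed-size windows with overlap. It is a simplification: `RecursiveCharacterTextSplitter` splits on separators and then merges pieces, so its chunk boundaries will differ. The function name `split_with_overlap` is hypothetical, and plain list items stand in for tiktoken tokens.

```python
from typing import List


def split_with_overlap(
    tokens: List[str], chunk_size: int, overlap_ratio: float = 0.15
) -> List[List[str]]:
    """Slide a window of chunk_size tokens, stepping by chunk_size - overlap."""
    overlap = int(overlap_ratio * chunk_size)
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start : start + chunk_size])
        # Stop once the current window has reached the end of the input
        if start + chunk_size >= len(tokens):
            break
    return chunks


# 500 stand-in tokens, 200-token chunks, 30-token overlap
demo_tokens = [str(i) for i in range(500)]
demo_chunks = split_with_overlap(demo_tokens, chunk_size=200)
```

With a chunk size of 200, consecutive windows share 30 tokens, so each new chunk starts 170 tokens after the previous one.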

### Parent document retriever

In [None]:
parent_doc_retriever = MongoDBAtlasParentDocumentRetriever.from_connection_string(
    connection_string=MONGODB_URI,
    embedding_model=embedding_model,
    child_splitter=get_splitter(200),
    database_name=DB_NAME,
    collection_name=COLLECTION_NAME,
    text_key="page_content",
    search_kwargs={"k": 10},
)

In [None]:
# Parent chunk retriever
parent_chunk_retriever = MongoDBAtlasParentDocumentRetriever.from_connection_string(
    connection_string=MONGODB_URI,
    embedding_model=embedding_model,
    child_splitter=get_splitter(200),
    parent_splitter=get_splitter(800),
    database_name=DB_NAME,
    collection_name=COLLECTION_NAME,
    text_key="page_content",
    search_kwargs={"k": 10},
)
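
Both retrievers follow the same idea: search over small child chunks, then return the larger parent each matching chunk came from. The sketch below illustrates that flow with the standard library only. It is a conceptual toy, not how `MongoDBAtlasParentDocumentRetriever` is implemented: word overlap stands in for vector similarity, an in-memory list stands in for the MongoDB collection, and all function names are hypothetical.

```python
from typing import Dict, List, Tuple


def chunk_words(text: str, size: int) -> List[str]:
    """Split text into consecutive chunks of at most `size` words."""
    words = text.split()
    return [" ".join(words[i : i + size]) for i in range(0, len(words), size)]


def build_child_index(parents: Dict[str, str], child_size: int) -> List[Tuple[str, str]]:
    """Return (parent_id, child_text) pairs for every child chunk."""
    return [
        (parent_id, child)
        for parent_id, text in parents.items()
        for child in chunk_words(text, child_size)
    ]


def score(query: str, text: str) -> int:
    """Toy relevance score: number of query words present in the text."""
    text_words = set(text.lower().split())
    return sum(1 for word in set(query.lower().split()) if word in text_words)


def retrieve_parents(
    query: str, index: List[Tuple[str, str]], parents: Dict[str, str], k: int
) -> List[str]:
    """Rank child chunks, keep the top k, and return their parents without duplicates."""
    top_children = sorted(index, key=lambda pair: score(query, pair[1]), reverse=True)[:k]
    parent_ids: List[str] = []
    for parent_id, _ in top_children:
        if parent_id not in parent_ids:
            parent_ids.append(parent_id)
    return [parents[parent_id] for parent_id in parent_ids]


toy_parents = {
    "a": "indexes speed up slow queries by avoiding collection scans and the profiler shows slow operations",
    "b": "atlas teams let you grant roles to many users at once across projects",
}
toy_index = build_child_index(toy_parents, child_size=5)
toy_result = retrieve_parents("slow queries", toy_index, toy_parents, k=2)
```

Here two child chunks match the query, but both belong to parent `a`, so a single parent text is returned. Without a `parent_splitter`, the parent is the whole source document; with one, the parent is the larger chunk (800 tokens in the cell above) that contains the matching child.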

## Step 6: Ingest documents into MongoDB

In [None]:
import asyncio
from typing import Generator, List

In [None]:
BATCH_SIZE = 256
MAX_CONCURRENCY = 4

In [None]:
async def process_batch(batch: List[Document], semaphore: asyncio.Semaphore) -> None:
    """
    Ingest a batch of documents into MongoDB

    Args:
        batch (List[Document]): Batch of documents to ingest
        semaphore (asyncio.Semaphore): Semaphore that caps the number of concurrent ingestion tasks
    """
    async with semaphore:
        await parent_doc_retriever.aadd_documents(batch)
        print(f"Processed {len(batch)} documents")

In [None]:
def get_batches(docs: List[Document], batch_size: int) -> Generator:
    """
    Return batches of documents to ingest into MongoDB

    Args:
        docs (List[Document]): List of LangChain documents
        batch_size (int): Batch size

    Yields:
        Generator: Batch of documents
    """
    for i in range(0, len(docs), batch_size):
        yield docs[i : i + batch_size]

In [None]:
async def process_docs(docs: List[Document]) -> List[None]:
    """
    Asynchronously ingest LangChain documents into MongoDB

    Args:
        docs (List[Document]): List of LangChain documents

    Returns:
        List[None]: Results of the task executions
    """
    semaphore = asyncio.Semaphore(MAX_CONCURRENCY)
    batches = get_batches(docs, BATCH_SIZE)

    tasks = []
    for batch in batches:
        tasks.append(process_batch(batch, semaphore))
    # Gather results from all tasks
    results = await asyncio.gather(*tasks)
    return results
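
The pattern used above (slice into batches, cap concurrency with a semaphore, gather the tasks) can be tried without a database. In this minimal sketch, `fake_ingest` is a hypothetical stand-in for `aadd_documents` that just sleeps, and the `stats` dictionary records how many batches were in flight at once.

```python
import asyncio
from typing import Dict, List, Tuple


async def fake_ingest(batch: List[int], semaphore: asyncio.Semaphore, stats: Dict[str, int]) -> int:
    """Pretend to ingest a batch while tracking concurrent executions."""
    async with semaphore:
        stats["active"] += 1
        stats["peak"] = max(stats["peak"], stats["active"])
        await asyncio.sleep(0.01)  # stand-in for the network call
        stats["active"] -= 1
        return len(batch)


async def ingest_all(items: List[int], batch_size: int, max_concurrency: int) -> Tuple[List[int], int]:
    """Ingest all items in batches; return batch sizes and peak concurrency."""
    semaphore = asyncio.Semaphore(max_concurrency)
    stats = {"active": 0, "peak": 0}
    batches = [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
    # gather returns results in the order the tasks were passed in
    sizes = await asyncio.gather(*(fake_ingest(b, semaphore, stats) for b in batches))
    return list(sizes), stats["peak"]


def run_demo() -> Tuple[List[int], int]:
    """Script entry point; inside a notebook, `await ingest_all(...)` directly instead."""
    return asyncio.run(ingest_all(list(range(500)), batch_size=256, max_concurrency=4))
```

With 500 items and a batch size of 256, this yields two batches of 256 and 244 items, which matches the batch sizes this notebook prints during ingestion. Tasks may finish in any order, but `asyncio.gather` still returns results in submission order.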

In [None]:
collection = mongodb_client[DB_NAME][COLLECTION_NAME]
# Delete any existing documents from the collection
collection.delete_many({})
print("Deletion complete.")
# Ingest LangChain documents into MongoDB
results = await process_docs(docs)

Deletion complete.
Processed 244 documents
Processed 256 documents


## Step 7: Create a vector search index

In [None]:
from pymongo.errors import OperationFailure
from pymongo.operations import SearchIndexModel

In [None]:
VS_INDEX_NAME = "vector_index"

In [None]:
# Vector search index definition
model = SearchIndexModel(
    definition={
        "fields": [
            {
                "type": "vector",
                "path": "embedding",
                "numDimensions": 1536,
                "similarity": "cosine",
            }
        ]
    },
    name=VS_INDEX_NAME,
    type="vectorSearch",
)
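
The index definition sets `numDimensions` to 1536, the default output size of `text-embedding-3-small`, and ranks matches by cosine similarity. As a rough illustration of what that measure computes, here is a pure-Python sketch; the function name is hypothetical, and the score Atlas reports may be normalized differently.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors of equal length."""
    if len(a) != len(b):
        # Mirrors why the index dimension must match the embedding size
        raise ValueError(f"Dimension mismatch: {len(a)} != {len(b)}")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("Cosine similarity is undefined for zero vectors")
    return dot / (norm_a * norm_b)


same_direction = cosine_similarity([1.0, 0.0], [2.0, 0.0])
orthogonal = cosine_similarity([1.0, 0.0], [0.0, 1.0])
```

Cosine similarity depends only on direction, so scaling a vector does not change its score against another vector.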

In [None]:
# Try to create the index; an OperationFailure is treated here as "index already exists"
try:
    collection.create_search_index(model=model)
    print(
        f"Successfully created index {VS_INDEX_NAME} for collection {COLLECTION_NAME}"
    )
except OperationFailure:
    print(
        f"Duplicate index {VS_INDEX_NAME} found for collection {COLLECTION_NAME}. Skipping index creation."
    )

Successfully created index vector_index for collection parent_doc
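
A newly created search index is built in the background, so queries issued right away may return no results. The sketch below polls until the index reports itself as queryable. It assumes the object passed in behaves like a PyMongo `Collection`, whose `list_search_indexes(name)` yields index documents with a `queryable` field; the helper name, timeout, and polling interval are hypothetical.

```python
import time
from typing import Any


def wait_until_queryable(
    collection: Any, index_name: str, timeout_s: float = 120.0, poll_s: float = 2.0
) -> bool:
    """Poll the collection's search indexes until `index_name` is queryable.

    Returns True once the index is queryable, or False if the timeout expires.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        for index in collection.list_search_indexes(index_name):
            if index.get("queryable"):
                return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_s)
```

In this notebook, it could be called as `wait_until_queryable(collection, VS_INDEX_NAME)` before running the queries in Step 8.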


## Step 8: Usage

### In a RAG application

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

In [None]:
# Retrieve and parse documents
retrieve = {
    "context": parent_doc_retriever
    | (lambda docs: "\n\n".join([d.page_content for d in docs])),
    "question": RunnablePassthrough(),
}
template = """Answer the question based only on the following context. If no context is provided, respond with I DON'T KNOW: \
{context}

Question: {question}
"""
# Define the chat prompt
prompt = ChatPromptTemplate.from_template(template)
# Define the model to be used for chat completion
llm = ChatOpenAI(temperature=0, model="gpt-4o-2024-11-20")
# Parse output as a string
parse_output = StrOutputParser()
# Naive RAG chain
rag_chain = retrieve | prompt | llm | parse_output

In [None]:
# Test the RAG chain
print(rag_chain.invoke("How do I improve slow queries in MongoDB?"))

To improve slow queries in MongoDB, you can use the following tools and best practices:

### Tools:
1. **Performance Advisor**:
   - Monitors slow queries and suggests new indexes to improve query performance.
   - Provides recommendations for index ranking and dropping unused indexes.

2. **Namespace Insights**:
   - Monitors collection-level query latency and provides query latency metrics and statistics.

3. **Query Profiler**:
   - Displays slow-running operations and their key performance statistics.
   - Allows exploration of a sample of historical queries for up to the last 24 hours.

4. **Real-Time Performance Panel (RTPP)**:
   - Identifies relevant database operations, evaluates query execution times, and shows the ratio of documents scanned to documents returned during query execution.

### Best Practices:
1. Create queries that are supported by your current indexes to reduce search time.
2. Avoid creating documents with large array fields that require extensive processing.
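
The `retrieve` step in the chain joins the page content of the retrieved parents with blank lines and places the result in the prompt. The sketch below reproduces that assembly in plain Python. The `max_chars` budget is a hypothetical addition, not part of the chain above: parent documents can be long, so capping the context size is one option to consider.

```python
from typing import List, Optional

PROMPT_TEMPLATE = (
    "Answer the question based only on the following context. "
    "If no context is provided, respond with I DON'T KNOW: "
    "{context}\n\nQuestion: {question}\n"
)


def build_context(passages: List[str], max_chars: Optional[int] = None) -> str:
    """Join passages with blank lines, stopping before max_chars is exceeded."""
    selected: List[str] = []
    used = 0
    for passage in passages:
        # Each passage after the first also costs a two-character separator
        extra = len(passage) + (2 if selected else 0)
        if max_chars is not None and used + extra > max_chars:
            break
        selected.append(passage)
        used += extra
    return "\n\n".join(selected)


def build_prompt(question: str, passages: List[str], max_chars: Optional[int] = None) -> str:
    """Fill the prompt template with the joined context and the question."""
    return PROMPT_TEMPLATE.format(
        context=build_context(passages, max_chars), question=question
    )
```

When no passages are retrieved, the context is an empty string, which is the case the prompt's I DON'T KNOW instruction is meant to cover.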


### In an AI agent

In [None]:
from typing import Annotated, Dict

from langchain_core.tools import tool
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from typing_extensions import TypedDict

In [None]:
# Converting the retriever into an agent tool
@tool
def get_info_about_mongodb(user_query: str) -> str:
    """
    Retrieve information about MongoDB.

    Args:
    user_query (str): The user's query string.

    Returns:
    str: The retrieved information formatted as a string.
    """
    docs = parent_doc_retriever.invoke(user_query)
    context = "\n\n".join([d.page_content for d in docs])
    return context

In [None]:
tools = [get_info_about_mongodb]

In [None]:
# Define the LLM to use as the brain of the agent
llm = ChatOpenAI(temperature=0, model="gpt-4o-2024-11-20")
# Agent prompt
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful AI assistant."
            " You are provided with tools to answer questions about MongoDB."
            " Think step-by-step and use these tools to get the information required to answer the user query."
            " Do not re-run tools unless absolutely necessary."
            " If you are not able to get enough information using the tools, reply with I DON'T KNOW."
            " You have access to the following tools: {tool_names}."
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)
# Partial the prompt with tool names
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
# Bind tools to LLM
llm_with_tools = prompt | llm.bind_tools(tools)

In [None]:
# Define graph state
class GraphState(TypedDict):
    messages: Annotated[list, add_messages]

In [None]:
def agent(state: GraphState) -> Dict[str, List]:
    """
    Agent node

    Args:
        state (GraphState): Graph state

    Returns:
        Dict[str, List]: Updates to the graph state
    """
    messages = state["messages"]
    response = llm_with_tools.invoke(messages)
    # We return a list, because this will get added to the existing list
    return {"messages": [response]}

In [None]:
# Convert tools into a graph node
tool_node = ToolNode(tools)

In [None]:
# Parameterize the graph with the state
graph = StateGraph(GraphState)
# Add graph nodes
graph.add_node("agent", agent)
graph.add_node("tools", tool_node)
# Add graph edges
graph.add_edge(START, "agent")
graph.add_edge("tools", "agent")
graph.add_conditional_edges(
    "agent",
    tools_condition,
    {"tools": "tools", END: END},
)
# Compile the graph
app = graph.compile()
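
The control flow this graph encodes is a loop: call the model, run any tools it requests, feed the results back, and stop when the model replies without tool calls. The sketch below mimics that loop in plain Python. It is illustrative only: `fake_model` is a hypothetical stub standing in for the LLM, messages are plain dictionaries rather than LangChain message objects, and `max_steps` is an added safety limit.

```python
from typing import Any, Callable, Dict, List

Message = Dict[str, Any]


def fake_model(messages: List[Message]) -> Message:
    """Stub model: request the tool once, then answer from the tool result."""
    tool_results = [m for m in messages if m["role"] == "tool"]
    if not tool_results:
        call = {"name": "lookup", "args": {"user_query": messages[0]["content"]}}
        return {"role": "assistant", "content": "", "tool_calls": [call]}
    answer = f"Based on the docs: {tool_results[-1]['content']}"
    return {"role": "assistant", "content": answer, "tool_calls": []}


def run_agent(
    question: str,
    model: Callable[[List[Message]], Message],
    tools: Dict[str, Callable[..., str]],
    max_steps: int = 5,
) -> List[Message]:
    """Alternate between the model and its tools until no tool is requested."""
    messages: List[Message] = [{"role": "user", "content": question}]
    for _ in range(max_steps):
        response = model(messages)  # corresponds to the "agent" node
        messages.append(response)
        if not response["tool_calls"]:  # corresponds to routing to END
            break
        for call in response["tool_calls"]:  # corresponds to the "tools" node
            result = tools[call["name"]](**call["args"])
            messages.append({"role": "tool", "content": result})
    return messages


demo_messages = run_agent(
    "How do I improve slow queries in MongoDB?",
    fake_model,
    {"lookup": lambda user_query: "create indexes that support your queries"},
)
```

The resulting message list has the same shape as the streamed output in the next cell: an assistant turn requesting the tool, the tool result, and a final assistant answer.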

In [None]:
# Execute the agent and view outputs
inputs = {
    "messages": [
        ("user", "How do I improve slow queries in MongoDB?"),
    ]
}

for output in app.stream(inputs):
    for key, value in output.items():
        print(f"Node {key}:")
        print(value)
print("---FINAL ANSWER---")
print(value["messages"][-1].content)

Node agent:
{'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_TXveSn3gsgMOJc3sHCuq1Xj2', 'function': {'arguments': '{"user_query":"How do I improve slow queries in MongoDB?"}', 'name': 'get_info_about_mongodb'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 28, 'prompt_tokens': 165, 'total_tokens': 193, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-11-20', 'system_fingerprint': 'fp_ec7eab8ec3', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-b041acdf-2c87-4977-888e-0a8ebd4b7600-0', tool_calls=[{'name': 'get_info_about_mongodb', 'args': {'user_query': 'How do I improve slow queries in MongoDB?'}, 'id': 'call_TXveSn3gsgMOJc3sHCuq1Xj2', 'type': 'tool_call'}], usage_metadata={'input_tokens': 165, 'output_tokens': 2