[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mongodb-developer/GenAI-Showcase/blob/main/notebooks/rag/langchain_parent_document_retrieval.ipynb)

[![View Article](https://img.shields.io/badge/View%20Article-blue)](https://www.mongodb.com/developer/products/atlas/advanced-rag-parent-doc-retrieval/?utm_campaign=devrel&utm_source=cross-post&utm_medium=organic_social&utm_content=https%3A%2F%2Fgithub.com%2Fmongodb-developer%2FGenAI-Showcase&utm_term=apoorva.joshi)

# Parent Document Retrieval Using MongoDB and LangChain

This notebook shows you how to implement parent document retrieval in your RAG application using MongoDB's LangChain integration.
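Parent document retrieval embeds and searches small child chunks, which match queries precisely, but hands the LLM the larger parent each matching chunk came from, so the model sees more surrounding context. The toy, in-memory sketch below illustrates only that idea. It is not how `MongoDBAtlasParentDocumentRetriever` is implemented. It scores child chunks by word overlap with the query, a stand-in for vector similarity assumed purely for illustration, and returns each parent at most once, in the order of its best-ranked child. `Child`, `overlap_score`, and `retrieve_parents` are hypothetical names.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Child:
    parent_id: str  # ID of the parent document this chunk was split from
    text: str


def overlap_score(query: str, text: str) -> int:
    """Toy relevance score: number of shared lowercase words (stand-in for vector similarity)."""
    return len(set(query.lower().split()) & set(text.lower().split()))


def retrieve_parents(
    query: str, children: list[Child], parents: dict[str, str], k: int = 10
) -> list[str]:
    """Rank child chunks, keep the top k, and return their unique parents in rank order."""
    ranked = sorted(children, key=lambda c: overlap_score(query, c.text), reverse=True)[:k]
    parent_ids: list[str] = []
    for child in ranked:
        if child.parent_id not in parent_ids:  # several children may share one parent
            parent_ids.append(child.parent_id)
    return [parents[pid] for pid in parent_ids]
```

Because several of the top `k` child chunks can come from the same parent, deduplication means you may get back fewer than `k` parents.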

## Step 1: Install required libraries

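- **datasets**: Python package to download datasets from Hugging Face

- **pymongo**: Python driver for MongoDB

- **langchain**: Python package for LangChain's core modules

- **ragas**: Python package for evaluating RAG pipelines

- **rapidfuzz**: Python package for fast fuzzy string matching

- **langchain-openai**: Python package to use OpenAI models via LangChain

- **langchain-mongodb**: LangChain's MongoDB integration, installed here from the `main` branch of its GitHub repository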

In [150]:
! pip install -qU datasets pymongo langchain ragas rapidfuzz langchain-openai 'git+https://github.com/langchain-ai/langchain-mongodb.git@main#subdirectory=libs/mongodb' 

## Step 2: Set up prerequisites

- **Set the MongoDB connection string**: Follow the steps [here](https://www.mongodb.com/docs/manual/reference/connection-string/) to get the connection string from the Atlas UI.

- **Set the OpenAI API key**: Steps to obtain an API key are [here](https://help.openai.com/en/articles/4936850-where-do-i-find-my-openai-api-key).

- **Set the Hugging Face token**: Steps to create a token are [here](https://huggingface.co/docs/hub/en/security-tokens#how-to-manage-user-access-tokens). You only need a **read** token for this tutorial.
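Prompting with `getpass` on every run can get tedious. One common alternative, sketched below under the assumption that you may export these values as environment variables before starting the notebook, is to read each secret from the environment and prompt only when it is missing. The `get_secret` helper is hypothetical and not part of the original notebook.

```python
import getpass
import os


def get_secret(env_var: str, prompt: str) -> str:
    """Return env_var's value if set; otherwise prompt for it (without echo) and store it."""
    value = os.environ.get(env_var)
    if not value:
        value = getpass.getpass(prompt)
        os.environ[env_var] = value  # store it so later cells and libraries can read it
    return value
```

For example, `MONGODB_URI = get_secret("MONGODB_URI", "Enter your MongoDB connection string:")`. Storing the OpenAI key and Hugging Face token under `OPENAI_API_KEY` and `HF_TOKEN` keeps them where `langchain-openai` and the Hugging Face libraries look by default.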

In [4]:
import os
import getpass
from openai import OpenAI
from pymongo import MongoClient

In [5]:
os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API Key:")

In [6]:
MONGODB_URI = getpass.getpass("Enter your MongoDB connection string:")
mongodb_client = MongoClient(
    MONGODB_URI, appname="devrel.showcase.parent_doc_retrieval"
)
mongodb_client.admin.command("ping")

{'ok': 1.0,
 '$clusterTime': {'clusterTime': Timestamp(1732218048, 1),
  'signature': {'hash': b"8\xbf>\xa0\xbc\xf1\x1c\x01\xcd!tr~\xd2\x15\xa5o'\xde\xfa",
   'keyId': 7390069253761662978}},
 'operationTime': Timestamp(1732218048, 1)}

In [7]:
os.environ["HF_TOKEN"] = getpass.getpass("Enter your HF Access Token:")

## Step 3: Load the dataset

In [3]:
from datasets import load_dataset
import pandas as pd


In [27]:
data = load_dataset("mongodb-eai/docs", split="train")
data_head = data.take(1000)
df = pd.DataFrame(data_head)

In [28]:
df.head()

| | updated | _id | metadata | action | sourceName | body | url | format | title |
|---|---|---|---|---|---|---|---|---|---|
| 0 | {'$date': '2024-05-20T17:30:49.148Z'} | {'$oid': '664b88c96e4f895074208162'} | {'contentType': None, 'pageDescription': None,... | created | snooty-cloud-docs | # View Database Access History\n\n- This featu... | https://mongodb.com/docs/atlas/access-tracking/ | md | View Database Access History |
| 1 | {'$date': '2024-05-20T17:30:49.148Z'} | {'$oid': '664b88c96e4f895074208178'} | {'contentType': None, 'pageDescription': None,... | created | snooty-cloud-docs | # Manage Organization Teams\n\nYou can create ... | https://mongodb.com/docs/atlas/access/manage-t... | md | Manage Organization Teams |
| 2 | {'$date': '2024-05-20T17:30:49.148Z'} | {'$oid': '664b88c96e4f895074208183'} | {'contentType': None, 'pageDescription': None,... | created | snooty-cloud-docs | # Manage Organizations\n\nIn the organizations... | https://mongodb.com/docs/atlas/access/orgs-cre... | md | Manage Organizations |
| 3 | {'$date': '2024-05-20T17:30:49.148Z'} | {'$oid': '664b88c96e4f89507420818f'} | {'contentType': None, 'pageDescription': None,... | created | snooty-cloud-docs | # Alert Basics\n\nAtlas provides built-in tool... | https://mongodb.com/docs/atlas/alert-basics/ | md | Alert Basics |
| 4 | {'$date': '2024-05-20T17:30:49.148Z'} | {'$oid': '664b88c96e4f89507420819d'} | {'contentType': None, 'pageDescription': None,... | created | snooty-cloud-docs | # Resolve Alerts\n\nAtlas issues alerts for th... | https://mongodb.com/docs/atlas/alert-resolutions/ | md | Resolve Alerts |


## Step 4: Convert dataset to LangChain Documents

In [29]:
from langchain_core.documents import Document

In [30]:
docs = []
metadata_fields = ["updated", "url", "title"]
for _, row in df.iterrows():
    content = row["body"]
    metadata = row["metadata"]
    for field in metadata_fields:
        metadata[field] = row[field]
    docs.append(Document(page_content=content, metadata=metadata))
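The loop above merges three top-level columns (`updated`, `url`, `title`) into each row's `metadata` dict. Because it writes into the dict object stored in the DataFrame, it also modifies `df` in place. If you'd rather leave `df` untouched, a minimal alternative is to build a fresh dict per row. The sketch below shows that merge on plain dictionaries; `build_metadata` is a hypothetical helper, and it works the same way on a pandas row because both support `row["column"]` lookups.

```python
from __future__ import annotations


def build_metadata(row, fields: list[str]) -> dict:
    """Return a new dict: the row's existing metadata plus the selected top-level fields."""
    return {**row["metadata"], **{field: row[field] for field in fields}}
```

Inside the loop, this would look like `Document(page_content=row["body"], metadata=build_metadata(row, metadata_fields))`.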

In [31]:
docs[0]

Document(metadata={'contentType': None, 'pageDescription': None, 'productName': 'MongoDB Atlas', 'tags': ['atlas', 'docs'], 'version': None, 'updated': {'$date': '2024-05-20T17:30:49.148Z'}, 'url': 'https://mongodb.com/docs/atlas/access-tracking/', 'title': 'View Database Access History'}, page_content='# View Database Access History\n\n- This feature is not available for `M0` free clusters, `M2`, and `M5` clusters. To learn more, see Atlas M0 (Free Cluster), M2, and M5 Limits.\n\n- This feature is not supported on Serverless instances at this time. To learn more, see Serverless Instance Limitations.\n\n## Overview\n\nAtlas parses the MongoDB database logs to collect a list of authentication requests made against your clusters through the following methods:\n\n- `mongosh`\n\n- Compass\n\n- Drivers\n\nAuthentication requests made with API Keys through the Atlas Administration API are not logged.\n\nAtlas logs the following information for each authentication request within the last 7 da

In [32]:
len(docs)

1000

## Step 5: Instantiate the retrievers

In [104]:
from langchain_mongodb.retrievers.parent_document import (
    MongoDBAtlasParentDocumentRetriever,
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings

In [105]:
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")

In [None]:
DB_NAME = "langchain"
PARENT_DOC_COLLECTION = "parent_doc"
PARENT_CHUNK_COLLECTION = "parent_chunk"

In [106]:
def get_splitter(chunk_size):
    # Token-based splitter (cl100k_base); consecutive chunks overlap by ~15% of chunk_size
    return RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        encoding_name="cl100k_base",
        chunk_size=chunk_size,
        chunk_overlap=int(0.15 * chunk_size),
    )
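With `chunk_overlap` set to 15% of `chunk_size`, a 200-token child chunk shares about 30 tokens with its neighbor, and an 800-token parent chunk shares about 120. The toy function below illustrates that arithmetic with a fixed-size sliding window over a list of tokens: each new window starts `chunk_size - overlap` tokens after the previous one. This is a simplification. `RecursiveCharacterTextSplitter` prefers to break on separators such as blank lines and newlines, so its real chunk boundaries and overlaps vary. `window_chunks` is a hypothetical helper.

```python
def window_chunks(tokens: list, chunk_size: int, overlap_ratio: float = 0.15) -> list:
    """Split tokens into fixed-size windows; consecutive windows share ~overlap_ratio * chunk_size tokens."""
    overlap = round(overlap_ratio * chunk_size)
    step = chunk_size - overlap  # distance between the starts of consecutive windows
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start : start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # this window already reaches the end of the tokens
    return chunks
```

For 500 tokens and `chunk_size=200`, the step is 170, so the windows start at tokens 0, 170, and 340.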

In [132]:
def get_retriever(collection, child_chunk_size, **kwargs):
    # Forward extra keyword arguments (such as parent_splitter) to the retriever;
    # without **kwargs here, parent_splitter would be silently ignored
    return MongoDBAtlasParentDocumentRetriever.from_connection_string(
        connection_string=MONGODB_URI,
        embedding_model=embedding_model,
        child_splitter=get_splitter(child_chunk_size),
        database_name=DB_NAME,
        collection_name=collection,
        text_key="page_content",
        search_type="similarity",
        search_kwargs={"k": 10},
        **kwargs,
    )

In [133]:
parent_doc_retriever = get_retriever(
    collection=PARENT_DOC_COLLECTION, child_chunk_size=200
)

In [134]:
parent_chunk_retriever = get_retriever(
    collection=PARENT_CHUNK_COLLECTION,
    child_chunk_size=200,
    parent_splitter=get_splitter(800),
)
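The two retrievers differ in what they treat as the parent. `parent_doc_retriever` has no `parent_splitter`, so each full document is a parent and is split directly into ~200-token children. `parent_chunk_retriever` is meant to first split each document into ~800-token parent chunks and then split each of those into children. The toy sketch below mirrors that two-level structure on a list of words, tagging each child with the ID of its parent so the parent can be looked up after search. It splits on fixed word counts without overlap purely for illustration; `split_words`, `two_level_split`, and the `"<doc_id>-<n>"` parent ID scheme are hypothetical.

```python
from __future__ import annotations


def split_words(words: list[str], size: int) -> list[list[str]]:
    """Split a list of words into consecutive pieces of at most `size` words (no overlap)."""
    return [words[i : i + size] for i in range(0, len(words), size)]


def two_level_split(doc_id: str, text: str, parent_size: int | None, child_size: int):
    """Return (parents, children): parents maps parent_id -> text; children are (parent_id, text) pairs.

    With parent_size=None the whole document is its own single parent, as in parent_doc_retriever.
    """
    words = text.split()
    pieces = [words] if parent_size is None else split_words(words, parent_size)
    parents, children = {}, []
    for i, piece in enumerate(pieces):
        parent_id = f"{doc_id}-{i}"
        parents[parent_id] = " ".join(piece)
        for child in split_words(piece, child_size):
            children.append((parent_id, " ".join(child)))
    return parents, children
```

Smaller parents give the LLM less (but more focused) context per hit, which is the trade-off the two collections let you compare.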

## Step 6: Ingest documents into MongoDB

In [113]:
import asyncio

In [114]:
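CHUNK_SIZE = 100  # Number of documents per ingestion batch (not a text chunk size)
MAX_CONCURRENCY = 4  # Maximum number of batches ingested at the same time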

In [115]:
async def process_chunk(chunk, semaphore, retriever):
    async with semaphore:
        await retriever.aadd_documents(chunk)
        print(f"Processed {len(chunk)} documents")

In [116]:
def get_chunks(docs, chunk_size):
    for i in range(0, len(docs), chunk_size):
        yield docs[i : i + chunk_size]

In [117]:
async def process_docs(docs, retriever):
    semaphore = asyncio.Semaphore(MAX_CONCURRENCY)
    chunks = get_chunks(docs, CHUNK_SIZE)

    tasks = []
    for chunk in chunks:
        tasks.append(process_chunk(chunk, semaphore, retriever))

    # Gather all tasks and get results
    results = await asyncio.gather(*tasks)
    return results

In [74]:
def get_collection(coll_name):
    return mongodb_client[DB_NAME][coll_name]

In [71]:
get_collection(PARENT_DOC_COLLECTION).delete_many({})
results = await process_docs(docs, parent_doc_retriever)

Processed 100 documents
Processed 100 documents
Processed 100 documents
Processed 100 documents
Processed 100 documents
Processed 100 documents
Processed 100 documents
Processed 100 documents
Processed 100 documents
Processed 100 documents


In [61]:
get_collection(PARENT_CHUNK_COLLECTION).delete_many({})
results = await process_docs(docs, parent_chunk_retriever)

Processed 100 documents
Processed 100 documents
Processed 100 documents
Processed 100 documents
Processed 100 documents
Processed 100 documents
Processed 100 documents
Processed 100 documents
Processed 100 documents
Processed 100 documents


## Step 7: Create vector search indexes

In [75]:
from pymongo.operations import SearchIndexModel
from pymongo.errors import OperationFailure

In [73]:
model = SearchIndexModel(
    definition={
        "fields": [
            {
                "type": "vector",
                "path": "embedding",
                "numDimensions": 1536,
                "similarity": "cosine",
            }
        ]
    },
    name="vector_index",
    type="vectorSearch",
)
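The index's `numDimensions` must equal the length of the vectors stored in the `embedding` field. `text-embedding-3-small` returns 1536-dimensional vectors by default, which is why the definition uses 1536. `"similarity": "cosine"` ranks results by the cosine of the angle between the query vector and each stored vector. The sketch below computes that quantity in plain Python and adds a hypothetical `check_dimensions` helper you could run on a sample embedding before creating the index. Neither function is part of pymongo or LangChain, and `check_dimensions` assumes the vector field is the first entry in `fields`, as it is here.

```python
from __future__ import annotations

import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between vectors a and b: dot(a, b) / (|a| * |b|)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def check_dimensions(index_definition: dict, sample_embedding: list[float]) -> None:
    """Raise ValueError if the sample vector's length differs from the index's numDimensions."""
    expected = index_definition["fields"][0]["numDimensions"]
    if len(sample_embedding) != expected:
        raise ValueError(f"Index expects {expected} dimensions, got {len(sample_embedding)}")
```

For example, pass the same `definition` dict you give to `SearchIndexModel` together with `embedding_model.embed_query("test")`.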

In [76]:
def create_vs_index(coll_name):
    try:
        get_collection(coll_name).create_search_index(model=model)
        print(f"Successfully created index for collection {coll_name}.")
    except OperationFailure as e:
        # Usually raised because an index named "vector_index" already exists
        print(
            f"Could not create index for collection {coll_name} ({e}). Skipping index creation."
        )

In [77]:
create_vs_index(PARENT_DOC_COLLECTION)

Successfully created index for collection parent_doc.


In [78]:
create_vs_index(PARENT_CHUNK_COLLECTION)

Successfully created index for collection parent_chunk.
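Atlas builds search indexes asynchronously, so a freshly created index may not be queryable for a short while, and vector searches issued before then can return no results. A minimal polling sketch follows, assuming the `queryable` field that `Collection.list_search_indexes()` reports for each index. `wait_until_queryable`, its timeout, and its polling interval are hypothetical choices, not part of the original notebook.

```python
import time


def wait_until_queryable(collection, index_name: str, timeout: float = 300, poll_interval: float = 5) -> bool:
    """Poll list_search_indexes until the named index reports queryable=True or the timeout expires."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        indexes = list(collection.list_search_indexes(index_name))
        if indexes and indexes[0].get("queryable"):
            return True
        time.sleep(poll_interval)
    return False
```

For example, `wait_until_queryable(get_collection(PARENT_DOC_COLLECTION), "vector_index")` returns `True` once the index can serve queries, or `False` if it is still building after five minutes.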


## Step 8: Create a RAG chain

In [139]:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

In [142]:
def get_rag_chain(retriever):
    retrieve = {
        "context": retriever
        | (lambda docs: "\n\n".join([d.page_content for d in docs])),
        "question": RunnablePassthrough(),
    }
    template = """Answer the question based only on the following context. If no context is provided, respond with I DON't KNOW: \
    {context}

    Question: {question}
    """
    # Defining the chat prompt
    prompt = ChatPromptTemplate.from_template(template)
    # Defining the model to be used for chat completion
    llm = ChatOpenAI(temperature=0, model="gpt-4o-2024-11-20")
    # Parse output as a string
    parse_output = StrOutputParser()

    # Naive RAG chain
    rag_chain = retrieve | prompt | llm | parse_output
    return rag_chain
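In `get_rag_chain`, the `retrieve` step pipes the retriever's output through a lambda that joins each document's `page_content` with blank lines, while `RunnablePassthrough()` forwards the question unchanged; the prompt then fills `{context}` and `{question}`. The stdlib-only sketch below reproduces that formatting on plain strings so you can see the text the model receives. `format_context`, `build_prompt`, and the flattened `TEMPLATE` (written without the original's leading indentation and line continuation) are illustrative stand-ins, not LangChain APIs.

```python
from __future__ import annotations


def format_context(page_contents: list[str]) -> str:
    """Join retrieved documents' text with blank lines, like the lambda in get_rag_chain."""
    return "\n\n".join(page_contents)


TEMPLATE = """Answer the question based only on the following context. If no context is provided, respond with I DON'T KNOW:
{context}

Question: {question}
"""


def build_prompt(page_contents: list[str], question: str) -> str:
    """Fill the template the same way the prompt step fills {context} and {question}."""
    return TEMPLATE.format(context=format_context(page_contents), question=question)
```

With a live cluster and queryable indexes, you would run the real chain with something like `get_rag_chain(parent_doc_retriever).invoke("How do I resolve alerts in Atlas?")`, since LangChain runnables expose `invoke`.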