Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from typing import List

from langchain_core.caches import BaseCache as BaseCache
from langchain_core.callbacks import Callbacks as Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from langchain_mongodb.retrievers.parent_document import (
MongoDBAtlasParentDocumentRetriever,
)

from langchain_mongodb.retrievers.self_querying import MongoDBAtlasSelfQueryRetriever

__all__ = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import uuid
from typing import Any, List, Optional

import pymongo
from langchain.retrievers.parent_document_retriever import ParentDocumentRetriever
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
Expand All @@ -20,10 +21,10 @@
from langchain_mongodb.utils import DRIVER_METADATA, make_serializable


class MongoDBAtlasParentDocumentRetriever(ParentDocumentRetriever):
class MongoDBAtlasParentDocumentRetriever(BaseRetriever):
"""MongoDB Atlas's ParentDocumentRetriever

Parent Document Retrieval is a common approach to enhance the performance of
"Parent Document Retrieval" is a common approach to enhance the performance of
retrieval methods in RAG by providing the LLM with a broader context to consider.
In essence, we divide the original documents into relatively small chunks,
embed each one, and store them in a vector database.
Expand All @@ -38,25 +39,25 @@ class MongoDBAtlasParentDocumentRetriever(ParentDocumentRetriever):
and the docstore :class:`~langchain_mongodb.docstores.MongoDBDocStore`
by the same MongoDB Collection.

For more details, see superclasses
:class:`~langchain.retrievers.parent_document_retriever.ParentDocumentRetriever`
and :class:`~langchain.retrievers.MultiVectorRetriever`.
This retriever extends :class:`~langchain_core.retrievers.BaseRetriever` and
implements the parent document retrieval pattern without requiring the legacy
langchain-classic package, making it compatible with LangChain 1.0+.

Examples:
>>> from langchain_mongodb.retrievers.parent_document import (
>>> ParentDocumentRetriever
>>> MongoDBAtlasParentDocumentRetriever
>>> )
>>> from langchain_text_splitters import RecursiveCharacterTextSplitter
>>> from langchain_openai import OpenAIEmbeddings
>>>
>>> retriever = ParentDocumentRetriever.from_connection_string(
>>> retriever = MongoDBAtlasParentDocumentRetriever.from_connection_string(
>>> "mongodb+srv://<user>:<clustername>.mongodb.net",
>>> OpenAIEmbeddings(model="text-embedding-3-large"),
>>> RecursiveCharacterTextSplitter(chunk_size=400),
>>> "example_database"
>>> )
retriever.add_documents([Document(..., technical_report_pages)
>>> resp = retriever.invoke("Langchain MongDB Partnership Ecosystem")
>>> retriever.add_documents([Document(...), ...]) # Parent documents
>>> resp = retriever.invoke("Langchain MongoDB Partnership Ecosystem")
>>> print(resp)
[Document(...), ...]

Expand All @@ -65,17 +66,22 @@ class MongoDBAtlasParentDocumentRetriever(ParentDocumentRetriever):
vectorstore: MongoDBAtlasVectorSearch
"""Vectorstore API to add, embed, and search through child documents"""

child_splitter: TextSplitter

docstore: MongoDBDocStore
"""Provides an API around the Collection to add the parent documents"""

id_key: str = "doc_id"
"""Key stored in metadata pointing to parent document"""

search_kwargs: dict = {}
"""Additional search parameters for vector search"""

def _get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
) -> List[Document]:
query_vector = self.vectorstore._embedding.embed_query(query)

Expand Down Expand Up @@ -120,10 +126,10 @@ def _get_relevant_documents(
return docs

async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> List[Document]:
"""Asynchronous version of get_relevant_documents"""

Expand All @@ -136,14 +142,14 @@ async def _aget_relevant_documents(

@classmethod
def from_connection_string(
cls,
connection_string: str,
embedding_model: Embeddings,
child_splitter: TextSplitter,
database_name: str,
collection_name: str = "document_with_chunks",
id_key: str = "doc_id",
**kwargs: Any,
cls,
connection_string: str,
embedding_model: Embeddings,
child_splitter: TextSplitter,
database_name: str,
collection_name: str = "document_with_chunks",
id_key: str = "doc_id",
**kwargs: Any,
) -> MongoDBAtlasParentDocumentRetriever:
"""Construct Retriever using one Collection for VectorStore and one for DocStore

Expand All @@ -158,7 +164,7 @@ def from_connection_string(
If parent_splitter is given, the documents will have already been split.
database_name: Name of database to connect to. Created if it does not exist.
collection_name: Name of collection to use.
It includes parent documents, sub-documents and their embeddings.
It includes parent documents, sub-documents, and their embeddings.
id_key: Key used to identify parent documents.
**kwargs: Additional keyword arguments. See parent classes for more.

Expand All @@ -184,6 +190,44 @@ def from_connection_string(
**kwargs,
)

def add_documents(
    self,
    documents: List[Document],
    ids: Optional[List[str]] = None,
    add_to_docstore: bool = True,
) -> None:
    """Add parent documents to the docstore and their chunks to the vectorstore.

    Each parent document is split with ``child_splitter``; every resulting
    chunk is tagged via ``metadata[self.id_key]`` with the id of its parent
    document so that retrieval can map matching chunks back to their parents.

    Args:
        documents: Parent documents to add.
        ids: Optional ids for the parent documents. If provided, must be the
            same length as ``documents``; otherwise UUID4 strings are
            generated.
        add_to_docstore: Whether to also store the parent documents in the
            docstore. Pass False only when the parents are already stored.

    Raises:
        ValueError: If ``ids`` is provided and its length does not match
            ``documents``.
    """
    if ids is None:
        doc_ids = [str(uuid.uuid4()) for _ in documents]
    elif len(ids) != len(documents):
        # Guard against silent misalignment: without this check a short
        # `ids` list would raise IndexError after some chunks were already
        # written, and a long one would silently ignore trailing ids.
        raise ValueError(
            "Got uneven list of documents and ids. "
            "If `ids` is provided, it should be the same length as `documents`."
        )
    else:
        doc_ids = ids

    # Split every parent into child chunks, tagging each chunk with its
    # parent's id so the retriever can later recover the full document.
    sub_docs = []
    for _id, doc in zip(doc_ids, documents):
        _sub_docs = self.child_splitter.split_documents([doc])
        for _doc in _sub_docs:
            _doc.metadata[self.id_key] = _id
        sub_docs.extend(_sub_docs)

    # Single bulk insert of all chunks; embedding happens inside the store.
    self.vectorstore.add_documents(sub_docs)

    if add_to_docstore:
        # mset expects a sequence of (key, Document) tuples.
        self.docstore.mset(list(zip(doc_ids, documents)))

def close(self) -> None:
    """Close the resources used by the MongoDBAtlasParentDocumentRetriever."""
    # Only the vectorstore is closed. Per the class docstring the docstore is
    # backed by the same MongoDB collection, so it presumably shares the
    # underlying client — confirm no separate docstore.close() is required.
    self.vectorstore.close()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict, Sequence, Tuple, Union

from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever

from langchain_classic.retrievers.self_query.base import SelfQueryRetriever
from langchain_core.language_models import BaseLanguageModel
from langchain_core.runnables import Runnable
from langchain_core.structured_query import (
Expand All @@ -16,6 +16,7 @@
from pydantic import Field

from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_classic.chains.query_constructor.schema import AttributeInfo


class MongoDBStructuredQueryTranslator(Visitor):
Expand Down
27 changes: 15 additions & 12 deletions libs/langchain-mongodb/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,25 @@ build-backend = "hatchling.build"

[project]
name = "langchain-mongodb"
version = "0.7.2"
version = "0.8.0"
description = "An integration package connecting MongoDB and LangChain"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"langchain-core>=0.3,<1.0",
"langchain>=0.3,<1.0",
"langchain-core>=1.0.0,<1.1.0",
"langchain>=1.0.0,<1.1.0",
"pymongo>=4.6.1",
"langchain-text-splitters>=0.3,<1.0",
"langchain-text-splitters>=1.0.0,<1.1.0",
"numpy>=1.26",
"lark<2.0.0,>=1.1.9",
]

[dependency-groups]
dev = [
"freezegun>=1.2.2",
"langchain>=0.3.14,<1.0",
"langchain-core>=0.3.29,<1.0",
"langchain-text-splitters>=0.3.5,<1.0",
"langchain>=1.0.0,<1.1.0",
"langchain-core>=1.0.0,<1.1.0",
"langchain-text-splitters>=1.0.0,<1.1.0",
"pytest-mock>=3.10.0",
"pytest>=7.3.0",
"syrupy>=4.0.2",
Expand All @@ -33,13 +33,13 @@ dev = [
"pre-commit>=4.0",
"mypy>=1.10",
"simsimd>=6.5.0",
"langchain-ollama>=0.2.2,<1.0",
"langchain-openai>=0.2.14,<1.0",
"langchain-community>=0.3.27,<1.0",
"langchain-ollama>=1.0.0,<1.1.0",
"langchain-openai>=1.0.0,<1.1.0",
"langchain-community>=1.0.0a1,<1.1.0",
"pypdf>=5.0.1",
"langgraph>=0.2.72",
"langgraph>=1.0.0,<1.1.0",
"flaky>=3.8.1",
"langchain-tests==0.3.22,<1.0",
"langchain-tests>=1.0.0,<1.1.0",
"pip>=25.0.1",
"typing-extensions>=4.12.2",
]
Expand All @@ -60,6 +60,8 @@ filterwarnings = [
"error",
# Ignore ResourceWarning raised by langchain standardized base classes.
"ignore:unclosed <socket.socket:ResourceWarning",
# Ignore LangGraph deprecation warnings from internal imports
"ignore::langgraph.warnings.LangGraphDeprecatedSinceV10",
]

[tool.mypy]
Expand All @@ -81,3 +83,4 @@ lint.ignore = ["E501", "B008", "UP007", "UP006", "UP035", "UP045"]

[tool.coverage.run]
omit = ["tests/*"]

Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import requests
from flaky import flaky # type:ignore[import-untyped]
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langgraph.prebuilt import create_react_agent
from langchain.agents import create_agent
from pymongo import MongoClient

from langchain_mongodb.agent_toolkit import (
Expand Down Expand Up @@ -64,7 +64,9 @@ def test_toolkit_response(db):
prompt = MONGODB_AGENT_SYSTEM_PROMPT.format(top_k=5)

test_query = "Which country's customers spent the most?"
agent = create_react_agent(llm, toolkit.get_tools(), prompt=prompt)
agent = create_agent(llm,
tools=toolkit.get_tools(),
system_prompt=prompt)
agent.step_timeout = 60
events = agent.stream(
{"messages": [("user", test_query)]},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import warnings

from langchain.memory import ConversationBufferMemory # type: ignore[import-not-found]
from langchain_classic.memory import ConversationBufferMemory # type: ignore[import-not-found]
from langchain_core.messages import message_to_dict

from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Generator, Sequence, Union

import pytest
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain_classic.chains.query_constructor.schema import AttributeInfo
from langchain_classic.retrievers.self_query.base import SelfQueryRetriever
from langchain_core.documents import Document
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langchain_openai.chat_models.base import BaseChatOpenAI
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import mongomock
import pytest
from langchain.memory import ConversationBufferMemory # type: ignore[import-not-found]
from langchain_classic.memory import ConversationBufferMemory # type: ignore[import-not-found]
from langchain_core.messages import message_to_dict
from pytest_mock import MockerFixture

Expand Down
Loading