From 40ee1acaf4fb4140a2021ff5871aa2b8c864fc27 Mon Sep 17 00:00:00 2001 From: Jib Date: Wed, 21 Feb 2024 14:53:00 -0500 Subject: [PATCH] updated the README.md to include usage instructions; 'beefed up' unit tests with mocks --- libs/partners/mongodb-atlas/README.md | 39 +++ .../langchain_mongodb_atlas/vectorstores.py | 1 - .../tests/unit_tests/test_vectorstores.py | 240 +++++++++++++++++- 3 files changed, 265 insertions(+), 15 deletions(-) diff --git a/libs/partners/mongodb-atlas/README.md b/libs/partners/mongodb-atlas/README.md index 1090f6f0b4bd5c..676ceb4dd00e3b 100644 --- a/libs/partners/mongodb-atlas/README.md +++ b/libs/partners/mongodb-atlas/README.md @@ -1 +1,40 @@ # langchain-mongodb-atlas + +# Installation +``` +pip install -U langchain-mongodb-atlas +``` + +# Usage +- See [integrations doc](../../../docs/docs/integrations/vectorstores/mongodb_atlas.ipynb) for more in-depth usage instructions. +- See [Getting Started with the LangChain Integration](https://www.mongodb.com/docs/atlas/atlas-vector-search/ai-integrations/langchain/#get-started-with-the-langchain-integration) for a walkthrough on using your first LangChain implementation with MongoDB Atlas. + +## Using MongoDBAtlasVectorSearch +```python +from langchain_community.vectorstores import MongoDBAtlasVectorSearch + +# Pull MongoDB Atlas URI from environment variables +MONGODB_ATLAS_CLUSTER_URI = os.environ.get("MONGODB_ATLAS_CLUSTER_URI") + +DB_NAME = "langchain_db" +COLLECTION_NAME = "test" +ATLAS_VECTOR_SEARCH_INDEX_NAME = "index_name" +MONGODB_COLLECTION = client[DB_NAME][COLLECTION_NAME] + +# Create the vector search via `from_connection_string` +vector_search = MongoDBAtlasVectorSearch.from_connection_string( + MONGODB_ATLAS_CLUSTER_URI, + DB_NAME + "."
+ COLLECTION_NAME, + OpenAIEmbeddings(disallowed_special=()), + index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME, +) + +# initialize MongoDB python client +client = MongoClient(MONGODB_ATLAS_CLUSTER_URI) +# Create the vector search via instantiation +vector_search_2 = MongoDBAtlasVectorSearch( + collection=MONGODB_COLLECTION, + embeddings=OpenAIEmbeddings(disallowed_special=()), + index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME, +) +``` \ No newline at end of file diff --git a/libs/partners/mongodb-atlas/langchain_mongodb_atlas/vectorstores.py b/libs/partners/mongodb-atlas/langchain_mongodb_atlas/vectorstores.py index 601b319ea5518a..cab2e849e70e62 100644 --- a/libs/partners/mongodb-atlas/langchain_mongodb_atlas/vectorstores.py +++ b/libs/partners/mongodb-atlas/langchain_mongodb_atlas/vectorstores.py @@ -215,7 +215,6 @@ def _similarity_search_with_score( for res in cursor: text = res.pop(self._text_key) score = res.pop("score") - del res[self._embedding_key] docs.append((Document(page_content=text, metadata=res), score)) return docs diff --git a/libs/partners/mongodb-atlas/tests/unit_tests/test_vectorstores.py b/libs/partners/mongodb-atlas/tests/unit_tests/test_vectorstores.py index 1f35855b9d1787..05ca5a1a050429 100644 --- a/libs/partners/mongodb-atlas/tests/unit_tests/test_vectorstores.py +++ b/libs/partners/mongodb-atlas/tests/unit_tests/test_vectorstores.py @@ -1,38 +1,104 @@ -import os +import uuid +from copy import deepcopy +from typing import Any, List, Optional import pytest +from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from pymongo import MongoClient from pymongo.collection import Collection +from pymongo.results import DeleteResult, InsertManyResult from langchain_mongodb_atlas.vectorstores import MongoDBAtlasVectorSearch INDEX_NAME = "langchain-test-index" NAMESPACE = "langchain_test_db.langchain_test_collection" -CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI") DB_NAME, COLLECTION_NAME = NAMESPACE.split(".") 
-def get_collection() -> Collection: - test_client: MongoClient = MongoClient(CONNECTION_STRING) - return test_client[DB_NAME][COLLECTION_NAME] +class ConsistentFakeEmbeddings(Embeddings): + """Fake embeddings functionality for testing.""" + + def __init__(self, dimensionality: int = 10) -> None: + self.known_texts: List[str] = [] + self.dimensionality = dimensionality + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return consistent embeddings for each text seen so far.""" + out_vectors = [] + for text in texts: + if text not in self.known_texts: + self.known_texts.append(text) + vector = [float(1.0)] * (self.dimensionality - 1) + [ + float(self.known_texts.index(text)) + ] + out_vectors.append(vector) + return out_vectors + + def embed_query(self, text: str) -> List[float]: + """Return consistent embeddings for the text, if seen before, or a constant + one if the text is unknown.""" + return self.embed_documents([text])[0] + + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + return self.embed_documents(texts) + + async def aembed_query(self, text: str) -> List[float]: + return self.embed_query(text) + + +class MockCollection(Collection): + """Mocked Mongo Collection""" + + _aggregate_result: list[Any] + _insert_result: Optional[InsertManyResult] + _data: list[Any] + + def __init__(self) -> None: + self._data = [] + self._aggregate_result = [] + self._insert_result = None + + def delete_many(self, *args, **kwargs) -> DeleteResult: # type: ignore + old_len = len(self._data) + self._data = [] + return DeleteResult({"n": old_len}, acknowledged=True) + + def insert_many(self, to_insert: list[Any], *args, **kwargs) -> InsertManyResult: # type: ignore + mongodb_inserts = [ + {"_id": str(uuid.uuid4()), "score": 1, **insert} for insert in to_insert + ] + self._data.extend(mongodb_inserts) + return self._insert_result or InsertManyResult( + [k["_id"] for k in mongodb_inserts], acknowledged=True + ) + + def 
aggregate(self, *args, **kwargs) -> list[Any]: # type: ignore + return deepcopy(self._aggregate_result) + + def count_documents(self, *args, **kwargs) -> int: # type: ignore + return len(self._data) + + def __repr__(self) -> str: + return "FakeCollection" + + +def get_collection() -> MockCollection: + return MockCollection() @pytest.fixture() -def collection() -> Collection: +def collection() -> MockCollection: return get_collection() -def test_initialization(collection: Collection, embedding_openai: Embeddings) -> None: - """Test initialization of vector store class""" - assert MongoDBAtlasVectorSearch(collection, embedding_openai) +@pytest.fixture() +def embedding_openai() -> Embeddings: + return ConsistentFakeEmbeddings() -def test_init_from_connection_string(embedding_openai: Embeddings) -> None: +def test_initialization(collection: Collection, embedding_openai: Embeddings) -> None: """Test initialization of vector store class""" - assert MongoDBAtlasVectorSearch.from_connection_string( - CONNECTION_STRING, NAMESPACE, embedding_openai - ) + assert MongoDBAtlasVectorSearch(collection, embedding_openai) def test_init_from_texts(collection: Collection, embedding_openai: Embeddings) -> None: @@ -40,3 +106,149 @@ def test_init_from_texts(collection: Collection, embedding_openai: Embeddings) - assert MongoDBAtlasVectorSearch.from_texts( [], embedding_openai, collection=collection ) + + +class TestMongoDBAtlasVectorSearch: + @classmethod + def setup_class(cls) -> None: + # ensure the test collection is empty + collection = get_collection() + assert collection.count_documents({}) == 0 # type: ignore[index] # noqa: E501 + + @classmethod + def teardown_class(cls) -> None: + collection = get_collection() + # delete all the documents in the collection + collection.delete_many({}) # type: ignore[index] + + @pytest.fixture(autouse=True) + def setup(self) -> None: + collection = get_collection() + # delete all the documents in the collection + collection.delete_many({}) # 
type: ignore[index] + + def _validate_search( + self, + vectorstore: MongoDBAtlasVectorSearch, + collection: MockCollection, + search_term: str = "sandwich", + page_content: str = "What is a sandwich?", + metadata: Optional[Any] = 1, + ) -> None: + collection._aggregate_result = list( + filter( + lambda x: search_term.lower() in x[vectorstore._text_key].lower(), + collection._data, + ) + ) + output = vectorstore.similarity_search("", k=1) + assert output[0].page_content == page_content + assert output[0].metadata.get("c") == metadata + + def test_from_documents( + self, embedding_openai: Embeddings, collection: MockCollection + ) -> None: + """Test end to end construction and search.""" + documents = [ + Document(page_content="Dogs are tough.", metadata={"a": 1}), + Document(page_content="Cats have fluff.", metadata={"b": 1}), + Document(page_content="What is a sandwich?", metadata={"c": 1}), + Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}), + ] + vectorstore = MongoDBAtlasVectorSearch.from_documents( + documents, + embedding_openai, + collection=collection, + index_name=INDEX_NAME, + ) + self._validate_search( + vectorstore, collection, metadata=documents[2].metadata["c"] + ) + + def test_from_texts( + self, embedding_openai: Embeddings, collection: MockCollection + ) -> None: + texts = [ + "Dogs are tough.", + "Cats have fluff.", + "What is a sandwich?", + "That fence is purple.", + ] + vectorstore = MongoDBAtlasVectorSearch.from_texts( + texts, + embedding_openai, + collection=collection, + index_name=INDEX_NAME, + ) + self._validate_search(vectorstore, collection, metadata=None) + + def test_from_texts_with_metadatas( + self, embedding_openai: Embeddings, collection: MockCollection + ) -> None: + texts = [ + "Dogs are tough.", + "Cats have fluff.", + "What is a sandwich?", + "The fence is purple.", + ] + metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}] + vectorstore = MongoDBAtlasVectorSearch.from_texts( + texts, + 
embedding_openai, + metadatas=metadatas, + collection=collection, + index_name=INDEX_NAME, + ) + self._validate_search(vectorstore, collection, metadata=metadatas[2]["c"]) + + def test_from_texts_with_metadatas_and_pre_filter( + self, embedding_openai: Embeddings, collection: MockCollection + ) -> None: + texts = [ + "Dogs are tough.", + "Cats have fluff.", + "What is a sandwich?", + "The fence is purple.", + ] + metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}] + vectorstore = MongoDBAtlasVectorSearch.from_texts( + texts, + embedding_openai, + metadatas=metadatas, + collection=collection, + index_name=INDEX_NAME, + ) + collection._aggregate_result = list( + filter( + lambda x: "sandwich" in x[vectorstore._text_key].lower() + and x.get("c") < 0, + collection._data, + ) + ) + output = vectorstore.similarity_search( + "Sandwich", k=1, pre_filter={"range": {"lte": 0, "path": "c"}} + ) + assert output == [] + + def test_mmr( + self, embedding_openai: Embeddings, collection: MockCollection + ) -> None: + texts = ["foo", "foo", "fou", "foy"] + vectorstore = MongoDBAtlasVectorSearch.from_texts( + texts, + embedding_openai, + collection=collection, + index_name=INDEX_NAME, + ) + query = "foo" + self._validate_search( + vectorstore, + collection, + search_term=query[0:2], + page_content=query, + metadata=None, + ) + output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1) + assert len(output) == len(texts) + assert output[0].page_content == "foo" + assert output[1].page_content != "foo"