-
Notifications
You must be signed in to change notification settings - Fork 14k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
updated the README.md to include usage instructions; 'beefed up' unit…
… tests with mocks
- Loading branch information
Showing
3 changed files
with
265 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,40 @@ | ||
# langchain-mongodb-atlas | ||
|
||
# Installation | ||
``` | ||
pip install -U langchain-mongodb-atlas | ||
``` | ||
|
||
# Usage | ||
- See [integrations doc](../../../docs/docs/integrations/vectorstores/mongodb_atlas.ipynb) for more in-depth usage instructions. | ||
- See [Getting Started with the LangChain Integration](https://www.mongodb.com/docs/atlas/atlas-vector-search/ai-integrations/langchain/#get-started-with-the-langchain-integration) for a walkthrough on using your first LangChain implementation with MongoDB Atlas. | ||
|
||
## Using MongoDBAtlasVectorSearch | ||
```python | ||
from langchain_community.vectorstores import MongoDBAtlasVectorSearch | ||
|
||
# Pull MongoDB Atlas URI from environment variables | ||
MONGODB_ATLAS_CLUSTER_URI = os.environ.get("MONGODB_ATLAS_CLUSTER_URI") | ||
|
||
DB_NAME = "langchain_db" | ||
COLLECTION_NAME = "test" | ||
ATLAS_VECTOR_SEARCH_INDEX_NAME = "index_name" | ||
MONGODB_COLLECTION = client[DB_NAME][COLLECITON_NAME] | ||
|
||
# Create the vector search via `from_connection_string` | ||
vector_search = MongoDBAtlasVectorSearch.from_connection_string( | ||
MONGODB_ATLAS_CLUSTER_URI, | ||
DB_NAME + "." + COLLECTION_NAME, | ||
OpenAIEmbeddings(disallowed_special=()), | ||
index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME, | ||
) | ||
|
||
# initialize MongoDB python client | ||
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI) | ||
# Create the vector search via instantiation | ||
vector_search_2 = MongoDBAtlasVectorSearch( | ||
collection=MONGODB_COLLECTION, | ||
embeddings=OpenAIEmbeddings(disallowed_special=()), | ||
index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME, | ||
) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
240 changes: 226 additions & 14 deletions
240
libs/partners/mongodb-atlas/tests/unit_tests/test_vectorstores.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,254 @@ | ||
import os | ||
import uuid | ||
from copy import deepcopy | ||
from typing import Any, List, Optional | ||
|
||
import pytest | ||
from langchain_core.documents import Document | ||
from langchain_core.embeddings import Embeddings | ||
from pymongo import MongoClient | ||
from pymongo.collection import Collection | ||
from pymongo.results import DeleteResult, InsertManyResult | ||
|
||
from langchain_mongodb_atlas.vectorstores import MongoDBAtlasVectorSearch | ||
|
||
INDEX_NAME = "langchain-test-index" | ||
NAMESPACE = "langchain_test_db.langchain_test_collection" | ||
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI") | ||
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".") | ||
|
||
|
||
def get_collection() -> Collection: | ||
test_client: MongoClient = MongoClient(CONNECTION_STRING) | ||
return test_client[DB_NAME][COLLECTION_NAME] | ||
class ConsistentFakeEmbeddings(Embeddings): | ||
"""Fake embeddings functionality for testing.""" | ||
|
||
def __init__(self, dimensionality: int = 10) -> None: | ||
self.known_texts: List[str] = [] | ||
self.dimensionality = dimensionality | ||
|
||
def embed_documents(self, texts: List[str]) -> List[List[float]]: | ||
"""Return consistent embeddings for each text seen so far.""" | ||
out_vectors = [] | ||
for text in texts: | ||
if text not in self.known_texts: | ||
self.known_texts.append(text) | ||
vector = [float(1.0)] * (self.dimensionality - 1) + [ | ||
float(self.known_texts.index(text)) | ||
] | ||
out_vectors.append(vector) | ||
return out_vectors | ||
|
||
def embed_query(self, text: str) -> List[float]: | ||
"""Return consistent embeddings for the text, if seen before, or a constant | ||
one if the text is unknown.""" | ||
return self.embed_documents([text])[0] | ||
|
||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]: | ||
return self.embed_documents(texts) | ||
|
||
async def aembed_query(self, text: str) -> List[float]: | ||
return self.embed_query(text) | ||
|
||
|
||
class MockCollection(Collection): | ||
"""Mocked Mongo Collection""" | ||
|
||
_aggregate_result: list[Any] | ||
_insert_result: Optional[InsertManyResult] | ||
_data: list[Any] | ||
|
||
def __init__(self) -> None: | ||
self._data = [] | ||
self._aggregate_result = [] | ||
self._insert_result = None | ||
|
||
def delete_many(self, *args, **kwargs) -> DeleteResult: # type: ignore | ||
old_len = len(self._data) | ||
self._data = [] | ||
return DeleteResult({"n": old_len}, acknowledged=True) | ||
|
||
def insert_many(self, to_insert: list[Any], *args, **kwargs) -> InsertManyResult: # type: ignore | ||
mongodb_inserts = [ | ||
{"_id": str(uuid.uuid4()), "score": 1, **insert} for insert in to_insert | ||
] | ||
self._data.extend(mongodb_inserts) | ||
return self._insert_result or InsertManyResult( | ||
[k["_id"] for k in mongodb_inserts], acknowledged=True | ||
) | ||
|
||
def aggregate(self, *args, **kwargs) -> list[Any]: # type: ignore | ||
return deepcopy(self._aggregate_result) | ||
|
||
def count_documents(self, *args, **kwargs) -> int: # type: ignore | ||
return len(self._data) | ||
|
||
def __repr__(self) -> str: | ||
return "FakeCollection" | ||
|
||
|
||
def get_collection() -> MockCollection: | ||
return MockCollection() | ||
|
||
|
||
@pytest.fixture() | ||
def collection() -> Collection: | ||
def collection() -> MockCollection: | ||
return get_collection() | ||
|
||
|
||
def test_initialization(collection: Collection, embedding_openai: Embeddings) -> None: | ||
"""Test initialization of vector store class""" | ||
assert MongoDBAtlasVectorSearch(collection, embedding_openai) | ||
@pytest.fixture() | ||
def embedding_openai() -> Embeddings: | ||
return ConsistentFakeEmbeddings() | ||
|
||
|
||
def test_init_from_connection_string(embedding_openai: Embeddings) -> None: | ||
def test_initialization(collection: Collection, embedding_openai: Embeddings) -> None: | ||
"""Test initialization of vector store class""" | ||
assert MongoDBAtlasVectorSearch.from_connection_string( | ||
CONNECTION_STRING, NAMESPACE, embedding_openai | ||
) | ||
assert MongoDBAtlasVectorSearch(collection, embedding_openai) | ||
|
||
|
||
def test_init_from_texts(collection: Collection, embedding_openai: Embeddings) -> None: | ||
"""Test from_texts operation on an empty list""" | ||
assert MongoDBAtlasVectorSearch.from_texts( | ||
[], embedding_openai, collection=collection | ||
) | ||
|
||
|
||
class TestMongoDBAtlasVectorSearch: | ||
@classmethod | ||
def setup_class(cls) -> None: | ||
# ensure the test collection is empty | ||
collection = get_collection() | ||
assert collection.count_documents({}) == 0 # type: ignore[index] # noqa: E501 | ||
|
||
@classmethod | ||
def teardown_class(cls) -> None: | ||
collection = get_collection() | ||
# delete all the documents in the collection | ||
collection.delete_many({}) # type: ignore[index] | ||
|
||
@pytest.fixture(autouse=True) | ||
def setup(self) -> None: | ||
collection = get_collection() | ||
# delete all the documents in the collection | ||
collection.delete_many({}) # type: ignore[index] | ||
|
||
def _validate_search( | ||
self, | ||
vectorstore: MongoDBAtlasVectorSearch, | ||
collection: MockCollection, | ||
search_term: str = "sandwich", | ||
page_content: str = "What is a sandwich?", | ||
metadata: Optional[Any] = 1, | ||
) -> None: | ||
collection._aggregate_result = list( | ||
filter( | ||
lambda x: search_term.lower() in x[vectorstore._text_key].lower(), | ||
collection._data, | ||
) | ||
) | ||
output = vectorstore.similarity_search("", k=1) | ||
assert output[0].page_content == page_content | ||
assert output[0].metadata.get("c") == metadata | ||
|
||
def test_from_documents( | ||
self, embedding_openai: Embeddings, collection: MockCollection | ||
) -> None: | ||
"""Test end to end construction and search.""" | ||
documents = [ | ||
Document(page_content="Dogs are tough.", metadata={"a": 1}), | ||
Document(page_content="Cats have fluff.", metadata={"b": 1}), | ||
Document(page_content="What is a sandwich?", metadata={"c": 1}), | ||
Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}), | ||
] | ||
vectorstore = MongoDBAtlasVectorSearch.from_documents( | ||
documents, | ||
embedding_openai, | ||
collection=collection, | ||
index_name=INDEX_NAME, | ||
) | ||
self._validate_search( | ||
vectorstore, collection, metadata=documents[2].metadata["c"] | ||
) | ||
|
||
def test_from_texts( | ||
self, embedding_openai: Embeddings, collection: MockCollection | ||
) -> None: | ||
texts = [ | ||
"Dogs are tough.", | ||
"Cats have fluff.", | ||
"What is a sandwich?", | ||
"That fence is purple.", | ||
] | ||
vectorstore = MongoDBAtlasVectorSearch.from_texts( | ||
texts, | ||
embedding_openai, | ||
collection=collection, | ||
index_name=INDEX_NAME, | ||
) | ||
self._validate_search(vectorstore, collection, metadata=None) | ||
|
||
def test_from_texts_with_metadatas( | ||
self, embedding_openai: Embeddings, collection: MockCollection | ||
) -> None: | ||
texts = [ | ||
"Dogs are tough.", | ||
"Cats have fluff.", | ||
"What is a sandwich?", | ||
"The fence is purple.", | ||
] | ||
metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}] | ||
vectorstore = MongoDBAtlasVectorSearch.from_texts( | ||
texts, | ||
embedding_openai, | ||
metadatas=metadatas, | ||
collection=collection, | ||
index_name=INDEX_NAME, | ||
) | ||
self._validate_search(vectorstore, collection, metadata=metadatas[2]["c"]) | ||
|
||
def test_from_texts_with_metadatas_and_pre_filter( | ||
self, embedding_openai: Embeddings, collection: MockCollection | ||
) -> None: | ||
texts = [ | ||
"Dogs are tough.", | ||
"Cats have fluff.", | ||
"What is a sandwich?", | ||
"The fence is purple.", | ||
] | ||
metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}] | ||
vectorstore = MongoDBAtlasVectorSearch.from_texts( | ||
texts, | ||
embedding_openai, | ||
metadatas=metadatas, | ||
collection=collection, | ||
index_name=INDEX_NAME, | ||
) | ||
collection._aggregate_result = list( | ||
filter( | ||
lambda x: "sandwich" in x[vectorstore._text_key].lower() | ||
and x.get("c") < 0, | ||
collection._data, | ||
) | ||
) | ||
output = vectorstore.similarity_search( | ||
"Sandwich", k=1, pre_filter={"range": {"lte": 0, "path": "c"}} | ||
) | ||
assert output == [] | ||
|
||
def test_mmr( | ||
self, embedding_openai: Embeddings, collection: MockCollection | ||
) -> None: | ||
texts = ["foo", "foo", "fou", "foy"] | ||
vectorstore = MongoDBAtlasVectorSearch.from_texts( | ||
texts, | ||
embedding_openai, | ||
collection=collection, | ||
index_name=INDEX_NAME, | ||
) | ||
query = "foo" | ||
self._validate_search( | ||
vectorstore, | ||
collection, | ||
search_term=query[0:2], | ||
page_content=query, | ||
metadata=None, | ||
) | ||
output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1) | ||
assert len(output) == len(texts) | ||
assert output[0].page_content == "foo" | ||
assert output[1].page_content != "foo" |