updated the README.md to include usage instructions; 'beefed up' unit tests with mocks
Jibola committed Feb 21, 2024
1 parent 0921e47 commit 40ee1ac
Showing 3 changed files with 265 additions and 15 deletions.
39 changes: 39 additions & 0 deletions libs/partners/mongodb-atlas/README.md
@@ -1 +1,40 @@
# langchain-mongodb-atlas

## Installation
```
pip install -U langchain-mongodb-atlas
```

## Usage
- See [integrations doc](../../../docs/docs/integrations/vectorstores/mongodb_atlas.ipynb) for more in-depth usage instructions.
- See [Getting Started with the LangChain Integration](https://www.mongodb.com/docs/atlas/atlas-vector-search/ai-integrations/langchain/#get-started-with-the-langchain-integration) for a walkthrough on using your first LangChain implementation with MongoDB Atlas.

### Using MongoDBAtlasVectorSearch
```python
import os

from langchain_mongodb_atlas.vectorstores import MongoDBAtlasVectorSearch
from langchain_openai import OpenAIEmbeddings
from pymongo import MongoClient

# Pull the MongoDB Atlas URI from environment variables
MONGODB_ATLAS_CLUSTER_URI = os.environ.get("MONGODB_ATLAS_CLUSTER_URI")

DB_NAME = "langchain_db"
COLLECTION_NAME = "test"
ATLAS_VECTOR_SEARCH_INDEX_NAME = "index_name"

# Option 1: create the vector store via `from_connection_string`
vector_search = MongoDBAtlasVectorSearch.from_connection_string(
    MONGODB_ATLAS_CLUSTER_URI,
    DB_NAME + "." + COLLECTION_NAME,
    OpenAIEmbeddings(disallowed_special=()),
    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
)

# Option 2: initialize the MongoDB python client yourself and instantiate directly
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
MONGODB_COLLECTION = client[DB_NAME][COLLECTION_NAME]

vector_search_2 = MongoDBAtlasVectorSearch(
    collection=MONGODB_COLLECTION,
    embedding=OpenAIEmbeddings(disallowed_special=()),
    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
)
```
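
With either handle in place, ingestion and querying follow the standard LangChain vector store API. A minimal follow-up sketch (not part of this commit; it assumes the Atlas vector search index named above already exists on the collection and that `OPENAI_API_KEY` is set):

```python
# Hypothetical follow-up: add a few texts and run a similarity search.
# Assumes the Atlas vector search index has already been created.
vector_search.add_texts(
    ["MongoDB Atlas supports vector search.", "LangChain integrates with Atlas."],
    metadatas=[{"source": "docs"}, {"source": "release-notes"}],
)

results = vector_search.similarity_search("How do I run a vector search?", k=2)
for doc in results:
    print(doc.page_content, doc.metadata)
```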
1 change: 0 additions & 1 deletion libs/partners/mongodb-atlas/langchain_mongodb_atlas/vectorstores.py
@@ -215,7 +215,6 @@ def _similarity_search_with_score(
for res in cursor:
text = res.pop(self._text_key)
score = res.pop("score")
del res[self._embedding_key]
docs.append((Document(page_content=text, metadata=res), score))
return docs

240 changes: 226 additions & 14 deletions libs/partners/mongodb-atlas/tests/unit_tests/test_vectorstores.py
@@ -1,42 +1,254 @@
import os
import uuid
from copy import deepcopy
from typing import Any, List, Optional

import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.results import DeleteResult, InsertManyResult

from langchain_mongodb_atlas.vectorstores import MongoDBAtlasVectorSearch

INDEX_NAME = "langchain-test-index"
NAMESPACE = "langchain_test_db.langchain_test_collection"
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")


def get_collection() -> Collection:
test_client: MongoClient = MongoClient(CONNECTION_STRING)
return test_client[DB_NAME][COLLECTION_NAME]
class ConsistentFakeEmbeddings(Embeddings):
"""Fake embeddings functionality for testing."""

def __init__(self, dimensionality: int = 10) -> None:
self.known_texts: List[str] = []
self.dimensionality = dimensionality

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Return consistent embeddings for each text seen so far."""
out_vectors = []
for text in texts:
if text not in self.known_texts:
self.known_texts.append(text)
vector = [float(1.0)] * (self.dimensionality - 1) + [
float(self.known_texts.index(text))
]
out_vectors.append(vector)
return out_vectors

def embed_query(self, text: str) -> List[float]:
"""Return consistent embeddings for the text, if seen before, or a constant
one if the text is unknown."""
return self.embed_documents([text])[0]

async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
return self.embed_documents(texts)

async def aembed_query(self, text: str) -> List[float]:
return self.embed_query(text)


class MockCollection(Collection):
"""Mocked Mongo Collection"""

_aggregate_result: list[Any]
_insert_result: Optional[InsertManyResult]
_data: list[Any]

def __init__(self) -> None:
self._data = []
self._aggregate_result = []
self._insert_result = None

def delete_many(self, *args, **kwargs) -> DeleteResult: # type: ignore
old_len = len(self._data)
self._data = []
return DeleteResult({"n": old_len}, acknowledged=True)

def insert_many(self, to_insert: list[Any], *args, **kwargs) -> InsertManyResult: # type: ignore
mongodb_inserts = [
{"_id": str(uuid.uuid4()), "score": 1, **insert} for insert in to_insert
]
self._data.extend(mongodb_inserts)
return self._insert_result or InsertManyResult(
[k["_id"] for k in mongodb_inserts], acknowledged=True
)

def aggregate(self, *args, **kwargs) -> list[Any]: # type: ignore
return deepcopy(self._aggregate_result)

def count_documents(self, *args, **kwargs) -> int: # type: ignore
return len(self._data)

def __repr__(self) -> str:
return "FakeCollection"


def get_collection() -> MockCollection:
return MockCollection()


@pytest.fixture()
def collection() -> Collection:
def collection() -> MockCollection:
return get_collection()


def test_initialization(collection: Collection, embedding_openai: Embeddings) -> None:
"""Test initialization of vector store class"""
assert MongoDBAtlasVectorSearch(collection, embedding_openai)
@pytest.fixture()
def embedding_openai() -> Embeddings:
return ConsistentFakeEmbeddings()


def test_init_from_connection_string(embedding_openai: Embeddings) -> None:
def test_initialization(collection: Collection, embedding_openai: Embeddings) -> None:
"""Test initialization of vector store class"""
assert MongoDBAtlasVectorSearch.from_connection_string(
CONNECTION_STRING, NAMESPACE, embedding_openai
)
assert MongoDBAtlasVectorSearch(collection, embedding_openai)


def test_init_from_texts(collection: Collection, embedding_openai: Embeddings) -> None:
"""Test from_texts operation on an empty list"""
assert MongoDBAtlasVectorSearch.from_texts(
[], embedding_openai, collection=collection
)


class TestMongoDBAtlasVectorSearch:
@classmethod
def setup_class(cls) -> None:
# ensure the test collection is empty
collection = get_collection()
assert collection.count_documents({}) == 0 # type: ignore[index] # noqa: E501

@classmethod
def teardown_class(cls) -> None:
collection = get_collection()
# delete all the documents in the collection
collection.delete_many({}) # type: ignore[index]

@pytest.fixture(autouse=True)
def setup(self) -> None:
collection = get_collection()
# delete all the documents in the collection
collection.delete_many({}) # type: ignore[index]

def _validate_search(
self,
vectorstore: MongoDBAtlasVectorSearch,
collection: MockCollection,
search_term: str = "sandwich",
page_content: str = "What is a sandwich?",
metadata: Optional[Any] = 1,
) -> None:
collection._aggregate_result = list(
filter(
lambda x: search_term.lower() in x[vectorstore._text_key].lower(),
collection._data,
)
)
output = vectorstore.similarity_search("", k=1)
assert output[0].page_content == page_content
assert output[0].metadata.get("c") == metadata

def test_from_documents(
self, embedding_openai: Embeddings, collection: MockCollection
) -> None:
"""Test end to end construction and search."""
documents = [
Document(page_content="Dogs are tough.", metadata={"a": 1}),
Document(page_content="Cats have fluff.", metadata={"b": 1}),
Document(page_content="What is a sandwich?", metadata={"c": 1}),
Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}),
]
vectorstore = MongoDBAtlasVectorSearch.from_documents(
documents,
embedding_openai,
collection=collection,
index_name=INDEX_NAME,
)
self._validate_search(
vectorstore, collection, metadata=documents[2].metadata["c"]
)

def test_from_texts(
self, embedding_openai: Embeddings, collection: MockCollection
) -> None:
texts = [
"Dogs are tough.",
"Cats have fluff.",
"What is a sandwich?",
"That fence is purple.",
]
vectorstore = MongoDBAtlasVectorSearch.from_texts(
texts,
embedding_openai,
collection=collection,
index_name=INDEX_NAME,
)
self._validate_search(vectorstore, collection, metadata=None)

def test_from_texts_with_metadatas(
self, embedding_openai: Embeddings, collection: MockCollection
) -> None:
texts = [
"Dogs are tough.",
"Cats have fluff.",
"What is a sandwich?",
"The fence is purple.",
]
metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
vectorstore = MongoDBAtlasVectorSearch.from_texts(
texts,
embedding_openai,
metadatas=metadatas,
collection=collection,
index_name=INDEX_NAME,
)
self._validate_search(vectorstore, collection, metadata=metadatas[2]["c"])

def test_from_texts_with_metadatas_and_pre_filter(
self, embedding_openai: Embeddings, collection: MockCollection
) -> None:
texts = [
"Dogs are tough.",
"Cats have fluff.",
"What is a sandwich?",
"The fence is purple.",
]
metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
vectorstore = MongoDBAtlasVectorSearch.from_texts(
texts,
embedding_openai,
metadatas=metadatas,
collection=collection,
index_name=INDEX_NAME,
)
collection._aggregate_result = list(
filter(
lambda x: "sandwich" in x[vectorstore._text_key].lower()
and x.get("c") < 0,
collection._data,
)
)
output = vectorstore.similarity_search(
"Sandwich", k=1, pre_filter={"range": {"lte": 0, "path": "c"}}
)
assert output == []

def test_mmr(
self, embedding_openai: Embeddings, collection: MockCollection
) -> None:
texts = ["foo", "foo", "fou", "foy"]
vectorstore = MongoDBAtlasVectorSearch.from_texts(
texts,
embedding_openai,
collection=collection,
index_name=INDEX_NAME,
)
query = "foo"
self._validate_search(
vectorstore,
collection,
search_term=query[0:2],
page_content=query,
metadata=None,
)
output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1)
assert len(output) == len(texts)
assert output[0].page_content == "foo"
assert output[1].page_content != "foo"
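
Because the collection and embeddings are mocked, these tests exercise MongoDBAtlasVectorSearch entirely in memory, without a live Atlas cluster. As a rough standalone sketch of the same pattern (reusing `get_collection`, `ConsistentFakeEmbeddings`, and `INDEX_NAME` from this module; not part of the commit):

```python
# Sketch: drive the vector store against the in-memory mocks defined above.
collection = get_collection()
vectorstore = MongoDBAtlasVectorSearch.from_texts(
    ["Dogs are tough.", "What is a sandwich?"],
    ConsistentFakeEmbeddings(),
    collection=collection,
    index_name=INDEX_NAME,
)

# The mock's aggregate() simply returns _aggregate_result, so the "vector
# search" step is simulated by filtering the documents inserted above.
collection._aggregate_result = [
    doc for doc in collection._data
    if "sandwich" in doc[vectorstore._text_key].lower()
]
print(vectorstore.similarity_search("sandwich", k=1)[0].page_content)
```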
