Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement max marginal relevance for momento vector index #13619

Merged
Merged
94 changes: 90 additions & 4 deletions libs/langchain/langchain/vectorstores/momento_vector_index.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -11,15 +12,17 @@
)
from uuid import uuid4

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

from langchain.utils import get_from_env
from langchain.vectorstores.utils import DistanceStrategy
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance

VST = TypeVar("VST", bound="VectorStore")

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from momento import PreviewVectorIndexClient
Expand Down Expand Up @@ -75,9 +78,8 @@ def __init__(
index_name (str, optional): The name of the index to store the documents in.
Defaults to "default".
distance_strategy (DistanceStrategy, optional): The distance strategy to
use. Defaults to DistanceStrategy.COSINE. If you select
DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses the squared
Euclidean distance.
use. If you select DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses
the squared Euclidean distance. Defaults to DistanceStrategy.COSINE.
text_field (str, optional): The name of the metadata field to store the
original text in. Defaults to "text".
ensure_index_exists (bool, optional): Whether to ensure that the index
Expand Down Expand Up @@ -125,6 +127,7 @@ def _create_index_if_not_exists(self, num_dimensions: int) -> bool:
elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
similarity_metric = SimilarityMetric.EUCLIDEAN_SIMILARITY
else:
logger.error(f"Distance strategy {self.distance_strategy} not implemented.")
raise ValueError(
f"Distance strategy {self.distance_strategy} not implemented."
)
Expand All @@ -137,8 +140,10 @@ def _create_index_if_not_exists(self, num_dimensions: int) -> bool:
elif isinstance(response, CreateIndex.IndexAlreadyExists):
return False
elif isinstance(response, CreateIndex.Error):
logger.error(f"Error creating index: {response.inner_exception}")
raise response.inner_exception
else:
logger.error(f"Unexpected response: {response}")
raise Exception(f"Unexpected response: {response}")

def add_texts(
Expand Down Expand Up @@ -331,6 +336,87 @@ def similarity_search_by_vector(
)
return [doc for doc, _ in results]

def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.

Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.

Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
from momento.requests.vector_index import ALL_METADATA
from momento.responses.vector_index import SearchAndFetchVectors

response = self._client.search_and_fetch_vectors(
self.index_name, embedding, top_k=fetch_k, metadata_fields=ALL_METADATA
)

if isinstance(response, SearchAndFetchVectors.Success):
pass
elif isinstance(response, SearchAndFetchVectors.Error):
logger.error(f"Error searching and fetching vectors: {response}")
return []
else:
logger.error(f"Unexpected response: {response}")
raise Exception(f"Unexpected response: {response}")

mmr_selected = maximal_marginal_relevance(
query_embedding=np.array([embedding], dtype=np.float32),
embedding_list=[hit.vector for hit in response.hits],
lambda_mult=lambda_mult,
k=k,
)
selected = [response.hits[i].metadata for i in mmr_selected]
return [
Document(page_content=metadata.pop(self.text_field, ""), metadata=metadata) # type: ignore # noqa: E501
for metadata in selected
]

def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.

Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.

Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
embedding = self._embedding.embed_query(query)
return self.max_marginal_relevance_search_by_vector(
embedding, k, fetch_k, lambda_mult, **kwargs
)

@classmethod
def from_texts(
cls: Type[VST],
Expand Down
15 changes: 7 additions & 8 deletions libs/langchain/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_from_texts_with_metadatas(


def test_from_texts_with_scores(vector_store: MomentoVectorIndex) -> None:
# """Test end to end construction and search with scores and IDs."""
"""Test end to end construction and search with scores and IDs."""
texts = ["apple", "orange", "hammer"]
metadatas = [{"page": f"{i}"} for i in range(len(texts))]

Expand Down Expand Up @@ -162,3 +162,25 @@ def test_add_documents_with_ids(vector_store: MomentoVectorIndex) -> None:
)
assert isinstance(response, Search.Success)
assert [hit.id for hit in response.hits] == ids


def test_max_marginal_relevance_search(vector_store: MomentoVectorIndex) -> None:
"""Test max marginal relevance search."""
pepperoni_pizza = "pepperoni pizza"
cheese_pizza = "cheese pizza"
hot_dog = "hot dog"

vector_store.add_texts([pepperoni_pizza, cheese_pizza, hot_dog])
wait()
search_results = vector_store.similarity_search("pizza", k=2)

assert search_results == [
Document(page_content=pepperoni_pizza, metadata={}),
Document(page_content=cheese_pizza, metadata={}),
]

search_results = vector_store.max_marginal_relevance_search(query="pizza", k=2)
assert search_results == [
Document(page_content=pepperoni_pizza, metadata={}),
Document(page_content=hot_dog, metadata={}),
]
Loading