Skip to content

Commit

Permalink
feat: add configurable distance strategy and add more docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
malandis committed Oct 9, 2023
1 parent 25c9c19 commit 3133ea9
Showing 1 changed file with 92 additions and 6 deletions.
98 changes: 92 additions & 6 deletions libs/langchain/langchain/vectorstores/momento_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore
from langchain.utils import get_from_env

from langchain.vectorstores.utils import DistanceStrategy

VST = TypeVar("VST", bound="VectorStore")

Expand All @@ -25,23 +25,48 @@


class MomentoVectorIndex(VectorStore):
"""Vector Store implementation backed by Momento Vector Index.
"""`Momento Vector Index` (MVI) vector store.
Momento Vector Index is a serverless vector index that can be used to store and
search vectors.
search vectors. To use you should have the ``momento`` python package installed.
Example:
.. code-block:: python
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import MomentoVectorIndex
from momento import (
CredentialProvider,
PreviewVectorIndexClient,
VectorIndexConfigurations,
)
vectorstore = MomentoVectorIndex(
embedding=OpenAIEmbeddings(),
client=PreviewVectorIndexClient(
VectorIndexConfigurations.Default.latest(),
credential_provider=CredentialProvider.from_environment_variable(
"MOMENTO_API_KEY"
),
),
index_name="my-index",
)
"""

_client: "PreviewVectorIndexClient"
index_name: str
distance_strategy: DistanceStrategy
text_field: str
fields: set[str]
_ensure_index_exists: bool

def __init__(
self,
embedding: Embeddings,
client: "PreviewVectorIndexClient",
index_name: str = "default",
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
text_field: str = "text",
ensure_index_exists: bool = True,
**kwargs: Any,
):
"""Initialize a Vector Store backed by Momento Vector Index.
Expand All @@ -54,8 +79,14 @@ def __init__(
authenticate the Vector Index with.
index_name (str, optional): The name of the index to store the documents in.
Defaults to "default".
distance_strategy (DistanceStrategy, optional): The distance strategy to
use. Defaults to DistanceStrategy.COSINE. If you select
DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses the squared
Euclidean distance.
text_field (str, optional): The name of the metadata field to store the
original text in. Defaults to "text".
ensure_index_exists (bool, optional): Whether to ensure that the index
exists before adding documents to it. Defaults to True.
"""
try:
from momento import PreviewVectorIndexClient
Expand All @@ -68,17 +99,44 @@ def __init__(
self._client: PreviewVectorIndexClient = client
self._embedding = embedding
self.index_name = index_name
self.__validate_distance_strategy(distance_strategy)
self.distance_strategy = distance_strategy
self.text_field = text_field
self._ensure_index_exists = ensure_index_exists

@staticmethod
def __validate_distance_strategy(distance_strategy: DistanceStrategy) -> None:
if distance_strategy not in [
DistanceStrategy.COSINE,
DistanceStrategy.MAX_INNER_PRODUCT,
DistanceStrategy.MAX_INNER_PRODUCT,
]:
raise ValueError(f"Distance strategy {distance_strategy} not implemented.")

@property
def embeddings(self) -> Embeddings:
return self._embedding

def _create_index_if_not_exists(self, num_dimensions: int) -> bool:
"""Create index if it does not exist."""
from momento.requests.vector_index import SimilarityMetric
from momento.responses.vector_index import CreateIndex

response = self._client.create_index(self.index_name, num_dimensions)
similarity_metric = None
if self.distance_strategy == DistanceStrategy.COSINE:
similarity_metric = SimilarityMetric.COSINE_SIMILARITY
elif self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
similarity_metric = SimilarityMetric.INNER_PRODUCT
elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
similarity_metric = SimilarityMetric.EUCLIDEAN_SIMILARITY
else:
raise ValueError(
f"Distance strategy {self.distance_strategy} not implemented."
)

response = self._client.create_index(
self.index_name, num_dimensions, similarity_metric
)
if isinstance(response, CreateIndex.Success):
return True
elif isinstance(response, CreateIndex.IndexAlreadyExists):
Expand Down Expand Up @@ -126,7 +184,11 @@ def add_texts(
except NotImplementedError:
embeddings = [self._embedding.embed_query(x) for x in texts]

self._create_index_if_not_exists(len(embeddings[0]))
# Create index if it does not exist.
# We assume that if it does exist, then it was created with the desired number
# of dimensions and similarity metric.
if self._ensure_index_exists:
self._create_index_if_not_exists(len(embeddings[0]))

if "ids" in kwargs:
ids = kwargs["ids"]
Expand Down Expand Up @@ -179,6 +241,15 @@ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[boo
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Search for similar documents to the query string.
Args:
query (str): The query string to search for.
k (int, optional): The number of results to return. Defaults to 4.
Returns:
List[Document]: A list of documents that are similar to the query.
"""
res = self.similarity_search_with_score(query=query, k=k, **kwargs)
return [doc for doc, _ in res]

Expand Down Expand Up @@ -251,6 +322,15 @@ def similarity_search_with_score_by_vector(
def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
"""Search for similar documents to the query vector.
Args:
embedding (List[float]): The query vector to search for.
k (int, optional): The number of results to return. Defaults to 4.
Returns:
List[Document]: A list of documents that are similar to the query.
"""
results = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, **kwargs
)
Expand Down Expand Up @@ -279,6 +359,12 @@ def from_texts(
in. Defaults to "default".
- text_field (str, optional): The name of the metadata field to store the
original text in. Defaults to "text".
- distance_strategy (DistanceStrategy, optional): The distance strategy to
use. Defaults to DistanceStrategy.COSINE. If you select
DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses the squared
Euclidean distance.
- ensure_index_exists (bool, optional): Whether to ensure that the index
exists before adding documents to it. Defaults to True.
Additionally you can either pass in a client or an API key
- client (PreviewVectorIndexClient): The Momento Vector Index client to use.
- api_key (Optional[str]): The configuration to use to initialize
Expand Down

0 comments on commit 3133ea9

Please sign in to comment.