-
Notifications
You must be signed in to change notification settings - Fork 13.5k
/
momento_vector_index.py
489 lines (424 loc) · 18.6 KB
/
momento_vector_index.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
import logging
from typing import (
TYPE_CHECKING,
Any,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
cast,
)
from uuid import uuid4
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_env
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import (
DistanceStrategy,
maximal_marginal_relevance,
)
# Type variable bound to VectorStore; used to annotate the ``from_texts``
# classmethod so it is typed as returning the class it was called on.
VST = TypeVar("VST", bound="VectorStore")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Imported for static type checking only; at runtime the ``momento`` package is
# imported lazily inside the methods that need it.
if TYPE_CHECKING:
    from momento import PreviewVectorIndexClient
class MomentoVectorIndex(VectorStore):
    """`Momento Vector Index` (MVI) vector store.

    Momento Vector Index is a serverless vector index that can be used to store and
    search vectors. To use you should have the ``momento`` python package installed.

    The original text of every item is stored in the item's metadata under
    ``text_field`` and is moved back into ``Document.page_content`` when search
    results are converted to documents.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import OpenAIEmbeddings
            from langchain_community.vectorstores import MomentoVectorIndex
            from momento import (
                CredentialProvider,
                PreviewVectorIndexClient,
                VectorIndexConfigurations,
            )

            vectorstore = MomentoVectorIndex(
                embedding=OpenAIEmbeddings(),
                client=PreviewVectorIndexClient(
                    VectorIndexConfigurations.Default.latest(),
                    credential_provider=CredentialProvider.from_environment_variable(
                        "MOMENTO_API_KEY"
                    ),
                ),
                index_name="my-index",
            )
    """
def __init__(
    self,
    embedding: Embeddings,
    client: "PreviewVectorIndexClient",
    index_name: str = "default",
    distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
    text_field: str = "text",
    ensure_index_exists: bool = True,
    **kwargs: Any,
):
    """Initialize a Vector Store backed by Momento Vector Index.

    Args:
        embedding (Embeddings): The embedding function to use.
        client (PreviewVectorIndexClient): The Momento Vector Index client
            used for all index operations (create, upsert, delete, search).
        index_name (str, optional): The name of the index to store the documents in.
            Defaults to "default".
        distance_strategy (DistanceStrategy, optional): The distance strategy to
            use. If you select DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses
            the squared Euclidean distance. Defaults to DistanceStrategy.COSINE.
        text_field (str, optional): The name of the metadata field to store the
            original text in. Defaults to "text".
        ensure_index_exists (bool, optional): Whether to ensure that the index
            exists before adding documents to it. Defaults to True.
        **kwargs (Any): Accepted so callers can forward extra keyword
            arguments; not used by the constructor.

    Raises:
        ImportError: If the ``momento`` package cannot be imported.
        ValueError: If ``distance_strategy`` is rejected by the strategy
            validation.
    """
    try:
        from momento import PreviewVectorIndexClient
    except ImportError as e:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "Could not import momento python package. "
            "Please install it with `pip install momento`."
        ) from e

    self._client: PreviewVectorIndexClient = client
    self._embedding = embedding
    self.index_name = index_name
    # Validate before storing so an unsupported strategy never ends up on the
    # instance.
    self.__validate_distance_strategy(distance_strategy)
    self.distance_strategy = distance_strategy
    self.text_field = text_field
    self._ensure_index_exists = ensure_index_exists
@staticmethod
def __validate_distance_strategy(distance_strategy: DistanceStrategy) -> None:
    """Reject distance strategies that this store cannot use.

    Args:
        distance_strategy (DistanceStrategy): The strategy requested by the
            caller.

    Raises:
        ValueError: If the strategy is not one of COSINE, MAX_INNER_PRODUCT or
            EUCLIDEAN_DISTANCE.
    """
    # EUCLIDEAN_DISTANCE must be accepted here: index creation maps it to a
    # Momento similarity metric, so rejecting it would make that path
    # unreachable.
    supported_strategies = (
        DistanceStrategy.COSINE,
        DistanceStrategy.MAX_INNER_PRODUCT,
        DistanceStrategy.EUCLIDEAN_DISTANCE,
    )
    if distance_strategy not in supported_strategies:
        raise ValueError(f"Distance strategy {distance_strategy} not implemented.")
@property
def embeddings(self) -> Embeddings:
    """Return the embedding function held by this vector store."""
    embedding_function = self._embedding
    return embedding_function
def _create_index_if_not_exists(self, num_dimensions: int) -> bool:
    """Ask Momento to create the index, tolerating an index that already exists.

    Args:
        num_dimensions (int): Number of dimensions passed to ``create_index``.

    Returns:
        bool: True if the index was created by this call, False if the service
            reported that it already exists.

    Raises:
        ValueError: If ``self.distance_strategy`` has no matching Momento
            similarity metric.
        Exception: The ``inner_exception`` of an error response, or a generic
            ``Exception`` for a response type that is not recognised.
    """
    from momento.requests.vector_index import SimilarityMetric
    from momento.responses.vector_index import CreateIndex

    # Translate the LangChain strategy into the Momento similarity metric.
    strategy = self.distance_strategy
    if strategy == DistanceStrategy.COSINE:
        metric = SimilarityMetric.COSINE_SIMILARITY
    elif strategy == DistanceStrategy.MAX_INNER_PRODUCT:
        metric = SimilarityMetric.INNER_PRODUCT
    elif strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
        metric = SimilarityMetric.EUCLIDEAN_SIMILARITY
    else:
        logger.error(f"Distance strategy {self.distance_strategy} not implemented.")
        raise ValueError(
            f"Distance strategy {self.distance_strategy} not implemented."
        )

    response = self._client.create_index(self.index_name, num_dimensions, metric)

    # Each recognised response type either returns or raises; anything that
    # falls through is unexpected.
    if isinstance(response, CreateIndex.Success):
        return True
    if isinstance(response, CreateIndex.IndexAlreadyExists):
        return False
    if isinstance(response, CreateIndex.Error):
        logger.error(f"Error creating index: {response.inner_exception}")
        raise response.inner_exception
    logger.error(f"Unexpected response: {response}")
    raise Exception(f"Unexpected response: {response}")
def add_texts(
    self,
    texts: Iterable[str],
    metadatas: Optional[List[dict]] = None,
    **kwargs: Any,
) -> List[str]:
    """Run more texts through the embeddings and add to the vectorstore.

    Args:
        texts (Iterable[str]): Iterable of strings to add to the vectorstore.
        metadatas (Optional[List[dict]]): Optional list of metadatas associated with
            the texts. The dicts are copied, not modified; each stored copy
            additionally carries the text under ``self.text_field``.
        kwargs (Any): Other optional parameters. Specifically:
            - ids (List[str], optional): List of ids to use for the texts.
              Defaults to None, in which case uuids are generated.

    Returns:
        List[str]: List of ids from adding the texts into the vectorstore.

    Raises:
        ValueError: If the number of ``metadatas`` or ``ids`` does not match the
            number of texts.
        Exception: The ``inner_exception`` of a failed upsert, or a generic
            ``Exception`` for a response type that is not recognised.
    """
    from momento.requests.vector_index import Item
    from momento.responses.vector_index import UpsertItemBatch

    texts = list(texts)
    if not texts:
        return []

    if metadatas is None:
        item_metadatas = [{self.text_field: text} for text in texts]
    else:
        # A length mismatch would otherwise make ``zip`` drop texts silently.
        if len(metadatas) != len(texts):
            raise ValueError("Number of metadatas must match number of texts")
        # Build new dicts so the caller's metadata objects are left untouched.
        item_metadatas = [
            {**metadata, self.text_field: text}
            for metadata, text in zip(metadatas, texts)
        ]

    # ``ids=None`` means the same as not passing ids at all. Validate before any
    # embedding or network call so bad input fails fast.
    ids = kwargs.get("ids")
    if ids is not None and len(ids) != len(texts):
        raise ValueError("Number of ids must match number of texts")

    try:
        embeddings = self._embedding.embed_documents(texts)
    except NotImplementedError:
        # Fall back to embedding one text at a time.
        embeddings = [self._embedding.embed_query(x) for x in texts]

    # Create index if it does not exist.
    # We assume that if it does exist, then it was created with the desired number
    # of dimensions and similarity metric.
    if self._ensure_index_exists:
        self._create_index_if_not_exists(len(embeddings[0]))

    if ids is None:
        ids = [str(uuid4()) for _ in range(len(embeddings))]

    # Upsert in fixed-size batches.
    batch_size = 128
    for start in range(0, len(embeddings), batch_size):
        end = start + batch_size
        items = [
            Item(id=item_id, vector=vector, metadata=metadata)
            for item_id, vector, metadata in zip(
                ids[start:end],
                embeddings[start:end],
                item_metadatas[start:end],
            )
        ]

        response = self._client.upsert_item_batch(self.index_name, items)
        if isinstance(response, UpsertItemBatch.Error):
            raise response.inner_exception
        if not isinstance(response, UpsertItemBatch.Success):
            raise Exception(f"Unexpected response: {response}")

    return ids
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
    """Remove items from the index by their ids.

    Args:
        ids (List[str]): Ids of the items to delete. When None, nothing is sent
            to the service.
        kwargs (Any): Other optional parameters (unused)

    Returns:
        Optional[bool]: True if ``ids`` is None or the service reported
            success, False for any other response.
    """
    from momento.responses.vector_index import DeleteItemBatch

    # No ids means there is nothing to delete; report success without a call.
    if ids is None:
        return True

    outcome = self._client.delete_item_batch(self.index_name, ids)
    deleted = isinstance(outcome, DeleteItemBatch.Success)
    return deleted
def similarity_search(
    self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
    """Return the documents most similar to a query string, without scores.

    Args:
        query (str): The query string to search for.
        k (int, optional): The number of results to return. Defaults to 4.
        kwargs (Any): Forwarded unchanged to ``similarity_search_with_score``.

    Returns:
        List[Document]: The matching documents, in the order returned by
            ``similarity_search_with_score``.
    """
    scored_documents = self.similarity_search_with_score(query=query, k=k, **kwargs)
    documents = []
    # Keep only the document from each (document, score) pair.
    for document, _score in scored_documents:
        documents.append(document)
    return documents
def similarity_search_with_score(
    self,
    query: str,
    k: int = 4,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Embed a query string and return the most similar documents with scores.

    Args:
        query (str): The query string to search for.
        k (int, optional): The number of results to return. Defaults to 4.
        kwargs (Any): Vector Store specific search parameters, forwarded
            unchanged to ``similarity_search_with_score_by_vector``.

    Returns:
        List[Tuple[Document, float]]: A list of tuples of the form
            (Document, score).
    """
    query_vector = self._embedding.embed_query(query)
    return self.similarity_search_with_score_by_vector(
        embedding=query_vector, k=k, **kwargs
    )
def similarity_search_with_score_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Search for similar documents to the query vector.

    Args:
        embedding (List[float]): The query vector to search for.
        k (int, optional): The number of results to return. Defaults to 4.
        kwargs (Any): Vector Store specific search parameters. The following are
            forwarded to the Momento Vector Index:
            - top_k (int, optional): The number of results to return. Takes
              precedence over ``k`` when supplied.
            - filter_expression (optional): Passed through to the client's
              ``search`` call. Defaults to None.

    Returns:
        List[Tuple[Document, float]]: A list of tuples of the form
            (Document, score). Empty if the search did not succeed.
    """
    from momento.requests.vector_index import ALL_METADATA
    from momento.responses.vector_index import Search

    # ``k`` is a named parameter and can never appear in ``kwargs``, so the
    # override has to be read from the ``top_k`` key itself.
    if "top_k" in kwargs:
        k = kwargs["top_k"]
    filter_expression = kwargs.get("filter_expression", None)
    response = self._client.search(
        self.index_name,
        embedding,
        top_k=k,
        metadata_fields=ALL_METADATA,
        filter_expression=filter_expression,
    )

    if not isinstance(response, Search.Success):
        # Keep the best-effort empty result, but leave a trace of the failure.
        logger.error(f"Error searching index: {response}")
        return []

    results = []
    for hit in response.hits:
        # Move the stored text out of the metadata into the page content. An
        # empty default matches the MMR search and avoids a KeyError for items
        # that lack the text field.
        text = cast(str, hit.metadata.pop(self.text_field, ""))
        doc = Document(page_content=text, metadata=hit.metadata)
        results.append((doc, hit.score))

    return results
def similarity_search_by_vector(
    self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
    """Return the documents most similar to a query vector, without scores.

    Args:
        embedding (List[float]): The query vector to search for.
        k (int, optional): The number of results to return. Defaults to 4.
        kwargs (Any): Forwarded unchanged to
            ``similarity_search_with_score_by_vector``.

    Returns:
        List[Document]: The matching documents, in the order returned by
            ``similarity_search_with_score_by_vector``.
    """
    scored_documents = self.similarity_search_with_score_by_vector(
        embedding=embedding, k=k, **kwargs
    )
    documents = []
    # Keep only the document from each (document, score) pair.
    for document, _score in scored_documents:
        documents.append(document)
    return documents
def max_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    **kwargs: Any,
) -> List[Document]:
    """Select documents for a query vector using maximal marginal relevance.

    Maximal marginal relevance balances similarity to the query against
    diversity among the documents that are selected.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        kwargs: ``filter_expression``, if present, is passed to the client's
            ``search_and_fetch_vectors`` call.

    Returns:
        List of Documents selected by maximal marginal relevance; empty if the
        service returned an error response.

    Raises:
        Exception: If the response is neither a success nor an error response.
    """
    from momento.requests.vector_index import ALL_METADATA
    from momento.responses.vector_index import SearchAndFetchVectors

    response = self._client.search_and_fetch_vectors(
        self.index_name,
        embedding,
        top_k=fetch_k,
        metadata_fields=ALL_METADATA,
        filter_expression=kwargs.get("filter_expression", None),
    )

    # Anything other than success ends the call here: errors are logged and
    # yield no documents, unknown response types are logged and raised.
    if not isinstance(response, SearchAndFetchVectors.Success):
        if isinstance(response, SearchAndFetchVectors.Error):
            logger.error(f"Error searching and fetching vectors: {response}")
            return []
        logger.error(f"Unexpected response: {response}")
        raise Exception(f"Unexpected response: {response}")

    hits = response.hits
    chosen_indices = maximal_marginal_relevance(
        query_embedding=np.array([embedding], dtype=np.float32),
        embedding_list=[hit.vector for hit in hits],
        lambda_mult=lambda_mult,
        k=k,
    )

    documents = []
    for index in chosen_indices:
        metadata = hits[index].metadata
        # The stored text leaves the metadata and becomes the page content.
        content = metadata.pop(self.text_field, "")
        documents.append(Document(page_content=content, metadata=metadata))  # type: ignore
    return documents
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    **kwargs: Any,
) -> List[Document]:
    """Select documents for a query string using maximal marginal relevance.

    The query is embedded with the store's embedding function and the search is
    delegated to ``max_marginal_relevance_search_by_vector``.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        kwargs: Forwarded unchanged to the by-vector search.

    Returns:
        List of Documents selected by maximal marginal relevance.
    """
    query_vector = self._embedding.embed_query(query)
    documents = self.max_marginal_relevance_search_by_vector(
        query_vector, k, fetch_k, lambda_mult, **kwargs
    )
    return documents
@classmethod
def from_texts(
    cls: Type[VST],
    texts: List[str],
    embedding: Embeddings,
    metadatas: Optional[List[dict]] = None,
    **kwargs: Any,
) -> VST:
    """Return the Vector Store initialized from texts and embeddings.

    Args:
        cls (Type[VST]): The Vector Store class to use to initialize
            the Vector Store.
        texts (List[str]): The texts to initialize the Vector Store with.
        embedding (Embeddings): The embedding function to use.
        metadatas (Optional[List[dict]], optional): The metadata associated with
            the texts. Defaults to None.
        kwargs (Any): Vector Store specific parameters. The following are
            optional and forwarded to the Vector Store constructor:
            - index_name (str, optional): The name of the index to store the documents
              in. Defaults to "default".
            - text_field (str, optional): The name of the metadata field to store the
              original text in. Defaults to "text".
            - distance_strategy (DistanceStrategy, optional): The distance strategy to
              use. Defaults to DistanceStrategy.COSINE. If you select
              DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses the squared
              Euclidean distance.
            - ensure_index_exists (bool, optional): Whether to ensure that the index
              exists before adding documents to it. Defaults to True.
            Additionally you can either pass in a client or an API key
            - client (PreviewVectorIndexClient): The Momento Vector Index client to use.
              When omitted or None, a client is built from the API key.
            - api_key (Optional[str]): The API key used to build a client when
              no client is supplied. Defaults to None, in which case the key
              is read from the environment variable `MOMENTO_API_KEY`.
            The remaining kwargs are also passed to ``add_texts`` (e.g. ``ids``).

    Returns:
        VST: Momento Vector Index vector store initialized from texts and
            embeddings.
    """
    from momento import (
        CredentialProvider,
        PreviewVectorIndexClient,
        VectorIndexConfigurations,
    )

    # Pop both connection settings up front so neither the client nor the API
    # key is forwarded to the constructor or to ``add_texts``.
    client = kwargs.pop("client", None)
    supplied_api_key = kwargs.pop("api_key", None)

    # An explicit ``client=None`` is treated the same as no client at all.
    if client is None:
        api_key = supplied_api_key or get_from_env("api_key", "MOMENTO_API_KEY")
        client = PreviewVectorIndexClient(
            configuration=VectorIndexConfigurations.Default.latest(),
            credential_provider=CredentialProvider.from_string(api_key),
        )

    vector_db = cls(embedding=embedding, client=client, **kwargs)  # type: ignore
    vector_db.add_texts(texts=texts, metadatas=metadatas, **kwargs)
    return vector_db