-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
types.py
525 lines (449 loc) · 18.4 KB
/
types.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
from typing import Optional, Union, TypeVar, List, Dict, Any, Tuple, cast
from numpy.typing import NDArray
import numpy as np
from typing_extensions import Literal, TypedDict, Protocol
import chromadb.errors as errors
from chromadb.types import (
Metadata,
UpdateMetadata,
Vector,
LiteralValue,
LogicalOperator,
WhereOperator,
OperatorExpression,
Where,
WhereDocumentOperator,
WhereDocument,
)
from inspect import signature
from tenacity import retry
# Re-export types from chromadb.types
__all__ = ["Metadata", "Where", "WhereDocument", "UpdateCollectionMetadata"]
# Reserved metadata key: validate_metadata() rejects any metadata dict containing it.
META_KEY_CHROMA_DOCUMENT = "chroma:document"
# Generic helper: "either a single T or a list of T".
T = TypeVar("T")
OneOrMany = Union[T, List[T]]
# URIs
# A URI is represented as a plain string; URIs is a list of them.
URI = str
URIs = List[URI]
def maybe_cast_one_to_many_uri(target: OneOrMany[URI]) -> URIs:
    """Normalize a single URI or a list of URIs into a list of URIs.

    A bare string is wrapped in a one-element list. Any other value is
    returned unchanged; ``cast`` only re-types it and performs no runtime
    check or copy.
    """
    uris = target if not isinstance(target, str) else [target]
    return cast(URIs, uris)
# IDs
# A record ID is represented as a plain string; IDs is a list of them.
ID = str
IDs = List[ID]
def maybe_cast_one_to_many_ids(target: OneOrMany[ID]) -> IDs:
    """Normalize a single ID or a list of IDs into a list of IDs.

    A bare string is wrapped in a one-element list. Any other value is
    returned unchanged; ``cast`` only re-types it and performs no runtime
    check or copy.
    """
    ids = target if not isinstance(target, str) else [target]
    return cast(IDs, ids)
# Embeddings
# One embedding is a chromadb.types.Vector; Embeddings is a list of them.
Embedding = Vector
Embeddings = List[Embedding]
def maybe_cast_one_to_many_embedding(target: OneOrMany[Embedding]) -> Embeddings:
    """Normalize a single embedding or a list of embeddings into a list of embeddings.

    A non-empty list whose first element is an ``int`` or ``float`` is treated
    as one embedding and wrapped in an outer list. Everything else, including
    an empty list, is returned unchanged (``cast`` only re-types it), so that
    later validation can report the problem instead of this helper raising
    ``IndexError``.
    """
    # Guard against an empty list before peeking at the first element.
    is_single_embedding = (
        isinstance(target, list)
        and len(target) > 0
        and isinstance(target[0], (int, float))
    )
    if is_single_embedding:
        # One Embedding
        return cast(Embeddings, [target])
    # Already a sequence
    return cast(Embeddings, target)
# Metadatas
# A list of per-record Metadata mappings (Metadata comes from chromadb.types).
Metadatas = List[Metadata]
def maybe_cast_one_to_many_metadata(target: OneOrMany[Metadata]) -> Metadatas:
    """Normalize a single metadata dict or a list of them into a list.

    A ``dict`` is wrapped in a one-element list. Any other value is returned
    unchanged; ``cast`` only re-types it and performs no runtime check.
    """
    metadatas = [target] if isinstance(target, dict) else target
    return cast(Metadatas, metadatas)
# Collection-level metadata is an arbitrary string-keyed mapping.
CollectionMetadata = Dict[str, Any]
# Alias of chromadb.types.UpdateMetadata, re-exported under a collection-specific name.
UpdateCollectionMetadata = UpdateMetadata
# Documents
# A document is represented as a plain string; Documents is a list of them.
Document = str
Documents = List[Document]
def is_document(target: Any) -> bool:
    """Return True when ``target`` is a ``str`` (a single document), else False."""
    return isinstance(target, str)
def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents:
    """Normalize a single document or a list of documents into a list.

    A value accepted by ``is_document`` is wrapped in a one-element list.
    Any other value is returned unchanged; ``cast`` only re-types it.
    """
    if not is_document(target):
        # Already a sequence
        return cast(Documents, target)
    # One Document
    return cast(Documents, [target])
# Images
# Allowed element dtypes for image arrays. np.float64 is used instead of the
# np.float_ alias because that alias was removed in NumPy 2.0 (it referred to
# the same type), so this line works on both NumPy 1.x and 2.x.
ImageDType = Union[np.uint, np.int_, np.float64]
# A single image is a NumPy array of one of the dtypes above; Images is a list of them.
Image = NDArray[ImageDType]
Images = List[Image]
def is_image(target: Any) -> bool:
    """Return True when ``target`` is a NumPy array with at least two dimensions."""
    return isinstance(target, np.ndarray) and len(target.shape) >= 2
def maybe_cast_one_to_many_image(target: OneOrMany[Image]) -> Images:
    """Normalize a single image or a list of images into a list of images.

    A value accepted by ``is_image`` is wrapped in a one-element list.
    Any other value is returned unchanged; ``cast`` only re-types it.
    """
    if not is_image(target):
        # Already a sequence
        return cast(Images, target)
    # One Image
    return cast(Images, [target])
# Constrained TypeVar over the per-record value types declared in this module.
Parameter = TypeVar("Parameter", Document, Image, Embedding, Metadata, ID)
# This should just be List[Literal["documents", "embeddings", "metadatas", "distances"]]
# However, this provokes an incompatibility with the Overrides library and Python 3.7
Include = List[
    Union[
        Literal["documents"],
        Literal["embeddings"],
        Literal["metadatas"],
        Literal["distances"],
        Literal["uris"],
        Literal["data"],
    ]
]
# Re-export types from chromadb.types
# (self-assignments bind the imported names as attributes of this module)
LiteralValue = LiteralValue
LogicalOperator = LogicalOperator
WhereOperator = WhereOperator
OperatorExpression = OperatorExpression
Where = Where
WhereDocumentOperator = WhereDocumentOperator
# Inputs an embedding function can receive: a list of documents or a list of images.
Embeddable = Union[Documents, Images]
# Contravariant TypeVar used as the input type parameter of EmbeddingFunction.
D = TypeVar("D", bound=Embeddable, contravariant=True)
# Data-loader output: a list whose entries are an Image or None.
Loadable = List[Optional[Image]]
# Covariant TypeVar used as the return type parameter of DataLoader.
L = TypeVar("L", covariant=True, bound=Loadable)
class GetResult(TypedDict):
    """Typed dictionary for a flat result set: one entry per record in each list.

    NOTE(review): presumably the return shape of a collection ``get`` call, with
    optional fields set to None when not requested — confirm against callers.
    """

    # Record IDs; the only non-optional payload field.
    ids: List[ID]
    embeddings: Optional[List[Embedding]]
    documents: Optional[List[Document]]
    uris: Optional[URIs]
    data: Optional[Loadable]
    metadatas: Optional[List[Metadata]]
    # List of Include literals ("documents", "embeddings", ...) carried with the result.
    included: Include
class QueryResult(TypedDict):
    """Typed dictionary for a nested result set: each field is a list of per-group lists.

    Compared with ``GetResult`` every payload field has one extra level of
    nesting, and a ``distances`` field is added.
    NOTE(review): presumably one inner list per query input — confirm against callers.
    """

    ids: List[IDs]
    embeddings: Optional[List[List[Embedding]]]
    documents: Optional[List[List[Document]]]
    uris: Optional[List[List[URI]]]
    data: Optional[List[Loadable]]
    metadatas: Optional[List[List[Metadata]]]
    distances: Optional[List[List[float]]]
    # List of Include literals ("documents", "embeddings", ...) carried with the result.
    included: Include
class IndexMetadata(TypedDict):
    """Typed dictionary of bookkeeping values describing an index."""

    dimensionality: int
    # The current number of elements in the index (total = additions - deletes)
    curr_elements: int
    # The auto-incrementing ID of the last inserted element, never decreases so
    # can be used as a count of total historical size. Should increase by 1 every add.
    # Assume cannot overflow
    total_elements_added: int
    # NOTE(review): presumably a creation timestamp; unit/epoch not visible here — confirm.
    time_created: float
class EmbeddingFunction(Protocol[D]):
    """Protocol for callables that map an embeddable input of type ``D`` to ``Embeddings``.

    Whenever a subclass is created, ``__init_subclass__`` replaces the subclass's
    ``__call__`` with a wrapper that passes the result through
    ``maybe_cast_one_to_many_embedding`` and then ``validate_embeddings``.
    """

    def __call__(self, input: D) -> Embeddings:
        """Embed ``input``; the protocol itself provides no implementation."""
        ...

    def __init_subclass__(cls) -> None:
        """Wrap ``cls.__call__`` so its output is normalized and validated."""
        super().__init_subclass__()
        # Look up the __call__ visible on the subclass (its own or an inherited one).
        # getattr without a default raises AttributeError if the attribute is missing.
        # NOTE: because this hook runs for every descendant, an inherited __call__
        # that was already wrapped for a parent class is wrapped again here.
        call = getattr(cls, "__call__")

        def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings:
            # Run the subclass's own implementation, then normalize a single
            # embedding into a list of embeddings and validate the result.
            result = call(self, input)
            return validate_embeddings(maybe_cast_one_to_many_embedding(result))

        setattr(cls, "__call__", __call__)

    def embed_with_retries(self, input: D, **retry_kwargs: Dict) -> Embeddings:
        """Call this embedding function on ``input`` wrapped in tenacity's ``retry``.

        All ``retry_kwargs`` are forwarded unchanged to ``retry``.
        """
        return retry(**retry_kwargs)(self.__call__)(input)
def validate_embedding_function(
    embedding_function: EmbeddingFunction[Embeddable],
) -> None:
    """Check that an embedding function's ``__call__`` matches the protocol's parameter names.

    The parameter names of ``embedding_function.__class__.__call__`` are compared
    with those of ``EmbeddingFunction.__call__``.

    Raises:
        ValueError: if the two sets of parameter names differ.
    """
    actual_call = embedding_function.__class__.__call__
    actual_params = signature(actual_call).parameters.keys()
    expected_params = signature(EmbeddingFunction.__call__).parameters.keys()

    if actual_params == expected_params:
        return

    raise ValueError(
        f"Expected EmbeddingFunction.__call__ to have the following signature: {expected_params}, got {actual_params}\n"
        "Please see https://docs.trychroma.com/embeddings for details of the EmbeddingFunction interface.\n"
        "Please note the recent change to the EmbeddingFunction interface: https://docs.trychroma.com/migration#migration-to-0416---november-7-2023 \n"
    )
class DataLoader(Protocol[L]):
    """Protocol for callables that take a list of URIs and return a value of type ``L``."""

    def __call__(self, uris: URIs) -> L:
        """Load whatever ``uris`` refer to; the protocol itself provides no implementation."""
        ...
def validate_ids(ids: IDs) -> IDs:
    """Check that ``ids`` is a non-empty list of unique strings and return it.

    Raises:
        ValueError: if ``ids`` is not a list, is empty, or contains a non-str item.
        errors.DuplicateIDError: if any ID occurs more than once. The message
            lists the duplicates, abbreviated when there are ten or more.
    """
    if not isinstance(ids, list):
        raise ValueError(f"Expected IDs to be a list, got {type(ids).__name__} as IDs")
    if len(ids) == 0:
        raise ValueError(f"Expected IDs to be a non-empty list, got {len(ids)} IDs")

    unique_ids = set()
    repeated_ids = set()
    for candidate in ids:
        if not isinstance(candidate, str):
            raise ValueError(f"Expected ID to be a str, got {candidate}")
        if candidate in unique_ids:
            repeated_ids.add(candidate)
        else:
            unique_ids.add(candidate)

    if not repeated_ids:
        return ids

    n_dups = len(repeated_ids)
    if n_dups < 10:
        example_string = ", ".join(repeated_ids)
        message = f"Expected IDs to be unique, found duplicates of: {example_string}"
    else:
        # Sample at most eleven duplicates, then show the first and last five of the sample.
        sample = list(repeated_ids)[:11]
        head = ", ".join(sample[:5])
        tail = ", ".join(sample[-5:])
        example_string = f"{head}, ..., {tail}"
        message = f"Expected IDs to be unique, found {n_dups} duplicated IDs: {example_string}"
    raise errors.DuplicateIDError(message)
def validate_metadata(metadata: Metadata) -> Metadata:
    """Check a single metadata mapping and return it unchanged.

    ``None`` is accepted and returned as-is. Otherwise the value must be a
    non-empty dict with str keys (excluding the reserved document key) and
    values that are str, int, float or bool.

    Raises:
        ValueError: wrong container type, empty dict, reserved key, or bad value type.
        TypeError: a key is not a str.
    """
    if metadata is None:
        return metadata
    if not isinstance(metadata, dict):
        raise ValueError(
            f"Expected metadata to be a dict or None, got {type(metadata).__name__} as metadata"
        )
    if len(metadata) == 0:
        raise ValueError(
            f"Expected metadata to be a non-empty dict, got {len(metadata)} metadata attributes"
        )

    for name, entry in metadata.items():
        if name == META_KEY_CHROMA_DOCUMENT:
            raise ValueError(
                f"Expected metadata to not contain the reserved key {META_KEY_CHROMA_DOCUMENT}"
            )
        if not isinstance(name, str):
            raise TypeError(
                f"Expected metadata key to be a str, got {name} which is a {type(name).__name__}"
            )
        # bool is listed explicitly even though isinstance(True, int) is already True
        if not isinstance(entry, (bool, str, int, float)):
            raise ValueError(
                f"Expected metadata value to be a str, int, float or bool, got {entry} which is a {type(entry).__name__}"
            )
    return metadata
def validate_update_metadata(metadata: UpdateMetadata) -> UpdateMetadata:
    """Check a metadata mapping used for updates and return it unchanged.

    ``None`` is accepted and returned as-is. Otherwise the value must be a
    non-empty dict with str keys and values that are str, int, float, bool
    or None.

    Raises:
        ValueError: wrong container type, empty dict, non-str key, or bad value type.
    """
    if metadata is None:
        return metadata
    if not isinstance(metadata, dict):
        raise ValueError(
            f"Expected metadata to be a dict or None, got {type(metadata)}"
        )
    if len(metadata) == 0:
        raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")

    for name, entry in metadata.items():
        if not isinstance(name, str):
            raise ValueError(f"Expected metadata key to be a str, got {name}")
        # bool is listed explicitly even though isinstance(True, int) is already True
        if not isinstance(entry, (bool, str, int, float, type(None))):
            raise ValueError(
                f"Expected metadata value to be a str, int, or float, got {entry}"
            )
    return metadata
def validate_metadatas(metadatas: Metadatas) -> Metadatas:
    """Check that ``metadatas`` is a list and validate each entry with ``validate_metadata``.

    Returns the same list. Raises ValueError if ``metadatas`` is not a list;
    any error from ``validate_metadata`` propagates.
    """
    if not isinstance(metadatas, list):
        raise ValueError(f"Expected metadatas to be a list, got {metadatas}")
    for entry in metadatas:
        validate_metadata(entry)
    return metadatas
def validate_where(where: Where) -> Where:
    """Check a ``where`` filter expression and return it unchanged.

    The filter must be a dict with exactly one str key. For ``$and`` / ``$or``
    the value must be a list of at least two filters, each validated
    recursively. Otherwise the value is a str, int, float, or a one-entry
    operator expression such as ``{"$gt": 3}``.

    Raises:
        ValueError: on any structural or type violation.
    """
    logical_keys = ("$and", "$or")
    unchecked_keys = ("$and", "$or", "$in", "$nin")
    numeric_ops = ("$gt", "$gte", "$lt", "$lte")
    list_ops = ("$in", "$nin")
    known_ops = ("$gt", "$gte", "$lt", "$lte", "$ne", "$eq", "$in", "$nin")

    if not isinstance(where, dict):
        raise ValueError(f"Expected where to be a dict, got {where}")
    if len(where) != 1:
        raise ValueError(f"Expected where to have exactly one operator, got {where}")

    for field, condition in where.items():
        if not isinstance(field, str):
            raise ValueError(f"Expected where key to be a str, got {field}")

        if field not in unchecked_keys and not isinstance(
            condition, (str, int, float, dict)
        ):
            raise ValueError(
                f"Expected where value to be a str, int, float, or operator expression, got {condition}"
            )

        if field in logical_keys:
            if not isinstance(condition, list):
                raise ValueError(
                    f"Expected where value for $and or $or to be a list of where expressions, got {condition}"
                )
            if len(condition) <= 1:
                raise ValueError(
                    f"Expected where value for $and or $or to be a list with at least two where expressions, got {condition}"
                )
            for sub_where in condition:
                validate_where(sub_where)

        # The remaining checks apply only to operator expressions
        if not isinstance(condition, dict):
            continue

        # Ensure there is only one operator
        if len(condition) != 1:
            raise ValueError(
                f"Expected operator expression to have exactly one operator, got {condition}"
            )

        for op, arg in condition.items():
            # Only numbers can be compared with gt, gte, lt, lte
            if op in numeric_ops and not isinstance(arg, (int, float)):
                raise ValueError(
                    f"Expected operand value to be an int or a float for operator {op}, got {arg}"
                )
            if op in list_ops and not isinstance(arg, list):
                raise ValueError(
                    f"Expected operand value to be an list for operator {op}, got {arg}"
                )
            if op not in known_ops:
                raise ValueError(
                    f"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin, "
                    f"got {op}"
                )
            if not isinstance(arg, (str, int, float, list)):
                raise ValueError(
                    f"Expected where operand value to be a str, int, float, or list of those type, got {arg}"
                )
            if isinstance(arg, list):
                # Lists must be non-empty and every item must match the first item's type
                if len(arg) == 0 or not all(
                    isinstance(item, type(arg[0])) for item in arg
                ):
                    raise ValueError(
                        f"Expected where operand value to be a non-empty list, and all values to be of the same type "
                        f"got {arg}"
                    )
    return where
def validate_where_document(where_document: WhereDocument) -> WhereDocument:
    """Check a ``where_document`` filter expression and return it unchanged.

    The filter must be a dict with exactly one key drawn from ``$contains``,
    ``$not_contains``, ``$and`` and ``$or``. For ``$and`` / ``$or`` the value
    must be a list of at least two filters, each validated recursively; for the
    other operators it must be a non-empty str.

    Raises:
        ValueError: on any structural or type violation.
    """
    if not isinstance(where_document, dict):
        raise ValueError(
            f"Expected where document to be a dictionary, got {where_document}"
        )
    if len(where_document) != 1:
        raise ValueError(
            f"Expected where document to have exactly one operator, got {where_document}"
        )

    for op, arg in where_document.items():
        if op not in ("$contains", "$not_contains", "$and", "$or"):
            raise ValueError(
                f"Expected where document operator to be one of $contains, $and, $or, got {op}"
            )

        if op in ("$and", "$or"):
            if not isinstance(arg, list):
                raise ValueError(
                    f"Expected document value for $and or $or to be a list of where document expressions, got {arg}"
                )
            if len(arg) <= 1:
                raise ValueError(
                    f"Expected document value for $and or $or to be a list with at least two where document expressions, got {arg}"
                )
            for sub_expression in arg:
                validate_where_document(sub_expression)
            continue

        # Value is a $contains operator
        if not isinstance(arg, str):
            raise ValueError(
                f"Expected where document operand value for operator $contains to be a str, got {arg}"
            )
        if len(arg) == 0:
            raise ValueError(
                "Expected where document operand value for operator $contains to be a non-empty str"
            )
    return where_document
def validate_include(include: Include, allow_distances: bool) -> Include:
    """Check that ``include`` is a list of allowed field names and return it.

    Args:
        include: list of field names to validate.
        allow_distances: when True, ``"distances"`` is also accepted (get does
            not allow distances, so callers pass False there).

    Raises:
        ValueError: if ``include`` is not a list, an item is not a str, or an
            item is not one of the allowed values.
    """
    if not isinstance(include, list):
        raise ValueError(f"Expected include to be a list, got {include}")

    # The allowed set does not depend on the item, so build it once up front
    # rather than once per loop iteration.
    allowed_values = ["embeddings", "documents", "metadatas", "uris", "data"]
    if allow_distances:
        allowed_values.append("distances")

    for item in include:
        if not isinstance(item, str):
            raise ValueError(f"Expected include item to be a str, got {item}")
        if item not in allowed_values:
            raise ValueError(
                f"Expected include item to be one of {', '.join(allowed_values)}, got {item}"
            )
    return include
def validate_n_results(n_results: int) -> int:
    """Check that ``n_results`` is a positive integer and return it.

    Args:
        n_results: requested number of results.

    Returns:
        ``n_results`` unchanged.

    Raises:
        ValueError: if ``n_results`` is not an int (bools are rejected too).
        TypeError: if ``n_results`` is zero or negative.
    """
    # bool is a subclass of int, so True/False would otherwise slip through
    # the int check and be treated as 1/0.
    if isinstance(n_results, bool) or not isinstance(n_results, int):
        raise ValueError(
            f"Expected requested number of results to be a int, got {n_results}"
        )
    if n_results <= 0:
        # TypeError is kept (rather than ValueError) so existing callers that
        # catch it keep working.
        raise TypeError(
            f"Number of requested results {n_results}, cannot be negative, or zero."
        )
    return n_results
def validate_embeddings(embeddings: Embeddings) -> Embeddings:
    """Check that ``embeddings`` is a non-empty list of non-empty numeric lists.

    Each value must be an int or float; bools are rejected. Returns the same
    list object.

    Raises:
        ValueError: on any structural or type violation.
    """
    if not isinstance(embeddings, list):
        raise ValueError(
            f"Expected embeddings to be a list, got {type(embeddings).__name__}"
        )
    if len(embeddings) == 0:
        raise ValueError(
            f"Expected embeddings to be a list with at least one item, got {len(embeddings)} embeddings"
        )
    if any(not isinstance(vector, list) for vector in embeddings):
        found_types = list(set(type(vector).__name__ for vector in embeddings))
        raise ValueError(
            "Expected each embedding in the embeddings to be a list, got "
            f"{found_types}"
        )

    for position, vector in enumerate(embeddings):
        if len(vector) == 0:
            raise ValueError(
                f"Expected each embedding in the embeddings to be a non-empty list, got empty embedding at pos {position}"
            )
        # bool is an int subclass, so it has to be excluded explicitly
        has_bad_value = any(
            isinstance(item, bool) or not isinstance(item, (int, float))
            for item in vector
        )
        if has_bad_value:
            value_types = list(set(type(item).__name__ for item in vector))
            raise ValueError(
                "Expected each value in the embedding to be a int or float, got an embedding with "
                f"{value_types} - {vector}"
            )
    return embeddings
def validate_batch(
    batch: Tuple[
        IDs,
        Optional[Embeddings],
        Optional[Metadatas],
        Optional[Documents],
        Optional[URIs],
    ],
    limits: Dict[str, Any],
) -> None:
    """Check that a batch does not exceed the configured maximum batch size.

    The batch size is the length of the first tuple element (the IDs); the
    limit is read from ``limits["max_batch_size"]``.

    Raises:
        ValueError: if the batch holds more IDs than the limit allows.
    """
    batch_size = len(batch[0])
    max_batch_size = limits["max_batch_size"]
    if batch_size > max_batch_size:
        raise ValueError(
            f"Batch size {batch_size} exceeds maximum batch size {max_batch_size}"
        )