
[TST] Property Test Generation Fixes #2383

Merged 15 commits on Jun 26, 2024
chromadb/test/property/invariants.py (6 changes: 0 additions & 6 deletions)
@@ -163,12 +163,6 @@ def _exact_distances(
     return np.argsort(distances).tolist(), distances.tolist()


-def is_metadata_valid(normalized_record_set: NormalizedRecordSet) -> bool:
-    if normalized_record_set["metadatas"] is None:
-        return True
-    return not any([len(m) == 0 for m in normalized_record_set["metadatas"]])
-
-
 def ann_accuracy(
     collection: Collection,
     record_set: RecordSet,
chromadb/test/property/strategies.py (43 changes: 33 additions & 10 deletions)
@@ -10,9 +10,7 @@
 import re
 from hypothesis.strategies._internal.strategies import SearchStrategy
 from chromadb.test.conftest import NOT_CLUSTER_ONLY
-
 from dataclasses import dataclass
-
 from chromadb.api.types import (
     Documents,
     Embeddable,
@@ -60,7 +58,7 @@ class RecordSet(TypedDict):

     ids: Union[types.ID, List[types.ID]]
     embeddings: Optional[Union[types.Embeddings, types.Embedding]]
-    metadatas: Optional[Union[List[types.Metadata], types.Metadata]]
+    metadatas: Optional[Union[List[Optional[types.Metadata]], types.Metadata]]
     documents: Optional[Union[List[types.Document], types.Document]]


@@ -71,7 +69,7 @@ class NormalizedRecordSet(TypedDict):

     ids: List[types.ID]
     embeddings: Optional[types.Embeddings]
-    metadatas: Optional[List[types.Metadata]]
+    metadatas: Optional[List[Optional[types.Metadata]]]
     documents: Optional[List[types.Document]]

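With metadatas widened to allow None entries, a record set can now mix per-record metadata with explicit None values. A small illustrative snippet (the values are made up, not from this PR):

```python
from typing import List, Optional

# A record set shape the widened RecordSet type now admits: some records
# carry metadata, others explicitly carry None.
metadatas: Optional[List[Optional[dict]]] = [{"a": 1}, None, {"a": 3}]

record_set = {
    "ids": ["id1", "id2", "id3"],
    "embeddings": None,
    "metadatas": metadatas,  # mixed Metadata / None entries are valid here
    "documents": None,
}
```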

@@ -347,10 +345,16 @@ def collections(


 @st.composite
-def metadata(draw: st.DrawFn, collection: Collection) -> types.Metadata:
+def metadata(
+    draw: st.DrawFn, collection: Collection, min_size=0, max_size=None
+) -> Optional[types.Metadata]:
     """Strategy for generating metadata that could be a part of the given collection"""
     # First draw a random dictionary.
-    metadata: types.Metadata = draw(st.dictionaries(safe_text, st.one_of(*safe_values)))
+    metadata: types.Metadata = draw(
+        st.dictionaries(
+            safe_text, st.one_of(*safe_values), min_size=min_size, max_size=max_size
+        )
+    )
     # Then, remove keys that overlap with the known keys for the coll
     # to avoid type errors when comparing.
     if collection.known_metadata_keys:
@@ -362,6 +366,9 @@ def metadata(draw: st.DrawFn, collection: Collection) -> types.Metadata:
             k: st.just(v) for k, v in collection.known_metadata_keys.items()
         }
         metadata.update(draw(st.fixed_dictionaries({}, optional=sampling_dict)))  # type: ignore
+    # We don't allow submitting empty metadata
+    if metadata == {}:
+        return None
     return metadata
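Because the strategy now returns None instead of an empty dict, generated record sets can no longer contain empty metadata, which is what made the `is_metadata_valid` invariant above removable. A rough standalone sketch of the same idea (simplified stand-in strategy, assumed names):

```python
from typing import Optional

import hypothesis.strategies as st
from hypothesis import given


# Simplified stand-in for the real strategy: an empty draw becomes None,
# so an empty metadata mapping can never reach collection.add().
@st.composite
def simple_metadata(
    draw: st.DrawFn, min_size: int = 0, max_size: Optional[int] = None
) -> Optional[dict]:
    d = draw(
        st.dictionaries(st.text(), st.integers(), min_size=min_size, max_size=max_size)
    )
    return d if d else None


@given(simple_metadata())
def test_metadata_is_never_empty(m: Optional[dict]) -> None:
    assert m is None or len(m) > 0
```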


@@ -394,6 +401,12 @@ def recordsets(
     id_strategy: SearchStrategy[str] = safe_text,
     min_size: int = 1,
     max_size: int = 50,
+    # If num_unique_metadata is None, the number of metadata generations
+    # will be the size of the record set. If set, the number of metadata
+    # generations will be the value of num_unique_metadata.
+    num_unique_metadata: Optional[int] = None,
+    min_metadata_size: int = 0,
+    max_metadata_size: Optional[int] = None,
 ) -> RecordSet:
     collection = draw(collection_strategy)
@@ -404,9 +417,20 @@
     embeddings: Optional[Embeddings] = None
     if collection.has_embeddings:
         embeddings = create_embeddings(collection.dimension, len(ids), collection.dtype)
-    metadatas = draw(
-        st.lists(metadata(collection), min_size=len(ids), max_size=len(ids))
-    )
+    num_metadata = num_unique_metadata if num_unique_metadata is not None else len(ids)
+    generated_metadatas = draw(
+        st.lists(
+            metadata(
+                collection, min_size=min_metadata_size, max_size=max_metadata_size
+            ),
+            min_size=num_metadata,
+            max_size=num_metadata,
+        )
+    )
+    metadatas = []
+    for i in range(len(ids)):
+        metadatas.append(generated_metadatas[i % len(generated_metadatas)])

Review comment (Contributor): cool

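To make the reuse concrete: with num_unique_metadata=5, the drawn pool of 5 metadata values is cycled across the whole record set via the modulo indexing above. A small sketch with assumed values:

```python
# Sketch: 5 unique metadata dicts reused across 12 records via modulo cycling.
generated_metadatas = [{"k": i} for i in range(5)]  # stand-in for drawn metadata
ids = [f"id{i}" for i in range(12)]

metadatas = [generated_metadatas[i % len(generated_metadatas)] for i in range(len(ids))]

assert len(metadatas) == len(ids)
assert len({tuple(m.items()) for m in metadatas}) == 5  # only 5 distinct values
```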
     documents: Optional[Documents] = None
     if collection.has_documents:
         documents = draw(

@@ -423,7 +447,7 @@
         if embeddings is not None and draw(st.booleans())
         else embeddings
     )
-    single_metadata: Union[Metadata, List[Metadata]] = (
+    single_metadata: Union[Optional[Metadata], List[Optional[Metadata]]] = (
         metadatas[0] if draw(st.booleans()) else metadatas
     )
     single_document = (
@@ -435,7 +459,6 @@
             "metadatas": single_metadata,
             "documents": single_document,
         }
-
     return {
         "ids": ids,
         "embeddings": embeddings,
chromadb/test/property/test_add.py (133 changes: 100 additions & 33 deletions)
@@ -22,9 +22,12 @@
 collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll")


+# Hypothesis tends to generate smaller values, so we explicitly segregate the
+# tests into tiers: Small, Medium. Hypothesis struggles to generate large
+# record sets, so we explicitly create a large record set without using Hypothesis.
 @given(
     collection=collection_st,
-    record_set=strategies.recordsets(collection_st),
+    record_set=strategies.recordsets(collection_st, min_size=1, max_size=500),
     should_compact=st.booleans(),
 )
@settings(
@@ -34,47 +37,110 @@
         fast=hypothesis.settings(max_examples=200),
     ),
 )
-def test_add(
+def test_add_small(
     api: ServerAPI,
     collection: strategies.Collection,
     record_set: strategies.RecordSet,
     should_compact: bool,
 ) -> None:
+    _test_add(api, collection, record_set, should_compact)
+
+
+@given(
+    collection=collection_st,
+    record_set=strategies.recordsets(
+        collection_st,
+        min_size=250,
+        max_size=500,
+        num_unique_metadata=5,
+        min_metadata_size=1,
+        max_metadata_size=5,
+    ),
+    should_compact=st.booleans(),
+)
+@settings(
+    deadline=None,
+    parent=override_hypothesis_profile(
+        normal=hypothesis.settings(max_examples=10),
+        fast=hypothesis.settings(max_examples=5),
+    ),
+    suppress_health_check=[
+        hypothesis.HealthCheck.too_slow,
+        hypothesis.HealthCheck.data_too_large,
+        hypothesis.HealthCheck.large_base_example,
+        hypothesis.HealthCheck.function_scoped_fixture,
+    ],
+)
+def test_add_medium(
+    api: ServerAPI,
+    collection: strategies.Collection,
+    record_set: strategies.RecordSet,
+    should_compact: bool,
+) -> None:
+    # Cluster tests transmit their results over grpc, which has a payload limit.
+    # This breaks the ann_accuracy invariant by default, since the vector reader
+    # returns a payload of dataset size. So we need to batch the queries in the
+    # ann_accuracy invariant.
+    _test_add(api, collection, record_set, should_compact, batch_ann_accuracy=True)


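On the batching rationale above: a minimal sketch (sizes are illustrative assumptions) of how query indices get split so each response stays under the grpc payload limit:

```python
# Split N query indices into fixed-size batches, mirroring the batched
# ann_accuracy path in _test_add below.
n_records = 37
batch_size = 10

batches = [
    list(range(i, min(i + batch_size, n_records)))
    for i in range(0, n_records, batch_size)
]
# -> [0..9], [10..19], [20..29], [30..36]; each batch is queried separately.
assert sum(len(b) for b in batches) == n_records
```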
+def _test_add(
+    api: ServerAPI,
+    collection: strategies.Collection,
+    record_set: strategies.RecordSet,
+    should_compact: bool,
+    batch_ann_accuracy: bool = False,
+) -> None:
     reset(api)

     # TODO: Generative embedding functions
     coll = api.create_collection(
         name=collection.name,
         metadata=collection.metadata,  # type: ignore
         embedding_function=collection.embedding_function,
     )

     normalized_record_set = invariants.wrap_all(record_set)

-    if not invariants.is_metadata_valid(normalized_record_set):
-        with pytest.raises(Exception):
-            coll.add(**normalized_record_set)
-        return
-
-    coll.add(**record_set)
-
-    if not NOT_CLUSTER_ONLY:
-        # Only wait for compaction if the size of the collection is
-        # some minimal size
-        if should_compact and len(normalized_record_set["ids"]) > 10:
-            initial_version = coll.get_model()["version"]
-            # Wait for the model to be updated
-            wait_for_version_increase(api, collection.name, initial_version)
+    # TODO: The type of add() is incorrect as it does not allow for metadatas
+    # like [{"a": 1}, None, {"a": 3}]
+    coll.add(**record_set)  # type: ignore
+    if (
+        not NOT_CLUSTER_ONLY
+        and should_compact
+        and len(normalized_record_set["ids"]) > 10
+    ):
+        # Only wait for compaction if the size of the collection is
+        # some minimal size
+        initial_version = coll.get_model()["version"]
+        # Wait for the model to be updated
+        wait_for_version_increase(api, collection.name, initial_version)

     invariants.count(coll, cast(strategies.RecordSet, normalized_record_set))
     n_results = max(1, (len(normalized_record_set["ids"]) // 10))
-    invariants.ann_accuracy(
-        coll,
-        cast(strategies.RecordSet, normalized_record_set),
-        n_results=n_results,
-        embedding_function=collection.embedding_function,
-    )
+
+    if batch_ann_accuracy:
+        batch_size = 10
+        for i in range(0, len(normalized_record_set["ids"]), batch_size):
+            invariants.ann_accuracy(
+                coll,
+                cast(strategies.RecordSet, normalized_record_set),
+                n_results=n_results,
+                embedding_function=collection.embedding_function,
+                query_indices=list(
+                    range(i, min(i + batch_size, len(normalized_record_set["ids"])))
+                ),
+            )
+    else:
+        invariants.ann_accuracy(
+            coll,
+            cast(strategies.RecordSet, normalized_record_set),
+            n_results=n_results,
+            embedding_function=collection.embedding_function,
+        )


# Hypothesis struggles to generate large record sets so we explicitly create
# a large record set
def create_large_recordset(
Review comment from @atroyn (Contributor), Jun 20, 2024:
> Feel like we could add some more randomization in here. For example, all embeddings are the same - this is guaranteed to produce a bad HNSW graph. However, this is unrelated to the focus of this PR.

Reply from the Collaborator (Author):
> Agreed, I tried to replace this with Hypothesis but still need to do some munging; I cut myself a task.
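In the spirit of that suggestion, a rough sketch (shape parameters assumed; not what this PR implements) of randomizing the large record set's embeddings so the HNSW graph is not built from identical points:

```python
import numpy as np

# Distinct random embeddings per record instead of one repeated vector.
# n and dim are illustrative; the real test would derive them from the collection.
rng = np.random.default_rng(42)
n, dim = 50_000, 32
embeddings = rng.standard_normal((n, dim)).astype(np.float32).tolist()
```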

     min_size: int = 45000,
     max_size: int = 50000,

@@ -94,9 +160,11 @@ def create_large_recordset(
     return cast(strategies.RecordSet, record_set)


-@given(collection=collection_st)
-@settings(deadline=None, max_examples=1)
-def test_add_large(api: ServerAPI, collection: strategies.Collection) -> None:
+@given(collection=collection_st, should_compact=st.booleans())
+@settings(deadline=None, max_examples=5)
+def test_add_large(
+    api: ServerAPI, collection: strategies.Collection, should_compact: bool
+) -> None:
     reset(api)

     record_set = create_large_recordset(

@@ -111,10 +179,6 @@ def test_add_large(api: ServerAPI, collection: strategies.Collection) -> None:
     )
     normalized_record_set = invariants.wrap_all(record_set)

-    if not invariants.is_metadata_valid(normalized_record_set):
-        with pytest.raises(Exception):
-            coll.add(**normalized_record_set)
-        return
     for batch in create_batches(
         api=api,
         ids=cast(List[str], record_set["ids"]),

@@ -123,6 +187,14 @@ def test_add_large(api: ServerAPI, collection: strategies.Collection) -> None:
         documents=cast(List[str], record_set["documents"]),
     ):
         coll.add(*batch)
+
+    if not NOT_CLUSTER_ONLY and should_compact:
+        initial_version = coll.get_model()["version"]
+        # Wait for the model to be updated; since the record set is larger, add some additional time
+        wait_for_version_increase(
+            api, collection.name, initial_version, additional_time=240
+        )
+
     invariants.count(coll, cast(strategies.RecordSet, normalized_record_set))


@@ -141,12 +213,7 @@ def test_add_large_exceeding(api: ServerAPI, collection: strategies.Collection)
         metadata=collection.metadata,  # type: ignore
         embedding_function=collection.embedding_function,
     )
-    normalized_record_set = invariants.wrap_all(record_set)
-
-    if not invariants.is_metadata_valid(normalized_record_set):
-        with pytest.raises(Exception):
-            coll.add(**normalized_record_set)
-        return
+
     with pytest.raises(Exception) as e:
         coll.add(**record_set)
     assert "exceeds maximum batch size" in str(e.value)