Skip to content

Commit

Permalink
API tests pass
Browse files (browse the repository at this point in the history)
  • Loading branch information
atroyn committed Oct 30, 2023
1 parent 9e00022 commit 07480b6
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 45 deletions.
10 changes: 5 additions & 5 deletions chromadb/api/models/Collection.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -172,18 +172,18 @@ def get(
"""

if "datas" in include and self._data_loader is None:
raise ValueError(
"You must set a data loader on the collection if loading from URIs."
)

valid_where = validate_where(where) if where else None
valid_where_document = (
validate_where_document(where_document) if where_document else None
)
valid_ids = validate_ids(maybe_cast_one_to_many_ids(ids)) if ids else None
valid_include = validate_include(include, allow_distances=False)

if "datas" in include and self._data_loader is None:
raise ValueError(
"You must set a data loader on the collection if loading from URIs."
)

if "datas" in include and "uris" not in include:
valid_include.append("uris")

Expand Down
22 changes: 5 additions & 17 deletions chromadb/api/segment.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -149,30 +149,14 @@ def create_collection(
embedding_function: Optional[
EmbeddingFunction[Any]
] = ef.DefaultEmbeddingFunction(),
data_loader: Optional[DataLoader[Any]] = None,
data_loader: Optional[DataLoader[Embeddable]] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
if metadata is not None:
validate_metadata(metadata)

if existing:
if get_or_create:
if metadata and existing[0]["metadata"] != metadata:
self._modify(id=existing[0]["id"], new_metadata=metadata)
existing = self._sysdb.get_collections(id=existing[0]["id"])
return Collection(
client=self,
id=existing[0]["id"],
name=existing[0]["name"],
metadata=existing[0]["metadata"], # type: ignore
embedding_function=embedding_function,
data_loader=data_loader,
)
else:
raise ValueError(f"Collection {name} already exists.")

# TODO: remove backwards compatibility in naming requirements
check_index_name(name)

Expand Down Expand Up @@ -222,13 +206,15 @@ def get_or_create_collection(
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Embeddable]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
return self.create_collection( # type: ignore
name=name,
metadata=metadata,
embedding_function=embedding_function,
data_loader=data_loader,
get_or_create=True,
tenant=tenant,
database=database,
Expand All @@ -245,6 +231,7 @@ def get_collection(
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Embeddable]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
Expand All @@ -259,6 +246,7 @@ def get_collection(
name=existing[0]["name"],
metadata=existing[0]["metadata"], # type: ignore
embedding_function=embedding_function,
data_loader=data_loader,
)
else:
raise ValueError(f"Collection {name} does not exist.")
Expand Down
86 changes: 63 additions & 23 deletions chromadb/test/test_api.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -80,13 +80,17 @@ def test_persist_index_loading(api_fixture, request):
api2 = request.getfixturevalue("local_persist_api_cache_bust")
collection = api2.get_collection("test")

includes = ["embeddings", "documents", "metadatas", "distances"]
nn = collection.query(
query_texts="hello",
n_results=1,
include=["embeddings", "documents", "metadatas", "distances"],
)
for key in nn.keys():
assert len(nn[key]) == 1
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None


@pytest.mark.parametrize("api_fixture", [local_persist_api])
Expand All @@ -102,13 +106,17 @@ def embedding_function(input):
api2 = request.getfixturevalue("local_persist_api_cache_bust")
collection = api2.get_collection("test", embedding_function=embedding_function)

includes = ["embeddings", "documents", "metadatas", "distances"]
nn = collection.query(
query_texts="hello",
n_results=1,
include=["embeddings", "documents", "metadatas", "distances"],
include=includes,
)
for key in nn.keys():
assert len(nn[key]) == 1
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None


@pytest.mark.parametrize("api_fixture", [local_persist_api])
Expand All @@ -128,14 +136,18 @@ def embedding_function(input):
"test", embedding_function=embedding_function
)

includes = ["embeddings", "documents", "metadatas", "distances"]
nn = collection.query(
query_texts="hello",
n_results=1,
include=["embeddings", "documents", "metadatas", "distances"],
include=includes,
)

for key in nn.keys():
assert len(nn[key]) == 1
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None

assert nn["ids"] == [["id1"]]
assert nn["embeddings"] == [[[1, 2, 3]]]
Expand Down Expand Up @@ -243,9 +255,13 @@ def test_get_from_db(api):
api.reset()
collection = api.create_collection("testspace")
collection.add(**batch_records)
records = collection.get(include=["embeddings", "documents", "metadatas"])
includes = ["embeddings", "documents", "metadatas"]
records = collection.get(include=includes)
for key in records.keys():
assert len(records[key]) == 2
if (key in includes) or (key == "ids"):
assert len(records[key]) == 2
else:
assert records[key] is None


def test_reset_db(api):
Expand All @@ -264,32 +280,42 @@ def test_get_nearest_neighbors(api):
collection = api.create_collection("testspace")
collection.add(**batch_records)

includes = ["embeddings", "documents", "metadatas", "distances"]
nn = collection.query(
query_embeddings=[1.1, 2.3, 3.2],
n_results=1,
where={},
include=["embeddings", "documents", "metadatas", "distances"],
include=includes,
)
for key in nn.keys():
assert len(nn[key]) == 1
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None

nn = collection.query(
query_embeddings=[[1.1, 2.3, 3.2]],
n_results=1,
where={},
include=["embeddings", "documents", "metadatas", "distances"],
include=includes,
)
for key in nn.keys():
assert len(nn[key]) == 1
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None

nn = collection.query(
query_embeddings=[[1.1, 2.3, 3.2], [0.1, 2.3, 4.5]],
n_results=1,
where={},
include=["embeddings", "documents", "metadatas", "distances"],
include=includes,
)
for key in nn.keys():
assert len(nn[key]) == 2
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 2
else:
assert nn[key] is None


def test_delete(api):
Expand Down Expand Up @@ -392,14 +418,18 @@ def test_increment_index_on(api):
collection.add(**batch_records)
assert collection.count() == 2

includes = ["embeddings", "documents", "metadatas", "distances"]
# increment index
nn = collection.query(
query_embeddings=[[1.1, 2.3, 3.2]],
n_results=1,
include=["embeddings", "documents", "metadatas", "distances"],
include=includes,
)
for key in nn.keys():
assert len(nn[key]) == 1
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None


def test_add_a_collection(api):
Expand Down Expand Up @@ -448,7 +478,10 @@ def test_peek(api):
# peek
peek = collection.peek()
for key in peek.keys():
assert len(peek[key]) == 2
if key in ["embeddings", "documents", "metadatas"] or key == "ids":
assert len(peek[key]) == 2
else:
assert peek[key] is None


# TEST METADATA AND METADATA FILTERING
Expand Down Expand Up @@ -1121,14 +1154,17 @@ def test_persist_index_loading_params(api, request):
)

assert collection.metadata["hnsw:space"] == "ip"

includes = ["embeddings", "documents", "metadatas", "distances"]
nn = collection.query(
query_texts="hello",
n_results=1,
include=["embeddings", "documents", "metadatas", "distances"],
include=includes,
)
for key in nn.keys():
assert len(nn[key]) == 1
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None


def test_add_large(api):
Expand Down Expand Up @@ -1235,14 +1271,18 @@ def test_get_nearest_neighbors_where_n_results_more_than_element(api):
collection = api.create_collection("testspace")
collection.add(**records)

results1 = collection.query(
includes = ["embeddings", "documents", "metadatas", "distances"]
results = collection.query(
query_embeddings=[[1.1, 2.3, 3.2]],
n_results=5,
where={},
include=["embeddings", "documents", "metadatas", "distances"],
include=includes,
)
for key in results1.keys():
assert len(results1[key][0]) == 2
for key in results.keys():
if key in includes or key == "ids":
assert len(results[key][0]) == 2
else:
assert results[key] is None


def test_invalid_n_results_param(api):
Expand Down

0 comments on commit 07480b6

Please sign in to comment.