Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] URI Data Loader #1294

Merged
merged 20 commits into from
Nov 7, 2023
16 changes: 15 additions & 1 deletion chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
Documents,
Embeddable,
EmbeddingFunction,
DataLoader,
Embeddings,
IDs,
Include,
Loadable,
Metadatas,
URIs,
Where,
QueryResult,
GetResult,
Expand Down Expand Up @@ -62,6 +65,7 @@ def create_collection(
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
get_or_create: bool = False,
) -> Collection:
"""Create a new collection with the given name and metadata.
Expand Down Expand Up @@ -98,6 +102,7 @@ def get_collection(
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
) -> Collection:
"""Get a collection with the given name.
Args:
Expand Down Expand Up @@ -127,6 +132,7 @@ def get_or_create_collection(
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
) -> Collection:
"""Get or create a collection with the given name and metadata.
Args:
Expand Down Expand Up @@ -193,6 +199,7 @@ def _add(
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
"""[Internal] Add embeddings to a collection specified by UUID.
If (some) ids already exist, only the new embeddings will be added.
Expand All @@ -203,6 +210,7 @@ def _add(
embedding: The sequence of embeddings to add.
metadata: The metadata to associate with the embeddings. Defaults to None.
documents: The documents to associate with the embeddings. Defaults to None.
uris: URIs of data sources for each embedding. Defaults to None.

Returns:
True if the embeddings were added successfully.
Expand All @@ -217,6 +225,7 @@ def _update(
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
"""[Internal] Update entries in a collection specified by UUID.

Expand All @@ -226,7 +235,7 @@ def _update(
embeddings: The sequence of embeddings to update. Defaults to None.
metadatas: The metadata to associate with the embeddings. Defaults to None.
documents: The documents to associate with the embeddings. Defaults to None.

uris: URIs of data sources for each embedding. Defaults to None.
Returns:
True if the embeddings were updated successfully.
"""
Expand All @@ -240,6 +249,7 @@ def _upsert(
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
"""[Internal] Add or update entries in the a collection specified by UUID.
If an entry with the same id already exists, it will be updated,
Expand All @@ -251,6 +261,7 @@ def _upsert(
embeddings: The sequence of embeddings to add
metadatas: The metadata to associate with the embeddings. Defaults to None.
documents: The documents to associate with the embeddings. Defaults to None.
uris: URIs of data sources for each embedding. Defaults to None.
"""
pass

Expand Down Expand Up @@ -496,6 +507,7 @@ def create_collection(
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
Expand All @@ -511,6 +523,7 @@ def get_collection(
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
Expand All @@ -525,6 +538,7 @@ def get_or_create_collection(
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
Expand Down
30 changes: 26 additions & 4 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
from chromadb.api import AdminAPI, ClientAPI, ServerAPI
from chromadb.api.types import (
CollectionMetadata,
DataLoader,
Documents,
Embeddable,
EmbeddingFunction,
Embeddings,
GetResult,
IDs,
Include,
Loadable,
Metadatas,
QueryResult,
URIs,
)
from chromadb.config import Settings, System
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE
Expand Down Expand Up @@ -173,13 +177,17 @@ def create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
get_or_create: bool = False,
) -> Collection:
return self._server.create_collection(
name=name,
metadata=metadata,
embedding_function=embedding_function,
data_loader=data_loader,
tenant=self.tenant,
database=self.database,
get_or_create=get_or_create,
Expand All @@ -188,14 +196,18 @@ def create_collection(
@override
def get_collection(
self,
name: Optional[str] = None,
name: str,
id: Optional[UUID] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
) -> Collection:
return self._server.get_collection(
id=id,
name=name,
embedding_function=embedding_function,
data_loader=data_loader,
tenant=self.tenant,
database=self.database,
)
Expand All @@ -205,12 +217,16 @@ def get_or_create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
) -> Collection:
return self._server.get_or_create_collection(
name=name,
metadata=metadata,
embedding_function=embedding_function,
data_loader=data_loader,
tenant=self.tenant,
database=self.database,
)
Expand Down Expand Up @@ -251,13 +267,15 @@ def _add(
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
return self._server._add(
ids=ids,
collection_id=collection_id,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
)

@override
Expand All @@ -268,13 +286,15 @@ def _update(
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
return self._server._update(
collection_id=collection_id,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
)

@override
Expand All @@ -285,13 +305,15 @@ def _upsert(
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
return self._server._upsert(
collection_id=collection_id,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
)

@override
Expand Down
43 changes: 34 additions & 9 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
from chromadb.api import ServerAPI
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
DataLoader,
Documents,
Embeddable,
Embeddings,
EmbeddingFunction,
IDs,
Include,
Loadable,
Metadatas,
URIs,
Where,
WhereDocument,
GetResult,
Expand Down Expand Up @@ -223,6 +226,7 @@ def create_collection(
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
Expand All @@ -246,6 +250,7 @@ def create_collection(
id=resp_json["id"],
name=resp_json["name"],
embedding_function=embedding_function,
data_loader=data_loader,
metadata=resp_json["metadata"],
)

Expand All @@ -255,7 +260,10 @@ def get_collection(
self,
name: str,
id: Optional[UUID] = None,
embedding_function: Optional[EmbeddingFunction[Embeddable]] = ef.DefaultEmbeddingFunction(), # type: ignore
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
Expand All @@ -276,6 +284,7 @@ def get_collection(
name=resp_json["name"],
id=resp_json["id"],
embedding_function=embedding_function,
data_loader=data_loader,
metadata=resp_json["metadata"],
)

Expand All @@ -287,16 +296,20 @@ def get_or_create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction[Embeddable]] = ef.DefaultEmbeddingFunction(), # type: ignore
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
return cast(
Collection,
self.create_collection(
name,
metadata,
embedding_function,
name=name,
metadata=metadata,
embedding_function=embedding_function,
data_loader=data_loader,
get_or_create=True,
tenant=tenant,
database=database,
Expand Down Expand Up @@ -403,6 +416,8 @@ def _get(
embeddings=body.get("embeddings", None),
metadatas=body.get("metadatas", None),
documents=body.get("documents", None),
data=None,
uris=body.get("uris", None),
)

@trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION)
Expand All @@ -429,7 +444,11 @@ def _delete(
def _submit_batch(
self,
batch: Tuple[
IDs, Optional[Embeddings], Optional[Metadatas], Optional[Documents]
IDs,
Optional[Embeddings],
Optional[Metadatas],
Optional[Documents],
Optional[URIs],
],
url: str,
) -> requests.Response:
Expand All @@ -444,6 +463,7 @@ def _submit_batch(
"embeddings": batch[1],
"metadatas": batch[2],
"documents": batch[3],
"uris": batch[4],
}
),
)
Expand All @@ -458,12 +478,13 @@ def _add(
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
"""
Adds a batch of embeddings to the database
- pass in column oriented data lists
"""
batch = (ids, embeddings, metadatas, documents)
batch = (ids, embeddings, metadatas, documents, uris)
validate_batch(batch, {"max_batch_size": self.max_batch_size})
resp = self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
raise_chroma_error(resp)
Expand All @@ -478,12 +499,13 @@ def _update(
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
"""
Updates a batch of embeddings in the database
- pass in column oriented data lists
"""
batch = (ids, embeddings, metadatas, documents)
batch = (ids, embeddings, metadatas, documents, uris)
validate_batch(batch, {"max_batch_size": self.max_batch_size})
resp = self._submit_batch(
batch, "/collections/" + str(collection_id) + "/update"
Expand All @@ -500,12 +522,13 @@ def _upsert(
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
"""
Upserts a batch of embeddings in the database
- pass in column oriented data lists
"""
batch = (ids, embeddings, metadatas, documents)
batch = (ids, embeddings, metadatas, documents, uris)
validate_batch(batch, {"max_batch_size": self.max_batch_size})
resp = self._submit_batch(
batch, "/collections/" + str(collection_id) + "/upsert"
Expand Down Expand Up @@ -547,6 +570,8 @@ def _query(
embeddings=body.get("embeddings", None),
metadatas=body.get("metadatas", None),
documents=body.get("documents", None),
uris=body.get("uris", None),
data=None,
)

@trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL)
Expand Down
Loading