Skip to content

Commit

Permalink
Merge branch 'chroma-core:main' into main
Browse files (browse the repository at this point in the history)
  • Loading branch information
csbasil committed Feb 15, 2024
2 parents 83b2308 + da68516 commit 3d898ed
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 19 deletions.
32 changes: 16 additions & 16 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import json
import orjson as json
import logging
from typing import Optional, cast, Tuple
from typing import Sequence
Expand Down Expand Up @@ -147,7 +147,7 @@ def heartbeat(self) -> int:
"""Returns the current server time in nanoseconds to check if the server is alive"""
resp = self._session.get(self._api_url)
raise_chroma_error(resp)
return int(resp.json()["nanosecond heartbeat"])
return int(json.loads(resp.text)["nanosecond heartbeat"])

@trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
@override
Expand Down Expand Up @@ -177,7 +177,7 @@ def get_database(
params={"tenant": tenant},
)
raise_chroma_error(resp)
resp_json = resp.json()
resp_json = json.loads(resp.text)
return Database(
id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"]
)
Expand All @@ -198,7 +198,7 @@ def get_tenant(self, name: str) -> Tenant:
self._api_url + "/tenants/" + name,
)
raise_chroma_error(resp)
resp_json = resp.json()
resp_json = json.loads(resp.text)
return Tenant(name=resp_json["name"])

@trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
Expand All @@ -221,7 +221,7 @@ def list_collections(
},
)
raise_chroma_error(resp)
json_collections = resp.json()
json_collections = json.loads(resp.text)
collections = []
for json_collection in json_collections:
collections.append(Collection(self, **json_collection))
Expand All @@ -239,7 +239,7 @@ def count_collections(
params={"tenant": tenant, "database": database},
)
raise_chroma_error(resp)
return cast(int, resp.json())
return cast(int, json.loads(resp.text))

@trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
@override
Expand Down Expand Up @@ -268,7 +268,7 @@ def create_collection(
params={"tenant": tenant, "database": database},
)
raise_chroma_error(resp)
resp_json = resp.json()
resp_json = json.loads(resp.text)
return Collection(
client=self,
id=resp_json["id"],
Expand Down Expand Up @@ -302,7 +302,7 @@ def get_collection(
self._api_url + "/collections/" + name if name else str(id), params=_params
)
raise_chroma_error(resp)
resp_json = resp.json()
resp_json = json.loads(resp.text)
return Collection(
client=self,
name=resp_json["name"],
Expand Down Expand Up @@ -381,7 +381,7 @@ def _count(
self._api_url + "/collections/" + str(collection_id) + "/count"
)
raise_chroma_error(resp)
return cast(int, resp.json())
return cast(int, json.loads(resp.text))

@trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION)
@override
Expand Down Expand Up @@ -434,7 +434,7 @@ def _get(
)

raise_chroma_error(resp)
body = resp.json()
body = json.loads(resp.text)
return GetResult(
ids=body["ids"],
embeddings=body.get("embeddings", None),
Expand Down Expand Up @@ -462,7 +462,7 @@ def _delete(
)

raise_chroma_error(resp)
return cast(IDs, resp.json())
return cast(IDs, json.loads(resp.text))

@trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL)
def _submit_batch(
Expand Down Expand Up @@ -586,7 +586,7 @@ def _query(
)

raise_chroma_error(resp)
body = resp.json()
body = json.loads(resp.text)

return QueryResult(
ids=body["ids"],
Expand All @@ -604,15 +604,15 @@ def reset(self) -> bool:
"""Resets the database"""
resp = self._session.post(self._api_url + "/reset")
raise_chroma_error(resp)
return cast(bool, resp.json())
return cast(bool, json.loads(resp.text))

@trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION)
@override
def get_version(self) -> str:
"""Returns the version of the server"""
resp = self._session.get(self._api_url + "/version")
raise_chroma_error(resp)
return cast(str, resp.json())
return cast(str, json.loads(resp.text))

@override
def get_settings(self) -> Settings:
Expand All @@ -626,7 +626,7 @@ def max_batch_size(self) -> int:
if self._max_batch_size == -1:
resp = self._session.get(self._api_url + "/pre-flight-checks")
raise_chroma_error(resp)
self._max_batch_size = cast(int, resp.json()["max_batch_size"])
self._max_batch_size = cast(int, json.loads(resp.text)["max_batch_size"])
return self._max_batch_size


Expand All @@ -637,7 +637,7 @@ def raise_chroma_error(resp: requests.Response) -> None:

chroma_error = None
try:
body = resp.json()
body = json.loads(resp.text)
if "error" in body:
if body["error"] in errors.error_types:
chroma_error = errors.error_types[body["error"]](body["message"])
Expand Down
8 changes: 6 additions & 2 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

# Re-export types from chromadb.types
__all__ = ["Metadata", "Where", "WhereDocument", "UpdateCollectionMetadata"]

META_KEY_CHROMA_DOCUMENT = "chroma:document"
T = TypeVar("T")
OneOrMany = Union[T, List[T]]

Expand Down Expand Up @@ -265,6 +265,10 @@ def validate_metadata(metadata: Metadata) -> Metadata:
if len(metadata) == 0:
raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")
for key, value in metadata.items():
if key == META_KEY_CHROMA_DOCUMENT:
raise ValueError(
f"Expected metadata to not contain the reserved key {META_KEY_CHROMA_DOCUMENT}"
)
if not isinstance(key, str):
raise TypeError(
f"Expected metadata key to be a str, got {key} which is a {type(key)}"
Expand Down Expand Up @@ -476,7 +480,7 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings:
raise ValueError(
f"Expected each embedding in the embeddings to be a list, got {embeddings}"
)
for i,embedding in enumerate(embeddings):
for i, embedding in enumerate(embeddings):
if len(embedding) == 0:
raise ValueError(
f"Expected each embedding in the embeddings to be a non-empty list, got empty embedding at pos {i}"
Expand Down
9 changes: 9 additions & 0 deletions chromadb/test/segment/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import tempfile
import pytest
from typing import Generator, List, Callable, Iterator, Dict, Optional, Union, Sequence

from chromadb.api.types import validate_metadata
from chromadb.config import System, Settings
from chromadb.db.base import ParameterValue, get_sql
from chromadb.db.impl.sqlite import SqliteDB
Expand Down Expand Up @@ -677,3 +679,10 @@ def test_delete_segment(
res = cur.execute(sql, params)
# assert that all FTS rows are gone
assert len(res.fetchall()) == 0


def test_metadata_validation_forbidden_key() -> None:
    """Check that validate_metadata rejects the reserved "chroma:document" key.

    The call is expected to raise a ValueError whose message mentions the
    reserved key name.
    """
    reserved_metadata = {
        "chroma:document": "this is not the document you are looking for"
    }
    with pytest.raises(ValueError, match="chroma:document"):
        validate_metadata(reserved_metadata)
1 change: 1 addition & 0 deletions clients/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies = [
'typing_extensions >= 4.5.0',
'tenacity>=8.2.3',
'PyYAML>=6.0.0',
'orjson>=3.9.12',
]

[tool.black]
Expand Down
1 change: 1 addition & 0 deletions clients/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ PyYAML>=6.0.0
requests >= 2.28
tenacity>=8.2.3
typing_extensions >= 4.5.0
orjson>=3.9.12
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ dependencies = [
'tenacity>=8.2.3',
'PyYAML>=6.0.0',
'mmh3>=4.0.1',
'orjson>=3.9.12',
]

[tool.black]
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ tqdm>=4.65.0
typer>=0.9.0
typing_extensions>=4.5.0
uvicorn[standard]==0.18.3
orjson>=3.9.12
1 change: 0 additions & 1 deletion server.htpasswd

This file was deleted.

0 comments on commit 3d898ed

Please sign in to comment.