diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 142ab78a05f..b3090317d6f 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -37,6 +37,8 @@ "UpdateCollectionMetadata", "QueryResult", "GetResult", + "__version__", + "Settings", ] logger = logging.getLogger(__name__) diff --git a/chromadb/api/client.py b/chromadb/api/client.py index ba797677e46..45761f57b59 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -1,10 +1,14 @@ +import logging from typing import ClassVar, Dict, Optional, Sequence from uuid import UUID import uuid from overrides import override import requests + +from chromadb import errors from chromadb.api import AdminAPI, ClientAPI, ServerAPI +from chromadb.api.fastapi import FastAPI from chromadb.api.types import ( CollectionMetadata, DataLoader, @@ -28,6 +32,9 @@ from chromadb.telemetry.product.events import ClientStartEvent from chromadb.types import Database, Tenant, Where, WhereDocument import chromadb.utils.embedding_functions as ef +from chromadb.utils.client_utils import compare_versions + +logger = logging.getLogger(__name__) class SharedSystemClient: @@ -137,12 +144,11 @@ def __init__( settings: Settings = Settings(), ) -> None: super().__init__(settings=settings) + self._validated = False self.tenant = tenant self.database = database # Create an admin client for verifying that databases and tenants exist self._admin_client = AdminClient.from_system(self._system) - self._validate_tenant_database(tenant=tenant, database=database) - # Get the root system component we want to interact with self._server = self._system.instance(ServerAPI) @@ -168,19 +174,19 @@ def from_system( # Note - we could do this in less verbose ways, but they break type checking @override def heartbeat(self) -> int: - return self._server.heartbeat() + return self._validate()._server.heartbeat() @override def list_collections( self, limit: Optional[int] = None, offset: Optional[int] = None ) -> Sequence[Collection]: - return self._server.list_collections( + return self._validate()._server.list_collections( limit, offset, tenant=self.tenant, database=self.database ) @override def count_collections(self) -> int: - return self._server.count_collections( + return self._validate()._server.count_collections( tenant=self.tenant, database=self.database ) @@ -195,7 +201,7 @@ def create_collection( data_loader: Optional[DataLoader[Loadable]] = None, get_or_create: bool = False, ) -> Collection: - return self._server.create_collection( + return self._validate()._server.create_collection( name=name, metadata=metadata, embedding_function=embedding_function, @@ -215,7 +221,7 @@ def get_collection( ] = ef.DefaultEmbeddingFunction(), # type: ignore data_loader: Optional[DataLoader[Loadable]] = None, ) -> Collection: - return self._server.get_collection( + return self._validate()._server.get_collection( id=id, name=name, embedding_function=embedding_function, @@ -234,7 +240,7 @@ def get_or_create_collection( ] = ef.DefaultEmbeddingFunction(), # type: ignore data_loader: Optional[DataLoader[Loadable]] = None, ) -> Collection: - return self._server.get_or_create_collection( + return self._validate()._server.get_or_create_collection( name=name, metadata=metadata, embedding_function=embedding_function, @@ -250,7 +256,7 @@ def _modify( new_name: Optional[str] = None, new_metadata: Optional[CollectionMetadata] = None, ) -> None: - return self._server._modify( + return self._validate()._server._modify( id=id, new_name=new_name, new_metadata=new_metadata, @@ -261,7 +267,7 @@ def delete_collection( self, name: str, ) -> None: - return self._server.delete_collection( + return self._validate()._server.delete_collection( name=name, tenant=self.tenant, database=self.database, @@ -281,7 +287,7 @@ def _add( documents: Optional[Documents] = None, uris: Optional[URIs] = None, ) -> bool: - return self._server._add( + return self._validate()._server._add( ids=ids, collection_id=collection_id, embeddings=embeddings, @@ -300,7 +306,7 @@ def _update( documents: Optional[Documents] = None, uris: Optional[URIs] = None, ) -> bool: - return self._server._update( + return self._validate()._server._update( collection_id=collection_id, ids=ids, embeddings=embeddings, @@ -319,7 +325,7 @@ def _upsert( documents: Optional[Documents] = None, uris: Optional[URIs] = None, ) -> bool: - return self._server._upsert( + return self._validate()._server._upsert( collection_id=collection_id, ids=ids, embeddings=embeddings, @@ -330,13 +336,13 @@ def _upsert( @override def _count(self, collection_id: UUID) -> int: - return self._server._count( + return self._validate()._server._count( collection_id=collection_id, ) @override def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: - return self._server._peek( + return self._validate()._server._peek( collection_id=collection_id, n=n, ) @@ -355,7 +361,7 @@ def _get( where_document: Optional[WhereDocument] = {}, include: Include = ["embeddings", "metadatas", "documents"], ) -> GetResult: - return self._server._get( + return self._validate()._server._get( collection_id=collection_id, ids=ids, where=where, @@ -375,7 +381,7 @@ def _delete( where: Optional[Where] = {}, where_document: Optional[WhereDocument] = {}, ) -> IDs: - return self._server._delete( + return self._validate()._server._delete( collection_id=collection_id, ids=ids, where=where, @@ -392,7 +398,7 @@ def _query( where_document: WhereDocument = {}, include: Include = ["embeddings", "metadatas", "documents", "distances"], ) -> QueryResult: - return self._server._query( + return self._validate()._server._query( collection_id=collection_id, query_embeddings=query_embeddings, n_results=n_results, @@ -433,6 +439,57 @@ def set_database(self, database: str) -> None: self._validate_tenant_database(tenant=self.tenant, database=database) self.database = database + def _validate_connectivity(self) -> None: + """Validates connectivity to the server. Returns True if successful, False otherwise.""" + if not isinstance(self._server, FastAPI): + return + try: + self._server.heartbeat() + except requests.exceptions.ConnectionError as e: + raise errors.GenericError( + code=-1, message=f"Chroma server seems inaccessible: {str(e)}" + ) + except requests.exceptions.HTTPError as ex: + if ex.response.status_code in [504, 502, 503]: # type: ignore + # proxy errors, Gateway timeout = 504, Bad Gateway = 502, Service Unavailable = 503 + raise errors.GenericError( + code=ex.response.status_code, # type: ignore + message=f"Your proxy reports Chroma server might not be accessible: {str(ex)}", + ) + else: + raise errors.GenericError(code=ex.response.status_code, message=str(ex)) # type: ignore + + def _validate(self) -> "Client": + if self._validated: + return self + self._validate_connectivity() + self._version_compatibility_check() + self._validate_tenant_database(tenant=self.tenant, database=self.database) + self._validated = True + return self + + @staticmethod + def _min_server_compatible_version() -> str: + # TODO - this should be automatically generated upon release + return "0.4.15" + + def _version_compatibility_check(self) -> None: + """Checks if the client and server versions are compatible. Raises an error if not.""" + if isinstance(self._server, FastAPI): + server_version = self.get_version() + if ( + compare_versions(server_version, self._min_server_compatible_version()) + < 0 + ): + from chromadb import __version__ as local_chroma_version + + raise ValueError( + f"It appears you are using newer version of Chroma client (v{local_chroma_version}) " + f"that is not compatible with Chroma server (v{server_version}). " + f"Please upgrade your server to a compatible version " + f"(min: v{self._min_server_compatible_version()})." + ) + def _validate_tenant_database(self, tenant: str, database: str) -> None: try: self._admin_client.get_tenant(name=tenant) @@ -440,6 +497,23 @@ def _validate_tenant_database(self, tenant: str, database: str) -> None: raise ValueError( "Could not connect to a Chroma server. Are you sure it is running?" ) + except requests.exceptions.HTTPError as ex: + if ex.response.status_code in [401, 403]: # type: ignore + raise ValueError( + "Authentication error. Have you configured your client to use authentication?" + ) + if ex.response.status_code == 404: # type: ignore + from chromadb import __version__ as local_chroma_version + + raise ValueError( + f"It appears you are using newer version of Chroma client (v{local_chroma_version}) " + f"that is not compatible with Chroma server (v{self.get_version()}). " + "Please upgrade your server to the latest version." + ) + else: + raise ValueError( + f"Could not connect to tenant {tenant}. Are you sure it exists?" + ) # Propagate ChromaErrors except ChromaError as e: raise e diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index d3d1a8a4e7e..21bdd74c288 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -639,14 +639,14 @@ def raise_chroma_error(resp: requests.Response) -> None: if "error" in body: if body["error"] in errors.error_types: chroma_error = errors.error_types[body["error"]](body["message"]) + else: + chroma_error = errors.GenericError(resp.status_code, body["message"]) + # TODO do we need to catch BaseException here? except BaseException: pass if chroma_error: raise chroma_error - try: - resp.raise_for_status() - except requests.HTTPError: - raise (Exception(resp.text)) + resp.raise_for_status() diff --git a/chromadb/errors.py b/chromadb/errors.py index f082fc76665..f6c85695fdd 100644 --- a/chromadb/errors.py +++ b/chromadb/errors.py @@ -75,6 +75,31 @@ def name(cls) -> str: return "AuthorizationError" +class GenericError(ChromaError): + def __init__(self, code: int, message: str) -> None: + self._code = code + self._message = message + + @overrides + def code(self) -> int: + return self._code + + @overrides + def message(self) -> str: + return self._message + + @classmethod + @overrides + def name(cls) -> str: + return "ServerError" + + def __str__(self) -> str: + return f"{self.name()}(code={self._code}, message={self._message})" + + def __repr__(self) -> str: + return str(self) + + error_types: Dict[str, Type[ChromaError]] = { "InvalidDimension": InvalidDimensionException, "InvalidCollection": InvalidCollectionException, diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index f01f4137908..f4c4a9c975f 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -5,7 +5,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.routing import APIRoute -from fastapi import HTTPException, status +from fastapi import status from uuid import UUID from chromadb.api.models.Collection import Collection from chromadb.api.types import GetResult, QueryResult @@ -32,6 +32,7 @@ from chromadb.api import ServerAPI from chromadb.errors import ( ChromaError, + InvalidUUIDError, InvalidDimensionException, InvalidHTTPVersion, ) @@ -83,7 +84,9 @@ async def catch_exceptions_middleware( return fastapi_json_response(e) except Exception as e: logger.exception(e) - return JSONResponse(content={"error": repr(e)}, status_code=500) + return JSONResponse( + content={"error": str(type(e)), "message": f"e{str(e)}"}, status_code=500 + ) async def check_http_version_middleware( @@ -477,17 +480,14 @@ def delete_collection( ), ) def add(self, collection_id: str, add: AddEmbedding) -> None: - try: - result = self._api._add( - collection_id=_uuid(collection_id), - embeddings=add.embeddings, # type: ignore - metadatas=add.metadatas, # type: ignore - documents=add.documents, # type: ignore - uris=add.uris, # type: ignore - ids=add.ids, - ) - except InvalidDimensionException as e: - raise HTTPException(status_code=500, detail=str(e)) + result = self._api._add( + collection_id=_uuid(collection_id), + embeddings=add.embeddings, # type: ignore + metadatas=add.metadatas, # type: ignore + documents=add.documents, # type: ignore + uris=add.uris, # type: ignore + ids=add.ids, + ) return result # type: ignore @trace_method("FastAPI.update", OpenTelemetryGranularity.OPERATION) diff --git a/chromadb/test/client/test_client_compatibility.py b/chromadb/test/client/test_client_compatibility.py new file mode 100644 index 00000000000..a4d65f9fe85 --- /dev/null +++ b/chromadb/test/client/test_client_compatibility.py @@ -0,0 +1,118 @@ +import json +import uuid +import time + +import pytest +from hypothesis import given, strategies as st +from pytest_httpserver import HTTPServer + +import chromadb +from chromadb.api.client import SharedSystemClient +from chromadb.errors import GenericError +from chromadb.types import Tenant, Database + + +@pytest.fixture(autouse=True) +def reset_client_settings() -> None: + SharedSystemClient.clear_system_cache() + + +def test_incompatible_server_version(caplog: pytest.LogCaptureFixture) -> None: + with HTTPServer(port=8001) as httpserver: + httpserver.expect_request("/api/v1/collections").respond_with_data( + json.dumps([]) + ) + httpserver.expect_request("/api/v1/heartbeat").respond_with_data( + json.dumps({"nanosecond heartbeat": int(time.time_ns())}) + ) + + httpserver.expect_request("/api/v1").respond_with_data( + json.dumps({"nanosecond heartbeat": int(time.time_ns())}) + ) + + httpserver.expect_request("/api/v1/version").respond_with_data( + json.dumps("0.4.1") + ) + client = chromadb.HttpClient( + host="localhost", + port="8001", + settings=chromadb.Settings(chroma_server_http_port=8001), + ) + + with pytest.raises(ValueError) as e: + client.list_collections() + assert "It appears you are using newer version of Chroma client" in str(e.value) + + +def test_compatible_server_version(caplog: pytest.LogCaptureFixture) -> None: + with HTTPServer(port=8001) as httpserver: + httpserver.expect_request("/api/v1/collections").respond_with_data( + json.dumps([]) + ) + httpserver.expect_request("/api/v1/heartbeat").respond_with_data( + json.dumps({"nanosecond heartbeat": int(time.time_ns())}) + ) + + httpserver.expect_request("/api/v1").respond_with_data( + json.dumps({"nanosecond heartbeat": int(time.time_ns())}) + ) + + httpserver.expect_request("/api/v1/version").respond_with_data( + json.dumps("0.4.15") + ) + httpserver.expect_request("/api/v1/tenants/default_tenant").respond_with_data( + json.dumps(Tenant(name="default_tenant")) + ) + httpserver.expect_request( + "/api/v1/databases/default_database" + ).respond_with_data( + json.dumps( + Database( + name="default_database", + tenant="default_tenant", + id=str(uuid.uuid4()), # type: ignore + ) + ) + ) + + client = chromadb.HttpClient( + host="localhost", + port="8001", + settings=chromadb.Settings(chroma_server_http_port=8001), + ) + + client.list_collections() + + +def test_client_server_not_available(caplog: pytest.LogCaptureFixture) -> None: + with HTTPServer(port=8002) as _: + client = chromadb.HttpClient( + host="localhost", + port="8001", + settings=chromadb.Settings(chroma_server_http_port=8001), + ) + + with pytest.raises(GenericError) as e: + client.list_collections() + assert "Chroma server seems inaccessible" in str(e.value) + + +@given(status=st.sampled_from([502, 503, 504])) +def test_client_server_with_proxy_error( + status: int, caplog: pytest.LogCaptureFixture +) -> None: + with HTTPServer(port=8001) as httpserver: + httpserver.expect_request("/api/v1/heartbeat").respond_with_data( + "Oh no!", status=status + ) + + httpserver.expect_request("/api/v1").respond_with_data("Oh no!", status=status) + client = chromadb.HttpClient( + host="localhost", + port="8001", + settings=chromadb.Settings(chroma_server_http_port=8001), + ) + + with pytest.raises(GenericError) as e: + client.list_collections() + assert "Your proxy reports Chroma server might not be" in str(e.value) diff --git a/chromadb/utils/client_utils.py b/chromadb/utils/client_utils.py new file mode 100644 index 00000000000..bfbb1d93613 --- /dev/null +++ b/chromadb/utils/client_utils.py @@ -0,0 +1,19 @@ +def compare_versions(version1: str, version2: str) -> int: + """Compares two versions of the format X.Y.Z and returns 1 if version1 is greater than version2, -1 if version1 is + less than version2, and 0 if version1 is equal to version2. + """ + v1_components = list(map(int, version1.split("."))) + v2_components = list(map(int, version2.split("."))) + + for v1, v2 in zip(v1_components, v2_components): + if v1 > v2: + return 1 + elif v1 < v2: + return -1 + + if len(v1_components) > len(v2_components): + return 1 + elif len(v1_components) < len(v2_components): + return -1 + + return 0 diff --git a/clients/js/test/client.test.ts b/clients/js/test/client.test.ts index 512237a2457..56af0ec5015 100644 --- a/clients/js/test/client.test.ts +++ b/clients/js/test/client.test.ts @@ -191,5 +191,5 @@ test('wrong code returns an error', async () => { // @ts-ignore - supposed to fail const results = await collection.get({ where: { "test": { "$contains": "hello" } } }); expect(results.error).toBeDefined() - expect(results.error).toContain("ValueError('Expected where operator") + expect(results.error).toContain("ValueError") }) diff --git a/requirements_dev.txt b/requirements_dev.txt index 4dce86e2efe..90696e499b4 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -8,6 +8,7 @@ mypy-protobuf pre-commit pytest pytest-asyncio +pytest-httpserver setuptools_scm types-protobuf types-requests==2.30.0.0