Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG]: Thin client imports #2466

Merged
merged 5 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/release-chromadb.yml
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ jobs:
python-version: '3.12'
- name: Build Client
run: ./clients/python/build_python_thin_client.sh
- name: Test Client Package
run: bin/test-package.sh dist/*.tar.gz
- name: Install setuptools_scm
run: python -m pip install setuptools_scm
- name: Publish to Test PyPI
Expand Down
3 changes: 1 addition & 2 deletions chromadb/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
TypeVar,
)
from dataclasses import dataclass
from starlette.datastructures import Headers

from pydantic import SecretStr

Expand Down Expand Up @@ -88,7 +87,7 @@ def __init__(self, system: System) -> None:
)

@abstractmethod
def authenticate_or_raise(self, headers: Headers) -> UserIdentity:
def authenticate_or_raise(self, headers: Dict[str, str]) -> UserIdentity:
pass

def ignore_operation(self, verb: str, path: str) -> bool:
Expand Down
10 changes: 6 additions & 4 deletions chromadb/auth/basic_authn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
OpenTelemetryGranularity,
trace_method,
)
from starlette.datastructures import Headers


from typing import Dict


logger = logging.getLogger(__name__)

__all__ = ["BasicAuthenticationServerProvider", "BasicAuthClientProvider"]
Expand Down Expand Up @@ -100,11 +102,11 @@ def __init__(self, system: System) -> None:
"BasicAuthenticationServerProvider.authenticate", OpenTelemetryGranularity.ALL
)
@override
def authenticate_or_raise(self, headers: Headers) -> UserIdentity:
def authenticate_or_raise(self, headers: Dict[str, str]) -> UserIdentity:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a breaking change for anyone overriding auth. We should be more careful here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. But the use of starlette and fastapi here was already a breaking change for the thin client, which did not require (or import) either of those. Technically, duck-typing helps us here too.

try:
if AUTHORIZATION_HEADER not in headers:
if AUTHORIZATION_HEADER.lower() not in headers.keys():
raise AuthError(AUTHORIZATION_HEADER + " header not found")
_auth_header = headers[AUTHORIZATION_HEADER]
_auth_header = headers[AUTHORIZATION_HEADER.lower()]
_auth_header = re.sub(r"^Basic ", "", _auth_header)
_auth_header = _auth_header.strip()

Expand Down
11 changes: 6 additions & 5 deletions chromadb/auth/token_authn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
import time
import traceback
from enum import Enum
from starlette.datastructures import Headers
from typing import cast, Dict, List, Optional, TypedDict, TypeVar

from fastapi import HTTPException

from overrides import override
from pydantic import SecretStr
import yaml
Expand Down Expand Up @@ -191,13 +190,13 @@ def __init__(self, system: System) -> None:
"TokenAuthenticationServerProvider.authenticate", OpenTelemetryGranularity.ALL
)
@override
def authenticate_or_raise(self, headers: Headers) -> UserIdentity:
def authenticate_or_raise(self, headers: Dict[str, str]) -> UserIdentity:
try:
if self._token_transport_header.value not in headers:
if self._token_transport_header.value.lower() not in headers.keys():
raise AuthError(
f"Authorization header '{self._token_transport_header.value}' not found"
)
token = headers[self._token_transport_header.value]
token = headers[self._token_transport_header.value.lower()]
if self._token_transport_header == TokenTransportHeader.AUTHORIZATION:
if not token.startswith("Bearer "):
raise AuthError("Bearer not found in Authorization header")
Expand Down Expand Up @@ -232,4 +231,6 @@ def authenticate_or_raise(self, headers: Headers) -> UserIdentity:
time.sleep(
random.uniform(0.001, 0.005)
) # add some jitter to avoid timing attacks
from fastapi import HTTPException
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really want to import in the hot path here? Why do we need this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a good point. I had it initially in the __init__ as a self.HTTPException. Shall I move it back?

Copy link
Contributor Author

@tazarov tazarov Jul 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@HammadB, this is addressed separately in #2477 (a more elegant solution to the import problem).


raise HTTPException(status_code=403, detail="Forbidden")
2 changes: 1 addition & 1 deletion chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def auth_and_get_tenant_and_database_for_request(
if not self.authn_provider:
return (tenant, database)

user_identity = self.authn_provider.authenticate_or_raise(headers)
user_identity = self.authn_provider.authenticate_or_raise(dict(headers))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes a header check that was previously `self._token_transport_header.value not in headers`, which doesn't work with a plain `dict`.


(
new_tenant,
Expand Down
3 changes: 1 addition & 2 deletions chromadb/test/auth/test_base_class_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from hypothesis import given, settings
from overrides import override
from starlette.datastructures import Headers
from typing import Dict, List, Tuple

from chromadb.api import ServerAPI
Expand All @@ -17,7 +16,7 @@ class DummyServerAuthenticationProvider(ServerAuthenticationProvider):
"""

@override
def authenticate_or_raise(self, headers: Headers) -> UserIdentity:
def authenticate_or_raise(self, headers: Dict[str, str]) -> UserIdentity:
return UserIdentity(user_id="test_user")


Expand Down
22 changes: 10 additions & 12 deletions chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
import numpy as np
import numpy.typing as npt
import httpx
from onnxruntime import InferenceSession, get_available_providers, SessionOptions
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random
from tokenizers import Tokenizer

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings

Expand Down Expand Up @@ -156,8 +154,8 @@ def _forward(
return np.concatenate(all_embeddings)

@cached_property
def tokenizer(self) -> Tokenizer:
tokenizer = Tokenizer.from_file(
def tokenizer(self) -> "Tokenizer": # noqa F821
tokenizer = self.Tokenizer.from_file(
os.path.join(
self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "tokenizer.json"
)
Expand All @@ -169,26 +167,26 @@ def tokenizer(self) -> Tokenizer:
return tokenizer

@cached_property
def model(self) -> InferenceSession:
def model(self) -> "InferenceSession": # noqa F821
if self._preferred_providers is None or len(self._preferred_providers) == 0:
if len(get_available_providers()) > 0:
if len(self.ort.get_available_providers()) > 0:
logger.debug(
f"WARNING: No ONNX providers provided, defaulting to available providers: "
f"{get_available_providers()}"
f"{self.ort.get_available_providers()}"
)
self._preferred_providers = get_available_providers()
self._preferred_providers = self.ort.get_available_providers()
elif not set(self._preferred_providers).issubset(
set(get_available_providers())
set(self.ort.get_available_providers())
):
raise ValueError(
f"Preferred providers must be subset of available providers: {get_available_providers()}"
f"Preferred providers must be subset of available providers: {self.ort.get_available_providers()}"
)

# Suppress onnxruntime warnings. This produces logspew, mainly when onnx tries to use CoreML, which doesn't fit this model.
so = SessionOptions()
so = self.ort.SessionOptions()
so.log_severity_level = 3

return InferenceSession(
return self.ort.InferenceSession(
os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "model.onnx"),
# Since 1.9, ONNX Runtime requires providers to be specified when there are multiple available - https://onnxruntime.ai/docs/api/python/api_summary.html
# This is probably not ideal but will improve DX as no exceptions will be raised in multi-provider envs
Expand Down
Loading