Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Gateway provider unimplemented endpoint errors #10822

Merged
merged 6 commits into from Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 3 additions & 13 deletions mlflow/gateway/providers/ai21labs.py
Expand Up @@ -6,10 +6,12 @@
from mlflow.gateway.config import AI21LabsConfig, RouteConfig
from mlflow.gateway.providers.base import BaseProvider
from mlflow.gateway.providers.utils import rename_payload_keys, send_request
from mlflow.gateway.schemas import chat, completions, embeddings
from mlflow.gateway.schemas import completions


class AI21LabsProvider(BaseProvider):
NAME = "AI21Labs"

def __init__(self, config: RouteConfig) -> None:
super().__init__(config)
if config.model.config is None or not isinstance(config.model.config, AI21LabsConfig):
Expand Down Expand Up @@ -82,15 +84,3 @@ async def completions(self, payload: completions.RequestPayload) -> completions.
total_tokens=None,
),
)

async def chat(self, payload: chat.RequestPayload) -> None:
    """Reject chat requests: AI21Labs exposes no chat endpoint."""
    detail = "The chat route is not available for AI21Labs models."
    raise HTTPException(status_code=404, detail=detail)

async def embeddings(self, payload: embeddings.RequestPayload) -> None:
    """Reject embeddings requests: AI21Labs exposes no embeddings endpoint."""
    detail = "The embeddings route is not available for AI21Labs models."
    raise HTTPException(status_code=404, detail=detail)
16 changes: 3 additions & 13 deletions mlflow/gateway/providers/anthropic.py
Expand Up @@ -10,7 +10,7 @@
)
from mlflow.gateway.providers.base import BaseProvider, ProviderAdapter
from mlflow.gateway.providers.utils import rename_payload_keys, send_request
from mlflow.gateway.schemas import chat, completions, embeddings
from mlflow.gateway.schemas import completions


class AnthropicAdapter(ProviderAdapter):
Expand Down Expand Up @@ -97,6 +97,8 @@ def model_to_embeddings(cls, resp, config):


class AnthropicProvider(BaseProvider, AnthropicAdapter):
NAME = "Anthropic"

def __init__(self, config: RouteConfig) -> None:
super().__init__(config)
if config.model.config is None or not isinstance(config.model.config, AnthropicConfig):
Expand Down Expand Up @@ -134,15 +136,3 @@ async def completions(self, payload: completions.RequestPayload) -> completions.
# ```

return AnthropicAdapter.model_to_completions(resp, self.config)

async def chat(self, payload: chat.RequestPayload) -> None:
    """Always fails with 404: Anthropic has no chat endpoint in this gateway."""
    message = "The chat route is not available for Anthropic models."
    raise HTTPException(status_code=404, detail=message)

async def embeddings(self, payload: embeddings.RequestPayload) -> None:
    """Always fails with 404: Anthropic has no embeddings endpoint in this gateway."""
    message = "The embeddings route is not available for Anthropic models."
    raise HTTPException(status_code=404, detail=message)
15 changes: 12 additions & 3 deletions mlflow/gateway/providers/base.py
Expand Up @@ -19,13 +19,22 @@ def __init__(self, config: RouteConfig):
self.config = config

async def chat(self, payload: chat.RequestPayload) -> chat.ResponsePayload:
    """Default chat handler for providers that expose no chat endpoint.

    Raises an HTTP 404 whose message names the provider via ``self.NAME``.
    Subclasses that support chat override this method.

    The scraped hunk contained both the old ``raise NotImplementedError`` and
    this raise; the first made the second unreachable, so only the
    HTTPException form is kept.
    """
    # NOTE(review): reviewers on this change suggested 501 (Not Implemented)
    # is semantically more accurate than 404 — confirm no callers depend on
    # 404 before switching.
    raise HTTPException(
        status_code=404,
        detail=f"The chat route is not available for {self.NAME} models.",
    )
Copy link
Collaborator Author

@gabrielfu gabrielfu Jan 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 501 status makes more sense, but not sure if this will break any user code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. 501 (not implemented) makes more sense. Let's change it.


async def completions(self, payload: completions.RequestPayload) -> completions.ResponsePayload:
    """Default completions handler for providers that expose no completions endpoint.

    Raises an HTTP 404 whose message names the provider via ``self.NAME``.
    Subclasses that support completions override this method.

    The scraped hunk contained both the old ``raise NotImplementedError`` and
    this raise; the first made the second unreachable, so only the
    HTTPException form is kept.
    """
    # NOTE(review): reviewers suggested 501 (Not Implemented) may be more
    # accurate than 404 — confirm before changing.
    raise HTTPException(
        status_code=404,
        detail=f"The completions route is not available for {self.NAME} models.",
    )

async def embeddings(self, payload: embeddings.RequestPayload) -> embeddings.ResponsePayload:
    """Default embeddings handler for providers that expose no embeddings endpoint.

    Raises an HTTP 404 whose message names the provider via ``self.NAME``.
    Subclasses that support embeddings override this method.

    The scraped hunk contained both the old ``raise NotImplementedError`` and
    this raise; the first made the second unreachable, so only the
    HTTPException form is kept.
    """
    # NOTE(review): reviewers suggested 501 (Not Implemented) may be more
    # accurate than 404 — confirm before changing.
    raise HTTPException(
        status_code=404,
        detail=f"The embeddings route is not available for {self.NAME} models.",
    )

@staticmethod
def check_for_model_field(payload):
Expand Down
16 changes: 3 additions & 13 deletions mlflow/gateway/providers/bedrock.py
Expand Up @@ -17,7 +17,7 @@
from mlflow.gateway.providers.base import BaseProvider, ProviderAdapter
from mlflow.gateway.providers.cohere import CohereAdapter
from mlflow.gateway.providers.utils import rename_payload_keys
from mlflow.gateway.schemas import chat, completions, embeddings
from mlflow.gateway.schemas import completions

AWS_BEDROCK_ANTHROPIC_MAXIMUM_MAX_TOKENS = 8191

Expand Down Expand Up @@ -164,6 +164,8 @@ def of_str(cls, name: str):


class AWSBedrockProvider(BaseProvider):
NAME = "AWS Bedrock"

def __init__(self, config: RouteConfig):
super().__init__(config)

Expand Down Expand Up @@ -283,15 +285,3 @@ async def completions(self, payload: completions.RequestPayload) -> completions.
payload = self.underlying_provider_adapter.completions_to_model(payload, self.config)
response = self._request(payload)
return self.underlying_provider_adapter.model_to_completions(response, self.config)

async def chat(self, payload: chat.RequestPayload) -> None:
    """AWS Bedrock offers no chat endpoint here; unconditionally raise 404."""
    raise HTTPException(
        status_code=404,
        detail="The chat route is not available for AWS Bedrock models.",
    )

async def embeddings(self, payload: embeddings.RequestPayload) -> None:
    """AWS Bedrock offers no embeddings endpoint here; unconditionally raise 404."""
    raise HTTPException(
        status_code=404,
        detail="The embeddings route is not available for AWS Bedrock models.",
    )
7 changes: 3 additions & 4 deletions mlflow/gateway/providers/cohere.py
Expand Up @@ -7,7 +7,7 @@
from mlflow.gateway.config import CohereConfig, RouteConfig
from mlflow.gateway.providers.base import BaseProvider, ProviderAdapter
from mlflow.gateway.providers.utils import rename_payload_keys, send_request
from mlflow.gateway.schemas import chat, completions, embeddings
from mlflow.gateway.schemas import completions, embeddings


class CohereAdapter(ProviderAdapter):
Expand Down Expand Up @@ -115,6 +115,8 @@ def embeddings_to_model(cls, payload, config):


class CohereProvider(BaseProvider):
NAME = "Cohere"

def __init__(self, config: RouteConfig) -> None:
super().__init__(config)
if config.model.config is None or not isinstance(config.model.config, CohereConfig):
Expand All @@ -130,9 +132,6 @@ async def _request(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
payload=payload,
)

async def chat(self, payload: chat.RequestPayload) -> chat.ResponsePayload:
    """Cohere exposes no chat endpoint in this gateway; reject the request.

    Raises:
        HTTPException: always, with status 404 and a "for Cohere models."
            message, matching every other provider's unimplemented-route
            error. The previous 422 (Unprocessable Entity) was inconsistent:
            422 means the payload is malformed, not that the route is absent.
    """
    raise HTTPException(
        status_code=404, detail="The chat route is not available for Cohere models."
    )

async def completions(self, payload: completions.RequestPayload) -> completions.ResponsePayload:
payload = jsonable_encoder(payload, exclude_none=True)
self.check_for_model_field(payload)
Expand Down
18 changes: 3 additions & 15 deletions mlflow/gateway/providers/huggingface.py
Expand Up @@ -10,10 +10,12 @@
rename_payload_keys,
send_request,
)
from mlflow.gateway.schemas import chat, completions, embeddings
from mlflow.gateway.schemas import completions


class HFTextGenerationInferenceServerProvider(BaseProvider):
NAME = "Hugging Face Text Generation Inference"

def __init__(self, config: RouteConfig) -> None:
super().__init__(config)
if config.model.config is None or not isinstance(
Expand All @@ -31,12 +33,6 @@ async def _request(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
payload=payload,
)

async def chat(self, payload: chat.RequestPayload) -> chat.ResponsePayload:
    """Reject chat requests: the TGI server exposes no chat endpoint."""
    message = "The chat route is not available for the Text Generation Inference provider."
    raise HTTPException(status_code=404, detail=message)

async def completions(self, payload: completions.RequestPayload) -> completions.ResponsePayload:
payload = jsonable_encoder(payload, exclude_none=True)
self.check_for_model_field(payload)
Expand Down Expand Up @@ -116,11 +112,3 @@ async def completions(self, payload: completions.RequestPayload) -> completions.
total_tokens=input_tokens + output_tokens,
),
)

async def embeddings(self, payload: embeddings.RequestPayload) -> embeddings.ResponsePayload:
    """The TGI server exposes no embeddings endpoint; reject the request.

    Raises:
        HTTPException: always (404). The message now says "embeddings route"
            (plural) — the original said "embedding route", inconsistent with
            every other provider's error text and with the updated test
            expectation elsewhere in this change.
    """
    raise HTTPException(
        status_code=404,
        detail=(
            "The embeddings route is not available for the Text Generation Inference provider."
        ),
    )
2 changes: 2 additions & 0 deletions mlflow/gateway/providers/mlflow.py
Expand Up @@ -52,6 +52,8 @@ def validate_predictions(cls, predictions):


class MlflowModelServingProvider(BaseProvider):
NAME = "MLflow Model Serving"

def __init__(self, config: RouteConfig) -> None:
super().__init__(config)
if config.model.config is None or not isinstance(
Expand Down
2 changes: 2 additions & 0 deletions mlflow/gateway/providers/mosaicml.py
Expand Up @@ -13,6 +13,8 @@


class MosaicMLProvider(BaseProvider):
NAME = "MosaicML"

def __init__(self, config: RouteConfig) -> None:
super().__init__(config)
if config.model.config is None or not isinstance(config.model.config, MosaicMLConfig):
Expand Down
2 changes: 2 additions & 0 deletions mlflow/gateway/providers/openai.py
Expand Up @@ -7,6 +7,8 @@


class OpenAIProvider(BaseProvider):
NAME = "OpenAI"

def __init__(self, config: RouteConfig) -> None:
super().__init__(config)
if config.model.config is None or not isinstance(config.model.config, OpenAIConfig):
Expand Down
2 changes: 2 additions & 0 deletions mlflow/gateway/providers/palm.py
Expand Up @@ -11,6 +11,8 @@


class PaLMProvider(BaseProvider):
NAME = "PaLM"

def __init__(self, config: RouteConfig) -> None:
super().__init__(config)
if config.model.config is None or not isinstance(config.model.config, PaLMConfig):
Expand Down
4 changes: 2 additions & 2 deletions tests/gateway/providers/test_huggingface.py
Expand Up @@ -140,7 +140,7 @@ async def test_chat_is_not_supported_for_tgi():
with pytest.raises(HTTPException, match=r".*") as e:
await provider.chat(chat.RequestPayload(**payload))
assert (
"The chat route is not available for the Text Generation Inference provider."
"The chat route is not available for Hugging Face Text Generation Inference models."
in e.value.detail
)
assert e.value.status_code == 404
Expand All @@ -155,7 +155,7 @@ async def test_embeddings_are_not_supported_for_tgi():
with pytest.raises(HTTPException, match=r".*") as e:
await provider.embeddings(embeddings.RequestPayload(**payload))
assert (
"The embedding route is not available for the Text Generation Inference provider."
"The embeddings route is not available for Hugging Face Text Generation Inference models."
in e.value.detail
)
assert e.value.status_code == 404