Add rate limit to deployment api #10779

Merged
merged 7 commits on Jan 24, 2024
22 changes: 21 additions & 1 deletion docs/source/llms/deployments/index.rst
@@ -101,6 +101,9 @@ For details about the configuration file's parameters (including parameters for
name: gpt-3.5-turbo
config:
openai_api_key: $OPENAI_API_KEY
limit:
renewal_period: minute
calls: 10

- name: chat
endpoint_type: llm/v1/chat
@@ -284,6 +287,9 @@ Here's an example of a provider configuration within an endpoint:
name: gpt-4
config:
openai_api_key: $OPENAI_API_KEY
limit:
renewal_period: minute
calls: 10

In the above configuration, ``openai`` is the `provider` for the model.

@@ -324,6 +330,11 @@ an endpoint in the MLflow Deployments Server consists of the following fields:
* **name**: The name of the model to use. For example, ``gpt-3.5-turbo`` for OpenAI's ``GPT-3.5-Turbo`` model.
* **config**: Contains any additional configuration details required for the model. This includes specifying the API base URL and the API key.

* **limit**: Specifies the rate limit that this endpoint will enforce. The ``limit`` field contains the following fields:

  * **renewal_period**: The time unit of the rate limit, one of ``second``, ``minute``, ``hour``, ``day``, ``month``, or ``year``.
  * **calls**: The number of calls this endpoint will accept within the specified time unit.

Comment on lines +333 to +337
thanks for updating the docs!

Here's an example of an endpoint configuration:

.. code-block:: yaml
@@ -336,6 +347,9 @@ Here's an example of an endpoint configuration:
name: gpt-3.5-turbo
config:
openai_api_key: $OPENAI_API_KEY
limit:
renewal_period: minute
calls: 10

In the example above, a request sent to the completions endpoint would be forwarded to the
``gpt-3.5-turbo`` model provided by ``openai``.
@@ -423,10 +437,13 @@ Here is an example of a single-endpoint configuration:
name: gpt-3.5-turbo
config:
openai_api_key: $OPENAI_API_KEY
limit:
renewal_period: minute
calls: 10


In this example, we define an endpoint named ``chat`` that corresponds to the ``llm/v1/chat`` type, which
will use the ``gpt-3.5-turbo`` model from OpenAI to return query responses from the OpenAI service.
will use the ``gpt-3.5-turbo`` model from OpenAI to return query responses from the OpenAI service, and accept up to 10 requests per minute.
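
For a rough sense of how the limit behaves at request time, the sketch below sends a burst of
requests to the ``chat`` endpoint (the host, port, and payload are illustrative; substitute the
address your server is actually listening on). Once the ten-requests-per-minute budget is
exhausted, the server starts responding with HTTP 429:

.. code-block:: python

    import requests

    base_url = "http://localhost:5000"  # illustrative; use your server's address
    payload = {"messages": [{"role": "user", "content": "Hello!"}]}

    for i in range(12):
        resp = requests.post(f"{base_url}/endpoints/chat/invocations", json=payload)
        if resp.status_code == 429:
            # The rate limit for this renewal period has been exceeded.
            print(f"request {i + 1}: rate limited")
        else:
            print(f"request {i + 1}: status {resp.status_code}")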

The MLflow Deployments Server configuration is very easy to update.
Simply edit the configuration file and save your changes, and the MLflow Deployments Server will automatically
@@ -681,6 +698,9 @@ An example configuration for Azure OpenAI is:
openai_deployment_name: "{your_deployment_name}"
openai_api_base: "https://{your_resource_name}-azureopenai.openai.azure.com/"
openai_api_version: "2023-05-15"
limit:
renewal_period: minute
calls: 10


.. note::
3 changes: 3 additions & 0 deletions examples/deployments/deployments_server/openai/config.yaml
@@ -6,6 +6,9 @@ endpoints:
name: gpt-3.5-turbo
config:
openai_api_key: $OPENAI_API_KEY
limit:
renewal_period: minute
calls: 10

- name: completions
endpoint_type: llm/v1/completions
47 changes: 33 additions & 14 deletions mlflow/deployments/server/app.py
@@ -1,11 +1,13 @@
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from fastapi import FastAPI, HTTPException, Request
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.responses import FileResponse, RedirectResponse
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

from mlflow.deployments.server.config import Endpoint
from mlflow.deployments.server.constants import (
@@ -16,7 +18,10 @@
MLFLOW_DEPLOYMENTS_LIST_ENDPOINTS_PAGE_SIZE,
MLFLOW_DEPLOYMENTS_QUERY_SUFFIX,
)
from mlflow.environment_variables import MLFLOW_DEPLOYMENTS_CONFIG
from mlflow.environment_variables import (
MLFLOW_DEPLOYMENTS_CONFIG,
MLFLOW_GATEWAY_RATE_LIMITS_STORAGE_URI,
)
from mlflow.exceptions import MlflowException
from mlflow.gateway.base_models import SetLimitsModel
from mlflow.gateway.config import (
@@ -40,29 +45,29 @@
from mlflow.gateway.utils import SearchRoutesToken, make_streaming_response
from mlflow.version import VERSION

_logger = logging.getLogger(__name__)


class GatewayAPI(FastAPI):
def __init__(self, config: GatewayConfig, *args: Any, **kwargs: Any):
def __init__(self, config: GatewayConfig, limiter: Limiter, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.state.limiter = limiter
self.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
self.dynamic_routes: Dict[str, Route] = {}
self.set_dynamic_routes(config)
self.set_dynamic_routes(config, limiter)

def set_dynamic_routes(self, config: GatewayConfig) -> None:
def set_dynamic_routes(self, config: GatewayConfig, limiter: Limiter) -> None:
self.dynamic_routes.clear()
for route in config.routes:
self.add_api_route(
path=(
MLFLOW_DEPLOYMENTS_ENDPOINTS_BASE + route.name + MLFLOW_DEPLOYMENTS_QUERY_SUFFIX
),
endpoint=_route_type_to_endpoint(route),
endpoint=_route_type_to_endpoint(route, limiter, "deployments"),
methods=["POST"],
)
# TODO: Remove Gateway server URLs after deprecation window elapses
self.add_api_route(
path=f"{MLFLOW_GATEWAY_ROUTE_BASE}{route.name}{MLFLOW_QUERY_SUFFIX}",
endpoint=_route_type_to_endpoint(route),
endpoint=_route_type_to_endpoint(route, limiter, "gateway"),
methods=["POST"],
include_in_schema=False,
)
@@ -75,8 +80,9 @@ def get_dynamic_route(self, route_name: str) -> Optional[Route]:
def _create_chat_endpoint(config: RouteConfig):
prov = get_provider(config.model.provider)(config)

# https://slowapi.readthedocs.io/en/latest/#limitations-and-known-issues
async def _chat(
payload: chat.RequestPayload,
request: Request, payload: chat.RequestPayload
) -> Union[chat.ResponsePayload, chat.StreamResponsePayload]:
if payload.stream:
return await make_streaming_response(prov.chat_stream(payload))
@@ -90,7 +96,7 @@ def _create_completions_endpoint(config: RouteConfig):
prov = get_provider(config.model.provider)(config)

async def _completions(
payload: completions.RequestPayload,
request: Request, payload: completions.RequestPayload
) -> Union[completions.ResponsePayload, completions.StreamResponsePayload]:
if payload.stream:
return await make_streaming_response(prov.completions_stream(payload))
@@ -103,7 +109,9 @@ async def _completions(
def _create_embeddings_endpoint(config: RouteConfig):
prov = get_provider(config.model.provider)(config)

async def _embeddings(payload: embeddings.RequestPayload) -> embeddings.ResponsePayload:
async def _embeddings(
request: Request, payload: embeddings.RequestPayload
) -> embeddings.ResponsePayload:
return await prov.embeddings(payload)

return _embeddings
@@ -113,14 +121,20 @@ async def _custom(request: Request):
return request.json()


def _route_type_to_endpoint(config: RouteConfig):
def _route_type_to_endpoint(config: RouteConfig, limiter: Limiter, key: str):
provider_to_factory = {
RouteType.LLM_V1_CHAT: _create_chat_endpoint,
RouteType.LLM_V1_COMPLETIONS: _create_completions_endpoint,
RouteType.LLM_V1_EMBEDDINGS: _create_embeddings_endpoint,
}
if factory := provider_to_factory.get(config.route_type):
return factory(config)
handler = factory(config)
if limit := config.limit:
limit_value = f"{limit.calls}/{limit.renewal_period}"
handler.__name__ = f"{handler.__name__}_{config.name}_{key}"
return limiter.limit(limit_value)(handler)
else:
return handler

raise HTTPException(
status_code=404,
@@ -147,6 +161,7 @@ class Config:
"name": "gpt-3.5-turbo",
"provider": "openai",
},
"limit": {"calls": 1, "key": None, "renewal_period": "minute"},
},
harupy marked this conversation as resolved.
{
"name": "anthropic-completions",
@@ -212,8 +227,12 @@ def create_app_from_config(config: GatewayConfig) -> GatewayAPI:
"""
Create the GatewayAPI app from the gateway configuration.
"""
limiter = Limiter(
key_func=get_remote_address, storage_uri=MLFLOW_GATEWAY_RATE_LIMITS_STORAGE_URI.get()
harupy marked this conversation as resolved.
)
app = GatewayAPI(
config=config,
limiter=limiter,
title="MLflow Deployments Server",
description="The core deployments API for reverse proxy interface using remote inference "
"endpoints within MLflow",
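
For readers unfamiliar with slowapi, here is a minimal, self-contained sketch of the decoration
pattern that `_route_type_to_endpoint` applies to each dynamic route (the app and route name
below are illustrative, not part of this PR):

from fastapi import FastAPI, Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)


@app.get("/ping")
@limiter.limit("10/minute")  # same "<calls>/<renewal_period>" string built from the endpoint config
async def ping(request: Request):  # slowapi requires the handler to accept the Request argument
    return {"status": "ok"}
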
6 changes: 5 additions & 1 deletion mlflow/deployments/server/config.py
@@ -1,12 +1,15 @@
from typing import Optional

from mlflow.gateway.base_models import ResponseModel
from mlflow.gateway.config import RouteModelInfo
from mlflow.gateway.config import Limit, RouteModelInfo


class Endpoint(ResponseModel):
name: str
endpoint_type: str
model: RouteModelInfo
endpoint_url: str
limit: Optional[Limit]

class Config:
schema_extra = {
@@ -18,5 +21,6 @@ class Config:
"provider": "openai",
},
"endpoint_url": "/endpoints/completions/invocations",
"limit": {"calls": 1, "key": None, "renewal_period": "minute"},
}
}
4 changes: 4 additions & 0 deletions mlflow/environment_variables.py
@@ -495,3 +495,7 @@ def get(self):
MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT = _EnvironmentVariable(
"MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT", int, 120
)

MLFLOW_GATEWAY_RATE_LIMITS_STORAGE_URI = _EnvironmentVariable(
"MLFLOW_GATEWAY_RATE_LIMITS_STORAGE_URI", str, None
)
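
When the deployments server runs with multiple worker processes, slowapi's default in-memory
storage tracks counts per process, so this new variable lets the limiter point at a shared
backend instead. A hypothetical example (the Redis URI is illustrative and requires the
corresponding storage dependency to be installed):

import os

# Illustrative only: share rate-limit counters across workers via a Redis backend.
# When unset, the limiter should fall back to in-memory, per-process storage.
os.environ["MLFLOW_GATEWAY_RATE_LIMITS_STORAGE_URI"] = "redis://localhost:6379"
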
41 changes: 31 additions & 10 deletions mlflow/gateway/config.py
@@ -332,11 +332,22 @@ class Config:
allow_population_by_field_name = True


class Limit(LimitModel):
calls: int
key: Optional[str] = None
renewal_period: str


class LimitsConfig(ConfigModel):
limits: Optional[List[Limit]] = []


# pylint: disable=no-self-argument
class RouteConfig(AliasedConfigModel):
name: str
route_type: RouteType = Field(alias="endpoint_type")
model: Model
limit: Optional[Limit] = None

@validator("name")
def validate_endpoint_name(cls, route_name):
@@ -387,6 +398,23 @@ def validate_route_type(cls, value):
return value
raise MlflowException.invalid_parameter_value(f"The route_type '{value}' is not supported.")

@validator("limit", pre=True)
def validate_limit(cls, value):
from limits import parse

if value:
limit = Limit(**value)
try:
parse(f"{limit.calls}/{limit.renewal_period}")
except ValueError:
raise MlflowException.invalid_parameter_value(
"Failed to parse the rate limit configuration."
"Please make sure limit.calls is a positive number and"
"limit.renewal_period is a right granularity"
)

return value

def to_route(self) -> "Route":
return Route(
name=self.name,
@@ -396,6 +424,7 @@ def to_route(self) -> "Route":
provider=self.model.provider,
),
route_url=f"{MLFLOW_GATEWAY_ROUTE_BASE}{self.name}{MLFLOW_QUERY_SUFFIX}",
limit=self.limit,
)


@@ -424,6 +453,7 @@ class Route(ConfigModel):
route_type: str
model: RouteModelInfo
route_url: str
limit: Optional[Limit] = None

class Config:
if IS_PYDANTIC_V2:
@@ -439,23 +469,14 @@ def to_endpoint(self):
endpoint_type=self.route_type,
model=self.model,
endpoint_url=self.route_url,
limit=self.limit,
)


class Limit(LimitModel):
calls: int
key: Optional[str] = None
renewal_period: str


class GatewayConfig(AliasedConfigModel):
routes: List[RouteConfig] = Field(alias="endpoints")


class LimitsConfig(ConfigModel):
limits: Optional[List[Limit]] = []


def _load_route_config(path: Union[str, Path]) -> GatewayConfig:
"""
Reads the gateway configuration yaml file from the storage location and returns an instance
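
As a quick illustration of what the new validator accepts: the limit is rendered as
"<calls>/<renewal_period>" and handed to limits.parse, which raises ValueError for strings it
cannot interpret (the values below are examples, not part of this PR):

from limits import parse

parse("10/minute")  # accepted: 10 calls per minute
parse("100/hour")   # accepted: 100 calls per hour

try:
    parse("10/fortnight")  # "fortnight" is not a recognized granularity
except ValueError as exc:
    print(f"rejected: {exc}")
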
1 change: 1 addition & 0 deletions requirements/gateway-requirements.txt
@@ -9,3 +9,4 @@ watchfiles<1
aiohttp<4
boto3<2,>=1.28.56
tiktoken<1
slowapi<1
4 changes: 4 additions & 0 deletions requirements/gateway-requirements.yaml
@@ -34,3 +34,7 @@ boto3:
tiktoken:
pip_release: tiktoken
max_major_version: 0

slowapi:
pip_release: slowapi
max_major_version: 0