Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT REVIEW] AI gateway rate limits design & quick POC #9939

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/gateway/openai/config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
routes:
- name: chat
route_type: llm/v1/chat
limit:
Copy link
Member Author

@harupy harupy Oct 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need multiple limits?

renewal_period: "minute"
calls: 1
Comment on lines +4 to +6
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Global limit vs. per-route limit.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

per-route limit

model:
provider: openai
name: gpt-3.5-turbo
Expand Down
4 changes: 4 additions & 0 deletions mlflow/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,7 @@ def get(self):
MLFLOW_SYSTEM_METRICS_SAMPLES_BEFORE_LOGGING = _EnvironmentVariable(
"MLFLOW_SYSTEM_METRICS_SAMPLES_BEFORE_LOGGING", int, None
)

# Storage URI for the AI gateway rate limiter backend (passed to slowapi's
# ``Limiter(storage_uri=...)`` in app.py). ``None`` (the default) presumably
# selects slowapi's in-memory storage — TODO confirm accepted URI schemes
# (e.g. redis://) against the slowapi/limits storage documentation.
MLFLOW_GATEWAY_RATE_LIMITS_STORAGE_URI = _EnvironmentVariable(
    "MLFLOW_GATEWAY_RATE_LIMITS_STORAGE_URI", str, None
)
49 changes: 34 additions & 15 deletions mlflow/gateway/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.responses import FileResponse, RedirectResponse
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
Copy link
Member Author

@harupy harupy Oct 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

from mlflow.environment_variables import MLFLOW_GATEWAY_CONFIG
from mlflow.environment_variables import (
MLFLOW_GATEWAY_CONFIG,
MLFLOW_GATEWAY_RATE_LIMITS_STORAGE_URI,
)
from mlflow.exceptions import MlflowException
from mlflow.gateway.base_models import SetLimitsModel
from mlflow.gateway.config import (
Expand Down Expand Up @@ -35,17 +41,19 @@


class GatewayAPI(FastAPI):
def __init__(self, config: GatewayConfig, *args: Any, **kwargs: Any):
def __init__(self, config: GatewayConfig, limiter: Limiter, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.dynamic_routes: Dict[str, Route] = {}
self.set_dynamic_routes(config)
self.state.limiter = limiter
self.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
self.set_dynamic_routes(config, limiter)

def set_dynamic_routes(self, config: GatewayConfig) -> None:
def set_dynamic_routes(self, config: GatewayConfig, limiter: Limiter) -> None:
self.dynamic_routes.clear()
for route in config.routes:
self.add_api_route(
path=f"{MLFLOW_GATEWAY_ROUTE_BASE}{route.name}{MLFLOW_QUERY_SUFFIX}",
endpoint=_route_type_to_endpoint(route),
endpoint=_route_type_to_endpoint(route, limiter),
methods=["POST"],
)
self.dynamic_routes[route.name] = route.to_route()
Expand All @@ -57,28 +65,30 @@ def get_dynamic_route(self, route_name: str) -> Optional[Route]:
def _create_chat_endpoint(config: RouteConfig):
    """Build the POST handler for an ``llm/v1/chat`` route.

    The handler accepts the raw ``Request`` (instead of a typed pydantic
    payload) because slowapi's rate-limit decorator requires a ``request``
    argument in the endpoint signature.
    See https://slowapi.readthedocs.io/en/latest/#limitations-and-known-issues

    :param config: Route configuration naming the provider and model.
    :return: An async endpoint coroutine suitable for ``add_api_route``.
    """
    prov = get_provider(config.model.provider)(config)

    # NOTE: the previous typed-payload variant of ``_chat`` was removed here;
    # it was dead code, immediately shadowed by this definition.
    async def _chat(request: Request) -> chat.ResponsePayload:
        # Parse JSON manually; validation still happens inside RequestPayload.
        payload = await request.json()
        return await prov.chat(chat.RequestPayload(**payload))

    return _chat


def _create_completions_endpoint(config: RouteConfig):
    """Build the POST handler for an ``llm/v1/completions`` route.

    The handler accepts the raw ``Request`` (instead of a typed pydantic
    payload) because slowapi's rate-limit decorator requires a ``request``
    argument in the endpoint signature.
    See https://slowapi.readthedocs.io/en/latest/#limitations-and-known-issues

    :param config: Route configuration naming the provider and model.
    :return: An async endpoint coroutine suitable for ``add_api_route``.
    """
    prov = get_provider(config.model.provider)(config)

    # NOTE: the previous typed-payload variant of ``_completions`` was removed
    # here; it was dead code, immediately shadowed by this definition.
    async def _completions(request: Request) -> completions.ResponsePayload:
        payload = await request.json()
        # BUG FIX: was ``prov.completions(**payload)``, which spread the JSON
        # dict as keyword arguments. Wrap it in RequestPayload instead, so the
        # provider receives a validated payload object — consistent with the
        # chat and embeddings endpoints.
        return await prov.completions(completions.RequestPayload(**payload))

    return _completions


def _create_embeddings_endpoint(config: RouteConfig):
    """Build the POST handler for an ``llm/v1/embeddings`` route.

    The handler accepts the raw ``Request`` (instead of a typed pydantic
    payload) because slowapi's rate-limit decorator requires a ``request``
    argument in the endpoint signature.
    See https://slowapi.readthedocs.io/en/latest/#limitations-and-known-issues

    :param config: Route configuration naming the provider and model.
    :return: An async endpoint coroutine suitable for ``add_api_route``.
    """
    prov = get_provider(config.model.provider)(config)

    # NOTE: the previous typed-payload variant of ``_embeddings`` was removed
    # here; it was dead code, immediately shadowed by this definition.
    async def _embeddings(request: Request) -> embeddings.ResponsePayload:
        # Parse JSON manually; validation still happens inside RequestPayload.
        payload = await request.json()
        return await prov.embeddings(embeddings.RequestPayload(**payload))

    return _embeddings

Expand All @@ -87,14 +97,19 @@ async def _custom(request: Request):
return request.json()


def _route_type_to_endpoint(config: RouteConfig):
def _route_type_to_endpoint(config: RouteConfig, limiter: Limiter):
provider_to_factory = {
RouteType.LLM_V1_CHAT: _create_chat_endpoint,
RouteType.LLM_V1_COMPLETIONS: _create_completions_endpoint,
RouteType.LLM_V1_EMBEDDINGS: _create_embeddings_endpoint,
}
if factory := provider_to_factory.get(config.route_type):
return factory(config)
handler = factory(config)
if config.limit:
limit_value = f"{config.limit.calls}/{config.limit.renewal_period}"
return limiter.limit(limit_value)(handler)
else:
return handler

raise HTTPException(
status_code=404,
Expand Down Expand Up @@ -148,8 +163,12 @@ def create_app_from_config(config: GatewayConfig) -> GatewayAPI:
"""
Create the GatewayAPI app from the gateway configuration.
"""
limiter = Limiter(
key_func=get_remote_address, storage_uri=MLFLOW_GATEWAY_RATE_LIMITS_STORAGE_URI.get()
)
app = GatewayAPI(
config=config,
limiter=limiter,
title="MLflow Gateway API",
description="The core gateway API for reverse proxy interface using remote inference "
"endpoints within MLflow",
Expand Down
21 changes: 11 additions & 10 deletions mlflow/gateway/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,22 @@ def validate_config(cls, config, values):
return cls._validate_config(config, values)


class Limit(LimitModel):
    """Rate-limit specification for a gateway route.

    ``calls`` and ``renewal_period`` are combined into a slowapi limit string
    (``f"{calls}/{renewal_period}"``) when the route endpoint is registered,
    e.g. ``calls=1, renewal_period="minute"`` → ``"1/minute"``.
    """

    # Maximum number of calls allowed per renewal period.
    calls: int
    # Optional limiter key — presumably scopes the limit (e.g. per user/API
    # key); TODO confirm: nothing in this diff reads it yet.
    key: Optional[str] = None
    # Time window name fed to slowapi, e.g. "minute".
    renewal_period: str


class LimitsConfig(ConfigModel):
    """Container model holding a list of :class:`Limit` entries (default: empty)."""

    limits: Optional[List[Limit]] = []


# pylint: disable=no-self-argument
class RouteConfig(ConfigModel):
name: str
route_type: RouteType
model: Model
limit: Optional[Limit] = None

@validator("name")
def validate_endpoint_name(cls, route_name):
Expand Down Expand Up @@ -375,20 +386,10 @@ class Config:
}


# NOTE(review): this definition duplicates the ``Limit`` model declared earlier
# in this file (it shadows the earlier one at import time) — one of the two
# should be removed.
class Limit(LimitModel):
    # Maximum number of calls allowed per renewal period.
    calls: int
    # Optional limiter key; purpose not evident from this diff — TODO confirm.
    key: Optional[str] = None
    # Time window name, e.g. "minute".
    renewal_period: str


class GatewayConfig(ConfigModel):
    """Top-level gateway configuration: the list of configured routes."""

    routes: List[RouteConfig]


# NOTE(review): this definition duplicates the ``LimitsConfig`` model declared
# earlier in this file (it shadows the earlier one at import time) — one of the
# two should be removed.
class LimitsConfig(ConfigModel):
    limits: Optional[List[Limit]] = []


def _load_route_config(path: Union[str, Path]) -> GatewayConfig:
"""
Reads the gateway configuration yaml file from the storage location and returns an instance
Expand Down