diff --git a/contributing/AUTOSCALING.md b/contributing/AUTOSCALING.md index 552c4d6e1..eb5f6e197 100644 --- a/contributing/AUTOSCALING.md +++ b/contributing/AUTOSCALING.md @@ -3,7 +3,7 @@ `dstack` features auto-scaling for services published via the gateway. The general flow is: - STEP 1: `dstack-gateway` parses nginx `access.log` to collect per-second statistics about requests to the service and request times. -- STEP 2: `dstack-gateway` aggregates statistics over a 1-minute window. +- STEP 2: `dstack-gateway` aggregates statistics over several predefined windows. - STEP 3: The server keeps gateway connections alive in the scheduled `process_gateways_connections` task and continuously collects stats from active gateways. This is separate from `GatewayPipeline`, which handles gateway provisioning and deletion. - STEP 4: When `RunPipeline` processes a service run, it loads the latest collected gateway stats for that service. - STEP 5: The autoscaler (configured via `dstack.yml`) computes the desired replica count for each replica group. @@ -17,6 +17,6 @@ ## RPSAutoscaler -`RPSAutoscaler` implements simple target tracking scaling. The target value represents requests per second per replica (in a 1-minute window). +`RPSAutoscaler` implements simple target tracking scaling. The target value represents requests per second per replica (in a configurable window). `scale_up_delay` tells how much time has to pass since the last upscale or downscale event before the next upscaling. `scale_down_delay` tells how much time has to pass since the last upscale or downscale event before the next downscaling. diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index b9b958f03..f9dbaf4e2 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -102,8 +102,18 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType: configuration_excludes["https"] = True replicas = run_spec.configuration.replicas - if isinstance(replicas, list) and all(g.router is None for g in replicas): - configuration_excludes["replicas"] = {"__all__": {"router": True}} + if isinstance(replicas, list): + replica_group_excludes: IncludeExcludeDictType = {} + if all(g.router is None for g in replicas): + replica_group_excludes["router"] = True + if all(g.scaling is None or g.scaling.window is None for g in replicas): + replica_group_excludes["scaling"] = {"window": True} + if replica_group_excludes: + configuration_excludes["replicas"] = {"__all__": replica_group_excludes} + + scaling = run_spec.configuration.scaling + if scaling is not None and scaling.window is None: + configuration_excludes["scaling"] = {"window": True} if configuration_excludes: spec_excludes["configuration"] = configuration_excludes diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 9f90b4037..a8b472b7f 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -40,6 +40,7 @@ parse_volume_configuration, ) from dstack._internal.core.services import is_valid_replica_group_name +from dstack._internal.proxy.gateway.const import SERVICE_SCALING_WINDOWS from dstack._internal.utils.common import has_duplicates, list_enum_values_for_annotation from dstack._internal.utils.json_schema import add_extra_schema_types from dstack._internal.utils.json_utils import ( @@ -66,6 +67,9 @@ MAX_PROBE_URL_LEN = 2048 DEFAULT_REPLICA_GROUP_NAME = "0" OPENAI_MODEL_PROBE_TIMEOUT = 30 +ALLOWED_SCALING_WINDOWS_DESCRIPTION = ", ".join(f"`{w}s`" for w in SERVICE_SCALING_WINDOWS) +DEFAULT_SCALING_WINDOW = 60 +assert DEFAULT_SCALING_WINDOW in SERVICE_SCALING_WINDOWS class RunConfigurationType(str, Enum): @@ -221,6 +225,16 @@ class ScalingSpec(CoreModel): gt=0, ), ] + window: Annotated[ + Optional[Duration], + Field( + description=( + "The time window used to calculate requests per second." + f" Allowed values: {ALLOWED_SCALING_WINDOWS_DESCRIPTION}." + f" Defaults to `{DEFAULT_SCALING_WINDOW}s`" + ), + ), + ] = None scale_up_delay: Annotated[ Duration, Field(description="The delay in seconds before scaling up") ] = Duration.parse("5m") @@ -228,6 +242,12 @@ class ScalingSpec(CoreModel): Duration, Field(description="The delay in seconds before scaling down") ] = Duration.parse("10m") + @validator("window") + def validate_window(cls, v: Optional[Duration]) -> Optional[Duration]: + if v is not None and v not in SERVICE_SCALING_WINDOWS: + raise ValueError(f"Window must be one of: {ALLOWED_SCALING_WINDOWS_DESCRIPTION}") + return v + class IPAddressPartitioningKey(CoreModel): type: Annotated[Literal["ip_address"], Field(description="Partitioning type")] = "ip_address" diff --git a/src/dstack/_internal/proxy/gateway/const.py b/src/dstack/_internal/proxy/gateway/const.py index 7b958030a..d9a294d2b 100644 --- a/src/dstack/_internal/proxy/gateway/const.py +++ b/src/dstack/_internal/proxy/gateway/const.py @@ -6,3 +6,4 @@ SERVER_CONNECTIONS_DIR_ON_GATEWAY = DSTACK_DIR_ON_GATEWAY / "server-connections" PROXY_PORT_ON_GATEWAY = 8000 SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE = "Service {ref} is already registered" +SERVICE_SCALING_WINDOWS = (30, 60, 300) diff --git a/src/dstack/_internal/proxy/gateway/services/stats.py b/src/dstack/_internal/proxy/gateway/services/stats.py index c3086fc05..358eff451 100644 --- a/src/dstack/_internal/proxy/gateway/services/stats.py +++ b/src/dstack/_internal/proxy/gateway/services/stats.py @@ -9,15 +9,15 @@ from pydantic import BaseModel +from dstack._internal.proxy.gateway.const import SERVICE_SCALING_WINDOWS from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats, ServiceStats, Stat from dstack._internal.proxy.lib.errors import UnexpectedProxyError from dstack._internal.utils.common import run_async logger = logging.getLogger(__name__) -WINDOWS = (30, 60, 300) -TTL = WINDOWS[-1] -EMPTY_STATS = {window: Stat(requests=0, request_time=0.0) for window in WINDOWS} +TTL = max(SERVICE_SCALING_WINDOWS) +EMPTY_STATS = {window: Stat(requests=0, request_time=0.0) for window in SERVICE_SCALING_WINDOWS} class StatFrame(BaseModel): @@ -67,7 +67,7 @@ def _aggregate(frames: Reversible[StatFrame], now: datetime.datetime) -> PerWind Aggregate 1s `frames` into windows 30s, 1m, 5m before `now` """ result = {} - for window in WINDOWS: + for window in SERVICE_SCALING_WINDOWS: req_count = 0 req_time_total = 0.0 for frame in reversed(frames): diff --git a/src/dstack/_internal/server/services/services/autoscalers.py b/src/dstack/_internal/server/services/services/autoscalers.py index 641d2cee4..f03f2957e 100644 --- a/src/dstack/_internal/server/services/services/autoscalers.py +++ b/src/dstack/_internal/server/services/services/autoscalers.py @@ -6,7 +6,7 @@ from pydantic import BaseModel import dstack._internal.utils.common as common_utils -from dstack._internal.core.models.configurations import ScalingSpec +from dstack._internal.core.models.configurations import DEFAULT_SCALING_WINDOW, ScalingSpec from dstack._internal.core.models.resources import Range from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats @@ -72,12 +72,14 @@ def __init__( min_replicas: int, max_replicas: int, target: float, + window: int, scale_up_delay: int, scale_down_delay: int, ): self.min_replicas = min_replicas self.max_replicas = max_replicas self.target = target + self.window = window self.scale_up_delay = scale_up_delay self.scale_down_delay = scale_down_delay @@ -92,8 +94,7 @@ def get_desired_count( now = common_utils.get_current_datetime() - # calculate the average RPS over the last minute - rps = stats[60].requests / 60 + rps = stats[self.window].requests / self.window new_desired_count = math.ceil(rps / self.target) # clip the desired count to the min and max values new_desired_count = min(max(new_desired_count, self.min_replicas), self.max_replicas) @@ -134,6 +135,7 @@ def get_service_scaler(count: Range[int], scaling: Optional[ScalingSpec]) -> Bas min_replicas=count.min, max_replicas=count.max, target=scaling.target, + window=scaling.window if scaling.window is not None else DEFAULT_SCALING_WINDOW, scale_up_delay=scaling.scale_up_delay, scale_down_delay=scaling.scale_down_delay, ) diff --git a/src/tests/_internal/server/services/services/test_autoscalers.py b/src/tests/_internal/server/services/services/test_autoscalers.py index 5df80b0c1..1125ae8ad 100644 --- a/src/tests/_internal/server/services/services/test_autoscalers.py +++ b/src/tests/_internal/server/services/services/test_autoscalers.py @@ -3,13 +3,21 @@ import pytest +from dstack._internal.core.models.configurations import DEFAULT_SCALING_WINDOW from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats, Stat from dstack._internal.server.services.services.autoscalers import BaseServiceScaler, RPSAutoscaler @pytest.fixture def rps_scaler(): - return RPSAutoscaler(0, 5, 10, 5 * 60, 10 * 60) + return RPSAutoscaler( + min_replicas=0, + max_replicas=5, + target=10, + window=DEFAULT_SCALING_WINDOW, + scale_up_delay=5 * 60, + scale_down_delay=10 * 60, + ) @pytest.fixture @@ -21,7 +29,9 @@ def time(): def stats(rps: float) -> PerWindowStats: - return {60: Stat(requests=int(rps * 60), request_time=0.1)} + return { + DEFAULT_SCALING_WINDOW: Stat(requests=int(rps * DEFAULT_SCALING_WINDOW), request_time=0.1) + } class TestRPSAutoscaler: @@ -139,3 +149,22 @@ def test_scale_to_zero(self, rps_scaler: BaseServiceScaler, time: datetime.datet ) == 0 ) + + @pytest.mark.parametrize("window,expected", [(30, 3), (60, 2), (300, 1)]) + def test_window(self, window: int, expected: int, time: datetime.datetime) -> None: + stats: PerWindowStats = { + 30: Stat(requests=900, request_time=0.1), # 900 req / 30s = 30 rps → 3 replicas + 60: Stat(requests=1200, request_time=0.1), # 1200 req / 60s = 20 rps → 2 replicas + 300: Stat(requests=1500, request_time=0.1), # 1500 req / 300s = 5 rps → 1 replica + } + scaler = RPSAutoscaler( + min_replicas=0, + max_replicas=5, + target=10, + window=window, + scale_up_delay=5 * 60, + scale_down_delay=10 * 60, + ) + assert ( + scaler.get_desired_count(1, stats, time - datetime.timedelta(seconds=3600)) == expected + )