diff --git a/gateway/pyproject.toml b/gateway/pyproject.toml index 5cec27e09..c40a37b7f 100644 --- a/gateway/pyproject.toml +++ b/gateway/pyproject.toml @@ -14,6 +14,9 @@ dependencies = [ "dstack[gateway] @ https://github.com/dstackai/dstack/archive/refs/heads/master.tar.gz", ] +[project.optional-dependencies] +sglang = ["sglang-router==0.2.1"] + [tool.setuptools.package-data] "dstack.gateway" = [ "resources/systemd/*", diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index f973980e0..67a7e409d 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -460,7 +460,9 @@ def create_gateway( image_id=aws_resources.get_gateway_image_id(ec2_client), instance_type="t3.micro", iam_instance_profile=None, - user_data=get_gateway_user_data(configuration.ssh_key_pub), + user_data=get_gateway_user_data( + configuration.ssh_key_pub, router=configuration.router + ), tags=tags, security_group_id=security_group_id, spot=False, diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py index 8b37a72b2..8573ec99a 100644 --- a/src/dstack/_internal/core/backends/azure/compute.py +++ b/src/dstack/_internal/core/backends/azure/compute.py @@ -277,7 +277,9 @@ def create_gateway( image_reference=_get_gateway_image_ref(), vm_size="Standard_B1ms", instance_name=instance_name, - user_data=get_gateway_user_data(configuration.ssh_key_pub), + user_data=get_gateway_user_data( + configuration.ssh_key_pub, router=configuration.router + ), ssh_pub_keys=[configuration.ssh_key_pub], spot=False, disk_size=30, diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 29129f6cb..769bf8912 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -38,6 +38,7 @@ SSHKey, ) from 
def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]:
    """Return the shell commands that install and launch dstack-gateway on a host.

    The commands create the blue/green virtualenvs under ``/home/ubuntu/dstack``,
    pip-install the gateway package into the blue one, and run the systemd
    installer from it.

    Args:
        router: Optional router configuration. It is forwarded to
            ``get_dstack_gateway_wheel`` so the package spec can carry the
            matching extra (e.g. ``dstack-gateway[sglang] @ <wheel url>``).

    Returns:
        The commands in execution order. Callers in this diff join them with
        ``" && "`` and run the result through ``su ubuntu -c``.
    """
    # Function-scope import keeps this block self-contained; the module's
    # top-level import list is not changed by this edit.
    import shlex

    build = get_dstack_runner_version()
    gateway_package = get_dstack_gateway_wheel(build, router)
    return [
        "mkdir -p /home/ubuntu/dstack",
        "python3 -m venv /home/ubuntu/dstack/blue",
        "python3 -m venv /home/ubuntu/dstack/green",
        # The spec contains spaces and brackets ("pkg[extra] @ url"), so it must
        # reach pip as one shell word. shlex.quote produces the same
        # single-quoted form as before for ordinary specs and additionally
        # escapes any embedded quote characters.
        f"/home/ubuntu/dstack/blue/bin/pip install {shlex.quote(gateway_package)}",
        "sudo /home/ubuntu/dstack/blue/bin/python -m dstack.gateway.systemd install --run",
    ]
dstack._internal.utils.common import get_or_error, parse_memory @@ -371,7 +373,9 @@ def create_gateway( # Consider deploying an NLB. It seems it requires some extra configuration on the cluster: # https://docs.aws.amazon.com/eks/latest/userguide/network-load-balancing.html instance_name = generate_unique_gateway_instance_name(configuration) - commands = _get_gateway_commands(authorized_keys=[configuration.ssh_key_pub]) + commands = _get_gateway_commands( + authorized_keys=[configuration.ssh_key_pub], router=configuration.router + ) pod = client.V1Pod( metadata=client.V1ObjectMeta( name=instance_name, @@ -983,9 +987,13 @@ def _add_authorized_key_to_jump_pod( ) -def _get_gateway_commands(authorized_keys: list[str]) -> list[str]: +def _get_gateway_commands( + authorized_keys: List[str], router: Optional[AnyRouterConfig] = None +) -> List[str]: authorized_keys_content = "\n".join(authorized_keys).strip() - gateway_commands = " && ".join(get_dstack_gateway_commands()) + gateway_commands = " && ".join(get_dstack_gateway_commands(router=router)) + quoted_gateway_commands = shlex.quote(gateway_commands) + commands = [ # install packages "apt-get update && apt-get install -y sudo wget openssh-server nginx python3.10-venv libaugeas0", @@ -1013,7 +1021,7 @@ def _get_gateway_commands(authorized_keys: list[str]) -> list[str]: # start sshd "/usr/sbin/sshd -p 22 -o PermitUserEnvironment=yes", # run gateway - f"su ubuntu -c '{gateway_commands}'", + f"su ubuntu -c {quoted_gateway_commands}", "sleep infinity", ] return commands diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index 6a480b580..39befe739 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -7,6 +7,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel +from dstack._internal.core.models.routers import AnyRouterConfig from 
class RouterType(str, Enum):
    """Identifiers of the router implementations a gateway can run."""

    SGLANG = "sglang"


# Values accepted by SGLangRouterConfig.policy.
_SGLangPolicy = Literal["random", "round_robin", "cache_aware", "power_of_two"]


class SGLangRouterConfig(CoreModel):
    """Configuration of an SGLang router attached to a gateway."""

    # Fixed tag identifying this config as the SGLang variant; its value matches
    # RouterType.SGLANG.
    type: Literal["sglang"] = "sglang"
    # Routing policy; defaults to "cache_aware".
    policy: _SGLangPolicy = "cache_aware"


# Only one router implementation exists, so the "any router" alias is that
# single config class.
AnyRouterConfig = SGLangRouterConfig
access logs for this internal endpoint + + proxy_read_timeout 300s; + proxy_send_timeout 300s; + + location / { + proxy_pass http://router_worker_{{ domain|replace('.', '_') }}_{{ ports[loop.index0] }}_upstream; + proxy_http_version 1.1; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header Connection ""; + proxy_set_header Upgrade $http_upgrade; + } +} +{% endfor %} diff --git a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 index b096fa80e..31f987706 100644 --- a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 +++ b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 @@ -4,9 +4,13 @@ limit_req_zone {{ zone.key }} zone={{ zone.name }}:10m rate={{ zone.rpm }}r/m; {% if replicas %} upstream {{ domain }}.upstream { + {% if router_port is not none %} + server 127.0.0.1:{{ router_port }}; # SGLang router on the gateway + {% else %} {% for replica in replicas %} server unix:{{ replica.socket }}; # replica {{ replica.id }} {% endfor %} + {% endif %} } {% else %} @@ -32,6 +36,13 @@ server { } {% endfor %} + {# For SGLang router: block all requests except whitelisted locations added dynamically above #} + {% if router is not none and router.type == "sglang" %} + location / { + return 403; + } + {% endif %} + location @websocket { set $dstack_replica_hit 1; {% if replicas %} diff --git a/src/dstack/_internal/proxy/gateway/routers/registry.py b/src/dstack/_internal/proxy/gateway/routers/registry.py index e1bfa4ff2..dd4f63f32 100644 --- a/src/dstack/_internal/proxy/gateway/routers/registry.py +++ b/src/dstack/_internal/proxy/gateway/routers/registry.py @@ -36,6 +36,7 @@ async def register_service( model=body.options.openai.model if body.options.openai is not None else None, ssh_private_key=body.ssh_private_key, repo=repo, + router=body.router, nginx=nginx, service_conn_pool=service_conn_pool, ) diff --git 
def get_router(router: AnyRouterConfig, context: RouterContext) -> Router:
    """Build the router implementation selected by ``router.type``.

    Args:
        router: Router configuration; its ``type`` field picks the implementation.
        context: Runtime context handed to the created router.

    Returns:
        A router instance bound to ``context``.

    Raises:
        ProxyError: If ``router.type`` names no available implementation.
    """
    # Guard clause: reject anything that is not the single supported type.
    if router.type != RouterType.SGLANG:
        raise ProxyError(f"Router type '{router.type}' is not available")
    return SglangRouter(config=router, context=context)


__all__ = [
    "Router",
    "RouterContext",
    "get_router",
]
class RouterContext(BaseModel):
    """Context for router initialization and configuration."""

    class Config:
        frozen = True

    # Address the router process is told to bind to.
    host: str = "127.0.0.1"
    # TCP port of the router's HTTP API.
    port: int
    # Directory passed to the router for its log files.
    log_dir: Path
    log_level: Literal["debug", "info", "warning", "error"] = "info"


class Router(ABC):
    """Abstract base class for router implementations.

    A router manages the lifecycle of worker replicas and handles request routing.
    Different router implementations may have different mechanisms for managing
    replicas.
    """

    def __init__(
        self,
        context: RouterContext,
        config: Optional[AnyRouterConfig] = None,
    ):
        """Initialize router with context.

        Args:
            context: Runtime context for the router (host, port, logging, etc.)
            config: Optional router configuration (implementation-specific)
        """
        self.context = context
        # Keep the configuration as well: it is part of the constructor contract
        # and was previously dropped unless a subclass stored it itself.
        self.config = config

    @abstractmethod
    def start(self) -> None:
        """Start the router process.

        Raises:
            Exception: If the router fails to start.
        """
        ...

    @abstractmethod
    def stop(self) -> None:
        """Stop the router process.

        Raises:
            Exception: If the router fails to stop.
        """
        ...

    @abstractmethod
    def is_running(self) -> bool:
        """Check if the router is currently running and responding.

        Returns:
            True if the router is running and healthy, False otherwise.
        """
        ...

    @abstractmethod
    def remove_replicas(self, replica_urls: List[str]) -> None:
        """Unregister replicas from the router (actual API calls to remove workers).

        Args:
            replica_urls: The list of replica URLs to remove from router.

        Raises:
            Exception: If removing replicas fails.
        """
        ...

    @abstractmethod
    def update_replicas(self, replica_urls: List[str]) -> None:
        """Update replicas for service, replacing the current set.

        Args:
            replica_urls: The new list of replica URLs for this service.

        Raises:
            Exception: If updating replicas fails.
        """
        ...


class SglangRouter(Router):
    """SGLang router implementation with 1:1 service-to-router."""

    TYPE = RouterType.SGLANG

    # How long start() waits for the router's HTTP API, and how often it polls.
    _STARTUP_TIMEOUT = 10.0
    _STARTUP_POLL_INTERVAL = 0.5
    # Grace period granted after terminate() before escalating to kill().
    _SHUTDOWN_TIMEOUT = 5.0

    def __init__(self, config: SGLangRouterConfig, context: RouterContext):
        """Initialize SGLang router.

        Args:
            config: SGLang router configuration (routing policy).
            context: Runtime context for the router (host, port, logging, etc.)
        """
        super().__init__(context=context, config=config)
        self.config = config
        # Handle of the process spawned by start(); None until started and
        # again after stop().
        self._process: Optional[subprocess.Popen] = None

    def pid_from_tcp_ipv4_port(self, port: int) -> Optional[int]:
        """
        Return PID of the process listening on the given TCP IPv4 port.
        If no process is found, return None.
        """
        for conn in psutil.net_connections(kind="tcp4"):
            if conn.laddr and conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
                return conn.pid
        return None

    def start(self) -> None:
        """Launch the router process and wait until its HTTP API answers.

        Raises:
            UnexpectedProxyError: If the router does not become healthy within
                the start-up timeout. The spawned process is stopped first so
                it cannot linger and keep the port occupied.
        """
        try:
            logger.info("Starting sglang-router-new on port %s...", self.context.port)

            # Prometheus port is offset by 10000 from router port to keep it in a separate range
            prometheus_port = self.context.port + 10000

            cmd = [
                sys.executable,
                "-m",
                "sglang_router.launch_router",
                "--host",
                self.context.host,
                "--port",
                str(self.context.port),
                "--prometheus-port",
                str(prometheus_port),
                "--prometheus-host",
                self.context.host,
                "--log-level",
                self.context.log_level,
                "--log-dir",
                str(self.context.log_dir),
                "--policy",
                self.config.policy,
            ]

            process = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            self._process = process

            if not self._wait_until_ready(process):
                # Do not leave a half-started router behind.
                self._terminate_process(process)
                self._process = None
                raise UnexpectedProxyError(
                    f"Failed to start sglang router on port {self.context.port}"
                )

            logger.info(
                "Sglang router started successfully on port %s (prometheus on %s)",
                self.context.port,
                prometheus_port,
            )

        except Exception:
            logger.exception("Failed to start sglang-router")
            raise

    def _wait_until_ready(self, process: subprocess.Popen) -> bool:
        """Poll the health endpoint until it answers or the start-up timeout passes."""
        deadline = time.monotonic() + self._STARTUP_TIMEOUT
        while True:
            time.sleep(self._STARTUP_POLL_INTERVAL)
            if process.poll() is not None:
                # The launcher has exited; further waiting cannot change the
                # outcome, so let one last health check decide.
                return self.is_running()
            if self.is_running():
                return True
            if time.monotonic() >= deadline:
                return False

    def _terminate_process(self, process: subprocess.Popen) -> None:
        """Stop a child process: terminate() first, kill() if it ignores the grace period."""
        if process.poll() is not None:
            return
        process.terminate()
        try:
            process.wait(timeout=self._SHUTDOWN_TIMEOUT)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()

    def stop(self) -> None:
        """Stop the router listening on the context port and delete its log directory.

        Raises:
            Exception: Re-raised after logging if stopping or cleanup fails.
        """
        try:
            pid = self.pid_from_tcp_ipv4_port(self.context.port)

            if pid:
                logger.debug(
                    "Stopping sglang-router process (PID: %s) on port %s",
                    pid,
                    self.context.port,
                )
                try:
                    proc = psutil.Process(pid)
                    proc.terminate()
                    try:
                        proc.wait(timeout=5)
                    except psutil.TimeoutExpired:
                        logger.warning(
                            "Process %s did not terminate gracefully, forcing kill", pid
                        )
                        proc.kill()
                        # Wait for the kill to take effect so the port is free
                        # again by the time stop() returns.
                        try:
                            proc.wait(timeout=5)
                        except psutil.TimeoutExpired:
                            logger.warning("Process %s is still alive after kill", pid)
                except psutil.NoSuchProcess:
                    logger.debug("sglang-router process %s already exited before stop()", pid)
            else:
                logger.debug("No sglang-router process found on port %s", self.context.port)

            # Reap the child spawned by start(); if it never bound the port the
            # lookup above cannot have found it, so stop it through the handle.
            process, self._process = self._process, None
            if process is not None:
                self._terminate_process(process)

            # Clean up router logs
            if self.context.log_dir.exists():
                logger.debug("Cleaning up router logs for port %s...", self.context.port)
                shutil.rmtree(self.context.log_dir, ignore_errors=True)

        except Exception:
            logger.exception("Failed to stop sglang-router")
            raise

    def is_running(self) -> bool:
        """Check if the SGLang router is running and responding to HTTP requests on the assigned port."""
        try:
            with httpx.Client(timeout=5.0) as client:
                response = client.get(f"http://{self.context.host}:{self.context.port}/workers")
                return response.status_code == 200
        except httpx.RequestError as e:
            logger.debug(
                "Sglang router not responding on port %s: %s",
                self.context.port,
                e,
            )
            return False

    def remove_replicas(self, replica_urls: List[str]) -> None:
        """Ask the router to drop each given worker URL; failures are logged, not raised."""
        for replica_url in replica_urls:
            self._remove_worker_from_router(replica_url)

    def update_replicas(self, replica_urls: List[str]) -> None:
        """Update replicas for service, replacing the current set.

        The router is queried for its current workers; missing URLs are added
        and surplus URLs are removed. Individual failures are logged and the
        remaining workers are still processed.
        """
        # Query router to get current worker URLs
        current_workers = self._get_router_workers()
        current_worker_urls: set[str] = set()
        for worker in current_workers:
            url = worker.get("url")
            if url and isinstance(url, str):
                # Normalize URL by removing trailing slashes to avoid path artifacts
                current_worker_urls.add(url.rstrip("/"))
        # Normalize target URLs to ensure consistent comparison
        target_worker_urls = {url.rstrip("/") for url in replica_urls}

        workers_to_add = target_worker_urls - current_worker_urls
        workers_to_remove = current_worker_urls - target_worker_urls

        if workers_to_add:
            logger.info(
                "Sglang router update: adding %d workers for router on port %s",
                len(workers_to_add),
                self.context.port,
            )
        if workers_to_remove:
            logger.info(
                "Sglang router update: removing %d workers for router on port %s",
                len(workers_to_remove),
                self.context.port,
            )

        # Add before removing; sorted() keeps the call order deterministic.
        for worker_url in sorted(workers_to_add):
            if not self._add_worker_to_router(worker_url):
                logger.warning("Failed to add worker %s, continuing with others", worker_url)

        for worker_url in sorted(workers_to_remove):
            if not self._remove_worker_from_router(worker_url):
                logger.warning("Failed to remove worker %s, continuing with others", worker_url)

    def _get_router_workers(self) -> List[dict]:
        """Return the router's ``workers`` list, or an empty list on any error or non-200 reply."""
        try:
            with httpx.Client(timeout=5.0) as client:
                response = client.get(f"http://{self.context.host}:{self.context.port}/workers")
                if response.status_code == 200:
                    response_data = response.json()
                    return response_data.get("workers", [])
                return []
        except Exception:
            logger.exception("Error getting sglang router workers")
            return []

    def _add_worker_to_router(self, worker_url: str) -> bool:
        """POST a worker to the router; True only for a 202 reply with status "accepted"."""
        try:
            payload = {"url": worker_url, "worker_type": "regular"}
            with httpx.Client(timeout=5.0) as client:
                response = client.post(
                    f"http://{self.context.host}:{self.context.port}/workers",
                    json=payload,
                )
                if response.status_code == 202:
                    response_data = response.json()
                    if response_data.get("status") == "accepted":
                        logger.info(
                            "Worker %s accepted by sglang router on port %s",
                            worker_url,
                            self.context.port,
                        )
                        return True
                    else:
                        logger.error(
                            "Sglang router on port %s failed to accept worker: %s",
                            self.context.port,
                            response_data,
                        )
                        return False
                else:
                    logger.error(
                        "Failed to add worker %s: status %d, %s",
                        worker_url,
                        response.status_code,
                        response.text,
                    )
                    return False
        except Exception:
            logger.exception("Error adding worker %s", worker_url)
            return False

    def _remove_worker_from_router(self, worker_url: str) -> bool:
        """DELETE a worker from the router; True only for a 202 reply with status "accepted"."""
        try:
            # The worker URL travels as a single path segment, so every reserved
            # character (including "/") is percent-encoded.
            encoded_url = urllib.parse.quote(worker_url, safe="")
            with httpx.Client(timeout=5.0) as client:
                response = client.delete(
                    f"http://{self.context.host}:{self.context.port}/workers/{encoded_url}"
                )
                if response.status_code == 202:
                    response_data = response.json()
                    if response_data.get("status") == "accepted":
                        logger.info(
                            "Removed worker %s from sglang router on port %s",
                            worker_url,
                            self.context.port,
                        )
                        return True
                    else:
                        logger.error(
                            "Sglang router on port %s failed to remove worker: %s",
                            self.context.port,
                            response_data,
                        )
                        return False
                else:
                    logger.error(
                        "Failed to remove worker %s: status %d, %s",
                        worker_url,
                        response.status_code,
                        response.text,
                    )
                    return False
        except Exception:
            logger.exception("Error removing worker %s", worker_url)
            return False
def __init__(self, conf_dir: Path = Path("/etc/nginx/sites-enabled")) -> None:
    """Set up the nginx config directory, the lock, and router/worker port bookkeeping."""
    self._conf_dir = conf_dir
    self._lock: Lock = Lock()
    # 1:1 service-to-router mapping
    self._router_port_to_domain: Dict[int, str] = {}
    self._domain_to_router: Dict[str, Router] = {}
    self._ROUTER_PORT_MIN: int = 20000
    self._ROUTER_PORT_MAX: int = 24999
    self._WORKER_PORT_MIN: int = 10001
    self._WORKER_PORT_MAX: int = 11999
    self._next_router_port: int = self._ROUTER_PORT_MIN
    # Tracking of worker ports to avoid conflicts across router instances
    self._allocated_worker_ports: set[int] = set()
    self._domain_to_worker_ports: Dict[str, list[int]] = {}
    self._next_worker_port: int = self._WORKER_PORT_MIN


async def register(self, conf: SiteConfig, acme: ACMESettings) -> None:
    """Register a site with nginx, starting and wiring an SGLang router if configured.

    For a ``ServiceConfig`` with an SGLang router this reuses or creates the
    per-domain router, allocates one local worker port per replica, writes the
    router-workers nginx config, and pushes the worker URLs to the router.
    The main site config is written last.

    Args:
        conf: Site configuration to render and write. ``conf.router_port`` is
            set on it when a router is used.
        acme: ACME settings passed to certbot when ``conf.https`` is set.

    Raises:
        Exception: Any failure from certbot, router start-up, port allocation,
            config writing or the router API is re-raised.
    """
    logger.debug("Registering %s domain %s", conf.type, conf.domain)
    conf_name = self.get_config_name(conf.domain)
    async with self._lock:
        if conf.https:
            await run_async(self.run_certbot, conf.domain, acme)

        if (
            isinstance(conf, ServiceConfig)
            and conf.router
            and conf.router.type == RouterType.SGLANG
        ):
            if conf.domain in self._domain_to_router:
                # Router already exists, reuse it
                router = self._domain_to_router[conf.domain]
                router_port = router.context.port
                conf.router_port = router_port
            else:
                # Allocate router port for new router
                router_port = self._allocate_router_port()
                conf.router_port = router_port

                # Per-service log directory and router context
                ctx = RouterContext(
                    port=router_port,
                    log_dir=Path(f"./router_logs/{conf.domain}"),
                )
                router = get_router(conf.router, context=ctx)

                # Store mappings
                self._router_port_to_domain[router_port] = conf.domain
                self._domain_to_router[conf.domain] = router

                # NOTE(review): indentation is lost in the reviewed diff; this
                # start-up step is assumed to belong to the new-router branch
                # (its cleanup only undoes the mappings added above) - confirm.
                try:
                    if not await run_async(router.is_running):
                        await run_async(router.start)
                except Exception:
                    # Clean up on failure
                    del self._router_port_to_domain[router_port]
                    del self._domain_to_router[conf.domain]
                    raise

            allocated_ports = self._allocate_worker_ports(len(conf.replicas))
            replica_urls = [f"http://{router.context.host}:{port}" for port in allocated_ports]

            # Write router workers config
            try:
                if conf.replicas:
                    await run_async(self.write_router_workers_conf, conf, allocated_ports)
                    # Discard old worker ports if domain already has allocated ports (required for scaling case)
                    for port in self._domain_to_worker_ports.get(conf.domain, []):
                        self._allocated_worker_ports.discard(port)
                    self._domain_to_worker_ports[conf.domain] = allocated_ports
            except Exception as e:
                logger.exception(
                    "write_router_workers_conf failed for domain=%s: %s", conf.domain, e
                )
                # The new ports are reserved but were never recorded for the
                # domain, so unregister() could not free them later. Release
                # them here instead of leaking them.
                for port in allocated_ports:
                    self._allocated_worker_ports.discard(port)
                raise

            # Update replicas to router (actual HTTP API calls to add workers)
            try:
                await run_async(router.update_replicas, replica_urls)
            except Exception as e:
                logger.exception(
                    "Failed to add replicas to router for domain=%s: %s", conf.domain, e
                )
                raise

        await run_async(self.write_conf, conf.render(), conf_name)

    logger.info("Registered %s domain %s", conf.type, conf.domain)
# --- Methods of class Nginx -------------------------------------------------

@staticmethod
def _is_port_available(port: int) -> bool:
    """Report whether ``127.0.0.1:port`` can currently be bound.

    Returns False both when the bind is refused and when the probe itself
    fails unexpectedly (the latter is logged as a warning).
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                probe.bind(("127.0.0.1", port))
            except OSError:
                # e.g. "Address already in use"
                return False
            return True
    except Exception:
        logger.warning("Error checking port %s availability", port)
        return False


def _allocate_router_port(self) -> int:
    """Pick the next free router port from the fixed router range.

    Scanning starts at ``_next_router_port`` and wraps around at the end of
    the range. A port is skipped when it is already mapped to a domain or when
    it cannot be bound locally.

    Raises:
        UnexpectedProxyError: If every port in the range was tried without success.
    """
    low, high = self._ROUTER_PORT_MIN, self._ROUTER_PORT_MAX
    candidate = self._next_router_port

    for _ in range(high - low + 1):
        # Successor of the candidate, wrapping back to the start of the range.
        successor = low if candidate >= high else candidate + 1

        if candidate not in self._router_port_to_domain:
            if self._is_port_available(candidate):
                self._next_router_port = successor
                logger.debug("Allocated router port %s", candidate)
                return candidate
            logger.debug("Port %s is in use, trying next port", candidate)

        candidate = successor

    raise UnexpectedProxyError(
        f"Router port range exhausted ({self._ROUTER_PORT_MIN}-{self._ROUTER_PORT_MAX}). "
        "All ports in range appear to be in use."
    )


def _allocate_worker_ports(self, num_ports: int) -> list[int]:
    """Reserve ``num_ports`` worker ports from the fixed worker range.

    Worker ports are the local ports nginx listens on to proxy to replica
    sockets, so they are tracked globally in ``_allocated_worker_ports``.

    Args:
        num_ports: Number of worker ports to allocate

    Returns:
        List of allocated worker port numbers

    Raises:
        UnexpectedProxyError: If fewer than ``num_ports`` could be reserved;
            ports reserved during the failed attempt are released first.
    """
    low, high = self._WORKER_PORT_MIN, self._WORKER_PORT_MAX
    budget = (high - low + 1) * 2  # Allow wrap-around
    reserved = []
    cursor = self._next_worker_port
    attempts = 0

    while len(reserved) < num_ports and attempts < budget:
        candidate = cursor
        # Every probe, successful or not, advances the cursor and costs one attempt.
        cursor = low if candidate >= high else candidate + 1
        attempts += 1

        if candidate in self._allocated_worker_ports:
            continue

        if self._is_port_available(candidate):
            reserved.append(candidate)
            self._allocated_worker_ports.add(candidate)
            logger.debug("Allocated worker port %s", candidate)
        else:
            logger.debug("Worker port %s is in use, trying next port", candidate)

    if len(reserved) < num_ports:
        # Free up the ports we did allocate
        for taken in reserved:
            self._allocated_worker_ports.discard(taken)
        raise UnexpectedProxyError(
            f"Failed to allocate {num_ports} worker ports in range "
            f"({self._WORKER_PORT_MIN}-{self._WORKER_PORT_MAX}). "
            f"Only allocated {len(reserved)} ports after {attempts} attempts."
        )

    # Remember where the next allocation should resume.
    self._next_worker_port = cursor
    if self._next_worker_port > self._WORKER_PORT_MAX:
        self._next_worker_port = self._WORKER_PORT_MIN  # Wrap around

    return reserved


def write_global_conf(self) -> None:
    """Write the packaged ``00-log-format.conf`` into the nginx config directory."""
    self.write_conf(read_package_resource("00-log-format.conf"), "00-log-format.conf")


def write_router_workers_conf(self, conf: ServiceConfig, allocated_ports: list[int]) -> None:
    """Render and write the per-domain router-workers nginx config."""
    self.write_conf(
        generate_router_workers_config(conf, allocated_ports),
        f"router-workers.{conf.domain}.conf",
    )


# --- Module-level helper (defined after class Nginx) -------------------------

def generate_router_workers_config(conf: ServiceConfig, allocated_ports: list[int]) -> str:
    """Render ``router_workers.jinja2`` for the service's replicas.

    Args:
        conf: Service configuration supplying the domain and replicas.
        allocated_ports: Local ports passed to the template as ``ports``; the
            template looks them up by replica position.

    Returns:
        The rendered nginx configuration text.
    """
    template = jinja2.Template(read_package_resource("router_workers.jinja2"))
    template_vars = {
        "domain": conf.domain,
        "replicas": conf.replicas,
        "ports": allocated_ports,
        "proxy_port": PROXY_PORT_ON_GATEWAY,
    }
    return template.render(**template_vars)
service_conn_pool: ServiceConnectionPool, + router: Optional[AnyRouterConfig] = None, ) -> None: service = models.Service( project_name=project_name, @@ -54,6 +56,7 @@ async def register_service( auth=auth, client_max_body_size=client_max_body_size, replicas=(), + router=router, ) async with lock: @@ -306,6 +309,15 @@ async def get_nginx_service_config( ) -> ServiceConfig: limit_req_zones: list[LimitReqZoneConfig] = [] locations: list[LocationConfig] = [] + is_sglang = service.router and service.router.type == RouterType.SGLANG + sglang_whitelisted_paths = [ + "/generate", + "/v1/", + "/chat/completions", + ] # Prefix match for paths that end with a slash and exact match for paths that don't + sglang_limits: dict[str, LimitReqConfig] = {} + sglang_prefix_lengths: dict[str, int] = {} # Track prefix lengths for most-specific selection + for i, rate_limit in enumerate(service.rate_limits): zone_name = f"{i}.{service.domain_safe}" if isinstance(rate_limit.key, models.IPAddressPartitioningKey): @@ -317,13 +329,39 @@ async def get_nginx_service_config( limit_req_zones.append( LimitReqZoneConfig(name=zone_name, key=key, rpm=round(rate_limit.rps * 60)) ) - locations.append( - LocationConfig( - prefix=rate_limit.prefix, - limit_req=LimitReqConfig(zone=zone_name, burst=rate_limit.burst), + if is_sglang: + for path in sglang_whitelisted_paths: + if rate_limit.prefix == path or path.startswith(rate_limit.prefix): + # Use the longest prefix if multiple prefixes match the same path + current_prefix_len = len(rate_limit.prefix) + if path not in sglang_limits or current_prefix_len > sglang_prefix_lengths.get( + path, 0 + ): + sglang_limits[path] = LimitReqConfig( + zone=zone_name, burst=rate_limit.burst + ) + sglang_prefix_lengths[path] = current_prefix_len + else: + locations.append( + LocationConfig( + prefix=rate_limit.prefix, + limit_req=LimitReqConfig(zone=zone_name, burst=rate_limit.burst), + ) ) - ) - if not any(location.prefix == "/" for location in locations): + + # Add 
SGLang whitelisted paths as locations + if is_sglang: + for path in sglang_whitelisted_paths: + # Use prefix match for paths that end with a slash and exact match for paths that don't + if path.endswith("/"): + locations.append(LocationConfig(prefix=path, limit_req=sglang_limits.get(path))) + else: + locations.append( + LocationConfig(prefix=f"= {path}", limit_req=sglang_limits.get(path)) + ) + + # Don't auto-add / location for SGLang routers (catch-all 403 handles it) + if not any(location.prefix == "/" for location in locations) and not is_sglang: locations.append(LocationConfig(prefix="/", limit_req=None)) return ServiceConfig( domain=service.domain_safe, @@ -335,6 +373,7 @@ async def get_nginx_service_config( limit_req_zones=limit_req_zones, locations=locations, replicas=sorted(replicas, key=lambda r: r.id), # sort for reproducible configs + router=service.router, ) diff --git a/src/dstack/_internal/proxy/lib/models.py b/src/dstack/_internal/proxy/lib/models.py index 5cb5471d8..bf37e0b5a 100644 --- a/src/dstack/_internal/proxy/lib/models.py +++ b/src/dstack/_internal/proxy/lib/models.py @@ -7,6 +7,7 @@ from typing_extensions import Annotated from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.proxy.lib.errors import UnexpectedProxyError @@ -57,6 +58,7 @@ class Service(ImmutableModel): client_max_body_size: int # only enforced on gateways strip_prefix: bool = True # only used in-server replicas: tuple[Replica, ...] 
+ router: Optional[AnyRouterConfig] = None @property def domain_safe(self) -> str: diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index afad2831b..273b1fb89 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -108,6 +108,7 @@ async def create_gateway_compute( ssh_key_pub=gateway_ssh_public_key, certificate=configuration.certificate, tags=configuration.tags, + router=configuration.router, ) gpd = await run_async( @@ -448,10 +449,16 @@ async def _update_gateway(gateway_compute_model: GatewayComputeModel, build: str gateway_compute_model.ssh_private_key, ) logger.debug("Updating gateway %s", connection.ip_address) + compute_config = GatewayComputeConfiguration.__response__.parse_raw( + gateway_compute_model.configuration + ) + + # Build package spec with extras and wheel URL + gateway_package = get_dstack_gateway_wheel(build, compute_config.router) commands = [ # prevent update.sh from overwriting itself during execution "cp dstack/update.sh dstack/_update.sh", - f"sh dstack/_update.sh {get_dstack_gateway_wheel(build)} {build}", + f'sh dstack/_update.sh "{gateway_package}" {build}', "rm dstack/_update.sh", ] stdout = await connection.tunnel.aexec("/bin/sh -c '" + " && ".join(commands) + "'") diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py index f8c090079..d4f1c831e 100644 --- a/src/dstack/_internal/server/services/gateways/client.py +++ b/src/dstack/_internal/server/services/gateways/client.py @@ -9,6 +9,7 @@ from dstack._internal.core.errors import GatewayError from dstack._internal.core.models.configurations import RateLimit from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.runs import JobSpec, 
JobSubmission, Run, get_service_port from dstack._internal.proxy.gateway.schemas.stats import ServiceStats from dstack._internal.server import settings @@ -45,6 +46,7 @@ async def register_service( options: dict, rate_limits: list[RateLimit], ssh_private_key: str, + router: Optional[AnyRouterConfig] = None, ): if "openai" in options: entrypoint = f"gateway.{domain.split('.', maxsplit=1)[1]}" @@ -59,6 +61,7 @@ async def register_service( "options": options, "rate_limits": [limit.dict() for limit in rate_limits], "ssh_private_key": ssh_private_key, + "router": router.dict() if router is not None else None, } resp = await self._client.post( self._url(f"/api/registry/{project}/services/register"), json=payload diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index a8089a93a..05c1fa909 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -82,6 +82,7 @@ async def _register_service_in_gateway( gateway_configuration = get_gateway_configuration(gateway) service_https = _get_service_https(run_spec, gateway_configuration) + router = gateway_configuration.router service_protocol = "https" if service_https else "http" if service_https and gateway_configuration.certificate is None: @@ -119,6 +120,7 @@ async def _register_service_in_gateway( options=service_spec.options, rate_limits=run_spec.configuration.rate_limits, ssh_private_key=run_model.project.ssh_private_key, + router=router, ) logger.info("%s: service is registered as %s", fmt(run_model), service_spec.url) except SSHError: diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index 996157350..6d06e6969 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -70,6 +70,7 @@ async def test_list(self, test_db, session: 
AsyncSession, client: AsyncClient): "name": gateway.name, "backend": backend.type.value, "region": gateway.region, + "router": None, "domain": gateway.wildcard_domain, "default": False, "public_ip": True, @@ -121,6 +122,7 @@ async def test_get(self, test_db, session: AsyncSession, client: AsyncClient): "name": gateway.name, "backend": backend.type.value, "region": gateway.region, + "router": None, "domain": gateway.wildcard_domain, "default": False, "public_ip": True, @@ -201,6 +203,7 @@ async def test_create_gateway(self, test_db, session: AsyncSession, client: Asyn "name": "test", "backend": backend.type.value, "region": "us", + "router": None, "domain": None, "default": True, "public_ip": True, @@ -253,6 +256,7 @@ async def test_create_gateway_without_name( "name": "random-name", "backend": backend.type.value, "region": "us", + "router": None, "domain": None, "default": True, "public_ip": True, @@ -355,6 +359,7 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: "name": gateway.name, "backend": backend.type.value, "region": gateway.region, + "router": None, "domain": gateway.wildcard_domain, "default": True, "public_ip": True, @@ -477,6 +482,7 @@ def get_backend(project, backend_type): "name": gateway_gcp.name, "backend": backend_gcp.type.value, "region": gateway_gcp.region, + "router": None, "domain": gateway_gcp.wildcard_domain, "default": False, "public_ip": True, @@ -546,6 +552,7 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: "name": gateway.name, "backend": backend.type.value, "region": gateway.region, + "router": None, "domain": "test.com", "default": False, "public_ip": True,