From 3c08dc7ff51ebee7ed3544c916fb4c4aae6a87be Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Thu, 6 Nov 2025 22:05:54 +0545 Subject: [PATCH 01/11] Add SGLang Router Support --- gateway/pyproject.toml | 5 +- .../_internal/core/backends/aws/compute.py | 4 +- .../_internal/core/backends/azure/compute.py | 4 +- .../_internal/core/backends/base/compute.py | 21 +- .../_internal/core/backends/gcp/compute.py | 4 +- src/dstack/_internal/core/models/gateways.py | 6 + src/dstack/_internal/core/models/routers.py | 27 ++ .../proxy/gateway/model_routers/__init__.py | 62 +++ .../proxy/gateway/model_routers/base.py | 147 +++++++ .../proxy/gateway/model_routers/sglang.py | 366 ++++++++++++++++++ .../gateway/resources/nginx/service.jinja2 | 4 + .../resources/nginx/sglang_workers.jinja2 | 23 ++ .../proxy/gateway/routers/registry.py | 1 + .../proxy/gateway/schemas/registry.py | 2 + .../_internal/proxy/gateway/services/nginx.py | 70 +++- .../proxy/gateway/services/registry.py | 6 + src/dstack/_internal/proxy/lib/models.py | 3 + .../server/services/gateways/__init__.py | 14 + .../server/services/gateways/client.py | 3 + .../server/services/services/__init__.py | 2 + .../_internal/server/routers/test_gateways.py | 7 + 21 files changed, 771 insertions(+), 10 deletions(-) create mode 100644 src/dstack/_internal/core/models/routers.py create mode 100644 src/dstack/_internal/proxy/gateway/model_routers/__init__.py create mode 100644 src/dstack/_internal/proxy/gateway/model_routers/base.py create mode 100644 src/dstack/_internal/proxy/gateway/model_routers/sglang.py create mode 100644 src/dstack/_internal/proxy/gateway/resources/nginx/sglang_workers.jinja2 diff --git a/gateway/pyproject.toml b/gateway/pyproject.toml index a67171c25..3c877c0af 100644 --- a/gateway/pyproject.toml +++ b/gateway/pyproject.toml @@ -11,9 +11,12 @@ requires-python = ">=3.10" dynamic = ["version"] dependencies = [ # release builds of dstack-gateway depend on a PyPI version of dstack instead - "dstack[gateway] @ git+https://github.com/dstackai/dstack.git@master", + "dstack[gateway] @ git+https://github.com/Bihan/dstack.git@add_sglang_router_support", ] +[project.optional-dependencies] +sglang = ["sglang-router==0.2.2"] + [tool.setuptools.package-data] "dstack.gateway" = [ "resources/systemd/*", diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index f973980e0..39141bd10 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -460,7 +460,9 @@ def create_gateway( image_id=aws_resources.get_gateway_image_id(ec2_client), instance_type="t3.micro", iam_instance_profile=None, - user_data=get_gateway_user_data(configuration.ssh_key_pub), + user_data=get_gateway_user_data( + configuration.ssh_key_pub, router_config=configuration.router_config + ), tags=tags, security_group_id=security_group_id, spot=False, diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py index 8b37a72b2..357bfd0a9 100644 --- a/src/dstack/_internal/core/backends/azure/compute.py +++ b/src/dstack/_internal/core/backends/azure/compute.py @@ -277,7 +277,9 @@ def create_gateway( image_reference=_get_gateway_image_ref(), vm_size="Standard_B1ms", instance_name=instance_name, - user_data=get_gateway_user_data(configuration.ssh_key_pub), + user_data=get_gateway_user_data( + configuration.ssh_key_pub, router_config=configuration.router_config + ), ssh_pub_keys=[configuration.ssh_key_pub], spot=False, disk_size=30, diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 9598a52b0..3fd0a5b6d 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -38,6 +38,7 @@ SSHKey, ) from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData +from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import ( Volume, @@ -876,7 +877,9 @@ def get_run_shim_script( ] -def get_gateway_user_data(authorized_key: str) -> str: +def get_gateway_user_data( + authorized_key: str, router_config: Optional[AnyRouterConfig] = None +) -> str: return get_cloud_config( package_update=True, packages=[ @@ -892,7 +895,7 @@ def get_gateway_user_data(authorized_key: str) -> str: "s/# server_names_hash_bucket_size 64;/server_names_hash_bucket_size 128;/", "/etc/nginx/nginx.conf", ], - ["su", "ubuntu", "-c", " && ".join(get_dstack_gateway_commands())], + ["su", "ubuntu", "-c", " && ".join(get_dstack_gateway_commands(router_config))], ], ssh_authorized_keys=[authorized_key], ) @@ -1021,16 +1024,24 @@ def get_dstack_gateway_wheel(build: str) -> str: r.raise_for_status() build = r.text.strip() logger.debug("Found the latest gateway build: %s", build) - return f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" + # return f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" + return "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl" -def get_dstack_gateway_commands() -> List[str]: +def get_dstack_gateway_commands(router_config: Optional[AnyRouterConfig] = None) -> List[str]: build = get_dstack_runner_version() + wheel = get_dstack_gateway_wheel(build) + # Use router type directly as pip extra + if router_config: + gateway_package = f"dstack-gateway[{router_config.type}]" + else: + gateway_package = "dstack-gateway" return [ "mkdir -p /home/ubuntu/dstack", "python3 -m venv /home/ubuntu/dstack/blue", "python3 -m venv /home/ubuntu/dstack/green", - f"/home/ubuntu/dstack/blue/bin/pip install {get_dstack_gateway_wheel(build)}", + f"/home/ubuntu/dstack/blue/bin/pip install {wheel}", + f"/home/ubuntu/dstack/blue/bin/pip install --upgrade '{gateway_package}'", "sudo /home/ubuntu/dstack/blue/bin/python -m dstack.gateway.systemd install --run", ] diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index 23aab379d..a333d415f 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -599,7 +599,9 @@ def create_gateway( machine_type="e2-medium", accelerators=[], spot=False, - user_data=get_gateway_user_data(configuration.ssh_key_pub), + user_data=get_gateway_user_data( + configuration.ssh_key_pub, router_config=configuration.router_config + ), authorized_keys=[configuration.ssh_key_pub], labels=labels, tags=[gcp_resources.DSTACK_GATEWAY_TAG], diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index 6a480b580..2cab4df16 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -7,6 +7,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel +from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.utils.tags import tags_validator @@ -50,6 +51,10 @@ class GatewayConfiguration(CoreModel): default: Annotated[bool, Field(description="Make the gateway default")] = False backend: Annotated[BackendType, Field(description="The gateway backend")] region: Annotated[str, Field(description="The gateway region")] + router_config: Annotated[ + Optional[AnyRouterConfig], + Field(description="The router configuration"), + ] = None domain: Annotated[ Optional[str], Field(description="The gateway domain, e.g. `example.com`") ] = None @@ -113,6 +118,7 @@ class GatewayComputeConfiguration(CoreModel): ssh_key_pub: str certificate: Optional[AnyGatewayCertificate] = None tags: Optional[Dict[str, str]] = None + router_config: Optional[AnyRouterConfig] = None class GatewayProvisioningData(CoreModel): diff --git a/src/dstack/_internal/core/models/routers.py b/src/dstack/_internal/core/models/routers.py new file mode 100644 index 000000000..58766aac7 --- /dev/null +++ b/src/dstack/_internal/core/models/routers.py @@ -0,0 +1,27 @@ +from enum import Enum +from typing import Union + +from pydantic import Field +from typing_extensions import Annotated, Literal + +from dstack._internal.core.models.common import CoreModel + + +class RouterType(str, Enum): + SGLANG = "sglang" + VLLM = "vllm" + + +class SGLangRouterConfig(CoreModel): + type: Literal["sglang"] = "sglang" + policy: str = "cache_aware" + + +class VLLMRouterConfig(CoreModel): + type: Literal["vllm"] = "vllm" + policy: str = "cache_aware" + + +AnyRouterConfig = Annotated[ + Union[SGLangRouterConfig, VLLMRouterConfig], Field(discriminator="type") +] diff --git a/src/dstack/_internal/proxy/gateway/model_routers/__init__.py b/src/dstack/_internal/proxy/gateway/model_routers/__init__.py new file mode 100644 index 000000000..487a78e04 --- /dev/null +++ b/src/dstack/_internal/proxy/gateway/model_routers/__init__.py @@ -0,0 +1,62 @@ +from typing import Dict, List, Optional, Type + +from dstack._internal.core.models.routers import AnyRouterConfig, RouterType +from dstack._internal.utils.logging import get_logger + +from .base import Replica, Router, RouterContext + +logger = get_logger(__name__) + +"""This provides a registry of available router implementations.""" + +_ROUTER_CLASSES: List[Type[Router]] = [] + +try: + from dstack._internal.proxy.gateway.model_routers.sglang import SglangRouter + + _ROUTER_CLASSES.append(SglangRouter) + logger.debug("Registered SglangRouter") +except ImportError as e: + logger.warning("SGLang router not available: %s", e) + +_ROUTER_TYPE_TO_CLASS_MAP: Dict[RouterType, Type[Router]] = {} + +for router_class in _ROUTER_CLASSES: + router_type_str = getattr(router_class, "TYPE", None) + if router_type_str is None: + logger.warning(f"Router class {router_class.__name__} missing TYPE attribute, skipping") + continue + router_type = RouterType(router_type_str) + _ROUTER_TYPE_TO_CLASS_MAP[router_type] = router_class + +_AVAILABLE_ROUTER_TYPES = list(_ROUTER_TYPE_TO_CLASS_MAP.keys()) + + +def get_router_class(router_type: RouterType) -> Optional[Type[Router]]: + """Get the router class for a given router type.""" + return _ROUTER_TYPE_TO_CLASS_MAP.get(router_type) + + +def get_router(router_config: AnyRouterConfig, context: Optional[RouterContext] = None) -> Router: + """Factory function to create a router instance from router configuration.""" + router_type = RouterType(router_config.type) + router_class = get_router_class(router_type) + + if router_class is None: + available_types = [rt.value for rt in _AVAILABLE_ROUTER_TYPES] + raise ValueError( + f"Router type '{router_type.value}' is not available. " + f"Available types: {available_types}" + ) + + # Router implementations may have different constructor signatures + # SglangRouter takes (router_config, context), others might differ + return router_class(router_config=router_config, context=context) + + +__all__ = [ + "Router", + "RouterContext", + "Replica", + "get_router", +] diff --git a/src/dstack/_internal/proxy/gateway/model_routers/base.py b/src/dstack/_internal/proxy/gateway/model_routers/base.py new file mode 100644 index 000000000..23c3a72b3 --- /dev/null +++ b/src/dstack/_internal/proxy/gateway/model_routers/base.py @@ -0,0 +1,147 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import List, Literal, Optional + +from pydantic import BaseModel + +from dstack._internal.core.models.routers import AnyRouterConfig + + +class RouterContext(BaseModel): + """Context for router initialization and configuration.""" + + class Config: + frozen = True + + host: str = "127.0.0.1" + port: int = 3000 + log_dir: Path = Path("./router_logs") + log_level: Literal["debug", "info", "warning", "error"] = "info" + + +class Replica(BaseModel): + """Represents a single replica (worker) endpoint managed by the router. + + The model field identifies which model this replica serves. + In SGLang, model = model_id (e.g., "meta-llama/Meta-Llama-3.1-8B-Instruct"). + """ + + url: str # HTTP URL where the replica is accessible (e.g., "http://127.0.0.1:10001") + model: str # (e.g., "meta-llama/Meta-Llama-3.1-8B-Instruct") + + +class Router(ABC): + """Abstract base class for router implementations (e.g., SGLang, vLLM). + + A router manages the lifecycle of worker replicas and handles request routing. + Different router implementations may have different mechanisms for managing + replicas. + """ + + def __init__( + self, + router_config: Optional[AnyRouterConfig] = None, + context: Optional[RouterContext] = None, + ): + """Initialize router with context. + + Args: + router_config: Optional router configuration (implementation-specific) + context: Runtime context for the router (host, port, logging, etc.) + """ + self.context = context or RouterContext() + + @abstractmethod + def start(self) -> None: + """Start the router process. + + Raises: + Exception: If the router fails to start. + """ + ... + + @abstractmethod + def stop(self) -> None: + """Stop the router process. + + Raises: + Exception: If the router fails to stop. + """ + ... + + @abstractmethod + def is_running(self) -> bool: + """Check if the router is currently running and responding. + + Returns: + True if the router is running and healthy, False otherwise. + """ + ... + + @abstractmethod + def register_replicas( + self, domain: str, num_replicas: int, model_id: Optional[str] = None + ) -> List[Replica]: + """Register replicas to a domain (allocate ports/URLs for workers). + + Args: + domain: The domain name for this service. + num_replicas: The number of replicas to allocate for this domain. + model_id: Optional model identifier (e.g., "meta-llama/Meta-Llama-3.1-8B-Instruct"). + Required only for routers that support IGW (Inference Gateway) mode for multi-model serving. + + Returns: + List of Replica objects with allocated URLs and model_id set (if provided). + + Raises: + Exception: If allocation fails. + """ + ... + + @abstractmethod + def unregister_replicas(self, domain: str) -> None: + """Unregister replicas for a domain (remove model and unassign all its replicas). + + Args: + domain: The domain name for this service. + + Raises: + Exception: If removal fails or domain is not found. + """ + ... + + @abstractmethod + def add_replicas(self, replicas: List[Replica]) -> None: + """Register replicas with the router (actual API calls to add workers). + + Args: + replicas: The list of replicas to add to router. + + Raises: + Exception: If adding replicas fails. + """ + ... + + @abstractmethod + def remove_replicas(self, replicas: List[Replica]) -> None: + """Unregister replicas from the router (actual API calls to remove workers). + + Args: + replicas: The list of replicas to remove from router. + + Raises: + Exception: If removing replicas fails. + """ + ... + + @abstractmethod + def update_replicas(self, replicas: List[Replica]) -> None: + """Update replicas for service, replacing the current set. + + Args: + replicas: The new list of replicas for this service. + + Raises: + Exception: If updating replicas fails. + """ + ... diff --git a/src/dstack/_internal/proxy/gateway/model_routers/sglang.py b/src/dstack/_internal/proxy/gateway/model_routers/sglang.py new file mode 100644 index 000000000..048623ace --- /dev/null +++ b/src/dstack/_internal/proxy/gateway/model_routers/sglang.py @@ -0,0 +1,366 @@ +import json +import shutil +import subprocess +import time +import urllib.parse +from collections import defaultdict +from typing import DefaultDict, Dict, List, Optional + +from dstack._internal.core.models.routers import SGLangRouterConfig +from dstack._internal.proxy.gateway.const import DSTACK_DIR_ON_GATEWAY +from dstack._internal.utils.logging import get_logger + +from .base import Replica, Router, RouterContext + +logger = get_logger(__name__) + + +class SglangRouter(Router): + """SGLang router implementation using IGW (Inference Gateway) mode for multi-model serving.""" + + TYPE = "sglang" + + def __init__(self, router_config: SGLangRouterConfig, context: Optional[RouterContext] = None): + """Initialize SGLang router. + + Args: + router_config: SGLang router configuration (policy, cache_threshold, etc.) + context: Runtime context for the router (host, port, logging, etc.) + """ + super().__init__(router_config=router_config, context=context) + self.config = router_config + self._domain_to_model_id: Dict[str, str] = {} # domain -> model_id + self._domain_to_ports: Dict[ + str, List[int] + ] = {} # domain -> allocated sglang worker ports. + self._next_worker_port: int = 10001 # Starting port for worker endpoints + + def start(self) -> None: + """Start the SGLang router process.""" + try: + logger.info("Starting sglang-router...") + + # Determine active venv (blue or green) + version_file = DSTACK_DIR_ON_GATEWAY / "version" + if version_file.exists(): + version = version_file.read_text().strip() + else: + version = "blue" + + # Use Python from the active venv + venv_python = DSTACK_DIR_ON_GATEWAY / version / "bin" / "python3" + + cmd = [ + str(venv_python), + "-m", + "sglang_router.launch_router", + "--host", + "0.0.0.0", # Bind to all interfaces (nginx connects via 127.0.0.1) + "--port", + str(self.context.port), + "--enable-igw", + "--log-level", + self.context.log_level, + "--log-dir", + str(self.context.log_dir), + ] + + if hasattr(self.config, "policy") and self.config.policy: + cmd.extend(["--policy", self.config.policy]) + + # Add additional required configs here + + subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + # Wait for router to start + time.sleep(2) + + # Verify router is running + if not self.is_running(): + raise Exception("Failed to start sglang router") + + logger.info("Sglang router started successfully") + + except Exception as e: + logger.error(f"Failed to start sglang-router: {e}") + raise + + def stop(self) -> None: + """Stop the SGLang router process.""" + try: + result = subprocess.run( + ["pgrep", "-f", "sglang::router"], capture_output=True, timeout=5 + ) + if result.returncode == 0: + logger.info("Stopping sglang-router process...") + subprocess.run(["pkill", "-f", "sglang::router"], timeout=5) + else: + logger.debug("No sglang-router process found to stop") + + # Clean up router logs + if self.context.log_dir.exists(): + logger.debug("Cleaning up router logs...") + shutil.rmtree(self.context.log_dir, ignore_errors=True) + else: + logger.debug("No router logs directory found to clean up") + + except Exception as e: + logger.error(f"Failed to stop sglang-router: {e}") + raise + + def is_running(self) -> bool: + """Check if the SGLang router is running and responding to HTTP requests.""" + try: + result = subprocess.run( + ["curl", "-s", f"http://{self.context.host}:{self.context.port}/workers"], + capture_output=True, + timeout=5, + ) + return result.returncode == 0 + except Exception as e: + logger.error(f"Error checking sglang router status: {e}") + return False + + def is_model_registered(self, model_id: str) -> bool: + """Check if a model with the given model_id is registered.""" + return model_id in self._domain_to_model_id.values() + + def register_replicas( + self, domain: str, num_replicas: int, model_id: Optional[str] = None + ) -> List[Replica]: + """Register replicas to a domain (allocate ports/URLs for workers). + SGLang router uses IGW (Inference Gateway) mode, which requires model_id for multi-model serving. + + Maintains in-memory state: + - domain_to_model_id: Maps domain to model_id for unregistering by domain and model validation. + - domain_to_ports: Maps domain to allocated ports to track port assignments and avoid conflicts. + + Args: + domain: The domain name for this service. + num_replicas: The number of replicas to allocate for this domain. + model_id: Model identifier (required for SGLang IGW mode). + + Raises: + ValueError: If model_id is None (required for SGLang IGW mode). + """ + if model_id is None: + raise ValueError("model_id is required for SGLang router (IGW mode)") + + is_new_model = not self.is_model_registered(model_id) + + if is_new_model: + # Store domain -> model_id mapping + self._domain_to_model_id[domain] = model_id + + # Allocate ports for replicas + allocated_ports = [] + for _ in range(num_replicas): + allocated_ports.append(self._next_worker_port) + self._next_worker_port += 1 + + self._domain_to_ports[domain] = allocated_ports + + logger.debug( + f"Allocated model {model_id} (domain {domain}) with {num_replicas} replicas " + f"on ports {allocated_ports}" + ) + else: + # Verify domain matches + if self._domain_to_model_id.get(domain) != model_id: + raise ValueError(f"Domain {domain} does not match model_id {model_id}") + + # Get current allocated ports + current_ports = self._domain_to_ports.get(domain, []) + current_count = len(current_ports) + + if num_replicas == current_count: + # No change needed, return existing replicas + replicas = [ + Replica(url=f"http://{self.context.host}:{port}", model=model_id) + for port in current_ports + ] + return replicas + + # Re-allocate ports for new count + allocated_ports = [] + for _ in range(num_replicas): + allocated_ports.append(self._next_worker_port) + self._next_worker_port += 1 + + self._domain_to_ports[domain] = allocated_ports + + logger.debug( + f"Updated model {model_id} (domain {domain}) with {num_replicas} replicas " + f"on ports {allocated_ports}" + ) + + # Create Replica objects with URLs and model_id + replicas = [ + Replica(url=f"http://{self.context.host}:{port}", model=model_id) + for port in allocated_ports + ] + return replicas + + def unregister_replicas(self, domain: str) -> None: + """Unregister replicas for a domain (remove model and unassign all its replicas).""" + # Get model_id from domain mapping + model_id = self._domain_to_model_id.get(domain) + if model_id is None: + logger.warning(f"Domain {domain} not found in router mapping, skipping unregister") + return + + # Remove all workers for this model_id from the router + current_workers = self._get_router_workers(model_id) + for worker in current_workers: + self._remove_worker_from_router(worker["url"]) + + # Clean up internal state + if domain in self._domain_to_model_id: + del self._domain_to_model_id[domain] + if domain in self._domain_to_ports: + del self._domain_to_ports[domain] + + logger.debug(f"Removed model {model_id} (domain {domain})") + + def add_replicas(self, replicas: List[Replica]) -> None: + """Register replicas with the router (actual HTTP API calls to add workers).""" + for replica in replicas: + self._add_worker_to_router(replica.url, replica.model) + + def remove_replicas(self, replicas: List[Replica]) -> None: + """Unregister replicas from the router (actual HTTP API calls to remove workers).""" + for replica in replicas: + self._remove_worker_from_router(replica.url) + + def update_replicas(self, replicas: List[Replica]) -> None: + """Update replicas for a model, replacing the current set.""" + # Group replicas by model_id + replicas_by_model: DefaultDict[str, List[Replica]] = defaultdict(list) + for replica in replicas: + replicas_by_model[replica.model].append(replica) + + # Update each model separately + for model_id, model_replicas in replicas_by_model.items(): + # Get current workers for this model_id + current_workers = self._get_router_workers(model_id) + current_worker_urls = {worker["url"] for worker in current_workers} + + # Calculate target worker URLs + target_worker_urls = {replica.url for replica in model_replicas} + + # Workers to add + workers_to_add = target_worker_urls - current_worker_urls + # Workers to remove + workers_to_remove = current_worker_urls - target_worker_urls + + if workers_to_add: + logger.info( + "Sglang router update: adding %d workers for model %s", + len(workers_to_add), + model_id, + ) + if workers_to_remove: + logger.info( + "Sglang router update: removing %d workers for model %s", + len(workers_to_remove), + model_id, + ) + + # Add workers + for worker_url in sorted(workers_to_add): + success = self._add_worker_to_router(worker_url, model_id) + if not success: + logger.warning("Failed to add worker %s, continuing with others", worker_url) + + # Remove workers + for worker_url in sorted(workers_to_remove): + success = self._remove_worker_from_router(worker_url) + if not success: + logger.warning( + "Failed to remove worker %s, continuing with others", worker_url + ) + + def _get_router_workers(self, model_id: str) -> List[dict]: + """Get all workers for a specific model_id from the router.""" + try: + result = subprocess.run( + ["curl", "-s", f"http://{self.context.host}:{self.context.port}/workers"], + capture_output=True, + timeout=5, + ) + if result.returncode == 0: + response = json.loads(result.stdout.decode()) + workers = response.get("workers", []) + # Filter by model_id + workers = [w for w in workers if w.get("model_id") == model_id] + return workers + return [] + except Exception as e: + logger.error(f"Error getting sglang router workers: {e}") + return [] + + def _add_worker_to_router(self, worker_url: str, model_id: str) -> bool: + """Add a single worker to the router.""" + try: + payload = {"url": worker_url, "worker_type": "regular", "model_id": model_id} + result = subprocess.run( + [ + "curl", + "-X", + "POST", + f"http://{self.context.host}:{self.context.port}/workers", + "-H", + "Content-Type: application/json", + "-d", + json.dumps(payload), + ], + capture_output=True, + timeout=5, + ) + + if result.returncode == 0: + response = json.loads(result.stdout.decode()) + if response.get("status") == "accepted": + logger.info("Added worker %s to sglang router", worker_url) + return True + else: + logger.error("Failed to add worker %s: %s", worker_url, response) + return False + else: + logger.error("Failed to add worker %s: %s", worker_url, result.stderr.decode()) + return False + except Exception as e: + logger.error(f"Error adding worker {worker_url}: {e}") + return False + + def _remove_worker_from_router(self, worker_url: str) -> bool: + """Remove a single worker from the router.""" + try: + # URL encode the worker URL for the DELETE request + encoded_url = urllib.parse.quote(worker_url, safe="") + + result = subprocess.run( + [ + "curl", + "-X", + "DELETE", + f"http://{self.context.host}:{self.context.port}/workers/{encoded_url}", + ], + capture_output=True, + timeout=5, + ) + + if result.returncode == 0: + response = json.loads(result.stdout.decode()) + if response.get("status") == "accepted": + logger.info("Removed worker %s from sglang router", worker_url) + return True + else: + logger.error("Failed to remove worker %s: %s", worker_url, response) + return False + else: + logger.error("Failed to remove worker %s: %s", worker_url, result.stderr.decode()) + return False + except Exception as e: + logger.error(f"Error removing worker {worker_url}: {e}") + return False diff --git a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 index b096fa80e..f9ac27085 100644 --- a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 +++ b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 @@ -4,9 +4,13 @@ limit_req_zone {{ zone.key }} zone={{ zone.name }}:10m rate={{ zone.rpm }}r/m; {% if replicas %} upstream {{ domain }}.upstream { + {% if router_config is not none %} + server 127.0.0.1:3000; # SGLang router on the gateway + {% else %} {% for replica in replicas %} server unix:{{ replica.socket }}; # replica {{ replica.id }} {% endfor %} + {% endif %} } {% else %} diff --git a/src/dstack/_internal/proxy/gateway/resources/nginx/sglang_workers.jinja2 b/src/dstack/_internal/proxy/gateway/resources/nginx/sglang_workers.jinja2 new file mode 100644 index 000000000..e5ffe12ee --- /dev/null +++ b/src/dstack/_internal/proxy/gateway/resources/nginx/sglang_workers.jinja2 @@ -0,0 +1,23 @@ +{% for replica in replicas %} +# Worker {{ loop.index }} +upstream sglang_worker_{{ ports[loop.index0] }}_upstream { + server unix:{{ replica.socket }}; +} + +server { + listen 127.0.0.1:{{ ports[loop.index0] }}; + access_log off; # disable access logs for this internal endpoint + + proxy_read_timeout 300s; + proxy_send_timeout 300s; + + location / { + proxy_pass http://sglang_worker_{{ ports[loop.index0] }}_upstream; + proxy_http_version 1.1; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header Connection ""; + proxy_set_header Upgrade $http_upgrade; + } +} +{% endfor %} diff --git a/src/dstack/_internal/proxy/gateway/routers/registry.py b/src/dstack/_internal/proxy/gateway/routers/registry.py index e1bfa4ff2..faccdb257 100644 --- a/src/dstack/_internal/proxy/gateway/routers/registry.py +++ b/src/dstack/_internal/proxy/gateway/routers/registry.py @@ -36,6 +36,7 @@ async def register_service( model=body.options.openai.model if body.options.openai is not None else None, ssh_private_key=body.ssh_private_key, repo=repo, + router_config=body.router_config, nginx=nginx, service_conn_pool=service_conn_pool, ) diff --git a/src/dstack/_internal/proxy/gateway/schemas/registry.py b/src/dstack/_internal/proxy/gateway/schemas/registry.py index 8ab69b6af..8ee9a41ea 100644 --- a/src/dstack/_internal/proxy/gateway/schemas/registry.py +++ b/src/dstack/_internal/proxy/gateway/schemas/registry.py @@ -3,6 +3,7 @@ from pydantic import BaseModel, Field from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.proxy.lib.models import RateLimit @@ -44,6 +45,7 @@ class RegisterServiceRequest(BaseModel): options: Options ssh_private_key: str rate_limits: tuple[RateLimit, ...] = () + router_config: Optional[AnyRouterConfig] = None class RegisterReplicaRequest(BaseModel): diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index 2d3e755ac..dffac7a59 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -9,7 +9,13 @@ from pydantic import BaseModel from typing_extensions import Literal +from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.proxy.gateway.const import PROXY_PORT_ON_GATEWAY +from dstack._internal.proxy.gateway.model_routers import ( + Router, + RouterContext, + get_router, +) from dstack._internal.proxy.gateway.models import ACMESettings from dstack._internal.proxy.lib.errors import ProxyError, UnexpectedProxyError from dstack._internal.utils.common import run_async @@ -64,6 +70,8 @@ class ServiceConfig(SiteConfig): limit_req_zones: list[LimitReqZoneConfig] locations: list[LocationConfig] replicas: list[ReplicaConfig] + router_config: Optional[AnyRouterConfig] = None + model_id: Optional[str] = None class ModelEntrypointConfig(SiteConfig): @@ -77,16 +85,47 @@ class Nginx: def __init__(self, conf_dir: Path = Path("/etc/nginx/sites-enabled")) -> None: self._conf_dir = conf_dir self._lock: Lock = Lock() + self._router: Optional[Router] = None async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: logger.debug("Registering %s domain %s", conf.type, conf.domain) conf_name = self.get_config_name(conf.domain) - async with self._lock: if conf.https: await run_async(self.run_certbot, conf.domain, acme) await run_async(self.write_conf, conf.render(), conf_name) + if isinstance(conf, ServiceConfig) and conf.router_config and conf.model_id: + if self._router is None: + ctx = RouterContext( + host="127.0.0.1", + port=3000, + log_dir=Path("./router_logs"), + log_level="info", + ) + self._router = get_router(conf.router_config, context=ctx) + if not await run_async(self._router.is_running): + await run_async(self._router.start) + + replicas = await run_async( + self._router.register_replicas, + conf.domain, + len(conf.replicas), + conf.model_id, + ) + + allocated_ports = [int(r.url.rsplit(":", 1)[-1]) for r in replicas] + try: + await run_async(self.write_router_workers_conf, conf, allocated_ports) + except Exception as e: + logger.exception( + "write_router_workers_conf failed for domain=%s: %s", conf.domain, e + ) + raise + finally: + # Always update router state, regardless of nginx reload status + await run_async(self._router.update_replicas, replicas) + logger.info("Registered %s domain %s", conf.type, conf.domain) async def unregister(self, domain: str) -> None: @@ -96,6 +135,16 @@ async def unregister(self, domain: str) -> None: return async with self._lock: await run_async(sudo_rm, conf_path) + # Generic router implementation + if self._router is not None: + # Unregister replicas for this domain (router handles domain-to-model_id lookup) + await run_async(self._router.unregister_replicas, domain) + + # Remove workers config file (router-specific naming) + workers_conf_path = self._conf_dir / f"router-workers.{domain}.conf" + if workers_conf_path.exists(): + await run_async(sudo_rm, workers_conf_path) + await run_async(self.reload) logger.info("Unregistered domain %s", domain) @@ -168,6 +217,25 @@ def write_global_conf(self) -> None: conf = read_package_resource("00-log-format.conf") self.write_conf(conf, "00-log-format.conf") + def write_router_workers_conf(self, conf: ServiceConfig, allocated_ports: list[int]) -> None: + """Write router workers configuration file (generic).""" + # Pass ports to template + workers_config = generate_router_workers_config(conf, allocated_ports) + workers_conf_name = f"router-workers.{conf.domain}.conf" + workers_conf_path = self._conf_dir / workers_conf_name + sudo_write(workers_conf_path, workers_config) + self.reload() + + +def generate_router_workers_config(conf: ServiceConfig, allocated_ports: list[int]) -> str: + """Generate router workers configuration (generic, uses sglang_workers.jinja2 template).""" + template = read_package_resource("sglang_workers.jinja2") + return jinja2.Template(template).render( + replicas=conf.replicas, + ports=allocated_ports, + proxy_port=PROXY_PORT_ON_GATEWAY, + ) + def read_package_resource(file: str) -> str: return ( diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index 3ea412d79..4a717a790 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -6,6 +6,7 @@ import dstack._internal.proxy.gateway.schemas.registry as schemas from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.proxy.gateway import models as gateway_models from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo from dstack._internal.proxy.gateway.services.nginx import ( @@ -44,6 +45,7 @@ async def register_service( repo: GatewayProxyRepo, nginx: Nginx, service_conn_pool: ServiceConnectionPool, + router_config: Optional[AnyRouterConfig] = None, ) -> None: service = models.Service( project_name=project_name, @@ -54,6 +56,8 @@ async def register_service( auth=auth, client_max_body_size=client_max_body_size, replicas=(), + router_config=router_config, + model_id=model.name if model is not None else None, ) async with lock: @@ -335,6 +339,8 @@ async def get_nginx_service_config( limit_req_zones=limit_req_zones, locations=locations, replicas=sorted(replicas, key=lambda r: r.id), # sort for reproducible configs + router_config=service.router_config, + model_id=service.model_id, ) diff --git a/src/dstack/_internal/proxy/lib/models.py b/src/dstack/_internal/proxy/lib/models.py index 5cb5471d8..e7fe0ebe4 100644 --- a/src/dstack/_internal/proxy/lib/models.py +++ b/src/dstack/_internal/proxy/lib/models.py @@ -7,6 +7,7 @@ from typing_extensions import Annotated from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.proxy.lib.errors import UnexpectedProxyError @@ -57,6 +58,8 @@ class Service(ImmutableModel): client_max_body_size: int # only enforced on gateways strip_prefix: bool = True # only used in-server replicas: tuple[Replica, ...] + router_config: Optional[AnyRouterConfig] = None + model_id: Optional[str] = None @property def domain_safe(self) -> str: diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index f47b19299..d2c654075 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -108,6 +108,7 @@ async def create_gateway_compute( ssh_key_pub=gateway_ssh_public_key, certificate=configuration.certificate, tags=configuration.tags, + router_config=configuration.router_config, ) gpd = await run_async( @@ -449,11 +450,24 @@ async def _update_gateway(gateway_compute_model: GatewayComputeModel, build: str gateway_compute_model.ssh_private_key, ) logger.debug("Updating gateway %s", connection.ip_address) + compute_config = GatewayComputeConfiguration.__response__.parse_raw( + gateway_compute_model.configuration + ) + + # Determine gateway package with router extras (similar to compute.py) + if compute_config.router_config: + gateway_package = f"dstack-gateway[{compute_config.router_config.type}]" + else: + gateway_package = "dstack-gateway" + commands = [ # prevent update.sh from overwriting itself during execution "cp dstack/update.sh dstack/_update.sh", f"sh dstack/_update.sh {get_dstack_gateway_wheel(build)} {build}", "rm dstack/_update.sh", + # Install gateway package with router extras to the active venv (blue or green) + # update.sh writes the active version to dstack/version + f"version=$(cat /home/ubuntu/dstack/version) && /home/ubuntu/dstack/$version/bin/pip install --upgrade '{gateway_package}'", ] stdout = await connection.tunnel.aexec("/bin/sh -c '" + " && ".join(commands) + "'") if "Update successfully completed" in stdout: diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py index f8c090079..bc77f9fc5 100644 --- a/src/dstack/_internal/server/services/gateways/client.py +++ b/src/dstack/_internal/server/services/gateways/client.py @@ -9,6 +9,7 @@ from dstack._internal.core.errors import GatewayError from dstack._internal.core.models.configurations import RateLimit from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.runs import JobSpec, JobSubmission, Run, get_service_port from dstack._internal.proxy.gateway.schemas.stats import ServiceStats from dstack._internal.server import settings @@ -45,6 +46,7 @@ async def register_service( options: dict, rate_limits: list[RateLimit], ssh_private_key: str, + router_config: Optional[AnyRouterConfig] = None, ): if "openai" in options: entrypoint = f"gateway.{domain.split('.', maxsplit=1)[1]}" @@ -59,6 +61,7 @@ async def register_service( "options": options, "rate_limits": [limit.dict() for limit in rate_limits], "ssh_private_key": ssh_private_key, + "router_config": router_config.dict() if router_config is not None else None, } resp = await self._client.post( self._url(f"/api/registry/{project}/services/register"), json=payload diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index a8089a93a..d08290865 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -82,6 +82,7 @@ async def _register_service_in_gateway( gateway_configuration = get_gateway_configuration(gateway) service_https = _get_service_https(run_spec, gateway_configuration) + router_config = gateway_configuration.router_config service_protocol = "https" if service_https else "http" if service_https and gateway_configuration.certificate is None: @@ -119,6 +120,7 @@ async def _register_service_in_gateway( options=service_spec.options, rate_limits=run_spec.configuration.rate_limits, ssh_private_key=run_model.project.ssh_private_key, + router_config=router_config, ) logger.info("%s: service is registered as %s", fmt(run_model), service_spec.url) except SSHError: diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index 996157350..83a55e7d1 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -70,6 +70,7 @@ async def test_list(self, test_db, session: AsyncSession, client: AsyncClient): "name": gateway.name, "backend": backend.type.value, "region": gateway.region, + "router_config": None, "domain": gateway.wildcard_domain, "default": False, "public_ip": True, @@ -121,6 +122,7 @@ async def test_get(self, test_db, session: AsyncSession, client: AsyncClient): "name": gateway.name, "backend": backend.type.value, "region": gateway.region, + "router_config": None, "domain": gateway.wildcard_domain, "default": False, "public_ip": True, @@ -201,6 +203,7 @@ async def test_create_gateway(self, test_db, session: AsyncSession, client: Asyn "name": "test", "backend": backend.type.value, "region": "us", + "router_config": None, "domain": None, "default": True, "public_ip": True, @@ -253,6 +256,7 @@ async def test_create_gateway_without_name( "name": "random-name", "backend": backend.type.value, "region": "us", + "router_config": None, "domain": None, "default": True, "public_ip": True, @@ -355,6 +359,7 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: "name": gateway.name, "backend": backend.type.value, "region": gateway.region, + "router_config": None, "domain": gateway.wildcard_domain, "default": True, "public_ip": True, @@ -477,6 +482,7 @@ def get_backend(project, backend_type): "name": gateway_gcp.name, "backend": backend_gcp.type.value, "region": gateway_gcp.region, + "router_config": None, "domain": gateway_gcp.wildcard_domain, "default": False, "public_ip": True, @@ -546,6 +552,7 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: "name": gateway.name, "backend": backend.type.value, "region": gateway.region, + "router_config": None, "domain": "test.com", "default": False, "public_ip": True, From f1911327c32f32d98bac7615f38712b8f6f60584 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 7 Nov 2025 10:24:04 +0545 Subject: [PATCH 02/11] Rename router_config to router --- src/dstack/_internal/core/backends/aws/compute.py | 2 +- .../_internal/core/backends/azure/compute.py | 2 +- src/dstack/_internal/core/backends/base/compute.py | 12 +++++------- src/dstack/_internal/core/backends/gcp/compute.py | 2 +- src/dstack/_internal/core/models/gateways.py | 4 ++-- .../proxy/gateway/model_routers/__init__.py | 8 ++++---- .../_internal/proxy/gateway/model_routers/base.py | 4 ++-- .../proxy/gateway/model_routers/sglang.py | 8 ++++---- .../proxy/gateway/resources/nginx/service.jinja2 | 2 +- .../_internal/proxy/gateway/routers/registry.py | 2 +- .../_internal/proxy/gateway/schemas/registry.py | 2 +- .../_internal/proxy/gateway/services/nginx.py | 6 +++--- .../_internal/proxy/gateway/services/registry.py | 6 +++--- src/dstack/_internal/proxy/lib/models.py | 2 +- .../_internal/server/services/gateways/__init__.py | 6 +++--- .../_internal/server/services/gateways/client.py | 4 ++-- .../_internal/server/services/services/__init__.py | 4 ++-- .../_internal/server/routers/test_gateways.py | 14 +++++++------- 18 files changed, 44 insertions(+), 46 deletions(-) diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index 39141bd10..67a7e409d 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -461,7 +461,7 @@ def create_gateway( instance_type="t3.micro", iam_instance_profile=None, user_data=get_gateway_user_data( - configuration.ssh_key_pub, router_config=configuration.router_config + configuration.ssh_key_pub, router=configuration.router ), tags=tags, security_group_id=security_group_id, diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py index 357bfd0a9..8573ec99a 100644 --- a/src/dstack/_internal/core/backends/azure/compute.py +++ b/src/dstack/_internal/core/backends/azure/compute.py @@ -278,7 +278,7 @@ def create_gateway( vm_size="Standard_B1ms", instance_name=instance_name, user_data=get_gateway_user_data( - configuration.ssh_key_pub, router_config=configuration.router_config + configuration.ssh_key_pub, router=configuration.router ), ssh_pub_keys=[configuration.ssh_key_pub], spot=False, diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 3fd0a5b6d..4f92912a8 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -877,9 +877,7 @@ def get_run_shim_script( ] -def get_gateway_user_data( - authorized_key: str, router_config: Optional[AnyRouterConfig] = None -) -> str: +def get_gateway_user_data(authorized_key: str, router: Optional[AnyRouterConfig] = None) -> str: return get_cloud_config( package_update=True, packages=[ @@ -895,7 +893,7 @@ def get_gateway_user_data( "s/# server_names_hash_bucket_size 64;/server_names_hash_bucket_size 128;/", "/etc/nginx/nginx.conf", ], - ["su", "ubuntu", "-c", " && ".join(get_dstack_gateway_commands(router_config))], + ["su", "ubuntu", "-c", " && ".join(get_dstack_gateway_commands(router))], ], ssh_authorized_keys=[authorized_key], ) @@ -1028,12 +1026,12 @@ def get_dstack_gateway_wheel(build: str) -> str: return "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl" -def get_dstack_gateway_commands(router_config: Optional[AnyRouterConfig] = None) -> List[str]: +def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]: build = get_dstack_runner_version() wheel = get_dstack_gateway_wheel(build) # Use router type directly as pip extra - if router_config: - gateway_package = f"dstack-gateway[{router_config.type}]" + if router: + gateway_package = f"dstack-gateway[{router.type}]" else: gateway_package = "dstack-gateway" return [ diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index a333d415f..8e5c36fcc 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -600,7 +600,7 @@ def create_gateway( accelerators=[], spot=False, user_data=get_gateway_user_data( - configuration.ssh_key_pub, router_config=configuration.router_config + configuration.ssh_key_pub, router=configuration.router ), authorized_keys=[configuration.ssh_key_pub], labels=labels, diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index 2cab4df16..39befe739 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -51,7 +51,7 @@ class GatewayConfiguration(CoreModel): default: Annotated[bool, Field(description="Make the gateway default")] = False backend: Annotated[BackendType, Field(description="The gateway backend")] region: Annotated[str, Field(description="The gateway region")] - router_config: Annotated[ + router: Annotated[ Optional[AnyRouterConfig], Field(description="The router configuration"), ] = None @@ -118,7 +118,7 @@ class GatewayComputeConfiguration(CoreModel): ssh_key_pub: str certificate: Optional[AnyGatewayCertificate] = None tags: Optional[Dict[str, str]] = None - router_config: Optional[AnyRouterConfig] = None + router: Optional[AnyRouterConfig] = None class GatewayProvisioningData(CoreModel): diff --git a/src/dstack/_internal/proxy/gateway/model_routers/__init__.py b/src/dstack/_internal/proxy/gateway/model_routers/__init__.py index 487a78e04..008ec0ad8 100644 --- a/src/dstack/_internal/proxy/gateway/model_routers/__init__.py +++ b/src/dstack/_internal/proxy/gateway/model_routers/__init__.py @@ -37,9 +37,9 @@ def get_router_class(router_type: RouterType) -> Optional[Type[Router]]: return _ROUTER_TYPE_TO_CLASS_MAP.get(router_type) -def get_router(router_config: AnyRouterConfig, context: Optional[RouterContext] = None) -> Router: +def get_router(router: AnyRouterConfig, context: Optional[RouterContext] = None) -> Router: """Factory function to create a router instance from router configuration.""" - router_type = RouterType(router_config.type) + router_type = RouterType(router.type) router_class = get_router_class(router_type) if router_class is None: @@ -50,8 +50,8 @@ def get_router(router_config: AnyRouterConfig, context: Optional[RouterContext] ) # Router implementations may have different constructor signatures - # SglangRouter takes (router_config, context), others might differ - return router_class(router_config=router_config, context=context) + # SglangRouter takes (router, context), others might differ + return router_class(router=router, context=context) __all__ = [ diff --git a/src/dstack/_internal/proxy/gateway/model_routers/base.py b/src/dstack/_internal/proxy/gateway/model_routers/base.py index 23c3a72b3..fbd0e318d 100644 --- a/src/dstack/_internal/proxy/gateway/model_routers/base.py +++ b/src/dstack/_internal/proxy/gateway/model_routers/base.py @@ -40,13 +40,13 @@ class Router(ABC): def __init__( self, - router_config: Optional[AnyRouterConfig] = None, + router: Optional[AnyRouterConfig] = None, context: Optional[RouterContext] = None, ): """Initialize router with context. Args: - router_config: Optional router configuration (implementation-specific) + router: Optional router configuration (implementation-specific) context: Runtime context for the router (host, port, logging, etc.) """ self.context = context or RouterContext() diff --git a/src/dstack/_internal/proxy/gateway/model_routers/sglang.py b/src/dstack/_internal/proxy/gateway/model_routers/sglang.py index 048623ace..2d5e9e18b 100644 --- a/src/dstack/_internal/proxy/gateway/model_routers/sglang.py +++ b/src/dstack/_internal/proxy/gateway/model_routers/sglang.py @@ -20,15 +20,15 @@ class SglangRouter(Router): TYPE = "sglang" - def __init__(self, router_config: SGLangRouterConfig, context: Optional[RouterContext] = None): + def __init__(self, router: SGLangRouterConfig, context: Optional[RouterContext] = None): """Initialize SGLang router. Args: - router_config: SGLang router configuration (policy, cache_threshold, etc.) + router: SGLang router configuration (policy, cache_threshold, etc.) context: Runtime context for the router (host, port, logging, etc.) """ - super().__init__(router_config=router_config, context=context) - self.config = router_config + super().__init__(router=router, context=context) + self.config = router self._domain_to_model_id: Dict[str, str] = {} # domain -> model_id self._domain_to_ports: Dict[ str, List[int] diff --git a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 index f9ac27085..d2e2555ec 100644 --- a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 +++ b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 @@ -4,7 +4,7 @@ limit_req_zone {{ zone.key }} zone={{ zone.name }}:10m rate={{ zone.rpm }}r/m; {% if replicas %} upstream {{ domain }}.upstream { - {% if router_config is not none %} + {% if router is not none %} server 127.0.0.1:3000; # SGLang router on the gateway {% else %} {% for replica in replicas %} diff --git a/src/dstack/_internal/proxy/gateway/routers/registry.py b/src/dstack/_internal/proxy/gateway/routers/registry.py index faccdb257..dd4f63f32 100644 --- a/src/dstack/_internal/proxy/gateway/routers/registry.py +++ b/src/dstack/_internal/proxy/gateway/routers/registry.py @@ -36,7 +36,7 @@ async def register_service( model=body.options.openai.model if body.options.openai is not None else None, ssh_private_key=body.ssh_private_key, repo=repo, - router_config=body.router_config, + router=body.router, nginx=nginx, service_conn_pool=service_conn_pool, ) diff --git a/src/dstack/_internal/proxy/gateway/schemas/registry.py b/src/dstack/_internal/proxy/gateway/schemas/registry.py index 8ee9a41ea..53a29f68c 100644 --- a/src/dstack/_internal/proxy/gateway/schemas/registry.py +++ b/src/dstack/_internal/proxy/gateway/schemas/registry.py @@ -45,7 +45,7 @@ class RegisterServiceRequest(BaseModel): options: Options ssh_private_key: str rate_limits: tuple[RateLimit, ...] = () - router_config: Optional[AnyRouterConfig] = None + router: Optional[AnyRouterConfig] = None class RegisterReplicaRequest(BaseModel): diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index dffac7a59..bf0a5586a 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -70,7 +70,7 @@ class ServiceConfig(SiteConfig): limit_req_zones: list[LimitReqZoneConfig] locations: list[LocationConfig] replicas: list[ReplicaConfig] - router_config: Optional[AnyRouterConfig] = None + router: Optional[AnyRouterConfig] = None model_id: Optional[str] = None @@ -95,7 +95,7 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: await run_async(self.run_certbot, conf.domain, acme) await run_async(self.write_conf, conf.render(), conf_name) - if isinstance(conf, ServiceConfig) and conf.router_config and conf.model_id: + if isinstance(conf, ServiceConfig) and conf.router and conf.model_id: if self._router is None: ctx = RouterContext( host="127.0.0.1", @@ -103,7 +103,7 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: log_dir=Path("./router_logs"), log_level="info", ) - self._router = get_router(conf.router_config, context=ctx) + self._router = get_router(conf.router, context=ctx) if not await run_async(self._router.is_running): await run_async(self._router.start) diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index 4a717a790..0abc346c9 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -45,7 +45,7 @@ async def register_service( repo: GatewayProxyRepo, nginx: Nginx, service_conn_pool: ServiceConnectionPool, - router_config: Optional[AnyRouterConfig] = None, + router: Optional[AnyRouterConfig] = None, ) -> None: service = models.Service( project_name=project_name, @@ -56,7 +56,7 @@ async def register_service( auth=auth, client_max_body_size=client_max_body_size, replicas=(), - router_config=router_config, + router=router, model_id=model.name if model is not None else None, ) @@ -339,7 +339,7 @@ async def get_nginx_service_config( limit_req_zones=limit_req_zones, locations=locations, replicas=sorted(replicas, key=lambda r: r.id), # sort for reproducible configs - router_config=service.router_config, + router=service.router, model_id=service.model_id, ) diff --git a/src/dstack/_internal/proxy/lib/models.py b/src/dstack/_internal/proxy/lib/models.py index e7fe0ebe4..21890d233 100644 --- a/src/dstack/_internal/proxy/lib/models.py +++ b/src/dstack/_internal/proxy/lib/models.py @@ -58,7 +58,7 @@ class Service(ImmutableModel): client_max_body_size: int # only enforced on gateways strip_prefix: bool = True # only used in-server replicas: tuple[Replica, ...] - router_config: Optional[AnyRouterConfig] = None + router: Optional[AnyRouterConfig] = None model_id: Optional[str] = None @property diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index d2c654075..08380b240 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -108,7 +108,7 @@ async def create_gateway_compute( ssh_key_pub=gateway_ssh_public_key, certificate=configuration.certificate, tags=configuration.tags, - router_config=configuration.router_config, + router=configuration.router, ) gpd = await run_async( @@ -455,8 +455,8 @@ async def _update_gateway(gateway_compute_model: GatewayComputeModel, build: str ) # Determine gateway package with router extras (similar to compute.py) - if compute_config.router_config: - gateway_package = f"dstack-gateway[{compute_config.router_config.type}]" + if compute_config.router: + gateway_package = f"dstack-gateway[{compute_config.router.type}]" else: gateway_package = "dstack-gateway" diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py index bc77f9fc5..d4f1c831e 100644 --- a/src/dstack/_internal/server/services/gateways/client.py +++ b/src/dstack/_internal/server/services/gateways/client.py @@ -46,7 +46,7 @@ async def register_service( options: dict, rate_limits: list[RateLimit], ssh_private_key: str, - router_config: Optional[AnyRouterConfig] = None, + router: Optional[AnyRouterConfig] = None, ): if "openai" in options: entrypoint = f"gateway.{domain.split('.', maxsplit=1)[1]}" @@ -61,7 +61,7 @@ async def register_service( "options": options, "rate_limits": [limit.dict() for limit in rate_limits], "ssh_private_key": ssh_private_key, - "router_config": router_config.dict() if router_config is not None else None, + "router": router.dict() if router is not None else None, } resp = await self._client.post( self._url(f"/api/registry/{project}/services/register"), json=payload diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index d08290865..05c1fa909 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -82,7 +82,7 @@ async def _register_service_in_gateway( gateway_configuration = get_gateway_configuration(gateway) service_https = _get_service_https(run_spec, gateway_configuration) - router_config = gateway_configuration.router_config + router = gateway_configuration.router service_protocol = "https" if service_https else "http" if service_https and gateway_configuration.certificate is None: @@ -120,7 +120,7 @@ async def _register_service_in_gateway( options=service_spec.options, rate_limits=run_spec.configuration.rate_limits, ssh_private_key=run_model.project.ssh_private_key, - router_config=router_config, + router=router, ) logger.info("%s: service is registered as %s", fmt(run_model), service_spec.url) except SSHError: diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index 83a55e7d1..6d06e6969 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -70,7 +70,7 @@ async def test_list(self, test_db, session: AsyncSession, client: AsyncClient): "name": gateway.name, "backend": backend.type.value, "region": gateway.region, - "router_config": None, + "router": None, "domain": gateway.wildcard_domain, "default": False, "public_ip": True, @@ -122,7 +122,7 @@ async def test_get(self, test_db, session: AsyncSession, client: AsyncClient): "name": gateway.name, "backend": backend.type.value, "region": gateway.region, - "router_config": None, + "router": None, "domain": gateway.wildcard_domain, "default": False, "public_ip": True, @@ -203,7 +203,7 @@ async def test_create_gateway(self, test_db, session: AsyncSession, client: Asyn "name": "test", "backend": backend.type.value, "region": "us", - "router_config": None, + "router": None, "domain": None, "default": True, "public_ip": True, @@ -256,7 +256,7 @@ async def test_create_gateway_without_name( "name": "random-name", "backend": backend.type.value, "region": "us", - "router_config": None, + "router": None, "domain": None, "default": True, "public_ip": True, @@ -359,7 +359,7 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: "name": gateway.name, "backend": backend.type.value, "region": gateway.region, - "router_config": None, + "router": None, "domain": gateway.wildcard_domain, "default": True, "public_ip": True, @@ -482,7 +482,7 @@ def get_backend(project, backend_type): "name": gateway_gcp.name, "backend": backend_gcp.type.value, "region": gateway_gcp.region, - "router_config": None, + "router": None, "domain": gateway_gcp.wildcard_domain, "default": False, "public_ip": True, @@ -552,7 +552,7 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: "name": gateway.name, "backend": backend.type.value, "region": gateway.region, - "router_config": None, + "router": None, "domain": "test.com", "default": False, "public_ip": True, From 23d3e69ae0debd8db730e91b7f4637048b2c03e7 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 7 Nov 2025 17:40:15 +0545 Subject: [PATCH 03/11] Rename sglang_workers.jinja2 to router_workers.jinja2 --- .../nginx/{sglang_workers.jinja2 => router_workers.jinja2} | 4 ++-- src/dstack/_internal/proxy/gateway/services/nginx.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) rename src/dstack/_internal/proxy/gateway/resources/nginx/{sglang_workers.jinja2 => router_workers.jinja2} (72%) diff --git a/src/dstack/_internal/proxy/gateway/resources/nginx/sglang_workers.jinja2 b/src/dstack/_internal/proxy/gateway/resources/nginx/router_workers.jinja2 similarity index 72% rename from src/dstack/_internal/proxy/gateway/resources/nginx/sglang_workers.jinja2 rename to src/dstack/_internal/proxy/gateway/resources/nginx/router_workers.jinja2 index e5ffe12ee..3af7ea612 100644 --- a/src/dstack/_internal/proxy/gateway/resources/nginx/sglang_workers.jinja2 +++ b/src/dstack/_internal/proxy/gateway/resources/nginx/router_workers.jinja2 @@ -1,6 +1,6 @@ {% for replica in replicas %} # Worker {{ loop.index }} -upstream sglang_worker_{{ ports[loop.index0] }}_upstream { +upstream router_worker_{{ domain|replace('.', '_') }}_{{ ports[loop.index0] }}_upstream { server unix:{{ replica.socket }}; } @@ -12,7 +12,7 @@ server { proxy_send_timeout 300s; location / { - proxy_pass http://sglang_worker_{{ ports[loop.index0] }}_upstream; + proxy_pass http://router_worker_{{ domain|replace('.', '_') }}_{{ ports[loop.index0] }}_upstream; proxy_http_version 1.1; proxy_set_header Host $host; proxy_set_header X-Real-IP $remote_addr; diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index bf0a5586a..29687d396 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -228,9 +228,10 @@ def write_router_workers_conf(self, conf: ServiceConfig, allocated_ports: list[i def generate_router_workers_config(conf: ServiceConfig, allocated_ports: list[int]) -> str: - """Generate router workers configuration (generic, uses sglang_workers.jinja2 template).""" - template = read_package_resource("sglang_workers.jinja2") + """Generate router workers configuration (generic, uses router_workers.jinja2 template).""" + template = read_package_resource("router_workers.jinja2") return jinja2.Template(template).render( + domain=conf.domain, replicas=conf.replicas, ports=allocated_ports, proxy_port=PROXY_PORT_ON_GATEWAY, From e6c2bcb752c61a33e44ddb741bd5418d46b00b26 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Thu, 13 Nov 2025 17:46:02 +0545 Subject: [PATCH 04/11] Resolve SGLang API expose issue --- .../proxy/gateway/model_routers/__init__.py | 53 +++---------------- .../gateway/resources/nginx/service.jinja2 | 23 ++++++++ 2 files changed, 29 insertions(+), 47 deletions(-) diff --git a/src/dstack/_internal/proxy/gateway/model_routers/__init__.py b/src/dstack/_internal/proxy/gateway/model_routers/__init__.py index 008ec0ad8..396d70583 100644 --- a/src/dstack/_internal/proxy/gateway/model_routers/__init__.py +++ b/src/dstack/_internal/proxy/gateway/model_routers/__init__.py @@ -1,57 +1,16 @@ -from typing import Dict, List, Optional, Type +from typing import Optional -from dstack._internal.core.models.routers import AnyRouterConfig, RouterType -from dstack._internal.utils.logging import get_logger +from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.proxy.gateway.model_routers.sglang import SglangRouter from .base import Replica, Router, RouterContext -logger = get_logger(__name__) - -"""This provides a registry of available router implementations.""" - -_ROUTER_CLASSES: List[Type[Router]] = [] - -try: - from dstack._internal.proxy.gateway.model_routers.sglang import SglangRouter - - _ROUTER_CLASSES.append(SglangRouter) - logger.debug("Registered SglangRouter") -except ImportError as e: - logger.warning("SGLang router not available: %s", e) - -_ROUTER_TYPE_TO_CLASS_MAP: Dict[RouterType, Type[Router]] = {} - -for router_class in _ROUTER_CLASSES: - router_type_str = getattr(router_class, "TYPE", None) - if router_type_str is None: - logger.warning(f"Router class {router_class.__name__} missing TYPE attribute, skipping") - continue - router_type = RouterType(router_type_str) - _ROUTER_TYPE_TO_CLASS_MAP[router_type] = router_class - -_AVAILABLE_ROUTER_TYPES = list(_ROUTER_TYPE_TO_CLASS_MAP.keys()) - - -def get_router_class(router_type: RouterType) -> Optional[Type[Router]]: - """Get the router class for a given router type.""" - return _ROUTER_TYPE_TO_CLASS_MAP.get(router_type) - def get_router(router: AnyRouterConfig, context: Optional[RouterContext] = None) -> Router: """Factory function to create a router instance from router configuration.""" - router_type = RouterType(router.type) - router_class = get_router_class(router_type) - - if router_class is None: - available_types = [rt.value for rt in _AVAILABLE_ROUTER_TYPES] - raise ValueError( - f"Router type '{router_type.value}' is not available. " - f"Available types: {available_types}" - ) - - # Router implementations may have different constructor signatures - # SglangRouter takes (router, context), others might differ - return router_class(router=router, context=context) + if router.type == "sglang": + return SglangRouter(router=router, context=context) + raise ValueError(f"Router type '{router.type}' is not available") __all__ = [ diff --git a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 index d2e2555ec..494804c50 100644 --- a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 +++ b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 @@ -22,6 +22,28 @@ server { access_log {{ access_log_path }} dstack_stat; client_max_body_size {{ client_max_body_size }}; + {% if router is not none and router.type == "sglang" %} + # Whitelist approach: Only allow safe user-facing endpoints + # Allow /generate + location = /generate { + {% if auth %} + auth_request /_dstack_auth; + {% endif %} + try_files /nonexistent @$http_upgrade; + } + # Allow /v1/* endpoints (regex match) + location ~ ^/v1/ { + {% if auth %} + auth_request /_dstack_auth; + {% endif %} + try_files /nonexistent @$http_upgrade; + } + # Block everything else + # This will match any path that doesn't match /generate or /v1/* + location ~ . { + return 403; + } + {% else %} {% for location in locations %} location {{ location.prefix }} { {% if auth %} @@ -35,6 +57,7 @@ server { {% endif %} } {% endfor %} + {% endif %} location @websocket { set $dstack_replica_hit 1; From 178d5a5d4d84cd0085a20108a48361af98f68220 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Thu, 13 Nov 2025 20:44:24 +0545 Subject: [PATCH 05/11] Resolve model_id based single router to multi router Test sglang router per service implementation Test sglang router per service implementation Test sglang router per service implementation Test sglang router per service implementation Test sglang router per service implementation Test sglang router per service implementation --- .../proxy/gateway/model_routers/__init__.py | 3 +- .../proxy/gateway/model_routers/base.py | 66 +--- .../proxy/gateway/model_routers/sglang.py | 291 ++++++------------ .../gateway/resources/nginx/service.jinja2 | 4 +- .../_internal/proxy/gateway/services/nginx.py | 283 ++++++++++++++--- 5 files changed, 347 insertions(+), 300 deletions(-) diff --git a/src/dstack/_internal/proxy/gateway/model_routers/__init__.py b/src/dstack/_internal/proxy/gateway/model_routers/__init__.py index 396d70583..59e40dcb9 100644 --- a/src/dstack/_internal/proxy/gateway/model_routers/__init__.py +++ b/src/dstack/_internal/proxy/gateway/model_routers/__init__.py @@ -3,7 +3,7 @@ from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.proxy.gateway.model_routers.sglang import SglangRouter -from .base import Replica, Router, RouterContext +from .base import Router, RouterContext def get_router(router: AnyRouterConfig, context: Optional[RouterContext] = None) -> Router: @@ -16,6 +16,5 @@ def get_router(router: AnyRouterConfig, context: Optional[RouterContext] = None) __all__ = [ "Router", "RouterContext", - "Replica", "get_router", ] diff --git a/src/dstack/_internal/proxy/gateway/model_routers/base.py b/src/dstack/_internal/proxy/gateway/model_routers/base.py index fbd0e318d..93ef0b7d9 100644 --- a/src/dstack/_internal/proxy/gateway/model_routers/base.py +++ b/src/dstack/_internal/proxy/gateway/model_routers/base.py @@ -19,20 +19,8 @@ class Config: log_level: Literal["debug", "info", "warning", "error"] = "info" -class Replica(BaseModel): - """Represents a single replica (worker) endpoint managed by the router. - - The model field identifies which model this replica serves. - In SGLang, model = model_id (e.g., "meta-llama/Meta-Llama-3.1-8B-Instruct"). - """ - - url: str # HTTP URL where the replica is accessible (e.g., "http://127.0.0.1:10001") - model: str # (e.g., "meta-llama/Meta-Llama-3.1-8B-Instruct") - - class Router(ABC): - """Abstract base class for router implementations (e.g., SGLang, vLLM). - + """Abstract base class for router implementations. A router manages the lifecycle of worker replicas and handles request routing. Different router implementations may have different mechanisms for managing replicas. @@ -79,55 +67,11 @@ def is_running(self) -> bool: ... @abstractmethod - def register_replicas( - self, domain: str, num_replicas: int, model_id: Optional[str] = None - ) -> List[Replica]: - """Register replicas to a domain (allocate ports/URLs for workers). - - Args: - domain: The domain name for this service. - num_replicas: The number of replicas to allocate for this domain. - model_id: Optional model identifier (e.g., "meta-llama/Meta-Llama-3.1-8B-Instruct"). - Required only for routers that support IGW (Inference Gateway) mode for multi-model serving. - - Returns: - List of Replica objects with allocated URLs and model_id set (if provided). - - Raises: - Exception: If allocation fails. - """ - ... - - @abstractmethod - def unregister_replicas(self, domain: str) -> None: - """Unregister replicas for a domain (remove model and unassign all its replicas). - - Args: - domain: The domain name for this service. - - Raises: - Exception: If removal fails or domain is not found. - """ - ... - - @abstractmethod - def add_replicas(self, replicas: List[Replica]) -> None: - """Register replicas with the router (actual API calls to add workers). - - Args: - replicas: The list of replicas to add to router. - - Raises: - Exception: If adding replicas fails. - """ - ... - - @abstractmethod - def remove_replicas(self, replicas: List[Replica]) -> None: + def remove_replicas(self, replica_urls: List[str]) -> None: """Unregister replicas from the router (actual API calls to remove workers). Args: - replicas: The list of replicas to remove from router. + replica_urls: The list of replica URLs to remove from router. Raises: Exception: If removing replicas fails. @@ -135,11 +79,11 @@ def remove_replicas(self, replicas: List[Replica]) -> None: ... @abstractmethod - def update_replicas(self, replicas: List[Replica]) -> None: + def update_replicas(self, replica_urls: List[str]) -> None: """Update replicas for service, replacing the current set. Args: - replicas: The new list of replicas for this service. + replica_urls: The new list of replica URLs for this service. Raises: Exception: If updating replicas fails. diff --git a/src/dstack/_internal/proxy/gateway/model_routers/sglang.py b/src/dstack/_internal/proxy/gateway/model_routers/sglang.py index 2d5e9e18b..c0aa0f1b9 100644 --- a/src/dstack/_internal/proxy/gateway/model_routers/sglang.py +++ b/src/dstack/_internal/proxy/gateway/model_routers/sglang.py @@ -3,20 +3,19 @@ import subprocess import time import urllib.parse -from collections import defaultdict -from typing import DefaultDict, Dict, List, Optional +from typing import List, Optional from dstack._internal.core.models.routers import SGLangRouterConfig from dstack._internal.proxy.gateway.const import DSTACK_DIR_ON_GATEWAY from dstack._internal.utils.logging import get_logger -from .base import Replica, Router, RouterContext +from .base import Router, RouterContext logger = get_logger(__name__) class SglangRouter(Router): - """SGLang router implementation using IGW (Inference Gateway) mode for multi-model serving.""" + """SGLang router implementation with 1:1 service-to-router.""" TYPE = "sglang" @@ -29,16 +28,10 @@ def __init__(self, router: SGLangRouterConfig, context: Optional[RouterContext] """ super().__init__(router=router, context=context) self.config = router - self._domain_to_model_id: Dict[str, str] = {} # domain -> model_id - self._domain_to_ports: Dict[ - str, List[int] - ] = {} # domain -> allocated sglang worker ports. - self._next_worker_port: int = 10001 # Starting port for worker endpoints def start(self) -> None: - """Start the SGLang router process.""" try: - logger.info("Starting sglang-router...") + logger.info("Starting sglang-router-new on port %s...", self.context.port) # Determine active venv (blue or green) version_file = DSTACK_DIR_ON_GATEWAY / "version" @@ -47,18 +40,20 @@ def start(self) -> None: else: version = "blue" - # Use Python from the active venv venv_python = DSTACK_DIR_ON_GATEWAY / version / "bin" / "python3" + prometheus_port = self.context.port + 10000 + cmd = [ str(venv_python), "-m", "sglang_router.launch_router", "--host", - "0.0.0.0", # Bind to all interfaces (nginx connects via 127.0.0.1) + "0.0.0.0", "--port", str(self.context.port), - "--enable-igw", + "--prometheus-port", + str(prometheus_port), "--log-level", self.context.log_level, "--log-dir", @@ -68,48 +63,68 @@ def start(self) -> None: if hasattr(self.config, "policy") and self.config.policy: cmd.extend(["--policy", self.config.policy]) - # Add additional required configs here - subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - # Wait for router to start time.sleep(2) - # Verify router is running if not self.is_running(): - raise Exception("Failed to start sglang router") + raise Exception(f"Failed to start sglang router on port {self.context.port}") - logger.info("Sglang router started successfully") + logger.info( + "Sglang router started successfully on port %s (prometheus on %s)", + self.context.port, + prometheus_port, + ) except Exception as e: - logger.error(f"Failed to start sglang-router: {e}") + logger.error(f"Failed to start sglang-router-new: {e}") raise def stop(self) -> None: - """Stop the SGLang router process.""" try: result = subprocess.run( - ["pgrep", "-f", "sglang::router"], capture_output=True, timeout=5 + ["lsof", "-ti", f":{self.context.port}"], capture_output=True, timeout=5 ) if result.returncode == 0: - logger.info("Stopping sglang-router process...") - subprocess.run(["pkill", "-f", "sglang::router"], timeout=5) + pids = result.stdout.decode().strip().split("\n") + for pid in pids: + if pid: + logger.info( + "Stopping sglang-router-new process (PID: %s) on port %s", + pid, + self.context.port, + ) + subprocess.run(["kill", pid], timeout=5) else: - logger.debug("No sglang-router process found to stop") + result = subprocess.run( + ["pgrep", "-f", f"sglang.*--port.*{self.context.port}"], + capture_output=True, + timeout=5, + ) + if result.returncode == 0: + pids = result.stdout.decode().strip().split("\n") + for pid in pids: + if pid: + logger.info("Stopping sglang-router-new process (PID: %s)", pid) + subprocess.run(["kill", pid], timeout=5) + else: + logger.debug( + "No sglang-router-new process found on port %s", self.context.port + ) # Clean up router logs if self.context.log_dir.exists(): - logger.debug("Cleaning up router logs...") + logger.debug("Cleaning up router logs for port %s...", self.context.port) shutil.rmtree(self.context.log_dir, ignore_errors=True) else: logger.debug("No router logs directory found to clean up") except Exception as e: - logger.error(f"Failed to stop sglang-router: {e}") + logger.error(f"Failed to stop sglang-router-new: {e}") raise def is_running(self) -> bool: - """Check if the SGLang router is running and responding to HTTP requests.""" + """Check if the SGLang router is running and responding to HTTP requests on the assigned port.""" try: result = subprocess.run( ["curl", "-s", f"http://{self.context.host}:{self.context.port}/workers"], @@ -118,170 +133,55 @@ def is_running(self) -> bool: ) return result.returncode == 0 except Exception as e: - logger.error(f"Error checking sglang router status: {e}") + logger.error(f"Error checking sglang router status on port {self.context.port}: {e}") return False - def is_model_registered(self, model_id: str) -> bool: - """Check if a model with the given model_id is registered.""" - return model_id in self._domain_to_model_id.values() - - def register_replicas( - self, domain: str, num_replicas: int, model_id: Optional[str] = None - ) -> List[Replica]: - """Register replicas to a domain (allocate ports/URLs for workers). - SGLang router uses IGW (Inference Gateway) mode, which requires model_id for multi-model serving. - - Maintains in-memory state: - - domain_to_model_id: Maps domain to model_id for unregistering by domain and model validation. - - domain_to_ports: Maps domain to allocated ports to track port assignments and avoid conflicts. - - Args: - domain: The domain name for this service. - num_replicas: The number of replicas to allocate for this domain. - model_id: Model identifier (required for SGLang IGW mode). - - Raises: - ValueError: If model_id is None (required for SGLang IGW mode). - """ - if model_id is None: - raise ValueError("model_id is required for SGLang router (IGW mode)") + def remove_replicas(self, replica_urls: List[str]) -> None: + for replica_url in replica_urls: + self._remove_worker_from_router(replica_url) - is_new_model = not self.is_model_registered(model_id) - - if is_new_model: - # Store domain -> model_id mapping - self._domain_to_model_id[domain] = model_id - - # Allocate ports for replicas - allocated_ports = [] - for _ in range(num_replicas): - allocated_ports.append(self._next_worker_port) - self._next_worker_port += 1 - - self._domain_to_ports[domain] = allocated_ports - - logger.debug( - f"Allocated model {model_id} (domain {domain}) with {num_replicas} replicas " - f"on ports {allocated_ports}" + def update_replicas(self, replica_urls: List[str]) -> None: + """Update replicas for service, replacing the current set.""" + # Query router to get current worker URLs + current_workers = self._get_router_workers() + current_worker_urls: set[str] = set() + for worker in current_workers: + url = worker.get("url") + if url and isinstance(url, str): + current_worker_urls.add(url) + target_worker_urls = set(replica_urls) + + # Workers to add + workers_to_add = target_worker_urls - current_worker_urls + # Workers to remove + workers_to_remove = current_worker_urls - target_worker_urls + + if workers_to_add: + logger.info( + "Sglang router update: adding %d workers for router on port %s", + len(workers_to_add), + self.context.port, ) - else: - # Verify domain matches - if self._domain_to_model_id.get(domain) != model_id: - raise ValueError(f"Domain {domain} does not match model_id {model_id}") - - # Get current allocated ports - current_ports = self._domain_to_ports.get(domain, []) - current_count = len(current_ports) - - if num_replicas == current_count: - # No change needed, return existing replicas - replicas = [ - Replica(url=f"http://{self.context.host}:{port}", model=model_id) - for port in current_ports - ] - return replicas - - # Re-allocate ports for new count - allocated_ports = [] - for _ in range(num_replicas): - allocated_ports.append(self._next_worker_port) - self._next_worker_port += 1 - - self._domain_to_ports[domain] = allocated_ports - - logger.debug( - f"Updated model {model_id} (domain {domain}) with {num_replicas} replicas " - f"on ports {allocated_ports}" + if workers_to_remove: + logger.info( + "Sglang router update: removing %d workers for router on port %s", + len(workers_to_remove), + self.context.port, ) - # Create Replica objects with URLs and model_id - replicas = [ - Replica(url=f"http://{self.context.host}:{port}", model=model_id) - for port in allocated_ports - ] - return replicas - - def unregister_replicas(self, domain: str) -> None: - """Unregister replicas for a domain (remove model and unassign all its replicas).""" - # Get model_id from domain mapping - model_id = self._domain_to_model_id.get(domain) - if model_id is None: - logger.warning(f"Domain {domain} not found in router mapping, skipping unregister") - return - - # Remove all workers for this model_id from the router - current_workers = self._get_router_workers(model_id) - for worker in current_workers: - self._remove_worker_from_router(worker["url"]) - - # Clean up internal state - if domain in self._domain_to_model_id: - del self._domain_to_model_id[domain] - if domain in self._domain_to_ports: - del self._domain_to_ports[domain] - - logger.debug(f"Removed model {model_id} (domain {domain})") - - def add_replicas(self, replicas: List[Replica]) -> None: - """Register replicas with the router (actual HTTP API calls to add workers).""" - for replica in replicas: - self._add_worker_to_router(replica.url, replica.model) - - def remove_replicas(self, replicas: List[Replica]) -> None: - """Unregister replicas from the router (actual HTTP API calls to remove workers).""" - for replica in replicas: - self._remove_worker_from_router(replica.url) - - def update_replicas(self, replicas: List[Replica]) -> None: - """Update replicas for a model, replacing the current set.""" - # Group replicas by model_id - replicas_by_model: DefaultDict[str, List[Replica]] = defaultdict(list) - for replica in replicas: - replicas_by_model[replica.model].append(replica) - - # Update each model separately - for model_id, model_replicas in replicas_by_model.items(): - # Get current workers for this model_id - current_workers = self._get_router_workers(model_id) - current_worker_urls = {worker["url"] for worker in current_workers} - - # Calculate target worker URLs - target_worker_urls = {replica.url for replica in model_replicas} - - # Workers to add - workers_to_add = target_worker_urls - current_worker_urls - # Workers to remove - workers_to_remove = current_worker_urls - target_worker_urls - - if workers_to_add: - logger.info( - "Sglang router update: adding %d workers for model %s", - len(workers_to_add), - model_id, - ) - if workers_to_remove: - logger.info( - "Sglang router update: removing %d workers for model %s", - len(workers_to_remove), - model_id, - ) + # Add workers + for worker_url in sorted(workers_to_add): + success = self._add_worker_to_router(worker_url) + if not success: + logger.warning("Failed to add worker %s, continuing with others", worker_url) - # Add workers - for worker_url in sorted(workers_to_add): - success = self._add_worker_to_router(worker_url, model_id) - if not success: - logger.warning("Failed to add worker %s, continuing with others", worker_url) - - # Remove workers - for worker_url in sorted(workers_to_remove): - success = self._remove_worker_from_router(worker_url) - if not success: - logger.warning( - "Failed to remove worker %s, continuing with others", worker_url - ) + # Remove workers + for worker_url in sorted(workers_to_remove): + success = self._remove_worker_from_router(worker_url) + if not success: + logger.warning("Failed to remove worker %s, continuing with others", worker_url) - def _get_router_workers(self, model_id: str) -> List[dict]: - """Get all workers for a specific model_id from the router.""" + def _get_router_workers(self) -> List[dict]: try: result = subprocess.run( ["curl", "-s", f"http://{self.context.host}:{self.context.port}/workers"], @@ -291,18 +191,15 @@ def _get_router_workers(self, model_id: str) -> List[dict]: if result.returncode == 0: response = json.loads(result.stdout.decode()) workers = response.get("workers", []) - # Filter by model_id - workers = [w for w in workers if w.get("model_id") == model_id] return workers return [] except Exception as e: logger.error(f"Error getting sglang router workers: {e}") return [] - def _add_worker_to_router(self, worker_url: str, model_id: str) -> bool: - """Add a single worker to the router.""" + def _add_worker_to_router(self, worker_url: str) -> bool: try: - payload = {"url": worker_url, "worker_type": "regular", "model_id": model_id} + payload = {"url": worker_url, "worker_type": "regular"} result = subprocess.run( [ "curl", @@ -321,7 +218,11 @@ def _add_worker_to_router(self, worker_url: str, model_id: str) -> bool: if result.returncode == 0: response = json.loads(result.stdout.decode()) if response.get("status") == "accepted": - logger.info("Added worker %s to sglang router", worker_url) + logger.info( + "Added worker %s to sglang router on port %s", + worker_url, + self.context.port, + ) return True else: logger.error("Failed to add worker %s: %s", worker_url, response) @@ -334,9 +235,7 @@ def _add_worker_to_router(self, worker_url: str, model_id: str) -> bool: return False def _remove_worker_from_router(self, worker_url: str) -> bool: - """Remove a single worker from the router.""" try: - # URL encode the worker URL for the DELETE request encoded_url = urllib.parse.quote(worker_url, safe="") result = subprocess.run( @@ -353,7 +252,11 @@ def _remove_worker_from_router(self, worker_url: str) -> bool: if result.returncode == 0: response = json.loads(result.stdout.decode()) if response.get("status") == "accepted": - logger.info("Removed worker %s from sglang router", worker_url) + logger.info( + "Removed worker %s from sglang router on port %s", + worker_url, + self.context.port, + ) return True else: logger.error("Failed to remove worker %s: %s", worker_url, response) diff --git a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 index 494804c50..3b6803595 100644 --- a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 +++ b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 @@ -5,7 +5,9 @@ limit_req_zone {{ zone.key }} zone={{ zone.name }}:10m rate={{ zone.rpm }}r/m; {% if replicas %} upstream {{ domain }}.upstream { {% if router is not none %} - server 127.0.0.1:3000; # SGLang router on the gateway + {% if router.type == "sglang" and router_port is not none %} + server 127.0.0.1:{{ router_port }}; # SGLang router on the gateway + {% endif %} {% else %} {% for replica in replicas %} server unix:{{ replica.socket }}; # replica {{ replica.id }} diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index 29687d396..dfbab8bb8 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -1,9 +1,10 @@ import importlib.resources +import socket import subprocess import tempfile from asyncio import Lock from pathlib import Path -from typing import Optional +from typing import Dict, Optional import jinja2 from pydantic import BaseModel @@ -34,10 +35,9 @@ class SiteConfig(BaseModel): def render(self) -> str: template = read_package_resource(f"{self.type}.jinja2") - return jinja2.Template(template).render( - **self.dict(), - proxy_port=PROXY_PORT_ON_GATEWAY, - ) + render_dict = self.dict() + render_dict["proxy_port"] = PROXY_PORT_ON_GATEWAY + return jinja2.Template(template).render(**render_dict) class ReplicaConfig(BaseModel): @@ -72,6 +72,7 @@ class ServiceConfig(SiteConfig): replicas: list[ReplicaConfig] router: Optional[AnyRouterConfig] = None model_id: Optional[str] = None + router_port: Optional[int] = None class ModelEntrypointConfig(SiteConfig): @@ -85,7 +86,18 @@ class Nginx: def __init__(self, conf_dir: Path = Path("/etc/nginx/sites-enabled")) -> None: self._conf_dir = conf_dir self._lock: Lock = Lock() - self._router: Optional[Router] = None + # 1:1 service-to-router mapping + self._router_port_to_domain: Dict[int, str] = {} + self._domain_to_router: Dict[str, Router] = {} + self._ROUTER_PORT_MIN: int = 20000 + self._ROUTER_PORT_MAX: int = 24999 + self._WORKER_PORT_MIN: int = 10001 + self._WORKER_PORT_MAX: int = 11999 + self._next_router_port: int = self._ROUTER_PORT_MIN + # Tracking of worker ports to avoid conflicts across router instances + self._allocated_worker_ports: set[int] = set() + self._domain_to_worker_ports: Dict[str, list[int]] = {} + self._next_worker_port: int = self._WORKER_PORT_MIN async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: logger.debug("Registering %s domain %s", conf.type, conf.domain) @@ -93,38 +105,85 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: async with self._lock: if conf.https: await run_async(self.run_certbot, conf.domain, acme) - await run_async(self.write_conf, conf.render(), conf_name) - if isinstance(conf, ServiceConfig) and conf.router and conf.model_id: - if self._router is None: - ctx = RouterContext( - host="127.0.0.1", - port=3000, - log_dir=Path("./router_logs"), - log_level="info", - ) - self._router = get_router(conf.router, context=ctx) - if not await run_async(self._router.is_running): - await run_async(self._router.start) - - replicas = await run_async( - self._router.register_replicas, - conf.domain, - len(conf.replicas), - conf.model_id, - ) - - allocated_ports = [int(r.url.rsplit(":", 1)[-1]) for r in replicas] - try: - await run_async(self.write_router_workers_conf, conf, allocated_ports) - except Exception as e: - logger.exception( - "write_router_workers_conf failed for domain=%s: %s", conf.domain, e - ) - raise - finally: - # Always update router state, regardless of nginx reload status - await run_async(self._router.update_replicas, replicas) + if isinstance(conf, ServiceConfig) and conf.router: + if conf.router.type == "sglang": + # Check if router already exists for this domain + if conf.domain in self._domain_to_router: + # Router already exists, reuse it + router = self._domain_to_router[conf.domain] + router_port = router.context.port + conf.router_port = router_port + else: + # Allocate router port for new router + router_port = self._allocate_router_port() + conf.router_port = router_port + + # Create per-service log directory + log_dir = Path(f"./router_logs/{conf.domain}") + + # Create router context with allocated port + ctx = RouterContext( + host="127.0.0.1", + port=router_port, + log_dir=log_dir, + log_level="info", + ) + + # Create new router instance for this service + router = get_router(conf.router, context=ctx) + + # Store mappings + self._router_port_to_domain[router_port] = conf.domain + self._domain_to_router[conf.domain] = router + + # Start router if not running + if not await run_async(router.is_running): + await run_async(router.start) + + # Discard old worker ports if domain already has allocated ports (required for scaling case) + if conf.domain in self._domain_to_worker_ports: + old_worker_ports = self._domain_to_worker_ports[conf.domain] + for port in old_worker_ports: + self._allocated_worker_ports.discard(port) + + allocated_ports = self._allocate_worker_ports(len(conf.replicas)) + self._domain_to_worker_ports[conf.domain] = allocated_ports + + replica_urls = [ + f"http://{router.context.host}:{port}" for port in allocated_ports + ] + + # Write router workers config + try: + if conf.replicas: + await run_async(self.write_router_workers_conf, conf, allocated_ports) + except Exception as e: + # Discard allocated worker ports on error + for port in allocated_ports: + self._allocated_worker_ports.discard(port) + if conf.domain in self._domain_to_worker_ports: + del self._domain_to_worker_ports[conf.domain] + logger.exception( + "write_router_workers_conf failed for domain=%s: %s", conf.domain, e + ) + raise + + # Update replicas to router (actual HTTP API calls to add workers) + try: + await run_async(router.update_replicas, replica_urls) + except Exception as e: + # Free allocated worker ports on error + for port in allocated_ports: + self._allocated_worker_ports.discard(port) + if conf.domain in self._domain_to_worker_ports: + del self._domain_to_worker_ports[conf.domain] + logger.exception( + "Failed to add replicas to router for domain=%s: %s", conf.domain, e + ) + raise + + await run_async(self.write_conf, conf.render(), conf_name) logger.info("Registered %s domain %s", conf.type, conf.domain) @@ -135,12 +194,33 @@ async def unregister(self, domain: str) -> None: return async with self._lock: await run_async(sudo_rm, conf_path) - # Generic router implementation - if self._router is not None: - # Unregister replicas for this domain (router handles domain-to-model_id lookup) - await run_async(self._router.unregister_replicas, domain) - # Remove workers config file (router-specific naming) + if domain in self._domain_to_router: + router = self._domain_to_router[domain] + # Remove all workers for this domain + if domain in self._domain_to_worker_ports: + worker_ports = self._domain_to_worker_ports[domain] + replica_urls = [ + f"http://{router.context.host}:{port}" for port in worker_ports + ] + await run_async(router.remove_replicas, replica_urls) + # Stop and kill the router + await run_async(router.stop) + # Remove from mappings + router_port = router.context.port + if router_port in self._router_port_to_domain: + del self._router_port_to_domain[router_port] + del self._domain_to_router[domain] + + # Discard worker ports for this domain + if domain in self._domain_to_worker_ports: + worker_ports = self._domain_to_worker_ports[domain] + for port in worker_ports: + self._allocated_worker_ports.discard(port) + del self._domain_to_worker_ports[domain] + logger.debug("Freed worker ports %s for domain %s", worker_ports, domain) + + # Remove workers config file workers_conf_path = self._conf_dir / f"router-workers.{domain}.conf" if workers_conf_path.exists(): await run_async(sudo_rm, workers_conf_path) @@ -213,6 +293,125 @@ def certificate_exists(domain: str) -> bool: def get_config_name(domain: str) -> str: return f"443-{domain}.conf" + @staticmethod + def _is_port_available(port: int) -> bool: + """Check if a port is actually available (not in use by any process). + + Tries to bind to the port to see if it's available. + """ + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + try: + sock.bind(("127.0.0.1", port)) + # If bind succeeds, port is available + return True + except OSError: + # If bind fails (e.g., Address already in use), port is not available + return False + except Exception: + logger.debug("Error checking port %s availability", port) + return False + + def _allocate_router_port(self) -> int: + """Allocate next available router port in fixed range (20000-24999). + + Checks both our internal allocation map and actual port availability + to avoid conflicts with other services. Range chosen to avoid ephemeral ports. + """ + port = self._next_router_port + max_attempts = self._ROUTER_PORT_MAX - self._ROUTER_PORT_MIN + 1 + attempts = 0 + + while attempts < max_attempts: + # Check if port is already allocated by us + if port in self._router_port_to_domain: + port += 1 + if port > self._ROUTER_PORT_MAX: + port = self._ROUTER_PORT_MIN # Wrap around + attempts += 1 + continue + + # Check if port is actually available on the system + if self._is_port_available(port): + # Port is available, allocate it + self._next_router_port = port + 1 + if self._next_router_port > self._ROUTER_PORT_MAX: + self._next_router_port = self._ROUTER_PORT_MIN # Wrap around + logger.debug("Allocated router port %s", port) + return port + + # Port is in use, try next one + logger.debug("Port %s is in use, trying next port", port) + port += 1 + if port > self._ROUTER_PORT_MAX: + port = self._ROUTER_PORT_MIN # Wrap around + attempts += 1 + + raise UnexpectedProxyError( + f"Router port range exhausted ({self._ROUTER_PORT_MIN}-{self._ROUTER_PORT_MAX}). " + "All ports in range appear to be in use." + ) + + def _allocate_worker_ports(self, num_ports: int) -> list[int]: + """Allocate worker ports globally in fixed range (10001-11999). + + Worker ports are used by nginx to listen and proxy to worker sockets. + They must be unique across all router instances. Range chosen to avoid ephemeral ports. + + Args: + num_ports: Number of worker ports to allocate + + Returns: + List of allocated worker port numbers + """ + allocated = [] + port = self._next_worker_port + max_attempts = (self._WORKER_PORT_MAX - self._WORKER_PORT_MIN + 1) * 2 # Allow wrap-around + attempts = 0 + + while len(allocated) < num_ports and attempts < max_attempts: + # Check if port is already allocated globally + if port in self._allocated_worker_ports: + port += 1 + if port > self._WORKER_PORT_MAX: + port = self._WORKER_PORT_MIN # Wrap around + attempts += 1 + continue + + # Check if port is actually available on the system + if self._is_port_available(port): + allocated.append(port) + self._allocated_worker_ports.add(port) + logger.debug("Allocated worker port %s", port) + port += 1 + if port > self._WORKER_PORT_MAX: + port = self._WORKER_PORT_MIN # Wrap around + else: + logger.debug("Worker port %s is in use, trying next port", port) + port += 1 + if port > self._WORKER_PORT_MAX: + port = self._WORKER_PORT_MIN # Wrap around + + attempts += 1 + + if len(allocated) < num_ports: + # Free up the ports we did allocate + for p in allocated: + self._allocated_worker_ports.discard(p) + raise UnexpectedProxyError( + f"Failed to allocate {num_ports} worker ports in range " + f"({self._WORKER_PORT_MIN}-{self._WORKER_PORT_MAX}). " + f"Only allocated {len(allocated)} ports after {attempts} attempts." + ) + + # Update next worker port for next allocation + self._next_worker_port = port + if self._next_worker_port > self._WORKER_PORT_MAX: + self._next_worker_port = self._WORKER_PORT_MIN # Wrap around + + return allocated + def write_global_conf(self) -> None: conf = read_package_resource("00-log-format.conf") self.write_conf(conf, "00-log-format.conf") From b8794f1d385a2a7e10504ee46452aa60804dc065 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Mon, 17 Nov 2025 20:54:02 +0545 Subject: [PATCH 06/11] Resolve gateway installation command --- src/dstack/_internal/core/backends/base/compute.py | 9 ++++----- .../_internal/server/services/gateways/__init__.py | 13 ++++++------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 4f92912a8..987fed357 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -1029,17 +1029,16 @@ def get_dstack_gateway_wheel(build: str) -> str: def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]: build = get_dstack_runner_version() wheel = get_dstack_gateway_wheel(build) - # Use router type directly as pip extra + # Build package spec with extras if router is specified if router: - gateway_package = f"dstack-gateway[{router.type}]" + gateway_package = f"dstack-gateway[{router.type}] @ {wheel}" else: - gateway_package = "dstack-gateway" + gateway_package = f"dstack-gateway @ {wheel}" return [ "mkdir -p /home/ubuntu/dstack", "python3 -m venv /home/ubuntu/dstack/blue", "python3 -m venv /home/ubuntu/dstack/green", - f"/home/ubuntu/dstack/blue/bin/pip install {wheel}", - f"/home/ubuntu/dstack/blue/bin/pip install --upgrade '{gateway_package}'", + f"/home/ubuntu/dstack/blue/bin/pip install '{gateway_package}'", "sudo /home/ubuntu/dstack/blue/bin/python -m dstack.gateway.systemd install --run", ] diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 08380b240..3e375ca2b 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -454,20 +454,19 @@ async def _update_gateway(gateway_compute_model: GatewayComputeModel, build: str gateway_compute_model.configuration ) - # Determine gateway package with router extras (similar to compute.py) + # Build package spec with extras and wheel URL (similar to compute.py) + wheel = get_dstack_gateway_wheel(build) if compute_config.router: - gateway_package = f"dstack-gateway[{compute_config.router.type}]" + gateway_package = f"dstack-gateway[{compute_config.router.type}] @ {wheel}" else: - gateway_package = "dstack-gateway" + gateway_package = f"dstack-gateway @ {wheel}" commands = [ # prevent update.sh from overwriting itself during execution "cp dstack/update.sh dstack/_update.sh", - f"sh dstack/_update.sh {get_dstack_gateway_wheel(build)} {build}", + # Pass the full package spec (with extras) to update.sh instead of just the wheel + f"sh dstack/_update.sh '{gateway_package}' {build}", "rm dstack/_update.sh", - # Install gateway package with router extras to the active venv (blue or green) - # update.sh writes the active version to dstack/version - f"version=$(cat /home/ubuntu/dstack/version) && /home/ubuntu/dstack/$version/bin/pip install --upgrade '{gateway_package}'", ] stdout = await connection.tunnel.aexec("/bin/sh -c '" + " && ".join(commands) + "'") if "Update successfully completed" in stdout: From 56995566d6870eb3825b7cc0816c6cfe35f318c4 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Tue, 18 Nov 2025 14:32:52 +0545 Subject: [PATCH 07/11] Resolved Minor Review Comments --- src/dstack/_internal/core/models/routers.py | 17 +++-------------- .../proxy/gateway/model_routers/__init__.py | 4 ++-- .../proxy/gateway/model_routers/sglang.py | 9 ++++++--- .../_internal/proxy/gateway/services/nginx.py | 8 ++++---- 4 files changed, 15 insertions(+), 23 deletions(-) diff --git a/src/dstack/_internal/core/models/routers.py b/src/dstack/_internal/core/models/routers.py index 58766aac7..ec779b124 100644 --- a/src/dstack/_internal/core/models/routers.py +++ b/src/dstack/_internal/core/models/routers.py @@ -1,27 +1,16 @@ from enum import Enum -from typing import Union - -from pydantic import Field -from typing_extensions import Annotated, Literal +from typing import Literal from dstack._internal.core.models.common import CoreModel class RouterType(str, Enum): SGLANG = "sglang" - VLLM = "vllm" class SGLangRouterConfig(CoreModel): type: Literal["sglang"] = "sglang" - policy: str = "cache_aware" - - -class VLLMRouterConfig(CoreModel): - type: Literal["vllm"] = "vllm" - policy: str = "cache_aware" + policy: Literal["random", "round_robin", "cache_aware", "power_of_two"] = "cache_aware" -AnyRouterConfig = Annotated[ - Union[SGLangRouterConfig, VLLMRouterConfig], Field(discriminator="type") -] +AnyRouterConfig = SGLangRouterConfig diff --git a/src/dstack/_internal/proxy/gateway/model_routers/__init__.py b/src/dstack/_internal/proxy/gateway/model_routers/__init__.py index 59e40dcb9..b8d796ff4 100644 --- a/src/dstack/_internal/proxy/gateway/model_routers/__init__.py +++ b/src/dstack/_internal/proxy/gateway/model_routers/__init__.py @@ -2,15 +2,15 @@ from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.proxy.gateway.model_routers.sglang import SglangRouter +from dstack._internal.proxy.lib.errors import ProxyError from .base import Router, RouterContext def get_router(router: AnyRouterConfig, context: Optional[RouterContext] = None) -> Router: - """Factory function to create a router instance from router configuration.""" if router.type == "sglang": return SglangRouter(router=router, context=context) - raise ValueError(f"Router type '{router.type}' is not available") + raise ProxyError(f"Router type '{router.type}' is not available") __all__ = [ diff --git a/src/dstack/_internal/proxy/gateway/model_routers/sglang.py b/src/dstack/_internal/proxy/gateway/model_routers/sglang.py index c0aa0f1b9..7b7f28b67 100644 --- a/src/dstack/_internal/proxy/gateway/model_routers/sglang.py +++ b/src/dstack/_internal/proxy/gateway/model_routers/sglang.py @@ -5,8 +5,9 @@ import urllib.parse from typing import List, Optional -from dstack._internal.core.models.routers import SGLangRouterConfig +from dstack._internal.core.models.routers import RouterType, SGLangRouterConfig from dstack._internal.proxy.gateway.const import DSTACK_DIR_ON_GATEWAY +from dstack._internal.proxy.lib.errors import UnexpectedProxyError from dstack._internal.utils.logging import get_logger from .base import Router, RouterContext @@ -17,7 +18,7 @@ class SglangRouter(Router): """SGLang router implementation with 1:1 service-to-router.""" - TYPE = "sglang" + TYPE = RouterType.SGLANG def __init__(self, router: SGLangRouterConfig, context: Optional[RouterContext] = None): """Initialize SGLang router. @@ -68,7 +69,9 @@ def start(self) -> None: time.sleep(2) if not self.is_running(): - raise Exception(f"Failed to start sglang router on port {self.context.port}") + raise UnexpectedProxyError( + f"Failed to start sglang router on port {self.context.port}" + ) logger.info( "Sglang router started successfully on port %s (prometheus on %s)", diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index dfbab8bb8..6734a4791 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -10,7 +10,7 @@ from pydantic import BaseModel from typing_extensions import Literal -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import AnyRouterConfig, RouterType from dstack._internal.proxy.gateway.const import PROXY_PORT_ON_GATEWAY from dstack._internal.proxy.gateway.model_routers import ( Router, @@ -107,7 +107,7 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: await run_async(self.run_certbot, conf.domain, acme) if isinstance(conf, ServiceConfig) and conf.router: - if conf.router.type == "sglang": + if conf.router.type == RouterType.SGLANG: # Check if router already exists for this domain if conf.domain in self._domain_to_router: # Router already exists, reuse it @@ -314,7 +314,7 @@ def _is_port_available(port: int) -> bool: return False def _allocate_router_port(self) -> int: - """Allocate next available router port in fixed range (20000-24999). + """Allocate next available router port in fixed range. Checks both our internal allocation map and actual port availability to avoid conflicts with other services. Range chosen to avoid ephemeral ports. @@ -354,7 +354,7 @@ def _allocate_router_port(self) -> int: ) def _allocate_worker_ports(self, num_ports: int) -> list[int]: - """Allocate worker ports globally in fixed range (10001-11999). + """Allocate worker ports globally in fixed range. Worker ports are used by nginx to listen and proxy to worker sockets. They must be unique across all router instances. Range chosen to avoid ephemeral ports. From a47585889d90b218efb84b4f32739d50f563b6e1 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Tue, 18 Nov 2025 19:50:26 +0545 Subject: [PATCH 08/11] Resolve comments and kubernetes sglang router intregration Test gateway package update Test gateway package update Test gateway package update Test gateway package update Test gateway package update Resolve rate limits and location issue Resolve rate limits and location issue Resolve rate limits and location issue Resolve all major comments Resolve all major comments Resolve kubernetes gateway issue with sglang intregration --- gateway/pyproject.toml | 2 +- .../_internal/core/backends/base/compute.py | 17 +- .../core/backends/kubernetes/compute.py | 16 +- .../proxy/gateway/model_routers/__init__.py | 2 +- .../proxy/gateway/model_routers/base.py | 4 +- .../proxy/gateway/model_routers/sglang.py | 159 ++++++++---------- .../gateway/resources/nginx/service.jinja2 | 28 +-- .../_internal/proxy/gateway/services/nginx.py | 7 +- .../proxy/gateway/services/registry.py | 51 +++++- src/dstack/_internal/proxy/lib/models.py | 1 - .../server/services/gateways/__init__.py | 12 +- 11 files changed, 147 insertions(+), 152 deletions(-) diff --git a/gateway/pyproject.toml b/gateway/pyproject.toml index 3c877c0af..7bade5896 100644 --- a/gateway/pyproject.toml +++ b/gateway/pyproject.toml @@ -11,7 +11,7 @@ requires-python = ">=3.10" dynamic = ["version"] dependencies = [ # release builds of dstack-gateway depend on a PyPI version of dstack instead - "dstack[gateway] @ git+https://github.com/Bihan/dstack.git@add_sglang_router_support", + "dstack[gateway] @ https://github.com/Bihan/dstack/archive/refs/heads/add_sglang_router_support.tar.gz", ] [project.optional-dependencies] diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 987fed357..5bdc4746c 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -1014,7 +1014,7 @@ def get_latest_runner_build() -> Optional[str]: return None -def get_dstack_gateway_wheel(build: str) -> str: +def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = None) -> str: channel = "release" if settings.DSTACK_RELEASE else "stgn" base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}" if build == "latest": @@ -1022,18 +1022,17 @@ def get_dstack_gateway_wheel(build: str) -> str: r.raise_for_status() build = r.text.strip() logger.debug("Found the latest gateway build: %s", build) - # return f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" - return "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl" + # wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" + wheel = "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl" + # Build package spec with extras if router is specified + if router: + return f"dstack-gateway[{router.type}] @ {wheel}" + return f"dstack-gateway @ {wheel}" def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]: build = get_dstack_runner_version() - wheel = get_dstack_gateway_wheel(build) - # Build package spec with extras if router is specified - if router: - gateway_package = f"dstack-gateway[{router.type}] @ {wheel}" - else: - gateway_package = f"dstack-gateway @ {wheel}" + gateway_package = get_dstack_gateway_wheel(build, router) return [ "mkdir -p /home/ubuntu/dstack", "python3 -m venv /home/ubuntu/dstack/blue", diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index bf2c7c7fb..b8a3975a9 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -1,3 +1,4 @@ +import shlex import subprocess import tempfile import threading @@ -51,6 +52,7 @@ SSHConnectionParams, ) from dstack._internal.core.models.resources import CPUSpec, Memory +from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import Volume from dstack._internal.utils.common import parse_memory @@ -403,7 +405,9 @@ def create_gateway( # Consider deploying an NLB. It seems it requires some extra configuration on the cluster: # https://docs.aws.amazon.com/eks/latest/userguide/network-load-balancing.html instance_name = generate_unique_gateway_instance_name(configuration) - commands = _get_gateway_commands(authorized_keys=[configuration.ssh_key_pub]) + commands = _get_gateway_commands( + authorized_keys=[configuration.ssh_key_pub], router=configuration.router + ) pod = client.V1Pod( metadata=client.V1ObjectMeta( name=instance_name, @@ -940,9 +944,13 @@ def _add_authorized_key_to_jump_pod( ) -def _get_gateway_commands(authorized_keys: List[str]) -> List[str]: +def _get_gateway_commands( + authorized_keys: List[str], router: Optional[AnyRouterConfig] = None +) -> List[str]: authorized_keys_content = "\n".join(authorized_keys).strip() - gateway_commands = " && ".join(get_dstack_gateway_commands()) + gateway_commands = " && ".join(get_dstack_gateway_commands(router=router)) + quoted_gateway_commands = shlex.quote(gateway_commands) + commands = [ # install packages "apt-get update && apt-get install -y sudo wget openssh-server nginx python3.10-venv libaugeas0", @@ -971,7 +979,7 @@ def _get_gateway_commands(authorized_keys: List[str]) -> List[str]: # start sshd "/usr/sbin/sshd -p 22 -o PermitUserEnvironment=yes", # run gateway - f"su ubuntu -c '{gateway_commands}'", + f"su ubuntu -c {quoted_gateway_commands}", "sleep infinity", ] return commands diff --git a/src/dstack/_internal/proxy/gateway/model_routers/__init__.py b/src/dstack/_internal/proxy/gateway/model_routers/__init__.py index b8d796ff4..1db0e759a 100644 --- a/src/dstack/_internal/proxy/gateway/model_routers/__init__.py +++ b/src/dstack/_internal/proxy/gateway/model_routers/__init__.py @@ -9,7 +9,7 @@ def get_router(router: AnyRouterConfig, context: Optional[RouterContext] = None) -> Router: if router.type == "sglang": - return SglangRouter(router=router, context=context) + return SglangRouter(config=router, context=context) raise ProxyError(f"Router type '{router.type}' is not available") diff --git a/src/dstack/_internal/proxy/gateway/model_routers/base.py b/src/dstack/_internal/proxy/gateway/model_routers/base.py index 93ef0b7d9..65a414ca0 100644 --- a/src/dstack/_internal/proxy/gateway/model_routers/base.py +++ b/src/dstack/_internal/proxy/gateway/model_routers/base.py @@ -28,13 +28,13 @@ class Router(ABC): def __init__( self, - router: Optional[AnyRouterConfig] = None, + config: Optional[AnyRouterConfig] = None, context: Optional[RouterContext] = None, ): """Initialize router with context. Args: - router: Optional router configuration (implementation-specific) + config: Optional router configuration (implementation-specific) context: Runtime context for the router (host, port, logging, etc.) """ self.context = context or RouterContext() diff --git a/src/dstack/_internal/proxy/gateway/model_routers/sglang.py b/src/dstack/_internal/proxy/gateway/model_routers/sglang.py index 7b7f28b67..fed4ef480 100644 --- a/src/dstack/_internal/proxy/gateway/model_routers/sglang.py +++ b/src/dstack/_internal/proxy/gateway/model_routers/sglang.py @@ -1,12 +1,13 @@ -import json import shutil import subprocess +import sys import time import urllib.parse from typing import List, Optional +import httpx + from dstack._internal.core.models.routers import RouterType, SGLangRouterConfig -from dstack._internal.proxy.gateway.const import DSTACK_DIR_ON_GATEWAY from dstack._internal.proxy.lib.errors import UnexpectedProxyError from dstack._internal.utils.logging import get_logger @@ -20,50 +21,43 @@ class SglangRouter(Router): TYPE = RouterType.SGLANG - def __init__(self, router: SGLangRouterConfig, context: Optional[RouterContext] = None): + def __init__(self, config: SGLangRouterConfig, context: Optional[RouterContext] = None): """Initialize SGLang router. Args: - router: SGLang router configuration (policy, cache_threshold, etc.) + config: SGLang router configuration (policy, cache_threshold, etc.) context: Runtime context for the router (host, port, logging, etc.) """ - super().__init__(router=router, context=context) - self.config = router + super().__init__(config=config, context=context) + self.config = config def start(self) -> None: try: logger.info("Starting sglang-router-new on port %s...", self.context.port) - # Determine active venv (blue or green) - version_file = DSTACK_DIR_ON_GATEWAY / "version" - if version_file.exists(): - version = version_file.read_text().strip() - else: - version = "blue" - - venv_python = DSTACK_DIR_ON_GATEWAY / version / "bin" / "python3" - + # Prometheus port is offset by 10000 from router port to keep it in a separate range prometheus_port = self.context.port + 10000 cmd = [ - str(venv_python), + sys.executable, "-m", "sglang_router.launch_router", "--host", - "0.0.0.0", + self.context.host, "--port", str(self.context.port), "--prometheus-port", str(prometheus_port), + "--prometheus-host", + self.context.host, "--log-level", self.context.log_level, "--log-dir", str(self.context.log_dir), + "--policy", + self.config.policy, ] - if hasattr(self.config, "policy") and self.config.policy: - cmd.extend(["--policy", self.config.policy]) - subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) time.sleep(2) @@ -129,12 +123,9 @@ def stop(self) -> None: def is_running(self) -> bool: """Check if the SGLang router is running and responding to HTTP requests on the assigned port.""" try: - result = subprocess.run( - ["curl", "-s", f"http://{self.context.host}:{self.context.port}/workers"], - capture_output=True, - timeout=5, - ) - return result.returncode == 0 + with httpx.Client(timeout=5.0) as client: + response = client.get(f"http://{self.context.host}:{self.context.port}/workers") + return response.status_code == 200 except Exception as e: logger.error(f"Error checking sglang router status on port {self.context.port}: {e}") return False @@ -151,8 +142,11 @@ def update_replicas(self, replica_urls: List[str]) -> None: for worker in current_workers: url = worker.get("url") if url and isinstance(url, str): - current_worker_urls.add(url) - target_worker_urls = set(replica_urls) + # Normalize URL by removing trailing slashes to avoid path artifacts + normalized_url = url.rstrip("/") + current_worker_urls.add(normalized_url) + # Normalize target URLs to ensure consistent comparison + target_worker_urls = {url.rstrip("/") for url in replica_urls} # Workers to add workers_to_add = target_worker_urls - current_worker_urls @@ -186,16 +180,13 @@ def update_replicas(self, replica_urls: List[str]) -> None: def _get_router_workers(self) -> List[dict]: try: - result = subprocess.run( - ["curl", "-s", f"http://{self.context.host}:{self.context.port}/workers"], - capture_output=True, - timeout=5, - ) - if result.returncode == 0: - response = json.loads(result.stdout.decode()) - workers = response.get("workers", []) - return workers - return [] + with httpx.Client(timeout=5.0) as client: + response = client.get(f"http://{self.context.host}:{self.context.port}/workers") + if response.status_code == 200: + response_data = response.json() + workers = response_data.get("workers", []) + return workers + return [] except Exception as e: logger.error(f"Error getting sglang router workers: {e}") return [] @@ -203,36 +194,31 @@ def _get_router_workers(self) -> List[dict]: def _add_worker_to_router(self, worker_url: str) -> bool: try: payload = {"url": worker_url, "worker_type": "regular"} - result = subprocess.run( - [ - "curl", - "-X", - "POST", + with httpx.Client(timeout=5.0) as client: + response = client.post( f"http://{self.context.host}:{self.context.port}/workers", - "-H", - "Content-Type: application/json", - "-d", - json.dumps(payload), - ], - capture_output=True, - timeout=5, - ) - - if result.returncode == 0: - response = json.loads(result.stdout.decode()) - if response.get("status") == "accepted": - logger.info( - "Added worker %s to sglang router on port %s", + json=payload, + ) + if response.status_code == 200: + response_data = response.json() + if response_data.get("status") == "accepted": + logger.info( + "Added worker %s to sglang router on port %s", + worker_url, + self.context.port, + ) + return True + else: + logger.error("Failed to add worker %s: %s", worker_url, response_data) + return False + else: + logger.error( + "Failed to add worker %s: status %d, %s", worker_url, - self.context.port, + response.status_code, + response.text, ) - return True - else: - logger.error("Failed to add worker %s: %s", worker_url, response) return False - else: - logger.error("Failed to add worker %s: %s", worker_url, result.stderr.decode()) - return False except Exception as e: logger.error(f"Error adding worker {worker_url}: {e}") return False @@ -240,33 +226,30 @@ def _add_worker_to_router(self, worker_url: str) -> bool: def _remove_worker_from_router(self, worker_url: str) -> bool: try: encoded_url = urllib.parse.quote(worker_url, safe="") - - result = subprocess.run( - [ - "curl", - "-X", - "DELETE", - f"http://{self.context.host}:{self.context.port}/workers/{encoded_url}", - ], - capture_output=True, - timeout=5, - ) - - if result.returncode == 0: - response = json.loads(result.stdout.decode()) - if response.get("status") == "accepted": - logger.info( - "Removed worker %s from sglang router on port %s", + with httpx.Client(timeout=5.0) as client: + response = client.delete( + f"http://{self.context.host}:{self.context.port}/workers/{encoded_url}" + ) + if response.status_code == 200: + response_data = response.json() + if response_data.get("status") == "accepted": + logger.info( + "Removed worker %s from sglang router on port %s", + worker_url, + self.context.port, + ) + return True + else: + logger.error("Failed to remove worker %s: %s", worker_url, response_data) + return False + else: + logger.error( + "Failed to remove worker %s: status %d, %s", worker_url, - self.context.port, + response.status_code, + response.text, ) - return True - else: - logger.error("Failed to remove worker %s: %s", worker_url, response) return False - else: - logger.error("Failed to remove worker %s: %s", worker_url, result.stderr.decode()) - return False except Exception as e: logger.error(f"Error removing worker {worker_url}: {e}") return False diff --git a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 index 3b6803595..20590e37c 100644 --- a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 +++ b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 @@ -24,28 +24,6 @@ server { access_log {{ access_log_path }} dstack_stat; client_max_body_size {{ client_max_body_size }}; - {% if router is not none and router.type == "sglang" %} - # Whitelist approach: Only allow safe user-facing endpoints - # Allow /generate - location = /generate { - {% if auth %} - auth_request /_dstack_auth; - {% endif %} - try_files /nonexistent @$http_upgrade; - } - # Allow /v1/* endpoints (regex match) - location ~ ^/v1/ { - {% if auth %} - auth_request /_dstack_auth; - {% endif %} - try_files /nonexistent @$http_upgrade; - } - # Block everything else - # This will match any path that doesn't match /generate or /v1/* - location ~ . { - return 403; - } - {% else %} {% for location in locations %} location {{ location.prefix }} { {% if auth %} @@ -59,6 +37,12 @@ server { {% endif %} } {% endfor %} + + {# For SGLang router: block all requests except whitelisted locations added dynamically above #} + {% if router is not none and router.type == "sglang" %} + location / { + return 403; + } {% endif %} location @websocket { diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index 6734a4791..6a55d3b9a 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -71,7 +71,6 @@ class ServiceConfig(SiteConfig): locations: list[LocationConfig] replicas: list[ReplicaConfig] router: Optional[AnyRouterConfig] = None - model_id: Optional[str] = None router_port: Optional[int] = None @@ -124,10 +123,8 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: # Create router context with allocated port ctx = RouterContext( - host="127.0.0.1", port=router_port, log_dir=log_dir, - log_level="info", ) # Create new router instance for this service @@ -421,9 +418,7 @@ def write_router_workers_conf(self, conf: ServiceConfig, allocated_ports: list[i # Pass ports to template workers_config = generate_router_workers_config(conf, allocated_ports) workers_conf_name = f"router-workers.{conf.domain}.conf" - workers_conf_path = self._conf_dir / workers_conf_name - sudo_write(workers_conf_path, workers_config) - self.reload() + self.write_conf(workers_config, workers_conf_name) def generate_router_workers_config(conf: ServiceConfig, allocated_ports: list[int]) -> str: diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index 0abc346c9..636d8c38e 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -6,7 +6,7 @@ import dstack._internal.proxy.gateway.schemas.registry as schemas from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import AnyRouterConfig, RouterType from dstack._internal.proxy.gateway import models as gateway_models from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo from dstack._internal.proxy.gateway.services.nginx import ( @@ -57,7 +57,6 @@ async def register_service( client_max_body_size=client_max_body_size, replicas=(), router=router, - model_id=model.name if model is not None else None, ) async with lock: @@ -310,6 +309,15 @@ async def get_nginx_service_config( ) -> ServiceConfig: limit_req_zones: list[LimitReqZoneConfig] = [] locations: list[LocationConfig] = [] + is_sglang = service.router and service.router.type == RouterType.SGLANG + sglang_whitelisted_paths = [ + "/generate", + "/v1/", + "/chat/completions", + ] # Prefix match for paths that end with a slash and exact match for paths that don't + sglang_limits: dict[str, LimitReqConfig] = {} + sglang_prefix_lengths: dict[str, int] = {} # Track prefix lengths for most-specific selection + for i, rate_limit in enumerate(service.rate_limits): zone_name = f"{i}.{service.domain_safe}" if isinstance(rate_limit.key, models.IPAddressPartitioningKey): @@ -321,13 +329,39 @@ async def get_nginx_service_config( limit_req_zones.append( LimitReqZoneConfig(name=zone_name, key=key, rpm=round(rate_limit.rps * 60)) ) - locations.append( - LocationConfig( - prefix=rate_limit.prefix, - limit_req=LimitReqConfig(zone=zone_name, burst=rate_limit.burst), + if is_sglang: + for path in sglang_whitelisted_paths: + if rate_limit.prefix == path or path.startswith(rate_limit.prefix): + # Use the longest prefix if multiple prefixes match the same path + current_prefix_len = len(rate_limit.prefix) + if path not in sglang_limits or current_prefix_len > sglang_prefix_lengths.get( + path, 0 + ): + sglang_limits[path] = LimitReqConfig( + zone=zone_name, burst=rate_limit.burst + ) + sglang_prefix_lengths[path] = current_prefix_len + else: + locations.append( + LocationConfig( + prefix=rate_limit.prefix, + limit_req=LimitReqConfig(zone=zone_name, burst=rate_limit.burst), + ) ) - ) - if not any(location.prefix == "/" for location in locations): + + # Add SGLang whitelisted paths as locations + if is_sglang: + for path in sglang_whitelisted_paths: + # Use prefix match for paths that end with a slash and exact match for paths that don't + if path.endswith("/"): + locations.append(LocationConfig(prefix=path, limit_req=sglang_limits.get(path))) + else: + locations.append( + LocationConfig(prefix=f"= {path}", limit_req=sglang_limits.get(path)) + ) + + # Don't auto-add / location for SGLang routers (catch-all 403 handles it) + if not any(location.prefix == "/" for location in locations) and not is_sglang: locations.append(LocationConfig(prefix="/", limit_req=None)) return ServiceConfig( domain=service.domain_safe, @@ -340,7 +374,6 @@ async def get_nginx_service_config( locations=locations, replicas=sorted(replicas, key=lambda r: r.id), # sort for reproducible configs router=service.router, - model_id=service.model_id, ) diff --git a/src/dstack/_internal/proxy/lib/models.py b/src/dstack/_internal/proxy/lib/models.py index 21890d233..bf37e0b5a 100644 --- a/src/dstack/_internal/proxy/lib/models.py +++ b/src/dstack/_internal/proxy/lib/models.py @@ -59,7 +59,6 @@ class Service(ImmutableModel): strip_prefix: bool = True # only used in-server replicas: tuple[Replica, ...] router: Optional[AnyRouterConfig] = None - model_id: Optional[str] = None @property def domain_safe(self) -> str: diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 3e375ca2b..e2a48fa27 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -454,18 +454,12 @@ async def _update_gateway(gateway_compute_model: GatewayComputeModel, build: str gateway_compute_model.configuration ) - # Build package spec with extras and wheel URL (similar to compute.py) - wheel = get_dstack_gateway_wheel(build) - if compute_config.router: - gateway_package = f"dstack-gateway[{compute_config.router.type}] @ {wheel}" - else: - gateway_package = f"dstack-gateway @ {wheel}" - + # Build package spec with extras and wheel URL + gateway_package = get_dstack_gateway_wheel(build, compute_config.router) commands = [ # prevent update.sh from overwriting itself during execution "cp dstack/update.sh dstack/_update.sh", - # Pass the full package spec (with extras) to update.sh instead of just the wheel - f"sh dstack/_update.sh '{gateway_package}' {build}", + f'sh dstack/_update.sh "{gateway_package}" {build}', "rm dstack/_update.sh", ] stdout = await connection.tunnel.aexec("/bin/sh -c '" + " && ".join(commands) + "'") From 741c0eada325b37171897d13ce7c4298515a729d Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 21 Nov 2025 10:12:04 +0545 Subject: [PATCH 09/11] Resolve additional comments --- .../gateway/resources/nginx/service.jinja2 | 4 +- .../{ => services}/model_routers/__init__.py | 10 +- .../{ => services}/model_routers/base.py | 10 +- .../{ => services}/model_routers/sglang.py | 110 ++++++++++-------- .../_internal/proxy/gateway/services/nginx.py | 40 +++---- 5 files changed, 89 insertions(+), 85 deletions(-) rename src/dstack/_internal/proxy/gateway/{ => services}/model_routers/__init__.py (61%) rename src/dstack/_internal/proxy/gateway/{ => services}/model_routers/base.py (93%) rename src/dstack/_internal/proxy/gateway/{ => services}/model_routers/sglang.py (74%) diff --git a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 index 20590e37c..31f987706 100644 --- a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 +++ b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 @@ -4,10 +4,8 @@ limit_req_zone {{ zone.key }} zone={{ zone.name }}:10m rate={{ zone.rpm }}r/m; {% if replicas %} upstream {{ domain }}.upstream { - {% if router is not none %} - {% if router.type == "sglang" and router_port is not none %} + {% if router_port is not none %} server 127.0.0.1:{{ router_port }}; # SGLang router on the gateway - {% endif %} {% else %} {% for replica in replicas %} server unix:{{ replica.socket }}; # replica {{ replica.id }} diff --git a/src/dstack/_internal/proxy/gateway/model_routers/__init__.py b/src/dstack/_internal/proxy/gateway/services/model_routers/__init__.py similarity index 61% rename from src/dstack/_internal/proxy/gateway/model_routers/__init__.py rename to src/dstack/_internal/proxy/gateway/services/model_routers/__init__.py index 1db0e759a..9678699ac 100644 --- a/src/dstack/_internal/proxy/gateway/model_routers/__init__.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/__init__.py @@ -1,14 +1,12 @@ -from typing import Optional - -from dstack._internal.core.models.routers import AnyRouterConfig -from dstack._internal.proxy.gateway.model_routers.sglang import SglangRouter +from dstack._internal.core.models.routers import AnyRouterConfig, RouterType +from dstack._internal.proxy.gateway.services.model_routers.sglang import SglangRouter from dstack._internal.proxy.lib.errors import ProxyError from .base import Router, RouterContext -def get_router(router: AnyRouterConfig, context: Optional[RouterContext] = None) -> Router: - if router.type == "sglang": +def get_router(router: AnyRouterConfig, context: RouterContext) -> Router: + if router.type == RouterType.SGLANG: return SglangRouter(config=router, context=context) raise ProxyError(f"Router type '{router.type}' is not available") diff --git a/src/dstack/_internal/proxy/gateway/model_routers/base.py b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py similarity index 93% rename from src/dstack/_internal/proxy/gateway/model_routers/base.py rename to src/dstack/_internal/proxy/gateway/services/model_routers/base.py index 65a414ca0..867591ca1 100644 --- a/src/dstack/_internal/proxy/gateway/model_routers/base.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py @@ -14,8 +14,8 @@ class Config: frozen = True host: str = "127.0.0.1" - port: int = 3000 - log_dir: Path = Path("./router_logs") + port: int + log_dir: Path log_level: Literal["debug", "info", "warning", "error"] = "info" @@ -28,16 +28,16 @@ class Router(ABC): def __init__( self, + context: RouterContext, config: Optional[AnyRouterConfig] = None, - context: Optional[RouterContext] = None, ): """Initialize router with context. Args: - config: Optional router configuration (implementation-specific) context: Runtime context for the router (host, port, logging, etc.) + config: Optional router configuration (implementation-specific) """ - self.context = context or RouterContext() + self.context = context @abstractmethod def start(self) -> None: diff --git a/src/dstack/_internal/proxy/gateway/model_routers/sglang.py b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py similarity index 74% rename from src/dstack/_internal/proxy/gateway/model_routers/sglang.py rename to src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py index fed4ef480..c3a0dfaae 100644 --- a/src/dstack/_internal/proxy/gateway/model_routers/sglang.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py @@ -6,6 +6,7 @@ from typing import List, Optional import httpx +import psutil from dstack._internal.core.models.routers import RouterType, SGLangRouterConfig from dstack._internal.proxy.lib.errors import UnexpectedProxyError @@ -21,16 +22,26 @@ class SglangRouter(Router): TYPE = RouterType.SGLANG - def __init__(self, config: SGLangRouterConfig, context: Optional[RouterContext] = None): + def __init__(self, config: SGLangRouterConfig, context: RouterContext): """Initialize SGLang router. Args: config: SGLang router configuration (policy, cache_threshold, etc.) context: Runtime context for the router (host, port, logging, etc.) """ - super().__init__(config=config, context=context) + super().__init__(context=context, config=config) self.config = config + def pid_from_tcp_ipv4_port(self, port: int) -> Optional[int]: + """ + Return PID of the process listening on the given TCP IPv4 port. + If no process is found, return None. + """ + for conn in psutil.net_connections(kind="tcp4"): + if conn.laddr and conn.laddr.port == port and conn.status == psutil.CONN_LISTEN: + return conn.pid + return None + def start(self) -> None: try: logger.info("Starting sglang-router-new on port %s...", self.context.port) @@ -73,51 +84,42 @@ def start(self) -> None: prometheus_port, ) - except Exception as e: - logger.error(f"Failed to start sglang-router-new: {e}") + except Exception: + logger.exception("Failed to start sglang-router") raise def stop(self) -> None: try: - result = subprocess.run( - ["lsof", "-ti", f":{self.context.port}"], capture_output=True, timeout=5 - ) - if result.returncode == 0: - pids = result.stdout.decode().strip().split("\n") - for pid in pids: - if pid: - logger.info( - "Stopping sglang-router-new process (PID: %s) on port %s", - pid, - self.context.port, + pid = self.pid_from_tcp_ipv4_port(self.context.port) + + if pid: + logger.debug( + "Stopping sglang-router process (PID: %s) on port %s", + pid, + self.context.port, + ) + try: + proc = psutil.Process(pid) + proc.terminate() + try: + proc.wait(timeout=5) + except psutil.TimeoutExpired: + logger.warning( + "Process %s did not terminate gracefully, forcing kill", pid ) - subprocess.run(["kill", pid], timeout=5) + proc.kill() + except psutil.NoSuchProcess: + logger.debug("sglang-router process %s already exited before stop()", pid) else: - result = subprocess.run( - ["pgrep", "-f", f"sglang.*--port.*{self.context.port}"], - capture_output=True, - timeout=5, - ) - if result.returncode == 0: - pids = result.stdout.decode().strip().split("\n") - for pid in pids: - if pid: - logger.info("Stopping sglang-router-new process (PID: %s)", pid) - subprocess.run(["kill", pid], timeout=5) - else: - logger.debug( - "No sglang-router-new process found on port %s", self.context.port - ) + logger.debug("No sglang-router process found on port %s", self.context.port) # Clean up router logs if self.context.log_dir.exists(): logger.debug("Cleaning up router logs for port %s...", self.context.port) shutil.rmtree(self.context.log_dir, ignore_errors=True) - else: - logger.debug("No router logs directory found to clean up") - except Exception as e: - logger.error(f"Failed to stop sglang-router-new: {e}") + except Exception: + logger.exception("Failed to stop sglang-router") raise def is_running(self) -> bool: @@ -126,8 +128,12 @@ def is_running(self) -> bool: with httpx.Client(timeout=5.0) as client: response = client.get(f"http://{self.context.host}:{self.context.port}/workers") return response.status_code == 200 - except Exception as e: - logger.error(f"Error checking sglang router status on port {self.context.port}: {e}") + except httpx.RequestError as e: + logger.debug( + "Sglang router not responding on port %s: %s", + self.context.port, + e, + ) return False def remove_replicas(self, replica_urls: List[str]) -> None: @@ -187,8 +193,8 @@ def _get_router_workers(self) -> List[dict]: workers = response_data.get("workers", []) return workers return [] - except Exception as e: - logger.error(f"Error getting sglang router workers: {e}") + except Exception: + logger.exception("Error getting sglang router workers") return [] def _add_worker_to_router(self, worker_url: str) -> bool: @@ -199,17 +205,21 @@ def _add_worker_to_router(self, worker_url: str) -> bool: f"http://{self.context.host}:{self.context.port}/workers", json=payload, ) - if response.status_code == 200: + if response.status_code == 202: response_data = response.json() if response_data.get("status") == "accepted": logger.info( - "Added worker %s to sglang router on port %s", + "Worker %s accepted by sglang router on port %s", worker_url, self.context.port, ) return True else: - logger.error("Failed to add worker %s: %s", worker_url, response_data) + logger.error( + "Sglang router on port %s failed to accept worker: %s", + self.context.port, + response_data, + ) return False else: logger.error( @@ -219,8 +229,8 @@ def _add_worker_to_router(self, worker_url: str) -> bool: response.text, ) return False - except Exception as e: - logger.error(f"Error adding worker {worker_url}: {e}") + except Exception: + logger.exception("Error adding worker %s", worker_url) return False def _remove_worker_from_router(self, worker_url: str) -> bool: @@ -230,7 +240,7 @@ def _remove_worker_from_router(self, worker_url: str) -> bool: response = client.delete( f"http://{self.context.host}:{self.context.port}/workers/{encoded_url}" ) - if response.status_code == 200: + if response.status_code == 202: response_data = response.json() if response_data.get("status") == "accepted": logger.info( @@ -240,7 +250,11 @@ def _remove_worker_from_router(self, worker_url: str) -> bool: ) return True else: - logger.error("Failed to remove worker %s: %s", worker_url, response_data) + logger.error( + "Sglang router on port %s failed to remove worker: %s", + self.context.port, + response_data, + ) return False else: logger.error( @@ -250,6 +264,6 @@ def _remove_worker_from_router(self, worker_url: str) -> bool: response.text, ) return False - except Exception as e: - logger.error(f"Error removing worker {worker_url}: {e}") + except Exception: + logger.exception("Error removing worker %s", worker_url) return False diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index 6a55d3b9a..bbda92d91 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -12,12 +12,12 @@ from dstack._internal.core.models.routers import AnyRouterConfig, RouterType from dstack._internal.proxy.gateway.const import PROXY_PORT_ON_GATEWAY -from dstack._internal.proxy.gateway.model_routers import ( +from dstack._internal.proxy.gateway.models import ACMESettings +from dstack._internal.proxy.gateway.services.model_routers import ( Router, RouterContext, get_router, ) -from dstack._internal.proxy.gateway.models import ACMESettings from dstack._internal.proxy.lib.errors import ProxyError, UnexpectedProxyError from dstack._internal.utils.common import run_async from dstack._internal.utils.logging import get_logger @@ -135,18 +135,16 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: self._domain_to_router[conf.domain] = router # Start router if not running - if not await run_async(router.is_running): - await run_async(router.start) - - # Discard old worker ports if domain already has allocated ports (required for scaling case) - if conf.domain in self._domain_to_worker_ports: - old_worker_ports = self._domain_to_worker_ports[conf.domain] - for port in old_worker_ports: - self._allocated_worker_ports.discard(port) + try: + if not await run_async(router.is_running): + await run_async(router.start) + except Exception: + # Clean up on failure + del self._router_port_to_domain[router_port] + del self._domain_to_router[conf.domain] + raise allocated_ports = self._allocate_worker_ports(len(conf.replicas)) - self._domain_to_worker_ports[conf.domain] = allocated_ports - replica_urls = [ f"http://{router.context.host}:{port}" for port in allocated_ports ] @@ -155,12 +153,13 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: try: if conf.replicas: await run_async(self.write_router_workers_conf, conf, allocated_ports) + # Discard old worker ports if domain already has allocated ports (required for scaling case) + if conf.domain in self._domain_to_worker_ports: + old_worker_ports = self._domain_to_worker_ports[conf.domain] + for port in old_worker_ports: + self._allocated_worker_ports.discard(port) + self._domain_to_worker_ports[conf.domain] = allocated_ports except Exception as e: - # Discard allocated worker ports on error - for port in allocated_ports: - self._allocated_worker_ports.discard(port) - if conf.domain in self._domain_to_worker_ports: - del self._domain_to_worker_ports[conf.domain] logger.exception( "write_router_workers_conf failed for domain=%s: %s", conf.domain, e ) @@ -170,11 +169,6 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: try: await run_async(router.update_replicas, replica_urls) except Exception as e: - # Free allocated worker ports on error - for port in allocated_ports: - self._allocated_worker_ports.discard(port) - if conf.domain in self._domain_to_worker_ports: - del self._domain_to_worker_ports[conf.domain] logger.exception( "Failed to add replicas to router for domain=%s: %s", conf.domain, e ) @@ -307,7 +301,7 @@ def _is_port_available(port: int) -> bool: # If bind fails (e.g., Address already in use), port is not available return False except Exception: - logger.debug("Error checking port %s availability", port) + logger.warning("Error checking port %s availability", port) return False def _allocate_router_port(self) -> int: From d7b21f95325d9ba8ec9d7e492fefd21d7e0d8394 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 21 Nov 2025 11:13:26 +0545 Subject: [PATCH 10/11] Pinned sglang-router to 0.2.1 --- gateway/pyproject.toml | 4 ++-- src/dstack/_internal/core/backends/base/compute.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/gateway/pyproject.toml b/gateway/pyproject.toml index 7bade5896..c40a37b7f 100644 --- a/gateway/pyproject.toml +++ b/gateway/pyproject.toml @@ -11,11 +11,11 @@ requires-python = ">=3.10" dynamic = ["version"] dependencies = [ # release builds of dstack-gateway depend on a PyPI version of dstack instead - "dstack[gateway] @ https://github.com/Bihan/dstack/archive/refs/heads/add_sglang_router_support.tar.gz", + "dstack[gateway] @ https://github.com/dstackai/dstack/archive/refs/heads/master.tar.gz", ] [project.optional-dependencies] -sglang = ["sglang-router==0.2.2"] +sglang = ["sglang-router==0.2.1"] [tool.setuptools.package-data] "dstack.gateway" = [ diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 5bdc4746c..2181bea76 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -1022,8 +1022,7 @@ def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = Non r.raise_for_status() build = r.text.strip() logger.debug("Found the latest gateway build: %s", build) - # wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" - wheel = "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl" + wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" # Build package spec with extras if router is specified if router: return f"dstack-gateway[{router.type}] @ {wheel}" From c7a6c015400629d709e42983aa9e84f69f99bc97 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 21 Nov 2025 11:40:59 +0545 Subject: [PATCH 11/11] Fix linting error --- src/dstack/_internal/core/backends/kubernetes/compute.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 406a5a506..71a59ad22 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -4,7 +4,7 @@ import threading import time from enum import Enum -from typing import Optional +from typing import List, Optional from gpuhunt import KNOWN_AMD_GPUS, KNOWN_NVIDIA_GPUS, AcceleratorVendor from kubernetes import client @@ -50,9 +50,9 @@ Resources, SSHConnectionParams, ) -from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.placement import PlacementGroup from dstack._internal.core.models.resources import CPUSpec, GPUSpec, Memory +from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import Volume from dstack._internal.utils.common import get_or_error, parse_memory