Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions gateway/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ dependencies = [
"dstack[gateway] @ https://github.com/dstackai/dstack/archive/refs/heads/master.tar.gz",
]

[project.optional-dependencies]
sglang = ["sglang-router==0.2.1"]

[tool.setuptools.package-data]
"dstack.gateway" = [
"resources/systemd/*",
Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,9 @@ def create_gateway(
image_id=aws_resources.get_gateway_image_id(ec2_client),
instance_type="t3.micro",
iam_instance_profile=None,
user_data=get_gateway_user_data(configuration.ssh_key_pub),
user_data=get_gateway_user_data(
configuration.ssh_key_pub, router=configuration.router
),
tags=tags,
security_group_id=security_group_id,
spot=False,
Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/core/backends/azure/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,9 @@ def create_gateway(
image_reference=_get_gateway_image_ref(),
vm_size="Standard_B1ms",
instance_name=instance_name,
user_data=get_gateway_user_data(configuration.ssh_key_pub),
user_data=get_gateway_user_data(
configuration.ssh_key_pub, router=configuration.router
),
ssh_pub_keys=[configuration.ssh_key_pub],
spot=False,
disk_size=30,
Expand Down
18 changes: 12 additions & 6 deletions src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
SSHKey,
)
from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
from dstack._internal.core.models.volumes import (
Volume,
Expand Down Expand Up @@ -881,7 +882,7 @@ def get_run_shim_script(
]


def get_gateway_user_data(authorized_key: str) -> str:
def get_gateway_user_data(authorized_key: str, router: Optional[AnyRouterConfig] = None) -> str:
return get_cloud_config(
package_update=True,
packages=[
Expand All @@ -897,7 +898,7 @@ def get_gateway_user_data(authorized_key: str) -> str:
"s/# server_names_hash_bucket_size 64;/server_names_hash_bucket_size 128;/",
"/etc/nginx/nginx.conf",
],
["su", "ubuntu", "-c", " && ".join(get_dstack_gateway_commands())],
["su", "ubuntu", "-c", " && ".join(get_dstack_gateway_commands(router))],
],
ssh_authorized_keys=[authorized_key],
)
Expand Down Expand Up @@ -1018,24 +1019,29 @@ def get_latest_runner_build() -> Optional[str]:
return None


def get_dstack_gateway_wheel(build: str) -> str:
def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = None) -> str:
channel = "release" if settings.DSTACK_RELEASE else "stgn"
base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}"
if build == "latest":
r = requests.get(f"{base_url}/latest-version", timeout=5)
r.raise_for_status()
build = r.text.strip()
logger.debug("Found the latest gateway build: %s", build)
return f"{base_url}/dstack_gateway-{build}-py3-none-any.whl"
wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl"
# Build package spec with extras if router is specified
if router:
return f"dstack-gateway[{router.type}] @ {wheel}"
return f"dstack-gateway @ {wheel}"


def get_dstack_gateway_commands() -> List[str]:
def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]:
build = get_dstack_runner_version()
gateway_package = get_dstack_gateway_wheel(build, router)
return [
"mkdir -p /home/ubuntu/dstack",
"python3 -m venv /home/ubuntu/dstack/blue",
"python3 -m venv /home/ubuntu/dstack/green",
f"/home/ubuntu/dstack/blue/bin/pip install {get_dstack_gateway_wheel(build)}",
f"/home/ubuntu/dstack/blue/bin/pip install '{gateway_package}'",
"sudo /home/ubuntu/dstack/blue/bin/python -m dstack.gateway.systemd install --run",
]

Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,9 @@ def create_gateway(
machine_type="e2-medium",
accelerators=[],
spot=False,
user_data=get_gateway_user_data(configuration.ssh_key_pub),
user_data=get_gateway_user_data(
configuration.ssh_key_pub, router=configuration.router
),
authorized_keys=[configuration.ssh_key_pub],
labels=labels,
tags=[gcp_resources.DSTACK_GATEWAY_TAG],
Expand Down
18 changes: 13 additions & 5 deletions src/dstack/_internal/core/backends/kubernetes/compute.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import shlex
import subprocess
import tempfile
import threading
import time
from enum import Enum
from typing import Optional
from typing import List, Optional

from gpuhunt import KNOWN_AMD_GPUS, KNOWN_NVIDIA_GPUS, AcceleratorVendor
from kubernetes import client
Expand Down Expand Up @@ -51,6 +52,7 @@
)
from dstack._internal.core.models.placement import PlacementGroup
from dstack._internal.core.models.resources import CPUSpec, GPUSpec, Memory
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
from dstack._internal.core.models.volumes import Volume
from dstack._internal.utils.common import get_or_error, parse_memory
Expand Down Expand Up @@ -371,7 +373,9 @@ def create_gateway(
# Consider deploying an NLB. It seems it requires some extra configuration on the cluster:
# https://docs.aws.amazon.com/eks/latest/userguide/network-load-balancing.html
instance_name = generate_unique_gateway_instance_name(configuration)
commands = _get_gateway_commands(authorized_keys=[configuration.ssh_key_pub])
commands = _get_gateway_commands(
authorized_keys=[configuration.ssh_key_pub], router=configuration.router
)
pod = client.V1Pod(
metadata=client.V1ObjectMeta(
name=instance_name,
Expand Down Expand Up @@ -983,9 +987,13 @@ def _add_authorized_key_to_jump_pod(
)


def _get_gateway_commands(authorized_keys: list[str]) -> list[str]:
def _get_gateway_commands(
authorized_keys: List[str], router: Optional[AnyRouterConfig] = None
) -> List[str]:
authorized_keys_content = "\n".join(authorized_keys).strip()
gateway_commands = " && ".join(get_dstack_gateway_commands())
gateway_commands = " && ".join(get_dstack_gateway_commands(router=router))
quoted_gateway_commands = shlex.quote(gateway_commands)

commands = [
# install packages
"apt-get update && apt-get install -y sudo wget openssh-server nginx python3.10-venv libaugeas0",
Expand Down Expand Up @@ -1013,7 +1021,7 @@ def _get_gateway_commands(authorized_keys: list[str]) -> list[str]:
# start sshd
"/usr/sbin/sshd -p 22 -o PermitUserEnvironment=yes",
# run gateway
f"su ubuntu -c '{gateway_commands}'",
f"su ubuntu -c {quoted_gateway_commands}",
"sleep infinity",
]
return commands
Expand Down
6 changes: 6 additions & 0 deletions src/dstack/_internal/core/models/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.utils.tags import tags_validator


Expand Down Expand Up @@ -50,6 +51,10 @@ class GatewayConfiguration(CoreModel):
default: Annotated[bool, Field(description="Make the gateway default")] = False
backend: Annotated[BackendType, Field(description="The gateway backend")]
region: Annotated[str, Field(description="The gateway region")]
router: Annotated[
Optional[AnyRouterConfig],
Field(description="The router configuration"),
] = None
domain: Annotated[
Optional[str], Field(description="The gateway domain, e.g. `example.com`")
] = None
Expand Down Expand Up @@ -113,6 +118,7 @@ class GatewayComputeConfiguration(CoreModel):
ssh_key_pub: str
certificate: Optional[AnyGatewayCertificate] = None
tags: Optional[Dict[str, str]] = None
router: Optional[AnyRouterConfig] = None


class GatewayProvisioningData(CoreModel):
Expand Down
16 changes: 16 additions & 0 deletions src/dstack/_internal/core/models/routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from enum import Enum
from typing import Literal

from dstack._internal.core.models.common import CoreModel


class RouterType(str, Enum):
SGLANG = "sglang"


class SGLangRouterConfig(CoreModel):
type: Literal["sglang"] = "sglang"
policy: Literal["random", "round_robin", "cache_aware", "power_of_two"] = "cache_aware"


AnyRouterConfig = SGLangRouterConfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{% for replica in replicas %}
# Worker {{ loop.index }}
upstream router_worker_{{ domain|replace('.', '_') }}_{{ ports[loop.index0] }}_upstream {
server unix:{{ replica.socket }};
}

server {
listen 127.0.0.1:{{ ports[loop.index0] }};
access_log off; # disable access logs for this internal endpoint

proxy_read_timeout 300s;
proxy_send_timeout 300s;

location / {
proxy_pass http://router_worker_{{ domain|replace('.', '_') }}_{{ ports[loop.index0] }}_upstream;
proxy_http_version 1.1;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Connection "";
proxy_set_header Upgrade $http_upgrade;
}
}
{% endfor %}
11 changes: 11 additions & 0 deletions src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ limit_req_zone {{ zone.key }} zone={{ zone.name }}:10m rate={{ zone.rpm }}r/m;

{% if replicas %}
upstream {{ domain }}.upstream {
{% if router_port is not none %}
server 127.0.0.1:{{ router_port }}; # SGLang router on the gateway
{% else %}
{% for replica in replicas %}
server unix:{{ replica.socket }}; # replica {{ replica.id }}
{% endfor %}
{% endif %}
}
{% else %}

Expand All @@ -32,6 +36,13 @@ server {
}
{% endfor %}

{# For SGLang router: block all requests except whitelisted locations added dynamically above #}
{% if router is not none and router.type == "sglang" %}
location / {
return 403;
}
{% endif %}

location @websocket {
set $dstack_replica_hit 1;
{% if replicas %}
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/routers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ async def register_service(
model=body.options.openai.model if body.options.openai is not None else None,
ssh_private_key=body.ssh_private_key,
repo=repo,
router=body.router,
nginx=nginx,
service_conn_pool=service_conn_pool,
)
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/proxy/gateway/schemas/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pydantic import BaseModel, Field

from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.proxy.lib.models import RateLimit


Expand Down Expand Up @@ -44,6 +45,7 @@ class RegisterServiceRequest(BaseModel):
options: Options
ssh_private_key: str
rate_limits: tuple[RateLimit, ...] = ()
router: Optional[AnyRouterConfig] = None


class RegisterReplicaRequest(BaseModel):
Expand Down
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) I'd suggest to move the model_routers directory to proxy/gateway/services/model_routers.

The subdirectories in proxy/gateway are supposed to represent architectural tiers:

  • proxy/gateway/repo - the data tier
  • proxy/gateway/services - the logic tier
  • proxy/gateway/routers and proxy/gateway/schemas - the presentation tier

The model routers implementation is part of the logic tier.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved

Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from dstack._internal.core.models.routers import AnyRouterConfig, RouterType
from dstack._internal.proxy.gateway.services.model_routers.sglang import SglangRouter
from dstack._internal.proxy.lib.errors import ProxyError

from .base import Router, RouterContext


def get_router(router: AnyRouterConfig, context: RouterContext) -> Router:
if router.type == RouterType.SGLANG:
return SglangRouter(config=router, context=context)
raise ProxyError(f"Router type '{router.type}' is not available")


__all__ = [
"Router",
"RouterContext",
"get_router",
]
91 changes: 91 additions & 0 deletions src/dstack/_internal/proxy/gateway/services/model_routers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Literal, Optional

from pydantic import BaseModel

from dstack._internal.core.models.routers import AnyRouterConfig


class RouterContext(BaseModel):
"""Context for router initialization and configuration."""

class Config:
frozen = True

host: str = "127.0.0.1"
port: int
log_dir: Path
log_level: Literal["debug", "info", "warning", "error"] = "info"


class Router(ABC):
"""Abstract base class for router implementations.
A router manages the lifecycle of worker replicas and handles request routing.
Different router implementations may have different mechanisms for managing
replicas.
"""

def __init__(
self,
context: RouterContext,
config: Optional[AnyRouterConfig] = None,
):
"""Initialize router with context.

Args:
context: Runtime context for the router (host, port, logging, etc.)
config: Optional router configuration (implementation-specific)
"""
self.context = context

@abstractmethod
def start(self) -> None:
"""Start the router process.

Raises:
Exception: If the router fails to start.
"""
...

@abstractmethod
def stop(self) -> None:
"""Stop the router process.

Raises:
Exception: If the router fails to stop.
"""
...

@abstractmethod
def is_running(self) -> bool:
"""Check if the router is currently running and responding.

Returns:
True if the router is running and healthy, False otherwise.
"""
...

@abstractmethod
def remove_replicas(self, replica_urls: List[str]) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) This method looks redundant, because it is possible to remove replicas using update_replicas.

For example, this can be replaced with run_async(router.update_replicas, [])

"""Unregister replicas from the router (actual API calls to remove workers).

Args:
replica_urls: The list of replica URLs to remove from router.

Raises:
Exception: If removing replicas fails.
"""
...

@abstractmethod
def update_replicas(self, replica_urls: List[str]) -> None:
"""Update replicas for service, replacing the current set.

Args:
replica_urls: The new list of replica URLs for this service.

Raises:
Exception: If updating replicas fails.
"""
...
Loading