Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion cli/dstack/_internal/backend/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,13 @@ def list_jobs(self, repo_id: str, run_name: str) -> List[Job]:

def run_job(self, job: Job, failed_to_start_job_new_status: JobStatus):
self.predict_build_plan(job) # raises exception on missing build
base_jobs.run_job(self.storage(), self.compute(), job, failed_to_start_job_new_status)
base_jobs.run_job(
self.storage(),
self.compute(),
self.secrets_manager(),
job,
failed_to_start_job_new_status,
)

def restart_job(self, job: Job):
base_jobs.restart_job(self.storage(), self.compute(), job)
Expand Down
57 changes: 40 additions & 17 deletions cli/dstack/_internal/backend/base/gateway.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import subprocess
import time
from typing import List, Optional

import pkg_resources

from dstack._internal.backend.base.compute import Compute
from dstack._internal.backend.base.head import (
delete_head_object,
list_head_objects,
put_head_object,
)
from dstack._internal.backend.base.secrets import SecretsManager
from dstack._internal.backend.base.storage import Storage
from dstack._internal.core.error import DstackError
from dstack._internal.core.error import SSHCommandError
from dstack._internal.core.gateway import GatewayHead
from dstack._internal.hub.utils.ssh import HUB_PRIVATE_KEY_PATH
from dstack._internal.utils.common import PathLike
from dstack._internal.utils.common import PathLike, removeprefix
from dstack._internal.utils.interpolator import VariablesInterpolator
from dstack._internal.utils.random_names import generate_name


Expand All @@ -37,34 +40,54 @@ def delete_gateway(compute: Compute, storage: Storage, instance_name: str):
delete_head_object(storage, head)


def ssh_copy_id(
def resolve_hostname(secrets_manager: SecretsManager, repo_id: str, hostname: str) -> str:
secrets = {}
_, missed = VariablesInterpolator({}).interpolate(hostname, return_missing=True)
for ns_name in missed:
name = removeprefix(ns_name, "secrets.")
value = secrets_manager.get_secret(repo_id, name)
if value is not None:
secrets[name] = value.secret_value
return VariablesInterpolator({"secrets": secrets}).interpolate(hostname)


def publish(
hostname: str,
public_key: bytes,
port: int,
ssh_key: bytes,
user: str = "ubuntu",
id_rsa: Optional[PathLike] = HUB_PRIVATE_KEY_PATH,
):
command = f"echo '{public_key.decode()}' >> ~/.ssh/authorized_keys"
exec_ssh_command(hostname, command, user=user, id_rsa=id_rsa)
) -> str:
command = ["sudo", "python3", "-", hostname, str(port), f'"{ssh_key.decode().strip()}"']
with open(
pkg_resources.resource_filename("dstack._internal", "scripts/gateway_publish.py"), "r"
) as f:
output = exec_ssh_command(
hostname, command=" ".join(command), user=user, id_rsa=id_rsa, stdin=f
)
return output.decode().strip()


def exec_ssh_command(hostname: str, command: str, user: str, id_rsa: Optional[PathLike]) -> bytes:
def exec_ssh_command(
hostname: str, command: str, user: str, id_rsa: Optional[PathLike], stdin=None
) -> bytes:
args = ["ssh"]
if id_rsa is not None:
args += ["-i", id_rsa]
args += [
"-o",
"StrictHostKeyChecking=accept-new",
"StrictHostKeyChecking=no",
"-o",
"UserKnownHostsFile=/dev/null",
f"{user}@{hostname}",
command,
]
proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if not hostname: # ssh hangs indefinitely with empty hostname
raise SSHCommandError(
args, "ssh: Could not connect to the gateway, because hostname is empty"
)
proc = subprocess.Popen(args, stdin=stdin, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = proc.communicate()
if proc.returncode != 0:
raise SSHCommandError(args, stderr.decode())
return stdout


class SSHCommandError(DstackError):
def __init__(self, cmd: List[str], message: str):
super().__init__(message)
self.cmd = cmd
12 changes: 10 additions & 2 deletions cli/dstack/_internal/backend/base/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dstack._internal.backend.base.gateway as gateway
from dstack._internal.backend.base import runners
from dstack._internal.backend.base.compute import Compute, InstanceNotFoundError, NoCapacityError
from dstack._internal.backend.base.secrets import SecretsManager
from dstack._internal.backend.base.storage import Storage
from dstack._internal.core.error import BackendError, BackendValueError, NoMatchingInstanceError
from dstack._internal.core.instance import InstanceType
Expand Down Expand Up @@ -119,17 +120,24 @@ def predict_job_instance(
def run_job(
storage: Storage,
compute: Compute,
secrets_manager: SecretsManager,
job: Job,
failed_to_start_job_new_status: JobStatus,
):
if job.status != JobStatus.SUBMITTED:
raise BackendError("Can't create a request for a job which status is not SUBMITTED")
try:
if job.configuration_type == ConfigurationType.SERVICE:
job.gateway.hostname = gateway.resolve_hostname(
secrets_manager, job.repo_ref.repo_id, job.gateway.hostname
)
private_bytes, public_bytes = generate_rsa_key_pair_bytes(comment=job.run_name)
gateway.ssh_copy_id(job.gateway.hostname, public_bytes)
job.gateway.sock_path = gateway.publish(
job.gateway.hostname, job.gateway.public_port, public_bytes
)
job.gateway.ssh_key = private_bytes.decode()
update_job(storage, job)

_try_run_job(
storage=storage,
compute=compute,
Expand Down Expand Up @@ -163,7 +171,7 @@ def restart_job(


def stop_job(
storage: Storage, compute: Compute, repo_id: str, job_id: str, terminate: str, abort: str
storage: Storage, compute: Compute, repo_id: str, job_id: str, terminate: bool, abort: bool
):
logger.info("Stopping job [repo_id=%s job_id=%s]", repo_id, job_id)
job_head = get_job_head(storage, repo_id, job_id)
Expand Down
12 changes: 8 additions & 4 deletions cli/dstack/_internal/backend/gcp/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,17 @@ def create_gateway_firewall_rules(
network: str,
):
firewall_rule = compute_v1.Firewall()
firewall_rule.name = "dstack-gateway-in-" + network.replace("/", "-")
firewall_rule.name = "dstack-gateway-in-all-" + network.replace("/", "-")
firewall_rule.direction = "INGRESS"

allowed_ports = compute_v1.Allowed()
allowed_ports.I_p_protocol = "tcp"
allowed_ports.ports = ["22", "80", "443"]
allowed_ports.ports = ["0-65535"]

firewall_rule.allowed = [allowed_ports]
firewall_rule.source_ranges = ["0.0.0.0/0"]
firewall_rule.network = network
firewall_rule.description = "Allowing TCP traffic on ports 22, 80, and 443 from Internet."
firewall_rule.description = "Allowing TCP traffic on all ports from Internet."

firewall_rule.target_tags = [DSTACK_GATEWAY_TAG]

Expand All @@ -114,4 +114,8 @@ def gateway_disks(zone: str) -> List[compute_v1.AttachedDisk]:
def gateway_user_data_script() -> str:
return f"""#!/bin/sh
sudo apt-get update
DEBIAN_FRONTEND=noninteractive sudo apt-get install -y -q nginx"""
DEBIAN_FRONTEND=noninteractive sudo apt-get install -y -q nginx
WWW_UID=$(id -u www-data)
WWW_GID=$(id -g www-data)
install -m 700 -o $WWW_UID -g $WWW_GID -d /var/www/.ssh
install -m 600 -o $WWW_UID -g $WWW_GID /dev/null /var/www/.ssh/authorized_keys"""
8 changes: 6 additions & 2 deletions cli/dstack/_internal/configurators/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@ def default_max_duration(self) -> Optional[int]:
return None # infinite

def ports(self) -> Dict[int, ports.PortMapping]:
port = self.conf.gateway.service_port
port = self.conf.port.container_port
return {port: ports.PortMapping(container_port=port)}

def gateway(self) -> Optional[job.Gateway]:
return job.Gateway.parse_obj(self.conf.gateway)
return job.Gateway(
hostname=self.conf.gateway,
service_port=self.conf.port.container_port,
public_port=self.conf.port.local_port,
)

def build_commands(self) -> List[str]:
return self.conf.build
Expand Down
26 changes: 16 additions & 10 deletions cli/dstack/_internal/core/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,6 @@ class Artifact(ForbidExtra):
] = False


class Gateway(ForbidExtra):
hostname: Annotated[str, Field(description="IP address or domain name")]
public_port: Annotated[
ValidPort, Field(description="The port that the gateway listens to")
] = 80
service_port: Annotated[ValidPort, Field(description="The port that the service listens to")]


class BaseConfiguration(ForbidExtra):
type: Literal["none"]
image: Annotated[Optional[str], Field(description="The name of the Docker image to run")]
Expand Down Expand Up @@ -119,7 +111,7 @@ def convert_env(cls, v) -> Dict[str, str]:

class BaseConfigurationWithPorts(BaseConfiguration):
ports: Annotated[
List[Union[constr(regex=r"^(?:([0-9]+|\*):)?[0-9]+$"), ValidPort, PortMapping]],
List[Union[ValidPort, constr(regex=r"^(?:[0-9]+|\*):[0-9]+$"), PortMapping]],
Field(description="Port numbers/mapping to expose"),
] = []

Expand Down Expand Up @@ -147,7 +139,21 @@ class TaskConfiguration(BaseConfigurationWithPorts):
class ServiceConfiguration(BaseConfiguration):
type: Literal["service"] = "service"
commands: Annotated[CommandsList, Field(description="The bash commands to run")]
gateway: Annotated[Gateway, Field(description="The gateway to publish the service")]
port: Annotated[
Union[ValidPort, constr(regex=r"^[0-9]+:[0-9]+$"), PortMapping],
Field(description="The port, that application listens to or the mapping"),
]
gateway: Annotated[
str, Field(description="The gateway IP address or domain to publish the service")
]

@validator("port")
def convert_port(cls, v) -> PortMapping:
if isinstance(v, int):
return PortMapping(local_port=80, container_port=v)
elif isinstance(v, str):
return PortMapping.parse(v)
return v


class DstackConfiguration(BaseModel):
Expand Down
10 changes: 9 additions & 1 deletion cli/dstack/_internal/core/error.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import List, Optional


class DstackError(Exception):
Expand Down Expand Up @@ -34,3 +34,11 @@ def __init__(self, message: Optional[str] = None, project_name: Optional[str] =

class NameNotFoundError(DstackError):
pass


class SSHCommandError(BackendError):
code = "ssh_command_error"

def __init__(self, cmd: List[str], message: str):
super().__init__(message)
self.cmd = cmd
3 changes: 2 additions & 1 deletion cli/dstack/_internal/core/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@

class Gateway(BaseModel):
hostname: str
ssh_key: Optional[str]
service_port: int
public_port: int = 80
ssh_key: Optional[str]
sock_path: Optional[str]


class GpusRequirements(BaseModel):
Expand Down
3 changes: 3 additions & 0 deletions cli/dstack/_internal/core/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ class Profile(ForbidExtra):
class ProfilesConfig(ForbidExtra):
profiles: List[Profile]

class Config:
schema_extra = {"$schema": "http://json-schema.org/draft-07/schema#"}

def default(self) -> Profile:
for p in self.profiles:
if p.default:
Expand Down
7 changes: 6 additions & 1 deletion cli/dstack/_internal/hub/routers/runners.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException, status

from dstack._internal.core.build import BuildNotFoundError
from dstack._internal.core.error import NoMatchingInstanceError
from dstack._internal.core.error import NoMatchingInstanceError, SSHCommandError
from dstack._internal.core.job import Job, JobStatus
from dstack._internal.hub.models import StopRunners
from dstack._internal.hub.routers.util import call_backend, error_detail, get_backend, get_project
Expand Down Expand Up @@ -33,6 +33,11 @@ async def run(project_name: str, job: Job):
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_detail(msg=e.message, code=e.code),
)
except SSHCommandError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_detail(msg=e.message, code=e.code),
)


@router.post("/{project_name}/runners/restart")
Expand Down
Loading