Merged
86 changes: 50 additions & 36 deletions cli/dstack/_internal/cli/commands/build/__init__.py
@@ -4,10 +4,16 @@

from dstack._internal.api.runs import list_runs_hub
from dstack._internal.cli.commands import BasicCommand
from dstack._internal.cli.commands.run import _poll_run, _print_run_plan, _read_ssh_key_pub
from dstack._internal.cli.commands.run import (
_poll_run,
_print_run_plan,
_read_ssh_key_pub,
_reserve_ports,
)
from dstack._internal.cli.common import add_project_argument, check_init, console
from dstack._internal.cli.config import config, get_hub_client
from dstack._internal.cli.configuration import load_configuration
from dstack._internal.configurators.ports import PortUsedError
from dstack._internal.core.error import RepoNotInitializedError


@@ -26,45 +32,53 @@ def _command(self, args: argparse.Namespace):
elif configurator.profile.project:
project_name = configurator.profile.project

hub_client = get_hub_client(project_name=project_name)
if (
hub_client.repo.repo_data.repo_type != "local"
and not hub_client.get_repo_credentials()
):
raise RepoNotInitializedError("No credentials", project_name=project_name)
try:
hub_client = get_hub_client(project_name=project_name)
if (
hub_client.repo.repo_data.repo_type != "local"
and not hub_client.get_repo_credentials()
):
raise RepoNotInitializedError("No credentials", project_name=project_name)

if not config.repo_user_config.ssh_key_path:
ssh_key_pub = None
else:
ssh_key_pub = _read_ssh_key_pub(config.repo_user_config.ssh_key_path)
if not config.repo_user_config.ssh_key_path:
ssh_key_pub = None
else:
ssh_key_pub = _read_ssh_key_pub(config.repo_user_config.ssh_key_path)

configurator_args, run_args = configurator.get_parser().parse_known_args(
args.args + args.unknown
)
configurator.apply_args(configurator_args)
configurator_args, run_args = configurator.get_parser().parse_known_args(
args.args + args.unknown
)
configurator.apply_args(configurator_args)

run_plan = hub_client.get_run_plan(configurator)
console.print("dstack will execute the following plan:\n")
_print_run_plan(configurator.configuration_path, run_plan)
if not args.yes and not Confirm.ask("Continue?"):
console.print("\nExiting...")
exit(0)
console.print("\nProvisioning...\n")
run_plan = hub_client.get_run_plan(configurator)
console.print("dstack will execute the following plan:\n")
_print_run_plan(configurator.configuration_path, run_plan)
if not args.yes and not Confirm.ask("Continue?"):
console.print("\nExiting...")
exit(0)

run_name, jobs = hub_client.run_configuration(
configurator=configurator,
ssh_key_pub=ssh_key_pub,
run_args=run_args,
)
runs = list_runs_hub(hub_client, run_name=run_name)
run = runs[0]
_poll_run(
hub_client,
run,
jobs,
ssh_key=config.repo_user_config.ssh_key_path,
watcher=None,
)
ports_locks = _reserve_ports(
configurator.app_specs(), hub_client.get_project_backend_type() == "local"
)

console.print("\nProvisioning...\n")
run_name, jobs = hub_client.run_configuration(
configurator=configurator,
ssh_key_pub=ssh_key_pub,
run_args=run_args,
)
runs = list_runs_hub(hub_client, run_name=run_name)
run = runs[0]
_poll_run(
hub_client,
run,
jobs,
ssh_key=config.repo_user_config.ssh_key_path,
watcher=None,
ports_locks=ports_locks,
)
except PortUsedError as e:
exit(f"{type(e).__name__}: {e}")

def __init__(self, parser):
super().__init__(parser)
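The restructuring above applies a common pattern: bind the local ports a run will need before doing any expensive provisioning, so that a conflict aborts the command immediately (via the new except PortUsedError branch) rather than after the run has already been created. A generic, self-contained illustration of that pattern, not dstack code:

import socket

def try_reserve(ports):
    # Bind every port up front; an OSError here means one is already in use.
    held = []
    try:
        for port in ports:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.bind(("localhost", port))
            held.append(sock)
    except OSError:
        for sock in held:
            sock.close()
        raise
    return held

# Reserve first, do the slow work, then release just before the real consumer binds.
locks = try_reserve([8000, 8888])  # example port numbers
# ... provisioning would happen here ...
for sock in locks:
    sock.close()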
78 changes: 56 additions & 22 deletions cli/dstack/_internal/cli/commands/run/__init__.py
@@ -6,7 +6,7 @@
import time
from argparse import Namespace
from pathlib import Path
from typing import Dict, Iterator, List, Optional
from typing import Dict, Iterator, List, Optional, Tuple

import websocket
from cursor import cursor
@@ -23,6 +23,8 @@
from dstack._internal.cli.common import add_project_argument, check_init, console, print_runs
from dstack._internal.cli.config import config, get_hub_client
from dstack._internal.cli.configuration import load_configuration
from dstack._internal.configurators.ports import PortUsedError
from dstack._internal.core.app import AppSpec
from dstack._internal.core.error import RepoNotInitializedError
from dstack._internal.core.instance import InstanceType
from dstack._internal.core.job import Job, JobErrorCode, JobHead, JobStatus
@@ -85,14 +87,6 @@ def register(self):
type=str,
dest="profile_name",
)
self._parser.add_argument(
"-t",
"--tag",
metavar="TAG",
help="A tag name. Warning, if the tag exists, " "it will be overridden.",
type=str,
dest="tag_name",
)
self._parser.add_argument(
"args",
metavar="ARGS",
@@ -145,8 +139,14 @@ def _command(self, args: Namespace):
if not args.yes and not Confirm.ask("Continue?"):
console.print("\nExiting...")
exit(0)
console.print("\nProvisioning...\n")

ports_locks = None
if not args.detach:
ports_locks = _reserve_ports(
configurator.app_specs(), hub_client.get_project_backend_type() == "local"
)

console.print("\nProvisioning...\n")
run_name, jobs = hub_client.run_configuration(
configurator=configurator,
ssh_key_pub=ssh_key_pub,
Expand All @@ -161,7 +161,10 @@ def _command(self, args: Namespace):
jobs,
ssh_key=config.repo_user_config.ssh_key_path,
watcher=watcher,
ports_locks=ports_locks,
)
except PortUsedError as e:
exit(f"{type(e).__name__}: {e}")
finally:
if watcher.is_alive():
watcher.stop()
@@ -217,6 +220,7 @@ def _poll_run(
job_heads: List[JobHead],
ssh_key: Optional[str],
watcher: Optional[Watcher],
ports_locks: Tuple[PortsLock, PortsLock],
):
print_runs([run])
console.print()
@@ -276,7 +280,7 @@ def _poll_run(

# attach to the instance, attach to the container in the background
jobs = [hub_client.get_job(job_head.job_id) for job_head in job_heads]
ports = _attach(hub_client, jobs[0], ssh_key)
ports = _attach(hub_client, jobs[0], ssh_key, ports_locks)
console.print()
console.print("[grey58]To stop, press Ctrl+C.[/]")
console.print()
@@ -338,25 +342,57 @@ def _print_failed_run_message(run: RunHead):
console.print("Provisioning failed\n")


def _attach(hub_client: HubClient, job: Job, ssh_key_path: str) -> Dict[int, int]:
def _reserve_ports(apps: List[AppSpec], local_backend: bool) -> Tuple[PortsLock, PortsLock]:
host_ports = {}
host_ports_lock = PortsLock()
app_ports = {}
app_ports_lock = PortsLock()
openssh_server_port: Optional[int] = None

for app in apps:
app_ports[app.port] = app.map_to_port or 0
if app.app_name == "openssh-server":
openssh_server_port = app.port
if local_backend and openssh_server_port is None:
return host_ports_lock, app_ports_lock

if not local_backend:
if openssh_server_port is None:
host_ports.update(app_ports)
app_ports = {}
host_ports_lock = PortsLock(host_ports).acquire()

if openssh_server_port is not None:
del app_ports[openssh_server_port]
if app_ports:
app_ports_lock = PortsLock(app_ports).acquire()

return host_ports_lock, app_ports_lock


def _attach(
hub_client: HubClient, job: Job, ssh_key_path: str, ports_locks: Tuple[PortsLock, PortsLock]
) -> Dict[int, int]:
"""
:return: (host tunnel ports, container tunnel ports, ports mapping)
"""
backend_type = hub_client.get_project_backend_type()
app_ports = {}
openssh_server_port = 0
openssh_server_port: Optional[int] = None
for app in job.app_specs or []:
app_ports[app.port] = app.map_to_port or 0
if app.app_name == "openssh-server":
openssh_server_port = app.port
if backend_type == "local" and openssh_server_port == 0:
if backend_type == "local" and openssh_server_port is None:
console.print("Provisioning... It may take up to a minute. [green]✓[/]")
return {k: v for k, v in app_ports.items() if v != 0}

console.print("Starting SSH tunnel...")
include_ssh_config(config.ssh_config_path)
ws_port = int(job.env["WS_LOGS_PORT"])
host_ports = {ws_port: ws_port}

host_ports = {}
host_ports_lock, app_ports_lock = ports_locks

if backend_type != "local":
ssh_config_add_host(
@@ -374,19 +410,18 @@ def _attach(hub_client: HubClient, job: Job, ssh_key_path: str) -> Dict[int, int
"ControlPersist": "10m",
},
)
host_ports[ws_port] = 0 # to map dynamically
if openssh_server_port == 0:
host_ports.update(app_ports)
if openssh_server_port is None:
app_ports = {}
host_ports = PortsLock(host_ports).acquire().release()
host_ports = PortsLock({ws_port: ws_port}).acquire().release()
host_ports.update(host_ports_lock.release())
for i in range(3): # retry
time.sleep(2**i)
if run_ssh_tunnel(f"{job.run_name}-host", host_ports):
break
else:
console.print("[warning]Warning: failed to start SSH tunnel[/warning] [red]✗[/]")

if openssh_server_port != 0:
if openssh_server_port is not None:
options = {
"HostName": "localhost",
"Port": app_ports[openssh_server_port] or openssh_server_port,
@@ -403,8 +438,7 @@ def _attach(hub_client: HubClient, job: Job, ssh_key_path: str) -> Dict[int, int
ssh_config_add_host(config.ssh_config_path, job.run_name, options)
del app_ports[openssh_server_port]
if app_ports:
app_ports_lock = PortsLock(app_ports).acquire()
app_ports = app_ports_lock.dict()
app_ports.update(app_ports_lock.dict())
# try to attach in the background
threading.Thread(
target=_attach_to_container,
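This file exercises PortsLock from dstack._internal.configurators.ports without showing its implementation. Inferring only from the calls above (PortsLock({app_port: host_port_or_0}).acquire(), .release(), .dict()), a hypothetical minimal stand-in might look like the sketch below; the real class may differ, and is assumed to raise PortUsedError when a requested port is taken, given the except blocks added to both commands.

import socket
from typing import Dict, Optional

class PortsLockSketch:
    # Hypothetical stand-in for PortsLock, matching only the interface used in this diff.

    def __init__(self, restrictions: Optional[Dict[int, int]] = None):
        # Maps app (container) port -> requested host port; 0 means "any free port".
        self.restrictions = dict(restrictions or {})
        self._sockets: Dict[int, socket.socket] = {}

    def acquire(self) -> "PortsLockSketch":
        # Bind the requested ports now so a clash surfaces before provisioning.
        for app_port, host_port in self.restrictions.items():
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.bind(("localhost", host_port))  # raises OSError if taken; the real class presumably raises PortUsedError
            self._sockets[app_port] = sock
        return self

    def dict(self) -> Dict[int, int]:
        # App port -> host port actually reserved (resolves any 0s the OS picked).
        return {p: s.getsockname()[1] for p, s in self._sockets.items()}

    def release(self) -> Dict[int, int]:
        # Free the sockets (so run_ssh_tunnel can bind them) and return the mapping.
        mapping = self.dict()
        for s in self._sockets.values():
            s.close()
        self._sockets.clear()
        return mapping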
2 changes: 1 addition & 1 deletion cli/dstack/_internal/configurators/dev_environment.py
@@ -13,7 +13,7 @@
DEFAULT_MAX_DURATION_SECONDS = 6 * 3600

require_sshd = require(["sshd"])
install_ipykernel = f'(pip install -q --no-cache-dir ipykernel 2> /dev/null) || echo "no pip, ipykernel was not installed"'
install_ipykernel = f'(pip install --no-cache-dir ipykernel 2> /dev/null) || echo "no pip, ipykernel was not installed"'


class DevEnvironmentConfigurator(JobConfigurator):
4 changes: 0 additions & 4 deletions cli/dstack/_internal/configurators/ports.py
@@ -7,10 +7,6 @@
RESERVED_PORTS_END = 10999


class PortReservedError(DstackError):
pass


class PortUsedError(DstackError):
pass

8 changes: 5 additions & 3 deletions cli/dstack/_internal/core/profile.py
@@ -14,9 +14,9 @@
def parse_memory(v: Optional[Union[int, str]]) -> Optional[int]:
"""
Converts human-readable sizes (MB and GB) to megabytes
>>> mem_size("512MB")
>>> parse_memory("512MB")
512
>>> mem_size("1 GB")
>>> parse_memory("1 GB")
1024
"""
if isinstance(v, str):
@@ -88,7 +88,9 @@ class Profile(ForbidExtra):
retry_policy: ProfileRetryPolicy = ProfileRetryPolicy()
max_duration: Optional[Union[int, str]]
default: bool = False
_validate_limit = validator("max_duration", pre=True, allow_reuse=True)(parse_max_duration)
_validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)(
parse_max_duration
)


class ProfilesConfig(ForbidExtra):
Expand Down