diff --git a/cli/dstack/_internal/cli/commands/build/__init__.py b/cli/dstack/_internal/cli/commands/build/__init__.py
index f5785b491..167b92345 100644
--- a/cli/dstack/_internal/cli/commands/build/__init__.py
+++ b/cli/dstack/_internal/cli/commands/build/__init__.py
@@ -4,10 +4,16 @@
 from dstack._internal.api.runs import list_runs_hub
 from dstack._internal.cli.commands import BasicCommand
-from dstack._internal.cli.commands.run import _poll_run, _print_run_plan, _read_ssh_key_pub
+from dstack._internal.cli.commands.run import (
+    _poll_run,
+    _print_run_plan,
+    _read_ssh_key_pub,
+    _reserve_ports,
+)
 from dstack._internal.cli.common import add_project_argument, check_init, console
 from dstack._internal.cli.config import config, get_hub_client
 from dstack._internal.cli.configuration import load_configuration
+from dstack._internal.configurators.ports import PortUsedError
 from dstack._internal.core.error import RepoNotInitializedError
@@ -26,45 +32,53 @@ def _command(self, args: argparse.Namespace):
         elif configurator.profile.project:
             project_name = configurator.profile.project
 
-        hub_client = get_hub_client(project_name=project_name)
-        if (
-            hub_client.repo.repo_data.repo_type != "local"
-            and not hub_client.get_repo_credentials()
-        ):
-            raise RepoNotInitializedError("No credentials", project_name=project_name)
+        try:
+            hub_client = get_hub_client(project_name=project_name)
+            if (
+                hub_client.repo.repo_data.repo_type != "local"
+                and not hub_client.get_repo_credentials()
+            ):
+                raise RepoNotInitializedError("No credentials", project_name=project_name)
 
-        if not config.repo_user_config.ssh_key_path:
-            ssh_key_pub = None
-        else:
-            ssh_key_pub = _read_ssh_key_pub(config.repo_user_config.ssh_key_path)
+            if not config.repo_user_config.ssh_key_path:
+                ssh_key_pub = None
+            else:
+                ssh_key_pub = _read_ssh_key_pub(config.repo_user_config.ssh_key_path)
 
-        configurator_args, run_args = configurator.get_parser().parse_known_args(
-            args.args + args.unknown
-        )
-        configurator.apply_args(configurator_args)
+            configurator_args, run_args = configurator.get_parser().parse_known_args(
+                args.args + args.unknown
+            )
+            configurator.apply_args(configurator_args)
 
-        run_plan = hub_client.get_run_plan(configurator)
-        console.print("dstack will execute the following plan:\n")
-        _print_run_plan(configurator.configuration_path, run_plan)
-        if not args.yes and not Confirm.ask("Continue?"):
-            console.print("\nExiting...")
-            exit(0)
-        console.print("\nProvisioning...\n")
+            run_plan = hub_client.get_run_plan(configurator)
+            console.print("dstack will execute the following plan:\n")
+            _print_run_plan(configurator.configuration_path, run_plan)
+            if not args.yes and not Confirm.ask("Continue?"):
+                console.print("\nExiting...")
+                exit(0)
 
-        run_name, jobs = hub_client.run_configuration(
-            configurator=configurator,
-            ssh_key_pub=ssh_key_pub,
-            run_args=run_args,
-        )
-        runs = list_runs_hub(hub_client, run_name=run_name)
-        run = runs[0]
-        _poll_run(
-            hub_client,
-            run,
-            jobs,
-            ssh_key=config.repo_user_config.ssh_key_path,
-            watcher=None,
-        )
+            ports_locks = _reserve_ports(
+                configurator.app_specs(), hub_client.get_project_backend_type() == "local"
+            )
+
+            console.print("\nProvisioning...\n")
+            run_name, jobs = hub_client.run_configuration(
+                configurator=configurator,
+                ssh_key_pub=ssh_key_pub,
+                run_args=run_args,
+            )
+            runs = list_runs_hub(hub_client, run_name=run_name)
+            run = runs[0]
+            _poll_run(
+                hub_client,
+                run,
+                jobs,
+                ssh_key=config.repo_user_config.ssh_key_path,
+                watcher=None,
+                ports_locks=ports_locks,
+            )
+        except PortUsedError as e:
+            exit(f"{type(e).__name__}: {e}")
 
     def __init__(self, parser):
         super().__init__(parser)
diff --git a/cli/dstack/_internal/cli/commands/run/__init__.py b/cli/dstack/_internal/cli/commands/run/__init__.py
index fede51beb..a05a690bc 100644
--- a/cli/dstack/_internal/cli/commands/run/__init__.py
+++ b/cli/dstack/_internal/cli/commands/run/__init__.py
@@ -6,7 +6,7 @@
 import time
 from argparse import Namespace
 from pathlib import Path
-from typing import Dict, Iterator, List, Optional
+from typing import Dict, Iterator, List, Optional, Tuple
 
 import websocket
 from cursor import cursor
@@ -23,6 +23,8 @@
 from dstack._internal.cli.common import add_project_argument, check_init, console, print_runs
 from dstack._internal.cli.config import config, get_hub_client
 from dstack._internal.cli.configuration import load_configuration
+from dstack._internal.configurators.ports import PortUsedError
+from dstack._internal.core.app import AppSpec
 from dstack._internal.core.error import RepoNotInitializedError
 from dstack._internal.core.instance import InstanceType
 from dstack._internal.core.job import Job, JobErrorCode, JobHead, JobStatus
@@ -85,14 +87,6 @@ def register(self):
             type=str,
             dest="profile_name",
         )
-        self._parser.add_argument(
-            "-t",
-            "--tag",
-            metavar="TAG",
-            help="A tag name. Warning, if the tag exists, " "it will be overridden.",
-            type=str,
-            dest="tag_name",
-        )
         self._parser.add_argument(
             "args",
             metavar="ARGS",
@@ -145,8 +139,14 @@ def _command(self, args: Namespace):
         if not args.yes and not Confirm.ask("Continue?"):
             console.print("\nExiting...")
             exit(0)
 
-        console.print("\nProvisioning...\n")
+        ports_locks = None
+        if not args.detach:
+            ports_locks = _reserve_ports(
+                configurator.app_specs(), hub_client.get_project_backend_type() == "local"
+            )
+
+        console.print("\nProvisioning...\n")
         run_name, jobs = hub_client.run_configuration(
             configurator=configurator,
             ssh_key_pub=ssh_key_pub,
@@ -161,7 +161,10 @@ def _command(self, args: Namespace):
                 jobs,
                 ssh_key=config.repo_user_config.ssh_key_path,
                 watcher=watcher,
+                ports_locks=ports_locks,
             )
+        except PortUsedError as e:
+            exit(f"{type(e).__name__}: {e}")
         finally:
             if watcher.is_alive():
                 watcher.stop()
@@ -217,6 +220,7 @@ def _poll_run(
     job_heads: List[JobHead],
     ssh_key: Optional[str],
     watcher: Optional[Watcher],
+    ports_locks: Tuple[PortsLock, PortsLock],
 ):
     print_runs([run])
     console.print()
@@ -276,7 +280,7 @@ def _poll_run(
 
         # attach to the instance, attach to the container in the background
         jobs = [hub_client.get_job(job_head.job_id) for job_head in job_heads]
-        ports = _attach(hub_client, jobs[0], ssh_key)
+        ports = _attach(hub_client, jobs[0], ssh_key, ports_locks)
         console.print()
         console.print("[grey58]To stop, press Ctrl+C.[/]")
         console.print()
@@ -338,25 +342,57 @@ def _print_failed_run_message(run: RunHead):
     console.print("Provisioning failed\n")
 
 
-def _attach(hub_client: HubClient, job: Job, ssh_key_path: str) -> Dict[int, int]:
+def _reserve_ports(apps: List[AppSpec], local_backend: bool) -> Tuple[PortsLock, PortsLock]:
+    host_ports = {}
+    host_ports_lock = PortsLock()
+    app_ports = {}
+    app_ports_lock = PortsLock()
+    openssh_server_port: Optional[int] = None
+
+    for app in apps:
+        app_ports[app.port] = app.map_to_port or 0
+        if app.app_name == "openssh-server":
+            openssh_server_port = app.port
+    if local_backend and openssh_server_port is None:
+        return host_ports_lock, app_ports_lock
+
+    if not local_backend:
+        if openssh_server_port is None:
+            host_ports.update(app_ports)
+            app_ports = {}
+        host_ports_lock = PortsLock(host_ports).acquire()
+
+    if openssh_server_port is not None:
+        del app_ports[openssh_server_port]
+    if app_ports:
+        app_ports_lock = PortsLock(app_ports).acquire()
+
+    return host_ports_lock, app_ports_lock
+
+
+def _attach(
+    hub_client: HubClient, job: Job, ssh_key_path: str, ports_locks: Tuple[PortsLock, PortsLock]
+) -> Dict[int, int]:
     """
     :return: (host tunnel ports, container tunnel ports, ports mapping)
     """
     backend_type = hub_client.get_project_backend_type()
     app_ports = {}
-    openssh_server_port = 0
+    openssh_server_port: Optional[int] = None
     for app in job.app_specs or []:
         app_ports[app.port] = app.map_to_port or 0
         if app.app_name == "openssh-server":
             openssh_server_port = app.port
-    if backend_type == "local" and openssh_server_port == 0:
+    if backend_type == "local" and openssh_server_port is None:
         console.print("Provisioning... It may take up to a minute. [green]✓[/]")
         return {k: v for k, v in app_ports.items() if v != 0}
 
     console.print("Starting SSH tunnel...")
     include_ssh_config(config.ssh_config_path)
     ws_port = int(job.env["WS_LOGS_PORT"])
-    host_ports = {ws_port: ws_port}
+
+    host_ports = {}
+    host_ports_lock, app_ports_lock = ports_locks
 
     if backend_type != "local":
         ssh_config_add_host(
@@ -374,11 +410,10 @@ def _attach(hub_client: HubClient, job: Job, ssh_key_path: str) -> Dict[int, int
                 "ControlPersist": "10m",
             },
         )
-        host_ports[ws_port] = 0  # to map dynamically
-        if openssh_server_port == 0:
-            host_ports.update(app_ports)
+        if openssh_server_port is None:
             app_ports = {}
-        host_ports = PortsLock(host_ports).acquire().release()
+        host_ports = PortsLock({ws_port: ws_port}).acquire().release()
+        host_ports.update(host_ports_lock.release())
         for i in range(3):  # retry
             time.sleep(2**i)
             if run_ssh_tunnel(f"{job.run_name}-host", host_ports):
@@ -386,7 +421,7 @@ def _attach(hub_client: HubClient, job: Job, ssh_key_path: str) -> Dict[int, int
         else:
             console.print("[warning]Warning: failed to start SSH tunnel[/warning] [red]✗[/]")
 
-    if openssh_server_port != 0:
+    if openssh_server_port is not None:
         options = {
             "HostName": "localhost",
             "Port": app_ports[openssh_server_port] or openssh_server_port,
@@ -403,8 +438,7 @@ def _attach(hub_client: HubClient, job: Job, ssh_key_path: str) -> Dict[int, int
         ssh_config_add_host(config.ssh_config_path, job.run_name, options)
         del app_ports[openssh_server_port]
     if app_ports:
-        app_ports_lock = PortsLock(app_ports).acquire()
-        app_ports = app_ports_lock.dict()
+        app_ports.update(app_ports_lock.dict())
         # try to attach in the background
         threading.Thread(
             target=_attach_to_container,
diff --git a/cli/dstack/_internal/configurators/dev_environment.py b/cli/dstack/_internal/configurators/dev_environment.py
index 8f6109834..3192e9011 100644
--- a/cli/dstack/_internal/configurators/dev_environment.py
+++ b/cli/dstack/_internal/configurators/dev_environment.py
@@ -13,7 +13,7 @@
 DEFAULT_MAX_DURATION_SECONDS = 6 * 3600
 
 require_sshd = require(["sshd"])
-install_ipykernel = f'(pip install -q --no-cache-dir ipykernel 2> /dev/null) || echo "no pip, ipykernel was not installed"'
+install_ipykernel = f'(pip install --no-cache-dir ipykernel 2> /dev/null) || echo "no pip, ipykernel was not installed"'
 
 
 class DevEnvironmentConfigurator(JobConfigurator):
diff --git a/cli/dstack/_internal/configurators/ports.py b/cli/dstack/_internal/configurators/ports.py
index 467fd1f91..809b3744b 100644
--- a/cli/dstack/_internal/configurators/ports.py
+++ b/cli/dstack/_internal/configurators/ports.py
@@ -7,10 +7,6 @@
 RESERVED_PORTS_END = 10999
 
 
-class PortReservedError(DstackError):
-    pass
-
-
 class PortUsedError(DstackError):
     pass
 
diff --git a/cli/dstack/_internal/core/profile.py b/cli/dstack/_internal/core/profile.py
index 66d494e73..b429aee8f 100644
--- a/cli/dstack/_internal/core/profile.py
+++ b/cli/dstack/_internal/core/profile.py
@@ -14,9 +14,9 @@
 def parse_memory(v: Optional[Union[int, str]]) -> Optional[int]:
     """
     Converts human-readable sizes (MB and GB) to megabytes
-    >>> mem_size("512MB")
+    >>> parse_memory("512MB")
     512
-    >>> mem_size("1 GB")
+    >>> parse_memory("1 GB")
     1024
     """
     if isinstance(v, str):
@@ -88,7 +88,9 @@ class Profile(ForbidExtra):
     retry_policy: ProfileRetryPolicy = ProfileRetryPolicy()
     max_duration: Optional[Union[int, str]]
     default: bool = False
-    _validate_limit = validator("max_duration", pre=True, allow_reuse=True)(parse_max_duration)
+    _validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)(
+        parse_max_duration
+    )
 
 
 class ProfilesConfig(ForbidExtra):
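
Note on the pattern used above (not part of the patch): the diff acquires a PortsLock for the host and app ports before provisioning starts, so a busy local port surfaces as PortUsedError up front, and the lock is released only when the SSH tunnel is about to bind the same port numbers. The sketch below is purely illustrative of that acquire/dict/release flow under assumed semantics; PortsLockSketch and the socket-based locking are stand-ins, not the actual PortsLock from dstack._internal.configurators.ports.

# Illustrative sketch only -- not dstack's implementation. It mimics the
# acquire()/dict()/release() flow the diff relies on: bind the requested
# ports up front (failing fast if one is taken), keep the sockets open while
# provisioning, then release them just before the tunnel reuses the ports.
import socket
from typing import Dict, Optional


class PortUsedError(Exception):
    # local stand-in for dstack's PortUsedError
    pass


class PortsLockSketch:
    def __init__(self, ports: Optional[Dict[int, int]] = None):
        # maps container port -> requested host port (0 means "any free port")
        self._ports = dict(ports or {})
        self._sockets: Dict[int, socket.socket] = {}

    def acquire(self) -> "PortsLockSketch":
        for container_port, host_port in self._ports.items():
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                sock.bind(("", host_port))  # raises OSError if the port is busy
            except OSError:
                sock.close()
                raise PortUsedError(f"port {host_port} is already in use")
            self._sockets[container_port] = sock
        return self

    def dict(self) -> Dict[int, int]:
        # container port -> host port actually bound
        return {cp: s.getsockname()[1] for cp, s in self._sockets.items()}

    def release(self) -> Dict[int, int]:
        # free the ports and hand back the final mapping for the tunnel to use
        mapping = self.dict()
        for s in self._sockets.values():
            s.close()
        self._sockets.clear()
        return mapping


if __name__ == "__main__":
    # reserve before provisioning; 8888 gets any free port, 3000 must be free
    lock = PortsLockSketch({8888: 0, 3000: 3000}).acquire()
    print(lock.dict())     # e.g. {8888: 49512, 3000: 3000}
    print(lock.release())  # release right before the SSH tunnel binds them

Holding the bound sockets from the confirmation prompt until the tunnel starts is what prevents another process from grabbing a reserved port in between, which is the race the "-t/--tag"-era code path was exposed to when ports were only locked inside _attach.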