Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 79 additions & 69 deletions cli/dstack/_internal/cli/utils/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,34 +72,6 @@ def print_run_plan(configuration_file: str, run_plan: RunPlan):
console.print()


def reserve_ports(apps: List[AppSpec], local_backend: bool) -> Tuple[PortsLock, PortsLock]:
    """
    Build the pair of port locks (host ports, app ports) for the given apps.

    NOTE(review): this copy of the function lost its indentation in this view;
    the nesting below is reconstructed from the statements themselves - confirm
    against the repository before relying on it.

    :param apps: app specs; each one contributes ``app.port -> app.map_to_port or 0``
    :param local_backend: whether the run uses the local backend
    :return: host_ports_lock, app_ports_lock
    """
    host_ports = {}
    host_ports_lock = PortsLock()
    app_ports = {}
    app_ports_lock = PortsLock()
    openssh_server_port: Optional[int] = None

    for app in apps:
        # 0 replaces a missing/falsy map_to_port
        # (presumably "let PortsLock pick a free port" - TODO confirm)
        app_ports[app.port] = app.map_to_port or 0
        if app.app_name == "openssh-server":
            openssh_server_port = app.port
    if local_backend and openssh_server_port is None:
        # local backend without an openssh-server app: both locks are returned
        # exactly as constructed above, with no ports passed to them
        return host_ports_lock, app_ports_lock

    if not local_backend:
        if openssh_server_port is None:
            # no openssh-server app: every app port moves to the host mapping
            host_ports.update(app_ports)
            app_ports = {}
        host_ports_lock = PortsLock(host_ports).acquire()

    if openssh_server_port is not None:
        # the openssh-server port itself is left out of the app ports lock
        del app_ports[openssh_server_port]
    if app_ports:
        # the app lock is only acquired when there is something left to lock
        app_ports_lock = PortsLock(app_ports).acquire()

    return host_ports_lock, app_ports_lock


def poll_run(
hub_client: HubClient,
run: RunHead,
Expand Down Expand Up @@ -241,61 +213,61 @@ def _print_failed_run_message(run: RunHead):
console.print("Provisioning failed\n")


def reserve_ports(apps: List[AppSpec], local_backend: bool) -> Tuple[PortsLock, PortsLock]:
    """
    Pick which ports lock (host or container) has to hold the app ports.

    Every app contributes ``app.port -> app.map_to_port or 0``. Which lock gets
    acquired depends on the backend and on whether the apps include an
    ``openssh-server`` entry (see ``get_ssh_server_port``).

    :param apps: app specs of the job
    :param local_backend: whether the run uses the local backend
    :return: host_ports_lock, app_ports_lock
    """
    requested_ports = {}
    for app in apps:
        requested_ports[app.port] = app.map_to_port or 0
    container_ssh_port = get_ssh_server_port(apps)

    if container_ssh_port is None:
        if local_backend:
            # local backend without ssh in the container: all ports mapped by runner
            return PortsLock(), PortsLock()
        # cloud backend without ssh in the container: use a host ssh tunnel
        host_lock = PortsLock(requested_ports).acquire()
        return host_lock, PortsLock()

    # any backend with ssh in the container: use a container ssh tunnel,
    # so the ssh port itself is not part of the locked app ports
    del requested_ports[container_ssh_port]
    # cloud backend: ProxyJump reaches ssh in the container, no host port forwarding needed
    # local backend: the same host, no port forwarding needed
    host_lock = PortsLock()
    container_lock = PortsLock(requested_ports).acquire()
    return host_lock, container_lock


def _attach(
hub_client: HubClient, job: Job, ssh_key_path: str, ports_locks: Tuple[PortsLock, PortsLock]
) -> Dict[int, int]:
"""
:return: (host tunnel ports, container tunnel ports, ports mapping)
:return: ports_mapping
"""
backend_type = hub_client.get_project_backend_type()
app_ports = {}
openssh_server_port: Optional[int] = None
for app in job.app_specs or []:
app_ports[app.port] = app.map_to_port or 0
if app.app_name == "openssh-server":
openssh_server_port = app.port
if backend_type == "local" and openssh_server_port is None:
app_ports = {app.port: app.map_to_port or 0 for app in job.app_specs or []}
host_ports = {}
ssh_server_port = get_ssh_server_port(job.app_specs or [])

if backend_type == "local" and ssh_server_port is None:
console.print("Provisioning... It may take up to a minute. [green]✓[/]")
# local backend without ssh in container: all ports mapped by runner
return {k: v for k, v in app_ports.items() if v != 0}

console.print("Starting SSH tunnel...")
include_ssh_config(config.ssh_config_path)
ws_port = int(job.env["WS_LOGS_PORT"])

host_ports = {}
host_ports_lock, app_ports_lock = ports_locks

if backend_type != "local" and not ENABLE_LOCAL_CLOUD:
ssh_config_add_host(
config.ssh_config_path,
f"{job.run_name}-host",
{
"HostName": job.host_name,
# TODO: use non-root for all backends
"User": "ubuntu" if backend_type in ("azure", "gcp", "lambda") else "root",
"IdentityFile": ssh_key_path,
"StrictHostKeyChecking": "no",
"UserKnownHostsFile": "/dev/null",
"ControlPath": config.ssh_control_path(f"{job.run_name}-host"),
"ControlMaster": "auto",
"ControlPersist": "yes",
},
)
if openssh_server_port is None:
# cloud backend, need to forward logs websocket
console.print("Starting SSH tunnel...")
if ssh_server_port is None:
# ssh in the container: no need to forward app ports
app_ports = {}
host_ports = PortsLock({ws_port: 0}).acquire().release()
host_ports.update(host_ports_lock.release())
for i in range(3): # retry
time.sleep(2**i)
if run_ssh_tunnel(f"{job.run_name}-host", host_ports):
break
else:
console.print("[warning]Warning: failed to start SSH tunnel[/warning] [red]✗[/]")
host_ports = _run_host_ssh_tunnel(job, ssh_key_path, host_ports_lock, backend_type)

if openssh_server_port is not None:
if ssh_server_port is not None:
# ssh in the container: update ssh config, run tunnel if any apps
options = {
"HostName": "localhost",
"Port": app_ports[openssh_server_port] or openssh_server_port,
"Port": app_ports[ssh_server_port] or ssh_server_port,
"User": "root",
"IdentityFile": ssh_key_path,
"StrictHostKeyChecking": "no",
Expand All @@ -307,20 +279,51 @@ def _attach(
if backend_type != "local" and not ENABLE_LOCAL_CLOUD:
options["ProxyJump"] = f"{job.run_name}-host"
ssh_config_add_host(config.ssh_config_path, job.run_name, options)
del app_ports[openssh_server_port]
del app_ports[ssh_server_port]
if app_ports:
# save mapping, but don't release ports yet
app_ports.update(app_ports_lock.dict())
# try to attach in the background
threading.Thread(
target=_attach_to_container,
target=_run_container_ssh_tunnel,
args=(hub_client, job.run_name, app_ports_lock),
daemon=True,
).start()

return {**host_ports, **app_ports}


def _attach_to_container(hub_client: HubClient, run_name: str, ports_lock: PortsLock):
def _run_host_ssh_tunnel(
    job: Job, ssh_key_path: str, ports_lock: PortsLock, backend_type: str
) -> Dict[int, int]:
    """
    Add the ``<run_name>-host`` entry to the ssh config and start a tunnel to it.

    The tunnel is attempted up to three times, sleeping 1, 2 and 4 seconds
    before the respective attempt; if none succeeds a warning is printed and
    the function still returns normally.

    :param job: job providing ``run_name``, ``host_name`` and ``env["WS_LOGS_PORT"]``
    :param ssh_key_path: value written as ``IdentityFile`` for the host entry
    :param ports_lock: host ports lock; released here and merged into the result
    :param backend_type: backend name, used to choose the ssh user
    :return: ports mapping handed to ``run_ssh_tunnel`` (logs port plus released host ports)
    """
    host_alias = f"{job.run_name}-host"
    # TODO: use non-root for all backends
    if backend_type in ("azure", "gcp", "lambda"):
        ssh_user = "ubuntu"
    else:
        ssh_user = "root"
    host_options = {
        "HostName": job.host_name,
        "User": ssh_user,
        "IdentityFile": ssh_key_path,
        "StrictHostKeyChecking": "no",
        "UserKnownHostsFile": "/dev/null",
        "ControlPath": config.ssh_control_path(host_alias),
        "ControlMaster": "auto",
        "ControlPersist": "yes",
    }
    ssh_config_add_host(config.ssh_config_path, host_alias, host_options)

    # get free port for logs
    logs_port = int(job.env["WS_LOGS_PORT"])
    tunnel_ports = PortsLock({logs_port: 0}).acquire().release()
    tunnel_ports.update(ports_lock.release())

    tunnel_started = False
    for delay in (1, 2, 4):  # retry with growing pauses
        time.sleep(delay)
        if run_ssh_tunnel(host_alias, tunnel_ports):
            tunnel_started = True
            break
    if not tunnel_started:
        console.print("[warning]Warning: failed to start SSH tunnel[/warning] [red]✗[/]")
    return tunnel_ports


def _run_container_ssh_tunnel(hub_client: HubClient, run_name: str, ports_lock: PortsLock):
# idle BUILDING
for run in _poll_run_head(hub_client, run_name, loop_statuses=[JobStatus.BUILDING]):
pass
Expand All @@ -337,7 +340,7 @@ def _attach_to_container(hub_client: HubClient, run_name: str, ports_lock: Ports
"[red]ERROR[/] Can't establish SSH tunnel with the container\n"
"[grey58]Aborting...[/]"
)
hub_client.stop_jobs(run_name, abort=True)
hub_client.stop_jobs(run_name, terminate=True, abort=True)
exit(1)


Expand Down Expand Up @@ -421,3 +424,10 @@ def _ask_on_interrupt(hub_client: HubClient, run_name: str):
ssh_config_remove_host(config.ssh_config_path, f"{run_name}-host")
ssh_config_remove_host(config.ssh_config_path, run_name)
exit(0)


def get_ssh_server_port(apps: List[AppSpec]) -> Optional[int]:
    """
    Find the port of the ``openssh-server`` app, if the apps include one.

    :param apps: app specs to search, in order
    :return: ``port`` of the first app named ``openssh-server``, or None when there is none
    """
    ssh_ports = (spec.port for spec in apps if spec.app_name == "openssh-server")
    return next(ssh_ports, None)