diff --git a/docs/docs/concepts/fleets.md b/docs/docs/concepts/fleets.md
index 60575f9af..15c9bf78a 100644
--- a/docs/docs/concepts/fleets.md
+++ b/docs/docs/concepts/fleets.md
@@ -300,6 +300,47 @@ ssh_config:
- 3.255.177.52
```
+#### Head node
+
+If fleet nodes are behind a head node, configure [`proxy_jump`](../reference/dstack.yml/fleet.md#ssh_config-proxy_jump):
+
+
+
+ ```yaml
+ type: fleet
+ name: my-fleet
+
+ ssh_config:
+ user: ubuntu
+ identity_file: ~/.ssh/worker_node_key
+ hosts:
+ - 3.255.177.51
+ - 3.255.177.52
+ proxy_jump:
+ hostname: 3.255.177.50
+ user: ubuntu
+ identity_file: ~/.ssh/head_node_key
+ ```
+
+
+
+To be able to attach to runs, both explicitly with `dstack attach` and implicitly with `dstack apply`, you must
+either add the head node key (`~/.ssh/head_node_key`) to the SSH agent:
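+
+    ```shell
+    ssh-add ~/.ssh/head_node_key
+    ```
+
+or configure the key path in `~/.ssh/config`: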
+
+
+
+ ```
+ Host 3.255.177.50
+ IdentityFile ~/.ssh/head_node_key
+ ```
+
+
+
+Here, `Host` must match `ssh_config.proxy_jump.hostname`, or `ssh_config.hosts[n].proxy_jump.hostname` if you
+configure head nodes on a per-worker basis.
+
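+For example, to configure a head node on a per-worker basis, set `proxy_jump` for individual hosts. A minimal
+sketch reusing the addresses above:
+
+    ```yaml
+    type: fleet
+    name: my-fleet
+
+    ssh_config:
+      user: ubuntu
+      identity_file: ~/.ssh/worker_node_key
+      hosts:
+        - hostname: 3.255.177.51
+          proxy_jump:
+            hostname: 3.255.177.50
+            user: ubuntu
+            identity_file: ~/.ssh/head_node_key
+    ```
+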
+> Currently, [services](services.md) do not work on instances behind a head node.
+
!!! info "Reference"
For all SSH fleet configuration options, refer to the [reference](../reference/dstack.yml/fleet.md).
diff --git a/docs/docs/reference/dstack.yml/fleet.md b/docs/docs/reference/dstack.yml/fleet.md
index 537ddb109..7c11967f7 100644
--- a/docs/docs/reference/dstack.yml/fleet.md
+++ b/docs/docs/reference/dstack.yml/fleet.md
@@ -17,12 +17,26 @@ The `fleet` configuration type allows creating and updating fleets.
show_root_heading: false
item_id_prefix: ssh_config-
+#### `ssh_config.proxy_jump` { #ssh_config-proxy_jump data-toc-label="proxy_jump" }
+
+#SCHEMA# dstack._internal.core.models.fleets.SSHProxyParams
+ overrides:
+ show_root_heading: false
+ item_id_prefix: proxy_jump-
+
#### `ssh_config.hosts[n]` { #ssh_config-hosts data-toc-label="hosts" }
#SCHEMA# dstack._internal.core.models.fleets.SSHHostParams
overrides:
show_root_heading: false
+##### `ssh_config.hosts[n].proxy_jump` { #proxy_jump data-toc-label="hosts[n].proxy_jump" }
+
+#SCHEMA# dstack._internal.core.models.fleets.SSHProxyParams
+ overrides:
+ show_root_heading: false
+ item_id_prefix: hosts-proxy_jump-
+
### `resources`
#SCHEMA# dstack._internal.core.models.resources.ResourcesSpecSchema
diff --git a/src/dstack/_internal/cli/services/configurators/fleet.py b/src/dstack/_internal/cli/services/configurators/fleet.py
index 7eea9fb70..35699f216 100644
--- a/src/dstack/_internal/cli/services/configurators/fleet.py
+++ b/src/dstack/_internal/cli/services/configurators/fleet.py
@@ -201,13 +201,16 @@ def apply_args(self, conf: FleetConfiguration, args: argparse.Namespace, unknown
def _preprocess_spec(spec: FleetSpec):
- if spec.configuration.ssh_config is not None:
- spec.configuration.ssh_config.ssh_key = _resolve_ssh_key(
- spec.configuration.ssh_config.identity_file
- )
- for host in spec.configuration.ssh_config.hosts:
+ ssh_config = spec.configuration.ssh_config
+ if ssh_config is not None:
+ ssh_config.ssh_key = _resolve_ssh_key(ssh_config.identity_file)
+ if ssh_config.proxy_jump is not None:
+ ssh_config.proxy_jump.ssh_key = _resolve_ssh_key(ssh_config.proxy_jump.identity_file)
+ for host in ssh_config.hosts:
if not isinstance(host, str):
host.ssh_key = _resolve_ssh_key(host.identity_file)
+ if host.proxy_jump is not None:
+ host.proxy_jump.ssh_key = _resolve_ssh_key(host.proxy_jump.identity_file)
def _resolve_ssh_key(ssh_key_path: Optional[str]) -> Optional[SSHKey]:
diff --git a/src/dstack/_internal/core/backends/remote/provisioning.py b/src/dstack/_internal/core/backends/remote/provisioning.py
index 3f3769791..7a2398b7c 100644
--- a/src/dstack/_internal/core/backends/remote/provisioning.py
+++ b/src/dstack/_internal/core/backends/remote/provisioning.py
@@ -1,9 +1,9 @@
import io
import json
import time
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
from textwrap import dedent
-from typing import Any, Dict, Generator, List
+from typing import Any, Dict, Generator, List, Optional
import paramiko
from gpuhunt import AcceleratorVendor, correct_gpu_memory_gib
@@ -17,6 +17,7 @@
Gpu,
InstanceType,
Resources,
+ SSHConnectionParams,
)
from dstack._internal.utils.gpu import (
convert_amd_gpu_name,
@@ -262,35 +263,72 @@ def host_info_to_instance_type(host_info: Dict[str, Any]) -> InstanceType:
@contextmanager
def get_paramiko_connection(
- ssh_user: str, host: str, port: int, pkeys: List[paramiko.PKey]
+ ssh_user: str,
+ host: str,
+ port: int,
+ pkeys: List[paramiko.PKey],
+ proxy: Optional[SSHConnectionParams] = None,
+ proxy_pkeys: Optional[list[paramiko.PKey]] = None,
) -> Generator[paramiko.SSHClient, None, None]:
- with paramiko.SSHClient() as client:
- client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
- for pkey in pkeys:
- conn_url = f"{ssh_user}@{host}:{port}"
+ if proxy is not None:
+ if proxy_pkeys is None:
+ raise ProvisioningError("Missing proxy private keys")
+ proxy_ctx = get_paramiko_connection(
+ proxy.username, proxy.hostname, proxy.port, proxy_pkeys
+ )
+ else:
+ proxy_ctx = nullcontext()
+ conn_url = f"{ssh_user}@{host}:{port}"
+ with proxy_ctx as proxy_client, paramiko.SSHClient() as client:
+ proxy_channel: Optional[paramiko.Channel] = None
+ if proxy_client is not None:
try:
- logger.debug("Try to connect to %s with key %s", conn_url, pkey.fingerprint)
- client.connect(
- username=ssh_user,
- hostname=host,
- port=port,
- pkey=pkey,
- look_for_keys=False,
- allow_agent=False,
- timeout=SSH_CONNECT_TIMEOUT,
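+                # Open a "direct-tcpip" channel through the proxy's transport;
+                # it is passed as `sock` to client.connect() below, so the SSH
+                # session to the target host is tunneled via the head node.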
+ proxy_channel = proxy_client.get_transport().open_channel(
+ "direct-tcpip", (host, port), ("", 0)
)
- except paramiko.AuthenticationException:
- logger.debug(
- f'Authentication failed to connect to "{conn_url}" and {pkey.fingerprint}'
- )
- continue # try next key
except (paramiko.SSHException, OSError) as e:
- raise ProvisioningError(f"Connect failed: {e}") from e
- else:
+ raise ProvisioningError(f"Proxy channel failed: {e}") from e
+ client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ for pkey in pkeys:
+ logger.debug("Try to connect to %s with key %s", conn_url, pkey.fingerprint)
+ connected = _paramiko_connect(client, ssh_user, host, port, pkey, proxy_channel)
+ if connected:
yield client
return
- else:
- keys_fp = ", ".join(f"{pk.fingerprint!r}" for pk in pkeys)
- raise ProvisioningError(
- f"SSH connection to the {conn_url} with keys [{keys_fp}] was unsuccessful"
+ logger.debug(
+                f'Authentication failed to connect to "{conn_url}" with key {pkey.fingerprint}'
)
+ keys_fp = ", ".join(f"{pk.fingerprint!r}" for pk in pkeys)
+ raise ProvisioningError(
+ f"SSH connection to the {conn_url} with keys [{keys_fp}] was unsuccessful"
+ )
+
+
+def _paramiko_connect(
+ client: paramiko.SSHClient,
+ user: str,
+ host: str,
+ port: int,
+ pkey: paramiko.PKey,
+ channel: Optional[paramiko.Channel] = None,
+) -> bool:
+ """
+ Returns `True` if connected, `False` if auth failed, and raises `ProvisioningError`
+ on other errors.
+ """
+ try:
+ client.connect(
+ username=user,
+ hostname=host,
+ port=port,
+ pkey=pkey,
+ look_for_keys=False,
+ allow_agent=False,
+ timeout=SSH_CONNECT_TIMEOUT,
+ sock=channel,
+ )
+ return True
+ except paramiko.AuthenticationException:
+ return False
+ except (paramiko.SSHException, OSError) as e:
+ raise ProvisioningError(f"Connect failed: {e}") from e
diff --git a/src/dstack/_internal/core/models/fleets.py b/src/dstack/_internal/core/models/fleets.py
index abc0766c4..c56720b9a 100644
--- a/src/dstack/_internal/core/models/fleets.py
+++ b/src/dstack/_internal/core/models/fleets.py
@@ -39,6 +39,14 @@ class InstanceGroupPlacement(str, Enum):
CLUSTER = "cluster"
+class SSHProxyParams(CoreModel):
+    hostname: Annotated[str, Field(description="The IP address or domain of the proxy host")]
+    port: Annotated[Optional[int], Field(description="The SSH port of the proxy host")] = None
+    user: Annotated[str, Field(description="The user to log in with on the proxy host")]
+    identity_file: Annotated[str, Field(description="The private key to use for the proxy host")]
+ ssh_key: Optional[SSHKey] = None
+
+
class SSHHostParams(CoreModel):
hostname: Annotated[str, Field(description="The IP address or domain to connect to")]
port: Annotated[
@@ -50,6 +58,9 @@ class SSHHostParams(CoreModel):
identity_file: Annotated[
Optional[str], Field(description="The private key to use for this host")
] = None
+ proxy_jump: Annotated[
+ Optional[SSHProxyParams], Field(description="The SSH proxy configuration for this host")
+ ] = None
internal_ip: Annotated[
Optional[str],
Field(
@@ -96,6 +107,9 @@ class SSHParams(CoreModel):
Optional[str], Field(description="The private key to use for all hosts")
] = None
ssh_key: Optional[SSHKey] = None
+ proxy_jump: Annotated[
+ Optional[SSHProxyParams], Field(description="The SSH proxy configuration for all hosts")
+ ] = None
hosts: Annotated[
List[Union[SSHHostParams, str]],
Field(
diff --git a/src/dstack/_internal/core/models/instances.py b/src/dstack/_internal/core/models/instances.py
index 9b5bce0e5..84ae92a9e 100644
--- a/src/dstack/_internal/core/models/instances.py
+++ b/src/dstack/_internal/core/models/instances.py
@@ -92,6 +92,8 @@ class RemoteConnectionInfo(CoreModel):
port: int
ssh_user: str
ssh_keys: List[SSHKey]
+ ssh_proxy: Optional[SSHConnectionParams] = None
+ ssh_proxy_keys: Optional[list[SSHKey]] = None
env: Env = Env()
diff --git a/src/dstack/_internal/core/services/ssh/attach.py b/src/dstack/_internal/core/services/ssh/attach.py
index 788d369b2..a095fafb2 100644
--- a/src/dstack/_internal/core/services/ssh/attach.py
+++ b/src/dstack/_internal/core/services/ssh/attach.py
@@ -2,7 +2,7 @@
import re
import time
from pathlib import Path
-from typing import Optional
+from typing import Optional, Union
import psutil
@@ -14,6 +14,8 @@
from dstack._internal.core.services.ssh.tunnel import SSHTunnel, ports_to_forwarded_sockets
from dstack._internal.utils.path import FilePath, PathLike
from dstack._internal.utils.ssh import (
+ default_ssh_config_path,
+ get_host_config,
include_ssh_config,
normalize_path,
update_ssh_config,
@@ -88,28 +90,63 @@ def __init__(
},
)
self.ssh_proxy = ssh_proxy
- if ssh_proxy is None:
- self.host_config = {
+
+ hosts: dict[str, dict[str, Union[str, int, FilePath]]] = {}
+ self.hosts = hosts
+
+ if local_backend:
+ hosts[run_name] = {
"HostName": hostname,
- "Port": ssh_port,
- "User": user if dockerized else container_user,
- "IdentityFile": self.identity_file,
- "IdentitiesOnly": "yes",
- "StrictHostKeyChecking": "no",
- "UserKnownHostsFile": "/dev/null",
- }
- else:
- self.host_config = {
- "HostName": ssh_proxy.hostname,
- "Port": ssh_proxy.port,
- "User": ssh_proxy.username,
+ "Port": container_ssh_port,
+ "User": container_user,
"IdentityFile": self.identity_file,
"IdentitiesOnly": "yes",
"StrictHostKeyChecking": "no",
"UserKnownHostsFile": "/dev/null",
}
- if dockerized and not local_backend:
- self.container_config = {
+ elif dockerized:
+ if ssh_proxy is not None:
+ # SSH instance with jump host
+            # dstack has no IdentityFile for the jump host; it must be either
+            # preconfigured in ~/.ssh/config or loaded into the ssh-agent
+ hosts[f"{run_name}-jump-host"] = {
+ "HostName": ssh_proxy.hostname,
+ "Port": ssh_proxy.port,
+ "User": ssh_proxy.username,
+ "StrictHostKeyChecking": "no",
+ "UserKnownHostsFile": "/dev/null",
+ }
+ jump_host_config = get_host_config(ssh_proxy.hostname, default_ssh_config_path)
+ jump_host_identity_files = jump_host_config.get("identityfile")
+ if jump_host_identity_files:
+ hosts[f"{run_name}-jump-host"].update(
+ {
+ "IdentityFile": jump_host_identity_files[0],
+ "IdentitiesOnly": "yes",
+ }
+ )
+ hosts[f"{run_name}-host"] = {
+ "HostName": hostname,
+ "Port": ssh_port,
+ "User": user,
+ "IdentityFile": self.identity_file,
+ "IdentitiesOnly": "yes",
+ "StrictHostKeyChecking": "no",
+ "UserKnownHostsFile": "/dev/null",
+ "ProxyJump": f"{run_name}-jump-host",
+ }
+ else:
+ # Regular SSH instance or VM-based cloud instance
+ hosts[f"{run_name}-host"] = {
+ "HostName": hostname,
+ "Port": ssh_port,
+ "User": user,
+ "IdentityFile": self.identity_file,
+ "IdentitiesOnly": "yes",
+ "StrictHostKeyChecking": "no",
+ "UserKnownHostsFile": "/dev/null",
+ }
+ hosts[run_name] = {
"HostName": "localhost",
"Port": container_ssh_port,
"User": container_user,
@@ -119,32 +156,41 @@ def __init__(
"UserKnownHostsFile": "/dev/null",
"ProxyJump": f"{run_name}-host",
}
- elif ssh_proxy is not None:
- self.container_config = {
- "HostName": hostname,
- "Port": ssh_port,
- "User": container_user,
- "IdentityFile": self.identity_file,
- "IdentitiesOnly": "yes",
- "StrictHostKeyChecking": "no",
- "UserKnownHostsFile": "/dev/null",
- "ProxyJump": f"{run_name}-jump-host",
- }
else:
- self.container_config = None
- if local_backend:
- self.container_config = None
- self.host_config = {
- "HostName": hostname,
- "Port": container_ssh_port,
- "User": container_user,
- "IdentityFile": self.identity_file,
- "IdentitiesOnly": "yes",
- "StrictHostKeyChecking": "no",
- "UserKnownHostsFile": "/dev/null",
- }
- if self.container_config is not None and get_ssh_client_info().supports_multiplexing:
- self.container_config.update(
+ if ssh_proxy is not None:
+ # Kubernetes
+ hosts[f"{run_name}-jump-host"] = {
+ "HostName": ssh_proxy.hostname,
+ "Port": ssh_proxy.port,
+ "User": ssh_proxy.username,
+ "IdentityFile": self.identity_file,
+ "IdentitiesOnly": "yes",
+ "StrictHostKeyChecking": "no",
+ "UserKnownHostsFile": "/dev/null",
+ }
+ hosts[run_name] = {
+ "HostName": hostname,
+ "Port": ssh_port,
+ "User": container_user,
+ "IdentityFile": self.identity_file,
+ "IdentitiesOnly": "yes",
+ "StrictHostKeyChecking": "no",
+ "UserKnownHostsFile": "/dev/null",
+ "ProxyJump": f"{run_name}-jump-host",
+ }
+ else:
+ # Container-based backends
+ hosts[run_name] = {
+ "HostName": hostname,
+ "Port": ssh_port,
+ "User": container_user,
+ "IdentityFile": self.identity_file,
+ "IdentitiesOnly": "yes",
+ "StrictHostKeyChecking": "no",
+ "UserKnownHostsFile": "/dev/null",
+ }
+ if get_ssh_client_info().supports_multiplexing:
+ hosts[run_name].update(
{
"ControlMaster": "auto",
"ControlPath": self.control_sock_path,
@@ -153,14 +199,8 @@ def __init__(
def attach(self):
include_ssh_config(self.ssh_config_path)
- if self.container_config is None:
- update_ssh_config(self.ssh_config_path, self.run_name, self.host_config)
- elif self.ssh_proxy is not None:
- update_ssh_config(self.ssh_config_path, f"{self.run_name}-jump-host", self.host_config)
- update_ssh_config(self.ssh_config_path, self.run_name, self.container_config)
- else:
- update_ssh_config(self.ssh_config_path, f"{self.run_name}-host", self.host_config)
- update_ssh_config(self.ssh_config_path, self.run_name, self.container_config)
+ for host, options in self.hosts.items():
+ update_ssh_config(self.ssh_config_path, host, options)
max_retries = 10
self._ports_lock.release()
@@ -178,9 +218,8 @@ def attach(self):
def detach(self):
self.tunnel.close()
- update_ssh_config(self.ssh_config_path, f"{self.run_name}-jump-host", {})
- update_ssh_config(self.ssh_config_path, f"{self.run_name}-host", {})
- update_ssh_config(self.ssh_config_path, self.run_name, {})
+ for host in self.hosts:
+ update_ssh_config(self.ssh_config_path, host, {})
def __enter__(self):
self.attach()
diff --git a/src/dstack/_internal/core/services/ssh/tunnel.py b/src/dstack/_internal/core/services/ssh/tunnel.py
index 98789a618..03c9894b4 100644
--- a/src/dstack/_internal/core/services/ssh/tunnel.py
+++ b/src/dstack/_internal/core/services/ssh/tunnel.py
@@ -70,6 +70,7 @@ def __init__(
ssh_config_path: Union[PathLike, Literal["none"]] = "none",
port: Optional[int] = None,
ssh_proxy: Optional[SSHConnectionParams] = None,
+ ssh_proxy_identity: Optional[FilePathOrContent] = None,
):
"""
:param forwarded_sockets: Connections to the specified local sockets will be
@@ -97,7 +98,15 @@ def __init__(
identity_path, opener=lambda path, flags: os.open(path, flags, 0o600), mode="w"
) as f:
f.write(identity.content)
- self.identity_path = normalize_path(identity_path)
+ self.identity_path = normalize_path(self._get_identity_path(identity, "identity"))
+ if ssh_proxy_identity is not None:
+ self.ssh_proxy_identity_path = normalize_path(
+ self._get_identity_path(ssh_proxy_identity, "proxy_identity")
+ )
+ elif ssh_proxy is not None:
+ self.ssh_proxy_identity_path = self.identity_path
+ else:
+ self.ssh_proxy_identity_path = None
self.log_path = normalize_path(os.path.join(temp_dir.name, "tunnel.log"))
self.ssh_client_info = get_ssh_client_info()
self.ssh_exec_path = str(self.ssh_client_info.path)
@@ -166,7 +175,7 @@ def proxy_command(self) -> Optional[List[str]]:
return [
self.ssh_exec_path,
"-i",
- self.identity_path,
+ self.ssh_proxy_identity_path,
"-W",
"%h:%p",
"-o",
@@ -263,6 +272,16 @@ def _remove_log_file(self) -> None:
except OSError as e:
logger.debug("Failed to remove SSH tunnel log file %s: %s", self.log_path, e)
+ def _get_identity_path(self, identity: FilePathOrContent, tmp_filename: str) -> PathLike:
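+        # A FilePath is used as-is; FileContent is materialized under temp_dir
+        # as `tmp_filename` with 0o600 permissions so the SSH client accepts it.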
+ if isinstance(identity, FilePath):
+ return identity.path
+ identity_path = os.path.join(self.temp_dir.name, tmp_filename)
+ with open(
+ identity_path, opener=lambda path, flags: os.open(path, flags, 0o600), mode="w"
+ ) as f:
+ f.write(identity.content)
+ return identity_path
+
def ports_to_forwarded_sockets(
ports: Dict[int, int], bind_local: str = "localhost"
diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py
index 63fe6f312..fa798bdc4 100644
--- a/src/dstack/_internal/server/background/tasks/process_instances.py
+++ b/src/dstack/_internal/server/background/tasks/process_instances.py
@@ -47,6 +47,7 @@
InstanceStatus,
InstanceType,
RemoteConnectionInfo,
+ SSHKey,
)
from dstack._internal.core.models.placement import (
PlacementGroup,
@@ -86,6 +87,7 @@
get_instance_profile,
get_instance_provisioning_data,
get_instance_requirements,
+ get_instance_ssh_private_keys,
)
from dstack._internal.server.services.runner import client as runner_client
from dstack._internal.server.services.runner.client import HealthStatus
@@ -232,11 +234,11 @@ async def _add_remote(instance: InstanceModel) -> None:
remote_details = RemoteConnectionInfo.parse_raw(cast(str, instance.remote_connection_info))
# Prepare connection key
try:
- pkeys = [
- pkey_from_str(sk.private)
- for sk in remote_details.ssh_keys
- if sk.private is not None
- ]
+ pkeys = _ssh_keys_to_pkeys(remote_details.ssh_keys)
+ if remote_details.ssh_proxy_keys is not None:
+ ssh_proxy_pkeys = _ssh_keys_to_pkeys(remote_details.ssh_proxy_keys)
+ else:
+ ssh_proxy_pkeys = None
except (ValueError, PasswordRequiredException):
instance.status = InstanceStatus.TERMINATED
instance.termination_reason = "Unsupported private SSH key type"
@@ -254,7 +256,9 @@ async def _add_remote(instance: InstanceModel) -> None:
authorized_keys.append(instance.project.ssh_public_key.strip())
try:
- future = run_async(_deploy_instance, remote_details, pkeys, authorized_keys)
+ future = run_async(
+ _deploy_instance, remote_details, pkeys, ssh_proxy_pkeys, authorized_keys
+ )
deploy_timeout = 20 * 60 # 20 minutes
result = await asyncio.wait_for(future, timeout=deploy_timeout)
health, host_info = result
@@ -356,7 +360,7 @@ async def _add_remote(instance: InstanceModel) -> None:
ssh_port=remote_details.port,
dockerized=True,
backend_data=None,
- ssh_proxy=None,
+ ssh_proxy=remote_details.ssh_proxy,
)
instance.status = InstanceStatus.IDLE if health else InstanceStatus.PROVISIONING
@@ -379,10 +383,16 @@ async def _add_remote(instance: InstanceModel) -> None:
def _deploy_instance(
remote_details: RemoteConnectionInfo,
pkeys: List[PKey],
+ ssh_proxy_pkeys: Optional[list[PKey]],
authorized_keys: List[str],
) -> Tuple[HealthStatus, Dict[str, Any]]:
with get_paramiko_connection(
- remote_details.ssh_user, remote_details.host, remote_details.port, pkeys
+ remote_details.ssh_user,
+ remote_details.host,
+ remote_details.port,
+ pkeys,
+ remote_details.ssh_proxy,
+ ssh_proxy_pkeys,
) as client:
logger.info(f"Connected to {remote_details.ssh_user} {remote_details.host}")
@@ -638,18 +648,14 @@ async def _check_instance(instance: InstanceModel) -> None:
instance.status = InstanceStatus.BUSY
return
- ssh_private_key = instance.project.ssh_private_key
- # TODO: Drop this logic and always use project key once it's safe to assume that most on-prem
- # fleets are (re)created after this change: https://github.com/dstackai/dstack/pull/1716
- if instance.remote_connection_info is not None:
- remote_conn_info: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw(
- instance.remote_connection_info
- )
- ssh_private_key = remote_conn_info.ssh_keys[0].private
+ ssh_private_keys = get_instance_ssh_private_keys(instance)
# May return False if fails to establish ssh connection
health_status_response = await run_async(
- _instance_healthcheck, ssh_private_key, job_provisioning_data, None
+ _instance_healthcheck,
+ ssh_private_keys,
+ job_provisioning_data,
+ None,
)
if isinstance(health_status_response, bool) or health_status_response is None:
health_status = HealthStatus(healthy=False, reason="SSH or tunnel error")
@@ -971,3 +977,7 @@ def _get_instance_timeout_interval(
if backend_type == BackendType.VULTR and instance_type_name.startswith("vbm"):
return timedelta(seconds=3300)
return timedelta(seconds=600)
+
+
+def _ssh_keys_to_pkeys(ssh_keys: list[SSHKey]) -> list[PKey]:
+ return [pkey_from_str(sk.private) for sk in ssh_keys if sk.private is not None]
diff --git a/src/dstack/_internal/server/background/tasks/process_metrics.py b/src/dstack/_internal/server/background/tasks/process_metrics.py
index 3ebe4f6e6..49ac45fc1 100644
--- a/src/dstack/_internal/server/background/tasks/process_metrics.py
+++ b/src/dstack/_internal/server/background/tasks/process_metrics.py
@@ -3,18 +3,19 @@
from typing import Dict, List, Optional
from sqlalchemy import delete, select
-from sqlalchemy.orm import selectinload
+from sqlalchemy.orm import joinedload
from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT
from dstack._internal.core.models.runs import JobStatus
from dstack._internal.server import settings
from dstack._internal.server.db import get_session_ctx
-from dstack._internal.server.models import JobMetricsPoint, JobModel
+from dstack._internal.server.models import InstanceModel, JobMetricsPoint, JobModel
from dstack._internal.server.schemas.runner import MetricsResponse
from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_runtime_data
+from dstack._internal.server.services.pools import get_instance_ssh_private_keys
from dstack._internal.server.services.runner import client
from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
-from dstack._internal.utils.common import batched, get_current_datetime, run_async
+from dstack._internal.utils.common import batched, get_current_datetime, get_or_error, run_async
from dstack._internal.utils.logging import get_logger
logger = get_logger(__name__)
@@ -29,14 +30,12 @@ async def collect_metrics():
async with get_session_ctx() as session:
res = await session.execute(
select(JobModel)
- .where(
- JobModel.status.in_([JobStatus.RUNNING]),
- )
- .options(selectinload(JobModel.project))
+ .where(JobModel.status.in_([JobStatus.RUNNING]))
+ .options(joinedload(JobModel.instance).joinedload(InstanceModel.project))
.order_by(JobModel.last_processed_at.asc())
.limit(MAX_JOBS_FETCHED)
)
- job_models = res.scalars().all()
+ job_models = res.unique().scalars().all()
for batch in batched(job_models, BATCH_SIZE):
await _collect_jobs_metrics(batch)
@@ -87,6 +86,7 @@ def _get_recently_collected_metric_cutoff() -> int:
async def _collect_job_metrics(job_model: JobModel) -> Optional[JobMetricsPoint]:
+ ssh_private_keys = get_instance_ssh_private_keys(get_or_error(job_model.instance))
jpd = get_job_provisioning_data(job_model)
jrd = get_job_runtime_data(job_model)
if jpd is None:
@@ -94,7 +94,7 @@ async def _collect_job_metrics(job_model: JobModel) -> Optional[JobMetricsPoint]
try:
res = await run_async(
_pull_runner_metrics,
- job_model.project.ssh_private_key,
+ ssh_private_keys,
jpd,
jrd,
)
diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py
index a91f82a98..9a303f9cf 100644
--- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py
+++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py
@@ -26,6 +26,7 @@
from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint
from dstack._internal.server.db import get_session_ctx
from dstack._internal.server.models import (
+ InstanceModel,
JobModel,
ProjectModel,
RepoModel,
@@ -42,6 +43,7 @@
)
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.logging import fmt
+from dstack._internal.server.services.pools import get_instance_ssh_private_keys
from dstack._internal.server.services.repos import (
get_code_model,
get_repo_creds,
@@ -101,7 +103,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
res = await session.execute(
select(JobModel)
.where(JobModel.id == job_model.id)
- .options(joinedload(JobModel.instance))
+ .options(joinedload(JobModel.instance).joinedload(InstanceModel.project))
.execution_options(populate_existing=True)
)
job_model = res.unique().scalar_one()
@@ -152,18 +154,9 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
job_provisioning_data=job_provisioning_data,
)
- server_ssh_private_key = project.ssh_private_key
- # TODO: Drop this logic and always use project key once it's safe to assume that most on-prem
- # fleets are (re)created after this change: https://github.com/dstackai/dstack/pull/1716
- if (
- job_model.instance is not None
- and job_model.instance.remote_connection_info is not None
- and job_provisioning_data.dockerized
- ):
- remote_conn_info: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw(
- job_model.instance.remote_connection_info
- )
- server_ssh_private_key = remote_conn_info.ssh_keys[0].private
+ server_ssh_private_keys = get_instance_ssh_private_keys(
+ common_utils.get_or_error(job_model.instance)
+ )
secrets = {} # TODO secrets
@@ -203,7 +196,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
user_ssh_key = ""
success = await common_utils.run_async(
_process_provisioning_with_shim,
- server_ssh_private_key,
+ server_ssh_private_keys,
job_provisioning_data,
None,
run,
@@ -229,7 +222,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
)
success = await common_utils.run_async(
_submit_job_to_runner,
- server_ssh_private_key,
+ server_ssh_private_keys,
job_provisioning_data,
None,
run,
@@ -272,7 +265,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
)
success = await common_utils.run_async(
_process_pulling_with_shim,
- server_ssh_private_key,
+ server_ssh_private_keys,
job_provisioning_data,
None,
run,
@@ -282,14 +275,14 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
code,
secrets,
repo_creds,
- server_ssh_private_key,
+ server_ssh_private_keys,
job_provisioning_data,
)
elif initial_status == JobStatus.RUNNING:
logger.debug("%s: process running job, age=%s", fmt(job_model), job_submission.age)
success = await common_utils.run_async(
_process_running,
- server_ssh_private_key,
+ server_ssh_private_keys,
job_provisioning_data,
job_submission.job_runtime_data,
run_model,
@@ -493,7 +486,7 @@ def _process_pulling_with_shim(
code: bytes,
secrets: Dict[str, str],
repo_credentials: Optional[RemoteRepoCreds],
- server_ssh_private_key: str,
+ server_ssh_private_keys: tuple[str, Optional[str]],
job_provisioning_data: JobProvisioningData,
) -> bool:
"""
@@ -558,7 +551,7 @@ def _process_pulling_with_shim(
return True
return _submit_job_to_runner(
- server_ssh_private_key,
+ server_ssh_private_keys,
job_provisioning_data,
job_runtime_data,
run=run,
diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py
index 71afeb700..fc06ad49a 100644
--- a/src/dstack/_internal/server/services/fleets.py
+++ b/src/dstack/_internal/server/services/fleets.py
@@ -31,6 +31,7 @@
InstanceOfferWithAvailability,
InstanceStatus,
RemoteConnectionInfo,
+ SSHConnectionParams,
SSHKey,
)
from dstack._internal.core.models.pools import Instance
@@ -428,6 +429,7 @@ async def create_fleet_ssh_instance_model(
ssh_user = ssh_params.user
ssh_key = ssh_params.ssh_key
port = ssh_params.port
+ proxy_jump = ssh_params.proxy_jump
internal_ip = None
blocks = 1
else:
@@ -435,6 +437,7 @@ async def create_fleet_ssh_instance_model(
ssh_user = host.user or ssh_params.user
ssh_key = host.ssh_key or ssh_params.ssh_key
port = host.port or ssh_params.port
+ proxy_jump = host.proxy_jump or ssh_params.proxy_jump
internal_ip = host.internal_ip
blocks = host.blocks
@@ -442,6 +445,17 @@ async def create_fleet_ssh_instance_model(
# This should not be reachable but checked by fleet spec validation
raise ServerClientError("ssh key or user not specified")
+ if proxy_jump is not None:
+ ssh_proxy = SSHConnectionParams(
+ hostname=proxy_jump.hostname,
+ port=proxy_jump.port or 22,
+ username=proxy_jump.user,
+ )
+ ssh_proxy_keys = [proxy_jump.ssh_key]
+ else:
+ ssh_proxy = None
+ ssh_proxy_keys = None
+
instance_model = await pools_services.create_ssh_instance_model(
project=project,
pool=pool,
@@ -451,6 +465,8 @@ async def create_fleet_ssh_instance_model(
host=hostname,
ssh_user=ssh_user,
ssh_keys=[ssh_key],
+ ssh_proxy=ssh_proxy,
+ ssh_proxy_keys=ssh_proxy_keys,
env=env,
internal_ip=internal_ip,
instance_network=ssh_params.network,
diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py
index 9db598201..b2f18ba04 100644
--- a/src/dstack/_internal/server/services/jobs/__init__.py
+++ b/src/dstack/_internal/server/services/jobs/__init__.py
@@ -7,6 +7,7 @@
import requests
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import joinedload
import dstack._internal.server.services.backends as backends_services
from dstack._internal.core.backends.base import Backend
@@ -20,7 +21,7 @@
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import is_core_model_instance
from dstack._internal.core.models.configurations import RunConfigurationType
-from dstack._internal.core.models.instances import InstanceStatus, RemoteConnectionInfo
+from dstack._internal.core.models.instances import InstanceStatus
from dstack._internal.core.models.runs import (
Job,
JobProvisioningData,
@@ -49,6 +50,7 @@
from dstack._internal.server.services.jobs.configurators.service import ServiceJobConfigurator
from dstack._internal.server.services.jobs.configurators.task import TaskJobConfigurator
from dstack._internal.server.services.logging import fmt
+from dstack._internal.server.services.pools import get_instance_ssh_private_keys
from dstack._internal.server.services.runner import client
from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
from dstack._internal.server.services.volumes import (
@@ -170,29 +172,22 @@ def _get_job_configurator(run_spec: RunSpec) -> JobConfigurator:
async def stop_runner(session: AsyncSession, job_model: JobModel):
- project = await session.get(ProjectModel, job_model.project_id)
- ssh_private_key = project.ssh_private_key
-
res = await session.execute(
- select(InstanceModel).where(
+ select(InstanceModel)
+ .where(
InstanceModel.project_id == job_model.project_id,
InstanceModel.id == job_model.instance_id,
)
+ .options(joinedload(InstanceModel.project))
)
instance: Optional[InstanceModel] = res.scalar()
- # TODO: Drop this logic and always use project key once it's safe to assume that most on-prem
- # fleets are (re)created after this change: https://github.com/dstackai/dstack/pull/1716
- if instance and instance.remote_connection_info is not None:
- remote_conn_info: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw(
- instance.remote_connection_info
- )
- ssh_private_key = remote_conn_info.ssh_keys[0].private
+ ssh_private_keys = get_instance_ssh_private_keys(common.get_or_error(instance))
try:
jpd = get_job_provisioning_data(job_model)
if jpd is not None:
jrd = get_job_runtime_data(job_model)
- await run_async(_stop_runner, ssh_private_key, jpd, jrd, job_model)
+ await run_async(_stop_runner, ssh_private_keys, jpd, jrd, job_model)
except SSHError:
logger.debug("%s: failed to stop runner", fmt(job_model))
@@ -239,16 +234,8 @@ async def process_terminating_job(
jpd = get_job_provisioning_data(job_model)
if jpd is not None:
logger.debug("%s: stopping container", fmt(job_model))
- ssh_private_key = instance_model.project.ssh_private_key
- # TODO: Drop this logic and always use project key once it's safe to assume that
- # most on-prem fleets are (re)created after this change:
- # https://github.com/dstackai/dstack/pull/1716
- if instance_model and instance_model.remote_connection_info is not None:
- remote_conn_info: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw(
- instance_model.remote_connection_info
- )
- ssh_private_key = remote_conn_info.ssh_keys[0].private
- await stop_container(job_model, jpd, ssh_private_key)
+ ssh_private_keys = get_instance_ssh_private_keys(instance_model)
+ await stop_container(job_model, jpd, ssh_private_keys)
volume_models: list[VolumeModel]
if jrd is not None and jrd.volume_names is not None:
volume_models = await list_project_volume_models(
@@ -351,14 +338,16 @@ def _set_job_termination_status(job_model: JobModel):
async def stop_container(
- job_model: JobModel, job_provisioning_data: JobProvisioningData, ssh_private_key: str
+ job_model: JobModel,
+ job_provisioning_data: JobProvisioningData,
+ ssh_private_keys: tuple[str, Optional[str]],
):
if job_provisioning_data.dockerized:
# send a request to the shim to terminate the docker container
# SSHError and RequestException are caught in the `runner_ssh_tunner` decorator
await run_async(
_shim_submit_stop,
- ssh_private_key,
+ ssh_private_keys,
job_provisioning_data,
None,
job_model,
diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py
index 220dae558..d21560188 100644
--- a/src/dstack/_internal/server/services/pools.py
+++ b/src/dstack/_internal/server/services/pools.py
@@ -30,6 +30,7 @@
InstanceType,
RemoteConnectionInfo,
Resources,
+ SSHConnectionParams,
SSHKey,
)
from dstack._internal.core.models.pools import Instance, Pool, PoolInstances
@@ -293,6 +294,27 @@ def get_instance_requirements(instance_model: InstanceModel) -> Requirements:
return Requirements.__response__.parse_raw(instance_model.requirements)
+def get_instance_ssh_private_keys(instance_model: InstanceModel) -> tuple[str, Optional[str]]:
+ """
+ Returns a pair of SSH private keys: host key and optional proxy jump key.
+ """
+ host_private_key = instance_model.project.ssh_private_key
+ if instance_model.remote_connection_info is None:
+ # Cloud instance
+ return host_private_key, None
+ # SSH instance
+ rci = RemoteConnectionInfo.__response__.parse_raw(instance_model.remote_connection_info)
+ if rci.ssh_proxy is None:
+ return host_private_key, None
+ if rci.ssh_proxy_keys is None:
+        # Inconsistent RemoteConnectionInfo: ssh_proxy is set but ssh_proxy_keys is missing
+ raise ValueError("Missing instance SSH proxy private keys")
+ proxy_private_keys = [key.private for key in rci.ssh_proxy_keys if key.private is not None]
+ if not proxy_private_keys:
+ raise ValueError("No instance SSH proxy private key found")
+ return host_private_key, proxy_private_keys[0]
+
+
async def generate_instance_name(
session: AsyncSession,
project: ProjectModel,
@@ -737,6 +759,8 @@ async def create_ssh_instance_model(
port: int,
ssh_user: str,
ssh_keys: List[SSHKey],
+ ssh_proxy: Optional[SSHConnectionParams],
+ ssh_proxy_keys: Optional[list[SSHKey]],
env: Env,
blocks: Union[Literal["auto"], int],
) -> InstanceModel:
@@ -773,6 +797,8 @@ async def create_ssh_instance_model(
port=port,
ssh_user=ssh_user,
ssh_keys=ssh_keys,
+ ssh_proxy=ssh_proxy,
+ ssh_proxy_keys=ssh_proxy_keys,
env=env,
)
im = InstanceModel(
diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py
index b66f81fd1..6d3c11359 100644
--- a/src/dstack/_internal/server/services/runner/ssh.py
+++ b/src/dstack/_internal/server/services/runner/ssh.py
@@ -17,13 +17,18 @@
logger = get_logger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
+# A host private key or a pair of (host private key, optional proxy jump private key)
+PrivateKeyOrPair = Union[str, tuple[str, Optional[str]]]
def runner_ssh_tunnel(
ports: List[int], retries: int = 3, retry_interval: float = 1
) -> Callable[
[Callable[Concatenate[Dict[int, int], P], R]],
- Callable[Concatenate[str, JobProvisioningData, Optional[JobRuntimeData], P], Union[bool, R]],
+ Callable[
+ Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P],
+ Union[bool, R],
+ ],
]:
"""
A decorator that opens an SSH tunnel to the runner.
@@ -36,11 +41,12 @@ def runner_ssh_tunnel(
def decorator(
func: Callable[Concatenate[Dict[int, int], P], R],
) -> Callable[
- Concatenate[str, JobProvisioningData, Optional[JobRuntimeData], P], Union[bool, R]
+ Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P],
+ Union[bool, R],
]:
@functools.wraps(func)
def wrapper(
- ssh_private_key: str,
+ ssh_private_key: PrivateKeyOrPair,
job_provisioning_data: JobProvisioningData,
job_runtime_data: Optional[JobRuntimeData],
*args: P.args,
@@ -59,6 +65,16 @@ def wrapper(
# without SSH
return func(container_ports_map, *args, **kwargs)
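+            # A bare string is a host key only; a tuple also carries the
+            # optional proxy jump key (see PrivateKeyOrPair above).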
+ if isinstance(ssh_private_key, str):
+ ssh_proxy_private_key = None
+ else:
+ ssh_private_key, ssh_proxy_private_key = ssh_private_key
+ identity = FileContent(ssh_private_key)
+ if ssh_proxy_private_key is not None:
+ proxy_identity = FileContent(ssh_proxy_private_key)
+ else:
+ proxy_identity = None
+
for attempt in range(retries):
last = attempt == retries - 1
# remote_host:local mapping
@@ -74,8 +90,9 @@ def wrapper(
),
port=job_provisioning_data.ssh_port,
forwarded_sockets=ports_to_forwarded_sockets(tunnel_ports_map),
- identity=FileContent(ssh_private_key),
+ identity=identity,
ssh_proxy=job_provisioning_data.ssh_proxy,
+ ssh_proxy_identity=proxy_identity,
):
return func(runner_ports_map, *args, **kwargs)
except SSHError:
diff --git a/src/dstack/_internal/utils/ssh.py b/src/dstack/_internal/utils/ssh.py
index 2199c8b8c..492eafe2a 100644
--- a/src/dstack/_internal/utils/ssh.py
+++ b/src/dstack/_internal/utils/ssh.py
@@ -159,7 +159,7 @@ def get_ssh_config(path: PathLike, host: str) -> Optional[Dict[str, str]]:
return None
-def update_ssh_config(path: PathLike, host: str, options: Dict[str, Union[str, FilePath]]):
+def update_ssh_config(path: PathLike, host: str, options: Dict[str, Union[str, int, FilePath]]):
Path(path).parent.mkdir(parents=True, exist_ok=True)
with FileLock(str(path) + ".lock"):
copy_mode = True
diff --git a/src/dstack/api/server/_fleets.py b/src/dstack/api/server/_fleets.py
index 2f295f846..822d3f3a2 100644
--- a/src/dstack/api/server/_fleets.py
+++ b/src/dstack/api/server/_fleets.py
@@ -62,12 +62,20 @@ def _get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[_ExcludeDict]:
spec_excludes: _ExcludeDict = {}
configuration_excludes: _ExcludeDict = {}
profile_excludes: set[str] = set()
+ ssh_config_excludes: _ExcludeDict = {}
ssh_hosts_excludes: set[str] = set()
# TODO: Can be removed in 0.19
if fleet_spec.configuration_path is None:
spec_excludes["configuration_path"] = True
if fleet_spec.configuration.ssh_config is not None:
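+        # Exclude unset proxy_jump fields so requests stay compatible with
+        # older servers that don't know about them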
+ if fleet_spec.configuration.ssh_config.proxy_jump is None:
+ ssh_config_excludes["proxy_jump"] = True
+ if all(
+ isinstance(h, str) or h.proxy_jump is None
+ for h in fleet_spec.configuration.ssh_config.hosts
+ ):
+ ssh_hosts_excludes.add("proxy_jump")
if all(
isinstance(h, str) or h.internal_ip is None
for h in fleet_spec.configuration.ssh_config.hosts
@@ -98,7 +106,9 @@ def _get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[_ExcludeDict]:
configuration_excludes["blocks"] = True
if ssh_hosts_excludes:
- configuration_excludes["ssh_config"] = {"hosts": {"__all__": ssh_hosts_excludes}}
+ ssh_config_excludes["hosts"] = {"__all__": ssh_hosts_excludes}
+ if ssh_config_excludes:
+ configuration_excludes["ssh_config"] = ssh_config_excludes
if configuration_excludes:
spec_excludes["configuration"] = configuration_excludes
if profile_excludes:
diff --git a/src/tests/_internal/server/background/tasks/test_process_metrics.py b/src/tests/_internal/server/background/tasks/test_process_metrics.py
index 7d93a89eb..422270eac 100644
--- a/src/tests/_internal/server/background/tasks/test_process_metrics.py
+++ b/src/tests/_internal/server/background/tasks/test_process_metrics.py
@@ -6,6 +6,7 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
+from dstack._internal.core.models.instances import InstanceStatus
from dstack._internal.core.models.runs import JobStatus
from dstack._internal.core.models.users import GlobalRole, ProjectRole
from dstack._internal.server import settings
@@ -17,8 +18,10 @@
from dstack._internal.server.schemas.runner import GPUMetrics, MetricsResponse
from dstack._internal.server.services.projects import add_project_member
from dstack._internal.server.testing.common import (
+ create_instance,
create_job,
create_job_metrics_point,
+ create_pool,
create_project,
create_repo,
create_run,
@@ -42,6 +45,13 @@ async def test_collects_metrics(self, test_db, session: AsyncSession):
session=session,
project_id=project.id,
)
+ pool = await create_pool(session=session, project=project)
+ instance = await create_instance(
+ session=session,
+ project=project,
+ pool=pool,
+ status=InstanceStatus.BUSY,
+ )
run = await create_run(
session=session,
project=project,
@@ -53,6 +63,8 @@ async def test_collects_metrics(self, test_db, session: AsyncSession):
run=run,
status=JobStatus.RUNNING,
job_provisioning_data=get_job_provisioning_data(),
+ instance_assigned=True,
+ instance=instance,
)
with (
patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py
index 571d4fec3..f986e6d94 100644
--- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py
+++ b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py
@@ -103,6 +103,13 @@ async def test_leaves_provisioning_job_unchanged_if_runner_not_alive(
repo=repo,
user=user,
)
+ pool = await create_pool(session=session, project=project)
+ instance = await create_instance(
+ session=session,
+ project=project,
+ pool=pool,
+ status=InstanceStatus.BUSY,
+ )
job_provisioning_data = get_job_provisioning_data(dockerized=False)
job = await create_job(
session=session,
@@ -110,6 +117,8 @@ async def test_leaves_provisioning_job_unchanged_if_runner_not_alive(
status=JobStatus.PROVISIONING,
submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc),
job_provisioning_data=job_provisioning_data,
+ instance=instance,
+ instance_assigned=True,
)
with (
patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
@@ -144,12 +153,21 @@ async def test_runs_provisioning_job(self, test_db, session: AsyncSession):
repo=repo,
user=user,
)
+ pool = await create_pool(session=session, project=project)
+ instance = await create_instance(
+ session=session,
+ project=project,
+ pool=pool,
+ status=InstanceStatus.BUSY,
+ )
job_provisioning_data = get_job_provisioning_data(dockerized=False)
job = await create_job(
session=session,
run=run,
status=JobStatus.PROVISIONING,
job_provisioning_data=job_provisioning_data,
+ instance=instance,
+ instance_assigned=True,
)
with (
patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
@@ -186,12 +204,21 @@ async def test_updates_running_job(self, test_db, session: AsyncSession, tmp_pat
repo=repo,
user=user,
)
+ pool = await create_pool(session=session, project=project)
+ instance = await create_instance(
+ session=session,
+ project=project,
+ pool=pool,
+ status=InstanceStatus.BUSY,
+ )
job_provisioning_data = get_job_provisioning_data(dockerized=False)
job = await create_job(
session=session,
run=run,
status=JobStatus.RUNNING,
job_provisioning_data=job_provisioning_data,
+ instance=instance,
+ instance_assigned=True,
)
with (
patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
@@ -277,6 +304,13 @@ async def test_provisioning_shim_with_volumes(
run_name="test-run",
run_spec=run_spec,
)
+ pool = await create_pool(session=session, project=project)
+ instance = await create_instance(
+ session=session,
+ project=project,
+ pool=pool,
+ status=InstanceStatus.BUSY,
+ )
job_provisioning_data = get_job_provisioning_data(dockerized=True)
with patch(
@@ -288,6 +322,8 @@ async def test_provisioning_shim_with_volumes(
run=run,
status=JobStatus.PROVISIONING,
job_provisioning_data=job_provisioning_data,
+ instance=instance,
+ instance_assigned=True,
)
await process_running_jobs()
@@ -337,12 +373,21 @@ async def test_pulling_shim(
repo=repo,
user=user,
)
+ pool = await create_pool(session=session, project=project)
+ instance = await create_instance(
+ session=session,
+ project=project,
+ pool=pool,
+ status=InstanceStatus.BUSY,
+ )
job = await create_job(
session=session,
run=run,
status=JobStatus.PULLING,
job_provisioning_data=get_job_provisioning_data(dockerized=True),
job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None),
+ instance=instance,
+ instance_assigned=True,
)
shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING
shim_client_mock.get_task.return_value.ports = [
@@ -385,6 +430,13 @@ async def test_pulling_shim_port_mapping_not_ready(
repo=repo,
user=user,
)
+ pool = await create_pool(session=session, project=project)
+ instance = await create_instance(
+ session=session,
+ project=project,
+ pool=pool,
+ status=InstanceStatus.BUSY,
+ )
job_provisioning_data = get_job_provisioning_data(dockerized=True)
job = await create_job(
session=session,
@@ -392,6 +444,8 @@ async def test_pulling_shim_port_mapping_not_ready(
status=JobStatus.PULLING,
job_provisioning_data=job_provisioning_data,
job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None),
+ instance=instance,
+ instance_assigned=True,
)
shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING
shim_client_mock.get_task.return_value.ports = None
@@ -470,12 +524,21 @@ async def test_provisioning_shim_force_stop_if_already_running_api_v1(
run_name="test-run",
run_spec=run_spec,
)
+ pool = await create_pool(session=session, project=project)
+ instance = await create_instance(
+ session=session,
+ project=project,
+ pool=pool,
+ status=InstanceStatus.BUSY,
+ )
job = await create_job(
session=session,
run=run,
status=JobStatus.PROVISIONING,
job_provisioning_data=get_job_provisioning_data(dockerized=True),
submitted_at=get_current_datetime(),
+ instance=instance,
+ instance_assigned=True,
)
monkeypatch.setattr(
"dstack._internal.server.services.runner.ssh.SSHTunnel", Mock(return_value=MagicMock())
@@ -588,11 +651,20 @@ async def test_inactivity_duration(
),
),
)
+ pool = await create_pool(session=session, project=project)
+ instance = await create_instance(
+ session=session,
+ project=project,
+ pool=pool,
+ status=InstanceStatus.BUSY,
+ )
job = await create_job(
session=session,
run=run,
status=JobStatus.RUNNING,
job_provisioning_data=get_job_provisioning_data(),
+ instance=instance,
+ instance_assigned=True,
)
with (
patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
diff --git a/src/tests/_internal/server/background/tasks/test_process_runs.py b/src/tests/_internal/server/background/tasks/test_process_runs.py
index 2cccf3088..bdd65a612 100644
--- a/src/tests/_internal/server/background/tasks/test_process_runs.py
+++ b/src/tests/_internal/server/background/tasks/test_process_runs.py
@@ -8,6 +8,7 @@
import dstack._internal.server.background.tasks.process_runs as process_runs
from dstack._internal.core.models.configurations import ServiceConfiguration
+from dstack._internal.core.models.instances import InstanceStatus
from dstack._internal.core.models.profiles import Profile
from dstack._internal.core.models.resources import Range
from dstack._internal.core.models.runs import (
@@ -116,11 +117,20 @@ async def test_running_to_done(self, test_db, session: AsyncSession):
async def test_terminate_run_jobs(self, test_db, session: AsyncSession):
run = await make_run(session, status=RunStatus.TERMINATING)
run.termination_reason = RunTerminationReason.JOB_FAILED
+ pool = await create_pool(session=session, project=run.project)
+ instance = await create_instance(
+ session=session,
+ project=run.project,
+ pool=pool,
+ status=InstanceStatus.BUSY,
+ )
job = await create_job(
session=session,
run=run,
job_provisioning_data=get_job_provisioning_data(),
status=JobStatus.RUNNING,
+ instance=instance,
+ instance_assigned=True,
)
with patch("dstack._internal.server.services.jobs._stop_runner") as stop_runner:
diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py
index 81b6f422f..2ad610603 100644
--- a/src/tests/_internal/server/routers/test_fleets.py
+++ b/src/tests/_internal/server/routers/test_fleets.py
@@ -450,6 +450,7 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A
"port": None,
"identity_file": None,
"ssh_key": None, # should not return ssh_key
+ "proxy_jump": None,
"hosts": ["1.1.1.1"],
"network": None,
},
diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py
index 453b60d2a..f15afce57 100644
--- a/src/tests/_internal/server/routers/test_runs.py
+++ b/src/tests/_internal/server/routers/test_runs.py
@@ -18,6 +18,7 @@
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceOfferWithAvailability,
+ InstanceStatus,
InstanceType,
Resources,
)
@@ -47,7 +48,9 @@
create_backend,
create_gateway,
create_gateway_compute,
+ create_instance,
create_job,
+ create_pool,
create_project,
create_repo,
create_run,
@@ -1315,11 +1318,20 @@ async def test_terminates_running_run(
user=user,
status=RunStatus.RUNNING,
)
+ pool = await create_pool(session=session, project=project)
+ instance = await create_instance(
+ session=session,
+ project=project,
+ pool=pool,
+ status=InstanceStatus.BUSY,
+ )
job = await create_job(
session=session,
run=run,
job_provisioning_data=get_job_provisioning_data(),
status=JobStatus.RUNNING,
+ instance=instance,
+ instance_assigned=True,
)
with patch("dstack._internal.server.services.jobs._stop_runner") as stop_runner:
response = await client.post(