diff --git a/docs/docs/concepts/fleets.md b/docs/docs/concepts/fleets.md index 60575f9af..15c9bf78a 100644 --- a/docs/docs/concepts/fleets.md +++ b/docs/docs/concepts/fleets.md @@ -300,6 +300,47 @@ ssh_config: - 3.255.177.52 ``` +#### Head node + +If fleet nodes are behind a head node, configure [`proxy_jump`](../reference/dstack.yml/fleet.md#proxy_jump): + +
+ + ```yaml + type: fleet + name: my-fleet + + ssh_config: + user: ubuntu + identity_file: ~/.ssh/worker_node_key + hosts: + - 3.255.177.51 + - 3.255.177.52 + proxy_jump: + hostname: 3.255.177.50 + user: ubuntu + identity_file: ~/.ssh/head_node_key + ``` + +
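+
+If only some hosts are behind a head node, `proxy_jump` can instead be set per host via
+`ssh_config.hosts[n].proxy_jump`. A minimal sketch (the addresses are illustrative):
+
+<div>
+
+    ```yaml
+    type: fleet
+    name: my-fleet
+
+    ssh_config:
+      user: ubuntu
+      identity_file: ~/.ssh/worker_node_key
+      hosts:
+        - hostname: 3.255.177.51
+          proxy_jump:
+            hostname: 3.255.177.50
+            user: ubuntu
+            identity_file: ~/.ssh/head_node_key
+        - 3.255.177.52
+    ```
+
+</div>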
+
+To be able to attach to runs, both explicitly with `dstack attach` and implicitly with `dstack apply`, you must either
+add the head node key (`~/.ssh/head_node_key`) to an SSH agent, e.g. with `ssh-add ~/.ssh/head_node_key`, or configure
+a key path in `~/.ssh/config`:
+
<div>
+ + ``` + Host 3.255.177.50 + IdentityFile ~/.ssh/head_node_key + ``` + +
+ +where `Host` must match `ssh_config.proxy_jump.hostname` or `ssh_config.hosts[n].proxy_jump.hostname` if you configure head nodes +on a per-worker basis. + +> Currently, [services](services.md) do not work on instances with a head node setup. + !!! info "Reference" For all SSH fleet configuration options, refer to the [reference](../reference/dstack.yml/fleet.md). diff --git a/docs/docs/reference/dstack.yml/fleet.md b/docs/docs/reference/dstack.yml/fleet.md index 537ddb109..7c11967f7 100644 --- a/docs/docs/reference/dstack.yml/fleet.md +++ b/docs/docs/reference/dstack.yml/fleet.md @@ -17,12 +17,26 @@ The `fleet` configuration type allows creating and updating fleets. show_root_heading: false item_id_prefix: ssh_config- +#### `ssh_config.proxy_jump` { #ssh_config-proxy_jump data-toc-label="proxy_jump" } + +#SCHEMA# dstack._internal.core.models.fleets.SSHProxyParams + overrides: + show_root_heading: false + item_id_prefix: proxy_jump- + #### `ssh_config.hosts[n]` { #ssh_config-hosts data-toc-label="hosts" } #SCHEMA# dstack._internal.core.models.fleets.SSHHostParams overrides: show_root_heading: false +##### `ssh_config.hosts[n].proxy_jump` { #proxy_jump data-toc-label="hosts[n].proxy_jump" } + +#SCHEMA# dstack._internal.core.models.fleets.SSHProxyParams + overrides: + show_root_heading: false + item_id_prefix: hosts-proxy_jump- + ### `resources` #SCHEMA# dstack._internal.core.models.resources.ResourcesSpecSchema diff --git a/src/dstack/_internal/cli/services/configurators/fleet.py b/src/dstack/_internal/cli/services/configurators/fleet.py index 7eea9fb70..35699f216 100644 --- a/src/dstack/_internal/cli/services/configurators/fleet.py +++ b/src/dstack/_internal/cli/services/configurators/fleet.py @@ -201,13 +201,16 @@ def apply_args(self, conf: FleetConfiguration, args: argparse.Namespace, unknown def _preprocess_spec(spec: FleetSpec): - if spec.configuration.ssh_config is not None: - spec.configuration.ssh_config.ssh_key = _resolve_ssh_key( - spec.configuration.ssh_config.identity_file - ) - for host in spec.configuration.ssh_config.hosts: + ssh_config = spec.configuration.ssh_config + if ssh_config is not None: + ssh_config.ssh_key = _resolve_ssh_key(ssh_config.identity_file) + if ssh_config.proxy_jump is not None: + ssh_config.proxy_jump.ssh_key = _resolve_ssh_key(ssh_config.proxy_jump.identity_file) + for host in ssh_config.hosts: if not isinstance(host, str): host.ssh_key = _resolve_ssh_key(host.identity_file) + if host.proxy_jump is not None: + host.proxy_jump.ssh_key = _resolve_ssh_key(host.proxy_jump.identity_file) def _resolve_ssh_key(ssh_key_path: Optional[str]) -> Optional[SSHKey]: diff --git a/src/dstack/_internal/core/backends/remote/provisioning.py b/src/dstack/_internal/core/backends/remote/provisioning.py index 3f3769791..7a2398b7c 100644 --- a/src/dstack/_internal/core/backends/remote/provisioning.py +++ b/src/dstack/_internal/core/backends/remote/provisioning.py @@ -1,9 +1,9 @@ import io import json import time -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from textwrap import dedent -from typing import Any, Dict, Generator, List +from typing import Any, Dict, Generator, List, Optional import paramiko from gpuhunt import AcceleratorVendor, correct_gpu_memory_gib @@ -17,6 +17,7 @@ Gpu, InstanceType, Resources, + SSHConnectionParams, ) from dstack._internal.utils.gpu import ( convert_amd_gpu_name, @@ -262,35 +263,72 @@ def host_info_to_instance_type(host_info: Dict[str, Any]) -> InstanceType: @contextmanager def 
get_paramiko_connection( - ssh_user: str, host: str, port: int, pkeys: List[paramiko.PKey] + ssh_user: str, + host: str, + port: int, + pkeys: List[paramiko.PKey], + proxy: Optional[SSHConnectionParams] = None, + proxy_pkeys: Optional[list[paramiko.PKey]] = None, ) -> Generator[paramiko.SSHClient, None, None]: - with paramiko.SSHClient() as client: - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - for pkey in pkeys: - conn_url = f"{ssh_user}@{host}:{port}" + if proxy is not None: + if proxy_pkeys is None: + raise ProvisioningError("Missing proxy private keys") + proxy_ctx = get_paramiko_connection( + proxy.username, proxy.hostname, proxy.port, proxy_pkeys + ) + else: + proxy_ctx = nullcontext() + conn_url = f"{ssh_user}@{host}:{port}" + with proxy_ctx as proxy_client, paramiko.SSHClient() as client: + proxy_channel: Optional[paramiko.Channel] = None + if proxy_client is not None: try: - logger.debug("Try to connect to %s with key %s", conn_url, pkey.fingerprint) - client.connect( - username=ssh_user, - hostname=host, - port=port, - pkey=pkey, - look_for_keys=False, - allow_agent=False, - timeout=SSH_CONNECT_TIMEOUT, + proxy_channel = proxy_client.get_transport().open_channel( + "direct-tcpip", (host, port), ("", 0) ) - except paramiko.AuthenticationException: - logger.debug( - f'Authentication failed to connect to "{conn_url}" and {pkey.fingerprint}' - ) - continue # try next key except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"Connect failed: {e}") from e - else: + raise ProvisioningError(f"Proxy channel failed: {e}") from e + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + for pkey in pkeys: + logger.debug("Try to connect to %s with key %s", conn_url, pkey.fingerprint) + connected = _paramiko_connect(client, ssh_user, host, port, pkey, proxy_channel) + if connected: yield client return - else: - keys_fp = ", ".join(f"{pk.fingerprint!r}" for pk in pkeys) - raise ProvisioningError( - f"SSH connection to the {conn_url} with keys [{keys_fp}] was unsuccessful" + logger.debug( + f'Authentication failed to connect to "{conn_url}" and {pkey.fingerprint}' ) + keys_fp = ", ".join(f"{pk.fingerprint!r}" for pk in pkeys) + raise ProvisioningError( + f"SSH connection to the {conn_url} with keys [{keys_fp}] was unsuccessful" + ) + + +def _paramiko_connect( + client: paramiko.SSHClient, + user: str, + host: str, + port: int, + pkey: paramiko.PKey, + channel: Optional[paramiko.Channel] = None, +) -> bool: + """ + Returns `True` if connected, `False` if auth failed, and raises `ProvisioningError` + on other errors. 
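+
+    When `channel` is provided, the connection is tunneled through the proxy
+    jump host via `sock=channel`. Only the single given `pkey` is tried;
+    `look_for_keys` and `allow_agent` are disabled.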
+ """ + try: + client.connect( + username=user, + hostname=host, + port=port, + pkey=pkey, + look_for_keys=False, + allow_agent=False, + timeout=SSH_CONNECT_TIMEOUT, + sock=channel, + ) + return True + except paramiko.AuthenticationException: + return False + except (paramiko.SSHException, OSError) as e: + raise ProvisioningError(f"Connect failed: {e}") from e diff --git a/src/dstack/_internal/core/models/fleets.py b/src/dstack/_internal/core/models/fleets.py index abc0766c4..c56720b9a 100644 --- a/src/dstack/_internal/core/models/fleets.py +++ b/src/dstack/_internal/core/models/fleets.py @@ -39,6 +39,14 @@ class InstanceGroupPlacement(str, Enum): CLUSTER = "cluster" +class SSHProxyParams(CoreModel): + hostname: Annotated[str, Field(description="The IP address or domain of proxy host")] + port: Annotated[Optional[int], Field(description="The SSH port of proxy host")] = None + user: Annotated[str, Field(description="The user to log in with for proxy host")] + identity_file: Annotated[str, Field(description="The private key to use for proxy host")] + ssh_key: Optional[SSHKey] = None + + class SSHHostParams(CoreModel): hostname: Annotated[str, Field(description="The IP address or domain to connect to")] port: Annotated[ @@ -50,6 +58,9 @@ class SSHHostParams(CoreModel): identity_file: Annotated[ Optional[str], Field(description="The private key to use for this host") ] = None + proxy_jump: Annotated[ + Optional[SSHProxyParams], Field(description="The SSH proxy configuration for this host") + ] = None internal_ip: Annotated[ Optional[str], Field( @@ -96,6 +107,9 @@ class SSHParams(CoreModel): Optional[str], Field(description="The private key to use for all hosts") ] = None ssh_key: Optional[SSHKey] = None + proxy_jump: Annotated[ + Optional[SSHProxyParams], Field(description="The SSH proxy configuration for all hosts") + ] = None hosts: Annotated[ List[Union[SSHHostParams, str]], Field( diff --git a/src/dstack/_internal/core/models/instances.py b/src/dstack/_internal/core/models/instances.py index 9b5bce0e5..84ae92a9e 100644 --- a/src/dstack/_internal/core/models/instances.py +++ b/src/dstack/_internal/core/models/instances.py @@ -92,6 +92,8 @@ class RemoteConnectionInfo(CoreModel): port: int ssh_user: str ssh_keys: List[SSHKey] + ssh_proxy: Optional[SSHConnectionParams] = None + ssh_proxy_keys: Optional[list[SSHKey]] = None env: Env = Env() diff --git a/src/dstack/_internal/core/services/ssh/attach.py b/src/dstack/_internal/core/services/ssh/attach.py index 788d369b2..a095fafb2 100644 --- a/src/dstack/_internal/core/services/ssh/attach.py +++ b/src/dstack/_internal/core/services/ssh/attach.py @@ -2,7 +2,7 @@ import re import time from pathlib import Path -from typing import Optional +from typing import Optional, Union import psutil @@ -14,6 +14,8 @@ from dstack._internal.core.services.ssh.tunnel import SSHTunnel, ports_to_forwarded_sockets from dstack._internal.utils.path import FilePath, PathLike from dstack._internal.utils.ssh import ( + default_ssh_config_path, + get_host_config, include_ssh_config, normalize_path, update_ssh_config, @@ -88,28 +90,63 @@ def __init__( }, ) self.ssh_proxy = ssh_proxy - if ssh_proxy is None: - self.host_config = { + + hosts: dict[str, dict[str, Union[str, int, FilePath]]] = {} + self.hosts = hosts + + if local_backend: + hosts[run_name] = { "HostName": hostname, - "Port": ssh_port, - "User": user if dockerized else container_user, - "IdentityFile": self.identity_file, - "IdentitiesOnly": "yes", - "StrictHostKeyChecking": "no", - "UserKnownHostsFile": 
"/dev/null", - } - else: - self.host_config = { - "HostName": ssh_proxy.hostname, - "Port": ssh_proxy.port, - "User": ssh_proxy.username, + "Port": container_ssh_port, + "User": container_user, "IdentityFile": self.identity_file, "IdentitiesOnly": "yes", "StrictHostKeyChecking": "no", "UserKnownHostsFile": "/dev/null", } - if dockerized and not local_backend: - self.container_config = { + elif dockerized: + if ssh_proxy is not None: + # SSH instance with jump host + # dstack has no IdentityFile for jump host, it must be either preconfigured + # in the ~/.ssh/config or loaded into ssh-agent + hosts[f"{run_name}-jump-host"] = { + "HostName": ssh_proxy.hostname, + "Port": ssh_proxy.port, + "User": ssh_proxy.username, + "StrictHostKeyChecking": "no", + "UserKnownHostsFile": "/dev/null", + } + jump_host_config = get_host_config(ssh_proxy.hostname, default_ssh_config_path) + jump_host_identity_files = jump_host_config.get("identityfile") + if jump_host_identity_files: + hosts[f"{run_name}-jump-host"].update( + { + "IdentityFile": jump_host_identity_files[0], + "IdentitiesOnly": "yes", + } + ) + hosts[f"{run_name}-host"] = { + "HostName": hostname, + "Port": ssh_port, + "User": user, + "IdentityFile": self.identity_file, + "IdentitiesOnly": "yes", + "StrictHostKeyChecking": "no", + "UserKnownHostsFile": "/dev/null", + "ProxyJump": f"{run_name}-jump-host", + } + else: + # Regular SSH instance or VM-based cloud instance + hosts[f"{run_name}-host"] = { + "HostName": hostname, + "Port": ssh_port, + "User": user, + "IdentityFile": self.identity_file, + "IdentitiesOnly": "yes", + "StrictHostKeyChecking": "no", + "UserKnownHostsFile": "/dev/null", + } + hosts[run_name] = { "HostName": "localhost", "Port": container_ssh_port, "User": container_user, @@ -119,32 +156,41 @@ def __init__( "UserKnownHostsFile": "/dev/null", "ProxyJump": f"{run_name}-host", } - elif ssh_proxy is not None: - self.container_config = { - "HostName": hostname, - "Port": ssh_port, - "User": container_user, - "IdentityFile": self.identity_file, - "IdentitiesOnly": "yes", - "StrictHostKeyChecking": "no", - "UserKnownHostsFile": "/dev/null", - "ProxyJump": f"{run_name}-jump-host", - } else: - self.container_config = None - if local_backend: - self.container_config = None - self.host_config = { - "HostName": hostname, - "Port": container_ssh_port, - "User": container_user, - "IdentityFile": self.identity_file, - "IdentitiesOnly": "yes", - "StrictHostKeyChecking": "no", - "UserKnownHostsFile": "/dev/null", - } - if self.container_config is not None and get_ssh_client_info().supports_multiplexing: - self.container_config.update( + if ssh_proxy is not None: + # Kubernetes + hosts[f"{run_name}-jump-host"] = { + "HostName": ssh_proxy.hostname, + "Port": ssh_proxy.port, + "User": ssh_proxy.username, + "IdentityFile": self.identity_file, + "IdentitiesOnly": "yes", + "StrictHostKeyChecking": "no", + "UserKnownHostsFile": "/dev/null", + } + hosts[run_name] = { + "HostName": hostname, + "Port": ssh_port, + "User": container_user, + "IdentityFile": self.identity_file, + "IdentitiesOnly": "yes", + "StrictHostKeyChecking": "no", + "UserKnownHostsFile": "/dev/null", + "ProxyJump": f"{run_name}-jump-host", + } + else: + # Container-based backends + hosts[run_name] = { + "HostName": hostname, + "Port": ssh_port, + "User": container_user, + "IdentityFile": self.identity_file, + "IdentitiesOnly": "yes", + "StrictHostKeyChecking": "no", + "UserKnownHostsFile": "/dev/null", + } + if get_ssh_client_info().supports_multiplexing: + hosts[run_name].update( { 
"ControlMaster": "auto", "ControlPath": self.control_sock_path, @@ -153,14 +199,8 @@ def __init__( def attach(self): include_ssh_config(self.ssh_config_path) - if self.container_config is None: - update_ssh_config(self.ssh_config_path, self.run_name, self.host_config) - elif self.ssh_proxy is not None: - update_ssh_config(self.ssh_config_path, f"{self.run_name}-jump-host", self.host_config) - update_ssh_config(self.ssh_config_path, self.run_name, self.container_config) - else: - update_ssh_config(self.ssh_config_path, f"{self.run_name}-host", self.host_config) - update_ssh_config(self.ssh_config_path, self.run_name, self.container_config) + for host, options in self.hosts.items(): + update_ssh_config(self.ssh_config_path, host, options) max_retries = 10 self._ports_lock.release() @@ -178,9 +218,8 @@ def attach(self): def detach(self): self.tunnel.close() - update_ssh_config(self.ssh_config_path, f"{self.run_name}-jump-host", {}) - update_ssh_config(self.ssh_config_path, f"{self.run_name}-host", {}) - update_ssh_config(self.ssh_config_path, self.run_name, {}) + for host in self.hosts: + update_ssh_config(self.ssh_config_path, host, {}) def __enter__(self): self.attach() diff --git a/src/dstack/_internal/core/services/ssh/tunnel.py b/src/dstack/_internal/core/services/ssh/tunnel.py index 98789a618..03c9894b4 100644 --- a/src/dstack/_internal/core/services/ssh/tunnel.py +++ b/src/dstack/_internal/core/services/ssh/tunnel.py @@ -70,6 +70,7 @@ def __init__( ssh_config_path: Union[PathLike, Literal["none"]] = "none", port: Optional[int] = None, ssh_proxy: Optional[SSHConnectionParams] = None, + ssh_proxy_identity: Optional[FilePathOrContent] = None, ): """ :param forwarded_sockets: Connections to the specified local sockets will be @@ -97,7 +98,15 @@ def __init__( identity_path, opener=lambda path, flags: os.open(path, flags, 0o600), mode="w" ) as f: f.write(identity.content) - self.identity_path = normalize_path(identity_path) + self.identity_path = normalize_path(self._get_identity_path(identity, "identity")) + if ssh_proxy_identity is not None: + self.ssh_proxy_identity_path = normalize_path( + self._get_identity_path(ssh_proxy_identity, "proxy_identity") + ) + elif ssh_proxy is not None: + self.ssh_proxy_identity_path = self.identity_path + else: + self.ssh_proxy_identity_path = None self.log_path = normalize_path(os.path.join(temp_dir.name, "tunnel.log")) self.ssh_client_info = get_ssh_client_info() self.ssh_exec_path = str(self.ssh_client_info.path) @@ -166,7 +175,7 @@ def proxy_command(self) -> Optional[List[str]]: return [ self.ssh_exec_path, "-i", - self.identity_path, + self.ssh_proxy_identity_path, "-W", "%h:%p", "-o", @@ -263,6 +272,16 @@ def _remove_log_file(self) -> None: except OSError as e: logger.debug("Failed to remove SSH tunnel log file %s: %s", self.log_path, e) + def _get_identity_path(self, identity: FilePathOrContent, tmp_filename: str) -> PathLike: + if isinstance(identity, FilePath): + return identity.path + identity_path = os.path.join(self.temp_dir.name, tmp_filename) + with open( + identity_path, opener=lambda path, flags: os.open(path, flags, 0o600), mode="w" + ) as f: + f.write(identity.content) + return identity_path + def ports_to_forwarded_sockets( ports: Dict[int, int], bind_local: str = "localhost" diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py index 63fe6f312..fa798bdc4 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ 
b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -47,6 +47,7 @@ InstanceStatus, InstanceType, RemoteConnectionInfo, + SSHKey, ) from dstack._internal.core.models.placement import ( PlacementGroup, @@ -86,6 +87,7 @@ get_instance_profile, get_instance_provisioning_data, get_instance_requirements, + get_instance_ssh_private_keys, ) from dstack._internal.server.services.runner import client as runner_client from dstack._internal.server.services.runner.client import HealthStatus @@ -232,11 +234,11 @@ async def _add_remote(instance: InstanceModel) -> None: remote_details = RemoteConnectionInfo.parse_raw(cast(str, instance.remote_connection_info)) # Prepare connection key try: - pkeys = [ - pkey_from_str(sk.private) - for sk in remote_details.ssh_keys - if sk.private is not None - ] + pkeys = _ssh_keys_to_pkeys(remote_details.ssh_keys) + if remote_details.ssh_proxy_keys is not None: + ssh_proxy_pkeys = _ssh_keys_to_pkeys(remote_details.ssh_proxy_keys) + else: + ssh_proxy_pkeys = None except (ValueError, PasswordRequiredException): instance.status = InstanceStatus.TERMINATED instance.termination_reason = "Unsupported private SSH key type" @@ -254,7 +256,9 @@ async def _add_remote(instance: InstanceModel) -> None: authorized_keys.append(instance.project.ssh_public_key.strip()) try: - future = run_async(_deploy_instance, remote_details, pkeys, authorized_keys) + future = run_async( + _deploy_instance, remote_details, pkeys, ssh_proxy_pkeys, authorized_keys + ) deploy_timeout = 20 * 60 # 20 minutes result = await asyncio.wait_for(future, timeout=deploy_timeout) health, host_info = result @@ -356,7 +360,7 @@ async def _add_remote(instance: InstanceModel) -> None: ssh_port=remote_details.port, dockerized=True, backend_data=None, - ssh_proxy=None, + ssh_proxy=remote_details.ssh_proxy, ) instance.status = InstanceStatus.IDLE if health else InstanceStatus.PROVISIONING @@ -379,10 +383,16 @@ async def _add_remote(instance: InstanceModel) -> None: def _deploy_instance( remote_details: RemoteConnectionInfo, pkeys: List[PKey], + ssh_proxy_pkeys: Optional[list[PKey]], authorized_keys: List[str], ) -> Tuple[HealthStatus, Dict[str, Any]]: with get_paramiko_connection( - remote_details.ssh_user, remote_details.host, remote_details.port, pkeys + remote_details.ssh_user, + remote_details.host, + remote_details.port, + pkeys, + remote_details.ssh_proxy, + ssh_proxy_pkeys, ) as client: logger.info(f"Connected to {remote_details.ssh_user} {remote_details.host}") @@ -638,18 +648,14 @@ async def _check_instance(instance: InstanceModel) -> None: instance.status = InstanceStatus.BUSY return - ssh_private_key = instance.project.ssh_private_key - # TODO: Drop this logic and always use project key once it's safe to assume that most on-prem - # fleets are (re)created after this change: https://github.com/dstackai/dstack/pull/1716 - if instance.remote_connection_info is not None: - remote_conn_info: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw( - instance.remote_connection_info - ) - ssh_private_key = remote_conn_info.ssh_keys[0].private + ssh_private_keys = get_instance_ssh_private_keys(instance) # May return False if fails to establish ssh connection health_status_response = await run_async( - _instance_healthcheck, ssh_private_key, job_provisioning_data, None + _instance_healthcheck, + ssh_private_keys, + job_provisioning_data, + None, ) if isinstance(health_status_response, bool) or health_status_response is None: health_status = HealthStatus(healthy=False, reason="SSH or tunnel 
error") @@ -971,3 +977,7 @@ def _get_instance_timeout_interval( if backend_type == BackendType.VULTR and instance_type_name.startswith("vbm"): return timedelta(seconds=3300) return timedelta(seconds=600) + + +def _ssh_keys_to_pkeys(ssh_keys: list[SSHKey]) -> list[PKey]: + return [pkey_from_str(sk.private) for sk in ssh_keys if sk.private is not None] diff --git a/src/dstack/_internal/server/background/tasks/process_metrics.py b/src/dstack/_internal/server/background/tasks/process_metrics.py index 3ebe4f6e6..49ac45fc1 100644 --- a/src/dstack/_internal/server/background/tasks/process_metrics.py +++ b/src/dstack/_internal/server/background/tasks/process_metrics.py @@ -3,18 +3,19 @@ from typing import Dict, List, Optional from sqlalchemy import delete, select -from sqlalchemy.orm import selectinload +from sqlalchemy.orm import joinedload from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT from dstack._internal.core.models.runs import JobStatus from dstack._internal.server import settings from dstack._internal.server.db import get_session_ctx -from dstack._internal.server.models import JobMetricsPoint, JobModel +from dstack._internal.server.models import InstanceModel, JobMetricsPoint, JobModel from dstack._internal.server.schemas.runner import MetricsResponse from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_runtime_data +from dstack._internal.server.services.pools import get_instance_ssh_private_keys from dstack._internal.server.services.runner import client from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel -from dstack._internal.utils.common import batched, get_current_datetime, run_async +from dstack._internal.utils.common import batched, get_current_datetime, get_or_error, run_async from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -29,14 +30,12 @@ async def collect_metrics(): async with get_session_ctx() as session: res = await session.execute( select(JobModel) - .where( - JobModel.status.in_([JobStatus.RUNNING]), - ) - .options(selectinload(JobModel.project)) + .where(JobModel.status.in_([JobStatus.RUNNING])) + .options(joinedload(JobModel.instance).joinedload(InstanceModel.project)) .order_by(JobModel.last_processed_at.asc()) .limit(MAX_JOBS_FETCHED) ) - job_models = res.scalars().all() + job_models = res.unique().scalars().all() for batch in batched(job_models, BATCH_SIZE): await _collect_jobs_metrics(batch) @@ -87,6 +86,7 @@ def _get_recently_collected_metric_cutoff() -> int: async def _collect_job_metrics(job_model: JobModel) -> Optional[JobMetricsPoint]: + ssh_private_keys = get_instance_ssh_private_keys(get_or_error(job_model.instance)) jpd = get_job_provisioning_data(job_model) jrd = get_job_runtime_data(job_model) if jpd is None: @@ -94,7 +94,7 @@ async def _collect_job_metrics(job_model: JobModel) -> Optional[JobMetricsPoint] try: res = await run_async( _pull_runner_metrics, - job_model.project.ssh_private_key, + ssh_private_keys, jpd, jrd, ) diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index a91f82a98..9a303f9cf 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -26,6 +26,7 @@ from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint from dstack._internal.server.db import get_session_ctx from dstack._internal.server.models 
import ( + InstanceModel, JobModel, ProjectModel, RepoModel, @@ -42,6 +43,7 @@ ) from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.pools import get_instance_ssh_private_keys from dstack._internal.server.services.repos import ( get_code_model, get_repo_creds, @@ -101,7 +103,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): res = await session.execute( select(JobModel) .where(JobModel.id == job_model.id) - .options(joinedload(JobModel.instance)) + .options(joinedload(JobModel.instance).joinedload(InstanceModel.project)) .execution_options(populate_existing=True) ) job_model = res.unique().scalar_one() @@ -152,18 +154,9 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): job_provisioning_data=job_provisioning_data, ) - server_ssh_private_key = project.ssh_private_key - # TODO: Drop this logic and always use project key once it's safe to assume that most on-prem - # fleets are (re)created after this change: https://github.com/dstackai/dstack/pull/1716 - if ( - job_model.instance is not None - and job_model.instance.remote_connection_info is not None - and job_provisioning_data.dockerized - ): - remote_conn_info: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw( - job_model.instance.remote_connection_info - ) - server_ssh_private_key = remote_conn_info.ssh_keys[0].private + server_ssh_private_keys = get_instance_ssh_private_keys( + common_utils.get_or_error(job_model.instance) + ) secrets = {} # TODO secrets @@ -203,7 +196,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): user_ssh_key = "" success = await common_utils.run_async( _process_provisioning_with_shim, - server_ssh_private_key, + server_ssh_private_keys, job_provisioning_data, None, run, @@ -229,7 +222,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): ) success = await common_utils.run_async( _submit_job_to_runner, - server_ssh_private_key, + server_ssh_private_keys, job_provisioning_data, None, run, @@ -272,7 +265,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): ) success = await common_utils.run_async( _process_pulling_with_shim, - server_ssh_private_key, + server_ssh_private_keys, job_provisioning_data, None, run, @@ -282,14 +275,14 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): code, secrets, repo_creds, - server_ssh_private_key, + server_ssh_private_keys, job_provisioning_data, ) elif initial_status == JobStatus.RUNNING: logger.debug("%s: process running job, age=%s", fmt(job_model), job_submission.age) success = await common_utils.run_async( _process_running, - server_ssh_private_key, + server_ssh_private_keys, job_provisioning_data, job_submission.job_runtime_data, run_model, @@ -493,7 +486,7 @@ def _process_pulling_with_shim( code: bytes, secrets: Dict[str, str], repo_credentials: Optional[RemoteRepoCreds], - server_ssh_private_key: str, + server_ssh_private_keys: tuple[str, Optional[str]], job_provisioning_data: JobProvisioningData, ) -> bool: """ @@ -558,7 +551,7 @@ def _process_pulling_with_shim( return True return _submit_job_to_runner( - server_ssh_private_key, + server_ssh_private_keys, job_provisioning_data, job_runtime_data, run=run, diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 71afeb700..fc06ad49a 100644 --- 
a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -31,6 +31,7 @@ InstanceOfferWithAvailability, InstanceStatus, RemoteConnectionInfo, + SSHConnectionParams, SSHKey, ) from dstack._internal.core.models.pools import Instance @@ -428,6 +429,7 @@ async def create_fleet_ssh_instance_model( ssh_user = ssh_params.user ssh_key = ssh_params.ssh_key port = ssh_params.port + proxy_jump = ssh_params.proxy_jump internal_ip = None blocks = 1 else: @@ -435,6 +437,7 @@ async def create_fleet_ssh_instance_model( ssh_user = host.user or ssh_params.user ssh_key = host.ssh_key or ssh_params.ssh_key port = host.port or ssh_params.port + proxy_jump = host.proxy_jump or ssh_params.proxy_jump internal_ip = host.internal_ip blocks = host.blocks @@ -442,6 +445,17 @@ async def create_fleet_ssh_instance_model( # This should not be reachable but checked by fleet spec validation raise ServerClientError("ssh key or user not specified") + if proxy_jump is not None: + ssh_proxy = SSHConnectionParams( + hostname=proxy_jump.hostname, + port=proxy_jump.port or 22, + username=proxy_jump.user, + ) + ssh_proxy_keys = [proxy_jump.ssh_key] + else: + ssh_proxy = None + ssh_proxy_keys = None + instance_model = await pools_services.create_ssh_instance_model( project=project, pool=pool, @@ -451,6 +465,8 @@ async def create_fleet_ssh_instance_model( host=hostname, ssh_user=ssh_user, ssh_keys=[ssh_key], + ssh_proxy=ssh_proxy, + ssh_proxy_keys=ssh_proxy_keys, env=env, internal_ip=internal_ip, instance_network=ssh_params.network, diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 9db598201..b2f18ba04 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -7,6 +7,7 @@ import requests from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload import dstack._internal.server.services.backends as backends_services from dstack._internal.core.backends.base import Backend @@ -20,7 +21,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import is_core_model_instance from dstack._internal.core.models.configurations import RunConfigurationType -from dstack._internal.core.models.instances import InstanceStatus, RemoteConnectionInfo +from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.runs import ( Job, JobProvisioningData, @@ -49,6 +50,7 @@ from dstack._internal.server.services.jobs.configurators.service import ServiceJobConfigurator from dstack._internal.server.services.jobs.configurators.task import TaskJobConfigurator from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.pools import get_instance_ssh_private_keys from dstack._internal.server.services.runner import client from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel from dstack._internal.server.services.volumes import ( @@ -170,29 +172,22 @@ def _get_job_configurator(run_spec: RunSpec) -> JobConfigurator: async def stop_runner(session: AsyncSession, job_model: JobModel): - project = await session.get(ProjectModel, job_model.project_id) - ssh_private_key = project.ssh_private_key - res = await session.execute( - select(InstanceModel).where( + select(InstanceModel) + .where( InstanceModel.project_id == job_model.project_id, InstanceModel.id == 
job_model.instance_id, ) + .options(joinedload(InstanceModel.project)) ) instance: Optional[InstanceModel] = res.scalar() - # TODO: Drop this logic and always use project key once it's safe to assume that most on-prem - # fleets are (re)created after this change: https://github.com/dstackai/dstack/pull/1716 - if instance and instance.remote_connection_info is not None: - remote_conn_info: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw( - instance.remote_connection_info - ) - ssh_private_key = remote_conn_info.ssh_keys[0].private + ssh_private_keys = get_instance_ssh_private_keys(common.get_or_error(instance)) try: jpd = get_job_provisioning_data(job_model) if jpd is not None: jrd = get_job_runtime_data(job_model) - await run_async(_stop_runner, ssh_private_key, jpd, jrd, job_model) + await run_async(_stop_runner, ssh_private_keys, jpd, jrd, job_model) except SSHError: logger.debug("%s: failed to stop runner", fmt(job_model)) @@ -239,16 +234,8 @@ async def process_terminating_job( jpd = get_job_provisioning_data(job_model) if jpd is not None: logger.debug("%s: stopping container", fmt(job_model)) - ssh_private_key = instance_model.project.ssh_private_key - # TODO: Drop this logic and always use project key once it's safe to assume that - # most on-prem fleets are (re)created after this change: - # https://github.com/dstackai/dstack/pull/1716 - if instance_model and instance_model.remote_connection_info is not None: - remote_conn_info: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw( - instance_model.remote_connection_info - ) - ssh_private_key = remote_conn_info.ssh_keys[0].private - await stop_container(job_model, jpd, ssh_private_key) + ssh_private_keys = get_instance_ssh_private_keys(instance_model) + await stop_container(job_model, jpd, ssh_private_keys) volume_models: list[VolumeModel] if jrd is not None and jrd.volume_names is not None: volume_models = await list_project_volume_models( @@ -351,14 +338,16 @@ def _set_job_termination_status(job_model: JobModel): async def stop_container( - job_model: JobModel, job_provisioning_data: JobProvisioningData, ssh_private_key: str + job_model: JobModel, + job_provisioning_data: JobProvisioningData, + ssh_private_keys: tuple[str, Optional[str]], ): if job_provisioning_data.dockerized: # send a request to the shim to terminate the docker container # SSHError and RequestException are caught in the `runner_ssh_tunner` decorator await run_async( _shim_submit_stop, - ssh_private_key, + ssh_private_keys, job_provisioning_data, None, job_model, diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index 220dae558..d21560188 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -30,6 +30,7 @@ InstanceType, RemoteConnectionInfo, Resources, + SSHConnectionParams, SSHKey, ) from dstack._internal.core.models.pools import Instance, Pool, PoolInstances @@ -293,6 +294,27 @@ def get_instance_requirements(instance_model: InstanceModel) -> Requirements: return Requirements.__response__.parse_raw(instance_model.requirements) +def get_instance_ssh_private_keys(instance_model: InstanceModel) -> tuple[str, Optional[str]]: + """ + Returns a pair of SSH private keys: host key and optional proxy jump key. 
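+
+    For cloud instances and for SSH instances without a proxy jump host, the
+    second element is `None`. Raises `ValueError` if a proxy is configured but
+    no proxy private key is available.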
+ """ + host_private_key = instance_model.project.ssh_private_key + if instance_model.remote_connection_info is None: + # Cloud instance + return host_private_key, None + # SSH instance + rci = RemoteConnectionInfo.__response__.parse_raw(instance_model.remote_connection_info) + if rci.ssh_proxy is None: + return host_private_key, None + if rci.ssh_proxy_keys is None: + # Inconsistent RemoteConnectionInfo structure - proxy without keys + raise ValueError("Missing instance SSH proxy private keys") + proxy_private_keys = [key.private for key in rci.ssh_proxy_keys if key.private is not None] + if not proxy_private_keys: + raise ValueError("No instance SSH proxy private key found") + return host_private_key, proxy_private_keys[0] + + async def generate_instance_name( session: AsyncSession, project: ProjectModel, @@ -737,6 +759,8 @@ async def create_ssh_instance_model( port: int, ssh_user: str, ssh_keys: List[SSHKey], + ssh_proxy: Optional[SSHConnectionParams], + ssh_proxy_keys: Optional[list[SSHKey]], env: Env, blocks: Union[Literal["auto"], int], ) -> InstanceModel: @@ -773,6 +797,8 @@ async def create_ssh_instance_model( port=port, ssh_user=ssh_user, ssh_keys=ssh_keys, + ssh_proxy=ssh_proxy, + ssh_proxy_keys=ssh_proxy_keys, env=env, ) im = InstanceModel( diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index b66f81fd1..6d3c11359 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -17,13 +17,18 @@ logger = get_logger(__name__) P = ParamSpec("P") R = TypeVar("R") +# A host private key or pair of (host private key, optional proxy jump private key) +PrivateKeyOrPair = Union[str, tuple[str, Optional[str]]] def runner_ssh_tunnel( ports: List[int], retries: int = 3, retry_interval: float = 1 ) -> Callable[ [Callable[Concatenate[Dict[int, int], P], R]], - Callable[Concatenate[str, JobProvisioningData, Optional[JobRuntimeData], P], Union[bool, R]], + Callable[ + Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P], + Union[bool, R], + ], ]: """ A decorator that opens an SSH tunnel to the runner. 
@@ -36,11 +41,12 @@ def runner_ssh_tunnel( def decorator( func: Callable[Concatenate[Dict[int, int], P], R], ) -> Callable[ - Concatenate[str, JobProvisioningData, Optional[JobRuntimeData], P], Union[bool, R] + Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P], + Union[bool, R], ]: @functools.wraps(func) def wrapper( - ssh_private_key: str, + ssh_private_key: PrivateKeyOrPair, job_provisioning_data: JobProvisioningData, job_runtime_data: Optional[JobRuntimeData], *args: P.args, @@ -59,6 +65,16 @@ def wrapper( # without SSH return func(container_ports_map, *args, **kwargs) + if isinstance(ssh_private_key, str): + ssh_proxy_private_key = None + else: + ssh_private_key, ssh_proxy_private_key = ssh_private_key + identity = FileContent(ssh_private_key) + if ssh_proxy_private_key is not None: + proxy_identity = FileContent(ssh_proxy_private_key) + else: + proxy_identity = None + for attempt in range(retries): last = attempt == retries - 1 # remote_host:local mapping @@ -74,8 +90,9 @@ def wrapper( ), port=job_provisioning_data.ssh_port, forwarded_sockets=ports_to_forwarded_sockets(tunnel_ports_map), - identity=FileContent(ssh_private_key), + identity=identity, ssh_proxy=job_provisioning_data.ssh_proxy, + ssh_proxy_identity=proxy_identity, ): return func(runner_ports_map, *args, **kwargs) except SSHError: diff --git a/src/dstack/_internal/utils/ssh.py b/src/dstack/_internal/utils/ssh.py index 2199c8b8c..492eafe2a 100644 --- a/src/dstack/_internal/utils/ssh.py +++ b/src/dstack/_internal/utils/ssh.py @@ -159,7 +159,7 @@ def get_ssh_config(path: PathLike, host: str) -> Optional[Dict[str, str]]: return None -def update_ssh_config(path: PathLike, host: str, options: Dict[str, Union[str, FilePath]]): +def update_ssh_config(path: PathLike, host: str, options: Dict[str, Union[str, int, FilePath]]): Path(path).parent.mkdir(parents=True, exist_ok=True) with FileLock(str(path) + ".lock"): copy_mode = True diff --git a/src/dstack/api/server/_fleets.py b/src/dstack/api/server/_fleets.py index 2f295f846..822d3f3a2 100644 --- a/src/dstack/api/server/_fleets.py +++ b/src/dstack/api/server/_fleets.py @@ -62,12 +62,20 @@ def _get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[_ExcludeDict]: spec_excludes: _ExcludeDict = {} configuration_excludes: _ExcludeDict = {} profile_excludes: set[str] = set() + ssh_config_excludes: _ExcludeDict = {} ssh_hosts_excludes: set[str] = set() # TODO: Can be removed in 0.19 if fleet_spec.configuration_path is None: spec_excludes["configuration_path"] = True if fleet_spec.configuration.ssh_config is not None: + if fleet_spec.configuration.ssh_config.proxy_jump is None: + ssh_config_excludes["proxy_jump"] = True + if all( + isinstance(h, str) or h.proxy_jump is None + for h in fleet_spec.configuration.ssh_config.hosts + ): + ssh_hosts_excludes.add("proxy_jump") if all( isinstance(h, str) or h.internal_ip is None for h in fleet_spec.configuration.ssh_config.hosts @@ -98,7 +106,9 @@ def _get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[_ExcludeDict]: configuration_excludes["blocks"] = True if ssh_hosts_excludes: - configuration_excludes["ssh_config"] = {"hosts": {"__all__": ssh_hosts_excludes}} + ssh_config_excludes["hosts"] = {"__all__": ssh_hosts_excludes} + if ssh_config_excludes: + configuration_excludes["ssh_config"] = ssh_config_excludes if configuration_excludes: spec_excludes["configuration"] = configuration_excludes if profile_excludes: diff --git a/src/tests/_internal/server/background/tasks/test_process_metrics.py 
b/src/tests/_internal/server/background/tasks/test_process_metrics.py index 7d93a89eb..422270eac 100644 --- a/src/tests/_internal/server/background/tasks/test_process_metrics.py +++ b/src/tests/_internal/server/background/tasks/test_process_metrics.py @@ -6,6 +6,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.runs import JobStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.server import settings @@ -17,8 +18,10 @@ from dstack._internal.server.schemas.runner import GPUMetrics, MetricsResponse from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( + create_instance, create_job, create_job_metrics_point, + create_pool, create_project, create_repo, create_run, @@ -42,6 +45,13 @@ async def test_collects_metrics(self, test_db, session: AsyncSession): session=session, project_id=project.id, ) + pool = await create_pool(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + pool=pool, + status=InstanceStatus.BUSY, + ) run = await create_run( session=session, project=project, @@ -53,6 +63,8 @@ async def test_collects_metrics(self, test_db, session: AsyncSession): run=run, status=JobStatus.RUNNING, job_provisioning_data=get_job_provisioning_data(), + instance_assigned=True, + instance=instance, ) with ( patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py index 571d4fec3..f986e6d94 100644 --- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py @@ -103,6 +103,13 @@ async def test_leaves_provisioning_job_unchanged_if_runner_not_alive( repo=repo, user=user, ) + pool = await create_pool(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + pool=pool, + status=InstanceStatus.BUSY, + ) job_provisioning_data = get_job_provisioning_data(dockerized=False) job = await create_job( session=session, @@ -110,6 +117,8 @@ async def test_leaves_provisioning_job_unchanged_if_runner_not_alive( status=JobStatus.PROVISIONING, submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc), job_provisioning_data=job_provisioning_data, + instance=instance, + instance_assigned=True, ) with ( patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, @@ -144,12 +153,21 @@ async def test_runs_provisioning_job(self, test_db, session: AsyncSession): repo=repo, user=user, ) + pool = await create_pool(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + pool=pool, + status=InstanceStatus.BUSY, + ) job_provisioning_data = get_job_provisioning_data(dockerized=False) job = await create_job( session=session, run=run, status=JobStatus.PROVISIONING, job_provisioning_data=job_provisioning_data, + instance=instance, + instance_assigned=True, ) with ( patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, @@ -186,12 +204,21 @@ async def test_updates_running_job(self, test_db, session: AsyncSession, tmp_pat repo=repo, user=user, ) + pool = await create_pool(session=session, 
project=project) + instance = await create_instance( + session=session, + project=project, + pool=pool, + status=InstanceStatus.BUSY, + ) job_provisioning_data = get_job_provisioning_data(dockerized=False) job = await create_job( session=session, run=run, status=JobStatus.RUNNING, job_provisioning_data=job_provisioning_data, + instance=instance, + instance_assigned=True, ) with ( patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, @@ -277,6 +304,13 @@ async def test_provisioning_shim_with_volumes( run_name="test-run", run_spec=run_spec, ) + pool = await create_pool(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + pool=pool, + status=InstanceStatus.BUSY, + ) job_provisioning_data = get_job_provisioning_data(dockerized=True) with patch( @@ -288,6 +322,8 @@ async def test_provisioning_shim_with_volumes( run=run, status=JobStatus.PROVISIONING, job_provisioning_data=job_provisioning_data, + instance=instance, + instance_assigned=True, ) await process_running_jobs() @@ -337,12 +373,21 @@ async def test_pulling_shim( repo=repo, user=user, ) + pool = await create_pool(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + pool=pool, + status=InstanceStatus.BUSY, + ) job = await create_job( session=session, run=run, status=JobStatus.PULLING, job_provisioning_data=get_job_provisioning_data(dockerized=True), job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), + instance=instance, + instance_assigned=True, ) shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING shim_client_mock.get_task.return_value.ports = [ @@ -385,6 +430,13 @@ async def test_pulling_shim_port_mapping_not_ready( repo=repo, user=user, ) + pool = await create_pool(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + pool=pool, + status=InstanceStatus.BUSY, + ) job_provisioning_data = get_job_provisioning_data(dockerized=True) job = await create_job( session=session, @@ -392,6 +444,8 @@ async def test_pulling_shim_port_mapping_not_ready( status=JobStatus.PULLING, job_provisioning_data=job_provisioning_data, job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), + instance=instance, + instance_assigned=True, ) shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING shim_client_mock.get_task.return_value.ports = None @@ -470,12 +524,21 @@ async def test_provisioning_shim_force_stop_if_already_running_api_v1( run_name="test-run", run_spec=run_spec, ) + pool = await create_pool(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + pool=pool, + status=InstanceStatus.BUSY, + ) job = await create_job( session=session, run=run, status=JobStatus.PROVISIONING, job_provisioning_data=get_job_provisioning_data(dockerized=True), submitted_at=get_current_datetime(), + instance=instance, + instance_assigned=True, ) monkeypatch.setattr( "dstack._internal.server.services.runner.ssh.SSHTunnel", Mock(return_value=MagicMock()) @@ -588,11 +651,20 @@ async def test_inactivity_duration( ), ), ) + pool = await create_pool(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + pool=pool, + status=InstanceStatus.BUSY, + ) job = await create_job( session=session, run=run, status=JobStatus.RUNNING, job_provisioning_data=get_job_provisioning_data(), + instance=instance, + 
instance_assigned=True, ) with ( patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, diff --git a/src/tests/_internal/server/background/tasks/test_process_runs.py b/src/tests/_internal/server/background/tasks/test_process_runs.py index 2cccf3088..bdd65a612 100644 --- a/src/tests/_internal/server/background/tasks/test_process_runs.py +++ b/src/tests/_internal/server/background/tasks/test_process_runs.py @@ -8,6 +8,7 @@ import dstack._internal.server.background.tasks.process_runs as process_runs from dstack._internal.core.models.configurations import ServiceConfiguration +from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.resources import Range from dstack._internal.core.models.runs import ( @@ -116,11 +117,20 @@ async def test_running_to_done(self, test_db, session: AsyncSession): async def test_terminate_run_jobs(self, test_db, session: AsyncSession): run = await make_run(session, status=RunStatus.TERMINATING) run.termination_reason = RunTerminationReason.JOB_FAILED + pool = await create_pool(session=session, project=run.project) + instance = await create_instance( + session=session, + project=run.project, + pool=pool, + status=InstanceStatus.BUSY, + ) job = await create_job( session=session, run=run, job_provisioning_data=get_job_provisioning_data(), status=JobStatus.RUNNING, + instance=instance, + instance_assigned=True, ) with patch("dstack._internal.server.services.jobs._stop_runner") as stop_runner: diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index 81b6f422f..2ad610603 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -450,6 +450,7 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A "port": None, "identity_file": None, "ssh_key": None, # should not return ssh_key + "proxy_jump": None, "hosts": ["1.1.1.1"], "network": None, }, diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 453b60d2a..f15afce57 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -18,6 +18,7 @@ from dstack._internal.core.models.instances import ( InstanceAvailability, InstanceOfferWithAvailability, + InstanceStatus, InstanceType, Resources, ) @@ -47,7 +48,9 @@ create_backend, create_gateway, create_gateway_compute, + create_instance, create_job, + create_pool, create_project, create_repo, create_run, @@ -1315,11 +1318,20 @@ async def test_terminates_running_run( user=user, status=RunStatus.RUNNING, ) + pool = await create_pool(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + pool=pool, + status=InstanceStatus.BUSY, + ) job = await create_job( session=session, run=run, job_provisioning_data=get_job_provisioning_data(), status=JobStatus.RUNNING, + instance=instance, + instance_assigned=True, ) with patch("dstack._internal.server.services.jobs._stop_runner") as stop_runner: response = await client.post(