41 changes: 41 additions & 0 deletions docs/docs/concepts/fleets.md
@@ -300,6 +300,47 @@ ssh_config:

#### Head node

If fleet nodes are behind a head node, configure [`proxy_jump`](../reference/dstack.yml/fleet.md#proxy_jump):

<div editor-title="examples/misc/fleets/.dstack.yml">

```yaml
type: fleet
name: my-fleet

ssh_config:
user: ubuntu
identity_file: ~/.ssh/worker_node_key
hosts:
- 3.255.177.51
- 3.255.177.52
proxy_jump:
hostname: 3.255.177.50
user: ubuntu
identity_file: ~/.ssh/head_node_key
```

</div>

To be able to attach to runs, both explicitly with `dstack attach` and implicitly with `dstack apply`, you must either
add the head node key (`~/.ssh/head_node_key`) to an SSH agent (e.g., `ssh-add ~/.ssh/head_node_key`) or configure a key path in `~/.ssh/config`:

Comment on lines +327 to +328

**Collaborator:** I don't understand. Why do we require the user to configure a key path in `~/.ssh/config` if the user specifies `identity_file: ~/.ssh/head_node_key`? Can't we just use it to connect to the head node?

**Collaborator (author):** These can be two different users: the one creating the fleet has the key, but other users who run workloads may not. With the regular setup, the key problem is solved via shim -- shim keeps the user's key on the instance while a run is running -- but the head node has no dstack agent to manage authorized keys, so we require each user to have the head node key preconfigured on their machine. We could download the key to a user's machine, but I'm not sure that's a good idea.

**Collaborator:** I'd put this comment somewhere in the code.


<div editor-title="~/.ssh/config">

```
Host 3.255.177.50
IdentityFile ~/.ssh/head_node_key
```

</div>

where `Host` must match `ssh_config.proxy_jump.hostname`, or `ssh_config.hosts[n].proxy_jump.hostname` if you configure
head nodes on a per-worker basis.
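
If different workers sit behind different head nodes, `proxy_jump` can instead be set per host. A minimal sketch (the addresses and key paths are illustrative):

<div editor-title="examples/misc/fleets/.dstack.yml">

```yaml
type: fleet
name: my-fleet

ssh_config:
  user: ubuntu
  identity_file: ~/.ssh/worker_node_key
  hosts:
    # This worker is reached through its own head node
    - hostname: 3.255.177.51
      proxy_jump:
        hostname: 3.255.177.50
        user: ubuntu
        identity_file: ~/.ssh/head_node_key
    # This worker is connected to directly
    - 3.255.177.52
```

</div>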

> Currently, [services](services.md) do not work on instances behind a head node.

!!! info "Reference"
For all SSH fleet configuration options, refer to the [reference](../reference/dstack.yml/fleet.md).

14 changes: 14 additions & 0 deletions docs/docs/reference/dstack.yml/fleet.md

```diff
@@ -17,12 +17,26 @@ The `fleet` configuration type allows creating and updating fleets.
         show_root_heading: false
         item_id_prefix: ssh_config-
 
+#### `ssh_config.proxy_jump` { #ssh_config-proxy_jump data-toc-label="proxy_jump" }
+
+#SCHEMA# dstack._internal.core.models.fleets.SSHProxyParams
+    overrides:
+        show_root_heading: false
+        item_id_prefix: proxy_jump-
+
 #### `ssh_config.hosts[n]` { #ssh_config-hosts data-toc-label="hosts" }
 
 #SCHEMA# dstack._internal.core.models.fleets.SSHHostParams
     overrides:
         show_root_heading: false
 
+##### `ssh_config.hosts[n].proxy_jump` { #proxy_jump data-toc-label="hosts[n].proxy_jump" }
+
+#SCHEMA# dstack._internal.core.models.fleets.SSHProxyParams
+    overrides:
+        show_root_heading: false
+        item_id_prefix: hosts-proxy_jump-
+
 ### `resources`
 
 #SCHEMA# dstack._internal.core.models.resources.ResourcesSpecSchema
```

13 changes: 8 additions & 5 deletions src/dstack/_internal/cli/services/configurators/fleet.py

```diff
@@ -201,13 +201,16 @@ def apply_args(self, conf: FleetConfiguration, args: argparse.Namespace, unknown
 
 
 def _preprocess_spec(spec: FleetSpec):
-    if spec.configuration.ssh_config is not None:
-        spec.configuration.ssh_config.ssh_key = _resolve_ssh_key(
-            spec.configuration.ssh_config.identity_file
-        )
-        for host in spec.configuration.ssh_config.hosts:
+    ssh_config = spec.configuration.ssh_config
+    if ssh_config is not None:
+        ssh_config.ssh_key = _resolve_ssh_key(ssh_config.identity_file)
+        if ssh_config.proxy_jump is not None:
+            ssh_config.proxy_jump.ssh_key = _resolve_ssh_key(ssh_config.proxy_jump.identity_file)
+        for host in ssh_config.hosts:
             if not isinstance(host, str):
                 host.ssh_key = _resolve_ssh_key(host.identity_file)
+                if host.proxy_jump is not None:
+                    host.proxy_jump.ssh_key = _resolve_ssh_key(host.proxy_jump.identity_file)
 
 
 def _resolve_ssh_key(ssh_key_path: Optional[str]) -> Optional[SSHKey]:
```

92 changes: 65 additions & 27 deletions src/dstack/_internal/core/backends/remote/provisioning.py

```diff
@@ -1,9 +1,9 @@
 import io
 import json
 import time
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from textwrap import dedent
-from typing import Any, Dict, Generator, List
+from typing import Any, Dict, Generator, List, Optional
 
 import paramiko
 from gpuhunt import AcceleratorVendor, correct_gpu_memory_gib
@@ -17,6 +17,7 @@
     Gpu,
     InstanceType,
     Resources,
+    SSHConnectionParams,
 )
 from dstack._internal.utils.gpu import (
     convert_amd_gpu_name,
@@ -262,35 +263,72 @@ def host_info_to_instance_type(host_info: Dict[str, Any]) -> InstanceType:
 
 @contextmanager
 def get_paramiko_connection(
-    ssh_user: str, host: str, port: int, pkeys: List[paramiko.PKey]
+    ssh_user: str,
+    host: str,
+    port: int,
+    pkeys: List[paramiko.PKey],
+    proxy: Optional[SSHConnectionParams] = None,
+    proxy_pkeys: Optional[list[paramiko.PKey]] = None,
 ) -> Generator[paramiko.SSHClient, None, None]:
-    with paramiko.SSHClient() as client:
-        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-        for pkey in pkeys:
-            conn_url = f"{ssh_user}@{host}:{port}"
-            try:
-                logger.debug("Try to connect to %s with key %s", conn_url, pkey.fingerprint)
-                client.connect(
-                    username=ssh_user,
-                    hostname=host,
-                    port=port,
-                    pkey=pkey,
-                    look_for_keys=False,
-                    allow_agent=False,
-                    timeout=SSH_CONNECT_TIMEOUT,
-                )
-            except paramiko.AuthenticationException:
-                logger.debug(
-                    f'Authentication failed to connect to "{conn_url}" and {pkey.fingerprint}'
-                )
-                continue  # try next key
-            except (paramiko.SSHException, OSError) as e:
-                raise ProvisioningError(f"Connect failed: {e}") from e
-            else:
-                yield client
-                return
-        else:
-            keys_fp = ", ".join(f"{pk.fingerprint!r}" for pk in pkeys)
-            raise ProvisioningError(
-                f"SSH connection to the {conn_url} with keys [{keys_fp}] was unsuccessful"
-            )
+    if proxy is not None:
+        if proxy_pkeys is None:
+            raise ProvisioningError("Missing proxy private keys")
+        proxy_ctx = get_paramiko_connection(
+            proxy.username, proxy.hostname, proxy.port, proxy_pkeys
+        )
+    else:
+        proxy_ctx = nullcontext()
+    conn_url = f"{ssh_user}@{host}:{port}"
+    with proxy_ctx as proxy_client, paramiko.SSHClient() as client:
+        proxy_channel: Optional[paramiko.Channel] = None
+        if proxy_client is not None:
+            try:
+                proxy_channel = proxy_client.get_transport().open_channel(
+                    "direct-tcpip", (host, port), ("", 0)
+                )
+            except (paramiko.SSHException, OSError) as e:
+                raise ProvisioningError(f"Proxy channel failed: {e}") from e
+        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+        for pkey in pkeys:
+            logger.debug("Try to connect to %s with key %s", conn_url, pkey.fingerprint)
+            connected = _paramiko_connect(client, ssh_user, host, port, pkey, proxy_channel)
+            if connected:
+                yield client
+                return
+            logger.debug(
+                f'Authentication failed to connect to "{conn_url}" and {pkey.fingerprint}'
+            )
+        keys_fp = ", ".join(f"{pk.fingerprint!r}" for pk in pkeys)
+        raise ProvisioningError(
+            f"SSH connection to the {conn_url} with keys [{keys_fp}] was unsuccessful"
+        )
+
+
+def _paramiko_connect(
+    client: paramiko.SSHClient,
+    user: str,
+    host: str,
+    port: int,
+    pkey: paramiko.PKey,
+    channel: Optional[paramiko.Channel] = None,
+) -> bool:
+    """
+    Returns `True` if connected, `False` if auth failed, and raises `ProvisioningError`
+    on other errors.
+    """
+    try:
+        client.connect(
+            username=user,
+            hostname=host,
+            port=port,
+            pkey=pkey,
+            look_for_keys=False,
+            allow_agent=False,
+            timeout=SSH_CONNECT_TIMEOUT,
+            sock=channel,
+        )
+        return True
+    except paramiko.AuthenticationException:
+        return False
+    except (paramiko.SSHException, OSError) as e:
+        raise ProvisioningError(f"Connect failed: {e}") from e
```
14 changes: 14 additions & 0 deletions src/dstack/_internal/core/models/fleets.py

```diff
@@ -39,6 +39,14 @@ class InstanceGroupPlacement(str, Enum):
     CLUSTER = "cluster"
 
 
+class SSHProxyParams(CoreModel):
+    hostname: Annotated[str, Field(description="The IP address or domain of proxy host")]
+    port: Annotated[Optional[int], Field(description="The SSH port of proxy host")] = None
+    user: Annotated[str, Field(description="The user to log in with for proxy host")]
+    identity_file: Annotated[str, Field(description="The private key to use for proxy host")]
+    ssh_key: Optional[SSHKey] = None
+
+
 class SSHHostParams(CoreModel):
     hostname: Annotated[str, Field(description="The IP address or domain to connect to")]
     port: Annotated[
@@ -50,6 +58,9 @@ class SSHHostParams(CoreModel):
     identity_file: Annotated[
         Optional[str], Field(description="The private key to use for this host")
     ] = None
+    proxy_jump: Annotated[
+        Optional[SSHProxyParams], Field(description="The SSH proxy configuration for this host")
+    ] = None
     internal_ip: Annotated[
         Optional[str],
         Field(
@@ -96,6 +107,9 @@ class SSHParams(CoreModel):
         Optional[str], Field(description="The private key to use for all hosts")
     ] = None
     ssh_key: Optional[SSHKey] = None
+    proxy_jump: Annotated[
+        Optional[SSHProxyParams], Field(description="The SSH proxy configuration for all hosts")
+    ] = None
     hosts: Annotated[
         List[Union[SSHHostParams, str]],
         Field(
```
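
A sketch of how the new models compose, mirroring the YAML examples from the docs above. The `user` field on `SSHParams` is assumed here (it appears in the fleet YAML but not in this hunk), and all values are illustrative:

```python
from dstack._internal.core.models.fleets import SSHHostParams, SSHParams, SSHProxyParams

# A head node definition, reusable as a fleet-wide default or per host.
head = SSHProxyParams(
    hostname="3.255.177.50",
    user="ubuntu",
    identity_file="~/.ssh/head_node_key",
)

ssh_config = SSHParams(
    user="ubuntu",  # assumed field; shown in the docs YAML, not in this hunk
    identity_file="~/.ssh/worker_node_key",
    proxy_jump=head,  # fleet-wide default
    hosts=[
        "3.255.177.51",  # plain string host, presumably using the fleet-wide proxy_jump
        SSHHostParams(hostname="3.255.177.52", proxy_jump=head),  # per-host override
    ],
)
```
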
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/models/instances.py

```diff
@@ -92,6 +92,8 @@ class RemoteConnectionInfo(CoreModel):
     port: int
     ssh_user: str
     ssh_keys: List[SSHKey]
+    ssh_proxy: Optional[SSHConnectionParams] = None
+    ssh_proxy_keys: Optional[list[SSHKey]] = None
     env: Env = Env()
 
 
```