Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/dstack/_internal/core/models/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,29 @@ class SSHHostParams(CoreModel):
identity_file: Annotated[
Optional[str], Field(description="The private key to use for this host")
] = None
internal_ip: Annotated[
Optional[str],
Field(
description=(
"The internal IP of the host used for communication inside the cluster."
" If not specified, `dstack` will use the IP address from `network` or from the first found internal network."
)
),
] = None
ssh_key: Optional[SSHKey] = None

@validator("internal_ip")
def validate_internal_ip(cls, value):
    """Check that `internal_ip`, when provided, is a well-formed private IP address.

    `None` is passed through untouched. Otherwise the value must be parseable by
    `ipaddress.ip_address` and must report `is_private`; the original string is
    returned unchanged.

    Raises:
        ValueError: if the value is not a valid IP address or is not private.
    """
    if value is not None:
        try:
            parsed = ipaddress.ip_address(value)
        except ValueError as exc:
            # Re-raise with a short message while keeping the parsing error as the cause.
            raise ValueError("Invalid IP address") from exc
        if not parsed.is_private:
            raise ValueError("IP address is not private")
    return value


class SSHParams(CoreModel):
user: Annotated[Optional[str], Field(description="The user to log in with on all hosts")] = (
Expand All @@ -70,7 +91,13 @@ class SSHParams(CoreModel):
]
network: Annotated[
Optional[str],
Field(description="The network address for cluster setup in the format `<ip>/<netmask>`"),
Field(
description=(
"The network address for cluster setup in the format `<ip>/<netmask>`."
" `dstack` will use IP addresses from this network for communication between hosts."
" If not specified, `dstack` will use IPs from the first found internal network."
)
),
]

@validator("network")
Expand Down
29 changes: 24 additions & 5 deletions src/dstack/_internal/server/background/tasks/process_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
from dstack._internal.server.utils.common import run_async
from dstack._internal.utils.common import get_current_datetime
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.network import get_ip_from_network
from dstack._internal.utils.network import get_ip_from_network, is_ip_among_addresses
from dstack._internal.utils.ssh import (
pkey_from_str,
)
Expand Down Expand Up @@ -290,16 +290,20 @@ async def _add_remote(instance: InstanceModel) -> None:

instance_type = host_info_to_instance_type(host_info)
instance_network = None
internal_ip = None
try:
default_jpd = JobProvisioningData.__response__.parse_raw(instance.job_provisioning_data)
instance_network = default_jpd.instance_network
internal_ip = default_jpd.internal_ip
except ValidationError:
pass

internal_ip = get_ip_from_network(
network=instance_network,
addresses=host_info.get("addresses", []),
)
host_network_addresses = host_info.get("addresses", [])
if internal_ip is None:
internal_ip = get_ip_from_network(
network=instance_network,
addresses=host_network_addresses,
)
if instance_network is not None and internal_ip is None:
instance.status = InstanceStatus.TERMINATED
instance.termination_reason = "Failed to locate internal IP address on the given network"
Expand All @@ -312,6 +316,21 @@ async def _add_remote(instance: InstanceModel) -> None:
},
)
return
if internal_ip is not None:
if not is_ip_among_addresses(ip_address=internal_ip, addresses=host_network_addresses):
instance.status = InstanceStatus.TERMINATED
instance.termination_reason = (
"Specified internal IP not found among instance interfaces"
)
logger.warning(
"Failed to add instance %s: specified internal IP not found among instance interfaces",
instance.name,
extra={
"instance_name": instance.name,
"instance_status": InstanceStatus.TERMINATED.value,
},
)
return

region = instance.region
jpd = JobProvisioningData(
Expand Down
15 changes: 15 additions & 0 deletions src/dstack/_internal/server/services/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,11 +402,13 @@ async def create_fleet_ssh_instance_model(
ssh_user = ssh_params.user
ssh_key = ssh_params.ssh_key
port = ssh_params.port
internal_ip = None
else:
hostname = host.hostname
ssh_user = host.user or ssh_params.user
ssh_key = host.ssh_key or ssh_params.ssh_key
port = host.port or ssh_params.port
internal_ip = host.internal_ip

if ssh_user is None or ssh_key is None:
# This should not be reachable but checked by fleet spec validation
Expand All @@ -422,6 +424,7 @@ async def create_fleet_ssh_instance_model(
ssh_user=ssh_user,
ssh_keys=[ssh_key],
env=env,
internal_ip=internal_ip,
instance_network=ssh_params.network,
port=port or 22,
)
Expand Down Expand Up @@ -678,6 +681,7 @@ def _validate_fleet_spec(spec: FleetSpec):
for host in spec.configuration.ssh_config.hosts:
if is_core_model_instance(host, SSHHostParams) and host.ssh_key is not None:
_validate_ssh_key(host.ssh_key)
_validate_internal_ips(spec.configuration.ssh_config)


def _validate_all_ssh_params_specified(ssh_config: SSHParams):
Expand Down Expand Up @@ -706,6 +710,17 @@ def _validate_ssh_key(ssh_key: SSHKey):
)


def _validate_internal_ips(ssh_config: SSHParams):
    """Validate how `internal_ip` is used across the hosts of an SSH fleet config.

    Either no host sets `internal_ip`, or every host does. When hosts set it,
    `ssh_config.network` must not be set as well. Hosts given as plain strings
    never carry an `internal_ip`.

    Raises:
        ServerClientError: if only some hosts set `internal_ip`, or if
            `internal_ip` is combined with `network`.
    """
    hosts_with_ip = sum(
        1
        for host in ssh_config.hosts
        if not isinstance(host, str) and host.internal_ip is not None
    )
    if hosts_with_ip == 0:
        # Nothing to cross-check when no host specifies an internal IP.
        return
    if hosts_with_ip != len(ssh_config.hosts):
        raise ServerClientError("internal_ip must be specified for all hosts")
    if ssh_config.network is not None:
        raise ServerClientError("internal_ip is mutually exclusive with network")


def _get_fleet_nodes_to_provision(spec: FleetSpec) -> int:
if spec.configuration.nodes is None or spec.configuration.nodes.min is None:
return 0
Expand Down
3 changes: 2 additions & 1 deletion src/dstack/_internal/server/services/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,7 @@ async def create_ssh_instance_model(
pool: PoolModel,
instance_name: str,
instance_num: int,
internal_ip: Optional[str],
instance_network: Optional[str],
region: Optional[str],
host: str,
Expand All @@ -676,7 +677,7 @@ async def create_ssh_instance_model(
instance_id=instance_name,
hostname=host,
region=host_region,
internal_ip=None,
internal_ip=internal_ip,
instance_network=instance_network,
price=0,
username=ssh_user,
Expand Down
18 changes: 17 additions & 1 deletion src/dstack/_internal/utils/network.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import ipaddress
from typing import Optional, Sequence
from typing import List, Optional, Sequence


def get_ip_from_network(network: Optional[str], addresses: Sequence[str]) -> Optional[str]:
Expand Down Expand Up @@ -32,3 +32,19 @@ def get_ip_from_network(network: Optional[str], addresses: Sequence[str]) -> Opt
# return any ipv4
internal_ip = str(ip_addresses[0]) if ip_addresses else None
return internal_ip


def is_ip_among_addresses(ip_address: str, addresses: Sequence[str]) -> bool:
    """Return whether `ip_address` matches one of the IPs found in `addresses`.

    The IPs are obtained via `get_ips_from_addresses` and compared to
    `ip_address` as exact strings.
    """
    return ip_address in get_ips_from_addresses(addresses)


def get_ips_from_addresses(addresses: Sequence[str]) -> List[str]:
    """Extract the IPv4 host addresses from a list of interface addresses.

    Each entry is parsed with `ipaddress.IPv4Interface`, so both `<ip>` and
    `<ip>/<prefix>` forms are accepted. Entries that cannot be parsed as an
    IPv4 interface (malformed strings, IPv6 addresses, invalid prefixes) are
    skipped.

    Args:
        addresses: Interface addresses, e.g. `["10.0.0.1/24", "192.168.1.5"]`.

    Returns:
        The host IPs as strings, in the order of the parsable entries.
    """
    ip_addresses = []
    for address in addresses:
        try:
            interface = ipaddress.IPv4Interface(address)
        except ValueError:
            # Both AddressValueError (bad address) and NetmaskValueError (bad
            # prefix) derive from ValueError; neither should abort the scan.
            continue
        ip_addresses.append(str(interface.ip))
    return ip_addresses
30 changes: 19 additions & 11 deletions src/dstack/api/server/_fleets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional

from pydantic import parse_obj_as

Expand Down Expand Up @@ -29,11 +29,7 @@ def get_plan(
spec: FleetSpec,
) -> FleetPlan:
body = GetFleetPlanRequest(spec=spec)
body_json = body.json()
if spec.configuration_path is None:
# Handle old server versions that do not accept configuration_path
# TODO: Can be removed in 0.19
body_json = body.json(exclude={"spec": {"configuration_path"}})
body_json = body.json(exclude=_get_fleet_spec_excludes(spec))
resp = self._request(f"/api/project/{project_name}/fleets/get_plan", body=body_json)
return parse_obj_as(FleetPlan.__response__, resp.json())

Expand All @@ -43,11 +39,7 @@ def create(
spec: FleetSpec,
) -> Fleet:
body = CreateFleetRequest(spec=spec)
body_json = body.json()
if spec.configuration_path is None:
# Handle old server versions that do not accept configuration_path
# TODO: Can be removed in 0.19
body_json = body.json(exclude={"spec": {"configuration_path"}})
body_json = body.json(exclude=_get_fleet_spec_excludes(spec))
resp = self._request(f"/api/project/{project_name}/fleets/create", body=body_json)
return parse_obj_as(Fleet.__response__, resp.json())

Expand All @@ -58,3 +50,19 @@ def delete(self, project_name: str, names: List[str]) -> None:
def delete_instances(self, project_name: str, name: str, instance_nums: List[int]) -> None:
body = DeleteFleetInstancesRequest(name=name, instance_nums=instance_nums)
self._request(f"/api/project/{project_name}/fleets/delete_instances", body=body.json())


def _get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[dict]:
    """Build the `exclude` argument used when serializing a fleet request body.

    Fields that are unset are dropped from the request so that older servers
    which do not accept them keep working:

    * `spec.configuration_path` is excluded when it is `None`.
    * `internal_ip` is excluded from every SSH host when no host sets it.

    Both exclusions are collected under a single `spec` entry so that one does
    not overwrite the other.

    Returns:
        A nested exclude mapping, or `None` when nothing has to be excluded.
    """
    spec_excludes: dict = {}
    # TODO: Can be removed in 0.19
    if fleet_spec.configuration_path is None:
        # `True` marks the whole field as excluded inside a nested exclude mapping.
        spec_excludes["configuration_path"] = True
    ssh_config = fleet_spec.configuration.ssh_config
    if ssh_config is not None and all(
        isinstance(h, str) or h.internal_ip is None for h in ssh_config.hosts
    ):
        spec_excludes["configuration"] = {
            "ssh_config": {"hosts": {"__all__": {"internal_ip"}}}
        }
    if not spec_excludes:
        return None
    return {"spec": spec_excludes}