Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/dstack/_internal/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ async def lifespan(app: FastAPI):
logger.info("Background processing is disabled")
PROBES_SCHEDULER.start()
dstack_version = DSTACK_VERSION if DSTACK_VERSION else "(no version)"
logger.info(
"Job network mode: %s (%d)",
settings.JOB_NETWORK_MODE.name,
settings.JOB_NETWORK_MODE.value,
)
logger.info(f"The admin token is {admin.token.get_plaintext_or_error()}", {"show_path": False})
logger.info(
f"The dstack server {dstack_version} is running at {SERVER_URL}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@
)
from dstack._internal.server.utils import sentry_utils
from dstack._internal.utils import common as common_utils
from dstack._internal.utils import env as env_utils
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -188,6 +187,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
run_spec = run.run_spec
profile = run_spec.merged_profile
job = find_job(run.jobs, job_model.replica_num, job_model.job_num)
multinode = job.job_spec.jobs_per_replica > 1

# Master job chooses fleet for the run.
# Due to two-step processing, it's saved to job_model.fleet.
Expand Down Expand Up @@ -310,6 +310,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
session=session,
instances_with_offers=fleet_instances_with_offers,
job_model=job_model,
multinode=multinode,
)
job_model.fleet = fleet_model
job_model.instance_assigned = True
Expand Down Expand Up @@ -385,7 +386,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
offer=offer,
instance_num=instance_num,
)
job_model.job_runtime_data = _prepare_job_runtime_data(offer).json()
job_model.job_runtime_data = _prepare_job_runtime_data(offer, multinode).json()
# Both this task and process_fleets can add instances to fleets.
# TODO: Ensure this does not violate nodes.max when it's enforced.
instance.fleet_id = fleet_model.id
Expand Down Expand Up @@ -614,6 +615,7 @@ async def _assign_job_to_fleet_instance(
session: AsyncSession,
instances_with_offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]],
job_model: JobModel,
multinode: bool,
) -> Optional[InstanceModel]:
if len(instances_with_offers) == 0:
return None
Expand Down Expand Up @@ -643,7 +645,7 @@ async def _assign_job_to_fleet_instance(
job_model.instance = instance
job_model.used_instance_id = instance.id
job_model.job_provisioning_data = instance.job_provisioning_data
job_model.job_runtime_data = _prepare_job_runtime_data(offer).json()
job_model.job_runtime_data = _prepare_job_runtime_data(offer, multinode).json()
return instance


Expand Down Expand Up @@ -839,12 +841,17 @@ def _create_instance_model_for_job(
return instance


def _prepare_job_runtime_data(offer: InstanceOfferWithAvailability) -> JobRuntimeData:
def _prepare_job_runtime_data(
offer: InstanceOfferWithAvailability, multinode: bool
) -> JobRuntimeData:
if offer.blocks == offer.total_blocks:
if env_utils.get_bool("DSTACK_FORCE_BRIDGE_NETWORK"):
if settings.JOB_NETWORK_MODE == settings.JobNetworkMode.FORCED_BRIDGE:
network_mode = NetworkMode.BRIDGE
else:
elif settings.JOB_NETWORK_MODE == settings.JobNetworkMode.HOST_WHEN_POSSIBLE:
network_mode = NetworkMode.HOST
else:
assert settings.JOB_NETWORK_MODE == settings.JobNetworkMode.HOST_FOR_MULTINODE_ONLY
network_mode = NetworkMode.HOST if multinode else NetworkMode.BRIDGE
return JobRuntimeData(
network_mode=network_mode,
offer=offer,
Expand Down
46 changes: 46 additions & 0 deletions src/dstack/_internal/server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@

import os
import warnings
from enum import Enum
from pathlib import Path

from dstack._internal.utils.env import environ
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)

DSTACK_DIR_PATH = Path("~/.dstack/").expanduser()

SERVER_DIR_PATH = Path(os.getenv("DSTACK_SERVER_DIR", DSTACK_DIR_PATH / "server"))
Expand Down Expand Up @@ -136,3 +142,43 @@
DO_NOT_UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT") is not None
SKIP_GATEWAY_UPDATE = os.getenv("DSTACK_SKIP_GATEWAY_UPDATE") is not None
ENABLE_PROMETHEUS_METRICS = os.getenv("DSTACK_ENABLE_PROMETHEUS_METRICS") is not None


class JobNetworkMode(Enum):
# "host" for multinode runs only, "bridge" otherwise. Opt-in new defaut
HOST_FOR_MULTINODE_ONLY = 1
# "bridge" if the job occupies only a part of the instance, "host" otherswise. Current default
HOST_WHEN_POSSIBLE = 2
# Always "bridge", even for multinode runs. Same as legacy DSTACK_FORCE_BRIDGE_NETWORK=true
FORCED_BRIDGE = 3


def _get_job_network_mode() -> JobNetworkMode:
# Current default
mode = JobNetworkMode.HOST_WHEN_POSSIBLE
bridge_var = "DSTACK_FORCE_BRIDGE_NETWORK"
force_bridge = environ.get_bool(bridge_var)
mode_var = "DSTACK_SERVER_JOB_NETWORK_MODE"
mode_from_env = environ.get_enum(mode_var, JobNetworkMode, value_type=int)
if mode_from_env is not None:
if force_bridge is not None:
logger.warning(
f"{bridge_var} is deprecated since 0.19.27 and ignored when {mode_var} is set"
)
return mode_from_env
if force_bridge is not None:
if force_bridge:
mode = JobNetworkMode.FORCED_BRIDGE
logger.warning(
(
f"{bridge_var} is deprecated since 0.19.27."
f" Set {mode_var} to {mode.value} and remove {bridge_var}"
)
)
else:
logger.warning(f"{bridge_var} is deprecated since 0.19.27. Remove {bridge_var}")
return mode


JOB_NETWORK_MODE = _get_job_network_mode()
del _get_job_network_mode
96 changes: 85 additions & 11 deletions src/dstack/_internal/utils/env.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,88 @@
import os
from collections.abc import Mapping
from enum import Enum
from typing import Optional, TypeVar, Union, overload

_Value = Union[str, int]
_T = TypeVar("_T", bound=Enum)

def get_bool(name: str, default: bool = False) -> bool:
try:
value = os.environ[name]
except KeyError:
return default
value = value.lower()
if value in ["0", "false", "off"]:
return False
if value in ["1", "true", "on"]:
return True
raise ValueError(f"Invalid bool value: {name}={value}")

class Environ:
def __init__(self, environ: Mapping[str, str]):
self._environ = environ

@overload
def get_bool(self, name: str, *, default: None = None) -> Optional[bool]: ...

@overload
def get_bool(self, name: str, *, default: bool) -> bool: ...

def get_bool(self, name: str, *, default: Optional[bool] = None) -> Optional[bool]:
try:
raw_value = self._environ[name]
except KeyError:
return default
value = raw_value.lower()
if value in ["0", "false", "off"]:
return False
if value in ["1", "true", "on"]:
return True
raise ValueError(f"Invalid bool value: {name}={raw_value}")

@overload
def get_int(self, name: str, *, default: None = None) -> Optional[int]: ...

@overload
def get_int(self, name: str, *, default: int) -> int: ...

def get_int(self, name: str, *, default: Optional[int] = None) -> Optional[int]:
try:
raw_value = self._environ[name]
except KeyError:
return default
try:
return int(raw_value)
except ValueError as e:
raise ValueError(f"Invalid int value: {e}: {name}={raw_value}") from e

@overload
def get_enum(
self,
name: str,
enum_cls: type[_T],
*,
value_type: Optional[type[_Value]] = None,
default: None = None,
) -> Optional[_T]: ...

@overload
def get_enum(
self,
name: str,
enum_cls: type[_T],
*,
value_type: Optional[type[_Value]] = None,
default: _T,
) -> _T: ...

def get_enum(
self,
name: str,
enum_cls: type[_T],
*,
value_type: Optional[type[_Value]] = None,
default: Optional[_T] = None,
) -> Optional[_T]:
try:
raw_value = self._environ[name]
except KeyError:
return default
try:
if value_type is not None:
raw_value = value_type(raw_value)
return enum_cls(raw_value)
except (ValueError, TypeError) as e:
raise ValueError(f"Invalid {enum_cls.__name__} value: {e}: {name}={raw_value}") from e


environ = Environ(os.environ)
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sqlalchemy.orm import joinedload

from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import NetworkMode
from dstack._internal.core.models.configurations import TaskConfiguration
from dstack._internal.core.models.fleets import FleetNodesSpec
from dstack._internal.core.models.health import HealthStatus
Expand All @@ -25,8 +26,12 @@
VolumeMountPoint,
VolumeStatus,
)
from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs
from dstack._internal.server.background.tasks.process_submitted_jobs import (
_prepare_job_runtime_data,
process_submitted_jobs,
)
from dstack._internal.server.models import InstanceModel, JobModel, VolumeAttachmentModel
from dstack._internal.server.settings import JobNetworkMode
from dstack._internal.server.testing.common import (
ComputeMockSpec,
create_fleet,
Expand Down Expand Up @@ -1004,3 +1009,102 @@ async def test_picks_high_priority_jobs_first(self, test_db, session: AsyncSessi
await process_submitted_jobs()
await session.refresh(job2)
assert job2.status == JobStatus.PROVISIONING


@pytest.mark.parametrize(
["job_network_mode", "blocks", "multinode", "network_mode", "constraints_are_set"],
[
pytest.param(
JobNetworkMode.HOST_FOR_MULTINODE_ONLY,
2,
False,
NetworkMode.BRIDGE,
True,
id="host-for-multinode-only--half-of-instance",
),
pytest.param(
JobNetworkMode.HOST_FOR_MULTINODE_ONLY,
4,
False,
NetworkMode.BRIDGE,
False,
id="host-for-multinode-only--entire-instance",
),
pytest.param(
JobNetworkMode.HOST_FOR_MULTINODE_ONLY,
4,
True,
NetworkMode.HOST,
False,
id="host-for-multinode-only--entire-instance--multinode",
),
pytest.param(
JobNetworkMode.HOST_WHEN_POSSIBLE,
2,
False,
NetworkMode.BRIDGE,
True,
id="host-when-possible--half-of-instance",
),
pytest.param(
JobNetworkMode.HOST_WHEN_POSSIBLE,
4,
False,
NetworkMode.HOST,
False,
id="host-when-possible--entire-instance",
),
pytest.param(
JobNetworkMode.HOST_WHEN_POSSIBLE,
4,
True,
NetworkMode.HOST,
False,
id="host-when-possible--entire-instance--multinode",
),
pytest.param(
JobNetworkMode.FORCED_BRIDGE,
2,
False,
NetworkMode.BRIDGE,
True,
id="forced-bridge--half-of-instance",
),
pytest.param(
JobNetworkMode.FORCED_BRIDGE,
4,
False,
NetworkMode.BRIDGE,
False,
id="forced-bridge--entire-instance",
),
pytest.param(
JobNetworkMode.FORCED_BRIDGE,
4,
True,
NetworkMode.BRIDGE,
False,
id="forced-bridge--entire-instance--multinode",
),
],
)
def test_prepare_job_runtime_data(
monkeypatch: pytest.MonkeyPatch,
job_network_mode: JobNetworkMode,
blocks: int,
multinode: bool,
network_mode: NetworkMode,
constraints_are_set: bool,
):
monkeypatch.setattr("dstack._internal.server.settings.JOB_NETWORK_MODE", job_network_mode)
offer = get_instance_offer_with_availability(blocks=blocks, total_blocks=4)
jrd = _prepare_job_runtime_data(offer=offer, multinode=multinode)
assert jrd.network_mode == network_mode
if constraints_are_set:
assert jrd.gpu is not None
assert jrd.cpu is not None
assert jrd.memory is not None
else:
assert jrd.gpu is None
assert jrd.cpu is None
assert jrd.memory is None
Loading
Loading