diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index eda62ba08..747aef25d 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -160,6 +160,11 @@ async def lifespan(app: FastAPI): logger.info("Background processing is disabled") PROBES_SCHEDULER.start() dstack_version = DSTACK_VERSION if DSTACK_VERSION else "(no version)" + logger.info( + "Job network mode: %s (%d)", + settings.JOB_NETWORK_MODE.name, + settings.JOB_NETWORK_MODE.value, + ) logger.info(f"The admin token is {admin.token.get_plaintext_or_error()}", {"show_path": False}) logger.info( f"The dstack server {dstack_version} is running at {SERVER_URL}", diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 21b699a0b..f5a087edf 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -87,7 +87,6 @@ ) from dstack._internal.server.utils import sentry_utils from dstack._internal.utils import common as common_utils -from dstack._internal.utils import env as env_utils from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -188,6 +187,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): run_spec = run.run_spec profile = run_spec.merged_profile job = find_job(run.jobs, job_model.replica_num, job_model.job_num) + multinode = job.job_spec.jobs_per_replica > 1 # Master job chooses fleet for the run. # Due to two-step processing, it's saved to job_model.fleet. @@ -310,6 +310,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): session=session, instances_with_offers=fleet_instances_with_offers, job_model=job_model, + multinode=multinode, ) job_model.fleet = fleet_model job_model.instance_assigned = True @@ -385,7 +386,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): offer=offer, instance_num=instance_num, ) - job_model.job_runtime_data = _prepare_job_runtime_data(offer).json() + job_model.job_runtime_data = _prepare_job_runtime_data(offer, multinode).json() # Both this task and process_fleets can add instances to fleets. # TODO: Ensure this does not violate nodes.max when it's enforced. instance.fleet_id = fleet_model.id @@ -614,6 +615,7 @@ async def _assign_job_to_fleet_instance( session: AsyncSession, instances_with_offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]], job_model: JobModel, + multinode: bool, ) -> Optional[InstanceModel]: if len(instances_with_offers) == 0: return None @@ -643,7 +645,7 @@ async def _assign_job_to_fleet_instance( job_model.instance = instance job_model.used_instance_id = instance.id job_model.job_provisioning_data = instance.job_provisioning_data - job_model.job_runtime_data = _prepare_job_runtime_data(offer).json() + job_model.job_runtime_data = _prepare_job_runtime_data(offer, multinode).json() return instance @@ -839,12 +841,17 @@ def _create_instance_model_for_job( return instance -def _prepare_job_runtime_data(offer: InstanceOfferWithAvailability) -> JobRuntimeData: +def _prepare_job_runtime_data( + offer: InstanceOfferWithAvailability, multinode: bool +) -> JobRuntimeData: if offer.blocks == offer.total_blocks: - if env_utils.get_bool("DSTACK_FORCE_BRIDGE_NETWORK"): + if settings.JOB_NETWORK_MODE == settings.JobNetworkMode.FORCED_BRIDGE: network_mode = NetworkMode.BRIDGE - else: + elif settings.JOB_NETWORK_MODE == settings.JobNetworkMode.HOST_WHEN_POSSIBLE: network_mode = NetworkMode.HOST + else: + assert settings.JOB_NETWORK_MODE == settings.JobNetworkMode.HOST_FOR_MULTINODE_ONLY + network_mode = NetworkMode.HOST if multinode else NetworkMode.BRIDGE return JobRuntimeData( network_mode=network_mode, offer=offer, diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index ac4924d9c..cbeb74903 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -4,8 +4,14 @@ import os import warnings +from enum import Enum from pathlib import Path +from dstack._internal.utils.env import environ +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + DSTACK_DIR_PATH = Path("~/.dstack/").expanduser() SERVER_DIR_PATH = Path(os.getenv("DSTACK_SERVER_DIR", DSTACK_DIR_PATH / "server")) @@ -136,3 +142,43 @@ DO_NOT_UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT") is not None SKIP_GATEWAY_UPDATE = os.getenv("DSTACK_SKIP_GATEWAY_UPDATE") is not None ENABLE_PROMETHEUS_METRICS = os.getenv("DSTACK_ENABLE_PROMETHEUS_METRICS") is not None + + +class JobNetworkMode(Enum): + # "host" for multinode runs only, "bridge" otherwise. Opt-in new defaut + HOST_FOR_MULTINODE_ONLY = 1 + # "bridge" if the job occupies only a part of the instance, "host" otherswise. Current default + HOST_WHEN_POSSIBLE = 2 + # Always "bridge", even for multinode runs. Same as legacy DSTACK_FORCE_BRIDGE_NETWORK=true + FORCED_BRIDGE = 3 + + +def _get_job_network_mode() -> JobNetworkMode: + # Current default + mode = JobNetworkMode.HOST_WHEN_POSSIBLE + bridge_var = "DSTACK_FORCE_BRIDGE_NETWORK" + force_bridge = environ.get_bool(bridge_var) + mode_var = "DSTACK_SERVER_JOB_NETWORK_MODE" + mode_from_env = environ.get_enum(mode_var, JobNetworkMode, value_type=int) + if mode_from_env is not None: + if force_bridge is not None: + logger.warning( + f"{bridge_var} is deprecated since 0.19.27 and ignored when {mode_var} is set" + ) + return mode_from_env + if force_bridge is not None: + if force_bridge: + mode = JobNetworkMode.FORCED_BRIDGE + logger.warning( + ( + f"{bridge_var} is deprecated since 0.19.27." + f" Set {mode_var} to {mode.value} and remove {bridge_var}" + ) + ) + else: + logger.warning(f"{bridge_var} is deprecated since 0.19.27. Remove {bridge_var}") + return mode + + +JOB_NETWORK_MODE = _get_job_network_mode() +del _get_job_network_mode diff --git a/src/dstack/_internal/utils/env.py b/src/dstack/_internal/utils/env.py index 28152c1e5..53eb4c61d 100644 --- a/src/dstack/_internal/utils/env.py +++ b/src/dstack/_internal/utils/env.py @@ -1,14 +1,88 @@ import os +from collections.abc import Mapping +from enum import Enum +from typing import Optional, TypeVar, Union, overload +_Value = Union[str, int] +_T = TypeVar("_T", bound=Enum) -def get_bool(name: str, default: bool = False) -> bool: - try: - value = os.environ[name] - except KeyError: - return default - value = value.lower() - if value in ["0", "false", "off"]: - return False - if value in ["1", "true", "on"]: - return True - raise ValueError(f"Invalid bool value: {name}={value}") + +class Environ: + def __init__(self, environ: Mapping[str, str]): + self._environ = environ + + @overload + def get_bool(self, name: str, *, default: None = None) -> Optional[bool]: ... + + @overload + def get_bool(self, name: str, *, default: bool) -> bool: ... + + def get_bool(self, name: str, *, default: Optional[bool] = None) -> Optional[bool]: + try: + raw_value = self._environ[name] + except KeyError: + return default + value = raw_value.lower() + if value in ["0", "false", "off"]: + return False + if value in ["1", "true", "on"]: + return True + raise ValueError(f"Invalid bool value: {name}={raw_value}") + + @overload + def get_int(self, name: str, *, default: None = None) -> Optional[int]: ... + + @overload + def get_int(self, name: str, *, default: int) -> int: ... + + def get_int(self, name: str, *, default: Optional[int] = None) -> Optional[int]: + try: + raw_value = self._environ[name] + except KeyError: + return default + try: + return int(raw_value) + except ValueError as e: + raise ValueError(f"Invalid int value: {e}: {name}={raw_value}") from e + + @overload + def get_enum( + self, + name: str, + enum_cls: type[_T], + *, + value_type: Optional[type[_Value]] = None, + default: None = None, + ) -> Optional[_T]: ... + + @overload + def get_enum( + self, + name: str, + enum_cls: type[_T], + *, + value_type: Optional[type[_Value]] = None, + default: _T, + ) -> _T: ... + + def get_enum( + self, + name: str, + enum_cls: type[_T], + *, + value_type: Optional[type[_Value]] = None, + default: Optional[_T] = None, + ) -> Optional[_T]: + try: + raw_value = self._environ[name] + except KeyError: + return default + try: + if value_type is not None: + raw_value = value_type(raw_value) + return enum_cls(raw_value) + except (ValueError, TypeError) as e: + raise ValueError(f"Invalid {enum_cls.__name__} value: {e}: {name}={raw_value}") from e + + +environ = Environ(os.environ) diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index f3f7df124..109dd4f2e 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import joinedload from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.common import NetworkMode from dstack._internal.core.models.configurations import TaskConfiguration from dstack._internal.core.models.fleets import FleetNodesSpec from dstack._internal.core.models.health import HealthStatus @@ -25,8 +26,12 @@ VolumeMountPoint, VolumeStatus, ) -from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs +from dstack._internal.server.background.tasks.process_submitted_jobs import ( + _prepare_job_runtime_data, + process_submitted_jobs, +) from dstack._internal.server.models import InstanceModel, JobModel, VolumeAttachmentModel +from dstack._internal.server.settings import JobNetworkMode from dstack._internal.server.testing.common import ( ComputeMockSpec, create_fleet, @@ -1004,3 +1009,102 @@ async def test_picks_high_priority_jobs_first(self, test_db, session: AsyncSessi await process_submitted_jobs() await session.refresh(job2) assert job2.status == JobStatus.PROVISIONING + + +@pytest.mark.parametrize( + ["job_network_mode", "blocks", "multinode", "network_mode", "constraints_are_set"], + [ + pytest.param( + JobNetworkMode.HOST_FOR_MULTINODE_ONLY, + 2, + False, + NetworkMode.BRIDGE, + True, + id="host-for-multinode-only--half-of-instance", + ), + pytest.param( + JobNetworkMode.HOST_FOR_MULTINODE_ONLY, + 4, + False, + NetworkMode.BRIDGE, + False, + id="host-for-multinode-only--entire-instance", + ), + pytest.param( + JobNetworkMode.HOST_FOR_MULTINODE_ONLY, + 4, + True, + NetworkMode.HOST, + False, + id="host-for-multinode-only--entire-instance--multinode", + ), + pytest.param( + JobNetworkMode.HOST_WHEN_POSSIBLE, + 2, + False, + NetworkMode.BRIDGE, + True, + id="host-when-possible--half-of-instance", + ), + pytest.param( + JobNetworkMode.HOST_WHEN_POSSIBLE, + 4, + False, + NetworkMode.HOST, + False, + id="host-when-possible--entire-instance", + ), + pytest.param( + JobNetworkMode.HOST_WHEN_POSSIBLE, + 4, + True, + NetworkMode.HOST, + False, + id="host-when-possible--entire-instance--multinode", + ), + pytest.param( + JobNetworkMode.FORCED_BRIDGE, + 2, + False, + NetworkMode.BRIDGE, + True, + id="forced-bridge--half-of-instance", + ), + pytest.param( + JobNetworkMode.FORCED_BRIDGE, + 4, + False, + NetworkMode.BRIDGE, + False, + id="forced-bridge--entire-instance", + ), + pytest.param( + JobNetworkMode.FORCED_BRIDGE, + 4, + True, + NetworkMode.BRIDGE, + False, + id="forced-bridge--entire-instance--multinode", + ), + ], +) +def test_prepare_job_runtime_data( + monkeypatch: pytest.MonkeyPatch, + job_network_mode: JobNetworkMode, + blocks: int, + multinode: bool, + network_mode: NetworkMode, + constraints_are_set: bool, +): + monkeypatch.setattr("dstack._internal.server.settings.JOB_NETWORK_MODE", job_network_mode) + offer = get_instance_offer_with_availability(blocks=blocks, total_blocks=4) + jrd = _prepare_job_runtime_data(offer=offer, multinode=multinode) + assert jrd.network_mode == network_mode + if constraints_are_set: + assert jrd.gpu is not None + assert jrd.cpu is not None + assert jrd.memory is not None + else: + assert jrd.gpu is None + assert jrd.cpu is None + assert jrd.memory is None diff --git a/src/tests/_internal/utils/test_env.py b/src/tests/_internal/utils/test_env.py index 1c222a29b..4e9e2f6b7 100644 --- a/src/tests/_internal/utils/test_env.py +++ b/src/tests/_internal/utils/test_env.py @@ -1,38 +1,113 @@ +from enum import Enum +from typing import Union + import pytest -from dstack._internal.utils.env import get_bool +from dstack._internal.utils.env import Environ, _Value + + +class _TestEnviron: + def get_environ(self, **env: str) -> Environ: + return Environ(env) + + +class TestEnvironGetBool(_TestEnviron): + @pytest.mark.parametrize( + ["value", "expected"], + [ + ["0", False], + ["1", True], + ["true", True], + ["True", True], + ["FALSE", False], + ["off", False], + ["ON", True], + ], + ) + def test_is_set(self, value: str, expected: bool): + environ = self.get_environ(VAR=value) + assert environ.get_bool("VAR") is expected + + def test_not_set_default_not_set(self): + environ = self.get_environ() + assert environ.get_bool("VAR") is None + + @pytest.mark.parametrize("default", [False, True]) + def test_not_set_default_is_set(self, default: bool): + environ = self.get_environ() + assert environ.get_bool("VAR", default=default) is default + + @pytest.mark.parametrize("value", ["", "2", "foo"]) + def test_error_bad_value(self, value: str): + environ = self.get_environ(VAR=value) + with pytest.raises(ValueError, match=f"VAR={value}"): + environ.get_bool("VAR") + + +class TestEnvironGetInt(_TestEnviron): + def test_is_set(self): + environ = self.get_environ(VAR="12") + assert environ.get_int("VAR") == 12 + + def test_not_set_default_not_set(self): + environ = self.get_environ() + assert environ.get_int("VAR") is None + + def test_not_set_default_is_set(self): + environ = self.get_environ() + assert environ.get_int("VAR", default=12) == 12 + + @pytest.mark.parametrize("value", ["", "false", "10a"]) + def test_error_bad_value(self, value: str): + environ = self.get_environ(VAR=value) + with pytest.raises(ValueError, match=f"VAR={value}"): + environ.get_int("VAR") + + +class _Enum(Enum): + FOO: Union[str, int] + BAR: Union[str, int] + +class _StrEnum(_Enum): + FOO = "foo" + BAR = "bar" -@pytest.mark.parametrize( - ["value", "expected"], - [ - ["0", False], - ["1", True], - ["true", True], - ["True", True], - ["FALSE", False], - ["off", False], - ["ON", True], - ], -) -def test_get_bool_is_set(monkeypatch: pytest.MonkeyPatch, value: str, expected: bool): - monkeypatch.setenv("VAR", value) - assert get_bool("VAR") is expected +class _IntEnum(_Enum): + FOO = 100 + BAR = 200 -def test_get_bool_not_set_default_not_set(monkeypatch: pytest.MonkeyPatch): - monkeypatch.delenv("VAR", raising=False) - assert get_bool("VAR") is False +class TestEnvironGetEnum(_TestEnviron): + @pytest.mark.parametrize( + ["enum_cls", "value_type", "value"], + [ + pytest.param(_StrEnum, str, "foo", id="str"), + pytest.param(_IntEnum, int, "100", id="int"), + ], + ) + def test_is_set(self, enum_cls: type[_Enum], value_type: type[_Value], value: str): + environ = self.get_environ(VAR=value) + assert environ.get_enum("VAR", enum_cls, value_type=value_type) is enum_cls.FOO -@pytest.mark.parametrize("default", [False, True]) -def test_get_bool_not_set_default_is_set(monkeypatch: pytest.MonkeyPatch, default: bool): - monkeypatch.delenv("VAR", raising=False) - assert get_bool("VAR", default) is default + def test_not_set_default_not_set(self): + environ = self.get_environ() + assert environ.get_enum("VAR", _StrEnum) is None + def test_not_set_default_is_set(self): + environ = self.get_environ() + assert environ.get_enum("VAR", _IntEnum, default=_IntEnum.BAR) is _IntEnum.BAR -@pytest.mark.parametrize("value", ["", "2", "foo"]) -def test_get_bool_error_value(monkeypatch: pytest.MonkeyPatch, value: str): - monkeypatch.setenv("VAR", value) - with pytest.raises(ValueError, match=f"VAR={value}"): - assert get_bool("VAR") + @pytest.mark.parametrize( + ["enum_cls", "value_type", "value"], + [ + pytest.param(_StrEnum, str, "baz", id="str"), + pytest.param(_IntEnum, int, "300", id="int"), + pytest.param(_IntEnum, int, "10a", id="invalid-int"), + ], + ) + def test_error_bad_value(self, enum_cls: type[_Enum], value_type: type[_Value], value: str): + environ = self.get_environ(VAR=value) + with pytest.raises(ValueError, match=f"VAR={value}"): + environ.get_enum("VAR", enum_cls, value_type=value_type)