diff --git a/src/dstack/_internal/server/background/tasks/process_fleets.py b/src/dstack/_internal/server/background/tasks/process_fleets.py
index 176b56644..7b6cdc5de 100644
--- a/src/dstack/_internal/server/background/tasks/process_fleets.py
+++ b/src/dstack/_internal/server/background/tasks/process_fleets.py
@@ -1,10 +1,11 @@
+from collections import defaultdict
 from datetime import timedelta
 from typing import List
 from uuid import UUID

 from sqlalchemy import select, update
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload, load_only
+from sqlalchemy.orm import joinedload, load_only, selectinload

 from dstack._internal.core.models.fleets import FleetSpec, FleetStatus
 from dstack._internal.core.models.instances import InstanceStatus
@@ -37,30 +38,68 @@

 @sentry_utils.instrument_background_task
 async def process_fleets():
-    lock, lockset = get_locker(get_db().dialect_name).get_lockset(FleetModel.__tablename__)
+    fleet_lock, fleet_lockset = get_locker(get_db().dialect_name).get_lockset(
+        FleetModel.__tablename__
+    )
+    instance_lock, instance_lockset = get_locker(get_db().dialect_name).get_lockset(
+        InstanceModel.__tablename__
+    )
     async with get_session_ctx() as session:
-        async with lock:
+        async with fleet_lock, instance_lock:
             res = await session.execute(
                 select(FleetModel)
                 .where(
                     FleetModel.deleted == False,
-                    FleetModel.id.not_in(lockset),
+                    FleetModel.id.not_in(fleet_lockset),
                     FleetModel.last_processed_at
                     < get_current_datetime() - MIN_PROCESSING_INTERVAL,
                 )
-                .options(load_only(FleetModel.id))
+                .options(
+                    load_only(FleetModel.id, FleetModel.name),
+                    selectinload(FleetModel.instances).load_only(InstanceModel.id),
+                )
                 .order_by(FleetModel.last_processed_at.asc())
                 .limit(BATCH_SIZE)
                 .with_for_update(skip_locked=True, key_share=True)
             )
-            fleet_models = list(res.scalars().all())
+            fleet_models = list(res.scalars().unique().all())
             fleet_ids = [fm.id for fm in fleet_models]
+            res = await session.execute(
+                select(InstanceModel)
+                .where(
+                    InstanceModel.id.not_in(instance_lockset),
+                    InstanceModel.fleet_id.in_(fleet_ids),
+                )
+                .options(load_only(InstanceModel.id, InstanceModel.fleet_id))
+                .order_by(InstanceModel.id)
+                .with_for_update(skip_locked=True, key_share=True)
+            )
+            instance_models = list(res.scalars().all())
+            fleet_id_to_locked_instances = defaultdict(list)
+            for instance_model in instance_models:
+                fleet_id_to_locked_instances[instance_model.fleet_id].append(instance_model)
+            # Process only fleets with all their instances locked.
+            # Other fleets won't be processed but will still be locked to avoid a new transaction.
+            # This should not be problematic as long as process_fleets is quick.
+            fleet_models_to_process = []
+            for fleet_model in fleet_models:
+                if len(fleet_model.instances) == len(fleet_id_to_locked_instances[fleet_model.id]):
+                    fleet_models_to_process.append(fleet_model)
+                else:
+                    logger.debug(
+                        "Fleet %s processing will be skipped: some instances were not locked",
+                        fleet_model.name,
+                    )
             for fleet_id in fleet_ids:
-                lockset.add(fleet_id)
+                fleet_lockset.add(fleet_id)
+            instance_ids = [im.id for im in instance_models]
+            for instance_id in instance_ids:
+                instance_lockset.add(instance_id)
             try:
-                await _process_fleets(session=session, fleet_models=fleet_models)
+                await _process_fleets(session=session, fleet_models=fleet_models_to_process)
             finally:
-                lockset.difference_update(fleet_ids)
+                fleet_lockset.difference_update(fleet_ids)
+                instance_lockset.difference_update(instance_ids)


 async def _process_fleets(session: AsyncSession, fleet_models: List[FleetModel]):
@@ -99,8 +138,8 @@ def _consolidate_fleet_state_with_spec(session: AsyncSession, fleet_model: Fleet
         return
     if not _is_fleet_ready_for_consolidation(fleet_model):
         return
-    added_instances = _maintain_fleet_nodes_min(session, fleet_model, fleet_spec)
-    if added_instances:
+    changed_instances = _maintain_fleet_nodes_in_min_max_range(session, fleet_model, fleet_spec)
+    if changed_instances:
         fleet_model.consolidation_attempt += 1
     else:
         # The fleet is already consolidated or consolidation is in progress.
@@ -138,28 +177,47 @@ def _get_consolidation_retry_delay(consolidation_attempt: int) -> timedelta:
     return _CONSOLIDATION_RETRY_DELAYS[-1]


-def _maintain_fleet_nodes_min(
+def _maintain_fleet_nodes_in_min_max_range(
     session: AsyncSession,
     fleet_model: FleetModel,
     fleet_spec: FleetSpec,
 ) -> bool:
     """
-    Ensures the fleet has at least `nodes.min` instances.
-    Returns `True` if retried or added new instances and `False` otherwise.
+    Ensures the fleet has at least `nodes.min` and at most `nodes.max` instances.
+    Returns `True` if it retried, added new instances, or terminated redundant instances, and `False` otherwise.
     """
     assert fleet_spec.configuration.nodes is not None
     for instance in fleet_model.instances:
         # Delete terminated but not deleted instances since
         # they are going to be replaced with new pending instances.
         if instance.status == InstanceStatus.TERMINATED and not instance.deleted:
-            # It's safe to modify instances without instance lock since
-            # no other task modifies already terminated instances.
             instance.deleted = True
             instance.deleted_at = get_current_datetime()
     active_instances = [i for i in fleet_model.instances if not i.deleted]
     active_instances_num = len(active_instances)
     if active_instances_num >= fleet_spec.configuration.nodes.min:
-        return False
+        if (
+            fleet_spec.configuration.nodes.max is None
+            or active_instances_num <= fleet_spec.configuration.nodes.max
+        ):
+            return False
+        # The fleet has more instances than allowed by nodes.max.
+        # This is possible due to race conditions (e.g. provisioning jobs in a fleet concurrently)
+        # or if nodes.max is updated.
+        nodes_redundant = active_instances_num - fleet_spec.configuration.nodes.max
+        for instance in fleet_model.instances:
+            if nodes_redundant == 0:
+                break
+            if instance.status in [InstanceStatus.IDLE]:
+                instance.status = InstanceStatus.TERMINATING
+                instance.termination_reason = "Fleet has too many instances"
+                nodes_redundant -= 1
+                logger.info(
+                    "Terminating instance %s: %s",
+                    instance.name,
+                    instance.termination_reason,
+                )
+        return True
     nodes_missing = fleet_spec.configuration.nodes.min - active_instances_num
     for i in range(nodes_missing):
         instance_model = create_fleet_instance_model(
diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
index e8dc7b2f3..2814840b5 100644
--- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
+++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
@@ -260,7 +260,6 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):

     instance_filters = [
         InstanceModel.deleted == False,
-        InstanceModel.total_blocks > InstanceModel.busy_blocks,
         InstanceModel.id.not_in(detaching_instances_ids),
     ]

@@ -514,9 +513,6 @@ async def _find_optimal_fleet_with_offers(
         )
         return run_model.fleet, fleet_instances_with_pool_offers

-    if len(fleet_models) == 0:
-        return None, []
-
     nodes_required_num = _get_nodes_required_num_for_run(run_spec)
     # The current strategy is first to consider fleets that can accommodate
     # the run without additional provisioning and choose the one with the cheapest pool offer.
@@ -534,6 +530,7 @@
         ]
     ] = []
     for candidate_fleet_model in fleet_models:
+        candidate_fleet = fleet_model_to_fleet(candidate_fleet_model)
         fleet_instances_with_pool_offers = _get_fleet_instances_with_pool_offers(
             fleet_model=candidate_fleet_model,
             run_spec=run_spec,
@@ -541,24 +538,21 @@
             master_job_provisioning_data=master_job_provisioning_data,
             volumes=volumes,
         )
-        fleet_has_available_capacity = nodes_required_num <= len(fleet_instances_with_pool_offers)
+        fleet_has_pool_capacity = nodes_required_num <= len(fleet_instances_with_pool_offers)
         fleet_cheapest_pool_offer = math.inf
         if len(fleet_instances_with_pool_offers) > 0:
            fleet_cheapest_pool_offer = fleet_instances_with_pool_offers[0][1].price

-        candidate_fleet = fleet_model_to_fleet(candidate_fleet_model)
-        profile = None
-        requirements = None
         try:
+            _check_can_create_new_instance_in_fleet(candidate_fleet)
             profile, requirements = _get_run_profile_and_requirements_in_fleet(
                 job=job,
                 run_spec=run_spec,
                 fleet=candidate_fleet,
             )
         except ValueError:
-            pass
-        fleet_backend_offers = []
-        if profile is not None and requirements is not None:
+            fleet_backend_offers = []
+        else:
             multinode = (
                 candidate_fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER
                 or job.job_spec.jobs_per_replica > 1
@@ -579,8 +573,12 @@
         if len(fleet_backend_offers) > 0:
             fleet_cheapest_backend_offer = fleet_backend_offers[0][1].price

+        if not _run_can_fit_into_fleet(run_spec, candidate_fleet):
+            logger.debug("Skipping fleet %s from consideration: run cannot fit into fleet", candidate_fleet.name)
+            continue
+
         fleet_priority = (
-            not fleet_has_available_capacity,
+            not fleet_has_pool_capacity,
             fleet_cheapest_pool_offer,
             fleet_cheapest_backend_offer,
         )
@@ -593,10 +591,13 @@
                 fleet_priority,
             )
         )
+    if len(candidate_fleets_with_offers) == 0:
+        return None, []
     if run_spec.merged_profile.fleets is None and all(
         t[2] == 0 and t[3] == 0 for t in candidate_fleets_with_offers
     ):
-        # If fleets are not specified and no fleets have available pool or backend offers, create a new fleet.
+        # If fleets are not specified and no fleets have available pool
+        # or backend offers, create a new fleet.
         # This is for compatibility with non-fleet-first UX when runs created new fleets
         # if there are no instances to reuse.
         return None, []
@@ -616,6 +617,39 @@ def _get_nodes_required_num_for_run(run_spec: RunSpec) -> int:
     return nodes_required_num


+def _run_can_fit_into_fleet(run_spec: RunSpec, fleet: Fleet) -> bool:
+    """
+    Returns `False` if the run certainly cannot fit into the fleet.
+    This is a helpful heuristic to avoid even considering fleets that are too small for a run.
+    A run may still not fit even if this function returns `True`.
+    In that case, some jobs will fail due to exceeding `nodes.max`,
+    or more than `nodes.max` instances will be provisioned
+    and eventually removed by the fleet consolidation logic.
+    """
+    # No check for cloud fleets with blocks > 1 since we don't know
+    # how many jobs such fleets can accommodate.
+    nodes_required_num = _get_nodes_required_num_for_run(run_spec)
+    if (
+        fleet.spec.configuration.nodes is not None
+        and fleet.spec.configuration.blocks == 1
+        and fleet.spec.configuration.nodes.max is not None
+    ):
+        busy_instances = [i for i in fleet.instances if i.busy_blocks > 0]
+        fleet_available_capacity = fleet.spec.configuration.nodes.max - len(busy_instances)
+        if fleet_available_capacity < nodes_required_num:
+            return False
+    elif fleet.spec.configuration.ssh_config is not None:
+        # Currently assume that each idle block can run a job.
+        # TODO: Take resources / eligible offers into account.
+        total_idle_blocks = 0
+        for instance in fleet.instances:
+            total_blocks = instance.total_blocks or 1
+            total_idle_blocks += total_blocks - instance.busy_blocks
+        if total_idle_blocks < nodes_required_num:
+            return False
+    return True
+
+
 def _get_fleet_instances_with_pool_offers(
     fleet_model: FleetModel,
     run_spec: RunSpec,
@@ -713,6 +747,7 @@ async def _run_job_on_new_instance(
     if fleet_model is not None:
         fleet = fleet_model_to_fleet(fleet_model)
         try:
+            _check_can_create_new_instance_in_fleet(fleet)
             profile, requirements = _get_run_profile_and_requirements_in_fleet(
                 job=job,
                 run_spec=run.run_spec,
@@ -787,8 +822,6 @@ def _get_run_profile_and_requirements_in_fleet(
     run_spec: RunSpec,
     fleet: Fleet,
 ) -> tuple[Profile, Requirements]:
-    if not _check_can_create_new_instance_in_fleet(fleet):
-        raise ValueError("Cannot fit new instance into fleet")
     profile = combine_fleet_and_run_profiles(fleet.spec.merged_profile, run_spec.merged_profile)
     if profile is None:
         raise ValueError("Cannot combine fleet profile")
@@ -801,13 +834,23 @@ def _get_run_profile_and_requirements_in_fleet(
     return profile, requirements


-def _check_can_create_new_instance_in_fleet(fleet: Fleet) -> bool:
+def _check_can_create_new_instance_in_fleet(fleet: Fleet):
+    if not _can_create_new_instance_in_fleet(fleet):
+        raise ValueError("Cannot fit new instance into fleet")
+
+
+def _can_create_new_instance_in_fleet(fleet: Fleet) -> bool:
     if fleet.spec.configuration.ssh_config is not None:
         return False
-    # TODO: Respect nodes.max
-    # Ensure concurrent provisioning does not violate nodes.max
-    # E.g. lock fleet and split instance model creation
-    # and instance provisioning into separate transactions.
+    active_instances = [i for i in fleet.instances if i.status.is_active()]
+    # nodes.max is a soft limit that can be exceeded when provisioning concurrently.
+    # The fleet consolidation logic will remove redundant nodes eventually.
+    if (
+        fleet.spec.configuration.nodes is not None
+        and fleet.spec.configuration.nodes.max is not None
+        and len(active_instances) >= fleet.spec.configuration.nodes.max
+    ):
+        return False
     return True
diff --git a/src/tests/_internal/server/background/tasks/test_process_fleets.py b/src/tests/_internal/server/background/tasks/test_process_fleets.py
index 2d47da6a7..4370f77b5 100644
--- a/src/tests/_internal/server/background/tasks/test_process_fleets.py
+++ b/src/tests/_internal/server/background/tasks/test_process_fleets.py
@@ -126,3 +126,36 @@ async def test_consolidation_creates_missing_instances(self, test_db, session: A
         instances = (await session.execute(select(InstanceModel))).scalars().all()
         assert len(instances) == 2
         assert {i.instance_num for i in instances} == {0, 1}  # uses 0 for next instance num
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+    async def test_consolidation_terminates_redundant_instances(
+        self, test_db, session: AsyncSession
+    ):
+        project = await create_project(session)
+        spec = get_fleet_spec()
+        spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1)
+        fleet = await create_fleet(
+            session=session,
+            project=project,
+            spec=spec,
+        )
+        instance1 = await create_instance(
+            session=session,
+            project=project,
+            fleet=fleet,
+            status=InstanceStatus.BUSY,
+            instance_num=0,
+        )
+        instance2 = await create_instance(
+            session=session,
+            project=project,
+            fleet=fleet,
+            status=InstanceStatus.IDLE,
+            instance_num=1,
+        )
+        await process_fleets()
+        await session.refresh(instance1)
+        await session.refresh(instance2)
+        assert instance1.status == InstanceStatus.BUSY
+        assert instance2.status == InstanceStatus.TERMINATING
diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py
index cbf387284..b4ebf9fb5 100644
--- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py
+++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py
@@ -494,7 +494,9 @@ async def test_assigns_job_to_shared_instance(self, test_db, session: AsyncSessi
             project_id=project.id,
         )
         offer = get_instance_offer_with_availability(gpu_count=8, cpu_count=64, memory_gib=128)
-        fleet = await create_fleet(session=session, project=project)
+        fleet_spec = get_fleet_spec()
+        fleet_spec.configuration.blocks = 4
+        fleet = await create_fleet(session=session, project=project, spec=fleet_spec)
         instance = await create_instance(
             session=session,
             project=project,
@@ -537,7 +539,9 @@ async def test_assigns_multi_node_job_to_shared_instance(self, test_db, session:
             project_id=project.id,
         )
         offer = get_instance_offer_with_availability(gpu_count=8, cpu_count=64, memory_gib=128)
-        fleet = await create_fleet(session=session, project=project)
+        fleet_spec = get_fleet_spec()
+        fleet_spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=None)
+        fleet = await create_fleet(session=session, project=project, spec=fleet_spec)
         instance = await create_instance(
             session=session,
             project=project,
@@ -743,6 +747,55 @@ async def test_assigns_no_fleet_when_all_fleets_occupied(self, test_db, session:
         assert job.instance_id is None
         assert job.fleet_id is None

+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+    async def test_assigns_no_fleet_if_run_cannot_fit(self, test_db, session: AsyncSession):
+        project = await create_project(session)
+        user = await create_user(session)
+        repo = await create_repo(session=session, project_id=project.id)
+        fleet_spec = get_fleet_spec()
+        fleet_spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=3)
+        fleet = await create_fleet(session=session, project=project, spec=fleet_spec)
+        instance1 = await create_instance(
+            session=session,
+            project=project,
+            fleet=fleet,
+            instance_num=0,
+            status=InstanceStatus.BUSY,
+            busy_blocks=1,
+        )
+        instance2 = await create_instance(
+            session=session,
+            project=project,
+            fleet=fleet,
+            instance_num=1,
+            status=InstanceStatus.IDLE,
+            busy_blocks=0,
+        )
+        fleet.instances.append(instance1)
+        fleet.instances.append(instance2)
+        run_spec = get_run_spec(repo_id=repo.name)
+        run_spec.configuration = TaskConfiguration(nodes=3, commands=["echo"])
+        run = await create_run(
+            session=session,
+            project=project,
+            repo=repo,
+            user=user,
+            run_spec=run_spec,
+        )
+        job = await create_job(
+            session=session,
+            run=run,
+            instance_assigned=False,
+        )
+        await session.commit()
+        await process_submitted_jobs()
+        await session.refresh(job)
+        assert job.status == JobStatus.SUBMITTED
+        assert job.instance_assigned
+        assert job.instance_id is None
+        assert job.fleet_id is None
+
     @pytest.mark.asyncio
     @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
     async def test_does_not_assign_job_to_elastic_empty_fleet_without_backend_offers_if_fleets_unspecified(