diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py
index e2a4e6a55..1c4758889 100644
--- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py
+++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py
@@ -245,6 +245,7 @@ class _ProcessResult:
     fleet_update_map: _FleetUpdateMap = field(default_factory=_FleetUpdateMap)
     instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict)
     new_instance_creates: list["_NewInstanceCreate"] = field(default_factory=list)
+    consolidation_limit_reached: bool = False
 
 
 class _NewInstanceCreate(TypedDict):
@@ -358,6 +359,8 @@ def _get_fleet_spec_if_ready_for_consolidation(fleet_model: FleetModel) -> Optio
         or consolidation_fleet_spec.autocreated
     ):
         return None
+    if fleet_model.consolidation_attempt >= _MAX_CONSOLIDATION_ATTEMPTS:
+        return None
     if not _is_fleet_ready_for_consolidation(fleet_model):
         return None
     return consolidation_fleet_spec
@@ -502,6 +505,16 @@ async def _apply_process_result(
                 "status_message", context.fleet_model.status_message
             ),
         )
+    if result.consolidation_limit_reached:
+        events.emit(
+            session=session,
+            message=(
+                f"Fleet consolidation stopped after {_MAX_CONSOLIDATION_ATTEMPTS} attempts."
+                " Update the fleet to resume"
+            ),
+            actor=events.SystemActor(),
+            targets=[events.Target.from_model(context.fleet_model)],
+        )
 
 
 async def _process_fleet(
@@ -560,7 +573,10 @@ def _consolidate_fleet_state_with_spec(
     result.instance_id_to_update_map.update(maintain_nodes_result.instance_id_to_update_map)
     result.new_instance_creates = maintain_nodes_result.new_instance_creates
     if len(spec_mismatch_updates) > 0 or maintain_nodes_result.changes_required:
-        result.fleet_update_map["consolidation_attempt"] = fleet_model.consolidation_attempt + 1
+        new_attempt = fleet_model.consolidation_attempt + 1
+        result.fleet_update_map["consolidation_attempt"] = new_attempt
+        if new_attempt >= _MAX_CONSOLIDATION_ATTEMPTS:
+            result.consolidation_limit_reached = True
     else:
         # The fleet is consolidated with respect to spec and nodes min/max.
         result.fleet_update_map["consolidation_attempt"] = 0
@@ -575,6 +591,8 @@ def _is_fleet_ready_for_consolidation(fleet_model: FleetModel) -> bool:
     return duration_since_last_consolidation >= consolidation_retry_delay
 
 
+_MAX_CONSOLIDATION_ATTEMPTS = 15
+
 # We use exponentially increasing consolidation retry delays so that
 # consolidation does not happen too often. In particular, this prevents
 # retrying instance provisioning constantly in case of no offers.
diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py
index 1acceeeec..4726dfb9d 100644
--- a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py
+++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py
@@ -23,7 +23,7 @@
     FleetPipeline,
     FleetWorker,
 )
-from dstack._internal.server.models import FleetModel, InstanceModel
+from dstack._internal.server.models import EventModel, EventTargetModel, FleetModel, InstanceModel
 from dstack._internal.server.services.projects import add_project_member
 from dstack._internal.server.testing.common import (
     create_fleet,
@@ -1409,3 +1409,148 @@ async def test_consolidation_preserves_instances_matching_fleet_spec(
         await session.refresh(instance)
         assert instance.status == InstanceStatus.IDLE
         assert fleet.consolidation_attempt == 0
+
+    async def test_consolidation_stops_at_max_attempts(
+        self, test_db, session: AsyncSession, worker: FleetWorker
+    ):
+        project = await create_project(session)
+        spec = get_fleet_spec()
+        spec.configuration.nodes = FleetNodesSpec(min=2, target=2, max=2)
+        fleet = await create_fleet(
+            session=session,
+            project=project,
+            spec=spec,
+        )
+        await create_instance(
+            session=session,
+            project=project,
+            fleet=fleet,
+            status=InstanceStatus.IDLE,
+            instance_num=0,
+        )
+        fleet.consolidation_attempt = fleets_pipeline._MAX_CONSOLIDATION_ATTEMPTS
+        fleet.last_consolidated_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc)
+        fleet.lock_token = uuid.uuid4()
+        fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+        await session.commit()
+
+        await worker.process(_fleet_to_pipeline_item(fleet))
+
+        await session.refresh(fleet)
+        instances = (
+            (
+                await session.execute(
+                    select(InstanceModel).where(
+                        InstanceModel.fleet_id == fleet.id,
+                        InstanceModel.deleted == False,
+                    )
+                )
+            )
+            .scalars()
+            .all()
+        )
+        assert len(instances) == 1
+        assert fleet.consolidation_attempt == fleets_pipeline._MAX_CONSOLIDATION_ATTEMPTS
+        assert not fleet.deleted
+
+    async def test_consolidation_emits_event_on_reaching_limit(
+        self, test_db, session: AsyncSession, worker: FleetWorker
+    ):
+        project = await create_project(session)
+        spec = get_fleet_spec()
+        spec.configuration.nodes = FleetNodesSpec(min=2, target=2, max=2)
+        fleet = await create_fleet(
+            session=session,
+            project=project,
+            spec=spec,
+        )
+        await create_instance(
+            session=session,
+            project=project,
+            fleet=fleet,
+            status=InstanceStatus.IDLE,
+            instance_num=0,
+        )
+        fleet.consolidation_attempt = fleets_pipeline._MAX_CONSOLIDATION_ATTEMPTS - 1
+        fleet.last_consolidated_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc)
+        fleet.lock_token = uuid.uuid4()
+        fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+        await session.commit()
+
+        await worker.process(_fleet_to_pipeline_item(fleet))
+
+        await session.refresh(fleet)
+        instances = (
+            (
+                await session.execute(
+                    select(InstanceModel).where(
+                        InstanceModel.fleet_id == fleet.id,
+                        InstanceModel.deleted == False,
+                    )
+                )
+            )
+            .scalars()
+            .all()
+        )
+        # Last allowed consolidation still creates the missing instance
+        assert len(instances) == 2
+        assert fleet.consolidation_attempt == fleets_pipeline._MAX_CONSOLIDATION_ATTEMPTS
+        # Verify the consolidation-stopped event was emitted
+        event_models = (
+            (
+                await session.execute(
+                    select(EventModel)
+                    .join(EventTargetModel)
+                    .where(EventTargetModel.entity_id == fleet.id)
+                )
+            )
+            .scalars()
+            .all()
+        )
+        consolidation_stopped_events = [
+            e for e in event_models if "consolidation stopped" in e.message
+        ]
+        assert len(consolidation_stopped_events) == 1
+
+    async def test_consolidation_resumes_after_attempt_reset(
+        self, test_db, session: AsyncSession, worker: FleetWorker
+    ):
+        project = await create_project(session)
+        spec = get_fleet_spec()
+        spec.configuration.nodes = FleetNodesSpec(min=2, target=2, max=2)
+        fleet = await create_fleet(
+            session=session,
+            project=project,
+            spec=spec,
+        )
+        await create_instance(
+            session=session,
+            project=project,
+            fleet=fleet,
+            status=InstanceStatus.IDLE,
+            instance_num=0,
+        )
+        # Simulate in-place update resetting the attempt counter
+        fleet.consolidation_attempt = 0
+        fleet.last_consolidated_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc)
+        fleet.lock_token = uuid.uuid4()
+        fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+        await session.commit()
+
+        await worker.process(_fleet_to_pipeline_item(fleet))
+
+        await session.refresh(fleet)
+        instances = (
+            (
+                await session.execute(
+                    select(InstanceModel).where(
+                        InstanceModel.fleet_id == fleet.id,
+                        InstanceModel.deleted == False,
+                    )
+                )
+            )
+            .scalars()
+            .all()
+        )
+        assert len(instances) == 2
+        assert fleet.consolidation_attempt == 1