Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/dstack/_internal/server/background/pipeline_tasks/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ class _ProcessResult:
fleet_update_map: _FleetUpdateMap = field(default_factory=_FleetUpdateMap)
instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict)
new_instance_creates: list["_NewInstanceCreate"] = field(default_factory=list)
consolidation_limit_reached: bool = False


class _NewInstanceCreate(TypedDict):
Expand Down Expand Up @@ -358,6 +359,8 @@ def _get_fleet_spec_if_ready_for_consolidation(fleet_model: FleetModel) -> Optio
or consolidation_fleet_spec.autocreated
):
return None
if fleet_model.consolidation_attempt >= _MAX_CONSOLIDATION_ATTEMPTS:
return None
if not _is_fleet_ready_for_consolidation(fleet_model):
return None
return consolidation_fleet_spec
Expand Down Expand Up @@ -502,6 +505,16 @@ async def _apply_process_result(
"status_message", context.fleet_model.status_message
),
)
if result.consolidation_limit_reached:
events.emit(
session=session,
message=(
f"Fleet consolidation stopped after {_MAX_CONSOLIDATION_ATTEMPTS} attempts."
" Update the fleet to resume"
),
actor=events.SystemActor(),
targets=[events.Target.from_model(context.fleet_model)],
)


async def _process_fleet(
Expand Down Expand Up @@ -560,7 +573,10 @@ def _consolidate_fleet_state_with_spec(
result.instance_id_to_update_map.update(maintain_nodes_result.instance_id_to_update_map)
result.new_instance_creates = maintain_nodes_result.new_instance_creates
if len(spec_mismatch_updates) > 0 or maintain_nodes_result.changes_required:
result.fleet_update_map["consolidation_attempt"] = fleet_model.consolidation_attempt + 1
new_attempt = fleet_model.consolidation_attempt + 1
result.fleet_update_map["consolidation_attempt"] = new_attempt
if new_attempt >= _MAX_CONSOLIDATION_ATTEMPTS:
result.consolidation_limit_reached = True
else:
# The fleet is consolidated with respect to spec and nodes min/max.
result.fleet_update_map["consolidation_attempt"] = 0
Expand All @@ -575,6 +591,8 @@ def _is_fleet_ready_for_consolidation(fleet_model: FleetModel) -> bool:
return duration_since_last_consolidation >= consolidation_retry_delay


_MAX_CONSOLIDATION_ATTEMPTS = 15

# We use exponentially increasing consolidation retry delays so that
# consolidation does not happen too often. In particular, this prevents
# retrying instance provisioning constantly in case of no offers.
Expand Down
147 changes: 146 additions & 1 deletion src/tests/_internal/server/background/pipeline_tasks/test_fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
FleetPipeline,
FleetWorker,
)
from dstack._internal.server.models import FleetModel, InstanceModel
from dstack._internal.server.models import EventModel, EventTargetModel, FleetModel, InstanceModel
from dstack._internal.server.services.projects import add_project_member
from dstack._internal.server.testing.common import (
create_fleet,
Expand Down Expand Up @@ -1409,3 +1409,148 @@ async def test_consolidation_preserves_instances_matching_fleet_spec(
await session.refresh(instance)
assert instance.status == InstanceStatus.IDLE
assert fleet.consolidation_attempt == 0

async def test_consolidation_stops_at_max_attempts(
    self, test_db, session: AsyncSession, worker: FleetWorker
):
    """A fleet whose attempt counter already sits at the cap must not be consolidated."""
    project = await create_project(session)
    spec = get_fleet_spec()
    spec.configuration.nodes = FleetNodesSpec(min=2, target=2, max=2)
    fleet = await create_fleet(session=session, project=project, spec=spec)
    await create_instance(
        session=session,
        project=project,
        fleet=fleet,
        status=InstanceStatus.IDLE,
        instance_num=0,
    )
    # Counter is already at the limit, so the worker must skip consolidation
    # even though the fleet is below nodes.min.
    fleet.consolidation_attempt = fleets_pipeline._MAX_CONSOLIDATION_ATTEMPTS
    fleet.last_consolidated_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc)
    fleet.lock_token = uuid.uuid4()
    fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
    await session.commit()

    await worker.process(_fleet_to_pipeline_item(fleet))

    await session.refresh(fleet)
    query_result = await session.execute(
        select(InstanceModel).where(
            InstanceModel.fleet_id == fleet.id,
            InstanceModel.deleted == False,
        )
    )
    instances = query_result.scalars().all()
    # No second instance was provisioned and the counter did not move.
    assert len(instances) == 1
    assert fleet.consolidation_attempt == fleets_pipeline._MAX_CONSOLIDATION_ATTEMPTS
    assert not fleet.deleted

async def test_consolidation_emits_event_on_reaching_limit(
    self, test_db, session: AsyncSession, worker: FleetWorker
):
    """The attempt that crosses the cap still runs and records a stop event."""
    project = await create_project(session)
    spec = get_fleet_spec()
    spec.configuration.nodes = FleetNodesSpec(min=2, target=2, max=2)
    fleet = await create_fleet(session=session, project=project, spec=spec)
    await create_instance(
        session=session,
        project=project,
        fleet=fleet,
        status=InstanceStatus.IDLE,
        instance_num=0,
    )
    # One attempt below the cap: the next consolidation is the last allowed one.
    fleet.consolidation_attempt = fleets_pipeline._MAX_CONSOLIDATION_ATTEMPTS - 1
    fleet.last_consolidated_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc)
    fleet.lock_token = uuid.uuid4()
    fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
    await session.commit()

    await worker.process(_fleet_to_pipeline_item(fleet))

    await session.refresh(fleet)
    query_result = await session.execute(
        select(InstanceModel).where(
            InstanceModel.fleet_id == fleet.id,
            InstanceModel.deleted == False,
        )
    )
    instances = query_result.scalars().all()
    # Last allowed consolidation still creates the missing instance
    assert len(instances) == 2
    assert fleet.consolidation_attempt == fleets_pipeline._MAX_CONSOLIDATION_ATTEMPTS
    # Verify the consolidation-stopped event was emitted
    events_result = await session.execute(
        select(EventModel)
        .join(EventTargetModel)
        .where(EventTargetModel.entity_id == fleet.id)
    )
    stopped_events = [
        event
        for event in events_result.scalars().all()
        if "consolidation stopped" in event.message
    ]
    assert len(stopped_events) == 1

async def test_consolidation_resumes_after_attempt_reset(
    self, test_db, session: AsyncSession, worker: FleetWorker
):
    """Resetting the attempt counter (as an in-place update does) re-enables consolidation."""
    project = await create_project(session)
    spec = get_fleet_spec()
    spec.configuration.nodes = FleetNodesSpec(min=2, target=2, max=2)
    fleet = await create_fleet(session=session, project=project, spec=spec)
    await create_instance(
        session=session,
        project=project,
        fleet=fleet,
        status=InstanceStatus.IDLE,
        instance_num=0,
    )
    # Simulate in-place update resetting the attempt counter
    fleet.consolidation_attempt = 0
    fleet.last_consolidated_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc)
    fleet.lock_token = uuid.uuid4()
    fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
    await session.commit()

    await worker.process(_fleet_to_pipeline_item(fleet))

    await session.refresh(fleet)
    query_result = await session.execute(
        select(InstanceModel).where(
            InstanceModel.fleet_id == fleet.id,
            InstanceModel.deleted == False,
        )
    )
    instances = query_result.scalars().all()
    # Consolidation ran again: the missing instance was created and the
    # counter was incremented from the reset value.
    assert len(instances) == 2
    assert fleet.consolidation_attempt == 1
Loading