92 changes: 75 additions & 17 deletions src/dstack/_internal/server/background/tasks/process_fleets.py
@@ -1,10 +1,11 @@
from collections import defaultdict
from datetime import timedelta
from typing import List
from uuid import UUID

from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload, load_only
from sqlalchemy.orm import joinedload, load_only, selectinload

from dstack._internal.core.models.fleets import FleetSpec, FleetStatus
from dstack._internal.core.models.instances import InstanceStatus
@@ -37,30 +38,68 @@

@sentry_utils.instrument_background_task
async def process_fleets():
lock, lockset = get_locker(get_db().dialect_name).get_lockset(FleetModel.__tablename__)
fleet_lock, fleet_lockset = get_locker(get_db().dialect_name).get_lockset(
FleetModel.__tablename__
)
instance_lock, instance_lockset = get_locker(get_db().dialect_name).get_lockset(
InstanceModel.__tablename__
)
async with get_session_ctx() as session:
async with lock:
async with fleet_lock, instance_lock:
res = await session.execute(
select(FleetModel)
.where(
FleetModel.deleted == False,
FleetModel.id.not_in(lockset),
FleetModel.id.not_in(fleet_lockset),
FleetModel.last_processed_at
< get_current_datetime() - MIN_PROCESSING_INTERVAL,
)
.options(load_only(FleetModel.id))
.options(
load_only(FleetModel.id, FleetModel.name),
selectinload(FleetModel.instances).load_only(InstanceModel.id),
)
.order_by(FleetModel.last_processed_at.asc())
.limit(BATCH_SIZE)
.with_for_update(skip_locked=True, key_share=True)
)
fleet_models = list(res.scalars().all())
fleet_models = list(res.scalars().unique().all())
fleet_ids = [fm.id for fm in fleet_models]
res = await session.execute(
select(InstanceModel)
.where(
InstanceModel.id.not_in(instance_lockset),
InstanceModel.fleet_id.in_(fleet_ids),
)
.options(load_only(InstanceModel.id, InstanceModel.fleet_id))
.order_by(InstanceModel.id)
.with_for_update(skip_locked=True, key_share=True)
)
instance_models = list(res.scalars().all())
fleet_id_to_locked_instances = defaultdict(list)
for instance_model in instance_models:
fleet_id_to_locked_instances[instance_model.fleet_id].append(instance_model)
# Process only fleets with all instances locked.
# Other fleets won't be processed but will still be locked to avoid a new transaction.
# This should not be problematic as long as process_fleets is quick.
fleet_models_to_process = []
for fleet_model in fleet_models:
if len(fleet_model.instances) == len(fleet_id_to_locked_instances[fleet_model.id]):
fleet_models_to_process.append(fleet_model)
else:
logger.debug(
"Fleet %s processing will be skipped: some instance were not locked",
fleet_model.name,
)
for fleet_id in fleet_ids:
lockset.add(fleet_id)
fleet_lockset.add(fleet_id)
instance_ids = [im.id for im in instance_models]
for instance_id in instance_ids:
instance_lockset.add(instance_id)
try:
await _process_fleets(session=session, fleet_models=fleet_models)
await _process_fleets(session=session, fleet_models=fleet_models_to_process)
finally:
lockset.difference_update(fleet_ids)
fleet_lockset.difference_update(fleet_ids)
instance_lockset.difference_update(instance_ids)
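
Not part of the diff: a minimal sketch of the row-locking construct both queries in process_fleets use, assuming the FleetModel import from this file; the batch size and the simplified SQL rendering are illustrative. On PostgreSQL, skip_locked=True combined with key_share=True renders roughly as FOR NO KEY UPDATE SKIP LOCKED, so concurrent workers pick disjoint batches and the weaker NO KEY UPDATE lock does not block inserts of rows that reference the locked rows' keys.

from sqlalchemy import select

stmt = (
    select(FleetModel)
    .where(FleetModel.deleted == False)
    .order_by(FleetModel.last_processed_at.asc())
    .limit(10)
    .with_for_update(skip_locked=True, key_share=True)
)
# Roughly equivalent SQL on PostgreSQL:
#   SELECT ... ORDER BY last_processed_at ASC LIMIT 10
#   FOR NO KEY UPDATE SKIP LOCKED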


async def _process_fleets(session: AsyncSession, fleet_models: List[FleetModel]):
@@ -99,8 +138,8 @@ def _consolidate_fleet_state_with_spec(session: AsyncSession, fleet_model: Fleet
return
if not _is_fleet_ready_for_consolidation(fleet_model):
return
added_instances = _maintain_fleet_nodes_min(session, fleet_model, fleet_spec)
if added_instances:
changed_instances = _maintain_fleet_nodes_in_min_max_range(session, fleet_model, fleet_spec)
if changed_instances:
fleet_model.consolidation_attempt += 1
else:
# The fleet is already consolidated or consolidation is in progress.
@@ -138,28 +177,47 @@ def _get_consolidation_retry_delay(consolidation_attempt: int) -> timedelta:
return _CONSOLIDATION_RETRY_DELAYS[-1]


def _maintain_fleet_nodes_min(
def _maintain_fleet_nodes_in_min_max_range(
session: AsyncSession,
fleet_model: FleetModel,
fleet_spec: FleetSpec,
) -> bool:
"""
Ensures the fleet has at least `nodes.min` instances.
Returns `True` if retried or added new instances and `False` otherwise.
Ensures the fleet has at least `nodes.min` and at most `nodes.max` instances.
Returns `True` if it retried, added new instances, or terminated redundant instances, and `False` otherwise.
"""
assert fleet_spec.configuration.nodes is not None
for instance in fleet_model.instances:
# Delete terminated but not deleted instances since
# they are going to be replaced with new pending instances.
if instance.status == InstanceStatus.TERMINATED and not instance.deleted:
# It's safe to modify instances without instance lock since
# no other task modifies already terminated instances.
instance.deleted = True
instance.deleted_at = get_current_datetime()
active_instances = [i for i in fleet_model.instances if not i.deleted]
active_instances_num = len(active_instances)
if active_instances_num >= fleet_spec.configuration.nodes.min:
return False
if (
fleet_spec.configuration.nodes.max is None
or active_instances_num <= fleet_spec.configuration.nodes.max
):
return False
# Fleet has more instances than allowed by nodes.max.
# This is possible due to race conditions (e.g. provisioning jobs in a fleet concurrently)
# or if nodes.max is updated.
nodes_redundant = active_instances_num - fleet_spec.configuration.nodes.max
for instance in fleet_model.instances:
if nodes_redundant == 0:
break
if instance.status in [InstanceStatus.IDLE]:
instance.status = InstanceStatus.TERMINATING
instance.termination_reason = "Fleet has too many instances"
nodes_redundant -= 1
logger.info(
"Terminating instance %s: %s",
instance.name,
instance.termination_reason,
)
return True
nodes_missing = fleet_spec.configuration.nodes.min - active_instances_num
for i in range(nodes_missing):
instance_model = create_fleet_instance_model(
@@ -260,7 +260,6 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):

instance_filters = [
InstanceModel.deleted == False,
InstanceModel.total_blocks > InstanceModel.busy_blocks,
InstanceModel.id.not_in(detaching_instances_ids),
]

@@ -514,9 +513,6 @@ async def _find_optimal_fleet_with_offers(
)
return run_model.fleet, fleet_instances_with_pool_offers

if len(fleet_models) == 0:
return None, []

nodes_required_num = _get_nodes_required_num_for_run(run_spec)
# The current strategy is first to consider fleets that can accommodate
# the run without additional provisioning and choose the one with the cheapest pool offer.
@@ -534,31 +530,29 @@
]
] = []
for candidate_fleet_model in fleet_models:
candidate_fleet = fleet_model_to_fleet(candidate_fleet_model)
fleet_instances_with_pool_offers = _get_fleet_instances_with_pool_offers(
fleet_model=candidate_fleet_model,
run_spec=run_spec,
job=job,
master_job_provisioning_data=master_job_provisioning_data,
volumes=volumes,
)
fleet_has_available_capacity = nodes_required_num <= len(fleet_instances_with_pool_offers)
fleet_has_pool_capacity = nodes_required_num <= len(fleet_instances_with_pool_offers)
fleet_cheapest_pool_offer = math.inf
if len(fleet_instances_with_pool_offers) > 0:
fleet_cheapest_pool_offer = fleet_instances_with_pool_offers[0][1].price

candidate_fleet = fleet_model_to_fleet(candidate_fleet_model)
profile = None
requirements = None
try:
_check_can_create_new_instance_in_fleet(candidate_fleet)
profile, requirements = _get_run_profile_and_requirements_in_fleet(
job=job,
run_spec=run_spec,
fleet=candidate_fleet,
)
except ValueError:
pass
fleet_backend_offers = []
if profile is not None and requirements is not None:
fleet_backend_offers = []
else:
multinode = (
candidate_fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER
or job.job_spec.jobs_per_replica > 1
@@ -579,8 +573,12 @@
if len(fleet_backend_offers) > 0:
fleet_cheapest_backend_offer = fleet_backend_offers[0][1].price

if not _run_can_fit_into_fleet(run_spec, candidate_fleet):
logger.debug("Skipping fleet %s from consideration: run cannot fit into fleet")
continue

fleet_priority = (
not fleet_has_available_capacity,
not fleet_has_pool_capacity,
fleet_cheapest_pool_offer,
fleet_cheapest_backend_offer,
)
@@ -593,10 +591,13 @@
fleet_priority,
)
)
if len(candidate_fleets_with_offers) == 0:
return None, []
if run_spec.merged_profile.fleets is None and all(
t[2] == 0 and t[3] == 0 for t in candidate_fleets_with_offers
):
# If fleets are not specified and no fleets have available pool or backend offers, create a new fleet.
# If fleets are not specified and no fleets have available pool
# or backend offers, create a new fleet.
# This is for compatibility with non-fleet-first UX when runs created new fleets
# if there are no instances to reuse.
return None, []
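
Not part of the diff: a small illustration of how the fleet_priority tuple above orders candidates. Python compares tuples element-wise, so fleets that already have pool capacity (False sorts before True) come first, with ties broken by the cheapest pool offer and then by the cheapest backend offer; the fleet names and prices below are made up.

import math

candidates = [
    ("fleet-a", (True, math.inf, 1.20)),   # no pool capacity, backend offer at $1.20/h
    ("fleet-b", (False, 0.90, 0.80)),      # pool capacity, pool offer at $0.90/h
    ("fleet-c", (False, 0.50, math.inf)),  # pool capacity, cheaper pool offer
]
best_name, _ = min(candidates, key=lambda c: c[1])
print(best_name)  # fleet-c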
@@ -616,6 +617,39 @@ def _get_nodes_required_num_for_run(run_spec: RunSpec) -> int:
return nodes_required_num


def _run_can_fit_into_fleet(run_spec: RunSpec, fleet: Fleet) -> bool:
"""
Returns `False` if the run definitely cannot fit into the fleet.
This is a helpful heuristic to avoid even considering fleets that are too small for a run.
A run may still fail to fit even if this function returns `True`.
In that case some jobs will fail due to exceeding `nodes.max`,
or more than `nodes.max` instances will be provisioned
and eventually removed by the fleet consolidation logic.
"""
# No check for cloud fleets with blocks > 1 since we don't know
# how many jobs such fleets can accommodate.
nodes_required_num = _get_nodes_required_num_for_run(run_spec)
if (
fleet.spec.configuration.nodes is not None
and fleet.spec.configuration.blocks == 1
and fleet.spec.configuration.nodes.max is not None
):
busy_instances = [i for i in fleet.instances if i.busy_blocks > 0]
fleet_available_capacity = fleet.spec.configuration.nodes.max - len(busy_instances)
if fleet_available_capacity < nodes_required_num:
return False
elif fleet.spec.configuration.ssh_config is not None:
# Currently assume that each idle block can run a job.
# TODO: Take resources / eligible offers into account.
total_idle_blocks = 0
for instance in fleet.instances:
total_blocks = instance.total_blocks or 1
total_idle_blocks += total_blocks - instance.busy_blocks
if total_idle_blocks < nodes_required_num:
return False
return True
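
Not part of the diff: a worked example of the cloud-fleet branch of the heuristic above, using made-up numbers. For SSH fleets the same check counts idle blocks across instances instead of whole nodes.

nodes_max = 4           # fleet.spec.configuration.nodes.max
busy_instances_num = 3  # instances with busy_blocks > 0
nodes_required_num = 2  # nodes the run needs

fleet_available_capacity = nodes_max - busy_instances_num  # 4 - 3 = 1
print(fleet_available_capacity >= nodes_required_num)  # False -> the fleet is skipped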


def _get_fleet_instances_with_pool_offers(
fleet_model: FleetModel,
run_spec: RunSpec,
@@ -713,6 +747,7 @@ async def _run_job_on_new_instance(
if fleet_model is not None:
fleet = fleet_model_to_fleet(fleet_model)
try:
_check_can_create_new_instance_in_fleet(fleet)
profile, requirements = _get_run_profile_and_requirements_in_fleet(
job=job,
run_spec=run.run_spec,
@@ -787,8 +822,6 @@ def _get_run_profile_and_requirements_in_fleet(
run_spec: RunSpec,
fleet: Fleet,
) -> tuple[Profile, Requirements]:
if not _check_can_create_new_instance_in_fleet(fleet):
raise ValueError("Cannot fit new instance into fleet")
profile = combine_fleet_and_run_profiles(fleet.spec.merged_profile, run_spec.merged_profile)
if profile is None:
raise ValueError("Cannot combine fleet profile")
@@ -801,13 +834,23 @@
return profile, requirements


def _check_can_create_new_instance_in_fleet(fleet: Fleet) -> bool:
def _check_can_create_new_instance_in_fleet(fleet: Fleet):
if not _can_create_new_instance_in_fleet(fleet):
raise ValueError("Cannot fit new instance into fleet")


def _can_create_new_instance_in_fleet(fleet: Fleet) -> bool:
if fleet.spec.configuration.ssh_config is not None:
return False
# TODO: Respect nodes.max
# Ensure concurrent provisioning does not violate nodes.max
# E.g. lock fleet and split instance model creation
# and instance provisioning into separate transactions.
active_instances = [i for i in fleet.instances if i.status.is_active()]
# nodes.max is a soft limit that can be exceeded when provisioning concurrently.
# The fleet consolidation logic will remove redundant nodes eventually.
if (
fleet.spec.configuration.nodes is not None
and fleet.spec.configuration.nodes.max is not None
and len(active_instances) >= fleet.spec.configuration.nodes.max
):
return False
return True


33 changes: 33 additions & 0 deletions src/tests/_internal/server/background/tasks/test_process_fleets.py
@@ -126,3 +126,36 @@ async def test_consolidation_creates_missing_instances(self, test_db, session: A
instances = (await session.execute(select(InstanceModel))).scalars().all()
assert len(instances) == 2
assert {i.instance_num for i in instances} == {0, 1} # uses 0 for next instance num

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_consolidation_terminates_redundant_instances(
self, test_db, session: AsyncSession
):
project = await create_project(session)
spec = get_fleet_spec()
spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1)
fleet = await create_fleet(
session=session,
project=project,
spec=spec,
)
instance1 = await create_instance(
session=session,
project=project,
fleet=fleet,
status=InstanceStatus.BUSY,
instance_num=0,
)
instance2 = await create_instance(
session=session,
project=project,
fleet=fleet,
status=InstanceStatus.IDLE,
instance_num=1,
)
await process_fleets()
await session.refresh(instance1)
await session.refresh(instance2)
assert instance1.status == InstanceStatus.BUSY
assert instance2.status == InstanceStatus.TERMINATING