diff --git a/docs/docs/concepts/fleets.md b/docs/docs/concepts/fleets.md
index 76e99a24c..406763d18 100644
--- a/docs/docs/concepts/fleets.md
+++ b/docs/docs/concepts/fleets.md
@@ -59,6 +59,27 @@ Once the status of instances changes to `idle`, they can be used by dev environm
### Configuration options
+#### Nodes { #nodes }
+
+The `nodes` property controls how many instances to provision and maintain in the fleet:
+
+```yaml
+type: fleet
+
+name: my-fleet
+
+nodes:
+ min: 1 # Always maintain at least 1 instance
+ target: 2 # Provision 2 instances initially
+ max: 3 # Do not allow more than 3 instances
+```
+
+`dstack` ensures the fleet always has at least `nodes.min` instances, creating new instances in the background if necessary. If you don't need to keep instances in the fleet forever, set `nodes.min` to `0`. By default, `dstack apply` provisions `nodes.min` instances; set `nodes.target` to provision more instances initially than need to be maintained.
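+
+`nodes` also accepts shorthand values: a single number (e.g. `nodes: 2`) sets `min`, `target`, and `max` to that number, while a range such as `nodes: 1..3` sets `min` and `max`, with `target` defaulting to `min`. Either bound of the range may be omitted:
+
+```yaml
+type: fleet
+
+name: my-fleet
+
+nodes: 0..3  # min: 0, target: 0, max: 3
+```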
+
#### Placement { #cloud-placement }
To ensure instances are interconnected (e.g., for
diff --git a/src/dstack/_internal/core/compatibility/fleets.py b/src/dstack/_internal/core/compatibility/fleets.py
index a8c8bbebf..2ac640b27 100644
--- a/src/dstack/_internal/core/compatibility/fleets.py
+++ b/src/dstack/_internal/core/compatibility/fleets.py
@@ -59,6 +59,11 @@ def get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[IncludeExcludeDic
profile_excludes.add("stop_criteria")
if profile.schedule is None:
profile_excludes.add("schedule")
+ if (
+ fleet_spec.configuration.nodes
+ and fleet_spec.configuration.nodes.min == fleet_spec.configuration.nodes.target
+ ):
+ configuration_excludes["nodes"] = {"target"}
if configuration_excludes:
spec_excludes["configuration"] = configuration_excludes
if profile_excludes:
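A sketch of what this branch contributes to the exclude mapping (hypothetical spec, assuming no other fields are excluded); omitting `target` when it merely repeats `min` keeps serialized specs parseable by servers that predate `FleetNodesSpec`:

```python
# Nested-dict shape consumed by pydantic's .dict(exclude=...):
spec_excludes = {"configuration": {"nodes": {"target"}}}
```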
diff --git a/src/dstack/_internal/core/models/fleets.py b/src/dstack/_internal/core/models/fleets.py
index 357f9b5b0..7d7d31aa8 100644
--- a/src/dstack/_internal/core/models/fleets.py
+++ b/src/dstack/_internal/core/models/fleets.py
@@ -19,7 +19,7 @@
TerminationPolicy,
parse_idle_duration,
)
-from dstack._internal.core.models.resources import Range, ResourcesSpec
+from dstack._internal.core.models.resources import ResourcesSpec
from dstack._internal.utils.common import list_enum_values_for_annotation
from dstack._internal.utils.json_schema import add_extra_schema_types
from dstack._internal.utils.tags import tags_validator
@@ -141,6 +141,58 @@ def validate_network(cls, value):
return value
+class FleetNodesSpec(CoreModel):
+ min: Annotated[
+ int, Field(description=("The minimum number of instances to maintain in the fleet"))
+ ]
+ target: Annotated[
+ int,
+ Field(
+ description=(
+ "The number of instances to provision on fleet apply. `min` <= `target` <= `max`"
+ " Defaults to `min`"
+ )
+ ),
+ ]
+ max: Annotated[
+ Optional[int],
+ Field(
+ description=(
+ "The maximum number of instances allowed in the fleet. Unlimited if not specified"
+ )
+ ),
+ ] = None
+
+ @root_validator(pre=True)
+ def set_min_and_target_defaults(cls, values):
+ min_ = values.get("min")
+ target = values.get("target")
+ if min_ is None:
+ values["min"] = 0
+ if target is None:
+ values["target"] = values["min"]
+ return values
+
+ @validator("min")
+ def validate_min(cls, v: int) -> int:
+ if v < 0:
+ raise ValueError("min cannot be negative")
+ return v
+
+ @root_validator(skip_on_failure=True)
+ def _post_validate_ranges(cls, values):
+ min_ = values["min"]
+ target = values["target"]
+ max_ = values.get("max")
+ if target < min_:
+ raise ValueError("target must not be be less than min")
+ if max_ is not None and max_ < min_:
+ raise ValueError("max must not be less than min")
+ if max_ is not None and max_ < target:
+ raise ValueError("max must not be less than target")
+ return values
+
+
class InstanceGroupParams(CoreModel):
env: Annotated[
Env,
@@ -151,7 +203,9 @@ class InstanceGroupParams(CoreModel):
Field(description="The parameters for adding instances via SSH"),
] = None
- nodes: Annotated[Optional[Range[int]], Field(description="The number of instances")] = None
+ nodes: Annotated[
+        Optional[FleetNodesSpec], Field(description="The number of instances in a cloud fleet")
+ ] = None
placement: Annotated[
Optional[InstanceGroupPlacement],
Field(description="The placement of instances: `any` or `cluster`"),
@@ -248,6 +302,16 @@ def schema_extra(schema: Dict[str, Any], model: Type):
extra_types=[{"type": "string"}],
)
+ @validator("nodes", pre=True)
+    def parse_nodes(cls, v: Optional[Union[dict, str, int]]) -> Optional[dict]:
+        if isinstance(v, str) and ".." in v:
+            v = v.replace(" ", "")
+            min_, max_ = v.split("..")
+            return dict(min=min_ or None, max=max_ or None)
+        elif isinstance(v, (str, int)):
+            return dict(min=v, max=v)
+ return v
+
_validate_idle_duration = validator("idle_duration", pre=True, allow_reuse=True)(
parse_idle_duration
)
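Taken together, the `parse_nodes` pre-validator and the `FleetNodesSpec` validators accept several input forms. A minimal sketch of the resulting behavior, mirroring the unit tests added later in this diff:

```python
from dstack._internal.core.models.fleets import FleetConfiguration, FleetNodesSpec

# `target` defaults to `min`; `max` stays unbounded when omitted.
assert FleetNodesSpec.parse_obj({"min": 1}) == FleetNodesSpec(min=1, target=1, max=None)

# Int and range shorthand are expanded before field validation runs.
conf = FleetConfiguration.parse_obj({"type": "fleet", "nodes": 2})
assert conf.nodes == FleetNodesSpec(min=2, target=2, max=2)
conf = FleetConfiguration.parse_obj({"type": "fleet", "nodes": "1..3"})
assert conf.nodes == FleetNodesSpec(min=1, target=1, max=3)
```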
diff --git a/src/dstack/_internal/server/background/tasks/process_fleets.py b/src/dstack/_internal/server/background/tasks/process_fleets.py
index 4ce819e59..2536902a6 100644
--- a/src/dstack/_internal/server/background/tasks/process_fleets.py
+++ b/src/dstack/_internal/server/background/tasks/process_fleets.py
@@ -1,11 +1,13 @@
from datetime import timedelta
from typing import List
+from uuid import UUID
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload, load_only
-from dstack._internal.core.models.fleets import FleetStatus
+from dstack._internal.core.models.fleets import FleetSpec, FleetStatus
+from dstack._internal.core.models.instances import InstanceStatus
from dstack._internal.server.db import get_db, get_session_ctx
from dstack._internal.server.models import (
FleetModel,
@@ -15,7 +17,9 @@
RunModel,
)
from dstack._internal.server.services.fleets import (
+ create_fleet_instance_model,
get_fleet_spec,
+ get_next_instance_num,
is_fleet_empty,
is_fleet_in_use,
)
@@ -65,31 +69,111 @@ async def _process_fleets(session: AsyncSession, fleet_models: List[FleetModel])
res = await session.execute(
select(FleetModel)
.where(FleetModel.id.in_(fleet_ids))
- .options(joinedload(FleetModel.instances).load_only(InstanceModel.deleted))
.options(
- joinedload(FleetModel.instances).joinedload(InstanceModel.jobs).load_only(JobModel.id)
+ joinedload(FleetModel.instances).joinedload(InstanceModel.jobs).load_only(JobModel.id),
+ joinedload(FleetModel.project),
)
.options(joinedload(FleetModel.runs).load_only(RunModel.status))
.execution_options(populate_existing=True)
)
fleet_models = list(res.unique().scalars().all())
+ # TODO: Drop fleets auto-deletion after dropping fleets auto-creation.
deleted_fleets_ids = []
- now = get_current_datetime()
for fleet_model in fleet_models:
+ _consolidate_fleet_state_with_spec(session, fleet_model)
deleted = _autodelete_fleet(fleet_model)
if deleted:
deleted_fleets_ids.append(fleet_model.id)
- fleet_model.last_processed_at = now
+ fleet_model.last_processed_at = get_current_datetime()
+ await _update_deleted_fleets_placement_groups(session, deleted_fleets_ids)
+ await session.commit()
- await session.execute(
- update(PlacementGroupModel)
- .where(
- PlacementGroupModel.fleet_id.in_(deleted_fleets_ids),
+
+def _consolidate_fleet_state_with_spec(session: AsyncSession, fleet_model: FleetModel):
+ if fleet_model.status == FleetStatus.TERMINATING:
+ return
+ fleet_spec = get_fleet_spec(fleet_model)
+ if fleet_spec.configuration.nodes is None or fleet_spec.autocreated:
+ # Only explicitly created cloud fleets are consolidated.
+ return
+ if not _is_fleet_ready_for_consolidation(fleet_model):
+ return
+ added_instances = _maintain_fleet_nodes_min(session, fleet_model, fleet_spec)
+ if added_instances:
+ fleet_model.consolidation_attempt += 1
+ else:
+ # The fleet is already consolidated or consolidation is in progress.
+ # We reset consolidation_attempt in both cases for simplicity.
+        # The second case does not need the reset, but it is harmless: it only
+        # happens when consolidation takes longer than the delay, so it won't happen too often.
+ # TODO: Reset consolidation_attempt on fleet in-place update.
+ fleet_model.consolidation_attempt = 0
+ fleet_model.last_consolidated_at = get_current_datetime()
+
+
+def _is_fleet_ready_for_consolidation(fleet_model: FleetModel) -> bool:
+ consolidation_retry_delay = _get_consolidation_retry_delay(fleet_model.consolidation_attempt)
+ last_consolidated_at = fleet_model.last_consolidated_at or fleet_model.last_processed_at
+ duration_since_last_consolidation = get_current_datetime() - last_consolidated_at
+ return duration_since_last_consolidation >= consolidation_retry_delay
+
+
+# We use exponentially increasing consolidation retry delays so that
+# consolidation does not happen too often. In particular, this prevents
+# retrying instance provisioning constantly in case of no offers.
+# TODO: Adjust delays.
+_CONSOLIDATION_RETRY_DELAYS = [
+ timedelta(seconds=30),
+ timedelta(minutes=1),
+ timedelta(minutes=2),
+ timedelta(minutes=5),
+ timedelta(minutes=10),
+]
+
+
+def _get_consolidation_retry_delay(consolidation_attempt: int) -> timedelta:
+ if consolidation_attempt < len(_CONSOLIDATION_RETRY_DELAYS):
+ return _CONSOLIDATION_RETRY_DELAYS[consolidation_attempt]
+ return _CONSOLIDATION_RETRY_DELAYS[-1]
+
+
+def _maintain_fleet_nodes_min(
+ session: AsyncSession,
+ fleet_model: FleetModel,
+ fleet_spec: FleetSpec,
+) -> bool:
+ """
+ Ensures the fleet has at least `nodes.min` instances.
+    Returns `True` if new instances were added and `False` otherwise.
+ """
+ assert fleet_spec.configuration.nodes is not None
+ for instance in fleet_model.instances:
+ # Delete terminated but not deleted instances since
+ # they are going to be replaced with new pending instances.
+ if instance.status == InstanceStatus.TERMINATED and not instance.deleted:
+ # It's safe to modify instances without instance lock since
+ # no other task modifies already terminated instances.
+ instance.deleted = True
+ instance.deleted_at = get_current_datetime()
+ active_instances = [i for i in fleet_model.instances if not i.deleted]
+ active_instances_num = len(active_instances)
+ if active_instances_num >= fleet_spec.configuration.nodes.min:
+ return False
+ nodes_missing = fleet_spec.configuration.nodes.min - active_instances_num
+    for _ in range(nodes_missing):
+ instance_model = create_fleet_instance_model(
+ session=session,
+ project=fleet_model.project,
+ # TODO: Store fleet.user and pass it instead of the project owner.
+ username=fleet_model.project.owner.name,
+ spec=fleet_spec,
+ instance_num=get_next_instance_num({i.instance_num for i in active_instances}),
)
- .values(fleet_deleted=True)
- )
- await session.commit()
+ active_instances.append(instance_model)
+ fleet_model.instances.append(instance_model)
+ logger.info("Added %s instances to fleet %s", nodes_missing, fleet_model.name)
+ return True
def _autodelete_fleet(fleet_model: FleetModel) -> bool:
@@ -100,7 +184,7 @@ def _autodelete_fleet(fleet_model: FleetModel) -> bool:
if (
fleet_model.status != FleetStatus.TERMINATING
and fleet_spec.configuration.nodes is not None
- and (fleet_spec.configuration.nodes.min is None or fleet_spec.configuration.nodes.min == 0)
+ and fleet_spec.configuration.nodes.min == 0
):
# Empty fleets that allow 0 nodes should not be auto-deleted
return False
@@ -110,3 +194,13 @@ def _autodelete_fleet(fleet_model: FleetModel) -> bool:
fleet_model.deleted = True
logger.info("Fleet %s deleted", fleet_model.name)
return True
+
+
+async def _update_deleted_fleets_placement_groups(session: AsyncSession, fleets_ids: list[UUID]):
+ await session.execute(
+ update(PlacementGroupModel)
+ .where(
+ PlacementGroupModel.fleet_id.in_(fleets_ids),
+ )
+ .values(fleet_deleted=True)
+ )
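A quick behavioral sketch of the clamped backoff above (it exercises a private helper, so illustration only):

```python
from datetime import timedelta

from dstack._internal.server.background.tasks.process_fleets import (
    _get_consolidation_retry_delay,
)

# Attempts walk the delay table, then stay clamped at the last entry, so a
# fleet that repeatedly fails to consolidate is retried every 10 minutes.
assert _get_consolidation_retry_delay(0) == timedelta(seconds=30)
assert _get_consolidation_retry_delay(3) == timedelta(minutes=5)
assert _get_consolidation_retry_delay(99) == timedelta(minutes=10)
```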
diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py
index 849c7dc54..4694141b3 100644
--- a/src/dstack/_internal/server/background/tasks/process_instances.py
+++ b/src/dstack/_internal/server/background/tasks/process_instances.py
@@ -53,14 +53,12 @@
PlacementStrategy,
)
from dstack._internal.core.models.profiles import (
- RetryEvent,
TerminationPolicy,
)
from dstack._internal.core.models.runs import (
JobProvisioningData,
Retry,
)
-from dstack._internal.core.services.profiles import get_retry
from dstack._internal.server import settings as server_settings
from dstack._internal.server.background.tasks.common import get_provisioning_timeout
from dstack._internal.server.db import get_db, get_session_ctx
@@ -327,7 +325,6 @@ async def _add_remote(instance: InstanceModel) -> None:
e,
)
instance.status = InstanceStatus.PENDING
- instance.last_retry_at = get_current_datetime()
return
instance_type = host_info_to_instance_type(host_info, cpu_arch)
@@ -426,7 +423,6 @@ async def _add_remote(instance: InstanceModel) -> None:
instance.offer = instance_offer.json()
instance.job_provisioning_data = jpd.json()
instance.started_at = get_current_datetime()
- instance.last_retry_at = get_current_datetime()
def _deploy_instance(
@@ -493,29 +489,6 @@ def _deploy_instance(
async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
- if instance.last_retry_at is not None:
- last_retry = instance.last_retry_at
- if get_current_datetime() < last_retry + timedelta(minutes=1):
- return
-
- if (
- instance.profile is None
- or instance.requirements is None
- or instance.instance_configuration is None
- ):
- instance.status = InstanceStatus.TERMINATED
- instance.termination_reason = "Empty profile, requirements or instance_configuration"
- instance.last_retry_at = get_current_datetime()
- logger.warning(
- "Empty profile, requirements or instance_configuration. Terminate instance: %s",
- instance.name,
- extra={
- "instance_name": instance.name,
- "instance_status": InstanceStatus.TERMINATED.value,
- },
- )
- return
-
if _need_to_wait_fleet_provisioning(instance):
logger.debug("Waiting for the first instance in the fleet to be provisioned")
return
@@ -529,7 +502,6 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
instance.termination_reason = (
f"Error to parse profile, requirements or instance_configuration: {e}"
)
- instance.last_retry_at = get_current_datetime()
logger.warning(
"Error to parse profile, requirements or instance_configuration. Terminate instance: %s",
instance.name,
@@ -540,24 +512,6 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
)
return
- retry = get_retry(profile)
- should_retry = retry is not None and RetryEvent.NO_CAPACITY in retry.on_events
-
- if retry is not None:
- retry_duration_deadline = _get_retry_duration_deadline(instance, retry)
- if get_current_datetime() > retry_duration_deadline:
- instance.status = InstanceStatus.TERMINATED
- instance.termination_reason = "Retry duration expired"
- logger.warning(
- "Retry duration expired. Terminating instance %s",
- instance.name,
- extra={
- "instance_name": instance.name,
- "instance_status": InstanceStatus.TERMINATED.value,
- },
- )
- return
-
placement_group_models = []
placement_group_model = None
if instance.fleet_id:
@@ -595,15 +549,6 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
exclude_not_available=True,
)
- if not offers and should_retry:
- instance.last_retry_at = get_current_datetime()
- logger.debug(
- "No offers for instance %s. Next retry",
- instance.name,
- extra={"instance_name": instance.name},
- )
- return
-
# Limit number of offers tried to prevent long-running processing
# in case all offers fail.
for backend, instance_offer in offers[: server_settings.MAX_OFFERS_TRIED]:
@@ -681,7 +626,6 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
instance.offer = instance_offer.json()
instance.total_blocks = instance_offer.total_blocks
instance.started_at = get_current_datetime()
- instance.last_retry_at = get_current_datetime()
logger.info(
"Created instance %s",
@@ -702,21 +646,18 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
)
return
- instance.last_retry_at = get_current_datetime()
-
- if not should_retry:
- _mark_terminated(instance, "All offers failed" if offers else "No offers found")
- if (
- instance.fleet
- and _is_fleet_master_instance(instance)
- and _is_cloud_cluster(instance.fleet)
- ):
- # Do not attempt to deploy other instances, as they won't determine the correct cluster
- # backend, region, and placement group without a successfully deployed master instance
- for sibling_instance in instance.fleet.instances:
- if sibling_instance.id == instance.id:
- continue
- _mark_terminated(sibling_instance, "Master instance failed to start")
+ _mark_terminated(instance, "All offers failed" if offers else "No offers found")
+ if (
+ instance.fleet
+ and _is_fleet_master_instance(instance)
+ and _is_cloud_cluster(instance.fleet)
+ ):
+ # Do not attempt to deploy other instances, as they won't determine the correct cluster
+ # backend, region, and placement group without a successfully deployed master instance
+ for sibling_instance in instance.fleet.instances:
+ if sibling_instance.id == instance.id:
+ continue
+ _mark_terminated(sibling_instance, "Master instance failed to start")
def _mark_terminated(instance: InstanceModel, termination_reason: str) -> None:
diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
index 6d93eb523..21b699a0b 100644
--- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
+++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
@@ -5,7 +5,7 @@
from datetime import datetime, timedelta
from typing import List, Optional, Tuple
-from sqlalchemy import and_, func, not_, or_, select
+from sqlalchemy import and_, not_, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager, joinedload, load_only, noload, selectinload
@@ -16,6 +16,7 @@
from dstack._internal.core.models.fleets import (
Fleet,
FleetConfiguration,
+ FleetNodesSpec,
FleetSpec,
FleetStatus,
InstanceGroupPlacement,
@@ -26,7 +27,7 @@
CreationPolicy,
TerminationPolicy,
)
-from dstack._internal.core.models.resources import Memory, Range
+from dstack._internal.core.models.resources import Memory
from dstack._internal.core.models.runs import (
Job,
JobProvisioningData,
@@ -54,6 +55,7 @@
from dstack._internal.server.services.fleets import (
fleet_model_to_fleet,
get_fleet_requirements,
+ get_next_instance_num,
)
from dstack._internal.server.services.instances import (
filter_pool_instances,
@@ -384,6 +386,8 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
instance_num=instance_num,
)
job_model.job_runtime_data = _prepare_job_runtime_data(offer).json()
+ # Both this task and process_fleets can add instances to fleets.
+ # TODO: Ensure this does not violate nodes.max when it's enforced.
instance.fleet_id = fleet_model.id
logger.info(
"The job %s created the new instance %s",
@@ -755,12 +759,17 @@ def _create_fleet_model_for_job(
placement = InstanceGroupPlacement.ANY
if run.run_spec.configuration.type == "task" and run.run_spec.configuration.nodes > 1:
placement = InstanceGroupPlacement.CLUSTER
+ nodes = _get_nodes_required_num_for_run(run.run_spec)
spec = FleetSpec(
configuration=FleetConfiguration(
name=run.run_spec.run_name,
placement=placement,
reservation=run.run_spec.configuration.reservation,
- nodes=Range(min=_get_nodes_required_num_for_run(run.run_spec), max=None),
+ nodes=FleetNodesSpec(
+ min=nodes,
+ target=nodes,
+ max=None,
+ ),
),
profile=run.run_spec.merged_profile,
autocreated=True,
@@ -778,10 +787,13 @@ def _create_fleet_model_for_job(
async def _get_next_instance_num(session: AsyncSession, fleet_model: FleetModel) -> int:
res = await session.execute(
- select(func.count(InstanceModel.id)).where(InstanceModel.fleet_id == fleet_model.id)
+ select(InstanceModel.instance_num).where(
+ InstanceModel.fleet_id == fleet_model.id,
+ InstanceModel.deleted.is_(False),
+ )
)
- instance_count = res.scalar_one()
- return instance_count
+ taken_instance_nums = set(res.scalars().all())
+ return get_next_instance_num(taken_instance_nums)
def _create_instance_model_for_job(
diff --git a/src/dstack/_internal/server/migrations/versions/2498ab323443_add_fleetmodel_consolidation_attempt_.py b/src/dstack/_internal/server/migrations/versions/2498ab323443_add_fleetmodel_consolidation_attempt_.py
new file mode 100644
index 000000000..534dacaba
--- /dev/null
+++ b/src/dstack/_internal/server/migrations/versions/2498ab323443_add_fleetmodel_consolidation_attempt_.py
@@ -0,0 +1,44 @@
+"""Add FleetModel.consolidation_attempt and FleetModel.last_consolidated_at
+
+Revision ID: 2498ab323443
+Revises: e2d08cd1b8d9
+Create Date: 2025-08-29 16:08:48.686595
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+import dstack._internal.server.models
+
+# revision identifiers, used by Alembic.
+revision = "2498ab323443"
+down_revision = "e2d08cd1b8d9"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table("fleets", schema=None) as batch_op:
+ batch_op.add_column(
+ sa.Column("consolidation_attempt", sa.Integer(), server_default="0", nullable=False)
+ )
+ batch_op.add_column(
+ sa.Column(
+ "last_consolidated_at",
+ dstack._internal.server.models.NaiveDateTime(),
+ nullable=True,
+ )
+ )
+
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table("fleets", schema=None) as batch_op:
+ batch_op.drop_column("last_consolidated_at")
+ batch_op.drop_column("consolidation_attempt")
+
+ # ### end Alembic commands ###
diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py
index cd8873e73..1243a4ae1 100644
--- a/src/dstack/_internal/server/models.py
+++ b/src/dstack/_internal/server/models.py
@@ -551,6 +551,9 @@ class FleetModel(BaseModel):
jobs: Mapped[List["JobModel"]] = relationship(back_populates="fleet")
instances: Mapped[List["InstanceModel"]] = relationship(back_populates="fleet")
+ consolidation_attempt: Mapped[int] = mapped_column(Integer, server_default="0")
+ last_consolidated_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
+
class InstanceModel(BaseModel):
__tablename__ = "instances"
@@ -605,8 +608,8 @@ class InstanceModel(BaseModel):
Integer, default=DEFAULT_FLEET_TERMINATION_IDLE_TIME
)
- # retry policy
- last_retry_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
+    # Deprecated: no longer written; the column is kept (deferred) for schema compatibility
+ last_retry_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime, deferred=True)
# instance termination handling
termination_deadline: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py
index 33257d3a9..277ed41b3 100644
--- a/src/dstack/_internal/server/services/fleets.py
+++ b/src/dstack/_internal/server/services/fleets.py
@@ -449,25 +449,24 @@ async def create_fleet(
return await _create_fleet(session=session, project=project, user=user, spec=spec)
-async def create_fleet_instance_model(
+def create_fleet_instance_model(
session: AsyncSession,
project: ProjectModel,
- user: UserModel,
+ username: str,
spec: FleetSpec,
- reservation: Optional[str],
instance_num: int,
) -> InstanceModel:
profile = spec.merged_profile
requirements = get_fleet_requirements(spec)
- instance_model = await instances_services.create_instance_model(
+ instance_model = instances_services.create_instance_model(
session=session,
project=project,
- user=user,
+ username=username,
profile=profile,
requirements=requirements,
instance_name=f"{spec.configuration.name}-{instance_num}",
instance_num=instance_num,
- reservation=reservation,
+ reservation=spec.merged_profile.reservation,
blocks=spec.configuration.blocks,
tags=spec.configuration.tags,
)
@@ -655,6 +654,19 @@ def get_fleet_requirements(fleet_spec: FleetSpec) -> Requirements:
return requirements
+def get_next_instance_num(taken_instance_nums: set[int]) -> int:
+    """Returns the lowest non-negative instance num not in `taken_instance_nums`."""
+ if not taken_instance_nums:
+ return 0
+ min_instance_num = min(taken_instance_nums)
+ if min_instance_num > 0:
+ return 0
+ instance_num = min_instance_num + 1
+ while True:
+ if instance_num not in taken_instance_nums:
+ return instance_num
+ instance_num += 1
+
+
async def _create_fleet(
session: AsyncSession,
project: ProjectModel,
@@ -705,12 +717,11 @@ async def _create_fleet(
fleet_model.instances.append(instances_model)
else:
for i in range(_get_fleet_nodes_to_provision(spec)):
- instance_model = await create_fleet_instance_model(
+ instance_model = create_fleet_instance_model(
session=session,
project=project,
- user=user,
+ username=user.name,
spec=spec,
- reservation=spec.configuration.reservation,
instance_num=i,
)
fleet_model.instances.append(instance_model)
@@ -778,7 +789,7 @@ async def _update_fleet(
if added_hosts:
await _check_ssh_hosts_not_yet_added(session, spec, fleet.id)
for host in added_hosts.values():
- instance_num = _get_next_instance_num(active_instance_nums)
+ instance_num = get_next_instance_num(active_instance_nums)
instance_model = await create_fleet_ssh_instance_model(
project=project,
spec=spec,
@@ -994,9 +1005,9 @@ def _validate_internal_ips(ssh_config: SSHParams):
def _get_fleet_nodes_to_provision(spec: FleetSpec) -> int:
- if spec.configuration.nodes is None or spec.configuration.nodes.min is None:
+ if spec.configuration.nodes is None:
return 0
- return spec.configuration.nodes.min
+ return spec.configuration.nodes.target
def _terminate_fleet_instances(fleet_model: FleetModel, instance_nums: Optional[List[int]]):
@@ -1013,16 +1024,3 @@ def _terminate_fleet_instances(fleet_model: FleetModel, instance_nums: Optional[
instance.deleted = True
else:
instance.status = InstanceStatus.TERMINATING
-
-
-def _get_next_instance_num(instance_nums: set[int]) -> int:
- if not instance_nums:
- return 0
- min_instance_num = min(instance_nums)
- if min_instance_num > 0:
- return 0
- instance_num = min_instance_num + 1
- while True:
- if instance_num not in instance_nums:
- return instance_num
- instance_num += 1
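The now-public `get_next_instance_num` always picks the lowest free number, reusing slots freed by deleted instances:

```python
from dstack._internal.server.services.fleets import get_next_instance_num

assert get_next_instance_num(set()) == 0      # empty fleet starts at 0
assert get_next_instance_num({1, 2}) == 0     # 0 is free, so it is reused
assert get_next_instance_num({0, 1, 3}) == 2  # the first gap is filled
```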
diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py
index 79fadc2b9..7c679b0cc 100644
--- a/src/dstack/_internal/server/services/instances.py
+++ b/src/dstack/_internal/server/services/instances.py
@@ -513,10 +513,10 @@ async def list_active_remote_instances(
return instance_models
-async def create_instance_model(
+def create_instance_model(
session: AsyncSession,
project: ProjectModel,
- user: UserModel,
+ username: str,
profile: Profile,
requirements: Requirements,
instance_name: str,
@@ -536,7 +536,7 @@ async def create_instance_model(
instance_config = InstanceConfiguration(
project_name=project.name,
instance_name=instance_name,
- user=user.name,
+ user=username,
ssh_keys=[project_ssh_key],
instance_id=str(instance_id),
reservation=reservation,
diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py
index c99733bb3..4a40bb9eb 100644
--- a/src/dstack/_internal/server/testing/common.py
+++ b/src/dstack/_internal/server/testing/common.py
@@ -28,6 +28,7 @@
from dstack._internal.core.models.envs import Env
from dstack._internal.core.models.fleets import (
FleetConfiguration,
+ FleetNodesSpec,
FleetSpec,
FleetStatus,
InstanceGroupPlacement,
@@ -60,7 +61,7 @@
)
from dstack._internal.core.models.repos.base import RepoType
from dstack._internal.core.models.repos.local import LocalRunRepoData
-from dstack._internal.core.models.resources import CPUSpec, Memory, Range, ResourcesSpec
+from dstack._internal.core.models.resources import CPUSpec, Memory, ResourcesSpec
from dstack._internal.core.models.runs import (
JobProvisioningData,
JobRuntimeData,
@@ -579,7 +580,7 @@ def get_fleet_spec(conf: Optional[FleetConfiguration] = None) -> FleetSpec:
def get_fleet_configuration(
name: str = "test-fleet",
- nodes: Range[int] = Range(min=1, max=1),
+ nodes: FleetNodesSpec = FleetNodesSpec(min=1, target=1, max=1),
placement: Optional[InstanceGroupPlacement] = None,
) -> FleetConfiguration:
return FleetConfiguration(
diff --git a/src/tests/_internal/core/models/test_fleets.py b/src/tests/_internal/core/models/test_fleets.py
new file mode 100644
index 000000000..a9214f7ec
--- /dev/null
+++ b/src/tests/_internal/core/models/test_fleets.py
@@ -0,0 +1,122 @@
+from typing import Any
+
+import pytest
+from pydantic import ValidationError
+
+from dstack._internal.core.models.fleets import FleetConfiguration, FleetNodesSpec
+
+
+class TestFleetConfiguration:
+ @pytest.mark.parametrize(
+ ["input_nodes", "expected_nodes"],
+ [
+ pytest.param(
+ 1,
+ FleetNodesSpec(
+ min=1,
+ target=1,
+ max=1,
+ ),
+ id="int",
+ ),
+ pytest.param(
+ "1..2",
+ FleetNodesSpec(
+ min=1,
+ target=1,
+ max=2,
+ ),
+ id="closed-range",
+ ),
+ pytest.param(
+ "..2",
+ FleetNodesSpec(
+ min=0,
+ target=0,
+ max=2,
+ ),
+ id="range-without-min",
+ ),
+ pytest.param(
+ "1..",
+ FleetNodesSpec(
+ min=1,
+ target=1,
+ max=None,
+ ),
+ id="range-without-max",
+ ),
+ pytest.param(
+ {
+ "min": 1,
+ "max": 2,
+ },
+ FleetNodesSpec(
+ min=1,
+ target=1,
+ max=2,
+ ),
+ id="dict-without-target",
+ ),
+ pytest.param(
+ {
+ "min": 1,
+ "target": 2,
+ "max": 3,
+ },
+ FleetNodesSpec(
+ min=1,
+ target=2,
+ max=3,
+ ),
+ id="dict-with-all-attributes",
+ ),
+ pytest.param(
+ {
+ "target": 2,
+ "max": 3,
+ },
+ FleetNodesSpec(
+ min=0,
+ target=2,
+ max=3,
+ ),
+ id="dict-without-min",
+ ),
+ pytest.param(
+ {},
+ FleetNodesSpec(
+ min=0,
+ target=0,
+ max=None,
+ ),
+ id="dict-empty",
+ ),
+ ],
+ )
+ def test_parses_nodes(self, input_nodes: Any, expected_nodes: FleetNodesSpec):
+ configuration_input = {
+ "type": "fleet",
+ "nodes": input_nodes,
+ }
+ configuration = FleetConfiguration.parse_obj(configuration_input)
+ assert configuration.nodes == expected_nodes
+
+ @pytest.mark.parametrize(
+ ["input_nodes"],
+ [
+ pytest.param("2..1", id="min-gt-max"),
+ pytest.param({"min": -1}, id="negative-min"),
+ pytest.param({"target": -1}, id="negative-target"),
+ pytest.param({"target": 2, "max": 1}, id="target-gt-max"),
+ pytest.param({"min": 2, "max": 1}, id="min-gt-max"),
+ pytest.param({"min": 2, "target": 1}, id="min-gt-target"),
+ ],
+ )
+ def test_rejects_nodes(self, input_nodes: Any):
+ configuration_input = {
+ "type": "fleet",
+ "nodes": input_nodes,
+ }
+ with pytest.raises(ValidationError):
+ FleetConfiguration.parse_obj(configuration_input)
diff --git a/src/tests/_internal/server/background/tasks/test_process_fleets.py b/src/tests/_internal/server/background/tasks/test_process_fleets.py
index c128cb77b..2d47da6a7 100644
--- a/src/tests/_internal/server/background/tasks/test_process_fleets.py
+++ b/src/tests/_internal/server/background/tasks/test_process_fleets.py
@@ -1,11 +1,13 @@
import pytest
+from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
-from dstack._internal.core.models.fleets import FleetStatus
+from dstack._internal.core.models.fleets import FleetNodesSpec, FleetStatus
from dstack._internal.core.models.instances import InstanceStatus
from dstack._internal.core.models.runs import RunStatus
from dstack._internal.core.models.users import GlobalRole, ProjectRole
from dstack._internal.server.background.tasks.process_fleets import process_fleets
+from dstack._internal.server.models import InstanceModel
from dstack._internal.server.services.projects import add_project_member
from dstack._internal.server.testing.common import (
create_fleet,
@@ -101,3 +103,26 @@ async def test_does_not_delete_fleet_with_instance(self, test_db, session: Async
await process_fleets()
await session.refresh(fleet)
assert not fleet.deleted
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_consolidation_creates_missing_instances(self, test_db, session: AsyncSession):
+        project = await create_project(session)
+        spec = get_fleet_spec()
+        spec.configuration.nodes = FleetNodesSpec(min=2, target=2, max=2)
+        fleet = await create_fleet(
+            session=session,
+            project=project,
+            spec=spec,
+        )
+        await create_instance(
+            session=session,
+            project=project,
+            fleet=fleet,
+            status=InstanceStatus.IDLE,
+            instance_num=1,
+        )
+        await process_fleets()
+        instances = (await session.execute(select(InstanceModel))).scalars().all()
+        assert len(instances) == 2
+        assert {i.instance_num for i in instances} == {0, 1}  # instance_num 0 is free and reused
diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py
index b64cf2c56..f3f7df124 100644
--- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py
+++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py
@@ -8,13 +8,13 @@
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.configurations import TaskConfiguration
+from dstack._internal.core.models.fleets import FleetNodesSpec
from dstack._internal.core.models.health import HealthStatus
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceStatus,
)
from dstack._internal.core.models.profiles import Profile
-from dstack._internal.core.models.resources import Range
from dstack._internal.core.models.runs import (
JobStatus,
JobTerminationReason,
@@ -656,7 +656,7 @@ async def test_creates_new_instance_in_existing_non_empty_fleet(
user = await create_user(session)
repo = await create_repo(session=session, project_id=project.id)
fleet_spec = get_fleet_spec()
- fleet_spec.configuration.nodes = Range(min=1, max=2)
+ fleet_spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=2)
fleet = await create_fleet(session=session, project=project, spec=fleet_spec)
instance = await create_instance(
session=session,
@@ -746,7 +746,7 @@ async def test_does_not_assign_job_to_elastic_empty_fleet_if_fleets_unspecified(
user = await create_user(session)
repo = await create_repo(session=session, project_id=project.id)
fleet_spec = get_fleet_spec()
- fleet_spec.configuration.nodes = Range(min=0, max=1)
+ fleet_spec.configuration.nodes = FleetNodesSpec(min=0, target=0, max=1)
await create_fleet(session=session, project=project, spec=fleet_spec, name="fleet")
# Need a second non-empty fleet to have two-stage processing
fleet2 = await create_fleet(
@@ -786,7 +786,7 @@ async def test_assigns_job_to_elastic_empty_fleet_if_fleets_specified(
user = await create_user(session)
repo = await create_repo(session=session, project_id=project.id)
fleet_spec = get_fleet_spec()
- fleet_spec.configuration.nodes = Range(min=0, max=1)
+ fleet_spec.configuration.nodes = FleetNodesSpec(min=0, target=0, max=1)
fleet = await create_fleet(session=session, project=project, spec=fleet_spec, name="fleet")
run_spec = get_run_spec(repo_id=repo.name)
run_spec.configuration.fleets = [fleet.name]
@@ -817,7 +817,7 @@ async def test_assigns_job_to_elastic_non_empty_busy_fleet_if_fleets_specified(
user = await create_user(session)
repo = await create_repo(session=session, project_id=project.id)
fleet_spec = get_fleet_spec()
- fleet_spec.configuration.nodes = Range(min=1, max=2)
+ fleet_spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=2)
fleet = await create_fleet(session=session, project=project, spec=fleet_spec, name="fleet")
await create_instance(
session=session,
@@ -857,7 +857,7 @@ async def test_creates_new_instance_in_existing_empty_fleet(
user = await create_user(session)
repo = await create_repo(session=session, project_id=project.id)
fleet_spec = get_fleet_spec()
- fleet_spec.configuration.nodes = Range(min=0, max=1)
+ fleet_spec.configuration.nodes = FleetNodesSpec(min=0, target=0, max=1)
fleet = await create_fleet(session=session, project=project, spec=fleet_spec)
run = await create_run(
session=session,
diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py
index 33fc73e01..e4b5192cf 100644
--- a/src/tests/_internal/server/routers/test_fleets.py
+++ b/src/tests/_internal/server/routers/test_fleets.py
@@ -336,7 +336,7 @@ async def test_creates_fleet(self, test_db, session: AsyncSession, client: Async
"spec": {
"configuration_path": spec.configuration_path,
"configuration": {
- "nodes": {"min": 1, "max": 1},
+ "nodes": {"min": 1, "target": 1, "max": 1},
"placement": None,
"env": {},
"ssh_config": None,