From 33e2bbe72b82fad3e7f96e8b58e5eb39035ae7bb Mon Sep 17 00:00:00 2001
From: Victor Skvortsov
Date: Tue, 7 Oct 2025 14:09:35 +0500
Subject: [PATCH] Consider multinode replica inactive only if all jobs done

---
 .../background/tasks/process_instances.py    |  4 +-
 .../server/background/tasks/process_runs.py  | 50 +++++++++---------
 .../background/tasks/test_process_runs.py    | 51 ++++++++++++++++++-
 3 files changed, 78 insertions(+), 27 deletions(-)

diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py
index b44c9271b..ec7ca8f7e 100644
--- a/src/dstack/_internal/server/background/tasks/process_instances.py
+++ b/src/dstack/_internal/server/background/tasks/process_instances.py
@@ -259,9 +259,7 @@ async def _add_remote(instance: InstanceModel) -> None:
     if instance.status == InstanceStatus.PENDING:
         instance.status = InstanceStatus.PROVISIONING
 
-    retry_duration_deadline = instance.created_at.replace(
-        tzinfo=datetime.timezone.utc
-    ) + timedelta(seconds=PROVISIONING_TIMEOUT_SECONDS)
+    retry_duration_deadline = instance.created_at + timedelta(seconds=PROVISIONING_TIMEOUT_SECONDS)
     if retry_duration_deadline < get_current_datetime():
         instance.status = InstanceStatus.TERMINATED
         instance.termination_reason = "Provisioning timeout expired"
diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py
index 03017202b..df1cce72f 100644
--- a/src/dstack/_internal/server/background/tasks/process_runs.py
+++ b/src/dstack/_internal/server/background/tasks/process_runs.py
@@ -256,8 +256,8 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
     for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs):
         replica_statuses: Set[RunStatus] = set()
         replica_needs_retry = False
-
         replica_active = True
+        jobs_done_num = 0
         for job_model in job_models:
             job = find_job(run.jobs, job_model.replica_num, job_model.job_num)
             if (
@@ -272,8 +272,7 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
             ):
                 # the job is done or going to be done
                 replica_statuses.add(RunStatus.DONE)
-                # for some reason the replica is done, it's not active
-                replica_active = False
+                jobs_done_num += 1
             elif job_model.termination_reason == JobTerminationReason.SCALED_DOWN:
                 # the job was scaled down
                 replica_active = False
@@ -313,26 +312,14 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
 
         if not replica_needs_retry or retry_single_job:
             run_statuses.update(replica_statuses)
-        if replica_active:
-            # submitted_at = replica created
-            replicas_info.append(
-                autoscalers.ReplicaInfo(
-                    active=True,
-                    timestamp=min(job.submitted_at for job in job_models).replace(
-                        tzinfo=datetime.timezone.utc
-                    ),
-                )
-            )
-        else:
-            # last_processed_at = replica scaled down
-            replicas_info.append(
-                autoscalers.ReplicaInfo(
-                    active=False,
-                    timestamp=max(job.last_processed_at for job in job_models).replace(
-                        tzinfo=datetime.timezone.utc
-                    ),
-                )
-            )
+        if jobs_done_num == len(job_models):
+            # Consider the replica inactive only if all of its jobs are done.
+            # If only some jobs are done, the replica is still considered active to avoid
+            # provisioning new replicas for partially done multi-node tasks.
+            replica_active = False
+
+        replica_info = _get_replica_info(job_models, replica_active)
+        replicas_info.append(replica_info)
 
     termination_reason: Optional[RunTerminationReason] = None
     if RunStatus.FAILED in run_statuses:
@@ -410,6 +397,23 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
         run_model.resubmission_attempt += 1
 
 
+def _get_replica_info(
+    replica_job_models: list[JobModel],
+    replica_active: bool,
+) -> autoscalers.ReplicaInfo:
+    if replica_active:
+        # submitted_at = replica created
+        return autoscalers.ReplicaInfo(
+            active=True,
+            timestamp=min(job.submitted_at for job in replica_job_models),
+        )
+    # last_processed_at = replica scaled down
+    return autoscalers.ReplicaInfo(
+        active=False,
+        timestamp=max(job.last_processed_at for job in replica_job_models),
+    )
+
+
 async def _handle_run_replicas(
     session: AsyncSession,
     run_model: RunModel,
diff --git a/src/tests/_internal/server/background/tasks/test_process_runs.py b/src/tests/_internal/server/background/tasks/test_process_runs.py
index 8b2d51878..ca1ac060e 100644
--- a/src/tests/_internal/server/background/tasks/test_process_runs.py
+++ b/src/tests/_internal/server/background/tasks/test_process_runs.py
@@ -373,7 +373,7 @@ async def test_some_failed_to_terminating(
         session: AsyncSession,
         job_status: JobStatus,
         job_termination_reason: JobTerminationReason,
-    ) -> None:
+    ):
         run = await make_run(session, status=RunStatus.RUNNING, replicas=2)
         await create_job(
             session=session,
@@ -389,6 +389,55 @@ async def test_some_failed_to_terminating(
         assert run.status == RunStatus.TERMINATING
         assert run.termination_reason == RunTerminationReason.JOB_FAILED
 
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+    async def test_considers_replicas_inactive_only_when_all_jobs_done(
+        self,
+        test_db,
+        session: AsyncSession,
+    ):
+        project = await create_project(session=session)
+        user = await create_user(session=session)
+        repo = await create_repo(session=session, project_id=project.id)
+        run_name = "test-run"
+        run_spec = get_run_spec(
+            repo_id=repo.name,
+            run_name=run_name,
+            configuration=TaskConfiguration(
+                commands=["echo hello"],
+                nodes=2,
+            ),
+        )
+        run = await create_run(
+            session=session,
+            project=project,
+            repo=repo,
+            user=user,
+            run_name=run_name,
+            run_spec=run_spec,
+            status=RunStatus.RUNNING,
+        )
+        await create_job(
+            session=session,
+            run=run,
+            status=JobStatus.DONE,
+            termination_reason=JobTerminationReason.DONE_BY_RUNNER,
+            replica_num=0,
+            job_num=0,
+        )
+        await create_job(
+            session=session,
+            run=run,
+            status=JobStatus.RUNNING,
+            replica_num=0,
+            job_num=1,
+        )
+
+        await process_runs.process_runs()
+        await session.refresh(run)
+        assert run.status == RunStatus.RUNNING
+        # Should not create a new replica with new jobs
+        assert len(run.jobs) == 2
+
     @pytest.mark.asyncio
     @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
     async def test_pending_to_submitted_adds_replicas(self, test_db, session: AsyncSession):
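The change above boils down to one rule: a replica now counts as inactive for autoscaling only when every one of its jobs is done (or the replica was scaled down), whereas previously a single done job marked the whole multi-node replica inactive. The standalone sketch below restates that rule outside of dstack; Job and ReplicaInfo here are simplified stand-ins for the server models, not the actual dstack classes.

# Standalone sketch of the replica-activity rule (simplified stand-in types).
from dataclasses import dataclass
from datetime import datetime


@dataclass
class Job:  # stand-in for dstack's JobModel
    done: bool
    scaled_down: bool
    submitted_at: datetime
    last_processed_at: datetime


@dataclass
class ReplicaInfo:  # stand-in for autoscalers.ReplicaInfo
    active: bool
    timestamp: datetime


def get_replica_info(jobs: list[Job]) -> ReplicaInfo:
    # A replica stays active unless it was scaled down or *all* of its jobs
    # are done; a single finished job no longer deactivates a multi-node replica.
    active = not any(job.scaled_down for job in jobs)
    if all(job.done for job in jobs):
        active = False
    if active:
        # Active replica: report its creation time (earliest submission).
        return ReplicaInfo(active=True, timestamp=min(j.submitted_at for j in jobs))
    # Inactive replica: report when it was last processed (e.g., scaled down).
    return ReplicaInfo(active=False, timestamp=max(j.last_processed_at for j in jobs))

With one done job and one running job, the sketch still reports active=True, which is the run-level behavior the new test asserts: no extra replica is provisioned while a multi-node task is only partially done.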