Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,7 @@ async def _add_remote(instance: InstanceModel) -> None:
if instance.status == InstanceStatus.PENDING:
instance.status = InstanceStatus.PROVISIONING

retry_duration_deadline = instance.created_at.replace(
tzinfo=datetime.timezone.utc
) + timedelta(seconds=PROVISIONING_TIMEOUT_SECONDS)
retry_duration_deadline = instance.created_at + timedelta(seconds=PROVISIONING_TIMEOUT_SECONDS)
if retry_duration_deadline < get_current_datetime():
instance.status = InstanceStatus.TERMINATED
instance.termination_reason = "Provisioning timeout expired"
Expand Down
50 changes: 27 additions & 23 deletions src/dstack/_internal/server/background/tasks/process_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs):
replica_statuses: Set[RunStatus] = set()
replica_needs_retry = False

replica_active = True
jobs_done_num = 0
for job_model in job_models:
job = find_job(run.jobs, job_model.replica_num, job_model.job_num)
if (
Expand All @@ -272,8 +272,7 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
):
# the job is done or going to be done
replica_statuses.add(RunStatus.DONE)
# for some reason the replica is done, it's not active
replica_active = False
jobs_done_num += 1
elif job_model.termination_reason == JobTerminationReason.SCALED_DOWN:
# the job was scaled down
replica_active = False
Expand Down Expand Up @@ -313,26 +312,14 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
if not replica_needs_retry or retry_single_job:
run_statuses.update(replica_statuses)

if replica_active:
# submitted_at = replica created
replicas_info.append(
autoscalers.ReplicaInfo(
active=True,
timestamp=min(job.submitted_at for job in job_models).replace(
tzinfo=datetime.timezone.utc
),
)
)
else:
# last_processed_at = replica scaled down
replicas_info.append(
autoscalers.ReplicaInfo(
active=False,
timestamp=max(job.last_processed_at for job in job_models).replace(
tzinfo=datetime.timezone.utc
),
)
)
if jobs_done_num == len(job_models):
# Consider replica inactive if all its jobs are done for some reason.
# If only some jobs are done, replica is considered active to avoid
# provisioning new replicas for partially done multi-node tasks.
replica_active = False

replica_info = _get_replica_info(job_models, replica_active)
replicas_info.append(replica_info)

termination_reason: Optional[RunTerminationReason] = None
if RunStatus.FAILED in run_statuses:
Expand Down Expand Up @@ -410,6 +397,23 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
run_model.resubmission_attempt += 1


def _get_replica_info(
    replica_job_models: list[JobModel],
    replica_active: bool,
) -> autoscalers.ReplicaInfo:
    """Build the autoscaler's view of one replica from its job models.

    For an active replica the timestamp is when the replica was created
    (earliest job submission); for an inactive one it is when the replica
    was scaled down (latest job processing time).
    """
    if not replica_active:
        # last_processed_at = replica scaled down
        scaled_down_at = max(job.last_processed_at for job in replica_job_models)
        return autoscalers.ReplicaInfo(active=False, timestamp=scaled_down_at)
    # submitted_at = replica created
    created_at = min(job.submitted_at for job in replica_job_models)
    return autoscalers.ReplicaInfo(active=True, timestamp=created_at)


async def _handle_run_replicas(
session: AsyncSession,
run_model: RunModel,
Expand Down
51 changes: 50 additions & 1 deletion src/tests/_internal/server/background/tasks/test_process_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ async def test_some_failed_to_terminating(
session: AsyncSession,
job_status: JobStatus,
job_termination_reason: JobTerminationReason,
) -> None:
):
run = await make_run(session, status=RunStatus.RUNNING, replicas=2)
await create_job(
session=session,
Expand All @@ -389,6 +389,55 @@ async def test_some_failed_to_terminating(
assert run.status == RunStatus.TERMINATING
assert run.termination_reason == RunTerminationReason.JOB_FAILED

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_considers_replicas_inactive_only_when_all_jobs_done(
self,
test_db,
session: AsyncSession,
):
"""A multi-node replica with only some jobs done must remain active.

Regression test: for a 2-node task replica where one job is DONE and
the other is still RUNNING, processing the run should keep its status
RUNNING and must not provision a replacement replica (job count stays
at the original two jobs).
"""
project = await create_project(session=session)
user = await create_user(session=session)
repo = await create_repo(session=session, project_id=project.id)
run_name = "test-run"
# Two-node task: a single replica made up of two jobs (job_num 0 and 1).
run_spec = get_run_spec(
repo_id=repo.name,
run_name=run_name,
configuration=TaskConfiguration(
commands=["echo hello"],
nodes=2,
),
)
run = await create_run(
session=session,
project=project,
repo=repo,
user=user,
run_name=run_name,
run_spec=run_spec,
status=RunStatus.RUNNING,
)
# First job of the replica has already finished successfully...
await create_job(
session=session,
run=run,
status=JobStatus.DONE,
termination_reason=JobTerminationReason.DONE_BY_RUNNER,
replica_num=0,
job_num=0,
)
# ...while the second job of the same replica is still running, so the
# replica as a whole must still be considered active.
await create_job(
session=session,
run=run,
status=JobStatus.RUNNING,
replica_num=0,
job_num=1,
)
await process_runs.process_runs()
await session.refresh(run)
assert run.status == RunStatus.RUNNING
# Should not create new replica with new jobs
assert len(run.jobs) == 2

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_pending_to_submitted_adds_replicas(self, test_db, session: AsyncSession):
Expand Down