Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 35 additions & 16 deletions src/dstack/_internal/server/background/tasks/process_pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict, Optional, Union
from uuid import UUID

import requests
from pydantic import parse_raw_as
from sqlalchemy import select
from sqlalchemy.orm import joinedload
Expand All @@ -26,13 +27,18 @@

# NOTE(review): not referenced in the visible part of this file; presumably the
# delay before a pending job is retried -- confirm against the callers.
PENDING_JOB_RETRY_INTERVAL = timedelta(seconds=60)

# Added to the current time to compute an instance's termination deadline when
# a shim health check fails and no deadline is set yet; once that deadline has
# passed, the instance is marked TERMINATING.
TERMINATION_DEADLINE_OFFSET = timedelta(minutes=20)

# Terminate instance if the instance has not started within 10 minutes
STARTING_TIMEOUT_SECONDS = 10 * 60 # 10 minutes in seconds


@dataclass
class HealthStatus:
    """Outcome of a single shim health check.

    Attributes:
        healthy: ``True`` when the check succeeded, ``False`` otherwise.
        reason: Human-readable explanation of the outcome.
    """

    healthy: bool
    reason: str

    def __str__(self) -> str:
        """Return the reason text so the status can be logged directly."""
        return self.reason


Expand Down Expand Up @@ -99,8 +105,8 @@ async def check_shim(instance_id: UUID) -> None:

if health.healthy:
logger.debug("check instance %s status: shim health is OK", instance.name)
instance.fail_count = 0
instance.fail_reason = None
instance.termination_deadline = None
instance.health_status = None

if instance.status in (InstanceStatus.CREATING, InstanceStatus.STARTING):
instance.status = (
Expand All @@ -110,24 +116,32 @@ async def check_shim(instance_id: UUID) -> None:
else:
logger.debug("check instance %s status: shim health: %s", instance.name, health)

instance.fail_count += 1
instance.fail_reason = health.reason
if instance.termination_deadline is None:
instance.termination_deadline = (
get_current_datetime() + TERMINATION_DEADLINE_OFFSET
)
instance.health_status = health.reason

if instance.status in (InstanceStatus.READY, InstanceStatus.BUSY):
logger.warning(
"instance %s: shim has become unavailable, marked as failed", instance.name
)
FAIL_THRESHOLD = 10 * 6 * 20 # instance_healthcheck fails 20 minutes constantly
if instance.fail_count > FAIL_THRESHOLD:
logger.warning("instance %s shim is not available", instance.name)
deadline = instance.termination_deadline.replace(tzinfo=datetime.timezone.utc)
if get_current_datetime() > deadline:
instance.status = InstanceStatus.TERMINATING
instance.termination_reason = "Termination deadline"
logger.warning("mark instance %s as TERMINATED", instance.name)

if instance.status == InstanceStatus.STARTING and instance.started_at is not None:
STARTING_TIMEOUT = 10 * 60 # 10 minutes
starting_time_threshold = instance.started_at + timedelta(seconds=STARTING_TIMEOUT)
starting_time_threshold = instance.started_at.replace(
tzinfo=datetime.timezone.utc
) + timedelta(seconds=STARTING_TIMEOUT_SECONDS)
expire_starting = starting_time_threshold < get_current_datetime()
if expire_starting:
instance.status = InstanceStatus.TERMINATING
logger.warning(
"The Instance %s can't start in %s seconds. Marked as TERMINATED",
instance.name,
STARTING_TIMEOUT_SECONDS,
)

await session.commit()

Expand All @@ -148,8 +162,13 @@ def instance_healthcheck(*, ports: Dict[int, int]) -> HealthStatus:
healthy=False,
reason=f"Service name is {resp.service}, service version: {resp.version}",
)
except requests.RequestException as e:
return HealthStatus(healthy=False, reason=f"Can't request shim: {e}")
except Exception as e:
return HealthStatus(healthy=False, reason=f"Exception ({e.__class__.__name__}): {e}")
logger.exception("Unknown exception from shim.healthcheck: %s", e)
return HealthStatus(
healthy=False, reason=f"Unknown exception ({e.__class__.__name__}): {e}"
)


async def terminate(instance_id: UUID) -> None:
Expand All @@ -163,11 +182,10 @@ async def terminate(instance_id: UUID) -> None:
).one()

jpd = parse_raw_as(JobProvisioningData, instance.job_provisioning_data)
BACKEND_TYPE = jpd.backend
backends = await backends_services.get_project_backends(project=instance.project)
backend = next((b for b in backends if b.TYPE == BACKEND_TYPE), None)
backend = next((b for b in backends if b.TYPE == jpd.backend), None)
if backend is None:
raise ValueError(f"there is no backend {BACKEND_TYPE}")
raise ValueError(f"there is no backend {jpd.backend}")

await run_async(
backend.compute().terminate_instance, jpd.instance_id, jpd.region, jpd.backend_data
Expand Down Expand Up @@ -217,6 +235,7 @@ async def terminate_idle_instance() -> None:
instance.deleted_at = get_current_datetime()
instance.finished_at = get_current_datetime()
instance.status = InstanceStatus.TERMINATED
instance.termination_reason = "Idle timeout"

idle_time = current_time - last_time
logger.info(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Rework termination handling

Revision ID: 1a48dfe44a40
Revises: 9eea6af28e10
Create Date: 2024-02-21 10:11:32.350099

"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "1a48dfe44a40"
down_revision = "9eea6af28e10"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Replace the fail-tracking columns of ``instances`` with termination columns.

    Drops ``fail_reason`` and ``fail_count`` in one batch, then adds the
    nullable ``termination_deadline``, ``termination_reason`` and
    ``health_status`` columns in a second batch.
    """
    # First batch: remove the columns that are no longer used.
    with op.batch_alter_table("instances", schema=None) as drop_batch:
        for obsolete_name in ("fail_reason", "fail_count"):
            drop_batch.drop_column(obsolete_name)

    # Second batch: add the replacement columns; all of them are nullable.
    with op.batch_alter_table("instances", schema=None) as add_batch:
        deadline_column = sa.Column("termination_deadline", sa.DateTime(), nullable=True)
        add_batch.add_column(deadline_column)

        reason_column = sa.Column(
            "termination_reason", sa.VARCHAR(length=4000), nullable=True
        )
        add_batch.add_column(reason_column)

        health_column = sa.Column("health_status", sa.VARCHAR(length=4000), nullable=True)
        add_batch.add_column(health_column)


def downgrade() -> None:
    """Restore the fail-tracking columns of ``instances`` and drop the termination ones.

    Re-adds ``fail_count`` (non-nullable, server default ``0``) and
    ``fail_reason`` in one batch, then drops ``termination_deadline``,
    ``termination_reason`` and ``health_status`` in a second batch.
    """
    # First batch: bring back the columns removed by upgrade().
    with op.batch_alter_table("instances", schema=None) as restore_batch:
        fail_count_column = sa.Column(
            "fail_count", sa.Integer(), server_default=sa.text("0"), nullable=False
        )
        restore_batch.add_column(fail_count_column)

        fail_reason_column = sa.Column("fail_reason", sa.String(length=4000), nullable=True)
        restore_batch.add_column(fail_reason_column)

    # Second batch: remove the columns introduced by upgrade().
    with op.batch_alter_table("instances", schema=None) as drop_batch:
        for added_name in ("termination_deadline", "termination_reason", "health_status"):
            drop_batch.drop_column(added_name)
10 changes: 5 additions & 5 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
String,
Text,
UniqueConstraint,
text,
)
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.sql import false
Expand Down Expand Up @@ -287,17 +286,18 @@ class InstanceModel(BaseModel):

# VM
started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime)
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime)
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime)

# termination policy
termination_policy: Mapped[Optional[TerminationPolicy]] = mapped_column(String(50))
termination_idle_time: Mapped[int] = mapped_column(
Integer, default=DEFAULT_POOL_TERMINATION_IDLE_TIME
)

# connection fail handling
fail_count: Mapped[int] = mapped_column(Integer, server_default=text("0"))
fail_reason: Mapped[Optional[str]] = mapped_column(String(4000))
# instance termination handling
termination_deadline: Mapped[Optional[datetime]] = mapped_column(DateTime)
termination_reason: Mapped[Optional[str]] = mapped_column(String(4000))
health_status: Mapped[Optional[str]] = mapped_column(String(4000))

# backend
backend: Mapped[BackendType] = mapped_column(Enum(BackendType))
Expand Down
12 changes: 6 additions & 6 deletions src/dstack/_internal/server/routers/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dstack._internal.server.db import get_session
from dstack._internal.server.models import ProjectModel, UserModel
from dstack._internal.server.schemas.runs import AddRemoteInstanceRequest
from dstack._internal.server.security.permissions import ProjectAdmin, ProjectMember
from dstack._internal.server.security.permissions import ProjectMember
from dstack._internal.server.services.runs import (
abort_runs_of_pool,
list_project_runs,
Expand All @@ -33,7 +33,7 @@ async def list_pool(
async def remove_instance(
body: schemas.RemoveInstanceRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
) -> None:
_, project_model = user_project
await pools.remove_instance(
Expand All @@ -45,7 +45,7 @@ async def remove_instance(
async def set_default_pool(
body: schemas.SetDefaultPoolRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
) -> bool:
_, project_model = user_project
return await pools.set_default_pool(session, project_model, body.pool_name)
Expand All @@ -55,7 +55,7 @@ async def set_default_pool(
async def delete_pool(
body: schemas.DeletePoolRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
) -> None:
pool_name = body.name
_, project_model = user_project
Expand Down Expand Up @@ -87,7 +87,7 @@ async def delete_pool(
async def create_pool(
body: schemas.CreatePoolRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
) -> None:
_, project = user_project
await pools.create_pool_model(session=session, project=project, name=body.name)
Expand All @@ -97,7 +97,7 @@ async def create_pool(
async def show_pool(
body: schemas.ShowPoolRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
) -> models.PoolInstances:
_, project = user_project
instances = await pools.show_pool(session, project, pool_name=body.name)
Expand Down
5 changes: 4 additions & 1 deletion src/dstack/_internal/server/services/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ async def add_remote(
name="instance",
resources=instance_resource,
),
region="",
region="", # TODO: add region
price=0.0,
availability=InstanceAvailability.AVAILABLE,
)
Expand All @@ -361,11 +361,14 @@ async def add_remote(
name=instance_name,
project=project,
pool=pool_model,
backend=BackendType.REMOTE,
created_at=common_utils.get_current_datetime(),
started_at=common_utils.get_current_datetime(),
status=InstanceStatus.PENDING,
job_provisioning_data=local.json(),
offer=offer.json(),
region=offer.region,
price=offer.price,
termination_policy=profile.termination_policy,
termination_idle_time=profile.termination_idle_time,
)
Expand Down
Loading