53 changes: 40 additions & 13 deletions src/dstack/_internal/cli/commands/pool.py
@@ -1,8 +1,11 @@
import argparse
import datetime
import time
from pathlib import Path
from typing import Sequence

from rich.console import Group
from rich.live import Live
from rich.table import Table

from dstack._internal.cli.commands import APIBaseCommand
@@ -32,6 +35,9 @@
from dstack.api._public.resources import Resources
from dstack.api.utils import load_profile

REFRESH_RATE_PER_SEC = 5
LIVE_PROVISION_INTERVAL_SECS = 10

logger = get_logger(__name__)


@@ -77,18 +83,24 @@ def _register(self) -> None:
delete_parser.set_defaults(subfunc=self._delete)

# show pool instances
show_parser = subparsers.add_parser(
"show",
ps_parser = subparsers.add_parser(
"ps",
help="Show pool instances",
description="Show instances in the pool",
formatter_class=self._parser.formatter_class,
)
show_parser.add_argument(
ps_parser.add_argument(
"--pool",
dest="pool_name",
help="The name of the pool. If not set, the default pool will be used",
)
show_parser.set_defaults(subfunc=self._show)
ps_parser.add_argument(
"-w",
"--watch",
help="Watch instances in realtime",
action="store_true",
)
ps_parser.set_defaults(subfunc=self._ps)

# add instance
add_parser = subparsers.add_parser(
@@ -196,10 +208,26 @@ def _set_default(self, args: argparse.Namespace) -> None:
if not result:
console.print(f"Failed to set default pool {args.pool_name!r}", style="error")

def _show(self, args: argparse.Namespace) -> None:
resp = self.api.client.pool.show(self.api.project, args.pool_name)
console.print(f" [bold]Pool name[/] {resp.name}\n")
print_instance_table(resp.instances)
def _ps(self, args: argparse.Namespace) -> None:
pool_name_template = " [bold]Pool name[/] {}\n"
if not args.watch:
resp = self.api.client.pool.show(self.api.project, args.pool_name)
console.print(pool_name_template.format(resp.name))
console.print(print_instance_table(resp.instances))
console.print()
return

try:
with Live(console=console, refresh_per_second=REFRESH_RATE_PER_SEC) as live:
while True:
resp = self.api.client.pool.show(self.api.project, args.pool_name)
group = Group(
pool_name_template.format(resp.name), print_instance_table(resp.instances)
)
live.update(group)
time.sleep(LIVE_PROVISION_INTERVAL_SECS)
except KeyboardInterrupt:
pass

def _add(self, args: argparse.Namespace) -> None:
super()._command(args)
@@ -287,12 +315,12 @@ def print_pool_table(pools: Sequence[Pool], verbose: bool) -> None:
console.print()


def print_instance_table(instances: Sequence[Instance]) -> None:
def print_instance_table(instances: Sequence[Instance]) -> Table:
table = Table(box=None)
table.add_column("INSTANCE NAME")
table.add_column("INSTANCE")
table.add_column("BACKEND")
table.add_column("REGION")
table.add_column("INSTANCE TYPE")
table.add_column("RESOURCES")
table.add_column("SPOT")
table.add_column("STATUS")
table.add_column("PRICE")
@@ -317,8 +345,7 @@ def print_instance_table(instances: Sequence[Instance]) -> None:
]
table.add_row(*row)

console.print(table)
console.print()
return table


def print_offers_table(
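As context for the new `ps --watch` flag above: the loop follows rich's standard Live pattern, rebuilding the table on every poll and handing the result to `live.update()` so the terminal output is redrawn in place. A minimal standalone sketch of that pattern, assuming a `fetch_rows` callable in place of the real `pool.show()` call (all names below are illustrative, not dstack's API):

import time

from rich.console import Group
from rich.live import Live
from rich.table import Table

REFRESH_RATE_PER_SEC = 5  # how often rich repaints the screen
POLL_INTERVAL_SECS = 10   # how often fresh data is fetched

def build_table(rows):
    # The table is rebuilt from scratch on every poll; Live swaps the whole renderable.
    table = Table(box=None)
    table.add_column("INSTANCE")
    table.add_column("STATUS")
    for name, status in rows:
        table.add_row(name, status)
    return table

def watch(fetch_rows):
    try:
        with Live(refresh_per_second=REFRESH_RATE_PER_SEC) as live:
            while True:
                live.update(Group(" [bold]Pool name[/] default\n", build_table(fetch_rows())))
                time.sleep(POLL_INTERVAL_SECS)
    except KeyboardInterrupt:
        pass  # Ctrl-C simply stops watching, mirroring the CLI behaviour above

Returning the Table from print_instance_table instead of printing it is what lets the same helper feed both the one-shot output and the Live group.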
2 changes: 0 additions & 2 deletions src/dstack/_internal/core/backends/base/offers.py
@@ -29,9 +29,7 @@ def get_catalog_offers(
offers = []

catalog = catalog if catalog is not None else gpuhunt.default_catalog()
locs = []
for item in catalog.query(**asdict(q)):
locs.append(item.location)
if locations is not None and item.location not in locations:
continue
offer = catalog_item_to_offer(backend, item, requirements)
9 changes: 0 additions & 9 deletions src/dstack/_internal/core/models/instances.py
@@ -9,15 +9,6 @@
from dstack._internal.utils.common import pretty_resources


class InstanceState(str, Enum):
NOT_FOUND = "not_found"
PROVISIONING = "provisioning"
RUNNING = "running"
STOPPED = "stopped"
STOPPING = "stopping"
TERMINATED = "terminated"


class Gpu(BaseModel):
name: str
memory_mib: int
72 changes: 57 additions & 15 deletions src/dstack/_internal/server/background/tasks/process_pools.py
@@ -1,6 +1,7 @@
import datetime
from dataclasses import dataclass
from datetime import timedelta
from typing import Dict
from typing import Dict, Optional, Union
from uuid import UUID

from pydantic import parse_raw_as
@@ -25,6 +26,16 @@

PENDING_JOB_RETRY_INTERVAL = timedelta(seconds=60)


@dataclass
class HealthStatus:
healthy: bool
reason: str

def __str__(self):
return self.reason


logger = get_logger(__name__)


@@ -78,32 +89,67 @@ async def check_shim(instance_id: UUID) -> None:
ssh_private_key = instance.project.ssh_private_key
job_provisioning_data = parse_raw_as(JobProvisioningData, instance.job_provisioning_data)

instance_health = instance_healthcheck(ssh_private_key, job_provisioning_data)
instance_health: Union[Optional[HealthStatus], bool] = instance_healthcheck(
ssh_private_key, job_provisioning_data
)
if isinstance(instance_health, bool) or instance_health is None:
health = HealthStatus(healthy=False, reason="SSH or tunnel error")
else:
health = instance_health

logger.debug("check instance %s status: shim health is %s", instance.name, instance_health)
if health.healthy:
logger.debug("check instance %s status: shim health is OK", instance.name)
instance.fail_count = 0
instance.fail_reason = None

if instance_health:
if instance.status in (InstanceStatus.CREATING, InstanceStatus.STARTING):
instance.status = (
InstanceStatus.READY if instance.job_id is None else InstanceStatus.BUSY
)
await session.commit()
else:
logger.debug("check instance %s status: shim health: %s", instance.name, health)

instance.fail_count += 1
instance.fail_reason = health.reason

if instance.status in (InstanceStatus.READY, InstanceStatus.BUSY):
logger.warning(
"instance %s shim is not available, marked as failed", instance.name
)
instance.status = InstanceStatus.FAILED
await session.commit()
FAIL_THRESHOLD = 10 * 6 * 20  # instance_healthcheck has been failing constantly for ~20 minutes
if instance.fail_count > FAIL_THRESHOLD:
instance.status = InstanceStatus.TERMINATING
logger.warning("mark instance %s as TERMINATED", instance.name)

if instance.status == InstanceStatus.STARTING and instance.started_at is not None:
STARTING_TIMEOUT = 10 * 60 # 10 minutes
starting_time_threshold = instance.started_at + timedelta(seconds=STARTING_TIMEOUT)
expire_starting = starting_time_threshold < get_current_datetime()
if expire_starting:
instance.status = InstanceStatus.TERMINATING

await session.commit()


@runner_ssh_tunnel(ports=[client.REMOTE_SHIM_PORT], retries=1)
def instance_healthcheck(*, ports: Dict[int, int]) -> bool:
def instance_healthcheck(*, ports: Dict[int, int]) -> HealthStatus:
shim_client = client.ShimClient(port=ports[client.REMOTE_SHIM_PORT])
resp = shim_client.healthcheck()
if resp is None:
return False # shim is not available yet
return resp.service == "dstack-shim"
try:
resp = shim_client.healthcheck(unmask_exeptions=True)

if resp is None:
return HealthStatus(healthy=False, reason="Unknown reason")

if resp.service == "dstack-shim":
return HealthStatus(healthy=True, reason="Service is OK")
else:
return HealthStatus(
healthy=False,
reason=f"Service name is {resp.service}, service version: {resp.version}",
)
except Exception as e:
return HealthStatus(healthy=False, reason=f"Exception ({e.__class__.__name__}): {e}")


async def terminate(instance_id: UUID) -> None:
@@ -116,8 +162,6 @@ async def terminate(instance_id: UUID) -> None:
)
).one()

# TODO: need lock

jpd = parse_raw_as(JobProvisioningData, instance.job_provisioning_data)
BACKEND_TYPE = jpd.backend
backends = await backends_services.get_project_backends(project=instance.project)
@@ -152,8 +196,6 @@ async def terminate_idle_instance() -> None:
)
instances = res.scalars().all()

# TODO: need lock

for instance in instances:
last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc)
if instance.last_job_processed_at is not None:
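The failure handling above is effectively a counter with a reset: every unhealthy check increments fail_count and records the reason, any healthy check clears both, and only an unbroken run of failures longer than FAIL_THRESHOLD flips the instance to TERMINATING. A self-contained sketch of that behaviour, assuming a simplified stand-in Instance instead of the real InstanceModel:

from dataclasses import dataclass
from typing import Optional

@dataclass
class HealthStatus:
    healthy: bool
    reason: str

@dataclass
class Instance:
    # Stand-in for InstanceModel: only the fields the health check touches.
    name: str
    status: str = "ready"
    fail_count: int = 0
    fail_reason: Optional[str] = None

FAIL_THRESHOLD = 10 * 6 * 20  # roughly 20 minutes of consecutive failed checks

def apply_health(instance: Instance, health: HealthStatus) -> None:
    if health.healthy:
        instance.fail_count = 0
        instance.fail_reason = None
        return
    instance.fail_count += 1
    instance.fail_reason = health.reason
    if instance.fail_count > FAIL_THRESHOLD:
        instance.status = "terminating"

inst = Instance(name="gpu-1")
for _ in range(FAIL_THRESHOLD + 1):
    apply_health(inst, HealthStatus(healthy=False, reason="SSH or tunnel error"))
assert inst.status == "terminating"  # one healthy check anywhere in the run would have reset the counter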
@@ -12,7 +12,7 @@
InstanceOfferWithAvailability,
LaunchedInstanceInfo,
)
from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, CreationPolicy
from dstack._internal.core.models.profiles import CreationPolicy
from dstack._internal.core.models.runs import (
InstanceStatus,
Job,
@@ -23,17 +23,18 @@
RunSpec,
)
from dstack._internal.server.db import get_session_ctx
from dstack._internal.server.models import InstanceModel, JobModel, PoolModel, RunModel
from dstack._internal.server.models import InstanceModel, JobModel, RunModel
from dstack._internal.server.services import backends as backends_services
from dstack._internal.server.services.jobs import (
PROCESSING_POOL_LOCK,
SUBMITTED_PROCESSING_JOBS_IDS,
SUBMITTED_PROCESSING_JOBS_LOCK,
)
from dstack._internal.server.services.logging import job_log
from dstack._internal.server.services.pools import (
filter_pool_instances,
get_or_create_default_pool_by_name,
get_pool_instances,
list_project_pool_models,
)
from dstack._internal.server.services.runs import run_model_to_run
from dstack._internal.server.utils.common import run_async
@@ -87,57 +88,38 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
run_model = res.scalar_one()
project_model = run_model.project

# check default pool
pool = project_model.default_pool
if pool is None:
# TODO: get_or_create_default_pool...
pools = await list_project_pool_models(session, job_model.project)
for pool_item in pools:
if pool_item.id == job_model.project.default_pool_id:
pool = pool_item
if pool_item.name == DEFAULT_POOL_NAME:
pool = pool_item
if pool is None:
pool = PoolModel(
name=DEFAULT_POOL_NAME,
project=project_model,
)
session.add(pool)
await session.commit()
await session.refresh(pool)

if pool.id is not None:
project_model.default_pool_id = pool.id

run_spec = parse_raw_as(RunSpec, run_model.run_spec)
profile = run_spec.profile
run_pool = profile.pool_name
if run_pool is None:
run_pool = pool.name

# pool capacity

pool_instances = await get_pool_instances(session, project_model, run_pool)
relevant_instances = filter_pool_instances(
pool_instances, profile, run_spec.configuration.resources, status=InstanceStatus.READY
)
# check default pool
pool = project_model.default_pool
if pool is None:
pool = await get_or_create_default_pool_by_name(
session, project_model, pool_name=profile.pool_name
)
project_model.default_pool = pool

if relevant_instances:
sorted_instances = sorted(relevant_instances, key=lambda instance: instance.name)
instance = sorted_instances[0]
async with PROCESSING_POOL_LOCK:
# pool capacity
pool_instances = await get_pool_instances(session, project_model, pool.name)
relevant_instances = filter_pool_instances(
pool_instances, profile, run_spec.configuration.resources, status=InstanceStatus.READY
)

# need lock
instance.status = InstanceStatus.BUSY
instance.job = job_model
if relevant_instances:
sorted_instances = sorted(relevant_instances, key=lambda instance: instance.name)
instance = sorted_instances[0]
instance.status = InstanceStatus.BUSY
instance.job = job_model

logger.info(*job_log("now is provisioning", job_model))
job_model.job_provisioning_data = instance.job_provisioning_data
job_model.status = JobStatus.PROVISIONING
job_model.last_processed_at = common_utils.get_current_datetime()
logger.info(*job_log("now is provisioning", job_model))
job_model.job_provisioning_data = instance.job_provisioning_data
job_model.status = JobStatus.PROVISIONING
job_model.last_processed_at = common_utils.get_current_datetime()

await session.commit()
await session.commit()

return
return

run = run_model_to_run(run_model)
job = run.jobs[job_model.job_num]
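The key change in this hunk is that the READY-instance lookup and the flip to BUSY now happen inside the same PROCESSING_POOL_LOCK critical section, so two jobs processed concurrently cannot claim the same instance. A minimal sketch of that check-then-claim pattern with an asyncio.Lock (the FakeInstance class and claim_instance helper are illustrative, not dstack's API):

import asyncio
from typing import List, Optional

PROCESSING_POOL_LOCK = asyncio.Lock()

class FakeInstance:
    def __init__(self, name: str) -> None:
        self.name = name
        self.status = "ready"
        self.job = None

async def claim_instance(pool: List[FakeInstance], job: str) -> Optional[FakeInstance]:
    # Filtering and flipping the status happen under one lock, so a second
    # coroutine can never observe the instance as "ready" in between.
    async with PROCESSING_POOL_LOCK:
        ready = sorted((i for i in pool if i.status == "ready"), key=lambda i: i.name)
        if not ready:
            return None
        instance = ready[0]
        instance.status = "busy"
        instance.job = job
        return instance

async def main() -> None:
    pool = [FakeInstance("b"), FakeInstance("a")]
    claimed = await asyncio.gather(*(claim_instance(pool, f"job-{n}") for n in range(3)))
    print([i.name if i else None for i in claimed])  # two jobs get an instance; the third finds none READY

asyncio.run(main())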
@@ -0,0 +1,26 @@
"""Added fail reason for InstanceModel

Revision ID: ea4cd670dba5
Revises: 29c08c6a8cb3
Create Date: 2024-02-16 18:10:38.805380

"""


# revision identifiers, used by Alembic.
revision = "ea4cd670dba5"
down_revision = "29c08c6a8cb3"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
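For reference only, and not part of this diff: the upgrade()/downgrade() bodies above are left as pass, so this revision records no schema change. A migration that actually added the fail-tracking columns the background task uses would, under the usual Alembic conventions, look roughly like the following; the "instances" table name and column definitions are assumptions, not taken from this PR.

import sqlalchemy as sa
from alembic import op

def upgrade() -> None:
    # Hypothetical columns matching instance.fail_count / instance.fail_reason.
    op.add_column("instances", sa.Column("fail_count", sa.Integer(), nullable=False, server_default="0"))
    op.add_column("instances", sa.Column("fail_reason", sa.String(), nullable=True))

def downgrade() -> None:
    op.drop_column("instances", "fail_reason")
    op.drop_column("instances", "fail_count")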