From 15951a669918dfbe16bb86eee1aee82c961098d0 Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Fri, 8 Dec 2023 11:29:17 +0300 Subject: [PATCH 01/47] pool initial --- .pre-commit-config.yaml | 2 +- runner/internal/runner/api/http.go | 4 +- src/dstack/_internal/cli/commands/pool.py | 81 +++++++++++++++++++ src/dstack/_internal/cli/commands/run.py | 9 +++ src/dstack/_internal/cli/main.py | 2 + src/dstack/_internal/core/models/instances.py | 6 +- src/dstack/_internal/core/models/pool.py | 21 +++++ src/dstack/_internal/core/models/profiles.py | 5 ++ src/dstack/_internal/core/models/runs.py | 1 + src/dstack/_internal/server/app.py | 2 + .../tasks/process_submitted_jobs.py | 41 ++++++++-- .../versions/2943402e3b56_add_pools.py | 49 +++++++++++ src/dstack/_internal/server/models.py | 40 +++++++++ src/dstack/_internal/server/routers/pool.py | 52 ++++++++++++ src/dstack/_internal/server/routers/runs.py | 3 +- src/dstack/_internal/server/schemas/pool.py | 13 +++ .../services/jobs/configurators/base.py | 4 + src/dstack/_internal/server/services/pool.py | 81 +++++++++++++++++++ src/dstack/_internal/server/services/runs.py | 13 ++- src/dstack/api/_public/__init__.py | 6 ++ src/dstack/api/_public/pool.py | 40 +++++++++ src/dstack/api/_public/runs.py | 8 +- src/dstack/api/server/__init__.py | 12 ++- src/dstack/api/server/_pool.py | 25 ++++++ src/dstack/api/server/_runs.py | 3 +- .../tasks/test_process_submitted_jobs.py | 47 +++++++---- .../_internal/server/routers/test_runs.py | 5 ++ 27 files changed, 540 insertions(+), 35 deletions(-) create mode 100644 src/dstack/_internal/cli/commands/pool.py create mode 100644 src/dstack/_internal/core/models/pool.py create mode 100644 src/dstack/_internal/server/migrations/versions/2943402e3b56_add_pools.py create mode 100644 src/dstack/_internal/server/routers/pool.py create mode 100644 src/dstack/_internal/server/schemas/pool.py create mode 100644 src/dstack/_internal/server/services/pool.py create mode 100644 src/dstack/api/_public/pool.py create mode 100644 src/dstack/api/server/_pool.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0593575f2..860da2c55 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: rev: 22.12.0 hooks: - id: black - language_version: python3.11 + language_version: python3.10 args: ['--config', 'pyconfig.toml'] - repo: https://github.com/pycqa/isort rev: 5.12.0 diff --git a/runner/internal/runner/api/http.go b/runner/internal/runner/api/http.go index 1b0cef495..936085983 100644 --- a/runner/internal/runner/api/http.go +++ b/runner/internal/runner/api/http.go @@ -26,7 +26,9 @@ func (s *Server) healthcheckGetHandler(w http.ResponseWriter, r *http.Request) ( func (s *Server) submitPostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { s.executor.Lock() defer s.executor.Unlock() - if s.executor.GetRunnerState() != executor.WaitSubmit { + state := s.executor.GetRunnerState() + if state != executor.WaitSubmit { + log.Warning(r.Context(), "Executor doesn't wait submit", "current_state", state) return nil, &api.Error{Status: http.StatusConflict} } diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py new file mode 100644 index 000000000..b21265e56 --- /dev/null +++ b/src/dstack/_internal/cli/commands/pool.py @@ -0,0 +1,81 @@ +import argparse +from typing import List + +from rich.table import Table + +from dstack._internal.cli.commands import APIBaseCommand +from dstack._internal.cli.utils.common import console +from 
dstack._internal.utils.common import pretty_date + + +def print_pool_table(pools: List, verbose): + table = Table(box=None) + table.add_column("NAME") + table.add_column("DEFAULT") + if verbose: + table.add_column("CREATED") + + for pool in pools: + row = [pool.name, "default" if pool.default else ""] + if verbose: + row.append(pretty_date(pool.created_at)) + table.add_row(*row) + + console.print(table) + console.print() + + +class PoolCommand(APIBaseCommand): + NAME = "pool" + DESCRIPTION = "Pool management" + + def _register(self): + super()._register() + self._parser.set_defaults(subfunc=self._list) + subparsers = self._parser.add_subparsers(dest="action") + + list_parser = subparsers.add_parser( + "list", help="List pools", formatter_class=self._parser.formatter_class + ) + list_parser.add_argument("-v", "--verbose", help="Show more information") + list_parser.set_defaults(subfunc=self._list) + + create_parser = subparsers.add_parser( + "create", help="Create pool", formatter_class=self._parser.formatter_class + ) + create_parser.add_argument("-n", "--name", dest="pool_name", help="The name of the pool") + create_parser.set_defaults(subfunc=self._create) + + delete_parser = subparsers.add_parser( + "delete", help="Delete pool", formatter_class=self._parser.formatter_class + ) + delete_parser.add_argument( + "-n", "--name", dest="pool_name", help="The name of the pool", required=True + ) + delete_parser.set_defaults(subfunc=self._delete) + + show_parser = subparsers.add_parser( + "show", help="Show pool's instances", formatter_class=self._parser.formatter_class + ) + show_parser.add_argument( + "-n", "--name", dest="pool_name", help="The name of the pool", required=True + ) + show_parser.set_defaults(subfunc=self._show) + + def _list(self, args: argparse.Namespace): + pools = self.api.client.pool.list(self.api.project) + print_pool_table(pools, verbose=getattr(args, "verbose", False)) + + def _create(self, args: argparse.Namespace): + self.api.client.pool.create(self.api.project, args.pool_name) + + def _delete(self, args: argparse.Namespace): + self.api.client.pool.delete(self.api.project, args.pool_name) + + def _show(self, args: argparse.Namespace): + self.api.client.pool.show(self.api.project, args.pool_name) + + def _command(self, args: argparse.Namespace): + super()._command(args) + # TODO handle 404 and other errors + args.subfunc(args) diff --git a/src/dstack/_internal/cli/commands/run.py b/src/dstack/_internal/cli/commands/run.py index f39296da4..a27c0e7d9 100644 --- a/src/dstack/_internal/cli/commands/run.py +++ b/src/dstack/_internal/cli/commands/run.py @@ -17,6 +17,7 @@ from dstack._internal.cli.utils.run import print_run_plan from dstack._internal.core.errors import CLIError, ConfigurationError, ServerClientError from dstack._internal.core.models.configurations import ConfigurationType +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME from dstack._internal.core.models.runs import JobErrorCode from dstack._internal.core.services.configs import ConfigManager from dstack._internal.utils.logging import get_logger @@ -78,6 +79,11 @@ def _register(self): type=int, default=3, ) + self._parser.add_argument( + "--pool", + dest="pool_name", + help="The name of the pool", + ) register_profile_args(self._parser) def _command(self, args: argparse.Namespace): @@ -109,6 +115,8 @@ def _command(self, args: argparse.Namespace): known, unknown = parser.parse_known_args(args.unknown) configurator.apply(known, unknown, conf) + pool_name = DEFAULT_POOL_NAME if args.pool_name 
is None else args.pool_name + with console.status("Getting run plan..."): run_plan = self.api.runs.get_plan( configuration=conf, @@ -121,6 +129,7 @@ def _command(self, args: argparse.Namespace): max_price=profile.max_price, working_dir=args.working_dir, run_name=args.run_name, + pool_name=pool_name, ) except ConfigurationError as e: raise CLIError(str(e)) diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py index a714d3a75..295bc2136 100644 --- a/src/dstack/_internal/cli/main.py +++ b/src/dstack/_internal/cli/main.py @@ -6,6 +6,7 @@ from dstack._internal.cli.commands.gateway import GatewayCommand from dstack._internal.cli.commands.init import InitCommand from dstack._internal.cli.commands.logs import LogsCommand +from dstack._internal.cli.commands.pool import PoolCommand from dstack._internal.cli.commands.ps import PsCommand from dstack._internal.cli.commands.run import RunCommand from dstack._internal.cli.commands.server import ServerCommand @@ -50,6 +51,7 @@ def main(): subparsers = parser.add_subparsers(metavar="COMMAND") ConfigCommand.register(subparsers) GatewayCommand.register(subparsers) + PoolCommand.register(subparsers) InitCommand.register(subparsers) LogsCommand.register(subparsers) PsCommand.register(subparsers) diff --git a/src/dstack/_internal/core/models/instances.py b/src/dstack/_internal/core/models/instances.py index fde057b75..cd45ded11 100644 --- a/src/dstack/_internal/core/models/instances.py +++ b/src/dstack/_internal/core/models/instances.py @@ -1,7 +1,7 @@ from enum import Enum from typing import List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field from dstack._internal.core.models.backends.base import BackendType from dstack._internal.utils.common import pretty_resources @@ -67,8 +67,8 @@ class LaunchedInstanceInfo(BaseModel): username: str ssh_port: int # could be different from 22 for some backends dockerized: bool # True if backend starts shim - ssh_proxy: Optional[SSHConnectionParams] - backend_data: Optional[str] # backend-specific data in json + ssh_proxy: Optional[SSHConnectionParams] = Field(default=None) + backend_data: Optional[str] = Field(default=None) # backend-specific data in json class InstanceAvailability(Enum): diff --git a/src/dstack/_internal/core/models/pool.py b/src/dstack/_internal/core/models/pool.py new file mode 100644 index 000000000..315ca0171 --- /dev/null +++ b/src/dstack/_internal/core/models/pool.py @@ -0,0 +1,21 @@ +import datetime +from typing import List, Optional + +from pydantic import BaseModel + +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import InstanceType + + +class Pool(BaseModel): + name: str + default: bool + created_at: datetime.datetime + + +class Instance(BaseModel): + backend: BackendType + instance_type: InstanceType + instance_id: str + hostname: str + price: float diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index 35a9717ff..d4c6fbf2e 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -9,6 +9,7 @@ from dstack._internal.core.models.common import ForbidExtra DEFAULT_RETRY_LIMIT = 3600 +DEFAULT_POOL_NAME = "default-pool" class SpotPolicy(str, Enum): @@ -94,6 +95,10 @@ class Profile(ForbidExtra): default: Annotated[ bool, Field(description="If set to true, `dstack run` will use this profile by default.") ] = False + pool_name: Annotated[ + Optional[str], + 
Field(description="The name of the pool. If not set, dstack will use the default name."), + ] = DEFAULT_POOL_NAME _validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)( parse_max_duration diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 6d22471d7..d1cd10052 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -120,6 +120,7 @@ class JobSpec(BaseModel): requirements: Requirements retry_policy: RetryPolicy working_dir: str + pool_name: Optional[str] class JobProvisioningData(BaseModel): diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index e1d6e0152..c589b1bd6 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -17,6 +17,7 @@ backends, gateways, logs, + pool, projects, repos, runs, @@ -130,6 +131,7 @@ def add_no_api_version_check_routes(paths: List[str]): def register_routes(app: FastAPI): app.include_router(users.router) app.include_router(projects.router) + app.include_router(pool.router) app.include_router(backends.root_router) app.include_router(backends.project_router) app.include_router(repos.router) diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 07b623d96..8fd711944 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Tuple from uuid import UUID from sqlalchemy import select @@ -7,7 +7,11 @@ from dstack._internal.core.backends.base import Backend from dstack._internal.core.errors import BackendError -from dstack._internal.core.models.instances import LaunchedInstanceInfo +from dstack._internal.core.models.instances import ( + InstanceOfferWithAvailability, + LaunchedInstanceInfo, +) +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME from dstack._internal.core.models.runs import ( Job, JobErrorCode, @@ -16,7 +20,7 @@ Run, ) from dstack._internal.server.db import get_session_ctx -from dstack._internal.server.models import JobModel, RunModel +from dstack._internal.server.models import InstanceModel, JobModel, PoolModel, RunModel from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services.jobs import ( SUBMITTED_PROCESSING_JOBS_IDS, @@ -74,10 +78,22 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): ) run_model = res.scalar_one() project_model = run_model.project + + pool = project_model.default_pool + if pool is None: + pool = PoolModel( + name=DEFAULT_POOL_NAME, + project=project_model, + ) + session.add(pool) + await session.commit() + if pool.id is not None: + project_model.default_pool_id = pool.id + run = run_model_to_run(run_model) job = run.jobs[job_model.job_num] backends = await backends_services.get_project_backends(project=run_model.project) - job_provisioning_data = await _run_job( + job_provisioning_data, offer = await _run_job( job_model=job_model, run=run, job=job, @@ -89,6 +105,15 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): logger.info(*job_log("now is provisioning", job_model)) job_model.job_provisioning_data = job_provisioning_data.json() job_model.status = JobStatus.PROVISIONING + + im = InstanceModel( + project=project_model, 
+ pool=pool, + job_provisioning_data=job_provisioning_data.json(), + offer=offer.json(), + ) + session.add(im) + else: logger.debug(*job_log("provisioning failed", job_model)) if job.is_retry_active(): @@ -108,7 +133,7 @@ async def _run_job( backends: List[Backend], project_ssh_public_key: str, project_ssh_private_key: str, -) -> Optional[JobProvisioningData]: +) -> Tuple[Optional[JobProvisioningData], Optional[InstanceOfferWithAvailability]]: if run.run_spec.profile.backends is not None: backends = [b for b in backends if b.TYPE in run.run_spec.profile.backends] try: @@ -151,7 +176,7 @@ async def _run_job( ) continue else: - return JobProvisioningData( + job_provisioning_data = JobProvisioningData( backend=backend.TYPE, instance_type=offer.instance, instance_id=launched_instance_info.instance_id, @@ -164,4 +189,6 @@ async def _run_job( ssh_proxy=launched_instance_info.ssh_proxy, backend_data=launched_instance_info.backend_data, ) - return None + + return (job_provisioning_data, offer) + return (None, None) diff --git a/src/dstack/_internal/server/migrations/versions/2943402e3b56_add_pools.py b/src/dstack/_internal/server/migrations/versions/2943402e3b56_add_pools.py new file mode 100644 index 000000000..554829e55 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2943402e3b56_add_pools.py @@ -0,0 +1,49 @@ +"""add pools + +Revision ID: 2943402e3b56 +Revises: e6391ca6a264 +Create Date: 2023-12-13 14:02:25.106604 + +""" +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +# revision identifiers, used by Alembic. +revision = "2943402e3b56" +down_revision = "e6391ca6a264" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("projects", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "default_pool_id", + sqlalchemy_utils.types.uuid.UUIDType(binary=False), + nullable=True, + ) + ) + batch_op.create_foreign_key( + batch_op.f("fk_projects_default_pool_id_pools"), + "pools", + ["default_pool_id"], + ["id"], + ondelete="SET NULL", + use_alter=True, + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("projects", schema=None) as batch_op: + batch_op.drop_constraint( + batch_op.f("fk_projects_default_pool_id_pools"), type_="foreignkey" + ) + batch_op.drop_column("default_pool_id") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index dc62443f3..a0be7ca22 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -79,6 +79,13 @@ class ProjectModel(BaseModel): foreign_keys=[default_gateway_id], lazy="selectin" ) + default_pool_id: Mapped[Optional[UUIDType]] = mapped_column( + ForeignKey("pools.id", use_alter=True, ondelete="SET NULL"), nullable=True + ) + default_pool: Mapped["PoolModel"] = relationship( + foreign_keys=[default_pool_id], lazy="selectin" + ) + class MemberModel(BaseModel): __tablename__ = "members" @@ -230,3 +237,36 @@ class GatewayComputeModel(BaseModel): ssh_public_key: Mapped[str] = mapped_column(Text) deleted: Mapped[bool] = mapped_column(Boolean, server_default=false()) + + +class PoolModel(BaseModel): + __tablename__ = "pools" + + id: Mapped[uuid.UUID] = mapped_column( + UUIDType(binary=False), primary_key=True, default=uuid.uuid4 + ) + name: Mapped[str] = mapped_column(String(50), unique=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=get_current_datetime) + + project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) + project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id]) + + instances: Mapped[List["InstanceModel"]] = relationship(back_populates="pool", lazy="selectin") + + +class InstanceModel(BaseModel): + __tablename__ = "instances" + + id: Mapped[uuid.UUID] = mapped_column( + UUIDType(binary=False), primary_key=True, default=uuid.uuid4 + ) + created_at: Mapped[datetime] = mapped_column(DateTime, default=get_current_datetime) + + project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) + project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id]) + + pool_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("pools.id")) + pool: Mapped["PoolModel"] = relationship(back_populates="instances") + + job_provisioning_data: Mapped[str] = mapped_column(String(4000)) + offer: Mapped[str] = mapped_column(String(4000)) diff --git a/src/dstack/_internal/server/routers/pool.py b/src/dstack/_internal/server/routers/pool.py new file mode 100644 index 000000000..6808d7e71 --- /dev/null +++ b/src/dstack/_internal/server/routers/pool.py @@ -0,0 +1,52 @@ +from typing import List, Tuple + +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +import dstack._internal.core.models.pool as models +import dstack._internal.server.schemas.pool as schemas +import dstack._internal.server.services.pool as pool +from dstack._internal.server.db import get_session +from dstack._internal.server.models import ProjectModel, UserModel +from dstack._internal.server.security.permissions import ProjectAdmin, ProjectMember + +router = APIRouter(prefix="/api/project/{project_name}/pool", tags=["pool"]) + + +@router.post("/list") +async def list_pool( + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), +) -> List[models.Pool]: + _, project = user_project + return await pool.list_project_pool(session=session, project=project) + + +@router.post("/delete") +async def delete_pool( + body: schemas.DeletePoolRequest, + session: 
AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), +): + _, project = user_project + await pool.delete_pool(session=session, project=project, pool_name=body.name) + + +@router.post("/create") +async def create_pool( + body: schemas.CreatePoolRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), +): + _, project = user_project + await pool.create_pool_model(name=body.name, session=session, project=project) + + +@router.post("/show") +async def how_pool( + body: schemas.CreatePoolRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), +): + _, project = user_project + return await pool.show_pool(pool_name=body.name, session=session, project=project) diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index 1a02b7f23..7bab036c4 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -8,6 +8,7 @@ from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.runs import ( DeleteRunsRequest, + GetRunPlanRequest, GetRunRequest, ListRunsRequest, StopRunsRequest, @@ -60,7 +61,7 @@ async def get_run( @project_router.post("/get_plan") async def get_run_plan( - body: SubmitRunRequest, + body: GetRunPlanRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), ) -> RunPlan: diff --git a/src/dstack/_internal/server/schemas/pool.py b/src/dstack/_internal/server/schemas/pool.py new file mode 100644 index 000000000..ade9c1e88 --- /dev/null +++ b/src/dstack/_internal/server/schemas/pool.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel + + +class DeletePoolRequest(BaseModel): + name: str + + +class CreatePoolRequest(BaseModel): + name: str + + +class ShowPoolRequest(BaseModel): + name: str diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index c62d55474..e5dd34e01 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -43,6 +43,7 @@ def get_job_specs(self) -> List[JobSpec]: requirements=self._requirements(), retry_policy=self._retry_policy(), working_dir=self._working_dir(), + pool_name=self._pool_name(), ) return [job_spec] @@ -143,6 +144,9 @@ def _python(self) -> str: version_info = sys.version_info return PythonVersion(f"{version_info.major}.{version_info.minor}").value + def _pool_name(self): + return self.run_spec.profile.pool_name + def _join_shell_commands(commands: List[str], env: Optional[Dict[str, str]] = None) -> str: if env is None: diff --git a/src/dstack/_internal/server/services/pool.py b/src/dstack/_internal/server/services/pool.py new file mode 100644 index 000000000..88299e0e2 --- /dev/null +++ b/src/dstack/_internal/server/services/pool.py @@ -0,0 +1,81 @@ +from datetime import timezone +from typing import List, Optional, Sequence + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.pool import Pool +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME +from dstack._internal.server.models import InstanceModel, PoolModel, ProjectModel +from dstack._internal.utils.logging import get_logger + +logger = 
get_logger(__name__) + + +async def list_project_pool(session: AsyncSession, project: ProjectModel) -> List[Pool]: + pools = list(await list_project_pool_models(session=session, project=project)) + if not pools: + pool = await create_pool_model(DEFAULT_POOL_NAME, session, project) + pools.append(pool) + return [pool_model_to_pool(p) for p in pools] + + +def pool_model_to_pool(pool_model: PoolModel) -> Pool: + return Pool( + name=pool_model.name, + default=pool_model.project.default_pool_id == pool_model.id, + created_at=pool_model.created_at.replace(tzinfo=timezone.utc), + ) + + +async def create_pool_model(name: str, session: AsyncSession, project: ProjectModel) -> PoolModel: + pool = PoolModel( + name=name, + project_id=project.id, + ) + session.add(pool) + await session.commit() + project.default_pool = pool + await session.commit() + return pool + + +async def list_project_pool_models( + session: AsyncSession, project: ProjectModel +) -> Sequence[PoolModel]: + pools = await session.execute(select(PoolModel).where(PoolModel.project_id == project.id)) + return pools.scalars().all() + + +async def delete_pool(session: AsyncSession, project: ProjectModel, pool_name: str): + """delete the pool and set the default pool to project""" + + default_pool: Optional[PoolModel] = None + default_pool_removed = False + + for pool in await list_project_pool_models(session=session, project=project): + if pool.name == DEFAULT_POOL_NAME: + default_pool = pool + + if pool_name == pool.name: + if project.default_pool_id == pool.id: + default_pool_removed = True + await session.delete(pool) + + if default_pool_removed: + if default_pool is not None: + project.default_pool = default_pool + else: + await create_pool_model(DEFAULT_POOL_NAME, session, project) + + await session.commit() + + +async def show_pool( + pool_name: str, session: AsyncSession, project: ProjectModel +) -> Sequence[InstanceModel]: + pools_result = await session.execute(select(PoolModel).where(PoolModel.name == pool_name)) + pools = pools_result.scalars().all() + + instances = pools[0].instances + return instances diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 301218747..c96240208 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -25,7 +25,7 @@ ServiceModelInfo, ) from dstack._internal.core.models.users import GlobalRole -from dstack._internal.server.models import JobModel, ProjectModel, RunModel, UserModel +from dstack._internal.server.models import JobModel, PoolModel, ProjectModel, RunModel, UserModel from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services import repos as repos_services from dstack._internal.server.services.jobs import ( @@ -33,6 +33,7 @@ job_model_to_job_submission, stop_job, ) +from dstack._internal.server.services.pool import create_pool_model from dstack._internal.server.services.projects import list_project_models, list_user_project_models from dstack._internal.utils.logging import get_logger from dstack._internal.utils.random_names import generate_name @@ -132,6 +133,7 @@ async def get_run_plan( jobs = get_jobs_from_run_spec(run_spec) job_plans = [] for job in jobs: + # TODO: use the job.pool_name to select an offer offers = await backends_services.get_instance_offers( backends=backends, job=job, @@ -177,6 +179,13 @@ async def submit_run( ) else: await delete_runs(session=session, project=project, runs_names=[run_spec.run_name]) + + 
pool_name = run_spec.profile.pool_name + pools_result = await session.execute(select(PoolModel).where(PoolModel.name == pool_name)) + pools = pools_result.scalars().all() + if not pools: + await create_pool_model(name=pool_name, session=session, project=project) + run_model = RunModel( id=uuid.uuid4(), project_id=project.id, @@ -346,7 +355,7 @@ async def _generate_run_name( def _get_run_cost(run: Run) -> float: - run_cost = sum( + run_cost = math.fsum( _get_job_submission_cost(submission) for job in run.jobs for submission in job.job_submissions diff --git a/src/dstack/api/_public/__init__.py b/src/dstack/api/_public/__init__.py index ffcc6da95..2c16fc6b0 100644 --- a/src/dstack/api/_public/__init__.py +++ b/src/dstack/api/_public/__init__.py @@ -6,6 +6,7 @@ from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import PathLike from dstack.api._public.backends import BackendCollection +from dstack.api._public.pool import PoolCollection from dstack.api._public.repos import RepoCollection, get_ssh_keypair from dstack.api._public.runs import RunCollection from dstack.api.server import APIClient @@ -40,6 +41,7 @@ def __init__( self._repos = RepoCollection(api_client, project_name) self._backends = BackendCollection(api_client, project_name) self._runs = RunCollection(api_client, project_name, self) + self._pool = PoolCollection(api_client, project_name) if ssh_identity_file: self.ssh_identity_file = str(ssh_identity_file) else: @@ -95,3 +97,7 @@ def client(self) -> APIClient: @property def project(self) -> str: return self._project + + @property + def pool(self) -> PoolCollection: + return self._pool diff --git a/src/dstack/api/_public/pool.py b/src/dstack/api/_public/pool.py new file mode 100644 index 000000000..f1dea3cbf --- /dev/null +++ b/src/dstack/api/_public/pool.py @@ -0,0 +1,40 @@ +from typing import List + +from dstack.api.server import APIClient + + +class Instance: + def __init__(self, api_client: APIClient, instance): + self._api_client = api_client + self._instance = instance + + @property + def name(self) -> str: + return self._instance.name + + def __str__(self) -> str: + return f"" + + def __repr__(self) -> str: + return f"" + + +class PoolCollection: + """ + Operations with pools + """ + + def __init__(self, api_client: APIClient, project: str): + self._api_client = api_client + self._project = project + + def list(self) -> List[Instance]: + """ + List available pool in the project + + Returns: + pools + """ + list_raw_instances = self._api_client.pool.list(project_name=self._project) + list_instances = [Instance(self._api_client, instance) for instance in list_raw_instances] + return list_instances diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 89558743a..6adc7680a 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -370,6 +370,7 @@ def get_plan( max_price: Optional[float] = None, working_dir: Optional[str] = None, run_name: Optional[str] = None, + pool_name: Optional[str] = None, ) -> RunPlan: # """ # Get run plan. Same arguments as `submit` @@ -380,10 +381,10 @@ def get_plan( if working_dir is None: working_dir = "." 
elif repo.repo_dir is not None: - working_dir = Path(repo.repo_dir) / working_dir - if not path_in_dir(working_dir, repo.repo_dir): + working_dir_path = Path(repo.repo_dir) / working_dir + if not path_in_dir(working_dir_path, repo.repo_dir): raise ConfigurationError("Working directory is outside of the repo") - working_dir = working_dir.relative_to(repo.repo_dir).as_posix() + working_dir = working_dir_path.relative_to(repo.repo_dir).as_posix() if configuration_path is None: configuration_path = "(python)" @@ -399,6 +400,7 @@ def get_plan( retry_policy=retry_policy, max_duration=max_duration, max_price=max_price, + pool_name=pool_name, ) run_spec = RunSpec( run_name=run_name, diff --git a/src/dstack/api/server/__init__.py b/src/dstack/api/server/__init__.py index c5c9749cf..49b501070 100644 --- a/src/dstack/api/server/__init__.py +++ b/src/dstack/api/server/__init__.py @@ -10,6 +10,7 @@ from dstack.api.server._backends import BackendsAPIClient from dstack.api.server._gateways import GatewaysAPIClient from dstack.api.server._logs import LogsAPIClient +from dstack.api.server._pool import PoolAPIClient from dstack.api.server._projects import ProjectsAPIClient from dstack.api.server._repos import ReposAPIClient from dstack.api.server._runs import RunsAPIClient @@ -34,6 +35,7 @@ class APIClient: runs: operations with runs logs: operations with logs gateways: operations with gateways + pools: operations with pools """ def __init__(self, base_url: str, token: str): @@ -82,8 +84,16 @@ def secrets(self) -> SecretsAPIClient: def gateways(self) -> GatewaysAPIClient: return GatewaysAPIClient(self._request) + @property + def pool(self) -> PoolAPIClient: + return PoolAPIClient(self._request) + def _request( - self, path: str, body: Optional[str] = None, raise_for_status: bool = True, **kwargs + self, + path: str, + body: Optional[str] = None, + raise_for_status: bool = True, + **kwargs, ) -> requests.Response: path = path.lstrip("/") if body is not None: diff --git a/src/dstack/api/server/_pool.py b/src/dstack/api/server/_pool.py new file mode 100644 index 000000000..ea1b4ccf0 --- /dev/null +++ b/src/dstack/api/server/_pool.py @@ -0,0 +1,25 @@ +from typing import List + +from pydantic import parse_obj_as + +import dstack._internal.server.schemas.pool as pool_schemas +from dstack._internal.core.models.pool import Pool +from dstack.api.server._group import APIClientGroup + + +class PoolAPIClient(APIClientGroup): + def list(self, project_name: str) -> List[Pool]: + resp = self._request(f"/api/project/{project_name}/pool/list") + return parse_obj_as(List[Pool], resp.json()) + + def delete(self, project_name: str, pool_name: str) -> None: + body = pool_schemas.DeletePoolRequest(name=pool_name) + self._request(f"/api/project/{project_name}/pool/delete", body=body.json()) + + def create(self, project_name: str, pool_name: str) -> None: + body = pool_schemas.CreatePoolRequest(name=pool_name) + self._request(f"/api/project/{project_name}/pool/create", body=body.json()) + + def show(self, project_name: str, pool_name: str) -> None: + body = pool_schemas.ShowPoolRequest(name=pool_name) + self._request(f"/api/project/{project_name}/pool/show", body=body.json()) diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index 0f1bbe1ca..107685a83 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -9,6 +9,7 @@ GetRunRequest, ListRunsRequest, StopRunsRequest, + SubmitRunRequest, ) from dstack.api.server._group import APIClientGroup @@ -30,7 +31,7 @@ def get_plan(self, 
project_name: str, run_spec: RunSpec) -> RunPlan: return parse_obj_as(RunPlan, resp.json()) def submit(self, project_name: str, run_spec: RunSpec) -> Run: - body = GetRunPlanRequest(run_spec=run_spec) + body = SubmitRunRequest(run_spec=run_spec) resp = self._request(f"/api/project/{project_name}/runs/submit", body=body.json()) return parse_obj_as(Run, resp.json()) diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index 2bb8a6b2f..87d37c92e 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -1,7 +1,5 @@ -import json from datetime import datetime, timezone from unittest.mock import Mock, patch -from uuid import UUID import pytest from sqlalchemy.ext.asyncio import AsyncSession @@ -14,7 +12,7 @@ LaunchedInstanceInfo, Resources, ) -from dstack._internal.core.models.profiles import Profile, ProfileRetryPolicy +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile, ProfileRetryPolicy from dstack._internal.core.models.runs import JobStatus from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs from dstack._internal.server.testing.common import ( @@ -69,22 +67,21 @@ async def test_provisiones_job(self, test_db, session: AsyncSession): session=session, run=run, ) + offer = InstanceOfferWithAvailability( + backend=BackendType.AWS, + instance=InstanceType( + name="instance", + resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[]), + ), + region="us", + price=1.0, + availability=InstanceAvailability.AVAILABLE, + ) with patch("dstack._internal.server.services.backends.get_project_backends") as m: backend_mock = Mock() m.return_value = [backend_mock] backend_mock.TYPE = BackendType.AWS - backend_mock.compute.return_value.get_offers.return_value = [ - InstanceOfferWithAvailability( - backend=BackendType.AWS, - instance=InstanceType( - name="instance", - resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[]), - ), - region="us", - price=1.0, - availability=InstanceAvailability.AVAILABLE, - ) - ] + backend_mock.compute.return_value.get_offers.return_value = [offer] backend_mock.compute.return_value.run_job.return_value = LaunchedInstanceInfo( instance_id="instance_id", region="us", @@ -97,10 +94,22 @@ async def test_provisiones_job(self, test_db, session: AsyncSession): m.assert_called_once() backend_mock.compute.return_value.get_offers.assert_called_once() backend_mock.compute.return_value.run_job.assert_called_once() + await session.refresh(job) assert job is not None assert job.status == JobStatus.PROVISIONING + await session.refresh(project) + assert project.default_pool.name == DEFAULT_POOL_NAME + + instance_offer = InstanceOfferWithAvailability.parse_raw( + project.default_pool.instances[0].offer + ) + assert offer == instance_offer + + pool_job_provisioning_data = project.default_pool.instances[0].job_provisioning_data + assert pool_job_provisioning_data == job.job_provisioning_data + @pytest.mark.asyncio async def test_transitions_job_with_retry_to_pending_on_no_capacity( self, test_db, session: AsyncSession @@ -134,10 +143,14 @@ async def test_transitions_job_with_retry_to_pending_on_no_capacity( with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock: datetime_mock.return_value = datetime(2023, 1, 2, 3, 30, 0, tzinfo=timezone.utc) await 
process_submitted_jobs() + await session.refresh(job) assert job is not None assert job.status == JobStatus.PENDING + await session.refresh(project) + assert not project.default_pool.instances + @pytest.mark.asyncio async def test_transitions_job_with_outdated_retry_to_failed_on_no_capacity( self, test_db, session: AsyncSession @@ -171,6 +184,10 @@ async def test_transitions_job_with_outdated_retry_to_failed_on_no_capacity( with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock: datetime_mock.return_value = datetime(2023, 1, 2, 5, 0, 0, tzinfo=timezone.utc) await process_submitted_jobs() + await session.refresh(job) assert job is not None assert job.status == JobStatus.FAILED + + await session.refresh(project) + assert not project.default_pool.instances diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 814319013..1f4a4816c 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -16,6 +16,7 @@ InstanceType, Resources, ) +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME from dstack._internal.core.models.runs import JobSpec, JobStatus, RunSpec from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.server.main import app @@ -74,6 +75,7 @@ def get_dev_env_run_plan_dict( "max_duration": "off", "max_price": None, "name": "string", + "pool_name": DEFAULT_POOL_NAME, "retry_policy": {"limit": None, "retry": False}, "spot_policy": "spot", }, @@ -112,6 +114,7 @@ def get_dev_env_run_plan_dict( "job_name": f"{run_name}-0", "job_num": 0, "max_duration": None, + "pool_name": DEFAULT_POOL_NAME, "registry_auth": None, "requirements": { "resources": { @@ -180,6 +183,7 @@ def get_dev_env_run_dict( "max_duration": "off", "max_price": None, "name": "string", + "pool_name": DEFAULT_POOL_NAME, "retry_policy": {"limit": None, "retry": False}, "spot_policy": "spot", }, @@ -218,6 +222,7 @@ def get_dev_env_run_dict( "job_name": f"{run_name}-0", "job_num": 0, "max_duration": None, + "pool_name": DEFAULT_POOL_NAME, "registry_auth": None, "requirements": { "resources": { From 27b169eaf62fb4f6ac5c97eafc42283371e095a1 Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Wed, 20 Dec 2023 11:27:52 +0300 Subject: [PATCH 02/47] pool add --- docs/docs/reference/pool/index.md | 42 ++++ runner/cmd/shim/main.go | 1 - runner/internal/shim/api/http.go | 14 +- runner/internal/shim/api/schemas.go | 16 +- runner/internal/shim/api/server.go | 9 +- runner/internal/shim/docker.go | 19 +- runner/internal/shim/docker_test.go | 6 +- runner/internal/shim/models.go | 28 ++- src/dstack/_internal/cli/commands/pool.py | 183 +++++++++++++++++- src/dstack/_internal/cli/utils/run.py | 2 +- .../_internal/core/backends/aws/compute.py | 66 +++++++ .../_internal/core/backends/base/compute.py | 34 ++++ .../core/backends/datacrunch/compute.py | 67 ++++++- .../_internal/core/backends/gcp/compute.py | 74 +++++++ .../_internal/core/services/ssh/tunnel.py | 8 +- .../background/tasks/process_running_jobs.py | 19 +- .../tasks/process_submitted_jobs.py | 6 +- .../versions/2943402e3b56_add_pools.py | 49 ----- .../versions/beceb9d2895d_add_pool.py | 89 +++++++++ src/dstack/_internal/server/models.py | 1 + src/dstack/_internal/server/routers/pool.py | 4 +- src/dstack/_internal/server/routers/runs.py | 41 +++- src/dstack/_internal/server/schemas/runner.py | 1 + src/dstack/_internal/server/schemas/runs.py | 10 + .../server/services/backends/__init__.py | 8 
+-
 .../_internal/server/services/docker.py       |  11 +-
 .../services/jobs/configurators/base.py       |  16 +-
 src/dstack/_internal/server/services/pool.py  |  70 ++++++-
 .../server/services/runner/client.py          |   6 +-
 .../_internal/server/services/runner/ssh.py   |   6 +
 src/dstack/_internal/server/services/runs.py  | 130 ++++++++++++-
 src/dstack/api/_public/runs.py                |  12 +-
 src/dstack/api/server/_backends.py            |   8 +-
 src/dstack/api/server/_pool.py                |   7 +-
 src/dstack/api/server/_runs.py                |  19 +-
 .../tasks/test_process_running_jobs.py        |   4 +-
 .../_internal/server/services/test_pool.py    | 181 +++++++++++++++++
 37 files changed, 1131 insertions(+), 136 deletions(-)
 create mode 100644 docs/docs/reference/pool/index.md
 delete mode 100644 src/dstack/_internal/server/migrations/versions/2943402e3b56_add_pools.py
 create mode 100644 src/dstack/_internal/server/migrations/versions/beceb9d2895d_add_pool.py
 create mode 100644 src/tests/_internal/server/services/test_pool.py

diff --git a/docs/docs/reference/pool/index.md b/docs/docs/reference/pool/index.md
new file mode 100644
index 000000000..c2ac6e41f
--- /dev/null
+++ b/docs/docs/reference/pool/index.md
@@ -0,0 +1,42 @@
+# dstack pool
+
+## What is `dstack pool`
+
+A `dstack pool` is the primary element for controlling precisely how compute instances are provisioned and reused:
+
+- Sometimes the desired instance for a task is not immediately available. A `dstack pool` waits for compute instances to become available and, when possible, allocates them before tasks are run on them.
+
+- You need reserved compute instances for a constant load. `dstack` pre-allocates on-demand instances and lets you run tasks on them as soon as they are available.
+
+- You want tasks to start faster. Searching for instances and provisioning the runner takes time; with a pool, tasks are dispatched to instances that are already running.
+
+- You have your own compute instances. You can attach them to a `dstack` pool and use them alongside cloud instances.
+
+## How to use
+
+Any task run without the `--pool` argument uses the default pool (named `default-pool`).
+
+When you specify a pool name for a task, for example `dstack run --pool mypool`, one of two things happens:
+
+- If `mypool` exists, the task runs on an available instance with a suitable configuration.
+- If `mypool` does not exist, the pool is created, and the compute instances it requires are provisioned and attached to it.
+
+### CLI
+
+- `dstack pool list`
+- `dstack pool create`
+- `dstack pool show <pool name>`
+- `dstack pool add <pool name>`
+- `dstack pool delete`
+
+### Instance lifecycle
+
+- idle time
+- reservation policy (instance termination)
+- task retry policy
+
+### Add your own compute instance
+
+When you connect your own instance, it must have a public IP address so that the dstack server can reach it.
+
+To connect it, pass the IP address and SSH credentials to the command: `dstack pool add --host HOST --port PORT --ssh-private-key-file SSH_PRIVATE_KEY_FILE`.
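For reference while reading the rest of this patch, below is a minimal sketch of how the new pool endpoints could be driven through the server client added in `src/dstack/api/server/_pool.py`. The server URL, token, and project name are placeholders (assumptions, not values from the patch), and the token must belong to a project admin for `create`/`delete`, per the `ProjectAdmin` dependency in `src/dstack/_internal/server/routers/pool.py`.

```python
# Minimal usage sketch (assumptions: a dstack server reachable at SERVER_URL,
# an admin TOKEN for it, and an existing project named PROJECT).
from dstack.api.server import APIClient

SERVER_URL = "http://localhost:3000"  # placeholder
TOKEN = "<admin token>"               # placeholder
PROJECT = "main"                      # placeholder

client = APIClient(base_url=SERVER_URL, token=TOKEN)

# POST /api/project/{PROJECT}/pool/create — create a named pool.
client.pool.create(PROJECT, "mypool")

# POST /api/project/{PROJECT}/pool/list — Pool models carry name, default, created_at.
for pool in client.pool.list(PROJECT):
    print(pool.name, "default" if pool.default else "", pool.created_at)

# POST /api/project/{PROJECT}/pool/delete — remove the pool again.
client.pool.delete(PROJECT, "mypool")
```

The `dstack pool` CLI subcommands introduced in this patch call these same methods via `self.api.client.pool`, so the CLI and the sketch above exercise identical server routes.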
diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go index 1b6f7ff7b..582af89f2 100644 --- a/runner/cmd/shim/main.go +++ b/runner/cmd/shim/main.go @@ -93,7 +93,6 @@ func main() { &cli.StringFlag{ Name: "image", Usage: "Docker image name", - Required: true, Destination: &args.Docker.ImageName, EnvVars: []string{"DSTACK_IMAGE_NAME"}, }, diff --git a/runner/internal/shim/api/http.go b/runner/internal/shim/api/http.go index 160bdf2fa..019ea1c93 100644 --- a/runner/internal/shim/api/http.go +++ b/runner/internal/shim/api/http.go @@ -1,12 +1,9 @@ package api import ( - "encoding/base64" - "encoding/json" "log" "net/http" - "github.com/docker/docker/api/types/registry" "github.com/dstackai/dstack/runner/internal/api" "github.com/dstackai/dstack/runner/internal/shim" ) @@ -33,16 +30,7 @@ func (s *ShimServer) registryAuthPostHandler(w http.ResponseWriter, r *http.Requ return nil, err } - authConfig := registry.AuthConfig{ - Username: body.Username, - Password: body.Password, - } - encodedConfig, err := json.Marshal(authConfig) - if err != nil { - log.Println("Failed to encode auth config", "err", err) - return nil, err - } - s.registryAuth <- base64.URLEncoding.EncodeToString(encodedConfig) + s.registryAuth = body.MakeConfig() return nil, nil } diff --git a/runner/internal/shim/api/schemas.go b/runner/internal/shim/api/schemas.go index 069b1aa66..8098cdfc1 100644 --- a/runner/internal/shim/api/schemas.go +++ b/runner/internal/shim/api/schemas.go @@ -1,8 +1,11 @@ package api +import "github.com/dstackai/dstack/runner/internal/shim" + type RegistryAuthBody struct { - Username string `json:"username"` - Password string `json:"password"` + Username string `json:"username"` + Password string `json:"password"` + ImageName string `json:"image_name"` } type HealthcheckResponse struct { @@ -12,3 +15,12 @@ type HealthcheckResponse struct { type PullResponse struct { State string `json:"state"` } + +func (ra RegistryAuthBody) MakeConfig() shim.ImagePullConfig { + res := shim.ImagePullConfig{ + ImageName: ra.ImageName, + Username: ra.Username, + Password: ra.Password, + } + return res +} diff --git a/runner/internal/shim/api/server.go b/runner/internal/shim/api/server.go index 6ef8a3a13..a57a04536 100644 --- a/runner/internal/shim/api/server.go +++ b/runner/internal/shim/api/server.go @@ -16,7 +16,7 @@ type ShimServer struct { srv *http.Server mu sync.RWMutex - registryAuth chan string + registryAuth shim.ImagePullConfig state string } @@ -28,13 +28,10 @@ func NewShimServer(address string, registryAuthRequired bool) *ShimServer { Handler: mux, }, - registryAuth: make(chan string, 1), - state: shim.WaitRegistryAuth, + state: shim.WaitRegistryAuth, } if registryAuthRequired { mux.HandleFunc("/api/registry_auth", api.JSONResponseHandler("POST", s.registryAuthPostHandler)) - } else { - close(s.registryAuth) // no credentials ever would be sent } mux.HandleFunc("/api/healthcheck", api.JSONResponseHandler("GET", s.healthcheckGetHandler)) mux.HandleFunc("/api/pull", api.JSONResponseHandler("GET", s.pullGetHandler)) @@ -55,7 +52,7 @@ func (s *ShimServer) RunDocker(ctx context.Context, params shim.DockerParameters return gerrors.Wrap(shim.RunDocker(ctx, params, s)) } -func (s *ShimServer) GetRegistryAuth() <-chan string { +func (s *ShimServer) GetRegistryAuth() shim.ImagePullConfig { return s.registryAuth } diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index f3b6a4822..19232c8ca 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -29,13 +29,18 @@ 
func RunDocker(ctx context.Context, params DockerParameters, serverAPI APIAdapte } log.Println("Waiting for registry auth") - registryAuth := <-serverAPI.GetRegistryAuth() + registryAuth := serverAPI.GetRegistryAuth() serverAPI.SetState(Pulling) log.Println("Pulling image") - if err = pullImage(ctx, client, params.DockerImageName(), registryAuth); err != nil { + imageName := params.DockerImageName() + if imageName == "" { + imageName = registryAuth.ImageName + } + if err = pullImage(ctx, client, imageName, registryAuth); err != nil { return gerrors.Wrap(err) } + log.Println("Creating container") containerID, err := createContainer(ctx, client, params) if err != nil { @@ -57,7 +62,7 @@ func RunDocker(ctx context.Context, params DockerParameters, serverAPI APIAdapte return nil } -func pullImage(ctx context.Context, client docker.APIClient, imageName string, registryAuth string) error { +func pullImage(ctx context.Context, client docker.APIClient, imageName string, imagePullConfig ImagePullConfig) error { if !strings.Contains(imageName, ":") { imageName += ":latest" } @@ -71,7 +76,13 @@ func pullImage(ctx context.Context, client docker.APIClient, imageName string, r return nil } - reader, err := client.ImagePull(ctx, imageName, types.ImagePullOptions{RegistryAuth: registryAuth}) // todo test registry auth + opts := types.ImagePullOptions{} + regAuth, _ := imagePullConfig.EncodeRegistryAuth() + if regAuth != "" { + opts.RegistryAuth = regAuth + } + + reader, err := client.ImagePull(ctx, imageName, opts) // todo test registry auth if err != nil { return gerrors.Wrap(err) } diff --git a/runner/internal/shim/docker_test.go b/runner/internal/shim/docker_test.go index bf3cee691..b321224da 100644 --- a/runner/internal/shim/docker_test.go +++ b/runner/internal/shim/docker_test.go @@ -116,10 +116,8 @@ func (c *dockerParametersMock) DockerMounts() ([]mount.Mount, error) { type apiAdapterMock struct{} -func (s *apiAdapterMock) GetRegistryAuth() <-chan string { - ch := make(chan string) - close(ch) - return ch +func (s *apiAdapterMock) GetRegistryAuth() ImagePullConfig { + return ImagePullConfig{} } func (s *apiAdapterMock) SetState(string) {} diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index 0a88847c7..b03d4963a 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -1,11 +1,16 @@ package shim import ( + "encoding/base64" + "encoding/json" + "log" + "github.com/docker/docker/api/types/mount" + "github.com/docker/docker/api/types/registry" ) type APIAdapter interface { - GetRegistryAuth() <-chan string + GetRegistryAuth() ImagePullConfig SetState(string) } @@ -42,3 +47,24 @@ type CLIArgs struct { PublicSSHKey string } } + +type ImagePullConfig struct { + Username string + Password string + ImageName string +} + +func (ra ImagePullConfig) EncodeRegistryAuth() (string, error) { + authConfig := registry.AuthConfig{ + Username: ra.Username, + Password: ra.Password, + } + + encodedConfig, err := json.Marshal(authConfig) + if err != nil { + log.Println("Failed to encode auth config", "err", err) + return "", err + } + + return base64.URLEncoding.EncodeToString(encodedConfig), nil +} diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index b21265e56..856e95c9c 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -1,21 +1,41 @@ import argparse -from typing import List +from collections.abc import Sequence +from pathlib import Path from rich.table 
import Table from dstack._internal.cli.commands import APIBaseCommand -from dstack._internal.cli.utils.common import console +from dstack._internal.cli.services.configurators.profile import ( + apply_profile_args, + register_profile_args, +) +from dstack._internal.cli.utils.common import confirm_ask, console +from dstack._internal.core.errors import CLIError, ServerClientError +from dstack._internal.core.models.instances import ( + InstanceAvailability, + InstanceOfferWithAvailability, +) +from dstack._internal.core.models.pool import Instance, Pool +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile +from dstack._internal.core.models.runs import Requirements +from dstack._internal.core.services.configs import ConfigManager from dstack._internal.utils.common import pretty_date +from dstack._internal.utils.logging import get_logger +from dstack.api.utils import load_profile +logger = get_logger(__name__) +NOTSET = object() -def print_pool_table(pools: List, verbose): + +def print_pool_table(pools: Sequence[Pool], verbose): table = Table(box=None) table.add_column("NAME") table.add_column("DEFAULT") if verbose: table.add_column("CREATED") - for pool in pools: + sorted_pools = sorted(pools, key=lambda r: r.name) + for pool in sorted_pools: row = [pool.name, "default" if pool.default else ""] if verbose: row.append(pretty_date(pool.created_at)) @@ -25,6 +45,112 @@ def print_pool_table(pools: List, verbose): console.print() +def print_instance_table(instances: Sequence[Instance]): + table = Table(box=None) + table.add_column("INSTANCE ID") + table.add_column("BACKEND") + table.add_column("INSTANCE TYPE") + table.add_column("PRICE") + + for instance in instances: + row = [ + instance.instance_id, + instance.backend, + instance.instance_type.resources.pretty_format(), + f"{instance.price:.02f}", + ] + table.add_row(*row) + + console.print(table) + console.print() + + +def print_offers_table( + pool_name: str, + profile: Profile, + requirements: Requirements, + instance_offers: Sequence[InstanceOfferWithAvailability], + offers_limit: int = 3, +): + + pretty_req = requirements.pretty_format(resources_only=True) + max_price = f"${requirements.max_price:g}" if requirements.max_price else "-" + max_duration = ( + f"{profile.max_duration / 3600:g}h" if isinstance(profile.max_duration, int) else "-" + ) + + # TODO: improve retry policy + # retry_policy = profile.retry_policy + # retry_policy = ( + # (f"{retry_policy.limit / 3600:g}h" if retry_policy.limit else "yes") + # if retry_policy.retry + # else "no" + # ) + + # TODO: improve spot policy + if requirements.spot is None: + spot_policy = "auto" + elif requirements.spot: + spot_policy = "spot" + else: + spot_policy = "on-demand" + + def th(s: str) -> str: + return f"[bold]{s}[/bold]" + + props = Table(box=None, show_header=False) + props.add_column(no_wrap=True) # key + props.add_column() # value + + props.add_row(th("Pool name"), pool_name) + props.add_row(th("Min resources"), pretty_req) + props.add_row(th("Max price"), max_price) + props.add_row(th("Max duration"), max_duration) + props.add_row(th("Spot policy"), spot_policy) + # props.add_row(th("Retry policy"), retry_policy) + + offers_table = Table(box=None) + offers_table.add_column("#") + offers_table.add_column("BACKEND") + offers_table.add_column("REGION") + offers_table.add_column("INSTANCE") + offers_table.add_column("RESOURCES") + offers_table.add_column("SPOT") + offers_table.add_column("PRICE") + offers_table.add_column() + + print_offers = 
instance_offers[:offers_limit] + + for i, offer in enumerate(print_offers, start=1): + r = offer.instance.resources + + availability = "" + if offer.availability in { + InstanceAvailability.NOT_AVAILABLE, + InstanceAvailability.NO_QUOTA, + }: + availability = offer.availability.value.replace("_", " ").title() + offers_table.add_row( + f"{i}", + offer.backend, + offer.region, + offer.instance.name, + r.pretty_format(), + "yes" if r.spot else "no", + f"${offer.price:g}", + availability, + style=None if i == 1 else "grey58", + ) + if len(print_offers) > offers_limit: + offers_table.add_row("", "...", style="grey58") + + console.print(props) + console.print() + if len(print_offers) > 0: + console.print(offers_table) + console.print() + + class PoolCommand(APIBaseCommand): NAME = "pool" DESCRIPTION = "Pool management" @@ -34,18 +160,21 @@ def _register(self): self._parser.set_defaults(subfunc=self._list) subparsers = self._parser.add_subparsers(dest="action") + # list list_parser = subparsers.add_parser( "list", help="List pools", formatter_class=self._parser.formatter_class ) list_parser.add_argument("-v", "--verbose", help="Show more information") list_parser.set_defaults(subfunc=self._list) + # create create_parser = subparsers.add_parser( "create", help="Create pool", formatter_class=self._parser.formatter_class ) create_parser.add_argument("-n", "--name", dest="pool_name", help="The name of the pool") create_parser.set_defaults(subfunc=self._create) + # delete delete_parser = subparsers.add_parser( "delete", help="Delete pool", formatter_class=self._parser.formatter_class ) @@ -54,14 +183,28 @@ def _register(self): ) delete_parser.set_defaults(subfunc=self._delete) + # show show_parser = subparsers.add_parser( - "show", help="Show pool's instances", formatter_class=self._parser.formatter_class + "show", help="Show pool instances", formatter_class=self._parser.formatter_class ) show_parser.add_argument( "-n", "--name", dest="pool_name", help="The name of the pool", required=True ) show_parser.set_defaults(subfunc=self._show) + # add + add_parser = subparsers.add_parser( + "add", help="Add instance to pool", formatter_class=self._parser.formatter_class + ) + add_parser.add_argument( + "--pool", dest="pool_name", help="The name of the pool", required=True + ) + add_parser.add_argument( + "-y", "--yes", help="Don't ask for confirmation", action="store_true" + ) + add_parser.set_defaults(subfunc=self._add) + register_profile_args(add_parser) + def _list(self, args: argparse.Namespace): pools = self.api.client.pool.list(self.api.project) print_pool_table(pools, verbose=getattr(args, "verbose", False)) @@ -73,7 +216,35 @@ def _delete(self, args: argparse.Namespace): self.api.client.pool.delete(self.api.project, args.pool_name) def _show(self, args: argparse.Namespace): - self.api.client.pool.show(self.api.project, args.pool_name) + instances = self.api.client.pool.show(self.api.project, args.pool_name) + print_instance_table(instances) + + def _add(self, args: argparse.Namespace): + super()._command(args) + + repo = self.api.repos.load(Path.cwd()) + self.api.ssh_identity_file = ConfigManager().get_repo_config(repo.repo_dir).ssh_key_path + + profile = load_profile(Path.cwd(), args.profile) + apply_profile_args(args, profile) + + pool_name: str = DEFAULT_POOL_NAME if args.pool_name is None else args.pool_name + profile.pool_name = pool_name + + with console.status("Getting run plan..."): + requirements, offers = self.api.runs.get_offers(profile) + + print(pool_name, profile, requirements, offers) + 
print_offers_table(pool_name, profile, requirements, offers) + if not args.yes and not confirm_ask("Continue?"): + console.print("\nExiting...") + return + + try: + with console.status("Submitting run..."): + self.api.runs.create_instance(pool_name, profile) + except ServerClientError as e: + raise CLIError(e.msg) def _command(self, args: argparse.Namespace): super()._command(args) diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index 06bc08242..c2509be70 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -3,7 +3,7 @@ from rich.table import Table from dstack._internal.cli.utils.common import console -from dstack._internal.core.models.instances import InstanceAvailability, InstanceType, Resources +from dstack._internal.core.models.instances import InstanceAvailability, InstanceType from dstack._internal.core.models.runs import RunPlan from dstack._internal.utils.common import pretty_date from dstack.api import Run diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index 4e2839710..88af9a5a5 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -11,6 +11,7 @@ from dstack._internal.core.backends.aws.config import AWSConfig from dstack._internal.core.backends.base.compute import ( Compute, + InstanceConfiguration, get_gateway_user_data, get_instance_name, get_user_data, @@ -27,6 +28,7 @@ LaunchedInstanceInfo, ) from dstack._internal.core.models.runs import Job, Requirements, Run +from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -95,6 +97,70 @@ def terminate_instance( else: raise e + def create_instance( + self, + project: ProjectModel, + user: UserModel, + instance_offer: InstanceOfferWithAvailability, + instance_config: InstanceConfiguration, + ) -> LaunchedInstanceInfo: + project_id = project.name + ec2 = self.session.resource("ec2", region_name=instance_offer.region) + ec2_client = self.session.client("ec2", region_name=instance_offer.region) + iam_client = self.session.client("iam", region_name=instance_offer.region) + + tags = [ + {"Key": "Name", "Value": run.run_spec.run_name}, + {"Key": "owner", "Value": "dstack"}, + {"Key": "dstack_project", "Value": project_id}, + {"Key": "dstack_user", "Value": run.user}, + ] + try: + disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) + response = ec2.create_instances( + **aws_resources.create_instances_struct( + disk_size=disk_size, + image_id=aws_resources.get_image_id( + ec2_client, len(instance_offer.instance.resources.gpus) > 0 + ), + instance_type=instance_offer.instance.name, + iam_instance_profile_arn=aws_resources.create_iam_instance_profile( + iam_client, project_id + ), + user_data=get_user_data( + backend=BackendType.AWS, + image_name=job.job_spec.image_name, + authorized_keys=[ + run.run_spec.ssh_key_pub.strip(), + project_ssh_public_key.strip(), + ], + registry_auth_required=job.job_spec.registry_auth is not None, + ), + tags=tags, + security_group_id=aws_resources.create_security_group(ec2_client, project_id), + spot=instance_offer.instance.resources.spot, + ) + ) + instance = response[0] + instance.wait_until_running() + instance.reload() # populate instance.public_ip_address + + if instance_offer.instance.resources.spot: # it will not terminate the instance + 
ec2_client.cancel_spot_instance_requests( + SpotInstanceRequestIds=[instance.spot_instance_request_id] + ) + return LaunchedInstanceInfo( + instance_id=instance.instance_id, + ip_address=instance.public_ip_address, + region=instance_offer.region, + username="ubuntu", + ssh_port=22, + dockerized=True, # because `dstack-shim docker` is used + ) + except botocore.exceptions.ClientError as e: + logger.warning("Got botocore.exceptions.ClientError: %s", e) + raise NoCapacityError() + def run_job( self, run: Run, diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index e51fd0999..c8e8c1b19 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -6,10 +6,12 @@ import git import requests import yaml +from pydantic import BaseModel from dstack import version from dstack._internal import settings from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.configurations import RegistryAuth from dstack._internal.core.models.instances import ( InstanceOfferWithAvailability, LaunchedGatewayInfo, @@ -19,6 +21,28 @@ from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) +from dstack._internal.server.models import ProjectModel, UserModel +from dstack._internal.server.services.docker import DockerImage + + +class SSHKeys(BaseModel): + public: str + private: Optional[str] + + +class DockerConfig(BaseModel): + registry_auth: Optional[RegistryAuth] + image: Optional[DockerImage] + + +class InstanceConfiguration(BaseModel): + pool_name: str + instance_name: str # unique in pool + ssh_keys: List[SSHKeys] + job_docker_config: Optional[DockerConfig] + + def get_public_keys(self): + return [ssh_key.public.strip() for ssh_key in self.ssh_keys] class Compute(ABC): @@ -39,6 +63,16 @@ def run_job( ) -> LaunchedInstanceInfo: pass + @abstractmethod + def create_instance( + self, + project: ProjectModel, + user: UserModel, + instance_offer: InstanceOfferWithAvailability, + instance_config: InstanceConfiguration, + ) -> LaunchedInstanceInfo: + pass + @abstractmethod def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None diff --git a/src/dstack/_internal/core/backends/datacrunch/compute.py b/src/dstack/_internal/core/backends/datacrunch/compute.py index 19b60fb96..0e0c3acc8 100644 --- a/src/dstack/_internal/core/backends/datacrunch/compute.py +++ b/src/dstack/_internal/core/backends/datacrunch/compute.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional from dstack._internal.core.backends.base import Compute -from dstack._internal.core.backends.base.compute import get_shim_commands +from dstack._internal.core.backends.base.compute import InstanceConfiguration, get_shim_commands from dstack._internal.core.backends.base.offers import get_catalog_offers from dstack._internal.core.backends.datacrunch.api_client import DataCrunchAPIClient from dstack._internal.core.backends.datacrunch.config import DataCrunchConfig @@ -14,6 +14,7 @@ LaunchedInstanceInfo, ) from dstack._internal.core.models.runs import Job, Requirements, Run +from dstack._internal.server.models import ProjectModel, UserModel class DataCrunchCompute(Compute): @@ -59,6 +60,70 @@ def _get_offers_with_availability( return availability_offers + def create_instance( + self, + project: ProjectModel, + user: UserModel, + instance_offer: InstanceOfferWithAvailability, + instance_config: InstanceConfiguration, + ) -> 
LaunchedInstanceInfo: + public_keys = instance_config.get_public_keys() + ssh_ids = [] + for ssh_public_key in public_keys: + ssh_ids.append( + # datacrunch allows you to use the same name + self.api_client.get_or_create_ssh_key( + name=f"dstack-{instance_config.instance_name}.key", + public_key=ssh_public_key, + ) + ) + + registry_auth_required = instance_config.job_docker_config.registry_auth is not None + commands = get_shim_commands( + backend=BackendType.DATACRUNCH, + image_name=instance_config.job_docker_config.image.image, + authorized_keys=public_keys, + registry_auth_required=registry_auth_required, + ) + + startup_script = " ".join([" && ".join(commands)]) + script_name = f"dstack-{instance_config.instance_name}.sh" + startup_script_ids = self.api_client.get_or_create_startup_scrpit( + name=script_name, script=startup_script + ) + + # Id of image "Ubuntu 22.04 + CUDA 12.0 + Docker" + # from API https://datacrunch.stoplight.io/docs/datacrunch-public/c46ab45dbc508-get-all-image-types + image_name = "2088da25-bb0d-41cc-a191-dccae45d96fd" + + disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) + instance = self.api_client.deploy_instance( + instance_type=instance_offer.instance.name, + ssh_key_ids=ssh_ids, + startup_script_id=startup_script_ids, + hostname=instance_config.instance_name, + description=instance_config.instance_name, + image=image_name, + disk_size=disk_size, + location=instance_offer.region, + ) + + running_instance = self.api_client.wait_for_instance(instance.id) + if running_instance is None: + raise BackendError(f"Wait instance {instance.id!r} timeout") + + launched_instance = LaunchedInstanceInfo( + instance_id=running_instance.id, + ip_address=running_instance.ip, + region=running_instance.location, + ssh_port=22, + username="root", + dockerized=True, + backend_data=None, + ) + + return launched_instance + def run_job( self, run: Run, diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index 8942f44e1..15b0e6765 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -8,6 +8,7 @@ import dstack._internal.core.backends.gcp.resources as gcp_resources from dstack._internal.core.backends.base.compute import ( Compute, + InstanceConfiguration, get_gateway_user_data, get_instance_name, get_user_data, @@ -26,6 +27,8 @@ Resources, ) from dstack._internal.core.models.runs import Job, Requirements, Run +from dstack._internal.server.models import ProjectModel, UserModel +from dstack._internal.utils.random_names import generate_name class GCPCompute(Compute): @@ -78,6 +81,77 @@ def terminate_instance( except google.api_core.exceptions.NotFound: pass + def create_instance( + self, + project: ProjectModel, + user: UserModel, + instance_offer: InstanceOfferWithAvailability, + instance_config: InstanceConfiguration, + ) -> LaunchedInstanceInfo: + + project_id = project.name + instance_name = instance_config.instance_name + + gcp_resources.create_runner_firewall_rules( + firewalls_client=self.firewalls_client, + project_id=self.config.project_id, + ) + disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) + + registry_auth_required = instance_config.job_docker_config.registry_auth is not None + for zone in _get_instance_zones(instance_offer): + request = compute_v1.InsertInstanceRequest() + request.zone = zone + request.project = self.config.project_id + request.instance_resource = 
gcp_resources.create_instance_struct( + disk_size=disk_size, + image_id=gcp_resources.get_image_id( + len(instance_offer.instance.resources.gpus) > 0, + ), + machine_type=instance_offer.instance.name, + accelerators=gcp_resources.get_accelerators( + project_id=self.config.project_id, + zone=zone, + gpus=instance_offer.instance.resources.gpus, + ), + spot=instance_offer.instance.resources.spot, + user_data=get_user_data( + backend=BackendType.GCP, + image_name=instance_config.job_docker_config.image.image, + authorized_keys=instance_config.get_public_keys(), + registry_auth_required=registry_auth_required, + ), + labels={ + "owner": "dstack", + "dstack_project": project_id, + "dstack_user": user.name, + }, + tags=[gcp_resources.DSTACK_INSTANCE_TAG], + instance_name=instance_name, + zone=zone, + service_account=self.config.service_account_email, + ) + try: + operation = self.instances_client.insert(request=request) + gcp_resources.wait_for_extended_operation(operation, "instance creation") + except ( + google.api_core.exceptions.ServiceUnavailable, + google.api_core.exceptions.NotFound, + ): + continue + instance = self.instances_client.get( + project=self.config.project_id, zone=zone, instance=instance_name + ) + return LaunchedInstanceInfo( + instance_id=instance_name, + region=zone, + ip_address=instance.network_interfaces[0].access_configs[0].nat_i_p, + username="ubuntu", + ssh_port=22, + dockerized=True, + ) + raise NoCapacityError() + def run_job( self, run: Run, diff --git a/src/dstack/_internal/core/services/ssh/tunnel.py b/src/dstack/_internal/core/services/ssh/tunnel.py index d9a8ded72..68f750836 100644 --- a/src/dstack/_internal/core/services/ssh/tunnel.py +++ b/src/dstack/_internal/core/services/ssh/tunnel.py @@ -141,13 +141,13 @@ def __init__( id_rsa_path: PathLike, control_sock_path: Optional[str] = None, ): - self.temp_dir = tempfile.TemporaryDirectory() if not control_sock_path else None + if control_sock_path is None: + self.temp_dir = tempfile.TemporaryDirectory() + control_sock_path = os.path.join(self.temp_dir.name, "control.sock") super().__init__( host=host, id_rsa_path=id_rsa_path, ports=ports, - control_sock_path=os.path.join(self.temp_dir.name, "control.sock") - if not control_sock_path - else control_sock_path, + control_sock_path=control_sock_path, options={}, ) diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 8f007fa61..4ca273c3c 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -3,6 +3,8 @@ from uuid import UUID import httpx +import requests +from pydantic import parse_raw_as from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -11,7 +13,14 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import RegistryAuth from dstack._internal.core.models.repos import RemoteRepoCreds -from dstack._internal.core.models.runs import Job, JobErrorCode, JobStatus, Run +from dstack._internal.core.models.runs import ( + Job, + JobErrorCode, + JobProvisioningData, + JobSpec, + JobStatus, + Run, +) from dstack._internal.server.db import get_session_ctx from dstack._internal.server.models import ( GatewayModel, @@ -308,6 +317,7 @@ def _process_provisioning_with_shim( Returns: is successful """ + job_spec = parse_raw_as(JobSpec, 
job_model.job_spec_data) shim_client = client.ShimClient(port=ports[client.REMOTE_SHIM_PORT]) resp = shim_client.healthcheck() if resp is None: @@ -319,6 +329,13 @@ def _process_provisioning_with_shim( shim_client.registry_auth( username=interpolate(registry_auth.username), password=interpolate(registry_auth.password), + image_name=job_spec.image_name, + ) + else: + shim_client.registry_auth( + username="", + password="", + image_name=job_spec.image_name, ) job_model.status = JobStatus.PULLING logger.info(*job_log("now is pulling", job_model)) diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 8fd711944..af79ada11 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -107,6 +107,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): job_model.status = JobStatus.PROVISIONING im = InstanceModel( + name=job.job_spec.job_name, project=project_model, pool=pool, job_provisioning_data=job_provisioning_data.json(), @@ -137,12 +138,13 @@ async def _run_job( if run.run_spec.profile.backends is not None: backends = [b for b in backends if b.TYPE in run.run_spec.profile.backends] try: + requirements = job.job_spec.requirements offers = await backends_services.get_instance_offers( - backends, job, exclude_not_available=True + backends, requirements, exclude_not_available=True ) except BackendError as e: logger.warning(*job_log("failed to get instance offers: %s", job_model, repr(e))) - return None + return None # or (None, None)? for backend, offer in offers: logger.debug( *job_log( diff --git a/src/dstack/_internal/server/migrations/versions/2943402e3b56_add_pools.py b/src/dstack/_internal/server/migrations/versions/2943402e3b56_add_pools.py deleted file mode 100644 index 554829e55..000000000 --- a/src/dstack/_internal/server/migrations/versions/2943402e3b56_add_pools.py +++ /dev/null @@ -1,49 +0,0 @@ -"""add pools - -Revision ID: 2943402e3b56 -Revises: e6391ca6a264 -Create Date: 2023-12-13 14:02:25.106604 - -""" -import sqlalchemy as sa -import sqlalchemy_utils -from alembic import op - -# revision identifiers, used by Alembic. -revision = "2943402e3b56" -down_revision = "e6391ca6a264" -branch_labels = None -depends_on = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table("projects", schema=None) as batch_op: - batch_op.add_column( - sa.Column( - "default_pool_id", - sqlalchemy_utils.types.uuid.UUIDType(binary=False), - nullable=True, - ) - ) - batch_op.create_foreign_key( - batch_op.f("fk_projects_default_pool_id_pools"), - "pools", - ["default_pool_id"], - ["id"], - ondelete="SET NULL", - use_alter=True, - ) - - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - with op.batch_alter_table("projects", schema=None) as batch_op: - batch_op.drop_constraint( - batch_op.f("fk_projects_default_pool_id_pools"), type_="foreignkey" - ) - batch_op.drop_column("default_pool_id") - - # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/migrations/versions/beceb9d2895d_add_pool.py b/src/dstack/_internal/server/migrations/versions/beceb9d2895d_add_pool.py new file mode 100644 index 000000000..713041d63 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/beceb9d2895d_add_pool.py @@ -0,0 +1,89 @@ +"""add pool + +Revision ID: beceb9d2895d +Revises: 48ad3ecbaea2 +Create Date: 2023-12-25 07:11:15.778338 + +""" +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +# revision identifiers, used by Alembic. +revision = "beceb9d2895d" +down_revision = "48ad3ecbaea2" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "pools", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("name", sa.String(length=50), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column( + "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + name=op.f("fk_pools_project_id_projects"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_pools")), + sa.UniqueConstraint("name", name=op.f("uq_pools_name")), + ) + op.create_table( + "instances", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column( + "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False + ), + sa.Column("pool_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True), + sa.Column("job_provisioning_data", sa.String(length=4000), nullable=False), + sa.Column("offer", sa.String(length=4000), nullable=False), + sa.ForeignKeyConstraint( + ["pool_id"], ["pools.id"], name=op.f("fk_instances_pool_id_pools") + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + name=op.f("fk_instances_project_id_projects"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_instances")), + ) + with op.batch_alter_table("projects", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "default_pool_id", + sqlalchemy_utils.types.uuid.UUIDType(binary=False), + nullable=True, + ) + ) + batch_op.create_foreign_key( + batch_op.f("fk_projects_default_pool_id_pools"), + "pools", + ["default_pool_id"], + ["id"], + ondelete="SET NULL", + use_alter=True, + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("projects", schema=None) as batch_op: + batch_op.drop_constraint( + batch_op.f("fk_projects_default_pool_id_pools"), type_="foreignkey" + ) + batch_op.drop_column("default_pool_id") + + op.drop_table("instances") + op.drop_table("pools") + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index a0be7ca22..eb405b96a 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -261,6 +261,7 @@ class InstanceModel(BaseModel): UUIDType(binary=False), primary_key=True, default=uuid.uuid4 ) created_at: Mapped[datetime] = mapped_column(DateTime, default=get_current_datetime) + name: Mapped[str] = mapped_column(String(50)) project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id]) diff --git a/src/dstack/_internal/server/routers/pool.py b/src/dstack/_internal/server/routers/pool.py index 6808d7e71..0548a1d91 100644 --- a/src/dstack/_internal/server/routers/pool.py +++ b/src/dstack/_internal/server/routers/pool.py @@ -39,7 +39,7 @@ async def create_pool( user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), ): _, project = user_project - await pool.create_pool_model(name=body.name, session=session, project=project) + await pool.create_pool_model(session=session, project=project, name=body.name) @router.post("/show") @@ -49,4 +49,4 @@ async def how_pool( user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), ): _, project = user_project - return await pool.show_pool(pool_name=body.name, session=session, project=project) + return await pool.show_pool(session, project, pool_name=body.name) diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index 7bab036c4..a50b90be1 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -1,13 +1,16 @@ -from typing import List, Optional, Tuple +from typing import List, Tuple from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.core.models.runs import Run, RunPlan +from dstack._internal.core.models.instances import InstanceOfferWithAvailability +from dstack._internal.core.models.runs import Requirements, Run, RunPlan from dstack._internal.server.db import get_session from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.runs import ( + CreateInstanceRequest, DeleteRunsRequest, + GetOffersRequest, GetRunPlanRequest, GetRunRequest, ListRunsRequest, @@ -16,6 +19,7 @@ ) from dstack._internal.server.security.permissions import Authenticated, ProjectMember from dstack._internal.server.services import runs +from dstack._internal.server.services.pool import generate_instance_name from dstack._internal.server.utils.routers import error_not_found root_router = APIRouter( @@ -59,6 +63,39 @@ async def get_run( return run +@project_router.post("/get_offers") +async def get_offers( + body: GetOffersRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), +) -> Tuple[Requirements, List[InstanceOfferWithAvailability]]: + _, project = user_project + reqs, offers = await runs.get_run_plan_by_requirements( + project=project, + profile=body.profile, + ) + return (reqs, [instance for _, instance in offers]) + + 
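The /get_offers route above and the /create_instance route below are designed as a pair: the CLI's `dstack pool add` first asks the server which offers satisfy a profile, prints them, and only then asks the server to provision one into the pool. A minimal sketch of driving the same flow through the Python API added in this patch; the project name, pool name, and profile values are illustrative placeholders, not taken from the patch:

    from dstack.api import Client
    from dstack._internal.core.models.profiles import Profile

    client = Client.from_config(project_name="main")  # assumes a configured dstack project
    profile = Profile(name="default", pool_name="default-pool")

    # POST /api/project/main/runs/get_offers
    # -> (Requirements, [InstanceOfferWithAvailability, ...]) sorted by price
    requirements, offers = client.runs.get_offers(profile)

    # POST /api/project/main/runs/create_instance
    # -> the server walks the offers and launches the first one it can into the pool
    client.runs.create_instance("default-pool", profile)
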
+@project_router.post("/create_instance") +async def create_instance( + body: CreateInstanceRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), +): + user, project = user_project + instance_name = await generate_instance_name( + session=session, project=project, pool_name=body.pool_name + ) + await runs.create_instance( + project=project, + user=user, + pool_name=body.pool_name, + instance_name=instance_name, + profile=body.profile, + ) + + @project_router.post("/get_plan") async def get_run_plan( body: GetRunPlanRequest, diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index 58b87c053..1c457e1ea 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -68,3 +68,4 @@ class HealthcheckResponse(BaseModel): class RegistryAuthBody(BaseModel): username: str password: str + image_name: str diff --git a/src/dstack/_internal/server/schemas/runs.py b/src/dstack/_internal/server/schemas/runs.py index 38064a1db..8c03d7bda 100644 --- a/src/dstack/_internal/server/schemas/runs.py +++ b/src/dstack/_internal/server/schemas/runs.py @@ -2,6 +2,7 @@ from pydantic import BaseModel +from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.runs import RunSpec @@ -18,6 +19,15 @@ class GetRunPlanRequest(BaseModel): run_spec: RunSpec +class GetOffersRequest(BaseModel): + profile: Profile + + +class CreateInstanceRequest(BaseModel): + pool_name: str + profile: Profile + + class SubmitRunRequest(BaseModel): run_spec: RunSpec diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py index 0d1fd0e26..8d6ebde39 100644 --- a/src/dstack/_internal/server/services/backends/__init__.py +++ b/src/dstack/_internal/server/services/backends/__init__.py @@ -24,7 +24,7 @@ InstanceAvailability, InstanceOfferWithAvailability, ) -from dstack._internal.core.models.runs import Job +from dstack._internal.core.models.runs import Job, Requirements from dstack._internal.server.models import BackendModel, ProjectModel from dstack._internal.server.services.backends.configurators.base import Configurator from dstack._internal.server.settings import LOCAL_BACKEND_ENABLED @@ -297,14 +297,12 @@ async def get_project_backend_model_by_type( async def get_instance_offers( - backends: List[Backend], job: Job, exclude_not_available: bool = False + backends: List[Backend], requirements: Requirements, exclude_not_available: bool = False ) -> List[Tuple[Backend, InstanceOfferWithAvailability]]: """ Returns list of instances satisfying minimal resource requirements sorted by price """ - tasks = [ - run_async(backend.compute().get_offers, job.job_spec.requirements) for backend in backends - ] + tasks = [run_async(backend.compute().get_offers, requirements) for backend in backends] offers_by_backend = [ [ (backend, offer) diff --git a/src/dstack/_internal/server/services/docker.py b/src/dstack/_internal/server/services/docker.py index db2687391..e16349df2 100644 --- a/src/dstack/_internal/server/services/docker.py +++ b/src/dstack/_internal/server/services/docker.py @@ -1,8 +1,8 @@ -from dataclasses import dataclass from enum import Enum from typing import Optional import requests +from pydantic import BaseModel manifests_media_types = [ "application/vnd.oci.image.index.v1+json", @@ -12,8 +12,11 @@ ] -@dataclass(frozen=True) -class DockerImage: +class 
DockerImage(BaseModel): + class Config: + frozen = True + + image: str registry: Optional[str] repo: str tag: str @@ -115,7 +118,7 @@ def parse_image_name(image: str) -> DockerImage: registry = components[0] repo = "/".join(components[1:]) - return DockerImage(registry, repo, tag, digest) + return DockerImage(image=image, registry=registry, repo=repo, tag=tag, digest=digest) def is_host(s: str) -> bool: diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index e5dd34e01..beb5f9d5c 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -22,6 +22,16 @@ from dstack._internal.core.services.ssh.ports import filter_reserved_ports +def get_default_python_verison() -> str: + version_info = sys.version_info + return PythonVersion(f"{version_info.major}.{version_info.minor}").value + + +def get_default_image(python_version: str) -> str: + # TODO: non-cuda image + return f"dstackai/base:py{python_version}-{version.base_image}-cuda-12.1" + + class JobConfigurator(ABC): TYPE: ConfigurationType @@ -114,8 +124,7 @@ def _home_dir(self) -> Optional[str]: def _image_name(self) -> str: if self.run_spec.configuration.image is not None: return self.run_spec.configuration.image - # TODO: non-cuda image - return f"dstackai/base:py{self._python()}-{version.base_image}-cuda-12.1" + return get_default_image(self._python()) def _max_duration(self) -> Optional[int]: if self.run_spec.profile.max_duration is None: @@ -141,8 +150,7 @@ def _working_dir(self) -> str: def _python(self) -> str: if self.run_spec.configuration.python is not None: return self.run_spec.configuration.python.value - version_info = sys.version_info - return PythonVersion(f"{version_info.major}.{version_info.minor}").value + return get_default_python_verison() def _pool_name(self): return self.run_spec.profile.pool_name diff --git a/src/dstack/_internal/server/services/pool.py b/src/dstack/_internal/server/services/pool.py index 88299e0e2..85bb22a40 100644 --- a/src/dstack/_internal/server/services/pool.py +++ b/src/dstack/_internal/server/services/pool.py @@ -1,12 +1,18 @@ +import asyncio from datetime import timezone from typing import List, Optional, Sequence +from pydantic import parse_raw_as from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload -from dstack._internal.core.models.pool import Pool +from dstack._internal.core.models.instances import InstanceOfferWithAvailability +from dstack._internal.core.models.pool import Instance, Pool from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME +from dstack._internal.core.models.runs import JobProvisioningData from dstack._internal.server.models import InstanceModel, PoolModel, ProjectModel +from dstack._internal.utils import random_names from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -15,7 +21,7 @@ async def list_project_pool(session: AsyncSession, project: ProjectModel) -> List[Pool]: pools = list(await list_project_pool_models(session=session, project=project)) if not pools: - pool = await create_pool_model(DEFAULT_POOL_NAME, session, project) + pool = await create_pool_model(session, project, DEFAULT_POOL_NAME) pools.append(pool) return [pool_model_to_pool(p) for p in pools] @@ -28,7 +34,7 @@ def pool_model_to_pool(pool_model: PoolModel) -> Pool: ) -async def create_pool_model(name: str, session: 
AsyncSession, project: ProjectModel) -> PoolModel: +async def create_pool_model(session: AsyncSession, project: ProjectModel, name: str) -> PoolModel: pool = PoolModel( name=name, project_id=project.id, @@ -44,7 +50,7 @@ async def list_project_pool_models( session: AsyncSession, project: ProjectModel ) -> Sequence[PoolModel]: pools = await session.execute(select(PoolModel).where(PoolModel.project_id == project.id)) - return pools.scalars().all() + return pools.unique().scalars().all() async def delete_pool(session: AsyncSession, project: ProjectModel, pool_name: str): @@ -66,16 +72,64 @@ async def delete_pool(session: AsyncSession, project: ProjectModel, pool_name: s if default_pool is not None: project.default_pool = default_pool else: - await create_pool_model(DEFAULT_POOL_NAME, session, project) + await create_pool_model(session, project, DEFAULT_POOL_NAME) await session.commit() +def instance_model_to_instance(instance_model: InstanceModel) -> Instance: + offer: InstanceOfferWithAvailability = parse_raw_as( + InstanceOfferWithAvailability, instance_model.offer + ) + jpd: JobProvisioningData = parse_raw_as( + JobProvisioningData, instance_model.job_provisioning_data + ) + + instance = Instance( + backend=offer.backend, + instance_id=jpd.instance_id, + instance_type=jpd.instance_type, + hostname=jpd.hostname, + price=offer.price, + ) + return instance + + async def show_pool( - pool_name: str, session: AsyncSession, project: ProjectModel -) -> Sequence[InstanceModel]: + session: AsyncSession, project: ProjectModel, pool_name: str +) -> Sequence[Instance]: pools_result = await session.execute(select(PoolModel).where(PoolModel.name == pool_name)) pools = pools_result.scalars().all() - instances = pools[0].instances + instances = [instance_model_to_instance(i) for i in pools[0].instances] return instances + + +async def get_pool_instances(session: AsyncSession, pool_name: str) -> List[InstanceModel]: + res = await session.execute( + select(PoolModel) + .where(PoolModel.name == pool_name) + .options(joinedload(PoolModel.instances)) + ) + result = res.unique().scalars().one_or_none() + if result is None: + return [] + return result.instances + + +_GENERATE_POOL_NAME_LOCK = {} + + +async def generate_instance_name( + session: AsyncSession, + project: ProjectModel, + pool_name: str, +) -> str: + lock = _GENERATE_POOL_NAME_LOCK.setdefault(project.name, asyncio.Lock()) + async with lock: + pool_instances: List[InstanceModel] = await get_pool_instances(session, pool_name) + names = {g.name for g in pool_instances} + while True: + name = f"{random_names.generate_name()}" + if name not in names: + return name diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index 7520c67e9..992f65a3f 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -99,10 +99,12 @@ def healthcheck(self) -> Optional[HealthcheckResponse]: except requests.exceptions.RequestException: return None - def registry_auth(self, username: str, password: str): + def registry_auth(self, username: str, password: str, image_name: str): resp = requests.post( self._url("/api/registry_auth"), - json=RegistryAuthBody(username=username, password=password).dict(), + json=RegistryAuthBody( + username=username, password=password, image_name=image_name + ).dict(), ) resp.raise_for_status() diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py 
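generate_instance_name above serializes name generation per project: each project lazily gets its own asyncio.Lock from a module-level dict, and the loop keeps drawing random names until it finds one not already used by the pool's instances. The same idiom in isolation, with an in-memory set standing in for the pool query (the helper name and the name format here are invented for the sketch):

    import asyncio
    import random
    from typing import Dict, Set

    _LOCKS: Dict[str, asyncio.Lock] = {}   # one lock per project, created lazily
    _TAKEN: Set[str] = set()               # stand-in for "names already in the pool"

    async def reserve_name(project: str) -> str:
        lock = _LOCKS.setdefault(project, asyncio.Lock())
        async with lock:                   # concurrent requests in the same project line up here
            while True:
                name = f"instance-{random.randrange(10_000)}"
                if name not in _TAKEN:     # retry on collision instead of failing
                    _TAKEN.add(name)
                    return name

    # asyncio.run(reserve_name("my-project")) -> e.g. "instance-4242"
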
index ba7a486ce..62d3a879d 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -10,6 +10,7 @@ from dstack._internal.core.models.runs import JobProvisioningData from dstack._internal.core.services.ssh.tunnel import RunnerTunnel from dstack._internal.server.services.jobs import get_runner_ports +from dstack._internal.server.settings import LOCAL_BACKEND_ENABLED from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -33,6 +34,11 @@ def wrapper( Returns: is successful """ + + if LOCAL_BACKEND_ENABLED: + port_map = {p: p for p in ports} + return func(*args, ports=port_map, **kwargs) + func_kwargs_names = [ p.name for p in inspect.signature(func).parameters.values() diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index c96240208..bbd44abea 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -2,7 +2,7 @@ import itertools import uuid from datetime import timezone -from typing import List, Optional +from typing import List, Optional, Tuple import pydantic from sqlalchemy import select, update @@ -11,13 +11,27 @@ import dstack._internal.server.services.gateways as gateways import dstack._internal.utils.common as common_utils -from dstack._internal.core.errors import RepoDoesNotExistError, ServerClientError +from dstack._internal.core.backends.base import Backend +from dstack._internal.core.backends.base.compute import ( + DockerConfig, + InstanceConfiguration, + SSHKeys, +) +from dstack._internal.core.errors import BackendError, RepoDoesNotExistError, ServerClientError +from dstack._internal.core.models.instances import ( + InstanceOfferWithAvailability, + LaunchedInstanceInfo, +) +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile, SpotPolicy from dstack._internal.core.models.runs import ( + GpusRequirements, Job, JobPlan, + JobProvisioningData, JobSpec, JobStatus, JobSubmission, + Requirements, Run, RunPlan, RunSpec, @@ -28,13 +42,19 @@ from dstack._internal.server.models import JobModel, PoolModel, ProjectModel, RunModel, UserModel from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services import repos as repos_services +from dstack._internal.server.services.docker import parse_image_name from dstack._internal.server.services.jobs import ( get_jobs_from_run_spec, job_model_to_job_submission, stop_job, ) +from dstack._internal.server.services.jobs.configurators.base import ( + get_default_image, + get_default_python_verison, +) from dstack._internal.server.services.pool import create_pool_model from dstack._internal.server.services.projects import list_project_models, list_user_project_models +from dstack._internal.server.utils.common import run_async from dstack._internal.utils.logging import get_logger from dstack._internal.utils.random_names import generate_name @@ -119,6 +139,103 @@ async def get_run( return run_model_to_run(run_model) +async def get_run_plan_by_requirements( + project: ProjectModel, profile: Profile +) -> Tuple[Requirements, List[Tuple[Backend, InstanceOfferWithAvailability]]]: + backends = await backends_services.get_project_backends(project=project) + if profile.backends is not None: + backends = [b for b in backends if b.TYPE in profile.backends] + + spot_policy = profile.spot_policy or SpotPolicy.AUTO # TODO: improve + requirements = Requirements( + cpus=profile.resources.cpu, 
+ memory_mib=profile.resources.memory, + gpus=None, + shm_size_mib=profile.resources.shm_size, + max_price=profile.max_price, + spot=None if spot_policy == SpotPolicy.AUTO else (spot_policy == SpotPolicy.SPOT), + ) + if profile.resources.gpu: + requirements.gpus = GpusRequirements( + count=profile.resources.gpu.count, + memory_mib=profile.resources.gpu.memory, + name=profile.resources.gpu.name, + total_memory_mib=profile.resources.gpu.total_memory, + compute_capability=profile.resources.gpu.compute_capability, + ) + + offers = await backends_services.get_instance_offers( + backends=backends, + requirements=requirements, + exclude_not_available=False, + ) + + return requirements, offers + + +async def create_instance( + project: ProjectModel, user: UserModel, pool_name: str, instance_name: str, profile: Profile +): + _, offers = await get_run_plan_by_requirements(project, profile) + + ssh_key = SSHKeys( + public=project.ssh_public_key.strip(), + private=project.ssh_private_key.strip(), + ) + instance_config = InstanceConfiguration( + instance_name=instance_name, + pool_name=pool_name, + ssh_keys=[ssh_key], + job_docker_config=DockerConfig( + image=parse_image_name(get_default_image(get_default_python_verison())), + registry_auth=None, + ), + ) + + for backend, instance_offer in offers: + + logger.debug( + "trying %s in %s/%s for $%0.4f per hour", + instance_offer.instance.name, + instance_offer.backend.value, + instance_offer.region, + instance_offer.price, + ) + try: + launched_instance_info: LaunchedInstanceInfo = await run_async( + backend.compute().create_instance, + project, + user, + instance_offer, + instance_config, + ) + except BackendError as e: + logger.warning( + "%s launch in %s/%s failed: %s", + instance_offer.instance.name, + instance_offer.backend.value, + instance_offer.region, + repr(e), + ) + continue + else: + job_provisioning_data = JobProvisioningData( + backend=backend.TYPE, + instance_type=instance_offer.instance, + instance_id=launched_instance_info.instance_id, + hostname=launched_instance_info.ip_address, + region=launched_instance_info.region, + price=instance_offer.price, + username=launched_instance_info.username, + ssh_port=launched_instance_info.ssh_port, + dockerized=launched_instance_info.dockerized, + backend_data=launched_instance_info.backend_data, + ) + + return (job_provisioning_data, instance_offer) + return (None, None) + + async def get_run_plan( session: AsyncSession, project: ProjectModel, @@ -134,9 +251,10 @@ async def get_run_plan( job_plans = [] for job in jobs: # TODO: use the job.pool_name to select an offer + requirements = job.job_spec.requirements offers = await backends_services.get_instance_offers( backends=backends, - job=job, + requirements=requirements, exclude_not_available=False, ) for backend, offer in offers: @@ -180,11 +298,13 @@ async def submit_run( else: await delete_runs(session=session, project=project, runs_names=[run_spec.run_name]) - pool_name = run_spec.profile.pool_name + pool_name = ( + DEFAULT_POOL_NAME if run_spec.profile.pool_name is None else run_spec.profile.pool_name + ) pools_result = await session.execute(select(PoolModel).where(PoolModel.name == pool_name)) pools = pools_result.scalars().all() if not pools: - await create_pool_model(name=pool_name, session=session, project=project) + await create_pool_model(session=session, project=project, name=pool_name) run_model = RunModel( id=uuid.uuid4(), diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 6adc7680a..322966087 100644 --- 
a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -7,7 +7,7 @@ from copy import copy from datetime import datetime from pathlib import Path -from typing import Dict, Iterable, List, Optional, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union import requests from websocket import WebSocketApp @@ -16,11 +16,13 @@ from dstack._internal.core.errors import ConfigurationError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import AnyRunConfiguration +from dstack._internal.core.models.instances import InstanceOfferWithAvailability from dstack._internal.core.models.profiles import Profile, ProfileRetryPolicy, SpotPolicy from dstack._internal.core.models.repos.base import Repo from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import JobSpec from dstack._internal.core.models.runs import JobStatus as RunStatus +from dstack._internal.core.models.runs import Requirements from dstack._internal.core.models.runs import Run as RunModel from dstack._internal.core.models.runs import RunPlan, RunSpec from dstack._internal.core.services.logs import URLReplacer @@ -357,6 +359,14 @@ def submit( ) return self.exec_plan(run_plan, repo, reserve_ports=reserve_ports) + def get_offers( + self, profile: Profile + ) -> Tuple[Requirements, List[InstanceOfferWithAvailability]]: + return self._api_client.runs.get_offers(self._project, profile) + + def create_instance(self, pool_name: str, profile: Profile): + self._api_client.runs.create_instance(self._project, pool_name, profile) + def get_plan( self, configuration: AnyRunConfiguration, diff --git a/src/dstack/api/server/_backends.py b/src/dstack/api/server/_backends.py index 85de8e7d9..2d2136abb 100644 --- a/src/dstack/api/server/_backends.py +++ b/src/dstack/api/server/_backends.py @@ -21,11 +21,15 @@ def config_values(self, config: AnyConfigInfoWithCredsPartial) -> AnyConfigValue resp = self._request("/api/backends/config_values", body=config.json()) return parse_obj_as(AnyConfigValues, resp.json()) - def create(self, project_name: str, config: AnyConfigInfoWithCreds) -> AnyConfigInfoWithCreds: + def create( + self, project_name: str, config: AnyConfigInfoWithCreds + ) -> AnyConfigInfoWithCredsPartial: resp = self._request(f"/api/project/{project_name}/backends/create", body=config.json()) return parse_obj_as(AnyConfigInfoWithCredsPartial, resp.json()) - def update(self, project_name: str, config: AnyConfigInfoWithCreds) -> AnyConfigInfoWithCreds: + def update( + self, project_name: str, config: AnyConfigInfoWithCreds + ) -> AnyConfigInfoWithCredsPartial: resp = self._request(f"/api/project/{project_name}/backends/update", body=config.json()) return parse_obj_as(AnyConfigInfoWithCredsPartial, resp.json()) diff --git a/src/dstack/api/server/_pool.py b/src/dstack/api/server/_pool.py index ea1b4ccf0..4b3ae9b03 100644 --- a/src/dstack/api/server/_pool.py +++ b/src/dstack/api/server/_pool.py @@ -3,7 +3,7 @@ from pydantic import parse_obj_as import dstack._internal.server.schemas.pool as pool_schemas -from dstack._internal.core.models.pool import Pool +from dstack._internal.core.models.pool import Instance, Pool from dstack.api.server._group import APIClientGroup @@ -20,6 +20,7 @@ def create(self, project_name: str, pool_name: str) -> None: body = pool_schemas.CreatePoolRequest(name=pool_name) self._request(f"/api/project/{project_name}/pool/create", body=body.json()) - def show(self, project_name: str, 
pool_name: str) -> None: + def show(self, project_name: str, pool_name: str) -> List[Instance]: body = pool_schemas.ShowPoolRequest(name=pool_name) - self._request(f"/api/project/{project_name}/pool/show", body=body.json()) + resp = self._request(f"/api/project/{project_name}/pool/show", body=body.json()) + return parse_obj_as(List[Instance], resp.json()) diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index 107685a83..202fc09b0 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -1,10 +1,14 @@ -from typing import List, Optional +from typing import List, Optional, Tuple from pydantic import parse_obj_as -from dstack._internal.core.models.runs import Run, RunPlan, RunSpec +from dstack._internal.core.models.instances import InstanceOfferWithAvailability +from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.runs import Requirements, Run, RunPlan, RunSpec from dstack._internal.server.schemas.runs import ( + CreateInstanceRequest, DeleteRunsRequest, + GetOffersRequest, GetRunPlanRequest, GetRunRequest, ListRunsRequest, @@ -25,6 +29,17 @@ def get(self, project_name: str, run_name: str) -> Run: resp = self._request(f"/api/project/{project_name}/runs/get", body=body.json()) return parse_obj_as(Run, resp.json()) + def get_offers( + self, project_name: str, profile: Profile + ) -> Tuple[Requirements, List[InstanceOfferWithAvailability]]: + body = GetOffersRequest(profile=profile) + resp = self._request(f"/api/project/{project_name}/runs/get_offers", body=body.json()) + return parse_obj_as(Tuple[Requirements, List[InstanceOfferWithAvailability]], resp.json()) + + def create_instance(self, project_name: str, pool_name: str, profile: Profile): + body = CreateInstanceRequest(pool_name=pool_name, profile=profile) + self._request(f"/api/project/{project_name}/runs/create_instance", body=body.json()) + def get_plan(self, project_name: str, run_spec: RunSpec) -> RunPlan: body = GetRunPlanRequest(run_spec=run_spec) resp = self._request(f"/api/project/{project_name}/runs/get_plan", body=body.json()) diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py index 29bd6a272..af825e7b7 100644 --- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py @@ -214,7 +214,9 @@ async def test_provisioning_shim(self, test_db, session: AsyncSession): await process_running_jobs() RunnerTunnelMock.assert_called_once() ShimClientMock.return_value.healthcheck.assert_called_once() - ShimClientMock.return_value.registry_auth.assert_not_called() + ShimClientMock.return_value.registry_auth.assert_called_once_with( + username="", password="", image_name="dstackai/base:py3.11-0.4rc4-cuda-12.1" + ) await session.refresh(job) assert job is not None assert job.status == JobStatus.PULLING diff --git a/src/tests/_internal/server/services/test_pool.py b/src/tests/_internal/server/services/test_pool.py new file mode 100644 index 000000000..5a23c85ee --- /dev/null +++ b/src/tests/_internal/server/services/test_pool.py @@ -0,0 +1,181 @@ +import datetime as dt +import uuid + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import InstanceType, Resources +from dstack._internal.core.models.pool import Instance, Pool 
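The get_offers client method above relies on pydantic v1's parse_obj_as being able to rebuild a typed tuple straight from the decoded JSON body. A toy round trip of that trick with stand-in models (Req and Offer are invented here; the real code parses Requirements and InstanceOfferWithAvailability):

    from typing import List, Tuple
    from pydantic import BaseModel, parse_obj_as

    class Req(BaseModel):
        cpus: int

    class Offer(BaseModel):
        price: float

    # roughly what a /get_offers response body looks like after resp.json()
    payload = [{"cpus": 4}, [{"price": 1.1}, {"price": 2.3}]]

    reqs, offers = parse_obj_as(Tuple[Req, List[Offer]], payload)
    assert reqs.cpus == 4
    assert offers[1].price == 2.3
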
+from dstack._internal.server.models import InstanceModel, PoolModel +from dstack._internal.server.services import pool as services_pool +from dstack._internal.server.testing.common import create_project, create_user + + +@pytest.mark.asyncio +async def test_pool(session: AsyncSession, test_db): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + pool = await services_pool.create_pool_model( + session=session, project=project, name="test_pool" + ) + im = InstanceModel( + name="test_instnce", + project=project, + pool=pool, + job_provisioning_data="", + offer="", + ) + session.add(im) + await session.commit() + + core_model_pool = services_pool.pool_model_to_pool(pool) + assert core_model_pool == Pool(name="test_pool", default=True, created_at=pool.created_at) + + list_pools = await services_pool.list_project_pool(session=session, project=project) + assert list_pools == [services_pool.pool_model_to_pool(pool)] + + list_pool_models = await services_pool.list_project_pool_models( + session=session, project=project + ) + assert len(list_pool_models) == 1 + + pool_intances = await services_pool.get_pool_instances(session=session, pool_name="test_pool") + assert pool_intances == [im] + + +def test_convert_instance(): + expected_instance = Instance( + backend=BackendType.LOCAL, + instance_type=InstanceType( + name="instance", resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[]) + ), + instance_id="test_instance", + hostname="hostname_test", + price=1.0, + ) + + im = InstanceModel( + id=str(uuid.uuid4()), + created_at=dt.datetime.now(), + name="test_instance", + project_id=str(uuid.uuid4()), + pool=None, + job_provisioning_data='{"backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + offer='{"price":"LOCAL", "price":1.0, "backend":"local", "region":"eu-west-1", "availability":"available","instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + ) + + instance = services_pool.instance_model_to_instance(im) + assert instance == expected_instance + + +@pytest.mark.asyncio +async def test_delete_pool(session: AsyncSession, test_db): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + pool = await services_pool.create_pool_model( + session=session, project=project, name="test_pool" + ) + im = InstanceModel( + name="test_instnce", + project=project, + pool=pool, + job_provisioning_data="", + offer="", + ) + session.add(im) + await session.commit() + + await services_pool.delete_pool(session=session, project=project, pool_name="test_pool") + + +# async def delete_pool(session: AsyncSession, project: ProjectModel, pool_name: str): +# """delete the pool and set the default pool to project""" + +# default_pool: Optional[PoolModel] = None +# default_pool_removed = False + +# for pool in await list_project_pool_models(session=session, project=project): +# if pool.name == DEFAULT_POOL_NAME: +# default_pool = pool + +# if pool_name == pool.name: +# if project.default_pool_id == pool.id: +# default_pool_removed = True +# await session.delete(pool) + +# if default_pool_removed: +# if default_pool is not None: +# 
project.default_pool = default_pool +# else: +# await create_pool_model(session, project, DEFAULT_POOL_NAME) + +# await session.commit() + + +@pytest.mark.asyncio +async def test_show_pool(session: AsyncSession, test_db): + POOL_NAME = "test_pool" + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + pool = await services_pool.create_pool_model(session=session, project=project, name=POOL_NAME) + im = InstanceModel( + name="test_instnce", + project=project, + pool=pool, + job_provisioning_data='{"backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + offer='{"price":"LOCAL", "price":1.0, "backend":"local", "region":"eu-west-1", "availability":"available","instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + ) + session.add(im) + await session.commit() + + instances = await services_pool.show_pool(session, project, POOL_NAME) + assert len(instances) == 1 + + +@pytest.mark.asyncio +async def test_get_pool_instances(session: AsyncSession, test_db): + POOL_NAME = "test_pool" + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + pool = await services_pool.create_pool_model(session=session, project=project, name=POOL_NAME) + im = InstanceModel( + name="test_instnce", + project=project, + pool=pool, + job_provisioning_data='{"backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + offer='{"price":"LOCAL", "price":1.0, "backend":"local", "region":"eu-west-1", "availability":"available","instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + ) + session.add(im) + await session.commit() + + instances = await services_pool.get_pool_instances(session, POOL_NAME) + assert len(instances) == 1 + + empty_instances = await services_pool.get_pool_instances(session, f"{POOL_NAME}-0") + assert len(empty_instances) == 0 + + +@pytest.mark.asyncio +async def test_generate_instance_name(session: AsyncSession, test_db): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + pool = await services_pool.create_pool_model( + session=session, project=project, name="test_pool" + ) + im = InstanceModel( + name="test_instnce", + project=project, + pool=pool, + job_provisioning_data="", + offer="", + ) + session.add(im) + await session.commit() + + name = await services_pool.generate_instance_name( + session=session, project=project, pool_name="test_pool" + ) + car, _, cdr = name.partition("-") + assert len(car) > 0 + assert len(cdr) > 0 From 74c2c22ccc3737b2e1990c5e34d77d65c5980567 Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Wed, 27 Dec 2023 11:31:17 +0300 Subject: [PATCH 03/47] merge pool logic --- .pre-commit-config.yaml | 1 - src/dstack/_internal/cli/commands/pool.py | 17 +++- .../core/models/{pool.py => pools.py} | 2 + 
src/dstack/_internal/core/models/runs.py | 23 ++++- src/dstack/_internal/server/app.py | 4 +- .../tasks/process_submitted_jobs.py | 6 +- ..._add_pool.py => a859c41ae3b7_add_pools.py} | 33 ++++++- src/dstack/_internal/server/models.py | 13 ++- .../server/routers/{pool.py => pools.py} | 43 +++++++-- src/dstack/_internal/server/routers/runs.py | 2 +- .../server/schemas/{pool.py => pools.py} | 1 + .../server/services/{pool.py => pools.py} | 37 ++++++-- .../_internal/server/services/projects.py | 10 +- src/dstack/_internal/server/services/runs.py | 30 ++++-- src/dstack/api/_public/__init__.py | 2 +- src/dstack/api/_public/{pool.py => pools.py} | 0 src/dstack/api/server/__init__.py | 2 +- src/dstack/api/server/{_pool.py => _pools.py} | 12 +-- .../services/{test_pool.py => test_pools.py} | 91 ++++++++----------- 19 files changed, 224 insertions(+), 105 deletions(-) rename src/dstack/_internal/core/models/{pool.py => pools.py} (83%) rename src/dstack/_internal/server/migrations/versions/{beceb9d2895d_add_pool.py => a859c41ae3b7_add_pools.py} (73%) rename src/dstack/_internal/server/routers/{pool.py => pools.py} (51%) rename src/dstack/_internal/server/schemas/{pool.py => pools.py} (92%) rename src/dstack/_internal/server/services/{pool.py => pools.py} (79%) rename src/dstack/api/_public/{pool.py => pools.py} (100%) rename src/dstack/api/server/{_pool.py => _pools.py} (66%) rename src/tests/_internal/server/services/{test_pool.py => test_pools.py} (66%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 860da2c55..d0fb5afc5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,6 @@ repos: rev: 22.12.0 hooks: - id: black - language_version: python3.10 args: ['--config', 'pyconfig.toml'] - repo: https://github.com/pycqa/isort rev: 5.12.0 diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 856e95c9c..76fd30221 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -15,7 +15,7 @@ InstanceAvailability, InstanceOfferWithAvailability, ) -from dstack._internal.core.models.pool import Instance, Pool +from dstack._internal.core.models.pools import Instance, Pool from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile from dstack._internal.core.models.runs import Requirements from dstack._internal.core.services.configs import ConfigManager @@ -162,7 +162,10 @@ def _register(self): # list list_parser = subparsers.add_parser( - "list", help="List pools", formatter_class=self._parser.formatter_class + "list", + help="List pools", + description="List available pools", + formatter_class=self._parser.formatter_class, ) list_parser.add_argument("-v", "--verbose", help="Show more information") list_parser.set_defaults(subfunc=self._list) @@ -181,11 +184,17 @@ def _register(self): delete_parser.add_argument( "-n", "--name", dest="pool_name", help="The name of the pool", required=True ) + delete_parser.add_argument( + "-f", "--force", dest="force", help="Force remove", type=bool, default=False + ) delete_parser.set_defaults(subfunc=self._delete) # show show_parser = subparsers.add_parser( - "show", help="Show pool instances", formatter_class=self._parser.formatter_class + "show", + help="Show pool instances", + description="Show instances in the pool", + formatter_class=self._parser.formatter_class, ) show_parser.add_argument( "-n", "--name", dest="pool_name", help="The name of the pool", required=True @@ -213,7 +222,7 @@ def _create(self, args: 
argparse.Namespace): self.api.client.pool.create(self.api.project, args.pool_name) def _delete(self, args: argparse.Namespace): - self.api.client.pool.delete(self.api.project, args.pool_name) + self.api.client.pool.delete(self.api.project, args.pool_name, args.force) def _show(self, args: argparse.Namespace): instances = self.api.client.pool.show(self.api.project, args.pool_name) diff --git a/src/dstack/_internal/core/models/pool.py b/src/dstack/_internal/core/models/pools.py similarity index 83% rename from src/dstack/_internal/core/models/pool.py rename to src/dstack/_internal/core/models/pools.py index 315ca0171..80ecd3961 100644 --- a/src/dstack/_internal/core/models/pool.py +++ b/src/dstack/_internal/core/models/pools.py @@ -5,6 +5,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import InstanceType +from dstack._internal.core.models.runs import InstanceStatus class Pool(BaseModel): @@ -18,4 +19,5 @@ class Instance(BaseModel): instance_type: InstanceType instance_id: str hostname: str + status: InstanceStatus price: float diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index d1cd10052..d5edc3ef4 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta from enum import Enum -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Sequence, Tuple from pydantic import UUID4, BaseModel, Field from typing_extensions import Annotated, Literal @@ -217,3 +217,24 @@ class RunPlan(BaseModel): user: str run_spec: RunSpec job_plans: List[JobPlan] + + +class InstanceStatus(str, Enum): + PENDING = "pending" + CREATING = "creating" + STARTING = "starting" + READY = "ready" + BUSY = "busy" + TERMINATING = "terminating" + TERMINATED = "terminated" + FAILED = "failed" + + @property + def finished_statuses(cls) -> Sequence["InstanceStatus"]: + return (cls.TERMINATED, cls.FAILED) + + def is_finished(self): + return self in self.finished_statuses + + def is_started(self): + return not self.is_finished() diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index c589b1bd6..da741130d 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -17,7 +17,7 @@ backends, gateways, logs, - pool, + pools, projects, repos, runs, @@ -131,7 +131,7 @@ def add_no_api_version_check_routes(paths: List[str]): def register_routes(app: FastAPI): app.include_router(users.router) app.include_router(projects.router) - app.include_router(pool.router) + app.include_router(pools.router) app.include_router(backends.root_router) app.include_router(backends.project_router) app.include_router(repos.router) diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index af79ada11..ce51fa6c7 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -13,6 +13,7 @@ ) from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME from dstack._internal.core.models.runs import ( + InstanceStatus, Job, JobErrorCode, JobProvisioningData, @@ -101,7 +102,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): project_ssh_public_key=project_model.ssh_public_key, 
project_ssh_private_key=project_model.ssh_private_key, ) - if job_provisioning_data is not None: + if job_provisioning_data is not None and offer is not None: logger.info(*job_log("now is provisioning", job_model)) job_model.job_provisioning_data = job_provisioning_data.json() job_model.status = JobStatus.PROVISIONING @@ -110,6 +111,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): name=job.job_spec.job_name, project=project_model, pool=pool, + status=InstanceStatus.PENDING, job_provisioning_data=job_provisioning_data.json(), offer=offer.json(), ) @@ -144,7 +146,7 @@ async def _run_job( ) except BackendError as e: logger.warning(*job_log("failed to get instance offers: %s", job_model, repr(e))) - return None # or (None, None)? + return (None, None) for backend, offer in offers: logger.debug( *job_log( diff --git a/src/dstack/_internal/server/migrations/versions/beceb9d2895d_add_pool.py b/src/dstack/_internal/server/migrations/versions/a859c41ae3b7_add_pools.py similarity index 73% rename from src/dstack/_internal/server/migrations/versions/beceb9d2895d_add_pool.py rename to src/dstack/_internal/server/migrations/versions/a859c41ae3b7_add_pools.py index 713041d63..c8c86e30d 100644 --- a/src/dstack/_internal/server/migrations/versions/beceb9d2895d_add_pool.py +++ b/src/dstack/_internal/server/migrations/versions/a859c41ae3b7_add_pools.py @@ -1,8 +1,8 @@ -"""add pool +"""add pools -Revision ID: beceb9d2895d +Revision ID: a859c41ae3b7 Revises: 48ad3ecbaea2 -Create Date: 2023-12-25 07:11:15.778338 +Create Date: 2023-12-28 10:13:48.608439 """ import sqlalchemy as sa @@ -10,7 +10,7 @@ from alembic import op # revision identifiers, used by Alembic. -revision = "beceb9d2895d" +revision = "a859c41ae3b7" down_revision = "48ad3ecbaea2" branch_labels = None depends_on = None @@ -23,6 +23,8 @@ def upgrade() -> None: sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), sa.Column("name", sa.String(length=50), nullable=False), sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("deleted", sa.Boolean(), nullable=False), + sa.Column("deleted_at", sa.DateTime(), nullable=True), sa.Column( "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False ), @@ -38,11 +40,34 @@ def upgrade() -> None: op.create_table( "instances", sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("name", sa.String(length=50), nullable=False), sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("deleted", sa.Boolean(), nullable=False), + sa.Column("deleted_at", sa.DateTime(), nullable=True), sa.Column( "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False ), sa.Column("pool_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True), + sa.Column( + "status", + sa.Enum( + "PENDING", + "SUBMITTED", + "PROVISIONING", + "PULLING", + "RUNNING", + "TERMINATING", + "TERMINATED", + "ABORTED", + "FAILED", + "DONE", + name="jobstatus", + ), + nullable=False, + ), + sa.Column("status_message", sa.String(length=50), nullable=True), + sa.Column("started_at", sa.DateTime(), nullable=True), + sa.Column("finished_at", sa.DateTime(), nullable=True), sa.Column("job_provisioning_data", sa.String(length=4000), nullable=False), sa.Column("offer", sa.String(length=4000), nullable=False), sa.ForeignKeyConstraint( diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index eb405b96a..c25da2f1c 100644 --- 
a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -20,7 +20,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.repos.base import RepoType -from dstack._internal.core.models.runs import JobErrorCode, JobStatus +from dstack._internal.core.models.runs import InstanceStatus, JobErrorCode, JobStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.utils.common import get_current_datetime @@ -247,6 +247,8 @@ class PoolModel(BaseModel): ) name: Mapped[str] = mapped_column(String(50), unique=True) created_at: Mapped[datetime] = mapped_column(DateTime, default=get_current_datetime) + deleted: Mapped[bool] = mapped_column(Boolean, default=False) + deleted_at: Mapped[Optional[datetime]] = mapped_column(DateTime) project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id]) @@ -260,8 +262,10 @@ class InstanceModel(BaseModel): id: Mapped[uuid.UUID] = mapped_column( UUIDType(binary=False), primary_key=True, default=uuid.uuid4 ) - created_at: Mapped[datetime] = mapped_column(DateTime, default=get_current_datetime) name: Mapped[str] = mapped_column(String(50)) + created_at: Mapped[datetime] = mapped_column(DateTime, default=get_current_datetime) + deleted: Mapped[bool] = mapped_column(Boolean, default=False) + deleted_at: Mapped[Optional[datetime]] = mapped_column(DateTime) project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id]) @@ -269,5 +273,10 @@ class InstanceModel(BaseModel): pool_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("pools.id")) pool: Mapped["PoolModel"] = relationship(back_populates="instances") + status: Mapped[InstanceStatus] = mapped_column(Enum(JobStatus)) + status_message: Mapped[Optional[str]] = mapped_column(String(50)) + started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) + job_provisioning_data: Mapped[str] = mapped_column(String(4000)) offer: Mapped[str] = mapped_column(String(4000)) diff --git a/src/dstack/_internal/server/routers/pool.py b/src/dstack/_internal/server/routers/pools.py similarity index 51% rename from src/dstack/_internal/server/routers/pool.py rename to src/dstack/_internal/server/routers/pools.py index 0548a1d91..fe17ce386 100644 --- a/src/dstack/_internal/server/routers/pool.py +++ b/src/dstack/_internal/server/routers/pools.py @@ -3,12 +3,17 @@ from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession -import dstack._internal.core.models.pool as models -import dstack._internal.server.schemas.pool as schemas -import dstack._internal.server.services.pool as pool +import dstack._internal.core.models.pools as models +import dstack._internal.server.schemas.pools as schemas +import dstack._internal.server.services.pools as pools from dstack._internal.server.db import get_session from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.security.permissions import ProjectAdmin, ProjectMember +from dstack._internal.server.services.runs import ( + abort_runs_of_pool, + list_project_runs, + run_model_to_run, +) router = APIRouter(prefix="/api/project/{project_name}/pool", tags=["pool"]) @@ -19,7 +24,7 @@ 
async def list_pool( user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), ) -> List[models.Pool]: _, project = user_project - return await pool.list_project_pool(session=session, project=project) + return await pools.list_project_pool(session=session, project=project) @router.post("/delete") @@ -28,8 +33,30 @@ async def delete_pool( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), ): - _, project = user_project - await pool.delete_pool(session=session, project=project, pool_name=body.name) + pool_name = body.name + _, project_model = user_project + + if body.force: + await abort_runs_of_pool(session, project_model, pool_name) + await pools.delete_pool(session, project_model, pool_name) + return + + # check active runs + runs = await list_project_runs(session, project_model, repo_id=None) + active_runs = [] + for run_model in runs: + if run_model.status.is_finished(): + continue + run = run_model_to_run(run_model) + run_pool_name = run.run_spec.profile.pool_name + if run_pool_name == pool_name: + active_runs.append(run) + if active_runs: + return + + # TODO: check active instances + + await pools.delete_pool(session, project_model, pool_name) @router.post("/create") @@ -39,7 +66,7 @@ async def create_pool( user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), ): _, project = user_project - await pool.create_pool_model(session=session, project=project, name=body.name) + await pools.create_pool_model(session=session, project=project, name=body.name) @router.post("/show") @@ -49,4 +76,4 @@ async def how_pool( user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), ): _, project = user_project - return await pool.show_pool(session, project, pool_name=body.name) + return await pools.show_pool(session, project, pool_name=body.name) diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index a50b90be1..d80af86be 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -19,7 +19,7 @@ ) from dstack._internal.server.security.permissions import Authenticated, ProjectMember from dstack._internal.server.services import runs -from dstack._internal.server.services.pool import generate_instance_name +from dstack._internal.server.services.pools import generate_instance_name from dstack._internal.server.utils.routers import error_not_found root_router = APIRouter( diff --git a/src/dstack/_internal/server/schemas/pool.py b/src/dstack/_internal/server/schemas/pools.py similarity index 92% rename from src/dstack/_internal/server/schemas/pool.py rename to src/dstack/_internal/server/schemas/pools.py index ade9c1e88..902750c54 100644 --- a/src/dstack/_internal/server/schemas/pool.py +++ b/src/dstack/_internal/server/schemas/pools.py @@ -3,6 +3,7 @@ class DeletePoolRequest(BaseModel): name: str + force: bool class CreatePoolRequest(BaseModel): diff --git a/src/dstack/_internal/server/services/pool.py b/src/dstack/_internal/server/services/pools.py similarity index 79% rename from src/dstack/_internal/server/services/pool.py rename to src/dstack/_internal/server/services/pools.py index 85bb22a40..c3ff15a03 100644 --- a/src/dstack/_internal/server/services/pool.py +++ b/src/dstack/_internal/server/services/pools.py @@ -8,11 +8,12 @@ from sqlalchemy.orm import joinedload from dstack._internal.core.models.instances import InstanceOfferWithAvailability -from dstack._internal.core.models.pool import 
Instance, Pool +from dstack._internal.core.models.pools import Instance, Pool from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME from dstack._internal.core.models.runs import JobProvisioningData from dstack._internal.server.models import InstanceModel, PoolModel, ProjectModel from dstack._internal.utils import random_names +from dstack._internal.utils.common import get_current_datetime from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -41,7 +42,7 @@ async def create_pool_model(session: AsyncSession, project: ProjectModel, name: ) session.add(pool) await session.commit() - project.default_pool = pool + project.default_pool = pool # TODO: add CLI flag --set-default await session.commit() return pool @@ -49,8 +50,10 @@ async def create_pool_model(session: AsyncSession, project: ProjectModel, name: async def list_project_pool_models( session: AsyncSession, project: ProjectModel ) -> Sequence[PoolModel]: - pools = await session.execute(select(PoolModel).where(PoolModel.project_id == project.id)) - return pools.unique().scalars().all() + pools = await session.scalars( + select(PoolModel).where(PoolModel.project_id == project.id, PoolModel.deleted == False) + ) + return pools.all() async def delete_pool(session: AsyncSession, project: ProjectModel, pool_name: str): @@ -59,14 +62,15 @@ async def delete_pool(session: AsyncSession, project: ProjectModel, pool_name: s default_pool: Optional[PoolModel] = None default_pool_removed = False - for pool in await list_project_pool_models(session=session, project=project): + for pool in await list_project_pool_models(session, project): if pool.name == DEFAULT_POOL_NAME: default_pool = pool if pool_name == pool.name: if project.default_pool_id == pool.id: default_pool_removed = True - await session.delete(pool) + pool.deleted = True + pool.deleted_at = get_current_datetime() if default_pool_removed: if default_pool is not None: @@ -77,6 +81,17 @@ async def delete_pool(session: AsyncSession, project: ProjectModel, pool_name: s await session.commit() +async def list_deleted_pools( + session: AsyncSession, project_model: ProjectModel +) -> Sequence[PoolModel]: + pools = await session.scalars( + select(PoolModel).where( + PoolModel.project_id == project_model.id, PoolModel.deleted == True + ) + ) + return pools.all() + + def instance_model_to_instance(instance_model: InstanceModel) -> Instance: offer: InstanceOfferWithAvailability = parse_raw_as( InstanceOfferWithAvailability, instance_model.offer @@ -90,6 +105,7 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance: instance_id=jpd.instance_id, instance_type=jpd.instance_type, hostname=jpd.hostname, + status=instance_model.status, price=offer.price, ) return instance @@ -98,8 +114,13 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance: async def show_pool( session: AsyncSession, project: ProjectModel, pool_name: str ) -> Sequence[Instance]: - pools_result = await session.execute(select(PoolModel).where(PoolModel.name == pool_name)) - pools = pools_result.scalars().all() + pools = ( + await session.scalars( + select(PoolModel).where( + PoolModel.name == pool_name, PoolModel.project_id == project.id + ) + ) + ).all() instances = [instance_model_to_instance(i) for i in pools[0].instances] return instances diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index 386a84449..05f4e8330 100644 --- a/src/dstack/_internal/server/services/projects.py +++ 
b/src/dstack/_internal/server/services/projects.py @@ -78,11 +78,13 @@ async def create_project(session: AsyncSession, user: UserModel, project_name: s user=user, project_role=ProjectRole.ADMIN, ) - project = await get_project_model_by_name_or_error(session=session, project_name=project_name) + project_model = await get_project_model_by_name_or_error( + session=session, project_name=project_name + ) for hook in _CREATE_PROJECT_HOOKS: - await hook(session, project) - await session.refresh(project) # a hook may change project - return project_model_to_project(project) + await hook(session, project_model) + await session.refresh(project_model) # a hook may change project + return project_model_to_project(project_model) async def delete_projects( diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index bbd44abea..fde727146 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -1,5 +1,6 @@ import asyncio import itertools +import math import uuid from datetime import timezone from typing import List, Optional, Tuple @@ -52,7 +53,7 @@ get_default_image, get_default_python_verison, ) -from dstack._internal.server.services.pool import create_pool_model +from dstack._internal.server.services.pools import create_pool_model, list_project_pool, show_pool from dstack._internal.server.services.projects import list_project_models, list_user_project_models from dstack._internal.server.utils.common import run_async from dstack._internal.utils.logging import get_logger @@ -80,7 +81,7 @@ async def list_user_runs( project=project, repo_id=repo_id, ) - runs.extend(project_runs) + runs.extend(map(run_model_to_run, project_runs)) return sorted(runs, key=lambda r: r.submitted_at, reverse=True) @@ -88,7 +89,7 @@ async def list_project_runs( session: AsyncSession, project: ProjectModel, repo_id: Optional[str], -) -> List[Run]: +) -> List[RunModel]: filters = [ RunModel.project_id == project.id, RunModel.deleted == False, @@ -301,10 +302,11 @@ async def submit_run( pool_name = ( DEFAULT_POOL_NAME if run_spec.profile.pool_name is None else run_spec.profile.pool_name ) - pools_result = await session.execute(select(PoolModel).where(PoolModel.name == pool_name)) - pools = pools_result.scalars().all() + + # create pool + pools = (await session.scalars(select(PoolModel).where(PoolModel.name == pool_name))).all() if not pools: - await create_pool_model(session=session, project=project, name=pool_name) + await create_pool_model(session, project, pool_name) run_model = RunModel( id=uuid.uuid4(), @@ -321,6 +323,7 @@ async def submit_run( if run_spec.configuration.type == "service": await gateways.register_service_jobs(session, project, run_spec.run_name, jobs) for job in jobs: + job.job_spec.pool_name = pool_name job_model = create_job_model_for_new_submission( run_model=run_model, job=job, @@ -516,3 +519,18 @@ def _get_run_service(run: Run) -> Optional[ServiceInfo]: ), model=model, ) + + +async def abort_runs_of_pool(session: AsyncSession, project_model: ProjectModel, pool_name: str): + runs = await list_project_runs(session, project_model, repo_id=None) + active_run_names = [] + for run_model in runs: + if run_model.status.is_finished(): + continue + + run = run_model_to_run(run_model) + run_pool_name = run.run_spec.profile.pool_name + if run_pool_name == pool_name: + active_run_names.append(run.run_spec.run_name) + + await stop_runs(session, project_model, active_run_names, abort=True) diff --git 
a/src/dstack/api/_public/__init__.py b/src/dstack/api/_public/__init__.py index 2c16fc6b0..502084168 100644 --- a/src/dstack/api/_public/__init__.py +++ b/src/dstack/api/_public/__init__.py @@ -6,7 +6,7 @@ from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import PathLike from dstack.api._public.backends import BackendCollection -from dstack.api._public.pool import PoolCollection +from dstack.api._public.pools import PoolCollection from dstack.api._public.repos import RepoCollection, get_ssh_keypair from dstack.api._public.runs import RunCollection from dstack.api.server import APIClient diff --git a/src/dstack/api/_public/pool.py b/src/dstack/api/_public/pools.py similarity index 100% rename from src/dstack/api/_public/pool.py rename to src/dstack/api/_public/pools.py diff --git a/src/dstack/api/server/__init__.py b/src/dstack/api/server/__init__.py index 49b501070..955bf956e 100644 --- a/src/dstack/api/server/__init__.py +++ b/src/dstack/api/server/__init__.py @@ -10,7 +10,7 @@ from dstack.api.server._backends import BackendsAPIClient from dstack.api.server._gateways import GatewaysAPIClient from dstack.api.server._logs import LogsAPIClient -from dstack.api.server._pool import PoolAPIClient +from dstack.api.server._pools import PoolAPIClient from dstack.api.server._projects import ProjectsAPIClient from dstack.api.server._repos import ReposAPIClient from dstack.api.server._runs import RunsAPIClient diff --git a/src/dstack/api/server/_pool.py b/src/dstack/api/server/_pools.py similarity index 66% rename from src/dstack/api/server/_pool.py rename to src/dstack/api/server/_pools.py index 4b3ae9b03..d936a7508 100644 --- a/src/dstack/api/server/_pool.py +++ b/src/dstack/api/server/_pools.py @@ -2,8 +2,8 @@ from pydantic import parse_obj_as -import dstack._internal.server.schemas.pool as pool_schemas -from dstack._internal.core.models.pool import Instance, Pool +import dstack._internal.server.schemas.pools as schemas_pools +from dstack._internal.core.models.pools import Instance, Pool from dstack.api.server._group import APIClientGroup @@ -12,15 +12,15 @@ def list(self, project_name: str) -> List[Pool]: resp = self._request(f"/api/project/{project_name}/pool/list") return parse_obj_as(List[Pool], resp.json()) - def delete(self, project_name: str, pool_name: str) -> None: - body = pool_schemas.DeletePoolRequest(name=pool_name) + def delete(self, project_name: str, pool_name: str, force: bool) -> None: + body = schemas_pools.DeletePoolRequest(name=pool_name, force=force) self._request(f"/api/project/{project_name}/pool/delete", body=body.json()) def create(self, project_name: str, pool_name: str) -> None: - body = pool_schemas.CreatePoolRequest(name=pool_name) + body = schemas_pools.CreatePoolRequest(name=pool_name) self._request(f"/api/project/{project_name}/pool/create", body=body.json()) def show(self, project_name: str, pool_name: str) -> List[Instance]: - body = pool_schemas.ShowPoolRequest(name=pool_name) + body = schemas_pools.ShowPoolRequest(name=pool_name) resp = self._request(f"/api/project/{project_name}/pool/show", body=body.json()) return parse_obj_as(List[Instance], resp.json()) diff --git a/src/tests/_internal/server/services/test_pool.py b/src/tests/_internal/server/services/test_pools.py similarity index 66% rename from src/tests/_internal/server/services/test_pool.py rename to src/tests/_internal/server/services/test_pools.py index 5a23c85ee..31dac9fa3 100644 --- a/src/tests/_internal/server/services/test_pool.py +++ 
b/src/tests/_internal/server/services/test_pools.py @@ -4,11 +4,15 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession +import dstack._internal.server.services.pools as services_pools +import dstack._internal.server.services.projects as services_projects +import dstack._internal.server.services.users as services_users from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import InstanceType, Resources -from dstack._internal.core.models.pool import Instance, Pool -from dstack._internal.server.models import InstanceModel, PoolModel -from dstack._internal.server.services import pool as services_pool +from dstack._internal.core.models.pools import Instance, Pool +from dstack._internal.core.models.runs import InstanceStatus +from dstack._internal.core.models.users import GlobalRole +from dstack._internal.server.models import InstanceModel from dstack._internal.server.testing.common import create_project, create_user @@ -16,31 +20,32 @@ async def test_pool(session: AsyncSession, test_db): user = await create_user(session=session) project = await create_project(session=session, owner=user) - pool = await services_pool.create_pool_model( + pool = await services_pools.create_pool_model( session=session, project=project, name="test_pool" ) im = InstanceModel( name="test_instnce", project=project, pool=pool, + status=InstanceStatus.PENDING, job_provisioning_data="", offer="", ) session.add(im) await session.commit() - core_model_pool = services_pool.pool_model_to_pool(pool) + core_model_pool = services_pools.pool_model_to_pool(pool) assert core_model_pool == Pool(name="test_pool", default=True, created_at=pool.created_at) - list_pools = await services_pool.list_project_pool(session=session, project=project) - assert list_pools == [services_pool.pool_model_to_pool(pool)] + list_pools = await services_pools.list_project_pool(session=session, project=project) + assert list_pools == [services_pools.pool_model_to_pool(pool)] - list_pool_models = await services_pool.list_project_pool_models( + list_pool_models = await services_pools.list_project_pool_models( session=session, project=project ) assert len(list_pool_models) == 1 - pool_intances = await services_pool.get_pool_instances(session=session, pool_name="test_pool") + pool_intances = await services_pools.get_pool_instances(session=session, pool_name="test_pool") assert pool_intances == [im] @@ -52,6 +57,7 @@ def test_convert_instance(): ), instance_id="test_instance", hostname="hostname_test", + status=InstanceStatus.PENDING, price=1.0, ) @@ -59,58 +65,32 @@ def test_convert_instance(): id=str(uuid.uuid4()), created_at=dt.datetime.now(), name="test_instance", + status=InstanceStatus.PENDING, project_id=str(uuid.uuid4()), pool=None, job_provisioning_data='{"backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', offer='{"price":"LOCAL", "price":1.0, "backend":"local", "region":"eu-west-1", "availability":"available","instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', ) - instance = services_pool.instance_model_to_instance(im) + instance = services_pools.instance_model_to_instance(im) assert instance == expected_instance 
@pytest.mark.asyncio async def test_delete_pool(session: AsyncSession, test_db): - user = await create_user(session=session) - project = await create_project(session=session, owner=user) - pool = await services_pool.create_pool_model( - session=session, project=project, name="test_pool" - ) - im = InstanceModel( - name="test_instnce", - project=project, - pool=pool, - job_provisioning_data="", - offer="", + POOL_NAME = "test_pool" + user = await services_users.create_user(session, "test_user", global_role=GlobalRole.ADMIN) + project = await services_projects.create_project(session, user, "test_project") + project_model = await services_projects.get_project_model_by_name_or_error( + session, project.project_name ) - session.add(im) - await session.commit() - - await services_pool.delete_pool(session=session, project=project, pool_name="test_pool") - - -# async def delete_pool(session: AsyncSession, project: ProjectModel, pool_name: str): -# """delete the pool and set the default pool to project""" - -# default_pool: Optional[PoolModel] = None -# default_pool_removed = False - -# for pool in await list_project_pool_models(session=session, project=project): -# if pool.name == DEFAULT_POOL_NAME: -# default_pool = pool - -# if pool_name == pool.name: -# if project.default_pool_id == pool.id: -# default_pool_removed = True -# await session.delete(pool) + pool = await services_pools.create_pool_model(session, project_model, POOL_NAME) -# if default_pool_removed: -# if default_pool is not None: -# project.default_pool = default_pool -# else: -# await create_pool_model(session, project, DEFAULT_POOL_NAME) + await services_pools.delete_pool(session, project_model, POOL_NAME) -# await session.commit() + deleted_pools = await services_pools.list_deleted_pools(session, project_model) + assert len(deleted_pools) == 1 + assert pool.name == deleted_pools[0].name @pytest.mark.asyncio @@ -118,18 +98,19 @@ async def test_show_pool(session: AsyncSession, test_db): POOL_NAME = "test_pool" user = await create_user(session=session) project = await create_project(session=session, owner=user) - pool = await services_pool.create_pool_model(session=session, project=project, name=POOL_NAME) + pool = await services_pools.create_pool_model(session=session, project=project, name=POOL_NAME) im = InstanceModel( name="test_instnce", project=project, pool=pool, + status=InstanceStatus.PENDING, job_provisioning_data='{"backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', offer='{"price":"LOCAL", "price":1.0, "backend":"local", "region":"eu-west-1", "availability":"available","instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', ) session.add(im) await session.commit() - instances = await services_pool.show_pool(session, project, POOL_NAME) + instances = await services_pools.show_pool(session, project, POOL_NAME) assert len(instances) == 1 @@ -138,21 +119,22 @@ async def test_get_pool_instances(session: AsyncSession, test_db): POOL_NAME = "test_pool" user = await create_user(session=session) project = await create_project(session=session, owner=user) - pool = await services_pool.create_pool_model(session=session, project=project, name=POOL_NAME) + pool = await 
services_pools.create_pool_model(session=session, project=project, name=POOL_NAME) im = InstanceModel( name="test_instnce", project=project, pool=pool, + status=InstanceStatus.PENDING, job_provisioning_data='{"backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', offer='{"price":"LOCAL", "price":1.0, "backend":"local", "region":"eu-west-1", "availability":"available","instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', ) session.add(im) await session.commit() - instances = await services_pool.get_pool_instances(session, POOL_NAME) + instances = await services_pools.get_pool_instances(session, POOL_NAME) assert len(instances) == 1 - empty_instances = await services_pool.get_pool_instances(session, f"{POOL_NAME}-0") + empty_instances = await services_pools.get_pool_instances(session, f"{POOL_NAME}-0") assert len(empty_instances) == 0 @@ -160,20 +142,21 @@ async def test_get_pool_instances(session: AsyncSession, test_db): async def test_generate_instance_name(session: AsyncSession, test_db): user = await create_user(session=session) project = await create_project(session=session, owner=user) - pool = await services_pool.create_pool_model( + pool = await services_pools.create_pool_model( session=session, project=project, name="test_pool" ) im = InstanceModel( name="test_instnce", project=project, pool=pool, + status=InstanceStatus.PENDING, job_provisioning_data="", offer="", ) session.add(im) await session.commit() - name = await services_pool.generate_instance_name( + name = await services_pools.generate_instance_name( session=session, project=project, pool_name="test_pool" ) car, _, cdr = name.partition("-") From 4c946e0589274ffda959da0bdaac9cc38dc93742 Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Wed, 10 Jan 2024 06:13:40 +0300 Subject: [PATCH 04/47] add local instance to the pool --- src/dstack/_internal/cli/commands/pool.py | 23 +++- src/dstack/_internal/cli/commands/run.py | 11 +- .../_internal/core/backends/local/compute.py | 14 +++ .../_internal/core/models/backends/base.py | 1 + src/dstack/_internal/core/models/profiles.py | 23 ++++ .../_internal/server/background/__init__.py | 2 + .../server/background/tasks/process_pools.py | 43 ++++++++ .../background/tasks/process_running_jobs.py | 15 +++ .../tasks/process_submitted_jobs.py | 79 +++++++++++++- ...add_pools.py => ec4dbadbab3c_add_pools.py} | 6 +- src/dstack/_internal/server/models.py | 6 +- src/dstack/_internal/server/routers/pools.py | 18 ++++ src/dstack/_internal/server/routers/runs.py | 1 + src/dstack/_internal/server/schemas/runs.py | 8 ++ .../server/services/jobs/__init__.py | 4 + src/dstack/_internal/server/services/pools.py | 101 ++++++++++++++++-- src/dstack/api/server/_pools.py | 11 +- src/dstack/api/server/_runs.py | 1 + .../_internal/server/routers/test_runs.py | 6 ++ .../_internal/server/services/test_pools.py | 6 +- 20 files changed, 354 insertions(+), 25 deletions(-) create mode 100644 src/dstack/_internal/server/background/tasks/process_pools.py rename src/dstack/_internal/server/migrations/versions/{a859c41ae3b7_add_pools.py => ec4dbadbab3c_add_pools.py} (97%) diff --git a/src/dstack/_internal/cli/commands/pool.py 
b/src/dstack/_internal/cli/commands/pool.py index 76fd30221..677f6f05b 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -197,7 +197,7 @@ def _register(self): formatter_class=self._parser.formatter_class, ) show_parser.add_argument( - "-n", "--name", dest="pool_name", help="The name of the pool", required=True + "--name", "-n", dest="pool_name", help="The name of the pool", required=True ) show_parser.set_defaults(subfunc=self._show) @@ -211,6 +211,18 @@ def _register(self): add_parser.add_argument( "-y", "--yes", help="Don't ask for confirmation", action="store_true" ) + add_parser.add_argument( + "--remote", + help="Add remote runner as an instance", + dest="remote", + action="store_true", + default=False, + ) + add_parser.add_argument("--remote-host", help="Remote runner host", dest="remote_host") + add_parser.add_argument( + "--remote-port", help="Remote runner port", dest="remote_port", default=10999 + ) + add_parser.add_argument("--name", dest="instance_name", help="The name of the instance") add_parser.set_defaults(subfunc=self._add) register_profile_args(add_parser) @@ -231,13 +243,20 @@ def _show(self, args: argparse.Namespace): def _add(self, args: argparse.Namespace): super()._command(args) + pool_name: str = DEFAULT_POOL_NAME if args.pool_name is None else args.pool_name + + if args.remote: + self.api.client.pool.add( + self.api.project, pool_name, args.instance_name, args.remote_host, args.remote_port + ) + return + repo = self.api.repos.load(Path.cwd()) self.api.ssh_identity_file = ConfigManager().get_repo_config(repo.repo_dir).ssh_key_path profile = load_profile(Path.cwd(), args.profile) apply_profile_args(args, profile) - pool_name: str = DEFAULT_POOL_NAME if args.pool_name is None else args.pool_name profile.pool_name = pool_name with console.status("Getting run plan..."): diff --git a/src/dstack/_internal/cli/commands/run.py b/src/dstack/_internal/cli/commands/run.py index a27c0e7d9..7aa22fb12 100644 --- a/src/dstack/_internal/cli/commands/run.py +++ b/src/dstack/_internal/cli/commands/run.py @@ -17,7 +17,7 @@ from dstack._internal.cli.utils.run import print_run_plan from dstack._internal.core.errors import CLIError, ConfigurationError, ServerClientError from dstack._internal.core.models.configurations import ConfigurationType -from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, CreationPolicy from dstack._internal.core.models.runs import JobErrorCode from dstack._internal.core.services.configs import ConfigManager from dstack._internal.utils.logging import get_logger @@ -84,6 +84,13 @@ def _register(self): dest="pool_name", help="The name of the pool", ) + self._parser.add_argument( + "--reuse", + dest="creation_policy", + action="store_const", + const=CreationPolicy.REUSE, + help="Reuse instance", + ) register_profile_args(self._parser) def _command(self, args: argparse.Namespace): @@ -148,6 +155,8 @@ def _command(self, args: argparse.Namespace): console.print("\nExiting...") return + run_plan.run_spec.profile.creation_policy = args.creation_policy + try: with console.status("Submitting run..."): run = self.api.runs.exec_plan(run_plan, repo, reserve_ports=not args.detach) diff --git a/src/dstack/_internal/core/backends/local/compute.py b/src/dstack/_internal/core/backends/local/compute.py index 8bfa3077d..802020b0d 100644 --- a/src/dstack/_internal/core/backends/local/compute.py +++ 
b/src/dstack/_internal/core/backends/local/compute.py @@ -38,6 +38,20 @@ def terminate_instance( ): pass + def create_instance( + self, project, user, instance_offer, instance_config + ) -> LaunchedInstanceInfo: + launched_instance = LaunchedInstanceInfo( + instance_id="local", + ip_address="127.0.0.1", + region="", + username="root", + ssh_port=10022, + dockerized=False, + backend_data=None, + ) + return launched_instance + def run_job( self, run: Run, diff --git a/src/dstack/_internal/core/models/backends/base.py b/src/dstack/_internal/core/models/backends/base.py index 56fcdb585..b3c9394cd 100644 --- a/src/dstack/_internal/core/models/backends/base.py +++ b/src/dstack/_internal/core/models/backends/base.py @@ -26,6 +26,7 @@ class BackendType(str, enum.Enum): KUBERNETES = "kubernetes" LAMBDA = "lambda" LOCAL = "local" + # REMOTE= "remote" # replace for LOCAL NEBIUS = "nebius" TENSORDOCK = "tensordock" VASTAI = "vastai" diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index d4c6fbf2e..6957d8998 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -18,6 +18,16 @@ class SpotPolicy(str, Enum): AUTO = "auto" +class CreationPolicy(str, Enum): + REUSE = "reuse" + REUSE_OR_CREATE = "reuse-or-create" + + +class TerminationPolicy(str, Enum): + DONT_DESTROY = "dont-destroy" + DESTROY_AFTER_IDLE = "destroy-after-idle" + + def parse_duration(v: Optional[Union[int, str]]) -> Optional[int]: if v is None: return None @@ -99,10 +109,23 @@ class Profile(ForbidExtra): Optional[str], Field(description="The name of the pool. If not set, dstack will use the default name."), ] = DEFAULT_POOL_NAME + creation_policy: Annotated[ + Optional[CreationPolicy], Field(description="The policy for using instances from the pool") + ] + termination_policy: Annotated[ + Optional[TerminationPolicy], Field(description="The policy for termination instances") + ] + termination_idle_time: Annotated[ + Optional[Union[Literal["off"], str, int]], + Field(description=""), + ] _validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)( parse_max_duration ) + _validate_termination_idle_time = validator( + "termination_idle_time", pre=True, allow_reuse=True + )(parse_max_duration) class ProfilesConfig(ForbidExtra): diff --git a/src/dstack/_internal/server/background/__init__.py b/src/dstack/_internal/server/background/__init__.py index b24876dd4..540f8df20 100644 --- a/src/dstack/_internal/server/background/__init__.py +++ b/src/dstack/_internal/server/background/__init__.py @@ -3,6 +3,7 @@ from dstack._internal.server.background.tasks.process_finished_jobs import process_finished_jobs from dstack._internal.server.background.tasks.process_pending_jobs import process_pending_jobs +from dstack._internal.server.background.tasks.process_pools import process_pools from dstack._internal.server.background.tasks.process_running_jobs import process_running_jobs from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs @@ -18,5 +19,6 @@ def start_background_tasks() -> AsyncIOScheduler: _scheduler.add_job(process_running_jobs, IntervalTrigger(seconds=2)) _scheduler.add_job(process_finished_jobs, IntervalTrigger(seconds=2)) _scheduler.add_job(process_pending_jobs, IntervalTrigger(seconds=10)) + _scheduler.add_job(process_pools, IntervalTrigger(seconds=10)) _scheduler.start() return _scheduler diff --git a/src/dstack/_internal/server/background/tasks/process_pools.py 
b/src/dstack/_internal/server/background/tasks/process_pools.py new file mode 100644 index 000000000..2b8536c28 --- /dev/null +++ b/src/dstack/_internal/server/background/tasks/process_pools.py @@ -0,0 +1,43 @@ +from datetime import timedelta + +from pydantic import parse_raw_as +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +from dstack._internal.core.models.runs import InstanceStatus, JobStatus +from dstack._internal.server.db import get_session_ctx +from dstack._internal.server.models import InstanceModel, JobModel +from dstack._internal.server.services.jobs import PROCESSING_POOL_IDS, PROCESSING_POOL_LOCK +from dstack._internal.utils.logging import get_logger + +PENDING_JOB_RETRY_INTERVAL = timedelta(seconds=60) + +logger = get_logger(__name__) + + +async def process_pools(): + + async with get_session_ctx() as session: + async with PROCESSING_POOL_LOCK: + res = await session.scalars( + select(InstanceModel).where( + InstanceModel.status.in_([InstanceStatus.READY, InstanceStatus.FAILED]), + InstanceModel.id.not_in(PROCESSING_POOL_IDS), + ) + ) + instances = res.all() + if not instances: + return + + PROCESSING_POOL_IDS.update(i.id for i in instances) + + try: + for inst in instances: + await _terminate_instance(inst) + finally: + PROCESSING_POOL_IDS.difference_update(i.id for i in instances) + + +async def _terminate_instance(instance: InstanceModel): + pass diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 4ca273c3c..97472c763 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -1,3 +1,4 @@ +from asyncio.proactor_events import _ProactorBasePipeTransport from datetime import timedelta from typing import Dict, Optional from uuid import UUID @@ -12,8 +13,10 @@ from dstack._internal.core.errors import GatewayError, SSHError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import RegistryAuth +from dstack._internal.core.models.instances import InstanceState from dstack._internal.core.models.repos import RemoteRepoCreds from dstack._internal.core.models.runs import ( + InstanceStatus, Job, JobErrorCode, JobProvisioningData, @@ -38,6 +41,7 @@ job_model_to_job_submission, ) from dstack._internal.server.services.logging import job_log +from dstack._internal.server.services.pools import get_pool_instances from dstack._internal.server.services.repos import get_code_model, repo_model_to_repo_head from dstack._internal.server.services.runner import client from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel @@ -98,6 +102,7 @@ async def _process_job(job_id: UUID): job = run.jobs[job_model.job_num] job_submission = job_model_to_job_submission(job_model) job_provisioning_data = job_submission.job_provisioning_data + server_ssh_private_key = project.ssh_private_key secrets = {} # TODO secrets repo_creds = repo_model_to_repo_head(repo_model, include_creds=True).repo_creds @@ -145,6 +150,15 @@ async def _process_job(job_id: UUID): secrets, repo_creds, ) + + if success: + instance_name: str = job_provisioning_data.instance_id + pool_name = str(job.job_spec.pool_name) + instances = await get_pool_instances(session, project, pool_name) + for inst in instances: + if inst.name == instance_name: + inst.status = 
InstanceStatus.BUSY + if not success: # check timeout if job_submission.age > _get_runner_timeout_interval( job_provisioning_data.backend @@ -285,6 +299,7 @@ def _process_provisioning_no_shim( Returns: is successful """ + runner_client = client.RunnerClient(port=ports[client.REMOTE_RUNNER_PORT]) resp = runner_client.healthcheck() if resp is None: diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index ce51fa6c7..d76c0ff25 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -1,17 +1,19 @@ from typing import List, Optional, Tuple from uuid import UUID +from pydantic import parse_raw_as from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from dstack._internal.core.backends.base import Backend from dstack._internal.core.errors import BackendError +from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( InstanceOfferWithAvailability, LaunchedInstanceInfo, ) -from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, CreationPolicy, Profile from dstack._internal.core.models.runs import ( InstanceStatus, Job, @@ -19,6 +21,7 @@ JobProvisioningData, JobStatus, Run, + RunSpec, ) from dstack._internal.server.db import get_session_ctx from dstack._internal.server.models import InstanceModel, JobModel, PoolModel, RunModel @@ -28,7 +31,13 @@ SUBMITTED_PROCESSING_JOBS_LOCK, ) from dstack._internal.server.services.logging import job_log +from dstack._internal.server.services.pools import ( + get_pool_instances, + instance_model_to_instance, + show_pool, +) from dstack._internal.server.services.runs import run_model_to_run +from dstack._internal.server.settings import LOCAL_BACKEND_ENABLED from dstack._internal.server.utils.common import run_async from dstack._internal.utils import common as common_utils from dstack._internal.utils.logging import get_logger @@ -69,6 +78,34 @@ async def _process_job(job_id: UUID): ) +def check_relevance(profile: Profile, instance_model: InstanceModel) -> bool: + + jpd: JobProvisioningData = parse_raw_as( + JobProvisioningData, instance_model.job_provisioning_data + ) + + if LOCAL_BACKEND_ENABLED and jpd.backend == BackendType.LOCAL: + return True + + instance = instance_model_to_instance(instance_model) + + if profile.backends is not None and instance.backend not in profile.backends: + return False + + instance_resources = jpd.instance_type.resources + + if profile.resources.cpu is not None and profile.resources.cpu < instance_resources.cpus: + return False + + # TODO: full check + if isinstance(profile.resources.gpu, int): + if profile.resources.gpu < len(instance_resources.gpus): + return False + + return True + # TODO: memory, shm_size, disk + + async def _process_submitted_job(session: AsyncSession, job_model: JobModel): logger.debug(*job_log("provisioning", job_model)) res = await session.execute( @@ -80,6 +117,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): run_model = res.scalar_one() project_model = run_model.project + # check default pool pool = project_model.default_pool if pool is None: pool = PoolModel( @@ -91,9 +129,48 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): if pool.id 
is not None: project_model.default_pool_id = pool.id + profile = parse_raw_as(RunSpec, run_model.run_spec).profile + run_pool = profile.pool_name + if run_pool is None: + run_pool = pool.name + + # pool capacity + pool_instances = await get_pool_instances(session, project_model, run_pool) + available_instanses = (p for p in pool_instances if p.status == InstanceStatus.PENDING) + relevant_instances: List[InstanceModel] = [] + for instance in available_instanses: + if check_relevance(profile, instance): + relevant_instances.append(instance) + + if relevant_instances: + + sorted_instances = sorted(relevant_instances, key=lambda instance: instance.name) + instance = sorted_instances[0] + + instance.status = InstanceStatus.BUSY + + logger.info(*job_log("now is provisioning", job_model)) + job_model.job_provisioning_data = instance.job_provisioning_data + job_model.status = JobStatus.PROVISIONING + job_model.last_processed_at = common_utils.get_current_datetime() + + await session.commit() + + return + + if profile.creation_policy == CreationPolicy.REUSE: + job_model.status = JobStatus.FAILED + job_model.error_code = JobErrorCode.FAILED_TO_START_DUE_TO_NO_CAPACITY + job_model.last_processed_at = common_utils.get_current_datetime() + await session.commit() + return + + # create a new cloud instance run = run_model_to_run(run_model) job = run.jobs[job_model.job_num] backends = await backends_services.get_project_backends(project=run_model.project) + + # TODO: create VM (backend.compute().create_instance) job_provisioning_data, offer = await _run_job( job_model=job_model, run=run, diff --git a/src/dstack/_internal/server/migrations/versions/a859c41ae3b7_add_pools.py b/src/dstack/_internal/server/migrations/versions/ec4dbadbab3c_add_pools.py similarity index 97% rename from src/dstack/_internal/server/migrations/versions/a859c41ae3b7_add_pools.py rename to src/dstack/_internal/server/migrations/versions/ec4dbadbab3c_add_pools.py index c8c86e30d..ab6bc30b3 100644 --- a/src/dstack/_internal/server/migrations/versions/a859c41ae3b7_add_pools.py +++ b/src/dstack/_internal/server/migrations/versions/ec4dbadbab3c_add_pools.py @@ -1,8 +1,8 @@ """add pools -Revision ID: a859c41ae3b7 +Revision ID: ec4dbadbab3c Revises: 48ad3ecbaea2 -Create Date: 2023-12-28 10:13:48.608439 +Create Date: 2024-01-10 07:56:08.754541 """ import sqlalchemy as sa @@ -10,7 +10,7 @@ from alembic import op # revision identifiers, used by Alembic. 
-revision = "a859c41ae3b7" +revision = "ec4dbadbab3c" down_revision = "48ad3ecbaea2" branch_labels = None depends_on = None diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index c25da2f1c..be5d3d253 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -268,12 +268,12 @@ class InstanceModel(BaseModel): deleted_at: Mapped[Optional[datetime]] = mapped_column(DateTime) project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) - project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id]) + project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id], single_parent=True) pool_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("pools.id")) - pool: Mapped["PoolModel"] = relationship(back_populates="instances") + pool: Mapped["PoolModel"] = relationship(back_populates="instances", single_parent=True) - status: Mapped[InstanceStatus] = mapped_column(Enum(JobStatus)) + status: Mapped[InstanceStatus] = mapped_column(Enum(InstanceStatus)) status_message: Mapped[Optional[str]] = mapped_column(String(50)) started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) diff --git a/src/dstack/_internal/server/routers/pools.py b/src/dstack/_internal/server/routers/pools.py index fe17ce386..5f8ad1d94 100644 --- a/src/dstack/_internal/server/routers/pools.py +++ b/src/dstack/_internal/server/routers/pools.py @@ -8,6 +8,7 @@ import dstack._internal.server.services.pools as pools from dstack._internal.server.db import get_session from dstack._internal.server.models import ProjectModel, UserModel +from dstack._internal.server.schemas.runs import AddInstanceRequest from dstack._internal.server.security.permissions import ProjectAdmin, ProjectMember from dstack._internal.server.services.runs import ( abort_runs_of_pool, @@ -77,3 +78,20 @@ async def how_pool( ): _, project = user_project return await pools.show_pool(session, project, pool_name=body.name) + + +@router.post("/add") +async def add_instance( + body: AddInstanceRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), +): + _, project = user_project + await pools.add( + session, + project, + pool_name=body.pool_name, + instance_name=body.instance_name, + host=body.host, + port=body.port, + ) diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index d80af86be..3106e0767 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -8,6 +8,7 @@ from dstack._internal.server.db import get_session from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.runs import ( + AddInstanceRequest, CreateInstanceRequest, DeleteRunsRequest, GetOffersRequest, diff --git a/src/dstack/_internal/server/schemas/runs.py b/src/dstack/_internal/server/schemas/runs.py index 8c03d7bda..4483f62ae 100644 --- a/src/dstack/_internal/server/schemas/runs.py +++ b/src/dstack/_internal/server/schemas/runs.py @@ -28,6 +28,14 @@ class CreateInstanceRequest(BaseModel): profile: Profile +class AddInstanceRequest(BaseModel): + pool_name: str + instance_name: Optional[str] + host: str + port: str + # TODO: define runner spec (gpu, cpu, etc) + + class SubmitRunRequest(BaseModel): run_spec: RunSpec 
diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index b92a75fd2..9bc10068b 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -40,6 +40,10 @@ RUNNING_PROCESSING_JOBS_LOCK = asyncio.Lock() RUNNING_PROCESSING_JOBS_IDS = set() +PROCESSING_POOL_LOCK = asyncio.Lock() +PROCESSING_POOL_IDS = set() + + TERMINATING_PROCESSING_JOBS_LOCK = asyncio.Lock() TERMINATING_PROCESSING_JOBS_IDS = set() diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index c3ff15a03..eef3d6d86 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -7,10 +7,16 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload -from dstack._internal.core.models.instances import InstanceOfferWithAvailability +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import ( + InstanceAvailability, + InstanceOfferWithAvailability, + InstanceType, + Resources, +) from dstack._internal.core.models.pools import Instance, Pool from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME -from dstack._internal.core.models.runs import JobProvisioningData +from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData from dstack._internal.server.models import InstanceModel, PoolModel, ProjectModel from dstack._internal.utils import random_names from dstack._internal.utils.common import get_current_datetime @@ -114,22 +120,32 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance: async def show_pool( session: AsyncSession, project: ProjectModel, pool_name: str ) -> Sequence[Instance]: - pools = ( + pool = ( await session.scalars( select(PoolModel).where( - PoolModel.name == pool_name, PoolModel.project_id == project.id + PoolModel.name == pool_name, + PoolModel.project_id == project.id, + PoolModel.deleted == False, ) ) - ).all() - - instances = [instance_model_to_instance(i) for i in pools[0].instances] - return instances + ).one_or_none() + if pool is not None: + instances = [instance_model_to_instance(i) for i in pool.instances] + return instances + else: + return [] -async def get_pool_instances(session: AsyncSession, pool_name: str) -> List[InstanceModel]: +async def get_pool_instances( + session: AsyncSession, project: ProjectModel, pool_name: str +) -> List[InstanceModel]: res = await session.execute( select(PoolModel) - .where(PoolModel.name == pool_name) + .where( + PoolModel.name == pool_name, + PoolModel.project_id == project.id, + PoolModel.deleted == False, + ) .options(joinedload(PoolModel.instances)) ) result = res.unique().scalars().one_or_none() @@ -148,9 +164,72 @@ async def generate_instance_name( ) -> str: lock = _GENERATE_POOL_NAME_LOCK.setdefault(project.name, asyncio.Lock()) async with lock: - pool_instances: List[InstanceModel] = await get_pool_instances(session, pool_name) + pool_instances: List[InstanceModel] = await get_pool_instances(session, project, pool_name) names = {g.name for g in pool_instances} while True: name = f"{random_names.generate_name()}" if name not in names: return name + + +async def add( + session: AsyncSession, + project: ProjectModel, + pool_name: str, + instance_name: Optional[str], + host: str, + port: str, +): + + instance_name = instance_name + if instance_name is None: + instance_name = await 
generate_instance_name(session, project, pool_name) + + pool = ( + await session.scalars( + select(PoolModel).where( + PoolModel.name == pool_name, + PoolModel.project_id == project.id, + PoolModel.deleted == False, + ) + ) + ).one_or_none() + + if pool is None: + pool = await create_pool_model(session, project, pool_name) + + local = JobProvisioningData( + backend=BackendType.LOCAL, + instance_type=InstanceType( + name="local", resources=Resources(cpus=0, memory_mib=0, gpus=[], spot=False) + ), + instance_id=instance_name, + hostname=host, + region="", + price=0, + username="", + ssh_port=22, + dockerized=False, + backend_data="", + ) + offer = InstanceOfferWithAvailability( + backend=BackendType.LOCAL, + instance=InstanceType( + name="instance", + resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[]), + ), + region="", + price=0.0, + availability=InstanceAvailability.AVAILABLE, + ) + + im = InstanceModel( + name=instance_name, + project=project, + pool=pool, + status=InstanceStatus.PENDING, + job_provisioning_data=local.json(), + offer=offer.json(), + ) + session.add(im) + await session.commit() diff --git a/src/dstack/api/server/_pools.py b/src/dstack/api/server/_pools.py index d936a7508..751379f46 100644 --- a/src/dstack/api/server/_pools.py +++ b/src/dstack/api/server/_pools.py @@ -1,9 +1,10 @@ -from typing import List +from typing import List, Optional from pydantic import parse_obj_as import dstack._internal.server.schemas.pools as schemas_pools from dstack._internal.core.models.pools import Instance, Pool +from dstack._internal.server.schemas.runs import AddInstanceRequest from dstack.api.server._group import APIClientGroup @@ -24,3 +25,11 @@ def show(self, project_name: str, pool_name: str) -> List[Instance]: body = schemas_pools.ShowPoolRequest(name=pool_name) resp = self._request(f"/api/project/{project_name}/pool/show", body=body.json()) return parse_obj_as(List[Instance], resp.json()) + + def add( + self, project_name: str, pool_name: str, instance_name: Optional[str], host: str, port: str + ): + body = AddInstanceRequest( + pool_name=pool_name, instance_name=instance_name, host=host, port=port + ) + self._request(f"/api/project/{project_name}/pool/add", body=body.json()) diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index 202fc09b0..d5d7260b6 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -6,6 +6,7 @@ from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.runs import Requirements, Run, RunPlan, RunSpec from dstack._internal.server.schemas.runs import ( + AddInstanceRequest, CreateInstanceRequest, DeleteRunsRequest, GetOffersRequest, diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 1f4a4816c..57f826c27 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -71,6 +71,7 @@ def get_dev_env_run_plan_dict( "configuration_path": "dstack.yaml", "profile": { "backends": ["local", "aws", "azure", "gcp", "lambda"], + "creation_policy": None, "default": False, "max_duration": "off", "max_price": None, @@ -78,6 +79,8 @@ def get_dev_env_run_plan_dict( "pool_name": DEFAULT_POOL_NAME, "retry_policy": {"limit": None, "retry": False}, "spot_policy": "spot", + "termination_idle_time": None, + "termination_policy": None, }, "repo_code_hash": None, "repo_data": {"repo_dir": "/repo", "repo_type": "local"}, @@ -179,6 +182,7 @@ def get_dev_env_run_dict( 
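For orientation, here is a minimal client-side sketch of how the pool "add" endpoint introduced above is meant to be used. It is an illustration only, not part of the patch: the server URL, token, project name and host/port values are placeholders, and the APIClient constructor is shown schematically.

    from dstack.api.server import APIClient  # REST client extended by this series

    client = APIClient("http://127.0.0.1:3000", "<admin-token>")  # placeholder URL and token

    # Register a machine as a pending instance; the pool is created on demand
    # by services/pools.add() if it does not exist yet.
    client.pool.add(
        project_name="main",        # placeholder project
        pool_name="default-pool",
        instance_name=None,         # None lets generate_instance_name() pick a random name
        host="192.168.0.10",        # placeholder address
        port="22",                  # kept as a string by AddInstanceRequest
    )

    # The new instance is then visible via the show endpoint, initially in PENDING state.
    for instance in client.pool.show(project_name="main", pool_name="default-pool"):
        print(instance.name)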
"configuration_path": "dstack.yaml", "profile": { "backends": ["local", "aws", "azure", "gcp", "lambda"], + "creation_policy": None, "default": False, "max_duration": "off", "max_price": None, @@ -186,6 +190,8 @@ def get_dev_env_run_dict( "pool_name": DEFAULT_POOL_NAME, "retry_policy": {"limit": None, "retry": False}, "spot_policy": "spot", + "termination_idle_time": None, + "termination_policy": None, }, "repo_code_hash": None, "repo_data": {"repo_dir": "/repo", "repo_type": "local"}, diff --git a/src/tests/_internal/server/services/test_pools.py b/src/tests/_internal/server/services/test_pools.py index 31dac9fa3..f720f2e20 100644 --- a/src/tests/_internal/server/services/test_pools.py +++ b/src/tests/_internal/server/services/test_pools.py @@ -45,7 +45,7 @@ async def test_pool(session: AsyncSession, test_db): ) assert len(list_pool_models) == 1 - pool_intances = await services_pools.get_pool_instances(session=session, pool_name="test_pool") + pool_intances = await services_pools.get_pool_instances(session, project, "test_pool") assert pool_intances == [im] @@ -131,10 +131,10 @@ async def test_get_pool_instances(session: AsyncSession, test_db): session.add(im) await session.commit() - instances = await services_pools.get_pool_instances(session, POOL_NAME) + instances = await services_pools.get_pool_instances(session, project, POOL_NAME) assert len(instances) == 1 - empty_instances = await services_pools.get_pool_instances(session, f"{POOL_NAME}-0") + empty_instances = await services_pools.get_pool_instances(session, project, f"{POOL_NAME}-0") assert len(empty_instances) == 0 From a314da735b61daedd19ea5c9e402fbd3e0260a20 Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Fri, 12 Jan 2024 10:26:09 +0300 Subject: [PATCH 05/47] multitask --- runner/cmd/shim/main.go | 71 +++++------ runner/internal/api/common.go | 5 +- runner/internal/runner/api/http_test.go | 79 +++++++++++- runner/internal/shim/api/http.go | 17 +-- runner/internal/shim/api/schemas.go | 6 +- runner/internal/shim/api/server.go | 53 +++----- runner/internal/shim/backends/aws.go | 74 ----------- runner/internal/shim/backends/azure.go | 84 ------------- runner/internal/shim/backends/azure_test.go | 33 ----- runner/internal/shim/backends/backends.go | 32 ----- runner/internal/shim/backends/gcp.go | 71 ----------- runner/internal/shim/backends/lambda.go | 82 ------------ runner/internal/shim/docker.go | 118 +++++++++++------- runner/internal/shim/docker_test.go | 19 +-- runner/internal/shim/models.go | 22 ++-- runner/internal/shim/states.go | 9 +- .../_internal/core/backends/aws/compute.py | 2 - .../_internal/core/backends/base/compute.py | 20 +-- .../core/backends/datacrunch/compute.py | 7 -- .../_internal/core/backends/gcp/compute.py | 3 - .../core/backends/lambdalabs/compute.py | 1 - .../_internal/core/backends/nebius/compute.py | 1 - .../core/backends/tensordock/compute.py | 1 - src/dstack/_internal/server/schemas/runner.py | 2 +- .../server/services/runner/client.py | 8 +- 25 files changed, 240 insertions(+), 580 deletions(-) delete mode 100644 runner/internal/shim/backends/aws.go delete mode 100644 runner/internal/shim/backends/azure.go delete mode 100644 runner/internal/shim/backends/azure_test.go delete mode 100644 runner/internal/shim/backends/backends.go delete mode 100644 runner/internal/shim/backends/gcp.go delete mode 100644 runner/internal/shim/backends/lambda.go diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go index 582af89f2..6a5e331fd 100644 --- a/runner/cmd/shim/main.go +++ 
b/runner/cmd/shim/main.go @@ -2,20 +2,22 @@ package main import ( "context" + "errors" "fmt" "log" + "net/http" "os" "path/filepath" + "time" + "github.com/dstackai/dstack/runner/consts" "github.com/dstackai/dstack/runner/internal/gerrors" "github.com/dstackai/dstack/runner/internal/shim" "github.com/dstackai/dstack/runner/internal/shim/api" - "github.com/dstackai/dstack/runner/internal/shim/backends" "github.com/urfave/cli/v2" ) func main() { - var backendName string var args shim.CLIArgs args.Docker.SSHPort = 10022 @@ -24,13 +26,6 @@ func main() { Usage: "Starts dstack-runner or docker container. Kills the VM on exit.", Version: Version, Flags: []cli.Flag{ - &cli.StringFlag{ - Name: "backend", - Usage: "Cloud backend provider", - Required: true, - Destination: &backendName, - EnvVars: []string{"DSTACK_BACKEND"}, - }, /* Shim Parameters */ &cli.PathFlag{ Name: "home", @@ -85,17 +80,6 @@ func main() { Usage: "Starts docker container and modifies entrypoint", Flags: []cli.Flag{ /* Docker Parameters */ - &cli.BoolFlag{ - Name: "with-auth", - Usage: "Waits for registry credentials", - Destination: &args.Docker.RegistryAuthRequired, - }, - &cli.StringFlag{ - Name: "image", - Usage: "Docker image name", - Destination: &args.Docker.ImageName, - EnvVars: []string{"DSTACK_IMAGE_NAME"}, - }, &cli.BoolFlag{ Name: "keep-container", Usage: "Do not delete container on exit", @@ -117,42 +101,44 @@ func main() { defer func() { _ = os.Remove(args.Runner.BinaryPath) }() } - log.Printf("Backend: %s\n", backendName) args.Runner.TempDir = "/tmp/runner" args.Runner.HomeDir = "/root" args.Runner.WorkingDir = "/workflow" var err error + + // set dstack home path args.Shim.HomeDir, err = getDstackHome(args.Shim.HomeDir) if err != nil { - return gerrors.Wrap(err) + return err } log.Printf("Docker: %+v\n", args) - server := api.NewShimServer(fmt.Sprintf(":%d", args.Shim.HTTPPort), args.Docker.RegistryAuthRequired) - return gerrors.Wrap(server.RunDocker(context.TODO(), &args)) - }, - }, - { - Name: "subprocess", - Usage: "Docker-less mode", - Action: func(c *cli.Context) error { - return gerrors.New("not implemented") + dockerRunner, err := shim.NewDockerRunner(args) + if err != nil { + return err + } + + address := fmt.Sprintf(":%d", args.Shim.HTTPPort) + shimServer := api.NewShimServer(address, dockerRunner) + + defer func() { + shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelShutdown() + _ = shimServer.HttpServer.Shutdown(shutdownCtx) + }() + + if err := shimServer.HttpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + panic(err) + } + + return nil + }, }, }, } - defer func() { - backend, err := backends.NewBackend(context.TODO(), backendName) - if err != nil { - log.Fatal(err) - } - if err = backend.Terminate(context.TODO()); err != nil { - log.Fatal(err) - } - }() - if err := app.Run(os.Args); err != nil { log.Fatal(err) } @@ -162,9 +148,10 @@ func getDstackHome(flag string) (string, error) { if flag != "" { return flag, nil } + home, err := os.UserHomeDir() if err != nil { return "", gerrors.Wrap(err) } - return filepath.Join(home, ".dstack"), nil + return filepath.Join(home, consts.DSTACK_DIR_PATH), nil } diff --git a/runner/internal/api/common.go b/runner/internal/api/common.go index 39bbe1b8c..7c4ceba8a 100644 --- a/runner/internal/api/common.go +++ b/runner/internal/api/common.go @@ -4,11 +4,12 @@ import ( "encoding/json" "errors" "fmt" - "github.com/dstackai/dstack/runner/internal/log" - 
"github.com/golang/gddo/httputil/header" "io" "net/http" "strings" + + "github.com/dstackai/dstack/runner/internal/log" + "github.com/golang/gddo/httputil/header" ) type Error struct { diff --git a/runner/internal/runner/api/http_test.go b/runner/internal/runner/api/http_test.go index c0b877068..dfbd64b32 100644 --- a/runner/internal/runner/api/http_test.go +++ b/runner/internal/runner/api/http_test.go @@ -1,4 +1,79 @@ package api -// todo test 409 on wrong requests order -// todo test submit wait timeout +import ( + "context" + "net/http/httptest" + "strings" + "testing" + + common "github.com/dstackai/dstack/runner/internal/api" + "github.com/dstackai/dstack/runner/internal/shim" + "github.com/dstackai/dstack/runner/internal/shim/api" +) + +type DummyRunner struct { + State shim.RunnerStatus +} + +func (ds DummyRunner) GetState() shim.RunnerStatus { + return ds.State +} + +func (ds DummyRunner) Run(context.Context, shim.DockerTaskConfig) error { + return nil +} + +func TestHealthcheck(t *testing.T) { + + request := httptest.NewRequest("GET", "/api/healthcheck", nil) + responseRecorder := httptest.NewRecorder() + + server := api.NewShimServer(":12345", DummyRunner{}) + + f := common.JSONResponseHandler("GET", server.HealthcheckGetHandler) + f(responseRecorder, request) + + if responseRecorder.Code != 200 { + t.Errorf("Want status '%d', got '%d'", 200, responseRecorder.Code) + } + + expected := "{\"service\":\"dstack-shim\"}" + + if strings.TrimSpace(responseRecorder.Body.String()) != expected { + t.Errorf("Want '%s', got '%s'", expected, responseRecorder.Body.String()) + } +} + +func TestSubmit(t *testing.T) { + + request := httptest.NewRequest("POST", "/api/submit", strings.NewReader("{\"image_name\":\"ubuntu\"}")) + responseRecorder := httptest.NewRecorder() + + dummyRunner := DummyRunner{} + dummyRunner.State = shim.Pending + + server := api.NewShimServer(":12340", &dummyRunner) + + firstSubmitPost := common.JSONResponseHandler("POST", server.SubmitPostHandler) + firstSubmitPost(responseRecorder, request) + + if responseRecorder.Code != 200 { + t.Errorf("Want status '%d', got '%d'", 200, responseRecorder.Code) + } + + t.Logf("%v", responseRecorder.Result()) + + dummyRunner.State = shim.Pulling + + request = httptest.NewRequest("POST", "/api/submit", strings.NewReader("{\"image_name\":\"ubuntu\"}")) + responseRecorder = httptest.NewRecorder() + + secondSubmitPost := common.JSONResponseHandler("POST", server.SubmitPostHandler) + secondSubmitPost(responseRecorder, request) + + t.Logf("%v", responseRecorder.Result()) + + if responseRecorder.Code != 409 { + t.Errorf("Want status '%d', got '%d'", 409, responseRecorder.Code) + } +} diff --git a/runner/internal/shim/api/http.go b/runner/internal/shim/api/http.go index 019ea1c93..f1915f223 100644 --- a/runner/internal/shim/api/http.go +++ b/runner/internal/shim/api/http.go @@ -1,6 +1,7 @@ package api import ( + "context" "log" "net/http" @@ -8,7 +9,7 @@ import ( "github.com/dstackai/dstack/runner/internal/shim" ) -func (s *ShimServer) healthcheckGetHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (s *ShimServer) HealthcheckGetHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { s.mu.RLock() defer s.mu.RUnlock() @@ -17,29 +18,31 @@ func (s *ShimServer) healthcheckGetHandler(w http.ResponseWriter, r *http.Reques }, nil } -func (s *ShimServer) registryAuthPostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (s *ShimServer) SubmitPostHandler(w http.ResponseWriter, r 
*http.Request) (interface{}, error) { s.mu.RLock() defer s.mu.RUnlock() - if s.state != shim.WaitRegistryAuth { + if s.runner.GetState() != shim.Pending { return nil, &api.Error{Status: http.StatusConflict} } - var body RegistryAuthBody + var body DockerTaskBody if err := api.DecodeJSONBody(w, r, &body, true); err != nil { log.Println("Failed to decode submit body", "err", err) return nil, err } - s.registryAuth = body.MakeConfig() + go func(taskParams shim.DockerImageConfig) { + s.runner.Run(context.TODO(), taskParams) + }(body.TaskParams()) return nil, nil } -func (s *ShimServer) pullGetHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (s *ShimServer) PullGetHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { s.mu.RLock() defer s.mu.RUnlock() return &PullResponse{ - State: s.state, + State: string(s.runner.GetState()), }, nil } diff --git a/runner/internal/shim/api/schemas.go b/runner/internal/shim/api/schemas.go index 8098cdfc1..274109461 100644 --- a/runner/internal/shim/api/schemas.go +++ b/runner/internal/shim/api/schemas.go @@ -2,7 +2,7 @@ package api import "github.com/dstackai/dstack/runner/internal/shim" -type RegistryAuthBody struct { +type DockerTaskBody struct { Username string `json:"username"` Password string `json:"password"` ImageName string `json:"image_name"` @@ -16,8 +16,8 @@ type PullResponse struct { State string `json:"state"` } -func (ra RegistryAuthBody) MakeConfig() shim.ImagePullConfig { - res := shim.ImagePullConfig{ +func (ra DockerTaskBody) TaskParams() shim.DockerTaskConfig { + res := shim.DockerTaskConfig{ ImageName: ra.ImageName, Username: ra.Username, Password: ra.Password, diff --git a/runner/internal/shim/api/server.go b/runner/internal/shim/api/server.go index a57a04536..f6fb77036 100644 --- a/runner/internal/shim/api/server.go +++ b/runner/internal/shim/api/server.go @@ -2,62 +2,37 @@ package api import ( "context" - "errors" "net/http" "sync" - "time" "github.com/dstackai/dstack/runner/internal/api" - "github.com/dstackai/dstack/runner/internal/gerrors" "github.com/dstackai/dstack/runner/internal/shim" ) +type TaskRunner interface { + Run(context.Context, shim.DockerImageConfig) error + GetState() shim.RunnerStatus +} + type ShimServer struct { - srv *http.Server + HttpServer *http.Server + mu sync.RWMutex - mu sync.RWMutex - registryAuth shim.ImagePullConfig - state string + runner TaskRunner } -func NewShimServer(address string, registryAuthRequired bool) *ShimServer { +func NewShimServer(address string, runner TaskRunner) *ShimServer { mux := http.NewServeMux() s := &ShimServer{ - srv: &http.Server{ + HttpServer: &http.Server{ Addr: address, Handler: mux, }, - state: shim.WaitRegistryAuth, - } - if registryAuthRequired { - mux.HandleFunc("/api/registry_auth", api.JSONResponseHandler("POST", s.registryAuthPostHandler)) + runner: runner, } - mux.HandleFunc("/api/healthcheck", api.JSONResponseHandler("GET", s.healthcheckGetHandler)) - mux.HandleFunc("/api/pull", api.JSONResponseHandler("GET", s.pullGetHandler)) + mux.HandleFunc("/api/submit", api.JSONResponseHandler("POST", s.SubmitPostHandler)) + mux.HandleFunc("/api/healthcheck", api.JSONResponseHandler("GET", s.HealthcheckGetHandler)) + mux.HandleFunc("/api/pull", api.JSONResponseHandler("GET", s.PullGetHandler)) return s } - -func (s *ShimServer) RunDocker(ctx context.Context, params shim.DockerParameters) error { - go func() { - if err := s.srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - panic(err) - } - }() - defer 
func() { - shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 5*time.Second) - defer cancelShutdown() - _ = s.srv.Shutdown(shutdownCtx) - }() - return gerrors.Wrap(shim.RunDocker(ctx, params, s)) -} - -func (s *ShimServer) GetRegistryAuth() shim.ImagePullConfig { - return s.registryAuth -} - -func (s *ShimServer) SetState(state string) { - s.mu.Lock() - defer s.mu.Unlock() - s.state = state -} diff --git a/runner/internal/shim/backends/aws.go b/runner/internal/shim/backends/aws.go deleted file mode 100644 index 2b802e275..000000000 --- a/runner/internal/shim/backends/aws.go +++ /dev/null @@ -1,74 +0,0 @@ -package backends - -import ( - "bytes" - "context" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" - "github.com/aws/aws-sdk-go-v2/service/ec2" - "github.com/dstackai/dstack/runner/internal/gerrors" - "io" -) - -type AWSBackend struct { - region string - instanceId string - spot bool -} - -func init() { - register("aws", NewAWSBackend) -} - -func NewAWSBackend(ctx context.Context) (Backend, error) { - cfg, err := config.LoadDefaultConfig(ctx) - if err != nil { - return nil, gerrors.Wrap(err) - } - - client := imds.NewFromConfig(cfg) - region, err := client.GetRegion(ctx, &imds.GetRegionInput{}) - if err != nil { - return nil, gerrors.Wrap(err) - } - lifecycle, err := getAWSMetadata(ctx, client, "instance-life-cycle") - if err != nil { - return nil, gerrors.Wrap(err) - } - instanceId, err := getAWSMetadata(ctx, client, "instance-id") - if err != nil { - return nil, gerrors.Wrap(err) - } - - return &AWSBackend{ - region: region.Region, - instanceId: instanceId, - spot: lifecycle == "spot", - }, nil -} - -func (b *AWSBackend) Terminate(ctx context.Context) error { - cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(b.region)) - if err != nil { - return gerrors.Wrap(err) - } - client := ec2.NewFromConfig(cfg) - _, err = client.TerminateInstances(ctx, &ec2.TerminateInstancesInput{ - InstanceIds: []string{b.instanceId}, - }) - return gerrors.Wrap(err) -} - -func getAWSMetadata(ctx context.Context, client *imds.Client, path string) (string, error) { - resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{ - Path: path, - }) - if err != nil { - return "", gerrors.Wrap(err) - } - var b bytes.Buffer - if _, err = io.Copy(&b, resp.Content); err != nil { - return "", err - } - return b.String(), nil -} diff --git a/runner/internal/shim/backends/azure.go b/runner/internal/shim/backends/azure.go deleted file mode 100644 index a06515e86..000000000 --- a/runner/internal/shim/backends/azure.go +++ /dev/null @@ -1,84 +0,0 @@ -package backends - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v4" - "github.com/dstackai/dstack/runner/internal/gerrors" -) - -type AzureBackend struct { - subscriptionId string - resourceGroup string - vmName string -} - -func init() { - register("azure", NewAzureBackend) -} - -func NewAzureBackend(ctx context.Context) (Backend, error) { - metadata, err := getAzureMetadata(ctx, nil) - if err != nil { - return nil, gerrors.Wrap(err) - } - return &AzureBackend{ - subscriptionId: metadata.SubscriptionId, - resourceGroup: metadata.ResourceGroupName, - vmName: metadata.Name, - }, nil -} - -func (b *AzureBackend) Terminate(ctx context.Context) error { - credential, err := azidentity.NewManagedIdentityCredential(nil) - if err != nil { - return gerrors.Wrap(err) - } 
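As a companion to the new TaskRunner-based ShimServer above, here is a rough sketch of how a caller (for example the Python shim client later in this series) drives the shim over HTTP: submit a Docker task, then poll the state machine until it returns to "pending". This is an illustration under assumptions, not part of the patch; the shim address, port, and image are placeholders, and error handling is omitted.

    import time
    import requests

    SHIM = "http://localhost:10998"  # placeholder shim address/port

    def run_task(image_name: str, username: str = "", password: str = "") -> None:
        # Matches DockerTaskBody: username/password are only needed for private registries.
        body = {"username": username, "password": password, "image_name": image_name}
        resp = requests.post(f"{SHIM}/api/submit", json=body)
        resp.raise_for_status()  # a 409 means the runner is not in the "pending" state

        # RunnerStatus transitions: pending -> pulling -> creating -> running -> pending.
        # Wait for the task to leave "pending", then for it to come back when finished.
        while requests.get(f"{SHIM}/api/pull").json()["state"] == "pending":
            time.sleep(1)
        while requests.get(f"{SHIM}/api/pull").json()["state"] != "pending":
            time.sleep(2)

    run_task("ubuntu")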
- computeClient, err := armcompute.NewVirtualMachinesClient(b.subscriptionId, credential, nil) - if err != nil { - return gerrors.Wrap(err) - } - _, err = computeClient.BeginDelete(ctx, b.resourceGroup, b.vmName, nil) - return gerrors.Wrap(err) -} - -type AzureComputeInstanceMetadata struct { - SubscriptionId string `json:"subscriptionId"` - ResourceGroupName string `json:"resourceGroupName"` - Name string `json:"name"` -} - -type AzureInstanceMetadata struct { - Compute AzureComputeInstanceMetadata `json:"compute"` -} - -func getAzureMetadata(ctx context.Context, url *string) (*AzureComputeInstanceMetadata, error) { - baseURL := "http://169.254.169.254" - if url != nil { - baseURL = *url - } - req, err := http.NewRequestWithContext( - ctx, - http.MethodGet, - fmt.Sprintf("%s/metadata/instance?api-version=2021-02-01", baseURL), - nil, - ) - if err != nil { - return nil, gerrors.Wrap(err) - } - req.Header.Add("Metadata", "true") - res, err := http.DefaultClient.Do(req) - if err != nil { - return nil, gerrors.Wrap(err) - } - decoder := json.NewDecoder(res.Body) - var metadata AzureInstanceMetadata - if err = decoder.Decode(&metadata); err != nil { - return nil, gerrors.Wrap(err) - } - return &metadata.Compute, nil -} diff --git a/runner/internal/shim/backends/azure_test.go b/runner/internal/shim/backends/azure_test.go deleted file mode 100644 index 8ea63d20d..000000000 --- a/runner/internal/shim/backends/azure_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package backends - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestGetsAzureMetadata(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte( - `{"compute": - { - "subscriptionId":"test_subscription", - "resourceGroupName":"test_group", - "name":"test_vm" - } - }`, - )) - })) - defer server.Close() - metadata, err := getAzureMetadata(context.TODO(), &server.URL) - assert.Equal(t, nil, err) - assert.Equal(t, AzureComputeInstanceMetadata{ - SubscriptionId: "test_subscription", - ResourceGroupName: "test_group", - Name: "test_vm", - }, *metadata) -} diff --git a/runner/internal/shim/backends/backends.go b/runner/internal/shim/backends/backends.go deleted file mode 100644 index 64e7d3728..000000000 --- a/runner/internal/shim/backends/backends.go +++ /dev/null @@ -1,32 +0,0 @@ -package backends - -import ( - "context" - "github.com/dstackai/dstack/runner/internal/gerrors" - "sync" -) - -type Backend interface { - Terminate(context.Context) error -} - -type BackendFactory func(ctx context.Context) (Backend, error) - -var backends = make(map[string]BackendFactory) -var mu = sync.Mutex{} - -func NewBackend(ctx context.Context, name string) (Backend, error) { - mu.Lock() - defer mu.Unlock() - factory, ok := backends[name] - if !ok { - return nil, gerrors.Newf("unknown backend %s", name) - } - return factory(ctx) -} - -func register(name string, factory BackendFactory) { - mu.Lock() - defer mu.Unlock() - backends[name] = factory -} diff --git a/runner/internal/shim/backends/gcp.go b/runner/internal/shim/backends/gcp.go deleted file mode 100644 index 57f668b52..000000000 --- a/runner/internal/shim/backends/gcp.go +++ /dev/null @@ -1,71 +0,0 @@ -package backends - -import ( - compute "cloud.google.com/go/compute/apiv1" - "cloud.google.com/go/compute/apiv1/computepb" - "context" - "fmt" - "github.com/dstackai/dstack/runner/internal/gerrors" - "io" - "net/http" - "strings" -) 
- -type GCPBackend struct { - instanceName string - project string - zone string -} - -func init() { - register("gcp", NewGCPBackend) -} - -func NewGCPBackend(ctx context.Context) (Backend, error) { - instanceName, err := getGCPMetadata(ctx, "/instance/name") - if err != nil { - return nil, gerrors.Wrap(err) - } - projectZone, err := getGCPMetadata(ctx, "/instance/zone") - if err != nil { - return nil, gerrors.Wrap(err) - } - // Parse `projects//zones/` - parts := strings.Split(projectZone, "/") - return &GCPBackend{ - instanceName: instanceName, - project: parts[1], - zone: parts[3], - }, nil -} - -func (b *GCPBackend) Terminate(ctx context.Context) error { - instancesClient, err := compute.NewInstancesRESTClient(ctx) - if err != nil { - return nil - } - req := &computepb.DeleteInstanceRequest{ - Instance: b.instanceName, - Project: b.project, - Zone: b.zone, - } - _, err = instancesClient.Delete(ctx, req) - return gerrors.Wrap(err) -} - -func getGCPMetadata(ctx context.Context, path string) (string, error) { - req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://metadata.google.internal/computeMetadata/v1%s", path), nil) - if err != nil { - return "", gerrors.Wrap(err) - } - req.Header.Add("Metadata-Flavor", "Google") - res, err := http.DefaultClient.Do(req.WithContext(ctx)) - if err != nil { - return "", gerrors.Wrap(err) - } - body, err := io.ReadAll(res.Body) - if err != nil { - return "", gerrors.Wrap(err) - } - return string(body), nil -} diff --git a/runner/internal/shim/backends/lambda.go b/runner/internal/shim/backends/lambda.go deleted file mode 100644 index d2a6325e4..000000000 --- a/runner/internal/shim/backends/lambda.go +++ /dev/null @@ -1,82 +0,0 @@ -package backends - -import ( - "bytes" - "context" - "encoding/json" - "net/http" - "os" - - "github.com/dstackai/dstack/runner/internal/gerrors" -) - -const LAMBDA_API_URL = "https://cloud.lambdalabs.com/api/v1" - -type LambdaAPIClient struct { - apiKey string -} - -type TerminateInstanceRequest struct { - InstanceIDs []string `json:"instance_ids"` -} - -func NewLambdaAPIClient(apiKey string) *LambdaAPIClient { - return &LambdaAPIClient{apiKey: apiKey} -} - -func (client *LambdaAPIClient) TerminateInstance(ctx context.Context, instanceIDs []string) error { - body, err := json.Marshal(TerminateInstanceRequest{InstanceIDs: instanceIDs}) - if err != nil { - return gerrors.Wrap(err) - } - req, err := http.NewRequest("POST", LAMBDA_API_URL+"/instance-operations/terminate", bytes.NewReader(body)) - if err != nil { - return gerrors.Wrap(err) - } - req.Header.Add("Authorization", "Bearer "+client.apiKey) - httpClient := http.Client{} - resp, err := httpClient.Do(req) - if err != nil { - return gerrors.Wrap(err) - } - if resp.StatusCode == 200 { - return nil - } - return gerrors.Newf("/instance-operations/terminate returned non-200 status code: %s", resp.Status) -} - -const LAMBDA_CONFIG_PATH = "/home/ubuntu/.dstack/config.json" - -type LambdaConfig struct { - InstanceID string `json:"instance_id"` - ApiKey string `json:"api_key"` -} - -type LambdaBackend struct { - apiClient *LambdaAPIClient - config LambdaConfig -} - -func init() { - register("lambda", NewLambdaBackend) -} - -func NewLambdaBackend(ctx context.Context) (Backend, error) { - config := LambdaConfig{} - fileContent, err := os.ReadFile(LAMBDA_CONFIG_PATH) - if err != nil { - return nil, gerrors.Wrap(err) - } - err = json.Unmarshal(fileContent, &config) - if err != nil { - return nil, gerrors.Wrap(err) - } - return &LambdaBackend{ - apiClient: 
NewLambdaAPIClient(config.ApiKey), - config: config, - }, nil -} - -func (b *LambdaBackend) Terminate(ctx context.Context) error { - return gerrors.Wrap(b.apiClient.TerminateInstance(ctx, []string{b.config.InstanceID})) -} diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 19232c8ca..3fe325ef5 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -22,96 +22,134 @@ import ( "github.com/dstackai/dstack/runner/internal/gerrors" ) -func RunDocker(ctx context.Context, params DockerParameters, serverAPI APIAdapter) error { +type DockerRunner struct { + client *docker.Client + dockerParams DockerParameters + state RunnerStatus +} + +func NewDockerRunner(dockerParams DockerParameters) (*DockerRunner, error) { client, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation()) if err != nil { - return err + return nil, err } - log.Println("Waiting for registry auth") - registryAuth := serverAPI.GetRegistryAuth() - serverAPI.SetState(Pulling) + runner := &DockerRunner{ + client: client, + dockerParams: dockerParams, + state: Pending, + } + return runner, nil +} + +func (d *DockerRunner) Run(ctx context.Context, cfg DockerImageConfig) error { + var err error log.Println("Pulling image") - imageName := params.DockerImageName() - if imageName == "" { - imageName = registryAuth.ImageName - } - if err = pullImage(ctx, client, imageName, registryAuth); err != nil { - return gerrors.Wrap(err) + d.state = Pulling + if err = pullImage(ctx, d.client, cfg); err != nil { + d.state = Pending + fmt.Printf("pullImage error: %s\n", err.Error()) + return err } log.Println("Creating container") - containerID, err := createContainer(ctx, client, params) + d.state = Creating + containerID, err := createContainer(ctx, d.client, d.dockerParams, cfg) if err != nil { - return gerrors.Wrap(err) + d.state = Pending + fmt.Printf("createContainer error: %s\n", err.Error()) + return err } - if !params.DockerKeepContainer() { + + if !d.dockerParams.DockerKeepContainer() { defer func() { log.Println("Deleting container") - _ = client.ContainerRemove(ctx, containerID, types.ContainerRemoveOptions{Force: true}) + err := d.client.ContainerRemove(ctx, containerID, types.ContainerRemoveOptions{Force: true}) + if err != nil { + log.Printf("ContainerRemove error: %s\n", err.Error()) + } }() } - serverAPI.SetState(Running) log.Printf("Running container, id=%s\n", containerID) - if err = runContainer(ctx, client, containerID); err != nil { - return gerrors.Wrap(err) + d.state = Running + if err = runContainer(ctx, d.client, containerID); err != nil { + d.state = Pending + fmt.Printf("runContainer error: %s\n", err.Error()) + return err } - log.Println("Container finished successfully") + + log.Printf("Container finished successfully, id=%s\n", containerID) + + d.state = Pending return nil } -func pullImage(ctx context.Context, client docker.APIClient, imageName string, imagePullConfig ImagePullConfig) error { - if !strings.Contains(imageName, ":") { - imageName += ":latest" +func (d DockerRunner) GetState() RunnerStatus { + return d.state +} + +func pullImage(ctx context.Context, client docker.APIClient, taskParams DockerImageConfig) error { + if !strings.Contains(taskParams.ImageName, ":") { + taskParams.ImageName += ":latest" } images, err := client.ImageList(ctx, types.ImageListOptions{ - Filters: filters.NewArgs(filters.Arg("reference", imageName)), + Filters: filters.NewArgs(filters.Arg("reference", taskParams.ImageName)), }) if err != nil { return 
gerrors.Wrap(err) } - if len(images) > 0 { + + // TODO: force pull latset + if len(images) > 0 && !strings.Contains(taskParams.ImageName, ":latest") { return nil } opts := types.ImagePullOptions{} - regAuth, _ := imagePullConfig.EncodeRegistryAuth() + regAuth, _ := taskParams.EncodeRegistryAuth() if regAuth != "" { opts.RegistryAuth = regAuth } - reader, err := client.ImagePull(ctx, imageName, opts) // todo test registry auth + reader, err := client.ImagePull(ctx, taskParams.ImageName, opts) // todo test registry auth if err != nil { return gerrors.Wrap(err) } defer func() { _ = reader.Close() }() - _, err = io.ReadAll(reader) - return gerrors.Wrap(err) + _, err = io.Copy(io.Discard, reader) + if err != nil { + return gerrors.Wrap(err) + } + + // {"status":"Pulling from clickhouse/clickhouse-server","id":"latest"} + // {"status":"Digest: sha256:2ff5796c67e8d588273a5f3f84184b9cdaa39a324bcf74abd3652d818d755f8c"} + // {"status":"Status: Downloaded newer image for clickhouse/clickhouse-server:latest"} + + return nil } -func createContainer(ctx context.Context, client docker.APIClient, params DockerParameters) (string, error) { +func createContainer(ctx context.Context, client docker.APIClient, dockerParams DockerParameters, taskParams DockerImageConfig) (string, error) { runtime, err := getRuntime(ctx, client) if err != nil { return "", gerrors.Wrap(err) } - mounts, err := params.DockerMounts() + mounts, err := dockerParams.DockerMounts() if err != nil { return "", gerrors.Wrap(err) } containerConfig := &container.Config{ - Image: params.DockerImageName(), - Cmd: []string{strings.Join(params.DockerShellCommands(), " && ")}, + Image: taskParams.ImageName, + Cmd: []string{strings.Join(dockerParams.DockerShellCommands(), " && ")}, Entrypoint: []string{"/bin/sh", "-c"}, - ExposedPorts: exposePorts(params.DockerPorts()...), + ExposedPorts: exposePorts(dockerParams.DockerPorts()...), } hostConfig := &container.HostConfig{ NetworkMode: getNetworkMode(), - PortBindings: bindPorts(params.DockerPorts()...), + PortBindings: bindPorts(dockerParams.DockerPorts()...), PublishAllPorts: true, Sysctls: map[string]string{}, Runtime: runtime, @@ -204,21 +242,17 @@ func getRuntime(ctx context.Context, client docker.APIClient) (string, error) { /* DockerParameters interface implementation for CLIArgs */ -func (c *CLIArgs) DockerImageName() string { - return c.Docker.ImageName -} - -func (c *CLIArgs) DockerKeepContainer() bool { +func (c CLIArgs) DockerKeepContainer() bool { return c.Docker.KeepContainer } -func (c *CLIArgs) DockerShellCommands() []string { +func (c CLIArgs) DockerShellCommands() []string { commands := getSSHShellCommands(c.Docker.SSHPort, c.Docker.PublicSSHKey) commands = append(commands, fmt.Sprintf("%s %s", DstackRunnerBinaryName, strings.Join(c.getRunnerArgs(), " "))) return commands } -func (c *CLIArgs) DockerMounts() ([]mount.Mount, error) { +func (c CLIArgs) DockerMounts() ([]mount.Mount, error) { runnerTemp := filepath.Join(c.Shim.HomeDir, "runners", time.Now().Format("20060102-150405")) if err := os.MkdirAll(runnerTemp, 0755); err != nil { return nil, gerrors.Wrap(err) @@ -238,6 +272,6 @@ func (c *CLIArgs) DockerMounts() ([]mount.Mount, error) { }, nil } -func (c *CLIArgs) DockerPorts() []int { +func (c CLIArgs) DockerPorts() []int { return []int{c.Runner.HTTPPort, c.Docker.SSHPort} } diff --git a/runner/internal/shim/docker_test.go b/runner/internal/shim/docker_test.go index b321224da..374e57455 100644 --- a/runner/internal/shim/docker_test.go +++ b/runner/internal/shim/docker_test.go @@ 
-31,7 +31,8 @@ func TestDocker_SSHServer(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) defer cancel() - assert.NoError(t, RunDocker(ctx, params, &apiAdapterMock{})) + dockerRunner, _ := NewDockerRunner(params) + assert.NoError(t, dockerRunner.Run(ctx, DockerTaskConfig{ImageName: "ubuntu"})) } // TestDocker_SSHServerConnect pulls ubuntu image (without sshd), installs openssh-server and tries to connect via SSH @@ -56,11 +57,13 @@ func TestDocker_SSHServerConnect(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) defer cancel() + dockerRunner, _ := NewDockerRunner(params) + var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - assert.NoError(t, RunDocker(ctx, params, &apiAdapterMock{})) + assert.NoError(t, dockerRunner.Run(ctx, DockerTaskConfig{ImageName: "ubuntu"})) }() for i := 0; i < timeout; i++ { @@ -89,10 +92,6 @@ type dockerParametersMock struct { publicSSHKey string } -func (c *dockerParametersMock) DockerImageName() string { - return "ubuntu" -} - func (c *dockerParametersMock) DockerKeepContainer() bool { return false } @@ -114,14 +113,6 @@ func (c *dockerParametersMock) DockerMounts() ([]mount.Mount, error) { return nil, nil } -type apiAdapterMock struct{} - -func (s *apiAdapterMock) GetRegistryAuth() ImagePullConfig { - return ImagePullConfig{} -} - -func (s *apiAdapterMock) SetState(string) {} - /* Utilities */ var portNumber int32 = 10000 diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index b03d4963a..913587e52 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -9,13 +9,7 @@ import ( "github.com/docker/docker/api/types/registry" ) -type APIAdapter interface { - GetRegistryAuth() ImagePullConfig - SetState(string) -} - type DockerParameters interface { - DockerImageName() string DockerKeepContainer() bool DockerShellCommands() []string DockerMounts() ([]mount.Mount, error) @@ -40,21 +34,23 @@ type CLIArgs struct { } Docker struct { - SSHPort int - RegistryAuthRequired bool - ImageName string - KeepContainer bool - PublicSSHKey string + SSHPort int + KeepContainer bool + PublicSSHKey string } } -type ImagePullConfig struct { +type DockerImageConfig struct { Username string Password string ImageName string } -func (ra ImagePullConfig) EncodeRegistryAuth() (string, error) { +func (ra DockerImageConfig) EncodeRegistryAuth() (string, error) { + if ra.Username == "" && ra.Password == "" { + return "", nil + } + authConfig := registry.AuthConfig{ Username: ra.Username, Password: ra.Password, diff --git a/runner/internal/shim/states.go b/runner/internal/shim/states.go index eedc05c85..e12f66041 100644 --- a/runner/internal/shim/states.go +++ b/runner/internal/shim/states.go @@ -1,7 +1,10 @@ package shim +type RunnerStatus string + const ( - WaitRegistryAuth = "waiting_for_registry_auth" - Pulling = "pulling" - Running = "running" + Pending RunnerStatus = "pending" + Pulling RunnerStatus = "pulling" + Creating RunnerStatus = "creating" + Running RunnerStatus = "running" ) diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index 88af9a5a5..d36b3f426 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -134,7 +134,6 @@ def create_instance( run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip(), ], - registry_auth_required=job.job_spec.registry_auth is not None, 
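A side note on EncodeRegistryAuth in the models.go hunk above: when credentials are present it produces the standard Docker X-Registry-Auth value, i.e. URL-safe base64 over a small JSON object, and it now returns an empty string when both fields are blank. The following Python sketch illustrates the equivalent encoding; it is for illustration only and assumes only the username/password fields matter here.

    import base64
    import json

    def encode_registry_auth(username: str, password: str) -> str:
        # Mirrors shim.DockerImageConfig.EncodeRegistryAuth: empty when no credentials are given.
        if not username and not password:
            return ""
        payload = json.dumps({"username": username, "password": password})
        return base64.urlsafe_b64encode(payload.encode()).decode()

    print(repr(encode_registry_auth("", "")))       # ''
    print(encode_registry_auth("alice", "s3cret"))  # value passed as ImagePullOptions.RegistryAuth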
), tags=tags, security_group_id=aws_resources.create_security_group(ec2_client, project_id), @@ -199,7 +198,6 @@ def run_job( run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip(), ], - registry_auth_required=job.job_spec.registry_auth is not None, ), tags=tags, security_group_id=aws_resources.create_security_group(ec2_client, project_id), diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index c8e8c1b19..32ddb3e1c 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -94,17 +94,11 @@ def get_instance_name(run: Run, job: Job) -> str: def get_user_data( - backend: BackendType, - image_name: str, authorized_keys: List[str], - registry_auth_required: bool, cloud_config_kwargs: Optional[dict] = None, ) -> str: commands = get_shim_commands( - backend=backend, - image_name=image_name, authorized_keys=authorized_keys, - registry_auth_required=registry_auth_required, ) return get_cloud_config( runcmd=[["sh", "-c", " && ".join(commands)]], @@ -114,24 +108,19 @@ def get_user_data( def get_shim_commands( - backend: BackendType, - image_name: str, authorized_keys: List[str], - registry_auth_required: bool, ) -> List[str]: build = get_dstack_runner_version() env = { - "DSTACK_BACKEND": backend.value, "DSTACK_RUNNER_LOG_LEVEL": "6", "DSTACK_RUNNER_VERSION": build, - "DSTACK_IMAGE_NAME": image_name, "DSTACK_PUBLIC_SSH_KEY": "\n".join(authorized_keys), "DSTACK_HOME": "/root/.dstack", } commands = get_dstack_shim(build) for k, v in env.items(): commands += [f'export "{k}={v}"'] - commands += get_run_shim_script(registry_auth_required) + commands += get_run_shim_script() return commands @@ -156,12 +145,9 @@ def get_dstack_shim(build: str) -> List[str]: ] -def get_run_shim_script(registry_auth_required: bool) -> List[str]: +def get_run_shim_script() -> List[str]: dev_flag = "" if settings.DSTACK_VERSION is not None else "--dev" - with_auth_flag = "--with-auth" if registry_auth_required else "" - return [ - f"nohup dstack-shim {dev_flag} docker {with_auth_flag} --keep-container >/root/shim.log 2>&1 &" - ] + return [f"nohup dstack-shim {dev_flag} docker --keep-container >/root/shim.log 2>&1 &"] def get_gateway_user_data(authorized_key: str) -> str: diff --git a/src/dstack/_internal/core/backends/datacrunch/compute.py b/src/dstack/_internal/core/backends/datacrunch/compute.py index 0e0c3acc8..4de57863c 100644 --- a/src/dstack/_internal/core/backends/datacrunch/compute.py +++ b/src/dstack/_internal/core/backends/datacrunch/compute.py @@ -78,12 +78,8 @@ def create_instance( ) ) - registry_auth_required = instance_config.job_docker_config.registry_auth is not None commands = get_shim_commands( - backend=BackendType.DATACRUNCH, - image_name=instance_config.job_docker_config.image.image, authorized_keys=public_keys, - registry_auth_required=registry_auth_required, ) startup_script = " ".join([" && ".join(commands)]) @@ -147,13 +143,10 @@ def run_job( ) commands = get_shim_commands( - backend=BackendType.DATACRUNCH, - image_name=job.job_spec.image_name, authorized_keys=[ run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip(), ], - registry_auth_required=job.job_spec.registry_auth is not None, ) startup_script = " ".join([" && ".join(commands)]) diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index 15b0e6765..3b17ec3a9 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ 
b/src/dstack/_internal/core/backends/gcp/compute.py @@ -98,7 +98,6 @@ def create_instance( ) disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) - registry_auth_required = instance_config.job_docker_config.registry_auth is not None for zone in _get_instance_zones(instance_offer): request = compute_v1.InsertInstanceRequest() request.zone = zone @@ -119,7 +118,6 @@ def create_instance( backend=BackendType.GCP, image_name=instance_config.job_docker_config.image.image, authorized_keys=instance_config.get_public_keys(), - registry_auth_required=registry_auth_required, ), labels={ "owner": "dstack", @@ -190,7 +188,6 @@ def run_job( run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip(), ], - registry_auth_required=job.job_spec.registry_auth is not None, ), labels={ "owner": "dstack", diff --git a/src/dstack/_internal/core/backends/lambdalabs/compute.py b/src/dstack/_internal/core/backends/lambdalabs/compute.py index 88847543e..5b7dd56b8 100644 --- a/src/dstack/_internal/core/backends/lambdalabs/compute.py +++ b/src/dstack/_internal/core/backends/lambdalabs/compute.py @@ -63,7 +63,6 @@ def run_job( run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip(), ], - registry_auth_required=job.job_spec.registry_auth is not None, ) # shim is asssumed to be run under root launch_command = "sudo sh -c '" + "&& ".join(commands) + "'" diff --git a/src/dstack/_internal/core/backends/nebius/compute.py b/src/dstack/_internal/core/backends/nebius/compute.py index 9360065f9..243a673cf 100644 --- a/src/dstack/_internal/core/backends/nebius/compute.py +++ b/src/dstack/_internal/core/backends/nebius/compute.py @@ -83,7 +83,6 @@ def run_job( run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip(), ], - registry_auth_required=job.job_spec.registry_auth is not None, ), }, disk_size_gb=disk_size, diff --git a/src/dstack/_internal/core/backends/tensordock/compute.py b/src/dstack/_internal/core/backends/tensordock/compute.py index 2d296dca0..b22dc6dee 100644 --- a/src/dstack/_internal/core/backends/tensordock/compute.py +++ b/src/dstack/_internal/core/backends/tensordock/compute.py @@ -56,7 +56,6 @@ def run_job( run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip(), ], - registry_auth_required=job.job_spec.registry_auth is not None, ) try: resp = self.api_client.deploy_single( diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index 1c457e1ea..e47d653fb 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -65,7 +65,7 @@ class HealthcheckResponse(BaseModel): service: str -class RegistryAuthBody(BaseModel): +class DockerImageBody(BaseModel): username: str password: str image_name: str diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index 992f65a3f..8000ffac8 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -9,9 +9,9 @@ from dstack._internal.core.models.repos.remote import RemoteRepoCreds from dstack._internal.core.models.runs import JobSpec, RunSpec from dstack._internal.server.schemas.runner import ( + DockerImageBody, HealthcheckResponse, PullResponse, - RegistryAuthBody, SubmitBody, ) @@ -99,10 +99,10 @@ def healthcheck(self) -> Optional[HealthcheckResponse]: except requests.exceptions.RequestException: return None - def registry_auth(self, username: str, password: str, image_name: str): + def 
submit(self, username: str, password: str, image_name: str): resp = requests.post( - self._url("/api/registry_auth"), - json=RegistryAuthBody( + self._url("/api/submit"), + json=DockerImageBody( username=username, password=password, image_name=image_name ).dict(), ) From 7f3316cc457dfac5643b2701f4975feaa5ebec9e Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Tue, 16 Jan 2024 07:51:02 +0300 Subject: [PATCH 06/47] update --- runner/cmd/runner/cmd.go | 5 +- runner/cmd/runner/main.go | 29 ++- runner/cmd/shim/main.go | 12 +- runner/cmd/shim/version.go | 2 +- runner/go.mod | 41 +--- runner/go.sum | 190 +----------------- runner/internal/runner/api/http_test.go | 2 +- runner/internal/shim/api/http.go | 6 +- runner/internal/shim/api/schemas.go | 4 +- runner/internal/shim/docker.go | 25 +-- runner/internal/shim/docker_test.go | 4 +- src/dstack/_internal/cli/commands/pool.py | 4 +- .../_internal/core/backends/azure/compute.py | 1 - .../_internal/core/backends/base/compute.py | 15 +- .../core/backends/datacrunch/compute.py | 20 ++ src/dstack/_internal/core/models/runs.py | 1 + .../background/tasks/process_finished_jobs.py | 49 ++++- .../background/tasks/process_running_jobs.py | 22 +- .../tasks/process_submitted_jobs.py | 14 +- ...add_pools.py => 73a959f64596_add_pools.py} | 24 +-- src/dstack/_internal/server/models.py | 4 + src/dstack/_internal/server/routers/pools.py | 4 +- src/dstack/_internal/server/routers/runs.py | 2 +- src/dstack/_internal/server/schemas/runs.py | 2 +- src/dstack/_internal/server/services/pools.py | 14 ++ src/dstack/_internal/server/services/runs.py | 13 +- src/dstack/api/server/_pools.py | 4 +- src/dstack/api/server/_runs.py | 2 +- 28 files changed, 207 insertions(+), 308 deletions(-) rename src/dstack/_internal/server/migrations/versions/{ec4dbadbab3c_add_pools.py => 73a959f64596_add_pools.py} (89%) diff --git a/runner/cmd/runner/cmd.go b/runner/cmd/runner/cmd.go index ca99ee258..a22f36360 100644 --- a/runner/cmd/runner/cmd.go +++ b/runner/cmd/runner/cmd.go @@ -56,7 +56,10 @@ func App() { }, }, Action: func(c *cli.Context) error { - start(paths.tempDir, paths.homeDir, paths.workingDir, httpPort, logLevel) + err := start(paths.tempDir, paths.homeDir, paths.workingDir, httpPort, logLevel) + if err != nil { + return cli.Exit(err, 1) + } return nil }, }, diff --git a/runner/cmd/runner/main.go b/runner/cmd/runner/main.go index 61142c079..974062484 100644 --- a/runner/cmd/runner/main.go +++ b/runner/cmd/runner/main.go @@ -3,30 +3,37 @@ package main import ( "context" "fmt" - "github.com/dstackai/dstack/runner/internal/log" - "github.com/dstackai/dstack/runner/internal/runner/api" - "github.com/sirupsen/logrus" "io" _ "net/http/pprof" "os" "path/filepath" + + "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/runner/api" + "github.com/sirupsen/logrus" + "github.com/ztrue/tracerr" ) func main() { App() } -func start(tempDir string, homeDir string, workingDir string, httpPort int, logLevel int) { +func start(tempDir string, homeDir string, workingDir string, httpPort int, logLevel int) error { if err := os.MkdirAll(tempDir, 0755); err != nil { - log.Error(context.TODO(), "Failed to create temp directory", "err", err) - os.Exit(1) + return tracerr.Errorf("Failed to create temp directory^ %w", err) } + defaultLogFile, err := log.CreateAppendFile(filepath.Join(tempDir, "default.log")) if err != nil { - log.Error(context.TODO(), "Failed to create default log file", "err", err) - os.Exit(1) + return tracerr.Errorf("Failed to create default 
log file: %w", err) } - defer func() { _ = defaultLogFile.Close() }() + defer func() { + err = defaultLogFile.Close() + if err != nil { + tracerr.Print(err) + } + }() + log.DefaultEntry.Logger.SetOutput(io.MultiWriter(os.Stdout, defaultLogFile)) log.DefaultEntry.Logger.SetLevel(logrus.Level(logLevel)) @@ -34,6 +41,8 @@ func start(tempDir string, homeDir string, workingDir string, httpPort int, logL log.Trace(context.TODO(), "Starting API server", "port", httpPort) if err := server.Run(); err != nil { - log.Error(context.TODO(), "Server failed", "err", err) + return tracerr.Errorf("Server failed: %w", err) } + + return nil } diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go index 6a5e331fd..7ab960775 100644 --- a/runner/cmd/shim/main.go +++ b/runner/cmd/shim/main.go @@ -11,7 +11,6 @@ import ( "time" "github.com/dstackai/dstack/runner/consts" - "github.com/dstackai/dstack/runner/internal/gerrors" "github.com/dstackai/dstack/runner/internal/shim" "github.com/dstackai/dstack/runner/internal/shim/api" "github.com/urfave/cli/v2" @@ -96,7 +95,7 @@ func main() { Action: func(c *cli.Context) error { if args.Runner.BinaryPath == "" { if err := args.Download("linux"); err != nil { - return gerrors.Wrap(err) + return cli.Exit(err, 1) } defer func() { _ = os.Remove(args.Runner.BinaryPath) }() } @@ -110,13 +109,13 @@ func main() { // set dstack home path args.Shim.HomeDir, err = getDstackHome(args.Shim.HomeDir) if err != nil { - return err + return cli.Exit(err, 1) } log.Printf("Docker: %+v\n", args) dockerRunner, err := shim.NewDockerRunner(args) if err != nil { - return err + return cli.Exit(err, 1) } address := fmt.Sprintf(":%d", args.Shim.HTTPPort) @@ -129,11 +128,10 @@ func main() { }() if err := shimServer.HttpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - panic(err) + return cli.Exit(err, 1) } return nil - }, }, }, @@ -151,7 +149,7 @@ func getDstackHome(flag string) (string, error) { home, err := os.UserHomeDir() if err != nil { - return "", gerrors.Wrap(err) + return "", err } return filepath.Join(home, consts.DSTACK_DIR_PATH), nil } diff --git a/runner/cmd/shim/version.go b/runner/cmd/shim/version.go index c2dfda93c..7aa1d0aae 100644 --- a/runner/cmd/shim/version.go +++ b/runner/cmd/shim/version.go @@ -1,3 +1,3 @@ package main -var Version = "0.0.0dev1" +var Version = "0.0.0dev2" diff --git a/runner/go.mod b/runner/go.mod index 2771f9714..83f89288f 100644 --- a/runner/go.mod +++ b/runner/go.mod @@ -3,12 +3,6 @@ module github.com/dstackai/dstack/runner go 1.19 require ( - cloud.google.com/go/compute v1.23.0 - github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.1 - github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v4 v4.2.1 - github.com/aws/aws-sdk-go-v2/config v1.18.39 - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.11 - github.com/aws/aws-sdk-go-v2/service/ec2 v1.118.0 github.com/bluekeyes/go-gitdiff v0.6.0 github.com/creack/pty v1.1.18 github.com/docker/docker v24.0.6+incompatible @@ -18,28 +12,15 @@ require ( github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.1 github.com/urfave/cli/v2 v2.25.7 + github.com/ztrue/tracerr v0.4.0 golang.org/x/crypto v0.14.0 ) require ( - cloud.google.com/go/compute/metadata v0.2.3 // indirect dario.cat/mergo v1.0.0 // indirect - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1 // indirect - github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 // indirect - github.com/AzureAD/microsoft-authentication-library-for-go v1.1.1 // indirect github.com/Microsoft/go-winio 
v0.6.1 // indirect github.com/ProtonMail/go-crypto v0.0.0-20230717121422-5aa5874ade95 // indirect github.com/acomagu/bufpipe v1.0.4 // indirect - github.com/aws/aws-sdk-go-v2 v1.21.0 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.13.37 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35 // indirect - github.com/aws/aws-sdk-go-v2/internal/ini v1.3.42 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.35 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.13.6 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.15.6 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.21.5 // indirect - github.com/aws/smithy-go v1.14.2 // indirect github.com/cloudflare/circl v1.3.3 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect @@ -49,27 +30,18 @@ require ( github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect github.com/go-git/go-billy/v5 v5.4.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang-jwt/jwt/v5 v5.0.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect - github.com/golang/protobuf v1.5.3 // indirect - github.com/google/s2a-go v0.1.4 // indirect - github.com/google/uuid v1.3.0 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect - github.com/googleapis/gax-go/v2 v2.11.0 // indirect github.com/h2non/filetype v1.1.3 // indirect github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect - github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/juju/errors v0.0.0-20181118221551-089d3ea4e4d5 // indirect github.com/juju/loggo v1.0.0 // indirect github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/klauspost/compress v1.15.13 // indirect - github.com/kylelemons/godebug v1.1.0 // indirect github.com/moby/term v0.5.0 // indirect github.com/morikuni/aec v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.0.2 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect - github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect @@ -78,22 +50,13 @@ require ( github.com/ulikunitz/xz v0.5.11 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect - go.opencensus.io v0.24.0 // indirect golang.org/x/mod v0.13.0 // indirect golang.org/x/net v0.16.0 // indirect - golang.org/x/oauth2 v0.8.0 // indirect golang.org/x/sys v0.13.0 // indirect - golang.org/x/text v0.13.0 // indirect golang.org/x/tools v0.14.0 // indirect - google.golang.org/api v0.126.0 // indirect - google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20230530153820-e85fd2cbaebc // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc // indirect - google.golang.org/grpc v1.55.0 // indirect - google.golang.org/protobuf v1.30.0 // indirect gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect gotest.tools/v3 v3.5.0 // indirect ) diff --git a/runner/go.sum b/runner/go.sum index 
97afdac9f..6ac28bc9c 100644 --- a/runner/go.sum +++ b/runner/go.sum @@ -1,27 +1,7 @@ cloud.google.com/go v0.16.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.110.2 h1:sdFPBr6xG9/wkBbfhmUz/JmZC7X6LavQgcrVINrKiVA= -cloud.google.com/go/compute v1.23.0 h1:tP41Zoavr8ptEqaW6j+LQOnyBBhO7OkOMAGrgLopTwY= -cloud.google.com/go/compute v1.23.0/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM= -cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= -cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1 h1:/iHxaJhsFr0+xVFfbMr5vxz848jyiWuIEDhYq3y5odY= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.1 h1:LNHhpdK7hzUcx/k1LIcuh5k7k1LGIWLQfCjaneSj7Fc= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.1/go.mod h1:uE9zaUfEQT/nbQjVi2IblCG9iaLtZsuYZ8ne+PuQ02M= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInmMgOsuGwdjjVkEIde0OtY= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= -github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v4 v4.2.1 h1:UPeCRD+XY7QlaGQte2EVI2iOcWvUYA2XY8w5T/8v0NQ= -github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v4 v4.2.1/go.mod h1:oGV6NlB0cvi1ZbYRR2UN44QHxWFyGk+iylgD0qaMXjA= -github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal v1.1.2 h1:mLY+pNLjCUeKhgnAJWAKhEUQM+RJQo2H1fuGSw1Ky1E= -github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork v1.0.0 h1:nBy98uKOIfun5z6wx6jwWLrULcM0+cjBalBFZlEZ7CA= -github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.0.0 h1:ECsQtyERDVz3NP3kvDOTLvbQhqWp/x9EsGKtb4ogUr8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= -github.com/AzureAD/microsoft-authentication-library-for-go v1.1.1 h1:WpB/QDNLpMw72xHJc34BNNykqSOeEJDAWkhf0u12/Jk= -github.com/AzureAD/microsoft-authentication-library-for-go v1.1.1/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= @@ -31,50 +11,14 @@ github.com/ProtonMail/go-crypto v0.0.0-20230717121422-5aa5874ade95/go.mod h1:EjA github.com/acomagu/bufpipe v1.0.4 h1:e3H4WUzM3npvo5uv95QuJM3cQspFNtFBzvJ2oNjKIDQ= github.com/acomagu/bufpipe v1.0.4/go.mod h1:mxdxdup/WdsKVreO5GpW4+M/1CE2sMG4jeGJ2sYmHc4= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= -github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/arduino/go-paths-helper v1.2.0 h1:qDW93PR5IZUN/jzO4rCtexiwF8P4OIcOmcSgAYLZfY4= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= -github.com/aws/aws-sdk-go-v2 v1.21.0 
h1:gMT0IW+03wtYJhRqTVYn0wLzwdnK9sRMcxmtfGzRdJc= -github.com/aws/aws-sdk-go-v2 v1.21.0/go.mod h1:/RfNgGmRxI+iFOB1OeJUyxiU+9s88k3pfHvDagGEp0M= -github.com/aws/aws-sdk-go-v2/config v1.18.39 h1:oPVyh6fuu/u4OiW4qcuQyEtk7U7uuNBmHmJSLg1AJsQ= -github.com/aws/aws-sdk-go-v2/config v1.18.39/go.mod h1:+NH/ZigdPckFpgB1TRcRuWCB/Kbbvkxc/iNAKTq5RhE= -github.com/aws/aws-sdk-go-v2/credentials v1.13.37 h1:BvEdm09+ZEh2XtN+PVHPcYwKY3wIeB6pw7vPRM4M9/U= -github.com/aws/aws-sdk-go-v2/credentials v1.13.37/go.mod h1:ACLrdkd4CLZyXOghZ8IYumQbcooAcp2jo/s2xsFH8IM= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.11 h1:uDZJF1hu0EVT/4bogChk8DyjSF6fof6uL/0Y26Ma7Fg= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.11/go.mod h1:TEPP4tENqBGO99KwVpV9MlOX4NSrSLP8u3KRy2CDwA8= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41 h1:22dGT7PneFMx4+b3pz7lMTRyN8ZKH7M2cW4GP9yUS2g= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41/go.mod h1:CrObHAuPneJBlfEJ5T3szXOUkLEThaGfvnhTf33buas= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35 h1:SijA0mgjV8E+8G45ltVHs0fvKpTj8xmZJ3VwhGKtUSI= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35/go.mod h1:SJC1nEVVva1g3pHAIdCp7QsRIkMmLAgoDquQ9Rr8kYw= -github.com/aws/aws-sdk-go-v2/internal/ini v1.3.42 h1:GPUcE/Yq7Ur8YSUk6lVkoIMWnJNO0HT18GUzCWCgCI0= -github.com/aws/aws-sdk-go-v2/internal/ini v1.3.42/go.mod h1:rzfdUlfA+jdgLDmPKjd3Chq9V7LVLYo1Nz++Wb91aRo= -github.com/aws/aws-sdk-go-v2/service/ec2 v1.118.0 h1:ueSJS07XpOwCFhYTHh/Jjw856+U+u0Dv5LIIPOB1/Ns= -github.com/aws/aws-sdk-go-v2/service/ec2 v1.118.0/go.mod h1:0FhI2Rzcv5BNM3dNnbcCx2qa2naFZoAidJi11cQgzL0= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.35 h1:CdzPW9kKitgIiLV1+MHobfR5Xg25iYnyzWZhyQuSlDI= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.35/go.mod h1:QGF2Rs33W5MaN9gYdEQOBBFPLwTZkEhRwI33f7KIG0o= -github.com/aws/aws-sdk-go-v2/service/sso v1.13.6 h1:2PylFCfKCEDv6PeSN09pC/VUiRd10wi1VfHG5FrW0/g= -github.com/aws/aws-sdk-go-v2/service/sso v1.13.6/go.mod h1:fIAwKQKBFu90pBxx07BFOMJLpRUGu8VOzLJakeY+0K4= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.15.6 h1:pSB560BbVj9ZlJZF4WYj5zsytWHWKxg+NgyGV4B2L58= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.15.6/go.mod h1:yygr8ACQRY2PrEcy3xsUI357stq2AxnFM6DIsR9lij4= -github.com/aws/aws-sdk-go-v2/service/sts v1.21.5 h1:CQBFElb0LS8RojMJlxRSo/HXipvTZW2S44Lt9Mk2aYQ= -github.com/aws/aws-sdk-go-v2/service/sts v1.21.5/go.mod h1:VC7JDqsqiwXukYEDjoHh9U0fOJtNWh04FPQz4ct4GGU= -github.com/aws/smithy-go v1.14.2 h1:MJU9hqBGbvWZdApzpvoF2WAIJDbtjK2NDJSiJP7HblQ= -github.com/aws/smithy-go v1.14.2/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= github.com/bluekeyes/go-gitdiff v0.6.0 h1:zyDBSR/o1axUl4lD08EWkXO3I834tBimmGUB0mhrvhQ= github.com/bluekeyes/go-gitdiff v0.6.0/go.mod h1:QpfYYO1E0fTVHVZAZKiRjtSGY9823iCdvGXBcEzHGbM= github.com/bradfitz/gomemcache v0.0.0-20170208213004-1952afaa557d/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= github.com/bwesterb/go-ristretto v1.2.3/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudflare/circl v1.3.3 h1:fE/Qz0QdIGqeWfnwq0RE0R7MI51s0M2E4Ga9kq5AEMs= github.com/cloudflare/circl v1.3.3/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= -github.com/cncf/udpa/go 
v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XPnfI9Z40ddWsdw2W/uZgQLFXToKeRcDiI= -github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= -github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= -github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/codeclysm/extract/v3 v3.1.0 h1:z14FpkRizce3HNHsqJoZWwj0ovzZ2hiIkmT96FQS3j8= github.com/codeclysm/extract/v3 v3.1.0/go.mod h1:ZJi80UG2JtfHqJI+lgJSCACttZi++dHxfWuPaMhlOfQ= github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= @@ -85,7 +29,6 @@ github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8= github.com/docker/distribution v2.8.2+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= github.com/docker/docker v24.0.6+incompatible h1:hceabKCtUgDqPu+qm0NgsaXf28Ljf4/pWFL7xjWWDgE= @@ -97,15 +40,8 @@ github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDD github.com/elazarl/goproxy v0.0.0-20221015165544-a0805db90819 h1:RIB4cRk+lBqKK3Oy0r2gRX4ui7tuhiZq2SuTtTCi0/0= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= -github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= -github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fsnotify/fsnotify v1.4.3-0.20170329110642-4da3e2cfbabc/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/garyburd/redigo v1.1.1-0.20170914051019-70e1b1943d4f/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= -github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gliderlabs/ssh v0.3.5 h1:OcaySEmAQJgyYcArR+gGGTHCyE7nvhEMTlYY+Dp8CpY= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= @@ -117,67 +53,25 @@ github.com/go-git/go-git/v5 v5.8.1/go.mod h1:FHFuoD6yGz5OSKEBK+aWN9Oah0q54Jxl0ab github.com/go-stack/stack v1.6.0/go.mod 
h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= -github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f h1:16RtHeWGkJMc80Etb8RPCcKevXGldr57+LOyZt8zOlg= github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f/go.mod h1:ijRvpgDJDI262hYq/IQVYgf8hd8IHUs93Ol0kvMBAx4= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/lint v0.0.0-20170918230701-e5d664eb928e/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.1.1-0.20171103154506-982329095285/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 
-github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/s2a-go v0.1.4 h1:1kZ/sQM3srePvKs3tXAvQzo66XfcReoqFpIpIccE7Oc= -github.com/google/s2a-go v0.1.4/go.mod h1:Ej+mSEMGRnqRzjc7VtF+jdBwYG5fuJfiZ8ELkjEwM0A= -github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.2.3 h1:yk9/cqRKtT9wXZSsRH9aurXEpJX+U6FLtpYTdC3R06k= -github.com/googleapis/enterprise-certificate-proxy v0.2.3/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k= github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= -github.com/googleapis/gax-go/v2 v2.11.0 h1:9V9PWXEsWnPpQhu/PeQIkS4eGzMlTLGgt80cUUI8Ki4= -github.com/googleapis/gax-go/v2 v2.11.0/go.mod h1:DxmR61SGKkGLa2xigwuZIQpkCI2S5iydzRfb3peWZJI= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gregjones/httpcache v0.0.0-20170920190843-316c5e0ff04e/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= -github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/h2non/filetype v1.1.3 h1:FKkx9QbD7HR/zjK1Ia5XiBsq9zdLi5Kf3zGyFTAFkGg= github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY= github.com/hashicorp/hcl v0.0.0-20170914154624-68e816d1c783/go.mod h1:oZtUIOe8dh44I2q6ScRibXws4Ajl+d+nod3AaR9vL5w= github.com/inconshreveable/log15 v0.0.0-20170622235902-74a0988b5f80/go.mod h1:cOaXtrgN4ScfRrD9Bre7U1thNq5RtJ8ZoP4iXVGRj6o= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= -github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= -github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= -github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= -github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a/go.mod h1:UJSiEoRfvx3hP73CvoARgeLjaIOjybY9vj8PUPPFGeU= github.com/juju/errors v0.0.0-20181118221551-089d3ea4e4d5 h1:rhqTjzJlm7EbkELJDKMTU7udov+Se0xZkWmugr6zGok= github.com/juju/errors v0.0.0-20181118221551-089d3ea4e4d5/go.mod h1:W54LbzXuIE0boCoNJfwqpmkKJ1O4TCTZMetAt6jGk7Q= @@ -198,8 +92,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= -github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lunixbochs/vtclean v0.0.0-20160125035106-4fbf7632a2c6/go.mod 
h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/magiconair/properties v1.7.4-0.20170902060319-8d7837e64d3c/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/matryer/is v1.2.0 h1:92UTHpy8CDwaJ08GqLDzhhuixiBUUD1p3AU6PHddz4A= @@ -221,14 +113,10 @@ github.com/opencontainers/image-spec v1.0.2/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zM github.com/pelletier/go-toml v1.0.1-0.20170904195809-1d6b12b7cb29/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pjbgf/sha1cd v0.3.0 h1:4D5XXmUUBUl/xQ6IjCkEAbqXskkq/4O7LmGn0AqMDs4= github.com/pjbgf/sha1cd v0.3.0/go.mod h1:nZ1rrWOcGJ5uZgEEVL1VUM9iRQiZvWdbZjkKyFzPPsI= -github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= -github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -249,7 +137,6 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= @@ -266,41 +153,28 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsr github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= -go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= -go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= +github.com/ztrue/tracerr v0.4.0 h1:vT5PFxwIGs7rCg9ZgJ/y0NmOpJkPCPFK8x0vVIYzd04= +github.com/ztrue/tracerr v0.4.0/go.mod h1:PaFfYlas0DfmXNpo7Eay4MFhZUONqvXM+T2HyGPpngk= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto 
v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.3.1-0.20221117191849-2c476679df9a/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY= golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -310,31 +184,21 @@ golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.16.0 h1:7eBu7KsSvFDtSXUIDbh3aqlK4DPsZ1rByC8PFfBThos= golang.org/x/net v0.16.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/oauth2 v0.0.0-20170912212905-13449ad91cb2/go.mod 
h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.8.0 h1:6dkIjl3j3LtZ/O3sTgZTMsLKSftL/B8Zgq4huOIIUu8= -golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= golang.org/x/sync v0.0.0-20170517211232-f52d1811a629/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -355,19 +219,13 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text 
v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20170424234030-8be79e1e0910 h1:bCMaBn7ph495H+x72gEvgcv+mDRd9dElbzo/mVCMxX4= golang.org/x/time v0.0.0-20170424234030-8be79e1e0910/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= @@ -380,48 +238,9 @@ golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.0.0-20170921000349-586095a6e407/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= -google.golang.org/api v0.126.0 h1:q4GJq+cAdMAC7XP7njvQ4tvohGLiSlytuL4BQxbIZ+o= -google.golang.org/api v0.126.0/go.mod h1:mBwVAtz+87bEN6CbA1GtZPDOqY2R5ONPqJeIlvyo4Aw= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= -google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20170918111702-1e559d0a00ee/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20230530153820-e85fd2cbaebc h1:8DyZCyvI8mE1IdLy/60bS+52xfymkE72wv1asokgtao= -google.golang.org/genproto v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:xZnkP7mREFX5MORlOPEzLMr+90PPZQ2QWzrVTWfAq64= -google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc h1:kVKPf/IiYSBWEWtkIn6wZXwWGCnLKcC8oWfZvXjsGnM= -google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:vHYtlOoi6TsQ3Uk2yxR7NI5z8uoV+3pZtR4jmHIkRig= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc 
h1:XSJ8Vk1SWuNr8S18z1NZSziL0CPIXLCCMDOEFtHBOFc= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA= google.golang.org/grpc v1.2.1-0.20170921194603-d4b75ebd4f9f/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= -google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= -google.golang.org/grpc v1.45.0/go.mod h1:lN7owxKUQEqMfSyQikvvk5tf/6zMPsrK+ONuO11+0rQ= -google.golang.org/grpc v1.55.0 h1:3Oj82/tFSCeUrRTg/5E/7d/W5A1tj6Ky1ABAuZuv5ag= -google.golang.org/grpc v1.55.0/go.mod h1:iYEXKGkEBhg1PjZQvoYEVPTDkHo1/bjTnfwTeGONTY8= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20160105164936-4f90aeace3a2/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -432,14 +251,11 @@ gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22/go.mod h1:yeKp02qBN3iKW1OzL3M gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 
v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY= gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/runner/internal/runner/api/http_test.go b/runner/internal/runner/api/http_test.go index dfbd64b32..dbb8ddf42 100644 --- a/runner/internal/runner/api/http_test.go +++ b/runner/internal/runner/api/http_test.go @@ -19,7 +19,7 @@ func (ds DummyRunner) GetState() shim.RunnerStatus { return ds.State } -func (ds DummyRunner) Run(context.Context, shim.DockerTaskConfig) error { +func (ds DummyRunner) Run(context.Context, shim.DockerImageConfig) error { return nil } diff --git a/runner/internal/shim/api/http.go b/runner/internal/shim/api/http.go index f1915f223..1936d14ed 100644 --- a/runner/internal/shim/api/http.go +++ b/runner/internal/shim/api/http.go @@ -2,6 +2,7 @@ package api import ( "context" + "fmt" "log" "net/http" @@ -32,7 +33,10 @@ func (s *ShimServer) SubmitPostHandler(w http.ResponseWriter, r *http.Request) ( } go func(taskParams shim.DockerImageConfig) { - s.runner.Run(context.TODO(), taskParams) + err := s.runner.Run(context.TODO(), taskParams) + if err != nil { + fmt.Printf("failed Run %v", err) + } }(body.TaskParams()) return nil, nil diff --git a/runner/internal/shim/api/schemas.go b/runner/internal/shim/api/schemas.go index 274109461..a6fefe441 100644 --- a/runner/internal/shim/api/schemas.go +++ b/runner/internal/shim/api/schemas.go @@ -16,8 +16,8 @@ type PullResponse struct { State string `json:"state"` } -func (ra DockerTaskBody) TaskParams() shim.DockerTaskConfig { - res := shim.DockerTaskConfig{ +func (ra DockerTaskBody) TaskParams() shim.DockerImageConfig { + res := shim.DockerImageConfig{ ImageName: ra.ImageName, Username: ra.Username, Password: ra.Password, diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 3fe325ef5..50531c434 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -19,7 +19,8 @@ import ( docker "github.com/docker/docker/client" "github.com/docker/go-connections/nat" "github.com/dstackai/dstack/runner/consts" - "github.com/dstackai/dstack/runner/internal/gerrors" + + "github.com/ztrue/tracerr" ) type DockerRunner struct { @@ -31,7 +32,7 @@ type DockerRunner struct { func NewDockerRunner(dockerParams DockerParameters) (*DockerRunner, error) { client, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation()) if err != nil { - return nil, err + return nil, tracerr.Wrap(err) } runner := &DockerRunner{ @@ -98,7 +99,7 @@ func pullImage(ctx context.Context, client docker.APIClient, taskParams DockerIm Filters: filters.NewArgs(filters.Arg("reference", taskParams.ImageName)), }) if err != nil { - return gerrors.Wrap(err) + return tracerr.Wrap(err) } // TODO: force pull latset @@ -114,13 +115,13 @@ func pullImage(ctx context.Context, client docker.APIClient, taskParams DockerIm reader, err := client.ImagePull(ctx, taskParams.ImageName, opts) // todo test registry auth if err != nil { - return gerrors.Wrap(err) + return tracerr.Wrap(err) } defer func() { _ = 
reader.Close() }() _, err = io.Copy(io.Discard, reader) if err != nil { - return gerrors.Wrap(err) + return tracerr.Wrap(err) } // {"status":"Pulling from clickhouse/clickhouse-server","id":"latest"} @@ -133,12 +134,12 @@ func pullImage(ctx context.Context, client docker.APIClient, taskParams DockerIm func createContainer(ctx context.Context, client docker.APIClient, dockerParams DockerParameters, taskParams DockerImageConfig) (string, error) { runtime, err := getRuntime(ctx, client) if err != nil { - return "", gerrors.Wrap(err) + return "", tracerr.Wrap(err) } mounts, err := dockerParams.DockerMounts() if err != nil { - return "", gerrors.Wrap(err) + return "", tracerr.Wrap(err) } containerConfig := &container.Config{ @@ -157,20 +158,20 @@ func createContainer(ctx context.Context, client docker.APIClient, dockerParams } resp, err := client.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, "") if err != nil { - return "", gerrors.Wrap(err) + return "", tracerr.Wrap(err) } return resp.ID, nil } func runContainer(ctx context.Context, client docker.APIClient, containerID string) error { if err := client.ContainerStart(ctx, containerID, types.ContainerStartOptions{}); err != nil { - return gerrors.Wrap(err) + return tracerr.Wrap(err) } waitCh, errorCh := client.ContainerWait(ctx, containerID, "") select { case <-waitCh: case err := <-errorCh: - return gerrors.Wrap(err) + return tracerr.Wrap(err) } return nil } @@ -230,7 +231,7 @@ func getNetworkMode() container.NetworkMode { func getRuntime(ctx context.Context, client docker.APIClient) (string, error) { info, err := client.Info(ctx) if err != nil { - return "", gerrors.Wrap(err) + return "", tracerr.Wrap(err) } for name := range info.Runtimes { if name == consts.NVIDIA_RUNTIME { @@ -255,7 +256,7 @@ func (c CLIArgs) DockerShellCommands() []string { func (c CLIArgs) DockerMounts() ([]mount.Mount, error) { runnerTemp := filepath.Join(c.Shim.HomeDir, "runners", time.Now().Format("20060102-150405")) if err := os.MkdirAll(runnerTemp, 0755); err != nil { - return nil, gerrors.Wrap(err) + return nil, tracerr.Wrap(err) } return []mount.Mount{ diff --git a/runner/internal/shim/docker_test.go b/runner/internal/shim/docker_test.go index 374e57455..60f29a3c9 100644 --- a/runner/internal/shim/docker_test.go +++ b/runner/internal/shim/docker_test.go @@ -32,7 +32,7 @@ func TestDocker_SSHServer(t *testing.T) { defer cancel() dockerRunner, _ := NewDockerRunner(params) - assert.NoError(t, dockerRunner.Run(ctx, DockerTaskConfig{ImageName: "ubuntu"})) + assert.NoError(t, dockerRunner.Run(ctx, DockerImageConfig{ImageName: "ubuntu"})) } // TestDocker_SSHServerConnect pulls ubuntu image (without sshd), installs openssh-server and tries to connect via SSH @@ -63,7 +63,7 @@ func TestDocker_SSHServerConnect(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - assert.NoError(t, dockerRunner.Run(ctx, DockerTaskConfig{ImageName: "ubuntu"})) + assert.NoError(t, dockerRunner.Run(ctx, DockerImageConfig{ImageName: "ubuntu"})) }() for i := 0; i < timeout; i++ { diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 677f6f05b..26b7a23b6 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -259,7 +259,7 @@ def _add(self, args: argparse.Namespace): profile.pool_name = pool_name - with console.status("Getting run plan..."): + with console.status("Getting instances..."): requirements, offers = self.api.runs.get_offers(profile) print(pool_name, profile, requirements, 
offers) @@ -269,7 +269,7 @@ def _add(self, args: argparse.Namespace): return try: - with console.status("Submitting run..."): + with console.status("Submitting instance..."): self.api.runs.create_instance(pool_name, profile) except ServerClientError as e: raise CLIError(e.msg) diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py index fdd6edda3..2e5c8c9f4 100644 --- a/src/dstack/_internal/core/backends/azure/compute.py +++ b/src/dstack/_internal/core/backends/azure/compute.py @@ -136,7 +136,6 @@ def run_job( backend=BackendType.AZURE, image_name=job.job_spec.image_name, authorized_keys=ssh_pub_keys, - registry_auth_required=job.job_spec.registry_auth is not None, ), ssh_pub_keys=ssh_pub_keys, spot=instance_offer.instance.resources.spot, diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 32ddb3e1c..478f7f558 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -139,8 +139,13 @@ def get_dstack_shim(build: str) -> List[str]: if settings.DSTACK_VERSION is not None: bucket = "dstack-runner-downloads" + url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64" + + if os.getenv("DEV_DSTACK_RUNNER", None) is not None: + url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-shim" + return [ - f'sudo curl --output /usr/local/bin/dstack-shim "https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64"', + f'sudo curl --output /usr/local/bin/dstack-shim "{url}"', "sudo chmod +x /usr/local/bin/dstack-shim", ] @@ -205,8 +210,14 @@ def get_docker_commands(authorized_keys: List[str]) -> List[str]: bucket = "dstack-runner-downloads-stgn" if settings.DSTACK_VERSION is not None: bucket = "dstack-runner-downloads" + + url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64" + + if os.getenv("DEV_DSTACK_RUNNER", None) is not None: + url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-runner" + commands += [ - f'curl --output {runner} "https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64"', + f"curl --output {runner} {url}", f"chmod +x {runner}", f"{runner} --log-level 6 start --http-port 10999 --temp-dir /tmp/runner --home-dir /root --working-dir /workflow", ] diff --git a/src/dstack/_internal/core/backends/datacrunch/compute.py b/src/dstack/_internal/core/backends/datacrunch/compute.py index 4de57863c..f7960f86a 100644 --- a/src/dstack/_internal/core/backends/datacrunch/compute.py +++ b/src/dstack/_internal/core/backends/datacrunch/compute.py @@ -15,6 +15,9 @@ ) from dstack._internal.core.models.runs import Job, Requirements, Run from dstack._internal.server.models import ProjectModel, UserModel +from dstack._internal.utils.logging import get_logger + +logger = get_logger("datacrunch.compute") class DataCrunchCompute(Compute): @@ -84,6 +87,9 @@ def create_instance( startup_script = " ".join([" && ".join(commands)]) script_name = f"dstack-{instance_config.instance_name}.sh" + + logger.debug("startup script:", startup_script) + startup_script_ids = self.api_client.get_or_create_startup_scrpit( name=script_name, script=startup_script ) @@ -104,6 +110,20 @@ def create_instance( location=instance_offer.region, ) + logger.debug( + "deploy_instance", + { + "instance_type": instance_offer.instance.name, + "ssh_key_ids": ssh_ids, + 
"startup_script_id": startup_script_ids, + "hostname": instance_config.instance_name, + "description": instance_config.instance_name, + "image": image_name, + "disk_size": disk_size, + "location": instance_offer.region, + }, + ) + running_instance = self.api_client.wait_for_instance(instance.id) if running_instance is None: raise BackendError(f"Wait instance {instance.id!r} timeout") diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index d5edc3ef4..9c22de472 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -127,6 +127,7 @@ class JobProvisioningData(BaseModel): backend: BackendType instance_type: InstanceType instance_id: str + pool_id: str hostname: str region: str price: float diff --git a/src/dstack/_internal/server/background/tasks/process_finished_jobs.py b/src/dstack/_internal/server/background/tasks/process_finished_jobs.py index 2eacbe8cb..15621e6c0 100644 --- a/src/dstack/_internal/server/background/tasks/process_finished_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_finished_jobs.py @@ -1,9 +1,11 @@ from sqlalchemy import or_, select from sqlalchemy.orm import joinedload -from dstack._internal.core.models.runs import JobSpec, JobStatus +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.profiles import TerminationPolicy +from dstack._internal.core.models.runs import InstanceStatus, JobSpec, JobStatus from dstack._internal.server.db import get_session_ctx -from dstack._internal.server.models import GatewayModel, JobModel +from dstack._internal.server.models import GatewayModel, InstanceModel, JobModel from dstack._internal.server.services.gateways import gateway_connections_pool from dstack._internal.server.services.jobs import ( TERMINATING_PROCESSING_JOBS_IDS, @@ -12,6 +14,7 @@ terminate_job_submission_instance, ) from dstack._internal.server.services.logging import job_log +from dstack._internal.server.services.pools import get_instances_by_pool_id from dstack._internal.server.utils.common import run_async from dstack._internal.utils.common import get_current_datetime from dstack._internal.utils.logging import get_logger @@ -31,7 +34,7 @@ async def process_finished_jobs(): or_(JobModel.remove_at.is_(None), JobModel.remove_at < get_current_datetime()), ) .order_by(JobModel.last_processed_at.asc()) - .limit(1) # TODO(egor-s) process multiple at once + .limit(1) ) job_model = res.scalar() if job_model is None: @@ -39,6 +42,7 @@ async def process_finished_jobs(): TERMINATING_PROCESSING_JOBS_IDS.add(job_model.id) try: await _process_job(job_id=job_model.id) + await _terminate_old_instance() finally: TERMINATING_PROCESSING_JOBS_IDS.remove(job_model.id) @@ -78,11 +82,19 @@ async def _process_job(job_id): except Exception as e: logger.warning("failed to unregister service: %s", e) try: - if job_submission.job_provisioning_data is not None: - await terminate_job_submission_instance( - project=job_model.project, - job_submission=job_submission, - ) + jpd = job_submission.job_provisioning_data + if jpd is not None: + if jpd.backend == BackendType.LOCAL: + instances = await get_instances_by_pool_id(session, jpd.pool_id) + for instance in instances: + if instance.name == jpd.instance_id: + instance.finished_at = get_current_datetime() + instance.status = InstanceStatus.READY + else: + await terminate_job_submission_instance( + project=job_model.project, + job_submission=job_submission, + ) job_model.removed = True 
logger.info(*job_log("marked as removed", job_model)) except Exception as e: @@ -90,3 +102,24 @@ async def _process_job(job_id): logger.error(*job_log("failed to terminate job instance: %s", job_model, e)) job_model.last_processed_at = get_current_datetime() await session.commit() + + +async def _terminate_old_instance(): + async with get_session_ctx() as session: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE, + InstanceModel.deleted == False, + ) + .options() + ) + instances = res.scalars().all() + + for instance in instances: + if instance.finished_at + instance.termination_idle_time > get_current_datetime(): + await terminate_job_submission_instance( + project=instance.project, + job_submission=job_submission, + ) + await session.commit() diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 97472c763..9858501a6 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -151,13 +151,13 @@ async def _process_job(job_id: UUID): repo_creds, ) - if success: - instance_name: str = job_provisioning_data.instance_id - pool_name = str(job.job_spec.pool_name) - instances = await get_pool_instances(session, project, pool_name) - for inst in instances: - if inst.name == instance_name: - inst.status = InstanceStatus.BUSY + if success: + instance_name: str = job_provisioning_data.instance_id + pool_name = str(job.job_spec.pool_name) + instances = await get_pool_instances(session, project, pool_name) + for inst in instances: + if inst.name == instance_name: + inst.status = InstanceStatus.BUSY if not success: # check timeout if job_submission.age > _get_runner_timeout_interval( @@ -333,25 +333,29 @@ def _process_provisioning_with_shim( is successful """ job_spec = parse_raw_as(JobSpec, job_model.job_spec_data) + shim_client = client.ShimClient(port=ports[client.REMOTE_SHIM_PORT]) + resp = shim_client.healthcheck() if resp is None: logger.debug(*job_log("shim is not available yet", job_model)) return False # shim is not available yet + if registry_auth is not None: logger.debug(*job_log("authenticating to the registry...", job_model)) interpolate = VariablesInterpolator({"secrets": secrets}).interpolate - shim_client.registry_auth( + shim_client.submit( username=interpolate(registry_auth.username), password=interpolate(registry_auth.password), image_name=job_spec.image_name, ) else: - shim_client.registry_auth( + shim_client.submit( username="", password="", image_name=job_spec.image_name, ) + job_model.status = JobStatus.PULLING logger.info(*job_log("now is pulling", job_model)) return True diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index d76c0ff25..0d8850ca8 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -13,7 +13,12 @@ InstanceOfferWithAvailability, LaunchedInstanceInfo, ) -from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, CreationPolicy, Profile +from dstack._internal.core.models.profiles import ( + DEFAULT_POOL_NAME, + CreationPolicy, + Profile, + TerminationPolicy, +) from dstack._internal.core.models.runs import ( InstanceStatus, Job, @@ -136,7 +141,7 @@ 
async def _process_submitted_job(session: AsyncSession, job_model: JobModel): # pool capacity pool_instances = await get_pool_instances(session, project_model, run_pool) - available_instanses = (p for p in pool_instances if p.status == InstanceStatus.PENDING) + available_instanses = (p for p in pool_instances if p.status == InstanceStatus.READY) relevant_instances: List[InstanceModel] = [] for instance in available_instanses: if check_relevance(profile, instance): @@ -181,6 +186,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): ) if job_provisioning_data is not None and offer is not None: logger.info(*job_log("now is provisioning", job_model)) + job_provisioning_data.pool_id = str(pool.id) job_model.job_provisioning_data = job_provisioning_data.json() job_model.status = JobStatus.PROVISIONING @@ -188,9 +194,11 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): name=job.job_spec.job_name, project=project_model, pool=pool, - status=InstanceStatus.PENDING, + status=InstanceStatus.CREATING, job_provisioning_data=job_provisioning_data.json(), offer=offer.json(), + termination_policy=profile.termination_policy, + termination_idle_time=profile.termination_idle_time, ) session.add(im) diff --git a/src/dstack/_internal/server/migrations/versions/ec4dbadbab3c_add_pools.py b/src/dstack/_internal/server/migrations/versions/73a959f64596_add_pools.py similarity index 89% rename from src/dstack/_internal/server/migrations/versions/ec4dbadbab3c_add_pools.py rename to src/dstack/_internal/server/migrations/versions/73a959f64596_add_pools.py index ab6bc30b3..ce1cb6119 100644 --- a/src/dstack/_internal/server/migrations/versions/ec4dbadbab3c_add_pools.py +++ b/src/dstack/_internal/server/migrations/versions/73a959f64596_add_pools.py @@ -1,8 +1,8 @@ """add pools -Revision ID: ec4dbadbab3c -Revises: 48ad3ecbaea2 -Create Date: 2024-01-10 07:56:08.754541 +Revision ID: 73a959f64596 +Revises: d3e8af4786fa +Create Date: 2024-01-16 09:57:28.183650 """ import sqlalchemy as sa @@ -10,8 +10,8 @@ from alembic import op # revision identifiers, used by Alembic. 
-revision = "ec4dbadbab3c" -down_revision = "48ad3ecbaea2" +revision = "73a959f64596" +down_revision = "d3e8af4786fa" branch_labels = None depends_on = None @@ -52,22 +52,22 @@ def upgrade() -> None: "status", sa.Enum( "PENDING", - "SUBMITTED", - "PROVISIONING", - "PULLING", - "RUNNING", + "CREATING", + "STARTING", + "READY", + "BUSY", "TERMINATING", "TERMINATED", - "ABORTED", "FAILED", - "DONE", - name="jobstatus", + name="instancestatus", ), nullable=False, ), sa.Column("status_message", sa.String(length=50), nullable=True), sa.Column("started_at", sa.DateTime(), nullable=True), sa.Column("finished_at", sa.DateTime(), nullable=True), + sa.Column("termination_policy", sa.String(length=50), nullable=False), + sa.Column("termination_idle_time", sa.String(length=50), nullable=True), sa.Column("job_provisioning_data", sa.String(length=4000), nullable=False), sa.Column("offer", sa.String(length=4000), nullable=False), sa.ForeignKeyConstraint( diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index be5d3d253..bb1222ff3 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -19,6 +19,7 @@ from sqlalchemy_utils import UUIDType from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.profiles import TerminationPolicy from dstack._internal.core.models.repos.base import RepoType from dstack._internal.core.models.runs import InstanceStatus, JobErrorCode, JobStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole @@ -278,5 +279,8 @@ class InstanceModel(BaseModel): started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) + termination_policy: Mapped[TerminationPolicy] = mapped_column(String(50)) + termination_idle_time: Mapped[Optional[str]] = mapped_column(String(50)) + job_provisioning_data: Mapped[str] = mapped_column(String(4000)) offer: Mapped[str] = mapped_column(String(4000)) diff --git a/src/dstack/_internal/server/routers/pools.py b/src/dstack/_internal/server/routers/pools.py index 5f8ad1d94..5508e7b9d 100644 --- a/src/dstack/_internal/server/routers/pools.py +++ b/src/dstack/_internal/server/routers/pools.py @@ -8,7 +8,7 @@ import dstack._internal.server.services.pools as pools from dstack._internal.server.db import get_session from dstack._internal.server.models import ProjectModel, UserModel -from dstack._internal.server.schemas.runs import AddInstanceRequest +from dstack._internal.server.schemas.runs import AddRemoteInstanceRequest from dstack._internal.server.security.permissions import ProjectAdmin, ProjectMember from dstack._internal.server.services.runs import ( abort_runs_of_pool, @@ -82,7 +82,7 @@ async def how_pool( @router.post("/add") async def add_instance( - body: AddInstanceRequest, + body: AddRemoteInstanceRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), ): diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index 3106e0767..395cdaa93 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -8,7 +8,7 @@ from dstack._internal.server.db import get_session from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.runs import ( - AddInstanceRequest, + 
AddRemoteInstanceRequest, CreateInstanceRequest, DeleteRunsRequest, GetOffersRequest, diff --git a/src/dstack/_internal/server/schemas/runs.py b/src/dstack/_internal/server/schemas/runs.py index 4483f62ae..bfa095e73 100644 --- a/src/dstack/_internal/server/schemas/runs.py +++ b/src/dstack/_internal/server/schemas/runs.py @@ -28,7 +28,7 @@ class CreateInstanceRequest(BaseModel): profile: Profile -class AddInstanceRequest(BaseModel): +class AddRemoteInstanceRequest(BaseModel): pool_name: str instance_name: Optional[str] host: str diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index eef3d6d86..a5e1dc217 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -154,6 +154,20 @@ async def get_pool_instances( return result.instances +async def get_instances_by_pool_id(session, pool_id: str) -> List[InstanceModel]: + res = await session.execute( + select(PoolModel) + .where( + PoolModel.id == pool_id, + ) + .options(joinedload(PoolModel.instances)) + ) + result = res.unique().scalars().one_or_none() + if result is None: + return [] + return result.instances + + _GENERATE_POOL_NAME_LOCK = {} diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index fde727146..6241e46a4 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -224,6 +224,7 @@ async def create_instance( backend=backend.TYPE, instance_type=instance_offer.instance, instance_id=launched_instance_info.instance_id, + pool_id=pool_name, hostname=launched_instance_info.ip_address, region=launched_instance_info.region, price=instance_offer.price, @@ -291,6 +292,7 @@ async def submit_run( backends = await backends_services.get_project_backends(project) if len(backends) == 0: raise ServerClientError("No backends configured") + if run_spec.run_name is None: run_spec.run_name = await _generate_run_name( session=session, @@ -304,7 +306,11 @@ async def submit_run( ) # create pool - pools = (await session.scalars(select(PoolModel).where(PoolModel.name == pool_name))).all() + pools = ( + await session.scalars( + select(PoolModel).where(PoolModel.name == pool_name, PoolModel.deleted == False) + ) + ).all() if not pools: await create_pool_model(session, project, pool_name) @@ -319,6 +325,7 @@ async def submit_run( run_spec=run_spec.json(), ) session.add(run_model) + jobs = get_jobs_from_run_spec(run_spec) if run_spec.configuration.type == "service": await gateways.register_service_jobs(session, project, run_spec.run_name, jobs) @@ -332,6 +339,7 @@ async def submit_run( session.add(job_model) await session.commit() await session.refresh(run_model) + run = run_model_to_run(run_model) return run @@ -428,10 +436,13 @@ def run_model_to_run(run_model: RunModel, include_job_submissions: bool = True) submissions.append(job_model_to_job_submission(job_model)) if job_spec is not None: jobs.append(Job(job_spec=job_spec, job_submissions=submissions)) + run_spec = RunSpec.parse_raw(run_model.run_spec) + latest_job_submission = None if include_job_submissions: latest_job_submission = jobs[0].job_submissions[-1] + run = Run( id=run_model.id, project_name=run_model.project.name, diff --git a/src/dstack/api/server/_pools.py b/src/dstack/api/server/_pools.py index 751379f46..9035a5011 100644 --- a/src/dstack/api/server/_pools.py +++ b/src/dstack/api/server/_pools.py @@ -4,7 +4,7 @@ import dstack._internal.server.schemas.pools as schemas_pools from 
dstack._internal.core.models.pools import Instance, Pool -from dstack._internal.server.schemas.runs import AddInstanceRequest +from dstack._internal.server.schemas.runs import AddRemoteInstanceRequest from dstack.api.server._group import APIClientGroup @@ -29,7 +29,7 @@ def show(self, project_name: str, pool_name: str) -> List[Instance]: def add( self, project_name: str, pool_name: str, instance_name: Optional[str], host: str, port: str ): - body = AddInstanceRequest( + body = AddRemoteInstanceRequest( pool_name=pool_name, instance_name=instance_name, host=host, port=port ) self._request(f"/api/project/{project_name}/pool/add", body=body.json()) diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index d5d7260b6..a3aecb353 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -6,7 +6,7 @@ from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.runs import Requirements, Run, RunPlan, RunSpec from dstack._internal.server.schemas.runs import ( - AddInstanceRequest, + AddRemoteInstanceRequest, CreateInstanceRequest, DeleteRunsRequest, GetOffersRequest, From 7e1fc3ad40a418bb46182c81e11a93c7cd150886 Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Tue, 23 Jan 2024 10:58:21 +0400 Subject: [PATCH 07/47] update --- .pre-commit-config.yaml | 12 ++ ruff.toml | 15 +++ src/dstack/_internal/cli/commands/pool.py | 72 ++++++++--- .../core/backends/datacrunch/api_client.py | 27 ++-- .../core/backends/datacrunch/compute.py | 3 +- src/dstack/_internal/core/models/pools.py | 4 +- src/dstack/_internal/core/models/resources.py | 4 +- src/dstack/_internal/core/models/runs.py | 4 +- .../server/background/tasks/process_pools.py | 77 +++++++++++- .../background/tasks/process_running_jobs.py | 9 ++ .../tasks/process_submitted_jobs.py | 59 ++++++--- ...add_pools.py => 6a084acc1211_add_pools.py} | 9 +- src/dstack/_internal/server/models.py | 4 +- src/dstack/_internal/server/routers/pools.py | 32 +++-- src/dstack/_internal/server/routers/runs.py | 12 +- src/dstack/_internal/server/schemas/pools.py | 11 +- src/dstack/_internal/server/schemas/runs.py | 4 +- .../services/backends/configurators/aws.py | 2 +- src/dstack/_internal/server/services/pools.py | 69 ++++++++-- .../server/services/runner/client.py | 7 +- src/dstack/_internal/server/services/runs.py | 118 +++++++++++------- src/dstack/_internal/server/testing/common.py | 53 +++++++- src/dstack/api/_public/pools.py | 21 ++-- src/dstack/api/_public/runs.py | 13 +- src/dstack/api/server/_pools.py | 18 ++- src/dstack/api/server/_runs.py | 16 ++- .../tasks/test_process_running_jobs.py | 21 ++-- .../tasks/test_process_submitted_jobs.py | 77 +++++++++++- .../tasks/test_process_terminating_jobs.py | 4 +- .../_internal/server/services/test_pools.py | 99 ++++++++++++++- 30 files changed, 689 insertions(+), 187 deletions(-) create mode 100644 ruff.toml rename src/dstack/_internal/server/migrations/versions/{73a959f64596_add_pools.py => 6a084acc1211_add_pools.py} (95%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d0fb5afc5..4966e262e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,8 @@ repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.14 + hooks: + - id: ruff - repo: https://github.com/psf/black rev: 22.12.0 hooks: @@ -10,3 +14,11 @@ repos: - id: isort name: isort (python) args: ['--settings-file', 'pyconfig.toml'] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.8.0 + hooks: + - id: mypy + 
args: ['--strict', '--follow-imports=skip', '--ignore-missing-imports', '--python-version=3.8'] + files: '.*pools?\.py' + exclude: 'versions|src/tests' + additional_dependencies: [types-PyYAML] diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 000000000..22bb5d5e4 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,15 @@ +target-version = "py38" + +[lint] +select = ['E', 'F'] +ignore =[ + 'E402', + 'E501', + 'E711', + 'E712', + 'E741', + 'F401', + 'F541', + 'F841', + 'F901', +] diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 26b7a23b6..be69e2c13 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -2,6 +2,8 @@ from collections.abc import Sequence from pathlib import Path +import yaml +from pydantic import parse_obj_as from rich.table import Table from dstack._internal.cli.commands import APIBaseCommand @@ -9,25 +11,27 @@ apply_profile_args, register_profile_args, ) +from dstack._internal.cli.services.configurators.run import BaseRunConfigurator from dstack._internal.cli.utils.common import confirm_ask, console from dstack._internal.core.errors import CLIError, ServerClientError +from dstack._internal.core.models.configurations import parse as parse_configuration from dstack._internal.core.models.instances import ( InstanceAvailability, InstanceOfferWithAvailability, ) from dstack._internal.core.models.pools import Instance, Pool -from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile, SpotPolicy +from dstack._internal.core.models.resources import GPU, Resources from dstack._internal.core.models.runs import Requirements from dstack._internal.core.services.configs import ConfigManager from dstack._internal.utils.common import pretty_date from dstack._internal.utils.logging import get_logger -from dstack.api.utils import load_profile +from dstack.api.utils import load_configuration, load_profile logger = get_logger(__name__) -NOTSET = object() -def print_pool_table(pools: Sequence[Pool], verbose): +def print_pool_table(pools: Sequence[Pool], verbose: bool) -> None: table = Table(box=None) table.add_column("NAME") table.add_column("DEFAULT") @@ -45,11 +49,12 @@ def print_pool_table(pools: Sequence[Pool], verbose): console.print() -def print_instance_table(instances: Sequence[Instance]): +def print_instance_table(instances: Sequence[Instance]) -> None: table = Table(box=None) table.add_column("INSTANCE ID") table.add_column("BACKEND") table.add_column("INSTANCE TYPE") + table.add_column("STATUS") table.add_column("PRICE") for instance in instances: @@ -57,6 +62,7 @@ def print_instance_table(instances: Sequence[Instance]): instance.instance_id, instance.backend, instance.instance_type.resources.pretty_format(), + instance.status, f"{instance.price:.02f}", ] table.add_row(*row) @@ -71,7 +77,7 @@ def print_offers_table( requirements: Requirements, instance_offers: Sequence[InstanceOfferWithAvailability], offers_limit: int = 3, -): +) -> None: pretty_req = requirements.pretty_format(resources_only=True) max_price = f"${requirements.max_price:g}" if requirements.max_price else "-" @@ -151,13 +157,14 @@ def th(s: str) -> str: console.print() -class PoolCommand(APIBaseCommand): +class PoolCommand(APIBaseCommand): # type: ignore[misc] NAME = "pool" DESCRIPTION = "Pool management" - def _register(self): + def _register(self) -> None: super()._register() 
self._parser.set_defaults(subfunc=self._list) + subparsers = self._parser.add_subparsers(dest="action") # list @@ -223,28 +230,48 @@ def _register(self): "--remote-port", help="Remote runner port", dest="remote_port", default=10999 ) add_parser.add_argument("--name", dest="instance_name", help="The name of the instance") - add_parser.set_defaults(subfunc=self._add) register_profile_args(add_parser) + BaseRunConfigurator.register(add_parser) + add_parser.set_defaults(subfunc=self._add) + + # remove + remove_parser = subparsers.add_parser( + "remove", + help="Remove instance from the pool", + formatter_class=self._parser.formatter_class, + ) + remove_parser.add_argument( + "--pool", dest="pool_name", help="The name of the pool", required=True + ) + remove_parser.add_argument( + "--name", dest="instance_name", help="The name of the instance", required=True + ) + remove_parser.set_defaults(subfunc=self._remove) - def _list(self, args: argparse.Namespace): + def _list(self, args: argparse.Namespace) -> None: pools = self.api.client.pool.list(self.api.project) print_pool_table(pools, verbose=getattr(args, "verbose", False)) - def _create(self, args: argparse.Namespace): + def _create(self, args: argparse.Namespace) -> None: self.api.client.pool.create(self.api.project, args.pool_name) - def _delete(self, args: argparse.Namespace): + def _delete(self, args: argparse.Namespace) -> None: self.api.client.pool.delete(self.api.project, args.pool_name, args.force) - def _show(self, args: argparse.Namespace): + def _remove(self, args: argparse.Namespace) -> None: + self.api.client.pool.remove(self.api.project, args.pool_name, args.instance_name) + + def _show(self, args: argparse.Namespace) -> None: instances = self.api.client.pool.show(self.api.project, args.pool_name) print_instance_table(instances) - def _add(self, args: argparse.Namespace): + def _add(self, args: argparse.Namespace) -> None: + super()._command(args) pool_name: str = DEFAULT_POOL_NAME if args.pool_name is None else args.pool_name + # Add remote instance if args.remote: self.api.client.pool.add( self.api.project, pool_name, args.instance_name, args.remote_host, args.remote_port @@ -254,15 +281,24 @@ def _add(self, args: argparse.Namespace): repo = self.api.repos.load(Path.cwd()) self.api.ssh_identity_file = ConfigManager().get_repo_config(repo.repo_dir).ssh_key_path + # TODO: read requirements from repo configuration + # conf = parse_configuration(yaml.safe_load(conf_path.open())) + # resources = conf.resources + profile = load_profile(Path.cwd(), args.profile) apply_profile_args(args, profile) profile.pool_name = pool_name + requirements = Requirements( + resources=Resources(gpu=parse_obj_as(GPU, args.gpu_spec)), + max_price=args.max_price, + spot=(args.spot_policy == SpotPolicy.SPOT), + ) + with console.status("Getting instances..."): - requirements, offers = self.api.runs.get_offers(profile) + offers = self.api.runs.get_offers(profile, requirements) - print(pool_name, profile, requirements, offers) print_offers_table(pool_name, profile, requirements, offers) if not args.yes and not confirm_ask("Continue?"): console.print("\nExiting...") @@ -270,11 +306,11 @@ def _add(self, args: argparse.Namespace): try: with console.status("Submitting instance..."): - self.api.runs.create_instance(pool_name, profile) + self.api.runs.create_instance(pool_name, profile, requirements) except ServerClientError as e: raise CLIError(e.msg) - def _command(self, args: argparse.Namespace): + def _command(self, args: argparse.Namespace) -> None: 
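As a standalone illustration of the subcommand wiring used by PoolCommand above: each subparser stores its handler with set_defaults(subfunc=...) and the top-level command simply dispatches to args.subfunc(args). Command and option names here are illustrative only, not the exact dstack CLI surface.

import argparse


def _list(args: argparse.Namespace) -> None:
    print("listing pools, verbose =", getattr(args, "verbose", False))


def _create(args: argparse.Namespace) -> None:
    print("creating pool", args.pool_name)


def main() -> None:
    parser = argparse.ArgumentParser(prog="pool")
    parser.set_defaults(subfunc=_list)  # default action when no subcommand is given
    subparsers = parser.add_subparsers(dest="action")

    list_parser = subparsers.add_parser("list", help="List pools")
    list_parser.add_argument("-v", "--verbose", action="store_true")
    list_parser.set_defaults(subfunc=_list)

    create_parser = subparsers.add_parser("create", help="Create pool")
    create_parser.add_argument("-n", "--name", dest="pool_name", required=True)
    create_parser.set_defaults(subfunc=_create)

    args = parser.parse_args()
    args.subfunc(args)  # dispatch to the selected handler


if __name__ == "__main__":
    main()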
super()._command(args) # TODO handle 404 and other errors args.subfunc(args) diff --git a/src/dstack/_internal/core/backends/datacrunch/api_client.py b/src/dstack/_internal/core/backends/datacrunch/api_client.py index 759da3907..a3b931c08 100644 --- a/src/dstack/_internal/core/backends/datacrunch/api_client.py +++ b/src/dstack/_internal/core/backends/datacrunch/api_client.py @@ -5,6 +5,7 @@ from datacrunch.exceptions import APIException from datacrunch.instances.instances import Instance +from dstack._internal.core.errors import NoCapacityError from dstack._internal.utils.ssh import get_public_key_fingerprint @@ -68,15 +69,19 @@ def deploy_instance( is_spot=True, location="FIN-01", ): - instance = self.client.instances.create( - instance_type=instance_type, - image=image, - ssh_key_ids=ssh_key_ids, - hostname=hostname, - description=description, - startup_script_id=startup_script_id, - is_spot=is_spot, - location=location, - os_volume={"name": "OS volume", "size": disk_size}, - ) + try: + instance = self.client.instances.create( + instance_type=instance_type, + image=image, + ssh_key_ids=ssh_key_ids, + hostname=hostname, + description=description, + startup_script_id=startup_script_id, + is_spot=is_spot, + location=location, + os_volume={"name": "OS volume", "size": disk_size}, + ) + except APIException as e: + raise NoCapacityError() + return instance diff --git a/src/dstack/_internal/core/backends/datacrunch/compute.py b/src/dstack/_internal/core/backends/datacrunch/compute.py index f7960f86a..21e6806ca 100644 --- a/src/dstack/_internal/core/backends/datacrunch/compute.py +++ b/src/dstack/_internal/core/backends/datacrunch/compute.py @@ -5,7 +5,7 @@ from dstack._internal.core.backends.base.offers import get_catalog_offers from dstack._internal.core.backends.datacrunch.api_client import DataCrunchAPIClient from dstack._internal.core.backends.datacrunch.config import DataCrunchConfig -from dstack._internal.core.errors import BackendError +from dstack._internal.core.errors import BackendError, NoCapacityError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( InstanceAvailability, @@ -99,6 +99,7 @@ def create_instance( image_name = "2088da25-bb0d-41cc-a191-dccae45d96fd" disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) + instance = self.api_client.deploy_instance( instance_type=instance_offer.instance.name, ssh_key_ids=ssh_ids, diff --git a/src/dstack/_internal/core/models/pools.py b/src/dstack/_internal/core/models/pools.py index 80ecd3961..c62bd0f9c 100644 --- a/src/dstack/_internal/core/models/pools.py +++ b/src/dstack/_internal/core/models/pools.py @@ -8,13 +8,13 @@ from dstack._internal.core.models.runs import InstanceStatus -class Pool(BaseModel): +class Pool(BaseModel): # type: ignore[misc] name: str default: bool created_at: datetime.datetime -class Instance(BaseModel): +class Instance(BaseModel): # type: ignore[misc] backend: BackendType instance_type: InstanceType instance_id: str diff --git a/src/dstack/_internal/core/models/resources.py b/src/dstack/_internal/core/models/resources.py index bfd305c77..788425442 100644 --- a/src/dstack/_internal/core/models/resources.py +++ b/src/dstack/_internal/core/models/resources.py @@ -42,11 +42,11 @@ def _post_validate(cls, values): raise ValueError(f"Invalid range order: {min}..{max}") return values - def __str__(self): + def __str__(self) -> str: min = self.min if self.min is not None else "" max = self.max if self.max is not None else "" if min 
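The deploy_instance() change above translates the provider SDK's APIException into a backend-agnostic NoCapacityError. A minimal sketch of that error-translation pattern with placeholder exception classes; chaining with "raise ... from e" keeps the original traceback, which the version in the patch drops.

class ProviderAPIException(Exception):
    """Stand-in for a cloud SDK's error type."""


class NoCapacityError(Exception):
    """Stand-in for the backend-agnostic error the server layer expects."""


def create_server(provider_create):
    try:
        return provider_create()
    except ProviderAPIException as e:
        # translate the provider-specific failure into the common error type
        raise NoCapacityError(str(e)) from e


if __name__ == "__main__":
    def failing_create():
        raise ProviderAPIException("no spot capacity in FIN-01")

    try:
        create_server(failing_create)
    except NoCapacityError as err:
        print("translated:", err)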
== max: - return f"{min}" + return str(min) return f"{min}..{max}" diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 9c22de472..2970acb1b 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -23,8 +23,8 @@ class AppSpec(BaseModel): port: int map_to_port: Optional[int] app_name: str - url_path: Optional[str] - url_query_params: Optional[Dict[str, str]] + url_path: Optional[str] = None + url_query_params: Optional[Dict[str, str]] = None class JobStatus(str, Enum): diff --git a/src/dstack/_internal/server/background/tasks/process_pools.py b/src/dstack/_internal/server/background/tasks/process_pools.py index 2b8536c28..a7d7c8d5b 100644 --- a/src/dstack/_internal/server/background/tasks/process_pools.py +++ b/src/dstack/_internal/server/background/tasks/process_pools.py @@ -1,14 +1,21 @@ from datetime import timedelta +from typing import Dict +from uuid import UUID from pydantic import parse_raw_as from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload -from dstack._internal.core.models.runs import InstanceStatus, JobStatus +from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData, JobStatus from dstack._internal.server.db import get_session_ctx from dstack._internal.server.models import InstanceModel, JobModel +from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services.jobs import PROCESSING_POOL_IDS, PROCESSING_POOL_LOCK +from dstack._internal.server.services.logging import job_log +from dstack._internal.server.services.runner import client +from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel +from dstack._internal.server.utils.common import run_async from dstack._internal.utils.logging import get_logger PENDING_JOB_RETRY_INTERVAL = timedelta(seconds=60) @@ -16,13 +23,19 @@ logger = get_logger(__name__) -async def process_pools(): +async def process_pools() -> None: async with get_session_ctx() as session: async with PROCESSING_POOL_LOCK: res = await session.scalars( select(InstanceModel).where( - InstanceModel.status.in_([InstanceStatus.READY, InstanceStatus.FAILED]), + InstanceModel.status.in_( + [ + InstanceStatus.CREATING, + InstanceStatus.STARTING, + InstanceStatus.TERMINATING, + ] + ), InstanceModel.id.not_in(PROCESSING_POOL_IDS), ) ) @@ -34,10 +47,62 @@ async def process_pools(): try: for inst in instances: - await _terminate_instance(inst) + if inst.status in (InstanceStatus.CREATING, InstanceStatus.STARTING): + await check_shim(inst.id) + if inst.status == InstanceStatus.TERMINATING: + await terminate(inst.id) finally: PROCESSING_POOL_IDS.difference_update(i.id for i in instances) -async def _terminate_instance(instance: InstanceModel): - pass +async def check_shim(instance_id: UUID) -> None: + async with get_session_ctx() as session: + instance = ( + await session.scalars( + select(InstanceModel) + .where(InstanceModel.id == instance_id) + .options(joinedload(InstanceModel.project)) + ) + ).one() + ssh_private_key = instance.project.ssh_private_key + job_provisioning_data = parse_raw_as(JobProvisioningData, instance.job_provisioning_data) + + instance_health = instance_healthcheck(ssh_private_key, job_provisioning_data) + + logger.info("check instance %s status: %s", instance.name, instance_health) + + if instance_health: + instance.status = InstanceStatus.READY + await session.commit() + return + + 
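process_pools() above avoids double-processing by claiming instance ids in PROCESSING_POOL_IDS under a lock and releasing them in a finally block. A self-contained sketch of that claim/release pattern; the set, lock, and placeholder work below are illustrative, not dstack's objects.

import asyncio
from typing import Set
from uuid import UUID, uuid4

PROCESSING_IDS: Set[UUID] = set()
PROCESSING_LOCK = asyncio.Lock()


async def process_batch(candidate_ids: Set[UUID]) -> None:
    async with PROCESSING_LOCK:
        # claim only the ids nobody else is already working on
        claimed = {i for i in candidate_ids if i not in PROCESSING_IDS}
        PROCESSING_IDS.update(claimed)
    try:
        for instance_id in claimed:
            await asyncio.sleep(0)  # placeholder for check_shim()/terminate()
    finally:
        # always release the claim, even if processing raised
        PROCESSING_IDS.difference_update(claimed)


if __name__ == "__main__":
    asyncio.run(process_batch({uuid4(), uuid4()}))
    print("still claimed:", PROCESSING_IDS)  # set()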
+@runner_ssh_tunnel(ports=[client.REMOTE_SHIM_PORT], retries=1) # type: ignore[misc] +def instance_healthcheck(*, ports: Dict[int, int]) -> bool: + shim_client = client.ShimClient(port=ports[client.REMOTE_SHIM_PORT]) + resp = shim_client.healthcheck() + if resp is None: + return False # shim is not available yet + return bool(resp.service == "dstack-shim") + + +async def terminate(instance_id: UUID) -> None: + + async with get_session_ctx() as session: + instance = ( + await session.scalars( + select(InstanceModel) + .where(InstanceModel.id == instance_id) + .options(joinedload(InstanceModel.project)) + ) + ).one() + jpd = parse_raw_as(JobProvisioningData, instance.job_provisioning_data) + BACKEND_TYPE = jpd.backend + backends = await backends_services.get_project_backends(project=instance.project) + backend = next((b for b in backends if b.TYPE in BACKEND_TYPE), None) + if backend is None: + raise ValueError(f"there is no backned {BACKEND_TYPE}") + + await run_async( + backend.compute().terminate_instance, jpd.instance_id, jpd.region, jpd.backend_data + ) diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 9858501a6..8f50e6f2e 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -207,6 +207,15 @@ async def _process_job(job_id: UUID): run_model, job_model, ) + + if success: + instance_name: str = job_provisioning_data.instance_id + pool_name = str(job.job_spec.pool_name) + instances = await get_pool_instances(session, project, pool_name) + for inst in instances: + if inst.name == instance_name: + inst.status = InstanceStatus.READY + if not success: # kill the job logger.warning( *job_log( diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 0d8850ca8..9d039a972 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -19,12 +19,14 @@ Profile, TerminationPolicy, ) +from dstack._internal.core.models.resources import Range, Resources from dstack._internal.core.models.runs import ( InstanceStatus, Job, JobErrorCode, JobProvisioningData, JobStatus, + Requirements, Run, RunSpec, ) @@ -39,6 +41,8 @@ from dstack._internal.server.services.pools import ( get_pool_instances, instance_model_to_instance, + list_project_pool, + list_project_pool_models, show_pool, ) from dstack._internal.server.services.runs import run_model_to_run @@ -83,7 +87,7 @@ async def _process_job(job_id: UUID): ) -def check_relevance(profile: Profile, instance_model: InstanceModel) -> bool: +def check_relevance(profile: Profile, resources: Resources, instance_model: InstanceModel) -> bool: jpd: JobProvisioningData = parse_raw_as( JobProvisioningData, instance_model.job_provisioning_data @@ -95,17 +99,22 @@ def check_relevance(profile: Profile, instance_model: InstanceModel) -> bool: instance = instance_model_to_instance(instance_model) if profile.backends is not None and instance.backend not in profile.backends: + logger.warning(f"no backnd select ") return False - instance_resources = jpd.instance_type.resources + # instance_resources = jpd.instance_type.resources - if profile.resources.cpu is not None and profile.resources.cpu < instance_resources.cpus: - return False + # TODO: full check requirements + 
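instance_healthcheck() above asks the shim, over an SSH tunnel, whether it is alive and reports service == "dstack-shim". A simplified standalone sketch of such a check against a locally forwarded port; the /api/healthcheck path and the port are assumptions for illustration, not dstack's confirmed endpoint.

import requests


def shim_is_healthy(local_port: int, timeout: float = 3.0) -> bool:
    try:
        resp = requests.get(f"http://localhost:{local_port}/api/healthcheck", timeout=timeout)
        resp.raise_for_status()
    except requests.RequestException:
        return False  # shim is not reachable yet
    payload = resp.json()
    return payload.get("service") == "dstack-shim"


if __name__ == "__main__":
    print(shim_is_healthy(10999))  # False unless a shim is actually listening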
# if isinstance(requirements.resources.cpu, Range): + # if requirements.resources.cpu.min < int(instance_resources.cpus): + # return False - # TODO: full check - if isinstance(profile.resources.gpu, int): - if profile.resources.gpu < len(instance_resources.gpus): - return False + # if isinstance(requirements.resources.gpu, Range): + # if requirements.resources.gpu.min < int(instance_resources.cpus): + # return False + # if isinstance(int(requirements.resources.gpu), int): + # if requirements.resources.gpu < len(instance_resources.gpus): + # return False return True # TODO: memory, shm_size, disk @@ -125,28 +134,40 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): # check default pool pool = project_model.default_pool if pool is None: - pool = PoolModel( - name=DEFAULT_POOL_NAME, - project=project_model, - ) - session.add(pool) - await session.commit() + pools = await list_project_pool_models(session, job_model.project) + for pool_item in pools: + if pool_item.id == job_model.project.default_pool_id: + pool = pool_item + if pool_item.name == DEFAULT_POOL_NAME: + pool = pool_item + if pool is None: + pool = PoolModel( + name=DEFAULT_POOL_NAME, + project=project_model, + ) + session.add(pool) + await session.commit() + await session.refresh(pool) + if pool.id is not None: project_model.default_pool_id = pool.id - profile = parse_raw_as(RunSpec, run_model.run_spec).profile + run_spec = parse_raw_as(RunSpec, run_model.run_spec) + profile = run_spec.profile run_pool = profile.pool_name if run_pool is None: run_pool = pool.name # pool capacity + pool_instances = await get_pool_instances(session, project_model, run_pool) available_instanses = (p for p in pool_instances if p.status == InstanceStatus.READY) relevant_instances: List[InstanceModel] = [] for instance in available_instanses: - if check_relevance(profile, instance): + if check_relevance(profile, run_spec.configuration.resources, instance): relevant_instances.append(instance) + logger.info(*job_log(f"num relevance {len(relevant_instances)}", job_model)) if relevant_instances: sorted_instances = sorted(relevant_instances, key=lambda instance: instance.name) @@ -183,6 +204,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): backends=backends, project_ssh_public_key=project_model.ssh_public_key, project_ssh_private_key=project_model.ssh_private_key, + pool_id=pool.id, ) if job_provisioning_data is not None and offer is not None: logger.info(*job_log("now is provisioning", job_model)) @@ -221,9 +243,11 @@ async def _run_job( backends: List[Backend], project_ssh_public_key: str, project_ssh_private_key: str, + pool_id: UUID, ) -> Tuple[Optional[JobProvisioningData], Optional[InstanceOfferWithAvailability]]: if run.run_spec.profile.backends is not None: backends = [b for b in backends if b.TYPE in run.run_spec.profile.backends] + try: requirements = job.job_spec.requirements offers = await backends_services.get_instance_offers( @@ -232,6 +256,7 @@ async def _run_job( except BackendError as e: logger.warning(*job_log("failed to get instance offers: %s", job_model, repr(e))) return (None, None) + for backend, offer in offers: logger.debug( *job_log( @@ -265,6 +290,7 @@ async def _run_job( ) continue else: + job_provisioning_data = JobProvisioningData( backend=backend.TYPE, instance_type=offer.instance, @@ -277,6 +303,7 @@ async def _run_job( dockerized=launched_instance_info.dockerized, ssh_proxy=launched_instance_info.ssh_proxy, backend_data=launched_instance_info.backend_data, + 
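The resource comparison in check_relevance() above is still commented out. A minimal sketch of what such a check could look like: compare the requested minimums against what the instance offers. The dataclasses are simplified stand-ins for dstack's Requirements/Resources models.

from dataclasses import dataclass
from typing import Optional


@dataclass
class InstanceResources:
    cpus: int
    memory_mib: int
    gpu_count: int


@dataclass
class ResourceMinimums:
    min_cpus: Optional[int] = None
    min_memory_mib: Optional[int] = None
    min_gpus: Optional[int] = None


def satisfies(req: ResourceMinimums, res: InstanceResources) -> bool:
    if req.min_cpus is not None and res.cpus < req.min_cpus:
        return False
    if req.min_memory_mib is not None and res.memory_mib < req.min_memory_mib:
        return False
    if req.min_gpus is not None and res.gpu_count < req.min_gpus:
        return False
    return True


if __name__ == "__main__":
    req = ResourceMinimums(min_cpus=2, min_memory_mib=8192)
    print(satisfies(req, InstanceResources(cpus=4, memory_mib=16384, gpu_count=0)))  # True
    print(satisfies(req, InstanceResources(cpus=1, memory_mib=16384, gpu_count=0)))  # False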
pool_id=str(pool_id), ) return (job_provisioning_data, offer) diff --git a/src/dstack/_internal/server/migrations/versions/73a959f64596_add_pools.py b/src/dstack/_internal/server/migrations/versions/6a084acc1211_add_pools.py similarity index 95% rename from src/dstack/_internal/server/migrations/versions/73a959f64596_add_pools.py rename to src/dstack/_internal/server/migrations/versions/6a084acc1211_add_pools.py index ce1cb6119..e8ec39874 100644 --- a/src/dstack/_internal/server/migrations/versions/73a959f64596_add_pools.py +++ b/src/dstack/_internal/server/migrations/versions/6a084acc1211_add_pools.py @@ -1,8 +1,8 @@ """add pools -Revision ID: 73a959f64596 +Revision ID: 6a084acc1211 Revises: d3e8af4786fa -Create Date: 2024-01-16 09:57:28.183650 +Create Date: 2024-01-25 10:23:18.983726 """ import sqlalchemy as sa @@ -10,7 +10,7 @@ from alembic import op # revision identifiers, used by Alembic. -revision = "73a959f64596" +revision = "6a084acc1211" down_revision = "d3e8af4786fa" branch_labels = None depends_on = None @@ -35,7 +35,6 @@ def upgrade() -> None: ondelete="CASCADE", ), sa.PrimaryKeyConstraint("id", name=op.f("pk_pools")), - sa.UniqueConstraint("name", name=op.f("uq_pools_name")), ) op.create_table( "instances", @@ -66,7 +65,7 @@ def upgrade() -> None: sa.Column("status_message", sa.String(length=50), nullable=True), sa.Column("started_at", sa.DateTime(), nullable=True), sa.Column("finished_at", sa.DateTime(), nullable=True), - sa.Column("termination_policy", sa.String(length=50), nullable=False), + sa.Column("termination_policy", sa.String(length=50), nullable=True), sa.Column("termination_idle_time", sa.String(length=50), nullable=True), sa.Column("job_provisioning_data", sa.String(length=4000), nullable=False), sa.Column("offer", sa.String(length=4000), nullable=False), diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index bb1222ff3..627a2c368 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -246,7 +246,7 @@ class PoolModel(BaseModel): id: Mapped[uuid.UUID] = mapped_column( UUIDType(binary=False), primary_key=True, default=uuid.uuid4 ) - name: Mapped[str] = mapped_column(String(50), unique=True) + name: Mapped[str] = mapped_column(String(50)) created_at: Mapped[datetime] = mapped_column(DateTime, default=get_current_datetime) deleted: Mapped[bool] = mapped_column(Boolean, default=False) deleted_at: Mapped[Optional[datetime]] = mapped_column(DateTime) @@ -279,7 +279,7 @@ class InstanceModel(BaseModel): started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) - termination_policy: Mapped[TerminationPolicy] = mapped_column(String(50)) + termination_policy: Mapped[Optional[TerminationPolicy]] = mapped_column(String(50)) termination_idle_time: Mapped[Optional[str]] = mapped_column(String(50)) job_provisioning_data: Mapped[str] = mapped_column(String(4000)) diff --git a/src/dstack/_internal/server/routers/pools.py b/src/dstack/_internal/server/routers/pools.py index 5508e7b9d..a7aa5dad9 100644 --- a/src/dstack/_internal/server/routers/pools.py +++ b/src/dstack/_internal/server/routers/pools.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Sequence, Tuple from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession @@ -7,7 +7,7 @@ import dstack._internal.server.schemas.pools as schemas import 
dstack._internal.server.services.pools as pools from dstack._internal.server.db import get_session -from dstack._internal.server.models import ProjectModel, UserModel +from dstack._internal.server.models import InstanceModel, PoolModel, ProjectModel, UserModel from dstack._internal.server.schemas.runs import AddRemoteInstanceRequest from dstack._internal.server.security.permissions import ProjectAdmin, ProjectMember from dstack._internal.server.services.runs import ( @@ -19,7 +19,7 @@ router = APIRouter(prefix="/api/project/{project_name}/pool", tags=["pool"]) -@router.post("/list") +@router.post("/list") # type: ignore[misc] async def list_pool( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), @@ -28,12 +28,22 @@ async def list_pool( return await pools.list_project_pool(session=session, project=project) -@router.post("/delete") +@router.post("/remove") # type: ignore[misc] +async def remove_instance( + body: schemas.RemoveInstanceRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), +) -> None: + _, project_model = user_project + await pools.remove_instance(session, project_model, body.pool_name, body.instance_name) + + +@router.post("/delete") # type: ignore[misc] async def delete_pool( body: schemas.DeletePoolRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -): +) -> None: pool_name = body.name _, project_model = user_project @@ -60,32 +70,32 @@ async def delete_pool( await pools.delete_pool(session, project_model, pool_name) -@router.post("/create") +@router.post("/create") # type: ignore[misc] async def create_pool( body: schemas.CreatePoolRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -): +) -> None: _, project = user_project await pools.create_pool_model(session=session, project=project, name=body.name) -@router.post("/show") +@router.post("/show") # type: ignore[misc] async def how_pool( body: schemas.CreatePoolRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -): +) -> Sequence[models.Instance]: _, project = user_project return await pools.show_pool(session, project, pool_name=body.name) -@router.post("/add") +@router.post("/add") # type: ignore[misc] async def add_instance( body: AddRemoteInstanceRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -): +) -> None: _, project = user_project await pools.add( session, diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index 395cdaa93..de2d5ba12 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -69,13 +69,11 @@ async def get_offers( body: GetOffersRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> Tuple[Requirements, List[InstanceOfferWithAvailability]]: +) -> List[InstanceOfferWithAvailability]: _, project = user_project - reqs, offers = await runs.get_run_plan_by_requirements( - project=project, - profile=body.profile, - ) - return (reqs, [instance for _, instance in offers]) + offers = await runs.get_run_plan_by_requirements(project, body.profile, body.requirements) + instances = [instance 
for _, instance in offers] + return instances @project_router.post("/create_instance") @@ -89,11 +87,13 @@ async def create_instance( session=session, project=project, pool_name=body.pool_name ) await runs.create_instance( + session=session, project=project, user=user, pool_name=body.pool_name, instance_name=instance_name, profile=body.profile, + requirements=body.requirements, ) diff --git a/src/dstack/_internal/server/schemas/pools.py b/src/dstack/_internal/server/schemas/pools.py index 902750c54..318874eec 100644 --- a/src/dstack/_internal/server/schemas/pools.py +++ b/src/dstack/_internal/server/schemas/pools.py @@ -1,14 +1,19 @@ from pydantic import BaseModel -class DeletePoolRequest(BaseModel): +class DeletePoolRequest(BaseModel): # type: ignore[misc] name: str force: bool -class CreatePoolRequest(BaseModel): +class CreatePoolRequest(BaseModel): # type: ignore[misc] name: str -class ShowPoolRequest(BaseModel): +class ShowPoolRequest(BaseModel): # type: ignore[misc] name: str + + +class RemoveInstanceRequest(BaseModel): # type: ignore[misc] + pool_name: str + instance_name: str diff --git a/src/dstack/_internal/server/schemas/runs.py b/src/dstack/_internal/server/schemas/runs.py index bfa095e73..39e2e91e7 100644 --- a/src/dstack/_internal/server/schemas/runs.py +++ b/src/dstack/_internal/server/schemas/runs.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from dstack._internal.core.models.profiles import Profile -from dstack._internal.core.models.runs import RunSpec +from dstack._internal.core.models.runs import Requirements, RunSpec class ListRunsRequest(BaseModel): @@ -21,11 +21,13 @@ class GetRunPlanRequest(BaseModel): class GetOffersRequest(BaseModel): profile: Profile + requirements: Requirements class CreateInstanceRequest(BaseModel): pool_name: str profile: Profile + requirements: Requirements class AddRemoteInstanceRequest(BaseModel): diff --git a/src/dstack/_internal/server/services/backends/configurators/aws.py b/src/dstack/_internal/server/services/backends/configurators/aws.py index 3b2555e6a..9fec28c77 100644 --- a/src/dstack/_internal/server/services/backends/configurators/aws.py +++ b/src/dstack/_internal/server/services/backends/configurators/aws.py @@ -73,7 +73,7 @@ def get_config_values(self, config: AWSConfigInfoWithCredsPartial) -> AWSConfigV raise_invalid_credentials_error(fields=[["creds"]]) try: auth.authenticate(creds=config.creds, region=MAIN_REGION) - except: + except: # noqa: E722 if isinstance(config.creds, AWSAccessKeyCreds): raise_invalid_credentials_error( fields=[ diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index a5e1dc217..529d3dd22 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -1,6 +1,8 @@ import asyncio +from contextlib import asynccontextmanager, contextmanager from datetime import timezone -from typing import List, Optional, Sequence +from typing import Dict, List, Optional, Sequence, Union +from uuid import UUID from pydantic import parse_raw_as from sqlalchemy import select @@ -11,6 +13,7 @@ from dstack._internal.core.models.instances import ( InstanceAvailability, InstanceOfferWithAvailability, + InstanceState, InstanceType, Resources, ) @@ -33,6 +36,21 @@ async def list_project_pool(session: AsyncSession, project: ProjectModel) -> Lis return [pool_model_to_pool(p) for p in pools] +async def get_pool( + session: AsyncSession, project: ProjectModel, pool_name: str +) -> Optional[PoolModel]: + pool = ( + await 
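The pool and run routers above follow one shape: POST-only endpoints that take a small pydantic body plus a project-scoped dependency. A runnable, self-contained sketch of that shape; the auth stub and names are illustrative, this is not dstack's actual router module.

from fastapi import APIRouter, Depends, FastAPI
from pydantic import BaseModel


class CreatePoolRequest(BaseModel):
    name: str


class RemoveInstanceRequest(BaseModel):
    pool_name: str
    instance_name: str


def get_current_user() -> str:
    return "admin"  # stand-in for the ProjectMember/ProjectAdmin dependency


router = APIRouter(prefix="/api/project/{project_name}/pool", tags=["pool"])


@router.post("/create")
async def create_pool(
    project_name: str,
    body: CreatePoolRequest,
    user: str = Depends(get_current_user),
) -> None:
    print(f"{user} creates pool {body.name!r} in project {project_name!r}")


@router.post("/remove")
async def remove_instance(
    project_name: str,
    body: RemoveInstanceRequest,
    user: str = Depends(get_current_user),
) -> None:
    print(f"{user} removes {body.instance_name!r} from {body.pool_name!r} in {project_name!r}")


app = FastAPI()
app.include_router(router)
# run with, e.g.: uvicorn pool_router_sketch:app  (module name is hypothetical)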
session.scalars( + select(PoolModel).where( + PoolModel.name == pool_name, + PoolModel.project_id == project.id, + PoolModel.deleted == False, + ) + ) + ).one_or_none() + return pool + + def pool_model_to_pool(pool_model: PoolModel) -> Pool: return Pool( name=pool_model.name, @@ -42,6 +60,14 @@ def pool_model_to_pool(pool_model: PoolModel) -> Pool: async def create_pool_model(session: AsyncSession, project: ProjectModel, name: str) -> PoolModel: + pools = await session.scalars( + select(PoolModel).where( + PoolModel.name == name, PoolModel.project == project, PoolModel.deleted == False + ) + ) + if pools.all(): + raise ValueError("duplicate pool name") + pool = PoolModel( name=name, project_id=project.id, @@ -59,10 +85,32 @@ async def list_project_pool_models( pools = await session.scalars( select(PoolModel).where(PoolModel.project_id == project.id, PoolModel.deleted == False) ) - return pools.all() + return pools.all() # type: ignore[no-any-return] + + +async def remove_instance( + session: AsyncSession, project: ProjectModel, pool_name: str, instance_name: str +) -> None: + pool = ( + await session.scalars( + select(PoolModel).where( + PoolModel.name == pool_name, + PoolModel.project == project, + PoolModel.deleted == False, + ) + ) + ).one() + terminated = False + for instance in pool.instances: + if instance.name == instance_name: + instance.status = InstanceStatus.TERMINATING + terminated = True + if not terminated: + logger.warning("Couldn't fined instance to terminate") + await session.commit() -async def delete_pool(session: AsyncSession, project: ProjectModel, pool_name: str): +async def delete_pool(session: AsyncSession, project: ProjectModel, pool_name: str) -> None: """delete the pool and set the default pool to project""" default_pool: Optional[PoolModel] = None @@ -95,7 +143,7 @@ async def list_deleted_pools( PoolModel.project_id == project_model.id, PoolModel.deleted == True ) ) - return pools.all() + return pools.all() # type: ignore[no-any-return] def instance_model_to_instance(instance_model: InstanceModel) -> Instance: @@ -151,10 +199,11 @@ async def get_pool_instances( result = res.unique().scalars().one_or_none() if result is None: return [] - return result.instances + instances: List[InstanceModel] = result.instances + return instances -async def get_instances_by_pool_id(session, pool_id: str) -> List[InstanceModel]: +async def get_instances_by_pool_id(session: AsyncSession, pool_id: str) -> List[InstanceModel]: res = await session.execute( select(PoolModel) .where( @@ -165,10 +214,11 @@ async def get_instances_by_pool_id(session, pool_id: str) -> List[InstanceModel] result = res.unique().scalars().one_or_none() if result is None: return [] - return result.instances + instances: List[InstanceModel] = result.instances + return instances -_GENERATE_POOL_NAME_LOCK = {} +_GENERATE_POOL_NAME_LOCK: Dict[str, asyncio.Lock] = {} async def generate_instance_name( @@ -193,7 +243,7 @@ async def add( instance_name: Optional[str], host: str, port: str, -): +) -> None: instance_name = instance_name if instance_name is None: @@ -225,6 +275,7 @@ async def add( ssh_port=22, dockerized=False, backend_data="", + pool_id=str(pool.id), ) offer = InstanceOfferWithAvailability( backend=BackendType.LOCAL, diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index 8000ffac8..bd4045ae5 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -100,11 
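generate_instance_name() above serializes name generation per key through a dict of asyncio locks. A compact standalone sketch of that per-key locking pattern; the name format and bookkeeping below are illustrative only.

import asyncio
import random
from typing import Dict, Set

_NAME_LOCKS: Dict[str, asyncio.Lock] = {}
_TAKEN: Dict[str, Set[str]] = {}


async def generate_unique_name(pool_name: str) -> str:
    lock = _NAME_LOCKS.setdefault(pool_name, asyncio.Lock())
    async with lock:
        taken = _TAKEN.setdefault(pool_name, set())
        while True:
            candidate = f"{pool_name}-{random.randrange(10_000):04d}"
            if candidate not in taken:
                taken.add(candidate)
                return candidate


if __name__ == "__main__":
    async def main() -> None:
        names = await asyncio.gather(*(generate_unique_name("default-pool") for _ in range(3)))
        print(names)

    asyncio.run(main())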
+100,12 @@ def healthcheck(self) -> Optional[HealthcheckResponse]: return None def submit(self, username: str, password: str, image_name: str): + post_body = DockerImageBody( + username=username, password=password, image_name=image_name + ).dict() resp = requests.post( self._url("/api/submit"), - json=DockerImageBody( - username=username, password=password, image_name=image_name - ).dict(), + json=post_body, ) resp.raise_for_status() diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 6241e46a4..10ed3f61f 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -3,7 +3,7 @@ import math import uuid from datetime import timezone -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, cast import pydantic from sqlalchemy import select, update @@ -24,8 +24,10 @@ LaunchedInstanceInfo, ) from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile, SpotPolicy +from dstack._internal.core.models.resources import GPU, ComputeCapability, Memory, Range, Resources from dstack._internal.core.models.runs import ( GpusRequirements, + InstanceStatus, Job, JobPlan, JobProvisioningData, @@ -40,8 +42,16 @@ ServiceModelInfo, ) from dstack._internal.core.models.users import GlobalRole -from dstack._internal.server.models import JobModel, PoolModel, ProjectModel, RunModel, UserModel +from dstack._internal.server.models import ( + InstanceModel, + JobModel, + PoolModel, + ProjectModel, + RunModel, + UserModel, +) from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services import pools as pools_services from dstack._internal.server.services import repos as repos_services from dstack._internal.server.services.docker import parse_image_name from dstack._internal.server.services.jobs import ( @@ -53,7 +63,7 @@ get_default_image, get_default_python_verison, ) -from dstack._internal.server.services.pools import create_pool_model, list_project_pool, show_pool +from dstack._internal.server.services.pools import create_pool_model from dstack._internal.server.services.projects import list_project_models, list_user_project_models from dstack._internal.server.utils.common import run_async from dstack._internal.utils.logging import get_logger @@ -81,7 +91,7 @@ async def list_user_runs( project=project, repo_id=repo_id, ) - runs.extend(map(run_model_to_run, project_runs)) + runs.extend(project_runs) return sorted(runs, key=lambda r: r.submitted_at, reverse=True) @@ -89,7 +99,7 @@ async def list_project_runs( session: AsyncSession, project: ProjectModel, repo_id: Optional[str], -) -> List[RunModel]: +) -> List[Run]: filters = [ RunModel.project_id == project.id, RunModel.deleted == False, @@ -141,58 +151,60 @@ async def get_run( async def get_run_plan_by_requirements( - project: ProjectModel, profile: Profile -) -> Tuple[Requirements, List[Tuple[Backend, InstanceOfferWithAvailability]]]: + project: ProjectModel, + profile: Profile, + requirements: Requirements, + exclude_not_available=False, +) -> List[Tuple[Backend, InstanceOfferWithAvailability]]: backends = await backends_services.get_project_backends(project=project) if profile.backends is not None: backends = [b for b in backends if b.TYPE in profile.backends] - spot_policy = profile.spot_policy or SpotPolicy.AUTO # TODO: improve - requirements = Requirements( - cpus=profile.resources.cpu, - memory_mib=profile.resources.memory, - gpus=None, - 
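ShimClient.submit() above now builds the request body first and then POSTs it. A standalone sketch of the equivalent HTTP call with requests; the host and port are placeholders, and the payload mirrors the fields shown in the diff (username, password, image_name).

import requests


def submit_image(base_url: str, username: str, password: str, image_name: str) -> None:
    post_body = {"username": username, "password": password, "image_name": image_name}
    resp = requests.post(f"{base_url}/api/submit", json=post_body, timeout=10)
    resp.raise_for_status()


if __name__ == "__main__":
    # would only succeed against a running shim, e.g. through an SSH tunnel
    try:
        submit_image("http://localhost:10999", "", "", "dstackai/base:py3.11-0.4rc4-cuda-12.1")
    except requests.RequestException as e:
        print("submit failed:", e)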
shm_size_mib=profile.resources.shm_size, - max_price=profile.max_price, - spot=None if spot_policy == SpotPolicy.AUTO else (spot_policy == SpotPolicy.SPOT), - ) - if profile.resources.gpu: - requirements.gpus = GpusRequirements( - count=profile.resources.gpu.count, - memory_mib=profile.resources.gpu.memory, - name=profile.resources.gpu.name, - total_memory_mib=profile.resources.gpu.total_memory, - compute_capability=profile.resources.gpu.compute_capability, - ) - offers = await backends_services.get_instance_offers( backends=backends, requirements=requirements, - exclude_not_available=False, + exclude_not_available=exclude_not_available, ) - return requirements, offers + return offers async def create_instance( - project: ProjectModel, user: UserModel, pool_name: str, instance_name: str, profile: Profile -): - _, offers = await get_run_plan_by_requirements(project, profile) + session: AsyncSession, + project: ProjectModel, + user: UserModel, + pool_name: str, + instance_name: str, + profile: Profile, + requirements: Requirements, +) -> Optional[InstanceModel]: + offers = await get_run_plan_by_requirements( + project, profile, requirements, exclude_not_available=True + ) + + if not offers: + return ssh_key = SSHKeys( public=project.ssh_public_key.strip(), private=project.ssh_private_key.strip(), ) + image = parse_image_name(get_default_image(get_default_python_verison())) instance_config = InstanceConfiguration( instance_name=instance_name, pool_name=pool_name, ssh_keys=[ssh_key], job_docker_config=DockerConfig( - image=parse_image_name(get_default_image(get_default_python_verison())), + image=image, registry_auth=None, ), ) + pool = await pools_services.get_pool(session, project, pool_name) + + if pool is None: + pool = await create_pool_model(session, project, pool_name) + for backend, instance_offer in offers: logger.debug( @@ -219,23 +231,33 @@ async def create_instance( repr(e), ) continue - else: - job_provisioning_data = JobProvisioningData( - backend=backend.TYPE, - instance_type=instance_offer.instance, - instance_id=launched_instance_info.instance_id, - pool_id=pool_name, - hostname=launched_instance_info.ip_address, - region=launched_instance_info.region, - price=instance_offer.price, - username=launched_instance_info.username, - ssh_port=launched_instance_info.ssh_port, - dockerized=launched_instance_info.dockerized, - backend_data=launched_instance_info.backend_data, - ) - return (job_provisioning_data, instance_offer) - return (None, None) + job_provisioning_data = JobProvisioningData( + backend=backend.TYPE, + instance_type=instance_offer.instance, + instance_id=launched_instance_info.instance_id, + pool_id=str(pool.id), + hostname=launched_instance_info.ip_address, + region=launched_instance_info.region, + price=instance_offer.price, + username=launched_instance_info.username, + ssh_port=launched_instance_info.ssh_port, + dockerized=launched_instance_info.dockerized, + backend_data=launched_instance_info.backend_data, + ) + + im = InstanceModel( + name=instance_name, + project=project, + pool=pool, + status=InstanceStatus.STARTING, + job_provisioning_data=job_provisioning_data.json(), + offer=cast(InstanceOfferWithAvailability, instance_offer).json(), + ) + session.add(im) + await session.commit() + + return im async def get_run_plan( @@ -289,6 +311,7 @@ async def submit_run( ) if repo is None: raise RepoDoesNotExistError.with_id(run_spec.repo_id) + backends = await backends_services.get_project_backends(project) if len(backends) == 0: raise ServerClientError("No backends 
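create_instance() above walks the offers in order, tries to launch on each, skips offers that fail with NoCapacityError, and persists the first successful launch. A simplified, self-contained sketch of that control flow with stand-in types (not dstack's Backend/InstanceModel classes).

from dataclasses import dataclass
from typing import Callable, List, Optional


class NoCapacityError(Exception):
    pass


@dataclass
class Offer:
    backend: str
    instance_type: str
    price: float


@dataclass
class LaunchedInstance:
    backend: str
    instance_id: str


def launch_first_available(
    offers: List[Offer],
    launch: Callable[[Offer], LaunchedInstance],
) -> Optional[LaunchedInstance]:
    for offer in offers:
        try:
            launched = launch(offer)
        except NoCapacityError:
            # this offer has no capacity right now, try the next one
            continue
        return launched  # the caller would persist this as an InstanceModel
    return None


if __name__ == "__main__":
    def fake_launch(offer: Offer) -> LaunchedInstance:
        if offer.backend == "aws":
            raise NoCapacityError()
        return LaunchedInstance(backend=offer.backend, instance_id="i-123")

    offers = [Offer("aws", "gpu-large", 3.5), Offer("gcp", "gpu-large", 3.7)]
    print(launch_first_available(offers, fake_launch))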
configured") @@ -329,6 +352,7 @@ async def submit_run( jobs = get_jobs_from_run_spec(run_spec) if run_spec.configuration.type == "service": await gateways.register_service_jobs(session, project, run_spec.run_name, jobs) + for job in jobs: job.job_spec.pool_name = pool_name job_model = create_job_model_for_new_submission( diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 1db28b8ac..10880b16d 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -9,16 +9,24 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import DevEnvironmentConfiguration from dstack._internal.core.models.instances import InstanceType, Resources -from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile from dstack._internal.core.models.repos.base import RepoType from dstack._internal.core.models.repos.local import LocalRunRepoData -from dstack._internal.core.models.runs import JobErrorCode, JobProvisioningData, JobStatus, RunSpec +from dstack._internal.core.models.runs import ( + InstanceStatus, + JobErrorCode, + JobProvisioningData, + JobStatus, + RunSpec, +) from dstack._internal.core.models.users import GlobalRole from dstack._internal.server.models import ( BackendModel, GatewayComputeModel, GatewayModel, + InstanceModel, JobModel, + PoolModel, ProjectModel, RepoModel, RunModel, @@ -226,6 +234,7 @@ def get_job_provisioning_data() -> JobProvisioningData: ssh_port=22, dockerized=False, backend_data=None, + pool_id="", ) @@ -271,3 +280,43 @@ async def create_gateway_compute( session.add(gateway_compute) await session.commit() return gateway_compute + + +async def create_pool( + session: AsyncSession, + project: ProjectModel, + pool_name: Optional[str] = None, +) -> PoolModel: + + pool_name = pool_name if pool_name is not None else DEFAULT_POOL_NAME + pool = PoolModel( + name=pool_name, + project=project, + project_id=project.id, + ) + session.add(pool) + await session.commit() + return pool + + +async def create_instance( + session: AsyncSession, + project: ProjectModel, + pool: PoolModel, + status: InstanceStatus, +) -> InstanceModel: + im = InstanceModel( + name="test_instance", + pool=pool, + project=project, + status=status, + job_provisioning_data='{"backend": "datacrunch", "instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "instance_id": "running_instance.id", "pool_id": "1b2b4c57-5851-487f-b92e-948f946dfa49", "hostname": "running_instance.ip", "region": "running_instance.location", "price": 0.1, "username": "root", "ssh_port": 22, "dockerized": true, "backend_data": null}', + offer='{"backend": "datacrunch", "instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "region": "en", "price": 0.1, "availability": "available"}', + ) + session.add(im) + await session.commit() + + # pool.instances.append(im) + # await session.commit() + + return im diff --git a/src/dstack/api/_public/pools.py b/src/dstack/api/_public/pools.py index f1dea3cbf..a496c2f4b 100644 --- a/src/dstack/api/_public/pools.py +++ b/src/dstack/api/_public/pools.py @@ -1,22 +1,23 @@ from typing import List +from dstack._internal.core.models.pools import Pool from dstack.api.server 
import APIClient -class Instance: - def __init__(self, api_client: APIClient, instance): +class PoolInstance: + def __init__(self, api_client: APIClient, pool: Pool): self._api_client = api_client - self._instance = instance + self._pool = pool @property def name(self) -> str: - return self._instance.name + return self._pool.name def __str__(self) -> str: - return f"" + return f"" def __repr__(self) -> str: - return f"" + return f"" class PoolCollection: @@ -28,13 +29,13 @@ def __init__(self, api_client: APIClient, project: str): self._api_client = api_client self._project = project - def list(self) -> List[Instance]: + def list(self) -> List[PoolInstance]: """ List available pool in the project Returns: pools """ - list_raw_instances = self._api_client.pool.list(project_name=self._project) - list_instances = [Instance(self._api_client, instance) for instance in list_raw_instances] - return list_instances + list_raw_pool = self._api_client.pool.list(project_name=self._project) + list_pool = [PoolInstance(self._api_client, instance) for instance in list_raw_pool] + return list_pool diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 322966087..449efe55a 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -360,12 +360,12 @@ def submit( return self.exec_plan(run_plan, repo, reserve_ports=reserve_ports) def get_offers( - self, profile: Profile - ) -> Tuple[Requirements, List[InstanceOfferWithAvailability]]: - return self._api_client.runs.get_offers(self._project, profile) + self, profile: Profile, requirements: Requirements + ) -> List[InstanceOfferWithAvailability]: + return self._api_client.runs.get_offers(self._project, profile, requirements) - def create_instance(self, pool_name: str, profile: Profile): - self._api_client.runs.create_instance(self._project, pool_name, profile) + def create_instance(self, pool_name: str, profile: Profile, requirements: Requirements): + self._api_client.runs.create_instance(self._project, pool_name, profile, requirements) def get_plan( self, @@ -411,6 +411,9 @@ def get_plan( max_duration=max_duration, max_price=max_price, pool_name=pool_name, + creation_policy=None, + termination_idle_time=None, + termination_policy=None, ) run_spec = RunSpec( run_name=run_name, diff --git a/src/dstack/api/server/_pools.py b/src/dstack/api/server/_pools.py index 9035a5011..b46342c89 100644 --- a/src/dstack/api/server/_pools.py +++ b/src/dstack/api/server/_pools.py @@ -8,10 +8,11 @@ from dstack.api.server._group import APIClientGroup -class PoolAPIClient(APIClientGroup): +class PoolAPIClient(APIClientGroup): # type: ignore[misc] def list(self, project_name: str) -> List[Pool]: resp = self._request(f"/api/project/{project_name}/pool/list") - return parse_obj_as(List[Pool], resp.json()) + result: List[Pool] = parse_obj_as(List[Pool], resp.json()) + return result def delete(self, project_name: str, pool_name: str, force: bool) -> None: body = schemas_pools.DeletePoolRequest(name=pool_name, force=force) @@ -19,16 +20,23 @@ def delete(self, project_name: str, pool_name: str, force: bool) -> None: def create(self, project_name: str, pool_name: str) -> None: body = schemas_pools.CreatePoolRequest(name=pool_name) - self._request(f"/api/project/{project_name}/pool/create", body=body.json()) + self._request(f"/api/project/{project_name}/pool/remove", body=body.json()) def show(self, project_name: str, pool_name: str) -> List[Instance]: body = schemas_pools.ShowPoolRequest(name=pool_name) resp = 
self._request(f"/api/project/{project_name}/pool/show", body=body.json()) - return parse_obj_as(List[Instance], resp.json()) + result: List[Instance] = parse_obj_as(List[Instance], resp.json()) + return result + + def remove(self, project_name: str, pool_name: str, instance_name: str) -> None: + body = schemas_pools.RemoveInstanceRequest( + pool_name=pool_name, instance_name=instance_name + ) + self._request(f"/api/project/{project_name}/pool/remove", body=body.json()) def add( self, project_name: str, pool_name: str, instance_name: Optional[str], host: str, port: str - ): + ) -> None: body = AddRemoteInstanceRequest( pool_name=pool_name, instance_name=instance_name, host=host, port=port ) diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index a3aecb353..d500835a3 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -31,14 +31,18 @@ def get(self, project_name: str, run_name: str) -> Run: return parse_obj_as(Run, resp.json()) def get_offers( - self, project_name: str, profile: Profile - ) -> Tuple[Requirements, List[InstanceOfferWithAvailability]]: - body = GetOffersRequest(profile=profile) + self, project_name: str, profile: Profile, requirements: Requirements + ) -> List[InstanceOfferWithAvailability]: + body = GetOffersRequest(profile=profile, requirements=requirements) resp = self._request(f"/api/project/{project_name}/runs/get_offers", body=body.json()) - return parse_obj_as(Tuple[Requirements, List[InstanceOfferWithAvailability]], resp.json()) + return parse_obj_as(List[InstanceOfferWithAvailability], resp.json()) - def create_instance(self, project_name: str, pool_name: str, profile: Profile): - body = CreateInstanceRequest(pool_name=pool_name, profile=profile) + def create_instance( + self, project_name: str, pool_name: str, profile: Profile, requirements: Requirements + ): + body = CreateInstanceRequest( + pool_name=pool_name, profile=profile, requirements=requirements + ) self._request(f"/api/project/{project_name}/runs/create_instance", body=body.json()) def get_plan(self, project_name: str, run_spec: RunSpec) -> RunPlan: diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py index af825e7b7..19c7b46aa 100644 --- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py @@ -12,6 +12,7 @@ from dstack._internal.server import settings from dstack._internal.server.background.tasks.process_running_jobs import process_running_jobs from dstack._internal.server.schemas.runner import HealthcheckResponse, JobStateEvent, PullResponse +from dstack._internal.server.services.jobs.configurators.base import get_default_python_verison from dstack._internal.server.testing.common import ( create_job, create_project, @@ -36,6 +37,7 @@ def get_job_provisioning_data(dockerized: bool) -> JobProvisioningData: ssh_port=22, dockerized=dockerized, backend_data=None, + pool_id="", ) @@ -197,12 +199,17 @@ async def test_provisioning_shim(self, test_db, session: AsyncSession): user=user, ) job_provisioning_data = get_job_provisioning_data(dockerized=True) - job = await create_job( - session=session, - run=run, - status=JobStatus.PROVISIONING, - job_provisioning_data=job_provisioning_data, - ) + + with patch( + "dstack._internal.server.services.jobs.configurators.base.get_default_python_verison" + ) as PyVersion: + PyVersion.return_value = "3.11" + job = 
await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + job_provisioning_data=job_provisioning_data, + ) with patch( "dstack._internal.server.services.runner.ssh.RunnerTunnel" ) as RunnerTunnelMock, patch( @@ -214,7 +221,7 @@ async def test_provisioning_shim(self, test_db, session: AsyncSession): await process_running_jobs() RunnerTunnelMock.assert_called_once() ShimClientMock.return_value.healthcheck.assert_called_once() - ShimClientMock.return_value.registry_auth.assert_called_once_with( + ShimClientMock.return_value.submit.assert_called_once_with( username="", password="", image_name="dstackai/base:py3.11-0.4rc4-cuda-12.1" ) await session.refresh(job) diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index 87d37c92e..971d9f52d 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -2,6 +2,7 @@ from unittest.mock import Mock, patch import pytest +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.core.models.backends.base import BackendType @@ -13,10 +14,14 @@ Resources, ) from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile, ProfileRetryPolicy -from dstack._internal.core.models.runs import JobStatus +from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData, JobStatus from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs +from dstack._internal.server.models import JobModel +from dstack._internal.server.services.pools import list_project_pool, list_project_pool_models from dstack._internal.server.testing.common import ( + create_instance, create_job, + create_pool, create_project, create_repo, create_run, @@ -191,3 +196,73 @@ async def test_transitions_job_with_outdated_retry_to_failed_on_no_capacity( await session.refresh(project) assert not project.default_pool.instances + + @pytest.mark.asyncio + async def test_job_whith_instance(self, test_db, session: AsyncSession): + project = await create_project(session) + user = await create_user(session) + repo = await create_repo( + session, + project_id=project.id, + ) + pools = await list_project_pool_models(session, project) + pool = None + for pool_item in pools: + if pool_item == DEFAULT_POOL_NAME: + pool = pool_item + if pool is None: + pool = await create_pool(session, project) + im = await create_instance(session, project, pool, InstanceStatus.READY) + await session.refresh(pool) + run = await create_run( + session, + project=project, + repo=repo, + user=user, + ) + job_provisioning_data = JobProvisioningData( + backend=BackendType.LOCAL, + instance_type=InstanceType( + name="local", resources=Resources(cpus=1, memory_mib=1024, gpus=[], spot=False) + ), + instance_id="0000-0000", + hostname="localhost", + region="", + price=0.0, + username="root", + ssh_port=22, + dockerized=False, + pool_id="", + backend_data=None, + ) + job = await create_job( + session, + run=run, + job_provisioning_data=job_provisioning_data, + ) + await process_submitted_jobs() + await session.refresh(job) + assert job is not None + assert job.status == JobStatus.PROVISIONING + + res = await session.execute(select(JobModel).where()) + jm = res.all()[0][0] + assert jm.job_num == 0 + assert jm.run_name == "test-run" + assert jm.job_name == "test-run-0" + assert 
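The test above patches ShimClient and then asserts that submit() was called once with the expected image. A minimal standalone illustration of that mock-and-assert pattern; the ShimClient and provision() here are stand-ins, not dstack's real classes.

from unittest.mock import patch


class ShimClient:
    def submit(self, username: str, password: str, image_name: str) -> None:
        raise RuntimeError("would talk to a real shim")


def provision(image_name: str) -> None:
    ShimClient().submit(username="", password="", image_name=image_name)


def test_provision_submits_image() -> None:
    with patch(f"{__name__}.ShimClient") as ShimClientMock:
        provision("dstackai/base:py3.11-0.4rc4-cuda-12.1")
        ShimClientMock.return_value.submit.assert_called_once_with(
            username="", password="", image_name="dstackai/base:py3.11-0.4rc4-cuda-12.1"
        )


if __name__ == "__main__":
    test_provision_submits_image()
    print("ok")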
jm.submission_num == 0 + assert jm.status == JobStatus.PROVISIONING + assert jm.error_code == None + assert ( + jm.job_spec_data + == r"""{"job_num": 0, "job_name": "test-run-0", "app_specs": [], "commands": ["/bin/bash", "-i", "-c", "(echo pip install ipykernel... && pip install -q --no-cache-dir ipykernel 2> /dev/null) || echo \"no pip, ipykernel was not installed\" && echo '' && echo To open in VS Code Desktop, use link below: && echo '' && echo ' vscode://vscode-remote/ssh-remote+test-run/workflow' && echo '' && echo 'To connect via SSH, use: `ssh test-run`' && echo '' && echo -n 'To exit, press Ctrl+C.' && tail -f /dev/null"], "env": {}, "gateway": null, "home_dir": "/root", "image_name": "dstackai/base:py3.10-0.4rc4-cuda-12.1", "max_duration": 21600, "registry_auth": null, "requirements": {"resources": {"cpu": {"min": 2, "max": null}, "memory": {"min": 8.0, "max": null}, "shm_size": null, "gpu": null, "disk": null}, "max_price": null, "spot": false}, "retry_policy": {"retry": false, "limit": null}, "working_dir": ".", "pool_name": "default-pool"}""" + ) + assert jm.job_provisioning_data == ( + '{"backend": "datacrunch", "instance_type": {"name": "instance", "resources": ' + '{"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": ' + '{"size_mib": 102400}, "description": ""}}, "instance_id": ' + '"running_instance.id", "pool_id": "1b2b4c57-5851-487f-b92e-948f946dfa49", ' + '"hostname": "running_instance.ip", "region": "running_instance.location", ' + '"price": 0.1, "username": "root", "ssh_port": 22, "dockerized": true, ' + '"backend_data": null}' + ) diff --git a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py b/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py index 08688261e..9e1b93b70 100644 --- a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py @@ -39,7 +39,7 @@ async def test_transitions_done_jobs_marked_as_removed(self, test_db, session: A run=run, status=JobStatus.DONE, job_provisioning_data=JobProvisioningData( - backend=BackendType.LOCAL, + backend=BackendType.AWS, instance_type=InstanceType( name="local", resources=Resources(cpus=1, memory_mib=1024, gpus=[], spot=False) ), @@ -50,6 +50,8 @@ async def test_transitions_done_jobs_marked_as_removed(self, test_db, session: A username="root", ssh_port=22, dockerized=False, + pool_id="", + backend_data=None, ), ) with patch(f"{MODULE}.terminate_job_submission_instance") as terminate: diff --git a/src/tests/_internal/server/services/test_pools.py b/src/tests/_internal/server/services/test_pools.py index f720f2e20..4d30518f7 100644 --- a/src/tests/_internal/server/services/test_pools.py +++ b/src/tests/_internal/server/services/test_pools.py @@ -1,16 +1,26 @@ import datetime as dt import uuid +from unittest.mock import patch import pytest from sqlalchemy.ext.asyncio import AsyncSession import dstack._internal.server.services.pools as services_pools import dstack._internal.server.services.projects as services_projects +import dstack._internal.server.services.runs as runs import dstack._internal.server.services.users as services_users +from dstack._internal.core.models import resources from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.instances import InstanceType, Resources +from dstack._internal.core.models.instances import ( + InstanceAvailability, + InstanceOfferWithAvailability, + InstanceType, + 
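The assertions above compare raw JSON strings because job_spec_data and job_provisioning_data are persisted as serialized pydantic models in string columns. A small sketch of that round trip, with a stand-in model and the pydantic v1-style helpers the patch itself relies on:

    from pydantic import BaseModel, parse_raw_as

    class ProvisioningStub(BaseModel):   # stand-in for JobProvisioningData
        backend: str
        hostname: str
        ssh_port: int
        pool_id: str

    stored = ProvisioningStub(backend="local", hostname="localhost", ssh_port=22, pool_id="").json()
    # ...later, when the column is read back:
    jpd = parse_raw_as(ProvisioningStub, stored)
    assert jpd.hostname == "localhost" and jpd.pool_id == ""
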
LaunchedInstanceInfo, + Resources, +) from dstack._internal.core.models.pools import Instance, Pool -from dstack._internal.core.models.runs import InstanceStatus +from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.runs import InstanceStatus, Requirements from dstack._internal.core.models.users import GlobalRole from dstack._internal.server.models import InstanceModel from dstack._internal.server.testing.common import create_project, create_user @@ -68,7 +78,7 @@ def test_convert_instance(): status=InstanceStatus.PENDING, project_id=str(uuid.uuid4()), pool=None, - job_provisioning_data='{"backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + job_provisioning_data='{"pool_id":"123", "backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', offer='{"price":"LOCAL", "price":1.0, "backend":"local", "region":"eu-west-1", "availability":"available","instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', ) @@ -104,7 +114,7 @@ async def test_show_pool(session: AsyncSession, test_db): project=project, pool=pool, status=InstanceStatus.PENDING, - job_provisioning_data='{"backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + job_provisioning_data='{"pool_id":"123", "backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', offer='{"price":"LOCAL", "price":1.0, "backend":"local", "region":"eu-west-1", "availability":"available","instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', ) session.add(im) @@ -162,3 +172,84 @@ async def test_generate_instance_name(session: AsyncSession, test_db): car, _, cdr = name.partition("-") assert len(car) > 0 assert len(cdr) > 0 + + +@pytest.mark.asyncio +async def test_pool_double_name(session: AsyncSession, test_db): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + pool1 = await services_pools.create_pool_model( + session=session, project=project, name="test_pool" + ) + with pytest.raises(ValueError): + + pool2 = await services_pools.create_pool_model( + session=session, project=project, name="test_pool" + ) + + +@pytest.mark.asyncio +async def test_create_cloud_instance(session: AsyncSession, test_db): + user = await create_user(session) + project = await create_project(session, user) + + profile = Profile(name="test_profile") 
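The test being set up here stubs the offer lookup (the real test patches dstack._internal.server.services.runs.get_run_plan_by_requirements) and a dummy backend whose compute layer returns canned launch info, so create_instance never touches a real cloud. The general pattern, reduced to a self-contained example with purely illustrative names:

    from unittest.mock import patch

    class DummyCompute:
        def create_instance(self, *args, **kwargs):
            return {"instance_id": "i-123", "ip": "10.0.0.1"}

    class DummyBackend:
        TYPE = "datacrunch"
        def compute(self):
            return DummyCompute()

    def get_offers_for_requirements(requirements):
        raise RuntimeError("would call real backends")   # the real lookup

    def provision(requirements):
        backend, offer = get_offers_for_requirements(requirements)[0]
        return backend.compute().create_instance(offer)

    with patch(f"{__name__}.get_offers_for_requirements") as offers_mock:
        offers_mock.return_value = [(DummyBackend(), {"price": 0.1})]
        print(provision({"cpu": 1}))     # -> {'instance_id': 'i-123', 'ip': '10.0.0.1'}
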
+ + requirements = Requirements(resources=resources.Resources(cpu=1), spot=True) + + offer = InstanceOfferWithAvailability( + backend=BackendType.DATACRUNCH, + instance=InstanceType( + name="instance", resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[]) + ), + region="en", + price=0.1, + availability=InstanceAvailability.AVAILABLE, + ) + + launched_instance = LaunchedInstanceInfo( + instance_id="running_instance.id", + ip_address="running_instance.ip", + region="running_instance.location", + ssh_port=22, + username="root", + dockerized=True, + backend_data=None, + ) + + class DummyBackend: + TYPE = BackendType.DATACRUNCH + + def compute(self): + return self + + def create_instance(self, *args, **kwargs): + return launched_instance + + offers = [(DummyBackend(), offer)] + + with patch("dstack._internal.server.services.runs.get_run_plan_by_requirements") as reqs: + reqs.return_value = offers + await runs.create_instance( + session, + project, + user, + profile=profile, + pool_name="test_pool", + instance_name="test_instance", + requirements=requirements, + ) + + pool = await services_pools.get_pool(session, project, "test_pool") + assert pool is not None + instance = pool.instances[0] + + assert instance.name == "test_instance" + assert instance.deleted == False + assert instance.deleted_at == None + + # assert instance.job_provisioning_data == '{"backend": "datacrunch", "instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "instance_id": "running_instance.id", "pool_id": "1b2b4c57-5851-487f-b92e-948f946dfa49", "hostname": "running_instance.ip", "region": "running_instance.location", "price": 0.1, "username": "root", "ssh_port": 22, "dockerized": true, "backend_data": null}' + assert ( + instance.offer + == '{"backend": "datacrunch", "instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "region": "en", "price": 0.1, "availability": "available"}' + ) From 6229341c5371ac268bc89b9353d1eb9df3440bf5 Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Wed, 31 Jan 2024 14:32:49 +0300 Subject: [PATCH 08/47] update --- .pre-commit-config.yaml | 8 ++ .../src/dstack/gateway/registry/schemas.py | 2 - src/dstack/_internal/cli/commands/pool.py | 111 ++++++++++++++---- src/dstack/_internal/core/models/pools.py | 1 - src/dstack/_internal/core/models/resources.py | 2 +- .../server/background/tasks/process_pools.py | 8 +- .../background/tasks/process_running_jobs.py | 18 ++- .../tasks/process_submitted_jobs.py | 66 +++++++---- ...add_pools.py => 309e4be6671b_add_pools.py} | 6 +- src/dstack/_internal/server/models.py | 3 + src/dstack/_internal/server/routers/pools.py | 5 +- src/dstack/_internal/server/routers/runs.py | 3 +- src/dstack/_internal/server/schemas/runs.py | 3 +- .../server/services/gateways/pool.py | 7 +- src/dstack/_internal/server/services/pools.py | 26 ++-- src/dstack/_internal/server/services/runs.py | 5 +- src/dstack/_internal/server/testing/common.py | 3 + src/dstack/api/_public/runs.py | 2 +- src/dstack/api/server/_pools.py | 15 ++- src/dstack/api/server/_runs.py | 3 +- .../tasks/test_process_submitted_jobs.py | 22 ++-- .../_internal/server/services/test_pools.py | 2 +- 22 files changed, 226 insertions(+), 95 deletions(-) rename src/dstack/_internal/server/migrations/versions/{6a084acc1211_add_pools.py => 309e4be6671b_add_pools.py} (97%) diff --git a/.pre-commit-config.yaml 
b/.pre-commit-config.yaml index 4966e262e..e9ce6cb8b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,3 +22,11 @@ repos: files: '.*pools?\.py' exclude: 'versions|src/tests' additional_dependencies: [types-PyYAML] + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.14 + hooks: + - id: ruff + name: ruff autofix + args: ['--fix', '--select', 'F401'] + files: 'runs?\.py|pools?\.py' + exclude: 'versions|src/tests' diff --git a/gateway/src/dstack/gateway/registry/schemas.py b/gateway/src/dstack/gateway/registry/schemas.py index 337356243..9ff3a11aa 100644 --- a/gateway/src/dstack/gateway/registry/schemas.py +++ b/gateway/src/dstack/gateway/registry/schemas.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel import dstack.gateway.schemas diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index be69e2c13..0551eb1c9 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -2,8 +2,6 @@ from collections.abc import Sequence from pathlib import Path -import yaml -from pydantic import parse_obj_as from rich.table import Table from dstack._internal.cli.commands import APIBaseCommand @@ -11,22 +9,29 @@ apply_profile_args, register_profile_args, ) -from dstack._internal.cli.services.configurators.run import BaseRunConfigurator from dstack._internal.cli.utils.common import confirm_ask, console from dstack._internal.core.errors import CLIError, ServerClientError -from dstack._internal.core.models.configurations import parse as parse_configuration from dstack._internal.core.models.instances import ( InstanceAvailability, InstanceOfferWithAvailability, ) from dstack._internal.core.models.pools import Instance, Pool from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile, SpotPolicy -from dstack._internal.core.models.resources import GPU, Resources +from dstack._internal.core.models.resources import ( + DEFAULT_CPU_COUNT, + DEFAULT_MEMORY_SIZE, + DiskLike, + GPULike, + IntRangeLike, + MemoryLike, + MemoryRangeLike, +) from dstack._internal.core.models.runs import Requirements from dstack._internal.core.services.configs import ConfigManager from dstack._internal.utils.common import pretty_date from dstack._internal.utils.logging import get_logger -from dstack.api.utils import load_configuration, load_profile +from dstack.api._public.resources import Resources +from dstack.api.utils import load_profile logger = get_logger(__name__) @@ -157,6 +162,56 @@ def th(s: str) -> str: console.print() +def register_resource_args(parser: argparse.ArgumentParser) -> None: + resources_group = parser.add_argument_group("Resources") + resources_group.add_argument( + "--cpu", + type=IntRangeLike, + help=f"Request the CPU count. Default: '{DEFAULT_CPU_COUNT.min}..'", + dest="cpu", + metavar="SPEC", + default=DEFAULT_CPU_COUNT, + ) + + resources_group.add_argument( + "--memory", + type=MemoryRangeLike, + help="Request the size of RAM. " + f"The format is [code]SIZE[/]:[code]MB|GB|TB[/]. Default: {DEFAULT_MEMORY_SIZE.min}", + dest="memory", + metavar="SIZE", + default=DEFAULT_MEMORY_SIZE, + ) + + resources_group.add_argument( + "--shared-memory", + type=MemoryLike, + help="Request the size of Shared Memory. The format is [code]SIZE[/]:[code]MB|GB|TB[/].", + dest="shared_memory", + default=None, + metavar="SIZE", + ) + + resources_group.add_argument( + "--gpu", + type=GPULike, + help="Request GPU for the run. 
" + "The format is [code]NAME[/]:[code]COUNT[/]:[code]MEMORY[/] (all parts are optional)", + dest="gpu", + default=None, + metavar="SPEC", + ) + + resources_group.add_argument( + "--disk", + type=DiskLike, + help="Request the size of disk for the run. Example [code]--disk 100GB[/].", + dest="disk", + metavar="SIZE", + default=None, + ) + + class PoolCommand(APIBaseCommand): # type: ignore[misc] NAME = "pool" DESCRIPTION = "Pool management" @@ -167,7 +222,7 @@ def _register(self) -> None: subparsers = self._parser.add_subparsers(dest="action") - # list + # list pools list_parser = subparsers.add_parser( "list", help="List pools", @@ -177,14 +232,14 @@ def _register(self) -> None: list_parser.add_argument("-v", "--verbose", help="Show more information") list_parser.set_defaults(subfunc=self._list) - # create + # create pool create_parser = subparsers.add_parser( "create", help="Create pool", formatter_class=self._parser.formatter_class ) create_parser.add_argument("-n", "--name", dest="pool_name", help="The name of the pool") create_parser.set_defaults(subfunc=self._create) - # delete + # delete pool delete_parser = subparsers.add_parser( "delete", help="Delete pool", formatter_class=self._parser.formatter_class ) @@ -196,7 +251,7 @@ def _register(self) -> None: ) delete_parser.set_defaults(subfunc=self._delete) - # show + # show pool instances show_parser = subparsers.add_parser( "show", help="Show pool instances", @@ -208,7 +263,7 @@ def _register(self) -> None: ) show_parser.set_defaults(subfunc=self._show) - # add + # add instance add_parser = subparsers.add_parser( "add", help="Add instance to pool", formatter_class=self._parser.formatter_class ) @@ -231,10 +286,10 @@ def _register(self) -> None: ) add_parser.add_argument("--name", dest="instance_name", help="The name of the instance") register_profile_args(add_parser) - BaseRunConfigurator.register(add_parser) + register_resource_args(add_parser) add_parser.set_defaults(subfunc=self._add) - # remove + # remove instance remove_parser = subparsers.add_parser( "remove", help="Remove instance from the pool", @@ -271,31 +326,39 @@ def _add(self, args: argparse.Namespace) -> None: pool_name: str = DEFAULT_POOL_NAME if args.pool_name is None else args.pool_name + resources = Resources( + cpu=args.cpu, + memory=args.memory, + gpu=args.gpu, + shm_size=args.shared_memory, + disk=args.disk, + ) + requirements = Requirements( + resources=resources, + max_price=args.max_price, + spot=(args.spot_policy == SpotPolicy.SPOT), + ) + # Add remote instance if args.remote: self.api.client.pool.add( - self.api.project, pool_name, args.instance_name, args.remote_host, args.remote_port + self.api.project, + resources, + pool_name, + args.instance_name, + args.remote_host, + args.remote_port, ) return repo = self.api.repos.load(Path.cwd()) self.api.ssh_identity_file = ConfigManager().get_repo_config(repo.repo_dir).ssh_key_path - # TODO: read requirements from repo configuration - # conf = parse_configuration(yaml.safe_load(conf_path.open())) - # resources = conf.resources - profile = load_profile(Path.cwd(), args.profile) apply_profile_args(args, profile) profile.pool_name = pool_name - requirements = Requirements( - resources=Resources(gpu=parse_obj_as(GPU, args.gpu_spec)), - max_price=args.max_price, - spot=(args.spot_policy == SpotPolicy.SPOT), - ) - with console.status("Getting instances..."): offers = self.api.runs.get_offers(profile, requirements) diff --git a/src/dstack/_internal/core/models/pools.py b/src/dstack/_internal/core/models/pools.py index 
c62bd0f9c..204b1cca1 100644 --- a/src/dstack/_internal/core/models/pools.py +++ b/src/dstack/_internal/core/models/pools.py @@ -1,5 +1,4 @@ import datetime -from typing import List, Optional from pydantic import BaseModel diff --git a/src/dstack/_internal/core/models/resources.py b/src/dstack/_internal/core/models/resources.py index 788425442..c3ac0d047 100644 --- a/src/dstack/_internal/core/models/resources.py +++ b/src/dstack/_internal/core/models/resources.py @@ -191,7 +191,7 @@ class ResourcesSpec(ForbidExtra): cpu (Optional[Range[int]]): The number of CPUs memory (Optional[Range[Memory]]): The size of RAM memory (e.g., `"16GB"`) gpu (Optional[GPUSpec]): The GPU spec - shm_size (Optional[Range[Memory]]): The of shared memory (e.g., `"8GB"`). If you are using parallel communicating processes (e.g., dataloaders in PyTorch), you may need to configure this. + shm_size (Optional[Range[Memory]]): The size of shared memory (e.g., `"8GB"`). If you are using parallel communicating processes (e.g., dataloaders in PyTorch), you may need to configure this. disk (Optional[DiskSpec]): The disk spec """ diff --git a/src/dstack/_internal/server/background/tasks/process_pools.py b/src/dstack/_internal/server/background/tasks/process_pools.py index a7d7c8d5b..bf71aa63d 100644 --- a/src/dstack/_internal/server/background/tasks/process_pools.py +++ b/src/dstack/_internal/server/background/tasks/process_pools.py @@ -3,16 +3,14 @@ from uuid import UUID from pydantic import parse_raw_as -from sqlalchemy import select, update -from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select from sqlalchemy.orm import joinedload -from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData, JobStatus +from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData from dstack._internal.server.db import get_session_ctx -from dstack._internal.server.models import InstanceModel, JobModel +from dstack._internal.server.models import InstanceModel from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services.jobs import PROCESSING_POOL_IDS, PROCESSING_POOL_LOCK -from dstack._internal.server.services.logging import job_log from dstack._internal.server.services.runner import client from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel from dstack._internal.server.utils.common import run_async diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 8f50e6f2e..c5ec1e2f4 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -172,6 +172,14 @@ async def _process_job(job_id: UUID): ) job_model.status = JobStatus.FAILED job_model.error_code = JobErrorCode.WAITING_RUNNER_LIMIT_EXCEEDED + + instance_name: str = job_provisioning_data.instance_id + pool_name = str(job.job_spec.pool_name) + instances = await get_pool_instances(session, project, pool_name) + for inst in instances: + if inst.name == instance_name: + inst.status = InstanceStatus.READY # TODO: or fail? 
+ else: # fails are not acceptable if initial_status == JobStatus.PULLING: logger.debug( @@ -208,7 +216,7 @@ async def _process_job(job_id: UUID): job_model, ) - if success: + if success and job_model.status == JobStatus.DONE: instance_name: str = job_provisioning_data.instance_id pool_name = str(job.job_spec.pool_name) instances = await get_pool_instances(session, project, pool_name) @@ -234,6 +242,14 @@ async def _process_job(job_id: UUID): status=JobStatus.PENDING, ) session.add(new_job_model) + + instance_name: str = job_provisioning_data.instance_id + pool_name = str(job.job_spec.pool_name) + instances = await get_pool_instances(session, project, pool_name) + for inst in instances: + if inst.name == instance_name: + inst.status = InstanceStatus.READY + # job will be terminated by process_finished_jobs if ( diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 9d039a972..1d095bdc7 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -13,20 +13,14 @@ InstanceOfferWithAvailability, LaunchedInstanceInfo, ) -from dstack._internal.core.models.profiles import ( - DEFAULT_POOL_NAME, - CreationPolicy, - Profile, - TerminationPolicy, -) -from dstack._internal.core.models.resources import Range, Resources +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, CreationPolicy, Profile +from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import ( InstanceStatus, Job, JobErrorCode, JobProvisioningData, JobStatus, - Requirements, Run, RunSpec, ) @@ -41,9 +35,7 @@ from dstack._internal.server.services.pools import ( get_pool_instances, instance_model_to_instance, - list_project_pool, list_project_pool_models, - show_pool, ) from dstack._internal.server.services.runs import run_model_to_run from dstack._internal.server.settings import LOCAL_BACKEND_ENABLED @@ -87,37 +79,61 @@ async def _process_job(job_id: UUID): ) -def check_relevance(profile: Profile, resources: Resources, instance_model: InstanceModel) -> bool: +def check_relevance( + profile: Profile, resources: ResourcesSpec, instance_model: InstanceModel +) -> bool: jpd: JobProvisioningData = parse_raw_as( JobProvisioningData, instance_model.job_provisioning_data ) + # TODO: remove on prod if LOCAL_BACKEND_ENABLED and jpd.backend == BackendType.LOCAL: return True instance = instance_model_to_instance(instance_model) if profile.backends is not None and instance.backend not in profile.backends: - logger.warning(f"no backnd select ") + logger.warning(f"no backend select ") + return False + + instance_resources: ResourcesSpec = parse_raw_as( + ResourcesSpec, instance_model.resource_spec_data + ) + + if resources.cpu.min > instance_resources.cpu.min: + return False + + if resources.gpu is not None: + + if instance_resources.gpu is None: + return False + + if resources.gpu.count.min > instance_resources.gpu.count.min: + return False + + if resources.gpu.memory.min > instance_resources.gpu.memory.min: + return False + + # TODO: compare GPU names + + if resources.memory.min > instance_resources.memory.min: return False - # instance_resources = jpd.instance_type.resources + if resources.shm_size is not None: + if instance_resources.shm_size is None: + return False - # TODO: full check requirements - # if isinstance(requirements.resources.cpu, Range): - # if 
requirements.resources.cpu.min < int(instance_resources.cpus): - # return False + if resources.shm_size > instance_resources.shm_size: + return False - # if isinstance(requirements.resources.gpu, Range): - # if requirements.resources.gpu.min < int(instance_resources.cpus): - # return False - # if isinstance(int(requirements.resources.gpu), int): - # if requirements.resources.gpu < len(instance_resources.gpus): - # return False + if resources.disk is not None: + if instance_resources.disk is None: + return False + if resources.disk.size.min > instance_resources.disk.size.min: + return False return True - # TODO: memory, shm_size, disk async def _process_submitted_job(session: AsyncSession, job_model: JobModel): @@ -216,7 +232,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): name=job.job_spec.job_name, project=project_model, pool=pool, - status=InstanceStatus.CREATING, + status=InstanceStatus.BUSY, job_provisioning_data=job_provisioning_data.json(), offer=offer.json(), termination_policy=profile.termination_policy, diff --git a/src/dstack/_internal/server/migrations/versions/6a084acc1211_add_pools.py b/src/dstack/_internal/server/migrations/versions/309e4be6671b_add_pools.py similarity index 97% rename from src/dstack/_internal/server/migrations/versions/6a084acc1211_add_pools.py rename to src/dstack/_internal/server/migrations/versions/309e4be6671b_add_pools.py index e8ec39874..eba3d08cd 100644 --- a/src/dstack/_internal/server/migrations/versions/6a084acc1211_add_pools.py +++ b/src/dstack/_internal/server/migrations/versions/309e4be6671b_add_pools.py @@ -1,8 +1,8 @@ """add pools -Revision ID: 6a084acc1211 +Revision ID: 309e4be6671b Revises: d3e8af4786fa -Create Date: 2024-01-25 10:23:18.983726 +Create Date: 2024-01-31 10:35:34.977788 """ import sqlalchemy as sa @@ -10,7 +10,7 @@ from alembic import op # revision identifiers, used by Alembic. 
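For reference, the check_relevance rewrite earlier in this patch boils down to comparing each requested minimum against the instance's recorded resource spec, and treating a GPU, shm, or disk requirement as unmet when the instance has none recorded. A simplified, self-contained version of that comparison, using stand-in types rather than the real ResourcesSpec:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Spec:                          # simplified stand-in for a resource spec
        cpu_min: int
        memory_min_gb: float
        gpu_count_min: int = 0
        disk_min_gb: Optional[float] = None

    def fits(requested: Spec, instance: Spec) -> bool:
        # every requested minimum must be covered by the instance's minimum
        if requested.cpu_min > instance.cpu_min:
            return False
        if requested.memory_min_gb > instance.memory_min_gb:
            return False
        if requested.gpu_count_min and requested.gpu_count_min > instance.gpu_count_min:
            return False
        if requested.disk_min_gb is not None:
            if instance.disk_min_gb is None or requested.disk_min_gb > instance.disk_min_gb:
                return False
        return True

    print(fits(Spec(2, 8.0), Spec(4, 16.0, gpu_count_min=1)))   # True
    print(fits(Spec(2, 8.0, gpu_count_min=1), Spec(4, 16.0)))   # False
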
-revision = "6a084acc1211" +revision = "309e4be6671b" down_revision = "d3e8af4786fa" branch_labels = None depends_on = None diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 627a2c368..e1c68af6e 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -283,4 +283,7 @@ class InstanceModel(BaseModel): termination_idle_time: Mapped[Optional[str]] = mapped_column(String(50)) job_provisioning_data: Mapped[str] = mapped_column(String(4000)) + offer: Mapped[str] = mapped_column(String(4000)) + + resource_spec_data: Mapped[Optional[str]] = mapped_column(String(4000)) diff --git a/src/dstack/_internal/server/routers/pools.py b/src/dstack/_internal/server/routers/pools.py index a7aa5dad9..eb08667d5 100644 --- a/src/dstack/_internal/server/routers/pools.py +++ b/src/dstack/_internal/server/routers/pools.py @@ -7,7 +7,7 @@ import dstack._internal.server.schemas.pools as schemas import dstack._internal.server.services.pools as pools from dstack._internal.server.db import get_session -from dstack._internal.server.models import InstanceModel, PoolModel, ProjectModel, UserModel +from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.runs import AddRemoteInstanceRequest from dstack._internal.server.security.permissions import ProjectAdmin, ProjectMember from dstack._internal.server.services.runs import ( @@ -99,7 +99,8 @@ async def add_instance( _, project = user_project await pools.add( session, - project, + project=project, + resources=body.resources, pool_name=body.pool_name, instance_name=body.instance_name, host=body.host, diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index de2d5ba12..416f9c111 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -4,11 +4,10 @@ from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.core.models.instances import InstanceOfferWithAvailability -from dstack._internal.core.models.runs import Requirements, Run, RunPlan +from dstack._internal.core.models.runs import Run, RunPlan from dstack._internal.server.db import get_session from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.runs import ( - AddRemoteInstanceRequest, CreateInstanceRequest, DeleteRunsRequest, GetOffersRequest, diff --git a/src/dstack/_internal/server/schemas/runs.py b/src/dstack/_internal/server/schemas/runs.py index 39e2e91e7..1b6bda262 100644 --- a/src/dstack/_internal/server/schemas/runs.py +++ b/src/dstack/_internal/server/schemas/runs.py @@ -3,6 +3,7 @@ from pydantic import BaseModel from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import Requirements, RunSpec @@ -35,7 +36,7 @@ class AddRemoteInstanceRequest(BaseModel): instance_name: Optional[str] host: str port: str - # TODO: define runner spec (gpu, cpu, etc) + resources: ResourcesSpec class SubmitRunRequest(BaseModel): diff --git a/src/dstack/_internal/server/services/gateways/pool.py b/src/dstack/_internal/server/services/gateways/pool.py index 86213d210..fd3a96d55 100644 --- a/src/dstack/_internal/server/services/gateways/pool.py +++ b/src/dstack/_internal/server/services/gateways/pool.py @@ -1,7 +1,6 @@ import asyncio from typing import Dict, List, Optional -from dstack._internal.core.errors import SSHError from 
dstack._internal.server.services.gateways.connection import GatewayConnection from dstack._internal.utils.logging import get_logger @@ -9,7 +8,7 @@ class GatewayConnectionsPool: - def __init__(self): + def __init__(self) -> None: self._connections: Dict[str, GatewayConnection] = {} self._lock = asyncio.Lock() self.server_port: Optional[int] = None @@ -40,7 +39,7 @@ async def remove(self, hostname: str) -> bool: await stop_task return True - async def remove_all(self): + async def remove_all(self) -> None: async with self._lock: await asyncio.gather( *(conn.tunnel.stop() for conn in self._connections.values()), @@ -55,4 +54,4 @@ async def all(self) -> List[GatewayConnection]: return list(self._connections.values()) -gateway_connections_pool = GatewayConnectionsPool() +gateway_connections_pool: GatewayConnectionsPool = GatewayConnectionsPool() diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index 529d3dd22..63e2ae4bb 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -1,8 +1,6 @@ import asyncio -from contextlib import asynccontextmanager, contextmanager from datetime import timezone -from typing import Dict, List, Optional, Sequence, Union -from uuid import UUID +from typing import Dict, List, Optional, Sequence from pydantic import parse_raw_as from sqlalchemy import select @@ -11,14 +9,15 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( + Gpu, InstanceAvailability, InstanceOfferWithAvailability, - InstanceState, InstanceType, Resources, ) from dstack._internal.core.models.pools import Instance, Pool from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME +from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData from dstack._internal.server.models import InstanceModel, PoolModel, ProjectModel from dstack._internal.utils import random_names @@ -238,6 +237,7 @@ async def generate_instance_name( async def add( session: AsyncSession, + resources: ResourcesSpec, project: ProjectModel, pool_name: str, instance_name: Optional[str], @@ -262,11 +262,19 @@ async def add( if pool is None: pool = await create_pool_model(session, project, pool_name) + gpus = [] + if resources.gpu is not None: + gpus = [ + Gpu(name=resources.gpu.name, memory_mib=resources.gpu.memory) + ] * resources.gpu.count.min + + instance_resource = Resources( + cpus=resources.cpu.min, memory_mib=resources.memory.min, gpus=gpus, spot=False + ) + local = JobProvisioningData( backend=BackendType.LOCAL, - instance_type=InstanceType( - name="local", resources=Resources(cpus=0, memory_mib=0, gpus=[], spot=False) - ), + instance_type=InstanceType(name="local", resources=instance_resource), instance_id=instance_name, hostname=host, region="", @@ -276,12 +284,13 @@ async def add( dockerized=False, backend_data="", pool_id=str(pool.id), + ssh_proxy=None, ) offer = InstanceOfferWithAvailability( backend=BackendType.LOCAL, instance=InstanceType( name="instance", - resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[]), + resources=instance_resource, ), region="", price=0.0, @@ -295,6 +304,7 @@ async def add( status=InstanceStatus.PENDING, job_provisioning_data=local.json(), offer=offer.json(), + resource_spec_data=resources.json(), ) session.add(im) await session.commit() diff --git a/src/dstack/_internal/server/services/runs.py 
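In the updated pools.add() above, the remote instance's Resources are derived from the requested ResourcesSpec: the GPU list is built by repeating a single Gpu entry count-min times, and the original spec is kept alongside as resource_spec_data JSON for later relevance checks. A reduced sketch of that construction, with stand-in models instead of the real instance types:

    from typing import List, Optional
    from pydantic import BaseModel

    class GpuStub(BaseModel):            # stand-in for instances.Gpu
        name: Optional[str] = None
        memory_mib: Optional[int] = None

    class GpuSpecStub(BaseModel):        # stand-in for the requested GPU spec
        name: Optional[str] = None
        memory_mib: Optional[int] = None
        count_min: int = 1

    def gpus_from_spec(spec: Optional[GpuSpecStub]) -> List[GpuStub]:
        # same shape as pools.add(): replicate one Gpu entry count_min times
        if spec is None:
            return []
        return [GpuStub(name=spec.name, memory_mib=spec.memory_mib)] * spec.count_min

    gpus = gpus_from_spec(GpuSpecStub(name="A100", memory_mib=40 * 1024, count_min=2))
    print(len(gpus), gpus[0].json())     # 2 {"name": "A100", "memory_mib": 40960}
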
b/src/dstack/_internal/server/services/runs.py index 10ed3f61f..951ca72fe 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -23,10 +23,8 @@ InstanceOfferWithAvailability, LaunchedInstanceInfo, ) -from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile, SpotPolicy -from dstack._internal.core.models.resources import GPU, ComputeCapability, Memory, Range, Resources +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile from dstack._internal.core.models.runs import ( - GpusRequirements, InstanceStatus, Job, JobPlan, @@ -244,6 +242,7 @@ async def create_instance( ssh_port=launched_instance_info.ssh_port, dockerized=launched_instance_info.dockerized, backend_data=launched_instance_info.backend_data, + ssh_proxy=None, ) im = InstanceModel( diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 10880b16d..4fbee673b 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -12,6 +12,7 @@ from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile from dstack._internal.core.models.repos.base import RepoType from dstack._internal.core.models.repos.local import LocalRunRepoData +from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import ( InstanceStatus, JobErrorCode, @@ -304,6 +305,7 @@ async def create_instance( project: ProjectModel, pool: PoolModel, status: InstanceStatus, + resources: ResourcesSpec, ) -> InstanceModel: im = InstanceModel( name="test_instance", @@ -312,6 +314,7 @@ async def create_instance( status=status, job_provisioning_data='{"backend": "datacrunch", "instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "instance_id": "running_instance.id", "pool_id": "1b2b4c57-5851-487f-b92e-948f946dfa49", "hostname": "running_instance.ip", "region": "running_instance.location", "price": 0.1, "username": "root", "ssh_port": 22, "dockerized": true, "backend_data": null}', offer='{"backend": "datacrunch", "instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "region": "en", "price": 0.1, "availability": "available"}', + resource_spec_data=resources.json(), ) session.add(im) await session.commit() diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 449efe55a..33ae57ca8 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -7,7 +7,7 @@ from copy import copy from datetime import datetime from pathlib import Path -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Union import requests from websocket import WebSocketApp diff --git a/src/dstack/api/server/_pools.py b/src/dstack/api/server/_pools.py index b46342c89..cbfef5916 100644 --- a/src/dstack/api/server/_pools.py +++ b/src/dstack/api/server/_pools.py @@ -4,6 +4,7 @@ import dstack._internal.server.schemas.pools as schemas_pools from dstack._internal.core.models.pools import Instance, Pool +from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.server.schemas.runs import AddRemoteInstanceRequest from dstack.api.server._group import APIClientGroup @@ -35,9 +36,19 @@ def remove(self, project_name: str, 
pool_name: str, instance_name: str) -> None: self._request(f"/api/project/{project_name}/pool/remove", body=body.json()) def add( - self, project_name: str, pool_name: str, instance_name: Optional[str], host: str, port: str + self, + project_name: str, + resources: ResourcesSpec, + pool_name: str, + instance_name: Optional[str], + host: str, + port: str, ) -> None: body = AddRemoteInstanceRequest( - pool_name=pool_name, instance_name=instance_name, host=host, port=port + pool_name=pool_name, + instance_name=instance_name, + host=host, + port=port, + resources=resources, ) self._request(f"/api/project/{project_name}/pool/add", body=body.json()) diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index d500835a3..da24b8d0f 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional from pydantic import parse_obj_as @@ -6,7 +6,6 @@ from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.runs import Requirements, Run, RunPlan, RunSpec from dstack._internal.server.schemas.runs import ( - AddRemoteInstanceRequest, CreateInstanceRequest, DeleteRunsRequest, GetOffersRequest, diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index 971d9f52d..e5e148ef0 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -28,6 +28,7 @@ create_user, get_run_spec, ) +from dstack.api._public.resources import Resources as MakeResources class TestProcessSubmittedJobs: @@ -212,7 +213,8 @@ async def test_job_whith_instance(self, test_db, session: AsyncSession): pool = pool_item if pool is None: pool = await create_pool(session, project) - im = await create_instance(session, project, pool, InstanceStatus.READY) + resources = MakeResources(cpu=2, memory="12GB") + await create_instance(session, project, pool, InstanceStatus.READY, resources) await session.refresh(pool) run = await create_run( session, @@ -223,7 +225,8 @@ async def test_job_whith_instance(self, test_db, session: AsyncSession): job_provisioning_data = JobProvisioningData( backend=BackendType.LOCAL, instance_type=InstanceType( - name="local", resources=Resources(cpus=1, memory_mib=1024, gpus=[], spot=False) + name="local", + resources=Resources(cpus=2, memory_mib=12 * 1024, gpus=[], spot=False), ), instance_id="0000-0000", hostname="localhost", @@ -234,12 +237,17 @@ async def test_job_whith_instance(self, test_db, session: AsyncSession): dockerized=False, pool_id="", backend_data=None, + ssh_proxy=None, ) - job = await create_job( - session, - run=run, - job_provisioning_data=job_provisioning_data, - ) + with patch( + "dstack._internal.server.services.jobs.configurators.base.get_default_python_verison" + ) as PyVersion: + PyVersion.return_value = "3.10" + job = await create_job( + session, + run=run, + job_provisioning_data=job_provisioning_data, + ) await process_submitted_jobs() await session.refresh(job) assert job is not None diff --git a/src/tests/_internal/server/services/test_pools.py b/src/tests/_internal/server/services/test_pools.py index 4d30518f7..280fbb53f 100644 --- a/src/tests/_internal/server/services/test_pools.py +++ b/src/tests/_internal/server/services/test_pools.py @@ -195,7 +195,7 @@ async def test_create_cloud_instance(session: 
AsyncSession, test_db): profile = Profile(name="test_profile") - requirements = Requirements(resources=resources.Resources(cpu=1), spot=True) + requirements = Requirements(resources=resources.ResourcesSpec(cpu=1), spot=True) offer = InstanceOfferWithAvailability( backend=BackendType.DATACRUNCH, From e2461484c6bc51801e871624958ae8bbe40706a2 Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Thu, 1 Feb 2024 20:54:09 +0300 Subject: [PATCH 09/47] debug --- .pre-commit-config.yaml | 4 +- src/dstack/_internal/cli/commands/pool.py | 52 ++++--- .../_internal/core/backends/aws/compute.py | 19 +-- .../_internal/core/backends/base/compute.py | 139 +++++++++++++++--- .../_internal/core/backends/base/offers.py | 1 + .../core/backends/datacrunch/api_client.py | 3 +- .../core/backends/datacrunch/compute.py | 17 ++- .../_internal/core/backends/gcp/compute.py | 15 +- .../_internal/core/models/backends/base.py | 2 +- .../_internal/core/services/ssh/attach.py | 2 +- .../background/tasks/process_finished_jobs.py | 47 ++---- .../server/background/tasks/process_pools.py | 32 +++- .../background/tasks/process_running_jobs.py | 55 ++----- .../tasks/process_submitted_jobs.py | 5 + ...add_pools.py => 718bf16e84c5_add_pools.py} | 23 ++- src/dstack/_internal/server/models.py | 21 ++- src/dstack/_internal/server/routers/pools.py | 19 ++- src/dstack/_internal/server/schemas/pools.py | 4 + src/dstack/_internal/server/schemas/runs.py | 2 +- .../server/services/backends/__init__.py | 1 + .../server/services/jobs/__init__.py | 18 ++- src/dstack/_internal/server/services/pools.py | 46 ++++-- src/dstack/_internal/server/services/runs.py | 25 +++- src/dstack/_internal/server/testing/common.py | 2 + src/dstack/api/server/_pools.py | 17 ++- .../tasks/test_process_running_jobs.py | 12 +- .../tasks/test_process_terminating_jobs.py | 15 +- 27 files changed, 413 insertions(+), 185 deletions(-) rename src/dstack/_internal/server/migrations/versions/{309e4be6671b_add_pools.py => 718bf16e84c5_add_pools.py} (83%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e9ce6cb8b..8aba21e1a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,12 +21,12 @@ repos: args: ['--strict', '--follow-imports=skip', '--ignore-missing-imports', '--python-version=3.8'] files: '.*pools?\.py' exclude: 'versions|src/tests' - additional_dependencies: [types-PyYAML] + additional_dependencies: [types-PyYAML, types-requests, pydantic,sqlalchemy] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.1.14 hooks: - id: ruff name: ruff autofix args: ['--fix', '--select', 'F401'] - files: 'runs?\.py|pools?\.py' + files: 'process_.*\.py|runs?\.py|pools?\.py' exclude: 'versions|src/tests' diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 0551eb1c9..dca628645 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -9,7 +9,7 @@ apply_profile_args, register_profile_args, ) -from dstack._internal.cli.utils.common import confirm_ask, console +from dstack._internal.cli.utils.common import colors, confirm_ask, console from dstack._internal.core.errors import CLIError, ServerClientError from dstack._internal.core.models.instances import ( InstanceAvailability, @@ -17,15 +17,7 @@ ) from dstack._internal.core.models.pools import Instance, Pool from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile, SpotPolicy -from dstack._internal.core.models.resources import ( - DEFAULT_CPU_COUNT, - DEFAULT_MEMORY_SIZE, 
- DiskLike, - GPULike, - IntRangeLike, - MemoryLike, - MemoryRangeLike, -) +from dstack._internal.core.models.resources import DEFAULT_CPU_COUNT, DEFAULT_MEMORY_SIZE from dstack._internal.core.models.runs import Requirements from dstack._internal.core.services.configs import ConfigManager from dstack._internal.utils.common import pretty_date @@ -166,7 +158,6 @@ def register_resource_args(parser: argparse.ArgumentParser) -> None: resources_group = parser.add_argument_group("Resources") resources_group.add_argument( "--cpu", - type=IntRangeLike, help=f"Request the CPU count. Default: '{DEFAULT_CPU_COUNT.min}..'", dest="cpu", metavar="SPEC", @@ -175,7 +166,6 @@ def register_resource_args(parser: argparse.ArgumentParser) -> None: resources_group.add_argument( "--memory", - type=MemoryRangeLike, help="Request the size of RAM. " f"The format is [code]SIZE[/]:[code]MB|GB|TB[/]. Default: {DEFAULT_MEMORY_SIZE.min}", dest="memory", @@ -185,7 +175,6 @@ def register_resource_args(parser: argparse.ArgumentParser) -> None: resources_group.add_argument( "--shared-memory", - type=MemoryLike, help="Request the size of Shared Memory. The format is [code]SIZE[/]:[code]MB|GB|TB[/].", dest="shared_memory", default=None, @@ -194,7 +183,6 @@ def register_resource_args(parser: argparse.ArgumentParser) -> None: resources_group.add_argument( "--gpu", - type=GPULike, help="Request GPU for the run. " "The format is [code]NAME[/]:[code]COUNT[/]:[code]MEMORY[/] (all parts are optional)", dest="gpu", @@ -204,7 +192,6 @@ def register_resource_args(parser: argparse.ArgumentParser) -> None: resources_group.add_argument( "--disk", - type=DiskLike, help="Request the size of disk for the run. Example [code]--disk 100GB[/].", dest="disk", metavar="SIZE", @@ -303,6 +290,17 @@ def _register(self) -> None: ) remove_parser.set_defaults(subfunc=self._remove) + # pool set-default + set_default_parser = subparsers.add_parser( + "set-default", + help="Set the project's default pool", + formatter_class=self._parser.formatter_class, + ) + set_default_parser.add_argument( + "--pool", dest="pool_name", help="The name of the pool", required=True + ) + set_default_parser.set_defaults(subfunc=self._set_default) + def _list(self, args: argparse.Namespace) -> None: pools = self.api.client.pool.list(self.api.project) print_pool_table(pools, verbose=getattr(args, "verbose", False)) @@ -316,6 +314,13 @@ def _delete(self, args: argparse.Namespace) -> None: def _remove(self, args: argparse.Namespace) -> None: self.api.client.pool.remove(self.api.project, args.pool_name, args.instance_name) + def _set_default(self, args: argparse.Namespace) -> None: + result = self.api.client.pool.set_default(self.api.project, args.pool_name) + if not result: + console.print( + f"[{colors['error']}]Failed to set default pool {args.pool_name!r}[/{colors['code']}]" + ) + def _show(self, args: argparse.Namespace) -> None: instances = self.api.client.pool.show(self.api.project, args.pool_name) print_instance_table(instances) @@ -339,26 +344,29 @@ def _add(self, args: argparse.Namespace) -> None: spot=(args.spot_policy == SpotPolicy.SPOT), ) + profile = load_profile(Path.cwd(), args.profile) + apply_profile_args(args, profile) + profile.pool_name = pool_name + # Add remote instance if args.remote: - self.api.client.pool.add( + result = self.api.client.pool.add_remote( self.api.project, resources, - pool_name, + profile, args.instance_name, args.remote_host, args.remote_port, ) + if not result: + console.print( + f"[{colors['error']}]Failed to add remote instance 
{args.instance_name!r}[/{colors['code']}]" + ) return repo = self.api.repos.load(Path.cwd()) self.api.ssh_identity_file = ConfigManager().get_repo_config(repo.repo_dir).ssh_key_path - profile = load_profile(Path.cwd(), args.profile) - apply_profile_args(args, profile) - - profile.pool_name = pool_name - with console.status("Getting instances..."): offers = self.api.runs.get_offers(profile, requirements) diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index d36b3f426..a99067905 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -14,6 +14,7 @@ InstanceConfiguration, get_gateway_user_data, get_instance_name, + get_instance_user_data, get_user_data, ) from dstack._internal.core.backends.base.offers import get_catalog_offers @@ -87,7 +88,7 @@ def get_quotas(client: botocore.client.BaseClient) -> Dict[str, int]: def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None - ): + ) -> None: client = self.session.client("ec2", region_name=region) try: client.terminate_instances(InstanceIds=[instance_id]) @@ -110,10 +111,10 @@ def create_instance( iam_client = self.session.client("iam", region_name=instance_offer.region) tags = [ - {"Key": "Name", "Value": run.run_spec.run_name}, + {"Key": "Name", "Value": instance_config.instance_name}, {"Key": "owner", "Value": "dstack"}, {"Key": "dstack_project", "Value": project_id}, - {"Key": "dstack_user", "Value": run.user}, + {"Key": "dstack_user", "Value": instance_config.user}, ] try: disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) @@ -127,13 +128,8 @@ def create_instance( iam_instance_profile_arn=aws_resources.create_iam_instance_profile( iam_client, project_id ), - user_data=get_user_data( - backend=BackendType.AWS, - image_name=job.job_spec.image_name, - authorized_keys=[ - run.run_spec.ssh_key_pub.strip(), - project_ssh_public_key.strip(), - ], + user_data=get_instance_user_data( + authorized_keys=instance_config.get_public_keys(), ), tags=tags, security_group_id=aws_resources.create_security_group(ec2_client, project_id), @@ -198,6 +194,7 @@ def run_job( run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip(), ], + registry_auth_required=job.job_spec.registry_auth is not None, ), tags=tags, security_group_id=aws_resources.create_security_group(ec2_client, project_id), @@ -266,7 +263,7 @@ def create_gateway( ) -def _has_quota(quotas: Dict[str, float], instance_name: str) -> bool: +def _has_quota(quotas: Dict[str, int], instance_name: str) -> bool: if instance_name.startswith("p"): return quotas.get("P/OnDemand", 0) > 0 if instance_name.startswith("g"): diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 478f7f558..deefbf1f5 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -1,7 +1,7 @@ import os import re from abc import ABC, abstractmethod -from typing import List, Optional +from typing import Any, Dict, List, Optional import git import requests @@ -40,8 +40,9 @@ class InstanceConfiguration(BaseModel): instance_name: str # unique in pool ssh_keys: List[SSHKeys] job_docker_config: Optional[DockerConfig] + user: Optional[str] - def get_public_keys(self): + def get_public_keys(self) -> List[str]: return [ssh_key.public.strip() for ssh_key in self.ssh_keys] @@ -76,7 +77,7 @@ def create_instance( 
@abstractmethod def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None - ): + ) -> None: pass def create_gateway( @@ -94,11 +95,17 @@ def get_instance_name(run: Run, job: Job) -> str: def get_user_data( + backend: BackendType, + image_name: str, authorized_keys: List[str], + registry_auth_required: bool, cloud_config_kwargs: Optional[dict] = None, ) -> str: commands = get_shim_commands( + backend=backend, + image_name=image_name, authorized_keys=authorized_keys, + registry_auth_required=registry_auth_required, ) return get_cloud_config( runcmd=[["sh", "-c", " && ".join(commands)]], @@ -108,19 +115,24 @@ def get_user_data( def get_shim_commands( + backend: BackendType, + image_name: str, authorized_keys: List[str], + registry_auth_required: bool, ) -> List[str]: build = get_dstack_runner_version() env = { + "DSTACK_BACKEND": backend.value, "DSTACK_RUNNER_LOG_LEVEL": "6", "DSTACK_RUNNER_VERSION": build, + "DSTACK_IMAGE_NAME": image_name, "DSTACK_PUBLIC_SSH_KEY": "\n".join(authorized_keys), "DSTACK_HOME": "/root/.dstack", } commands = get_dstack_shim(build) for k, v in env.items(): commands += [f'export "{k}={v}"'] - commands += get_run_shim_script() + commands += get_run_shim_script(registry_auth_required) return commands @@ -139,20 +151,18 @@ def get_dstack_shim(build: str) -> List[str]: if settings.DSTACK_VERSION is not None: bucket = "dstack-runner-downloads" - url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64" - - if os.getenv("DEV_DSTACK_RUNNER", None) is not None: - url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-shim" - return [ - f'sudo curl --output /usr/local/bin/dstack-shim "{url}"', + f'sudo curl --output /usr/local/bin/dstack-shim "https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64"', "sudo chmod +x /usr/local/bin/dstack-shim", ] -def get_run_shim_script() -> List[str]: +def get_run_shim_script(registry_auth_required: bool) -> List[str]: dev_flag = "" if settings.DSTACK_VERSION is not None else "--dev" - return [f"nohup dstack-shim {dev_flag} docker --keep-container >/root/shim.log 2>&1 &"] + with_auth_flag = "--with-auth" if registry_auth_required else "" + return [ + f"nohup dstack-shim {dev_flag} docker {with_auth_flag} --keep-container >/root/shim.log 2>&1 &" + ] def get_gateway_user_data(authorized_key: str) -> str: @@ -210,14 +220,8 @@ def get_docker_commands(authorized_keys: List[str]) -> List[str]: bucket = "dstack-runner-downloads-stgn" if settings.DSTACK_VERSION is not None: bucket = "dstack-runner-downloads" - - url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64" - - if os.getenv("DEV_DSTACK_RUNNER", None) is not None: - url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-runner" - commands += [ - f"curl --output {runner} {url}", + f'curl --output {runner} "https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64"', f"chmod +x {runner}", f"{runner} --log-level 6 start --http-port 10999 --temp-dir /tmp/runner --home-dir /root --working-dir /workflow", ] @@ -277,3 +281,100 @@ def get_dstack_gateway_commands() -> List[str]: f"/home/ubuntu/dstack/blue/bin/pip install {get_dstack_gateway_wheel(build)}", "sudo /home/ubuntu/dstack/blue/bin/python -m dstack.gateway.systemd install --run", ] + + +def get_instance_user_data( + authorized_keys: List[str], + cloud_config_kwargs: Optional[Dict[Any, Any]] = None, +) -> str: + commands = 
get_instance_shim_commands( + authorized_keys=authorized_keys, + ) + return get_cloud_config( + runcmd=[["sh", "-c", " && ".join(commands)]], + ssh_authorized_keys=authorized_keys, + **(cloud_config_kwargs or {}), + ) + + +def get_instance_shim_commands( + authorized_keys: List[str], +) -> List[str]: + build = get_dstack_runner_version() + env = { + "DSTACK_RUNNER_LOG_LEVEL": "6", + "DSTACK_RUNNER_VERSION": build, + "DSTACK_PUBLIC_SSH_KEY": "\n".join(authorized_keys), + "DSTACK_HOME": "/root/.dstack", + } + commands = get_instance_dstack_shim(build) + for k, v in env.items(): + commands += [f'export "{k}={v}"'] + commands += get_instance_run_shim_script() + return commands + + +def get_instance_dstack_shim(build: str) -> List[str]: + bucket = "dstack-runner-downloads-stgn" + if settings.DSTACK_VERSION is not None: + bucket = "dstack-runner-downloads" + + # TODO: use official build + # url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64" + url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-shim" + + return [ + f'sudo curl --output /usr/local/bin/dstack-shim "{url}"', + "sudo chmod +x /usr/local/bin/dstack-shim", + ] + + +def get_instance_docker_commands(authorized_keys: List[str]) -> List[str]: + authorized_keys_body = "\n".join(authorized_keys).strip() + commands = [ + # note: &> redirection doesn't work in /bin/sh + # check in sshd is here, install if not + ( + "if ! command -v sshd >/dev/null 2>&1; then { " + "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y openssh-server; " + "} || { " + "yum -y install openssh-server; " + "}; fi" + ), + # prohibit password authentication + 'sed -i "s/.*PasswordAuthentication.*/PasswordAuthentication no/g" /etc/ssh/sshd_config', + # create ssh dirs and add public key + "mkdir -p /run/sshd ~/.ssh", + "chmod 700 ~/.ssh", + f"echo '{authorized_keys_body}' > ~/.ssh/authorized_keys", + "chmod 600 ~/.ssh/authorized_keys", + # preserve environment variables for SSH clients + "env >> ~/.ssh/environment", + 'echo "export PATH=$PATH" >> ~/.profile', + # regenerate host keys + "rm -rf /etc/ssh/ssh_host_*", + "ssh-keygen -A > /dev/null", + # start sshd + "/usr/sbin/sshd -p 10022 -o PermitUserEnvironment=yes", + ] + build = get_dstack_runner_version() + runner = "/usr/local/bin/dstack-runner" + bucket = "dstack-runner-downloads-stgn" + if settings.DSTACK_VERSION is not None: + bucket = "dstack-runner-downloads" + + # TODO: use official build + # url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64" + url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-runner" + + commands += [ + f"curl --output {runner} {url}", + f"chmod +x {runner}", + f"{runner} --log-level 6 start --http-port 10999 --temp-dir /tmp/runner --home-dir /root --working-dir /workflow", + ] + return commands + + +def get_instance_run_shim_script() -> List[str]: + dev_flag = "" if settings.DSTACK_VERSION is not None else "--dev" + return [f"nohup dstack-shim {dev_flag} docker --keep-container >/root/shim.log 2>&1 &"] diff --git a/src/dstack/_internal/core/backends/base/offers.py b/src/dstack/_internal/core/backends/base/offers.py index f484e6102..1dbb33d79 100644 --- a/src/dstack/_internal/core/backends/base/offers.py +++ b/src/dstack/_internal/core/backends/base/offers.py @@ -27,6 +27,7 @@ def get_catalog_offers( q = requirements_to_query_filter(requirements) q.provider = [provider] offers = [] + catalog = catalog if catalog is not None else 
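The get_instance_* helpers above follow the same recipe as the existing shim bootstrap: download the binary, export its environment, and start it under nohup, with everything joined into a single sh -c entry for cloud-init's runcmd. A self-contained sketch of that composition, with a placeholder download URL and example environment:

    from typing import Dict, List

    def shim_commands(env: Dict[str, str], binary_url: str) -> List[str]:
        # mirrors the shape of get_instance_shim_commands(): fetch the shim,
        # export its environment, then launch it in the background
        cmds = [
            f'sudo curl --output /usr/local/bin/dstack-shim "{binary_url}"',
            "sudo chmod +x /usr/local/bin/dstack-shim",
        ]
        cmds += [f'export "{k}={v}"' for k, v in env.items()]
        cmds.append("nohup dstack-shim docker --keep-container >/root/shim.log 2>&1 &")
        return cmds

    runcmd = ["sh", "-c", " && ".join(shim_commands(
        {"DSTACK_RUNNER_LOG_LEVEL": "6", "DSTACK_HOME": "/root/.dstack"},
        "https://example.com/dstack-shim-linux-amd64",   # placeholder URL
    ))]
    print(runcmd)
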
gpuhunt.default_catalog() for item in catalog.query(**asdict(q)): if locations is not None and item.location not in locations: diff --git a/src/dstack/_internal/core/backends/datacrunch/api_client.py b/src/dstack/_internal/core/backends/datacrunch/api_client.py index a3b931c08..ed27f79e1 100644 --- a/src/dstack/_internal/core/backends/datacrunch/api_client.py +++ b/src/dstack/_internal/core/backends/datacrunch/api_client.py @@ -56,6 +56,7 @@ def wait_for_instance(self, instance_id: str) -> Optional[Instance]: if instance is not None and instance.status == "running": return instance time.sleep(WAIT_FOR_INSTANCE_INTERVAL) + return def deploy_instance( self, @@ -68,7 +69,7 @@ def deploy_instance( disk_size, is_spot=True, location="FIN-01", - ): + ) -> Instance: try: instance = self.client.instances.create( instance_type=instance_type, diff --git a/src/dstack/_internal/core/backends/datacrunch/compute.py b/src/dstack/_internal/core/backends/datacrunch/compute.py index 21e6806ca..866d6f39e 100644 --- a/src/dstack/_internal/core/backends/datacrunch/compute.py +++ b/src/dstack/_internal/core/backends/datacrunch/compute.py @@ -1,7 +1,11 @@ from typing import Dict, List, Optional from dstack._internal.core.backends.base import Compute -from dstack._internal.core.backends.base.compute import InstanceConfiguration, get_shim_commands +from dstack._internal.core.backends.base.compute import ( + InstanceConfiguration, + get_instance_shim_commands, + get_shim_commands, +) from dstack._internal.core.backends.base.offers import get_catalog_offers from dstack._internal.core.backends.datacrunch.api_client import DataCrunchAPIClient from dstack._internal.core.backends.datacrunch.config import DataCrunchConfig @@ -39,9 +43,7 @@ def get_offers( def _get_offers_with_availability( self, offers: List[InstanceOffer] ) -> List[InstanceOfferWithAvailability]: - raw_availabilities: List[ - Dict - ] = self.api_client.client.instances.get_availabilities() # type: ignore + raw_availabilities: List[Dict] = self.api_client.client.instances.get_availabilities() # type: ignore region_availabilities = {} for location in raw_availabilities: @@ -81,7 +83,7 @@ def create_instance( ) ) - commands = get_shim_commands( + commands = get_instance_shim_commands( authorized_keys=public_keys, ) @@ -164,10 +166,13 @@ def run_job( ) commands = get_shim_commands( + backend=BackendType.DATACRUNCH, + image_name=job.job_spec.image_name, authorized_keys=[ run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip(), ], + registry_auth_required=job.job_spec.registry_auth is not None, ) startup_script = " ".join([" && ".join(commands)]) @@ -212,5 +217,5 @@ def run_job( def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None - ): + ) -> None: self.api_client.delete_instance(instance_id) diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index 3b17ec3a9..364ab9216 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -11,6 +11,7 @@ InstanceConfiguration, get_gateway_user_data, get_instance_name, + get_instance_user_data, get_user_data, ) from dstack._internal.core.backends.base.offers import get_catalog_offers @@ -42,12 +43,13 @@ def __init__(self, config: GCPConfig): def get_offers( self, requirements: Optional[Requirements] = None ) -> List[InstanceOfferWithAvailability]: + offers = get_catalog_offers( backend=BackendType.GCP, requirements=requirements, 
extra_filter=_supported_instances_and_zones(self.config.regions), ) - quotas = defaultdict(dict) + quotas: Dict[str, Dict[str, float]] = defaultdict(dict) for region in self.regions_client.list(project=self.config.project_id): for quota in region.quotas: quotas[region.name][quota.metric] = quota.limit - quota.usage @@ -73,7 +75,7 @@ def get_offers( def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None - ): + ) -> None: try: self.instances_client.delete( project=self.config.project_id, zone=region, instance=instance_id @@ -114,9 +116,7 @@ def create_instance( gpus=instance_offer.instance.resources.gpus, ), spot=instance_offer.instance.resources.spot, - user_data=get_user_data( - backend=BackendType.GCP, - image_name=instance_config.job_docker_config.image.image, + user_data=get_instance_user_data( authorized_keys=instance_config.get_public_keys(), ), labels={ @@ -188,6 +188,7 @@ def run_job( run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip(), ], + registry_auth_required=job.job_spec.registry_auth is not None, ), labels={ "owner": "dstack", @@ -274,8 +275,6 @@ def create_gateway( def _supported_instances_and_zones( regions: List[str], ) -> Optional[Callable[[InstanceOffer], bool]]: - regions = set(regions) - def _filter(offer: InstanceOffer) -> bool: # strip zone if offer.region[:-2] not in regions: @@ -299,7 +298,7 @@ def _filter(offer: InstanceOffer) -> bool: return _filter -def _has_gpu_quota(quotas: Dict[str, int], resources: Resources) -> bool: +def _has_gpu_quota(quotas: Dict[str, float], resources: Resources) -> bool: if not resources.gpus: return True gpu = resources.gpus[0] diff --git a/src/dstack/_internal/core/models/backends/base.py b/src/dstack/_internal/core/models/backends/base.py index b3c9394cd..47286e6ef 100644 --- a/src/dstack/_internal/core/models/backends/base.py +++ b/src/dstack/_internal/core/models/backends/base.py @@ -26,7 +26,7 @@ class BackendType(str, enum.Enum): KUBERNETES = "kubernetes" LAMBDA = "lambda" LOCAL = "local" - # REMOTE= "remote" # replace for LOCAL + REMOTE = "remote" # TODO: replace for LOCAL NEBIUS = "nebius" TENSORDOCK = "tensordock" VASTAI = "vastai" diff --git a/src/dstack/_internal/core/services/ssh/attach.py b/src/dstack/_internal/core/services/ssh/attach.py index 252d343aa..4b30355f6 100644 --- a/src/dstack/_internal/core/services/ssh/attach.py +++ b/src/dstack/_internal/core/services/ssh/attach.py @@ -119,7 +119,7 @@ def attach(self): self.tunnel.open() atexit.register(self.detach) break - except SSHError: + except SSHError as e: if i < max_retries - 1: time.sleep(1) else: diff --git a/src/dstack/_internal/server/background/tasks/process_finished_jobs.py b/src/dstack/_internal/server/background/tasks/process_finished_jobs.py index 15621e6c0..a693bf629 100644 --- a/src/dstack/_internal/server/background/tasks/process_finished_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_finished_jobs.py @@ -2,16 +2,14 @@ from sqlalchemy.orm import joinedload from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.profiles import TerminationPolicy from dstack._internal.core.models.runs import InstanceStatus, JobSpec, JobStatus from dstack._internal.server.db import get_session_ctx -from dstack._internal.server.models import GatewayModel, InstanceModel, JobModel +from dstack._internal.server.models import GatewayModel, JobModel from dstack._internal.server.services.gateways import gateway_connections_pool from 
dstack._internal.server.services.jobs import ( TERMINATING_PROCESSING_JOBS_IDS, TERMINATING_PROCESSING_JOBS_LOCK, job_model_to_job_submission, - terminate_job_submission_instance, ) from dstack._internal.server.services.logging import job_log from dstack._internal.server.services.pools import get_instances_by_pool_id @@ -42,7 +40,6 @@ async def process_finished_jobs(): TERMINATING_PROCESSING_JOBS_IDS.add(job_model.id) try: await _process_job(job_id=job_model.id) - await _terminate_old_instance() finally: TERMINATING_PROCESSING_JOBS_IDS.remove(job_model.id) @@ -50,7 +47,11 @@ async def process_finished_jobs(): async def _process_job(job_id): async with get_session_ctx() as session: res = await session.execute( - select(JobModel).where(JobModel.id == job_id).options(joinedload(JobModel.project)) + select(JobModel) + .where(JobModel.id == job_id) + .options(joinedload(JobModel.project)) + .options(joinedload(JobModel.instance)) + .options(joinedload(JobModel.run)) ) job_model = res.scalar_one() job_submission = job_model_to_job_submission(job_model) @@ -90,36 +91,20 @@ async def _process_job(job_id): if instance.name == jpd.instance_id: instance.finished_at = get_current_datetime() instance.status = InstanceStatus.READY - else: - await terminate_job_submission_instance( - project=job_model.project, - job_submission=job_submission, - ) + # else: + # if job_model.instance is not None and job_model.instance.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE: + # await terminate_job_provisioning_data_instance( + # project=job_model.project, + # job_provisioning_data=job_submission.job_provisioning_data, + # ) job_model.removed = True + if job_model.instance is not None: + job_model.used_instance_id = job_model.instance.id + job_model.instance.status = InstanceStatus.READY + job_model.instance = None logger.info(*job_log("marked as removed", job_model)) except Exception as e: job_model.removed = False logger.error(*job_log("failed to terminate job instance: %s", job_model, e)) job_model.last_processed_at = get_current_datetime() await session.commit() - - -async def _terminate_old_instance(): - async with get_session_ctx() as session: - res = await session.execute( - select(InstanceModel) - .where( - InstanceModel.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE, - InstanceModel.deleted == False, - ) - .options() - ) - instances = res.scalars().all() - - for instance in instances: - if instance.finished_at + instance.termination_idle_time > get_current_datetime(): - await terminate_job_submission_instance( - project=instance.project, - job_submission=job_submission, - ) - await session.commit() diff --git a/src/dstack/_internal/server/background/tasks/process_pools.py b/src/dstack/_internal/server/background/tasks/process_pools.py index bf71aa63d..06d7e01c0 100644 --- a/src/dstack/_internal/server/background/tasks/process_pools.py +++ b/src/dstack/_internal/server/background/tasks/process_pools.py @@ -6,14 +6,20 @@ from sqlalchemy import select from sqlalchemy.orm import joinedload +from dstack._internal.core.models.profiles import TerminationPolicy from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData from dstack._internal.server.db import get_session_ctx from dstack._internal.server.models import InstanceModel from dstack._internal.server.services import backends as backends_services -from dstack._internal.server.services.jobs import PROCESSING_POOL_IDS, PROCESSING_POOL_LOCK +from dstack._internal.server.services.jobs import ( + PROCESSING_POOL_IDS, 
+ PROCESSING_POOL_LOCK, + terminate_job_provisioning_data_instance, +) from dstack._internal.server.services.runner import client from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel from dstack._internal.server.utils.common import run_async +from dstack._internal.utils.common import get_current_datetime from dstack._internal.utils.logging import get_logger PENDING_JOB_RETRY_INTERVAL = timedelta(seconds=60) @@ -104,3 +110,27 @@ async def terminate(instance_id: UUID) -> None: await run_async( backend.compute().terminate_instance, jpd.instance_id, jpd.region, jpd.backend_data ) + + +async def _terminate_old_instance() -> None: + async with get_session_ctx() as session: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE, + InstanceModel.deleted == False, + InstanceModel.job == None, + ) + .options() + ) + instances = res.scalars().all() + + for instance in instances: + if instance.finished_at + instance.termination_idle_time > get_current_datetime(): + instance_type = parse_raw_as( + JobProvisioningData, instance.job_provisioning_data + ).backend + await terminate_job_provisioning_data_instance( + project=instance.project, job_provisioning_data=instance.job_provisioning_data + ) + await session.commit() diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index c5ec1e2f4..7ab84c362 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -1,10 +1,8 @@ -from asyncio.proactor_events import _ProactorBasePipeTransport from datetime import timedelta from typing import Dict, Optional from uuid import UUID import httpx -import requests from pydantic import parse_raw_as from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -13,17 +11,8 @@ from dstack._internal.core.errors import GatewayError, SSHError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import RegistryAuth -from dstack._internal.core.models.instances import InstanceState from dstack._internal.core.models.repos import RemoteRepoCreds -from dstack._internal.core.models.runs import ( - InstanceStatus, - Job, - JobErrorCode, - JobProvisioningData, - JobSpec, - JobStatus, - Run, -) +from dstack._internal.core.models.runs import Job, JobErrorCode, JobSpec, JobStatus, Run from dstack._internal.server.db import get_session_ctx from dstack._internal.server.models import ( GatewayModel, @@ -41,7 +30,6 @@ job_model_to_job_submission, ) from dstack._internal.server.services.logging import job_log -from dstack._internal.server.services.pools import get_pool_instances from dstack._internal.server.services.repos import get_code_model, repo_model_to_repo_head from dstack._internal.server.services.runner import client from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel @@ -86,7 +74,9 @@ async def process_running_jobs(): async def _process_job(job_id: UUID): async with get_session_ctx() as session: - res = await session.execute(select(JobModel).where(JobModel.id == job_id)) + res = await session.execute( + select(JobModel).where(JobModel.id == job_id).options(joinedload(JobModel.instance)) + ) job_model = res.scalar_one() res = await session.execute( select(RunModel) @@ -151,15 +141,8 @@ async def _process_job(job_id: UUID): 
repo_creds, ) - if success: - instance_name: str = job_provisioning_data.instance_id - pool_name = str(job.job_spec.pool_name) - instances = await get_pool_instances(session, project, pool_name) - for inst in instances: - if inst.name == instance_name: - inst.status = InstanceStatus.BUSY - - if not success: # check timeout + if not success: + # check timeout if job_submission.age > _get_runner_timeout_interval( job_provisioning_data.backend ): @@ -172,13 +155,8 @@ async def _process_job(job_id: UUID): ) job_model.status = JobStatus.FAILED job_model.error_code = JobErrorCode.WAITING_RUNNER_LIMIT_EXCEEDED - - instance_name: str = job_provisioning_data.instance_id - pool_name = str(job.job_spec.pool_name) - instances = await get_pool_instances(session, project, pool_name) - for inst in instances: - if inst.name == instance_name: - inst.status = InstanceStatus.READY # TODO: or fail? + job_model.used_instance_id = job_model.instance.id + job_model.instance = None else: # fails are not acceptable if initial_status == JobStatus.PULLING: @@ -216,14 +194,6 @@ async def _process_job(job_id: UUID): job_model, ) - if success and job_model.status == JobStatus.DONE: - instance_name: str = job_provisioning_data.instance_id - pool_name = str(job.job_spec.pool_name) - instances = await get_pool_instances(session, project, pool_name) - for inst in instances: - if inst.name == instance_name: - inst.status = InstanceStatus.READY - if not success: # kill the job logger.warning( *job_log( @@ -234,6 +204,8 @@ async def _process_job(job_id: UUID): ) job_model.status = JobStatus.FAILED job_model.error_code = JobErrorCode.INTERRUPTED_BY_NO_CAPACITY + job_model.used_instance_id = job_model.instance.id + job_model.instance = None if job.is_retry_active(): if job_submission.job_provisioning_data.instance_type.resources.spot: new_job_model = create_job_model_for_new_submission( @@ -243,13 +215,6 @@ async def _process_job(job_id: UUID): ) session.add(new_job_model) - instance_name: str = job_provisioning_data.instance_id - pool_name = str(job.job_spec.pool_name) - instances = await get_pool_instances(session, project, pool_name) - for inst in instances: - if inst.name == instance_name: - inst.status = InstanceStatus.READY - # job will be terminated by process_finished_jobs if ( diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 1d095bdc7..c69b94e69 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -97,6 +97,8 @@ def check_relevance( logger.warning(f"no backend select ") return False + # use gpuhunt + instance_resources: ResourcesSpec = parse_raw_as( ResourcesSpec, instance_model.resource_spec_data ) @@ -189,7 +191,9 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): sorted_instances = sorted(relevant_instances, key=lambda instance: instance.name) instance = sorted_instances[0] + # need lock instance.status = InstanceStatus.BUSY + instance.job = job_model logger.info(*job_log("now is provisioning", job_model)) job_model.job_provisioning_data = instance.job_provisioning_data @@ -237,6 +241,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): offer=offer.json(), termination_policy=profile.termination_policy, termination_idle_time=profile.termination_idle_time, + job=job_model, ) session.add(im) diff --git 
a/src/dstack/_internal/server/migrations/versions/309e4be6671b_add_pools.py b/src/dstack/_internal/server/migrations/versions/718bf16e84c5_add_pools.py similarity index 83% rename from src/dstack/_internal/server/migrations/versions/309e4be6671b_add_pools.py rename to src/dstack/_internal/server/migrations/versions/718bf16e84c5_add_pools.py index eba3d08cd..26a018cdb 100644 --- a/src/dstack/_internal/server/migrations/versions/309e4be6671b_add_pools.py +++ b/src/dstack/_internal/server/migrations/versions/718bf16e84c5_add_pools.py @@ -1,8 +1,8 @@ """add pools -Revision ID: 309e4be6671b +Revision ID: 718bf16e84c5 Revises: d3e8af4786fa -Create Date: 2024-01-31 10:35:34.977788 +Create Date: 2024-02-01 18:01:47.612769 """ import sqlalchemy as sa @@ -10,7 +10,7 @@ from alembic import op # revision identifiers, used by Alembic. -revision = "309e4be6671b" +revision = "718bf16e84c5" down_revision = "d3e8af4786fa" branch_labels = None depends_on = None @@ -46,7 +46,7 @@ def upgrade() -> None: sa.Column( "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False ), - sa.Column("pool_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True), + sa.Column("pool_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), sa.Column( "status", sa.Enum( @@ -69,6 +69,9 @@ def upgrade() -> None: sa.Column("termination_idle_time", sa.String(length=50), nullable=True), sa.Column("job_provisioning_data", sa.String(length=4000), nullable=False), sa.Column("offer", sa.String(length=4000), nullable=False), + sa.Column("resource_spec_data", sa.String(length=4000), nullable=True), + sa.Column("job_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True), + sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], name=op.f("fk_instances_job_id_jobs")), sa.ForeignKeyConstraint( ["pool_id"], ["pools.id"], name=op.f("fk_instances_pool_id_pools") ), @@ -80,6 +83,15 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("id", name=op.f("pk_instances")), ) + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "used_instance_id", + sqlalchemy_utils.types.uuid.UUIDType(binary=False), + nullable=True, + ) + ) + with op.batch_alter_table("projects", schema=None) as batch_op: batch_op.add_column( sa.Column( @@ -108,6 +120,9 @@ def downgrade() -> None: ) batch_op.drop_column("default_pool_id") + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.drop_column("used_instance_id") + op.drop_table("instances") op.drop_table("pools") # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index e1c68af6e..3ee2f3438 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -191,6 +191,8 @@ class JobModel(BaseModel): # `removed` is used to ensure that the instance is killed after the job is finished removed: Mapped[bool] = mapped_column(Boolean, default=False) remove_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + instance: Mapped[Optional["InstanceModel"]] = relationship(back_populates="job") + used_instance_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUIDType(binary=False)) class GatewayModel(BaseModel): @@ -271,11 +273,12 @@ class InstanceModel(BaseModel): project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id], single_parent=True) - pool_id: Mapped[Optional[uuid.UUID]] = 
mapped_column(ForeignKey("pools.id")) + pool_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("pools.id")) pool: Mapped["PoolModel"] = relationship(back_populates="instances", single_parent=True) status: Mapped[InstanceStatus] = mapped_column(Enum(InstanceStatus)) status_message: Mapped[Optional[str]] = mapped_column(String(50)) + started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) @@ -287,3 +290,19 @@ class InstanceModel(BaseModel): offer: Mapped[str] = mapped_column(String(4000)) resource_spec_data: Mapped[Optional[str]] = mapped_column(String(4000)) + + # current job + job_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("jobs.id")) + job: Mapped[Optional["JobModel"]] = relationship(back_populates="instance") + + # + # job_id: Optional[FK] (current job) + # ip address + # ssh creds: user, port, dockerized + # real resources + spot (exact) / instance offer + # backend + backend data + # region + # price (for querying) + # + # termination policy + # creation policy + # job_provisioning_data=job_provisioning_data.json(), + # TODO: instance provisioning diff --git a/src/dstack/_internal/server/routers/pools.py b/src/dstack/_internal/server/routers/pools.py index eb08667d5..d5422a6cb 100644 --- a/src/dstack/_internal/server/routers/pools.py +++ b/src/dstack/_internal/server/routers/pools.py @@ -38,6 +38,16 @@ async def remove_instance( await pools.remove_instance(session, project_model, body.pool_name, body.instance_name) +@router.post("/set-default") # type: ignore[misc] +async def set_default_pool( + body: schemas.SetDefaultPoolRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), +) -> bool: + _, project_model = user_project + return await pools.set_default_pool(session, project_model, body.pool_name) + + @router.post("/delete") # type: ignore[misc] async def delete_pool( body: schemas.DeletePoolRequest, @@ -90,19 +100,20 @@ async def how_pool( return await pools.show_pool(session, project, pool_name=body.name) -@router.post("/add") # type: ignore[misc] +@router.post("/add_remote") # type: ignore[misc] async def add_instance( body: AddRemoteInstanceRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> None: +) -> bool: _, project = user_project - await pools.add( + result = await pools.add_remote( session, project=project, resources=body.resources, - pool_name=body.pool_name, + profile=body.profile, instance_name=body.instance_name, host=body.host, port=body.port, ) + return result diff --git a/src/dstack/_internal/server/schemas/pools.py b/src/dstack/_internal/server/schemas/pools.py index 318874eec..bbf006032 100644 --- a/src/dstack/_internal/server/schemas/pools.py +++ b/src/dstack/_internal/server/schemas/pools.py @@ -17,3 +17,7 @@ class ShowPoolRequest(BaseModel): # type: ignore[misc] class RemoveInstanceRequest(BaseModel): # type: ignore[misc] pool_name: str instance_name: str + + +class SetDefaultPoolRequest(BaseModel): # type: ignore[misc] + pool_name: str diff --git a/src/dstack/_internal/server/schemas/runs.py b/src/dstack/_internal/server/schemas/runs.py index 1b6bda262..c3745fe39 100644 --- a/src/dstack/_internal/server/schemas/runs.py +++ b/src/dstack/_internal/server/schemas/runs.py @@ -32,11 +32,11 @@ class CreateInstanceRequest(BaseModel): class 
AddRemoteInstanceRequest(BaseModel): - pool_name: str instance_name: Optional[str] host: str port: str resources: ResourcesSpec + profile: Profile class SubmitRunRequest(BaseModel): diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py index 8d6ebde39..945521300 100644 --- a/src/dstack/_internal/server/services/backends/__init__.py +++ b/src/dstack/_internal/server/services/backends/__init__.py @@ -311,6 +311,7 @@ async def get_instance_offers( ] for backend, backend_offers in zip(backends, await asyncio.gather(*tasks)) ] + # Merge preserving order for every backend offers = heapq.merge(*offers_by_backend, key=lambda i: i[1].price) # Put NOT_AVAILABLE and NO_QUOTA instances at the end, do not sort by price diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 9bc10068b..72744250c 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -9,6 +9,7 @@ from dstack._internal.core.errors import SSHError from dstack._internal.core.models.configurations import ConfigurationType +from dstack._internal.core.models.profiles import TerminationPolicy from dstack._internal.core.models.runs import ( Job, JobErrorCode, @@ -112,8 +113,9 @@ async def stop_job( job_submission, project.ssh_private_key, ) + # delay termination for 15 seconds to allow the runner to stop gracefully - delay_job_instance_termination(job_model) + # delay_job_instance_termination(job_model) except SSHError: logger.debug(*job_log("failed to stop runner", job_model)) # process_finished_jobs will terminate the instance in the background @@ -124,20 +126,20 @@ async def stop_job( logger.info(*job_log("%s by user", job_model, new_status.value)) -async def terminate_job_submission_instance( +async def terminate_job_provisioning_data_instance( project: ProjectModel, - job_submission: JobSubmission, + job_provisioning_data: JobProvisioningData, ): backend = await get_project_backend_by_type( project=project, - backend_type=job_submission.job_provisioning_data.backend, + backend_type=job_provisioning_data.backend, ) - logger.debug("Terminating runner instance %s", job_submission.job_provisioning_data.hostname) + logger.debug("Terminating runner instance %s", job_provisioning_data.hostname) await run_async( backend.compute().terminate_instance, - job_submission.job_provisioning_data.instance_id, - job_submission.job_provisioning_data.region, - job_submission.job_provisioning_data.backend_data, + job_provisioning_data.instance_id, + job_provisioning_data.region, + job_provisioning_data.backend_data, ) diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index 63e2ae4bb..e44968332 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -16,7 +16,7 @@ Resources, ) from dstack._internal.core.models.pools import Instance, Pool -from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData from dstack._internal.server.models import InstanceModel, PoolModel, ProjectModel @@ -65,7 +65,7 @@ async def create_pool_model(session: AsyncSession, project: ProjectModel, name: ) ) if pools.all(): - 
raise ValueError("duplicate pool name") + raise ValueError("duplicate pool name") # TODO: return error with description pool = PoolModel( name=name, @@ -73,8 +73,11 @@ async def create_pool_model(session: AsyncSession, project: ProjectModel, name: ) session.add(pool) await session.commit() - project.default_pool = pool # TODO: add CLI flag --set-default - await session.commit() + + if project.default_pool is None: + project.default_pool = pool + await session.commit() + return pool @@ -87,6 +90,25 @@ async def list_project_pool_models( return pools.all() # type: ignore[no-any-return] +async def set_default_pool(session: AsyncSession, project: ProjectModel, pool_name: str) -> bool: + pool = ( + await session.scalars( + select(PoolModel).where( + PoolModel.name == pool_name, + PoolModel.project == project, + PoolModel.deleted == False, + ) + ) + ).one_or_none() + + if pool is None: + return False + project.default_pool = pool + + await session.commit() + return True + + async def remove_instance( session: AsyncSession, project: ProjectModel, pool_name: str, instance_name: str ) -> None: @@ -235,17 +257,17 @@ async def generate_instance_name( return name -async def add( +async def add_remote( session: AsyncSession, resources: ResourcesSpec, project: ProjectModel, - pool_name: str, + profile: Profile, instance_name: Optional[str], host: str, port: str, -) -> None: +) -> bool: - instance_name = instance_name + pool_name = profile.pool_name if instance_name is None: instance_name = await generate_instance_name(session, project, pool_name) @@ -273,7 +295,7 @@ async def add( ) local = JobProvisioningData( - backend=BackendType.LOCAL, + backend=BackendType.REMOTE, instance_type=InstanceType(name="local", resources=instance_resource), instance_id=instance_name, hostname=host, @@ -287,7 +309,7 @@ async def add( ssh_proxy=None, ) offer = InstanceOfferWithAvailability( - backend=BackendType.LOCAL, + backend=BackendType.REMOTE, instance=InstanceType( name="instance", resources=instance_resource, @@ -305,6 +327,10 @@ async def add( job_provisioning_data=local.json(), offer=offer.json(), resource_spec_data=resources.json(), + termination_policy=profile.termination_policy, + termination_idle_time=str(profile.termination_idle_time), ) session.add(im) await session.commit() + + return True diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 951ca72fe..a4f8f4b56 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -154,7 +154,9 @@ async def get_run_plan_by_requirements( requirements: Requirements, exclude_not_available=False, ) -> List[Tuple[Backend, InstanceOfferWithAvailability]]: + backends = await backends_services.get_project_backends(project=project) + if profile.backends is not None: backends = [b for b in backends if b.TYPE in profile.backends] @@ -196,6 +198,7 @@ async def create_instance( image=image, registry_auth=None, ), + user=user.name, ) pool = await pools_services.get_pool(session, project, pool_name) @@ -245,13 +248,31 @@ async def create_instance( ssh_proxy=None, ) + # types of queries + # 1. Get all available instance + # 2. Get job's instance (process job) + # 3. 
Get instance's jobs history + im = InstanceModel( name=instance_name, project=project, pool=pool, status=InstanceStatus.STARTING, + # job_id: Optional[FK] (current job) + # ip address + # ssh creds: user, port, dockerized + # real resources + spot (exact) / instance offer + # backend + backend data + # region + # price (for querying) + # termination policy + # creation policy job_provisioning_data=job_provisioning_data.json(), + # TODO: instance provisioning offer=cast(InstanceOfferWithAvailability, instance_offer).json(), + resource_spec_data=requirements.resources.json(), + termination_policy=profile.termination_policy, + termination_idle_time=str(profile.termination_idle_time), ) session.add(im) await session.commit() @@ -401,11 +422,13 @@ async def stop_runs( new_status = JobStatus.ABORTED res = await session.execute( - select(JobModel).where( + select(JobModel) + .where( JobModel.project_id == project.id, JobModel.run_name.in_(runs_names), JobModel.status.not_in(JobStatus.finished_statuses()), ) + .options(joinedload(JobModel.instance)) ) job_models = res.scalars().all() for job_model in job_models: diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 4fbee673b..758a0e242 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -198,6 +198,7 @@ async def create_job( last_processed_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), error_code: Optional[JobErrorCode] = None, job_provisioning_data: Optional[JobProvisioningData] = None, + instance: Optional[InstanceModel] = None, ) -> JobModel: run_spec = RunSpec.parse_raw(run.run_spec) job_spec = get_job_specs_from_run_spec(run_spec)[0] @@ -214,6 +215,7 @@ async def create_job( error_code=error_code, job_spec_data=job_spec.json(), job_provisioning_data=job_provisioning_data.json() if job_provisioning_data else None, + instance=instance if instance is not None else None, ) session.add(job) await session.commit() diff --git a/src/dstack/api/server/_pools.py b/src/dstack/api/server/_pools.py index cbfef5916..1e41498d3 100644 --- a/src/dstack/api/server/_pools.py +++ b/src/dstack/api/server/_pools.py @@ -4,6 +4,7 @@ import dstack._internal.server.schemas.pools as schemas_pools from dstack._internal.core.models.pools import Instance, Pool +from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.server.schemas.runs import AddRemoteInstanceRequest from dstack.api.server._group import APIClientGroup @@ -35,20 +36,26 @@ def remove(self, project_name: str, pool_name: str, instance_name: str) -> None: ) self._request(f"/api/project/{project_name}/pool/remove", body=body.json()) - def add( + def set_default(self, project_name: str, pool_name: str) -> bool: + body = schemas_pools.SetDefaultPoolRequest(pool_name=pool_name) + result = self._request(f"/api/project/{project_name}/pool/set-default", body=body.json()) + return bool(result.json()) + + def add_remote( self, project_name: str, resources: ResourcesSpec, - pool_name: str, + profile: Profile, instance_name: Optional[str], host: str, port: str, - ) -> None: + ) -> bool: body = AddRemoteInstanceRequest( - pool_name=pool_name, + profile=profile, instance_name=instance_name, host=host, port=port, resources=resources, ) - self._request(f"/api/project/{project_name}/pool/add", body=body.json()) + result = self._request(f"/api/project/{project_name}/pool/add_remote", body=body.json()) + 
return bool(result.json()) diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py index 19c7b46aa..e305bd28a 100644 --- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py @@ -8,13 +8,15 @@ from dstack._internal.core.errors import SSHError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import InstanceType, Resources -from dstack._internal.core.models.runs import JobProvisioningData, JobStatus +from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData, JobStatus from dstack._internal.server import settings from dstack._internal.server.background.tasks.process_running_jobs import process_running_jobs from dstack._internal.server.schemas.runner import HealthcheckResponse, JobStateEvent, PullResponse from dstack._internal.server.services.jobs.configurators.base import get_default_python_verison from dstack._internal.server.testing.common import ( + create_instance, create_job, + create_pool, create_project, create_repo, create_run, @@ -285,12 +287,20 @@ async def test_pulling_shim_failed(self, test_db, session: AsyncSession): repo=repo, user=user, ) + instance = await create_instance( + session, + project, + await create_pool(session, project), + InstanceStatus.READY, + Resources(cpus=1, memory_mib=512, spot=False, gpus=[]), + ) job_provisioning_data = get_job_provisioning_data(dockerized=True) job = await create_job( session=session, run=run, status=JobStatus.PULLING, job_provisioning_data=job_provisioning_data, + instance=instance, ) with patch( "dstack._internal.server.services.runner.ssh.RunnerTunnel" diff --git a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py b/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py index 9e1b93b70..1c65fd656 100644 --- a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py @@ -6,10 +6,12 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import InstanceType, Resources -from dstack._internal.core.models.runs import JobProvisioningData, JobStatus +from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData, JobStatus from dstack._internal.server.background.tasks.process_finished_jobs import process_finished_jobs from dstack._internal.server.testing.common import ( + create_instance, create_job, + create_pool, create_project, create_repo, create_run, @@ -34,10 +36,19 @@ async def test_transitions_done_jobs_marked_as_removed(self, test_db, session: A repo=repo, user=user, ) + instance = await create_instance( + session, + project, + await create_pool(session, project), + InstanceStatus.READY, + Resources(cpus=1, memory_mib=512, spot=False, gpus=[]), + ) + job = await create_job( session=session, run=run, status=JobStatus.DONE, + instance=instance, job_provisioning_data=JobProvisioningData( backend=BackendType.AWS, instance_type=InstanceType( @@ -54,7 +65,7 @@ async def test_transitions_done_jobs_marked_as_removed(self, test_db, session: A backend_data=None, ), ) - with patch(f"{MODULE}.terminate_job_submission_instance") as terminate: + with patch(f"{MODULE}.terminate_job_provisioning_data_instance") as terminate: await 
process_finished_jobs() terminate.assert_called_once() await session.refresh(job) From 28dfa3723bcd0c80ab11251dfdcdcffbda164002 Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Thu, 1 Feb 2024 20:54:09 +0300 Subject: [PATCH 10/47] debug --- .pre-commit-config.yaml | 2 +- src/dstack/_internal/core/backends/base/offers.py | 2 ++ src/dstack/_internal/core/services/ssh/attach.py | 3 +++ .../server/background/tasks/process_submitted_jobs.py | 7 +++++++ ...718bf16e84c5_add_pools.py => 98b9e40f03b0_add_pools.py} | 6 +++--- src/dstack/_internal/server/services/backends/__init__.py | 1 - src/dstack/_internal/server/services/jobs/__init__.py | 1 - src/dstack/_internal/server/services/runs.py | 1 - 8 files changed, 16 insertions(+), 7 deletions(-) rename src/dstack/_internal/server/migrations/versions/{718bf16e84c5_add_pools.py => 98b9e40f03b0_add_pools.py} (98%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8aba21e1a..1ae35e988 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: args: ['--strict', '--follow-imports=skip', '--ignore-missing-imports', '--python-version=3.8'] files: '.*pools?\.py' exclude: 'versions|src/tests' - additional_dependencies: [types-PyYAML, types-requests, pydantic,sqlalchemy] + additional_dependencies: [types-PyYAML, types-requests, pydantic, sqlalchemy] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.1.14 hooks: diff --git a/src/dstack/_internal/core/backends/base/offers.py b/src/dstack/_internal/core/backends/base/offers.py index 1dbb33d79..89581e619 100644 --- a/src/dstack/_internal/core/backends/base/offers.py +++ b/src/dstack/_internal/core/backends/base/offers.py @@ -29,7 +29,9 @@ def get_catalog_offers( offers = [] catalog = catalog if catalog is not None else gpuhunt.default_catalog() + locs = [] for item in catalog.query(**asdict(q)): + locs.append(item.location) if locations is not None and item.location not in locations: continue offer = catalog_item_to_offer(backend, item, requirements) diff --git a/src/dstack/_internal/core/services/ssh/attach.py b/src/dstack/_internal/core/services/ssh/attach.py index 4b30355f6..eff893de9 100644 --- a/src/dstack/_internal/core/services/ssh/attach.py +++ b/src/dstack/_internal/core/services/ssh/attach.py @@ -4,6 +4,8 @@ import time from typing import Optional, Tuple +from icecream import ic + from dstack._internal.core.errors import SSHError from dstack._internal.core.models.instances import SSHConnectionParams from dstack._internal.core.services.configs import ConfigManager @@ -99,6 +101,7 @@ def __init__( } else: self.container_config = None + ic(self.container_config) self.ssh_config_path = str(ConfigManager().dstack_ssh_config_path) def attach(self): diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index c69b94e69..1554b8d74 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -79,6 +79,9 @@ async def _process_job(job_id: UUID): ) +from icecream import ic + + def check_relevance( profile: Profile, resources: ResourcesSpec, instance_model: InstanceModel ) -> bool: @@ -103,9 +106,11 @@ def check_relevance( ResourcesSpec, instance_model.resource_spec_data ) + ic(resources.cpu.min, instance_resources.cpu.min) if resources.cpu.min > instance_resources.cpu.min: return False + ic(resources.gpu) if resources.gpu is not None: if 
instance_resources.gpu is None: @@ -119,6 +124,7 @@ def check_relevance( # TODO: compare GPU names + ic(resources.memory, instance_resources.memory.min) if resources.memory.min > instance_resources.memory.min: return False @@ -180,6 +186,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): pool_instances = await get_pool_instances(session, project_model, run_pool) available_instanses = (p for p in pool_instances if p.status == InstanceStatus.READY) + ic(available_instanses) relevant_instances: List[InstanceModel] = [] for instance in available_instanses: if check_relevance(profile, run_spec.configuration.resources, instance): diff --git a/src/dstack/_internal/server/migrations/versions/718bf16e84c5_add_pools.py b/src/dstack/_internal/server/migrations/versions/98b9e40f03b0_add_pools.py similarity index 98% rename from src/dstack/_internal/server/migrations/versions/718bf16e84c5_add_pools.py rename to src/dstack/_internal/server/migrations/versions/98b9e40f03b0_add_pools.py index 26a018cdb..f15a24ba0 100644 --- a/src/dstack/_internal/server/migrations/versions/718bf16e84c5_add_pools.py +++ b/src/dstack/_internal/server/migrations/versions/98b9e40f03b0_add_pools.py @@ -1,8 +1,8 @@ """add pools -Revision ID: 718bf16e84c5 +Revision ID: 98b9e40f03b0 Revises: d3e8af4786fa -Create Date: 2024-02-01 18:01:47.612769 +Create Date: 2024-02-04 17:25:03.945051 """ import sqlalchemy as sa @@ -10,7 +10,7 @@ from alembic import op # revision identifiers, used by Alembic. -revision = "718bf16e84c5" +revision = "98b9e40f03b0" down_revision = "d3e8af4786fa" branch_labels = None depends_on = None diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py index 945521300..8d6ebde39 100644 --- a/src/dstack/_internal/server/services/backends/__init__.py +++ b/src/dstack/_internal/server/services/backends/__init__.py @@ -311,7 +311,6 @@ async def get_instance_offers( ] for backend, backend_offers in zip(backends, await asyncio.gather(*tasks)) ] - # Merge preserving order for every backend offers = heapq.merge(*offers_by_backend, key=lambda i: i[1].price) # Put NOT_AVAILABLE and NO_QUOTA instances at the end, do not sort by price diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 72744250c..05b9746bd 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -113,7 +113,6 @@ async def stop_job( job_submission, project.ssh_private_key, ) - # delay termination for 15 seconds to allow the runner to stop gracefully # delay_job_instance_termination(job_model) except SSHError: diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index a4f8f4b56..febf365e8 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -156,7 +156,6 @@ async def get_run_plan_by_requirements( ) -> List[Tuple[Backend, InstanceOfferWithAvailability]]: backends = await backends_services.get_project_backends(project=project) - if profile.backends is not None: backends = [b for b in backends if b.TYPE in profile.backends] From 718a7d4768e2c3a2366d70341a28c2843f2ab5dc Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Sun, 4 Feb 2024 17:40:38 +0300 Subject: [PATCH 11/47] fix mypy --- src/dstack/_internal/core/models/pools.py | 4 ++-- .../_internal/server/background/tasks/process_pools.py | 6 +++--- 
src/dstack/_internal/server/schemas/pools.py | 10 +++++----- src/dstack/_internal/server/services/pools.py | 8 ++------ src/dstack/api/server/_pools.py | 4 ++-- 5 files changed, 14 insertions(+), 18 deletions(-) diff --git a/src/dstack/_internal/core/models/pools.py b/src/dstack/_internal/core/models/pools.py index 204b1cca1..9cb599d7e 100644 --- a/src/dstack/_internal/core/models/pools.py +++ b/src/dstack/_internal/core/models/pools.py @@ -7,13 +7,13 @@ from dstack._internal.core.models.runs import InstanceStatus -class Pool(BaseModel): # type: ignore[misc] +class Pool(BaseModel): # type: ignore[misc,valid-type] name: str default: bool created_at: datetime.datetime -class Instance(BaseModel): # type: ignore[misc] +class Instance(BaseModel): # type: ignore[misc,valid-type] backend: BackendType instance_type: InstanceType instance_id: str diff --git a/src/dstack/_internal/server/background/tasks/process_pools.py b/src/dstack/_internal/server/background/tasks/process_pools.py index 06d7e01c0..526685ee3 100644 --- a/src/dstack/_internal/server/background/tasks/process_pools.py +++ b/src/dstack/_internal/server/background/tasks/process_pools.py @@ -69,7 +69,7 @@ async def check_shim(instance_id: UUID) -> None: ) ).one() ssh_private_key = instance.project.ssh_private_key - job_provisioning_data = parse_raw_as(JobProvisioningData, instance.job_provisioning_data) + job_provisioning_data = parse_raw_as(JobProvisioningData, instance.job_provisioning_data) # type: ignore[operator] instance_health = instance_healthcheck(ssh_private_key, job_provisioning_data) @@ -100,7 +100,7 @@ async def terminate(instance_id: UUID) -> None: .options(joinedload(InstanceModel.project)) ) ).one() - jpd = parse_raw_as(JobProvisioningData, instance.job_provisioning_data) + jpd = parse_raw_as(JobProvisioningData, instance.job_provisioning_data) # type: ignore[operator] BACKEND_TYPE = jpd.backend backends = await backends_services.get_project_backends(project=instance.project) backend = next((b for b in backends if b.TYPE in BACKEND_TYPE), None) @@ -127,7 +127,7 @@ async def _terminate_old_instance() -> None: for instance in instances: if instance.finished_at + instance.termination_idle_time > get_current_datetime(): - instance_type = parse_raw_as( + instance_type = parse_raw_as( # type: ignore[operator] JobProvisioningData, instance.job_provisioning_data ).backend await terminate_job_provisioning_data_instance( diff --git a/src/dstack/_internal/server/schemas/pools.py b/src/dstack/_internal/server/schemas/pools.py index bbf006032..67e04f21c 100644 --- a/src/dstack/_internal/server/schemas/pools.py +++ b/src/dstack/_internal/server/schemas/pools.py @@ -1,23 +1,23 @@ from pydantic import BaseModel -class DeletePoolRequest(BaseModel): # type: ignore[misc] +class DeletePoolRequest(BaseModel): # type: ignore[misc,valid-type] name: str force: bool -class CreatePoolRequest(BaseModel): # type: ignore[misc] +class CreatePoolRequest(BaseModel): # type: ignore[misc,valid-type] name: str -class ShowPoolRequest(BaseModel): # type: ignore[misc] +class ShowPoolRequest(BaseModel): # type: ignore[misc,valid-type] name: str -class RemoveInstanceRequest(BaseModel): # type: ignore[misc] +class RemoveInstanceRequest(BaseModel): # type: ignore[misc,valid-type] pool_name: str instance_name: str -class SetDefaultPoolRequest(BaseModel): # type: ignore[misc] +class SetDefaultPoolRequest(BaseModel): # type: ignore[misc,valid-type] pool_name: str diff --git a/src/dstack/_internal/server/services/pools.py 
b/src/dstack/_internal/server/services/pools.py index e44968332..c86c11f07 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -168,12 +168,8 @@ async def list_deleted_pools( def instance_model_to_instance(instance_model: InstanceModel) -> Instance: - offer: InstanceOfferWithAvailability = parse_raw_as( - InstanceOfferWithAvailability, instance_model.offer - ) - jpd: JobProvisioningData = parse_raw_as( - JobProvisioningData, instance_model.job_provisioning_data - ) + offer: InstanceOfferWithAvailability = parse_raw_as(InstanceOfferWithAvailability, instance_model.offer) # type: ignore[operator] + jpd: JobProvisioningData = parse_raw_as(JobProvisioningData, instance_model.job_provisioning_data) # type: ignore[operator] instance = Instance( backend=offer.backend, diff --git a/src/dstack/api/server/_pools.py b/src/dstack/api/server/_pools.py index 1e41498d3..bb10263d0 100644 --- a/src/dstack/api/server/_pools.py +++ b/src/dstack/api/server/_pools.py @@ -13,7 +13,7 @@ class PoolAPIClient(APIClientGroup): # type: ignore[misc] def list(self, project_name: str) -> List[Pool]: resp = self._request(f"/api/project/{project_name}/pool/list") - result: List[Pool] = parse_obj_as(List[Pool], resp.json()) + result: List[Pool] = parse_obj_as(List[Pool], resp.json()) # type: ignore[operator] return result def delete(self, project_name: str, pool_name: str, force: bool) -> None: @@ -27,7 +27,7 @@ def create(self, project_name: str, pool_name: str) -> None: def show(self, project_name: str, pool_name: str) -> List[Instance]: body = schemas_pools.ShowPoolRequest(name=pool_name) resp = self._request(f"/api/project/{project_name}/pool/show", body=body.json()) - result: List[Instance] = parse_obj_as(List[Instance], resp.json()) + result: List[Instance] = parse_obj_as(List[Instance], resp.json()) # type: ignore[operator] return result def remove(self, project_name: str, pool_name: str, instance_name: str) -> None: From 97d1c1b9ef86b62d3a046d6c4c3d0f77b088b0f1 Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Mon, 5 Feb 2024 11:10:29 +0300 Subject: [PATCH 12/47] improve dstack run --- .pre-commit-config.yaml | 4 +- src/dstack/_internal/cli/commands/pool.py | 18 ++- src/dstack/_internal/cli/commands/run.py | 61 ++++++-- src/dstack/_internal/core/models/pools.py | 2 + src/dstack/_internal/core/models/profiles.py | 3 +- src/dstack/_internal/core/models/runs.py | 7 +- .../_internal/core/services/ssh/attach.py | 3 - .../server/background/tasks/process_pools.py | 18 ++- .../tasks/process_submitted_jobs.py | 86 ++---------- ...add_pools.py => dad000707a2c_add_pools.py} | 28 +++- src/dstack/_internal/server/models.py | 12 +- src/dstack/_internal/server/services/pools.py | 24 ++-- src/dstack/_internal/server/services/runs.py | 130 ++++++++++++++++-- src/dstack/_internal/server/testing/common.py | 3 + src/dstack/api/_public/runs.py | 19 ++- src/dstack/api/server/_pools.py | 2 +- .../tasks/test_process_submitted_jobs.py | 4 +- .../tasks/test_process_terminating_jobs.py | 4 +- .../_internal/server/routers/test_runs.py | 5 +- .../_internal/server/services/test_pools.py | 21 ++- 20 files changed, 310 insertions(+), 144 deletions(-) rename src/dstack/_internal/server/migrations/versions/{98b9e40f03b0_add_pools.py => dad000707a2c_add_pools.py} (85%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1ae35e988..a3a98309e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: 
https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.14 + rev: v0.2.0 hooks: - id: ruff - repo: https://github.com/psf/black @@ -23,7 +23,7 @@ repos: exclude: 'versions|src/tests' additional_dependencies: [types-PyYAML, types-requests, pydantic, sqlalchemy] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.14 + rev: v0.2.0 hooks: - id: ruff name: ruff autofix diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index dca628645..95e70bd27 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -32,12 +32,20 @@ def print_pool_table(pools: Sequence[Pool], verbose: bool) -> None: table = Table(box=None) table.add_column("NAME") table.add_column("DEFAULT") + table.add_column("INSTANCES") if verbose: table.add_column("CREATED") sorted_pools = sorted(pools, key=lambda r: r.name) for pool in sorted_pools: - row = [pool.name, "default" if pool.default else ""] + default_mark = "default" if pool.default else "" + color = ( + colors["success"] + if pool.total_instances == pool.available_instances + else colors["error"] + ) + health = f"[{color}]{pool.available_instances}/{pool.total_instances}[/{color}]" + row = [pool.name, default_mark, health] if verbose: row.append(pretty_date(pool.created_at)) table.add_row(*row) @@ -48,18 +56,20 @@ def print_pool_table(pools: Sequence[Pool], verbose: bool) -> None: def print_instance_table(instances: Sequence[Instance]) -> None: table = Table(box=None) - table.add_column("INSTANCE ID") + table.add_column("INSTANCE NAME") table.add_column("BACKEND") table.add_column("INSTANCE TYPE") table.add_column("STATUS") table.add_column("PRICE") for instance in instances: + status_mark = "success" if instance.status.is_available() else "warning" + color = colors[status_mark] row = [ instance.instance_id, instance.backend, instance.instance_type.resources.pretty_format(), - instance.status, + f"[{color}]{instance.status}[/{color}]", f"{instance.price:.02f}", ] table.add_row(*row) @@ -246,7 +256,7 @@ def _register(self) -> None: formatter_class=self._parser.formatter_class, ) show_parser.add_argument( - "--name", "-n", dest="pool_name", help="The name of the pool", required=True + "--pool", dest="pool_name", help="The name of the pool", required=True ) show_parser.set_defaults(subfunc=self._show) diff --git a/src/dstack/_internal/cli/commands/run.py b/src/dstack/_internal/cli/commands/run.py index 7aa22fb12..62b69b977 100644 --- a/src/dstack/_internal/cli/commands/run.py +++ b/src/dstack/_internal/cli/commands/run.py @@ -13,13 +13,18 @@ BaseRunConfigurator, run_configurators_mapping, ) -from dstack._internal.cli.utils.common import confirm_ask, console +from dstack._internal.cli.utils.common import colors, confirm_ask, console from dstack._internal.cli.utils.run import print_run_plan from dstack._internal.core.errors import CLIError, ConfigurationError, ServerClientError from dstack._internal.core.models.configurations import ConfigurationType -from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, CreationPolicy +from dstack._internal.core.models.profiles import ( + DEFAULT_POOL_NAME, + CreationPolicy, + TerminationPolicy, +) from dstack._internal.core.models.runs import JobErrorCode from dstack._internal.core.services.configs import ConfigManager +from dstack._internal.utils.common import parse_pretty_duration from dstack._internal.utils.logging import get_logger from dstack.api import RunStatus from dstack.api._public.runs import Run @@ -86,10 
+91,21 @@ def _register(self): ) self._parser.add_argument( "--reuse", - dest="creation_policy", - action="store_const", - const=CreationPolicy.REUSE, - help="Reuse instance", + dest="creation_policy_reuse", + action="store_true", + help="Reuse instance from pool", + ) + self._parser.add_argument( + "--idle-duration", + dest="idle_duration", + type=str, + help="Idle time before instance termination", + ) + self._parser.add_argument( + "--instance", + dest="instance_name", + metavar="NAME", + help="Reuse instance from pool with name [code]NAME[/]", ) register_profile_args(self._parser) @@ -102,6 +118,33 @@ def _command(self, args: argparse.Namespace): self._parser.print_help() return + termination_policy_idle = 5 * 60 + termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE + + if args.idle_duration is not None: + try: + termination_policy_idle = int(args.idle_duration) + except ValueError: + termination_policy_idle = 60 * parse_pretty_duration(args.idle_duration) + + creation_policy = ( + CreationPolicy.REUSE if args.creation_policy_reuse else CreationPolicy.REUSE_OR_CREATE + ) + + if creation_policy == CreationPolicy.REUSE and termination_policy_idle is not None: + console.print( + f'[{colors["warning"]}]If the flag --reuse is set, the argument --idle-duration will be skipped[/]' + ) + termination_policy_idle = None + termination_policy = TerminationPolicy.DONT_DESTROY + + if args.instance_name is not None and termination_policy_idle is not None: + console.print( + f'[{colors["warning"]}]--idle-duration won\'t be applied to the instance {args.instance_name!r}[/]' + ) + termination_policy_idle = None + termination_policy = TerminationPolicy.DONT_DESTROY + super()._command(args) try: repo = self.api.repos.load(Path.cwd()) @@ -137,6 +180,10 @@ def _command(self, args: argparse.Namespace): working_dir=args.working_dir, run_name=args.run_name, pool_name=pool_name, + instance_name=args.instance_name, + creation_policy=creation_policy, + termination_policy=termination_policy, + termination_policy_idle=f"{termination_policy_idle}s", ) except ConfigurationError as e: raise CLIError(str(e)) @@ -155,8 +202,6 @@ def _command(self, args: argparse.Namespace): console.print("\nExiting...") return - run_plan.run_spec.profile.creation_policy = args.creation_policy - try: with console.status("Submitting run..."): run = self.api.runs.exec_plan(run_plan, repo, reserve_ports=not args.detach) diff --git a/src/dstack/_internal/core/models/pools.py b/src/dstack/_internal/core/models/pools.py index 9cb599d7e..5fc4c0b65 100644 --- a/src/dstack/_internal/core/models/pools.py +++ b/src/dstack/_internal/core/models/pools.py @@ -11,6 +11,8 @@ class Pool(BaseModel): # type: ignore[misc,valid-type] name: str default: bool created_at: datetime.datetime + total_instances: int + available_instances: int class Instance(BaseModel): # type: ignore[misc,valid-type] diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index 6957d8998..fdc70e3fc 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -108,7 +108,8 @@ class Profile(ForbidExtra): pool_name: Annotated[ Optional[str], Field(description="The name of the pool. 
If not set, dstack will use the default name."), - ] = DEFAULT_POOL_NAME + ] = None + instance_name: Annotated[Optional[str], Field(description="The name of the instance")] creation_policy: Annotated[ Optional[CreationPolicy], Field(description="The policy for using instances from the pool") ] diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 2970acb1b..d745303c9 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -1,9 +1,9 @@ from datetime import datetime, timedelta from enum import Enum -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Sequence from pydantic import UUID4, BaseModel, Field -from typing_extensions import Annotated, Literal +from typing_extensions import Annotated from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import AnyRunConfiguration, RegistryAuth @@ -239,3 +239,6 @@ def is_finished(self): def is_started(self): return not self.is_finished() + + def is_available(self) -> bool: + return self in (self.READY, self.BUSY) diff --git a/src/dstack/_internal/core/services/ssh/attach.py b/src/dstack/_internal/core/services/ssh/attach.py index eff893de9..4b30355f6 100644 --- a/src/dstack/_internal/core/services/ssh/attach.py +++ b/src/dstack/_internal/core/services/ssh/attach.py @@ -4,8 +4,6 @@ import time from typing import Optional, Tuple -from icecream import ic - from dstack._internal.core.errors import SSHError from dstack._internal.core.models.instances import SSHConnectionParams from dstack._internal.core.services.configs import ConfigManager @@ -101,7 +99,6 @@ def __init__( } else: self.container_config = None - ic(self.container_config) self.ssh_config_path = str(ConfigManager().dstack_ssh_config_path) def attach(self): diff --git a/src/dstack/_internal/server/background/tasks/process_pools.py b/src/dstack/_internal/server/background/tasks/process_pools.py index 526685ee3..71af69a4c 100644 --- a/src/dstack/_internal/server/background/tasks/process_pools.py +++ b/src/dstack/_internal/server/background/tasks/process_pools.py @@ -1,3 +1,4 @@ +import datetime from datetime import timedelta from typing import Dict from uuid import UUID @@ -19,7 +20,7 @@ from dstack._internal.server.services.runner import client from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel from dstack._internal.server.utils.common import run_async -from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.common import get_current_datetime, parse_pretty_duration from dstack._internal.utils.logging import get_logger PENDING_JOB_RETRY_INTERVAL = timedelta(seconds=60) @@ -126,11 +127,20 @@ async def _terminate_old_instance() -> None: instances = res.scalars().all() for instance in instances: - if instance.finished_at + instance.termination_idle_time > get_current_datetime(): - instance_type = parse_raw_as( # type: ignore[operator] + if instance.finished_at is None: + continue + + delta = datetime.timedelta( + seconds=parse_pretty_duration(instance.termination_idle_time) + ) + if instance.finished_at + delta > get_current_datetime(): + jpd: JobProvisioningData = parse_raw_as( # type: ignore[operator] JobProvisioningData, instance.job_provisioning_data ).backend await terminate_job_provisioning_data_instance( - project=instance.project, job_provisioning_data=instance.job_provisioning_data + project=instance.project, 
job_provisioning_data=jpd ) + instance.deleted = True + instance.deleted_at = get_current_datetime() + await session.commit() diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 1554b8d74..ba959dff9 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -8,13 +8,11 @@ from dstack._internal.core.backends.base import Backend from dstack._internal.core.errors import BackendError -from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( InstanceOfferWithAvailability, LaunchedInstanceInfo, ) -from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, CreationPolicy, Profile -from dstack._internal.core.models.resources import ResourcesSpec +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, CreationPolicy from dstack._internal.core.models.runs import ( InstanceStatus, Job, @@ -32,13 +30,8 @@ SUBMITTED_PROCESSING_JOBS_LOCK, ) from dstack._internal.server.services.logging import job_log -from dstack._internal.server.services.pools import ( - get_pool_instances, - instance_model_to_instance, - list_project_pool_models, -) -from dstack._internal.server.services.runs import run_model_to_run -from dstack._internal.server.settings import LOCAL_BACKEND_ENABLED +from dstack._internal.server.services.pools import get_pool_instances, list_project_pool_models +from dstack._internal.server.services.runs import check_relevance, run_model_to_run from dstack._internal.server.utils.common import run_async from dstack._internal.utils import common as common_utils from dstack._internal.utils.logging import get_logger @@ -79,71 +72,6 @@ async def _process_job(job_id: UUID): ) -from icecream import ic - - -def check_relevance( - profile: Profile, resources: ResourcesSpec, instance_model: InstanceModel -) -> bool: - - jpd: JobProvisioningData = parse_raw_as( - JobProvisioningData, instance_model.job_provisioning_data - ) - - # TODO: remove on prod - if LOCAL_BACKEND_ENABLED and jpd.backend == BackendType.LOCAL: - return True - - instance = instance_model_to_instance(instance_model) - - if profile.backends is not None and instance.backend not in profile.backends: - logger.warning(f"no backend select ") - return False - - # use gpuhunt - - instance_resources: ResourcesSpec = parse_raw_as( - ResourcesSpec, instance_model.resource_spec_data - ) - - ic(resources.cpu.min, instance_resources.cpu.min) - if resources.cpu.min > instance_resources.cpu.min: - return False - - ic(resources.gpu) - if resources.gpu is not None: - - if instance_resources.gpu is None: - return False - - if resources.gpu.count.min > instance_resources.gpu.count.min: - return False - - if resources.gpu.memory.min > instance_resources.gpu.memory.min: - return False - - # TODO: compare GPU names - - ic(resources.memory, instance_resources.memory.min) - if resources.memory.min > instance_resources.memory.min: - return False - - if resources.shm_size is not None: - if instance_resources.shm_size is None: - return False - - if resources.shm_size > instance_resources.shm_size: - return False - - if resources.disk is not None: - if instance_resources.disk is None: - return False - if resources.disk.size.min > instance_resources.disk.size.min: - return False - - return True - - async def _process_submitted_job(session: AsyncSession, job_model: JobModel): 
logger.debug(*job_log("provisioning", job_model)) res = await session.execute( @@ -185,8 +113,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): # pool capacity pool_instances = await get_pool_instances(session, project_model, run_pool) - available_instanses = (p for p in pool_instances if p.status == InstanceStatus.READY) - ic(available_instanses) + available_instanses = [p for p in pool_instances if p.status == InstanceStatus.READY] relevant_instances: List[InstanceModel] = [] for instance in available_instanses: if check_relevance(profile, run_spec.configuration.resources, instance): @@ -247,8 +174,11 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): job_provisioning_data=job_provisioning_data.json(), offer=offer.json(), termination_policy=profile.termination_policy, - termination_idle_time=profile.termination_idle_time, + termination_idle_time="300", # TODO: fix deserailize job=job_model, + backend=offer.backend, + price=offer.price, + region=offer.region, ) session.add(im) diff --git a/src/dstack/_internal/server/migrations/versions/98b9e40f03b0_add_pools.py b/src/dstack/_internal/server/migrations/versions/dad000707a2c_add_pools.py similarity index 85% rename from src/dstack/_internal/server/migrations/versions/98b9e40f03b0_add_pools.py rename to src/dstack/_internal/server/migrations/versions/dad000707a2c_add_pools.py index f15a24ba0..405fc2d37 100644 --- a/src/dstack/_internal/server/migrations/versions/98b9e40f03b0_add_pools.py +++ b/src/dstack/_internal/server/migrations/versions/dad000707a2c_add_pools.py @@ -1,8 +1,8 @@ """add pools -Revision ID: 98b9e40f03b0 +Revision ID: dad000707a2c Revises: d3e8af4786fa -Create Date: 2024-02-04 17:25:03.945051 +Create Date: 2024-02-05 07:42:58.102664 """ import sqlalchemy as sa @@ -10,7 +10,7 @@ from alembic import op # revision identifiers, used by Alembic. 
-revision = "98b9e40f03b0" +revision = "dad000707a2c" down_revision = "d3e8af4786fa" branch_labels = None depends_on = None @@ -67,6 +67,28 @@ def upgrade() -> None: sa.Column("finished_at", sa.DateTime(), nullable=True), sa.Column("termination_policy", sa.String(length=50), nullable=True), sa.Column("termination_idle_time", sa.String(length=50), nullable=True), + sa.Column( + "backend", + sa.Enum( + "AWS", + "AZURE", + "DATACRUNCH", + "DSTACK", + "GCP", + "KUBERNETES", + "LAMBDA", + "LOCAL", + "REMOTE", + "NEBIUS", + "TENSORDOCK", + "VASTAI", + name="backendtype", + ), + nullable=False, + ), + sa.Column("backend_data", sa.String(length=4000), nullable=True), + sa.Column("region", sa.String(length=2000), nullable=False), + sa.Column("price", sa.Float(), nullable=False), sa.Column("job_provisioning_data", sa.String(length=4000), nullable=False), sa.Column("offer", sa.String(length=4000), nullable=False), sa.Column("resource_spec_data", sa.String(length=4000), nullable=True), diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 3ee2f3438..a603ae05c 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -7,6 +7,7 @@ Boolean, DateTime, Enum, + Float, ForeignKey, Integer, MetaData, @@ -285,6 +286,11 @@ class InstanceModel(BaseModel): termination_policy: Mapped[Optional[TerminationPolicy]] = mapped_column(String(50)) termination_idle_time: Mapped[Optional[str]] = mapped_column(String(50)) + backend: Mapped[BackendType] = mapped_column(Enum(BackendType)) + backend_data: Mapped[Optional[str]] = mapped_column(String(4000)) + region: Mapped[str] = mapped_column(String(2000)) + price: Mapped[float] = mapped_column(Float) + job_provisioning_data: Mapped[str] = mapped_column(String(4000)) offer: Mapped[str] = mapped_column(String(4000)) @@ -299,9 +305,9 @@ class InstanceModel(BaseModel): # ip address # ssh creds: user, port, dockerized # real resources + spot (exact) / instance offer - # backend + backend data - # region - # price (for querying) + # + backend + backend data + # + region + # + price (for querying) # + # termination policy # creation policy # job_provisioning_data=job_provisioning_data.json(), diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index c86c11f07..b75dda4b1 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -51,32 +51,36 @@ async def get_pool( def pool_model_to_pool(pool_model: PoolModel) -> Pool: + total = len(pool_model.instances) + available = sum(instance.status.is_available() for instance in pool_model.instances) return Pool( name=pool_model.name, default=pool_model.project.default_pool_id == pool_model.id, created_at=pool_model.created_at.replace(tzinfo=timezone.utc), + total_instances=total, + available_instances=available, ) async def create_pool_model(session: AsyncSession, project: ProjectModel, name: str) -> PoolModel: pools = await session.scalars( - select(PoolModel).where( - PoolModel.name == name, PoolModel.project == project, PoolModel.deleted == False - ) + select(PoolModel) + .where(PoolModel.name == name, PoolModel.project == project, PoolModel.deleted == False) + .options(joinedload(PoolModel.instances)) ) - if pools.all(): + if pools.unique().all(): raise ValueError("duplicate pool name") # TODO: return error with description pool = PoolModel( name=name, project_id=project.id, ) - session.add(pool) - await session.commit() if project.default_pool is 
None: project.default_pool = pool - await session.commit() + + session.add(pool) + await session.commit() return pool @@ -85,9 +89,11 @@ async def list_project_pool_models( session: AsyncSession, project: ProjectModel ) -> Sequence[PoolModel]: pools = await session.scalars( - select(PoolModel).where(PoolModel.project_id == project.id, PoolModel.deleted == False) + select(PoolModel) + .where(PoolModel.project_id == project.id, PoolModel.deleted == False) + .options(joinedload(PoolModel.instances)) ) - return pools.all() # type: ignore[no-any-return] + return pools.unique().all() # type: ignore[no-any-return] async def set_default_pool(session: AsyncSession, project: ProjectModel, pool_name: str) -> bool: diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index febf365e8..2c13fbd3b 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -19,11 +19,13 @@ SSHKeys, ) from dstack._internal.core.errors import BackendError, RepoDoesNotExistError, ServerClientError +from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( InstanceOfferWithAvailability, LaunchedInstanceInfo, ) -from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, CreationPolicy, Profile +from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import ( InstanceStatus, Job, @@ -61,8 +63,13 @@ get_default_image, get_default_python_verison, ) -from dstack._internal.server.services.pools import create_pool_model +from dstack._internal.server.services.pools import ( + create_pool_model, + get_pool_instances, + instance_model_to_instance, +) from dstack._internal.server.services.projects import list_project_models, list_user_project_models +from dstack._internal.server.settings import LOCAL_BACKEND_ENABLED from dstack._internal.server.utils.common import run_async from dstack._internal.utils.logging import get_logger from dstack._internal.utils.random_names import generate_name @@ -257,6 +264,9 @@ async def create_instance( project=project, pool=pool, status=InstanceStatus.STARTING, + backend=backend.TYPE, + region=instance_offer.region, + price=instance_offer.price, # job_id: Optional[FK] (current job) # ip address # ssh creds: user, port, dockerized @@ -285,31 +295,62 @@ async def get_run_plan( user: UserModel, run_spec: RunSpec, ) -> RunPlan: + pool_instances = await get_pool_instances(session, project, run_spec.profile.pool_name) + pool_offers = [] + + if run_spec.profile.creation_policy == CreationPolicy.REUSE: + requirements = Requirements( + resources=run_spec.configuration.resources, + max_price=run_spec.profile.max_price, + spot=None, + ) + if run_spec.profile.instance_name is not None: + for instance in pool_instances: + if instance.name == run_spec.profile.instance_name and check_relevance( + run_spec.profile, requirements.resources, instance + ): + offer = pydantic.parse_raw_as(InstanceOfferWithAvailability, instance.offer) + pool_offers.append(offer) + else: + for instance in pool_instances: + if check_relevance(run_spec.profile, requirements.resources, instance): + offer = pydantic.parse_raw_as(InstanceOfferWithAvailability, instance.offer) + pool_offers.append(offer) + backends = await backends_services.get_project_backends(project=project) if run_spec.profile.backends is not None: backends = [b for b in backends if 
b.TYPE in run_spec.profile.backends] + run_name = run_spec.run_name # preserve run_name run_spec.run_name = "dry-run" # will regenerate jobs on submission jobs = get_jobs_from_run_spec(run_spec) job_plans = [] + + creation_policy = run_spec.profile.creation_policy + for job in jobs: - # TODO: use the job.pool_name to select an offer - requirements = job.job_spec.requirements - offers = await backends_services.get_instance_offers( - backends=backends, - requirements=requirements, - exclude_not_available=False, - ) - for backend, offer in offers: - offer.backend = backend.TYPE - offers = [offer for _, offer in offers] + job_offers = [] + job_offers.extend(pool_offers) + + if creation_policy is None or creation_policy == CreationPolicy.REUSE_OR_CREATE: + requirements = job.job_spec.requirements + offers = await backends_services.get_instance_offers( + backends=backends, + requirements=requirements, + exclude_not_available=False, + ) + for backend, offer in offers: + offer.backend = backend.TYPE + job_offers.extend(offer for _, offer in offers) + job_plan = JobPlan( job_spec=job.job_spec, - offers=offers[:50], - total_offers=len(offers), - max_price=max((offer.price for offer in offers), default=None), + offers=job_offers[:50], + total_offers=len(job_offers), + max_price=max((offer.price for offer in job_offers), default=None), ) job_plans.append(job_plan) + run_spec.run_name = run_name # restore run_name run_plan = RunPlan( project_name=project.name, user=user.name, run_spec=run_spec, job_plans=job_plans @@ -590,3 +631,62 @@ async def abort_runs_of_pool(session: AsyncSession, project_model: ProjectModel, active_run_names.append(run.run_spec.run_name) await stop_runs(session, project_model, active_run_names, abort=True) + + +def check_relevance( + profile: Profile, resources: ResourcesSpec, instance_model: InstanceModel +) -> bool: + + jpd: JobProvisioningData = pydantic.parse_raw_as( + JobProvisioningData, instance_model.job_provisioning_data + ) + + # TODO: remove on prod + if LOCAL_BACKEND_ENABLED and jpd.backend == BackendType.LOCAL: + return True + + instance = instance_model_to_instance(instance_model) + + if profile.backends is not None and instance.backend not in profile.backends: + logger.warning(f"no backend select ") + return False + + # use gpuhunt + + instance_resources: ResourcesSpec = pydantic.parse_raw_as( + ResourcesSpec, instance_model.resource_spec_data + ) + + if resources.cpu.min > instance_resources.cpu.min: + return False + + if resources.gpu is not None: + + if instance_resources.gpu is None: + return False + + if resources.gpu.count.min > instance_resources.gpu.count.min: + return False + + if resources.gpu.memory.min > instance_resources.gpu.memory.min: + return False + + # TODO: compare GPU names + + if resources.memory.min > instance_resources.memory.min: + return False + + if resources.shm_size is not None: + if instance_resources.shm_size is None: + return False + + if resources.shm_size > instance_resources.shm_size: + return False + + if resources.disk is not None: + if instance_resources.disk is None: + return False + if resources.disk.size.min > instance_resources.disk.size.min: + return False + + return True diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 758a0e242..c734166c2 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -317,6 +317,9 @@ async def create_instance( job_provisioning_data='{"backend": "datacrunch", "instance_type": 
{"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "instance_id": "running_instance.id", "pool_id": "1b2b4c57-5851-487f-b92e-948f946dfa49", "hostname": "running_instance.ip", "region": "running_instance.location", "price": 0.1, "username": "root", "ssh_port": 22, "dockerized": true, "backend_data": null}', offer='{"backend": "datacrunch", "instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "region": "en", "price": 0.1, "availability": "available"}', resource_spec_data=resources.json(), + price=1, + region="eu-west", + backend=BackendType.DATACRUNCH, ) session.add(im) await session.commit() diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 33ae57ca8..18232487b 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -17,7 +17,13 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import AnyRunConfiguration from dstack._internal.core.models.instances import InstanceOfferWithAvailability -from dstack._internal.core.models.profiles import Profile, ProfileRetryPolicy, SpotPolicy +from dstack._internal.core.models.profiles import ( + CreationPolicy, + Profile, + ProfileRetryPolicy, + SpotPolicy, + TerminationPolicy, +) from dstack._internal.core.models.repos.base import Repo from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import JobSpec @@ -381,6 +387,10 @@ def get_plan( working_dir: Optional[str] = None, run_name: Optional[str] = None, pool_name: Optional[str] = None, + instance_name: Optional[str] = None, + creation_policy: Optional[CreationPolicy] = None, + termination_policy: Optional[TerminationPolicy] = None, + termination_policy_idle: Optional[Union[int, str]] = None, ) -> RunPlan: # """ # Get run plan. 
Same arguments as `submit` @@ -411,9 +421,10 @@ def get_plan( max_duration=max_duration, max_price=max_price, pool_name=pool_name, - creation_policy=None, - termination_idle_time=None, - termination_policy=None, + instance_name=instance_name, + creation_policy=creation_policy, + termination_policy=termination_policy, + termination_idle_time=None, # TODO: fix deserialize ) run_spec = RunSpec( run_name=run_name, diff --git a/src/dstack/api/server/_pools.py b/src/dstack/api/server/_pools.py index bb10263d0..a1d87012c 100644 --- a/src/dstack/api/server/_pools.py +++ b/src/dstack/api/server/_pools.py @@ -22,7 +22,7 @@ def delete(self, project_name: str, pool_name: str, force: bool) -> None: def create(self, project_name: str, pool_name: str) -> None: body = schemas_pools.CreatePoolRequest(name=pool_name) - self._request(f"/api/project/{project_name}/pool/remove", body=body.json()) + self._request(f"/api/project/{project_name}/pool/create", body=body.json()) def show(self, project_name: str, pool_name: str) -> List[Instance]: body = schemas_pools.ShowPoolRequest(name=pool_name) diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index e5e148ef0..7014b7884 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -235,7 +235,7 @@ async def test_job_whith_instance(self, test_db, session: AsyncSession): username="root", ssh_port=22, dockerized=False, - pool_id="", + pool_id=str(pool.id), backend_data=None, ssh_proxy=None, ) @@ -263,7 +263,7 @@ async def test_job_whith_instance(self, test_db, session: AsyncSession): assert jm.error_code == None assert ( jm.job_spec_data - == r"""{"job_num": 0, "job_name": "test-run-0", "app_specs": [], "commands": ["/bin/bash", "-i", "-c", "(echo pip install ipykernel... && pip install -q --no-cache-dir ipykernel 2> /dev/null) || echo \"no pip, ipykernel was not installed\" && echo '' && echo To open in VS Code Desktop, use link below: && echo '' && echo ' vscode://vscode-remote/ssh-remote+test-run/workflow' && echo '' && echo 'To connect via SSH, use: `ssh test-run`' && echo '' && echo -n 'To exit, press Ctrl+C.' && tail -f /dev/null"], "env": {}, "gateway": null, "home_dir": "/root", "image_name": "dstackai/base:py3.10-0.4rc4-cuda-12.1", "max_duration": 21600, "registry_auth": null, "requirements": {"resources": {"cpu": {"min": 2, "max": null}, "memory": {"min": 8.0, "max": null}, "shm_size": null, "gpu": null, "disk": null}, "max_price": null, "spot": false}, "retry_policy": {"retry": false, "limit": null}, "working_dir": ".", "pool_name": "default-pool"}""" + == r"""{"job_num": 0, "job_name": "test-run-0", "app_specs": [], "commands": ["/bin/bash", "-i", "-c", "(echo pip install ipykernel... && pip install -q --no-cache-dir ipykernel 2> /dev/null) || echo \"no pip, ipykernel was not installed\" && echo '' && echo To open in VS Code Desktop, use link below: && echo '' && echo ' vscode://vscode-remote/ssh-remote+test-run/workflow' && echo '' && echo 'To connect via SSH, use: `ssh test-run`' && echo '' && echo -n 'To exit, press Ctrl+C.' 
&& tail -f /dev/null"], "env": {}, "gateway": null, "home_dir": "/root", "image_name": "dstackai/base:py3.10-0.4rc4-cuda-12.1", "max_duration": 21600, "registry_auth": null, "requirements": {"resources": {"cpu": {"min": 2, "max": null}, "memory": {"min": 8.0, "max": null}, "shm_size": null, "gpu": null, "disk": null}, "max_price": null, "spot": false}, "retry_policy": {"retry": false, "limit": null}, "working_dir": ".", "pool_name": null}""" ) assert jm.job_provisioning_data == ( '{"backend": "datacrunch", "instance_type": {"name": "instance", "resources": ' diff --git a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py b/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py index 1c65fd656..795e5aa17 100644 --- a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py @@ -18,7 +18,7 @@ create_user, ) -MODULE = "dstack._internal.server.background.tasks.process_finished_jobs" +MODULE = "dstack._internal.server.services.jobs" class TestProcessFinishedJobs: @@ -63,11 +63,11 @@ async def test_transitions_done_jobs_marked_as_removed(self, test_db, session: A dockerized=False, pool_id="", backend_data=None, + ssh_proxy=None, ), ) with patch(f"{MODULE}.terminate_job_provisioning_data_instance") as terminate: await process_finished_jobs() - terminate.assert_called_once() await session.refresh(job) assert job is not None assert job.removed diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 57f826c27..d3bcc0b36 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -73,6 +73,7 @@ def get_dev_env_run_plan_dict( "backends": ["local", "aws", "azure", "gcp", "lambda"], "creation_policy": None, "default": False, + "instance_name": None, "max_duration": "off", "max_price": None, "name": "string", @@ -184,6 +185,7 @@ def get_dev_env_run_dict( "backends": ["local", "aws", "azure", "gcp", "lambda"], "creation_policy": None, "default": False, + "instance_name": None, "max_duration": "off", "max_price": None, "name": "string", @@ -266,7 +268,7 @@ def get_dev_env_run_dict( "error_code": None, "job_provisioning_data": None, }, - "cost": 0, + "cost": 0.0, "service": None, } @@ -587,7 +589,6 @@ async def test_terminates_running_run(self, test_db, session: AsyncSession): await session.refresh(job) assert job.status == JobStatus.TERMINATED assert not job.removed - assert job.remove_at is not None @pytest.mark.asyncio async def test_leaves_finished_runs_unchanged(self, test_db, session: AsyncSession): diff --git a/src/tests/_internal/server/services/test_pools.py b/src/tests/_internal/server/services/test_pools.py index 280fbb53f..4c65fc91e 100644 --- a/src/tests/_internal/server/services/test_pools.py +++ b/src/tests/_internal/server/services/test_pools.py @@ -40,12 +40,22 @@ async def test_pool(session: AsyncSession, test_db): status=InstanceStatus.PENDING, job_provisioning_data="", offer="", + region="", + price=1, + backend=BackendType.LOCAL, ) session.add(im) await session.commit() + await session.refresh(pool) core_model_pool = services_pools.pool_model_to_pool(pool) - assert core_model_pool == Pool(name="test_pool", default=True, created_at=pool.created_at) + assert core_model_pool == Pool( + name="test_pool", + default=True, + created_at=pool.created_at.replace(tzinfo=dt.timezone.utc), # ??? 
+ total_instances=1, + available_instances=0, + ) list_pools = await services_pools.list_project_pool(session=session, project=project) assert list_pools == [services_pools.pool_model_to_pool(pool)] @@ -116,6 +126,9 @@ async def test_show_pool(session: AsyncSession, test_db): status=InstanceStatus.PENDING, job_provisioning_data='{"pool_id":"123", "backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', offer='{"price":"LOCAL", "price":1.0, "backend":"local", "region":"eu-west-1", "availability":"available","instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + region="eu-west", + price=1, + backend=BackendType.LOCAL, ) session.add(im) await session.commit() @@ -137,6 +150,9 @@ async def test_get_pool_instances(session: AsyncSession, test_db): status=InstanceStatus.PENDING, job_provisioning_data='{"backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', offer='{"price":"LOCAL", "price":1.0, "backend":"local", "region":"eu-west-1", "availability":"available","instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + region="eu-west", + price=1, + backend=BackendType.LOCAL, ) session.add(im) await session.commit() @@ -162,6 +178,9 @@ async def test_generate_instance_name(session: AsyncSession, test_db): status=InstanceStatus.PENDING, job_provisioning_data="", offer="", + backend=BackendType.REMOTE, + region="", + price=0, ) session.add(im) await session.commit() From d575d4a71ca0969ebc15e669eb5155d51ad5013b Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Mon, 5 Feb 2024 13:26:43 +0300 Subject: [PATCH 13/47] fixup! 
improve dstack run --- src/dstack/_internal/cli/commands/pool.py | 4 ++-- src/dstack/_internal/cli/utils/run.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 95e70bd27..3cfa0020f 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -152,10 +152,10 @@ def th(s: str) -> str: "yes" if r.spot else "no", f"${offer.price:g}", availability, - style=None if i == 1 else "grey58", + style=None if i == 1 else colors["secondary"], ) if len(print_offers) > offers_limit: - offers_table.add_row("", "...", style="grey58") + offers_table.add_row("", "...", style=colors["secondary"]) console.print(props) console.print() diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index c2509be70..8369394e1 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -2,7 +2,7 @@ from rich.table import Table -from dstack._internal.cli.utils.common import console +from dstack._internal.cli.utils.common import colors, console from dstack._internal.core.models.instances import InstanceAvailability, InstanceType from dstack._internal.core.models.runs import RunPlan from dstack._internal.utils.common import pretty_date @@ -78,10 +78,10 @@ def th(s: str) -> str: "yes" if r.spot else "no", f"${offer.price:g}", availability, - style=None if i == 1 else "grey58", + style=None if i == 1 else colors["secondary"], ) if job_plan.total_offers > len(job_plan.offers): - offers.add_row("", "...", style="secondary") + offers.add_row("", "...", style=colors["secondary"]) console.print(props) console.print() From 89160da751bfa4f5ddb28058d8ca267361abdc51 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Mon, 5 Feb 2024 13:48:53 +0100 Subject: [PATCH 14/47] Refactor check_relevance to use gpuhunt --- .../tasks/process_submitted_jobs.py | 16 ++-- src/dstack/_internal/server/services/pools.py | 63 +++++++++++++- src/dstack/_internal/server/services/runs.py | 87 ++----------------- 3 files changed, 75 insertions(+), 91 deletions(-) diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index ba959dff9..7f8e68f09 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -30,8 +30,12 @@ SUBMITTED_PROCESSING_JOBS_LOCK, ) from dstack._internal.server.services.logging import job_log -from dstack._internal.server.services.pools import get_pool_instances, list_project_pool_models -from dstack._internal.server.services.runs import check_relevance, run_model_to_run +from dstack._internal.server.services.pools import ( + filter_pool_instances, + get_pool_instances, + list_project_pool_models, +) +from dstack._internal.server.services.runs import run_model_to_run from dstack._internal.server.utils.common import run_async from dstack._internal.utils import common as common_utils from dstack._internal.utils.logging import get_logger @@ -113,11 +117,9 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): # pool capacity pool_instances = await get_pool_instances(session, project_model, run_pool) - available_instanses = [p for p in pool_instances if p.status == InstanceStatus.READY] - relevant_instances: List[InstanceModel] = [] - for instance in available_instanses: - if 
check_relevance(profile, run_spec.configuration.resources, instance): - relevant_instances.append(instance) + relevant_instances = filter_pool_instances( + pool_instances, profile, run_spec.configuration.resources, status=InstanceStatus.READY + ) logger.info(*job_log(f"num relevance {len(relevant_instances)}", job_model)) if relevant_instances: diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index b75dda4b1..bb0095327 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -2,23 +2,30 @@ from datetime import timezone from typing import Dict, List, Optional, Sequence +import gpuhunt from pydantic import parse_raw_as from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload +from dstack._internal.core.backends.base.offers import ( + offer_to_catalog_item, + requirements_to_query_filter, +) from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( Gpu, InstanceAvailability, + InstanceOffer, InstanceOfferWithAvailability, InstanceType, Resources, ) from dstack._internal.core.models.pools import Instance, Pool -from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile, SpotPolicy from dstack._internal.core.models.resources import ResourcesSpec -from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData +from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData, Requirements +from dstack._internal.server import settings from dstack._internal.server.models import InstanceModel, PoolModel, ProjectModel from dstack._internal.utils import random_names from dstack._internal.utils.common import get_current_datetime @@ -174,8 +181,12 @@ async def list_deleted_pools( def instance_model_to_instance(instance_model: InstanceModel) -> Instance: - offer: InstanceOfferWithAvailability = parse_raw_as(InstanceOfferWithAvailability, instance_model.offer) # type: ignore[operator] - jpd: JobProvisioningData = parse_raw_as(JobProvisioningData, instance_model.job_provisioning_data) # type: ignore[operator] + offer: InstanceOfferWithAvailability = parse_raw_as( + InstanceOfferWithAvailability, instance_model.offer + ) + jpd: JobProvisioningData = parse_raw_as( + JobProvisioningData, instance_model.job_provisioning_data + ) instance = Instance( backend=offer.backend, @@ -336,3 +347,47 @@ async def add_remote( await session.commit() return True + + +def filter_pool_instances( + pool_instances: List[InstanceModel], + profile: Profile, + resources: ResourcesSpec, + *, + status: Optional[InstanceStatus] = None, +) -> List[InstanceModel]: + """ + Filter instances by `instance_name`, `backends`, `resources`, `spot_policy`, `max_price`, `status` + """ + instances: List[InstanceModel] = [] + candidates: List[InstanceModel] = [] + for instance in pool_instances: + if profile.instance_name is not None and instance.name != profile.instance_name: + continue + if status is not None and instance.status != status: + continue + + # TODO: remove on prod + if settings.LOCAL_BACKEND_ENABLED and instance.backend == BackendType.LOCAL: + instances.append(instance) + continue + + if profile.backends is not None and instance.backend not in profile.backends: + continue + candidates.append(instance) + + requirements = Requirements( + resources=resources, + 
max_price=profile.max_price, + spot={ + SpotPolicy.AUTO: None, + SpotPolicy.SPOT: True, + SpotPolicy.ONDEMAND: False, + }[profile.spot_policy], + ) + query_filter = requirements_to_query_filter(requirements) + for instance in candidates: + catalog_item = offer_to_catalog_item(parse_raw_as(InstanceOffer, instance.offer)) + if gpuhunt.matches(catalog_item, query_filter): + instances.append(instance) + return instances diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 2c13fbd3b..d31c66111 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -19,13 +19,11 @@ SSHKeys, ) from dstack._internal.core.errors import BackendError, RepoDoesNotExistError, ServerClientError -from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( InstanceOfferWithAvailability, LaunchedInstanceInfo, ) from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, CreationPolicy, Profile -from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import ( InstanceStatus, Job, @@ -65,11 +63,10 @@ ) from dstack._internal.server.services.pools import ( create_pool_model, + filter_pool_instances, get_pool_instances, - instance_model_to_instance, ) from dstack._internal.server.services.projects import list_project_models, list_user_project_models -from dstack._internal.server.settings import LOCAL_BACKEND_ENABLED from dstack._internal.server.utils.common import run_async from dstack._internal.utils.logging import get_logger from dstack._internal.utils.random_names import generate_name @@ -299,23 +296,12 @@ async def get_run_plan( pool_offers = [] if run_spec.profile.creation_policy == CreationPolicy.REUSE: - requirements = Requirements( - resources=run_spec.configuration.resources, - max_price=run_spec.profile.max_price, - spot=None, - ) - if run_spec.profile.instance_name is not None: - for instance in pool_instances: - if instance.name == run_spec.profile.instance_name and check_relevance( - run_spec.profile, requirements.resources, instance - ): - offer = pydantic.parse_raw_as(InstanceOfferWithAvailability, instance.offer) - pool_offers.append(offer) - else: - for instance in pool_instances: - if check_relevance(run_spec.profile, requirements.resources, instance): - offer = pydantic.parse_raw_as(InstanceOfferWithAvailability, instance.offer) - pool_offers.append(offer) + for instance in filter_pool_instances( + pool_instances, run_spec.profile, run_spec.configuration.resources + ): + pool_offers.append( + pydantic.parse_raw_as(InstanceOfferWithAvailability, instance.offer) + ) backends = await backends_services.get_project_backends(project=project) if run_spec.profile.backends is not None: @@ -631,62 +617,3 @@ async def abort_runs_of_pool(session: AsyncSession, project_model: ProjectModel, active_run_names.append(run.run_spec.run_name) await stop_runs(session, project_model, active_run_names, abort=True) - - -def check_relevance( - profile: Profile, resources: ResourcesSpec, instance_model: InstanceModel -) -> bool: - - jpd: JobProvisioningData = pydantic.parse_raw_as( - JobProvisioningData, instance_model.job_provisioning_data - ) - - # TODO: remove on prod - if LOCAL_BACKEND_ENABLED and jpd.backend == BackendType.LOCAL: - return True - - instance = instance_model_to_instance(instance_model) - - if profile.backends is not None and instance.backend not in profile.backends: - logger.warning(f"no 
backend select ") - return False - - # use gpuhunt - - instance_resources: ResourcesSpec = pydantic.parse_raw_as( - ResourcesSpec, instance_model.resource_spec_data - ) - - if resources.cpu.min > instance_resources.cpu.min: - return False - - if resources.gpu is not None: - - if instance_resources.gpu is None: - return False - - if resources.gpu.count.min > instance_resources.gpu.count.min: - return False - - if resources.gpu.memory.min > instance_resources.gpu.memory.min: - return False - - # TODO: compare GPU names - - if resources.memory.min > instance_resources.memory.min: - return False - - if resources.shm_size is not None: - if instance_resources.shm_size is None: - return False - - if resources.shm_size > instance_resources.shm_size: - return False - - if resources.disk is not None: - if instance_resources.disk is None: - return False - if resources.disk.size.min > instance_resources.disk.size.min: - return False - - return True From d7a32b115b313848f58ed0f853a96926305fe411 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Mon, 5 Feb 2024 13:51:14 +0100 Subject: [PATCH 15/47] Fix tests discovery for Python 3.8 --- src/dstack/_internal/cli/commands/pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 95e70bd27..af026e511 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -1,6 +1,6 @@ import argparse -from collections.abc import Sequence from pathlib import Path +from typing import Sequence from rich.table import Table From 2b6e24ffad43d8c04efd355584797823068ebdfa Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Mon, 5 Feb 2024 16:11:33 +0300 Subject: [PATCH 16/47] Improve profile.pool_name handling --- src/dstack/_internal/cli/commands/pool.py | 16 ++++++------ src/dstack/_internal/cli/commands/run.py | 2 +- src/dstack/_internal/server/routers/runs.py | 15 ++++++++--- src/dstack/_internal/server/services/pools.py | 25 +++++++++++++++++++ src/dstack/_internal/server/services/runs.py | 19 +++----------- src/dstack/api/_public/runs.py | 4 +-- src/dstack/api/server/_runs.py | 6 ++--- 7 files changed, 56 insertions(+), 31 deletions(-) diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 3cfa0020f..81b242e3a 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -16,7 +16,7 @@ InstanceOfferWithAvailability, ) from dstack._internal.core.models.pools import Instance, Pool -from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile, SpotPolicy +from dstack._internal.core.models.profiles import Profile, SpotPolicy from dstack._internal.core.models.resources import DEFAULT_CPU_COUNT, DEFAULT_MEMORY_SIZE from dstack._internal.core.models.runs import Requirements from dstack._internal.core.services.configs import ConfigManager @@ -265,7 +265,9 @@ def _register(self) -> None: "add", help="Add instance to pool", formatter_class=self._parser.formatter_class ) add_parser.add_argument( - "--pool", dest="pool_name", help="The name of the pool", required=True + "--pool", + dest="pool_name", + help="The name of the pool. 
If not set, the default pool will be used", ) add_parser.add_argument( "-y", "--yes", help="Don't ask for confirmation", action="store_true" @@ -293,7 +295,9 @@ def _register(self) -> None: formatter_class=self._parser.formatter_class, ) remove_parser.add_argument( - "--pool", dest="pool_name", help="The name of the pool", required=True + "--pool", + dest="pool_name", + help="The name of the pool. If not set, the default pool will be used", ) remove_parser.add_argument( "--name", dest="instance_name", help="The name of the instance", required=True @@ -339,8 +343,6 @@ def _add(self, args: argparse.Namespace) -> None: super()._command(args) - pool_name: str = DEFAULT_POOL_NAME if args.pool_name is None else args.pool_name - resources = Resources( cpu=args.cpu, memory=args.memory, @@ -356,7 +358,7 @@ def _add(self, args: argparse.Namespace) -> None: profile = load_profile(Path.cwd(), args.profile) apply_profile_args(args, profile) - profile.pool_name = pool_name + profile.pool_name = args.pool_name # Add remote instance if args.remote: @@ -378,7 +380,7 @@ def _add(self, args: argparse.Namespace) -> None: self.api.ssh_identity_file = ConfigManager().get_repo_config(repo.repo_dir).ssh_key_path with console.status("Getting instances..."): - offers = self.api.runs.get_offers(profile, requirements) + pool_name, offers = self.api.runs.get_offers(profile, requirements) print_offers_table(pool_name, profile, requirements, offers) if not args.yes and not confirm_ask("Continue?"): diff --git a/src/dstack/_internal/cli/commands/run.py b/src/dstack/_internal/cli/commands/run.py index 62b69b977..54ad6c424 100644 --- a/src/dstack/_internal/cli/commands/run.py +++ b/src/dstack/_internal/cli/commands/run.py @@ -87,7 +87,7 @@ def _register(self): self._parser.add_argument( "--pool", dest="pool_name", - help="The name of the pool", + help="The name of the pool. 
If not set, the default pool will be used", ) self._parser.add_argument( "--reuse", diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index 67492f8fe..556696ec6 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -20,7 +20,10 @@ ) from dstack._internal.server.security.permissions import Authenticated, ProjectMember from dstack._internal.server.services import runs -from dstack._internal.server.services.pools import generate_instance_name +from dstack._internal.server.services.pools import ( + generate_instance_name, + get_or_create_default_pool_by_name, +) root_router = APIRouter( prefix="/api/runs", @@ -68,11 +71,17 @@ async def get_offers( body: GetOffersRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> List[InstanceOfferWithAvailability]: +) -> Tuple[str, List[InstanceOfferWithAvailability]]: _, project = user_project + + active_pool = await get_or_create_default_pool_by_name( + session, project, body.profile.pool_name + ) + offers = await runs.get_run_plan_by_requirements(project, body.profile, body.requirements) instances = [instance for _, instance in offers] - return instances + + return active_pool.name, instances @project_router.post("/create_instance") diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index b75dda4b1..27f4d1149 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -50,6 +50,29 @@ async def get_pool( return pool +async def get_or_create_default_pool_by_name( + session: AsyncSession, project: ProjectModel, pool_name: Optional[str] +) -> PoolModel: + active_pool = None + if pool_name is None: + default_pool = None + pools = [ + pool + for pool in (await list_project_pool_models(session, project)) + if project.default_pool == pool + ] + if pools: + default_pool = pools[0] + if not default_pool: + default_pool = await create_pool_model(session, project, DEFAULT_POOL_NAME) + active_pool = default_pool + else: + active_pool = await get_pool(session, project, pool_name) + if active_pool is None: + active_pool = await create_pool_model(session, project, DEFAULT_POOL_NAME) + return active_pool + + def pool_model_to_pool(pool_model: PoolModel) -> Pool: total = len(pool_model.instances) available = sum(instance.status.is_available() for instance in pool_model.instances) @@ -269,6 +292,8 @@ async def add_remote( port: str, ) -> bool: + pool_model = await get_or_create_default_pool_by_name(session, project, profile.pool_name) + pool_name = profile.pool_name if instance_name is None: instance_name = await generate_instance_name(session, project, pool_name) diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 2c13fbd3b..4823e3bcd 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -24,7 +24,7 @@ InstanceOfferWithAvailability, LaunchedInstanceInfo, ) -from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, CreationPolicy, Profile +from dstack._internal.core.models.profiles import CreationPolicy, Profile from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import ( InstanceStatus, @@ -45,7 +45,6 @@ from dstack._internal.server.models import ( InstanceModel, JobModel, - PoolModel, ProjectModel, 
RunModel, UserModel, @@ -65,6 +64,7 @@ ) from dstack._internal.server.services.pools import ( create_pool_model, + get_or_create_default_pool_by_name, get_pool_instances, instance_model_to_instance, ) @@ -384,18 +384,7 @@ async def submit_run( else: await delete_runs(session=session, project=project, runs_names=[run_spec.run_name]) - pool_name = ( - DEFAULT_POOL_NAME if run_spec.profile.pool_name is None else run_spec.profile.pool_name - ) - - # create pool - pools = ( - await session.scalars( - select(PoolModel).where(PoolModel.name == pool_name, PoolModel.deleted == False) - ) - ).all() - if not pools: - await create_pool_model(session, project, pool_name) + pool = await get_or_create_default_pool_by_name(session, project, run_spec.profile.pool_name) run_model = RunModel( id=uuid.uuid4(), @@ -414,7 +403,7 @@ async def submit_run( await gateways.register_service_jobs(session, project, run_spec.run_name, jobs) for job in jobs: - job.job_spec.pool_name = pool_name + job.job_spec.pool_name = pool.name job_model = create_job_model_for_new_submission( run_model=run_model, job=job, diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 18232487b..0a0a3c537 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -7,7 +7,7 @@ from copy import copy from datetime import datetime from pathlib import Path -from typing import Dict, Iterable, List, Optional, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union import requests from websocket import WebSocketApp @@ -367,7 +367,7 @@ def submit( def get_offers( self, profile: Profile, requirements: Requirements - ) -> List[InstanceOfferWithAvailability]: + ) -> Tuple[str, List[InstanceOfferWithAvailability]]: return self._api_client.runs.get_offers(self._project, profile, requirements) def create_instance(self, pool_name: str, profile: Profile, requirements: Requirements): diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index da24b8d0f..dc0224db1 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Tuple from pydantic import parse_obj_as @@ -31,10 +31,10 @@ def get(self, project_name: str, run_name: str) -> Run: def get_offers( self, project_name: str, profile: Profile, requirements: Requirements - ) -> List[InstanceOfferWithAvailability]: + ) -> Tuple[str, List[InstanceOfferWithAvailability]]: body = GetOffersRequest(profile=profile, requirements=requirements) resp = self._request(f"/api/project/{project_name}/runs/get_offers", body=body.json()) - return parse_obj_as(List[InstanceOfferWithAvailability], resp.json()) + return parse_obj_as(Tuple[str, List[InstanceOfferWithAvailability]], resp.json()) def create_instance( self, project_name: str, pool_name: str, profile: Profile, requirements: Requirements From cae383a78f8affb43fe8958a8275a476487634d1 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Mon, 5 Feb 2024 14:29:26 +0100 Subject: [PATCH 17/47] Fix spot policy mapping --- src/dstack/_internal/server/services/pools.py | 1 + .../server/background/tasks/test_process_submitted_jobs.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index bb0095327..f081e76c1 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -380,6 +380,7 @@ def filter_pool_instances( 
resources=resources, max_price=profile.max_price, spot={ + None: None, SpotPolicy.AUTO: None, SpotPolicy.SPOT: True, SpotPolicy.ONDEMAND: False, diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index 7014b7884..16310057a 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -199,7 +199,7 @@ async def test_transitions_job_with_outdated_retry_to_failed_on_no_capacity( assert not project.default_pool.instances @pytest.mark.asyncio - async def test_job_whith_instance(self, test_db, session: AsyncSession): + async def test_job_with_instance(self, test_db, session: AsyncSession): project = await create_project(session) user = await create_user(session) repo = await create_repo( From 58dab52b7d227b7410433337b6c2f4105320913b Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Mon, 5 Feb 2024 16:34:58 +0300 Subject: [PATCH 18/47] fixup! Improve profile.pool_name handling --- src/dstack/_internal/server/services/pools.py | 21 ++++--------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index d35fd9340..6d2681525 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -305,22 +305,9 @@ async def add_remote( pool_model = await get_or_create_default_pool_by_name(session, project, profile.pool_name) - pool_name = profile.pool_name + profile.pool_name = pool_model.name if instance_name is None: - instance_name = await generate_instance_name(session, project, pool_name) - - pool = ( - await session.scalars( - select(PoolModel).where( - PoolModel.name == pool_name, - PoolModel.project_id == project.id, - PoolModel.deleted == False, - ) - ) - ).one_or_none() - - if pool is None: - pool = await create_pool_model(session, project, pool_name) + instance_name = await generate_instance_name(session, project, profile.pool_name) gpus = [] if resources.gpu is not None: @@ -343,7 +330,7 @@ async def add_remote( ssh_port=22, dockerized=False, backend_data="", - pool_id=str(pool.id), + pool_id=str(pool_model.id), ssh_proxy=None, ) offer = InstanceOfferWithAvailability( @@ -360,7 +347,7 @@ async def add_remote( im = InstanceModel( name=instance_name, project=project, - pool=pool, + pool=pool_model, status=InstanceStatus.PENDING, job_provisioning_data=local.json(), offer=offer.json(), From e047d4e6c28888b86843a7164147d59f10bf9459 Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Tue, 6 Feb 2024 13:16:46 +0300 Subject: [PATCH 19/47] small fixies --- .pre-commit-config.yaml | 7 +-- src/dstack/_internal/cli/commands/pool.py | 10 +++- src/dstack/_internal/cli/commands/run.py | 16 ++---- .../core/backends/kubernetes/compute.py | 8 +-- src/dstack/_internal/core/models/profiles.py | 8 +-- .../_internal/server/background/__init__.py | 6 ++- .../background/tasks/process_finished_jobs.py | 1 + .../server/background/tasks/process_pools.py | 39 +++++++++++--- .../background/tasks/process_running_jobs.py | 2 + .../tasks/process_submitted_jobs.py | 7 ++- ...add_pools.py => b55bd09bf186_add_pools.py} | 7 +-- src/dstack/_internal/server/models.py | 2 + src/dstack/_internal/server/services/pools.py | 4 ++ .../_internal/server/services/runner/ssh.py | 1 + src/dstack/_internal/server/services/runs.py | 54 +++++++++++++------ 
src/dstack/api/_public/runs.py | 4 +- .../tasks/test_process_submitted_jobs.py | 13 ++--- 17 files changed, 121 insertions(+), 68 deletions(-) rename src/dstack/_internal/server/migrations/versions/{dad000707a2c_add_pools.py => b55bd09bf186_add_pools.py} (96%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f155d038e..addc646d2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.2.0 + rev: v0.2.1 hooks: - - id: ruff - name: ruff autofix F401 - args: ['--fix', '--select', 'F401'] - files: 'process_.*\.py|runs?\.py|pools?\.py' - exclude: 'versions|src/tests' - id: ruff name: ruff common args: ['--fix'] diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 0c7ad55b0..bd90c525e 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -16,7 +16,7 @@ InstanceOfferWithAvailability, ) from dstack._internal.core.models.pools import Instance, Pool -from dstack._internal.core.models.profiles import Profile, SpotPolicy +from dstack._internal.core.models.profiles import Profile, SpotPolicy, TerminationPolicy from dstack._internal.core.models.resources import DEFAULT_CPU_COUNT, DEFAULT_MEMORY_SIZE from dstack._internal.core.models.runs import Requirements from dstack._internal.core.services.configs import ConfigManager @@ -358,6 +358,12 @@ def _add(self, args: argparse.Namespace) -> None: apply_profile_args(args, profile) profile.pool_name = args.pool_name + # TODO: add full support + termination_policy_idle = 5 * 60 # 5 minutes by default + termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE + profile.termination_idle_time = str(termination_policy_idle) + profile.termination_policy = termination_policy + # Add remote instance if args.remote: result = self.api.client.pool.add_remote( @@ -386,7 +392,7 @@ def _add(self, args: argparse.Namespace) -> None: return try: - with console.status("Submitting instance..."): + with console.status("Creating instance..."): self.api.runs.create_instance(pool_name, profile, requirements) except ServerClientError as e: raise CLIError(e.msg) diff --git a/src/dstack/_internal/cli/commands/run.py b/src/dstack/_internal/cli/commands/run.py index 54ad6c424..f33d9453a 100644 --- a/src/dstack/_internal/cli/commands/run.py +++ b/src/dstack/_internal/cli/commands/run.py @@ -17,11 +17,7 @@ from dstack._internal.cli.utils.run import print_run_plan from dstack._internal.core.errors import CLIError, ConfigurationError, ServerClientError from dstack._internal.core.models.configurations import ConfigurationType -from dstack._internal.core.models.profiles import ( - DEFAULT_POOL_NAME, - CreationPolicy, - TerminationPolicy, -) +from dstack._internal.core.models.profiles import CreationPolicy, TerminationPolicy from dstack._internal.core.models.runs import JobErrorCode from dstack._internal.core.services.configs import ConfigManager from dstack._internal.utils.common import parse_pretty_duration @@ -118,7 +114,7 @@ def _command(self, args: argparse.Namespace): self._parser.print_help() return - termination_policy_idle = 5 * 60 + termination_policy_idle = 5 * 60 # 5 minutes by default termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE if args.idle_duration is not None: @@ -135,14 +131,12 @@ def _command(self, args: argparse.Namespace): console.print( f'[{colors["warning"]}]If the flag --reuse is set, the argument --idle-duration will be skipped[/]' ) - 
termination_policy_idle = None termination_policy = TerminationPolicy.DONT_DESTROY if args.instance_name is not None and termination_policy_idle is not None: console.print( f'[{colors["warning"]}]--idle-duration won\'t be applied to the instance {args.instance_name!r}[/]' ) - termination_policy_idle = None termination_policy = TerminationPolicy.DONT_DESTROY super()._command(args) @@ -165,8 +159,6 @@ def _command(self, args: argparse.Namespace): known, unknown = parser.parse_known_args(args.unknown) configurator.apply(known, unknown, conf) - pool_name = DEFAULT_POOL_NAME if args.pool_name is None else args.pool_name - with console.status("Getting run plan..."): run_plan = self.api.runs.get_plan( configuration=conf, @@ -179,11 +171,11 @@ def _command(self, args: argparse.Namespace): max_price=profile.max_price, working_dir=args.working_dir, run_name=args.run_name, - pool_name=pool_name, + pool_name=args.pool_name, instance_name=args.instance_name, creation_policy=creation_policy, termination_policy=termination_policy, - termination_policy_idle=f"{termination_policy_idle}s", + termination_policy_idle=termination_policy_idle, ) except ConfigurationError as e: raise CLIError(str(e)) diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index b21d9b4b7..2007a7d6b 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -184,13 +184,7 @@ def terminate_instance( if e.status != 404: raise - def create_gateway( - self, - instance_name: str, - ssh_key_pub: str, - region: str, - project_id: str, - ) -> LaunchedGatewayInfo: + def create_gateway(self, instance_name: str, ssh_key_pub: str, region: str, project_id: str): # Gateway creation is currently limited to Kubernetes with Load Balancer support. # If the cluster does not support Load Balancer, the service will be provisioned but # the external IP/hostname will never be allocated. 
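
The idle-termination flow wired up in this patch reduces to one rule: an instance whose last job finished more than termination_idle_time seconds ago (five minutes unless --idle-duration overrides it) is terminated by the background task. Below is a minimal, standard-library-only sketch of that check; the names are simplified for illustration and this is not the patch's actual code.

import datetime
from typing import Optional

DEFAULT_IDLE_SECONDS = 5 * 60  # the "5 minutes by default" used in run.py


def should_terminate(
    last_job_processed_at: Optional[datetime.datetime],
    termination_idle_seconds: int,
    now: Optional[datetime.datetime] = None,
) -> bool:
    """Return True if the instance has been idle longer than its idle timeout."""
    if last_job_processed_at is None:
        # The instance has not finished a job yet, so there is no idle period to measure.
        return False
    now = now or datetime.datetime.now(datetime.timezone.utc)
    idle_for = now - last_job_processed_at.replace(tzinfo=datetime.timezone.utc)
    return idle_for > datetime.timedelta(seconds=termination_idle_seconds)


if __name__ == "__main__":
    finished = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(minutes=7)
    print(should_terminate(finished, DEFAULT_IDLE_SECONDS))  # True: ~7 min idle > 5 min
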
diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index 3a29d5c31..e36bafeab 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -124,9 +124,11 @@ class Profile(ForbidExtra): _validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)( parse_max_duration ) - _validate_termination_idle_time = validator( - "termination_idle_time", pre=True, allow_reuse=True - )(parse_max_duration) + + # TODO: fix deserialization + # _validate_termination_idle_time = validator( + # "termination_idle_time", pre=True, allow_reuse=True + # )(parse_max_duration) class ProfilesConfig(ForbidExtra): diff --git a/src/dstack/_internal/server/background/__init__.py b/src/dstack/_internal/server/background/__init__.py index 540f8df20..770dd53fa 100644 --- a/src/dstack/_internal/server/background/__init__.py +++ b/src/dstack/_internal/server/background/__init__.py @@ -3,7 +3,10 @@ from dstack._internal.server.background.tasks.process_finished_jobs import process_finished_jobs from dstack._internal.server.background.tasks.process_pending_jobs import process_pending_jobs -from dstack._internal.server.background.tasks.process_pools import process_pools +from dstack._internal.server.background.tasks.process_pools import ( + process_pools, + terminate_idle_instance, +) from dstack._internal.server.background.tasks.process_running_jobs import process_running_jobs from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs @@ -20,5 +23,6 @@ def start_background_tasks() -> AsyncIOScheduler: _scheduler.add_job(process_finished_jobs, IntervalTrigger(seconds=2)) _scheduler.add_job(process_pending_jobs, IntervalTrigger(seconds=10)) _scheduler.add_job(process_pools, IntervalTrigger(seconds=10)) + _scheduler.add_job(terminate_idle_instance, IntervalTrigger(seconds=10)) _scheduler.start() return _scheduler diff --git a/src/dstack/_internal/server/background/tasks/process_finished_jobs.py b/src/dstack/_internal/server/background/tasks/process_finished_jobs.py index a693bf629..878dea75c 100644 --- a/src/dstack/_internal/server/background/tasks/process_finished_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_finished_jobs.py @@ -101,6 +101,7 @@ async def _process_job(job_id): if job_model.instance is not None: job_model.used_instance_id = job_model.instance.id job_model.instance.status = InstanceStatus.READY + job_model.instance.last_job_processed_at = get_current_datetime() job_model.instance = None logger.info(*job_log("marked as removed", job_model)) except Exception as e: diff --git a/src/dstack/_internal/server/background/tasks/process_pools.py b/src/dstack/_internal/server/background/tasks/process_pools.py index 06d062316..ef12a3ded 100644 --- a/src/dstack/_internal/server/background/tasks/process_pools.py +++ b/src/dstack/_internal/server/background/tasks/process_pools.py @@ -99,6 +99,9 @@ async def terminate(instance_id: UUID) -> None: .options(joinedload(InstanceModel.project)) ) ).one() + + # TODO: need lock + jpd = parse_raw_as(JobProvisioningData, instance.job_provisioning_data) # type: ignore[operator] BACKEND_TYPE = jpd.backend backends = await backends_services.get_project_backends(project=instance.project) @@ -110,8 +113,17 @@ async def terminate(instance_id: UUID) -> None: backend.compute().terminate_instance, jpd.instance_id, jpd.region, jpd.backend_data ) + instance.deleted = True + instance.deleted_at = get_current_datetime() + 
instance.finished_at = get_current_datetime() + instance.status = InstanceStatus.TERMINATED + + logger.info("instance %s terminated", instance.name) + + await session.commit() -async def _terminate_old_instance() -> None: + +async def terminate_idle_instance() -> None: async with get_session_ctx() as session: res = await session.execute( select(InstanceModel) @@ -120,25 +132,40 @@ async def _terminate_old_instance() -> None: InstanceModel.deleted == False, InstanceModel.job == None, # noqa: E711 ) - .options() + .options(joinedload(InstanceModel.project)) ) instances = res.scalars().all() + # TODO: need lock + for instance in instances: - if instance.finished_at is None: + if instance.last_job_processed_at is None: continue delta = datetime.timedelta( seconds=parse_pretty_duration(instance.termination_idle_time) ) - if instance.finished_at + delta > get_current_datetime(): - jpd: JobProvisioningData = parse_raw_as( # type: ignore[operator] + if instance.last_job_processed_at.replace( + tzinfo=datetime.timezone.utc + ) + delta < get_current_datetime().replace(tzinfo=datetime.timezone.utc): + jpd: JobProvisioningData = parse_raw_as( JobProvisioningData, instance.job_provisioning_data - ).backend + ) await terminate_job_provisioning_data_instance( project=instance.project, job_provisioning_data=jpd ) instance.deleted = True instance.deleted_at = get_current_datetime() + instance.finished_at = get_current_datetime() + instance.status = InstanceStatus.TERMINATED + + idle_time = get_current_datetime().replace( + tzinfo=datetime.timezone.utc + ) - instance.last_job_processed_at.replace(tzinfo=datetime.timezone.utc) + logger.info( + "instance %s terminated by termination policy: idle time %ss", + instance.name, + str(idle_time.seconds), + ) await session.commit() diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 6b25f7e59..da4e36cb6 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -156,6 +156,7 @@ async def _process_job(job_id: UUID): job_model.status = JobStatus.FAILED job_model.error_code = JobErrorCode.WAITING_RUNNER_LIMIT_EXCEEDED job_model.used_instance_id = job_model.instance.id + job_model.instance.last_job_processed_at = common_utils.get_current_datetime() job_model.instance = None else: # fails are not acceptable @@ -205,6 +206,7 @@ async def _process_job(job_id: UUID): job_model.status = JobStatus.FAILED job_model.error_code = JobErrorCode.INTERRUPTED_BY_NO_CAPACITY job_model.used_instance_id = job_model.instance.id + job_model.instance.last_job_processed_at = common_utils.get_current_datetime() job_model.instance = None if job.is_retry_active(): if job_submission.job_provisioning_data.instance_type.resources.spot: diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 3f46e2115..2196cb654 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -90,6 +90,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): # check default pool pool = project_model.default_pool if pool is None: + # TODO: get_or_create_default_pool... 
pools = await list_project_pool_models(session, job_model.project) for pool_item in pools: if pool_item.id == job_model.project.default_pool_id: @@ -168,14 +169,16 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): job_model.status = JobStatus.PROVISIONING im = InstanceModel( - name=job.job_spec.job_name, + name=job.job_spec.job_name, # TODO: make new name project=project_model, pool=pool, + created_at=common_utils.get_current_datetime(), + started_at=common_utils.get_current_datetime(), status=InstanceStatus.BUSY, job_provisioning_data=job_provisioning_data.json(), offer=offer.json(), termination_policy=profile.termination_policy, - termination_idle_time="300", # TODO: fix deserailize + termination_idle_time=str(profile.termination_idle_time), job=job_model, backend=offer.backend, price=offer.price, diff --git a/src/dstack/_internal/server/migrations/versions/dad000707a2c_add_pools.py b/src/dstack/_internal/server/migrations/versions/b55bd09bf186_add_pools.py similarity index 96% rename from src/dstack/_internal/server/migrations/versions/dad000707a2c_add_pools.py rename to src/dstack/_internal/server/migrations/versions/b55bd09bf186_add_pools.py index 405fc2d37..73c107793 100644 --- a/src/dstack/_internal/server/migrations/versions/dad000707a2c_add_pools.py +++ b/src/dstack/_internal/server/migrations/versions/b55bd09bf186_add_pools.py @@ -1,8 +1,8 @@ """add pools -Revision ID: dad000707a2c +Revision ID: b55bd09bf186 Revises: d3e8af4786fa -Create Date: 2024-02-05 07:42:58.102664 +Create Date: 2024-02-06 08:44:44.235928 """ import sqlalchemy as sa @@ -10,7 +10,7 @@ from alembic import op # revision identifiers, used by Alembic. -revision = "dad000707a2c" +revision = "b55bd09bf186" down_revision = "d3e8af4786fa" branch_labels = None depends_on = None @@ -93,6 +93,7 @@ def upgrade() -> None: sa.Column("offer", sa.String(length=4000), nullable=False), sa.Column("resource_spec_data", sa.String(length=4000), nullable=True), sa.Column("job_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True), + sa.Column("last_job_processed_at", sa.DateTime(), nullable=True), sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], name=op.f("fk_instances_job_id_jobs")), sa.ForeignKeyConstraint( ["pool_id"], ["pools.id"], name=op.f("fk_instances_pool_id_pools") diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index a603ae05c..d0affc2db 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -280,6 +280,7 @@ class InstanceModel(BaseModel): status: Mapped[InstanceStatus] = mapped_column(Enum(InstanceStatus)) status_message: Mapped[Optional[str]] = mapped_column(String(50)) + # VM started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) @@ -300,6 +301,7 @@ class InstanceModel(BaseModel): # current job job_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("jobs.id")) job: Mapped[Optional["JobModel"]] = relationship(back_populates="instance") + last_job_processed_at: Mapped[Optional[datetime]] = mapped_column(DateTime) # + # job_id: Optional[FK] (current job) # ip address diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index fc4c33de6..8c75f7469 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -27,6 +27,7 @@ from 
dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData, Requirements from dstack._internal.server import settings from dstack._internal.server.models import InstanceModel, PoolModel, ProjectModel +from dstack._internal.utils import common as common_utils from dstack._internal.utils import random_names from dstack._internal.utils.common import get_current_datetime from dstack._internal.utils.logging import get_logger @@ -111,6 +112,7 @@ async def create_pool_model(session: AsyncSession, project: ProjectModel, name: session.add(pool) await session.commit() + await session.refresh(pool) return pool @@ -347,6 +349,8 @@ async def add_remote( name=instance_name, project=project, pool=pool_model, + created_at=common_utils.get_current_datetime(), + started_at=common_utils.get_current_datetime(), status=InstanceStatus.PENDING, job_provisioning_data=local.json(), offer=offer.json(), diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index 9ef754243..42bbd53d9 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -36,6 +36,7 @@ def wrapper( """ if LOCAL_BACKEND_ENABLED: + # without SSH port_map = {p: p for p in ports} return func(*args, ports=port_map, **kwargs) diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index b97ac714a..811f19668 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -23,7 +23,7 @@ InstanceOfferWithAvailability, LaunchedInstanceInfo, ) -from dstack._internal.core.models.profiles import CreationPolicy, Profile +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, CreationPolicy, Profile from dstack._internal.core.models.runs import ( InstanceStatus, Job, @@ -158,7 +158,8 @@ async def get_run_plan_by_requirements( requirements: Requirements, exclude_not_available=False, ) -> List[Tuple[Backend, InstanceOfferWithAvailability]]: - backends = await backends_services.get_project_backends(project=project) + backends: List[Backend] = await backends_services.get_project_backends(project=project) + if profile.backends is not None: backends = [b for b in backends if b.TYPE in profile.backends] @@ -258,6 +259,8 @@ async def create_instance( name=instance_name, project=project, pool=pool, + created_at=common_utils.get_current_datetime(), + started_at=common_utils.get_current_datetime(), status=InstanceStatus.STARTING, backend=backend.TYPE, region=instance_offer.region, @@ -290,30 +293,43 @@ async def get_run_plan( user: UserModel, run_spec: RunSpec, ) -> RunPlan: - pool_instances = await get_pool_instances(session, project, run_spec.profile.pool_name) - pool_offers = [] + profile = run_spec.profile - if run_spec.profile.creation_policy == CreationPolicy.REUSE: - for instance in filter_pool_instances( - pool_instances, run_spec.profile, run_spec.configuration.resources - ): - pool_offers.append( - pydantic.parse_raw_as(InstanceOfferWithAvailability, instance.offer) - ) + # TODO: get_or_create_default_pool + + pool_name = profile.pool_name + if profile.pool_name is None: + try: + pool_name = project.default_pool.name + except Exception as e: + pool_name = DEFAULT_POOL_NAME # TODO: get pool from project + + pool_instances = [ + instance + for instance in (await get_pool_instances(session, project, pool_name)) + if not instance.deleted + ] + + pool_offers: List[InstanceOfferWithAvailability] = [] + + for 
instance in filter_pool_instances( + pool_instances, profile, run_spec.configuration.resources + ): + pool_offers.append(pydantic.parse_raw_as(InstanceOfferWithAvailability, instance.offer)) backends = await backends_services.get_project_backends(project=project) - if run_spec.profile.backends is not None: - backends = [b for b in backends if b.TYPE in run_spec.profile.backends] + if profile.backends is not None: + backends = [b for b in backends if b.TYPE in profile.backends] run_name = run_spec.run_name # preserve run_name run_spec.run_name = "dry-run" # will regenerate jobs on submission jobs = get_jobs_from_run_spec(run_spec) job_plans = [] - creation_policy = run_spec.profile.creation_policy + creation_policy = profile.creation_policy for job in jobs: - job_offers = [] + job_offers: List[InstanceOfferWithAvailability] = [] job_offers.extend(pool_offers) if creation_policy is None or creation_policy == CreationPolicy.REUSE_OR_CREATE: @@ -335,6 +351,8 @@ async def get_run_plan( ) job_plans.append(job_plan) + run_spec.profile.termination_idle_time = None + run_spec.run_name = run_name # restore run_name run_plan = RunPlan( project_name=project.name, user=user.name, run_spec=run_spec, job_plans=job_plans @@ -368,6 +386,9 @@ async def submit_run( else: await delete_runs(session=session, project=project, runs_names=[run_spec.run_name]) + # TODO: fix deserialize + run_spec.profile.termination_idle_time = "300s" + pool = await get_or_create_default_pool_by_name(session, project, run_spec.profile.pool_name) run_model = RunModel( @@ -498,6 +519,9 @@ def run_model_to_run(run_model: RunModel, include_job_submissions: bool = True) run_spec = RunSpec.parse_raw(run_model.run_spec) + # TODO: fix deserialization + run_spec.profile.termination_idle_time = None + latest_job_submission = None if include_job_submissions: latest_job_submission = jobs[0].job_submissions[-1] diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 04d3e2566..3a2297ee5 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -388,7 +388,7 @@ def get_plan( instance_name: Optional[str] = None, creation_policy: Optional[CreationPolicy] = None, termination_policy: Optional[TerminationPolicy] = None, - termination_policy_idle: Optional[Union[int, str]] = None, + termination_policy_idle: Union[int, str] = 5 * 60, ) -> RunPlan: # """ # Get run plan. 
Same arguments as `submit` @@ -422,7 +422,7 @@ def get_plan( instance_name=instance_name, creation_policy=creation_policy, termination_policy=termination_policy, - termination_idle_time=None, # TODO: fix deserialize + termination_idle_time=None, # TODO: fix deserialization ) run_spec = RunSpec( run_name=run_name, diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index b1dba7f69..fe630eb29 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -17,11 +17,12 @@ from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData, JobStatus from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs from dstack._internal.server.models import JobModel -from dstack._internal.server.services.pools import list_project_pool_models +from dstack._internal.server.services.pools import ( + get_or_create_default_pool_by_name, +) from dstack._internal.server.testing.common import ( create_instance, create_job, - create_pool, create_project, create_repo, create_run, @@ -206,13 +207,7 @@ async def test_job_with_instance(self, test_db, session: AsyncSession): session, project_id=project.id, ) - pools = await list_project_pool_models(session, project) - pool = None - for pool_item in pools: - if pool_item == DEFAULT_POOL_NAME: - pool = pool_item - if pool is None: - pool = await create_pool(session, project) + pool = await get_or_create_default_pool_by_name(session, project, pool_name=None) resources = MakeResources(cpu=2, memory="12GB") await create_instance(session, project, pool, InstanceStatus.READY, resources) await session.refresh(pool) From c618305a4ffc60e26f397f1b61e6d8855b579fed Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Tue, 6 Feb 2024 16:44:16 +0100 Subject: [PATCH 20/47] Fix rich formatting, require --name --- src/dstack/_internal/cli/commands/config.py | 6 +- src/dstack/_internal/cli/commands/pool.py | 369 +++++++++--------- src/dstack/_internal/cli/commands/run.py | 6 +- src/dstack/_internal/cli/main.py | 6 +- src/dstack/_internal/cli/utils/common.py | 4 +- src/dstack/_internal/cli/utils/run.py | 6 +- .../core/services/configs/__init__.py | 8 +- 7 files changed, 197 insertions(+), 208 deletions(-) diff --git a/src/dstack/_internal/cli/commands/config.py b/src/dstack/_internal/cli/commands/config.py index c251baa54..6c98b2107 100644 --- a/src/dstack/_internal/cli/commands/config.py +++ b/src/dstack/_internal/cli/commands/config.py @@ -4,7 +4,7 @@ import dstack.api.server from dstack._internal.cli.commands import BaseCommand -from dstack._internal.cli.utils.common import colors, confirm_ask, console +from dstack._internal.cli.utils.common import confirm_ask, console from dstack._internal.core.errors import CLIError from dstack._internal.core.services.configs import ConfigManager @@ -79,6 +79,4 @@ def _command(self, args: argparse.Namespace): name=args.project, url=args.url, token=args.token, default=set_it_as_default ) config_manager.save() - console.print( - f"Configuration updated at [{colors['code']}]{config_manager.config_filepath}[/{colors['code']}]" - ) + console.print(f"Configuration updated at [code]{config_manager.config_filepath}[/]") diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index bd90c525e..b88ad2c5d 100644 --- 
a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -9,7 +9,7 @@ apply_profile_args, register_profile_args, ) -from dstack._internal.cli.utils.common import colors, confirm_ask, console +from dstack._internal.cli.utils.common import confirm_ask, console from dstack._internal.core.errors import CLIError, ServerClientError from dstack._internal.core.models.instances import ( InstanceAvailability, @@ -28,186 +28,6 @@ logger = get_logger(__name__) -def print_pool_table(pools: Sequence[Pool], verbose: bool) -> None: - table = Table(box=None) - table.add_column("NAME") - table.add_column("DEFAULT") - table.add_column("INSTANCES") - if verbose: - table.add_column("CREATED") - - sorted_pools = sorted(pools, key=lambda r: r.name) - for pool in sorted_pools: - default_mark = "default" if pool.default else "" - color = ( - colors["success"] - if pool.total_instances == pool.available_instances - else colors["error"] - ) - health = f"[{color}]{pool.available_instances}/{pool.total_instances}[/{color}]" - row = [pool.name, default_mark, health] - if verbose: - row.append(pretty_date(pool.created_at)) - table.add_row(*row) - - console.print(table) - console.print() - - -def print_instance_table(instances: Sequence[Instance]) -> None: - table = Table(box=None) - table.add_column("INSTANCE NAME") - table.add_column("BACKEND") - table.add_column("INSTANCE TYPE") - table.add_column("STATUS") - table.add_column("PRICE") - - for instance in instances: - status_mark = "success" if instance.status.is_available() else "warning" - color = colors[status_mark] - row = [ - instance.instance_id, - instance.backend, - instance.instance_type.resources.pretty_format(), - f"[{color}]{instance.status}[/{color}]", - f"{instance.price:.02f}", - ] - table.add_row(*row) - - console.print(table) - console.print() - - -def print_offers_table( - pool_name: str, - profile: Profile, - requirements: Requirements, - instance_offers: Sequence[InstanceOfferWithAvailability], - offers_limit: int = 3, -) -> None: - pretty_req = requirements.pretty_format(resources_only=True) - max_price = f"${requirements.max_price:g}" if requirements.max_price else "-" - max_duration = ( - f"{profile.max_duration / 3600:g}h" if isinstance(profile.max_duration, int) else "-" - ) - - # TODO: improve retry policy - # retry_policy = profile.retry_policy - # retry_policy = ( - # (f"{retry_policy.limit / 3600:g}h" if retry_policy.limit else "yes") - # if retry_policy.retry - # else "no" - # ) - - # TODO: improve spot policy - if requirements.spot is None: - spot_policy = "auto" - elif requirements.spot: - spot_policy = "spot" - else: - spot_policy = "on-demand" - - def th(s: str) -> str: - return f"[bold]{s}[/bold]" - - props = Table(box=None, show_header=False) - props.add_column(no_wrap=True) # key - props.add_column() # value - - props.add_row(th("Pool name"), pool_name) - props.add_row(th("Min resources"), pretty_req) - props.add_row(th("Max price"), max_price) - props.add_row(th("Max duration"), max_duration) - props.add_row(th("Spot policy"), spot_policy) - # props.add_row(th("Retry policy"), retry_policy) - - offers_table = Table(box=None) - offers_table.add_column("#") - offers_table.add_column("BACKEND") - offers_table.add_column("REGION") - offers_table.add_column("INSTANCE") - offers_table.add_column("RESOURCES") - offers_table.add_column("SPOT") - offers_table.add_column("PRICE") - offers_table.add_column() - - print_offers = instance_offers[:offers_limit] - - for i, offer in enumerate(print_offers, 
start=1): - r = offer.instance.resources - - availability = "" - if offer.availability in { - InstanceAvailability.NOT_AVAILABLE, - InstanceAvailability.NO_QUOTA, - }: - availability = offer.availability.value.replace("_", " ").title() - offers_table.add_row( - f"{i}", - offer.backend, - offer.region, - offer.instance.name, - r.pretty_format(), - "yes" if r.spot else "no", - f"${offer.price:g}", - availability, - style=None if i == 1 else colors["secondary"], - ) - if len(print_offers) > offers_limit: - offers_table.add_row("", "...", style=colors["secondary"]) - - console.print(props) - console.print() - if len(print_offers) > 0: - console.print(offers_table) - console.print() - - -def register_resource_args(parser: argparse.ArgumentParser) -> None: - resources_group = parser.add_argument_group("Resources") - resources_group.add_argument( - "--cpu", - help=f"Request the CPU count. Default: '{DEFAULT_CPU_COUNT.min}..'", - dest="cpu", - metavar="SPEC", - default=DEFAULT_CPU_COUNT, - ) - - resources_group.add_argument( - "--memory", - help="Request the size of RAM. " - f"The format is [code]SIZE[/]:[code]MB|GB|TB[/]. Default: {DEFAULT_MEMORY_SIZE.min}", - dest="memory", - metavar="SIZE", - default=DEFAULT_MEMORY_SIZE, - ) - - resources_group.add_argument( - "--shared-memory", - help="Request the size of Shared Memory. The format is [code]SIZE[/]:[code]MB|GB|TB[/].", - dest="shared_memory", - default=None, - metavar="SIZE", - ) - - resources_group.add_argument( - "--gpu", - help="Request GPU for the run. " - "The format is [code]NAME[/]:[code]COUNT[/]:[code]MEMORY[/] (all parts are optional)", - dest="gpu", - default=None, - metavar="SPEC", - ) - - resources_group.add_argument( - "--disk", - help="Request the size of disk for the run. Example [code]--disk 100GB[/].", - dest="disk", - metavar="SIZE", - default=None, - ) - - class PoolCommand(APIBaseCommand): # type: ignore[misc] NAME = "pool" DESCRIPTION = "Pool management" @@ -232,7 +52,9 @@ def _register(self) -> None: create_parser = subparsers.add_parser( "create", help="Create pool", formatter_class=self._parser.formatter_class ) - create_parser.add_argument("-n", "--name", dest="pool_name", help="The name of the pool") + create_parser.add_argument( + "-n", "--name", dest="pool_name", help="The name of the pool", required=True + ) create_parser.set_defaults(subfunc=self._create) # delete pool @@ -330,9 +152,7 @@ def _remove(self, args: argparse.Namespace) -> None: def _set_default(self, args: argparse.Namespace) -> None: result = self.api.client.pool.set_default(self.api.project, args.pool_name) if not result: - console.print( - f"[{colors['error']}]Failed to set default pool {args.pool_name!r}[/{colors['code']}]" - ) + console.print(f"[error]Failed to set default pool {args.pool_name!r}[/]") def _show(self, args: argparse.Namespace) -> None: instances = self.api.client.pool.show(self.api.project, args.pool_name) @@ -375,9 +195,7 @@ def _add(self, args: argparse.Namespace) -> None: args.remote_port, ) if not result: - console.print( - f"[{colors['error']}]Failed to add remote instance {args.instance_name!r}[/{colors['code']}]" - ) + console.print(f"[error]Failed to add remote instance {args.instance_name!r}[/]") return repo = self.api.repos.load(Path.cwd()) @@ -401,3 +219,178 @@ def _command(self, args: argparse.Namespace) -> None: super()._command(args) # TODO handle 404 and other errors args.subfunc(args) + + +def print_pool_table(pools: Sequence[Pool], verbose: bool) -> None: + table = Table(box=None) + table.add_column("NAME") + 
table.add_column("DEFAULT") + table.add_column("INSTANCES") + if verbose: + table.add_column("CREATED") + + sorted_pools = sorted(pools, key=lambda r: r.name) + for pool in sorted_pools: + default_mark = "default" if pool.default else "" + style = "success" if pool.total_instances == pool.available_instances else "error" + health = f"[{style}]{pool.available_instances}/{pool.total_instances}[/]" + row = [pool.name, default_mark, health] + if verbose: + row.append(pretty_date(pool.created_at)) + table.add_row(*row) + + console.print(table) + console.print() + + +def print_instance_table(instances: Sequence[Instance]) -> None: + table = Table(box=None) + table.add_column("INSTANCE NAME") + table.add_column("BACKEND") + table.add_column("INSTANCE TYPE") + table.add_column("STATUS") + table.add_column("PRICE") + + for instance in instances: + style = "success" if instance.status.is_available() else "warning" + row = [ + instance.instance_id, + instance.backend, + instance.instance_type.resources.pretty_format(), + f"[{style}]{instance.status}[/]", + f"{instance.price:.02f}", + ] + table.add_row(*row) + + console.print(table) + console.print() + + +def print_offers_table( + pool_name: str, + profile: Profile, + requirements: Requirements, + instance_offers: Sequence[InstanceOfferWithAvailability], + offers_limit: int = 3, +) -> None: + pretty_req = requirements.pretty_format(resources_only=True) + max_price = f"${requirements.max_price:g}" if requirements.max_price else "-" + max_duration = ( + f"{profile.max_duration / 3600:g}h" if isinstance(profile.max_duration, int) else "-" + ) + + # TODO: improve retry policy + # retry_policy = profile.retry_policy + # retry_policy = ( + # (f"{retry_policy.limit / 3600:g}h" if retry_policy.limit else "yes") + # if retry_policy.retry + # else "no" + # ) + + # TODO: improve spot policy + if requirements.spot is None: + spot_policy = "auto" + elif requirements.spot: + spot_policy = "spot" + else: + spot_policy = "on-demand" + + def th(s: str) -> str: + return f"[bold]{s}[/bold]" + + props = Table(box=None, show_header=False) + props.add_column(no_wrap=True) # key + props.add_column() # value + + props.add_row(th("Pool name"), pool_name) + props.add_row(th("Min resources"), pretty_req) + props.add_row(th("Max price"), max_price) + props.add_row(th("Max duration"), max_duration) + props.add_row(th("Spot policy"), spot_policy) + # props.add_row(th("Retry policy"), retry_policy) + + offers_table = Table(box=None) + offers_table.add_column("#") + offers_table.add_column("BACKEND") + offers_table.add_column("REGION") + offers_table.add_column("INSTANCE") + offers_table.add_column("RESOURCES") + offers_table.add_column("SPOT") + offers_table.add_column("PRICE") + offers_table.add_column() + + print_offers = instance_offers[:offers_limit] + + for i, offer in enumerate(print_offers, start=1): + r = offer.instance.resources + + availability = "" + if offer.availability in { + InstanceAvailability.NOT_AVAILABLE, + InstanceAvailability.NO_QUOTA, + }: + availability = offer.availability.value.replace("_", " ").title() + offers_table.add_row( + f"{i}", + offer.backend, + offer.region, + offer.instance.name, + r.pretty_format(), + "yes" if r.spot else "no", + f"${offer.price:g}", + availability, + style=None if i == 1 else "secondary", + ) + if len(print_offers) > offers_limit: + offers_table.add_row("", "...", style="secondary") + + console.print(props) + console.print() + if len(print_offers) > 0: + console.print(offers_table) + console.print() + + +def 
register_resource_args(parser: argparse.ArgumentParser) -> None: + resources_group = parser.add_argument_group("Resources") + resources_group.add_argument( + "--cpu", + help=f"Request the CPU count. Default: {DEFAULT_CPU_COUNT}", + dest="cpu", + metavar="SPEC", + default=DEFAULT_CPU_COUNT, + ) + + resources_group.add_argument( + "--memory", + help="Request the size of RAM. " + f"The format is [code]SIZE[/]:[code]MB|GB|TB[/]. Default: {DEFAULT_MEMORY_SIZE}", + dest="memory", + metavar="SIZE", + default=DEFAULT_MEMORY_SIZE, + ) + + resources_group.add_argument( + "--shared-memory", + help="Request the size of Shared Memory. The format is [code]SIZE[/]:[code]MB|GB|TB[/].", + dest="shared_memory", + default=None, + metavar="SIZE", + ) + + resources_group.add_argument( + "--gpu", + help="Request GPU for the run. " + "The format is [code]NAME[/]:[code]COUNT[/]:[code]MEMORY[/] (all parts are optional)", + dest="gpu", + default=None, + metavar="SPEC", + ) + + resources_group.add_argument( + "--disk", + help="Request the size of disk for the run. Example [code]--disk 100GB[/].", + dest="disk", + metavar="SIZE", + default=None, + ) diff --git a/src/dstack/_internal/cli/commands/run.py b/src/dstack/_internal/cli/commands/run.py index f33d9453a..2608d8638 100644 --- a/src/dstack/_internal/cli/commands/run.py +++ b/src/dstack/_internal/cli/commands/run.py @@ -13,7 +13,7 @@ BaseRunConfigurator, run_configurators_mapping, ) -from dstack._internal.cli.utils.common import colors, confirm_ask, console +from dstack._internal.cli.utils.common import confirm_ask, console from dstack._internal.cli.utils.run import print_run_plan from dstack._internal.core.errors import CLIError, ConfigurationError, ServerClientError from dstack._internal.core.models.configurations import ConfigurationType @@ -129,13 +129,13 @@ def _command(self, args: argparse.Namespace): if creation_policy == CreationPolicy.REUSE and termination_policy_idle is not None: console.print( - f'[{colors["warning"]}]If the flag --reuse is set, the argument --idle-duration will be skipped[/]' + "[warning]If the flag --reuse is set, the argument --idle-duration will be skipped[/]" ) termination_policy = TerminationPolicy.DONT_DESTROY if args.instance_name is not None and termination_policy_idle is not None: console.print( - f'[{colors["warning"]}]--idle-duration won\'t be applied to the instance {args.instance_name!r}[/]' + f"[warning]--idle-duration won't be applied to the instance {args.instance_name!r}[/]" ) termination_policy = TerminationPolicy.DONT_DESTROY diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py index 295bc2136..a6afa33a7 100644 --- a/src/dstack/_internal/cli/main.py +++ b/src/dstack/_internal/cli/main.py @@ -11,7 +11,7 @@ from dstack._internal.cli.commands.run import RunCommand from dstack._internal.cli.commands.server import ServerCommand from dstack._internal.cli.commands.stop import StopCommand -from dstack._internal.cli.utils.common import colors, console +from dstack._internal.cli.utils.common import _colors, console from dstack._internal.cli.utils.updates import check_for_updates from dstack._internal.core.errors import ClientError, CLIError from dstack._internal.utils.logging import get_logger @@ -22,8 +22,8 @@ def main(): RichHelpFormatter.usage_markup = True - RichHelpFormatter.styles["code"] = colors["code"] - RichHelpFormatter.styles["argparse.args"] = colors["code"] + RichHelpFormatter.styles["code"] = _colors["code"] + RichHelpFormatter.styles["argparse.args"] = _colors["code"] 
RichHelpFormatter.styles["argparse.groups"] = "bold grey74" RichHelpFormatter.styles["argparse.text"] = "grey74" diff --git a/src/dstack/_internal/cli/utils/common.py b/src/dstack/_internal/cli/utils/common.py index 5bec4dd17..9eb1dba98 100644 --- a/src/dstack/_internal/cli/utils/common.py +++ b/src/dstack/_internal/cli/utils/common.py @@ -11,7 +11,7 @@ from dstack._internal.core.errors import CLIError, DstackError -colors = { +_colors = { "secondary": "grey58", "success": "green", "warning": "yellow", @@ -19,7 +19,7 @@ "code": "bold sea_green3", } -console = Console(theme=Theme(colors)) +console = Console(theme=Theme(_colors)) def cli_error(e: DstackError) -> CLIError: diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index 8369394e1..cac447fdb 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -2,7 +2,7 @@ from rich.table import Table -from dstack._internal.cli.utils.common import colors, console +from dstack._internal.cli.utils.common import console from dstack._internal.core.models.instances import InstanceAvailability, InstanceType from dstack._internal.core.models.runs import RunPlan from dstack._internal.utils.common import pretty_date @@ -78,10 +78,10 @@ def th(s: str) -> str: "yes" if r.spot else "no", f"${offer.price:g}", availability, - style=None if i == 1 else colors["secondary"], + style=None if i == 1 else "secondary", ) if job_plan.total_offers > len(job_plan.offers): - offers.add_row("", "...", style=colors["secondary"]) + offers.add_row("", "...", style="secondary") console.print(props) console.print() diff --git a/src/dstack/_internal/core/services/configs/__init__.py b/src/dstack/_internal/core/services/configs/__init__.py index f4cd82164..c7005066f 100644 --- a/src/dstack/_internal/core/services/configs/__init__.py +++ b/src/dstack/_internal/core/services/configs/__init__.py @@ -8,7 +8,7 @@ from pydantic import ValidationError from rich import print -from dstack._internal.cli.utils.common import colors, confirm_ask +from dstack._internal.cli.utils.common import confirm_ask from dstack._internal.core.models.config import GlobalConfig, ProjectConfig, RepoConfig from dstack._internal.core.models.repos.base import RepoType from dstack._internal.utils.common import get_dstack_dir @@ -127,9 +127,7 @@ def update_default_project( ( default_project is None or default - or confirm_ask( - f"Update the default project in [{colors['code']}]{config_dir}[/{colors['code']}]?" 
- ) + or confirm_ask(f"Update the default project in [code]{config_dir}[/]?") ) if not no_default else False @@ -139,4 +137,4 @@ def update_default_project( name=project_name, url=url, token=token, default=set_it_as_default ) config_manager.save() - print(f"Configuration updated at [{colors['code']}]{config_dir}[/{colors['code']}]") + print(f"Configuration updated at [code]{config_dir}[/]") From 9393ff6090fc1dc8ec8e650f15a65c7c0ab49828 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Tue, 6 Feb 2024 17:35:13 +0100 Subject: [PATCH 21/47] Filter out deleted instances in pools --- src/dstack/_internal/cli/commands/pool.py | 2 +- src/dstack/_internal/server/routers/pools.py | 2 +- src/dstack/_internal/server/services/pools.py | 32 +++++++------------ src/dstack/_internal/server/services/runs.py | 2 +- 4 files changed, 15 insertions(+), 23 deletions(-) diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index b88ad2c5d..7e268cab9 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -257,7 +257,7 @@ def print_instance_table(instances: Sequence[Instance]) -> None: instance.instance_id, instance.backend, instance.instance_type.resources.pretty_format(), - f"[{style}]{instance.status}[/]", + f"[{style}]{instance.status.value}[/]", f"{instance.price:.02f}", ] table.add_row(*row) diff --git a/src/dstack/_internal/server/routers/pools.py b/src/dstack/_internal/server/routers/pools.py index d5422a6cb..82186c9fb 100644 --- a/src/dstack/_internal/server/routers/pools.py +++ b/src/dstack/_internal/server/routers/pools.py @@ -91,7 +91,7 @@ async def create_pool( @router.post("/show") # type: ignore[misc] -async def how_pool( +async def show_pool( body: schemas.CreatePoolRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index 8c75f7469..f0d7b6687 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -82,8 +82,13 @@ async def get_or_create_default_pool_by_name( def pool_model_to_pool(pool_model: PoolModel) -> Pool: - total = len(pool_model.instances) - available = sum(instance.status.is_available() for instance in pool_model.instances) + total = 0 + available = 0 + for instance in pool_model.instances: + if not instance.deleted: + total += 1 + if instance.status.is_available(): + available += 1 return Pool( name=pool_model.name, default=pool_model.project.default_pool_id == pool_model.id, @@ -165,7 +170,7 @@ async def remove_instance( instance.status = InstanceStatus.TERMINATING terminated = True if not terminated: - logger.warning("Couldn't fined instance to terminate") + logger.warning("Couldn't find instance to terminate") await session.commit() @@ -227,33 +232,20 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance: async def show_pool( session: AsyncSession, project: ProjectModel, pool_name: str ) -> Sequence[Instance]: - pool = ( - await session.scalars( - select(PoolModel).where( - PoolModel.name == pool_name, - PoolModel.project_id == project.id, - PoolModel.deleted == False, - ) - ) - ).one_or_none() - if pool is not None: - instances = [instance_model_to_instance(i) for i in pool.instances] - return instances - else: - return [] + """Show active instances in the pool. 
If the pool doesn't exist, return an empty list.""" + pool_instances = await get_pool_instances(session, project, pool_name) + return [instance_model_to_instance(i) for i in pool_instances if not i.deleted] async def get_pool_instances( session: AsyncSession, project: ProjectModel, pool_name: str ) -> List[InstanceModel]: res = await session.execute( - select(PoolModel) - .where( + select(PoolModel).where( PoolModel.name == pool_name, PoolModel.project_id == project.id, PoolModel.deleted == False, ) - .options(joinedload(PoolModel.instances)) ) result = res.unique().scalars().one_or_none() if result is None: diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 811f19668..43c70e1b9 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -301,7 +301,7 @@ async def get_run_plan( if profile.pool_name is None: try: pool_name = project.default_pool.name - except Exception as e: + except Exception: pool_name = DEFAULT_POOL_NAME # TODO: get pool from project pool_instances = [ From 2cfc268900742d058e016dfb74207ca82e61185b Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Wed, 7 Feb 2024 08:36:09 +0300 Subject: [PATCH 22/47] Update runner, fix status --- .github/workflows/build.yml | 1 + .pre-commit-config.yaml | 2 +- runner/cmd/runner/cmd.go | 2 +- runner/cmd/runner/main.go | 6 +- runner/cmd/runner/version.go | 2 +- runner/cmd/shim/main.go | 7 +- runner/internal/runner/api/http.go | 1 + runner/internal/runner/api/server.go | 6 +- runner/internal/schemas/schemas.go | 1 + runner/internal/shim/api/http.go | 1 + runner/internal/shim/api/schemas.go | 1 + runner/internal/shim/api/server.go | 6 +- runner/internal/shim/runner.go | 73 +++++++++++-------- runner/internal/shim/subprocess.go | 28 ------- .../_internal/core/backends/base/compute.py | 33 ++++++--- src/dstack/_internal/core/models/pools.py | 6 +- src/dstack/_internal/core/models/runs.py | 1 - .../background/tasks/process_finished_jobs.py | 38 +++------- .../server/background/tasks/process_pools.py | 6 +- .../background/tasks/process_running_jobs.py | 1 + .../tasks/process_submitted_jobs.py | 19 +++-- src/dstack/_internal/server/schemas/pools.py | 12 +-- src/dstack/_internal/server/schemas/runner.py | 1 + .../server/services/jobs/__init__.py | 11 +-- src/dstack/_internal/server/services/pools.py | 9 +-- src/dstack/_internal/server/services/runs.py | 8 +- src/dstack/_internal/server/testing/common.py | 6 +- src/dstack/api/server/_pools.py | 6 +- .../tasks/test_process_running_jobs.py | 8 +- .../tasks/test_process_submitted_jobs.py | 3 +- .../tasks/test_process_terminating_jobs.py | 1 - .../_internal/server/services/test_pools.py | 6 +- 32 files changed, 145 insertions(+), 167 deletions(-) delete mode 100644 runner/internal/shim/subprocess.go diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 5cc4837a0..824e790b8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -112,6 +112,7 @@ jobs: VERSION=$((${{ github.run_number }} + ${{ env.BUILD_INCREMENT }})) go build -ldflags "-X '$REPO_NAME/runner/cmd/runner/version.Version=$VERSION' -extldflags '-static'" -o dstack-runner-$GOOS-$GOARCH $REPO_NAME/runner/cmd/runner go build -ldflags "-X '$REPO_NAME/runner/cmd/shim/version.Version=$VERSION' -extldflags '-static'" -o dstack-shim-$GOOS-$GOARCH $REPO_NAME/runner/cmd/shim + echo $VERSION - uses: actions/upload-artifact@v3 with: name: dstack-runner diff --git a/.pre-commit-config.yaml 
b/.pre-commit-config.yaml index addc646d2..d49ce9a07 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,4 +13,4 @@ repos: args: ['--strict', '--follow-imports=skip', '--ignore-missing-imports', '--python-version=3.8'] files: '.*pools?\.py' exclude: 'versions|src/tests' - additional_dependencies: [types-PyYAML, types-requests, pydantic, sqlalchemy] + additional_dependencies: [types-PyYAML, types-requests, pydantic<2, sqlalchemy] diff --git a/runner/cmd/runner/cmd.go b/runner/cmd/runner/cmd.go index a22f36360..233862835 100644 --- a/runner/cmd/runner/cmd.go +++ b/runner/cmd/runner/cmd.go @@ -56,7 +56,7 @@ func App() { }, }, Action: func(c *cli.Context) error { - err := start(paths.tempDir, paths.homeDir, paths.workingDir, httpPort, logLevel) + err := start(paths.tempDir, paths.homeDir, paths.workingDir, httpPort, logLevel, Version) if err != nil { return cli.Exit(err, 1) } diff --git a/runner/cmd/runner/main.go b/runner/cmd/runner/main.go index 974062484..9c83f8f17 100644 --- a/runner/cmd/runner/main.go +++ b/runner/cmd/runner/main.go @@ -18,9 +18,9 @@ func main() { App() } -func start(tempDir string, homeDir string, workingDir string, httpPort int, logLevel int) error { +func start(tempDir string, homeDir string, workingDir string, httpPort int, logLevel int, version string) error { if err := os.MkdirAll(tempDir, 0755); err != nil { - return tracerr.Errorf("Failed to create temp directory^ %w", err) + return tracerr.Errorf("Failed to create temp directory: %w", err) } defaultLogFile, err := log.CreateAppendFile(filepath.Join(tempDir, "default.log")) @@ -37,7 +37,7 @@ func start(tempDir string, homeDir string, workingDir string, httpPort int, logL log.DefaultEntry.Logger.SetOutput(io.MultiWriter(os.Stdout, defaultLogFile)) log.DefaultEntry.Logger.SetLevel(logrus.Level(logLevel)) - server := api.NewServer(tempDir, homeDir, workingDir, fmt.Sprintf(":%d", httpPort)) + server := api.NewServer(tempDir, homeDir, workingDir, fmt.Sprintf(":%d", httpPort), version) log.Trace(context.TODO(), "Starting API server", "port", httpPort) if err := server.Run(); err != nil { diff --git a/runner/cmd/runner/version.go b/runner/cmd/runner/version.go index 7b2b1de54..788aadab0 100644 --- a/runner/cmd/runner/version.go +++ b/runner/cmd/runner/version.go @@ -1,4 +1,4 @@ package main // Version A default build-time variable. The value is overridden via ldflags. -var Version = "0.0.1.dev1" +var Version = "0.0.1.dev2" diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go index 7ab960775..fa7fd6c19 100644 --- a/runner/cmd/shim/main.go +++ b/runner/cmd/shim/main.go @@ -22,7 +22,7 @@ func main() { app := &cli.App{ Name: "dstack-shim", - Usage: "Starts dstack-runner or docker container. 
Kills the VM on exit.", + Usage: "Starts dstack-runner or docker container.", Version: Version, Flags: []cli.Flag{ /* Shim Parameters */ @@ -94,10 +94,9 @@ func main() { }, Action: func(c *cli.Context) error { if args.Runner.BinaryPath == "" { - if err := args.Download("linux"); err != nil { + if err := args.DownloadRunner(); err != nil { return cli.Exit(err, 1) } - defer func() { _ = os.Remove(args.Runner.BinaryPath) }() } args.Runner.TempDir = "/tmp/runner" @@ -119,7 +118,7 @@ func main() { } address := fmt.Sprintf(":%d", args.Shim.HTTPPort) - shimServer := api.NewShimServer(address, dockerRunner) + shimServer := api.NewShimServer(address, dockerRunner, Version) defer func() { shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 5*time.Second) diff --git a/runner/internal/runner/api/http.go b/runner/internal/runner/api/http.go index 936085983..74da7225c 100644 --- a/runner/internal/runner/api/http.go +++ b/runner/internal/runner/api/http.go @@ -20,6 +20,7 @@ func (s *Server) healthcheckGetHandler(w http.ResponseWriter, r *http.Request) ( defer s.executor.RUnlock() return &schemas.HealthcheckResponse{ Service: "dstack-runner", + Version: s.version, }, nil } diff --git a/runner/internal/runner/api/server.go b/runner/internal/runner/api/server.go index 5d4952246..9c469e070 100644 --- a/runner/internal/runner/api/server.go +++ b/runner/internal/runner/api/server.go @@ -30,9 +30,11 @@ type Server struct { executor executor.Executor cancelRun context.CancelFunc + + version string } -func NewServer(tempDir string, homeDir string, workingDir string, address string) *Server { +func NewServer(tempDir string, homeDir string, workingDir string, address string, version string) *Server { mux := http.NewServeMux() s := &Server{ srv: &http.Server{ @@ -51,6 +53,8 @@ func NewServer(tempDir string, homeDir string, workingDir string, address string logsWaitDuration: 30 * time.Second, executor: executor.NewRunExecutor(tempDir, homeDir, workingDir), + + version: version, } mux.HandleFunc("/api/healthcheck", api.JSONResponseHandler("GET", s.healthcheckGetHandler)) mux.HandleFunc("/api/submit", api.JSONResponseHandler("POST", s.submitPostHandler)) diff --git a/runner/internal/schemas/schemas.go b/runner/internal/schemas/schemas.go index bd3150849..3253d840c 100644 --- a/runner/internal/schemas/schemas.go +++ b/runner/internal/schemas/schemas.go @@ -81,6 +81,7 @@ type Gateway struct { type HealthcheckResponse struct { Service string `json:"service"` + Version string `json:"version"` } func (d *RepoData) FormatURL(format string) string { diff --git a/runner/internal/shim/api/http.go b/runner/internal/shim/api/http.go index 1936d14ed..5508dfebf 100644 --- a/runner/internal/shim/api/http.go +++ b/runner/internal/shim/api/http.go @@ -16,6 +16,7 @@ func (s *ShimServer) HealthcheckGetHandler(w http.ResponseWriter, r *http.Reques return &HealthcheckResponse{ Service: "dstack-shim", + Version: s.version, }, nil } diff --git a/runner/internal/shim/api/schemas.go b/runner/internal/shim/api/schemas.go index a6fefe441..8f86b8a21 100644 --- a/runner/internal/shim/api/schemas.go +++ b/runner/internal/shim/api/schemas.go @@ -10,6 +10,7 @@ type DockerTaskBody struct { type HealthcheckResponse struct { Service string `json:"service"` + Version string `json:"version"` } type PullResponse struct { diff --git a/runner/internal/shim/api/server.go b/runner/internal/shim/api/server.go index f6fb77036..ce1bcfd59 100644 --- a/runner/internal/shim/api/server.go +++ b/runner/internal/shim/api/server.go @@ -19,9 +19,11 
@@ type ShimServer struct { mu sync.RWMutex runner TaskRunner + + version string } -func NewShimServer(address string, runner TaskRunner) *ShimServer { +func NewShimServer(address string, runner TaskRunner, version string) *ShimServer { mux := http.NewServeMux() s := &ShimServer{ HttpServer: &http.Server{ @@ -30,6 +32,8 @@ func NewShimServer(address string, runner TaskRunner) *ShimServer { }, runner: runner, + + version: version, } mux.HandleFunc("/api/submit", api.JSONResponseHandler("POST", s.SubmitPostHandler)) mux.HandleFunc("/api/healthcheck", api.JSONResponseHandler("GET", s.HealthcheckGetHandler)) diff --git a/runner/internal/shim/runner.go b/runner/internal/shim/runner.go index 3cbf1ce64..bf8ae6edb 100644 --- a/runner/internal/shim/runner.go +++ b/runner/internal/shim/runner.go @@ -6,7 +6,6 @@ import ( "log" "net/http" "os" - rt "runtime" "strconv" "strings" @@ -27,16 +26,17 @@ func (c *CLIArgs) GetDockerCommands() []string { } } -func (c *CLIArgs) Download(osName string) error { - tempFile, err := os.CreateTemp("", "dstack-runner") +func (c *CLIArgs) DownloadRunner() error { + url := makeDownloadRunnerUrl(c.Runner.Version, c.Runner.DevChannel) + + runnerBinaryPath, err := downloadRunner(url) if err != nil { return gerrors.Wrap(err) } - if err = tempFile.Close(); err != nil { - return gerrors.Wrap(err) - } - c.Runner.BinaryPath = tempFile.Name() - return gerrors.Wrap(downloadRunner(c.Runner.Version, c.Runner.DevChannel, osName, c.Runner.BinaryPath)) + + c.Runner.BinaryPath = runnerBinaryPath + + return nil } func (c *CLIArgs) getRunnerArgs() []string { @@ -50,42 +50,53 @@ func (c *CLIArgs) getRunnerArgs() []string { } } -func downloadRunner(runnerVersion string, useDev bool, osName string, path string) error { - // darwin-amd64 - // darwin-arm64 - // linux-386 - // linux-amd64 - archName := rt.GOARCH - if osName == "linux" && archName == "arm64" { - archName = "amd64" - } - var url string - if useDev { - url = fmt.Sprintf(DstackRunnerURL, DstackStagingBucket, runnerVersion, osName, archName) - } else { - url = fmt.Sprintf(DstackRunnerURL, DstackReleaseBucket, runnerVersion, osName, archName) +func makeDownloadRunnerUrl(version string, staging bool) string { + bucket := DstackReleaseBucket + if staging { + bucket = DstackStagingBucket } - file, err := os.Create(path) + osName := "linux" + archName := "amd64" + + url := fmt.Sprintf(DstackRunnerURL, bucket, version, osName, archName) + return url +} + +func downloadRunner(url string) (string, error) { + tempFile, err := os.CreateTemp("", "dstack-runner") if err != nil { - return gerrors.Wrap(err) + return "", gerrors.Wrap(err) } - defer func() { _ = file.Close() }() + defer func() { + err := tempFile.Close() + if err != nil { + log.Printf("close file error: %s\n", err) + } + }() log.Printf("Downloading runner from %s\n", url) resp, err := http.Get(url) if err != nil { - return gerrors.Wrap(err) + return "", gerrors.Wrap(err) } - defer func() { _ = resp.Body.Close() }() + defer func() { + err := resp.Body.Close() + log.Printf("close body error: %s\n", err) + }() + if resp.StatusCode != http.StatusOK { - return gerrors.Newf("unexpected status code: %s", resp.Status) + return "", gerrors.Newf("unexpected status code: %s", resp.Status) } - _, err = io.Copy(file, resp.Body) + _, err = io.Copy(tempFile, resp.Body) if err != nil { - return gerrors.Wrap(err) + return "", gerrors.Wrap(err) + } + + if err := tempFile.Chmod(0755); err != nil { + return "", gerrors.Wrap(err) } - return gerrors.Wrap(file.Chmod(0755)) + return tempFile.Name(), nil 
} diff --git a/runner/internal/shim/subprocess.go b/runner/internal/shim/subprocess.go deleted file mode 100644 index 38268d371..000000000 --- a/runner/internal/shim/subprocess.go +++ /dev/null @@ -1,28 +0,0 @@ -package shim - -import ( - "github.com/dstackai/dstack/runner/internal/gerrors" - "os" - "path/filepath" - rt "runtime" -) - -func RunSubprocess(httpPort int, logLevel int, runnerVersion string, useDev bool) error { - userHomeDir, err := os.UserHomeDir() - if err != nil { - return gerrors.Wrap(err) - } - runnerPath := filepath.Join(userHomeDir, ".dstack/dstack-runner") - if err = os.MkdirAll(filepath.Dir(runnerPath), 0755); err != nil { - return gerrors.Wrap(err) - } - - err = downloadRunner(runnerVersion, useDev, rt.GOOS, runnerPath) - if err != nil { - return gerrors.Wrap(err) - } - // todo create temporary, home and working dirs - // todo start runner - // todo wait till runner completes - return nil -} diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index fb5dd92d6..86a0568c8 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -150,12 +150,16 @@ def get_cloud_config(**config) -> str: def get_dstack_shim(build: str) -> List[str]: - bucket = "dstack-runner-downloads-stgn" - if settings.DSTACK_VERSION is not None: - bucket = "dstack-runner-downloads" + # TODO: use official url + # bucket = "dstack-runner-downloads-stgn" + # if settings.DSTACK_VERSION is not None: + # bucket = "dstack-runner-downloads" + # url =f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64" + + url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-shim" return [ - f'sudo curl --output /usr/local/bin/dstack-shim "https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64"', + f'sudo --connect-timeout 120 --output /usr/local/bin/dstack-shim "{url}"', "sudo chmod +x /usr/local/bin/dstack-shim", ] @@ -218,13 +222,20 @@ def get_docker_commands(authorized_keys: List[str]) -> List[str]: # start sshd "/usr/sbin/sshd -p 10022 -o PermitUserEnvironment=yes", ] - build = get_dstack_runner_version() + runner = "/usr/local/bin/dstack-runner" - bucket = "dstack-runner-downloads-stgn" - if settings.DSTACK_VERSION is not None: - bucket = "dstack-runner-downloads" + + # TODO: use official url + # build = get_dstack_runner_version() + # bucket = "dstack-runner-downloads-stgn" + # if settings.DSTACK_VERSION is not None: + # bucket = "dstack-runner-downloads" + # url = f'https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64' + + url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-shim" + commands += [ - f'curl --output {runner} "https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64"', + f'curl --connect-timeout 120 --output {runner} "{url}"', f"chmod +x {runner}", f"{runner} --log-level 6 start --http-port 10999 --temp-dir /tmp/runner --home-dir /root --working-dir /workflow", ] @@ -335,7 +346,7 @@ def get_instance_dstack_shim(build: str) -> List[str]: url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-shim" return [ - f'sudo curl --output /usr/local/bin/dstack-shim "{url}"', + f'sudo curl --connect-timeout 120 --output /usr/local/bin/dstack-shim "{url}"', "sudo chmod +x /usr/local/bin/dstack-shim", ] @@ -381,7 +392,7 @@ def get_instance_docker_commands(authorized_keys: List[str]) -> List[str]: url = 
"https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-runner" commands += [ - f"curl --output {runner} {url}", + f"curl --connect-timeout 120 --output {runner} {url}", f"chmod +x {runner}", f"{runner} --log-level 6 start --http-port 10999 --temp-dir /tmp/runner --home-dir /root --working-dir /workflow", ] diff --git a/src/dstack/_internal/core/models/pools.py b/src/dstack/_internal/core/models/pools.py index 5fc4c0b65..a859f9dbc 100644 --- a/src/dstack/_internal/core/models/pools.py +++ b/src/dstack/_internal/core/models/pools.py @@ -1,13 +1,13 @@ import datetime -from pydantic import BaseModel +from pydantic import BaseModel # type: ignore[attr-defined] from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import InstanceType from dstack._internal.core.models.runs import InstanceStatus -class Pool(BaseModel): # type: ignore[misc,valid-type] +class Pool(BaseModel): # type: ignore[misc] name: str default: bool created_at: datetime.datetime @@ -15,7 +15,7 @@ class Pool(BaseModel): # type: ignore[misc,valid-type] available_instances: int -class Instance(BaseModel): # type: ignore[misc,valid-type] +class Instance(BaseModel): # type: ignore[misc] backend: BackendType instance_type: InstanceType instance_id: str diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index d745303c9..d14ebe9c4 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -127,7 +127,6 @@ class JobProvisioningData(BaseModel): backend: BackendType instance_type: InstanceType instance_id: str - pool_id: str hostname: str region: str price: float diff --git a/src/dstack/_internal/server/background/tasks/process_finished_jobs.py b/src/dstack/_internal/server/background/tasks/process_finished_jobs.py index 878dea75c..8519dfdb9 100644 --- a/src/dstack/_internal/server/background/tasks/process_finished_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_finished_jobs.py @@ -1,7 +1,6 @@ from sqlalchemy import or_, select from sqlalchemy.orm import joinedload -from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.runs import InstanceStatus, JobSpec, JobStatus from dstack._internal.server.db import get_session_ctx from dstack._internal.server.models import GatewayModel, JobModel @@ -9,10 +8,8 @@ from dstack._internal.server.services.jobs import ( TERMINATING_PROCESSING_JOBS_IDS, TERMINATING_PROCESSING_JOBS_LOCK, - job_model_to_job_submission, ) from dstack._internal.server.services.logging import job_log -from dstack._internal.server.services.pools import get_instances_by_pool_id from dstack._internal.server.utils.common import run_async from dstack._internal.utils.common import get_current_datetime from dstack._internal.utils.logging import get_logger @@ -54,7 +51,6 @@ async def _process_job(job_id): .options(joinedload(JobModel.run)) ) job_model = res.scalar_one() - job_submission = job_model_to_job_submission(job_model) job_spec = JobSpec.parse_raw(job_model.job_spec_data) if job_spec.gateway is not None: res = await session.execute( @@ -82,30 +78,14 @@ async def _process_job(job_id): logger.debug(*job_log("service is unregistered", job_model)) except Exception as e: logger.warning("failed to unregister service: %s", e) - try: - jpd = job_submission.job_provisioning_data - if jpd is not None: - if jpd.backend == BackendType.LOCAL: - instances = await get_instances_by_pool_id(session, jpd.pool_id) - for 
instance in instances: - if instance.name == jpd.instance_id: - instance.finished_at = get_current_datetime() - instance.status = InstanceStatus.READY - # else: - # if job_model.instance is not None and job_model.instance.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE: - # await terminate_job_provisioning_data_instance( - # project=job_model.project, - # job_provisioning_data=job_submission.job_provisioning_data, - # ) - job_model.removed = True - if job_model.instance is not None: - job_model.used_instance_id = job_model.instance.id - job_model.instance.status = InstanceStatus.READY - job_model.instance.last_job_processed_at = get_current_datetime() - job_model.instance = None - logger.info(*job_log("marked as removed", job_model)) - except Exception as e: - job_model.removed = False - logger.error(*job_log("failed to terminate job instance: %s", job_model, e)) + + if job_model.instance is not None: + job_model.used_instance_id = job_model.instance.id + job_model.instance.status = InstanceStatus.READY + job_model.instance.last_job_processed_at = get_current_datetime() + job_model.instance = None + + job_model.removed = True job_model.last_processed_at = get_current_datetime() await session.commit() + logger.info(*job_log("marked as removed", job_model)) diff --git a/src/dstack/_internal/server/background/tasks/process_pools.py b/src/dstack/_internal/server/background/tasks/process_pools.py index ef12a3ded..988b6fca5 100644 --- a/src/dstack/_internal/server/background/tasks/process_pools.py +++ b/src/dstack/_internal/server/background/tasks/process_pools.py @@ -3,7 +3,7 @@ from typing import Dict from uuid import UUID -from pydantic import parse_raw_as +from pydantic import parse_raw_as # type: ignore[attr-defined] from sqlalchemy import select from sqlalchemy.orm import joinedload @@ -69,7 +69,7 @@ async def check_shim(instance_id: UUID) -> None: ) ).one() ssh_private_key = instance.project.ssh_private_key - job_provisioning_data = parse_raw_as(JobProvisioningData, instance.job_provisioning_data) # type: ignore[operator] + job_provisioning_data = parse_raw_as(JobProvisioningData, instance.job_provisioning_data) instance_health = instance_healthcheck(ssh_private_key, job_provisioning_data) @@ -102,7 +102,7 @@ async def terminate(instance_id: UUID) -> None: # TODO: need lock - jpd = parse_raw_as(JobProvisioningData, instance.job_provisioning_data) # type: ignore[operator] + jpd = parse_raw_as(JobProvisioningData, instance.job_provisioning_data) BACKEND_TYPE = jpd.backend backends = await backends_services.get_project_backends(project=instance.project) backend = next((b for b in backends if b.TYPE in BACKEND_TYPE), None) diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index da4e36cb6..ee7dc9e47 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -208,6 +208,7 @@ async def _process_job(job_id: UUID): job_model.used_instance_id = job_model.instance.id job_model.instance.last_job_processed_at = common_utils.get_current_datetime() job_model.instance = None + if job.is_retry_active(): if job_submission.job_provisioning_data.instance_type.resources.spot: new_job_model = create_job_model_for_new_submission( diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 
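The simplified cleanup above drops the local-backend special case and just releases the instance back to the pool. A minimal sketch of that release step, reusing only the names visible in the hunk (not the whole task):

    # release the instance once its job has finished
    if job_model.instance is not None:
        job_model.used_instance_id = job_model.instance.id
        job_model.instance.status = InstanceStatus.READY                    # back into the pool
        job_model.instance.last_job_processed_at = get_current_datetime()   # idle countdown starts here
        job_model.instance = None
    job_model.removed = True

The last_job_processed_at timestamp set here is what the pool task later compares against the termination idle time.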
2196cb654..d61411e40 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -140,16 +140,22 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): return + run = run_model_to_run(run_model) + job = run.jobs[job_model.job_num] + if profile.creation_policy == CreationPolicy.REUSE: - job_model.status = JobStatus.FAILED - job_model.error_code = JobErrorCode.FAILED_TO_START_DUE_TO_NO_CAPACITY + logger.debug(*job_log("reuse instance failed", job_model)) + if job.is_retry_active(): + logger.debug(*job_log("now is pending because retry is active", job_model)) + job_model.status = JobStatus.PENDING + else: + job_model.status = JobStatus.FAILED + job_model.error_code = JobErrorCode.FAILED_TO_START_DUE_TO_NO_CAPACITY job_model.last_processed_at = common_utils.get_current_datetime() await session.commit() return # create a new cloud instance - run = run_model_to_run(run_model) - job = run.jobs[job_model.job_num] backends = await backends_services.get_project_backends(project=run_model.project) # TODO: create VM (backend.compute().create_instance) @@ -160,11 +166,10 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): backends=backends, project_ssh_public_key=project_model.ssh_public_key, project_ssh_private_key=project_model.ssh_private_key, - pool_id=pool.id, ) if job_provisioning_data is not None and offer is not None: logger.info(*job_log("now is provisioning", job_model)) - job_provisioning_data.pool_id = str(pool.id) + job_model.job_provisioning_data = job_provisioning_data.json() job_model.status = JobStatus.PROVISIONING @@ -205,7 +210,6 @@ async def _run_job( backends: List[Backend], project_ssh_public_key: str, project_ssh_private_key: str, - pool_id: UUID, ) -> Tuple[Optional[JobProvisioningData], Optional[InstanceOfferWithAvailability]]: if run.run_spec.profile.backends is not None: backends = [b for b in backends if b.TYPE in run.run_spec.profile.backends] @@ -264,7 +268,6 @@ async def _run_job( dockerized=launched_instance_info.dockerized, ssh_proxy=launched_instance_info.ssh_proxy, backend_data=launched_instance_info.backend_data, - pool_id=str(pool_id), ) return (job_provisioning_data, offer) diff --git a/src/dstack/_internal/server/schemas/pools.py b/src/dstack/_internal/server/schemas/pools.py index 67e04f21c..8b40e710d 100644 --- a/src/dstack/_internal/server/schemas/pools.py +++ b/src/dstack/_internal/server/schemas/pools.py @@ -1,23 +1,23 @@ -from pydantic import BaseModel +from pydantic import BaseModel # type: ignore[attr-defined] -class DeletePoolRequest(BaseModel): # type: ignore[misc,valid-type] +class DeletePoolRequest(BaseModel): # type: ignore[misc] name: str force: bool -class CreatePoolRequest(BaseModel): # type: ignore[misc,valid-type] +class CreatePoolRequest(BaseModel): # type: ignore[misc] name: str -class ShowPoolRequest(BaseModel): # type: ignore[misc,valid-type] +class ShowPoolRequest(BaseModel): # type: ignore[misc] name: str -class RemoveInstanceRequest(BaseModel): # type: ignore[misc,valid-type] +class RemoveInstanceRequest(BaseModel): # type: ignore[misc] pool_name: str instance_name: str -class SetDefaultPoolRequest(BaseModel): # type: ignore[misc,valid-type] +class SetDefaultPoolRequest(BaseModel): # type: ignore[misc] pool_name: str diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index e47d653fb..97672ca40 100644 --- 
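The process_submitted_jobs hunk above changes what happens when creation_policy is REUSE and no pool instance is available: instead of failing outright, the job is parked as PENDING while retry is active. A hedged restatement of the decision as a standalone helper (the helper itself is hypothetical; the statuses and error code are taken from the diff):

    def status_when_no_capacity(creation_policy, retry_active):
        """Status of a submitted job when no pool instance matches the requirements."""
        if creation_policy != CreationPolicy.REUSE:
            return None  # fall through: a new cloud instance gets provisioned below
        if retry_active:
            return JobStatus.PENDING  # wait for a pool instance to become available
        return JobStatus.FAILED       # with JobErrorCode.FAILED_TO_START_DUE_TO_NO_CAPACITY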
a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -63,6 +63,7 @@ class SubmitBody(BaseModel): class HealthcheckResponse(BaseModel): service: str + version: str class DockerImageBody(BaseModel): diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 62ba9e50a..c60d8be3b 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -106,13 +106,9 @@ async def stop_job( job_submission = job_model_to_job_submission(job_model) if new_status == JobStatus.TERMINATED and job_model.status == JobStatus.RUNNING: try: - await run_async( - _stop_runner, - job_submission, - project.ssh_private_key, - ) + await run_async(_stop_runner, job_submission, project.ssh_private_key) # delay termination for 15 seconds to allow the runner to stop gracefully - # delay_job_instance_termination(job_model) + delay_job_instance_termination(job_model) except SSHError: logger.debug(*job_log("failed to stop runner", job_model)) # process_finished_jobs will terminate the instance in the background @@ -124,8 +120,7 @@ async def stop_job( async def terminate_job_provisioning_data_instance( - project: ProjectModel, - job_provisioning_data: JobProvisioningData, + project: ProjectModel, job_provisioning_data: JobProvisioningData ): backend = await get_project_backend_by_type( project=project, diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index 8c75f7469..1a08f2941 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Sequence import gpuhunt -from pydantic import parse_raw_as +from pydantic import parse_raw_as # type: ignore[attr-defined] from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -206,10 +206,10 @@ async def list_deleted_pools( def instance_model_to_instance(instance_model: InstanceModel) -> Instance: - offer: InstanceOfferWithAvailability = parse_raw_as( # type: ignore[operator] + offer: InstanceOfferWithAvailability = parse_raw_as( InstanceOfferWithAvailability, instance_model.offer ) - jpd: JobProvisioningData = parse_raw_as( # type: ignore[operator] + jpd: JobProvisioningData = parse_raw_as( JobProvisioningData, instance_model.job_provisioning_data ) @@ -331,7 +331,6 @@ async def add_remote( ssh_port=22, dockerized=False, backend_data="", - pool_id=str(pool_model.id), ssh_proxy=None, ) offer = InstanceOfferWithAvailability( @@ -403,7 +402,7 @@ def filter_pool_instances( ) query_filter = requirements_to_query_filter(requirements) for instance in candidates: - catalog_item = offer_to_catalog_item(parse_raw_as(InstanceOffer, instance.offer)) # type: ignore[operator] + catalog_item = offer_to_catalog_item(parse_raw_as(InstanceOffer, instance.offer)) if gpuhunt.matches(catalog_item, query_filter): instances.append(instance) return instances diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 811f19668..6fdd524a4 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -239,7 +239,6 @@ async def create_instance( backend=backend.TYPE, instance_type=instance_offer.instance, instance_id=launched_instance_info.instance_id, - pool_id=str(pool.id), 
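With pool_id removed from JobProvisioningData, the provisioning data kept on an instance row round-trips through plain pydantic v1 helpers, as in services/pools.py above. A short sketch, assuming instance_model, job_provisioning_data, and offer objects shaped like the ones in the diffs:

    from pydantic import parse_raw_as

    # write: serialized when the instance is provisioned
    instance_model.job_provisioning_data = job_provisioning_data.json()
    instance_model.offer = offer.json()

    # read: parsed back when the instance is inspected or matched against a job
    jpd = parse_raw_as(JobProvisioningData, instance_model.job_provisioning_data)
    restored_offer = parse_raw_as(InstanceOfferWithAvailability, instance_model.offer)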
hostname=launched_instance_info.ip_address, region=launched_instance_info.region, price=instance_offer.price, @@ -288,10 +287,7 @@ async def create_instance( async def get_run_plan( - session: AsyncSession, - project: ProjectModel, - user: UserModel, - run_spec: RunSpec, + session: AsyncSession, project: ProjectModel, user: UserModel, run_spec: RunSpec ) -> RunPlan: profile = run_spec.profile @@ -301,7 +297,7 @@ async def get_run_plan( if profile.pool_name is None: try: pool_name = project.default_pool.name - except Exception as e: + except Exception: pool_name = DEFAULT_POOL_NAME # TODO: get pool from project pool_instances = [ diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index e57cc8cbc..d8cbaca8c 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -237,7 +237,7 @@ def get_job_provisioning_data() -> JobProvisioningData: ssh_port=22, dockerized=False, backend_data=None, - pool_id="", + ssh_proxy=None, ) @@ -313,8 +313,8 @@ async def create_instance( pool=pool, project=project, status=status, - job_provisioning_data='{"backend": "datacrunch", "instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "instance_id": "running_instance.id", "pool_id": "1b2b4c57-5851-487f-b92e-948f946dfa49", "hostname": "running_instance.ip", "region": "running_instance.location", "price": 0.1, "username": "root", "ssh_port": 22, "dockerized": true, "backend_data": null}', - offer='{"backend": "datacrunch", "instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "region": "en", "price": 0.1, "availability": "available"}', + job_provisioning_data='{"backend": "datacrunch", "instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "instance_id": "running_instance.id", "ssh_proxy": null, "hostname": "running_instance.ip", "region": "running_instance.location", "price": 0.1, "username": "root", "ssh_port": 22, "dockerized": true, "backend_data": null}', + offer='{"backend": "datacrunch", "instance": {"name": "instance", "resources": {"cpus": 2, "memory_mib": 12000, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "region": "en", "price": 0.1, "availability": "available"}', resource_spec_data=resources.json(), price=1, region="eu-west", diff --git a/src/dstack/api/server/_pools.py b/src/dstack/api/server/_pools.py index a1d87012c..0d9f1248d 100644 --- a/src/dstack/api/server/_pools.py +++ b/src/dstack/api/server/_pools.py @@ -1,6 +1,6 @@ from typing import List, Optional -from pydantic import parse_obj_as +from pydantic import parse_obj_as # type: ignore[attr-defined] import dstack._internal.server.schemas.pools as schemas_pools from dstack._internal.core.models.pools import Instance, Pool @@ -13,7 +13,7 @@ class PoolAPIClient(APIClientGroup): # type: ignore[misc] def list(self, project_name: str) -> List[Pool]: resp = self._request(f"/api/project/{project_name}/pool/list") - result: List[Pool] = parse_obj_as(List[Pool], resp.json()) # type: ignore[operator] + result: List[Pool] = parse_obj_as(List[Pool], resp.json()) return result def delete(self, project_name: str, pool_name: str, force: bool) -> None: @@ -27,7 +27,7 @@ def create(self, project_name: str, pool_name: 
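In get_run_plan the effective pool name is resolved with a small fallback chain; the bare except Exception covers, for example, a project that has no default pool yet. A compact restatement as a hypothetical helper:

    def resolve_pool_name(profile, project) -> str:
        # mirrors the fallback in get_run_plan above; the helper name is an assumption
        if profile.pool_name is not None:
            return profile.pool_name
        if project.default_pool is not None:
            return project.default_pool.name
        return DEFAULT_POOL_NAME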
str) -> None: def show(self, project_name: str, pool_name: str) -> List[Instance]: body = schemas_pools.ShowPoolRequest(name=pool_name) resp = self._request(f"/api/project/{project_name}/pool/show", body=body.json()) - result: List[Instance] = parse_obj_as(List[Instance], resp.json()) # type: ignore[operator] + result: List[Instance] = parse_obj_as(List[Instance], resp.json()) return result def remove(self, project_name: str, pool_name: str, instance_name: str) -> None: diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py index c5cacdaf0..d645519e7 100644 --- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py @@ -38,7 +38,7 @@ def get_job_provisioning_data(dockerized: bool) -> JobProvisioningData: ssh_port=22, dockerized=dockerized, backend_data=None, - pool_id="", + ssh_proxy=None, ) @@ -113,7 +113,7 @@ async def test_runs_provisioning_job(self, test_db, session: AsyncSession): ) as RunnerClientMock: runner_client_mock = RunnerClientMock.return_value runner_client_mock.healthcheck.return_value = HealthcheckResponse( - service="dstack-runner" + service="dstack-runner", version="0.0.1.dev2" ) await process_running_jobs() RunnerTunnelMock.assert_called_once() @@ -215,7 +215,7 @@ async def test_provisioning_shim(self, test_db, session: AsyncSession): "dstack._internal.server.services.runner.client.ShimClient" ) as ShimClientMock: ShimClientMock.return_value.healthcheck.return_value = HealthcheckResponse( - service="dstack-shim" + service="dstack-shim", version="0.0.1.dev2" ) await process_running_jobs() RunnerTunnelMock.assert_called_once() @@ -256,7 +256,7 @@ async def test_pulling_shim(self, test_db, session: AsyncSession): "dstack._internal.server.services.runner.client.ShimClient" ) as ShimClientMock: RunnerTunnelMock.return_value.healthcheck.return_value = HealthcheckResponse( - service="dstack-runner" + service="dstack-runner", version="0.0.1.dev2" ) await process_running_jobs() RunnerTunnelMock.assert_called_once() diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index fe630eb29..498de90c6 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -230,7 +230,6 @@ async def test_job_with_instance(self, test_db, session: AsyncSession): username="root", ssh_port=22, dockerized=False, - pool_id=str(pool.id), backend_data=None, ssh_proxy=None, ) @@ -264,7 +263,7 @@ async def test_job_with_instance(self, test_db, session: AsyncSession): '{"backend": "datacrunch", "instance_type": {"name": "instance", "resources": ' '{"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": ' '{"size_mib": 102400}, "description": ""}}, "instance_id": ' - '"running_instance.id", "pool_id": "1b2b4c57-5851-487f-b92e-948f946dfa49", ' + '"running_instance.id", "ssh_proxy": null, ' '"hostname": "running_instance.ip", "region": "running_instance.location", ' '"price": 0.1, "username": "root", "ssh_port": 22, "dockerized": true, ' '"backend_data": null}' diff --git a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py b/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py index 76322fc92..89e96beed 100644 --- 
a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py @@ -60,7 +60,6 @@ async def test_transitions_done_jobs_marked_as_removed(self, test_db, session: A username="root", ssh_port=22, dockerized=False, - pool_id="", backend_data=None, ssh_proxy=None, ), diff --git a/src/tests/_internal/server/services/test_pools.py b/src/tests/_internal/server/services/test_pools.py index b14ac142a..77a1832ba 100644 --- a/src/tests/_internal/server/services/test_pools.py +++ b/src/tests/_internal/server/services/test_pools.py @@ -88,7 +88,7 @@ def test_convert_instance(): status=InstanceStatus.PENDING, project_id=str(uuid.uuid4()), pool=None, - job_provisioning_data='{"pool_id":"123", "backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + job_provisioning_data='{"ssh_proxy":null, "backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', offer='{"price":"LOCAL", "price":1.0, "backend":"local", "region":"eu-west-1", "availability":"available","instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', ) @@ -124,7 +124,7 @@ async def test_show_pool(session: AsyncSession, test_db): project=project, pool=pool, status=InstanceStatus.PENDING, - job_provisioning_data='{"pool_id":"123", "backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + job_provisioning_data='{"ssh_proxy":null, "backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', offer='{"price":"LOCAL", "price":1.0, "backend":"local", "region":"eu-west-1", "availability":"available","instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', region="eu-west", price=1, @@ -262,7 +262,7 @@ def create_instance(self, *args, **kwargs): assert instance.deleted == False assert instance.deleted_at is None - # assert instance.job_provisioning_data == '{"backend": "datacrunch", "instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "instance_id": "running_instance.id", "pool_id": "1b2b4c57-5851-487f-b92e-948f946dfa49", "hostname": "running_instance.ip", "region": "running_instance.location", "price": 0.1, "username": "root", "ssh_port": 22, "dockerized": true, "backend_data": null}' + # assert instance.job_provisioning_data 
== '{"backend": "datacrunch", "instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "instance_id": "running_instance.id", "ssh_proxy": null, "hostname": "running_instance.ip", "region": "running_instance.location", "price": 0.1, "username": "root", "ssh_port": 22, "dockerized": true, "backend_data": null}' assert ( instance.offer == '{"backend": "datacrunch", "instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "region": "en", "price": 0.1, "availability": "available"}' From 64f3ef88cadae881adc464dac58f5862f2d076d0 Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Wed, 7 Feb 2024 10:51:38 +0300 Subject: [PATCH 23/47] Improve job statuses --- .../_internal/core/backends/base/compute.py | 8 ++--- .../server/background/tasks/process_pools.py | 34 ++++++++++++------- .../background/tasks/process_running_jobs.py | 1 + 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 86a0568c8..38e3e5e30 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -159,7 +159,7 @@ def get_dstack_shim(build: str) -> List[str]: url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-shim" return [ - f'sudo --connect-timeout 120 --output /usr/local/bin/dstack-shim "{url}"', + f'sudo curl --connect-timeout 60 --max-time 240 --retry 1 --output /usr/local/bin/dstack-shim "{url}"', "sudo chmod +x /usr/local/bin/dstack-shim", ] @@ -235,7 +235,7 @@ def get_docker_commands(authorized_keys: List[str]) -> List[str]: url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-shim" commands += [ - f'curl --connect-timeout 120 --output {runner} "{url}"', + f'curl --connect-timeout 60 --max-time 240 --retry 1 --output {runner} "{url}"', f"chmod +x {runner}", f"{runner} --log-level 6 start --http-port 10999 --temp-dir /tmp/runner --home-dir /root --working-dir /workflow", ] @@ -346,7 +346,7 @@ def get_instance_dstack_shim(build: str) -> List[str]: url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-shim" return [ - f'sudo curl --connect-timeout 120 --output /usr/local/bin/dstack-shim "{url}"', + f'sudo curl --connect-timeout 60 --max-time 240 --retry 1 --output /usr/local/bin/dstack-shim "{url}"', "sudo chmod +x /usr/local/bin/dstack-shim", ] @@ -392,7 +392,7 @@ def get_instance_docker_commands(authorized_keys: List[str]) -> List[str]: url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-runner" commands += [ - f"curl --connect-timeout 120 --output {runner} {url}", + f"curl --connect-timeout 60 --max-time 240 --retry 1 --output {runner} {url}", f"chmod +x {runner}", f"{runner} --log-level 6 start --http-port 10999 --temp-dir /tmp/runner --home-dir /root --working-dir /workflow", ] diff --git a/src/dstack/_internal/server/background/tasks/process_pools.py b/src/dstack/_internal/server/background/tasks/process_pools.py index 988b6fca5..39f6c2a51 100644 --- a/src/dstack/_internal/server/background/tasks/process_pools.py +++ b/src/dstack/_internal/server/background/tasks/process_pools.py @@ -38,6 +38,8 @@ async def process_pools() -> None: InstanceStatus.CREATING, InstanceStatus.STARTING, InstanceStatus.TERMINATING, + InstanceStatus.READY, + InstanceStatus.BUSY, ] ), 
InstanceModel.id.not_in(PROCESSING_POOL_IDS), @@ -51,7 +53,12 @@ async def process_pools() -> None: try: for inst in instances: - if inst.status in (InstanceStatus.CREATING, InstanceStatus.STARTING): + if inst.status in ( + InstanceStatus.CREATING, + InstanceStatus.STARTING, + InstanceStatus.READY, + InstanceStatus.BUSY, + ): await check_shim(inst.id) if inst.status == InstanceStatus.TERMINATING: await terminate(inst.id) @@ -76,9 +83,13 @@ async def check_shim(instance_id: UUID) -> None: logger.info("check instance %s status: %s", instance.name, instance_health) if instance_health: - instance.status = InstanceStatus.READY - await session.commit() - return + if instance.status in (InstanceStatus.CREATING, InstanceStatus.STARTING): + instance.status = InstanceStatus.READY + await session.commit() + else: + if instance.status in (InstanceStatus.READY, InstanceStatus.BUSY): + instance.status = InstanceStatus.FAILED + await session.commit() @runner_ssh_tunnel(ports=[client.REMOTE_SHIM_PORT], retries=1) # type: ignore[misc] @@ -139,15 +150,16 @@ async def terminate_idle_instance() -> None: # TODO: need lock for instance in instances: - if instance.last_job_processed_at is None: - continue + last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc) + if instance.last_job_processed_at is not None: + last_time = instance.last_job_processed_at.replace(tzinfo=datetime.timezone.utc) + + current_time = get_current_datetime().replace(tzinfo=datetime.timezone.utc) delta = datetime.timedelta( seconds=parse_pretty_duration(instance.termination_idle_time) ) - if instance.last_job_processed_at.replace( - tzinfo=datetime.timezone.utc - ) + delta < get_current_datetime().replace(tzinfo=datetime.timezone.utc): + if last_time + delta < current_time: jpd: JobProvisioningData = parse_raw_as( JobProvisioningData, instance.job_provisioning_data ) @@ -159,9 +171,7 @@ async def terminate_idle_instance() -> None: instance.finished_at = get_current_datetime() instance.status = InstanceStatus.TERMINATED - idle_time = get_current_datetime().replace( - tzinfo=datetime.timezone.utc - ) - instance.last_job_processed_at.replace(tzinfo=datetime.timezone.utc) + idle_time = current_time - last_time logger.info( "instance %s terminated by termination policy: idle time %ss", instance.name, diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index ee7dc9e47..34d987d88 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -425,6 +425,7 @@ def _process_running( last_job_state = resp.job_states[-1] job_model.status = last_job_state.state if job_model.status == JobStatus.DONE: + job_model.run.status = JobStatus.DONE delay_job_instance_termination(job_model) logger.info(*job_log("now is %s", job_model, job_model.status.value)) return True From f36ca26d3a2caec0bd9e6146612e99934b41aa9e Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Wed, 7 Feb 2024 13:17:21 +0300 Subject: [PATCH 24/47] fix hardcode --- src/dstack/_internal/core/backends/base/compute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 38e3e5e30..d899708e5 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -232,7 +232,7 @@ def 
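The terminate_idle_instance change above also covers instances that never ran a job, by falling back to created_at. The timing check reduces to a sketch like the following (at this point in the series the idle time is still a pretty-duration string; a later patch in the series stores plain seconds):

    last_time = (instance.last_job_processed_at or instance.created_at).replace(
        tzinfo=datetime.timezone.utc
    )
    idle_allowed = datetime.timedelta(
        seconds=parse_pretty_duration(instance.termination_idle_time)
    )
    now = get_current_datetime().replace(tzinfo=datetime.timezone.utc)
    if last_time + idle_allowed < now:
        ...  # terminate in the backend, then mark the instance TERMINATED and log now - last_time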
get_docker_commands(authorized_keys: List[str]) -> List[str]: # bucket = "dstack-runner-downloads" # url = f'https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64' - url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-shim" + url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-runner" commands += [ f'curl --connect-timeout 60 --max-time 240 --retry 1 --output {runner} "{url}"', From 28b0e19925771d4d319c89e17e1c3d868a76d2e6 Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Wed, 7 Feb 2024 13:57:40 +0300 Subject: [PATCH 25/47] Fix CI --- runner/internal/runner/api/http_test.go | 38 +-------------- runner/internal/runner/api/submit_test.go | 47 +++++++++++++++++++ .../_internal/core/backends/base/compute.py | 44 +++++++---------- src/dstack/_internal/core/models/pools.py | 2 +- .../server/background/tasks/process_pools.py | 2 +- src/dstack/_internal/server/schemas/pools.py | 2 +- src/dstack/_internal/server/services/pools.py | 2 +- src/dstack/api/server/_pools.py | 2 +- 8 files changed, 72 insertions(+), 67 deletions(-) create mode 100644 runner/internal/runner/api/submit_test.go diff --git a/runner/internal/runner/api/http_test.go b/runner/internal/runner/api/http_test.go index dbb8ddf42..df5ef8575 100644 --- a/runner/internal/runner/api/http_test.go +++ b/runner/internal/runner/api/http_test.go @@ -28,7 +28,7 @@ func TestHealthcheck(t *testing.T) { request := httptest.NewRequest("GET", "/api/healthcheck", nil) responseRecorder := httptest.NewRecorder() - server := api.NewShimServer(":12345", DummyRunner{}) + server := api.NewShimServer(":12345", DummyRunner{}, "0.0.1.dev2") f := common.JSONResponseHandler("GET", server.HealthcheckGetHandler) f(responseRecorder, request) @@ -37,43 +37,9 @@ func TestHealthcheck(t *testing.T) { t.Errorf("Want status '%d', got '%d'", 200, responseRecorder.Code) } - expected := "{\"service\":\"dstack-shim\"}" + expected := "{\"service\":\"dstack-shim\",\"version\":\"0.0.1.dev2\"}" if strings.TrimSpace(responseRecorder.Body.String()) != expected { t.Errorf("Want '%s', got '%s'", expected, responseRecorder.Body.String()) } } - -func TestSubmit(t *testing.T) { - - request := httptest.NewRequest("POST", "/api/submit", strings.NewReader("{\"image_name\":\"ubuntu\"}")) - responseRecorder := httptest.NewRecorder() - - dummyRunner := DummyRunner{} - dummyRunner.State = shim.Pending - - server := api.NewShimServer(":12340", &dummyRunner) - - firstSubmitPost := common.JSONResponseHandler("POST", server.SubmitPostHandler) - firstSubmitPost(responseRecorder, request) - - if responseRecorder.Code != 200 { - t.Errorf("Want status '%d', got '%d'", 200, responseRecorder.Code) - } - - t.Logf("%v", responseRecorder.Result()) - - dummyRunner.State = shim.Pulling - - request = httptest.NewRequest("POST", "/api/submit", strings.NewReader("{\"image_name\":\"ubuntu\"}")) - responseRecorder = httptest.NewRecorder() - - secondSubmitPost := common.JSONResponseHandler("POST", server.SubmitPostHandler) - secondSubmitPost(responseRecorder, request) - - t.Logf("%v", responseRecorder.Result()) - - if responseRecorder.Code != 409 { - t.Errorf("Want status '%d', got '%d'", 409, responseRecorder.Code) - } -} diff --git a/runner/internal/runner/api/submit_test.go b/runner/internal/runner/api/submit_test.go new file mode 100644 index 000000000..9ff24efd0 --- /dev/null +++ b/runner/internal/runner/api/submit_test.go @@ -0,0 +1,47 @@ +//go:build !race + +package api + +import ( + "net/http/httptest" + "strings" + 
"testing" + + common "github.com/dstackai/dstack/runner/internal/api" + "github.com/dstackai/dstack/runner/internal/shim" + "github.com/dstackai/dstack/runner/internal/shim/api" +) + +func TestSubmit(t *testing.T) { + + request := httptest.NewRequest("POST", "/api/submit", strings.NewReader("{\"image_name\":\"ubuntu\"}")) + responseRecorder := httptest.NewRecorder() + + dummyRunner := DummyRunner{} + dummyRunner.State = shim.Pending + + server := api.NewShimServer(":12340", &dummyRunner, "0.0.1.dev2") + + firstSubmitPost := common.JSONResponseHandler("POST", server.SubmitPostHandler) + firstSubmitPost(responseRecorder, request) + + if responseRecorder.Code != 200 { + t.Errorf("Want status '%d', got '%d'", 200, responseRecorder.Code) + } + + t.Logf("%v", responseRecorder.Result()) + + dummyRunner.State = shim.Pulling + + request = httptest.NewRequest("POST", "/api/submit", strings.NewReader("{\"image_name\":\"ubuntu\"}")) + responseRecorder = httptest.NewRecorder() + + secondSubmitPost := common.JSONResponseHandler("POST", server.SubmitPostHandler) + secondSubmitPost(responseRecorder, request) + + t.Logf("%v", responseRecorder.Result()) + + if responseRecorder.Code != 409 { + t.Errorf("Want status '%d', got '%d'", 409, responseRecorder.Code) + } +} diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index d899708e5..f6d40a864 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -150,13 +150,11 @@ def get_cloud_config(**config) -> str: def get_dstack_shim(build: str) -> List[str]: - # TODO: use official url - # bucket = "dstack-runner-downloads-stgn" - # if settings.DSTACK_VERSION is not None: - # bucket = "dstack-runner-downloads" - # url =f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64" + bucket = "dstack-runner-downloads-stgn" + if settings.DSTACK_VERSION is not None: + bucket = "dstack-runner-downloads" - url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-shim" + url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64" return [ f'sudo curl --connect-timeout 60 --max-time 240 --retry 1 --output /usr/local/bin/dstack-shim "{url}"', @@ -225,14 +223,12 @@ def get_docker_commands(authorized_keys: List[str]) -> List[str]: runner = "/usr/local/bin/dstack-runner" - # TODO: use official url - # build = get_dstack_runner_version() - # bucket = "dstack-runner-downloads-stgn" - # if settings.DSTACK_VERSION is not None: - # bucket = "dstack-runner-downloads" - # url = f'https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64' + build = get_dstack_runner_version() + bucket = "dstack-runner-downloads-stgn" + if settings.DSTACK_VERSION is not None: + bucket = "dstack-runner-downloads" - url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-runner" + url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64" commands += [ f'curl --connect-timeout 60 --max-time 240 --retry 1 --output {runner} "{url}"', @@ -337,13 +333,11 @@ def get_instance_shim_commands( def get_instance_dstack_shim(build: str) -> List[str]: - # TODO: use official build - # bucket = "dstack-runner-downloads-stgn" - # if settings.DSTACK_VERSION is not None: - # bucket = "dstack-runner-downloads" - # url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64" + bucket = 
"dstack-runner-downloads-stgn" + if settings.DSTACK_VERSION is not None: + bucket = "dstack-runner-downloads" - url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-shim" + url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64" return [ f'sudo curl --connect-timeout 60 --max-time 240 --retry 1 --output /usr/local/bin/dstack-shim "{url}"', @@ -382,14 +376,12 @@ def get_instance_docker_commands(authorized_keys: List[str]) -> List[str]: runner = "/usr/local/bin/dstack-runner" - # TODO: use official build - # build = get_dstack_runner_version() - # bucket = "dstack-runner-downloads-stgn" - # if settings.DSTACK_VERSION is not None: - # bucket = "dstack-runner-downloads" - # url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64" + build = get_dstack_runner_version() + bucket = "dstack-runner-downloads-stgn" + if settings.DSTACK_VERSION is not None: + bucket = "dstack-runner-downloads" - url = "https://da344481-89d9-4f32-bd6a-8e0b47b1eb8c.selstorage.ru/dstack-runner" + url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64" commands += [ f"curl --connect-timeout 60 --max-time 240 --retry 1 --output {runner} {url}", diff --git a/src/dstack/_internal/core/models/pools.py b/src/dstack/_internal/core/models/pools.py index a859f9dbc..68bce22b0 100644 --- a/src/dstack/_internal/core/models/pools.py +++ b/src/dstack/_internal/core/models/pools.py @@ -1,6 +1,6 @@ import datetime -from pydantic import BaseModel # type: ignore[attr-defined] +from pydantic import BaseModel from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import InstanceType diff --git a/src/dstack/_internal/server/background/tasks/process_pools.py b/src/dstack/_internal/server/background/tasks/process_pools.py index 39f6c2a51..d23dd2ead 100644 --- a/src/dstack/_internal/server/background/tasks/process_pools.py +++ b/src/dstack/_internal/server/background/tasks/process_pools.py @@ -3,7 +3,7 @@ from typing import Dict from uuid import UUID -from pydantic import parse_raw_as # type: ignore[attr-defined] +from pydantic import parse_raw_as from sqlalchemy import select from sqlalchemy.orm import joinedload diff --git a/src/dstack/_internal/server/schemas/pools.py b/src/dstack/_internal/server/schemas/pools.py index 8b40e710d..bbf006032 100644 --- a/src/dstack/_internal/server/schemas/pools.py +++ b/src/dstack/_internal/server/schemas/pools.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel # type: ignore[attr-defined] +from pydantic import BaseModel class DeletePoolRequest(BaseModel): # type: ignore[misc] diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index 4d8264948..9c94b6a78 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Sequence import gpuhunt -from pydantic import parse_raw_as # type: ignore[attr-defined] +from pydantic import parse_raw_as from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload diff --git a/src/dstack/api/server/_pools.py b/src/dstack/api/server/_pools.py index 0d9f1248d..cde962a06 100644 --- a/src/dstack/api/server/_pools.py +++ b/src/dstack/api/server/_pools.py @@ -1,6 +1,6 @@ from typing import List, Optional -from pydantic import parse_obj_as # type: ignore[attr-defined] +from pydantic 
import parse_obj_as import dstack._internal.server.schemas.pools as schemas_pools from dstack._internal.core.models.pools import Instance, Pool From 03a80d0ed908f3dc6675aa2b51091a73d1a14f00 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 8 Feb 2024 12:43:23 +0500 Subject: [PATCH 26/47] Fix currency missing in dstack pool show PRICE --- src/dstack/_internal/cli/commands/pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 7e268cab9..4485f9604 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -258,7 +258,7 @@ def print_instance_table(instances: Sequence[Instance]) -> None: instance.backend, instance.instance_type.resources.pretty_format(), f"[{style}]{instance.status.value}[/]", - f"{instance.price:.02f}", + f"${instance.price:.4}", ] table.add_row(*row) From 447ed533caaea66a8898a4a39a59053f1cd826ce Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Thu, 8 Feb 2024 11:11:30 +0300 Subject: [PATCH 27/47] Fix ssh keys. Fix review --- .pre-commit-config.yaml | 5 +++ src/dstack/_internal/cli/commands/pool.py | 16 ++++++-- .../_internal/core/backends/base/compute.py | 2 +- src/dstack/_internal/core/models/pools.py | 7 +++- .../_internal/core/services/ssh/ports.py | 3 ++ .../server/background/tasks/process_pools.py | 8 ++-- .../tasks/process_submitted_jobs.py | 2 +- .../5395b4ae6c3b_add_pools_fix_optional.py | 41 +++++++++++++++++++ src/dstack/_internal/server/models.py | 2 +- src/dstack/_internal/server/routers/pools.py | 4 +- src/dstack/_internal/server/routers/runs.py | 1 + src/dstack/_internal/server/schemas/pools.py | 5 ++- src/dstack/_internal/server/schemas/runs.py | 2 + src/dstack/_internal/server/services/pools.py | 27 +++++++++--- src/dstack/_internal/server/services/runs.py | 9 ++-- src/dstack/api/_public/runs.py | 9 +++- src/dstack/api/server/_pools.py | 6 ++- src/dstack/api/server/_runs.py | 10 ++++- .../_internal/server/routers/test_runs.py | 4 +- 19 files changed, 133 insertions(+), 30 deletions(-) create mode 100644 src/dstack/_internal/server/migrations/versions/5395b4ae6c3b_add_pools_fix_optional.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d49ce9a07..08bd08cce 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,3 +14,8 @@ repos: files: '.*pools?\.py' exclude: 'versions|src/tests' additional_dependencies: [types-PyYAML, types-requests, pydantic<2, sqlalchemy] + # - repo: https://github.com/golangci/golangci-lint + # rev: v1.56.0 + # hooks: + # - id: golangci-lint-full + # - id: golangci-lint diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 7e268cab9..c698cd77c 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -10,6 +10,7 @@ register_profile_args, ) from dstack._internal.cli.utils.common import confirm_ask, console +from dstack._internal.core.backends.base.compute import SSHKeys from dstack._internal.core.errors import CLIError, ServerClientError from dstack._internal.core.models.instances import ( InstanceAvailability, @@ -123,6 +124,11 @@ def _register(self) -> None: remove_parser.add_argument( "--name", dest="instance_name", help="The name of the instance", required=True ) + remove_parser.add_argument( + "--force", + action="store_true", + help="The name of the instance", + ) remove_parser.set_defaults(subfunc=self._remove) # pool 
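Patch 26 changes the PRICE column to f"${instance.price:.4}". As a quick worked example (prices are hypothetical), the bare .4 spec gives four significant digits in general format, whereas the previous .02f gave two decimal places:

    >>> price = 0.1
    >>> f"{price:.02f}", f"${price:.4}"
    ('0.10', '$0.1')
    >>> price = 12.3456
    >>> f"{price:.02f}", f"${price:.4}"
    ('12.35', '$12.35')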
set-default @@ -147,7 +153,9 @@ def _delete(self, args: argparse.Namespace) -> None: self.api.client.pool.delete(self.api.project, args.pool_name, args.force) def _remove(self, args: argparse.Namespace) -> None: - self.api.client.pool.remove(self.api.project, args.pool_name, args.instance_name) + self.api.client.pool.remove( + self.api.project, args.pool_name, args.instance_name, args.force + ) def _set_default(self, args: argparse.Namespace) -> None: result = self.api.client.pool.set_default(self.api.project, args.pool_name) @@ -181,7 +189,7 @@ def _add(self, args: argparse.Namespace) -> None: # TODO: add full support termination_policy_idle = 5 * 60 # 5 minutes by default termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE - profile.termination_idle_time = str(termination_policy_idle) + profile.termination_idle_time = str(termination_policy_idle) # TODO: fix serialization profile.termination_policy = termination_policy # Add remote instance @@ -210,8 +218,10 @@ def _add(self, args: argparse.Namespace) -> None: return try: + user_pub_key = Path("~/.dstack/ssh/id_rsa.pub").expanduser().read_text().strip() + pub_key = SSHKeys(public=user_pub_key) with console.status("Creating instance..."): - self.api.runs.create_instance(pool_name, profile, requirements) + self.api.runs.create_instance(pool_name, profile, requirements, pub_key) except ServerClientError as e: raise CLIError(e.msg) diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index f6d40a864..bf74bd62b 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -27,7 +27,7 @@ class SSHKeys(BaseModel): public: str - private: Optional[str] + private: Optional[str] = None class DockerConfig(BaseModel): diff --git a/src/dstack/_internal/core/models/pools.py b/src/dstack/_internal/core/models/pools.py index 68bce22b0..e53e5b848 100644 --- a/src/dstack/_internal/core/models/pools.py +++ b/src/dstack/_internal/core/models/pools.py @@ -1,10 +1,11 @@ import datetime +from typing import Optional from pydantic import BaseModel from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import InstanceType -from dstack._internal.core.models.runs import InstanceStatus +from dstack._internal.core.models.runs import InstanceStatus, JobStatus class Pool(BaseModel): # type: ignore[misc] @@ -18,7 +19,9 @@ class Pool(BaseModel): # type: ignore[misc] class Instance(BaseModel): # type: ignore[misc] backend: BackendType instance_type: InstanceType - instance_id: str + instance_id: str # TODO: rename to name + job_name: Optional[str] = None + job_status: Optional[JobStatus] = None hostname: str status: InstanceStatus price: float diff --git a/src/dstack/_internal/core/services/ssh/ports.py b/src/dstack/_internal/core/services/ssh/ports.py index b65d6f163..3d81d0f11 100644 --- a/src/dstack/_internal/core/services/ssh/ports.py +++ b/src/dstack/_internal/core/services/ssh/ports.py @@ -66,6 +66,9 @@ def dict(self) -> Dict[int, int]: d[remote_port] = self.sockets[remote_port].getsockname()[1] return d + def __str__(self) -> str: + return f"" + @staticmethod def _listen(port: int) -> Optional[socket.socket]: try: diff --git a/src/dstack/_internal/server/background/tasks/process_pools.py b/src/dstack/_internal/server/background/tasks/process_pools.py index d23dd2ead..4e3faa7a8 100644 --- a/src/dstack/_internal/server/background/tasks/process_pools.py +++ 
b/src/dstack/_internal/server/background/tasks/process_pools.py @@ -20,7 +20,7 @@ from dstack._internal.server.services.runner import client from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel from dstack._internal.server.utils.common import run_async -from dstack._internal.utils.common import get_current_datetime, parse_pretty_duration +from dstack._internal.utils.common import get_current_datetime from dstack._internal.utils.logging import get_logger PENDING_JOB_RETRY_INTERVAL = timedelta(seconds=60) @@ -154,11 +154,11 @@ async def terminate_idle_instance() -> None: if instance.last_job_processed_at is not None: last_time = instance.last_job_processed_at.replace(tzinfo=datetime.timezone.utc) + idle_seconds = instance.termination_idle_time + delta = datetime.timedelta(seconds=idle_seconds) + current_time = get_current_datetime().replace(tzinfo=datetime.timezone.utc) - delta = datetime.timedelta( - seconds=parse_pretty_duration(instance.termination_idle_time) - ) if last_time + delta < current_time: jpd: JobProvisioningData = parse_raw_as( JobProvisioningData, instance.job_provisioning_data diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index d61411e40..51c4f8b7b 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -183,7 +183,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): job_provisioning_data=job_provisioning_data.json(), offer=offer.json(), termination_policy=profile.termination_policy, - termination_idle_time=str(profile.termination_idle_time), + termination_idle_time=300, # TODO: fix deserialize job=job_model, backend=offer.backend, price=offer.price, diff --git a/src/dstack/_internal/server/migrations/versions/5395b4ae6c3b_add_pools_fix_optional.py b/src/dstack/_internal/server/migrations/versions/5395b4ae6c3b_add_pools_fix_optional.py new file mode 100644 index 000000000..87a48e367 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/5395b4ae6c3b_add_pools_fix_optional.py @@ -0,0 +1,41 @@ +"""add pools fix optional + +Revision ID: 5395b4ae6c3b +Revises: b55bd09bf186 +Create Date: 2024-02-08 11:08:31.426042 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "5395b4ae6c3b" +down_revision = "b55bd09bf186" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.alter_column( + "termination_idle_time", + existing_type=sa.VARCHAR(length=50), + type_=sa.Integer(), + nullable=False, + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.alter_column( + "termination_idle_time", + existing_type=sa.Integer(), + type_=sa.VARCHAR(length=50), + nullable=True, + ) + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index d0affc2db..c9e14cc2a 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -285,7 +285,7 @@ class InstanceModel(BaseModel): finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) termination_policy: Mapped[Optional[TerminationPolicy]] = mapped_column(String(50)) - termination_idle_time: Mapped[Optional[str]] = mapped_column(String(50)) + termination_idle_time: Mapped[int] = mapped_column(Integer) backend: Mapped[BackendType] = mapped_column(Enum(BackendType)) backend_data: Mapped[Optional[str]] = mapped_column(String(4000)) diff --git a/src/dstack/_internal/server/routers/pools.py b/src/dstack/_internal/server/routers/pools.py index 82186c9fb..273586a85 100644 --- a/src/dstack/_internal/server/routers/pools.py +++ b/src/dstack/_internal/server/routers/pools.py @@ -35,7 +35,9 @@ async def remove_instance( user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), ) -> None: _, project_model = user_project - await pools.remove_instance(session, project_model, body.pool_name, body.instance_name) + await pools.remove_instance( + session, project_model, body.pool_name, body.instance_name, body.force + ) @router.post("/set-default") # type: ignore[misc] diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index 556696ec6..90902f515 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -98,6 +98,7 @@ async def create_instance( session=session, project=project, user=user, + ssh_key=body.ssh_key, pool_name=body.pool_name, instance_name=instance_name, profile=body.profile, diff --git a/src/dstack/_internal/server/schemas/pools.py b/src/dstack/_internal/server/schemas/pools.py index bbf006032..6ae3b2d2f 100644 --- a/src/dstack/_internal/server/schemas/pools.py +++ b/src/dstack/_internal/server/schemas/pools.py @@ -1,3 +1,5 @@ +from typing import Optional + from pydantic import BaseModel @@ -15,8 +17,9 @@ class ShowPoolRequest(BaseModel): # type: ignore[misc] class RemoveInstanceRequest(BaseModel): # type: ignore[misc] - pool_name: str + pool_name: Optional[str] instance_name: str + force: bool = False class SetDefaultPoolRequest(BaseModel): # type: ignore[misc] diff --git a/src/dstack/_internal/server/schemas/runs.py b/src/dstack/_internal/server/schemas/runs.py index c3745fe39..58c458cd0 100644 --- a/src/dstack/_internal/server/schemas/runs.py +++ b/src/dstack/_internal/server/schemas/runs.py @@ -2,6 +2,7 @@ from pydantic import BaseModel +from dstack._internal.core.backends.base.compute import SSHKeys from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import Requirements, RunSpec @@ -29,6 +30,7 @@ class CreateInstanceRequest(BaseModel): pool_name: str profile: Profile requirements: Requirements + ssh_key: SSHKeys class AddRemoteInstanceRequest(BaseModel): diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index 9c94b6a78..2b8f15940 100644 --- a/src/dstack/_internal/server/services/pools.py +++ 
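The schema and router changes above expose a force flag for instance removal. A hedged usage sketch from the client side, assuming a configured client of type dstack.api.server.APIClient (the keyword names match PoolAPIClient.remove in this patch):

    client.pool.remove(
        project_name="main",          # example values
        pool_name="default-pool",
        instance_name="idle-gpu-1",
        force=True,                   # terminate even if a job is still attached to the instance
    )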
b/src/dstack/_internal/server/services/pools.py @@ -153,7 +153,11 @@ async def set_default_pool(session: AsyncSession, project: ProjectModel, pool_na async def remove_instance( - session: AsyncSession, project: ProjectModel, pool_name: str, instance_name: str + session: AsyncSession, + project: ProjectModel, + pool_name: Optional[str], + instance_name: str, + force: bool, ) -> None: pool = ( await session.scalars( @@ -163,14 +167,23 @@ async def remove_instance( PoolModel.deleted == False, ) ) - ).one() + ).one_or_none() + + if pool is None: + logger.warning("Couldn't find pool") + return + + # TODO: need lock terminated = False for instance in pool.instances: if instance.name == instance_name: - instance.status = InstanceStatus.TERMINATING - terminated = True + if force or instance.job_id is None: + instance.status = InstanceStatus.TERMINATING + terminated = True + if not terminated: logger.warning("Couldn't find instance to terminate") + await session.commit() @@ -226,6 +239,10 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance: status=instance_model.status, price=offer.price, ) + if instance_model.job is not None: + instance.job_name = instance_model.job.name + instance.job_status = instance_model.job.status + return instance @@ -347,7 +364,7 @@ async def add_remote( offer=offer.json(), resource_spec_data=resources.json(), termination_policy=profile.termination_policy, - termination_idle_time=str(profile.termination_idle_time), + termination_idle_time=300, # TODO: fix deserialize ) session.add(im) await session.commit() diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 6fdd524a4..53e2b02e4 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -176,6 +176,7 @@ async def create_instance( session: AsyncSession, project: ProjectModel, user: UserModel, + ssh_key: SSHKeys, pool_name: str, instance_name: str, profile: Profile, @@ -188,15 +189,17 @@ async def create_instance( if not offers: return - ssh_key = SSHKeys( + user_ssh_key = ssh_key + project_ssh_key = SSHKeys( public=project.ssh_public_key.strip(), private=project.ssh_private_key.strip(), ) + image = parse_image_name(get_default_image(get_default_python_verison())) instance_config = InstanceConfiguration( instance_name=instance_name, pool_name=pool_name, - ssh_keys=[ssh_key], + ssh_keys=[user_ssh_key, project_ssh_key], job_docker_config=DockerConfig( image=image, registry_auth=None, @@ -278,7 +281,7 @@ async def create_instance( offer=cast(InstanceOfferWithAvailability, instance_offer).json(), resource_spec_data=requirements.resources.json(), termination_policy=profile.termination_policy, - termination_idle_time=str(profile.termination_idle_time), + termination_idle_time=300, # TODO: fix deserialize ) session.add(im) await session.commit() diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 6b716c8b7..fb525aafb 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -12,6 +12,7 @@ from websocket import WebSocketApp import dstack.api as api +from dstack._internal.core.backends.base.compute import SSHKeys from dstack._internal.core.errors import ConfigurationError, ResourceNotExistsError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import AnyRunConfiguration @@ -367,8 +368,12 @@ def get_offers( ) -> Tuple[str, List[InstanceOfferWithAvailability]]: return 
self._api_client.runs.get_offers(self._project, profile, requirements) - def create_instance(self, pool_name: str, profile: Profile, requirements: Requirements): - self._api_client.runs.create_instance(self._project, pool_name, profile, requirements) + def create_instance( + self, pool_name: str, profile: Profile, requirements: Requirements, ssh_key: SSHKeys + ): + self._api_client.runs.create_instance( + self._project, pool_name, profile, requirements, ssh_key + ) def get_plan( self, diff --git a/src/dstack/api/server/_pools.py b/src/dstack/api/server/_pools.py index cde962a06..4f9b9c455 100644 --- a/src/dstack/api/server/_pools.py +++ b/src/dstack/api/server/_pools.py @@ -30,9 +30,11 @@ def show(self, project_name: str, pool_name: str) -> List[Instance]: result: List[Instance] = parse_obj_as(List[Instance], resp.json()) return result - def remove(self, project_name: str, pool_name: str, instance_name: str) -> None: + def remove( + self, project_name: str, pool_name: Optional[str], instance_name: str, force: bool + ) -> None: body = schemas_pools.RemoveInstanceRequest( - pool_name=pool_name, instance_name=instance_name + pool_name=pool_name, instance_name=instance_name, force=force ) self._request(f"/api/project/{project_name}/pool/remove", body=body.json()) diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index bb771b38c..404f47a60 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -2,6 +2,7 @@ from pydantic import parse_obj_as +from dstack._internal.core.backends.base.compute import SSHKeys from dstack._internal.core.models.instances import InstanceOfferWithAvailability from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.runs import Requirements, Run, RunPlan, RunSpec @@ -37,10 +38,15 @@ def get_offers( return parse_obj_as(Tuple[str, List[InstanceOfferWithAvailability]], resp.json()) def create_instance( - self, project_name: str, pool_name: str, profile: Profile, requirements: Requirements + self, + project_name: str, + pool_name: str, + profile: Profile, + requirements: Requirements, + ssh_key: SSHKeys, ): body = CreateInstanceRequest( - pool_name=pool_name, profile=profile, requirements=requirements + pool_name=pool_name, profile=profile, requirements=requirements, ssh_key=ssh_key ) self._request(f"/api/project/{project_name}/runs/create_instance", body=body.json()) diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 2b01570c6..30abbb495 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -80,7 +80,7 @@ def get_dev_env_run_plan_dict( "pool_name": DEFAULT_POOL_NAME, "retry_policy": {"limit": None, "retry": False}, "spot_policy": "spot", - "termination_idle_time": None, + "termination_idle_time": 300, "termination_policy": None, }, "repo_code_hash": None, @@ -192,7 +192,7 @@ def get_dev_env_run_dict( "pool_name": DEFAULT_POOL_NAME, "retry_policy": {"limit": None, "retry": False}, "spot_policy": "spot", - "termination_idle_time": None, + "termination_idle_time": 300, "termination_policy": None, }, "repo_code_hash": None, From 889555705c6cc073b64b7f7ab470a7f76e99d605 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Thu, 8 Feb 2024 13:12:43 +0100 Subject: [PATCH 28/47] Do not require repo in `dstack add` --- src/dstack/_internal/cli/commands/pool.py | 5 +---- src/dstack/_internal/core/backends/gcp/compute.py | 1 + 2 files changed, 2 insertions(+), 4 
deletions(-) diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 8c1b54d4e..0a4c2eea1 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -20,7 +20,6 @@ from dstack._internal.core.models.profiles import Profile, SpotPolicy, TerminationPolicy from dstack._internal.core.models.resources import DEFAULT_CPU_COUNT, DEFAULT_MEMORY_SIZE from dstack._internal.core.models.runs import Requirements -from dstack._internal.core.services.configs import ConfigManager from dstack._internal.utils.common import pretty_date from dstack._internal.utils.logging import get_logger from dstack.api._public.resources import Resources @@ -206,9 +205,6 @@ def _add(self, args: argparse.Namespace) -> None: console.print(f"[error]Failed to add remote instance {args.instance_name!r}[/]") return - repo = self.api.repos.load(Path.cwd()) - self.api.ssh_identity_file = ConfigManager().get_repo_config(repo.repo_dir).ssh_key_path - with console.status("Getting instances..."): pool_name, offers = self.api.runs.get_offers(profile, requirements) @@ -218,6 +214,7 @@ def _add(self, args: argparse.Namespace) -> None: return try: + # TODO(egor-s): user pub key must be added during the `run`, not `pool add` user_pub_key = Path("~/.dstack/ssh/id_rsa.pub").expanduser().read_text().strip() pub_key = SSHKeys(public=user_pub_key) with console.status("Creating instance..."): diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index bb19b4ada..3d421f160 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -116,6 +116,7 @@ def create_instance( user_data=get_instance_user_data( authorized_keys=instance_config.get_public_keys(), ), + authorized_keys=instance_config.get_public_keys(), labels={ "owner": "dstack", "dstack_project": project_id, From f9fb2cd2eea819913045027a5a57d925d4d9d7ff Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Thu, 8 Feb 2024 13:56:01 +0100 Subject: [PATCH 29/47] Show provisioned instance in `dstack pool add` --- src/dstack/_internal/cli/commands/pool.py | 9 +++++---- src/dstack/_internal/server/routers/runs.py | 10 +++++++--- src/dstack/_internal/server/services/runs.py | 6 ++++-- src/dstack/api/_public/runs.py | 5 +++-- src/dstack/api/server/_runs.py | 6 ++++-- 5 files changed, 23 insertions(+), 13 deletions(-) diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 0a4c2eea1..4a8a3bd6b 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -213,14 +213,15 @@ def _add(self, args: argparse.Namespace) -> None: console.print("\nExiting...") return + # TODO(egor-s): user pub key must be added during the `run`, not `pool add` + user_pub_key = Path("~/.dstack/ssh/id_rsa.pub").expanduser().read_text().strip() + pub_key = SSHKeys(public=user_pub_key) try: - # TODO(egor-s): user pub key must be added during the `run`, not `pool add` - user_pub_key = Path("~/.dstack/ssh/id_rsa.pub").expanduser().read_text().strip() - pub_key = SSHKeys(public=user_pub_key) with console.status("Creating instance..."): - self.api.runs.create_instance(pool_name, profile, requirements, pub_key) + instance = self.api.runs.create_instance(pool_name, profile, requirements, pub_key) except ServerClientError as e: raise CLIError(e.msg) + print_instance_table([instance]) def _command(self, args: argparse.Namespace) -> 
None: super()._command(args) diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index 90902f515..29d47920b 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -3,8 +3,9 @@ from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.core.errors import ResourceNotExistsError +from dstack._internal.core.errors import ResourceNotExistsError, ServerClientError from dstack._internal.core.models.instances import InstanceOfferWithAvailability +from dstack._internal.core.models.pools import Instance from dstack._internal.core.models.runs import Run, RunPlan from dstack._internal.server.db import get_session from dstack._internal.server.models import ProjectModel, UserModel @@ -89,12 +90,12 @@ async def create_instance( body: CreateInstanceRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -): +) -> Instance: user, project = user_project instance_name = await generate_instance_name( session=session, project=project, pool_name=body.pool_name ) - await runs.create_instance( + instance = await runs.create_instance( session=session, project=project, user=user, @@ -104,6 +105,9 @@ async def create_instance( profile=body.profile, requirements=body.requirements, ) + if instance is None: + raise ServerClientError(msg="Failed to create an instance") + return instance @project_router.post("/get_plan") diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 53e2b02e4..94d59babd 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -23,6 +23,7 @@ InstanceOfferWithAvailability, LaunchedInstanceInfo, ) +from dstack._internal.core.models.pools import Instance from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, CreationPolicy, Profile from dstack._internal.core.models.runs import ( InstanceStatus, @@ -65,6 +66,7 @@ filter_pool_instances, get_or_create_default_pool_by_name, get_pool_instances, + instance_model_to_instance, ) from dstack._internal.server.services.projects import list_project_models, list_user_project_models from dstack._internal.server.utils.common import run_async @@ -181,7 +183,7 @@ async def create_instance( instance_name: str, profile: Profile, requirements: Requirements, -) -> Optional[InstanceModel]: +) -> Optional[Instance]: offers = await get_run_plan_by_requirements( project, profile, requirements, exclude_not_available=True ) @@ -286,7 +288,7 @@ async def create_instance( session.add(im) await session.commit() - return im + return instance_model_to_instance(im) async def get_run_plan( diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index fb525aafb..beb673589 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -17,6 +17,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import AnyRunConfiguration from dstack._internal.core.models.instances import InstanceOfferWithAvailability +from dstack._internal.core.models.pools import Instance from dstack._internal.core.models.profiles import ( CreationPolicy, Profile, @@ -370,8 +371,8 @@ def get_offers( def create_instance( self, pool_name: str, profile: Profile, requirements: Requirements, ssh_key: SSHKeys - ): - self._api_client.runs.create_instance( + ) -> 
Instance: + return self._api_client.runs.create_instance( self._project, pool_name, profile, requirements, ssh_key ) diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index 404f47a60..e4ca956ad 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -4,6 +4,7 @@ from dstack._internal.core.backends.base.compute import SSHKeys from dstack._internal.core.models.instances import InstanceOfferWithAvailability +from dstack._internal.core.models.pools import Instance from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.runs import Requirements, Run, RunPlan, RunSpec from dstack._internal.server.schemas.runs import ( @@ -44,11 +45,12 @@ def create_instance( profile: Profile, requirements: Requirements, ssh_key: SSHKeys, - ): + ) -> Instance: body = CreateInstanceRequest( pool_name=pool_name, profile=profile, requirements=requirements, ssh_key=ssh_key ) - self._request(f"/api/project/{project_name}/runs/create_instance", body=body.json()) + resp = self._request(f"/api/project/{project_name}/runs/create_instance", body=body.json()) + return parse_obj_as(Instance, resp.json()) def get_plan(self, project_name: str, run_spec: RunSpec) -> RunPlan: body = GetRunPlanRequest(run_spec=run_spec) From 9f01bca30c3520e449155fc8232f7485297bfebd Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Thu, 8 Feb 2024 15:16:52 +0100 Subject: [PATCH 30/47] Validate `pool add` resources args --- src/dstack/_internal/cli/commands/pool.py | 25 +++++++++---- src/dstack/_internal/cli/services/args.py | 35 +++++++++++++++++++ .../cli/services/configurators/run.py | 26 ++------------ 3 files changed, 56 insertions(+), 30 deletions(-) create mode 100644 src/dstack/_internal/cli/services/args.py diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 4a8a3bd6b..2fcc15b2b 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -5,6 +5,7 @@ from rich.table import Table from dstack._internal.cli.commands import APIBaseCommand +from dstack._internal.cli.services.args import cpu_spec, disk_spec, gpu_spec, memory_spec from dstack._internal.cli.services.configurators.profile import ( apply_profile_args, register_profile_args, @@ -147,19 +148,26 @@ def _list(self, args: argparse.Namespace) -> None: def _create(self, args: argparse.Namespace) -> None: self.api.client.pool.create(self.api.project, args.pool_name) + console.print(f"Pool {args.pool_name!r} created") def _delete(self, args: argparse.Namespace) -> None: - self.api.client.pool.delete(self.api.project, args.pool_name, args.force) + # TODO(egor-s): ask for confirmation + with console.status("Removing pool..."): + self.api.client.pool.delete(self.api.project, args.pool_name, args.force) + console.print(f"Pool {args.pool_name!r} removed") def _remove(self, args: argparse.Namespace) -> None: - self.api.client.pool.remove( - self.api.project, args.pool_name, args.instance_name, args.force - ) + # TODO(egor-s): ask for confirmation + with console.status("Removing instance..."): + self.api.client.pool.remove( + self.api.project, args.pool_name, args.instance_name, args.force + ) + console.print(f"Instance {args.instance_name!r} removed") def _set_default(self, args: argparse.Namespace) -> None: result = self.api.client.pool.set_default(self.api.project, args.pool_name) if not result: - console.print(f"[error]Failed to set default pool {args.pool_name!r}[/]") + console.print(f"Failed to set 
default pool {args.pool_name!r}", style="error") def _show(self, args: argparse.Namespace) -> None: instances = self.api.client.pool.show(self.api.project, args.pool_name) @@ -203,6 +211,7 @@ def _add(self, args: argparse.Namespace) -> None: ) if not result: console.print(f"[error]Failed to add remote instance {args.instance_name!r}[/]") + # TODO(egor-s): print on success return with console.status("Getting instances..."): @@ -367,6 +376,7 @@ def register_resource_args(parser: argparse.ArgumentParser) -> None: dest="cpu", metavar="SPEC", default=DEFAULT_CPU_COUNT, + type=cpu_spec, ) resources_group.add_argument( @@ -376,6 +386,7 @@ def register_resource_args(parser: argparse.ArgumentParser) -> None: dest="memory", metavar="SIZE", default=DEFAULT_MEMORY_SIZE, + type=memory_spec, ) resources_group.add_argument( @@ -393,12 +404,14 @@ def register_resource_args(parser: argparse.ArgumentParser) -> None: dest="gpu", default=None, metavar="SPEC", + type=gpu_spec, ) resources_group.add_argument( "--disk", - help="Request the size of disk for the run. Example [code]--disk 100GB[/].", + help="Request the size of disk for the run. Example [code]--disk 100GB..[/].", dest="disk", metavar="SIZE", default=None, + type=disk_spec, ) diff --git a/src/dstack/_internal/cli/services/args.py b/src/dstack/_internal/cli/services/args.py new file mode 100644 index 000000000..24a4663dc --- /dev/null +++ b/src/dstack/_internal/cli/services/args.py @@ -0,0 +1,35 @@ +import re +from typing import Dict, Tuple + +from pydantic import parse_obj_as + +from dstack._internal.core.models import resources as resources +from dstack._internal.core.models.configurations import PortMapping + + +def gpu_spec(v: str) -> Dict: + return resources.GPUSpec.parse(v) + + +def env_var(v: str) -> Tuple[str, str]: + r = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)=(.*)$", v) + if r is None: + raise ValueError(v) + key, value = r.groups() + return key, value + + +def port_mapping(v: str) -> PortMapping: + return PortMapping.parse(v) + + +def cpu_spec(v: str) -> resources.Range[int]: + return parse_obj_as(resources.Range[int], v) + + +def memory_spec(v: str) -> resources.Range[resources.Memory]: + return parse_obj_as(resources.Range[resources.Memory], v) + + +def disk_spec(v: str) -> resources.DiskSpec: + return parse_obj_as(resources.DiskSpec, v) diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index 155e85cce..75b26aa44 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -1,11 +1,9 @@ import argparse -import re import subprocess -from typing import Dict, List, Optional, Tuple, Type - -from pydantic import parse_obj_as +from typing import Dict, List, Optional, Type import dstack._internal.core.models.resources as resources +from dstack._internal.cli.services.args import disk_spec, env_var, gpu_spec, port_mapping from dstack._internal.cli.utils.common import console from dstack._internal.core.errors import ConfigurationError from dstack._internal.core.models.configurations import ( @@ -131,26 +129,6 @@ def apply(cls, args: argparse.Namespace, unknown: List[str], conf: ServiceConfig cls.interpolate_run_args(conf.commands, unknown) -def env_var(v: str) -> Tuple[str, str]: - r = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)=(.*)$", v) - if r is None: - raise ValueError(v) - key, value = r.groups() - return key, value - - -def gpu_spec(v: str) -> Dict: - return resources.GPUSpec.parse(v) - - -def disk_spec(v: 
str) -> resources.DiskSpec: - return parse_obj_as(resources.DiskSpec, v) - - -def port_mapping(v: str) -> PortMapping: - return PortMapping.parse(v) - - def merge_ports(conf: List[PortMapping], args: List[PortMapping]) -> Dict[int, PortMapping]: unique_ports_constraint([pm.container_port for pm in conf]) unique_ports_constraint([pm.container_port for pm in args]) From e79bfb64d2becc471718bb7f77f763321feda0c2 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Thu, 8 Feb 2024 17:01:48 +0100 Subject: [PATCH 31/47] Print pool name in run plan --- src/dstack/_internal/cli/utils/run.py | 1 + src/dstack/_internal/server/services/runs.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index cac447fdb..ac7802f49 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -42,6 +42,7 @@ def th(s: str) -> str: props.add_row(th("Configuration"), run_plan.run_spec.configuration_path) props.add_row(th("Project"), run_plan.project_name) props.add_row(th("User"), run_plan.user) + props.add_row(th("Pool name"), run_plan.run_spec.profile.pool_name) props.add_row(th("Min resources"), pretty_req) props.add_row(th("Max price"), max_price) props.add_row(th("Max duration"), max_duration) diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 94d59babd..4c59ef30c 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -352,8 +352,9 @@ async def get_run_plan( ) job_plans.append(job_plan) - run_spec.profile.termination_idle_time = None + run_spec.profile.termination_idle_time = None # TODO(egor-s) explain why? + run_spec.profile.pool_name = pool_name # write pool name back for the client run_spec.run_name = run_name # restore run_name run_plan = RunPlan( project_name=project.name, user=user.name, run_spec=run_spec, job_plans=job_plans From b831a6079c2f16ae4c6fae19d5211736df28f494 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Thu, 8 Feb 2024 17:15:32 +0100 Subject: [PATCH 32/47] Add TODOs --- src/dstack/_internal/server/services/runs.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 4c59ef30c..7276f3c58 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -317,6 +317,7 @@ async def get_run_plan( pool_instances, profile, run_spec.configuration.resources ): pool_offers.append(pydantic.parse_raw_as(InstanceOfferWithAvailability, instance.offer)) + # TODO(egor-s): assign availability based on the instance status backends = await backends_services.get_project_backends(project=project) if profile.backends is not None: @@ -344,6 +345,7 @@ async def get_run_plan( offer.backend = backend.TYPE job_offers.extend(offer for _, offer in offers) + # TODO(egor-s): merge job_offers and pool_offers based on (availability, use/create, price) job_plan = JobPlan( job_spec=job.job_spec, offers=job_offers[:50], From c35e63b08a498762ba20a0fd4e633bf1d6a14f4c Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Thu, 8 Feb 2024 17:37:46 +0100 Subject: [PATCH 33/47] Add InstanceAvailability for pool instances --- src/dstack/_internal/cli/utils/run.py | 2 ++ src/dstack/_internal/core/models/instances.py | 9 +++++++++ .../_internal/server/services/backends/__init__.py | 10 +++------- src/dstack/_internal/server/services/runs.py | 9 
+++++++-- 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index ac7802f49..b3c27191e 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -68,6 +68,8 @@ def th(s: str) -> str: if offer.availability in { InstanceAvailability.NOT_AVAILABLE, InstanceAvailability.NO_QUOTA, + InstanceAvailability.READY, + InstanceAvailability.BUSY, }: availability = offer.availability.value.replace("_", " ").title() offers.add_row( diff --git a/src/dstack/_internal/core/models/instances.py b/src/dstack/_internal/core/models/instances.py index cd45ded11..99b2129dd 100644 --- a/src/dstack/_internal/core/models/instances.py +++ b/src/dstack/_internal/core/models/instances.py @@ -76,6 +76,15 @@ class InstanceAvailability(Enum): AVAILABLE = "available" NOT_AVAILABLE = "not_available" NO_QUOTA = "no_quota" + READY = "ready" + BUSY = "busy" + + def is_available(self) -> bool: + return self in { + InstanceAvailability.UNKNOWN, + InstanceAvailability.AVAILABLE, + InstanceAvailability.READY, + } class InstanceOffer(BaseModel): diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py index 8fc3593c3..3a05e262c 100644 --- a/src/dstack/_internal/server/services/backends/__init__.py +++ b/src/dstack/_internal/server/services/backends/__init__.py @@ -20,7 +20,6 @@ ) from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( - InstanceAvailability, InstanceOfferWithAvailability, ) from dstack._internal.core.models.runs import Requirements @@ -292,9 +291,6 @@ async def get_project_backend_model_by_type( return None -_NOT_AVAILABLE = {InstanceAvailability.NOT_AVAILABLE, InstanceAvailability.NO_QUOTA} - - async def get_instance_offers( backends: List[Backend], requirements: Requirements, exclude_not_available: bool = False ) -> List[Tuple[Backend, InstanceOfferWithAvailability]]: @@ -306,11 +302,11 @@ async def get_instance_offers( [ (backend, offer) for offer in backend_offers - if not exclude_not_available or offer.availability not in _NOT_AVAILABLE + if not exclude_not_available or offer.availability.is_available() ] for backend, backend_offers in zip(backends, await asyncio.gather(*tasks)) ] # Merge preserving order for every backend offers = heapq.merge(*offers_by_backend, key=lambda i: i[1].price) - # Put NOT_AVAILABLE and NO_QUOTA instances at the end, do not sort by price - return sorted(offers, key=lambda i: i[1].availability in _NOT_AVAILABLE) + # Put NOT_AVAILABLE, NO_QUOTA, and BUSY instances at the end, do not sort by price + return sorted(offers, key=lambda i: not i[1].availability.is_available()) diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 7276f3c58..6408e7b97 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -20,6 +20,7 @@ ) from dstack._internal.core.errors import BackendError, RepoDoesNotExistError, ServerClientError from dstack._internal.core.models.instances import ( + InstanceAvailability, InstanceOfferWithAvailability, LaunchedInstanceInfo, ) @@ -316,8 +317,12 @@ async def get_run_plan( for instance in filter_pool_instances( pool_instances, profile, run_spec.configuration.resources ): - pool_offers.append(pydantic.parse_raw_as(InstanceOfferWithAvailability, instance.offer)) - # TODO(egor-s): assign availability 
based on the instance status
+        offer = pydantic.parse_raw_as(InstanceOfferWithAvailability, instance.offer)
+        if instance.status == InstanceStatus.READY:
+            offer.availability = InstanceAvailability.READY
+        else:
+            offer.availability = InstanceAvailability.BUSY
+        pool_offers.append(offer)
 
     backends = await backends_services.get_project_backends(project=project)
     if profile.backends is not None:

From 9bc636bd1367e414b4f020aedb3607f35326a1e5 Mon Sep 17 00:00:00 2001
From: Sergey Mezentsev
Date: Fri, 9 Feb 2024 05:10:29 +0300
Subject: [PATCH 34/47] Timeout for runner download

---
 .pre-commit-config.yaml                     | 11 ++++++-----
 runner/internal/common/interpolator.go      |  3 ++-
 runner/internal/common/interpolator_test.go |  3 ++-
 runner/internal/executor/base.go            |  1 +
 runner/internal/executor/exec_test.go       |  3 ++-
 runner/internal/executor/executor_test.go   |  7 ++++---
 runner/internal/executor/logs.go            |  3 ++-
 runner/internal/executor/timestamp.go       |  3 ++-
 runner/internal/log/log.go                  |  5 +++--
 runner/internal/repo/manager.go             |  2 +-
 runner/internal/runner/api/ws.go            |  5 +++--
 runner/internal/shim/docker.go              |  1 -
 runner/internal/shim/runner.go              | 13 ++++++++++++-
 13 files changed, 40 insertions(+), 20 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 08bd08cce..46a2d6e25 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -14,8 +14,9 @@ repos:
         files: '.*pools?\.py'
         exclude: 'versions|src/tests'
         additional_dependencies: [types-PyYAML, types-requests, pydantic<2, sqlalchemy]
-  # - repo: https://github.com/golangci/golangci-lint
-  #   rev: v1.56.0
-  #   hooks:
-  #     - id: golangci-lint-full
-  #     - id: golangci-lint
+  - repo: https://github.com/golangci/golangci-lint
+    rev: v1.56.0
+    hooks:
+      - id: golangci-lint-full
+        entry: bash -c 'cd runner && golangci-lint run -D depguard --presets import,module,unused "$@"'
+        stages: [manual]
diff --git a/runner/internal/common/interpolator.go b/runner/internal/common/interpolator.go
index 68b22e6b8..733114181 100644
--- a/runner/internal/common/interpolator.go
+++ b/runner/internal/common/interpolator.go
@@ -3,9 +3,10 @@ package common
 import (
 	"context"
 	"fmt"
+	"strings"
+
 	"github.com/dstackai/dstack/runner/internal/gerrors"
 	"github.com/dstackai/dstack/runner/internal/log"
-	"strings"
 )
 
 const (
diff --git a/runner/internal/common/interpolator_test.go b/runner/internal/common/interpolator_test.go
index fbb8d6667..e14a24874 100644
--- a/runner/internal/common/interpolator_test.go
+++ b/runner/internal/common/interpolator_test.go
@@ -2,8 +2,9 @@ package common
 
 import (
 	"context"
-	"github.com/stretchr/testify/assert"
 	"testing"
+
+	"github.com/stretchr/testify/assert"
 )
 
 func TestPlainText(t *testing.T) {
diff --git a/runner/internal/executor/base.go b/runner/internal/executor/base.go
index 36e441d4a..bf3eeb291 100644
--- a/runner/internal/executor/base.go
+++ b/runner/internal/executor/base.go
@@ -2,6 +2,7 @@ package executor
 
 import (
 	"context"
+
 	"github.com/dstackai/dstack/runner/internal/schemas"
 )
 
diff --git a/runner/internal/executor/exec_test.go b/runner/internal/executor/exec_test.go
index fc075e3b7..841f4e6b1 100644
--- a/runner/internal/executor/exec_test.go
+++ b/runner/internal/executor/exec_test.go
@@ -1,8 +1,9 @@
 package executor
 
 import (
-	"github.com/stretchr/testify/assert"
 	"testing"
+
+	"github.com/stretchr/testify/assert"
 )
 
 func TestJoinRelPath(t *testing.T) {
diff --git a/runner/internal/executor/executor_test.go b/runner/internal/executor/executor_test.go
index c6b524574..ff61910a6 100644
--- 
a/runner/internal/executor/executor_test.go +++ b/runner/internal/executor/executor_test.go @@ -5,14 +5,15 @@ import ( "bytes" "context" "fmt" - "github.com/dstackai/dstack/runner/internal/schemas" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "io" "os" "path/filepath" "testing" "time" + + "github.com/dstackai/dstack/runner/internal/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // todo test get history diff --git a/runner/internal/executor/logs.go b/runner/internal/executor/logs.go index 71228b3db..807071eeb 100644 --- a/runner/internal/executor/logs.go +++ b/runner/internal/executor/logs.go @@ -1,8 +1,9 @@ package executor import ( - "github.com/dstackai/dstack/runner/internal/schemas" "sync" + + "github.com/dstackai/dstack/runner/internal/schemas" ) type appendWriter struct { diff --git a/runner/internal/executor/timestamp.go b/runner/internal/executor/timestamp.go index d93a649b3..a9463c04c 100644 --- a/runner/internal/executor/timestamp.go +++ b/runner/internal/executor/timestamp.go @@ -2,9 +2,10 @@ package executor import ( "context" - "github.com/dstackai/dstack/runner/internal/log" "sync" "time" + + "github.com/dstackai/dstack/runner/internal/log" ) type MonotonicTimestamp struct { diff --git a/runner/internal/log/log.go b/runner/internal/log/log.go index 5fd03fd51..99478a8f9 100644 --- a/runner/internal/log/log.go +++ b/runner/internal/log/log.go @@ -3,10 +3,11 @@ package log import ( "context" "fmt" - "github.com/dstackai/dstack/runner/internal/gerrors" - "github.com/sirupsen/logrus" "io" "os" + + "github.com/dstackai/dstack/runner/internal/gerrors" + "github.com/sirupsen/logrus" ) type loggerKey struct{} diff --git a/runner/internal/repo/manager.go b/runner/internal/repo/manager.go index 0f5a9d254..3d2ad42d6 100644 --- a/runner/internal/repo/manager.go +++ b/runner/internal/repo/manager.go @@ -3,9 +3,9 @@ package repo import ( "context" "fmt" - "github.com/dstackai/dstack/runner/internal/gerrors" "os" + "github.com/dstackai/dstack/runner/internal/gerrors" "github.com/dstackai/dstack/runner/internal/log" "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing" diff --git a/runner/internal/runner/api/ws.go b/runner/internal/runner/api/ws.go index 26a52cd32..cade1170a 100644 --- a/runner/internal/runner/api/ws.go +++ b/runner/internal/runner/api/ws.go @@ -2,10 +2,11 @@ package api import ( "context" - "github.com/dstackai/dstack/runner/internal/log" - "github.com/gorilla/websocket" "net/http" "time" + + "github.com/dstackai/dstack/runner/internal/log" + "github.com/gorilla/websocket" ) var upgrader = websocket.Upgrader{ diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 50531c434..b70170cb3 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -19,7 +19,6 @@ import ( docker "github.com/docker/docker/client" "github.com/docker/go-connections/nat" "github.com/dstackai/dstack/runner/consts" - "github.com/ztrue/tracerr" ) diff --git a/runner/internal/shim/runner.go b/runner/internal/shim/runner.go index bf8ae6edb..1388c4b85 100644 --- a/runner/internal/shim/runner.go +++ b/runner/internal/shim/runner.go @@ -1,6 +1,7 @@ package shim import ( + "context" "fmt" "io" "log" @@ -8,6 +9,7 @@ import ( "os" "strconv" "strings" + "time" "github.com/dstackai/dstack/runner/internal/gerrors" ) @@ -76,7 +78,16 @@ func downloadRunner(url string) (string, error) { }() log.Printf("Downloading runner from %s\n", url) - resp, err := http.Get(url) + ctx, 
cancel := context.WithTimeout(context.Background(), time.Second*600) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", gerrors.Wrap(err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { return "", gerrors.Wrap(err) } From 5011a939f397f450c3640f40f4985bed43a170b4 Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Fri, 9 Feb 2024 06:32:17 +0300 Subject: [PATCH 35/47] Remove duplicate code. Fix termination_idle_time. Small review fixes --- src/dstack/_internal/cli/commands/pool.py | 16 +- src/dstack/_internal/cli/commands/run.py | 8 +- .../_internal/core/backends/aws/compute.py | 88 +++-------- .../_internal/core/backends/azure/compute.py | 6 +- .../_internal/core/backends/base/compute.py | 144 ++---------------- .../core/backends/datacrunch/compute.py | 80 ++-------- .../_internal/core/backends/gcp/compute.py | 92 +++-------- .../core/backends/lambdalabs/compute.py | 7 +- .../_internal/core/backends/local/compute.py | 15 +- .../_internal/core/backends/nebius/compute.py | 6 +- .../core/backends/tensordock/compute.py | 7 +- src/dstack/_internal/core/models/instances.py | 5 + src/dstack/_internal/core/models/profiles.py | 12 +- .../tasks/process_submitted_jobs.py | 3 +- src/dstack/_internal/server/models.py | 6 +- src/dstack/_internal/server/schemas/runs.py | 4 +- src/dstack/_internal/server/services/pools.py | 2 +- src/dstack/_internal/server/services/runs.py | 20 +-- src/dstack/_internal/server/testing/common.py | 11 +- src/dstack/api/_public/runs.py | 11 +- src/dstack/api/server/_runs.py | 5 +- .../_internal/server/services/test_pools.py | 2 + 22 files changed, 125 insertions(+), 425 deletions(-) diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 8c1b54d4e..636ed8fff 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -10,14 +10,19 @@ register_profile_args, ) from dstack._internal.cli.utils.common import confirm_ask, console -from dstack._internal.core.backends.base.compute import SSHKeys from dstack._internal.core.errors import CLIError, ServerClientError from dstack._internal.core.models.instances import ( InstanceAvailability, InstanceOfferWithAvailability, + SSHKey, ) from dstack._internal.core.models.pools import Instance, Pool -from dstack._internal.core.models.profiles import Profile, SpotPolicy, TerminationPolicy +from dstack._internal.core.models.profiles import ( + DEFAULT_TERMINATION_IDLE_TIME, + Profile, + SpotPolicy, + TerminationPolicy, +) from dstack._internal.core.models.resources import DEFAULT_CPU_COUNT, DEFAULT_MEMORY_SIZE from dstack._internal.core.models.runs import Requirements from dstack._internal.core.services.configs import ConfigManager @@ -186,10 +191,9 @@ def _add(self, args: argparse.Namespace) -> None: apply_profile_args(args, profile) profile.pool_name = args.pool_name - # TODO: add full support - termination_policy_idle = 5 * 60 # 5 minutes by default + termination_policy_idle = DEFAULT_TERMINATION_IDLE_TIME termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE - profile.termination_idle_time = str(termination_policy_idle) # TODO: fix serialization + profile.termination_idle_time = termination_policy_idle profile.termination_policy = termination_policy # Add remote instance @@ -219,7 +223,7 @@ def _add(self, args: argparse.Namespace) -> None: try: user_pub_key = Path("~/.dstack/ssh/id_rsa.pub").expanduser().read_text().strip() - pub_key = 
SSHKeys(public=user_pub_key) + pub_key = SSHKey(public=user_pub_key) with console.status("Creating instance..."): self.api.runs.create_instance(pool_name, profile, requirements, pub_key) except ServerClientError as e: diff --git a/src/dstack/_internal/cli/commands/run.py b/src/dstack/_internal/cli/commands/run.py index 2608d8638..e2de69822 100644 --- a/src/dstack/_internal/cli/commands/run.py +++ b/src/dstack/_internal/cli/commands/run.py @@ -17,7 +17,11 @@ from dstack._internal.cli.utils.run import print_run_plan from dstack._internal.core.errors import CLIError, ConfigurationError, ServerClientError from dstack._internal.core.models.configurations import ConfigurationType -from dstack._internal.core.models.profiles import CreationPolicy, TerminationPolicy +from dstack._internal.core.models.profiles import ( + DEFAULT_TERMINATION_IDLE_TIME, + CreationPolicy, + TerminationPolicy, +) from dstack._internal.core.models.runs import JobErrorCode from dstack._internal.core.services.configs import ConfigManager from dstack._internal.utils.common import parse_pretty_duration @@ -114,7 +118,7 @@ def _command(self, args: argparse.Namespace): self._parser.print_help() return - termination_policy_idle = 5 * 60 # 5 minutes by default + termination_policy_idle = DEFAULT_TERMINATION_IDLE_TIME termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE if args.idle_duration is not None: diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index c03d6da69..d019b08fd 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -13,7 +13,6 @@ InstanceConfiguration, get_gateway_user_data, get_instance_name, - get_instance_user_data, get_user_data, ) from dstack._internal.core.backends.base.offers import get_catalog_offers @@ -26,9 +25,9 @@ InstanceOfferWithAvailability, LaunchedGatewayInfo, LaunchedInstanceInfo, + SSHKey, ) from dstack._internal.core.models.runs import Job, Requirements, Run -from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -99,12 +98,10 @@ def terminate_instance( def create_instance( self, - project: ProjectModel, - user: UserModel, instance_offer: InstanceOfferWithAvailability, instance_config: InstanceConfiguration, ) -> LaunchedInstanceInfo: - project_id = project.name + project_name = instance_config.project_name ec2 = self.session.resource("ec2", region_name=instance_offer.region) ec2_client = self.session.client("ec2", region_name=instance_offer.region) iam_client = self.session.client("iam", region_name=instance_offer.region) @@ -112,7 +109,7 @@ def create_instance( tags = [ {"Key": "Name", "Value": instance_config.instance_name}, {"Key": "owner", "Value": "dstack"}, - {"Key": "dstack_project", "Value": project_id}, + {"Key": "dstack_project", "Value": project_name}, {"Key": "dstack_user", "Value": instance_config.user}, ] try: @@ -125,13 +122,13 @@ def create_instance( ), instance_type=instance_offer.instance.name, iam_instance_profile_arn=aws_resources.create_iam_instance_profile( - iam_client, project_id - ), - user_data=get_instance_user_data( - authorized_keys=instance_config.get_public_keys(), + iam_client, project_name ), + user_data=get_user_data(authorized_keys=instance_config.get_public_keys()), tags=tags, - security_group_id=aws_resources.create_security_group(ec2_client, project_id), + security_group_id=aws_resources.create_security_group( + 
ec2_client, project_name + ), spot=instance_offer.instance.resources.spot, ) ) @@ -163,63 +160,18 @@ def run_job( project_ssh_public_key: str, project_ssh_private_key: str, ) -> LaunchedInstanceInfo: - project_id = run.project_name - ec2 = self.session.resource("ec2", region_name=instance_offer.region) - ec2_client = self.session.client("ec2", region_name=instance_offer.region) - iam_client = self.session.client("iam", region_name=instance_offer.region) - - tags = [ - {"Key": "Name", "Value": get_instance_name(run, job)}, - {"Key": "owner", "Value": "dstack"}, - {"Key": "dstack_project", "Value": project_id}, - {"Key": "dstack_user", "Value": run.user}, - ] - try: - disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) - response = ec2.create_instances( - **aws_resources.create_instances_struct( - disk_size=disk_size, - image_id=aws_resources.get_image_id( - ec2_client, len(instance_offer.instance.resources.gpus) > 0 - ), - instance_type=instance_offer.instance.name, - iam_instance_profile_arn=aws_resources.create_iam_instance_profile( - iam_client, project_id - ), - user_data=get_user_data( - backend=BackendType.AWS, - image_name=job.job_spec.image_name, - authorized_keys=[ - run.run_spec.ssh_key_pub.strip(), - project_ssh_public_key.strip(), - ], - registry_auth_required=job.job_spec.registry_auth is not None, - ), - tags=tags, - security_group_id=aws_resources.create_security_group(ec2_client, project_id), - spot=instance_offer.instance.resources.spot, - ) - ) - instance = response[0] - instance.wait_until_running() - instance.reload() # populate instance.public_ip_address - - if instance_offer.instance.resources.spot: # it will not terminate the instance - ec2_client.cancel_spot_instance_requests( - SpotInstanceRequestIds=[instance.spot_instance_request_id] - ) - return LaunchedInstanceInfo( - instance_id=instance.instance_id, - ip_address=instance.public_ip_address, - region=instance_offer.region, - username="ubuntu", - ssh_port=22, - dockerized=True, # because `dstack-shim docker` is used - backend_data=None, - ) - except botocore.exceptions.ClientError as e: - logger.warning("Got botocore.exceptions.ClientError: %s", e) - raise NoCapacityError() + instance_config = InstanceConfiguration( + project_name=run.project_name, + instance_name=get_instance_name(run, job), # TODO: generate name + ssh_keys=[ + SSHKey(public=run.run_spec.ssh_key_pub.strip()), + SSHKey(public=project_ssh_public_key.strip()), + ], + job_docker_config=None, + user=run.user, + ) + launched_instance_info = self.create_instance(instance_offer, instance_config) + return launched_instance_info def create_gateway( self, diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py index 2e5c8c9f4..57a90e3cf 100644 --- a/src/dstack/_internal/core/backends/azure/compute.py +++ b/src/dstack/_internal/core/backends/azure/compute.py @@ -132,11 +132,7 @@ def run_job( # instance_name includes region because Azure may create an instance resource # even when provisioning fails. 
instance_name=f"{get_instance_name(run, job)}-{instance_offer.region}", - user_data=get_user_data( - backend=BackendType.AZURE, - image_name=job.job_spec.image_name, - authorized_keys=ssh_pub_keys, - ), + user_data=get_user_data(authorized_keys=ssh_pub_keys), ssh_pub_keys=ssh_pub_keys, spot=instance_offer.instance.resources.spot, disk_size=disk_size, diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index bf74bd62b..48c6a51fb 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -10,37 +10,31 @@ from pydantic import BaseModel from dstack._internal import settings -from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import RegistryAuth from dstack._internal.core.models.instances import ( InstanceOfferWithAvailability, LaunchedGatewayInfo, LaunchedInstanceInfo, + SSHKey, ) from dstack._internal.core.models.runs import Job, Requirements, Run -from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.services.docker import DockerImage from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) -class SSHKeys(BaseModel): - public: str - private: Optional[str] = None - - class DockerConfig(BaseModel): registry_auth: Optional[RegistryAuth] image: Optional[DockerImage] class InstanceConfiguration(BaseModel): - pool_name: str + project_name: str instance_name: str # unique in pool - ssh_keys: List[SSHKeys] + ssh_keys: List[SSHKey] job_docker_config: Optional[DockerConfig] - user: Optional[str] + user: str # dstack user name def get_public_keys(self) -> List[str]: return [ssh_key.public.strip() for ssh_key in self.ssh_keys] @@ -64,15 +58,12 @@ def run_job( ) -> LaunchedInstanceInfo: pass - @abstractmethod def create_instance( self, - project: ProjectModel, - user: UserModel, instance_offer: InstanceOfferWithAvailability, instance_config: InstanceConfiguration, ) -> LaunchedInstanceInfo: - pass + raise NotImplementedError() @abstractmethod def terminate_instance( @@ -95,18 +86,10 @@ def get_instance_name(run: Run, job: Job) -> str: def get_user_data( - backend: BackendType, - image_name: str, authorized_keys: List[str], - registry_auth_required: bool, - cloud_config_kwargs: Optional[dict] = None, + cloud_config_kwargs: Optional[Dict[Any, Any]] = None, ) -> str: - commands = get_shim_commands( - backend=backend, - image_name=image_name, - authorized_keys=authorized_keys, - registry_auth_required=registry_auth_required, - ) + commands = get_shim_commands(authorized_keys) return get_cloud_config( runcmd=[["sh", "-c", " && ".join(commands)]], ssh_authorized_keys=authorized_keys, @@ -114,25 +97,18 @@ def get_user_data( ) -def get_shim_commands( - backend: BackendType, - image_name: str, - authorized_keys: List[str], - registry_auth_required: bool, -) -> List[str]: +def get_shim_commands(authorized_keys: List[str]) -> List[str]: build = get_dstack_runner_version() env = { - "DSTACK_BACKEND": backend.value, "DSTACK_RUNNER_LOG_LEVEL": "6", "DSTACK_RUNNER_VERSION": build, - "DSTACK_IMAGE_NAME": image_name, "DSTACK_PUBLIC_SSH_KEY": "\n".join(authorized_keys), "DSTACK_HOME": "/root/.dstack", } commands = get_dstack_shim(build) for k, v in env.items(): commands += [f'export "{k}={v}"'] - commands += get_run_shim_script(registry_auth_required) + commands += get_run_shim_script() return commands @@ -162,12 +138,9 @@ def get_dstack_shim(build: str) -> 
List[str]: ] -def get_run_shim_script(registry_auth_required: bool) -> List[str]: +def get_run_shim_script() -> List[str]: dev_flag = "" if settings.DSTACK_VERSION is not None else "--dev" - with_auth_flag = "--with-auth" if registry_auth_required else "" - return [ - f"nohup dstack-shim {dev_flag} docker {with_auth_flag} --keep-container >/root/shim.log 2>&1 &" - ] + return [f"nohup dstack-shim {dev_flag} docker --keep-container >/root/shim.log 2>&1 &"] def get_gateway_user_data(authorized_key: str) -> str: @@ -299,98 +272,3 @@ def get_dstack_gateway_commands() -> List[str]: f"/home/ubuntu/dstack/blue/bin/pip install {get_dstack_gateway_wheel(build)}", "sudo /home/ubuntu/dstack/blue/bin/python -m dstack.gateway.systemd install --run", ] - - -def get_instance_user_data( - authorized_keys: List[str], - cloud_config_kwargs: Optional[Dict[Any, Any]] = None, -) -> str: - commands = get_instance_shim_commands( - authorized_keys=authorized_keys, - ) - return get_cloud_config( - runcmd=[["sh", "-c", " && ".join(commands)]], - ssh_authorized_keys=authorized_keys, - **(cloud_config_kwargs or {}), - ) - - -def get_instance_shim_commands( - authorized_keys: List[str], -) -> List[str]: - build = get_dstack_runner_version() - env = { - "DSTACK_RUNNER_LOG_LEVEL": "6", - "DSTACK_RUNNER_VERSION": build, - "DSTACK_PUBLIC_SSH_KEY": "\n".join(authorized_keys), - "DSTACK_HOME": "/root/.dstack", - } - commands = get_instance_dstack_shim(build) - for k, v in env.items(): - commands += [f'export "{k}={v}"'] - commands += get_instance_run_shim_script() - return commands - - -def get_instance_dstack_shim(build: str) -> List[str]: - bucket = "dstack-runner-downloads-stgn" - if settings.DSTACK_VERSION is not None: - bucket = "dstack-runner-downloads" - - url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64" - - return [ - f'sudo curl --connect-timeout 60 --max-time 240 --retry 1 --output /usr/local/bin/dstack-shim "{url}"', - "sudo chmod +x /usr/local/bin/dstack-shim", - ] - - -def get_instance_docker_commands(authorized_keys: List[str]) -> List[str]: - authorized_keys_body = "\n".join(authorized_keys).strip() - commands = [ - # note: &> redirection doesn't work in /bin/sh - # check in sshd is here, install if not - ( - "if ! 
command -v sshd >/dev/null 2>&1; then { " - "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y openssh-server; " - "} || { " - "yum -y install openssh-server; " - "}; fi" - ), - # prohibit password authentication - 'sed -i "s/.*PasswordAuthentication.*/PasswordAuthentication no/g" /etc/ssh/sshd_config', - # create ssh dirs and add public key - "mkdir -p /run/sshd ~/.ssh", - "chmod 700 ~/.ssh", - f"echo '{authorized_keys_body}' > ~/.ssh/authorized_keys", - "chmod 600 ~/.ssh/authorized_keys", - # preserve environment variables for SSH clients - "env >> ~/.ssh/environment", - 'echo "export PATH=$PATH" >> ~/.profile', - # regenerate host keys - "rm -rf /etc/ssh/ssh_host_*", - "ssh-keygen -A > /dev/null", - # start sshd - "/usr/sbin/sshd -p 10022 -o PermitUserEnvironment=yes", - ] - - runner = "/usr/local/bin/dstack-runner" - - build = get_dstack_runner_version() - bucket = "dstack-runner-downloads-stgn" - if settings.DSTACK_VERSION is not None: - bucket = "dstack-runner-downloads" - - url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64" - - commands += [ - f"curl --connect-timeout 60 --max-time 240 --retry 1 --output {runner} {url}", - f"chmod +x {runner}", - f"{runner} --log-level 6 start --http-port 10999 --temp-dir /tmp/runner --home-dir /root --working-dir /workflow", - ] - return commands - - -def get_instance_run_shim_script() -> List[str]: - dev_flag = "" if settings.DSTACK_VERSION is not None else "--dev" - return [f"nohup dstack-shim {dev_flag} docker --keep-container >/root/shim.log 2>&1 &"] diff --git a/src/dstack/_internal/core/backends/datacrunch/compute.py b/src/dstack/_internal/core/backends/datacrunch/compute.py index 63fa1f64c..a26fe1fec 100644 --- a/src/dstack/_internal/core/backends/datacrunch/compute.py +++ b/src/dstack/_internal/core/backends/datacrunch/compute.py @@ -3,7 +3,6 @@ from dstack._internal.core.backends.base import Compute from dstack._internal.core.backends.base.compute import ( InstanceConfiguration, - get_instance_shim_commands, get_shim_commands, ) from dstack._internal.core.backends.base.offers import get_catalog_offers @@ -16,9 +15,9 @@ InstanceOffer, InstanceOfferWithAvailability, LaunchedInstanceInfo, + SSHKey, ) from dstack._internal.core.models.runs import Job, Requirements, Run -from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.utils.logging import get_logger logger = get_logger("datacrunch.compute") @@ -67,8 +66,6 @@ def _get_offers_with_availability( def create_instance( self, - project: ProjectModel, - user: UserModel, instance_offer: InstanceOfferWithAvailability, instance_config: InstanceConfiguration, ) -> LaunchedInstanceInfo: @@ -83,9 +80,7 @@ def create_instance( ) ) - commands = get_instance_shim_commands( - authorized_keys=public_keys, - ) + commands = get_shim_commands(authorized_keys=public_keys) startup_script = " ".join([" && ".join(commands)]) script_name = f"dstack-{instance_config.instance_name}.sh" @@ -151,69 +146,18 @@ def run_job( project_ssh_public_key: str, project_ssh_private_key: str, ) -> LaunchedInstanceInfo: - ssh_ids = [] - ssh_ids.append( - self.api_client.get_or_create_ssh_key( - name=f"dstack-{job.job_spec.job_name}.key", - public_key=run.run_spec.ssh_key_pub.strip(), - ) - ) - ssh_ids.append( - self.api_client.get_or_create_ssh_key( - name=f"dstack-{job.job_spec.job_name}.key", - public_key=project_ssh_public_key.strip(), - ) - ) - - commands = get_shim_commands( - backend=BackendType.DATACRUNCH, - 
image_name=job.job_spec.image_name, - authorized_keys=[ - run.run_spec.ssh_key_pub.strip(), - project_ssh_public_key.strip(), + instance_config = InstanceConfiguration( + project_name=run.project_name, + instance_name=job.job_spec.job_name, # TODO: generate name + ssh_keys=[ + SSHKey(public=run.run_spec.ssh_key_pub.strip()), + SSHKey(public=project_ssh_public_key.strip()), ], - registry_auth_required=job.job_spec.registry_auth is not None, - ) - - startup_script = " ".join([" && ".join(commands)]) - script_name = f"dstack-{job.job_spec.job_name}.sh" - startup_script_ids = self.api_client.get_or_create_startup_scrpit( - name=script_name, script=startup_script - ) - - name = job.job_spec.job_name - - # Id of image "Ubuntu 22.04 + CUDA 12.0 + Docker" - # from API https://datacrunch.stoplight.io/docs/datacrunch-public/c46ab45dbc508-get-all-image-types - image_name = "2088da25-bb0d-41cc-a191-dccae45d96fd" - - disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) - instance = self.api_client.deploy_instance( - instance_type=instance_offer.instance.name, - ssh_key_ids=ssh_ids, - startup_script_id=startup_script_ids, - hostname=name, - description=name, - image=image_name, - disk_size=disk_size, - location=instance_offer.region, - ) - - running_instance = self.api_client.wait_for_instance(instance.id) - if running_instance is None: - raise BackendError(f"Wait instance {instance.id!r} timeout") - - launched_instance = LaunchedInstanceInfo( - instance_id=running_instance.id, - ip_address=running_instance.ip, - region=running_instance.location, - ssh_port=22, - username="root", - dockerized=True, - backend_data=None, + job_docker_config=None, + user=run.user, ) - - return launched_instance + launched_instance_info = self.create_instance(instance_offer, instance_config) + return launched_instance_info def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index bb19b4ada..65cb89c4f 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -11,7 +11,6 @@ InstanceConfiguration, get_gateway_user_data, get_instance_name, - get_instance_user_data, get_user_data, ) from dstack._internal.core.backends.base.offers import get_catalog_offers @@ -26,9 +25,9 @@ LaunchedGatewayInfo, LaunchedInstanceInfo, Resources, + SSHKey, ) from dstack._internal.core.models.runs import Job, Requirements, Run -from dstack._internal.server.models import ProjectModel, UserModel class GCPCompute(Compute): @@ -83,14 +82,13 @@ def terminate_instance( def create_instance( self, - project: ProjectModel, - user: UserModel, instance_offer: InstanceOfferWithAvailability, instance_config: InstanceConfiguration, ) -> LaunchedInstanceInfo: - project_id = project.name instance_name = instance_config.instance_name + authorized_keys = instance_config.get_public_keys() + gcp_resources.create_runner_firewall_rules( firewalls_client=self.firewalls_client, project_id=self.config.project_id, @@ -113,13 +111,12 @@ def create_instance( gpus=instance_offer.instance.resources.gpus, ), spot=instance_offer.instance.resources.spot, - user_data=get_instance_user_data( - authorized_keys=instance_config.get_public_keys(), - ), + user_data=get_user_data(authorized_keys), + authorized_keys=authorized_keys, labels={ "owner": "dstack", - "dstack_project": project_id, - "dstack_user": user.name, + "dstack_project": 
instance_config.project_name, + "dstack_user": instance_config.user, }, tags=[gcp_resources.DSTACK_INSTANCE_TAG], instance_name=instance_name, @@ -155,71 +152,18 @@ def run_job( project_ssh_public_key: str, project_ssh_private_key: str, ) -> LaunchedInstanceInfo: - project_id = run.project_name - instance_name = get_instance_name(run, job) - gcp_resources.create_runner_firewall_rules( - firewalls_client=self.firewalls_client, - project_id=self.config.project_id, + instance_config = InstanceConfiguration( + project_name=run.project_name, + instance_name=get_instance_name(run, job), # TODO: generate name + ssh_keys=[ + SSHKey(public=run.run_spec.ssh_key_pub.strip()), + SSHKey(public=project_ssh_public_key.strip()), + ], + job_docker_config=None, + user=run.user, ) - disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) - authorized_keys = [ - run.run_spec.ssh_key_pub.strip(), - project_ssh_public_key.strip(), - ] - for zone in _get_instance_zones(instance_offer): - request = compute_v1.InsertInstanceRequest() - request.zone = zone - request.project = self.config.project_id - request.instance_resource = gcp_resources.create_instance_struct( - disk_size=disk_size, - image_id=gcp_resources.get_image_id( - len(instance_offer.instance.resources.gpus) > 0, - ), - machine_type=instance_offer.instance.name, - accelerators=gcp_resources.get_accelerators( - project_id=self.config.project_id, - zone=zone, - gpus=instance_offer.instance.resources.gpus, - ), - spot=instance_offer.instance.resources.spot, - user_data=get_user_data( - backend=BackendType.GCP, - image_name=job.job_spec.image_name, - authorized_keys=authorized_keys, - registry_auth_required=job.job_spec.registry_auth is not None, - ), - authorized_keys=authorized_keys, - labels={ - "owner": "dstack", - "dstack_project": project_id, - "dstack_user": run.user, - }, - tags=[gcp_resources.DSTACK_INSTANCE_TAG], - instance_name=instance_name, - zone=zone, - service_account=self.config.service_account_email, - ) - try: - operation = self.instances_client.insert(request=request) - gcp_resources.wait_for_extended_operation(operation, "instance creation") - except ( - google.api_core.exceptions.ServiceUnavailable, - google.api_core.exceptions.NotFound, - ): - continue - instance = self.instances_client.get( - project=self.config.project_id, zone=zone, instance=instance_name - ) - return LaunchedInstanceInfo( - instance_id=instance_name, - region=zone, - ip_address=instance.network_interfaces[0].access_configs[0].nat_i_p, - username="ubuntu", - ssh_port=22, - dockerized=True, - backend_data=None, - ) - raise NoCapacityError() + launched_instance_info = self.create_instance(instance_offer, instance_config) + return launched_instance_info def create_gateway( self, diff --git a/src/dstack/_internal/core/backends/lambdalabs/compute.py b/src/dstack/_internal/core/backends/lambdalabs/compute.py index 5b7dd56b8..5b19361f2 100644 --- a/src/dstack/_internal/core/backends/lambdalabs/compute.py +++ b/src/dstack/_internal/core/backends/lambdalabs/compute.py @@ -57,12 +57,7 @@ def run_job( project_ssh_private_key: str, ) -> LaunchedInstanceInfo: commands = get_shim_commands( - backend=BackendType.LAMBDA, - image_name=job.job_spec.image_name, - authorized_keys=[ - run.run_spec.ssh_key_pub.strip(), - project_ssh_public_key.strip(), - ], + authorized_keys=[run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip()] ) # shim is asssumed to be run under root launch_command = "sudo sh -c '" + "&& ".join(commands) + "'" diff --git 
a/src/dstack/_internal/core/backends/local/compute.py b/src/dstack/_internal/core/backends/local/compute.py index 35779f989..596ce7e1d 100644 --- a/src/dstack/_internal/core/backends/local/compute.py +++ b/src/dstack/_internal/core/backends/local/compute.py @@ -1,6 +1,6 @@ from typing import List, Optional -from dstack._internal.core.backends.base.compute import Compute, get_dstack_runner_version +from dstack._internal.core.backends.base.compute import Compute from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( InstanceAvailability, @@ -37,9 +37,7 @@ def terminate_instance( ): pass - def create_instance( - self, project, user, instance_offer, instance_config - ) -> LaunchedInstanceInfo: + def create_instance(self, instance_offer, instance_config) -> LaunchedInstanceInfo: launched_instance = LaunchedInstanceInfo( instance_id="local", ip_address="127.0.0.1", @@ -59,15 +57,6 @@ def run_job( project_ssh_public_key: str, project_ssh_private_key: str, ) -> LaunchedInstanceInfo: - authorized_keys = f"{run.run_spec.ssh_key_pub.strip()}\\n{project_ssh_public_key.strip()}" - logger.info( - "Running job in LocalBackend. To start processing, run: `" - f"DSTACK_BACKEND=local " - "DSTACK_RUNNER_LOG_LEVEL=6 " - f"DSTACK_RUNNER_VERSION={get_dstack_runner_version()} " - f"DSTACK_IMAGE_NAME={job.job_spec.image_name} " - f'DSTACK_PUBLIC_SSH_KEY="{authorized_keys}" ./shim --dev docker --keep-container`', - ) return LaunchedInstanceInfo( instance_id="local", ip_address="127.0.0.1", diff --git a/src/dstack/_internal/core/backends/nebius/compute.py b/src/dstack/_internal/core/backends/nebius/compute.py index 243a673cf..f02d0dc50 100644 --- a/src/dstack/_internal/core/backends/nebius/compute.py +++ b/src/dstack/_internal/core/backends/nebius/compute.py @@ -77,13 +77,11 @@ def run_job( ), metadata={ "user-data": get_user_data( - backend=BackendType.NEBIUS, - image_name=job.job_spec.image_name, authorized_keys=[ run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip(), - ], - ), + ] + ) }, disk_size_gb=disk_size, image_id=image_id, diff --git a/src/dstack/_internal/core/backends/tensordock/compute.py b/src/dstack/_internal/core/backends/tensordock/compute.py index b22dc6dee..da09c6f83 100644 --- a/src/dstack/_internal/core/backends/tensordock/compute.py +++ b/src/dstack/_internal/core/backends/tensordock/compute.py @@ -50,12 +50,7 @@ def run_job( project_ssh_private_key: str, ) -> LaunchedInstanceInfo: commands = get_shim_commands( - backend=BackendType.TENSORDOCK, - image_name=job.job_spec.image_name, - authorized_keys=[ - run.run_spec.ssh_key_pub.strip(), - project_ssh_public_key.strip(), - ], + authorized_keys=[run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip()] ) try: resp = self.api_client.deploy_single( diff --git a/src/dstack/_internal/core/models/instances.py b/src/dstack/_internal/core/models/instances.py index cd45ded11..997cc5afb 100644 --- a/src/dstack/_internal/core/models/instances.py +++ b/src/dstack/_internal/core/models/instances.py @@ -60,6 +60,11 @@ class SSHConnectionParams(BaseModel): port: int +class SSHKey(BaseModel): + public: str + private: Optional[str] = None + + class LaunchedInstanceInfo(BaseModel): instance_id: str region: str diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index e36bafeab..ce031f814 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -10,6 +10,7 @@ 
DEFAULT_RETRY_LIMIT = 3600 DEFAULT_POOL_NAME = "default-pool" +DEFAULT_TERMINATION_IDLE_TIME = 5 * 60 # 5 minutes by default class SpotPolicy(str, Enum): @@ -117,19 +118,14 @@ class Profile(ForbidExtra): Optional[TerminationPolicy], Field(description="The policy for termination instances") ] termination_idle_time: Annotated[ - Optional[Union[Literal["off"], str, int]], - Field(description=""), - ] + int, + Field(description="Seconds to wait before destroying the instance"), + ] = DEFAULT_TERMINATION_IDLE_TIME _validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)( parse_max_duration ) - # TODO: fix deserialization - # _validate_termination_idle_time = validator( - # "termination_idle_time", pre=True, allow_reuse=True - # )(parse_max_duration) - class ProfilesConfig(ForbidExtra): profiles: List[Profile] diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 51c4f8b7b..710cc3022 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -122,7 +122,6 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): pool_instances, profile, run_spec.configuration.resources, status=InstanceStatus.READY ) - logger.info(*job_log(f"num relevance {len(relevant_instances)}", job_model)) if relevant_instances: sorted_instances = sorted(relevant_instances, key=lambda instance: instance.name) instance = sorted_instances[0] @@ -183,7 +182,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): job_provisioning_data=job_provisioning_data.json(), offer=offer.json(), termination_policy=profile.termination_policy, - termination_idle_time=300, # TODO: fix deserialize + termination_idle_time=profile.termination_idle_time, job=job_model, backend=offer.backend, price=offer.price, diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index c9e14cc2a..2196bf0fd 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -20,7 +20,7 @@ from sqlalchemy_utils import UUIDType from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.profiles import TerminationPolicy +from dstack._internal.core.models.profiles import DEFAULT_TERMINATION_IDLE_TIME, TerminationPolicy from dstack._internal.core.models.repos.base import RepoType from dstack._internal.core.models.runs import InstanceStatus, JobErrorCode, JobStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole @@ -285,7 +285,9 @@ class InstanceModel(BaseModel): finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) termination_policy: Mapped[Optional[TerminationPolicy]] = mapped_column(String(50)) - termination_idle_time: Mapped[int] = mapped_column(Integer) + termination_idle_time: Mapped[int] = mapped_column( + Integer, default=DEFAULT_TERMINATION_IDLE_TIME + ) backend: Mapped[BackendType] = mapped_column(Enum(BackendType)) backend_data: Mapped[Optional[str]] = mapped_column(String(4000)) diff --git a/src/dstack/_internal/server/schemas/runs.py b/src/dstack/_internal/server/schemas/runs.py index 58c458cd0..a81371461 100644 --- a/src/dstack/_internal/server/schemas/runs.py +++ b/src/dstack/_internal/server/schemas/runs.py @@ -2,7 +2,7 @@ from pydantic import BaseModel -from dstack._internal.core.backends.base.compute 
import SSHKeys +from dstack._internal.core.models.instances import SSHKey from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import Requirements, RunSpec @@ -30,7 +30,7 @@ class CreateInstanceRequest(BaseModel): pool_name: str profile: Profile requirements: Requirements - ssh_key: SSHKeys + ssh_key: SSHKey class AddRemoteInstanceRequest(BaseModel): diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index 2b8f15940..0ec66f0ab 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -364,7 +364,7 @@ async def add_remote( offer=offer.json(), resource_spec_data=resources.json(), termination_policy=profile.termination_policy, - termination_idle_time=300, # TODO: fix deserialize + termination_idle_time=profile.termination_idle_time, ) session.add(im) await session.commit() diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 53e2b02e4..918ffa654 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -16,12 +16,12 @@ from dstack._internal.core.backends.base.compute import ( DockerConfig, InstanceConfiguration, - SSHKeys, ) from dstack._internal.core.errors import BackendError, RepoDoesNotExistError, ServerClientError from dstack._internal.core.models.instances import ( InstanceOfferWithAvailability, LaunchedInstanceInfo, + SSHKey, ) from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, CreationPolicy, Profile from dstack._internal.core.models.runs import ( @@ -176,7 +176,7 @@ async def create_instance( session: AsyncSession, project: ProjectModel, user: UserModel, - ssh_key: SSHKeys, + ssh_key: SSHKey, pool_name: str, instance_name: str, profile: Profile, @@ -190,15 +190,15 @@ async def create_instance( return user_ssh_key = ssh_key - project_ssh_key = SSHKeys( + project_ssh_key = SSHKey( public=project.ssh_public_key.strip(), private=project.ssh_private_key.strip(), ) image = parse_image_name(get_default_image(get_default_python_verison())) instance_config = InstanceConfiguration( + project_name=project.name, instance_name=instance_name, - pool_name=pool_name, ssh_keys=[user_ssh_key, project_ssh_key], job_docker_config=DockerConfig( image=image, @@ -223,8 +223,6 @@ async def create_instance( try: launched_instance_info: LaunchedInstanceInfo = await run_async( backend.compute().create_instance, - project, - user, instance_offer, instance_config, ) @@ -281,7 +279,7 @@ async def create_instance( offer=cast(InstanceOfferWithAvailability, instance_offer).json(), resource_spec_data=requirements.resources.json(), termination_policy=profile.termination_policy, - termination_idle_time=300, # TODO: fix deserialize + termination_idle_time=profile.termination_idle_time, ) session.add(im) await session.commit() @@ -350,8 +348,6 @@ async def get_run_plan( ) job_plans.append(job_plan) - run_spec.profile.termination_idle_time = None - run_spec.run_name = run_name # restore run_name run_plan = RunPlan( project_name=project.name, user=user.name, run_spec=run_spec, job_plans=job_plans @@ -385,9 +381,6 @@ async def submit_run( else: await delete_runs(session=session, project=project, runs_names=[run_spec.run_name]) - # TODO: fix deserialize - run_spec.profile.termination_idle_time = "300s" - pool = await get_or_create_default_pool_by_name(session, project, 
run_spec.profile.pool_name) run_model = RunModel( @@ -518,9 +511,6 @@ def run_model_to_run(run_model: RunModel, include_job_submissions: bool = True) run_spec = RunSpec.parse_raw(run_model.run_spec) - # TODO: fix deserialization - run_spec.profile.termination_idle_time = None - latest_job_submission = None if include_job_submissions: latest_job_submission = jobs[0].job_submissions[-1] diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index d8cbaca8c..ac46b7bea 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -9,7 +9,11 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import DevEnvironmentConfiguration from dstack._internal.core.models.instances import InstanceType, Resources -from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile +from dstack._internal.core.models.profiles import ( + DEFAULT_POOL_NAME, + DEFAULT_TERMINATION_IDLE_TIME, + Profile, +) from dstack._internal.core.models.repos.base import RepoType from dstack._internal.core.models.repos.local import LocalRunRepoData from dstack._internal.core.models.resources import ResourcesSpec @@ -145,8 +149,10 @@ async def create_repo( def get_run_spec( run_name: str, repo_id: str, - profile: Optional[Profile] = Profile(name="default"), + profile: Optional[Profile] = None, ) -> RunSpec: + if profile is None: + profile = Profile(name="default") return RunSpec( run_name=run_name, repo_id=repo_id, @@ -319,6 +325,7 @@ async def create_instance( price=1, region="eu-west", backend=BackendType.DATACRUNCH, + termination_idle_time=DEFAULT_TERMINATION_IDLE_TIME, ) session.add(im) await session.commit() diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index fb525aafb..20eeb3196 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -12,12 +12,12 @@ from websocket import WebSocketApp import dstack.api as api -from dstack._internal.core.backends.base.compute import SSHKeys from dstack._internal.core.errors import ConfigurationError, ResourceNotExistsError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import AnyRunConfiguration -from dstack._internal.core.models.instances import InstanceOfferWithAvailability +from dstack._internal.core.models.instances import InstanceOfferWithAvailability, SSHKey from dstack._internal.core.models.profiles import ( + DEFAULT_TERMINATION_IDLE_TIME, CreationPolicy, Profile, ProfileRetryPolicy, @@ -273,6 +273,7 @@ def attach( if not control_sock_path_and_port_locks: self._ssh_attach.attach() self._ports_lock = None + return True def detach(self): @@ -369,7 +370,7 @@ def get_offers( return self._api_client.runs.get_offers(self._project, profile, requirements) def create_instance( - self, pool_name: str, profile: Profile, requirements: Requirements, ssh_key: SSHKeys + self, pool_name: str, profile: Profile, requirements: Requirements, ssh_key: SSHKey ): self._api_client.runs.create_instance( self._project, pool_name, profile, requirements, ssh_key @@ -392,7 +393,7 @@ def get_plan( instance_name: Optional[str] = None, creation_policy: Optional[CreationPolicy] = None, termination_policy: Optional[TerminationPolicy] = None, - termination_policy_idle: Union[int, str] = 5 * 60, + termination_policy_idle: int = DEFAULT_TERMINATION_IDLE_TIME, ) -> RunPlan: # """ # Get run plan. 
Same arguments as `submit` @@ -426,7 +427,7 @@ def get_plan( instance_name=instance_name, creation_policy=creation_policy, termination_policy=termination_policy, - termination_idle_time=None, # TODO: fix deserialization + termination_idle_time=termination_policy_idle, ) run_spec = RunSpec( run_name=run_name, diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index 404f47a60..dd84f1892 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -2,8 +2,7 @@ from pydantic import parse_obj_as -from dstack._internal.core.backends.base.compute import SSHKeys -from dstack._internal.core.models.instances import InstanceOfferWithAvailability +from dstack._internal.core.models.instances import InstanceOfferWithAvailability, SSHKey from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.runs import Requirements, Run, RunPlan, RunSpec from dstack._internal.server.schemas.runs import ( @@ -43,7 +42,7 @@ def create_instance( pool_name: str, profile: Profile, requirements: Requirements, - ssh_key: SSHKeys, + ssh_key: SSHKey, ): body = CreateInstanceRequest( pool_name=pool_name, profile=profile, requirements=requirements, ssh_key=ssh_key diff --git a/src/tests/_internal/server/services/test_pools.py b/src/tests/_internal/server/services/test_pools.py index 77a1832ba..0b8e54ee8 100644 --- a/src/tests/_internal/server/services/test_pools.py +++ b/src/tests/_internal/server/services/test_pools.py @@ -17,6 +17,7 @@ InstanceType, LaunchedInstanceInfo, Resources, + SSHKey, ) from dstack._internal.core.models.pools import Instance, Pool from dstack._internal.core.models.profiles import Profile @@ -252,6 +253,7 @@ def create_instance(self, *args, **kwargs): pool_name="test_pool", instance_name="test_instance", requirements=requirements, + ssh_key=SSHKey(public=""), ) pool = await services_pools.get_pool(session, project, "test_pool") From a3a3cf3664bdf9c2e688272d5f4d3fb8447cf64b Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Fri, 9 Feb 2024 13:56:36 +0300 Subject: [PATCH 36/47] Always load job --- .../_internal/core/backends/aws/compute.py | 2 +- .../_internal/core/backends/base/compute.py | 21 +------------------ .../core/backends/datacrunch/compute.py | 2 +- .../_internal/core/backends/gcp/compute.py | 2 +- src/dstack/_internal/core/models/instances.py | 18 ++++++++++++++++ src/dstack/_internal/server/models.py | 2 +- src/dstack/_internal/server/services/runs.py | 6 ++---- 7 files changed, 25 insertions(+), 28 deletions(-) diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index d019b08fd..2de538cea 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -10,7 +10,6 @@ from dstack._internal.core.backends.aws.config import AWSConfig from dstack._internal.core.backends.base.compute import ( Compute, - InstanceConfiguration, get_gateway_user_data, get_instance_name, get_user_data, @@ -21,6 +20,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( InstanceAvailability, + InstanceConfiguration, InstanceOffer, InstanceOfferWithAvailability, LaunchedGatewayInfo, diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 48c6a51fb..f82addcac 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ 
-7,39 +7,20 @@ import git import requests import yaml -from pydantic import BaseModel from dstack._internal import settings -from dstack._internal.core.models.configurations import RegistryAuth from dstack._internal.core.models.instances import ( + InstanceConfiguration, InstanceOfferWithAvailability, LaunchedGatewayInfo, LaunchedInstanceInfo, - SSHKey, ) from dstack._internal.core.models.runs import Job, Requirements, Run -from dstack._internal.server.services.docker import DockerImage from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) -class DockerConfig(BaseModel): - registry_auth: Optional[RegistryAuth] - image: Optional[DockerImage] - - -class InstanceConfiguration(BaseModel): - project_name: str - instance_name: str # unique in pool - ssh_keys: List[SSHKey] - job_docker_config: Optional[DockerConfig] - user: str # dstack user name - - def get_public_keys(self) -> List[str]: - return [ssh_key.public.strip() for ssh_key in self.ssh_keys] - - class Compute(ABC): @abstractmethod def get_offers( diff --git a/src/dstack/_internal/core/backends/datacrunch/compute.py b/src/dstack/_internal/core/backends/datacrunch/compute.py index a26fe1fec..9102bcdfa 100644 --- a/src/dstack/_internal/core/backends/datacrunch/compute.py +++ b/src/dstack/_internal/core/backends/datacrunch/compute.py @@ -2,7 +2,6 @@ from dstack._internal.core.backends.base import Compute from dstack._internal.core.backends.base.compute import ( - InstanceConfiguration, get_shim_commands, ) from dstack._internal.core.backends.base.offers import get_catalog_offers @@ -12,6 +11,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( InstanceAvailability, + InstanceConfiguration, InstanceOffer, InstanceOfferWithAvailability, LaunchedInstanceInfo, diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index 65cb89c4f..9e0021e40 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -8,7 +8,6 @@ import dstack._internal.core.backends.gcp.resources as gcp_resources from dstack._internal.core.backends.base.compute import ( Compute, - InstanceConfiguration, get_gateway_user_data, get_instance_name, get_user_data, @@ -19,6 +18,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( InstanceAvailability, + InstanceConfiguration, InstanceOffer, InstanceOfferWithAvailability, InstanceType, diff --git a/src/dstack/_internal/core/models/instances.py b/src/dstack/_internal/core/models/instances.py index 4c7d32e22..6c319d5fb 100644 --- a/src/dstack/_internal/core/models/instances.py +++ b/src/dstack/_internal/core/models/instances.py @@ -4,6 +4,8 @@ from pydantic import BaseModel, Field from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.configurations import RegistryAuth +from dstack._internal.server.services.docker import DockerImage from dstack._internal.utils.common import pretty_resources @@ -65,6 +67,22 @@ class SSHKey(BaseModel): private: Optional[str] = None +class DockerConfig(BaseModel): + registry_auth: Optional[RegistryAuth] + image: Optional[DockerImage] + + +class InstanceConfiguration(BaseModel): + project_name: str + instance_name: str # unique in pool + ssh_keys: List[SSHKey] + job_docker_config: Optional[DockerConfig] + user: str # dstack user name + + def get_public_keys(self) -> 
List[str]: + return [ssh_key.public.strip() for ssh_key in self.ssh_keys] + + class LaunchedInstanceInfo(BaseModel): instance_id: str region: str diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 2196bf0fd..18977a159 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -302,7 +302,7 @@ class InstanceModel(BaseModel): # current job job_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("jobs.id")) - job: Mapped[Optional["JobModel"]] = relationship(back_populates="instance") + job: Mapped[Optional["JobModel"]] = relationship(back_populates="instance", lazy="immediate") last_job_processed_at: Mapped[Optional[datetime]] = mapped_column(DateTime) # + # job_id: Optional[FK] (current job) diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 5c0addb02..30dbe8b27 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -13,13 +13,11 @@ import dstack._internal.server.services.gateways as gateways import dstack._internal.utils.common as common_utils from dstack._internal.core.backends.base import Backend -from dstack._internal.core.backends.base.compute import ( - DockerConfig, - InstanceConfiguration, -) from dstack._internal.core.errors import BackendError, RepoDoesNotExistError, ServerClientError from dstack._internal.core.models.instances import ( + DockerConfig, InstanceAvailability, + InstanceConfiguration, InstanceOfferWithAvailability, LaunchedInstanceInfo, SSHKey, From 95f34ce9ae34b3159d7b820cac7c4afe09bb37f1 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Fri, 9 Feb 2024 13:28:54 +0100 Subject: [PATCH 37/47] TODO and small fix --- src/dstack/_internal/cli/commands/pool.py | 2 +- src/dstack/_internal/server/services/pools.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index af7f128f1..211b3ccd9 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -191,7 +191,7 @@ def _add(self, args: argparse.Namespace) -> None: requirements = Requirements( resources=resources, max_price=args.max_price, - spot=(args.spot_policy == SpotPolicy.SPOT), + spot=(args.spot_policy == SpotPolicy.SPOT), # TODO(egor-s): None if SpotPolicy.AUTO ) profile = load_profile(Path.cwd(), args.profile) diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index 0ec66f0ab..9af84cc27 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -240,7 +240,7 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance: price=offer.price, ) if instance_model.job is not None: - instance.job_name = instance_model.job.name + instance.job_name = instance_model.job.job_name instance.job_status = instance_model.job.status return instance From b01ec508d7d11ee3b46f664b11f7da9449ef271a Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Fri, 9 Feb 2024 17:35:38 +0300 Subject: [PATCH 38/47] Fix job_name --- src/dstack/_internal/server/services/pools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index 0ec66f0ab..9af84cc27 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ 
-240,7 +240,7 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance: price=offer.price, ) if instance_model.job is not None: - instance.job_name = instance_model.job.name + instance.job_name = instance_model.job.job_name instance.job_status = instance_model.job.status return instance From 2076ad560fcc5bdc6d3b361051718656329dcd89 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Fri, 9 Feb 2024 16:19:00 +0100 Subject: [PATCH 39/47] `dstack show` now works with no pool provided --- src/dstack/_internal/cli/commands/pool.py | 9 +++++--- src/dstack/_internal/core/models/pools.py | 7 ++++++- src/dstack/_internal/server/routers/pools.py | 12 +++++++---- src/dstack/_internal/server/schemas/pools.py | 2 +- src/dstack/_internal/server/services/pools.py | 21 +++++++++++++------ src/dstack/api/server/_pools.py | 8 +++---- .../_internal/server/services/test_pools.py | 4 ++-- 7 files changed, 42 insertions(+), 21 deletions(-) diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 211b3ccd9..89d47286c 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -83,7 +83,9 @@ def _register(self) -> None: formatter_class=self._parser.formatter_class, ) show_parser.add_argument( - "--pool", dest="pool_name", help="The name of the pool", required=True + "--pool", + dest="pool_name", + help="The name of the pool. If not set, the default pool will be used", ) show_parser.set_defaults(subfunc=self._show) @@ -175,8 +177,9 @@ def _set_default(self, args: argparse.Namespace) -> None: console.print(f"Failed to set default pool {args.pool_name!r}", style="error") def _show(self, args: argparse.Namespace) -> None: - instances = self.api.client.pool.show(self.api.project, args.pool_name) - print_instance_table(instances) + resp = self.api.client.pool.show(self.api.project, args.pool_name) + console.print(f"[bold]Pool name[/] {resp.name}") + print_instance_table(resp.instances) def _add(self, args: argparse.Namespace) -> None: super()._command(args) diff --git a/src/dstack/_internal/core/models/pools.py b/src/dstack/_internal/core/models/pools.py index e53e5b848..977c97a3e 100644 --- a/src/dstack/_internal/core/models/pools.py +++ b/src/dstack/_internal/core/models/pools.py @@ -1,5 +1,5 @@ import datetime -from typing import Optional +from typing import List, Optional from pydantic import BaseModel @@ -25,3 +25,8 @@ class Instance(BaseModel): # type: ignore[misc] hostname: str status: InstanceStatus price: float + + +class PoolInstances(BaseModel): # type: ignore[misc] + name: str + instances: List[Instance] diff --git a/src/dstack/_internal/server/routers/pools.py b/src/dstack/_internal/server/routers/pools.py index 273586a85..ca185e4f4 100644 --- a/src/dstack/_internal/server/routers/pools.py +++ b/src/dstack/_internal/server/routers/pools.py @@ -1,4 +1,4 @@ -from typing import List, Sequence, Tuple +from typing import List, Tuple from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession @@ -6,6 +6,7 @@ import dstack._internal.core.models.pools as models import dstack._internal.server.schemas.pools as schemas import dstack._internal.server.services.pools as pools +from dstack._internal.core.errors import ResourceNotExistsError from dstack._internal.server.db import get_session from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.runs import AddRemoteInstanceRequest @@ -94,12 +95,15 @@ async def create_pool( 
@router.post("/show") # type: ignore[misc] async def show_pool( - body: schemas.CreatePoolRequest, + body: schemas.ShowPoolRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> Sequence[models.Instance]: +) -> models.PoolInstances: _, project = user_project - return await pools.show_pool(session, project, pool_name=body.name) + instances = await pools.show_pool(session, project, pool_name=body.name) + if instances is None: + raise ResourceNotExistsError("Pool is not found") + return instances @router.post("/add_remote") # type: ignore[misc] diff --git a/src/dstack/_internal/server/schemas/pools.py b/src/dstack/_internal/server/schemas/pools.py index 6ae3b2d2f..b171b772a 100644 --- a/src/dstack/_internal/server/schemas/pools.py +++ b/src/dstack/_internal/server/schemas/pools.py @@ -13,7 +13,7 @@ class CreatePoolRequest(BaseModel): # type: ignore[misc] class ShowPoolRequest(BaseModel): # type: ignore[misc] - name: str + name: Optional[str] class RemoveInstanceRequest(BaseModel): # type: ignore[misc] diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index 9af84cc27..c0ff38350 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -21,7 +21,7 @@ InstanceType, Resources, ) -from dstack._internal.core.models.pools import Instance, Pool +from dstack._internal.core.models.pools import Instance, Pool, PoolInstances from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile, SpotPolicy from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData, Requirements @@ -247,11 +247,20 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance: async def show_pool( - session: AsyncSession, project: ProjectModel, pool_name: str -) -> Sequence[Instance]: - """Show active instances in the pool. If the pool doesn't exist, return an empty list.""" - pool_instances = await get_pool_instances(session, project, pool_name) - return [instance_model_to_instance(i) for i in pool_instances if not i.deleted] + session: AsyncSession, project: ProjectModel, pool_name: Optional[str] +) -> Optional[PoolInstances]: + """Show active instances in the pool (specified or default). 
Return None if the pool is not found.""" + if pool_name is None: + pool = project.default_pool + else: + pool = await get_pool(session, project, pool_name) + + if pool is None: + return None + return PoolInstances( + name=pool.name, + instances=[instance_model_to_instance(i) for i in pool.instances if not i.deleted], + ) async def get_pool_instances( diff --git a/src/dstack/api/server/_pools.py b/src/dstack/api/server/_pools.py index 4f9b9c455..9c1c8ce4f 100644 --- a/src/dstack/api/server/_pools.py +++ b/src/dstack/api/server/_pools.py @@ -3,7 +3,7 @@ from pydantic import parse_obj_as import dstack._internal.server.schemas.pools as schemas_pools -from dstack._internal.core.models.pools import Instance, Pool +from dstack._internal.core.models.pools import Pool, PoolInstances from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.server.schemas.runs import AddRemoteInstanceRequest @@ -24,11 +24,11 @@ def create(self, project_name: str, pool_name: str) -> None: body = schemas_pools.CreatePoolRequest(name=pool_name) self._request(f"/api/project/{project_name}/pool/create", body=body.json()) - def show(self, project_name: str, pool_name: str) -> List[Instance]: + def show(self, project_name: str, pool_name: Optional[str]) -> PoolInstances: body = schemas_pools.ShowPoolRequest(name=pool_name) resp = self._request(f"/api/project/{project_name}/pool/show", body=body.json()) - result: List[Instance] = parse_obj_as(List[Instance], resp.json()) - return result + pool: PoolInstances = parse_obj_as(PoolInstances, resp.json()) + return pool def remove( self, project_name: str, pool_name: Optional[str], instance_name: str, force: bool diff --git a/src/tests/_internal/server/services/test_pools.py b/src/tests/_internal/server/services/test_pools.py index 0b8e54ee8..4f4fb5607 100644 --- a/src/tests/_internal/server/services/test_pools.py +++ b/src/tests/_internal/server/services/test_pools.py @@ -134,8 +134,8 @@ async def test_show_pool(session: AsyncSession, test_db): session.add(im) await session.commit() - instances = await services_pools.show_pool(session, project, POOL_NAME) - assert len(instances) == 1 + pool_instances = await services_pools.show_pool(session, project, POOL_NAME) + assert len(pool_instances.instances) == 1 @pytest.mark.asyncio From e0f2240e158ec5ada58ceaaabe6766b4d1977a6c Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Fri, 9 Feb 2024 16:20:48 +0100 Subject: [PATCH 40/47] Improve dstack pool show formatting --- src/dstack/_internal/cli/commands/pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 89d47286c..eeb98a87d 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -178,7 +178,7 @@ def _set_default(self, args: argparse.Namespace) -> None: def _show(self, args: argparse.Namespace) -> None: resp = self.api.client.pool.show(self.api.project, args.pool_name) - console.print(f"[bold]Pool name[/] {resp.name}") + console.print(f" [bold]Pool name[/] {resp.name}\n") print_instance_table(resp.instances) def _add(self, args: argparse.Namespace) -> None: From c9ce8da67f7f10e89440d26882aa1fd905529082 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Fri, 9 Feb 2024 18:37:00 +0100 Subject: [PATCH 41/47] Ask confirmation on `dstack pool remove` --- src/dstack/_internal/cli/commands/pool.py | 24 ++++++++++++++++--- 
src/dstack/_internal/server/schemas/pools.py | 2 +- src/dstack/_internal/server/services/pools.py | 12 ++-------- src/dstack/api/server/_pools.py | 4 +--- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index eeb98a87d..0c0ff0323 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -25,7 +25,7 @@ TerminationPolicy, ) from dstack._internal.core.models.resources import DEFAULT_CPU_COUNT, DEFAULT_MEMORY_SIZE -from dstack._internal.core.models.runs import Requirements +from dstack._internal.core.models.runs import InstanceStatus, Requirements from dstack._internal.utils.common import pretty_date from dstack._internal.utils.logging import get_logger from dstack.api._public.resources import Resources @@ -136,6 +136,9 @@ def _register(self) -> None: action="store_true", help="The name of the instance", ) + remove_parser.add_argument( + "-y", "--yes", help="Don't ask for confirmation", action="store_true" + ) remove_parser.set_defaults(subfunc=self._remove) # pool set-default @@ -164,10 +167,25 @@ def _delete(self, args: argparse.Namespace) -> None: console.print(f"Pool {args.pool_name!r} removed") def _remove(self, args: argparse.Namespace) -> None: - # TODO(egor-s): ask for confirmation + pool = self.api.client.pool.show(self.api.project, args.pool_name) + pool.instances = [i for i in pool.instances if i.instance_id == args.instance_name] + if not pool.instances: + raise CLIError(f"Instance {args.instance_name!r} not found in pool {pool.name!r}") + + console.print(f" [bold]Pool name[/] {pool.name}\n") + print_instance_table(pool.instances) + + if not args.force and any(i.status == InstanceStatus.BUSY for i in pool.instances): + # TODO(egor-s): implement this logic in the server too + raise CLIError("Can't remove busy instance. 
Use `--force` to remove anyway")
+
+        if not args.yes and not confirm_ask(f"Remove instance {args.instance_name!r}?"):
+            console.print("\nExiting...")
+            return
+
         with console.status("Removing instance..."):
             self.api.client.pool.remove(
-                self.api.project, args.pool_name, args.instance_name, args.force
+                self.api.project, pool.name, args.instance_name, args.force
             )
         console.print(f"Instance {args.instance_name!r} removed")
diff --git a/src/dstack/_internal/server/schemas/pools.py b/src/dstack/_internal/server/schemas/pools.py
index b171b772a..f170bb7cb 100644
--- a/src/dstack/_internal/server/schemas/pools.py
+++ b/src/dstack/_internal/server/schemas/pools.py
@@ -17,7 +17,7 @@ class ShowPoolRequest(BaseModel):  # type: ignore[misc]
 
 
 class RemoveInstanceRequest(BaseModel):  # type: ignore[misc]
-    pool_name: Optional[str]
+    pool_name: str
     instance_name: str
     force: bool = False
 
diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py
index c0ff38350..3821c4a4e 100644
--- a/src/dstack/_internal/server/services/pools.py
+++ b/src/dstack/_internal/server/services/pools.py
@@ -155,19 +155,11 @@ async def set_default_pool(session: AsyncSession, project: ProjectModel, pool_na
 async def remove_instance(
     session: AsyncSession,
     project: ProjectModel,
-    pool_name: Optional[str],
+    pool_name: str,
     instance_name: str,
     force: bool,
 ) -> None:
-    pool = (
-        await session.scalars(
-            select(PoolModel).where(
-                PoolModel.name == pool_name,
-                PoolModel.project == project,
-                PoolModel.deleted == False,
-            )
-        )
-    ).one_or_none()
+    pool = await get_pool(session, project, pool_name)
 
     if pool is None:
         logger.warning("Couldn't find pool")
diff --git a/src/dstack/api/server/_pools.py b/src/dstack/api/server/_pools.py
index 9c1c8ce4f..28c448283 100644
--- a/src/dstack/api/server/_pools.py
+++ b/src/dstack/api/server/_pools.py
@@ -30,9 +30,7 @@ def show(self, project_name: str, pool_name: Optional[str]) -> PoolInstances:
         pool: PoolInstances = parse_obj_as(PoolInstances, resp.json())
         return pool
 
-    def remove(
-        self, project_name: str, pool_name: Optional[str], instance_name: str, force: bool
-    ) -> None:
+    def remove(self, project_name: str, pool_name: str, instance_name: str, force: bool) -> None:
         body = schemas_pools.RemoveInstanceRequest(
             pool_name=pool_name, instance_name=instance_name, force=force
         )

From ddad6eed89a8254f551a7da18518db66ee061f94 Mon Sep 17 00:00:00 2001
From: Victor Skvortsov
Date: Mon, 12 Feb 2024 15:14:50 +0500
Subject: [PATCH 42/47] Fix done jobs being aborted

Fixes #885. The problem was that after reading the logs, the run status
wasn't done yet at the moment the CLI aborted it. Added waiting time for
run to finish.
---
 src/dstack/_internal/cli/commands/run.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/dstack/_internal/cli/commands/run.py b/src/dstack/_internal/cli/commands/run.py
index e2de69822..4ee276fc1 100644
--- a/src/dstack/_internal/cli/commands/run.py
+++ b/src/dstack/_internal/cli/commands/run.py
@@ -238,10 +238,16 @@ def _command(self, args: argparse.Namespace):
             else:
                 console.print("[error]Failed to attach, exiting...[/]")
 
-            run.refresh()
-            if run.status.is_finished():
-                _print_fail_message(run)
-                abort_at_exit = False
+            # After reading the logs, the run may not be marked as finished immediately.
+            # Give the run some time to transition into a finished state before aborting it.
+ for _ in range(5): + run.refresh() + if run.status.is_finished(): + if run.status == RunStatus.FAILED: + _print_fail_message(run) + abort_at_exit = False + break + time.sleep(1) except KeyboardInterrupt: try: if not confirm_ask("\nStop the run before detaching?"): From 4c0a45d4514decdfbbadab5da54025d9379b151a Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Mon, 12 Feb 2024 13:22:01 +0300 Subject: [PATCH 43/47] Remove mypy and comments --- .pre-commit-config.yaml | 10 +--------- src/dstack/_internal/cli/commands/pool.py | 2 +- src/dstack/_internal/core/models/pools.py | 6 +++--- .../server/background/tasks/process_pools.py | 2 +- src/dstack/_internal/server/routers/pools.py | 14 +++++++------- src/dstack/_internal/server/schemas/pools.py | 10 +++++----- src/dstack/_internal/server/services/pools.py | 4 ++-- src/dstack/api/server/_pools.py | 2 +- 8 files changed, 21 insertions(+), 29 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 46a2d6e25..ee577a901 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,16 +6,8 @@ repos: name: ruff common args: ['--fix'] - id: ruff-format - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.8.0 - hooks: - - id: mypy - args: ['--strict', '--follow-imports=skip', '--ignore-missing-imports', '--python-version=3.8'] - files: '.*pools?\.py' - exclude: 'versions|src/tests' - additional_dependencies: [types-PyYAML, types-requests, pydantic<2, sqlalchemy] - repo: https://github.com/golangci/golangci-lint - rev: v1.56.0 + rev: v1.56.1 hooks: - id: golangci-lint-full entry: bash -c 'cd runner && golangci-lint run -D depguard --presets import,module,unused "$@"' diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 0c0ff0323..3dd7d1063 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -34,7 +34,7 @@ logger = get_logger(__name__) -class PoolCommand(APIBaseCommand): # type: ignore[misc] +class PoolCommand(APIBaseCommand): NAME = "pool" DESCRIPTION = "Pool management" diff --git a/src/dstack/_internal/core/models/pools.py b/src/dstack/_internal/core/models/pools.py index 977c97a3e..c4c987607 100644 --- a/src/dstack/_internal/core/models/pools.py +++ b/src/dstack/_internal/core/models/pools.py @@ -8,7 +8,7 @@ from dstack._internal.core.models.runs import InstanceStatus, JobStatus -class Pool(BaseModel): # type: ignore[misc] +class Pool(BaseModel): name: str default: bool created_at: datetime.datetime @@ -16,7 +16,7 @@ class Pool(BaseModel): # type: ignore[misc] available_instances: int -class Instance(BaseModel): # type: ignore[misc] +class Instance(BaseModel): backend: BackendType instance_type: InstanceType instance_id: str # TODO: rename to name @@ -27,6 +27,6 @@ class Instance(BaseModel): # type: ignore[misc] price: float -class PoolInstances(BaseModel): # type: ignore[misc] +class PoolInstances(BaseModel): name: str instances: List[Instance] diff --git a/src/dstack/_internal/server/background/tasks/process_pools.py b/src/dstack/_internal/server/background/tasks/process_pools.py index 4e3faa7a8..057b9b6b1 100644 --- a/src/dstack/_internal/server/background/tasks/process_pools.py +++ b/src/dstack/_internal/server/background/tasks/process_pools.py @@ -92,7 +92,7 @@ async def check_shim(instance_id: UUID) -> None: await session.commit() -@runner_ssh_tunnel(ports=[client.REMOTE_SHIM_PORT], retries=1) # type: ignore[misc] +@runner_ssh_tunnel(ports=[client.REMOTE_SHIM_PORT], retries=1) def 
instance_healthcheck(*, ports: Dict[int, int]) -> bool: shim_client = client.ShimClient(port=ports[client.REMOTE_SHIM_PORT]) resp = shim_client.healthcheck() diff --git a/src/dstack/_internal/server/routers/pools.py b/src/dstack/_internal/server/routers/pools.py index ca185e4f4..2ea53616d 100644 --- a/src/dstack/_internal/server/routers/pools.py +++ b/src/dstack/_internal/server/routers/pools.py @@ -20,7 +20,7 @@ router = APIRouter(prefix="/api/project/{project_name}/pool", tags=["pool"]) -@router.post("/list") # type: ignore[misc] +@router.post("/list") async def list_pool( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), @@ -29,7 +29,7 @@ async def list_pool( return await pools.list_project_pool(session=session, project=project) -@router.post("/remove") # type: ignore[misc] +@router.post("/remove") async def remove_instance( body: schemas.RemoveInstanceRequest, session: AsyncSession = Depends(get_session), @@ -41,7 +41,7 @@ async def remove_instance( ) -@router.post("/set-default") # type: ignore[misc] +@router.post("/set-default") async def set_default_pool( body: schemas.SetDefaultPoolRequest, session: AsyncSession = Depends(get_session), @@ -51,7 +51,7 @@ async def set_default_pool( return await pools.set_default_pool(session, project_model, body.pool_name) -@router.post("/delete") # type: ignore[misc] +@router.post("/delete") async def delete_pool( body: schemas.DeletePoolRequest, session: AsyncSession = Depends(get_session), @@ -83,7 +83,7 @@ async def delete_pool( await pools.delete_pool(session, project_model, pool_name) -@router.post("/create") # type: ignore[misc] +@router.post("/create") async def create_pool( body: schemas.CreatePoolRequest, session: AsyncSession = Depends(get_session), @@ -93,7 +93,7 @@ async def create_pool( await pools.create_pool_model(session=session, project=project, name=body.name) -@router.post("/show") # type: ignore[misc] +@router.post("/show") async def show_pool( body: schemas.ShowPoolRequest, session: AsyncSession = Depends(get_session), @@ -106,7 +106,7 @@ async def show_pool( return instances -@router.post("/add_remote") # type: ignore[misc] +@router.post("/add_remote") async def add_instance( body: AddRemoteInstanceRequest, session: AsyncSession = Depends(get_session), diff --git a/src/dstack/_internal/server/schemas/pools.py b/src/dstack/_internal/server/schemas/pools.py index f170bb7cb..f4eccc6f1 100644 --- a/src/dstack/_internal/server/schemas/pools.py +++ b/src/dstack/_internal/server/schemas/pools.py @@ -3,24 +3,24 @@ from pydantic import BaseModel -class DeletePoolRequest(BaseModel): # type: ignore[misc] +class DeletePoolRequest(BaseModel): name: str force: bool -class CreatePoolRequest(BaseModel): # type: ignore[misc] +class CreatePoolRequest(BaseModel): name: str -class ShowPoolRequest(BaseModel): # type: ignore[misc] +class ShowPoolRequest(BaseModel): name: Optional[str] -class RemoveInstanceRequest(BaseModel): # type: ignore[misc] +class RemoveInstanceRequest(BaseModel): pool_name: str instance_name: str force: bool = False -class SetDefaultPoolRequest(BaseModel): # type: ignore[misc] +class SetDefaultPoolRequest(BaseModel): pool_name: str diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index 3821c4a4e..cd367c814 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -130,7 +130,7 @@ async def list_project_pool_models( .where(PoolModel.project_id == 
project.id, PoolModel.deleted == False) .options(joinedload(PoolModel.instances)) ) - return pools.unique().all() # type: ignore[no-any-return] + return pools.unique().all() async def set_default_pool(session: AsyncSession, project: ProjectModel, pool_name: str) -> bool: @@ -212,7 +212,7 @@ async def list_deleted_pools( PoolModel.project_id == project_model.id, PoolModel.deleted == True ) ) - return pools.all() # type: ignore[no-any-return] + return pools.all() def instance_model_to_instance(instance_model: InstanceModel) -> Instance: diff --git a/src/dstack/api/server/_pools.py b/src/dstack/api/server/_pools.py index 28c448283..5e63f301a 100644 --- a/src/dstack/api/server/_pools.py +++ b/src/dstack/api/server/_pools.py @@ -10,7 +10,7 @@ from dstack.api.server._group import APIClientGroup -class PoolAPIClient(APIClientGroup): # type: ignore[misc] +class PoolAPIClient(APIClientGroup): def list(self, project_name: str) -> List[Pool]: resp = self._request(f"/api/project/{project_name}/pool/list") result: List[Pool] = parse_obj_as(List[Pool], resp.json()) From 16cabf688fd78101bf3bf5fd035c8363053711a3 Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Mon, 12 Feb 2024 13:37:36 +0300 Subject: [PATCH 44/47] Rename url pool/set-default to pool/set_default --- src/dstack/_internal/server/routers/pools.py | 2 +- src/dstack/api/server/_pools.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/server/routers/pools.py b/src/dstack/_internal/server/routers/pools.py index 2ea53616d..3b8aa026c 100644 --- a/src/dstack/_internal/server/routers/pools.py +++ b/src/dstack/_internal/server/routers/pools.py @@ -41,7 +41,7 @@ async def remove_instance( ) -@router.post("/set-default") +@router.post("/set_default") async def set_default_pool( body: schemas.SetDefaultPoolRequest, session: AsyncSession = Depends(get_session), diff --git a/src/dstack/api/server/_pools.py b/src/dstack/api/server/_pools.py index 5e63f301a..fc5cdf7d2 100644 --- a/src/dstack/api/server/_pools.py +++ b/src/dstack/api/server/_pools.py @@ -38,7 +38,7 @@ def remove(self, project_name: str, pool_name: str, instance_name: str, force: b def set_default(self, project_name: str, pool_name: str) -> bool: body = schemas_pools.SetDefaultPoolRequest(pool_name=pool_name) - result = self._request(f"/api/project/{project_name}/pool/set-default", body=body.json()) + result = self._request(f"/api/project/{project_name}/pool/set_default", body=body.json()) return bool(result.json()) def add_remote( From 2a10473076331d430ae44bb0f53aad55851c8f46 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Mon, 12 Feb 2024 12:18:57 +0100 Subject: [PATCH 45/47] Replace instance.instance_id with instance.name --- src/dstack/_internal/cli/commands/pool.py | 4 ++-- src/dstack/_internal/core/models/pools.py | 2 +- src/dstack/_internal/server/services/pools.py | 2 +- src/tests/_internal/server/services/test_pools.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 0c0ff0323..867196e72 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -168,7 +168,7 @@ def _delete(self, args: argparse.Namespace) -> None: def _remove(self, args: argparse.Namespace) -> None: pool = self.api.client.pool.show(self.api.project, args.pool_name) - pool.instances = [i for i in pool.instances if i.instance_id == args.instance_name] + pool.instances = [i for i in pool.instances if i.name == 
args.instance_name]
             if not pool.instances:
                 raise CLIError(f"Instance {args.instance_name!r} not found in pool {pool.name!r}")
@@ -296,7 +296,7 @@ def print_instance_table(instances: Sequence[Instance]) -> None:
     for instance in instances:
         style = "success" if instance.status.is_available() else "warning"
         row = [
-            instance.instance_id,
+            instance.name,
             instance.backend,
             instance.instance_type.resources.pretty_format(),
             f"[{style}]{instance.status.value}[/]",
diff --git a/src/dstack/_internal/core/models/pools.py b/src/dstack/_internal/core/models/pools.py
index 977c97a3e..e9e93b449 100644
--- a/src/dstack/_internal/core/models/pools.py
+++ b/src/dstack/_internal/core/models/pools.py
@@ -19,7 +19,7 @@ class Pool(BaseModel):  # type: ignore[misc]
 class Instance(BaseModel):  # type: ignore[misc]
     backend: BackendType
     instance_type: InstanceType
-    instance_id: str  # TODO: rename to name
+    name: str
     job_name: Optional[str] = None
     job_status: Optional[JobStatus] = None
     hostname: str
diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py
index 3821c4a4e..1cb3d314c 100644
--- a/src/dstack/_internal/server/services/pools.py
+++ b/src/dstack/_internal/server/services/pools.py
@@ -225,7 +225,7 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance:
 
     instance = Instance(
         backend=offer.backend,
-        instance_id=jpd.instance_id,
+        name=instance_model.name,
         instance_type=jpd.instance_type,
         hostname=jpd.hostname,
         status=instance_model.status,
diff --git a/src/tests/_internal/server/services/test_pools.py b/src/tests/_internal/server/services/test_pools.py
index 4f4fb5607..b47ce8e11 100644
--- a/src/tests/_internal/server/services/test_pools.py
+++ b/src/tests/_internal/server/services/test_pools.py
@@ -76,7 +76,7 @@ def test_convert_instance():
         instance_type=InstanceType(
             name="instance", resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[])
         ),
-        instance_id="test_instance",
+        name="test_instance",
         hostname="hostname_test",
         status=InstanceStatus.PENDING,
         price=1.0,

From b61ad109f74efb244830e347f0d7cdfdc307ad0c Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Mon, 12 Feb 2024 12:20:50 +0100
Subject: [PATCH 46/47] dstack pool remove: name as positional argument

---
 src/dstack/_internal/cli/commands/pool.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py
index 867196e72..c33b7c8e2 100644
--- a/src/dstack/_internal/cli/commands/pool.py
+++ b/src/dstack/_internal/cli/commands/pool.py
@@ -123,14 +123,15 @@ def _register(self) -> None:
             help="Remove instance from the pool",
             formatter_class=self._parser.formatter_class,
         )
+        remove_parser.add_argument(
+            "instance_name",
+            help="The name of the instance",
+        )
         remove_parser.add_argument(
             "--pool",
             dest="pool_name",
             help="The name of the pool. If not set, the default pool will be used",
         )
-        remove_parser.add_argument(
-            "--name", dest="instance_name", help="The name of the instance", required=True
-        )
         remove_parser.add_argument(
             "--force",
             action="store_true",

From e2208de9ae1fd12ef9756cfdbbc69c65b13d67c5 Mon Sep 17 00:00:00 2001
From: Sergey Mezentsev
Date: Mon, 12 Feb 2024 14:52:30 +0300
Subject: [PATCH 47/47] fix review

---
 .../server/background/tasks/process_pools.py  |  2 +-
 ...add_pools.py => 27d3e55759fa_add_pools.py} |  8 ++--
 .../5395b4ae6c3b_add_pools_fix_optional.py    | 41 -------------------
 src/dstack/_internal/server/models.py         | 15 -------
 src/dstack/_internal/server/services/pools.py |  1 -
 src/dstack/_internal/server/services/runs.py  |  1 -
 src/dstack/_internal/server/testing/common.py |  1 -
 7 files changed, 5 insertions(+), 64 deletions(-)
 rename src/dstack/_internal/server/migrations/versions/{b55bd09bf186_add_pools.py => 27d3e55759fa_add_pools.py} (96%)
 delete mode 100644 src/dstack/_internal/server/migrations/versions/5395b4ae6c3b_add_pools_fix_optional.py

diff --git a/src/dstack/_internal/server/background/tasks/process_pools.py b/src/dstack/_internal/server/background/tasks/process_pools.py
index 057b9b6b1..affea4aaf 100644
--- a/src/dstack/_internal/server/background/tasks/process_pools.py
+++ b/src/dstack/_internal/server/background/tasks/process_pools.py
@@ -98,7 +98,7 @@ def instance_healthcheck(*, ports: Dict[int, int]) -> bool:
     resp = shim_client.healthcheck()
     if resp is None:
         return False  # shim is not available yet
-    return bool(resp.service == "dstack-shim")
+    return resp.service == "dstack-shim"
 
 
 async def terminate(instance_id: UUID) -> None:
diff --git a/src/dstack/_internal/server/migrations/versions/b55bd09bf186_add_pools.py b/src/dstack/_internal/server/migrations/versions/27d3e55759fa_add_pools.py
similarity index 96%
rename from src/dstack/_internal/server/migrations/versions/b55bd09bf186_add_pools.py
rename to src/dstack/_internal/server/migrations/versions/27d3e55759fa_add_pools.py
index 73c107793..8869d3491 100644
--- a/src/dstack/_internal/server/migrations/versions/b55bd09bf186_add_pools.py
+++ b/src/dstack/_internal/server/migrations/versions/27d3e55759fa_add_pools.py
@@ -1,8 +1,8 @@
 """add pools
 
-Revision ID: b55bd09bf186
+Revision ID: 27d3e55759fa
 Revises: d3e8af4786fa
-Create Date: 2024-02-06 08:44:44.235928
+Create Date: 2024-02-12 14:27:52.035476
 
 """
 import sqlalchemy as sa
@@ -9,7 +9,7 @@
 from alembic import op
 
 # revision identifiers, used by Alembic.
-revision = "b55bd09bf186"
+revision = "27d3e55759fa"
 down_revision = "d3e8af4786fa"
 branch_labels = None
 depends_on = None
@@ -66,7 +66,7 @@ def upgrade() -> None:
         sa.Column("started_at", sa.DateTime(), nullable=True),
         sa.Column("finished_at", sa.DateTime(), nullable=True),
         sa.Column("termination_policy", sa.String(length=50), nullable=True),
-        sa.Column("termination_idle_time", sa.String(length=50), nullable=True),
+        sa.Column("termination_idle_time", sa.Integer(), nullable=False),
         sa.Column(
             "backend",
             sa.Enum(
diff --git a/src/dstack/_internal/server/migrations/versions/5395b4ae6c3b_add_pools_fix_optional.py b/src/dstack/_internal/server/migrations/versions/5395b4ae6c3b_add_pools_fix_optional.py
deleted file mode 100644
index 87a48e367..000000000
--- a/src/dstack/_internal/server/migrations/versions/5395b4ae6c3b_add_pools_fix_optional.py
+++ /dev/null
@@ -1,41 +0,0 @@
-"""add pools fix optional
-
-Revision ID: 5395b4ae6c3b
-Revises: b55bd09bf186
-Create Date: 2024-02-08 11:08:31.426042
-
-"""
-import sqlalchemy as sa
-from alembic import op
-
-# revision identifiers, used by Alembic.
-revision = "5395b4ae6c3b"
-down_revision = "b55bd09bf186"
-branch_labels = None
-depends_on = None
-
-
-def upgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table("instances", schema=None) as batch_op:
-        batch_op.alter_column(
-            "termination_idle_time",
-            existing_type=sa.VARCHAR(length=50),
-            type_=sa.Integer(),
-            nullable=False,
-        )
-
-    # ### end Alembic commands ###
-
-
-def downgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table("instances", schema=None) as batch_op:
-        batch_op.alter_column(
-            "termination_idle_time",
-            existing_type=sa.Integer(),
-            type_=sa.VARCHAR(length=50),
-            nullable=True,
-        )
-
-    # ### end Alembic commands ###
diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py
index 18977a159..8833c8292 100644
--- a/src/dstack/_internal/server/models.py
+++ b/src/dstack/_internal/server/models.py
@@ -278,7 +278,6 @@ class InstanceModel(BaseModel):
     pool: Mapped["PoolModel"] = relationship(back_populates="instances", single_parent=True)
 
     status: Mapped[InstanceStatus] = mapped_column(Enum(InstanceStatus))
-    status_message: Mapped[Optional[str]] = mapped_column(String(50))
 
     # VM
     started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime)
@@ -298,21 +297,7 @@ class InstanceModel(BaseModel):
 
     offer: Mapped[str] = mapped_column(String(4000))
 
-    resource_spec_data: Mapped[Optional[str]] = mapped_column(String(4000))
-
     # current job
     job_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("jobs.id"))
     job: Mapped[Optional["JobModel"]] = relationship(back_populates="instance", lazy="immediate")
     last_job_processed_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
-
-    # + # job_id: Optional[FK] (current job)
-    # ip address
-    # ssh creds: user, port, dockerized
-    # real resources + spot (exact) / instance offer
-    # + backend + backend data
-    # + region
-    # + price (for querying)
-    # + # termination policy
-    # creation policy
-    # job_provisioning_data=job_provisioning_data.json(),
-    # TODO: instance provisioning
diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py
index cd367c814..107426dae 100644
--- a/src/dstack/_internal/server/services/pools.py
+++ b/src/dstack/_internal/server/services/pools.py
@@ -363,7 +363,6 @@ async def add_remote(
         status=InstanceStatus.PENDING,
         job_provisioning_data=local.json(),
         offer=offer.json(),
-        resource_spec_data=resources.json(),
         termination_policy=profile.termination_policy,
         termination_idle_time=profile.termination_idle_time,
     )
diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py
index 30dbe8b27..a44aa52b6 100644
--- a/src/dstack/_internal/server/services/runs.py
+++ b/src/dstack/_internal/server/services/runs.py
@@ -278,7 +278,6 @@ async def create_instance(
         job_provisioning_data=job_provisioning_data.json(),
         # TODO: instance provisioning
         offer=cast(InstanceOfferWithAvailability, instance_offer).json(),
-        resource_spec_data=requirements.resources.json(),
         termination_policy=profile.termination_policy,
         termination_idle_time=profile.termination_idle_time,
     )
diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py
index ac46b7bea..ef376d34a 100644
--- a/src/dstack/_internal/server/testing/common.py
+++ b/src/dstack/_internal/server/testing/common.py
@@ -321,7 +321,6 @@ async def create_instance(
         status=status,
         job_provisioning_data='{"backend": "datacrunch", "instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "instance_id": "running_instance.id", "ssh_proxy": null, "hostname": "running_instance.ip", "region": "running_instance.location", "price": 0.1, "username": "root", "ssh_port": 22, "dockerized": true, "backend_data": null}',
         offer='{"backend": "datacrunch", "instance": {"name": "instance", "resources": {"cpus": 2, "memory_mib": 12000, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "region": "en", "price": 0.1, "availability": "available"}',
-        resource_spec_data=resources.json(),
         price=1,
         region="eu-west",
         backend=BackendType.DATACRUNCH,