From 94ab4e1d303b75f865b1c32b0fab0351e2e4d145 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 27 Jul 2023 13:36:55 +0500 Subject: [PATCH 1/2] Support custom run names --- cli/dstack/_internal/api/repos.py | 12 ++-- cli/dstack/_internal/backend/base/__init__.py | 11 +++- cli/dstack/_internal/backend/base/jobs.py | 11 +++- cli/dstack/_internal/backend/base/runners.py | 2 +- cli/dstack/_internal/backend/base/runs.py | 57 +++++++++++++------ cli/dstack/_internal/backend/local/compute.py | 3 +- cli/dstack/_internal/backend/local/runners.py | 12 ++++ .../_internal/cli/commands/ls/__init__.py | 1 - .../_internal/cli/commands/run/__init__.py | 31 ++++++++-- cli/dstack/_internal/cli/main.py | 1 - cli/dstack/_internal/cli/utils/config.py | 6 +- cli/dstack/_internal/core/repo/remote.py | 11 +++- cli/dstack/_internal/hub/models/__init__.py | 6 ++ cli/dstack/_internal/hub/routers/artifacts.py | 5 +- .../_internal/hub/routers/configurations.py | 5 +- cli/dstack/_internal/hub/routers/jobs.py | 13 ++--- cli/dstack/_internal/hub/routers/link.py | 7 +-- cli/dstack/_internal/hub/routers/logs.py | 17 ++++-- cli/dstack/_internal/hub/routers/repos.py | 15 +++-- cli/dstack/_internal/hub/routers/runners.py | 17 ++---- cli/dstack/_internal/hub/routers/runs.py | 49 ++++++++++------ cli/dstack/_internal/hub/routers/secrets.py | 13 ++--- cli/dstack/_internal/hub/routers/tags.py | 13 ++--- cli/dstack/_internal/hub/routers/util.py | 17 +++++- cli/dstack/api/hub/_api_client.py | 26 ++++++++- cli/dstack/api/hub/_client.py | 12 ++-- cli/requirements.txt | 1 + docs/docs/reference/cli/run.md | 31 ++++++---- setup.py | 1 + 29 files changed, 273 insertions(+), 133 deletions(-) diff --git a/cli/dstack/_internal/api/repos.py b/cli/dstack/_internal/api/repos.py index e1598d6e0..4e8a39334 100644 --- a/cli/dstack/_internal/api/repos.py +++ b/cli/dstack/_internal/api/repos.py @@ -46,7 +46,7 @@ def get_local_repo_credentials( if identity_file is not None: # must fail if key is invalid try: # user provided ssh key - return test_remote_repo_credentials( + return check_remote_repo_credentials( repo_data, RepoProtocol.SSH, identity_file=identity_file ) except GitCommandError: @@ -54,7 +54,7 @@ def get_local_repo_credentials( if oauth_token is not None: try: # user provided oauth token - return test_remote_repo_credentials( + return check_remote_repo_credentials( repo_data, RepoProtocol.HTTPS, oauth_token=oauth_token ) except GitCommandError: @@ -63,7 +63,7 @@ def get_local_repo_credentials( identities = get_host_config(original_hostname or repo_data.repo_host_name).get("identityfile") if identities: # must fail if key is invalid try: # key from ssh config - return test_remote_repo_credentials( + return check_remote_repo_credentials( repo_data, RepoProtocol.SSH, identity_file=identities[0] ) except GitCommandError: @@ -75,7 +75,7 @@ def get_local_repo_credentials( oauth_token = gh_hosts.get(repo_data.repo_host_name, {}).get("oauth_token") if oauth_token is not None: try: # token from gh config - return test_remote_repo_credentials( + return check_remote_repo_credentials( repo_data, RepoProtocol.HTTPS, oauth_token=oauth_token ) except GitCommandError: @@ -83,14 +83,14 @@ def get_local_repo_credentials( if os.path.exists(default_ssh_key): try: # default user key - return test_remote_repo_credentials( + return check_remote_repo_credentials( repo_data, RepoProtocol.SSH, identity_file=default_ssh_key ) except GitCommandError: pass -def test_remote_repo_credentials( +def check_remote_repo_credentials( repo_data: RemoteRepoData, 
protocol: RepoProtocol, *, diff --git a/cli/dstack/_internal/backend/base/__init__.py b/cli/dstack/_internal/backend/base/__init__.py index f29d6ae31..c2ab08a3a 100644 --- a/cli/dstack/_internal/backend/base/__init__.py +++ b/cli/dstack/_internal/backend/base/__init__.py @@ -88,6 +88,10 @@ def list_job_heads(self, repo_id: str, run_name: Optional[str] = None) -> List[J def delete_job_head(self, repo_id: str, job_id: str): pass + @abstractmethod + def delete_run_jobs(self, repo_id: str, run_name: str): + pass + @abstractmethod def list_run_heads( self, @@ -262,8 +266,8 @@ def logging(self) -> Logging: def predict_instance_type(self, job: Job) -> Optional[InstanceType]: return base_jobs.predict_job_instance(self.compute(), job) - def create_run(self, repo_id: str) -> str: - return base_runs.create_run(self.storage()) + def create_run(self, repo_id: str, run_name: Optional[str]) -> str: + return base_runs.create_run(self.storage(), run_name) def create_job(self, job: Job): base_jobs.create_job(self.storage(), job) @@ -293,6 +297,9 @@ def list_job_heads(self, repo_id: str, run_name: Optional[str] = None) -> List[J def delete_job_head(self, repo_id: str, job_id: str): base_jobs.delete_job_head(self.storage(), repo_id, job_id) + def delete_run_jobs(self, repo_id: str, run_name: str): + base_jobs.delete_jobs(self.storage(), repo_id, run_name) + def list_run_heads( self, repo_id: str, diff --git a/cli/dstack/_internal/backend/base/jobs.py b/cli/dstack/_internal/backend/base/jobs.py index b473b800f..219a30d45 100644 --- a/cli/dstack/_internal/backend/base/jobs.py +++ b/cli/dstack/_internal/backend/base/jobs.py @@ -96,6 +96,13 @@ def delete_job_head(storage: Storage, repo_id: str, job_id: str): storage.delete_object(job_head_key) +def delete_jobs(storage: Storage, repo_id: str, run_name: str): + job_key_run_prefix = _get_jobs_filenames_prefix(repo_id, run_name) + jobs_keys = storage.list_objects(job_key_run_prefix) + for job_key in jobs_keys: + storage.delete_object(job_key) + + def predict_job_instance( compute: Compute, job: Job, @@ -176,9 +183,9 @@ def stop_job( if new_status is not None and new_status != job.status: job.status = new_status update_job(storage, job) - if new_status.is_finished(): + if new_status in [JobStatus.TERMINATED, JobStatus.ABORTED]: if runner is not None: - runners.stop_runner(compute, runner) + runners.terminate_runner(compute, runner) def update_job_submission(job: Job): diff --git a/cli/dstack/_internal/backend/base/runners.py b/cli/dstack/_internal/backend/base/runners.py index 4fceb6603..340d85183 100644 --- a/cli/dstack/_internal/backend/base/runners.py +++ b/cli/dstack/_internal/backend/base/runners.py @@ -33,7 +33,7 @@ def delete_runner(storage: Storage, runner: Runner): storage.delete_object(_get_runner_filename(runner.runner_id)) -def stop_runner(compute: Compute, runner: Runner): +def terminate_runner(compute: Compute, runner: Runner): if runner.request_id: if runner.resources.spot: compute.cancel_spot_request(runner) diff --git a/cli/dstack/_internal/backend/base/runs.py b/cli/dstack/_internal/backend/base/runs.py index 68214524e..98083fd94 100644 --- a/cli/dstack/_internal/backend/base/runs.py +++ b/cli/dstack/_internal/backend/base/runs.py @@ -1,4 +1,5 @@ -from typing import List +import re +from typing import List, Optional import yaml @@ -7,29 +8,20 @@ from dstack._internal.backend.base.storage import Storage from dstack._internal.core.app import AppHead from dstack._internal.core.artifact import ArtifactHead +from dstack._internal.core.error import 
BackendValueError from dstack._internal.core.job import JobErrorCode, JobHead, JobStatus from dstack._internal.core.run import RequestStatus, RunHead, generate_remote_run_name_prefix -def create_run( - storage: Storage, -) -> str: - name = generate_remote_run_name_prefix() - run_name_index = _next_run_name_index(storage, name) - run_name = f"{name}-{run_name_index}" - return run_name - - -def _next_run_name_index(storage: Storage, run_name: str) -> int: - count = 0 - key = f"run_names/{run_name}.yaml" +def create_run(storage: Storage, run_name: Optional[str]) -> str: + if run_name is None: + return _generate_random_run_name(storage) + _validate_run_name(run_name) + key = _get_run_name_filename(run_name) obj = storage.get_object(key) if obj is None: storage.put_object(key=key, content=yaml.dump({"count": 1})) - return 1 - count = yaml.load(obj, yaml.FullLoader)["count"] - storage.put_object(key=key, content=yaml.dump({"count": count + 1})) - return count + 1 + return run_name def get_run_heads( @@ -55,6 +47,33 @@ def get_run_heads( return run_heads +def _generate_random_run_name(storage: Storage): + name = generate_remote_run_name_prefix() + run_name_index = _next_run_name_index(storage, name) + return f"{name}-{run_name_index}" + + +def _next_run_name_index(storage: Storage, run_name: str) -> int: + count = 0 + key = _get_run_name_filename(run_name) + obj = storage.get_object(key) + if obj is None: + storage.put_object(key=key, content=yaml.dump({"count": 1})) + return 1 + count = yaml.load(obj, yaml.FullLoader)["count"] + storage.put_object(key=key, content=yaml.dump({"count": count + 1})) + return count + 1 + + +def _validate_run_name(run_name: str): + if re.match(r"^[a-zA-Z0-9_-]{5,100}$", run_name) is None: + raise BackendValueError( + "Invalid run name. " + "Run name may include alphanumeric characters, dashes, and underscores, " + "and its length should be between 5 and 100 characters." 
+ ) + + def _create_run( storage: Storage, compute: Compute, @@ -183,3 +202,7 @@ def _update_run( jobs.update_job(storage, job) run.status = job_head.status run.job_heads.append(job_head) + + +def _get_run_name_filename(run_name: str) -> str: + return f"run_names/{run_name}.yaml" diff --git a/cli/dstack/_internal/backend/local/compute.py b/cli/dstack/_internal/backend/local/compute.py index 16901e9f0..cc50475e4 100644 --- a/cli/dstack/_internal/backend/local/compute.py +++ b/cli/dstack/_internal/backend/local/compute.py @@ -34,6 +34,7 @@ def restart_instance(self, job: Job): def terminate_instance(self, runner: Runner): runners.stop_process(runner.request_id) + runners.remove_container(runner.job.run_name) def cancel_spot_request(self, runner: Runner): - runners.stop_process(runner.request_id) + pass diff --git a/cli/dstack/_internal/backend/local/runners.py b/cli/dstack/_internal/backend/local/runners.py index 5b053d6e6..915a0eaa7 100644 --- a/cli/dstack/_internal/backend/local/runners.py +++ b/cli/dstack/_internal/backend/local/runners.py @@ -7,12 +7,15 @@ from typing import Optional import cpuinfo +import docker.errors import psutil import requests import yaml from psutil import NoSuchProcess from tqdm import tqdm +import docker + from dstack import version from dstack._internal.backend.base.config import BACKEND_CONFIG_FILENAME, RUNNER_CONFIG_FILENAME from dstack._internal.backend.local.config import LocalConfig @@ -212,3 +215,12 @@ def is_running(request_id: str) -> bool: return True except NoSuchProcess: return False + + +def remove_container(container_name: str): + client = docker.from_env() + try: + container = client.containers.get(container_name) + except docker.errors.NotFound: + return + container.remove() diff --git a/cli/dstack/_internal/cli/commands/ls/__init__.py b/cli/dstack/_internal/cli/commands/ls/__init__.py index ab2867d0b..e1e4c23fb 100644 --- a/cli/dstack/_internal/cli/commands/ls/__init__.py +++ b/cli/dstack/_internal/cli/commands/ls/__init__.py @@ -32,7 +32,6 @@ def register(self): nargs="?", default="", ) - self._parser.add_argument( "-r", "--recursive", help="Show all files recursively", action="store_true" ) diff --git a/cli/dstack/_internal/cli/commands/run/__init__.py b/cli/dstack/_internal/cli/commands/run/__init__.py index 196d3f32e..ebd63cb0d 100644 --- a/cli/dstack/_internal/cli/commands/run/__init__.py +++ b/cli/dstack/_internal/cli/commands/run/__init__.py @@ -19,6 +19,7 @@ from dstack._internal.cli.utils.watcher import Watcher from dstack._internal.configurators.ports import PortUsedError from dstack._internal.core.error import RepoNotInitializedError +from dstack.api.hub import HubClient class RunCommand(BasicCommand): @@ -26,7 +27,7 @@ class RunCommand(BasicCommand): DESCRIPTION = "Run a configuration" def __init__(self, parser): - super().__init__(parser, store_help=True) + super().__init__(parser, store_help=False) def register(self): self._parser.add_argument( @@ -43,6 +44,11 @@ def register(self): type=str, dest="file_name", ) + self._parser.add_argument( + "-n", + "--name", + help="The name of the run. 
If not specified, a random name is assigned.", + ) self._parser.add_argument( "-y", "--yes", help="Do not ask for plan confirmation", action="store_true", ) @@ -55,6 +61,11 @@ def register(self): help="Do not poll logs and run status", action="store_true", ) + self._parser.add_argument( + "--reload", + action="store_true", + help="Enable auto-reload", + ) add_project_argument(self._parser) self._parser.add_argument( "--profile", @@ -69,11 +80,6 @@ def register(self): nargs=argparse.ZERO_OR_MORE, help="Run arguments", ) - self._parser.add_argument( - "--reload", - action="store_true", - help="Enable auto-reload", - ) @check_init def _command(self, args: Namespace): @@ -99,6 +105,9 @@ def _command(self, args: Namespace): ): raise RepoNotInitializedError("No credentials", project_name=project_name) + if args.name: + _check_run_name(hub_client, args.name) + if not config.repo_user_config.ssh_key_path: ssh_key_pub = None else: @@ -126,6 +135,7 @@ def _command(self, args: Namespace): run_name, jobs = hub_client.run_configuration( configurator=configurator, ssh_key_pub=ssh_key_pub, + run_name=args.name, run_args=run_args, ) runs = list_runs_hub(hub_client, run_name=run_name) @@ -145,3 +155,12 @@ def _command(self, args: Namespace): if watcher.is_alive(): watcher.stop() watcher.join() + + +def _check_run_name(hub_client: HubClient, run_name: str): + runs = list_runs_hub(hub_client, run_name=run_name) + if len(runs) == 0: + return + if not Confirm.ask(f"[red]Run {run_name} already exists. Overwrite?[/]"): + exit(0) + hub_client.delete_run(run_name) diff --git a/cli/dstack/_internal/cli/main.py b/cli/dstack/_internal/cli/main.py index 818e5ee8b..b70435854 100644 --- a/cli/dstack/_internal/cli/main.py +++ b/cli/dstack/_internal/cli/main.py @@ -42,7 +42,6 @@ def main(): subparsers = parser.add_subparsers(metavar="COMMAND") cli_initialize(parser=subparsers) - if len(sys.argv) < 2: parser.print_help() exit(0) diff --git a/cli/dstack/_internal/cli/utils/config.py b/cli/dstack/_internal/cli/utils/config.py index 834badb68..8e62f40f1 100644 --- a/cli/dstack/_internal/cli/utils/config.py +++ b/cli/dstack/_internal/cli/utils/config.py @@ -9,6 +9,7 @@ from dstack._internal.cli.errors import CLIError from dstack._internal.cli.profiles import load_profiles from dstack._internal.core.error import RepoNotInitializedError +from dstack._internal.core.repo.remote import RepoError from dstack._internal.core.userconfig import RepoUserConfig from dstack._internal.utils.common import get_dstack_dir from dstack.api.hub import HubClient, HubClientConfig @@ -165,7 +166,10 @@ def get_hub_client(project_name: Optional[str] = None) -> HubClient: f"No default project is configured. Call `dstack start` or `dstack config`."
) repo_config = _read_repo_config_or_error_with_project_name(project_name) - repo = load_repo(repo_config) + try: + repo = load_repo(repo_config) + except RepoError as e: + raise CLIError(e.message) hub_client_config = HubClientConfig(url=project_config.url, token=project_config.token) hub_client = HubClient(config=hub_client_config, project=project_config.name, repo=repo) return hub_client diff --git a/cli/dstack/_internal/core/repo/remote.py b/cli/dstack/_internal/core/repo/remote.py index df3e38a4b..e3b40e88f 100644 --- a/cli/dstack/_internal/core/repo/remote.py +++ b/cli/dstack/_internal/core/repo/remote.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, Field from typing_extensions import Literal +from dstack._internal.core.error import DstackError from dstack._internal.core.repo import RepoProtocol from dstack._internal.core.repo.base import Repo, RepoData, RepoInfo, RepoRef from dstack._internal.utils.common import PathLike @@ -17,6 +18,10 @@ from dstack._internal.utils.ssh import get_host_config, make_ssh_command_for_git +class RepoError(DstackError): + pass + + class RemoteRepoCredentials(BaseModel): protocol: RepoProtocol private_key: Optional[str] @@ -109,7 +114,7 @@ def __init__( repo = git.Repo(self.local_repo_dir) tracking_branch = repo.active_branch.tracking_branch() if tracking_branch is None: - raise ValueError("No remote branch is configured") + raise RepoError("No remote branch is configured") self.repo_url = repo.remote(tracking_branch.remote_name).url repo_data = RemoteRepoData.from_url(self.repo_url, parse_ssh_config=True) repo_data.repo_branch = tracking_branch.remote_head @@ -120,7 +125,7 @@ def __init__( elif self.repo_url is not None: repo_data = RemoteRepoData.from_url(self.repo_url, parse_ssh_config=True) elif repo_data is None: - raise ValueError("No remote repo data provided") + raise RepoError("No remote repo data provided") if repo_ref is None: repo_ref = RepoRef(repo_id=slugify(repo_data.repo_name, repo_data.path("/"))) @@ -162,7 +167,7 @@ def timeout(self): if not self.warned and now > self.start_time + self.warning_time: print( "Provisioning is taking longer than usual, possibly because of having too many or large local " - "files that haven’t been pushed to Git. Tip: Exclude unnecessary files from provisioning " + "files that haven't been pushed to Git. Tip: Exclude unnecessary files from provisioning " "by using the `.gitignore` file." 
) self.warned = True diff --git a/cli/dstack/_internal/hub/models/__init__.py b/cli/dstack/_internal/hub/models/__init__.py index 019d8cf48..184d50772 100644 --- a/cli/dstack/_internal/hub/models/__init__.py +++ b/cli/dstack/_internal/hub/models/__init__.py @@ -6,6 +6,7 @@ from dstack._internal.core.job import Job from dstack._internal.core.repo import RemoteRepoCredentials, RepoSpec +from dstack._internal.core.repo.base import RepoRef from dstack._internal.core.repo.head import RepoHead from dstack._internal.core.run import RunHead from dstack._internal.core.secret import Secret @@ -286,6 +287,11 @@ class RunsGetPlan(BaseModel): jobs: List[Job] +class RunsCreate(BaseModel): + repo_ref: RepoRef + run_name: Optional[str] + + class RunsList(BaseModel): repo_id: str run_name: Optional[str] diff --git a/cli/dstack/_internal/hub/routers/artifacts.py b/cli/dstack/_internal/hub/routers/artifacts.py index b13e195f5..ca12bbf1f 100644 --- a/cli/dstack/_internal/hub/routers/artifacts.py +++ b/cli/dstack/_internal/hub/routers/artifacts.py @@ -4,9 +4,8 @@ from dstack._internal.core.artifact import Artifact from dstack._internal.hub.models import ArtifactsList -from dstack._internal.hub.routers.util import get_backend, get_project +from dstack._internal.hub.routers.util import call_backend, get_backend, get_project from dstack._internal.hub.security.permissions import ProjectMember -from dstack._internal.hub.utils.common import run_async router = APIRouter( prefix="/api/project", tags=["artifacts"], dependencies=[Depends(ProjectMember())] @@ -17,7 +16,7 @@ async def list_artifacts(project_name: str, body: ArtifactsList) -> List[Artifact]: project = await get_project(project_name=project_name) backend = await get_backend(project) - artifacts = await run_async( + artifacts = await call_backend( backend.list_run_artifact_files, body.repo_id, body.run_name, diff --git a/cli/dstack/_internal/hub/routers/configurations.py b/cli/dstack/_internal/hub/routers/configurations.py index 80b28ddbc..0f9f3f6d6 100644 --- a/cli/dstack/_internal/hub/routers/configurations.py +++ b/cli/dstack/_internal/hub/routers/configurations.py @@ -2,9 +2,8 @@ from dstack._internal.core.repo import RepoRef from dstack._internal.hub.db.models import User -from dstack._internal.hub.routers.util import get_backend, get_project +from dstack._internal.hub.routers.util import call_backend, get_backend, get_project from dstack._internal.hub.security.permissions import Authenticated, ProjectMember -from dstack._internal.hub.utils.common import run_async router = APIRouter( prefix="/api/project", tags=["configurations"], dependencies=[Depends(ProjectMember())] @@ -20,6 +19,6 @@ async def delete_configuration_cache( ): project = await get_project(project_name=project_name) backend = await get_backend(project) - await run_async( + await call_backend( backend.delete_configuration_cache, repo_ref.repo_id, user.name, configuration_path ) diff --git a/cli/dstack/_internal/hub/routers/jobs.py b/cli/dstack/_internal/hub/routers/jobs.py index 97343437d..1e53babca 100644 --- a/cli/dstack/_internal/hub/routers/jobs.py +++ b/cli/dstack/_internal/hub/routers/jobs.py @@ -5,9 +5,8 @@ from dstack._internal.core.job import Job, JobHead from dstack._internal.hub.db.models import User from dstack._internal.hub.models import JobHeadList, JobsGet, JobsList -from dstack._internal.hub.routers.util import get_backend, get_project +from dstack._internal.hub.routers.util import call_backend, get_backend, get_project from dstack._internal.hub.security.permissions import 
Authenticated, ProjectMember -from dstack._internal.hub.utils.common import run_async router = APIRouter(prefix="/api/project", tags=["jobs"], dependencies=[Depends(ProjectMember())]) @@ -17,7 +16,7 @@ async def create_job(project_name: str, job: Job, user: User = Depends(Authentic project = await get_project(project_name=project_name) backend = await get_backend(project) job.hub_user_name = user.name - await run_async(backend.create_job, job) + await call_backend(backend.create_job, job) return job @@ -25,7 +24,7 @@ async def create_job(project_name: str, job: Job, user: User = Depends(Authentic async def get_job(project_name: str, body: JobsGet) -> Job: project = await get_project(project_name=project_name) backend = await get_backend(project) - job = await run_async(backend.get_job, body.repo_id, body.job_id) + job = await call_backend(backend.get_job, body.repo_id, body.job_id) return job @@ -33,7 +32,7 @@ async def get_job(project_name: str, body: JobsGet) -> Job: async def list_job(project_name: str, body: JobsList) -> List[Job]: project = await get_project(project_name=project_name) backend = await get_backend(project) - jobs = await run_async(backend.list_jobs, body.repo_id, body.run_name) + jobs = await call_backend(backend.list_jobs, body.repo_id, body.run_name) return jobs @@ -41,7 +40,7 @@ async def list_job(project_name: str, body: JobsList) -> List[Job]: async def list_job_heads(project_name: str, body: JobHeadList) -> List[JobHead]: project = await get_project(project_name=project_name) backend = await get_backend(project) - job_heads = await run_async(backend.list_job_heads, body.repo_id, body.run_name) + job_heads = await call_backend(backend.list_job_heads, body.repo_id, body.run_name) return job_heads @@ -49,4 +48,4 @@ async def list_job_heads(project_name: str, body: JobHeadList) -> List[JobHead]: async def delete_job(project_name: str, body: JobsGet): project = await get_project(project_name=project_name) backend = await get_backend(project) - await run_async(backend.delete_job_head, body.repo_id, body.job_id) + await call_backend(backend.delete_job_head, body.repo_id, body.job_id) diff --git a/cli/dstack/_internal/hub/routers/link.py b/cli/dstack/_internal/hub/routers/link.py index 9b815da4d..04810d6a4 100644 --- a/cli/dstack/_internal/hub/routers/link.py +++ b/cli/dstack/_internal/hub/routers/link.py @@ -4,9 +4,8 @@ from dstack._internal.backend.local import LocalBackend from dstack._internal.hub.models import StorageLink -from dstack._internal.hub.routers.util import get_backend, get_project +from dstack._internal.hub.routers.util import call_backend, get_backend, get_project from dstack._internal.hub.security.permissions import ProjectMember -from dstack._internal.hub.utils.common import run_async router = APIRouter(prefix="/api/project", tags=["link"], dependencies=[Depends(ProjectMember())]) @@ -31,7 +30,7 @@ async def link_upload( token=token.credentials, ) ) - url = await run_async(backend.get_signed_upload_url, body.object_key) + url = await call_backend(backend.get_signed_upload_url, body.object_key) return url @@ -56,5 +55,5 @@ async def link_download( token=token.credentials, ) ) - url = await run_async(backend.get_signed_download_url, body.object_key) + url = await call_backend(backend.get_signed_download_url, body.object_key) return url diff --git a/cli/dstack/_internal/hub/routers/logs.py b/cli/dstack/_internal/hub/routers/logs.py index 149fc3887..88594290f 100644 --- a/cli/dstack/_internal/hub/routers/logs.py +++ 
b/cli/dstack/_internal/hub/routers/logs.py @@ -1,14 +1,14 @@ import itertools -from datetime import timedelta +from datetime import datetime, timedelta, timezone from typing import List from fastapi import APIRouter, Depends +from dstack._internal.core.job import Job from dstack._internal.core.log_event import LogEvent from dstack._internal.hub.models import PollLogs -from dstack._internal.hub.routers.util import get_backend, get_project +from dstack._internal.hub.routers.util import call_backend, get_backend, get_project from dstack._internal.hub.security.permissions import ProjectMember -from dstack._internal.hub.utils.common import run_async from dstack._internal.utils.common import get_current_datetime router = APIRouter(prefix="/api/project", tags=["logs"], dependencies=[Depends(ProjectMember())]) @@ -20,10 +20,19 @@ async def poll_logs(project_name: str, body: PollLogs) -> List[LogEvent]: project = await get_project(project_name=project_name) backend = await get_backend(project) + jobs = await call_backend(backend.list_jobs, body.repo_id, body.run_name) + if len(jobs) == 0: + return None + start_time = body.start_time if start_time is None: start_time = get_current_datetime() - timedelta(days=30) - logs_generator = await run_async( + # logs older than job_start_time may contain logs for overridden runs + job: Job = jobs[0] + job_start_time = datetime.fromtimestamp(job.created_at / 1000, timezone.utc) + start_time = max(job_start_time, start_time) + + logs_generator = await call_backend( backend.poll_logs, body.repo_id, body.run_name, diff --git a/cli/dstack/_internal/hub/routers/repos.py b/cli/dstack/_internal/hub/routers/repos.py index 526469780..c28919f12 100644 --- a/cli/dstack/_internal/hub/routers/repos.py +++ b/cli/dstack/_internal/hub/routers/repos.py @@ -4,9 +4,8 @@ from dstack._internal.core.repo import RemoteRepoCredentials, RepoHead, RepoRef from dstack._internal.hub.models import RepoHeadGet, ReposDelete, ReposUpdate, SaveRepoCredentials -from dstack._internal.hub.routers.util import error_detail, get_backend, get_project +from dstack._internal.hub.routers.util import call_backend, error_detail, get_backend, get_project from dstack._internal.hub.security.permissions import ProjectMember -from dstack._internal.hub.utils.common import run_async router = APIRouter(prefix="/api/project", tags=["repos"], dependencies=[Depends(ProjectMember())]) @@ -15,7 +14,7 @@ async def list_repo_heads(project_name: str) -> List[RepoHead]: project = await get_project(project_name=project_name) backend = await get_backend(project) - repo_heads = await run_async(backend.list_repo_heads) + repo_heads = await call_backend(backend.list_repo_heads) return repo_heads @@ -23,7 +22,7 @@ async def list_repo_heads(project_name: str) -> List[RepoHead]: async def get_repo_head(project_name: str, body: RepoHeadGet) -> RepoHead: project = await get_project(project_name=project_name) backend = await get_backend(project) - repo_heads = await run_async(backend.list_repo_heads) + repo_heads = await call_backend(backend.list_repo_heads) for repo_head in repo_heads: if repo_head.repo_id == body.repo_id: return repo_head @@ -37,7 +36,7 @@ async def get_repo_head(project_name: str, body: RepoHeadGet) -> RepoHead: async def save_repo_credentials(project_name: str, body: SaveRepoCredentials): project = await get_project(project_name=project_name) backend = await get_backend(project) - await run_async(backend.save_repo_credentials, body.repo_id, body.repo_credentials) + await 
call_backend(backend.save_repo_credentials, body.repo_id, body.repo_credentials) @router.post( @@ -46,7 +45,7 @@ async def save_repo_credentials(project_name: str, body: SaveRepoCredentials): async def get_repo_credentials(project_name: str, repo_ref: RepoRef) -> RemoteRepoCredentials: project = await get_project(project_name=project_name) backend = await get_backend(project) - repo_credentials = await run_async(backend.get_repo_credentials, repo_ref.repo_id) + repo_credentials = await call_backend(backend.get_repo_credentials, repo_ref.repo_id) if repo_credentials is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -59,7 +58,7 @@ async def get_repo_credentials(project_name: str, repo_ref: RepoRef) -> RemoteRe async def update_repo(project_name: str, body: ReposUpdate): project = await get_project(project_name=project_name) backend = await get_backend(project) - await run_async(backend.update_repo_last_run_at, body.repo_spec, body.last_run_at) + await call_backend(backend.update_repo_last_run_at, body.repo_spec, body.last_run_at) @router.post("/{project_name}/repos/delete") @@ -67,4 +66,4 @@ async def delete_repos(project_name: str, body: ReposDelete): project = await get_project(project_name=project_name) backend = await get_backend(project) for repo_id in body.repo_ids: - await run_async(backend.delete_repo, repo_id) + await call_backend(backend.delete_repo, repo_id) diff --git a/cli/dstack/_internal/hub/routers/runners.py b/cli/dstack/_internal/hub/routers/runners.py index 6fc5ba03f..80fd22d3e 100644 --- a/cli/dstack/_internal/hub/routers/runners.py +++ b/cli/dstack/_internal/hub/routers/runners.py @@ -1,12 +1,11 @@ from fastapi import APIRouter, Depends, HTTPException, status from dstack._internal.core.build import BuildNotFoundError -from dstack._internal.core.error import BackendValueError, NoMatchingInstanceError +from dstack._internal.core.error import NoMatchingInstanceError from dstack._internal.core.job import Job, JobStatus from dstack._internal.hub.models import StopRunners -from dstack._internal.hub.routers.util import error_detail, get_backend, get_project +from dstack._internal.hub.routers.util import call_backend, error_detail, get_backend, get_project from dstack._internal.hub.security.permissions import ProjectMember -from dstack._internal.hub.utils.common import run_async router = APIRouter( prefix="/api/project", tags=["runners"], dependencies=[Depends(ProjectMember())] @@ -21,7 +20,7 @@ async def run(project_name: str, job: Job): if job.retry_policy.retry: failed_to_start_job_new_status = JobStatus.PENDING try: - await run_async(backend.run_job, job, failed_to_start_job_new_status) + await call_backend(backend.run_job, job, failed_to_start_job_new_status) except NoMatchingInstanceError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -40,17 +39,11 @@ async def run(project_name: str, job: Job): async def restart(project_name: str, job: Job): project = await get_project(project_name=project_name) backend = await get_backend(project) - try: - await run_async(backend.restart_job, job) - except BackendValueError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=error_detail(e.message, code=e.code), - ) + await call_backend(backend.restart_job, job) @router.post("/{project_name}/runners/stop") async def stop(project_name: str, body: StopRunners): project = await get_project(project_name=project_name) backend = await get_backend(project) - await run_async(backend.stop_job, body.repo_id, body.job_id, 
body.terminate, body.abort) + await call_backend(backend.stop_job, body.repo_id, body.job_id, body.terminate, body.abort) diff --git a/cli/dstack/_internal/hub/routers/runs.py b/cli/dstack/_internal/hub/routers/runs.py index fc01af416..5726af0ed 100644 --- a/cli/dstack/_internal/hub/routers/runs.py +++ b/cli/dstack/_internal/hub/routers/runs.py @@ -5,17 +5,22 @@ from dstack._internal.backend.base import Backend from dstack._internal.core.build import BuildNotFoundError -from dstack._internal.core.error import NoMatchingInstanceError -from dstack._internal.core.job import Job, JobStatus +from dstack._internal.core.error import BackendValueError, NoMatchingInstanceError +from dstack._internal.core.job import JobStatus from dstack._internal.core.plan import JobPlan, RunPlan -from dstack._internal.core.repo import RepoRef from dstack._internal.core.run import RunHead from dstack._internal.hub.db.models import User -from dstack._internal.hub.models import RunInfo, RunsDelete, RunsGetPlan, RunsList, RunsStop +from dstack._internal.hub.models import ( + RunInfo, + RunsCreate, + RunsDelete, + RunsGetPlan, + RunsList, + RunsStop, +) from dstack._internal.hub.repository.projects import ProjectManager -from dstack._internal.hub.routers.util import error_detail, get_backend, get_project +from dstack._internal.hub.routers.util import call_backend, error_detail, get_backend, get_project from dstack._internal.hub.security.permissions import Authenticated, ProjectMember -from dstack._internal.hub.utils.common import run_async root_router = APIRouter(prefix="/api/runs", tags=["runs"], dependencies=[Depends(Authenticated())]) project_router = APIRouter( @@ -29,9 +34,9 @@ async def list_all_runs() -> List[RunInfo]: run_infos = [] for project in projects: backend = await get_backend(project) - repo_heads = await run_async(backend.list_repo_heads) + repo_heads = await call_backend(backend.list_repo_heads) for repo_head in repo_heads: - run_heads = await run_async( + run_heads = await call_backend( backend.list_run_heads, repo_head.repo_id, None, @@ -51,7 +56,7 @@ async def get_run_plan( backend = await get_backend(project) job_plans = [] for job in body.jobs: - instance_type = await run_async(backend.predict_instance_type, job) + instance_type = await call_backend(backend.predict_instance_type, job) if instance_type is None: msg = f"No instance type matching requirements ({job.requirements.pretty_format()})." 
if backend.name == "local": @@ -77,10 +82,10 @@ async def get_run_plan( response_model=str, response_class=PlainTextResponse, ) -async def create_run(project_name: str, repo_ref: RepoRef) -> str: +async def create_run(project_name: str, body: RunsCreate) -> str: project = await get_project(project_name=project_name) backend = await get_backend(project) - run_name = await run_async(backend.create_run, repo_ref.repo_id) + run_name = await call_backend(backend.create_run, body.repo_ref.repo_id, body.run_name) return run_name @@ -90,7 +95,7 @@ async def create_run(project_name: str, repo_ref: RepoRef) -> str: async def list_runs(project_name: str, body: RunsList) -> List[RunHead]: project = await get_project(project_name=project_name) backend = await get_backend(project) - run_heads = await run_async( + run_heads = await call_backend( backend.list_run_heads, body.repo_id, body.run_name, @@ -112,7 +117,7 @@ async def stop_runs(project_name: str, body: RunsStop): run_heads.append(run_head) for run_head in run_heads: for job_head in run_head.job_heads: - await run_async( + await call_backend( backend.stop_job, body.repo_id, job_head.job_id, @@ -134,23 +139,31 @@ async def delete_runs(project_name: str, body: RunsDelete): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=[ - error_detail( - f"Run {run_name} is not finished", code="cannot_delete_unfinished_run" - ) + error_detail(f"Run {run_name} is not finished", code=BackendValueError.code) ], ) run_heads.append(run_head) for run_head in run_heads: for job_head in run_head.job_heads: - await run_async( + if job_head.status == JobStatus.STOPPED: + # Force termination of a stopped run + await call_backend( + backend.stop_job, + body.repo_id, + job_head.job_id, + True, + True, + ) + await call_backend( backend.delete_job_head, body.repo_id, job_head.job_id, ) + await call_backend(backend.delete_run_jobs, body.repo_id, run_head.run_name) async def _get_run_head(backend: Backend, repo_id: str, run_name: str) -> RunHead: - run_head = await run_async( + run_head = await call_backend( backend.get_run_head, repo_id, run_name, diff --git a/cli/dstack/_internal/hub/routers/secrets.py b/cli/dstack/_internal/hub/routers/secrets.py index e843b90ab..ac9e21d92 100644 --- a/cli/dstack/_internal/hub/routers/secrets.py +++ b/cli/dstack/_internal/hub/routers/secrets.py @@ -5,9 +5,8 @@ from dstack._internal.core.repo import RepoRef from dstack._internal.core.secret import Secret from dstack._internal.hub.models import SecretAddUpdate -from dstack._internal.hub.routers.util import error_detail, get_backend, get_project +from dstack._internal.hub.routers.util import call_backend, error_detail, get_backend, get_project from dstack._internal.hub.security.permissions import ProjectMember -from dstack._internal.hub.utils.common import run_async router = APIRouter( prefix="/api/project", tags=["secrets"], dependencies=[Depends(ProjectMember())] @@ -18,7 +17,7 @@ async def list_secrets(project_name: str, repo_ref: RepoRef) -> List[str]: project = await get_project(project_name=project_name) backend = await get_backend(project) - secrets_names = await run_async(backend.list_secret_names, repo_ref.repo_id) + secrets_names = await call_backend(backend.list_secret_names, repo_ref.repo_id) return secrets_names @@ -26,21 +25,21 @@ async def list_secrets(project_name: str, repo_ref: RepoRef) -> List[str]: async def add_secret(project_name: str, body: SecretAddUpdate): project = await get_project(project_name=project_name) backend = await get_backend(project) - 
await run_async(backend.add_secret, body.repo_id, body.secret) + await call_backend(backend.add_secret, body.repo_id, body.secret) @router.post("/{project_name}/secrets/update") async def update_secret(project_name: str, body: SecretAddUpdate): project = await get_project(project_name=project_name) backend = await get_backend(project) - await run_async(backend.update_secret, body.repo_id, body.secret) + await call_backend(backend.update_secret, body.repo_id, body.secret) @router.post("/{project_name}/secrets/{secret_name}/get") async def get_secret(project_name: str, secret_name: str, repo_ref: RepoRef) -> Secret: project = await get_project(project_name=project_name) backend = await get_backend(project) - secret = await run_async(backend.get_secret, repo_ref.repo_id, secret_name) + secret = await call_backend(backend.get_secret, repo_ref.repo_id, secret_name) if secret is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=error_detail("Secret not found") @@ -52,4 +51,4 @@ async def get_secret(project_name: str, secret_name: str, repo_ref: RepoRef) -> async def delete_secret(project_name: str, secret_name: str, repo_ref: RepoRef): project = await get_project(project_name=project_name) backend = await get_backend(project) - await run_async(backend.delete_secret, repo_ref.repo_id, secret_name) + await call_backend(backend.delete_secret, repo_ref.repo_id, secret_name) diff --git a/cli/dstack/_internal/hub/routers/tags.py b/cli/dstack/_internal/hub/routers/tags.py index 72cb4c7ac..64faae5f6 100644 --- a/cli/dstack/_internal/hub/routers/tags.py +++ b/cli/dstack/_internal/hub/routers/tags.py @@ -5,9 +5,8 @@ from dstack._internal.core.repo import RepoRef from dstack._internal.core.tag import TagHead from dstack._internal.hub.models import AddTagPath, AddTagRun -from dstack._internal.hub.routers.util import error_detail, get_backend, get_project +from dstack._internal.hub.routers.util import call_backend, error_detail, get_backend, get_project from dstack._internal.hub.security.permissions import ProjectMember -from dstack._internal.hub.utils.common import run_async router = APIRouter(prefix="/api/project", tags=["tags"], dependencies=[Depends(ProjectMember())]) @@ -18,7 +17,7 @@ async def list_heads_tags(project_name: str, repo_ref: RepoRef) -> List[TagHead]: project = await get_project(project_name=project_name) backend = await get_backend(project) - tags = await run_async(backend.list_tag_heads, repo_ref.repo_id) + tags = await call_backend(backend.list_tag_heads, repo_ref.repo_id) return tags @@ -29,7 +28,7 @@ async def list_heads_tags(project_name: str, repo_ref: RepoRef) -> List[TagHead] async def get_tag(project_name: str, tag_name: str, repo_ref: RepoRef) -> TagHead: project = await get_project(project_name=project_name) backend = await get_backend(project) - tag = await run_async(backend.get_tag_head, repo_ref.repo_id, tag_name) + tag = await call_backend(backend.get_tag_head, repo_ref.repo_id, tag_name) if tag is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=error_detail("Tag not found") @@ -41,8 +40,8 @@ async def get_tag(project_name: str, tag_name: str, repo_ref: RepoRef) -> TagHea async def delete_tag(project_name: str, tag_name: str, repo_ref: RepoRef): project = await get_project(project_name=project_name) backend = await get_backend(project) - tag = await run_async(backend.get_tag_head, repo_ref.repo_id, tag_name) - await run_async(backend.delete_tag_head, repo_ref.repo_id, tag) + tag = await call_backend(backend.get_tag_head, 
repo_ref.repo_id, tag_name) + await call_backend(backend.delete_tag_head, repo_ref.repo_id, tag) @router.post("/{project_name}/tags/add/run") @@ -50,7 +49,7 @@ async def add_tag_from_run(project_name: str, body: AddTagRun): project = await get_project(project_name=project_name) backend = await get_backend(project) # todo pass error to CLI if tag already exists - await run_async( + await call_backend( backend.add_tag_from_run, body.repo_id, body.tag_name, body.run_name, body.run_jobs ) diff --git a/cli/dstack/_internal/hub/routers/util.py b/cli/dstack/_internal/hub/routers/util.py index 0691210e4..5b9d90c6e 100644 --- a/cli/dstack/_internal/hub/routers/util.py +++ b/cli/dstack/_internal/hub/routers/util.py @@ -3,12 +3,17 @@ from fastapi import HTTPException, status from dstack._internal.backend.base import Backend -from dstack._internal.core.error import BackendAuthError, BackendNotAvailableError +from dstack._internal.core.error import ( + BackendAuthError, + BackendNotAvailableError, + BackendValueError, +) from dstack._internal.hub.models import Project from dstack._internal.hub.repository.projects import ProjectManager from dstack._internal.hub.services.backends import cache as backends_cache from dstack._internal.hub.services.backends import get_configurator from dstack._internal.hub.services.backends.base import Configurator +from dstack._internal.hub.utils.common import run_async async def get_project(project_name: str) -> Project: @@ -57,3 +62,13 @@ def error_detail(msg: str, code: Optional[str] = None, **kwargs) -> Dict: "code": code, **kwargs, } + + +async def call_backend(func, *args): + try: + return await run_async(func, *args) + except BackendValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=error_detail(e.message, code=e.code), + ) diff --git a/cli/dstack/api/hub/_api_client.py b/cli/dstack/api/hub/_api_client.py index c641ba326..1e47781c1 100644 --- a/cli/dstack/api/hub/_api_client.py +++ b/cli/dstack/api/hub/_api_client.py @@ -27,6 +27,8 @@ PollLogs, ProjectInfo, ReposUpdate, + RunsCreate, + RunsDelete, RunsGetPlan, RunsList, SaveRepoCredentials, @@ -91,7 +93,7 @@ def get_run_plan(self, jobs: List[Job]) -> RunPlan: raise HubClientError(body["detail"]["msg"]) resp.raise_for_status() - def create_run(self) -> str: + def create_run(self, run_name: str) -> str: url = _project_url( url=self.url, project=self.project, @@ -102,7 +104,7 @@ def create_run(self) -> str: host=self.url, url=url, headers=self._headers(), - data=self.repo.repo_ref.json(), + data=RunsCreate(repo_ref=self.repo.repo_ref, run_name=run_name).json(), ) if resp.ok: return resp.text @@ -377,6 +379,26 @@ def list_run_heads( return [RunHead.parse_obj(run) for run in body] resp.raise_for_status() + def delete_runs(self, run_names: List[str]): + url = _project_url( + url=self.url, + project=self.project, + additional_path=f"/runs/delete", + ) + resp = _make_hub_request( + requests.post, + host=self.url, + url=url, + headers=self._headers(), + data=RunsDelete( + repo_id=self.repo.repo_id, + run_names=run_names, + ).json(), + ) + if resp.ok: + return + resp.raise_for_status() + def update_repo_last_run_at(self, last_run_at: int): url = _project_url( url=self.url, diff --git a/cli/dstack/api/hub/_client.py b/cli/dstack/api/hub/_client.py index db84adca1..8318c6f41 100644 --- a/cli/dstack/api/hub/_client.py +++ b/cli/dstack/api/hub/_client.py @@ -67,8 +67,8 @@ def get_project_backend_type(self) -> str: def _get_project_info(self) -> ProjectInfo: return 
self._api_client.get_project_info() - def create_run(self) -> str: - return self._api_client.create_run() + def create_run(self, run_name: Optional[str]) -> str: + return self._api_client.create_run(run_name) def create_job(self, job: Job): self._api_client.create_job(job=job) @@ -112,7 +112,7 @@ def delete_job_heads(self, run_name: Optional[str]): job_heads.append(job_head) else: if run_name: - sys.exit("The run is not finished yet. Stop the run first.") + raise HubClientError("The run is not finished yet. Stop the run first.") for job_head in job_heads: self.delete_job_head(job_head.job_id) @@ -128,6 +128,9 @@ def list_run_heads( include_request_heads=include_request_heads, ) + def delete_run(self, run_name: str): + self._api_client.delete_runs([run_name]) + def poll_logs( self, run_name: str, @@ -275,9 +278,10 @@ def run_configuration( self, configurator: "configurators.JobConfigurator", ssh_key_pub: str, + run_name: Optional[str] = None, run_args: Optional[List[str]] = None, ) -> Tuple[str, List[Job]]: - run_name = self.create_run() + run_name = self.create_run(run_name) configurator = copy.deepcopy(configurator) configurator.inject_context( {"run": {"name": run_name, "args": configurator.join_run_args(run_args)}} diff --git a/cli/requirements.txt b/cli/requirements.txt index 25648b967..992c6d941 100644 --- a/cli/requirements.txt +++ b/cli/requirements.txt @@ -28,6 +28,7 @@ psutil cryptography filelock watchfiles +docker # AWS boto3 diff --git a/docs/docs/reference/cli/run.md b/docs/docs/reference/cli/run.md index 36ca1b149..04a7d815c 100644 --- a/docs/docs/reference/cli/run.md +++ b/docs/docs/reference/cli/run.md @@ -8,18 +8,24 @@ This command runs a given configuration. ```shell $ dstack run --help -Usage: dstack run [--project PROJECT] [--profile PROFILE] [-d] [--reload] WORKING_DIR [ARGS ...] +Usage: dstack run [-h] [-f FILE] [-n NAME] [-y] [-d] [--project PROJECT] [--profile PROFILE] + [--reload] + WORKING_DIR [ARGS ...] Positional Arguments: - WORKING_DIR The working directory of the run - ARGS Run arguments + WORKING_DIR The working directory of the run + ARGS Run arguments Options: - --f FILE The path to the run configuration file. Defaults to WORKING_DIR/.dstack.yml. - --project PROJECT The name of the project - --profile PROFILE The name of the profile - -d, --detach Do not poll logs and run status - --reload Enable auto-reload + -h, --help Show this help message and exit + -f, --file FILE The path to the run configuration file. Defaults to + WORKING_DIR/.dstack.yml. + -n, --name NAME The name of the run. If not specified, a random name is assigned. + -y, --yes Do not ask for plan confirmation + -d, --detach Do not poll logs and run status + --reload Enable auto-reload + --project PROJECT The name of the project + --profile PROFILE The name of the profile ``` @@ -35,11 +41,12 @@ The following arguments are required: The following arguments are optional: - `-f FILE`, `--f FILE` – (Optional) The path to the run configuration file. Defaults to `WORKING_DIR/.dstack.yml`. +- `-n NAME`, `--name NAME` - (Optional) The name of the run. If not specified, a random name is assigned. +- `-y`, `--yes` - (Optional) Do not ask for plan confirmation +- `-d`, `--detach` – (Optional) Run in the detached mode to disable logs and run status polling. By default, the run is in the attached mode, so the logs are printed in real-time. 
+- `--reload` – (Optional) Enable auto-reload - `--project PROJECT` – (Optional) The name of the project - `--profile PROJECT` – (Optional) The name of the profile -- `--reload` – (Optional) Enable auto-reload -- `-d`, `--detach` – (Optional) Run in the detached mode. Means, the command doesn't - poll logs and run status. [//]: # (- `-t TAG`, `--tag TAG` – (Optional) A tag name. Warning, if the tag exists, it will be overridden.) - `-p PORT [PORT ...]`, `--ports PORT [PORT ...]` – (Optional) Requests ports or define mappings for them (`LOCAL_PORT:CONTAINER_PORT`) @@ -69,4 +76,4 @@ Build policies: [//]: # (Tags should be dropped) !!! info "NOTE:" - By default, it runs it in the attached mode, so you'll see the output in real-time. + By default, the run is in the attached mode, so you'll see the output in real-time. diff --git a/setup.py b/setup.py index 52307cc36..ec95e60d7 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ def get_long_description(): "grpcio>=1.50", # indirect "filelock", "watchfiles", + "docker>=6.0.0", ] AWS_DEPS = [ From 9863e2578c0760d0cba61bee141a9f1b42881842 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 27 Jul 2023 14:06:43 +0500 Subject: [PATCH 2/2] Fix dstack rm --- cli/dstack/_internal/backend/local/compute.py | 3 +++ cli/dstack/_internal/backend/local/runners.py | 14 ++++++++++---- .../_internal/cli/commands/build/__init__.py | 2 +- .../_internal/cli/commands/rm/__init__.py | 18 ++++-------------- .../_internal/cli/commands/run/__init__.py | 3 --- cli/dstack/_internal/hub/routers/runs.py | 2 +- cli/dstack/api/hub/_api_client.py | 2 ++ 7 files changed, 21 insertions(+), 23 deletions(-) diff --git a/cli/dstack/_internal/backend/local/compute.py b/cli/dstack/_internal/backend/local/compute.py index cc50475e4..71e05430b 100644 --- a/cli/dstack/_internal/backend/local/compute.py +++ b/cli/dstack/_internal/backend/local/compute.py @@ -3,6 +3,7 @@ from dstack._internal.backend.base.compute import Compute, choose_instance_type from dstack._internal.backend.local import runners from dstack._internal.backend.local.config import LocalConfig +from dstack._internal.core.error import BackendValueError from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo from dstack._internal.core.job import Job from dstack._internal.core.request import RequestHead @@ -29,6 +30,8 @@ def run_instance(self, job: Job, instance_type: InstanceType) -> LaunchedInstanc return LaunchedInstanceInfo(request_id=pid, location=None) def restart_instance(self, job: Job): + if runners.get_container(job.run_name) is None: + raise BackendValueError("Container not found") pid = runners.start_runner_process(self.backend_config, job.runner_id) return LaunchedInstanceInfo(request_id=pid, location=None) diff --git a/cli/dstack/_internal/backend/local/runners.py b/cli/dstack/_internal/backend/local/runners.py index 915a0eaa7..99a50845b 100644 --- a/cli/dstack/_internal/backend/local/runners.py +++ b/cli/dstack/_internal/backend/local/runners.py @@ -11,6 +11,7 @@ import psutil import requests import yaml +from docker.models.containers import Container from psutil import NoSuchProcess from tqdm import tqdm @@ -217,10 +218,15 @@ def is_running(request_id: str) -> bool: return False -def remove_container(container_name: str): +def get_container(container_name: str) -> Optional[Container]: client = docker.from_env() try: - container = client.containers.get(container_name) + return client.containers.get(container_name) except docker.errors.NotFound: - return - container.remove() + 
return None + + +def remove_container(container_name: str): + container = get_container(container_name) + if container is not None: + container.remove() diff --git a/cli/dstack/_internal/cli/commands/build/__init__.py b/cli/dstack/_internal/cli/commands/build/__init__.py index dd7050fd3..b52bfb640 100644 --- a/cli/dstack/_internal/cli/commands/build/__init__.py +++ b/cli/dstack/_internal/cli/commands/build/__init__.py @@ -39,13 +39,13 @@ def register(self): type=str, dest="file_name", ) - add_project_argument(self._parser) self._parser.add_argument( "-y", "--yes", help="Do not ask for plan confirmation", action="store_true", ) + add_project_argument(self._parser) self._parser.add_argument( "--profile", metavar="PROFILE", diff --git a/cli/dstack/_internal/cli/commands/rm/__init__.py b/cli/dstack/_internal/cli/commands/rm/__init__.py index 3ab3332f5..e6f884d98 100644 --- a/cli/dstack/_internal/cli/commands/rm/__init__.py +++ b/cli/dstack/_internal/cli/commands/rm/__init__.py @@ -37,22 +37,12 @@ def _command(self, args: Namespace): and (args.yes or Confirm.ask(f"[red]Delete the run '{args.run_name}'?[/]")) ) or (args.all and (args.yes or Confirm.ask("[red]Delete all runs?[/]"))): hub_client = get_hub_client(project_name=args.project) - deleted_run = False - job_heads = hub_client.list_job_heads(args.run_name) - if job_heads: - finished_job_heads = [] - for job_head in job_heads: - if job_head.status.is_finished(): - finished_job_heads.append(job_head) - elif args.run_name: - console.print("The run is not finished yet. Stop the run first.") - exit(1) - for job_head in finished_job_heads: - hub_client.delete_job_head(job_head.job_id) - deleted_run = True - if args.run_name and not deleted_run: + run_heads = hub_client.list_run_heads(args.run_name) + if len(run_heads) == 0 and args.run_name: console.print(f"Cannot find the run '{args.run_name}'") exit(1) + for run_head in run_heads: + hub_client.delete_run(run_head.run_name) console.print(f"[grey58]OK[/]") else: if not args.run_name and not args.all: diff --git a/cli/dstack/_internal/cli/commands/run/__init__.py b/cli/dstack/_internal/cli/commands/run/__init__.py index ebd63cb0d..fe33f106b 100644 --- a/cli/dstack/_internal/cli/commands/run/__init__.py +++ b/cli/dstack/_internal/cli/commands/run/__init__.py @@ -84,9 +84,6 @@ def register(self): @check_init def _command(self, args: Namespace): configurator = load_configuration(args.working_dir, args.file_name, args.profile_name) - if args.help: - configurator.get_parser(parser=copy.deepcopy(self._parser)).print_help() - exit(0) project_name = None if args.project: diff --git a/cli/dstack/_internal/hub/routers/runs.py b/cli/dstack/_internal/hub/routers/runs.py index 5726af0ed..765be6ea5 100644 --- a/cli/dstack/_internal/hub/routers/runs.py +++ b/cli/dstack/_internal/hub/routers/runs.py @@ -173,5 +173,5 @@ async def _get_run_head(backend: Backend, repo_id: str, run_name: str) -> RunHea return run_head raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=[error_detail(f"Run {run_name} not found", code="run_not_found")], + detail=[error_detail(f"Run {run_name} not found", code=BackendValueError.code)], ) diff --git a/cli/dstack/api/hub/_api_client.py b/cli/dstack/api/hub/_api_client.py index 1e47781c1..4a5b591a5 100644 --- a/cli/dstack/api/hub/_api_client.py +++ b/cli/dstack/api/hub/_api_client.py @@ -718,6 +718,8 @@ def _make_hub_request(request_func, host, *args, **kwargs) -> requests.Response: body = resp.json() detail = body.get("detail") if detail is not None: + if 
isinstance(detail, list): + detail = detail[0] if detail.get("code") == BackendNotAvailableError.code: raise HubClientError(detail["msg"]) elif detail.get("code") == BackendValueError.code: