From 9b13e28ce5885065ff7298fafc0d176fd588080c Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Wed, 5 Jul 2023 16:09:24 +0400
Subject: [PATCH 01/26] Fix method return type hints

---
 cli/dstack/_internal/backend/azure/compute.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cli/dstack/_internal/backend/azure/compute.py b/cli/dstack/_internal/backend/azure/compute.py
index 8205d2800..8f37bc18e 100644
--- a/cli/dstack/_internal/backend/azure/compute.py
+++ b/cli/dstack/_internal/backend/azure/compute.py
@@ -179,7 +179,7 @@ def _vm_type_available(vm_resource: ResourceSku) -> bool:
     return False
 
 
-def _get_gpu_name_memory(vm_name: str) -> Tuple[int, str]:
+def _get_gpu_name_memory(vm_name: str) -> Tuple[str, int]:
     if re.match(r"^Standard_NC\d+ads_A100_v4$", vm_name):
         return "A100", 80 * 1024
     if re.match(r"^Standard_NC\d+as_T4_v3$", vm_name):

From 5a32b9920ff3adf45964ab04dea41a01ef7bb796 Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Wed, 5 Jul 2023 16:54:58 +0400
Subject: [PATCH 02/26] Implement configuration pydantic models

---
 cli/dstack/_internal/core/configuration.py | 61 ++++++++++++++++++++++
 1 file changed, 61 insertions(+)
 create mode 100644 cli/dstack/_internal/core/configuration.py

diff --git a/cli/dstack/_internal/core/configuration.py b/cli/dstack/_internal/core/configuration.py
new file mode 100644
index 000000000..2609828d1
--- /dev/null
+++ b/cli/dstack/_internal/core/configuration.py
@@ -0,0 +1,61 @@
+from typing import List, Optional, Union
+
+from pydantic import BaseModel, Extra, Field, validator
+from typing_extensions import Annotated, Literal
+
+PythonVersions = Literal["3.7", "3.8", "3.9", "3.10", "3.11"]
+
+
+class ForbidExtra(BaseModel):
+    class Config:
+        extra = Extra.forbid
+
+
+class RegistryAuth(ForbidExtra):
+    username: Optional[str]
+    password: str
+
+
+class Artifact(ForbidExtra):
+    path: str
+    mount: bool = False
+
+
+class BaseConfiguration(ForbidExtra):
+    image: Optional[str]
+    registry_auth: Optional[RegistryAuth]
+    python: Optional[PythonVersions]
+    ports: List[Union[str, int]] = []
+    env: List[str] = []
+    build: List[str] = []
+    cache: List[str] = []
+
+    @validator("python", pre=True)
+    def convert_python(cls, v) -> str:
+        if isinstance(v, float):
+            v = str(v)
+            if v == "3.1":
+                v = "3.10"
+        return v
+
+
+class DevEnvironmentConfiguration(BaseConfiguration):
+    type: Literal["dev-environment"] = "dev-environment"
+    ide: Literal["vscode"]
+    init: List[str] = []
+
+
+class TaskConfiguration(BaseConfiguration):
+    type: Literal["task"] = "task"
+    commands: List[str]
+    artifacts: List[Artifact] = []
+
+
+class DstackConfiguration(BaseModel):
+    __root__: Annotated[
+        Union[DevEnvironmentConfiguration, TaskConfiguration], Field(discriminator="type")
+    ]
+
+
+def parse(data: dict) -> Union[DevEnvironmentConfiguration, TaskConfiguration]:
+    return DstackConfiguration.parse_obj(data).__root__

From 7bb48642f45a6294ddde8bbffeee394928019a33 Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Thu, 6 Jul 2023 18:10:56 +0400
Subject: [PATCH 03/26] Outline configurator interface

---
 .../_internal/cli/commands/run/__init__.py |   2 +
 .../_internal/configurators/__init__.py    | 238 ++++++++++++++++++
 .../configurators/dev_environment.py       |   6 +
 cli/dstack/_internal/configurators/task.py |   6 +
 cli/dstack/_internal/core/configuration.py |   7 +-
 cli/dstack/_internal/core/job.py           |   9 +-
 cli/dstack/_internal/core/profile.py       |  76 ++++++
 7 files changed, 338 insertions(+), 6 deletions(-)
 create mode 100644 cli/dstack/_internal/configurators/__init__.py
 create mode
100644 cli/dstack/_internal/configurators/dev_environment.py create mode 100644 cli/dstack/_internal/configurators/task.py create mode 100644 cli/dstack/_internal/core/profile.py diff --git a/cli/dstack/_internal/cli/commands/run/__init__.py b/cli/dstack/_internal/cli/commands/run/__init__.py index fbda9d4bd..b3189a55c 100644 --- a/cli/dstack/_internal/cli/commands/run/__init__.py +++ b/cli/dstack/_internal/cli/commands/run/__init__.py @@ -117,6 +117,8 @@ def _command(self, args: Namespace): ) if args.project: project_name = args.project + if args.help: + pass # todo watcher = Watcher(os.getcwd()) try: if args.reload: diff --git a/cli/dstack/_internal/configurators/__init__.py b/cli/dstack/_internal/configurators/__init__.py new file mode 100644 index 000000000..c0a315ef0 --- /dev/null +++ b/cli/dstack/_internal/configurators/__init__.py @@ -0,0 +1,238 @@ +import argparse +import json +import sys +import uuid +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + +from rich_argparse import RichHelpFormatter + +import dstack._internal.core.job as job +import dstack._internal.providers.ports as ports +import dstack.version as version +from dstack._internal.core.configuration import BaseConfiguration +from dstack._internal.core.error import DstackError +from dstack._internal.core.profile import Profile +from dstack._internal.core.repo import Repo +from dstack._internal.utils.common import get_milliseconds_since_epoch +from dstack._internal.utils.interpolator import VariablesInterpolator + + +class JobConfigurator(ABC): + def __init__( + self, + working_dir: str, + configuration_path: str, + configuration: BaseConfiguration, + profile: Profile, + ): + self.configuration_path = configuration_path + self.working_dir = working_dir + self.conf = configuration + self.profile = profile + self.build_policy = "use-build" + + def print_help(self, prog: str = "dstack run"): + parser = self.get_parser(prog) + parser.print_help() + + def get_parser(self, prog: Optional[str] = None) -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(prog=prog, formatter_class=RichHelpFormatter) + + spot_group = parser.add_mutually_exclusive_group() + spot_group.add_argument( + "--spot", action="store_const", dest="spot_policy", const=job.SpotPolicy.SPOT + ) + spot_group.add_argument( + "--on-demand", action="store_const", dest="spot_policy", const=job.SpotPolicy.ONDEMAND + ) + spot_group.add_argument( + "--spot-auto", action="store_const", dest="spot_policy", const=job.SpotPolicy.AUTO + ) + spot_group.add_argument("--spot-policy", type=job.SpotPolicy, dest="spot_policy") + + retry_group = parser.add_mutually_exclusive_group() + retry_group.add_argument("--retry", action="store_true") + retry_group.add_argument("--no-retry", action="store_true") + retry_group.add_argument("--retry-limit", type=str) + + build_policy = parser.add_mutually_exclusive_group() + for value in job.BuildPolicy: + build_policy.add_argument( + f"--{value}", action="store_const", dest="build_policy", const=value + ) + + return parser + + def apply_args(self, args: argparse.Namespace): + if args.spot_policy is not None: + self.profile.spot_policy = args.spot_policy + if args.retry: + self.profile.retry_policy.retry = True + elif args.no_retry: + self.profile.retry_policy.retry = False + elif args.retry_limit: + self.profile.retry_policy.retry = True + self.profile.retry_policy.limit = args.retry_limit + if args.build_policy is not None: + self.build_policy = args.build_policy + + def inject_context( + self, 
namespaces: Dict[str, Dict[str, str]], skip: Optional[List[str]] = None + ): + if skip is None: + skip = ["secrets"] + vi = VariablesInterpolator(namespaces, skip=skip) + + def interpolate(obj): + if isinstance(obj, str): + return vi.interpolate(obj) + if isinstance(obj, dict): + return {k: interpolate(v) for k, v in obj.items()} + if isinstance(obj, list): + return [interpolate(i) for i in obj] + return obj + + conf = json.loads(self.conf.json()) + conf = interpolate(conf) + self.conf = type(self.conf).parse_obj(conf) + + def get_jobs( + self, repo: Repo, run_name: str, repo_code_filename: str, ssh_key_pub: str + ) -> List[job.Job]: + created_at = get_milliseconds_since_epoch() + configured_job = job.Job( + job_id=f"{run_name},,0", + runner_id=uuid.uuid4().hex, + repo_ref=repo.repo_ref, + repo_data=repo.repo_data, + repo_code_filename=repo_code_filename, + run_name=run_name, + configuration_type=job.ConfigurationType(self.conf.type), + configuration_path=self.configuration_path, + status=job.JobStatus.SUBMITTED, + created_at=created_at, + submitted_at=created_at, + image_name=self.image_name(), + registry_auth=self.registry_auth(), + entrypoint=self.entrypoint(), + build_commands=self.build_commands(), + optional_build_commands=self.optional_build_commands(), + commands=self.commands(), + working_dir=self.working_dir, + home_dir=self.home_dir(), + env=self.env(), + artifact_specs=self.artifact_specs(), + cache_specs=self.cache_specs(), + app_specs=self.app_specs(), + dep_specs=self.dep_specs(), + spot_policy=self.spot_policy(), + retry_policy=self.retry_policy(), + build_policy=self.build_policy, + requirements=self.requirements(), + ssh_key_pub=ssh_key_pub, + ) + return [configured_job] + + @abstractmethod + def commands(self) -> List[str]: + pass + + @abstractmethod + def build_commands(self) -> List[str]: + pass + + @abstractmethod + def optional_build_commands(self) -> List[str]: + pass + + @abstractmethod + def artifact_specs(self) -> List[job.ArtifactSpec]: + pass + + @abstractmethod + def dep_specs(self): + pass + + def entrypoint(self) -> Optional[List[str]]: + if self.conf.image is None or self.commands(): + return ["/bin/sh", "-i", "-c"] + return None + + def home_dir(self) -> Optional[str]: + return "/root" if self.conf.image is None else None + + def image_name(self) -> str: + if self.conf.image is not None: + return self.conf.image + if self.profile.resources and self.profile.resources.gpu: + return f"dstackai/miniforge:py{self.python()}-{version.miniforge_image}-cuda-11.4" + return f"dstackai/miniforge:py{self.python()}-{version.miniforge_image}" + + def spot_policy(self) -> job.SpotPolicy: + return self.profile.spot_policy or job.SpotPolicy.AUTO + + def retry_policy(self) -> job.RetryPolicy: + return job.RetryPolicy.parse_obj(self.profile.retry_policy.dict()) + + def cache_specs(self) -> List[job.CacheSpec]: + return [ + job.CacheSpec(path=validate_local_path(path, self.home_dir(), self.working_dir)) + for path in self.conf.cache + ] + + def registry_auth(self) -> Optional[job.RegistryAuth]: + if self.conf.registry_auth is None: + return None + return job.RegistryAuth.parse_obj(self.conf.registry_auth.dict()) + + def app_specs(self) -> List[job.AppSpec]: + specs = [] + for i, pm in enumerate(ports.filter_reserved_ports(self.ports())): + specs.append( + job.AppSpec( + port=pm.port, + map_to_port=pm.map_to_port, + app_name=f"app_{i}", + ) + ) + return specs + + def python(self) -> str: + if self.conf.python is not None: + return self.conf.python + version_info = 
sys.version_info + return f"{version_info.major}.{version_info.minor}" # todo check is in supported + + def ports(self) -> Dict[int, ports.PortMapping]: + mapping = [ports.PortMapping(p) for p in self.conf.ports] + ports.unique_ports_constraint([pm.port for pm in mapping]) + ports.unique_ports_constraint( + [pm.map_to_port for pm in mapping if pm.map_to_port is not None], + error="Mapped port {} is already in use", + ) + return {pm.port: pm for pm in mapping} + + def env(self) -> Dict[str, str]: + return self.conf.env + + def requirements(self) -> job.Requirements: + return job.Requirements.parse_obj(self.profile.resources.dict()) + + +def validate_local_path(path: str, home: Optional[str], working_dir: str) -> str: + if path == "~" or path.startswith("~/"): + if home is None: + raise HomeDirUnsetError("home_dir is not defined, local path can't start with ~") + path = home if path == "~" else f"{home}/{path[len('~/'):]}" + while path.startswith("./"): + path = path[len("./") :] + if not path.startswith("/"): + path = "/".join( + ["/workflow", path] if working_dir == "." else ["/workflow", working_dir, path] + ) + return path + + +class HomeDirUnsetError(DstackError): + pass diff --git a/cli/dstack/_internal/configurators/dev_environment.py b/cli/dstack/_internal/configurators/dev_environment.py new file mode 100644 index 000000000..9f3875c39 --- /dev/null +++ b/cli/dstack/_internal/configurators/dev_environment.py @@ -0,0 +1,6 @@ +from dstack._internal.configurators import JobConfigurator +from dstack._internal.core.configuration import DevEnvironmentConfiguration + + +class DevEnvironmentConfigurator(JobConfigurator): + conf: DevEnvironmentConfiguration diff --git a/cli/dstack/_internal/configurators/task.py b/cli/dstack/_internal/configurators/task.py new file mode 100644 index 000000000..a1f06531d --- /dev/null +++ b/cli/dstack/_internal/configurators/task.py @@ -0,0 +1,6 @@ +from dstack._internal.configurators import JobConfigurator +from dstack._internal.core.configuration import TaskConfiguration + + +class TaskConfigurator(JobConfigurator): + conf: TaskConfiguration diff --git a/cli/dstack/_internal/core/configuration.py b/cli/dstack/_internal/core/configuration.py index 2609828d1..c63abd8aa 100644 --- a/cli/dstack/_internal/core/configuration.py +++ b/cli/dstack/_internal/core/configuration.py @@ -1,8 +1,9 @@ -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union from pydantic import BaseModel, Extra, Field, validator from typing_extensions import Annotated, Literal +# todo use Enum PythonVersions = Literal["3.7", "3.8", "3.9", "3.10", "3.11"] @@ -22,11 +23,13 @@ class Artifact(ForbidExtra): class BaseConfiguration(ForbidExtra): + type: Literal["none"] image: Optional[str] + # todo entrypoint registry_auth: Optional[RegistryAuth] python: Optional[PythonVersions] ports: List[Union[str, int]] = [] - env: List[str] = [] + env: Dict[str, str] = {} build: List[str] = [] cache: List[str] = [] diff --git a/cli/dstack/_internal/core/job.py b/cli/dstack/_internal/core/job.py index b25405981..8571becd4 100644 --- a/cli/dstack/_internal/core/job.py +++ b/cli/dstack/_internal/core/job.py @@ -18,6 +18,7 @@ RepoRef, ) +# todo use Enum BuildPolicy = ["use-build", "build", "force-build", "build-only"] @@ -131,7 +132,7 @@ def pretty_repr(self) -> str: class JobHead(JobRef): job_id: str repo_ref: RepoRef - hub_user_name: str + hub_user_name: str = "" run_name: str workflow_name: Optional[str] provider_name: str @@ -176,8 +177,8 @@ class Job(JobHead): ) 
     repo_code_filename: Optional[str] = None
     run_name: str
-    workflow_name: Optional[str]
-    provider_name: str
+    workflow_name: Optional[str]  # deprecated
+    provider_name: Optional[str]  # deprecated
     configuration_type: Optional[ConfigurationType]
     configuration_path: Optional[str]
     status: JobStatus
@@ -210,7 +211,7 @@ class Job(JobHead):
     build_policy: Optional[str]
     build_commands: Optional[List[str]]
     optional_build_commands: Optional[List[str]]
-    run_env: Optional[Dict[str, str]]
+    run_env: Optional[Dict[str, str]]  # deprecated
 
     @root_validator(pre=True)
     def preprocess_data(cls, data):
diff --git a/cli/dstack/_internal/core/profile.py b/cli/dstack/_internal/core/profile.py
new file mode 100644
index 000000000..de2f9bb84
--- /dev/null
+++ b/cli/dstack/_internal/core/profile.py
@@ -0,0 +1,76 @@
+import re
+from typing import Optional, Union
+
+from pydantic import validator
+
+from dstack._internal.core.configuration import ForbidExtra
+from dstack._internal.core.job import SpotPolicy
+
+DEFAULT_CPU = 2
+DEFAULT_MEM = "8GB"
+DEFAULT_RETRY_LIMIT = 3600
+
+
+def mem_size(v: Optional[Union[int, str]]) -> Optional[int]:
+    if isinstance(v, str):
+        m = re.fullmatch(r"(\d+) *([gm]b)?", v.strip().lower())
+        if not m:
+            raise ValueError(f"Invalid memory size: {v}")
+        v = int(m.group(1))
+        if m.group(2) == "gb":
+            v = v * 1024  # todo
+    return v
+
+
+def duration(v: Union[int, str]) -> int:
+    if isinstance(v, int):
+        return v
+    regex = re.compile(r"(?P<amount>\d+) *(?P<unit>[smhdw])$")
+    re_match = regex.match(v)
+    if not re_match:
+        raise ValueError(f"Cannot parse the duration {v}")
+    amount, unit = int(re_match.group("amount")), re_match.group("unit")
+    multiplier = {
+        "s": 1,
+        "m": 60,
+        "h": 3600,
+        "d": 24 * 3600,
+        "w": 7 * 24 * 3600,
+    }[unit]
+    return amount * multiplier
+
+
+class ProfileGPU(ForbidExtra):
+    name: Optional[str]
+    count: int = 1
+    memory: Optional[Union[int, str]]
+    _validate_mem = validator("memory", pre=True, allow_reuse=True)(mem_size)
+
+
+class ProfileResources(ForbidExtra):
+    gpu: Optional[Union[int, ProfileGPU]]
+    memory: Union[int, str] = mem_size(DEFAULT_MEM)
+    shm_size: Optional[Union[int, str]]
+    cpu: int = DEFAULT_CPU
+    _validate_mem = validator("memory", "shm_size", pre=True, allow_reuse=True)(mem_size)
+
+    @validator("gpu", pre=True)
+    def _validate_gpu(cls, v: Optional[Union[int, ProfileGPU]]) -> Optional[ProfileGPU]:
+        if isinstance(v, int):
+            v = ProfileGPU(count=v)
+        return v
+
+
+class ProfileRetryPolicy(ForbidExtra):
+    retry: bool = False
+    limit: Union[int, str] = DEFAULT_RETRY_LIMIT
+    _validate_limit = validator("limit", pre=True, allow_reuse=True)(duration)
+
+
+class Profile(ForbidExtra):
+    name: str
+    project: Optional[str]
+    resources: ProfileResources = ProfileResources()
+    spot_policy: Optional[SpotPolicy]
+    retry_policy: ProfileRetryPolicy = ProfileRetryPolicy()
+    default: bool = False

From 6f174e28835e6752195e24823de0f709d8a3543c Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Thu, 6 Jul 2023 19:00:50 +0400
Subject: [PATCH 04/26] Implement configurators

---
 .../_internal/configurators/__init__.py    | 19 ++++---
 .../configurators/dev_environment.py       | 51 +++++++++++++++++++
 cli/dstack/_internal/configurators/task.py | 20 ++++++++
 3 files changed, 84 insertions(+), 6 deletions(-)

diff --git a/cli/dstack/_internal/configurators/__init__.py b/cli/dstack/_internal/configurators/__init__.py
index c0a315ef0..cc4993ad1 100644
--- a/cli/dstack/_internal/configurators/__init__.py
+++ b/cli/dstack/_internal/configurators/__init__.py
@@ -31,6 +31,9 @@
self.conf = configuration self.profile = profile self.build_policy = "use-build" + # context + self.run_name: Optional[str] = None + self.ssh_key_pub: Optional[str] = None def print_help(self, prog: str = "dstack run"): parser = self.get_parser(prog) @@ -67,6 +70,7 @@ def get_parser(self, prog: Optional[str] = None) -> argparse.ArgumentParser: def apply_args(self, args: argparse.Namespace): if args.spot_policy is not None: self.profile.spot_policy = args.spot_policy + if args.retry: self.profile.retry_policy.retry = True elif args.no_retry: @@ -74,6 +78,7 @@ def apply_args(self, args: argparse.Namespace): elif args.retry_limit: self.profile.retry_policy.retry = True self.profile.retry_policy.limit = args.retry_limit + if args.build_policy is not None: self.build_policy = args.build_policy @@ -100,6 +105,9 @@ def interpolate(obj): def get_jobs( self, repo: Repo, run_name: str, repo_code_filename: str, ssh_key_pub: str ) -> List[job.Job]: + self.run_name = run_name + self.ssh_key_pub = ssh_key_pub + created_at = get_milliseconds_since_epoch() configured_job = job.Job( job_id=f"{run_name},,0", @@ -138,10 +146,6 @@ def get_jobs( def commands(self) -> List[str]: pass - @abstractmethod - def build_commands(self) -> List[str]: - pass - @abstractmethod def optional_build_commands(self) -> List[str]: pass @@ -151,9 +155,12 @@ def artifact_specs(self) -> List[job.ArtifactSpec]: pass @abstractmethod - def dep_specs(self): + def dep_specs(self) -> List[job.DepSpec]: pass + def build_commands(self) -> List[str]: + return self.conf.build + def entrypoint(self) -> Optional[List[str]]: if self.conf.image is None or self.commands(): return ["/bin/sh", "-i", "-c"] @@ -202,7 +209,7 @@ def python(self) -> str: if self.conf.python is not None: return self.conf.python version_info = sys.version_info - return f"{version_info.major}.{version_info.minor}" # todo check is in supported + return f"{version_info.major}.{version_info.minor}" # todo check if is in supported def ports(self) -> Dict[int, ports.PortMapping]: mapping = [ports.PortMapping(p) for p in self.conf.ports] diff --git a/cli/dstack/_internal/configurators/dev_environment.py b/cli/dstack/_internal/configurators/dev_environment.py index 9f3875c39..5a02e24ed 100644 --- a/cli/dstack/_internal/configurators/dev_environment.py +++ b/cli/dstack/_internal/configurators/dev_environment.py @@ -1,6 +1,57 @@ +from typing import List + from dstack._internal.configurators import JobConfigurator +from dstack._internal.core import job as job from dstack._internal.core.configuration import DevEnvironmentConfiguration +from dstack._internal.providers.extensions import OpenSSHExtension, VSCodeDesktopServer + +vscode_extensions = ["ms-python.python", "ms-toolsai.jupyter"] +pip_packages = ["ipykernel"] class DevEnvironmentConfigurator(JobConfigurator): conf: DevEnvironmentConfiguration + + # todo handle NoVSCodeVersionError + + def commands(self) -> List[str]: + commands = [] + # todo magic script + OpenSSHExtension.patch_commands(commands, ssh_key_pub=self.ssh_key_pub) + VSCodeDesktopServer.patch_commands(commands, vscode_extensions=vscode_extensions) + commands.append("pip install -q --no-cache-dir " + " ".join(pip_packages)) + commands.extend(self.conf.init) + commands.extend( + [ + "echo ''", + f"echo To open in VS Code Desktop, use one of these links:", + f"echo ''", + f"echo ' vscode://vscode-remote/ssh-remote+{self.run_name}/workflow'", + "echo ''", + f"echo 'To connect via SSH, use: `ssh {self.run_name}`'", + "echo ''", + "echo -n 'To exit, press Ctrl+C.'", + 
"cat", # idle + ] + ) + return commands + + def optional_build_commands(self) -> List[str]: + commands = [] + VSCodeDesktopServer.patch_setup(commands, vscode_extensions=vscode_extensions) + commands.append("pip install -q --no-cache-dir " + " ".join(pip_packages)) + return commands + + def artifact_specs(self) -> List[job.ArtifactSpec]: + return [] # not available + + def dep_specs(self) -> List[job.DepSpec]: + return [] # not available + + def spot_policy(self) -> job.SpotPolicy: + return self.profile.spot_policy or job.SpotPolicy.ONDEMAND + + def app_specs(self) -> List[job.AppSpec]: + specs = super().app_specs() + VSCodeDesktopServer.patch_apps(specs) + return specs diff --git a/cli/dstack/_internal/configurators/task.py b/cli/dstack/_internal/configurators/task.py index a1f06531d..736300973 100644 --- a/cli/dstack/_internal/configurators/task.py +++ b/cli/dstack/_internal/configurators/task.py @@ -1,6 +1,26 @@ +from typing import List + from dstack._internal.configurators import JobConfigurator +from dstack._internal.core import job as job from dstack._internal.core.configuration import TaskConfiguration class TaskConfigurator(JobConfigurator): conf: TaskConfiguration + + def commands(self) -> List[str]: + commands = [] + commands.extend(self.conf.commands) + return commands + + def optional_build_commands(self) -> List[str]: + return [] # not needed + + def artifact_specs(self) -> List[job.ArtifactSpec]: + specs = [] + for a in self.conf.artifacts: + specs.append(job.ArtifactSpec(artifact_path=a.path, mount=a.mount)) + return specs + + def dep_specs(self) -> List[job.DepSpec]: + return [] # not available yet From 28c28453de479bb904a25439d73ad3fe8b5c09bf Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Fri, 7 Jul 2023 15:35:57 +0400 Subject: [PATCH 05/26] Implement run command with configurators --- .../_internal/cli/commands/run/__init__.py | 48 ++++--- cli/dstack/_internal/cli/configuration.py | 52 ++++++++ cli/dstack/_internal/cli/profiles.py | 17 ++- .../_internal/configurators/__init__.py | 30 ++++- .../configurators/dev_environment.py | 2 +- cli/dstack/_internal/core/profile.py | 14 +- cli/dstack/api/hub/_client.py | 122 ++++-------------- 7 files changed, 147 insertions(+), 138 deletions(-) create mode 100644 cli/dstack/_internal/cli/configuration.py diff --git a/cli/dstack/_internal/cli/commands/run/__init__.py b/cli/dstack/_internal/cli/commands/run/__init__.py index b3189a55c..232ea5db9 100644 --- a/cli/dstack/_internal/cli/commands/run/__init__.py +++ b/cli/dstack/_internal/cli/commands/run/__init__.py @@ -18,11 +18,11 @@ from dstack._internal.api.runs import list_runs_hub from dstack._internal.backend.base.logs import fix_urls from dstack._internal.cli.commands import BasicCommand -from dstack._internal.cli.commands.run import configurations from dstack._internal.cli.commands.run.ssh_tunnel import PortsLock, run_ssh_tunnel from dstack._internal.cli.commands.run.watcher import LocalCopier, SSHCopier, Watcher from dstack._internal.cli.common import add_project_argument, check_init, console, print_runs from dstack._internal.cli.config import config, get_hub_client +from dstack._internal.cli.configuration import load_configuration from dstack._internal.core.error import RepoNotInitializedError from dstack._internal.core.instance import InstanceType from dstack._internal.core.job import Job, JobErrorCode, JobHead, JobStatus @@ -107,18 +107,16 @@ def register(self): @check_init def _command(self, args: Namespace): - ( - configuration_path, - provider_name, - provider_data, - 
project_name, - ) = configurations.parse_configuration_file( - args.working_dir, args.file_name, args.profile_name - ) + configurator = load_configuration(args.working_dir, args.file_name, args.profile_name) + # if args.help: # todo + # configurator.print_help(prog="dstack run") + + project_name = None if args.project: project_name = args.project - if args.help: - pass # todo + elif configurator.profile.project: + project_name = configurator.profile.project + watcher = Watcher(os.getcwd()) try: if args.reload: @@ -131,30 +129,28 @@ def _command(self, args: Namespace): raise RepoNotInitializedError("No credentials", project_name=project_name) if not config.repo_user_config.ssh_key_path: - ssh_pub_key = None + ssh_key_pub = None else: - ssh_pub_key = _read_ssh_key_pub(config.repo_user_config.ssh_key_path) + ssh_key_pub = _read_ssh_key_pub(config.repo_user_config.ssh_key_path) - run_plan = hub_client.get_run_plan( - configuration_path=configuration_path, - provider_name=provider_name, - provider_data=provider_data, - args=args, + # should we pass args.args here? + configurator_args, run_args = configurator.get_parser().parse_known_args( + args.args + args.unknown ) + configurator.apply_args(configurator_args) + + run_plan = hub_client.get_run_plan(configurator) console.print("dstack will execute the following plan:\n") - _print_run_plan(configuration_path, run_plan) + _print_run_plan(configurator.configuration_path, run_plan) if not args.yes and not Confirm.ask("Continue?"): console.print("\nExiting...") exit(0) console.print("\nProvisioning...\n") - run_name, jobs = hub_client.run_provider( - configuration_path=configuration_path, - provider_name=provider_name, - provider_data=provider_data, - ssh_pub_key=ssh_pub_key, - tag_name=args.tag_name, - args=args, + run_name, jobs = hub_client.run_configuration( + configurator=configurator, + ssh_key_pub=ssh_key_pub, + run_args=run_args, ) runs = list_runs_hub(hub_client, run_name=run_name) run = runs[0] diff --git a/cli/dstack/_internal/cli/configuration.py b/cli/dstack/_internal/cli/configuration.py new file mode 100644 index 000000000..3397cae9c --- /dev/null +++ b/cli/dstack/_internal/cli/configuration.py @@ -0,0 +1,52 @@ +from pathlib import Path +from typing import Optional + +import yaml + +from dstack._internal.cli.profiles import load_profiles +from dstack._internal.configurators import JobConfigurator +from dstack._internal.configurators.dev_environment import DevEnvironmentConfigurator +from dstack._internal.configurators.task import TaskConfigurator +from dstack._internal.core.configuration import ( + DevEnvironmentConfiguration, + TaskConfiguration, + parse, +) +from dstack._internal.core.profile import Profile + + +def load_configuration( + working_dir: str, configuration_path: Optional[str], profile_name: Optional[str] +) -> JobConfigurator: + configuration_path = resolve_configuration_path(configuration_path, working_dir) + configuration = parse(yaml.safe_load(configuration_path.read_text())) + # todo handle validation errors + profiles = load_profiles() + if profile_name: + try: + profile = profiles[profile_name] + except KeyError: + exit(f"Error: No profile `{profile_name}` found") + else: + profile = profiles.get("default", Profile(name="default")) + + if isinstance(configuration, DevEnvironmentConfiguration): + return DevEnvironmentConfigurator( + working_dir, str(configuration_path), configuration, profile + ) + elif isinstance(configuration, TaskConfiguration): + return TaskConfigurator(working_dir, str(configuration_path), 
configuration, profile) + exit(f"Unsupported configuration {type(configuration)}") + + +def resolve_configuration_path(file_name: str, working_dir: str) -> Path: + root = Path.cwd() + configuration_path = root / file_name if file_name else root / working_dir / ".dstack.yml" + if not file_name and not configuration_path.exists(): + configuration_path = root / working_dir / ".dstack.yaml" + if not configuration_path.exists(): + exit(f"Error: No such configuration file {configuration_path}") + try: + return configuration_path.relative_to(root) + except ValueError: + exit(f"Configuration file is outside the repository {root}") diff --git a/cli/dstack/_internal/cli/profiles.py b/cli/dstack/_internal/cli/profiles.py index c1c716795..96439dbec 100644 --- a/cli/dstack/_internal/cli/profiles.py +++ b/cli/dstack/_internal/cli/profiles.py @@ -1,13 +1,15 @@ import json from pathlib import Path -from typing import Any, Dict, Optional +from typing import Dict import jsonschema import pkg_resources import yaml +from dstack._internal.core.profile import Profile -def load_profiles() -> Optional[Dict[str, Dict[str, Any]]]: + +def load_profiles() -> Dict[str, Profile]: # NOTE: This only supports local profiles profiles_path = Path(".dstack") / "profiles.yml" if not profiles_path.exists(): @@ -22,8 +24,11 @@ def load_profiles() -> Optional[Dict[str, Dict[str, Any]]]: ) jsonschema.validate(profiles, schema) for profile in profiles["profiles"]: - if profile.get("default"): - profiles["default"] = profile - profiles[profile["name"]] = profile - del profiles["profiles"] + profile = Profile.parse_obj(profile) + if profile.default: + profiles[ + "default" + ] = profile # we can't have Profile(name="default"), we use the latest default=True + profiles[profile.name] = profile + del profiles["profiles"] # we can't have Profile(name="profiles") return profiles diff --git a/cli/dstack/_internal/configurators/__init__.py b/cli/dstack/_internal/configurators/__init__.py index cc4993ad1..6e6455a42 100644 --- a/cli/dstack/_internal/configurators/__init__.py +++ b/cli/dstack/_internal/configurators/__init__.py @@ -38,6 +38,7 @@ def __init__( def print_help(self, prog: str = "dstack run"): parser = self.get_parser(prog) parser.print_help() + exit(0) def get_parser(self, prog: Optional[str] = None) -> argparse.ArgumentParser: parser = argparse.ArgumentParser(prog=prog, formatter_class=RichHelpFormatter) @@ -162,7 +163,10 @@ def build_commands(self) -> List[str]: return self.conf.build def entrypoint(self) -> Optional[List[str]]: - if self.conf.image is None or self.commands(): + # todo custom entrypoint + if self.conf.image is None: # dstackai/miniforge + return ["/bin/bash", "-i", "-c"] + if self.commands(): # custom docker image with commands return ["/bin/sh", "-i", "-c"] return None @@ -180,7 +184,7 @@ def spot_policy(self) -> job.SpotPolicy: return self.profile.spot_policy or job.SpotPolicy.AUTO def retry_policy(self) -> job.RetryPolicy: - return job.RetryPolicy.parse_obj(self.profile.retry_policy.dict()) + return job.RetryPolicy.parse_obj(self.profile.retry_policy) def cache_specs(self) -> List[job.CacheSpec]: return [ @@ -191,7 +195,7 @@ def cache_specs(self) -> List[job.CacheSpec]: def registry_auth(self) -> Optional[job.RegistryAuth]: if self.conf.registry_auth is None: return None - return job.RegistryAuth.parse_obj(self.conf.registry_auth.dict()) + return job.RegistryAuth.parse_obj(self.conf.registry_auth) def app_specs(self) -> List[job.AppSpec]: specs = [] @@ -224,7 +228,25 @@ def env(self) -> Dict[str, str]: 
         return self.conf.env
 
     def requirements(self) -> job.Requirements:
-        return job.Requirements.parse_obj(self.profile.resources.dict())
+        r = job.Requirements(
+            cpus=self.profile.resources.cpu,
+            memory_mib=self.profile.resources.memory,
+            gpus=None,
+            shm_size_mib=self.profile.resources.shm_size,
+        )
+        if self.profile.resources.gpu:
+            r.gpus = job.GpusRequirements(
+                count=self.profile.resources.gpu.count,
+                memory_mib=self.profile.resources.gpu.memory,
+                name=self.profile.resources.gpu.name,
+            )
+        return r
+
+    @classmethod
+    def join_run_args(cls, args: List[str]) -> str:
+        return " ".join(
+            (arg if " " not in arg else '"%s"' % arg.replace('"', '\\"')) for arg in args
+        )
 
 
 def validate_local_path(path: str, home: Optional[str], working_dir: str) -> str:
diff --git a/cli/dstack/_internal/configurators/dev_environment.py b/cli/dstack/_internal/configurators/dev_environment.py
index 5a02e24ed..2601c72df 100644
--- a/cli/dstack/_internal/configurators/dev_environment.py
+++ b/cli/dstack/_internal/configurators/dev_environment.py
@@ -53,5 +53,5 @@ def spot_policy(self) -> job.SpotPolicy:
 
     def app_specs(self) -> List[job.AppSpec]:
         specs = super().app_specs()
-        VSCodeDesktopServer.patch_apps(specs)
+        OpenSSHExtension.patch_apps(specs)
         return specs
diff --git a/cli/dstack/_internal/core/profile.py b/cli/dstack/_internal/core/profile.py
index de2f9bb84..ec1cea9d6 100644
--- a/cli/dstack/_internal/core/profile.py
+++ b/cli/dstack/_internal/core/profile.py
@@ -12,14 +12,22 @@
 
 
 def mem_size(v: Optional[Union[int, str]]) -> Optional[int]:
+    """
+    Converts human-readable decimal sizes (MB and GB) to mebibytes
+    >>> mem_size("512MB")
+    488
+    >>> mem_size("1 GB")
+    953
+    """
+    dec_bin = 1000 / 1024
     if isinstance(v, str):
         m = re.fullmatch(r"(\d+) *([gm]b)?", v.strip().lower())
         if not m:
            raise ValueError(f"Invalid memory size: {v}")
-        v = int(m.group(1))
+        v = int(m.group(1)) * (dec_bin**2)
         if m.group(2) == "gb":
-            v = v * 1024  # todo
+            v = v * 1000
-    return v
+    return int(v)
 
 
 def duration(v: Union[int, str]) -> int:
diff --git a/cli/dstack/api/hub/_client.py b/cli/dstack/api/hub/_client.py
index 95cb87588..e554bcb3d 100644
--- a/cli/dstack/api/hub/_client.py
+++ b/cli/dstack/api/hub/_client.py
@@ -1,17 +1,16 @@
-import argparse
+import copy
 import sys
 import tempfile
 import time
 import urllib.parse
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, Generator, List, Optional, Tuple
+from typing import Generator, List, Optional, Tuple
 
-from dstack._internal import providers
+import dstack._internal.configurators as configurators
 from dstack._internal.api.repos import get_local_repo_credentials
 from dstack._internal.backend.base import artifacts as base_artifacts
 from dstack._internal.core.artifact import Artifact
-from dstack._internal.core.error import NameNotFoundError
 from dstack._internal.core.job import Job, JobHead, JobStatus
 from dstack._internal.core.log_event import LogEvent
 from dstack._internal.core.plan import RunPlan
@@ -21,7 +20,6 @@
 from dstack._internal.core.secret import Secret
 from dstack._internal.core.tag import TagHead
 from dstack._internal.hub.models import ProjectInfo
-from dstack._internal.utils.common import merge_workflow_data
 from dstack.api.hub._api_client import HubAPIClient
 from dstack.api.hub._config import HubClientConfig
 from dstack.api.hub._storage import HUBStorage
@@ -260,114 +258,42 @@ def delete_secret(self, secret_name: str):
     def delete_configuration_cache(self, configuration_path: str):
self._api_client.delete_configuration_cache(configuration_path=configuration_path) - def get_run_plan( - self, - configuration_path: str, - provider_name: str, - provider_data: Optional[Dict[str, Any]] = None, - args: Optional[argparse.Namespace] = None, - ) -> RunPlan: - if provider_name not in providers.get_provider_names(): - raise NameNotFoundError(f"No provider '{provider_name}' is found") - provider = providers.load_provider(provider_name) - provider.load( - hub_client=self, - args=args, - workflow_name=None, - provider_data=provider_data or {}, - run_name="dry-run", - ssh_key_pub="", + def get_run_plan(self, configurator: "configurators.JobConfigurator") -> RunPlan: + """ + :param configurator: args must be already applied + :return: run plan + """ + jobs = configurator.get_jobs( + repo=self.repo, run_name="dry-run", repo_code_filename="", ssh_key_pub="" ) - jobs = provider.get_jobs(repo=self.repo, configuration_path=configuration_path) run_plan = self._api_client.get_run_plan(jobs) return run_plan - def run_provider( + def run_configuration( self, - provider_name: str, - provider_data: Optional[Dict[str, Any]] = None, - configuration_path: Optional[str] = None, - tag_name: Optional[str] = None, - ssh_pub_key: Optional[str] = None, - args: Optional[argparse.Namespace] = None, + configurator: "configurators.JobConfigurator", + ssh_key_pub: str, + run_args: Optional[List[str]] = None, ) -> Tuple[str, List[Job]]: - """Runs provider by name - :return: run_name, jobs - """ - if provider_name not in providers.get_provider_names(): - raise NameNotFoundError(f"No provider '{provider_name}' is found") - provider = providers.load_provider(provider_name) - run_name = self.create_run() - provider.load( - hub_client=self, - args=args, - workflow_name=None, - provider_data=provider_data or {}, - run_name=run_name, - ssh_key_pub=ssh_pub_key, - ) # todo validate data - if tag_name: - tag_head = self.get_tag_head(tag_name) - if tag_head: - self.delete_tag_head(tag_head) - - repo_code_filename = self._upload_code_file() - jobs = provider.get_jobs( - repo=self.repo, - configuration_path=configuration_path, - repo_code_filename=repo_code_filename, - tag_name=tag_name, + configurator = copy.deepcopy(configurator) + configurator.inject_context( + {"run": {"name": run_name, "args": configurator.join_run_args(run_args)}} ) - for job in jobs: - self.submit_job(job) - if tag_name: - self.add_tag_from_run(tag_name, run_name, jobs) - self.update_repo_last_run_at(last_run_at=int(round(time.time() * 1000))) - return run_name, jobs # todo return run_head - def run_workflow( - self, - workflow_name: str, - workflow_data: Optional[Dict[str, Any]] = None, - tag_name: Optional[str] = None, - ssh_pub_key: Optional[str] = None, - args: Optional[argparse.Namespace] = None, - ) -> Tuple[str, List[Job]]: - """Runs workflow by name - :return: run_name, jobs - """ - workflow = self.repo.get_workflows(credentials=self.get_repo_credentials()).get( - workflow_name - ) - if workflow is None: - raise NameNotFoundError(f"No workflow '{workflow_name}' is found") - provider = providers.load_provider(workflow["provider"]) - - run_name = self.create_run() - provider.load( - self, - args, - workflow_name, - merge_workflow_data(workflow, workflow_data), - run_name, - ssh_pub_key, - ) - if tag_name: - tag_head = self.get_tag_head(tag_name) - if tag_head: - self.delete_tag_head(tag_head) + # todo handle tag_name & dependencies repo_code_filename = self._upload_code_file() - jobs = provider.get_jobs( - repo=self.repo, 
repo_code_filename=repo_code_filename, tag_name=tag_name + jobs = configurator.get_jobs( + repo=self.repo, + run_name=run_name, + repo_code_filename=repo_code_filename, + ssh_key_pub=ssh_key_pub, ) for job in jobs: self.submit_job(job) - if tag_name: - self.add_tag_from_run(tag_name, self.run_name, jobs) self.update_repo_last_run_at(last_run_at=int(round(time.time() * 1000))) - return run_name, jobs # todo return run_head + return run_name, jobs def _upload_code_file(self) -> str: with tempfile.NamedTemporaryFile("w+b") as f: From d6c8abd0096fab0b0af6653434b12c00ab090576 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Fri, 7 Jul 2023 16:35:23 +0400 Subject: [PATCH 06/26] Drop legacy configurations loader --- .../_internal/cli/commands/build/__init__.py | 55 ++++----- .../_internal/cli/commands/prune/__init__.py | 6 +- .../_internal/cli/commands/run/__init__.py | 2 +- .../cli/commands/run/configurations.py | 112 ------------------ 4 files changed, 26 insertions(+), 149 deletions(-) delete mode 100644 cli/dstack/_internal/cli/commands/run/configurations.py diff --git a/cli/dstack/_internal/cli/commands/build/__init__.py b/cli/dstack/_internal/cli/commands/build/__init__.py index 99ee3d591..7b8e9db57 100644 --- a/cli/dstack/_internal/cli/commands/build/__init__.py +++ b/cli/dstack/_internal/cli/commands/build/__init__.py @@ -7,16 +7,11 @@ from dstack._internal.api.runs import list_runs_hub from dstack._internal.cli.commands import BasicCommand -from dstack._internal.cli.commands.run import ( - _poll_run, - _print_run_plan, - _read_ssh_key_pub, - configurations, -) -from dstack._internal.cli.common import add_project_argument, check_init, console, print_runs +from dstack._internal.cli.commands.run import _poll_run, _print_run_plan, _read_ssh_key_pub +from dstack._internal.cli.common import add_project_argument, check_init, console from dstack._internal.cli.config import config, get_hub_client +from dstack._internal.cli.configuration import load_configuration from dstack._internal.core.error import RepoNotInitializedError -from dstack._internal.core.job import JobStatus class BuildCommand(BasicCommand): @@ -25,18 +20,15 @@ class BuildCommand(BasicCommand): @check_init def _command(self, args: argparse.Namespace): - ( - configuration_path, - provider_name, - provider_data, - project_name, - ) = configurations.parse_configuration_file( - args.working_dir, args.file_name, args.profile_name - ) - provider_data["build_policy"] = "build-only" + configurator = load_configuration(args.working_dir, args.file_name, args.profile_name) + configurator.build_policy = "build-only" + project_name = None if args.project: project_name = args.project + elif configurator.profile.project: + project_name = configurator.profile.project + try: hub_client = get_hub_client(project_name=project_name) if ( @@ -46,29 +38,28 @@ def _command(self, args: argparse.Namespace): raise RepoNotInitializedError("No credentials", project_name=project_name) if not config.repo_user_config.ssh_key_path: - ssh_pub_key = None + ssh_key_pub = None else: - ssh_pub_key = _read_ssh_key_pub(config.repo_user_config.ssh_key_path) + ssh_key_pub = _read_ssh_key_pub(config.repo_user_config.ssh_key_path) - run_plan = hub_client.get_run_plan( - configuration_path=configuration_path, - provider_name=provider_name, - provider_data=provider_data, - args=args, + # should we pass args.args here? 
+ configurator_args, run_args = configurator.get_parser().parse_known_args( + args.args + args.unknown ) + configurator.apply_args(configurator_args) + + run_plan = hub_client.get_run_plan(configurator) console.print("dstack will execute the following plan:\n") - _print_run_plan(configuration_path, run_plan) + _print_run_plan(configurator.configuration_path, run_plan) if not args.yes and not Confirm.ask("Continue?"): console.print("\nExiting...") exit(0) console.print("\nProvisioning...\n") - run_name, jobs = hub_client.run_provider( - configuration_path=configuration_path, - provider_name=provider_name, - provider_data=provider_data, - ssh_pub_key=ssh_pub_key, - args=args, + run_name, jobs = hub_client.run_configuration( + configurator=configurator, + ssh_key_pub=ssh_key_pub, + run_args=run_args, ) runs = list_runs_hub(hub_client, run_name=run_name) run = runs[0] @@ -80,7 +71,7 @@ def _command(self, args: argparse.Namespace): watcher=None, ) except ValidationError as e: - sys.exit( + sys.exit( # todo replace with pydantic f"There a syntax error in one of the files inside the {os.getcwd()}/.dstack/workflows directory:\n\n{e}" ) diff --git a/cli/dstack/_internal/cli/commands/prune/__init__.py b/cli/dstack/_internal/cli/commands/prune/__init__.py index f85906d97..90e1d4c72 100644 --- a/cli/dstack/_internal/cli/commands/prune/__init__.py +++ b/cli/dstack/_internal/cli/commands/prune/__init__.py @@ -3,9 +3,9 @@ from rich_argparse import RichHelpFormatter from dstack._internal.cli.commands import BasicCommand -from dstack._internal.cli.commands.run import configurations from dstack._internal.cli.common import add_project_argument, check_init, console from dstack._internal.cli.config import get_hub_client +from dstack._internal.cli.configuration import resolve_configuration_path from dstack.api.hub import HubClient @@ -46,8 +46,6 @@ def _command(self, args: argparse.Namespace): @staticmethod def prune_cache(args: argparse.Namespace, hub_client: HubClient): - configuration_path = str( - configurations.get_configuration_path(args.working_dir, args.file_name) - ) + configuration_path = str(resolve_configuration_path(args.file_name, args.working_dir)) hub_client.delete_configuration_cache(configuration_path=configuration_path) console.print(f"[grey58]Cache pruned[/]") diff --git a/cli/dstack/_internal/cli/commands/run/__init__.py b/cli/dstack/_internal/cli/commands/run/__init__.py index 232ea5db9..5dcf6dff1 100644 --- a/cli/dstack/_internal/cli/commands/run/__init__.py +++ b/cli/dstack/_internal/cli/commands/run/__init__.py @@ -163,7 +163,7 @@ def _command(self, args: Namespace): watcher=watcher, ) except ValidationError as e: - sys.exit( + sys.exit( # todo replace with pydantic f"There a syntax error in one of the files inside the {os.getcwd()}/.dstack/workflows directory:\n\n{e}" ) finally: diff --git a/cli/dstack/_internal/cli/commands/run/configurations.py b/cli/dstack/_internal/cli/commands/run/configurations.py deleted file mode 100644 index 763521f49..000000000 --- a/cli/dstack/_internal/cli/commands/run/configurations.py +++ /dev/null @@ -1,112 +0,0 @@ -import json -import os -from pathlib import Path -from typing import Any, Dict, Optional, Tuple - -import jsonschema -import pkg_resources -import yaml - -from dstack._internal.cli.common import console -from dstack._internal.cli.profiles import load_profiles -from dstack._internal.providers.extensions import NoVSCodeVersionError, VSCodeDesktopServer - - -def _init_base_provider_data(configuration_data: Dict[str, Any], provider_data: 
Dict[str, Any]): - if "cache" in configuration_data: - provider_data["cache"] = configuration_data["cache"] - if "ports" in configuration_data: - provider_data["ports"] = configuration_data["ports"] - if "python" in configuration_data: - provider_data["python"] = configuration_data["python"] - if "env" in configuration_data: - provider_data["env"] = configuration_data["env"] - provider_data["build"] = configuration_data.get("build") or [] - - -def _parse_dev_environment_configuration_data( - configuration_data: Dict[str, Any] -) -> Tuple[str, Dict[str, Any]]: - provider_name = "ssh" - provider_data = { - "configuration_type": "dev-environment", - "optional_build": [], - "commands": [], - } - _init_base_provider_data(configuration_data, provider_data) - try: - extensions = ["ms-python.python", "ms-toolsai.jupyter"] - VSCodeDesktopServer.patch_setup( - provider_data["optional_build"], vscode_extensions=extensions - ) - VSCodeDesktopServer.patch_commands(provider_data["commands"], vscode_extensions=extensions) - except NoVSCodeVersionError as e: - console.print( - "[grey58]Unable to detect the VS Code version and pre-install extensions. Fix by opening [" - "sea_green3]Command Palette[/sea_green3], executing [sea_green3]Shell Command: Install 'code' command in " - "PATH[/sea_green3], and restarting terminal.[/]\n" - ) - for key in ["optional_build", "commands"]: - provider_data[key].append("pip install -q --no-cache-dir ipykernel") - provider_data["commands"].extend(configuration_data.get("init") or []) - return provider_name, provider_data - - -def _parse_task_configuration_data( - configuration_data: Dict[str, Any] -) -> Tuple[str, Dict[str, Any]]: - # TODO: Support the `docker` provider - provider_name = "bash" - provider_data = { - "configuration_type": "task", - "commands": [], - } - _init_base_provider_data(configuration_data, provider_data) - provider_data["commands"].extend(configuration_data["commands"]) - return provider_name, provider_data - - -def parse_configuration_file( - working_dir: str, file_name: Optional[str], profile_name: Optional[str] -) -> Tuple[str, str, Dict[str, Any], Optional[str]]: - configuration_path = get_configuration_path(working_dir, file_name) - with configuration_path.open("r") as f: - configuration_data = yaml.load(f, yaml.FullLoader) - schema = json.loads( - pkg_resources.resource_string("dstack._internal", "schemas/configuration.json") - ) - jsonschema.validate(configuration_data, schema) - configuration_type = configuration_data["type"] - if configuration_type == "dev-environment": - (provider_name, provider_data) = _parse_dev_environment_configuration_data( - configuration_data - ) - elif configuration_type == "task": - (provider_name, provider_data) = _parse_task_configuration_data(configuration_data) - else: - exit(f"Unsupported configuration type: {configuration_type}") - profiles = load_profiles() - if profile_name: - if profile_name in profiles: - profile = profiles[profile_name] - else: - exit(f"Error: No profile `{profile_name}` found") - else: - profile = profiles.get("default", {}) - if "resources" in profile: - provider_data["resources"] = profile["resources"] - provider_data["spot_policy"] = profile.get("spot_policy") - provider_data["retry_policy"] = profile.get("retry_policy") - project_name = profile.get("project") - if not Path(os.getcwd()).samefile(Path(working_dir)): - provider_data["working_dir"] = str(Path(working_dir)) - return str(configuration_path), provider_name, provider_data, project_name - - -def 
get_configuration_path(working_dir: str, file_name: str) -> Path: - configuration_path = Path(file_name) if file_name else Path(working_dir) / ".dstack.yml" - if not file_name and not configuration_path.exists(): - configuration_path = Path(working_dir) / ".dstack.yaml" - if not configuration_path.exists(): - exit(f"Error: No such configuration file {configuration_path}") - return configuration_path From cbaf84fba66a137a8064c8d236d853885111f0c6 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Fri, 7 Jul 2023 17:31:01 +0400 Subject: [PATCH 07/26] Print relevant help for dstack run --- cli/dstack/_internal/cli/commands/__init__.py | 13 ++++++------- cli/dstack/_internal/cli/commands/run/__init__.py | 8 +++++--- cli/dstack/_internal/configurators/__init__.py | 12 +++++------- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/cli/dstack/_internal/cli/commands/__init__.py b/cli/dstack/_internal/cli/commands/__init__.py index cd8c953f7..a70cb8922 100644 --- a/cli/dstack/_internal/cli/commands/__init__.py +++ b/cli/dstack/_internal/cli/commands/__init__.py @@ -12,19 +12,18 @@ class BasicCommand(object): DESCRIPTION = "describe the command" SUBCOMMANDS = [] - def __init__(self, parser: _SubParsersAction): + def __init__(self, parser: _SubParsersAction, store_help: bool = False): kwargs = {} if self.description: kwargs["help"] = self.description - self._parser = parser.add_parser( + self._parser: argparse.ArgumentParser = parser.add_parser( self.name, add_help=False, formatter_class=RichHelpFormatter, **kwargs ) + help_kwargs = dict(action="help", default=argparse.SUPPRESS) + if store_help: + help_kwargs = dict(action="store_true") self._parser.add_argument( - "-h", - "--help", - action="help", - default=argparse.SUPPRESS, - help="Show this help message and exit", + "-h", "--help", help="Show this help message and exit", **help_kwargs ) self._parser.set_defaults(func=self.__command) diff --git a/cli/dstack/_internal/cli/commands/run/__init__.py b/cli/dstack/_internal/cli/commands/run/__init__.py index 5dcf6dff1..73dc77b74 100644 --- a/cli/dstack/_internal/cli/commands/run/__init__.py +++ b/cli/dstack/_internal/cli/commands/run/__init__.py @@ -1,4 +1,5 @@ import argparse +import copy import os import sys import threading @@ -48,7 +49,7 @@ class RunCommand(BasicCommand): DESCRIPTION = "Run a configuration" def __init__(self, parser): - super(RunCommand, self).__init__(parser) + super().__init__(parser, store_help=True) def register(self): self._parser.add_argument( @@ -108,8 +109,9 @@ def register(self): @check_init def _command(self, args: Namespace): configurator = load_configuration(args.working_dir, args.file_name, args.profile_name) - # if args.help: # todo - # configurator.print_help(prog="dstack run") + if args.help: + configurator.get_parser(parser=copy.deepcopy(self._parser)).print_help() + exit(0) project_name = None if args.project: diff --git a/cli/dstack/_internal/configurators/__init__.py b/cli/dstack/_internal/configurators/__init__.py index 6e6455a42..b3d1f8056 100644 --- a/cli/dstack/_internal/configurators/__init__.py +++ b/cli/dstack/_internal/configurators/__init__.py @@ -35,13 +35,11 @@ def __init__( self.run_name: Optional[str] = None self.ssh_key_pub: Optional[str] = None - def print_help(self, prog: str = "dstack run"): - parser = self.get_parser(prog) - parser.print_help() - exit(0) - - def get_parser(self, prog: Optional[str] = None) -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(prog=prog, formatter_class=RichHelpFormatter) + def get_parser( 
+        self, prog: Optional[str] = None, parser: Optional[argparse.ArgumentParser] = None
+    ) -> argparse.ArgumentParser:
+        if parser is None:
+            parser = argparse.ArgumentParser(prog=prog, formatter_class=RichHelpFormatter)
 
         spot_group = parser.add_mutually_exclusive_group()
         spot_group.add_argument(

From 139e0eacc7f041100def5e13ca852837d8afb6e5 Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Fri, 7 Jul 2023 18:05:12 +0400
Subject: [PATCH 08/26] Handle pydantic validation errors

---
 .../_internal/cli/commands/build/__init__.py | 79 ++++++-----------
 .../_internal/cli/commands/run/__init__.py   |  6 --
 cli/dstack/_internal/cli/config.py           |  5 +-
 cli/dstack/_internal/cli/configuration.py    | 15 ++--
 cli/dstack/_internal/cli/profiles.py         | 28 ++-----
 cli/dstack/_internal/core/profile.py         | 18 ++++-
 6 files changed, 67 insertions(+), 84 deletions(-)

diff --git a/cli/dstack/_internal/cli/commands/build/__init__.py b/cli/dstack/_internal/cli/commands/build/__init__.py
index 7b8e9db57..f5785b491 100644
--- a/cli/dstack/_internal/cli/commands/build/__init__.py
+++ b/cli/dstack/_internal/cli/commands/build/__init__.py
@@ -1,8 +1,5 @@
 import argparse
-import os
-import sys
 
-from jsonschema import ValidationError
 from rich.prompt import Confirm
 
 from dstack._internal.api.runs import list_runs_hub
@@ -29,51 +26,45 @@ def _command(self, args: argparse.Namespace):
         elif configurator.profile.project:
             project_name = configurator.profile.project
 
-        try:
-            hub_client = get_hub_client(project_name=project_name)
-            if (
-                hub_client.repo.repo_data.repo_type != "local"
-                and not hub_client.get_repo_credentials()
-            ):
-                raise RepoNotInitializedError("No credentials", project_name=project_name)
+        hub_client = get_hub_client(project_name=project_name)
+        if (
+            hub_client.repo.repo_data.repo_type != "local"
+            and not hub_client.get_repo_credentials()
+        ):
+            raise RepoNotInitializedError("No credentials", project_name=project_name)
 
-            if not config.repo_user_config.ssh_key_path:
-                ssh_key_pub = None
-            else:
-                ssh_key_pub = _read_ssh_key_pub(config.repo_user_config.ssh_key_path)
+        if not config.repo_user_config.ssh_key_path:
+            ssh_key_pub = None
+        else:
+            ssh_key_pub = _read_ssh_key_pub(config.repo_user_config.ssh_key_path)
 
-            # should we pass args.args here?
-            configurator_args, run_args = configurator.get_parser().parse_known_args(
-                args.args + args.unknown
-            )
-            configurator.apply_args(configurator_args)
+        configurator_args, run_args = configurator.get_parser().parse_known_args(
+            args.args + args.unknown
+        )
+        configurator.apply_args(configurator_args)
 
-            run_plan = hub_client.get_run_plan(configurator)
-            console.print("dstack will execute the following plan:\n")
-            _print_run_plan(configurator.configuration_path, run_plan)
-            if not args.yes and not Confirm.ask("Continue?"):
-                console.print("\nExiting...")
-                exit(0)
-            console.print("\nProvisioning...\n")
+        run_plan = hub_client.get_run_plan(configurator)
+        console.print("dstack will execute the following plan:\n")
+        _print_run_plan(configurator.configuration_path, run_plan)
+        if not args.yes and not Confirm.ask("Continue?"):
+            console.print("\nExiting...")
+            exit(0)
+        console.print("\nProvisioning...\n")
 
-            run_name, jobs = hub_client.run_configuration(
-                configurator=configurator,
-                ssh_key_pub=ssh_key_pub,
-                run_args=run_args,
-            )
-            runs = list_runs_hub(hub_client, run_name=run_name)
-            run = runs[0]
-            _poll_run(
-                hub_client,
-                run,
-                jobs,
-                ssh_key=config.repo_user_config.ssh_key_path,
-                watcher=None,
-            )
-        except ValidationError as e:
-            sys.exit(  # todo replace with pydantic
-                f"There a syntax error in one of the files inside the {os.getcwd()}/.dstack/workflows directory:\n\n{e}"
-            )
+        run_name, jobs = hub_client.run_configuration(
+            configurator=configurator,
+            ssh_key_pub=ssh_key_pub,
+            run_args=run_args,
+        )
+        runs = list_runs_hub(hub_client, run_name=run_name)
+        run = runs[0]
+        _poll_run(
+            hub_client,
+            run,
+            jobs,
+            ssh_key=config.repo_user_config.ssh_key_path,
+            watcher=None,
+        )
 
     def __init__(self, parser):
         super().__init__(parser)
diff --git a/cli/dstack/_internal/cli/commands/run/__init__.py b/cli/dstack/_internal/cli/commands/run/__init__.py
index 73dc77b74..cd5b4c14e 100644
--- a/cli/dstack/_internal/cli/commands/run/__init__.py
+++ b/cli/dstack/_internal/cli/commands/run/__init__.py
@@ -10,7 +10,6 @@
 
 import websocket
 from cursor import cursor
-from jsonschema import ValidationError
 from rich.progress import Progress, SpinnerColumn, TextColumn
 from rich.prompt import Confirm
 from rich.table import Table
@@ -135,7 +134,6 @@ def _command(self, args: Namespace):
         else:
             ssh_key_pub = _read_ssh_key_pub(config.repo_user_config.ssh_key_path)
 
-        # should we pass args.args here?
         configurator_args, run_args = configurator.get_parser().parse_known_args(
             args.args + args.unknown
         )
@@ -164,10 +162,6 @@ def _command(self, args: Namespace):
                 ssh_key=config.repo_user_config.ssh_key_path,
                 watcher=watcher,
             )
-        except ValidationError as e:
-            sys.exit(  # todo replace with pydantic
-                f"There a syntax error in one of the files inside the {os.getcwd()}/.dstack/workflows directory:\n\n{e}"
-            )
         finally:
             if watcher.is_alive():
                 watcher.stop()
diff --git a/cli/dstack/_internal/cli/config.py b/cli/dstack/_internal/cli/config.py
index dfb4019c6..834badb68 100644
--- a/cli/dstack/_internal/cli/config.py
+++ b/cli/dstack/_internal/cli/config.py
@@ -149,10 +149,7 @@ def get_default_project_config(self) -> Optional[CLIProjectConfig]:
 
 def get_hub_client(project_name: Optional[str] = None) -> HubClient:
     if project_name is None:
-        profiles = load_profiles()
-        if "default" in profiles:
-            if "project" in profiles["default"]:
-                project_name = profiles["default"]["project"]
+        project_name = load_profiles().default().project
     cli_config_manager = CLIConfigManager()
     project_config = cli_config_manager.get_default_project_config()
     if project_name is not None:
diff --git a/cli/dstack/_internal/cli/configuration.py b/cli/dstack/_internal/cli/configuration.py
index 3397cae9c..f7e60029f 100644
--- a/cli/dstack/_internal/cli/configuration.py
+++ b/cli/dstack/_internal/cli/configuration.py
@@ -1,6 +1,7 @@
 from pathlib import Path
 from typing import Optional
 
+import pydantic
 import yaml
 
 from dstack._internal.cli.profiles import load_profiles
@@ -12,23 +13,25 @@
     TaskConfiguration,
     parse,
 )
-from dstack._internal.core.profile import Profile
 
 
 def load_configuration(
     working_dir: str, configuration_path: Optional[str], profile_name: Optional[str]
 ) -> JobConfigurator:
     configuration_path = resolve_configuration_path(configuration_path, working_dir)
-    configuration = parse(yaml.safe_load(configuration_path.read_text()))
-    # todo handle validation errors
-    profiles = load_profiles()
+    try:
+        configuration = parse(yaml.safe_load(configuration_path.read_text()))
+        profiles = load_profiles()
+    except pydantic.ValidationError as e:
+        exit(e)
+
     if profile_name:
         try:
-            profile = profiles[profile_name]
+            profile = profiles.get(profile_name)
         except KeyError:
             exit(f"Error: No profile `{profile_name}` found")
     else:
-        profile = profiles.get("default", Profile(name="default"))
+        profile = profiles.default()
 
     if isinstance(configuration, DevEnvironmentConfiguration):
         return DevEnvironmentConfigurator(
diff --git a/cli/dstack/_internal/cli/profiles.py b/cli/dstack/_internal/cli/profiles.py
index 96439dbec..f38ec15cd 100644
--- a/cli/dstack/_internal/cli/profiles.py
+++ b/cli/dstack/_internal/cli/profiles.py
@@ -1,34 +1,16 @@
-import json
 from pathlib import Path
-from typing import Dict
 
-import jsonschema
-import pkg_resources
 import yaml
 
-from dstack._internal.core.profile import Profile
+from dstack._internal.core.profile import ProfilesConfig
 
 
-def load_profiles() -> Dict[str, Profile]:
+def load_profiles() -> ProfilesConfig:
     # NOTE: This only supports local profiles
     profiles_path = Path(".dstack") / "profiles.yml"
     if not profiles_path.exists():
         profiles_path = Path(".dstack") / "profiles.yaml"
     if not profiles_path.exists():
-        return {}
-    else:
-        with profiles_path.open("r") as f:
-            profiles = yaml.load(f, yaml.FullLoader)
-        schema = json.loads(
-            pkg_resources.resource_string("dstack._internal", "schemas/profiles.json")
-        )
-        jsonschema.validate(profiles, schema)
-        for profile in profiles["profiles"]:
-            profile = Profile.parse_obj(profile)
-            if profile.default:
-                profiles[
-                    "default"
-                ] = profile  # we can't have Profile(name="default"), we use the latest default=True
-            profiles[profile.name] = profile
-        del profiles["profiles"]  # we can't have Profile(name="profiles")
-        return profiles
+        return ProfilesConfig(profiles=[])
+
+    return ProfilesConfig.parse_obj(yaml.safe_load(profiles_path.read_text()))
diff --git a/cli/dstack/_internal/core/profile.py b/cli/dstack/_internal/core/profile.py
index ec1cea9d6..78ca87692 100644
--- a/cli/dstack/_internal/core/profile.py
+++ b/cli/dstack/_internal/core/profile.py
@@ -1,5 +1,5 @@
 import re
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 from pydantic import validator
 
@@ -82,3 +82,19 @@ class Profile(ForbidExtra):
     spot_policy: Optional[SpotPolicy]
     retry_policy: ProfileRetryPolicy = ProfileRetryPolicy()
     default: bool = False
+
+
+class ProfilesConfig(ForbidExtra):
+    profiles: List[Profile]
+
+    def default(self) -> Profile:
+        for p in self.profiles:
+            if p.default:
+                return p
+        return Profile(name="default")
+
+    def get(self, name: str) -> Profile:
+        for p in self.profiles:
+            if p.name == name:
+                return p
+        raise KeyError(name)
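
The new ProfilesConfig replaces the ad-hoc dict that load_profiles() used to build. A minimal sketch of how it is meant to be consumed, using only the classes from the diff above (the profiles.yml payload here is illustrative, not taken from the patch):

    import yaml

    from dstack._internal.core.profile import Profile, ProfilesConfig

    # Hypothetical .dstack/profiles.yml contents.
    raw = yaml.safe_load("profiles: [{name: gpu-large, default: true}]")

    config = ProfilesConfig.parse_obj(raw)  # pydantic raises ValidationError on bad input
    assert config.default().name == "gpu-large"  # first profile marked default=true
    assert isinstance(config.get("gpu-large"), Profile)  # raises KeyError for unknown names
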
From a08aa68c6b7c9b64484efd562defd3ad91ecb3a5 Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Mon, 10 Jul 2023 14:53:28 +0400
Subject: [PATCH 09/26] Allow configuring entrypoint and home dir. Fail if
 sshd is missing in dev environment

---
 .../_internal/configurators/__init__.py       |  6 ++-
 .../configurators/dev_environment.py          | 46 +++++++++++++++----
 .../configurators/extensions/__init__.py      |  3 ++
 .../configurators/extensions/shell.py         | 13 ++++++
 .../_internal/configurators/extensions/ssh.py | 37 +++++++++++++++
 cli/dstack/_internal/core/configuration.py    |  9 ++--
 6 files changed, 100 insertions(+), 14 deletions(-)
 create mode 100644 cli/dstack/_internal/configurators/extensions/__init__.py
 create mode 100644 cli/dstack/_internal/configurators/extensions/shell.py
 create mode 100644 cli/dstack/_internal/configurators/extensions/ssh.py

diff --git a/cli/dstack/_internal/configurators/__init__.py b/cli/dstack/_internal/configurators/__init__.py
index b3d1f8056..d285b02bb 100644
--- a/cli/dstack/_internal/configurators/__init__.py
+++ b/cli/dstack/_internal/configurators/__init__.py
@@ -1,5 +1,6 @@
 import argparse
 import json
+import shlex
 import sys
 import uuid
 from abc import ABC, abstractmethod
 from typing import Dict, List, Optional
@@ -161,7 +162,8 @@ def build_commands(self) -> List[str]:
         return self.conf.build
 
     def entrypoint(self) -> Optional[List[str]]:
-        # todo custom entrypoint
+        if self.conf.entrypoint is not None:
+            return shlex.split(self.conf.entrypoint)
         if self.conf.image is None:  # dstackai/miniforge
             return ["/bin/bash", "-i", "-c"]
         if self.commands():  # custom docker image with commands
@@ -169,7 +171,7 @@ def entrypoint(self) -> Optional[List[str]]:
         return None
 
     def home_dir(self) -> Optional[str]:
-        return "/root" if self.conf.image is None else None
+        return self.conf.home_dir
 
     def image_name(self) -> str:
         if self.conf.image is not None:
diff --git a/cli/dstack/_internal/configurators/dev_environment.py b/cli/dstack/_internal/configurators/dev_environment.py
index 2601c72df..f586bf3b2 100644
--- a/cli/dstack/_internal/configurators/dev_environment.py
+++ b/cli/dstack/_internal/configurators/dev_environment.py
@@ -1,25 +1,51 @@
-from typing import List
+from typing import List, Optional
 
+import dstack._internal.core.job as job
 from dstack._internal.configurators import JobConfigurator
-from dstack._internal.core import job as job
+from dstack._internal.configurators.extensions.shell import require
+from dstack._internal.configurators.extensions.ssh import SSHd
 from dstack._internal.core.configuration import DevEnvironmentConfiguration
-from dstack._internal.providers.extensions import OpenSSHExtension, VSCodeDesktopServer
+from dstack._internal.core.profile import Profile
+from dstack._internal.core.repo import Repo
+from dstack._internal.providers.extensions import VSCodeDesktopServer
+from dstack._internal.providers.ports import get_map_to_port
 
+require_sshd = require(["sshd"])
+install_ipykernel = (
+    f'pip install -q --no-cache-dir ipykernel || echo "no pip, ipykernel was not installed"'
+)
 vscode_extensions = ["ms-python.python", "ms-toolsai.jupyter"]
-pip_packages = ["ipykernel"]
 
 
 class DevEnvironmentConfigurator(JobConfigurator):
     conf: DevEnvironmentConfiguration
-    # todo handle NoVSCodeVersionError
+
+    def __init__(
+        self,
+        working_dir: str,
+        configuration_path: str,
+        configuration: DevEnvironmentConfiguration,
+        profile: Profile,
+    ):
+        super().__init__(working_dir, configuration_path, configuration, profile)
+        self.sshd: Optional[SSHd] = None
+
+    def get_jobs(
+        self, repo: Repo, run_name: str, repo_code_filename: str, ssh_key_pub: str
+    ) -> List[job.Job]:
+        self.sshd = SSHd(ssh_key_pub)
+        self.sshd.map_to_port = get_map_to_port(self.ports(), self.sshd.port)
+        return super().get_jobs(repo, run_name, repo_code_filename, ssh_key_pub)
 
     def commands(self) -> List[str]:
         commands = []
-        # todo magic script
-        OpenSSHExtension.patch_commands(commands, ssh_key_pub=self.ssh_key_pub)
+        if self.conf.image:
+            require_sshd(commands)
+        self.sshd.set_permissions(commands)
+        self.sshd.start(commands)
         VSCodeDesktopServer.patch_commands(commands, vscode_extensions=vscode_extensions)
-        commands.append("pip install -q --no-cache-dir " + " ".join(pip_packages))
+        commands.append(install_ipykernel)
         commands.extend(self.conf.init)
         commands.extend(
             [
@@ -38,8 +64,10 @@ def commands(self) -> List[str]:
 
     def optional_build_commands(self) -> List[str]:
         commands = []
+        if self.conf.image:
+            require_sshd(commands)
         VSCodeDesktopServer.patch_setup(commands, vscode_extensions=vscode_extensions)
-        commands.append("pip install -q --no-cache-dir " + " ".join(pip_packages))
+        commands.append(install_ipykernel)
         return commands
 
     def artifact_specs(self) -> List[job.ArtifactSpec]:
@@ -53,5 +81,5 @@ def spot_policy(self) -> job.SpotPolicy:
 
     def app_specs(self) -> List[job.AppSpec]:
         specs = super().app_specs()
-        OpenSSHExtension.patch_apps(specs)
+        self.sshd.add_app(specs)
         return specs
diff --git a/cli/dstack/_internal/configurators/extensions/__init__.py b/cli/dstack/_internal/configurators/extensions/__init__.py
new file mode 100644
index 000000000..a36596e71
--- /dev/null
+++ b/cli/dstack/_internal/configurators/extensions/__init__.py
@@ -0,0 +1,3 @@
+from typing import Callable, List
+
+CommandsExtension = Callable[[List[str]], None]
diff --git a/cli/dstack/_internal/configurators/extensions/shell.py b/cli/dstack/_internal/configurators/extensions/shell.py
new file mode 100644
index 000000000..e97b85d8a
--- /dev/null
+++ b/cli/dstack/_internal/configurators/extensions/shell.py
@@ -0,0 +1,13 @@
+from typing import List
+
+from dstack._internal.configurators.extensions import CommandsExtension
+
+
+def require(executables: List[str]) -> CommandsExtension:
+    def wrapper(commands: List[str]):
+        for exe in executables:
+            commands.append(
+                f'((command -v {exe} > /dev/null) || (echo "{exe} is required" && exit 1))'
+            )
+
+    return wrapper
diff --git a/cli/dstack/_internal/configurators/extensions/ssh.py b/cli/dstack/_internal/configurators/extensions/ssh.py
new file mode 100644
index 000000000..3a571adbb
--- /dev/null
+++ b/cli/dstack/_internal/configurators/extensions/ssh.py
@@ -0,0 +1,37 @@
+from typing import List, Optional
+
+from dstack._internal.core.app import AppSpec
+
+
+class SSHd:
+    def __init__(self, key_pub: str, *, port: int = 10022):
+        self.key_pub = key_pub
+        self.port = port
+        self.map_to_port: Optional[int] = None
+
+    def set_permissions(self, commands: List[str]):
+        commands.extend(
+            [
+                f'sed -i "s/.*PasswordAuthentication.*/PasswordAuthentication no/g" /etc/ssh/sshd_config',
+                f"mkdir -p /run/sshd ~/.ssh",
+                f"chmod 700 ~/.ssh",
+                f"touch ~/.ssh/authorized_keys",
+                f"chmod 600 ~/.ssh/authorized_keys",
+                f"rm -rf /etc/ssh/ssh_host_*",
+            ]
+        )
+
+    def start(self, commands: List[str]):
+        commands.extend(
+            [
+                f'echo "{self.key_pub}" >> ~/.ssh/authorized_keys',
+                f"env >> ~/.ssh/environment",
+                f"ssh-keygen -A > /dev/null",
+                f"/usr/sbin/sshd -p {self.port} -o PermitUserEnvironment=yes",
+            ]
+        )
+
+    def add_app(self, apps: List[AppSpec]):
+        apps.append(
+            AppSpec(port=self.port, map_to_port=self.map_to_port, app_name="openssh-server")
+        )
diff --git a/cli/dstack/_internal/core/configuration.py b/cli/dstack/_internal/core/configuration.py
index c63abd8aa..d1a97dbed 100644
--- a/cli/dstack/_internal/core/configuration.py
+++ b/cli/dstack/_internal/core/configuration.py
@@ -25,7 +25,8 @@ class Artifact(ForbidExtra):
 class BaseConfiguration(ForbidExtra):
     type: Literal["none"]
     image: Optional[str]
-    # todo entrypoint
+    entrypoint: Optional[str]
+    home_dir: str = "/root"
     registry_auth: Optional[RegistryAuth]
     python: Optional[PythonVersions]
     ports: List[Union[str, int]] = []
@@ -33,8 +34,10 @@ class BaseConfiguration(ForbidExtra):
     build: List[str] = []
     cache: List[str] = []
 
-    @validator("python", pre=True)
-    def convert_python(cls, v) -> str:
+    @validator("python", pre=True, always=True)
+    def convert_python(cls, v, values) -> Optional[str]:
+        if v is not None and values.get("image"):
+            raise KeyError("`image` and `python` are mutually exclusive fields")
         if isinstance(v, float):
             v = str(v)
         if v == "3.1":
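
The new entrypoint field is a single string that gets tokenized once with shlex before it reaches the job spec; no shell re-parsing happens later. A quick illustration of the splitting behavior the diff relies on:

    import shlex

    # POSIX-style tokenization: quotes group words, whitespace separates them.
    assert shlex.split("/bin/bash -ic") == ["/bin/bash", "-ic"]
    assert shlex.split('sh -c "echo hi"') == ["sh", "-c", "echo hi"]
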
From 87b8730fe879778284fc1d62e69a8f5a83fb77f3 Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Mon, 10 Jul 2023 16:32:06 +0400
Subject: [PATCH 10/26] Reimplement vscode extension

---
 .../configurators/dev_environment.py          | 25 +++----
 .../configurators/extensions/__init__.py      | 15 ++++
 .../configurators/extensions/vscode.py        | 74 +++++++++++++++++++
 3 files changed, 101 insertions(+), 13 deletions(-)
 create mode 100644 cli/dstack/_internal/configurators/extensions/vscode.py

diff --git a/cli/dstack/_internal/configurators/dev_environment.py b/cli/dstack/_internal/configurators/dev_environment.py
index f586bf3b2..cb5505a59 100644
--- a/cli/dstack/_internal/configurators/dev_environment.py
+++ b/cli/dstack/_internal/configurators/dev_environment.py
@@ -2,25 +2,22 @@
 
 import dstack._internal.core.job as job
 from dstack._internal.configurators import JobConfigurator
+from dstack._internal.configurators.extensions import IDEExtension
 from dstack._internal.configurators.extensions.shell import require
 from dstack._internal.configurators.extensions.ssh import SSHd
+from dstack._internal.configurators.extensions.vscode import VSCodeDesktop
 from dstack._internal.core.configuration import DevEnvironmentConfiguration
 from dstack._internal.core.profile import Profile
 from dstack._internal.core.repo import Repo
-from dstack._internal.providers.extensions import VSCodeDesktopServer
 from dstack._internal.providers.ports import get_map_to_port
 
 require_sshd = require(["sshd"])
-install_ipykernel = (
-    f'pip install -q --no-cache-dir ipykernel || echo "no pip, ipykernel was not installed"'
-)
-vscode_extensions = ["ms-python.python", "ms-toolsai.jupyter"]
+install_ipykernel = f'(pip install -q --no-cache-dir ipykernel 2> /dev/null) || echo "no pip, ipykernel was not installed"'
 
 
 class DevEnvironmentConfigurator(JobConfigurator):
     conf: DevEnvironmentConfiguration
-    # todo handle NoVSCodeVersionError
 
     def __init__(
         self,
         working_dir: str,
@@ -30,10 +27,14 @@ def __init__(
     ):
         super().__init__(working_dir, configuration_path, configuration, profile)
         self.sshd: Optional[SSHd] = None
+        self.ide: Optional[IDEExtension] = None
 
     def get_jobs(
         self, repo: Repo, run_name: str, repo_code_filename: str, ssh_key_pub: str
     ) -> List[job.Job]:
+        self.ide = VSCodeDesktop(
+            extensions=["ms-python.python", "ms-toolsai.jupyter"], run_name=run_name
+        )
         self.sshd = SSHd(ssh_key_pub)
         self.sshd.map_to_port = get_map_to_port(self.ports(), self.sshd.port)
         return super().get_jobs(repo, run_name, repo_code_filename, ssh_key_pub)
@@ -44,16 +45,14 @@ def commands(self) -> List[str]:
             require_sshd(commands)
         self.sshd.set_permissions(commands)
         self.sshd.start(commands)
-        VSCodeDesktopServer.patch_commands(commands, vscode_extensions=vscode_extensions)
+        self.ide.install_if_not_found(commands)
         commands.append(install_ipykernel)
         commands.extend(self.conf.init)
+        commands.append("echo ''")
+
+        self.ide.print_readme(commands)
         commands.extend(
             [
-                "echo ''",
-                f"echo To open in VS Code Desktop, use one of these links:",
-                f"echo ''",
-                f"echo '  vscode://vscode-remote/ssh-remote+{self.run_name}/workflow'",
-                "echo ''",
                 f"echo 'To connect via SSH, use: `ssh {self.run_name}`'",
                 "echo ''",
                 "echo -n 'To exit, press Ctrl+C.'",
@@ -66,7 +65,7 @@ def optional_build_commands(self) -> List[str]:
         commands = []
         if self.conf.image:
             require_sshd(commands)
-        VSCodeDesktopServer.patch_setup(commands, vscode_extensions=vscode_extensions)
+        self.ide.install(commands)
         commands.append(install_ipykernel)
         return commands
 
diff --git a/cli/dstack/_internal/configurators/extensions/__init__.py b/cli/dstack/_internal/configurators/extensions/__init__.py
index a36596e71..49e1835b6 100644
--- a/cli/dstack/_internal/configurators/extensions/__init__.py
+++ b/cli/dstack/_internal/configurators/extensions/__init__.py
@@ -1,3 +1,18 @@
+from abc import ABC, abstractmethod
 from typing import Callable, List
 
 CommandsExtension = Callable[[List[str]], None]
+
+
+class IDEExtension(ABC):
+    @abstractmethod
+    def install(self, commands: List[str]):
+        pass
+
+    @abstractmethod
+    def install_if_not_found(self, commands: List[str]):
+        pass
+
+    @abstractmethod
+    def print_readme(self, commands: List[str]):
+        pass
diff --git a/cli/dstack/_internal/configurators/extensions/vscode.py b/cli/dstack/_internal/configurators/extensions/vscode.py
new file mode 100644
index 000000000..b5a75cee2
--- /dev/null
+++ b/cli/dstack/_internal/configurators/extensions/vscode.py
@@ -0,0 +1,74 @@
+import subprocess
+from typing import List, Optional
+
+from dstack._internal.cli.common import console
+from dstack._internal.configurators.extensions import IDEExtension
+
+
+class VSCodeDesktop(IDEExtension):
+    def __init__(
+        self, extensions: List[str], version: Optional[str] = None, run_name: Optional[str] = None
+    ):
+        self.extensions = extensions
+        if version is None:
+            version = self.detect_code_version()
+            if version is None:
+                console.print(
+                    "[grey58]Unable to detect the VS Code version and pre-install extensions. "
+                    "Fix by opening [sea_green3]Command Palette[/sea_green3], executing [sea_green3]Shell Command: "
+                    "Install 'code' command in PATH[/sea_green3], and restarting terminal.[/]\n"
+                )
+        self.version = version
+        self.run_name = run_name
+
+    @classmethod
+    def detect_code_version(cls, exe: str = "code") -> Optional[str]:
+        try:
+            run = subprocess.run([exe, "--version"], capture_output=True)
+        except FileNotFoundError:
+            return None
+        if run.returncode == 0:
+            return run.stdout.decode().split("\n")[1].strip()
+        return None
+
+    def install(self, commands: List[str]):
+        if self.version is None:
+            return
+        url = (
+            f"https://update.code.visualstudio.com/commit:{self.version}/server-linux-$arch/stable"
+        )
+        archive = "vscode-server-linux-$arch.tar.gz"
+        target = f'~/.vscode-server/bin/"{self.version}"'
+        commands.extend(
+            [
+                f'if [ $(uname -m) = "aarch64" ]; then arch="arm64"; else arch="x64"; fi',
+                f"mkdir -p /tmp",
+                f'wget -q --show-progress "{url}" -O "/tmp/{archive}"',
+                f"mkdir -vp {target}",
+                f'tar --no-same-owner -xz --strip-components=1 -C {target} -f "/tmp/{archive}"',
+                f'rm "/tmp/{archive}"',
+            ]
+        )
+        if self.extensions:
+            extensions = " ".join(f'--install-extension "{name}"' for name in self.extensions)
+            commands.append(f'PATH="$PATH":{target}/bin code-server {extensions}')
+
+    def install_if_not_found(self, commands: List[str]):
+        if self.version is None:
+            return
+        install_commands = []
+        self.install(install_commands)
+        install_commands = " && ".join(install_commands)
+        commands.append(
+            f'if [ ! -d ~/.vscode-server/bin/"{self.version}" ]; then {install_commands}; fi'
+        )
+
+    def print_readme(self, commands: List[str]):
+        commands.extend(
+            [
+                f"echo To open in VS Code Desktop, use link below:",
+                f"echo ''",
+                f"echo '  vscode://vscode-remote/ssh-remote+{self.run_name}/workflow'",
+                f"echo ''",
+            ]
+        )
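
detect_code_version leans on the fact that the local `code` CLI prints three lines for --version: the release tag, the commit hash, and the CPU architecture. The parser keeps line index 1 (the commit) because the update.code.visualstudio.com download URL above is keyed by commit, not by release tag. A sketch of the call (the printed value is illustrative):

    from dstack._internal.configurators.extensions.vscode import VSCodeDesktop

    # Returns the second output line of `code --version`, i.e. the commit hash,
    # or None when the `code` CLI is not on PATH.
    print(VSCodeDesktop.detect_code_version())
    # e.g. 660393deaaa6d1996740ff4880f1bad43768c814
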
From 70175ea54b42231cfb33852150d7d5dec4d7779a Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Mon, 10 Jul 2023 17:18:30 +0400
Subject: [PATCH 11/26] Add ssh support to tasks configuration

---
 .../configurators/dev_environment.py         | 14 ++-------
 cli/dstack/_internal/configurators/task.py   | 30 +++++++++++++++++--
 2 files changed, 29 insertions(+), 15 deletions(-)

diff --git a/cli/dstack/_internal/configurators/dev_environment.py b/cli/dstack/_internal/configurators/dev_environment.py
index cb5505a59..c310263ac 100644
--- a/cli/dstack/_internal/configurators/dev_environment.py
+++ b/cli/dstack/_internal/configurators/dev_environment.py
@@ -7,7 +7,6 @@
 from dstack._internal.configurators.extensions.ssh import SSHd
 from dstack._internal.configurators.extensions.vscode import VSCodeDesktop
 from dstack._internal.core.configuration import DevEnvironmentConfiguration
-from dstack._internal.core.profile import Profile
 from dstack._internal.core.repo import Repo
 from dstack._internal.providers.ports import get_map_to_port
 
@@ -17,17 +16,8 @@
 
 class DevEnvironmentConfigurator(JobConfigurator):
     conf: DevEnvironmentConfiguration
-
-    def __init__(
-        self,
-        working_dir: str,
-        configuration_path: str,
-        configuration: DevEnvironmentConfiguration,
-        profile: Profile,
-    ):
-        super().__init__(working_dir, configuration_path, configuration, profile)
-        self.sshd: Optional[SSHd] = None
-        self.ide: Optional[IDEExtension] = None
+    sshd: Optional[SSHd]
+    ide: Optional[IDEExtension]
 
     def get_jobs(
         self, repo: Repo, run_name: str, repo_code_filename: str, ssh_key_pub: str
diff --git a/cli/dstack/_internal/configurators/task.py b/cli/dstack/_internal/configurators/task.py
index 736300973..428baddae 100644
--- a/cli/dstack/_internal/configurators/task.py
+++ b/cli/dstack/_internal/configurators/task.py
@@ -1,15 +1,28 @@
-from typing import List
+from typing import List, Optional
 
-from dstack._internal.configurators import JobConfigurator
+from dstack._internal.configurators import JobConfigurator, validate_local_path
+from dstack._internal.configurators.extensions.ssh import SSHd
 from dstack._internal.core import job as job
 from dstack._internal.core.configuration import TaskConfiguration
+from dstack._internal.core.repo import Repo
+from dstack._internal.providers.ports import get_map_to_port
 
 
 class TaskConfigurator(JobConfigurator):
     conf: TaskConfiguration
+    sshd: Optional[SSHd]
+
+    def get_jobs(
+        self, repo: Repo, run_name: str, repo_code_filename: str, ssh_key_pub: str
+    ) -> List[job.Job]:
+        self.sshd = SSHd(ssh_key_pub)
+        self.sshd.map_to_port = get_map_to_port(self.ports(), self.sshd.port)
+        return super().get_jobs(repo, run_name, repo_code_filename, ssh_key_pub)
 
     def commands(self) -> List[str]:
         commands = []
+        if self.conf.image is None:
+            self.sshd.start(commands)
         commands.extend(self.conf.commands)
         return commands
 
@@ -19,8 +32,19 @@ def optional_build_commands(self) -> List[str]:
 
     def artifact_specs(self) -> List[job.ArtifactSpec]:
         specs = []
         for a in self.conf.artifacts:
-            specs.append(job.ArtifactSpec(artifact_path=a.path, mount=a.mount))
+            specs.append(
+                job.ArtifactSpec(
+                    artifact_path=validate_local_path(a.path, self.home_dir(), self.working_dir),
+                    mount=a.mount,
+                )
+            )
         return specs
 
     def dep_specs(self) -> List[job.DepSpec]:
         return []  # not available yet
+
+    def app_specs(self) -> List[job.AppSpec]:
+        specs = super().app_specs()
+        if self.conf.image is None:
+            self.sshd.add_app(specs)
+        return specs

From 8bc8141c325cd83e62662a416467018a6600b316 Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Mon, 10 Jul 2023 18:14:59 +0400
Subject: [PATCH 12/26] Replace literals with enums

---
 cli/dstack/_internal/backend/base/build.py   |  8 ++++----
 .../_internal/cli/commands/run/__init__.py   |  1 -
 .../_internal/configurators/__init__.py      | 11 +++++-----
 cli/dstack/_internal/core/build.py           |  7 +++++++
 cli/dstack/_internal/core/configuration.py   | 16 +++++++++++----
 cli/dstack/_internal/core/job.py             | 20 +++++--------------
 6 files changed, 34 insertions(+), 29 deletions(-)

diff --git a/cli/dstack/_internal/backend/base/build.py b/cli/dstack/_internal/backend/base/build.py
index c3821badf..94669584a 100644
--- a/cli/dstack/_internal/backend/base/build.py
+++ b/cli/dstack/_internal/backend/base/build.py
@@ -5,7 +5,7 @@
 import cpuinfo
 
 from dstack._internal.backend.base.storage import Storage
-from dstack._internal.core.build import BuildNotFoundError, BuildPlan, DockerPlatform
+from dstack._internal.core.build import BuildNotFoundError, BuildPlan, BuildPolicy, DockerPlatform
 from dstack._internal.core.job import Job
 from dstack._internal.utils.escape import escape_head
 
@@ -13,7 +13,7 @@ def predict_build_plan(
     storage: Storage, job: Job, platform: Optional[DockerPlatform]
 ) -> BuildPlan:
-    if job.build_policy in ["force-build", "build-only"]:
+    if job.build_policy in [BuildPolicy.FORCE_BUILD, BuildPolicy.BUILD_ONLY]:
         return BuildPlan.yes
 
     if platform is None:
@@ -22,11 +22,11 @@
         return BuildPlan.use
 
     if job.build_commands:
-        if job.build_policy == "use-build":
+        if job.build_policy == BuildPolicy.USE_BUILD:
             raise BuildNotFoundError("Build not found. Run `dstack build` or add `--build` flag")
         return BuildPlan.yes
 
-    if job.optional_build_commands and job.build_policy == "build":
+    if job.optional_build_commands and job.build_policy == BuildPolicy.BUILD:
         return BuildPlan.yes
 
     return BuildPlan.no
diff --git a/cli/dstack/_internal/cli/commands/run/__init__.py b/cli/dstack/_internal/cli/commands/run/__init__.py
index cd5b4c14e..fede51beb 100644
--- a/cli/dstack/_internal/cli/commands/run/__init__.py
+++ b/cli/dstack/_internal/cli/commands/run/__init__.py
@@ -420,7 +420,6 @@ def _attach_to_container(hub_client: HubClient, run_name: str, ports_lock: Ports
     for run in _poll_run_head(hub_client, run_name, loop_statuses=[JobStatus.BUILDING]):
         pass
     app_ports = ports_lock.release()
-    # TODO replace long delay with starting ssh-server in the beginning
     for delay in range(0, 60 * 10 + 1, POLL_PROVISION_RATE_SECS):  # retry
         time.sleep(POLL_PROVISION_RATE_SECS if delay else 0)  # skip first sleep
         if run_ssh_tunnel(run_name, app_ports):
diff --git a/cli/dstack/_internal/configurators/__init__.py b/cli/dstack/_internal/configurators/__init__.py
index d285b02bb..99c735e14 100644
--- a/cli/dstack/_internal/configurators/__init__.py
+++ b/cli/dstack/_internal/configurators/__init__.py
@@ -11,7 +11,8 @@
 import dstack._internal.core.job as job
 import dstack._internal.providers.ports as ports
 import dstack.version as version
-from dstack._internal.core.configuration import BaseConfiguration
+from dstack._internal.core.build import BuildPolicy
+from dstack._internal.core.configuration import BaseConfiguration, PythonVersion
 from dstack._internal.core.error import DstackError
 from dstack._internal.core.profile import Profile
 from dstack._internal.core.repo import Repo
@@ -32,7 +32,7 @@ def __init__(
         self.working_dir = working_dir
         self.conf = configuration
         self.profile = profile
-        self.build_policy = "use-build"
+        self.build_policy = BuildPolicy.USE_BUILD
         # context
         self.run_name: Optional[str] = None
         self.ssh_key_pub: Optional[str] = None
@@ -60,7 +61,7 @@
         retry_group.add_argument("--retry-limit", type=str)
 
         build_policy = parser.add_mutually_exclusive_group()
-        for value in job.BuildPolicy:
+        for value in BuildPolicy:
             build_policy.add_argument(
                 f"--{value}", action="store_const", dest="build_policy", const=value
             )
@@ -211,9 +212,9 @@ def app_specs(self) -> List[job.AppSpec]:
 
     def python(self) -> str:
         if self.conf.python is not None:
-            return self.conf.python
+            return self.conf.python.value
         version_info = sys.version_info
-        return f"{version_info.major}.{version_info.minor}"  # todo check if is in supported
+        return PythonVersion(f"{version_info.major}.{version_info.minor}").value
 
     def ports(self) -> Dict[int, ports.PortMapping]:
         mapping = [ports.PortMapping(p) for p in self.conf.ports]
diff --git a/cli/dstack/_internal/core/build.py b/cli/dstack/_internal/core/build.py
index ff712fbad..406529b0d 100644
--- a/cli/dstack/_internal/core/build.py
+++ b/cli/dstack/_internal/core/build.py
@@ -3,6 +3,13 @@
 from dstack._internal.core.error import DstackError
 
 
+class BuildPolicy(str, Enum):
+    USE_BUILD = "use-build"
+    BUILD = "build"
+    FORCE_BUILD = "force-build"
+    BUILD_ONLY = "build-only"
+
+
 class DockerPlatform(str, Enum):
     amd64 = "amd64"
     arm64 = "arm64"
diff --git a/cli/dstack/_internal/core/configuration.py b/cli/dstack/_internal/core/configuration.py
index d1a97dbed..2ee05d115 100644
--- a/cli/dstack/_internal/core/configuration.py
+++ b/cli/dstack/_internal/core/configuration.py
@@ -1,10 +1,16 @@
+from enum import Enum
 from typing import Dict, List, Optional, Union
 
 from pydantic import BaseModel, Extra, Field, validator
 from typing_extensions import Annotated, Literal
 
-# todo use Enum
-PythonVersions = Literal["3.7", "3.8", "3.9", "3.10", "3.11"]
+
+class PythonVersion(str, Enum):
+    PY37 = "3.7"
+    PY38 = "3.8"
+    PY39 = "3.9"
+    PY310 = "3.10"
+    PY311 = "3.11"
 
 
 class ForbidExtra(BaseModel):
@@ -28,20 +34,22 @@ class BaseConfiguration(ForbidExtra):
     entrypoint: Optional[str]
     home_dir: str = "/root"
     registry_auth: Optional[RegistryAuth]
-    python: Optional[PythonVersions]
+    python: Optional[PythonVersion]
     ports: List[Union[str, int]] = []
     env: Dict[str, str] = {}
     build: List[str] = []
     cache: List[str] = []
 
     @validator("python", pre=True, always=True)
-    def convert_python(cls, v, values) -> Optional[str]:
+    def convert_python(cls, v, values) -> Optional[PythonVersion]:
         if v is not None and values.get("image"):
             raise KeyError("`image` and `python` are mutually exclusive fields")
         if isinstance(v, float):
             v = str(v)
         if v == "3.1":
             v = "3.10"
+        if isinstance(v, str):
+            return PythonVersion(v)
         return v
diff --git a/cli/dstack/_internal/core/job.py b/cli/dstack/_internal/core/job.py
index 8571becd4..5a594c8bf 100644
--- a/cli/dstack/_internal/core/job.py
+++ b/cli/dstack/_internal/core/job.py
@@ -2,10 +2,11 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional, Union
 
-from pydantic import BaseModel, Field, root_validator, validator
+from pydantic import BaseModel, Field, root_validator
 
 from dstack._internal.core.app import AppSpec
 from dstack._internal.core.artifact import ArtifactSpec
+from dstack._internal.core.build import BuildPolicy
 from dstack._internal.core.cache import CacheSpec
 from dstack._internal.core.dependents import DepSpec
 from dstack._internal.core.repo import (
@@ -18,9 +19,6 @@
     RepoRef,
 )
 
-# todo use Enum
-BuildPolicy = ["use-build", "build", "force-build", "build-only"]
-
 
 class GpusRequirements(BaseModel):
     count: Optional[int] = None
@@ -208,7 +206,7 @@ class Job(JobHead):
     location: Optional[str]
     tag_name: Optional[str]
     ssh_key_pub: Optional[str]
-    build_policy: Optional[str]
+    build_policy: BuildPolicy = BuildPolicy.USE_BUILD
     build_commands: Optional[List[str]]
    optional_build_commands: Optional[List[str]]
     run_env: Optional[Dict[str, str]]  # deprecated
@@ -228,14 +226,6 @@ def preprocess_data(cls, data):
             )
         return data
 
-    @validator("build_policy")
-    def default_build_policy(cls, v: Optional[str]) -> str:
-        if not v:
-            return BuildPolicy[0]
-        if v not in BuildPolicy:
-            raise KeyError(f"Unknown build policy: {v}")
-        return v
-
     def get_instance_spot_type(self) -> str:
         if self.requirements and self.requirements.spot:
             return "spot"
@@ -311,7 +301,7 @@ def serialize(self) -> dict:
             "ssh_key_pub": self.ssh_key_pub or "",
             "repo_code_filename": self.repo_code_filename,
             "instance_type": self.instance_type,
-            "build_policy": self.build_policy,
+            "build_policy": self.build_policy.value,
             "build_commands": self.build_commands or [],
             "optional_build_commands": self.optional_build_commands or [],
             "run_env": self.run_env or {},
@@ -447,7 +437,7 @@ def unserialize(job_data: dict):
         tag_name=job_data.get("tag_name") or None,
         ssh_key_pub=job_data.get("ssh_key_pub") or None,
         instance_type=job_data.get("instance_type") or None,
-        build_policy=job_data.get("build_policy") or None,
+        build_policy=job_data.get("build_policy") or BuildPolicy.USE_BUILD,
         build_commands=job_data.get("build_commands") or None,
         optional_build_commands=job_data.get("optional_build_commands") or None,
         run_env=job_data.get("run_env") or None,
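
Because BuildPolicy mixes in str, existing serialized jobs keep working: a member compares equal to its old literal and round-trips through .value. A small sketch of the behavior the serialize/unserialize changes above depend on:

    from dstack._internal.core.build import BuildPolicy

    assert BuildPolicy("force-build") is BuildPolicy.FORCE_BUILD  # parse from the old literal
    assert BuildPolicy.USE_BUILD == "use-build"                   # str mixin keeps comparisons
    assert BuildPolicy.USE_BUILD.value == "use-build"             # what serialize() now stores
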
From e2ea26f9284253b8d0ac1fe19b51d5d538473051 Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Mon, 10 Jul 2023 18:18:23 +0400
Subject: [PATCH 13/26] Pin pydantic v1

---
 cli/requirements.txt | 2 +-
 setup.py             | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/cli/requirements.txt b/cli/requirements.txt
index 9481ba633..8e6d709d3 100644
--- a/cli/requirements.txt
+++ b/cli/requirements.txt
@@ -11,7 +11,7 @@ rich-argparse
 fastapi
 starlette
 uvicorn
-pydantic
+pydantic==1.10.10
 sqlalchemy[asyncio]
 py-cpuinfo
 websocket-client
diff --git a/setup.py b/setup.py
index 082538342..f9b9b23ee 100644
--- a/setup.py
+++ b/setup.py
@@ -41,7 +41,7 @@ def get_long_description():
         "fastapi",
         "starlette>=0.26.0",
         "uvicorn",
-        "pydantic",
+        "pydantic==1.10.10",
         "sqlalchemy[asyncio]>=2.0.0",
         "websocket-client",
         "cursor",

From da99c0a721a423282f1323ddddc333f913067f14 Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Mon, 10 Jul 2023 19:13:13 +0400
Subject: [PATCH 14/26] Make pydantic version flexible

---
 cli/requirements.txt | 2 +-
 setup.py             | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/cli/requirements.txt b/cli/requirements.txt
index 8e6d709d3..25648b967 100644
--- a/cli/requirements.txt
+++ b/cli/requirements.txt
@@ -11,7 +11,7 @@ rich-argparse
 fastapi
 starlette
 uvicorn
-pydantic==1.10.10
+pydantic<=1.10.10
 sqlalchemy[asyncio]
 py-cpuinfo
 websocket-client
diff --git a/setup.py b/setup.py
index f9b9b23ee..83cb7d73e 100644
--- a/setup.py
+++ b/setup.py
@@ -41,7 +41,7 @@ def get_long_description():
         "fastapi",
         "starlette>=0.26.0",
         "uvicorn",
-        "pydantic==1.10.10",
+        "pydantic<=1.10.10",
         "sqlalchemy[asyncio]>=2.0.0",
         "websocket-client",
         "cursor",

From e9e0dfba3cf87e1c8d5e9213dcb23b5fb97ffdb7 Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Mon, 10 Jul 2023 19:41:51 +0400
Subject: [PATCH 15/26] Fix entrypoint tests

---
 cli/tests/providers/docker/test_entrypoint.py | 89 +++++--------------
 1 file changed, 24 insertions(+), 65 deletions(-)

diff --git a/cli/tests/providers/docker/test_entrypoint.py b/cli/tests/providers/docker/test_entrypoint.py
index 5bb1b87a5..902e3649f 100644
--- a/cli/tests/providers/docker/test_entrypoint.py
+++ b/cli/tests/providers/docker/test_entrypoint.py
@@ -1,83 +1,42 @@
 import unittest
-from argparse import Namespace
 from typing import List, Optional
-from unittest import mock
 
+from dstack._internal.configurators.task import TaskConfiguration, TaskConfigurator
+from dstack._internal.core.job import Job
+from dstack._internal.core.profile import Profile
 from dstack._internal.core.repo import RemoteRepo
-from dstack._internal.providers.docker.main import DockerProvider
 
 
-def create_provider_data(
-    commands: Optional[List[str]] = None, entrypoint: Optional[str] = None
-) -> dict:
-    return {
-        "image": "ubuntu:20.04",
-        "commands": commands,
-        "entrypoint": entrypoint,
-        "configuration_type": "task",
-    }
-
-
-args = Namespace(args=[], unknown=[], detach=True)
+def configure_job(commands: List[str], entrypoint: Optional[str]) -> Job:
+    conf = TaskConfiguration(
+        image="ubuntu:20.04",
+        commands=commands,
+        entrypoint=entrypoint,
+    )
+    repo = RemoteRepo(repo_url="https://github.com/dstackai/dstack-playground.git")
+    configurator = TaskConfigurator(".", ".dstack.yaml", conf, Profile(name="default"))
+    return configurator.get_jobs(repo, "run-name-1", "code.tar", "key.pub")[0]
 
 
 class TestEntrypoint(unittest.TestCase):
-    def setUp(self) -> None:
-        self.hub_client = mock.Mock()
-        self.hub_client.configure_mock(
-            repo=RemoteRepo(repo_url="https://github.com/dstackai/dstack-playground.git")
-        )
-
     def test_no_commands(self):
-        provider = DockerProvider()
-        provider.load(
-            self.hub_client, args, "dummy-workflow", create_provider_data(), "dummy-run-1"
-        )
-        for job in provider.get_jobs(self.hub_client.repo, "", ""):
-            data = job.serialize()
-            self.assertListEqual(data["commands"], [])
-            self.assertEqual(data["entrypoint"], None)
+        job = configure_job([], None)
+        self.assertListEqual(job.commands, [])
+        self.assertEqual(job.entrypoint, None)
 
     def test_no_entrypoint(self):
         commands = ["echo 123", "whoami"]
-        provider = DockerProvider()
-        provider.load(
-            self.hub_client,
-            args,
-            "dummy-workflow",
-            create_provider_data(commands=commands),
-            "dummy-run-1",
-        )
-        for job in provider.get_jobs(self.hub_client.repo, "", ""):
-            data = job.serialize()
-            self.assertListEqual(data["commands"], commands)
-            self.assertListEqual(data["entrypoint"], ["/bin/sh", "-i", "-c"])
+        job = configure_job(commands, None)
+        self.assertListEqual(job.commands, commands)
+        self.assertListEqual(job.entrypoint, ["/bin/sh", "-i", "-c"])
 
     def test_only_entrypoint(self):
-        provider = DockerProvider()
-        provider.load(
-            self.hub_client,
-            args,
-            "dummy-workflow",
-            create_provider_data(entrypoint="/bin/bash -ic"),
-            "dummy-run-1",
-        )
-        for job in provider.get_jobs(self.hub_client.repo, "", ""):
-            data = job.serialize()
-            self.assertListEqual(data["commands"], [])
-            self.assertListEqual(data["entrypoint"], ["/bin/bash", "-ic"])
+        job = configure_job([], "/bin/bash -ic")
+        self.assertListEqual(job.commands, [])
+        self.assertListEqual(job.entrypoint, ["/bin/bash", "-ic"])
 
     def test_entrypoint_override(self):
         commands = ["echo 123", "whoami"]
-        provider = DockerProvider()
-        provider.load(
-            self.hub_client,
-            args,
-            "dummy-workflow",
-            create_provider_data(commands=commands, entrypoint="/bin/bash -ic"),
-            "dummy-run-1",
-        )
-        for job in provider.get_jobs(self.hub_client.repo, "", ""):
-            data = job.serialize()
-            self.assertListEqual(data["commands"], commands)
-            self.assertListEqual(data["entrypoint"], ["/bin/bash", "-ic"])
+        job = configure_job(commands, "/bin/bash -ic")
+        self.assertListEqual(job.commands, commands)
+        self.assertListEqual(job.entrypoint, ["/bin/bash", "-ic"])
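
Patches 13 and 14 most likely exist because pydantic 2.0, with its breaking API, had just been released: the models in this series rely on v1-only constructs (parse_obj, @validator, Extra.forbid), so the dependency is first pinned exactly and then relaxed to a ceiling. A sanity check one could run in a local environment (the assertion is illustrative, not part of the patch):

    import pydantic

    # The codebase uses pydantic v1 APIs (parse_obj, @validator, Extra.forbid),
    # hence the `<=1.10.10` constraint above.
    assert pydantic.VERSION.startswith("1."), pydantic.VERSION
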
From 6286009b0585e0ab4b99c0ffdae8cf357a1e24a5 Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Mon, 10 Jul 2023 19:57:32 +0400
Subject: [PATCH 16/26] Move ports.py

---
 cli/dstack/_internal/cli/commands/run/ssh_tunnel.py        | 2 +-
 cli/dstack/_internal/configurators/__init__.py             | 2 +-
 cli/dstack/_internal/configurators/dev_environment.py      | 2 +-
 cli/dstack/_internal/{providers => configurators}/ports.py | 0
 cli/dstack/_internal/configurators/task.py                 | 2 +-
 cli/dstack/_internal/providers/__init__.py                 | 2 +-
 cli/dstack/_internal/providers/bash/main.py                | 2 +-
 cli/dstack/_internal/providers/code/main.py                | 2 +-
 cli/dstack/_internal/providers/docker/main.py              | 2 +-
 cli/dstack/_internal/providers/lab/main.py                 | 2 +-
 cli/dstack/_internal/providers/notebook/main.py            | 2 +-
 cli/dstack/_internal/providers/ssh/main.py                 | 2 +-
 12 files changed, 11 insertions(+), 11 deletions(-)
 rename cli/dstack/_internal/{providers => configurators}/ports.py (100%)

diff --git a/cli/dstack/_internal/cli/commands/run/ssh_tunnel.py b/cli/dstack/_internal/cli/commands/run/ssh_tunnel.py
index a696be459..72454b6af 100644
--- a/cli/dstack/_internal/cli/commands/run/ssh_tunnel.py
+++ b/cli/dstack/_internal/cli/commands/run/ssh_tunnel.py
@@ -3,7 +3,7 @@
 import subprocess
 from typing import Dict, List, Optional
 
-from dstack._internal.providers.ports import PortUsedError
+from dstack._internal.configurators.ports import PortUsedError
 
 
 class PortsLock:
diff --git a/cli/dstack/_internal/configurators/__init__.py b/cli/dstack/_internal/configurators/__init__.py
index 99c735e14..ea4f93d5a 100644
--- a/cli/dstack/_internal/configurators/__init__.py
+++ b/cli/dstack/_internal/configurators/__init__.py
@@ -8,8 +8,8 @@
 
 from rich_argparse import RichHelpFormatter
 
+import dstack._internal.configurators.ports as ports
 import dstack._internal.core.job as job
-import dstack._internal.providers.ports as ports
 import dstack.version as version
 from dstack._internal.core.build import BuildPolicy
 from dstack._internal.core.configuration import BaseConfiguration, PythonVersion
diff --git a/cli/dstack/_internal/configurators/dev_environment.py b/cli/dstack/_internal/configurators/dev_environment.py
index c310263ac..2725ed4da 100644
--- a/cli/dstack/_internal/configurators/dev_environment.py
+++ b/cli/dstack/_internal/configurators/dev_environment.py
@@ -6,9 +6,9 @@
 from dstack._internal.configurators.extensions.shell import require
 from dstack._internal.configurators.extensions.ssh import SSHd
 from dstack._internal.configurators.extensions.vscode import VSCodeDesktop
+from dstack._internal.configurators.ports import get_map_to_port
 from dstack._internal.core.configuration import DevEnvironmentConfiguration
 from dstack._internal.core.repo import Repo
-from dstack._internal.providers.ports import get_map_to_port
 
 require_sshd = require(["sshd"])
 install_ipykernel = f'(pip install -q --no-cache-dir ipykernel 2> /dev/null) || echo "no pip, ipykernel was not installed"'
diff --git a/cli/dstack/_internal/providers/ports.py b/cli/dstack/_internal/configurators/ports.py
similarity index 100%
rename from cli/dstack/_internal/providers/ports.py
rename to cli/dstack/_internal/configurators/ports.py
diff --git a/cli/dstack/_internal/configurators/task.py b/cli/dstack/_internal/configurators/task.py
index 428baddae..25c26dc00 100644
--- a/cli/dstack/_internal/configurators/task.py
+++ b/cli/dstack/_internal/configurators/task.py
@@ -2,10 +2,10 @@
 
 from dstack._internal.configurators import JobConfigurator, validate_local_path
 from dstack._internal.configurators.extensions.ssh import SSHd
+from dstack._internal.configurators.ports import get_map_to_port
 from dstack._internal.core import job as job
 from dstack._internal.core.configuration import TaskConfiguration
 from dstack._internal.core.repo import Repo
-from dstack._internal.providers.ports import get_map_to_port
 
 
 class TaskConfigurator(JobConfigurator):
diff --git a/cli/dstack/_internal/providers/__init__.py b/cli/dstack/_internal/providers/__init__.py
index 37d215753..a072372a4 100644
--- a/cli/dstack/_internal/providers/__init__.py
+++ b/cli/dstack/_internal/providers/__init__.py
@@ -9,6 +9,7 @@
 from typing import Any, Dict, List, Optional, Union
 
 import dstack.api.hub as hub
+from dstack._internal.configurators.ports import PortMapping, merge_ports
 from dstack._internal.core.cache import CacheSpec
 from dstack._internal.core.error import RepoNotInitializedError
 from dstack._internal.core.job import (
@@ -25,7 +26,6 @@
     SpotPolicy,
 )
 from dstack._internal.core.repo.base import Repo
-from dstack._internal.providers.ports import PortMapping, merge_ports
 from dstack._internal.utils.common import get_milliseconds_since_epoch, parse_pretty_duration
 from dstack._internal.utils.interpolator import VariablesInterpolator
diff --git a/cli/dstack/_internal/providers/bash/main.py b/cli/dstack/_internal/providers/bash/main.py
index d4066b3a5..72bf585e3 100644
--- a/cli/dstack/_internal/providers/bash/main.py
+++ b/cli/dstack/_internal/providers/bash/main.py
@@ -5,11 +5,11 @@
 
 import dstack.api.hub as hub
 from dstack import version
+from dstack._internal.configurators.ports import filter_reserved_ports, get_map_to_port
 from dstack._internal.core.app import AppSpec
 from dstack._internal.core.job import JobSpec
 from dstack._internal.providers import Provider
 from dstack._internal.providers.extensions import OpenSSHExtension
-from dstack._internal.providers.ports import filter_reserved_ports, get_map_to_port
 
 
 class BashProvider(Provider):
diff --git a/cli/dstack/_internal/providers/code/main.py b/cli/dstack/_internal/providers/code/main.py
index 5b0a95f15..2f43dac99 100644
--- a/cli/dstack/_internal/providers/code/main.py
+++ b/cli/dstack/_internal/providers/code/main.py
@@ -6,11 +6,11 @@
 
 import dstack.api.hub as hub
 from dstack import version
+from dstack._internal.configurators.ports import filter_reserved_ports, get_map_to_port
 from dstack._internal.core.app import AppSpec
 from dstack._internal.core.job import JobSpec
 from dstack._internal.providers import Provider
 from dstack._internal.providers.extensions import OpenSSHExtension
-from dstack._internal.providers.ports import filter_reserved_ports, get_map_to_port
 
 
 class CodeProvider(Provider):
diff --git a/cli/dstack/_internal/providers/docker/main.py b/cli/dstack/_internal/providers/docker/main.py
index 0bbe25875..6704366ba 100644
--- a/cli/dstack/_internal/providers/docker/main.py
+++ b/cli/dstack/_internal/providers/docker/main.py
@@ -4,10 +4,10 @@
 from rich_argparse import RichHelpFormatter
 
 import dstack.api.hub as hub
+from dstack._internal.configurators.ports import filter_reserved_ports
 from dstack._internal.core.app import AppSpec
 from dstack._internal.core.job import JobSpec
 from dstack._internal.providers import Provider
-from dstack._internal.providers.ports import filter_reserved_ports
 
 
 class DockerProvider(Provider):
diff --git a/cli/dstack/_internal/providers/lab/main.py b/cli/dstack/_internal/providers/lab/main.py
index 9eaad7d4d..878d9b768 100644
--- a/cli/dstack/_internal/providers/lab/main.py
+++ b/cli/dstack/_internal/providers/lab/main.py
@@ -6,11 +6,11 @@
 
 import dstack.api.hub as hub
 from dstack import version
+from dstack._internal.configurators.ports import filter_reserved_ports, get_map_to_port
 from dstack._internal.core.app import AppSpec
 from dstack._internal.core.job import JobSpec
 from dstack._internal.providers import Provider
 from dstack._internal.providers.extensions import OpenSSHExtension
-from dstack._internal.providers.ports import filter_reserved_ports, get_map_to_port
 
 
 class LabProvider(Provider):
diff --git a/cli/dstack/_internal/providers/notebook/main.py b/cli/dstack/_internal/providers/notebook/main.py
index 3a6cf0120..00acdc4d1 100644
--- a/cli/dstack/_internal/providers/notebook/main.py
+++ b/cli/dstack/_internal/providers/notebook/main.py
@@ -6,11 +6,11 @@
 
 import dstack.api.hub as hub
 from dstack import version
+from dstack._internal.configurators.ports import filter_reserved_ports, get_map_to_port
 from dstack._internal.core.app import AppSpec
 from dstack._internal.core.job import JobSpec
 from dstack._internal.providers import Provider
 from dstack._internal.providers.extensions import OpenSSHExtension
-from dstack._internal.providers.ports import filter_reserved_ports, get_map_to_port
 
 
 class NotebookProvider(Provider):
diff --git a/cli/dstack/_internal/providers/ssh/main.py b/cli/dstack/_internal/providers/ssh/main.py
index 94e613ccf..83fdb02a3 100644
--- a/cli/dstack/_internal/providers/ssh/main.py
+++ b/cli/dstack/_internal/providers/ssh/main.py
@@ -5,11 +5,11 @@
 
 import dstack.api.hub as hub
 from dstack import version
+from dstack._internal.configurators.ports import filter_reserved_ports, get_map_to_port
 from dstack._internal.core.app import AppSpec
 from dstack._internal.core.job import JobSpec
 from dstack._internal.providers import Provider
 from dstack._internal.providers.extensions import OpenSSHExtension
-from dstack._internal.providers.ports import filter_reserved_ports, get_map_to_port
 
 
 class SSHProvider(Provider):
validate_local_path("~/.cache", "/root", ".")) + + def test_missing_home(self): + with self.assertRaises(HomeDirUnsetError): + validate_local_path("~/.cache", None, ".") diff --git a/cli/tests/providers/docker/__init__.py b/cli/tests/providers/docker/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/cli/tests/providers/test_local_path.py b/cli/tests/providers/test_local_path.py deleted file mode 100644 index c17fab794..000000000 --- a/cli/tests/providers/test_local_path.py +++ /dev/null @@ -1,44 +0,0 @@ -import unittest - -from dstack._internal.providers import Provider - - -def make_provider(**kwargs): - p = Provider("test") - for k, v in kwargs.items(): - setattr(p, k, v) - return p - - -class TestLocalPath(unittest.TestCase): - def test_absolute(self): - p = make_provider() - path = "/root/.cache" - self.assertEqual(path, p._validate_local_path(path)) - - def test_relative(self): - p = make_provider() - path = ".cache/pip" - self.assertEqual(path, p._validate_local_path(path)) - - def test_relative_dot(self): - p = make_provider() - self.assertEqual("cache/pip", p._validate_local_path("./cache/pip")) - - def test_relative_dot_twice(self): - p = make_provider() - self.assertEqual("cache/pip", p._validate_local_path("././cache/pip")) - - def test_home(self): - home = "/root" - p = make_provider(home_dir=home) - self.assertEqual(home, p._validate_local_path("~")) - - def test_startswith_home(self): - p = make_provider(home_dir="/root") - self.assertEqual("/root/.cache", p._validate_local_path("~/.cache")) - - def test_missing_home(self): - p = make_provider() - with self.assertRaises(KeyError): - p._validate_local_path("~/.cache") From 237d4af65edecadaac22f6eba69c3cf73fb2c3ba Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Mon, 10 Jul 2023 20:05:57 +0400 Subject: [PATCH 18/26] Drop providers --- cli/dstack/_internal/providers/__init__.py | 556 ------------------ .../_internal/providers/bash/__init__.py | 0 cli/dstack/_internal/providers/bash/main.py | 108 ---- .../_internal/providers/code/__init__.py | 0 cli/dstack/_internal/providers/code/main.py | 156 ----- .../_internal/providers/docker/__init__.py | 0 cli/dstack/_internal/providers/docker/main.py | 97 --- cli/dstack/_internal/providers/extensions.py | 120 ---- .../_internal/providers/lab/__init__.py | 0 cli/dstack/_internal/providers/lab/main.py | 145 ----- .../_internal/providers/notebook/__init__.py | 0 .../_internal/providers/notebook/main.py | 142 ----- .../_internal/providers/ssh/__init__.py | 0 cli/dstack/_internal/providers/ssh/main.py | 122 ---- 14 files changed, 1446 deletions(-) delete mode 100644 cli/dstack/_internal/providers/__init__.py delete mode 100644 cli/dstack/_internal/providers/bash/__init__.py delete mode 100644 cli/dstack/_internal/providers/bash/main.py delete mode 100644 cli/dstack/_internal/providers/code/__init__.py delete mode 100644 cli/dstack/_internal/providers/code/main.py delete mode 100644 cli/dstack/_internal/providers/docker/__init__.py delete mode 100644 cli/dstack/_internal/providers/docker/main.py delete mode 100644 cli/dstack/_internal/providers/extensions.py delete mode 100644 cli/dstack/_internal/providers/lab/__init__.py delete mode 100644 cli/dstack/_internal/providers/lab/main.py delete mode 100644 cli/dstack/_internal/providers/notebook/__init__.py delete mode 100644 cli/dstack/_internal/providers/notebook/main.py delete mode 100644 cli/dstack/_internal/providers/ssh/__init__.py delete mode 100644 cli/dstack/_internal/providers/ssh/main.py diff --git 
a/cli/dstack/_internal/providers/__init__.py b/cli/dstack/_internal/providers/__init__.py deleted file mode 100644 index a072372a4..000000000 --- a/cli/dstack/_internal/providers/__init__.py +++ /dev/null @@ -1,556 +0,0 @@ -import argparse -import importlib -import shlex -import sys -import uuid -from abc import abstractmethod -from argparse import ArgumentParser, Namespace -from pkgutil import iter_modules -from typing import Any, Dict, List, Optional, Union - -import dstack.api.hub as hub -from dstack._internal.configurators.ports import PortMapping, merge_ports -from dstack._internal.core.cache import CacheSpec -from dstack._internal.core.error import RepoNotInitializedError -from dstack._internal.core.job import ( - ArtifactSpec, - BuildPolicy, - ConfigurationType, - DepSpec, - GpusRequirements, - Job, - JobSpec, - JobStatus, - Requirements, - RetryPolicy, - SpotPolicy, -) -from dstack._internal.core.repo.base import Repo -from dstack._internal.utils.common import get_milliseconds_since_epoch, parse_pretty_duration -from dstack._internal.utils.interpolator import VariablesInterpolator - -DEFAULT_CPU = 2 -DEFAULT_MEM = "8GB" - -DEFAULT_RETRY_LIMIT = 3600 - - -class Provider: - def __init__(self, provider_name: str): - self.provider_name: str = provider_name - self.provider_data: Optional[Dict[str, Any]] = None - self.provider_args: Optional[List[str]] = None - self.workflow_name: Optional[str] = None - self.run_as_provider: Optional[bool] = None - self.run_name: Optional[str] = None - self.dep_specs: Optional[List[DepSpec]] = None - self.cache_specs: List[CacheSpec] = [] - self.ssh_key_pub: Optional[str] = None - self.openssh_server: bool = False - self.loaded = False - self.home_dir: Optional[str] = None - self.ports: Dict[int, PortMapping] = {} - self.build_policy: Optional[str] = None - self.build_commands: List[str] = [] - self.optional_build_commands: List[str] = [] - self.commands: List[str] = [] - - # TODO: This is a dirty hack - def _safe_python_version(self, name: str): - python_version: str - v = self.provider_data.get(name) - if isinstance(v, str): - python_version = v - elif v == 3.1: - python_version = "3.10" - elif v: - python_version = str(v) - else: - version_info = sys.version_info - python_version = f"{version_info.major}.{version_info.minor}" - supported_python_versions = ["3.7", "3.8", "3.9", "3.10", "3.11"] - if python_version not in supported_python_versions: - sys.exit( - f"Python version `{python_version}` is not supported. " - f"Supported versions: {str(supported_python_versions)}." 
- ) - return python_version - - def _inject_context(self): - args = [] - for arg in self.provider_data.get("run_args", []): - if " " in arg: - arg = '"%s"' % arg.replace('"', '\\"') - args.append(arg) - - self.provider_data = self._inject_context_recursively( - VariablesInterpolator( - {"run": {"name": self.run_name, "args": " ".join(args)}}, skip=["secrets"] - ), - self.provider_data, - ) - - @staticmethod - def _inject_context_recursively(interpolator: VariablesInterpolator, obj: Any) -> Any: - if isinstance(obj, str): - return interpolator.interpolate(obj) - elif isinstance(obj, dict): - d = {} - for k in obj: - d[k] = Provider._inject_context_recursively(interpolator, obj[k]) - return d - elif isinstance(obj, list): - return [Provider._inject_context_recursively(interpolator, item) for item in obj] - else: - return obj - - def load( - self, - hub_client: "hub.HubClient", - args: Optional[Namespace], - workflow_name: Optional[str], - provider_data: Dict[str, Any], - run_name: str, - ssh_key_pub: Optional[str] = None, - ): - if getattr(args, "help", False): - self.help(workflow_name) - exit() # todo: find a better place for this - - self.provider_args = [] if args is None else args.args + args.unknown - self.workflow_name = workflow_name - self.provider_data = provider_data - self.run_as_provider = not workflow_name - self.configuration_type = ConfigurationType(self.provider_data["configuration_type"]) - self.run_name = run_name - self.ssh_key_pub = ssh_key_pub - self.openssh_server = self.provider_data.get("ssh", self.openssh_server) - self.build_policy = self.provider_data.get("build_policy") - - self.parse_args() - self.ports = self.provider_data.get("ports") or {} - if self.ssh_key_pub is None: - if self.openssh_server or ( - hub_client.get_project_backend_type() != "local" and not args.detach - ): - raise RepoNotInitializedError( - "No valid SSH identity", project_name=hub_client.project - ) - - self._inject_context() - self.build_commands = self._get_list_data("build") or [] - self.optional_build_commands = self._get_list_data("optional_build") or [] - self.commands = self._get_list_data("commands") or [] - self.dep_specs = self._dep_specs(hub_client) - self.cache_specs = self._cache_specs() - self.loaded = True - - @abstractmethod - def _create_parser(self, workflow_name: Optional[str]) -> Optional[ArgumentParser]: - return None - - def help(self, workflow_name: Optional[str]): - parser = self._create_parser(workflow_name) - if parser: - parser.print_help() - - @abstractmethod - def create_job_specs(self) -> List[JobSpec]: - pass - - @staticmethod - def _add_base_args(parser: ArgumentParser): - parser.add_argument("-r", "--requirements", metavar="PATH", type=str) - parser.add_argument("-e", "--env", action="append") - parser.add_argument("-a", "--artifact", metavar="PATH", dest="artifacts", action="append") - parser.add_argument("--dep", metavar="(:TAG | WORKFLOW)", dest="deps", action="append") - parser.add_argument("-w", "--working-dir", metavar="PATH", type=str) - spot_group = parser.add_mutually_exclusive_group() - spot_group.add_argument("-i", "--interruptible", action="store_true") - spot_group.add_argument("--spot", action="store_true") - spot_group.add_argument("--on-demand", action="store_true") - spot_group.add_argument("--spot-auto", action="store_true") - spot_group.add_argument("--spot-policy", type=str) - - retry_group = parser.add_mutually_exclusive_group() - retry_group.add_argument("--retry", action="store_true") - retry_group.add_argument("--no-retry", 
action="store_true") - retry_group.add_argument("--retry-limit", type=str) - - parser.add_argument("--cpu", metavar="NUM", type=int) - parser.add_argument("--memory", metavar="SIZE", type=str) - parser.add_argument("--gpu", metavar="NUM", type=int) - parser.add_argument("--gpu-name", metavar="NAME", type=str) - parser.add_argument("--gpu-memory", metavar="SIZE", type=str) - parser.add_argument("--shm-size", metavar="SIZE", type=str) - parser.add_argument( - "-p", "--port", metavar="PORTS", type=PortMapping, nargs=argparse.ONE_OR_MORE - ) - build_policy = parser.add_mutually_exclusive_group() - for value in BuildPolicy: - build_policy.add_argument( - f"--{value}", action="store_const", dest="build_policy", const=value - ) - - def _parse_base_args(self, args: Namespace, unknown_args): - if args.requirements: - self.provider_data["requirements"] = args.requirements - if args.artifacts: - self.provider_data["artifacts"] = args.artifacts - if args.deps: - self.provider_data["deps"] = args.deps - if args.working_dir: - self.provider_data["working_dir"] = args.working_dir - if args.env: - env = self.provider_data.get("env") or [] - env.extend(args.env) - self.provider_data["env"] = env - - if args.spot_policy: - self.provider_data["spot_policy"] = args.spot_policy - if args.interruptible or args.spot: - self.provider_data["spot_policy"] = SpotPolicy.SPOT.value - if args.on_demand: - self.provider_data["spot_policy"] = SpotPolicy.ONDEMAND.value - if args.spot_auto: - self.provider_data["spot_policy"] = SpotPolicy.AUTO.value - - if args.retry: - self.provider_data["retry_policy"] = {"retry": True} - if args.no_retry: - self.provider_data["retry_policy"] = {"retry": False} - if args.retry_limit: - self.provider_data["retry_policy"] = {"retry": True, "limit": args.retry_limit} - - resources = self.provider_data.get("resources") or {} - self.provider_data["resources"] = resources - if args.cpu: - resources["cpu"] = args.cpu - if args.memory: - resources["memory"] = args.memory - if args.gpu or args.gpu_name or args.gpu_memory: - gpu = ( - self.provider_data["resources"].get("gpu") or {} - if self.provider_data.get("resources") - else {} - ) - if type(gpu) is int: - gpu = {"count": gpu} - resources["gpu"] = gpu - if args.gpu: - gpu["count"] = args.gpu - if args.gpu_memory: - gpu["memory"] = args.gpu_memory - if args.gpu_name: - gpu["name"] = args.gpu_name - if args.shm_size: - resources["shm_size"] = args.shm_size - self.provider_data["ports"] = merge_ports( - [PortMapping(i) for i in self.provider_data.get("ports") or []], args.port or [] - ) - if args.build_policy: - self.build_policy = args.build_policy - if unknown_args: - self.provider_data["run_args"] = unknown_args - - def parse_args(self): - pass - - def get_jobs( - self, - repo: Repo, - configuration_path: Optional[str] = None, - repo_code_filename: Optional[str] = None, - tag_name: Optional[str] = None, - ) -> List[Job]: - if not self.loaded: - raise Exception("The provider is not loaded") - created_at = get_milliseconds_since_epoch() - job_specs = self.create_job_specs() - jobs = [] - for i, job_spec in enumerate(job_specs): - job = Job( - job_id=f"{self.run_name},{self.workflow_name or ''},{i}", - repo_ref=repo.repo_ref, - hub_user_name="", # HUB will fill it later - repo_data=repo.repo_data, - run_name=self.run_name, - workflow_name=self.workflow_name or None, - provider_name=self.provider_name, - configuration_type=self.configuration_type, - configuration_path=configuration_path, - status=JobStatus.SUBMITTED, - created_at=created_at, - 
submitted_at=created_at, - image_name=job_spec.image_name, - registry_auth=job_spec.registry_auth, - commands=job_spec.commands, - entrypoint=job_spec.entrypoint, - env=job_spec.env, - home_dir=self.home_dir, - working_dir=job_spec.working_dir, - artifact_specs=job_spec.artifact_specs, - cache_specs=self.cache_specs, - host_name=None, - spot_policy=self._spot_policy(), - retry_policy=self._retry_policy(), - requirements=job_spec.requirements, - dep_specs=self.dep_specs, - master_job=job_spec.master_job, - app_specs=job_spec.app_specs, - runner_id=uuid.uuid4().hex, - request_id=None, - tag_name=tag_name, - ssh_key_pub=self.ssh_key_pub, - repo_code_filename=repo_code_filename, - build_policy=self.build_policy, - build_commands=job_spec.build_commands, - optional_build_commands=self.optional_build_commands, - run_env=job_spec.run_env, - ) - jobs.append(job) - return jobs - - def _dep_specs(self, hub_client: "hub.HubClient") -> Optional[List[DepSpec]]: - if self.provider_data.get("deps"): - return [self._parse_dep_spec(dep, hub_client) for dep in self.provider_data["deps"]] - else: - return None - - def _validate_local_path(self, path: str) -> str: - if path == "~" or path.startswith("~/"): - if not self.home_dir: - raise KeyError("home_dir is not defined, local path can't start with ~") - home = self.home_dir.rstrip("/") - path = home if path == "~" else f"{home}/{path[len('~/'):]}" - while path.startswith("./"): - path = path[len("./") :] - if not path.startswith("/"): - pass # todo: use self.working_dir - return path - - def _artifact_specs(self) -> Optional[List[ArtifactSpec]]: - artifact_specs = [] - for item in self.provider_data.get("artifacts", []): - if isinstance(item, str): - item = {"artifact_path": item} - else: - item["artifact_path"] = item.pop("path") - item["artifact_path"] = self._validate_local_path(item["artifact_path"]) - artifact_specs.append(ArtifactSpec(**item)) - return artifact_specs or None - - def _cache_specs(self) -> List[CacheSpec]: - cache_specs = [] - for item in self.provider_data.get("cache", []): - if isinstance(item, str): - item = {"path": item} - item["path"] = self._validate_local_path(item["path"]) - cache_specs.append(CacheSpec(**item)) - return cache_specs - - @staticmethod - def _parse_dep_spec(dep: Union[dict, str], hub_client) -> DepSpec: - if isinstance(dep, str): - mount = False - if dep.startswith(":"): - tag_dep = True - dep = dep[1:] - else: - tag_dep = False - else: - mount = dep.get("mount") is True - tag_dep = dep.get("tag") is not None - dep = dep.get("tag") or dep.get("workflow") - t = dep.split("/") - if len(t) == 1: - if tag_dep: - return Provider._tag_dep(hub_client, t[0], mount) - else: - return Provider._workflow_dep(hub_client, t[0], mount) - elif len(t) == 3: - # This doesn't allow to refer to projects from other repos - if tag_dep: - return Provider._tag_dep(hub_client, t[2], mount) - else: - return Provider._workflow_dep(hub_client, t[2], mount) - else: - sys.exit(f"Invalid dep format: {dep}") - - @staticmethod - def _tag_dep(hub_client: "hub.HubClient", tag_name: str, mount: bool) -> DepSpec: - tag_head = hub_client.get_tag_head(tag_name) - if tag_head: - return DepSpec( - repo_ref=hub_client.repo.repo_ref, run_name=tag_head.run_name, mount=mount - ) - else: - sys.exit(f"Cannot find the tag '{tag_name}' in the '{hub_client.repo.repo_id}' repo") - - @staticmethod - def _workflow_dep(hub_client: "hub.HubClient", workflow_name: str, mount: bool) -> DepSpec: - job_heads = sorted( - hub_client.list_job_heads(), - key=lambda j: 
j.submitted_at, - reverse=True, - ) - run_name = next( - iter( - [ - job_head.run_name - for job_head in job_heads - if job_head.workflow_name == workflow_name - and job_head.status == JobStatus.DONE - ] - ), - None, - ) - if run_name: - return DepSpec(repo_ref=hub_client.repo.repo_ref, run_name=run_name, mount=mount) - else: - sys.exit( - f"Cannot find any successful workflow with the name '{workflow_name}' " - f"in the '{hub_client.repo.repo_id}' repo" - ) - - def _env(self) -> Optional[Dict[str, str]]: - if self.provider_data.get("env"): - env = {} - for e in self.provider_data.get("env"): - if "=" in e: - tokens = e.split("=", maxsplit=1) - env[tokens[0]] = tokens[1] - else: - env[e] = "" - return env - else: - return None - - def _get_list_data(self, name: str) -> Optional[List[str]]: - v = self.provider_data.get(name) - if isinstance(v, str): - return v.split("\n") - else: - return v - - def _get_entrypoint(self) -> Optional[List[str]]: - v = self.provider_data.get("entrypoint") - if isinstance(v, str): - return shlex.split(v) - return v - - def _spot_policy(self) -> SpotPolicy: - spot_policy = self.provider_data.get("spot_policy") - if spot_policy is not None: - return SpotPolicy(spot_policy) - if self.configuration_type is ConfigurationType.DEV_ENVIRONMENT: - return SpotPolicy.ONDEMAND - return SpotPolicy.AUTO - - def _retry_policy(self) -> RetryPolicy: - retry_policy = self.provider_data.get("retry_policy") - if retry_policy is None: - return RetryPolicy( - retry=False, - limit=0, - ) - if retry_policy.get("retry") is False: - return RetryPolicy( - retry=False, - limit=None, - ) - if retry_policy.get("limit"): - return RetryPolicy(retry=True, limit=parse_pretty_duration(retry_policy.get("limit"))) - return RetryPolicy( - retry=retry_policy.get("retry"), - limit=DEFAULT_RETRY_LIMIT, - ) - - def _resources(self) -> Requirements: - resources = Requirements() - cpu = self.provider_data["resources"].get("cpu", DEFAULT_CPU) - if not str(cpu).isnumeric(): - sys.exit("resources.cpu should be an integer") - cpu = int(cpu) - if cpu > 0: - resources.cpus = cpu - memory = self.provider_data["resources"].get("memory", DEFAULT_MEM) - resources.memory_mib = _str_to_mib(memory) - gpu = self.provider_data["resources"].get("gpu") - if gpu: - if str(gpu).isnumeric(): - gpu = int(self.provider_data["resources"]["gpu"]) - if gpu > 0: - resources.gpus = GpusRequirements(count=gpu) - else: - gpu_count = 0 - gpu_name = None - gpu_memory = None - if str(gpu.get("count")).isnumeric(): - gpu_count = int(gpu.get("count")) - if gpu.get("name"): - gpu_name = gpu.get("name") - if not gpu_count: - gpu_count = 1 - if gpu.get("memory"): - gpu_memory = _str_to_mib(gpu.get("memory")) - if not gpu_count: - gpu_count = 1 - if gpu_count: - resources.gpus = GpusRequirements( - count=gpu_count, name=gpu_name, memory_mib=gpu_memory - ) - for resource_name in self.provider_data["resources"]: - if resource_name.endswith("/gpu") and len(resource_name) > 4: - if not str(self.provider_data["resources"][resource_name]).isnumeric(): - sys.exit(f"resources.'{resource_name}' should be an integer") - gpu = int(self.provider_data["resources"][resource_name]) - if gpu > 0: - resources.gpus = GpusRequirements(count=gpu, name=resource_name[:-4]) - if self.provider_data["resources"].get("shm_size"): - resources.shm_size_mib = _str_to_mib(self.provider_data["resources"]["shm_size"]) - return resources - - @staticmethod - def _extend_commands_with_env(commands, env): - commands.extend([f"export {e}={env[e] if env.get(e) else ''}" for e in 
env]) - - -def get_provider_names() -> List[str]: - return list( - map( - lambda m: m[1], - filter( - lambda m: m.ispkg and not m[1].startswith("_"), - iter_modules(sys.modules[__name__].__path__), - ), - ) - ) - - -def _str_to_mib(s: str) -> int: - ns = s.replace(" ", "").lower() - if ns.endswith("mib"): - return int(s[:-3]) - elif ns.endswith("gib"): - return int(s[:-3]) * 1024 - elif ns.endswith("mi"): - return int(s[:-2]) - elif ns.endswith("gi"): - return int(s[:-2]) * 1024 - elif ns.endswith("mb"): - return int(int(s[:-2]) * 1000 * 1000 / 1024 / 1024) - elif ns.endswith("gb"): - return int(int(s[:-2]) * (1000 * 1000 * 1000) / 1024 / 1024) - elif ns.endswith("m"): - return int(int(s[:-1]) * 1000 * 1000 / 1024 / 1024) - elif ns.endswith("g"): - return int(int(s[:-1]) * (1000 * 1000 * 1000) / 1024 / 1024) - else: - raise Exception(f"Unknown memory unit: {s}") - - -def load_provider(provider_name) -> Provider: - return importlib.import_module( - f"dstack._internal.providers.{provider_name}.main" - ).__provider__() diff --git a/cli/dstack/_internal/providers/bash/__init__.py b/cli/dstack/_internal/providers/bash/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/cli/dstack/_internal/providers/bash/main.py b/cli/dstack/_internal/providers/bash/main.py deleted file mode 100644 index 72bf585e3..000000000 --- a/cli/dstack/_internal/providers/bash/main.py +++ /dev/null @@ -1,108 +0,0 @@ -from argparse import ArgumentParser, Namespace -from typing import Any, Dict, List, Optional - -from rich_argparse import RichHelpFormatter - -import dstack.api.hub as hub -from dstack import version -from dstack._internal.configurators.ports import filter_reserved_ports, get_map_to_port -from dstack._internal.core.app import AppSpec -from dstack._internal.core.job import JobSpec -from dstack._internal.providers import Provider -from dstack._internal.providers.extensions import OpenSSHExtension - - -class BashProvider(Provider): - def __init__(self): - super().__init__("bash") - self.python = None - self.env = None - self.artifact_specs = None - self.working_dir = None - self.resources = None - self.image_name = None - self.home_dir = "/root" - self.openssh_server = True - - def load( - self, - hub_client: "hub.HubClient", - args: Optional[Namespace], - workflow_name: Optional[str], - provider_data: Dict[str, Any], - run_name: str, - ssh_key_pub: Optional[str] = None, - ): - super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub) - self.python = self._safe_python_version("python") - self.env = self._env() - self.artifact_specs = self._artifact_specs() - self.working_dir = self.provider_data.get("working_dir") - self.resources = self._resources() - self.image_name = self._image_name() - - def _create_parser(self, workflow_name: Optional[str]) -> Optional[ArgumentParser]: - parser = ArgumentParser( - prog="dstack run " + (workflow_name or self.provider_name), - formatter_class=RichHelpFormatter, - ) - self._add_base_args(parser) - parser.add_argument("--no-ssh", action="store_false", dest="openssh_server") - if not workflow_name: - parser.add_argument("-c", "--command", type=str) - return parser - - def parse_args(self): - parser = self._create_parser(self.workflow_name) - args, unknown_args = parser.parse_known_args(self.provider_args) - self._parse_base_args(args, unknown_args) - if self.run_as_provider and args.command: - self.provider_data["commands"] = [args.command] - if not args.openssh_server: - self.openssh_server = False - - def 
create_job_specs(self) -> List[JobSpec]: - apps = [] - for i, pm in enumerate(filter_reserved_ports(self.ports)): - apps.append( - AppSpec( - port=pm.port, - map_to_port=pm.map_to_port, - app_name="bash" + (str(i) if len(self.ports) > 1 else ""), - ) - ) - if self.openssh_server: - OpenSSHExtension.patch_apps( - apps, map_to_port=get_map_to_port(self.ports, OpenSSHExtension.port) - ) - return [ - JobSpec( - image_name=self.image_name, - commands=self._commands(), - entrypoint=["/bin/bash", "-i", "-c"], - working_dir=self.working_dir, - artifact_specs=self.artifact_specs, - requirements=self.resources, - app_specs=apps, - build_commands=self.build_commands, - ) - ] - - def _image_name(self) -> str: - cuda_is_required = self.resources and self.resources.gpus - cuda_image_name = f"dstackai/miniforge:py{self.python}-{version.miniforge_image}-cuda-11.4" - cpu_image_name = f"dstackai/miniforge:py{self.python}-{version.miniforge_image}" - return cuda_image_name if cuda_is_required else cpu_image_name - - def _commands(self): - commands = [] - if self.env: - self._extend_commands_with_env(commands, self.env) - if self.openssh_server: - OpenSSHExtension.patch_commands(commands, ssh_key_pub=self.ssh_key_pub) - commands.extend(self.commands) - return commands - - -def __provider__(): - return BashProvider() diff --git a/cli/dstack/_internal/providers/code/__init__.py b/cli/dstack/_internal/providers/code/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/cli/dstack/_internal/providers/code/main.py b/cli/dstack/_internal/providers/code/main.py deleted file mode 100644 index 2f43dac99..000000000 --- a/cli/dstack/_internal/providers/code/main.py +++ /dev/null @@ -1,156 +0,0 @@ -import uuid -from argparse import ArgumentParser, Namespace -from typing import Any, Dict, List, Optional - -from rich_argparse import RichHelpFormatter - -import dstack.api.hub as hub -from dstack import version -from dstack._internal.configurators.ports import filter_reserved_ports, get_map_to_port -from dstack._internal.core.app import AppSpec -from dstack._internal.core.job import JobSpec -from dstack._internal.providers import Provider -from dstack._internal.providers.extensions import OpenSSHExtension - - -class CodeProvider(Provider): - code_port = 10000 - - def __init__(self): - super().__init__("code") - self.python = None - self.version = None - self.requirements = None - self.env = None - self.artifact_specs = None - self.working_dir = None - self.resources = None - self.image_name = None - self.home_dir = "/root" - self.openssh_server = True - - def load( - self, - hub_client: "hub.HubClient", - args: Optional[Namespace], - workflow_name: Optional[str], - provider_data: Dict[str, Any], - run_name: str, - ssh_key_pub: Optional[str] = None, - ): - super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub) - self.python = self._safe_python_version("python") - self.version = self.provider_data.get("version") or "1.78.1" - self.env = self._env() - self.artifact_specs = self._artifact_specs() - self.working_dir = self.provider_data.get("working_dir") - self.resources = self._resources() - self.image_name = self._image_name() - - def _create_parser(self, workflow_name: Optional[str]) -> Optional[ArgumentParser]: - parser = ArgumentParser( - prog="dstack run " + (workflow_name or self.provider_name), - formatter_class=RichHelpFormatter, - ) - self._add_base_args(parser) - parser.add_argument("--no-ssh", action="store_false", dest="openssh_server") - return parser - - def 
parse_args(self): - parser = self._create_parser(self.workflow_name) - args, unknown_args = parser.parse_known_args(self.provider_args) - self._parse_base_args(args, unknown_args) - if not args.openssh_server: - self.openssh_server = False - - def create_job_specs(self) -> List[JobSpec]: - env = {} - connection_token = uuid.uuid4().hex - env["CONNECTION_TOKEN"] = connection_token - apps = [] - for i, pm in enumerate(filter_reserved_ports(self.ports), start=1): - apps.append( - AppSpec( - port=pm.port, - map_to_port=pm.map_to_port, - app_name="code" + str(i), - ) - ) - apps.append( - AppSpec( - port=self.code_port, - map_to_port=get_map_to_port(self.ports, self.code_port), - app_name="code", - url_query_params={ - "tkn": connection_token, - }, - ) - ) - if self.openssh_server: - OpenSSHExtension.patch_apps( - apps, map_to_port=get_map_to_port(self.ports, OpenSSHExtension.port) - ) - return [ - JobSpec( - image_name=self.image_name, - commands=self._commands(), - entrypoint=["/bin/bash", "-i", "-c"], - run_env=env, - working_dir=self.working_dir, - artifact_specs=self.artifact_specs, - requirements=self.resources, - app_specs=apps, - build_commands=self._setup(), - ) - ] - - def _image_name(self) -> str: - cuda_is_required = self.resources and self.resources.gpus - cuda_image_name = f"dstackai/miniforge:py{self.python}-{version.miniforge_image}-cuda-11.4" - cpu_image_name = f"dstackai/miniforge:py{self.python}-{version.miniforge_image}" - return cuda_image_name if cuda_is_required else cpu_image_name - - def _setup(self) -> List[str]: - commands = [ - "pip install ipykernel -q", - "mkdir -p /tmp", - 'if [ $(uname -m) = "aarch64" ]; then arch="arm64"; else arch="x64"; fi', - f"wget -q https://github.com/gitpod-io/openvscode-server/releases/download/" - f"openvscode-server-v{self.version}/openvscode-server-v{self.version}-linux-$arch.tar.gz -O " - f"/tmp/openvscode-server-v{self.version}-linux-$arch.tar.gz", - f"tar -xzf /tmp/openvscode-server-v{self.version}-linux-$arch.tar.gz -C /tmp", - f"/tmp/openvscode-server-v{self.version}-linux-$arch/bin/openvscode-server --install-extension ms-python.python --install-extension ms-toolsai.jupyter", - "rm /usr/bin/python2*", - ] - if self.build_commands: - commands.extend(self.build_commands) - return commands - - def _commands(self): - commands = [] - if self.env: - self._extend_commands_with_env(commands, self.env) - if self.openssh_server: - OpenSSHExtension.patch_commands(commands, ssh_key_pub=self.ssh_key_pub) - if self.openssh_server: - commands.extend( - [ - f"echo Connect from code desktop", - f"echo ' vscode://vscode-remote/ssh-remote+{self.run_name}/workflow'", - f"echo ' vscode-insiders://vscode-remote/ssh-remote+{self.run_name}/workflow'", - ] - ) - commands.extend(self.commands) - commands.extend( - [ - 'if [ $(uname -m) = "aarch64" ]; then arch="arm64"; else arch="x64"; fi', - f"/tmp/openvscode-server-v{self.version}-linux-$arch/bin/openvscode-server" - f" --port {self.code_port} --host 0.0.0.0 --connection-token $CONNECTION_TOKEN" - f" --default-folder /workflow", - ] - ) - return commands - - -def __provider__(): - return CodeProvider() diff --git a/cli/dstack/_internal/providers/docker/__init__.py b/cli/dstack/_internal/providers/docker/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/cli/dstack/_internal/providers/docker/main.py b/cli/dstack/_internal/providers/docker/main.py deleted file mode 100644 index 6704366ba..000000000 --- a/cli/dstack/_internal/providers/docker/main.py +++ /dev/null @@ -1,97 +0,0 @@ 
-from argparse import ArgumentParser, Namespace -from typing import Any, Dict, List, Optional - -from rich_argparse import RichHelpFormatter - -import dstack.api.hub as hub -from dstack._internal.configurators.ports import filter_reserved_ports -from dstack._internal.core.app import AppSpec -from dstack._internal.core.job import JobSpec -from dstack._internal.providers import Provider - - -class DockerProvider(Provider): - def __init__(self): - super().__init__("docker") - self.image_name = None - self.registry_auth = None - self.entrypoint = None - self.artifact_specs = None - self.env = None - self.working_dir = None - self.resources = None - - def load( - self, - hub_client: "hub.HubClient", - args: Optional[Namespace], - workflow_name: Optional[str], - provider_data: Dict[str, Any], - run_name: str, - ssh_key_pub: Optional[str] = None, - ): - super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub) - self.image_name = self.provider_data["image"] - self.registry_auth = self.provider_data.get("registry_auth") - self.entrypoint = self._get_entrypoint() - if self.commands and self.entrypoint is None: # commands not empty - self.entrypoint = ["/bin/sh", "-i", "-c"] - self.artifact_specs = self._artifact_specs() - self.env = self.provider_data.get("env") - self.home_dir = self.provider_data.get("home_dir") - self.working_dir = self.provider_data.get("working_dir") - self.resources = self._resources() - - def _create_parser(self, workflow_name: Optional[str]) -> Optional[ArgumentParser]: - parser = ArgumentParser( - prog="dstack run " + (workflow_name or self.provider_name), - formatter_class=RichHelpFormatter, - ) - self._add_base_args(parser) - if not workflow_name: - parser.add_argument("image", metavar="IMAGE", type=str) - parser.add_argument("-c", "--command", type=str) - parser.add_argument("-e", "--entrypoint", type=str) - return parser - - def parse_args(self): - parser = self._create_parser(self.workflow_name) - args, unknown_args = parser.parse_known_args(self.provider_args) - self._parse_base_args(args, unknown_args) - if self.run_as_provider: - self.provider_data["image"] = args.image - if args.command: - self.provider_data["commands"] = [args.command] - if args.entrypoint: - self.provider_data["entrypoint"] = args.entrypoint - - def create_job_specs(self) -> List[JobSpec]: - apps = [] - for i, pm in enumerate(filter_reserved_ports(self.ports)): - apps.append( - AppSpec( - port=pm.port, - map_to_port=pm.map_to_port, - app_name="docker" + (str(i) if len(self.ports) > 1 else ""), - ) - ) - commands = [] - commands.extend(self.commands or []) - return [ - JobSpec( - image_name=self.image_name, - registry_auth=self.registry_auth, - commands=commands, - entrypoint=self.entrypoint, - env=self.env, - working_dir=self.working_dir, - artifact_specs=self.artifact_specs, - requirements=self.resources, - app_specs=apps, - build_commands=self.build_commands, - ) - ] - - -def __provider__(): - return DockerProvider() diff --git a/cli/dstack/_internal/providers/extensions.py b/cli/dstack/_internal/providers/extensions.py deleted file mode 100644 index 45b6b5a0e..000000000 --- a/cli/dstack/_internal/providers/extensions.py +++ /dev/null @@ -1,120 +0,0 @@ -import subprocess -from abc import ABC, abstractmethod -from typing import List, Optional - -import requests - -from dstack._internal.core.app import AppSpec -from dstack._internal.core.error import DstackError - - -class ProviderExtension(ABC): - @classmethod - @abstractmethod - def patch_setup(cls, commands: 
List[str], **kwargs): - pass - - @classmethod - @abstractmethod - def patch_commands(cls, commands: List[str], **kwargs): - pass - - @classmethod - @abstractmethod - def patch_apps(cls, apps: List[AppSpec], **kwargs): - pass - - -class OpenSSHExtension(ProviderExtension): - port = 10022 - - @classmethod - def patch_setup(cls, commands: List[str], **kwargs): - pass - - @classmethod - def patch_commands(cls, commands: List[str], *, ssh_key_pub: str = None, **kwargs): - assert ssh_key_pub is not None, "No SSH key provided" - commands.extend( - [ - f'echo "{ssh_key_pub}" >> ~/.ssh/authorized_keys', - f"env >> ~/.ssh/environment", - f"ssh-keygen -A > /dev/null", - f"/usr/sbin/sshd -p {cls.port} -o PermitUserEnvironment=yes", - ] - ) - - @classmethod - def patch_apps(cls, apps: List[AppSpec], *, map_to_port: Optional[int] = None, **kwargs): - apps.append(AppSpec(port=cls.port, map_to_port=map_to_port, app_name="openssh-server")) - - -class VSCodeDesktopServer(ProviderExtension): - @staticmethod - def get_tag_sha(tag: Optional[str] = None) -> str: - repo_api = "https://api.github.com/repos/microsoft/vscode" - if tag is None: # get latest - tag = requests.get(f"{repo_api}/releases/latest").json()["tag_name"] - obj = requests.get(f"{repo_api}/git/ref/tags/{tag}").json()["object"] - if obj["type"] == "commit": - return obj["sha"] - raise NotImplementedError() - - @staticmethod - def detect_code_sha(exe: str = "code") -> Optional[str]: - try: - run = subprocess.run([exe, "--version"], capture_output=True) - except FileNotFoundError: - return None - if run.returncode == 0: - return run.stdout.decode().split("\n")[1].strip() - return None - - @classmethod - def _vscode_server_install(cls, commit: str, extensions: Optional[List[str]]) -> List[str]: - url = f"https://update.code.visualstudio.com/commit:{commit}/server-linux-$arch/stable" - archive = "vscode-server-linux-$arch.tar.gz" - target = f'~/.vscode-server/bin/"{commit}"' - commands = [ - f'if [ $(uname -m) = "aarch64" ]; then arch="arm64"; else arch="x64"; fi', - f"mkdir -p /tmp", - f'wget -q --show-progress "{url}" -O "/tmp/{archive}"', - f"mkdir -vp {target}", - f'tar --no-same-owner -xz --strip-components=1 -C {target} -f "/tmp/{archive}"', - f'rm "/tmp/{archive}"', - ] - if extensions: - extensions = " ".join(f'--install-extension "{name}"' for name in extensions) - commands.append(f'PATH="$PATH":{target}/bin code-server {extensions}') - return commands - - @classmethod - def patch_setup( - cls, commands: List[str], *, vscode_extensions: Optional[List[str]] = None, **kwargs - ): - commit = cls.detect_code_sha() - if commit is None: - raise NoVSCodeVersionError() - commands.extend(cls._vscode_server_install(commit, extensions=vscode_extensions)) - - @classmethod - def patch_commands( - cls, commands: List[str], *, vscode_extensions: Optional[List[str]] = None, **kwargs - ): - commit = cls.detect_code_sha() - if commit is None: - raise NoVSCodeVersionError() - install_commands = " && ".join( - cls._vscode_server_install(commit, extensions=vscode_extensions) - ) - commands.append( - f'if [ ! 
-d ~/.vscode-server/bin/"{commit}" ]; then {install_commands}; fi' - ) - - @classmethod - def patch_apps(cls, apps: List[AppSpec], **kwargs): - pass - - -class NoVSCodeVersionError(DstackError): - pass diff --git a/cli/dstack/_internal/providers/lab/__init__.py b/cli/dstack/_internal/providers/lab/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/cli/dstack/_internal/providers/lab/main.py b/cli/dstack/_internal/providers/lab/main.py deleted file mode 100644 index 878d9b768..000000000 --- a/cli/dstack/_internal/providers/lab/main.py +++ /dev/null @@ -1,145 +0,0 @@ -import uuid -from argparse import ArgumentParser, Namespace -from typing import Any, Dict, List, Optional - -from rich_argparse import RichHelpFormatter - -import dstack.api.hub as hub -from dstack import version -from dstack._internal.configurators.ports import filter_reserved_ports, get_map_to_port -from dstack._internal.core.app import AppSpec -from dstack._internal.core.job import JobSpec -from dstack._internal.providers import Provider -from dstack._internal.providers.extensions import OpenSSHExtension - - -class LabProvider(Provider): - lab_port = 10000 - - def __init__(self): - super().__init__("lab") - self.python = None - self.version = None - self.env = None - self.artifact_specs = None - self.working_dir = None - self.resources = None - self.image_name = None - self.home_dir = "/root" - self.openssh_server = True - - def load( - self, - hub_client: "hub.HubClient", - args: Optional[Namespace], - workflow_name: Optional[str], - provider_data: Dict[str, Any], - run_name: str, - ssh_key_pub: Optional[str] = None, - ): - super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub) - self.python = self._safe_python_version("python") - self.version = self.provider_data.get("version") - self.env = self._env() - self.artifact_specs = self._artifact_specs() - self.working_dir = self.provider_data.get("working_dir") - self.resources = self._resources() - self.image_name = self._image_name() - - def _create_parser(self, workflow_name: Optional[str]) -> Optional[ArgumentParser]: - parser = ArgumentParser( - prog="dstack run " + (workflow_name or self.provider_name), - formatter_class=RichHelpFormatter, - ) - self._add_base_args(parser) - parser.add_argument("--no-ssh", action="store_false", dest="openssh_server") - return parser - - def parse_args(self): - parser = self._create_parser(self.workflow_name) - args, unknown_args = parser.parse_known_args(self.provider_args) - self._parse_base_args(args, unknown_args) - if not args.openssh_server: - self.openssh_server = False - - def create_job_specs(self) -> List[JobSpec]: - env = {} - token = uuid.uuid4().hex - env["TOKEN"] = token - apps = [] - for i, pm in enumerate(filter_reserved_ports(self.ports), start=1): - apps.append( - AppSpec( - port=pm.port, - map_to_port=pm.map_to_port, - app_name="lab" + str(i), - ) - ) - apps.append( - AppSpec( - port=self.lab_port, - map_to_port=get_map_to_port(self.ports, self.lab_port), - app_name="lab", - url_path="lab", - url_query_params={"token": token}, - ) - ) - if self.openssh_server: - OpenSSHExtension.patch_apps( - apps, map_to_port=get_map_to_port(self.ports, OpenSSHExtension.port) - ) - return [ - JobSpec( - image_name=self.image_name, - commands=self._commands(), - entrypoint=["/bin/bash", "-i", "-c"], - run_env=env, - working_dir=self.working_dir, - artifact_specs=self.artifact_specs, - requirements=self.resources, - app_specs=apps, - build_commands=self._setup(), - ) - ] - - def 
_image_name(self) -> str: - cuda_is_required = self.resources and self.resources.gpus - cuda_image_name = f"dstackai/miniforge:py{self.python}-{version.miniforge_image}-cuda-11.4" - cpu_image_name = f"dstackai/miniforge:py{self.python}-{version.miniforge_image}" - return cuda_image_name if cuda_is_required else cpu_image_name - - def _setup(self) -> List[str]: - commands = [ - "conda install psutil -y", - "pip install jupyterlab" + (f"=={self.version}" if self.version else ""), - "pip install ipywidgets", - "jupyter labextension enable --py widgetsnbextension", - "mkdir -p /root/.jupyter", - 'echo "c.ServerApp.allow_root = True" > /root/.jupyter/jupyter_server_config.py', - "echo \"c.ServerApp.allow_origin = '*'\" >> /root/.jupyter/jupyter_server_config.py", - 'echo "c.ServerApp.open_browser = False" >> /root/.jupyter/jupyter_server_config.py', - "echo \"c.ServerApp.ip = '0.0.0.0'\" >> /root/.jupyter/jupyter_server_config.py", - ] - if self.build_commands: - commands.extend(self.build_commands) - return commands - - def _commands(self): - commands = [] - if self.env: - self._extend_commands_with_env(commands, self.env) - if self.openssh_server: - OpenSSHExtension.patch_commands(commands, ssh_key_pub=self.ssh_key_pub) - commands.extend(self.commands) - commands.extend( - [ - f'echo "c.ServerApp.port = {self.lab_port}" >> /root/.jupyter/jupyter_server_config.py', - "echo \"c.ServerApp.token = '$TOKEN'\" >> /root/.jupyter/jupyter_server_config.py", - ] - ) - commands.append(f"jupyter lab") - return commands - - -def __provider__(): - return LabProvider() diff --git a/cli/dstack/_internal/providers/notebook/__init__.py b/cli/dstack/_internal/providers/notebook/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/cli/dstack/_internal/providers/notebook/main.py b/cli/dstack/_internal/providers/notebook/main.py deleted file mode 100644 index 00acdc4d1..000000000 --- a/cli/dstack/_internal/providers/notebook/main.py +++ /dev/null @@ -1,142 +0,0 @@ -import uuid -from argparse import ArgumentParser, Namespace -from typing import Any, Dict, List, Optional - -from rich_argparse import RichHelpFormatter - -import dstack.api.hub as hub -from dstack import version -from dstack._internal.configurators.ports import filter_reserved_ports, get_map_to_port -from dstack._internal.core.app import AppSpec -from dstack._internal.core.job import JobSpec -from dstack._internal.providers import Provider -from dstack._internal.providers.extensions import OpenSSHExtension - - -class NotebookProvider(Provider): - notebook_port = 10000 - - def __init__(self): - super().__init__("notebook") - self.python = None - self.version = None - self.env = None - self.artifact_specs = None - self.working_dir = None - self.resources = None - self.image_name = None - self.home_dir = "/root" - self.openssh_server = True - - def load( - self, - hub_client: "hub.HubClient", - args: Optional[Namespace], - workflow_name: Optional[str], - provider_data: Dict[str, Any], - run_name: str, - ssh_key_pub: Optional[str] = None, - ): - super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub) - self.python = self._safe_python_version("python") - self.version = self.provider_data.get("version") - self.env = self._env() - self.artifact_specs = self._artifact_specs() - self.working_dir = self.provider_data.get("working_dir") - self.resources = self._resources() - self.image_name = self._image_name() - - def _create_parser(self, workflow_name: Optional[str]) -> Optional[ArgumentParser]: - parser = 
ArgumentParser( - prog="dstack run " + (workflow_name or self.provider_name), - formatter_class=RichHelpFormatter, - ) - self._add_base_args(parser) - parser.add_argument("--no-ssh", action="store_false", dest="openssh_server") - return parser - - def parse_args(self): - parser = self._create_parser(self.workflow_name) - args, unknown_args = parser.parse_known_args(self.provider_args) - self._parse_base_args(args, unknown_args) - if not args.openssh_server: - self.openssh_server = False - - def create_job_specs(self) -> List[JobSpec]: - env = {} - token = uuid.uuid4().hex - env["TOKEN"] = token - apps = [] - for i, pm in enumerate(filter_reserved_ports(self.ports), start=1): - apps.append( - AppSpec( - port=pm.port, - map_to_port=pm.map_to_port, - app_name="notebook" + str(i), - ) - ) - apps.append( - AppSpec( - port=self.notebook_port, - map_to_port=get_map_to_port(self.ports, self.notebook_port), - app_name="notebook", - url_query_params={"token": token}, - ) - ) - if self.openssh_server: - OpenSSHExtension.patch_apps( - apps, map_to_port=get_map_to_port(self.ports, OpenSSHExtension.port) - ) - return [ - JobSpec( - image_name=self.image_name, - commands=self._commands(), - entrypoint=["/bin/bash", "-i", "-c"], - run_env=env, - working_dir=self.working_dir, - artifact_specs=self.artifact_specs, - requirements=self.resources, - app_specs=apps, - build_commands=self._setup(), - ) - ] - - def _image_name(self) -> str: - cuda_is_required = self.resources and self.resources.gpus - cuda_image_name = f"dstackai/miniforge:py{self.python}-{version.miniforge_image}-cuda-11.4" - cpu_image_name = f"dstackai/miniforge:py{self.python}-{version.miniforge_image}" - return cuda_image_name if cuda_is_required else cpu_image_name - - def _setup(self) -> List[str]: - commands = [ - "conda install psutil -y", - "pip install jupyter" + (f"=={self.version}" if self.version else ""), - "mkdir -p /root/.jupyter", - 'echo "c.NotebookApp.allow_root = True" > /root/.jupyter/jupyter_notebook_config.py', - "echo \"c.NotebookApp.allow_origin = '*'\" >> /root/.jupyter/jupyter_notebook_config.py", - 'echo "c.NotebookApp.open_browser = False" >> /root/.jupyter/jupyter_notebook_config.py', - "echo \"c.NotebookApp.ip = '0.0.0.0'\" >> /root/.jupyter/jupyter_notebook_config.py", - ] - if self.build_commands: - commands.extend(self.build_commands) - return commands - - def _commands(self): - commands = [] - if self.env: - self._extend_commands_with_env(commands, self.env) - if self.openssh_server: - OpenSSHExtension.patch_commands(commands, ssh_key_pub=self.ssh_key_pub) - commands.extend(self.commands) - commands.extend( - [ - f'echo "c.NotebookApp.port = {self.notebook_port}" >> /root/.jupyter/jupyter_notebook_config.py', - "echo \"c.NotebookApp.token = '$TOKEN'\" >> /root/.jupyter/jupyter_notebook_config.py", - ] - ) - commands.append(f"jupyter notebook") - return commands - - -def __provider__(): - return NotebookProvider() diff --git a/cli/dstack/_internal/providers/ssh/__init__.py b/cli/dstack/_internal/providers/ssh/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/cli/dstack/_internal/providers/ssh/main.py b/cli/dstack/_internal/providers/ssh/main.py deleted file mode 100644 index 83fdb02a3..000000000 --- a/cli/dstack/_internal/providers/ssh/main.py +++ /dev/null @@ -1,122 +0,0 @@ -from argparse import ArgumentParser, Namespace -from typing import Any, Dict, List, Optional - -from rich_argparse import RichHelpFormatter - -import dstack.api.hub as hub -from dstack import version -from 
dstack._internal.configurators.ports import filter_reserved_ports, get_map_to_port -from dstack._internal.core.app import AppSpec -from dstack._internal.core.job import JobSpec -from dstack._internal.providers import Provider -from dstack._internal.providers.extensions import OpenSSHExtension - - -class SSHProvider(Provider): - def __init__(self): - super().__init__("ssh") - self.python = None - self.env = None - self.artifact_specs = None - self.working_dir = None - self.resources = None - self.image_name = None - self.home_dir = "/root" - self.code = True - - def load( - self, - hub_client: "hub.HubClient", - args: Optional[Namespace], - workflow_name: Optional[str], - provider_data: Dict[str, Any], - run_name: str, - ssh_key_pub: Optional[str] = None, - ): - super().load(hub_client, args, workflow_name, provider_data, run_name, ssh_key_pub) - self.python = self._safe_python_version("python") - self.env = self._env() - self.artifact_specs = self._artifact_specs() - self.working_dir = self.provider_data.get("working_dir") - self.resources = self._resources() - self.image_name = self._image_name() - self.code = self.code or self.provider_data.get("code", self.code) - - def _create_parser(self, workflow_name: Optional[str]) -> Optional[ArgumentParser]: - parser = ArgumentParser( - prog="dstack run " + (workflow_name or self.provider_name), - formatter_class=RichHelpFormatter, - ) - self._add_base_args(parser) - parser.add_argument("--code", action="store_true", help="Print VS Code connection URI") - return parser - - def parse_args(self): - parser = self._create_parser(self.workflow_name) - args, unknown_args = parser.parse_known_args(self.provider_args) - self._parse_base_args(args, unknown_args) - if args.code: - self.code = True - - def create_job_specs(self) -> List[JobSpec]: - apps = [] - for i, pm in enumerate(filter_reserved_ports(self.ports)): - apps.append( - AppSpec( - port=pm.port, - map_to_port=pm.map_to_port, - app_name="ssh" + (str(i) if len(self.ports) > 1 else ""), - ) - ) - OpenSSHExtension.patch_apps( - apps, map_to_port=get_map_to_port(self.ports, OpenSSHExtension.port) - ) - return [ - JobSpec( - image_name=self.image_name, - commands=self._commands(), - entrypoint=["/bin/bash", "-i", "-c"], - working_dir=self.working_dir, - artifact_specs=self.artifact_specs, - requirements=self.resources, - app_specs=apps, - build_commands=self.build_commands, - ) - ] - - def _image_name(self) -> str: - cuda_is_required = self.resources and self.resources.gpus - cuda_image_name = f"dstackai/miniforge:py{self.python}-{version.miniforge_image}-cuda-11.4" - cpu_image_name = f"dstackai/miniforge:py{self.python}-{version.miniforge_image}" - return cuda_image_name if cuda_is_required else cpu_image_name - - def _commands(self): - commands = [] - if self.env: - self._extend_commands_with_env(commands, self.env) - OpenSSHExtension.patch_commands(commands, ssh_key_pub=self.ssh_key_pub) - commands.extend(self.commands) - if self.code: - commands.extend( - [ - "echo ''", - f"echo To open in VS Code Desktop, use one of these links:", - f"echo ''", - f"echo ' vscode://vscode-remote/ssh-remote+{self.run_name}/workflow'", - # f"echo ' vscode-insiders://vscode-remote/ssh-remote+{self.run_name}/workflow'", - "echo ''", - f"echo 'To connect via SSH, use: `ssh {self.run_name}`'", - "echo ''", - "echo -n 'To exit, press Ctrl+C.'", - ] - ) - commands.extend( - [ - "cat", - ] - ) - return commands - - -def __provider__(): - return SSHProvider() From 0fb390782bb34baeebf38d120025f50a7b1028cb Mon Sep 17 
00:00:00 2001 From: Egor Sklyarov Date: Tue, 11 Jul 2023 10:56:56 +0400 Subject: [PATCH 19/26] Drop json schemas --- cli/dstack/_internal/core/repo/base.py | 4 - cli/dstack/_internal/core/repo/local.py | 6 +- cli/dstack/_internal/core/repo/remote.py | 17 +- .../_internal/schemas/configuration.json | 217 ------ cli/dstack/_internal/schemas/profiles.json | 132 ---- cli/dstack/_internal/schemas/workflows.json | 630 ------------------ cli/dstack/_internal/utils/workflows.py | 49 -- 7 files changed, 2 insertions(+), 1053 deletions(-) delete mode 100644 cli/dstack/_internal/schemas/configuration.json delete mode 100644 cli/dstack/_internal/schemas/profiles.json delete mode 100644 cli/dstack/_internal/schemas/workflows.json delete mode 100644 cli/dstack/_internal/utils/workflows.py diff --git a/cli/dstack/_internal/core/repo/base.py b/cli/dstack/_internal/core/repo/base.py index 1a4e89677..f0eff1752 100644 --- a/cli/dstack/_internal/core/repo/base.py +++ b/cli/dstack/_internal/core/repo/base.py @@ -48,7 +48,3 @@ def __init__(self, repo_ref: RepoRef, repo_data: RepoData): @property def repo_id(self) -> str: return self.repo_ref.repo_id - - @abstractmethod - def get_workflows(self, credentials=None) -> Dict[str, Dict[str, Any]]: - pass diff --git a/cli/dstack/_internal/core/repo/local.py b/cli/dstack/_internal/core/repo/local.py index 9f2d63b17..9fddd8705 100644 --- a/cli/dstack/_internal/core/repo/local.py +++ b/cli/dstack/_internal/core/repo/local.py @@ -1,6 +1,6 @@ import tarfile from pathlib import Path -from typing import Any, BinaryIO, Dict, Optional +from typing import BinaryIO, Optional from typing_extensions import Literal @@ -9,7 +9,6 @@ from dstack._internal.utils.escape import escape_head from dstack._internal.utils.hash import get_sha256, slugify from dstack._internal.utils.ignore import GitIgnore -from dstack._internal.utils.workflows import load_workflows class LocalRepoData(RepoData): @@ -53,9 +52,6 @@ def __init__( repo_ref = RepoRef(repo_id=slugify(Path(repo_data.repo_dir).name, repo_data.repo_dir)) super().__init__(repo_ref, repo_data) - def get_workflows(self, credentials=None) -> Dict[str, Dict[str, Any]]: - return load_workflows(Path(self.repo_data.repo_dir) / ".dstack") - class TarIgnore(GitIgnore): def __call__(self, tarinfo: tarfile.TarInfo) -> Optional[tarfile.TarInfo]: diff --git a/cli/dstack/_internal/core/repo/remote.py b/cli/dstack/_internal/core/repo/remote.py index 28ae07765..df3e38a4b 100644 --- a/cli/dstack/_internal/core/repo/remote.py +++ b/cli/dstack/_internal/core/repo/remote.py @@ -3,8 +3,7 @@ import subprocess import tempfile import time -from pathlib import Path -from typing import Any, BinaryIO, Dict, Optional +from typing import BinaryIO, Optional import git import giturlparse @@ -16,7 +15,6 @@ from dstack._internal.utils.common import PathLike from dstack._internal.utils.hash import get_sha256, slugify from dstack._internal.utils.ssh import get_host_config, make_ssh_command_for_git -from dstack._internal.utils.workflows import load_workflows class RemoteRepoCredentials(BaseModel): @@ -128,19 +126,6 @@ def __init__( repo_ref = RepoRef(repo_id=slugify(repo_data.repo_name, repo_data.path("/"))) super().__init__(repo_ref, repo_data) - def get_workflows( - self, credentials: Optional[RemoteRepoCredentials] = None - ) -> Dict[str, Dict[str, Any]]: - if self.local_repo_dir is not None: - local_repo_dir = Path(self.local_repo_dir) - elif credentials is None: - raise RuntimeError("No credentials for remote only repo") - else: - temp_dir = 
tempfile.TemporaryDirectory() # will be removed by garbage collector - local_repo_dir = Path(temp_dir.name) - _clone_remote_repo(local_repo_dir, self.repo_data, credentials, depth=1) - return load_workflows(local_repo_dir / ".dstack") - def _clone_remote_repo( dst: PathLike, repo_data: RemoteRepoData, repo_credentials: RemoteRepoCredentials, **kwargs diff --git a/cli/dstack/_internal/schemas/configuration.json b/cli/dstack/_internal/schemas/configuration.json deleted file mode 100644 index f8a13f0ff..000000000 --- a/cli/dstack/_internal/schemas/configuration.json +++ /dev/null @@ -1,217 +0,0 @@ -{ - "$schema": "http://json-schema.org/draft-04/schema", - "definitions": { - "ports": { - "description": "The list of port numbers to expose", - "type": "array", - "items": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer", - "minimum": 0, - "maximum": 65536 - } - ] - } - }, - "_commands": { - "anyOf": [ - { - "type": "array", - "minItems": 1, - "items": { - "type": "string", - "minLength": 1 - } - }, - { - "type": "string", - "minLength": 1 - } - ] - }, - "cache": { - "description": "The directories to be cached between runs", - "type": "array", - "minItems": 1, - "items": { - "type": "string", - "minLength": 1 - } - }, - "python": { - "description": "The major version of Python", - "anyOf": [ - { - "type": "string", - "pattern": "^\\d+(\\.\\d+)?$" - }, - { - "type": "number" - } - ] - }, - "env": { - "description": "The list of environment variables", - "type": "array", - "items": { - "type": "string", - "minLength": 1 - }, - "minItems": 1 - }, - "image": { - "description": "The name of the Docker image", - "type": "string", - "minLength": 1 - }, - "registry_auth": { - "description": "Credentials to pull the private Docker image", - "type": "object", - "additionalProperties": false, - "properties": { - "username": { - "description": "Username", - "type": "string" - }, - "password": { - "description": "Password or access token", - "type": "string" - } - } - }, - "artifacts": { - "description": "The list of output artifacts", - "type": "array", - "minItems": 1, - "items": { - "type": "object", - "additionalProperties": false, - "required": [ - "path" - ], - "properties": { - "path": { - "description": "The absolute or relative path to the folder that must be stored as an output artifact", - "type": "string", - "minLength": 1 - }, - "mount": { - "description": "Must be set to `true` if the artifact files must be saved in real-time", - "type": "boolean", - "enum": [ - true - ] - } - } - } - }, - "dev-environment": { - "type": "object", - "additionalProperties": false, - "properties": { - "type": { - "type": "string", - "description": "The type of the configuration", - "enum": [ - "dev-environment" - ] - }, - "ide": { - "description": "The name of the IDE to setup", - "type": "string", - "enum": [ - "vscode" - ] - }, - "ports": { - "$ref": "#/definitions/ports" - }, - "env": { - "$ref": "#/definitions/env" - }, - "python": { - "$ref": "#/definitions/python" - }, - "build": { - "description": "The bash commands to build the environment", - "$ref": "#/definitions/_commands" - }, - "init": { - "description": "The bash commands to execute on start", - "$ref": "#/definitions/_commands" - }, - "cache": { - "$ref": "#/definitions/cache" - }, - "image": { - "$ref": "#/definitions/image" - }, - "registry_auth": { - "$ref": "#/definitions/registry_auth" - } - }, - "required": [ - "type", - "ide" - ] - }, - "task": { - "type": "object", - "additionalProperties": false, - "properties": { - 
"type": { - "type": "string", - "description": "The type of the configuration", - "enum": [ - "task" - ] - }, - "ports": { - "$ref": "#/definitions/ports" - }, - "env": { - "$ref": "#/definitions/env" - }, - "python": { - "$ref": "#/definitions/python" - }, - "build": { - "description": "The bash commands to build the environment", - "$ref": "#/definitions/_commands" - }, - "commands": { - "description": "The bash commands to run the task", - "$ref": "#/definitions/_commands" - }, - "cache": { - "$ref": "#/definitions/cache" - }, - "image": { - "$ref": "#/definitions/image" - }, - "registry_auth": { - "$ref": "#/definitions/registry_auth" - }, - "artifacts": { - "$ref": "#/definitions/artifacts" - } - }, - "required": [ - "type", - "commands" - ] - } - }, - "oneOf": [ - { - "$ref": "#/definitions/dev-environment" - }, - { - "$ref": "#/definitions/task" - } - ] -} \ No newline at end of file diff --git a/cli/dstack/_internal/schemas/profiles.json b/cli/dstack/_internal/schemas/profiles.json deleted file mode 100644 index 2b1ddb6db..000000000 --- a/cli/dstack/_internal/schemas/profiles.json +++ /dev/null @@ -1,132 +0,0 @@ -{ - "$schema": "http://json-schema.org/draft-04/schema", - "type": "object", - "additionalProperties": false, - "definitions": { - "resources": { - "description": "The hardware resource requirements", - "type": "object", - "additionalProperties": false, - "properties": { - "gpu": { - "description": "The GPU requirements", - "anyOf": [ - { - "type": "integer", - "minimum": 1 - }, - { - "type": "object", - "additionalProperties": false, - "properties": { - "name": { - "description": "The name of the GPU, e.g. K80 or V100", - "type": "string", - "minLength": 1 - }, - "count": { - "description": "The minimum number of GPUs", - "type": "integer", - "minimum": 1 - }, - "memory": { - "description": "The minimum amount of GPU memory, e.g. 512MB or 16GB", - "type": "string", - "pattern": "^\\d+[MG]B$" - } - } - } - ] - }, - "memory": { - "description": "The minimum amount of RAM memory, e.g. 512MB or 16GB", - "type": "string", - "pattern": "^\\d+[MG]B$" - }, - "shm_size": { - "description": "The minimum amount of shared memory, e.g. 512MB or 16GB", - "type": "string", - "pattern": "^\\d+[MG]B$" - }, - "cpu": { - "description": "The minimum number of virtual CPU cores", - "type": "integer", - "minimum": 1 - }, - "local": { - "description": "Must be set to `true` if the workflow must run locally", - "type": "boolean", - "enum": [ - true - ] - } - } - }, - "spot_policy": { - "description": "The policy for provisioning spot or on-demand instances", - "type": "string", - "enum": [ - "spot", - "on-demand", - "auto" - ] - }, - "retry_policy": { - "description": "The policy for re-submitting the run", - "type": "object", - "additionalProperties": false, - "properties": { - "retry": { - "description": "Whether to retry the run on failure or not", - "type": "boolean" - }, - "limit": { - "description": "The maximum period of retrying the run, e.g. 
1d", - "type": "string" - } - } - } - }, - "properties": { - "profiles": { - "type": "array", - "minItems": 1, - "items": { - "type": "object", - "minProperties": 2, - "additionalProperties": false, - "required": [ - "name" - ], - "properties": { - "name": { - "type": "string", - "minLength": 1 - }, - "project": { - "type": "string", - "minLength": 1 - }, - "resources": { - "$ref": "#/definitions/resources" - }, - "spot_policy": { - "$ref": "#/definitions/spot_policy" - }, - "retry_policy": { - "$ref": "#/definitions/retry_policy" - }, - "default": { - "type": "boolean", - "enum": [ - true - ] - } - } - } - } - }, - "required": [ - "profiles" - ] -} \ No newline at end of file diff --git a/cli/dstack/_internal/schemas/workflows.json b/cli/dstack/_internal/schemas/workflows.json deleted file mode 100644 index 819722599..000000000 --- a/cli/dstack/_internal/schemas/workflows.json +++ /dev/null @@ -1,630 +0,0 @@ -{ - "$schema": "http://json-schema.org/draft-04/schema", - "additionalProperties": false, - "definitions": { - "_commands": { - "anyOf": [ - { - "type": "array", - "minItems": 1, - "items": { - "type": "string", - "minLength": 1 - } - }, - { - "type": "string", - "minLength": 1 - } - ] - }, - "commands": { - "description": "The bash commands to run", - "$ref": "#/definitions/_commands" - }, - "setup": { - "description": "The bash commands to run before running workflow", - "$ref": "#/definitions/_commands" - }, - "env": { - "description": "The list of environment variables", - "type": "array", - "items": { - "type": "string", - "minLength": 1 - }, - "minItems": 1 - }, - "artifacts": { - "description": "The list of output artifacts", - "type": "array", - "minItems": 1, - "items": { - "type": "object", - "additionalProperties": false, - "required": [ - "path" - ], - "properties": { - "path": { - "description": "The absolute or relative path to the folder that must be stored as an output artifact", - "type": "string", - "minLength": 1 - }, - "mount": { - "description": "Must be set to `true` if the artifact files must be saved in real-time", - "type": "boolean", - "enum": [ - true - ] - } - } - } - }, - "cache": { - "description": "The directories to be cached between workflow runs", - "type": "array", - "minItems": 1, - "items": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "object", - "additionalProperties": false, - "properties": { - "path": { - "type": "string" - } - }, - "required": [ - "path" - ] - } - ] - } - }, - "python": { - "description": "The major version of Python", - "anyOf": [ - { - "type": "string", - "pattern": "^\\d+(\\.\\d+)?$" - }, - { - "type": "number" - } - ] - }, - "resources": { - "description": "The hardware resource requirements", - "type": "object", - "additionalProperties": false, - "properties": { - "gpu": { - "description": "The GPU requirements", - "anyOf": [ - { - "type": "integer", - "minimum": 1 - }, - { - "type": "object", - "additionalProperties": false, - "properties": { - "name": { - "description": "The name of the GPU, e.g. K80 or V100", - "type": "string", - "minLength": 1 - }, - "count": { - "description": "The minimum number of GPUs", - "type": "integer", - "minimum": 1 - }, - "memory": { - "description": "The minimum amount of GPU memory, e.g. 512MB or 16GB", - "type": "string", - "pattern": "^\\d+[MG]B$" - } - } - } - ] - }, - "memory": { - "description": "The minimum amount of RAM memory, e.g. 
512MB or 16GB", - "type": "string", - "pattern": "^\\d+[MG]B$" - }, - "shm_size": { - "description": "The minimum amount of shared memory, e.g. 512MB or 16GB", - "type": "string", - "pattern": "^\\d+[MG]B$" - }, - "cpu": { - "description": "The minimum number of virtual CPU cores", - "type": "integer", - "minimum": 1 - }, - "local": { - "description": "Must be set to `true` if the workflow must run locally", - "type": "boolean", - "enum": [ - true - ] - }, - "interruptible": { - "description": "Must be set to `true` if the workflow must use interruptible instances", - "type": "boolean", - "enum": [ - true - ] - } - } - }, - "ports": { - "description": "Port numbers to expose", - "type": "array", - "items": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer", - "minimum": 0, - "maximum": 65536 - } - ] - } - }, - "home_dir": { - "description": "The absolute path to the home directory inside the container", - "type": "string", - "minLength": 1 - }, - "working_dir": { - "description": "The absolute or relative path to the working directory where to run the workflow", - "type": "string", - "minLength": 1 - }, - "registry_auth": { - "description": "Credentials for pulling private container", - "type": "object", - "additionalProperties": false, - "properties": { - "username": { - "description": "Username", - "type": "string" - }, - "password": { - "description": "Password or access token", - "type": "string" - } - } - }, - "ssh_server": { - "description": "Run openssh server in the container", - "type": "boolean", - "default": false - }, - "build": { - "description": "Build image to accelerate start", - "type": "string", - "default": "use-build", - "enum": [ - "use-build", - "build", - "force-build", - "build-only" - ] - }, - "bash": { - "type": "object", - "required": [ - "commands" - ], - "additionalProperties": false, - "patternProperties": { - "^name$": { - }, - "^help$": { - }, - "^deps$": { - } - }, - "properties": { - "provider": { - "enum": [ - "bash" - ] - }, - "ports": { - "$ref": "#/definitions/ports" - }, - "commands": { - "$ref": "#/definitions/commands" - }, - "env": { - "$ref": "#/definitions/env" - }, - "python": { - "$ref": "#/definitions/python" - }, - "working_dir": { - "$ref": "#/definitions/working_dir" - }, - "artifacts": { - "$ref": "#/definitions/artifacts" - }, - "resources": { - "$ref": "#/definitions/resources" - }, - "ssh": { - "$ref": "#/definitions/ssh_server" - }, - "cache": { - "$ref": "#/definitions/cache" - }, - "build": { - "$ref": "#/definitions/build" - }, - "setup": { - "$ref": "#/definitions/setup" - } - } - }, - "docker": { - "type": "object", - "required": [ - "image" - ], - "additionalProperties": false, - "patternProperties": { - "^name$": { - }, - "^help$": { - }, - "^deps$": { - } - }, - "properties": { - "provider": { - "enum": [ - "docker" - ] - }, - "ports": { - "$ref": "#/definitions/ports" - }, - "commands": { - "$ref": "#/definitions/commands" - }, - "entrypoint": { - "description": "The Docker entrypoint", - "type": "string" - }, - "env": { - "$ref": "#/definitions/env" - }, - "image": { - "description": "The name of the Docker image to run", - "type": "string", - "minLength": 1 - }, - "registry_auth": { - "$ref": "#/definitions/registry_auth" - }, - "home_dir": { - "$ref": "#/definitions/home_dir" - }, - "working_dir": { - "$ref": "#/definitions/working_dir" - }, - "artifacts": { - "$ref": "#/definitions/artifacts" - }, - "resources": { - "$ref": "#/definitions/resources" - }, - "cache": { - "$ref": "#/definitions/cache" - }, - 
"build": { - "$ref": "#/definitions/build" - }, - "setup": { - "$ref": "#/definitions/setup" - } - } - }, - "code": { - "type": "object", - "additionalProperties": false, - "patternProperties": { - "^name$": { - }, - "^help$": { - }, - "^deps$": { - } - }, - "properties": { - "provider": { - "enum": [ - "code" - ] - }, - "version": { - "description": "The version of openvscode-server", - "anyOf": [ - { - "type": "string", - "pattern": "^\\d+(\\.\\d+)+$" - }, - { - "type": "number" - } - ] - }, - "setup": { - "$ref": "#/definitions/setup" - }, - "ports": { - "$ref": "#/definitions/ports" - }, - "env": { - "$ref": "#/definitions/env" - }, - "python": { - "$ref": "#/definitions/python" - }, - "working_dir": { - "$ref": "#/definitions/working_dir" - }, - "artifacts": { - "$ref": "#/definitions/artifacts" - }, - "resources": { - "$ref": "#/definitions/resources" - }, - "ssh": { - "$ref": "#/definitions/ssh_server" - }, - "cache": { - "$ref": "#/definitions/cache" - }, - "build": { - "$ref": "#/definitions/build" - } - } - }, - "lab": { - "type": "object", - "additionalProperties": false, - "patternProperties": { - "^name$": { - }, - "^help$": { - }, - "^deps$": { - } - }, - "properties": { - "provider": { - "enum": [ - "lab" - ] - }, - "setup": { - "$ref": "#/definitions/setup" - }, - "ports": { - "$ref": "#/definitions/ports" - }, - "env": { - "$ref": "#/definitions/env" - }, - "python": { - "$ref": "#/definitions/python" - }, - "working_dir": { - "$ref": "#/definitions/working_dir" - }, - "artifacts": { - "$ref": "#/definitions/artifacts" - }, - "resources": { - "$ref": "#/definitions/resources" - }, - "ssh": { - "$ref": "#/definitions/ssh_server" - }, - "cache": { - "$ref": "#/definitions/cache" - }, - "build": { - "$ref": "#/definitions/build" - } - } - }, - "notebook": { - "type": "object", - "additionalProperties": false, - "patternProperties": { - "^name$": { - }, - "^help$": { - }, - "^deps$": { - } - }, - "properties": { - "provider": { - "enum": [ - "notebook" - ] - }, - "setup": { - "$ref": "#/definitions/setup" - }, - "ports": { - "$ref": "#/definitions/ports" - }, - "env": { - "$ref": "#/definitions/env" - }, - "python": { - "$ref": "#/definitions/python" - }, - "working_dir": { - "$ref": "#/definitions/working_dir" - }, - "artifacts": { - "$ref": "#/definitions/artifacts" - }, - "resources": { - "$ref": "#/definitions/resources" - }, - "ssh": { - "$ref": "#/definitions/ssh_server" - }, - "cache": { - "$ref": "#/definitions/cache" - }, - "build": { - "$ref": "#/definitions/build" - } - } - }, - "ssh_provider": { - "type": "object", - "additionalProperties": false, - "patternProperties": { - "^name$": { - }, - "^help$": { - }, - "^deps$": { - } - }, - "properties": { - "provider": { - "enum": [ - "ssh" - ] - }, - "ports": { - "$ref": "#/definitions/ports" - }, - "setup": { - "$ref": "#/definitions/setup" - }, - "env": { - "$ref": "#/definitions/env" - }, - "python": { - "$ref": "#/definitions/python" - }, - "working_dir": { - "$ref": "#/definitions/working_dir" - }, - "artifacts": { - "$ref": "#/definitions/artifacts" - }, - "resources": { - "$ref": "#/definitions/resources" - }, - "cache": { - "$ref": "#/definitions/cache" - }, - "code": { - "description": "Print VS Code connection URI", - "type": "boolean", - "default": false - }, - "build": { - "$ref": "#/definitions/build" - } - } - } - }, - "properties": { - "workflows": { - "type": "array", - "minItems": 1, - "items": { - "type": "object", - "required": [ - "name" - ], - "properties": { - "name": { - "type": "string", - 
"minLength": 1 - }, - "help": { - "type": "string", - "minLength": 1 - }, - "deps": { - "description": "The dependencies on other workflows or tags", - "type": "array", - "minItems": 1, - "items": { - "oneOf": [ - { - "type": "object", - "additionalProperties": false, - "properties": { - "workflow": { - "description": "The name of the workflow", - "type": "string", - "minLength": 1 - } - } - }, - { - "type": "object", - "additionalProperties": false, - "properties": { - "tag": { - "description": "The name of the tag", - "type": "string", - "minLength": 1 - } - } - } - ] - } - } - }, - "anyOf": [ - { - "$ref": "#/definitions/bash" - }, - { - "$ref": "#/definitions/docker" - }, - { - "$ref": "#/definitions/code" - }, - { - "$ref": "#/definitions/lab" - }, - { - "$ref": "#/definitions/notebook" - }, - { - "$ref": "#/definitions/ssh_provider" - } - ] - } - } - }, - "required": [ - "workflows" - ], - "type": "object" -} diff --git a/cli/dstack/_internal/utils/workflows.py b/cli/dstack/_internal/utils/workflows.py deleted file mode 100644 index 92f9cc0c3..000000000 --- a/cli/dstack/_internal/utils/workflows.py +++ /dev/null @@ -1,49 +0,0 @@ -import json -import logging -from itertools import groupby -from pathlib import Path - -import jsonschema -import pkg_resources -import yaml - -from dstack._internal.utils.common import PathLike - -logger = logging.getLogger(__name__) - - -def load_workflows(dstack_dir: PathLike, skip_validation_errors: bool = False) -> dict: - dstack_dir = Path(dstack_dir) - files = [] - for pathname in [dstack_dir / "workflows.yaml", dstack_dir / "workflows.yml"]: - if pathname.is_file(): - files.append(pathname) - for pathname in dstack_dir.glob("workflows/*"): - if pathname.suffix not in {".yaml", ".yml"} or not pathname.is_file(): - continue - files.append(pathname) - schema = json.loads( - pkg_resources.resource_string("dstack._internal", "schemas/workflows.json") - ) - - workflows = [] - for file in files: - with file.open("r") as f: - content = yaml.load(f, yaml.FullLoader) - try: - jsonschema.validate(content, schema) - except jsonschema.ValidationError: - logger.warning(f"Workflows validation error: {file}") - if not skip_validation_errors: - raise - continue - workflows.extend(content["workflows"] or []) - - workflows_dict = {} - workflows.sort(key=lambda item: item["name"]) - for name, group in groupby(workflows, key=lambda item: item["name"]): - group = list(group) - if len(group) > 1: - raise NameError(f"{len(group)} workflows with the same name `{name}`") - workflows_dict[name] = group[0] - return workflows_dict From 9383415813c22cb9da9adcb350e2fb7171785d67 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Tue, 11 Jul 2023 10:59:18 +0400 Subject: [PATCH 20/26] Clean up utils.common --- cli/dstack/_internal/core/app.py | 2 - cli/dstack/_internal/utils/common.py | 42 +-------------------- cli/tests/utils/test_merge_workflow_data.py | 35 ----------------- 3 files changed, 1 insertion(+), 78 deletions(-) delete mode 100644 cli/tests/utils/test_merge_workflow_data.py diff --git a/cli/dstack/_internal/core/app.py b/cli/dstack/_internal/core/app.py index 7c2988af2..b1ba96e0e 100644 --- a/cli/dstack/_internal/core/app.py +++ b/cli/dstack/_internal/core/app.py @@ -2,8 +2,6 @@ from pydantic import BaseModel -from dstack._internal.utils.common import _quoted - class AppSpec(BaseModel): port: int diff --git a/cli/dstack/_internal/utils/common.py b/cli/dstack/_internal/utils/common.py index d98f0c550..e31c44758 100644 --- a/cli/dstack/_internal/utils/common.py +++ 
b/cli/dstack/_internal/utils/common.py @@ -1,10 +1,9 @@ -import copy import os import re import time from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Union PathLike = Union[str, os.PathLike] @@ -13,20 +12,6 @@ def get_dstack_dir() -> Path: return Path.joinpath(Path.home(), ".dstack") -def _quoted(s: Optional[str]) -> str: - if s: - return f'"{s}"' - else: - return "None" - - -def _quoted_masked(s: Optional[str]) -> str: - if s: - return f"\"{'*' * len(s)}\"" - else: - return "None" - - def pretty_date(time: Any = False): """ Get a datetime object or a int() Epoch timestamp and return a @@ -137,28 +122,3 @@ def timestamps_in_milliseconds_to_datetime(ts: int) -> datetime: def datetime_to_timestamp_in_milliseconds(dt: datetime) -> int: milliseconds = dt.microsecond // 1000 return int(dt.timestamp()) * 1000 + milliseconds - - -def format_list(items: Optional[list], *, formatter=str) -> Optional[str]: - if items is None: - return None - return "[{}]".format(", ".join(formatter(item) for item in items)) - - -def merge_workflow_data( - data: Dict[str, Any], override: Optional[Dict[str, Any]] -) -> Dict[str, Any]: - override = override or {} - result = {} - for key in data.keys() | override.keys(): - if key not in override: - result[key] = copy.deepcopy(data[key]) - elif key not in data: - result[key] = copy.deepcopy(override[key]) - else: - a, b = data[key], override[key] - if isinstance(a, dict) and isinstance(b, dict): - result[key] = merge_workflow_data(a, b) - else: - result[key] = copy.deepcopy(b) - return result diff --git a/cli/tests/utils/test_merge_workflow_data.py b/cli/tests/utils/test_merge_workflow_data.py deleted file mode 100644 index b6ac91159..000000000 --- a/cli/tests/utils/test_merge_workflow_data.py +++ /dev/null @@ -1,35 +0,0 @@ -from dstack._internal.utils.common import merge_workflow_data - - -def test_none_override(): - data = {"foo": "aaa", "bar": None, "buzz": 1.2} - r = merge_workflow_data(data, None) - assert r == data - - -def test_join(): - r = merge_workflow_data({"foo": "aaa"}, {"bar": 1.2}) - assert r == {"foo": "aaa", "bar": 1.2} - - -def test_plain_override(): - r = merge_workflow_data({"foo": "aaa", "bar": 123456}, {"bar": 1.2}) - assert r == {"foo": "aaa", "bar": 1.2} - - -def test_deep_override(): - r = merge_workflow_data( - {"foo": "aaa", "gpu": {"name": "V100", "count": 1}}, {"gpu": {"name": "A100"}} - ) - assert r == {"foo": "aaa", "gpu": {"name": "A100", "count": 1}} - - -def test_no_mutations(): - data = {"foo": {"bar": "123"}} - override = {"buzz": 123, "foo": {"addon": True}} - r = merge_workflow_data(data, override) - assert r == {"foo": {"bar": "123", "addon": True}, "buzz": 123} - r["buzz"] = 567 - r["foo"]["bar"] = None - assert data == {"foo": {"bar": "123"}} - assert override == {"buzz": 123, "foo": {"addon": True}} From 9ff4103b2f2f6dddc08546507ea711792f0ce887 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Tue, 11 Jul 2023 13:41:05 +0400 Subject: [PATCH 21/26] Annotate configuration models --- cli/dstack/_internal/core/configuration.py | 62 +++++++++++++++------- 1 file changed, 44 insertions(+), 18 deletions(-) diff --git a/cli/dstack/_internal/core/configuration.py b/cli/dstack/_internal/core/configuration.py index 2ee05d115..5caac4ee4 100644 --- a/cli/dstack/_internal/core/configuration.py +++ b/cli/dstack/_internal/core/configuration.py @@ -1,9 +1,11 @@ from enum import Enum from typing import Dict, List, Optional, Union -from pydantic 
import BaseModel, Extra, Field, validator +from pydantic import BaseModel, Extra, Field, conint, constr, validator from typing_extensions import Annotated, Literal +CommandsList = List[str] + class PythonVersion(str, Enum): PY37 = "3.7" @@ -19,26 +21,47 @@ class Config: class RegistryAuth(ForbidExtra): - username: Optional[str] - password: str + username: Annotated[Optional[str], Field(description="Username")] + password: Annotated[str, Field(description="Password or access token")] class Artifact(ForbidExtra): - path: str - mount: bool = False + path: Annotated[ + str, Field(description="The path to the folder that must be stored as an output artifact") + ] + mount: Annotated[ + bool, + Field( + description="Must be set to `true` if the artifact files must be saved in real-time" + ), + ] = False class BaseConfiguration(ForbidExtra): type: Literal["none"] - image: Optional[str] - entrypoint: Optional[str] - home_dir: str = "/root" - registry_auth: Optional[RegistryAuth] - python: Optional[PythonVersion] - ports: List[Union[str, int]] = [] - env: Dict[str, str] = {} - build: List[str] = [] - cache: List[str] = [] + image: Annotated[Optional[str], Field(description="The name of the Docker image to run")] + entrypoint: Annotated[Optional[str], Field(description="The Docker entrypoint")] + home_dir: Annotated[ + str, Field(description="The absolute path to the home directory inside the container") + ] = "/root" + registry_auth: Annotated[ + Optional[RegistryAuth], Field(description="Credentials for pulling a private container") + ] + python: Annotated[ + Optional[PythonVersion], + Field(description="The major version of Python\nMutually exclusive with the image"), + ] + ports: Annotated[ + List[Union[constr(regex=r"^\d+:\d+$"), conint(gt=0, le=65536)]], + Field(description="Port numbers/mapping to expose"), + ] = [] + env: Annotated[Dict[str, str], Field(description="The list of environment variables")] = {} + build: Annotated[ + CommandsList, Field(description="The bash commands to run during build stage") + ] = [] + cache: Annotated[ + List[str], Field(description="The directories to be cached between configuration runs") + ] = [] @validator("python", pre=True, always=True) def convert_python(cls, v, values) -> Optional[PythonVersion]: @@ -55,14 +78,14 @@ def convert_python(cls, v, values) -> Optional[PythonVersion]: class DevEnvironmentConfiguration(BaseConfiguration): type: Literal["dev-environment"] = "dev-environment" - ide: Literal["vscode"] - init: List[str] = [] + ide: Annotated[Literal["vscode"], Field(description="The IDE to run")] + init: Annotated[CommandsList, Field(description="The bash commands to run")] = [] class TaskConfiguration(BaseConfiguration): type: Literal["task"] = "task" - commands: List[str] - artifacts: List[Artifact] = [] + commands: Annotated[CommandsList, Field(description="The bash commands to run")] + artifacts: Annotated[List[Artifact], Field(description="The list of output artifacts")] = [] class DstackConfiguration(BaseModel): @@ -70,6 +93,9 @@ class DstackConfiguration(BaseModel): Union[DevEnvironmentConfiguration, TaskConfiguration], Field(discriminator="type") ] + class Config: + schema_extra = {"$schema": "http://json-schema.org/draft-07/schema#"} + def parse(data: dict) -> Union[DevEnvironmentConfiguration, TaskConfiguration]: return DstackConfiguration.parse_obj(data).__root__ From 3d7cd1673f3bcbdb919b12c03705170ad526489c Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Tue, 11 Jul 2023 13:59:50 +0400 Subject: [PATCH 22/26] Generate json schema on 
build --- .github/workflows/build.yml | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1187a8a72..19c98a7fa 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -217,6 +217,26 @@ jobs: aws s3 cp dstack-runner-${{ matrix.goos }}-${{ matrix.goarch }} "s3://dstack-runner-downloads-stgn/$VERSION/binaries/dstack-runner-${{ matrix.goos }}-${{ matrix.platform }}${{ matrix.extension }}" --acl public-read aws s3 cp dstack-runner-${{ matrix.goos }}-${{ matrix.goarch }} "s3://dstack-runner-downloads-stgn/latest/binaries/dstack-runner-${{ matrix.goos }}-${{ matrix.platform }}${{ matrix.extension }}" --acl public-read + generate-json-schema: + needs: [ cli-test-master ] + env: + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install AWS + run: pip install awscli + - name: Install dstack + run: pip install . + - name: Generate json schema + run: python -c "from dstack._internal.core.configuration import DstackConfiguration; print(DstackConfiguration.schema_json(indent=2))" > configuration.json + - name: Upload json schema to S3 + run: | + VERSION=$((${{ github.run_number }} + 150)) + aws s3 cp configuration.json "s3://dstack-runner-downloads-stgn/$VERSION/schemas/configuration.json" --acl public-read + aws s3 cp configuration.json "s3://dstack-runner-downloads-stgn/latest/schemas/configuration.json" --acl public-read + # cli-integration-tests: # needs: [ runner-upload-master ] # runs-on: ${{ matrix.os }} From 204606a5658aeb94ea6ffbd52431af5db3a8fd3c Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Tue, 11 Jul 2023 14:16:45 +0400 Subject: [PATCH 23/26] Generate json schema on release --- .github/workflows/release.yml | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 24ffee935..8c404c7f0 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -344,4 +344,24 @@ jobs: run: pip install awscli - run: | VERSION=${GITHUB_REF#refs/tags/} - echo $VERSION | aws s3 cp - s3://get-dstack/cli/latest-version --acl public-read \ No newline at end of file + echo $VERSION | aws s3 cp - s3://get-dstack/cli/latest-version --acl public-read + + generate-json-schema: + needs: [ cli-test-tag ] + env: + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install AWS + run: pip install awscli + - name: Install dstack + run: pip install . 
+ - name: Generate json schema + run: python -c "from dstack._internal.core.configuration import DstackConfiguration; print(DstackConfiguration.schema_json(indent=2))" > configuration.json + - name: Upload json schema to S3 + run: | + VERSION=${GITHUB_REF#refs/tags/} + aws s3 cp configuration.json "s3://dstack-runner-downloads/$VERSION/schemas/configuration.json" --acl public-read + aws s3 cp configuration.json "s3://dstack-runner-downloads/latest/schemas/configuration.json" --acl public-read From 4a645749b54df90b583f7985d05fd1b609267aa2 Mon Sep 17 00:00:00 2001 From: Egor Sklyarov Date: Tue, 11 Jul 2023 17:57:50 +0400 Subject: [PATCH 24/26] Update run & configuration docs --- docs/docs/reference/cli/run.md | 10 +++++----- docs/docs/reference/dstack.yml.md | 26 ++++++++++++++------------ 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/docs/docs/reference/cli/run.md b/docs/docs/reference/cli/run.md index b012afa04..9edb47f25 100644 --- a/docs/docs/reference/cli/run.md +++ b/docs/docs/reference/cli/run.md @@ -20,7 +20,6 @@ Options: --profile PROFILE The name of the profile -d, --detach Do not poll logs and run status --reload Enable auto-reload - -t, --tag TAG A tag name. Warning, if the tag exists, it will be overridden. ``` @@ -37,12 +36,13 @@ The following arguments are optional: - `-f FILE`, `--f FILE` – (Optional) The path to the run configuration file. Defaults to `WORKING_DIR/.dstack.yml`. - `--project PROJECT` – (Optional) The name of the project -- `--project PROJECT` – (Optional) The name of the profile +- `--profile PROJECT` – (Optional) The name of the profile - `--reload` – (Optional) Enable auto-reload - `-d`, `--detach` – (Optional) Run in the detached mode. Means, the command doesn't poll logs and run status. -- `-p PORT [PORT ...]`, `--port PORT [PORT ...]` – (Optional) Requests ports or define mappings for them (`APP_PORT:LOCAL_PORT`) -- `-t TAG`, `--tag TAG` – (Optional) A tag name. Warning, if the tag exists, it will be overridden. + +[//]: # (- `-p PORT [PORT ...]`, `--port PORT [PORT ...]` – (Optional) Requests ports or define mappings for them (`APP_PORT:LOCAL_PORT`)) +[//]: # (- `-t TAG`, `--tag TAG` – (Optional) A tag name. Warning, if the tag exists, it will be overridden.) - `ARGS` – (Optional) Use `ARGS` to pass custom run arguments Spot policy (the arguments are mutually exclusive): @@ -69,4 +69,4 @@ Build policies: [//]: # (Tags should be dropped) !!! info "NOTE:" - By default, it runs it in the attached mode, so you'll see the output in real-time. \ No newline at end of file + By default, it runs it in the attached mode, so you'll see the output in real-time. diff --git a/docs/docs/reference/dstack.yml.md b/docs/docs/reference/dstack.yml.md index e7a6fc15b..bf47b9423 100644 --- a/docs/docs/reference/dstack.yml.md +++ b/docs/docs/reference/dstack.yml.md @@ -13,22 +13,24 @@ types: `dev-environment` and `task`. Below is a full reference of all available properties. - `type` - (Required) The type of the configurations. Can be `dev-environment` or `task`. +- `image` - (Optional) The name of the Docker image. +- `entrypoint` - (Optional) The Docker entrypoint. - `build` - (Optional) The list of bash commands to build the environment. - `ide` - (Required if `type` is `dev-environment`). Can be `vscode`. -- `ports` - (Optional) The list of port numbers to expose -- `env` - (Optional) The list of environment variables (e.g. 
`PYTHONPATH=src`) - -[//]: # (- `image` - (Optional) The name of the Docker image (as an alternative or an addition to `setup`)) -- `registry_auth` - (Optional) Credentials to pull the private Docker image - - `username` - (Required) Username - - `password` - (Required) Password or access token -- `init` - (Optional, only for `dev-environment` type) The list of bash commands to execute on each run -- `commands` - (Required if `type` is `task`). The list of bash commands to run as a task -- `python` - (Optional) The major version of Python to pre-install (e.g., `"3.11""`). Defaults to the current version installed locally. -- `cache` - (Optional) The directories to be cached between runs +- `ports` - (Optional) The list of port numbers to expose. +- `env` - (Optional) The mapping or the list of environment variables (e.g. `PYTHONPATH: src` or `PYTHONPATH=src`). +- `registry_auth` - (Optional) Credentials to pull the private Docker image. + - `username` - (Required) Username. + - `password` - (Required) Password or access token. +- `init` - (Optional, only for `dev-environment` type) The list of bash commands to execute on each run. +- `commands` - (Required if `type` is `task`). The list of bash commands to run as a task. +- `python` - (Optional) The major version of Python to pre-install (e.g., `"3.11"`). Defaults to the current version installed locally. Mutually exclusive with `image`. +- `cache` - (Optional) The directories to be cached between runs. + +[//]: # (- `home_dir` - (Optional) The absolute path to the home directory inside the container) [//]: # (TODO: `artifacts` aren't documented) [//]: # (TODO: Add examples) -[//]: # (TODO: Mention here or somewhere else of how it works. What base image is used, how ports are forwarded, etc.) \ No newline at end of file +[//]: # (TODO: Mention here or somewhere else of how it works. What base image is used, how ports are forwarded, etc.) 
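For reference, the properties documented above map one-to-one onto the pydantic models introduced in PATCH 21/26. A minimal sketch of that round trip (the dict below is a hypothetical example, not taken from the dstack docs; `parse` and `TaskConfiguration` are the helpers defined in cli/dstack/_internal/core/configuration.py):

    from dstack._internal.core.configuration import TaskConfiguration, parse

    # A hypothetical `type: task` configuration, written as the dict that
    # yaml.safe_load would produce from a .dstack.yml file:
    data = {
        "type": "task",
        "python": "3.11",
        "env": {"PYTHONPATH": "src"},
        "ports": [8080],
        "commands": ["pip install -r requirements.txt", "python train.py"],
    }

    conf = parse(data)
    assert isinstance(conf, TaskConfiguration)
    assert conf.python.value == "3.11"  # the validator coerces "3.11" to the PythonVersion enum

The discriminated union on `type` is what routes this dict to `TaskConfiguration` rather than `DevEnvironmentConfiguration`.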
From 32b912551cb142f744582fa90f9bcbf90ec05d75 Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Tue, 11 Jul 2023 18:43:22 +0400
Subject: [PATCH 25/26] Support env variables as a list

---
 cli/dstack/_internal/core/configuration.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/cli/dstack/_internal/core/configuration.py b/cli/dstack/_internal/core/configuration.py
index 5caac4ee4..f78a5a7f2 100644
--- a/cli/dstack/_internal/core/configuration.py
+++ b/cli/dstack/_internal/core/configuration.py
@@ -52,10 +52,13 @@ class BaseConfiguration(ForbidExtra):
         Field(description="The major version of Python\nMutually exclusive with the image"),
     ]
     ports: Annotated[
-        List[Union[constr(regex=r"^\d+:\d+$"), conint(gt=0, le=65536)]],
+        List[Union[constr(regex=r"^[0-9]+:[0-9]+$"), conint(gt=0, le=65536)]],
         Field(description="Port numbers/mapping to expose"),
     ] = []
-    env: Annotated[Dict[str, str], Field(description="The list of environment variables")] = {}
+    env: Annotated[
+        Union[List[constr(regex=r"^[a-zA-Z_][a-zA-Z0-9_]*=.*$")], Dict[str, str]],
+        Field(description="The mapping or the list of environment variables"),
+    ] = {}
     build: Annotated[
         CommandsList, Field(description="The bash commands to run during build stage")
     ] = []
@@ -75,6 +78,12 @@ class BaseConfiguration(ForbidExtra):
             return PythonVersion(v)
         return v
 
+    @validator("env")
+    def convert_env(cls, v) -> Dict[str, str]:
+        if isinstance(v, list):
+            return dict(pair.split(sep="=", maxsplit=1) for pair in v)
+        return v
+
 
 class DevEnvironmentConfiguration(BaseConfiguration):
     type: Literal["dev-environment"] = "dev-environment"
@@ -98,4 +107,5 @@ class Config:
 
 
 def parse(data: dict) -> Union[DevEnvironmentConfiguration, TaskConfiguration]:
-    return DstackConfiguration.parse_obj(data).__root__
+    conf = DstackConfiguration.parse_obj(data).__root__
+    return conf

From 2ee86384c1d85489e0db53fc51dba472d064e1a9 Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Tue, 11 Jul 2023 19:20:31 +0400
Subject: [PATCH 26/26] Explicit build policy args

---
 cli/dstack/_internal/configurators/__init__.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/cli/dstack/_internal/configurators/__init__.py b/cli/dstack/_internal/configurators/__init__.py
index ea4f93d5a..9931dce93 100644
--- a/cli/dstack/_internal/configurators/__init__.py
+++ b/cli/dstack/_internal/configurators/__init__.py
@@ -61,10 +61,15 @@ def get_parser(
         retry_group.add_argument("--retry-limit", type=str)
 
         build_policy = parser.add_mutually_exclusive_group()
-        for value in BuildPolicy:
-            build_policy.add_argument(
-                f"--{value}", action="store_const", dest="build_policy", const=value
-            )
+        build_policy.add_argument(
+            "--build", action="store_const", dest="build_policy", const=BuildPolicy.BUILD
+        )
+        build_policy.add_argument(
+            "--force-build",
+            action="store_const",
+            dest="build_policy",
+            const=BuildPolicy.FORCE_BUILD,
+        )
 
         return parser
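With PATCH 25/26 applied, `env` accepts either a mapping or a list of `KEY=VALUE` strings; a quick sketch of the conversion, under the same assumptions as the example above:

    from dstack._internal.core.configuration import parse

    conf = parse(
        {
            "type": "dev-environment",
            "ide": "vscode",
            # A hypothetical env list; convert_env turns it into a dict:
            "env": ["PYTHONPATH=src", "TOKEN=abc=def"],
        }
    )
    # The validator splits each entry on the first "=" only, so values may
    # themselves contain "=":
    assert conf.env == {"PYTHONPATH": "src", "TOKEN": "abc=def"}

PATCH 26/26 then replaces the enum-driven loop with two explicit flags, so only `--build` and `--force-build` are exposed in the mutually exclusive group; any other `BuildPolicy` members are no longer reachable from the CLI.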