diff --git a/cli/dstack/_internal/backend/aws/__init__.py b/cli/dstack/_internal/backend/aws/__init__.py index 34d2c05bd..f54cb1b05 100644 --- a/cli/dstack/_internal/backend/aws/__init__.py +++ b/cli/dstack/_internal/backend/aws/__init__.py @@ -71,11 +71,11 @@ def secrets_manager(self) -> AWSSecretsManager: def logging(self) -> AWSLogging: return self._logging - def create_run(self, repo_id: str) -> str: + def create_run(self, repo_id: str, run_name: Optional[str]) -> str: self._logging.create_log_groups_if_not_exist( aws_utils.get_logs_client(self._session), self.backend_config.bucket_name, repo_id ) - return base_runs.create_run(self._storage) + return base_runs.create_run(self._storage, run_name) def _check_credentials(self): try: diff --git a/cli/dstack/_internal/backend/base/__init__.py b/cli/dstack/_internal/backend/base/__init__.py index c2ab08a3a..d67ec8195 100644 --- a/cli/dstack/_internal/backend/base/__init__.py +++ b/cli/dstack/_internal/backend/base/__init__.py @@ -45,11 +45,14 @@ def predict_instance_type(self, job: Job) -> Optional[InstanceType]: pass @abstractmethod - def create_run(self, repo_id: str) -> str: + def create_run(self, repo_id: str, run_name: Optional[str]) -> str: pass @abstractmethod - def create_job(self, job: Job): + def create_job( + self, + job: Job, + ): pass def submit_job(self, job: Job, failed_to_start_job_new_status: JobStatus = JobStatus.FAILED): diff --git a/cli/dstack/_internal/backend/base/jobs.py b/cli/dstack/_internal/backend/base/jobs.py index 219a30d45..1af9049cc 100644 --- a/cli/dstack/_internal/backend/base/jobs.py +++ b/cli/dstack/_internal/backend/base/jobs.py @@ -8,7 +8,6 @@ from dstack._internal.core.error import BackendError, BackendValueError, NoMatchingInstanceError from dstack._internal.core.instance import InstanceType from dstack._internal.core.job import ( - ConfigurationType, Job, JobErrorCode, JobHead, diff --git a/cli/dstack/_internal/backend/lambdalabs/__init__.py b/cli/dstack/_internal/backend/lambdalabs/__init__.py index 2531bc358..247912fdd 100644 --- a/cli/dstack/_internal/backend/lambdalabs/__init__.py +++ b/cli/dstack/_internal/backend/lambdalabs/__init__.py @@ -60,10 +60,10 @@ def secrets_manager(self) -> AWSSecretsManager: def logging(self) -> AWSLogging: return self._logging - def create_run(self, repo_id: str) -> str: + def create_run(self, repo_id: str, run_name: Optional[str]) -> str: self._logging.create_log_groups_if_not_exist( aws_utils.get_logs_client(self._session), self.backend_config.storage_config.bucket, repo_id, ) - return base_runs.create_run(self._storage) + return base_runs.create_run(self._storage, run_name) diff --git a/cli/dstack/_internal/cli/commands/build/__init__.py b/cli/dstack/_internal/cli/commands/build/__init__.py index b52bfb640..1ab3abe8d 100644 --- a/cli/dstack/_internal/cli/commands/build/__init__.py +++ b/cli/dstack/_internal/cli/commands/build/__init__.py @@ -105,6 +105,7 @@ def _command(self, args: argparse.Namespace): configurator=configurator, ssh_key_pub=ssh_key_pub, run_args=run_args, + run_plan=run_plan, ) runs = list_runs_hub(hub_client, run_name=run_name) run = runs[0] diff --git a/cli/dstack/_internal/cli/commands/run/__init__.py b/cli/dstack/_internal/cli/commands/run/__init__.py index fe33f106b..c0111e333 100644 --- a/cli/dstack/_internal/cli/commands/run/__init__.py +++ b/cli/dstack/_internal/cli/commands/run/__init__.py @@ -134,6 +134,7 @@ def _command(self, args: Namespace): ssh_key_pub=ssh_key_pub, run_name=args.name, run_args=run_args, + run_plan=run_plan, ) runs = list_runs_hub(hub_client, run_name=run_name) run = runs[0] diff --git a/cli/dstack/_internal/configurators/__init__.py b/cli/dstack/_internal/configurators/__init__.py index bd1b47c1d..6b7e152e1 100644 --- a/cli/dstack/_internal/configurators/__init__.py +++ b/cli/dstack/_internal/configurators/__init__.py @@ -14,6 +14,7 @@ from dstack._internal.core.build import BuildPolicy from dstack._internal.core.configuration import BaseConfiguration, PythonVersion from dstack._internal.core.error import DstackError +from dstack._internal.core.plan import RunPlan from dstack._internal.core.profile import Profile, parse_duration, parse_max_duration from dstack._internal.core.repo import Repo from dstack._internal.utils.common import get_milliseconds_since_epoch @@ -121,7 +122,12 @@ def interpolate(obj): self.conf = type(self.conf).parse_obj(conf) def get_jobs( - self, repo: Repo, run_name: str, repo_code_filename: str, ssh_key_pub: str + self, + repo: Repo, + run_name: str, + repo_code_filename: str, + ssh_key_pub: str, + run_plan: Optional[RunPlan] = None, ) -> List[job.Job]: self.run_name = run_name self.ssh_key_pub = ssh_key_pub @@ -138,7 +144,7 @@ def get_jobs( status=job.JobStatus.SUBMITTED, created_at=created_at, submitted_at=created_at, - image_name=self.image_name(), + image_name=self.image_name(run_plan), registry_auth=self.registry_auth(), entrypoint=self.entrypoint(), build_commands=self.build_commands(), @@ -201,11 +207,12 @@ def entrypoint(self) -> Optional[List[str]]: def home_dir(self) -> Optional[str]: return self.conf.home_dir - def image_name(self) -> str: + def image_name(self, run_plan: Optional[RunPlan]) -> Optional[str]: if self.conf.image is not None: return self.conf.image - if self.profile.resources and self.profile.resources.gpu: - return f"dstackai/base:py{self.python()}-{version.base_image}-cuda-11.8" + if run_plan is not None: + if len(run_plan.job_plans[0].instance_type.resources.gpus) > 0: + return f"dstackai/base:py{self.python()}-{version.base_image}-cuda-11.8" return f"dstackai/base:py{self.python()}-{version.base_image}" def cache_specs(self) -> List[job.CacheSpec]: diff --git a/cli/dstack/_internal/configurators/dev_environment.py b/cli/dstack/_internal/configurators/dev_environment.py index edab1e2f9..00e89cb9e 100644 --- a/cli/dstack/_internal/configurators/dev_environment.py +++ b/cli/dstack/_internal/configurators/dev_environment.py @@ -7,6 +7,7 @@ from dstack._internal.configurators.extensions.vscode import VSCodeDesktop from dstack._internal.configurators.ports import get_map_to_port from dstack._internal.core.configuration import DevEnvironmentConfiguration +from dstack._internal.core.plan import RunPlan from dstack._internal.core.repo import Repo DEFAULT_MAX_DURATION_SECONDS = 6 * 3600 @@ -20,14 +21,19 @@ class DevEnvironmentConfigurator(JobConfigurator): ide: Optional[IDEExtension] def get_jobs( - self, repo: Repo, run_name: str, repo_code_filename: str, ssh_key_pub: str + self, + repo: Repo, + run_name: str, + repo_code_filename: str, + ssh_key_pub: str, + run_plan: Optional[RunPlan] = None, ) -> List[job.Job]: self.ide = VSCodeDesktop( extensions=["ms-python.python", "ms-toolsai.jupyter"], run_name=run_name ) self.sshd = SSHd(ssh_key_pub) self.sshd.map_to_port = get_map_to_port(self.ports(), self.sshd.port) - return super().get_jobs(repo, run_name, repo_code_filename, ssh_key_pub) + return super().get_jobs(repo, run_name, repo_code_filename, ssh_key_pub, run_plan) def build_commands(self) -> List[str]: if len(self.conf.build) == 0: diff --git a/cli/dstack/_internal/configurators/task.py b/cli/dstack/_internal/configurators/task.py index d33fb4de8..a9da79cd5 100644 --- a/cli/dstack/_internal/configurators/task.py +++ b/cli/dstack/_internal/configurators/task.py @@ -5,6 +5,7 @@ from dstack._internal.configurators.ports import get_map_to_port from dstack._internal.core import job as job from dstack._internal.core.configuration import TaskConfiguration +from dstack._internal.core.plan import RunPlan from dstack._internal.core.repo import Repo DEFAULT_MAX_DURATION_SECONDS = 72 * 3600 @@ -15,11 +16,16 @@ class TaskConfigurator(JobConfigurator): sshd: Optional[SSHd] def get_jobs( - self, repo: Repo, run_name: str, repo_code_filename: str, ssh_key_pub: str + self, + repo: Repo, + run_name: str, + repo_code_filename: str, + ssh_key_pub: str, + run_plan: Optional[RunPlan] = None, ) -> List[job.Job]: self.sshd = SSHd(ssh_key_pub) self.sshd.map_to_port = get_map_to_port(self.ports(), self.sshd.port) - return super().get_jobs(repo, run_name, repo_code_filename, ssh_key_pub) + return super().get_jobs(repo, run_name, repo_code_filename, ssh_key_pub, run_plan) def optional_build_commands(self) -> List[str]: return [] # not needed diff --git a/cli/dstack/_internal/core/profile.py b/cli/dstack/_internal/core/profile.py index f9f719428..7eeb8c101 100644 --- a/cli/dstack/_internal/core/profile.py +++ b/cli/dstack/_internal/core/profile.py @@ -63,6 +63,10 @@ class ProfileGPU(ForbidExtra): ] _validate_mem = validator("memory", pre=True, allow_reuse=True)(parse_memory) + @validator("name") + def _validate_name(name: str): + return name.upper() + class ProfileResources(ForbidExtra): gpu: Optional[Union[int, ProfileGPU]] diff --git a/cli/dstack/api/hub/_client.py b/cli/dstack/api/hub/_client.py index 8318c6f41..33df23687 100644 --- a/cli/dstack/api/hub/_client.py +++ b/cli/dstack/api/hub/_client.py @@ -280,6 +280,7 @@ def run_configuration( ssh_key_pub: str, run_name: Optional[str] = None, run_args: Optional[List[str]] = None, + run_plan: Optional[RunPlan] = None, ) -> Tuple[str, List[Job]]: run_name = self.create_run(run_name) configurator = copy.deepcopy(configurator) @@ -295,6 +296,7 @@ def run_configuration( run_name=run_name, repo_code_filename=repo_code_filename, ssh_key_pub=ssh_key_pub, + run_plan=run_plan, ) for job in jobs: self.submit_job(job)