Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cli/dstack/_internal/backend/aws/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ def secrets_manager(self) -> AWSSecretsManager:
def logging(self) -> AWSLogging:
return self._logging

def create_run(self, repo_id: str) -> str:
def create_run(self, repo_id: str, run_name: Optional[str]) -> str:
self._logging.create_log_groups_if_not_exist(
aws_utils.get_logs_client(self._session), self.backend_config.bucket_name, repo_id
)
return base_runs.create_run(self._storage)
return base_runs.create_run(self._storage, run_name)

def _check_credentials(self):
try:
Expand Down
7 changes: 5 additions & 2 deletions cli/dstack/_internal/backend/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,14 @@ def predict_instance_type(self, job: Job) -> Optional[InstanceType]:
pass

@abstractmethod
def create_run(self, repo_id: str) -> str:
def create_run(self, repo_id: str, run_name: Optional[str]) -> str:
pass

@abstractmethod
def create_job(self, job: Job):
def create_job(
self,
job: Job,
):
pass

def submit_job(self, job: Job, failed_to_start_job_new_status: JobStatus = JobStatus.FAILED):
Expand Down
1 change: 0 additions & 1 deletion cli/dstack/_internal/backend/base/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from dstack._internal.core.error import BackendError, BackendValueError, NoMatchingInstanceError
from dstack._internal.core.instance import InstanceType
from dstack._internal.core.job import (
ConfigurationType,
Job,
JobErrorCode,
JobHead,
Expand Down
4 changes: 2 additions & 2 deletions cli/dstack/_internal/backend/lambdalabs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def secrets_manager(self) -> AWSSecretsManager:
def logging(self) -> AWSLogging:
return self._logging

def create_run(self, repo_id: str) -> str:
def create_run(self, repo_id: str, run_name: Optional[str]) -> str:
self._logging.create_log_groups_if_not_exist(
aws_utils.get_logs_client(self._session),
self.backend_config.storage_config.bucket,
repo_id,
)
return base_runs.create_run(self._storage)
return base_runs.create_run(self._storage, run_name)
1 change: 1 addition & 0 deletions cli/dstack/_internal/cli/commands/build/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def _command(self, args: argparse.Namespace):
configurator=configurator,
ssh_key_pub=ssh_key_pub,
run_args=run_args,
run_plan=run_plan,
)
runs = list_runs_hub(hub_client, run_name=run_name)
run = runs[0]
Expand Down
1 change: 1 addition & 0 deletions cli/dstack/_internal/cli/commands/run/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def _command(self, args: Namespace):
ssh_key_pub=ssh_key_pub,
run_name=args.name,
run_args=run_args,
run_plan=run_plan,
)
runs = list_runs_hub(hub_client, run_name=run_name)
run = runs[0]
Expand Down
17 changes: 12 additions & 5 deletions cli/dstack/_internal/configurators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dstack._internal.core.build import BuildPolicy
from dstack._internal.core.configuration import BaseConfiguration, PythonVersion
from dstack._internal.core.error import DstackError
from dstack._internal.core.plan import RunPlan
from dstack._internal.core.profile import Profile, parse_duration, parse_max_duration
from dstack._internal.core.repo import Repo
from dstack._internal.utils.common import get_milliseconds_since_epoch
Expand Down Expand Up @@ -121,7 +122,12 @@ def interpolate(obj):
self.conf = type(self.conf).parse_obj(conf)

def get_jobs(
self, repo: Repo, run_name: str, repo_code_filename: str, ssh_key_pub: str
self,
repo: Repo,
run_name: str,
repo_code_filename: str,
ssh_key_pub: str,
run_plan: Optional[RunPlan] = None,
) -> List[job.Job]:
self.run_name = run_name
self.ssh_key_pub = ssh_key_pub
Expand All @@ -138,7 +144,7 @@ def get_jobs(
status=job.JobStatus.SUBMITTED,
created_at=created_at,
submitted_at=created_at,
image_name=self.image_name(),
image_name=self.image_name(run_plan),
registry_auth=self.registry_auth(),
entrypoint=self.entrypoint(),
build_commands=self.build_commands(),
Expand Down Expand Up @@ -201,11 +207,12 @@ def entrypoint(self) -> Optional[List[str]]:
def home_dir(self) -> Optional[str]:
return self.conf.home_dir

def image_name(self) -> str:
def image_name(self, run_plan: Optional[RunPlan]) -> Optional[str]:
if self.conf.image is not None:
return self.conf.image
if self.profile.resources and self.profile.resources.gpu:
return f"dstackai/base:py{self.python()}-{version.base_image}-cuda-11.8"
if run_plan is not None:
if len(run_plan.job_plans[0].instance_type.resources.gpus) > 0:
return f"dstackai/base:py{self.python()}-{version.base_image}-cuda-11.8"
return f"dstackai/base:py{self.python()}-{version.base_image}"

def cache_specs(self) -> List[job.CacheSpec]:
Expand Down
10 changes: 8 additions & 2 deletions cli/dstack/_internal/configurators/dev_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dstack._internal.configurators.extensions.vscode import VSCodeDesktop
from dstack._internal.configurators.ports import get_map_to_port
from dstack._internal.core.configuration import DevEnvironmentConfiguration
from dstack._internal.core.plan import RunPlan
from dstack._internal.core.repo import Repo

DEFAULT_MAX_DURATION_SECONDS = 6 * 3600
Expand All @@ -20,14 +21,19 @@ class DevEnvironmentConfigurator(JobConfigurator):
ide: Optional[IDEExtension]

def get_jobs(
self, repo: Repo, run_name: str, repo_code_filename: str, ssh_key_pub: str
self,
repo: Repo,
run_name: str,
repo_code_filename: str,
ssh_key_pub: str,
run_plan: Optional[RunPlan] = None,
) -> List[job.Job]:
self.ide = VSCodeDesktop(
extensions=["ms-python.python", "ms-toolsai.jupyter"], run_name=run_name
)
self.sshd = SSHd(ssh_key_pub)
self.sshd.map_to_port = get_map_to_port(self.ports(), self.sshd.port)
return super().get_jobs(repo, run_name, repo_code_filename, ssh_key_pub)
return super().get_jobs(repo, run_name, repo_code_filename, ssh_key_pub, run_plan)

def build_commands(self) -> List[str]:
if len(self.conf.build) == 0:
Expand Down
10 changes: 8 additions & 2 deletions cli/dstack/_internal/configurators/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dstack._internal.configurators.ports import get_map_to_port
from dstack._internal.core import job as job
from dstack._internal.core.configuration import TaskConfiguration
from dstack._internal.core.plan import RunPlan
from dstack._internal.core.repo import Repo

DEFAULT_MAX_DURATION_SECONDS = 72 * 3600
Expand All @@ -15,11 +16,16 @@ class TaskConfigurator(JobConfigurator):
sshd: Optional[SSHd]

def get_jobs(
self, repo: Repo, run_name: str, repo_code_filename: str, ssh_key_pub: str
self,
repo: Repo,
run_name: str,
repo_code_filename: str,
ssh_key_pub: str,
run_plan: Optional[RunPlan] = None,
) -> List[job.Job]:
self.sshd = SSHd(ssh_key_pub)
self.sshd.map_to_port = get_map_to_port(self.ports(), self.sshd.port)
return super().get_jobs(repo, run_name, repo_code_filename, ssh_key_pub)
return super().get_jobs(repo, run_name, repo_code_filename, ssh_key_pub, run_plan)

def optional_build_commands(self) -> List[str]:
return [] # not needed
Expand Down
4 changes: 4 additions & 0 deletions cli/dstack/_internal/core/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ class ProfileGPU(ForbidExtra):
]
_validate_mem = validator("memory", pre=True, allow_reuse=True)(parse_memory)

@validator("name")
def _validate_name(name: str):
return name.upper()


class ProfileResources(ForbidExtra):
gpu: Optional[Union[int, ProfileGPU]]
Expand Down
2 changes: 2 additions & 0 deletions cli/dstack/api/hub/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def run_configuration(
ssh_key_pub: str,
run_name: Optional[str] = None,
run_args: Optional[List[str]] = None,
run_plan: Optional[RunPlan] = None,
) -> Tuple[str, List[Job]]:
run_name = self.create_run(run_name)
configurator = copy.deepcopy(configurator)
Expand All @@ -295,6 +296,7 @@ def run_configuration(
run_name=run_name,
repo_code_filename=repo_code_filename,
ssh_key_pub=ssh_key_pub,
run_plan=run_plan,
)
for job in jobs:
self.submit_job(job)
Expand Down