diff --git a/cli/dstack/_internal/configurators/__init__.py b/cli/dstack/_internal/configurators/__init__.py
index 9931dce93..9448ad033 100644
--- a/cli/dstack/_internal/configurators/__init__.py
+++ b/cli/dstack/_internal/configurators/__init__.py
@@ -14,7 +14,7 @@ from dstack._internal.core.build import BuildPolicy
 from dstack._internal.core.configuration import BaseConfiguration, PythonVersion
 from dstack._internal.core.error import DstackError
-from dstack._internal.core.profile import Profile
+from dstack._internal.core.profile import Profile, parse_duration, parse_max_duration
 from dstack._internal.core.repo import Repo
 from dstack._internal.utils.common import get_milliseconds_since_epoch
 from dstack._internal.utils.interpolator import VariablesInterpolator
@@ -60,6 +60,8 @@ def get_parser(
         retry_group.add_argument("--no-retry", action="store_true")
         retry_group.add_argument("--retry-limit", type=str)

+        parser.add_argument("--max-duration", type=str)
+
         build_policy = parser.add_mutually_exclusive_group()
         build_policy.add_argument(
             "--build", action="store_const", dest="build_policy", const=BuildPolicy.BUILD
@@ -83,11 +85,14 @@ def apply_args(self, args: argparse.Namespace):
             self.profile.retry_policy.retry = False
         elif args.retry_limit:
             self.profile.retry_policy.retry = True
-            self.profile.retry_policy.limit = args.retry_limit
+            self.profile.retry_policy.limit = parse_duration(args.retry_limit)

         if args.build_policy is not None:
             self.build_policy = args.build_policy

+        if args.max_duration:
+            self.profile.max_duration = parse_max_duration(args.max_duration)
+
     def inject_context(
         self, namespaces: Dict[str, Dict[str, str]], skip: Optional[List[str]] = None
     ):
@@ -142,6 +147,7 @@ def get_jobs(
             dep_specs=self.dep_specs(),
             spot_policy=self.spot_policy(),
             retry_policy=self.retry_policy(),
+            max_duration=self.max_duration(),
             build_policy=self.build_policy,
             requirements=self.requirements(),
             ssh_key_pub=ssh_key_pub,
@@ -164,6 +170,10 @@ def artifact_specs(self) -> List[job.ArtifactSpec]:
     def dep_specs(self) -> List[job.DepSpec]:
         pass

+    @abstractmethod
+    def default_max_duration(self) -> int:
+        pass
+
     def build_commands(self) -> List[str]:
         return self.conf.build

@@ -186,12 +196,6 @@ def image_name(self) -> str:
             return f"dstackai/miniforge:py{self.python()}-{version.miniforge_image}-cuda-11.4"
         return f"dstackai/miniforge:py{self.python()}-{version.miniforge_image}"

-    def spot_policy(self) -> job.SpotPolicy:
-        return self.profile.spot_policy or job.SpotPolicy.AUTO
-
-    def retry_policy(self) -> job.RetryPolicy:
-        return job.RetryPolicy.parse_obj(self.profile.retry_policy)
-
     def cache_specs(self) -> List[job.CacheSpec]:
         return [
             job.CacheSpec(path=validate_local_path(path, self.home_dir(), self.working_dir))
@@ -254,6 +258,19 @@ def join_run_args(cls, args: List[str]) -> str:
             (arg if " " not in arg else '"%s"' % arg.replace('"', '\\"')) for arg in args
         )

+    def spot_policy(self) -> job.SpotPolicy:
+        return self.profile.spot_policy or job.SpotPolicy.AUTO
+
+    def retry_policy(self) -> job.RetryPolicy:
+        return job.RetryPolicy.parse_obj(self.profile.retry_policy)
+
+    def max_duration(self) -> Optional[int]:
+        if self.profile.max_duration is None:
+            return self.default_max_duration()
+        if self.profile.max_duration < 0:
+            return None
+        return self.profile.max_duration
+

 def validate_local_path(path: str, home: Optional[str], working_dir: str) -> str:
     if path == "~" or path.startswith("~/"):
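Reviewer note, not part of the patch: a minimal sketch of how the new `max_duration()` resolution above is expected to behave. `resolve_max_duration` is a hypothetical standalone helper mirroring the method; the real code reads `self.profile.max_duration` and the configurator-specific `default_max_duration()`.

    # Hypothetical helper mirroring JobConfigurator.max_duration() above (sketch only).
    from typing import Optional

    def resolve_max_duration(profile_max_duration: Optional[int], default_max_duration: int) -> Optional[int]:
        if profile_max_duration is None:
            return default_max_duration  # nothing in the profile: fall back to the configurator default
        if profile_max_duration < 0:
            return None  # "off" is parsed to -1: run with no duration limit
        return profile_max_duration  # explicit limit in seconds

    assert resolve_max_duration(None, 6 * 3600) == 6 * 3600
    assert resolve_max_duration(-1, 6 * 3600) is None
    assert resolve_max_duration(2 * 3600, 6 * 3600) == 2 * 3600
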
diff --git a/cli/dstack/_internal/configurators/dev_environment.py b/cli/dstack/_internal/configurators/dev_environment.py
index 2725ed4da..8f6109834 100644
--- a/cli/dstack/_internal/configurators/dev_environment.py
+++ b/cli/dstack/_internal/configurators/dev_environment.py
@@ -10,6 +10,8 @@ from dstack._internal.core.configuration import DevEnvironmentConfiguration
 from dstack._internal.core.repo import Repo

+DEFAULT_MAX_DURATION_SECONDS = 6 * 3600
+
 require_sshd = require(["sshd"])
 install_ipykernel = f'(pip install -q --no-cache-dir ipykernel 2> /dev/null) || echo "no pip, ipykernel was not installed"'


@@ -59,16 +61,19 @@ def optional_build_commands(self) -> List[str]:
             commands.append(install_ipykernel)
         return commands

+    def default_max_duration(self) -> int:
+        return DEFAULT_MAX_DURATION_SECONDS
+
     def artifact_specs(self) -> List[job.ArtifactSpec]:
         return []  # not available

     def dep_specs(self) -> List[job.DepSpec]:
         return []  # not available

-    def spot_policy(self) -> job.SpotPolicy:
-        return self.profile.spot_policy or job.SpotPolicy.ONDEMAND
-
     def app_specs(self) -> List[job.AppSpec]:
         specs = super().app_specs()
         self.sshd.add_app(specs)
         return specs
+
+    def spot_policy(self) -> job.SpotPolicy:
+        return self.profile.spot_policy or job.SpotPolicy.ONDEMAND
diff --git a/cli/dstack/_internal/configurators/task.py b/cli/dstack/_internal/configurators/task.py
index 25c26dc00..8609cd975 100644
--- a/cli/dstack/_internal/configurators/task.py
+++ b/cli/dstack/_internal/configurators/task.py
@@ -7,6 +7,8 @@ from dstack._internal.core.configuration import TaskConfiguration
 from dstack._internal.core.repo import Repo

+DEFAULT_MAX_DURATION_SECONDS = 72 * 3600
+

 class TaskConfigurator(JobConfigurator):
     conf: TaskConfiguration

@@ -29,6 +31,9 @@ def commands(self) -> List[str]:
     def optional_build_commands(self) -> List[str]:
         return []  # not needed

+    def default_max_duration(self) -> int:
+        return DEFAULT_MAX_DURATION_SECONDS
+
     def artifact_specs(self) -> List[job.ArtifactSpec]:
         specs = []
         for a in self.conf.artifacts:
diff --git a/cli/dstack/_internal/core/job.py b/cli/dstack/_internal/core/job.py
index 4081542e1..daab51158 100644
--- a/cli/dstack/_internal/core/job.py
+++ b/cli/dstack/_internal/core/job.py
@@ -206,6 +206,7 @@ class Job(JobHead):
     requirements: Optional[Requirements]
     spot_policy: Optional[SpotPolicy]
     retry_policy: Optional[RetryPolicy]
+    max_duration: Optional[int]
     dep_specs: Optional[List[DepSpec]]
     master_job: Optional[JobRef]
     app_specs: Optional[List[AppSpec]]
@@ -287,6 +288,7 @@ def serialize(self) -> dict:
             "host_name": self.host_name or "",
             "spot_policy": self.spot_policy.value if self.spot_policy else None,
             "retry_policy": self.retry_policy.dict() if self.retry_policy else None,
+            "max_duration": self.max_duration or None,
             "requirements": self.requirements.serialize() if self.requirements else {},
             "deps": deps,
             "master_job_id": self.master_job.get_id() if self.master_job else "",
@@ -435,6 +437,9 @@ def unserialize(job_data: dict):
             host_name=job_data.get("host_name") or None,
             spot_policy=SpotPolicy(spot_policy) if spot_policy else None,
             retry_policy=retry_policy,
+            max_duration=int(job_data.get("max_duration"))
+            if job_data.get("max_duration")
+            else None,
             requirements=requirements,
             dep_specs=dep_specs or None,
             master_job=master_job,
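Reviewer note, not part of the patch: a small sketch of the storage convention used by `serialize`/`unserialize` above — a falsy `max_duration` is written as `None` and read back as `None`, so both `0` and `None` mean "no stored limit". The helper names below are hypothetical.

    # Hypothetical helpers mirroring the max_duration handling in Job.serialize/unserialize (sketch only).
    from typing import Optional

    def serialize_max_duration(max_duration: Optional[int]) -> Optional[int]:
        return max_duration or None

    def unserialize_max_duration(job_data: dict) -> Optional[int]:
        return int(job_data.get("max_duration")) if job_data.get("max_duration") else None

    assert unserialize_max_duration({"max_duration": serialize_max_duration(0)}) is None
    assert unserialize_max_duration({"max_duration": serialize_max_duration(7200)}) == 7200
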
diff --git a/cli/dstack/_internal/core/profile.py b/cli/dstack/_internal/core/profile.py
index 377adf3c7..66d494e73 100644
--- a/cli/dstack/_internal/core/profile.py
+++ b/cli/dstack/_internal/core/profile.py
@@ -29,13 +29,13 @@ def parse_memory(v: Optional[Union[int, str]]) -> Optional[int]:
     return int(v)


-def duration(v: Union[int, str]) -> int:
+def parse_duration(v: Union[int, str]) -> int:
     if isinstance(v, int):
         return v
     regex = re.compile(r"(?P<amount>\d+) *(?P<unit>[smhdw])$")
-    re_match = regex.match(duration)
+    re_match = regex.match(v)
     if not re_match:
-        raise ValueError(f"Cannot parse the duration {duration}")
+        raise ValueError(f"Cannot parse the duration {v}")
     amount, unit = int(re_match.group("amount")), re_match.group("unit")
     multiplier = {
         "s": 1,
@@ -47,6 +47,12 @@
     return amount * multiplier


+def parse_max_duration(v: Union[int, str]) -> int:
+    if v == "off":
+        return -1
+    return parse_duration(v)
+
+
 class ProfileGPU(ForbidExtra):
     name: Optional[str]
     count: int = 1
@@ -71,7 +77,7 @@ def _validate_gpu(cls, v: Optional[Union[int, ProfileGPU]]) -> Optional[ProfileG
 class ProfileRetryPolicy(ForbidExtra):
     retry: bool = False
     limit: Union[int, str] = DEFAULT_RETRY_LIMIT
-    _validate_limit = validator("limit", pre=True, allow_reuse=True)(duration)
+    _validate_limit = validator("limit", pre=True, allow_reuse=True)(parse_duration)


 class Profile(ForbidExtra):
@@ -80,7 +86,9 @@
     name: str
     project: str
     resources: ProfileResources = ProfileResources()
     spot_policy: Optional[SpotPolicy]
     retry_policy: ProfileRetryPolicy = ProfileRetryPolicy()
+    max_duration: Optional[Union[int, str]]
     default: bool = False
+    _validate_limit = validator("max_duration", pre=True, allow_reuse=True)(parse_max_duration)


 class ProfilesConfig(ForbidExtra):
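Reviewer note, not part of the patch: a usage sketch for the renamed `parse_duration` and the new `parse_max_duration`, assuming the usual `s`/`m`/`h`/`d`/`w` multipliers (only the `"s": 1` entry is visible in the hunk above).

    # Expected behaviour of the helpers added/renamed in profile.py (sketch only).
    from dstack._internal.core.profile import parse_duration, parse_max_duration

    assert parse_duration(3600) == 3600          # ints pass through unchanged
    assert parse_duration("90s") == 90
    assert parse_duration("4h") == 4 * 3600      # assuming "h" maps to 3600 seconds
    assert parse_duration("1d") == 24 * 3600     # assuming "d" maps to 86400 seconds
    assert parse_max_duration("off") == -1       # sentinel later treated as "no limit"
    assert parse_max_duration("72h") == 72 * 3600
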
diff --git a/docs/docs/reference/profiles.yml.md b/docs/docs/reference/profiles.yml.md
index d3187678c..d1f9de226 100644
--- a/docs/docs/reference/profiles.yml.md
+++ b/docs/docs/reference/profiles.yml.md
@@ -11,7 +11,7 @@ Below is a full reference of all available properties.
 - `resources` - (Optional) The minimum required resources
     - `memory` - (Optional) The minimum size of RAM memory (e.g., `"16GB"`).
     - `gpu` - (Optional) The minimum number of GPUs, their model name and memory
-        - `name` - (Optional) The name of the GPU model (e.g. `"K80"`, `"V100"`, `"A100"`, etc)
+        - `name` - (Optional) The name of the GPU model (e.g., `"K80"`, `"V100"`, `"A100"`, etc)
         - `count` - (Optional) The minimum number of GPUs. Defaults to `1`.
         - `memory` (Optional) The minimum size of GPU memory (e.g., `"16GB"`)
     - `shm_size` (Optional) The size of shared memory (e.g., `"8GB"`). If you are using parallel communicating
@@ -19,7 +19,8 @@
 - `spot_policy` - (Optional) The policy for provisioning spot or on-demand instances: `spot`, `on-demand`, or `auto`. `spot` provisions a spot instance. `on-demand` provisions a on-demand instance. `auto` first tries to provision a spot instance and then tries on-demand if spot is not available. Defaults to `on-demand` for dev environments and to `auto` for tasks.
 - `retry_policy` - (Optional) The policy for re-submitting the run.
     - `retry` - (Optional) Whether to retry the run on failure or not. Default to `false`
-    - `limit` - (Optional) The maximum period of retrying the run, e.g. 4h or 1d. Defaults to 1h if `retry` is `true`.
+    - `limit` - (Optional) The maximum period of retrying the run, e.g., `4h` or `1d`. Defaults to `1h` if `retry` is `true`.
+- `max_duration` - (Optional) The maximum duration of a run (e.g., `2h`, `1d`, etc). After it elapses, the run is forced to stop. Protects from running idle instances. Defaults to `6h` for dev environments and to `72h` for tasks. Use `max_duration: off` to disable maximum run duration.

 [//]: # (TODO: Add examples)

diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go
index fe359ee69..f102a0208 100644
--- a/runner/internal/executor/executor.go
+++ b/runner/internal/executor/executor.go
@@ -6,7 +6,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/dstackai/dstack/runner/internal/backend/base"
 	"io"
 	"os"
 	"path"
@@ -14,6 +13,8 @@ import (
 	"strconv"
 	"time"

+	"github.com/dstackai/dstack/runner/internal/backend/base"
+
 	"github.com/dstackai/dstack/runner/internal/models"

 	"github.com/dustin/go-humanize"
@@ -46,7 +47,6 @@ type Executor struct {
 	artifactsOut  []base.Artifacter
 	artifactsFUSE []base.Artifacter
 	repo          *repo.Manager
-	portID        string
 	streamLogs    *stream.Server
 	stoppedCh     chan struct{}
 }
@@ -142,15 +142,24 @@ func (ex *Executor) Run(ctx context.Context) error {
 	if err != nil {
 		return err
 	}
+	job, err := ex.backend.RefetchJob(runCtx)
+	if err != nil {
+		return gerrors.Wrap(err)
+	}
 	if stopped {
 		log.Info(runCtx, "Stopped")
 		ex.Stop()
 		log.Info(runCtx, "Waiting job end")
 		errRun := <-erCh
-		job, err := ex.backend.RefetchJob(runCtx)
-		if err != nil {
-			return gerrors.Wrap(err)
-		}
 		job.Status = states.Stopped
 		_ = ex.backend.UpdateState(runCtx)
 		return errRun
 	}
+	if job.MaxDurationExceeded() {
+		log.Info(runCtx, "Job max duration exceeded. Stopping...")
+		ex.Stop()
+		log.Info(runCtx, "Waiting job end")
+		errRun := <-erCh
+		job.Status = states.Stopped
+		_ = ex.backend.UpdateState(runCtx)
+		return errRun
+	}
diff --git a/runner/internal/models/backend.go b/runner/internal/models/backend.go
index cf25ab807..bd3112a98 100644
--- a/runner/internal/models/backend.go
+++ b/runner/internal/models/backend.go
@@ -3,6 +3,7 @@ package models
 import (
 	"fmt"
 	"strings"
+	"time"
 )

 type Resource struct {
@@ -54,6 +55,7 @@ type Job struct {
 	RunnerID          string      `yaml:"runner_id"`
 	SpotPolicy        string      `yaml:"spot_policy"`
 	RetryPolicy       RetryPolicy `yaml:"retry_policy"`
+	MaxDuration       uint64      `yaml:"max_duration,omitempty"`
 	Status            string      `yaml:"status"`
 	ErrorCode         string      `yaml:"error_code,omitempty"`
 	ContainerExitCode string      `yaml:"container_exit_code,omitempty"`
@@ -198,3 +200,11 @@ func (j *Job) GetInstanceType() string {
 func (j *Job) SecretsPrefix() string {
 	return fmt.Sprintf("secrets/%s/l;", j.RepoId)
 }
+
+func (j *Job) MaxDurationExceeded() bool {
+	if j.MaxDuration == 0 {
+		return false
+	}
+	now := uint64(time.Now().Unix())
+	return now > j.SubmittedAt/1000+j.MaxDuration
+}
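
Reviewer note, not part of the patch: a Python mirror of the Go `MaxDurationExceeded` check above, for clarity — `SubmittedAt` is in milliseconds, `MaxDuration` in seconds, and `0` disables the check. `max_duration_exceeded` is a hypothetical helper, not runner code.

    # Hypothetical Python mirror of (j *Job) MaxDurationExceeded() (sketch only).
    import time
    from typing import Optional

    def max_duration_exceeded(submitted_at_ms: int, max_duration_s: int, now_s: Optional[int] = None) -> bool:
        if max_duration_s == 0:
            return False  # 0 (unset) means no limit
        if now_s is None:
            now_s = int(time.time())
        return now_s > submitted_at_ms // 1000 + max_duration_s

    now = int(time.time())
    assert max_duration_exceeded((now - 7 * 3600) * 1000, 6 * 3600, now_s=now)   # 7h elapsed, 6h limit
    assert not max_duration_exceeded((now - 7 * 3600) * 1000, 0, now_s=now)      # no limit configured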