Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions cli/dstack/_internal/configurators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dstack._internal.core.build import BuildPolicy
from dstack._internal.core.configuration import BaseConfiguration, PythonVersion
from dstack._internal.core.error import DstackError
from dstack._internal.core.profile import Profile
from dstack._internal.core.profile import Profile, parse_duration, parse_max_duration
from dstack._internal.core.repo import Repo
from dstack._internal.utils.common import get_milliseconds_since_epoch
from dstack._internal.utils.interpolator import VariablesInterpolator
Expand Down Expand Up @@ -60,6 +60,8 @@ def get_parser(
retry_group.add_argument("--no-retry", action="store_true")
retry_group.add_argument("--retry-limit", type=str)

parser.add_argument("--max-duration", type=str)

build_policy = parser.add_mutually_exclusive_group()
build_policy.add_argument(
"--build", action="store_const", dest="build_policy", const=BuildPolicy.BUILD
Expand All @@ -83,11 +85,14 @@ def apply_args(self, args: argparse.Namespace):
self.profile.retry_policy.retry = False
elif args.retry_limit:
self.profile.retry_policy.retry = True
self.profile.retry_policy.limit = args.retry_limit
self.profile.retry_policy.limit = parse_duration(args.retry_limit)

if args.build_policy is not None:
self.build_policy = args.build_policy

if args.max_duration:
self.profile.max_duration = parse_max_duration(args.max_duration)

def inject_context(
self, namespaces: Dict[str, Dict[str, str]], skip: Optional[List[str]] = None
):
Expand Down Expand Up @@ -142,6 +147,7 @@ def get_jobs(
dep_specs=self.dep_specs(),
spot_policy=self.spot_policy(),
retry_policy=self.retry_policy(),
max_duration=self.max_duration(),
build_policy=self.build_policy,
requirements=self.requirements(),
ssh_key_pub=ssh_key_pub,
Expand All @@ -164,6 +170,10 @@ def artifact_specs(self) -> List[job.ArtifactSpec]:
def dep_specs(self) -> List[job.DepSpec]:
pass

@abstractmethod
def default_max_duration(self) -> int:
    """Return the fallback maximum run duration for this configurator.

    ``max_duration()`` calls this when the profile leaves
    ``max_duration`` unset; concrete configurators provide the value.
    """

def build_commands(self) -> List[str]:
    """Return the ``build`` value stored on the run configuration."""
    configured = self.conf.build
    return configured

Expand All @@ -186,12 +196,6 @@ def image_name(self) -> str:
return f"dstackai/miniforge:py{self.python()}-{version.miniforge_image}-cuda-11.4"
return f"dstackai/miniforge:py{self.python()}-{version.miniforge_image}"

def spot_policy(self) -> job.SpotPolicy:
return self.profile.spot_policy or job.SpotPolicy.AUTO

def retry_policy(self) -> job.RetryPolicy:
return job.RetryPolicy.parse_obj(self.profile.retry_policy)

def cache_specs(self) -> List[job.CacheSpec]:
return [
job.CacheSpec(path=validate_local_path(path, self.home_dir(), self.working_dir))
Expand Down Expand Up @@ -254,6 +258,19 @@ def join_run_args(cls, args: List[str]) -> str:
(arg if " " not in arg else '"%s"' % arg.replace('"', '\\"')) for arg in args
)

def spot_policy(self) -> job.SpotPolicy:
    """Return the profile's spot policy, or ``SpotPolicy.AUTO`` if unset."""
    configured = self.profile.spot_policy
    if configured:
        return configured
    return job.SpotPolicy.AUTO

def retry_policy(self) -> job.RetryPolicy:
    """Build a ``job.RetryPolicy`` from the profile's retry policy."""
    profile_policy = self.profile.retry_policy
    return job.RetryPolicy.parse_obj(profile_policy)

def max_duration(self) -> Optional[int]:
    """Resolve the effective maximum run duration.

    Returns:
        The configurator's default when the profile does not set a value,
        ``None`` when the profile value is negative (the "off" sentinel),
        and the profile value itself otherwise.
    """
    limit = self.profile.max_duration
    if limit is None:
        return self.default_max_duration()
    # A negative value marks the limit as disabled.
    return None if limit < 0 else limit


def validate_local_path(path: str, home: Optional[str], working_dir: str) -> str:
if path == "~" or path.startswith("~/"):
Expand Down
11 changes: 8 additions & 3 deletions cli/dstack/_internal/configurators/dev_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from dstack._internal.core.configuration import DevEnvironmentConfiguration
from dstack._internal.core.repo import Repo

# Default run-duration cap for dev environments: 6 hours, in seconds.
DEFAULT_MAX_DURATION_SECONDS = 6 * 60 * 60

# Result of `require(["sshd"])`, created once at import time for this module.
require_sshd = require(["sshd"])
# Best-effort shell snippet: pip's stderr is discarded and a notice is echoed
# if the install fails, so the overall command still succeeds.
# (Plain string: there are no placeholders, so no f-prefix is needed.)
install_ipykernel = '(pip install -q --no-cache-dir ipykernel 2> /dev/null) || echo "no pip, ipykernel was not installed"'

Expand Down Expand Up @@ -59,16 +61,19 @@ def optional_build_commands(self) -> List[str]:
commands.append(install_ipykernel)
return commands

def default_max_duration(self) -> int:
    """Return the module-level default duration cap (seconds) for dev environments."""
    fallback = DEFAULT_MAX_DURATION_SECONDS
    return fallback

def artifact_specs(self) -> List[job.ArtifactSpec]:
    """Return an empty list: artifacts are not available here."""
    no_artifacts: list = []  # not available
    return no_artifacts

def dep_specs(self) -> List[job.DepSpec]:
    """Return an empty list: dependencies are not available here."""
    no_deps: list = []  # not available
    return no_deps

def spot_policy(self) -> job.SpotPolicy:
return self.profile.spot_policy or job.SpotPolicy.ONDEMAND

def app_specs(self) -> List[job.AppSpec]:
    """Return the parent class's app specs after passing them to ``self.sshd.add_app``.

    NOTE(review): ``add_app`` presumably appends the sshd app to the list
    in place -- confirm against its definition.
    """
    inherited = super().app_specs()
    self.sshd.add_app(inherited)
    return inherited

def spot_policy(self) -> job.SpotPolicy:
    """Return the profile's spot policy, or ``SpotPolicy.ONDEMAND`` if unset."""
    chosen = self.profile.spot_policy
    return chosen if chosen else job.SpotPolicy.ONDEMAND
5 changes: 5 additions & 0 deletions cli/dstack/_internal/configurators/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from dstack._internal.core.configuration import TaskConfiguration
from dstack._internal.core.repo import Repo

# Default run-duration cap for tasks: 72 hours, in seconds.
DEFAULT_MAX_DURATION_SECONDS = 72 * 60 * 60


class TaskConfigurator(JobConfigurator):
conf: TaskConfiguration
Expand All @@ -29,6 +31,9 @@ def commands(self) -> List[str]:
def optional_build_commands(self) -> List[str]:
    """Return an empty list: tasks need no optional build commands."""
    nothing_extra: List[str] = []  # not needed
    return nothing_extra

def default_max_duration(self) -> int:
    """Return the module-level default duration cap (seconds) for tasks."""
    fallback = DEFAULT_MAX_DURATION_SECONDS
    return fallback

def artifact_specs(self) -> List[job.ArtifactSpec]:
specs = []
for a in self.conf.artifacts:
Expand Down
5 changes: 5 additions & 0 deletions cli/dstack/_internal/core/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class Job(JobHead):
requirements: Optional[Requirements]
spot_policy: Optional[SpotPolicy]
retry_policy: Optional[RetryPolicy]
max_duration: Optional[int]
dep_specs: Optional[List[DepSpec]]
master_job: Optional[JobRef]
app_specs: Optional[List[AppSpec]]
Expand Down Expand Up @@ -287,6 +288,7 @@ def serialize(self) -> dict:
"host_name": self.host_name or "",
"spot_policy": self.spot_policy.value if self.spot_policy else None,
"retry_policy": self.retry_policy.dict() if self.retry_policy else None,
"max_duration": self.max_duration or None,
"requirements": self.requirements.serialize() if self.requirements else {},
"deps": deps,
"master_job_id": self.master_job.get_id() if self.master_job else "",
Expand Down Expand Up @@ -435,6 +437,9 @@ def unserialize(job_data: dict):
host_name=job_data.get("host_name") or None,
spot_policy=SpotPolicy(spot_policy) if spot_policy else None,
retry_policy=retry_policy,
max_duration=int(job_data.get("max_duration"))
if job_data.get("max_duration")
else None,
requirements=requirements,
dep_specs=dep_specs or None,
master_job=master_job,
Expand Down
16 changes: 12 additions & 4 deletions cli/dstack/_internal/core/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ def parse_memory(v: Optional[Union[int, str]]) -> Optional[int]:
return int(v)


def duration(v: Union[int, str]) -> int:
def parse_duration(v: Union[int, str]) -> int:
if isinstance(v, int):
return v
regex = re.compile(r"(?P<amount>\d+) *(?P<unit>[smhdw])$")
re_match = regex.match(duration)
re_match = regex.match(v)
if not re_match:
raise ValueError(f"Cannot parse the duration {duration}")
raise ValueError(f"Cannot parse the duration {v}")
amount, unit = int(re_match.group("amount")), re_match.group("unit")
multiplier = {
"s": 1,
Expand All @@ -47,6 +47,12 @@ def duration(v: Union[int, str]) -> int:
return amount * multiplier


def parse_max_duration(v: Union[int, str]) -> int:
    """Parse a ``max_duration`` setting, mapping "off" to the ``-1`` sentinel.

    Args:
        v: An integer, a duration string accepted by ``parse_duration``,
            or the literal ``"off"``. A boolean ``False`` is also treated
            as "off", because YAML 1.1 loaders (e.g. PyYAML) read an
            unquoted ``off`` as ``False`` before this function sees it.

    Returns:
        ``-1`` when the limit is disabled; otherwise the result of
        ``parse_duration(v)``.

    Raises:
        ValueError: Propagated from ``parse_duration`` when a string
            cannot be parsed as a duration.
    """
    # Identity check (not `== False`) so the integer 0 is never mistaken
    # for the "off" marker.
    if v is False or v == "off":
        return -1
    return parse_duration(v)


class ProfileGPU(ForbidExtra):
name: Optional[str]
count: int = 1
Expand All @@ -71,7 +77,7 @@ def _validate_gpu(cls, v: Optional[Union[int, ProfileGPU]]) -> Optional[ProfileG
class ProfileRetryPolicy(ForbidExtra):
    """Retry settings of a profile."""

    # Whether to retry the run on failure.
    retry: bool = False
    # Maximum period of retrying; an int or a duration string such as "4h",
    # passed through parse_duration by the validator below.
    limit: Union[int, str] = DEFAULT_RETRY_LIMIT
    # Single validator binding: an earlier assignment to this same name
    # (using the pre-rename helper `duration`) was immediately overwritten
    # and has been dropped.
    _validate_limit = validator("limit", pre=True, allow_reuse=True)(parse_duration)


class Profile(ForbidExtra):
Expand All @@ -80,7 +86,9 @@ class Profile(ForbidExtra):
resources: ProfileResources = ProfileResources()
spot_policy: Optional[SpotPolicy]
retry_policy: ProfileRetryPolicy = ProfileRetryPolicy()
max_duration: Optional[Union[int, str]]
default: bool = False
_validate_limit = validator("max_duration", pre=True, allow_reuse=True)(parse_max_duration)


class ProfilesConfig(ForbidExtra):
Expand Down
5 changes: 3 additions & 2 deletions docs/docs/reference/profiles.yml.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ Below is a full reference of all available properties.
- `resources` - (Optional) The minimum required resources
- `memory` - (Optional) The minimum size of RAM memory (e.g., `"16GB"`).
- `gpu` - (Optional) The minimum number of GPUs, their model name and memory
- `name` - (Optional) The name of the GPU model (e.g. `"K80"`, `"V100"`, `"A100"`, etc)
- `name` - (Optional) The name of the GPU model (e.g., `"K80"`, `"V100"`, `"A100"`, etc)
- `count` - (Optional) The minimum number of GPUs. Defaults to `1`.
- `memory` - (Optional) The minimum size of GPU memory (e.g., `"16GB"`)
- `shm_size` - (Optional) The size of shared memory (e.g., `"8GB"`). If you are using parallel communicating
processes (e.g., dataloaders in PyTorch), you may need to configure this.
- `spot_policy` - (Optional) The policy for provisioning spot or on-demand instances: `spot`, `on-demand`, or `auto`. `spot` provisions a spot instance. `on-demand` provisions an on-demand instance. `auto` first tries to provision a spot instance and then tries on-demand if spot is not available. Defaults to `on-demand` for dev environments and to `auto` for tasks.
- `retry_policy` - (Optional) The policy for re-submitting the run.
- `retry` - (Optional) Whether to retry the run on failure or not. Defaults to `false`.
- `limit` - (Optional) The maximum period of retrying the run, e.g. 4h or 1d. Defaults to 1h if `retry` is `true`.
- `limit` - (Optional) The maximum period of retrying the run, e.g., `4h` or `1d`. Defaults to `1h` if `retry` is `true`.
- `max_duration` - (Optional) The maximum duration of a run (e.g., `2h`, `1d`, etc). After it elapses, the run is forced to stop. Protects from running idle instances. Defaults to `6h` for dev environments and to `72h` for tasks. Use `max_duration: off` to disable maximum run duration.

[//]: # (TODO: Add examples)

Expand Down
21 changes: 15 additions & 6 deletions runner/internal/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/dstackai/dstack/runner/internal/backend/base"
"io"
"os"
"path"
"path/filepath"
"strconv"
"time"

"github.com/dstackai/dstack/runner/internal/backend/base"

"github.com/dstackai/dstack/runner/internal/models"
"github.com/dustin/go-humanize"

Expand Down Expand Up @@ -46,7 +47,6 @@ type Executor struct {
artifactsOut []base.Artifacter
artifactsFUSE []base.Artifacter
repo *repo.Manager
portID string
streamLogs *stream.Server
stoppedCh chan struct{}
}
Expand Down Expand Up @@ -142,15 +142,24 @@ func (ex *Executor) Run(ctx context.Context) error {
if err != nil {
return err
}
job, err := ex.backend.RefetchJob(runCtx)
if err != nil {
return gerrors.Wrap(err)
}
if stopped {
log.Info(runCtx, "Stopped")
ex.Stop()
log.Info(runCtx, "Waiting job end")
errRun := <-erCh
job, err := ex.backend.RefetchJob(runCtx)
if err != nil {
return gerrors.Wrap(err)
}
job.Status = states.Stopped
_ = ex.backend.UpdateState(runCtx)
return errRun
}
if job.MaxDurationExceeded() {
log.Info(runCtx, "Job max duration exceeded. Stopping...")
ex.Stop()
log.Info(runCtx, "Waiting job end")
errRun := <-erCh
job.Status = states.Stopped
_ = ex.backend.UpdateState(runCtx)
return errRun
Expand Down
10 changes: 10 additions & 0 deletions runner/internal/models/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package models
import (
"fmt"
"strings"
"time"
)

type Resource struct {
Expand Down Expand Up @@ -54,6 +55,7 @@ type Job struct {
RunnerID string `yaml:"runner_id"`
SpotPolicy string `yaml:"spot_policy"`
RetryPolicy RetryPolicy `yaml:"retry_policy"`
MaxDuration uint64 `yaml:"max_duration,omitempty"`
Status string `yaml:"status"`
ErrorCode string `yaml:"error_code,omitempty"`
ContainerExitCode string `yaml:"container_exit_code,omitempty"`
Expand Down Expand Up @@ -198,3 +200,11 @@ func (j *Job) GetInstanceType() string {
// SecretsPrefix returns the string "secrets/<RepoId>/l;" built from the
// job's RepoId.
// NOTE(review): the trailing "l;" segment looks unusual -- confirm it is the
// intended key suffix before relying on it elsewhere.
func (j *Job) SecretsPrefix() string {
return fmt.Sprintf("secrets/%s/l;", j.RepoId)
}

// MaxDurationExceeded reports whether the current Unix time (in seconds) is
// past SubmittedAt/1000 + MaxDuration. A MaxDuration of zero means no limit
// is set, so the job is never reported as exceeded.
func (j *Job) MaxDurationExceeded() bool {
	limit := j.MaxDuration
	if limit == 0 {
		return false
	}
	// SubmittedAt is divided by 1000 to line up with the seconds-based
	// clock reading (presumably it is stored in milliseconds).
	deadline := j.SubmittedAt/1000 + limit
	return uint64(time.Now().Unix()) > deadline
}