Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions src/dstack/_internal/cli/services/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dstack._internal.core.models.profiles import (
CreationPolicy,
Profile,
ProfileRetryPolicy,
ProfileRetry,
SpotPolicy,
TerminationPolicy,
parse_duration,
Expand Down Expand Up @@ -120,10 +120,8 @@ def register_profile_args(parser: argparse.ArgumentParser):

retry_group = parser.add_argument_group("Retry policy")
retry_group_exc = retry_group.add_mutually_exclusive_group()
retry_group_exc.add_argument("--retry", action="store_const", dest="retry_policy", const=True)
retry_group_exc.add_argument(
"--no-retry", action="store_const", dest="retry_policy", const=False
)
retry_group_exc.add_argument("--retry", action="store_const", dest="retry", const=True)
retry_group_exc.add_argument("--no-retry", action="store_const", dest="retry", const=False)
retry_group_exc.add_argument(
"--retry-duration", type=retry_duration, dest="retry_duration", metavar="DURATION"
)
Expand Down Expand Up @@ -161,15 +159,12 @@ def apply_profile_args(
if args.spot_policy is not None:
profile_settings.spot_policy = args.spot_policy

if args.retry_policy is not None:
if not profile_settings.retry_policy:
profile_settings.retry_policy = ProfileRetryPolicy()
profile_settings.retry_policy.retry = args.retry_policy
if args.retry is not None:
profile_settings.retry = args.retry
elif args.retry_duration is not None:
if not profile_settings.retry_policy:
profile_settings.retry_policy = ProfileRetryPolicy()
profile_settings.retry_policy.retry = True
profile_settings.retry_policy.duration = args.retry_duration
profile_settings.retry = ProfileRetry(
duration=args.retry_duration,
)


def max_duration(v: str) -> int:
Expand Down
14 changes: 8 additions & 6 deletions src/dstack/_internal/core/models/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def parse_idle_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[st
return parse_duration(v)


# Deprecated in favor of ProfileRetry().
# TODO: Remove when no longer referenced.
class ProfileRetryPolicy(CoreModel):
retry: Annotated[bool, Field(description="Whether to retry the run on failure or not")] = False
duration: Annotated[
Expand All @@ -95,14 +97,15 @@ class RetryEvent(str, Enum):

class ProfileRetry(CoreModel):
on_events: Annotated[
List[RetryEvent],
Optional[List[RetryEvent]],
Field(
description=(
"The list of events that should be handled with retry."
" Supported events are `no-capacity`, `interruption`, and `error`"
" Supported events are `no-capacity`, `interruption`, and `error`."
" Omit to retry on all events"
)
),
]
] = None
duration: Annotated[
Optional[Union[int, str]],
Field(description="The maximum period of retrying the run, e.g., `4h` or `1d`"),
Expand All @@ -112,7 +115,8 @@ class ProfileRetry(CoreModel):

@root_validator
def _validate_fields(cls, values):
if "on_events" in values and len(values["on_events"]) == 0:
on_events = values.get("on_events", None)
if on_events is not None and len(values["on_events"]) == 0:
raise ValueError("`on_events` cannot be empty")
return values

Expand Down Expand Up @@ -249,8 +253,6 @@ class ProfileParams(CoreModel):
description="Deprecated in favor of `idle_duration`",
),
] = None
# The policy for resubmitting the run. Deprecated in favor of `retry`
retry_policy: Optional[ProfileRetryPolicy] = None

_validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)(
parse_max_duration
Expand Down
4 changes: 0 additions & 4 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
CreationPolicy,
Profile,
ProfileParams,
ProfileRetryPolicy,
RetryEvent,
SpotPolicy,
UtilizationPolicy,
Expand Down Expand Up @@ -204,9 +203,6 @@ class JobSpec(CoreModel):
retry: Optional[Retry]
volumes: Optional[List[MountPoint]] = None
ssh_key: Optional[JobSSHKey] = None
# For backward compatibility with 0.18.x when retry_policy was required.
# TODO: remove in 0.19
retry_policy: ProfileRetryPolicy = ProfileRetryPolicy(retry=False)
working_dir: Optional[str]


Expand Down
19 changes: 7 additions & 12 deletions src/dstack/_internal/core/services/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,7 @@
def get_retry(profile: Profile) -> Optional[Retry]:
profile_retry = profile.retry
if profile_retry is None:
# Handle retry_policy before retry was introduced
# TODO: Remove once retry_policy no longer supported
profile_retry_policy = profile.retry_policy
if profile_retry_policy is None:
return None
if not profile_retry_policy.retry:
return None
duration = profile_retry_policy.duration or DEFAULT_RETRY_DURATION
return Retry(
on_events=[RetryEvent.NO_CAPACITY, RetryEvent.INTERRUPTION, RetryEvent.ERROR],
duration=duration,
)
return None
if isinstance(profile_retry, bool):
if profile_retry:
return Retry(
Expand All @@ -32,6 +21,12 @@ def get_retry(profile: Profile) -> Optional[Retry]:
)
return None
profile_retry = profile_retry.copy()
if profile_retry.on_events is None:
profile_retry.on_events = [
RetryEvent.NO_CAPACITY,
RetryEvent.INTERRUPTION,
RetryEvent.ERROR,
]
if profile_retry.duration is None:
profile_retry.duration = DEFAULT_RETRY_DURATION
return Retry.parse_obj(profile_retry)
Expand Down
1 change: 0 additions & 1 deletion src/dstack/api/_public/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,6 @@ def get_plan(
reservation=reservation,
spot_policy=spot_policy,
retry=None,
retry_policy=retry_policy,
utilization_policy=utilization_policy,
max_duration=max_duration,
stop_duration=stop_duration,
Expand Down
12 changes: 6 additions & 6 deletions src/tests/_internal/cli/services/configurators/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
apply_profile_args,
register_profile_args,
)
from dstack._internal.core.models.profiles import Profile, ProfileRetryPolicy, SpotPolicy
from dstack._internal.core.models.profiles import Profile, ProfileRetry, SpotPolicy


class TestProfileArgs:
Expand Down Expand Up @@ -51,21 +51,21 @@ def test_spot_policy_on_demand(self):
assert profile.dict() == modified.dict()

def test_retry(self):
profile = Profile(name="test")
profile.retry_policy = ProfileRetryPolicy(retry=True)
profile = Profile(name="test", retry=None)
modified, _ = apply_args(profile, ["--retry"])
profile.retry = True
assert profile.dict() == modified.dict()

def test_no_retry(self):
profile = Profile(name="test", retry_policy=ProfileRetryPolicy(retry=True, duration=3600))
profile = Profile(name="test", retry=None)
modified, _ = apply_args(profile, ["--no-retry"])
profile.retry_policy.retry = False
profile.retry = False
assert profile.dict() == modified.dict()

def test_retry_duration(self):
profile = Profile(name="test")
modified, _ = apply_args(profile, ["--retry-duration", "1h"])
profile.retry_policy = ProfileRetryPolicy(retry=True, duration=3600)
profile.retry = ProfileRetry(on_events=None, duration="1h")
assert profile.dict() == modified.dict()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
InstanceType,
Resources,
)
from dstack._internal.core.models.profiles import Profile, ProfileRetryPolicy
from dstack._internal.core.models.profiles import Profile
from dstack._internal.core.models.runs import (
JobProvisioningData,
JobStatus,
Expand Down Expand Up @@ -372,7 +372,6 @@ async def test_fails_job_when_no_capacity(self, test_db, session: AsyncSession):
repo_id=repo.name,
profile=Profile(
name="default",
retry_policy=ProfileRetryPolicy(retry=True, duration=3600),
),
),
)
Expand Down
2 changes: 0 additions & 2 deletions src/tests/_internal/server/routers/test_fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ async def test_creates_fleet(self, test_db, session: AsyncSession, client: Async
"instance_types": None,
"spot_policy": None,
"retry": None,
"retry_policy": None,
"max_duration": None,
"stop_duration": None,
"max_price": None,
Expand Down Expand Up @@ -482,7 +481,6 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A
"instance_types": None,
"spot_policy": None,
"retry": None,
"retry_policy": None,
"max_duration": None,
"stop_duration": None,
"max_price": None,
Expand Down
6 changes: 0 additions & 6 deletions src/tests/_internal/server/routers/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def get_dev_env_run_plan_dict(
"stop_duration": None,
"max_price": None,
"retry": None,
"retry_policy": None,
"spot_policy": "spot",
"idle_duration": None,
"termination_idle_time": 300,
Expand All @@ -141,7 +140,6 @@ def get_dev_env_run_plan_dict(
"max_price": None,
"name": "string",
"retry": None,
"retry_policy": None,
"spot_policy": "spot",
"idle_duration": None,
"termination_idle_time": 300,
Expand Down Expand Up @@ -206,7 +204,6 @@ def get_dev_env_run_plan_dict(
"retry": None,
"volumes": volumes,
"ssh_key": None,
"retry_policy": {"retry": False, "duration": None},
"working_dir": ".",
},
"offers": [json.loads(o.json()) for o in offers],
Expand Down Expand Up @@ -277,7 +274,6 @@ def get_dev_env_run_dict(
"stop_duration": None,
"max_price": None,
"retry": None,
"retry_policy": None,
"spot_policy": "spot",
"idle_duration": None,
"termination_idle_time": 300,
Expand All @@ -298,7 +294,6 @@ def get_dev_env_run_dict(
"max_price": None,
"name": "string",
"retry": None,
"retry_policy": None,
"spot_policy": "spot",
"idle_duration": None,
"termination_idle_time": 300,
Expand Down Expand Up @@ -363,7 +358,6 @@ def get_dev_env_run_dict(
"retry": None,
"volumes": [],
"ssh_key": None,
"retry_policy": {"retry": False, "duration": None},
"working_dir": ".",
},
"job_submissions": [
Expand Down