Skip to content

Commit

Permalink
Add retry-with-delay to task metadata
Browse files Browse the repository at this point in the history
Signed-off-by: mucahitkantepe <mucahitkantepe@gmail.com>
  • Loading branch information
mucahitkantepe committed Apr 22, 2024
1 parent b0cd1e8 commit e1eb20b
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 9 deletions.
11 changes: 9 additions & 2 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class TaskMetadata(object):
deprecated (str): Can be used to provide a warning message for deprecated task. Absence or empty str indicates
that the task is active and not deprecated
retries (int): for retries=n; n > 0, on failures of this task, the task will be retried at-least n number of times.
retry_delay (Optional[Union[datetime.timedelta, int]]): specifies the delay in between retries of this task execution (approximately).
timeout (Optional[Union[datetime.timedelta, int]]): the max amount of time for which one execution of this task
should be executed for. The execution will be terminated if the runtime exceeds the given timeout
(approximately)
Expand All @@ -130,6 +131,7 @@ class TaskMetadata(object):
interruptible: Optional[bool] = None
deprecated: str = ""
retries: int = 0
retry_delay: Optional[Union[datetime.timedelta, int]] = None
timeout: Optional[Union[datetime.timedelta, int]] = None
pod_template_name: Optional[str] = None

Expand All @@ -138,7 +140,12 @@ def __post_init__(self):
if isinstance(self.timeout, int):
    # A bare int is interpreted as a number of seconds.
    self.timeout = datetime.timedelta(seconds=self.timeout)
elif not isinstance(self.timeout, datetime.timedelta):
    # The message must name the field that failed validation (timeout, not retry_delay).
    raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
if self.retry_delay:
if isinstance(self.retry_delay, int):
self.retry_delay = datetime.timedelta(seconds=self.retry_delay)
elif not isinstance(self.retry_delay, datetime.timedelta):
raise ValueError("retry_delay should be duration represented as either a datetime.timedelta or int seconds")
if self.cache and not self.cache_version:
raise ValueError("Caching is enabled ``cache=True`` but ``cache_version`` is not set.")
if self.cache_serialize and not self.cache:
Expand All @@ -150,7 +157,7 @@ def __post_init__(self):

@property
def retry_strategy(self) -> _literal_models.RetryStrategy:
    """
    Build the retry strategy model from this metadata's ``retries`` and ``retry_delay``.

    :return: a ``_literal_models.RetryStrategy`` carrying both the retry count and the retry delay.
    """
    # Single return that forwards both values; a retries-only return placed first
    # would make the delay-aware construction unreachable.
    return _literal_models.RetryStrategy(self.retries, self.retry_delay)

def to_taskmetadata_model(self) -> _task_model.TaskMetadata:
"""
Expand Down
5 changes: 5 additions & 0 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def task(
cache_version: str = ...,
cache_ignore_input_vars: Tuple[str, ...] = ...,
retries: int = ...,
retry_delay: Union[_datetime.timedelta, int] = ...,
interruptible: Optional[bool] = ...,
deprecated: str = ...,
timeout: Union[_datetime.timedelta, int] = ...,
Expand Down Expand Up @@ -130,6 +131,7 @@ def task(
cache_version: str = ...,
cache_ignore_input_vars: Tuple[str, ...] = ...,
retries: int = ...,
retry_delay: Union[_datetime.timedelta, int] = ...,
interruptible: Optional[bool] = ...,
deprecated: str = ...,
timeout: Union[_datetime.timedelta, int] = ...,
Expand Down Expand Up @@ -167,6 +169,7 @@ def task(
cache_version: str = "",
cache_ignore_input_vars: Tuple[str, ...] = (),
retries: int = 0,
retry_delay: Union[_datetime.timedelta, int] = 0,
interruptible: Optional[bool] = None,
deprecated: str = "",
timeout: Union[_datetime.timedelta, int] = 0,
Expand Down Expand Up @@ -239,6 +242,7 @@ def my_task(x: int, y: typing.Dict[str, str]) -> str:
this version if the function body/business logic has changed, but the signature hasn't.
:param cache_ignore_input_vars: Input variables that should not be included when calculating hash for cache.
:param retries: Number of times to retry this task during a workflow execution.
:param retry_delay: specifies the delay in between retries of this task execution (approximately).
:param interruptible: [Optional] Boolean that indicates that this task can be interrupted and/or scheduled on nodes
with lower QoS guarantees. This will directly reduce the `$`/`execution cost` associated,
at the cost of performance penalties due to potential interruptions. Requires additional
Expand Down Expand Up @@ -322,6 +326,7 @@ def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
cache_version=cache_version,
cache_ignore_input_vars=cache_ignore_input_vars,
retries=retries,
retry_delay=retry_delay,
interruptible=interruptible,
deprecated=deprecated,
timeout=timeout,
Expand Down
17 changes: 14 additions & 3 deletions flytekit/models/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@


class RetryStrategy(_common.FlyteIdlEntity):
def __init__(self, retries, retry_delay=None):
    """
    :param int retries: Number of retries to attempt on recoverable failures. If retries is 0, then
        only one attempt will be made.
    :param retry_delay: The delay in between retries of this task execution. If it is 0, then
        the task will be retried immediately. Defaults to None so that existing callers that only
        pass ``retries`` keep working.
    """
    self._retries = retries
    self._retry_delay = retry_delay

@property
def retries(self):
Expand All @@ -30,19 +33,27 @@ def retries(self):
"""
return self._retries

@property
def retry_delay(self):
    """
    The delay in between retries of this task execution. If it is 0, then the task will be retried immediately.
    :rtype: Union[datetime.timedelta, int]
    """
    # Return the stored delay, not the retry count.
    return self._retry_delay

def to_flyte_idl(self):
    """
    Serialize this retry strategy, including its retry delay, to the protobuf message.

    :rtype: flyteidl.core.literals_pb2.RetryStrategy
    """
    # Single return that includes retry_delay; a retries-only return placed first
    # would silently drop the delay from the serialized message.
    return _literals_pb2.RetryStrategy(retries=self.retries, retry_delay=self.retry_delay)

@classmethod
def from_flyte_idl(cls, pb2_object):
    """
    Build a RetryStrategy from its protobuf message, carrying over both ``retries`` and ``retry_delay``.

    :param flyteidl.core.literals_pb2.RetryStrategy pb2_object:
    :rtype: RetryStrategy
    """
    # Single return that forwards retry_delay so the delay survives deserialization.
    return cls(retries=pb2_object.retries, retry_delay=pb2_object.retry_delay)


class Primitive(_common.FlyteIdlEntity):
Expand Down
5 changes: 2 additions & 3 deletions flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,7 @@ def __init__(
:param datetime.timedelta timeout: The amount of time to wait before timing out. This includes queuing and
scheduler latency.
:param bool interruptible: Whether or not the task is interruptible.
:param flytekit.models.literals.RetryStrategy retries: Retry strategy for this task. 0 retries means only
try once.
:param flytekit.models.literals.RetryStrategy retries: Retry strategy for this task.
:param Text discovery_version: This is the version used to create a logical version for data in the cache.
This is only used when `discoverable` is true. Data is considered discoverable if: the inputs to a given
task are the same and the discovery_version is also the same.
Expand Down Expand Up @@ -231,7 +230,7 @@ def runtime(self):
@property
def retries(self):
    """
    Retry strategy for this task. 0 retries means only try once.
    :rtype: flytekit.models.literals.RetryStrategy
    """
    return self._retries
Expand Down
3 changes: 2 additions & 1 deletion tests/flytekit/unit/models/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_task_metadata():
True,
task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
timedelta(days=1),
literals.RetryStrategy(3),
literals.RetryStrategy(3, timedelta(minutes=10)),
True,
"0.1.1b0",
"This is deprecated!",
Expand All @@ -78,6 +78,7 @@ def test_task_metadata():

assert obj.discoverable is True
assert obj.retries.retries == 3
assert obj.retries.retry_delay == timedelta(minutes=10)
assert obj.interruptible is True
assert obj.timeout == timedelta(days=1)
assert obj.runtime.flavor == "python"
Expand Down

0 comments on commit e1eb20b

Please sign in to comment.