Skip to content

Commit

Permalink
Add RetryAfterError
Browse files Browse the repository at this point in the history
  • Loading branch information
goodoldneon committed May 24, 2024
1 parent 220f1c8 commit a19376d
Show file tree
Hide file tree
Showing 9 changed files with 208 additions and 5 deletions.
3 changes: 2 additions & 1 deletion inngest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


from ._internal.client_lib import Inngest
from ._internal.errors import NonRetriableError, StepError
from ._internal.errors import NonRetriableError, RetryAfterError, StepError
from ._internal.event_lib import Event
from ._internal.execution import Output
from ._internal.function import Context, Function
Expand Down Expand Up @@ -36,6 +36,7 @@
"NonRetriableError",
"Output",
"RateLimit",
"RetryAfterError",
"Step",
"StepError",
"StepSync",
Expand Down
5 changes: 5 additions & 0 deletions inngest/_internal/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def from_call_result(
if call_res.is_retriable is False:
headers[const.HeaderKey.NO_RETRY.value] = "true"

if call_res.retry_after is not None:
headers[
const.HeaderKey.RETRY_AFTER.value
] = transforms.to_iso_utc(call_res.retry_after)

return cls(
body=transforms.prep_body(d),
headers=headers,
Expand Down
2 changes: 2 additions & 0 deletions inngest/_internal/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class ErrorCode(enum.Enum):
OUTPUT_UNSERIALIZABLE = "output_unserializable"
QUERY_PARAM_MISSING = "query_param_missing"
REGISTRATION_FAILED = "registration_failed"
RETRY_AFTER_ERROR = "retry_after_error"
SERVER_KIND_MISMATCH = "server_kind_mismatch"
SIGNING_KEY_UNSPECIFIED = "signing_key_unspecified"
SIG_VERIFICATION_FAILED = "sig_verification_failed"
Expand All @@ -69,6 +70,7 @@ class HeaderKey(enum.Enum):
EXPECTED_SERVER_KIND = "X-Inngest-Expected-Server-Kind"
FRAMEWORK = "X-Inngest-Framework"
NO_RETRY = "X-Inngest-No-Retry"
RETRY_AFTER = "Retry-After"
SDK = "X-Inngest-SDK"
SERVER_KIND = "X-Inngest-Server-Kind"
SERVER_TIMING = "Server-Timing"
Expand Down
25 changes: 25 additions & 0 deletions inngest/_internal/errors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import datetime
import typing

import pydantic
Expand Down Expand Up @@ -136,6 +137,30 @@ def __init__(
self.cause = cause


class RetryAfterError(Error):
"""
Raise this retry after a time duration or datetime.
"""

code = const.ErrorCode.RETRY_AFTER_ERROR

def __init__(
self,
message: typing.Optional[str],
retry_after: typing.Union[int, datetime.timedelta, datetime.datetime],
) -> None:
super().__init__(message)

if isinstance(retry_after, int):
retry_after = datetime.datetime.now() + datetime.timedelta(
milliseconds=retry_after
)
elif isinstance(retry_after, datetime.timedelta):
retry_after = datetime.datetime.now() + retry_after

self.retry_after: datetime.datetime = retry_after


class StepError(Error):
"""
Wraps a userland error. This is necessary because the Executor sends
Expand Down
7 changes: 7 additions & 0 deletions inngest/_internal/execution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import datetime
import enum
import typing

Expand Down Expand Up @@ -38,6 +39,7 @@ class CallError(types.BaseModel):
message: str
name: str
original_error: object = pydantic.Field(exclude=True)
retry_after: typing.Optional[datetime.datetime]
stack: typing.Optional[str]
step_id: typing.Optional[str]

Expand All @@ -60,12 +62,17 @@ def from_error(
name = type(err).__name__
stack = transforms.get_traceback(err)

retry_after = None
if isinstance(err, errors.RetryAfterError):
retry_after = err.retry_after

return cls(
code=code,
is_retriable=is_retriable,
message=message,
name=name,
original_error=err,
retry_after=retry_after,
stack=stack,
step_id=step_id,
)
Expand Down
4 changes: 2 additions & 2 deletions inngest/_internal/step_lib/step_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ async def run(
op=execution.Opcode.STEP_RUN,
)
)
except errors.NonRetriableError as err:
# NonRetriableErrors should bubble up to the function level
except (errors.NonRetriableError, errors.RetryAfterError) as err:
# Bubble up these error types to the function level
raise err
except Exception as err:
transforms.remove_first_traceback_frame(err)
Expand Down
4 changes: 2 additions & 2 deletions inngest/_internal/step_lib/step_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ def run(
op=execution.Opcode.STEP_RUN,
)
)
except errors.NonRetriableError as err:
# NonRetriableErrors should bubble up to the function level
except (errors.NonRetriableError, errors.RetryAfterError) as err:
# Bubble up these error types to the function level
raise err
except Exception as err:
transforms.remove_first_traceback_frame(err)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_function/cases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
on_failure,
parallel_steps,
pydantic_output,
retry_after_error,
sleep_until,
step_callback_args,
step_callback_kwargs,
Expand Down Expand Up @@ -64,6 +65,7 @@
on_failure,
parallel_steps,
pydantic_output,
retry_after_error,
sleep_until,
step_callback_args,
step_callback_kwargs,
Expand Down
161 changes: 161 additions & 0 deletions tests/test_function/cases/retry_after_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import datetime
import typing

import inngest
import tests.helper

from . import base

_TEST_NAME = "retry_after_error"


class _State(base.BaseState):
fn_level_raise_1_time: typing.Optional[datetime.datetime] = None
fn_level_retry_1_time: typing.Optional[datetime.datetime] = None
fn_level_raise_2_time: typing.Optional[datetime.datetime] = None
fn_level_retry_2_time: typing.Optional[datetime.datetime] = None
step_level_raise_time: typing.Optional[datetime.datetime] = None
step_level_retry_time: typing.Optional[datetime.datetime] = None


def create(
client: inngest.Inngest,
framework: str,
is_sync: bool,
) -> base.Case:
test_name = base.create_test_name(_TEST_NAME, is_sync)
event_name = base.create_event_name(framework, test_name)
fn_id = base.create_fn_id(test_name)
state = _State()

@client.create_function(
fn_id=fn_id,
retries=2,
trigger=inngest.TriggerEvent(event=event_name),
)
def fn_sync(
ctx: inngest.Context,
step: inngest.StepSync,
) -> None:
state.run_id = ctx.run_id

if state.fn_level_raise_1_time is None:
# Raise a RetryAfterError and track what time we raised it
state.fn_level_raise_1_time = datetime.datetime.now()
raise inngest.RetryAfterError("fn-1", 1000)

if state.fn_level_retry_1_time is None:
# Track the time we retried
state.fn_level_retry_1_time = datetime.datetime.now()

def step_fn() -> None:
if state.step_level_raise_time is None:
# Raise a RetryAfterError and track what time we raised it
state.step_level_raise_time = datetime.datetime.now()
raise inngest.RetryAfterError(
"step", datetime.timedelta(seconds=1)
)

if state.step_level_retry_time is None:
# Track the time we retried
state.step_level_retry_time = datetime.datetime.now()

step.run("step_1", step_fn)

if state.fn_level_raise_2_time is None:
# Raise a RetryAfterError and track what time we raised it
state.fn_level_raise_2_time = datetime.datetime.now()
raise inngest.RetryAfterError(
"fn-2", datetime.datetime.now() + datetime.timedelta(seconds=1)
)

if state.fn_level_retry_2_time is None:
# Track the time we retried
state.fn_level_retry_2_time = datetime.datetime.now()

@client.create_function(
fn_id=fn_id,
retries=1,
trigger=inngest.TriggerEvent(event=event_name),
)
async def fn_async(
ctx: inngest.Context,
step: inngest.Step,
) -> None:
state.run_id = ctx.run_id

if state.fn_level_raise_1_time is None:
# Raise a RetryAfterError and track what time we raised it
state.fn_level_raise_1_time = datetime.datetime.now()
raise inngest.RetryAfterError("fn-1", 1000)

if state.fn_level_retry_1_time is None:
# Track the time we retried
state.fn_level_retry_1_time = datetime.datetime.now()

async def step_fn() -> None:
if state.step_level_raise_time is None:
# Raise a RetryAfterError and track what time we raised it
state.step_level_raise_time = datetime.datetime.now()
raise inngest.RetryAfterError(
"step", datetime.timedelta(seconds=1)
)

if state.step_level_retry_time is None:
# Track the time we retried
state.step_level_retry_time = datetime.datetime.now()

await step.run("step_1", step_fn)

if state.fn_level_raise_2_time is None:
# Raise a RetryAfterError and track what time we raised it
state.fn_level_raise_2_time = datetime.datetime.now()
raise inngest.RetryAfterError(
"fn-2", datetime.datetime.now() + datetime.timedelta(seconds=1)
)

if state.fn_level_retry_2_time is None:
# Track the time we retried
state.fn_level_retry_2_time = datetime.datetime.now()

def run_test(self: base.TestClass) -> None:
self.client.send_sync(inngest.Event(name=event_name))
run_id = state.wait_for_run_id()

tests.helper.client.wait_for_run_status(
run_id,
tests.helper.RunStatus.COMPLETED,
)

assert state.fn_level_raise_1_time is not None
assert state.fn_level_retry_1_time is not None
assert state.step_level_raise_time is not None
assert state.step_level_retry_time is not None
assert state.fn_level_raise_2_time is not None
assert state.fn_level_retry_2_time is not None

assert (
state.fn_level_retry_1_time - state.fn_level_raise_1_time
< datetime.timedelta(seconds=2)
)

assert (
state.step_level_retry_time - state.step_level_raise_time
< datetime.timedelta(seconds=2)
)

assert (
state.fn_level_retry_2_time - state.fn_level_raise_2_time
< datetime.timedelta(seconds=2)
)

if is_sync:
fn = fn_sync
else:
fn = fn_async

return base.Case(
fn=fn,
run_test=run_test,
name=test_name,
)

0 comments on commit a19376d

Please sign in to comment.