Skip to content

Commit

Permalink
StepFunctions: Improve Handling of Empty SendTaskFailure Calls (#10750)
Browse files Browse the repository at this point in the history
  • Loading branch information
MEPalma committed May 1, 2024
1 parent 5bfa1de commit 0da9865
Show file tree
Hide file tree
Showing 11 changed files with 656 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ def _extract_error_cause(failure_event: FailureEvent) -> dict:
f"Internal Error: invalid event details declaration in FailureEvent: '{failure_event}'."
)
spec_event_details: dict = list(failure_event.event_details.values())[0]
error = spec_event_details["error"]
cause = spec_event_details.get("cause") or ""
# If no cause or error fields are given, AWS binds an empty string; otherwise it attaches the value.
error = spec_event_details.get("error", "")
cause = spec_event_details.get("cause", "")
# Stepfunctions renames these fields to capital in this scenario.
return {
"Error": error,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from typing import Final
from typing import Final, Optional

from localstack.services.stepfunctions.asl.component.common.error_name.error_name import ErrorName

ILLEGAL_CUSTOM_ERROR_PREFIX: Final[str] = "States."


class CustomErrorName(ErrorName):
"""
States MAY report errors with other names, which MUST NOT begin with the prefix "States.".
"""

_ILLEGAL_PREFIX: Final[str] = "States."

def __init__(self, error_name: str):
if error_name.startswith(CustomErrorName._ILLEGAL_PREFIX):
def __init__(self, error_name: Optional[str]):
if error_name is not None and error_name.startswith(ILLEGAL_CUSTOM_ERROR_PREFIX):
raise ValueError(
f"Custom Error Names MUST NOT begin with the prefix 'States.', got '{error_name}'."
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from __future__ import annotations

import abc
from typing import Final
from typing import Final, Optional

from localstack.services.stepfunctions.asl.component.component import Component


class ErrorName(Component, abc.ABC):
def __init__(self, error_name: str):
self.error_name: Final[str] = error_name
error_name: Final[Optional[str]]

def matches(self, error_name: str) -> bool:
def __init__(self, error_name: Optional[str]):
self.error_name = error_name

def matches(self, error_name: Optional[str]) -> bool:
return self.error_name == error_name

def __eq__(self, other):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import json
import time
from typing import Optional

from localstack.aws.api.stepfunctions import (
HistoryEventExecutionDataDetails,
Expand Down Expand Up @@ -120,10 +121,10 @@ def _get_callback_outcome_failure_event(
self, env: Environment, ex: CallbackOutcomeFailureError
) -> FailureEvent:
callback_outcome_failure: CallbackOutcomeFailure = ex.callback_outcome_failure
error: str = callback_outcome_failure.error
error: Optional[str] = callback_outcome_failure.error
return FailureEvent(
env=env,
error_name=CustomErrorName(error_name=callback_outcome_failure.error),
error_name=CustomErrorName(error_name=error),
event_type=HistoryEventType.TaskFailed,
event_details=EventDetails(
taskFailedEventDetails=TaskFailedEventDetails(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def __init__(self, callback_id: CallbackId, output: str):


class CallbackOutcomeFailure(CallbackOutcome):
error: Final[str]
cause: Final[str]
error: Final[Optional[str]]
cause: Final[Optional[str]]

def __init__(self, callback_id: CallbackId, error: str, cause: str):
def __init__(self, callback_id: CallbackId, error: Optional[str], cause: Optional[str]):
super().__init__(callback_id=callback_id)
self.error = error
self.cause = cause
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,9 @@ class CallbackTemplates(TemplateLoader):
SQS_HEARTBEAT_SUCCESS_ON_TASK_TOKEN: Final[str] = os.path.join(
_THIS_FOLDER, "statemachines/sqs_hearbeat_success_on_task_token.json5"
)
SQS_PARALLEL_WAIT_FOR_TASK_TOKEN: Final[str] = os.path.join(
_THIS_FOLDER, "statemachines/sqs_parallel_wait_for_task_token.json5"
)
SQS_WAIT_FOR_TASK_TOKEN_CATCH: Final[str] = os.path.join(
_THIS_FOLDER, "statemachines/sqs_wait_for_task_token_catch.json5"
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
{
"Comment": "SQS_PARALLEL_WAIT_FOR_TASK_TOKEN",
"StartAt": "ParallelJob",
"States": {
"ParallelJob": {
"Type": "Parallel",
"Branches": [
{
"StartAt": "SendMessageWithWait",
"States": {
"SendMessageWithWait": {
"Type": "Task",
"Resource": "arn:aws:states:::sqs:sendMessage.waitForTaskToken",
"Parameters": {
"QueueUrl.$": "$.QueueUrl",
"MessageBody": {
"Context.$": "$",
"TaskToken.$": "$$.Task.Token"
}
},
"End": true
},
}
}
],
"Catch": [
{
"ErrorEquals": [
"States.Runtime"
],
"ResultPath": "$.states_runtime_error",
"Next": "CaughtRuntimeError"
},
{
"ErrorEquals": [
"States.ALL"
],
"ResultPath": "$.states_all_error",
"Next": "CaughtStatesALL"
}
],
"End": true
},
"CaughtRuntimeError": {
"Type": "Pass",
"End": true
},
"CaughtStatesALL": {
"Type": "Pass",
"End": true
},
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
{
"Comment": "SQS_WAIT_FOR_TASK_TOKEN_CATCH",
"StartAt": "SendMessageWithWait",
"States": {
"SendMessageWithWait": {
"Type": "Task",
"Resource": "arn:aws:states:::sqs:sendMessage.waitForTaskToken",
"Parameters": {
"QueueUrl.$": "$.QueueUrl",
"MessageBody": {
"Context.$": "$",
"TaskToken.$": "$$.Task.Token"
}
},
"Catch": [
{
"ErrorEquals": [
"States.Runtime"
],
"ResultPath": "$.states_runtime_error",
"Next": "CaughtRuntimeError"
},
{
"ErrorEquals": [
"States.ALL"
],
"ResultPath": "$.states_all_error",
"Next": "CaughtStatesALL"
}
],
"End": true
},
"CaughtRuntimeError": {
"Type": "Pass",
"End": true
},
"CaughtStatesALL": {
"Type": "Pass",
"End": true
}
}
}
75 changes: 75 additions & 0 deletions tests/aws/services/stepfunctions/v2/callback/test_callback.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import json
import threading

import pytest
from localstack_snapshot.snapshots.transformer import JsonpathTransformer, RegexTransformer

from localstack.services.stepfunctions.asl.eval.count_down_latch import CountDownLatch
from localstack.testing.aws.util import is_aws_cloud
from localstack.testing.pytest import markers
from localstack.utils.strings import short_uid
from localstack.utils.sync import retry
Expand Down Expand Up @@ -697,3 +699,76 @@ def test_sqs_wait_for_task_token_no_token_parameter(
definition,
exec_input,
)

@markers.aws.validated
@pytest.mark.parametrize(
"template",
[CT.SQS_PARALLEL_WAIT_FOR_TASK_TOKEN, CT.SQS_WAIT_FOR_TASK_TOKEN_CATCH],
ids=["SQS_PARALLEL_WAIT_FOR_TASK_TOKEN", "SQS_WAIT_FOR_TASK_TOKEN_CATCH"],
)
def test_sqs_failure_in_wait_for_task_tok_no_error_field(
self,
aws_client,
create_iam_role_for_sfn,
create_state_machine,
sqs_create_queue,
sfn_snapshot,
template,
request,
):
if (
not is_aws_cloud()
and request.node.name
== "test_sqs_failure_in_wait_for_task_tok_no_error_field[SQS_PARALLEL_WAIT_FOR_TASK_TOKEN]"
):
# TODO: The conditions in which TaskStateAborted error events are logged requires further investigations.
# These appear to be logged for Task state workers but only within Parallel states. The behaviour with
# other 'Abort' errors should also be investigated.
pytest.skip("Investigate occurrence logic of 'TaskStateAborted' errors")

sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sqs_integration())
sfn_snapshot.add_transformer(
JsonpathTransformer(
jsonpath="$..TaskToken",
replacement="task_token",
replace_reference=True,
)
)

queue_name = f"queue-{short_uid()}"
queue_url = sqs_create_queue(QueueName=queue_name)
sfn_snapshot.add_transformer(RegexTransformer(queue_url, "<sqs_queue_url>"))
sfn_snapshot.add_transformer(RegexTransformer(queue_name, "<sqs_queue_name>"))

def _empty_send_task_failure_on_sqs_message():
def _get_message_body():
receive_message_response = aws_client.sqs.receive_message(
QueueUrl=queue_url, MaxNumberOfMessages=1
)
return receive_message_response["Messages"][0]["Body"]

message_body_str = retry(_get_message_body, retries=60, sleep=1)
message_body = json.loads(message_body_str)
task_token = message_body["TaskToken"]
aws_client.stepfunctions.send_task_failure(taskToken=task_token)

thread_send_task_failure = threading.Thread(
target=_empty_send_task_failure_on_sqs_message,
args=(),
name="Thread_empty_send_task_failure_on_sqs_message",
)
thread_send_task_failure.daemon = True
thread_send_task_failure.start()

template = CT.load_sfn_template(template)
definition = json.dumps(template)

exec_input = json.dumps({"QueueUrl": queue_url, "Message": "test_message_txt"})
create_and_record_execution(
aws_client.stepfunctions,
create_iam_role_for_sfn,
create_state_machine,
sfn_snapshot,
definition,
exec_input,
)

0 comments on commit 0da9865

Please sign in to comment.