Skip to content

Commit

Permalink
StepFunctions: Fix Heartbeat Callback Locking (#10663)
Browse files Browse the repository at this point in the history
  • Loading branch information
MEPalma committed Apr 18, 2024
1 parent d858f00 commit 0a255e2
Show file tree
Hide file tree
Showing 16 changed files with 1,484 additions and 112 deletions.
12 changes: 9 additions & 3 deletions localstack/services/stepfunctions/asl/eval/callback/callback.py
@@ -1,6 +1,6 @@
import abc
from collections import OrderedDict
from threading import Event
from threading import Event, Lock
from typing import Final, Optional

from localstack.aws.api.stepfunctions import ActivityDoesNotExist, Arn
Expand Down Expand Up @@ -51,19 +51,25 @@ class CallbackConsumerLeft(CallbackConsumerError):


class HeartbeatEndpoint:
_mutex: Final[Lock]
_next_heartbeat_event: Final[Event]
_heartbeat_seconds: Final[int]

def __init__(self, heartbeat_seconds: int):
self._mutex = Lock()
self._next_heartbeat_event = Event()
self._heartbeat_seconds = heartbeat_seconds

def clear_and_wait(self) -> bool:
self._next_heartbeat_event.clear()
with self._mutex:
if self._next_heartbeat_event.is_set():
self._next_heartbeat_event.clear()
return True
return self._next_heartbeat_event.wait(timeout=self._heartbeat_seconds)

def notify(self):
self._next_heartbeat_event.set()
with self._mutex:
self._next_heartbeat_event.set()


class HeartbeatTimeoutError(TimeoutError):
Expand Down
9 changes: 9 additions & 0 deletions localstack/testing/snapshots/transformer_utility.py
Expand Up @@ -629,6 +629,15 @@ def sfn_map_run_arn(map_run_arn: LongArn, index: int) -> list[RegexTransformer]:
RegexTransformer(arn_parts[1], f"<MapRunArnPart1_{index}idx>"),
]

@staticmethod
def sfn_sqs_integration():
return [
*TransformerUtility.sqs_api(),
# Transform MD5OfMessageBody value bindings as in StepFunctions these are not deterministic
# about the input message.
TransformerUtility.key_value("MD5OfMessageBody"),
]

@staticmethod
def stepfunctions_api():
return [
Expand Down
2 changes: 1 addition & 1 deletion tests/aws/services/stepfunctions/utils.py
Expand Up @@ -332,7 +332,7 @@ def create_and_record_events(
definition,
execution_input,
):
sfn_snapshot.add_transformer(sfn_snapshot.transform.sqs_api())
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sqs_integration())
sfn_snapshot.add_transformers_list(
[
JsonpathTransformer(
Expand Down
170 changes: 160 additions & 10 deletions tests/aws/services/stepfunctions/v2/callback/test_callback.py
Expand Up @@ -3,6 +3,7 @@

from localstack_snapshot.snapshots.transformer import JsonpathTransformer, RegexTransformer

from localstack.services.stepfunctions.asl.eval.count_down_latch import CountDownLatch
from localstack.testing.pytest import markers
from localstack.utils.strings import short_uid
from localstack.utils.sync import retry
Expand All @@ -13,10 +14,37 @@
from tests.aws.services.stepfunctions.templates.timeouts.timeout_templates import (
TimeoutTemplates as TT,
)
from tests.aws.services.stepfunctions.utils import create, create_and_record_execution
from tests.aws.services.stepfunctions.utils import (
await_execution_terminated,
create,
create_and_record_execution,
)
from tests.aws.test_notifications import PUBLICATION_RETRIES, PUBLICATION_TIMEOUT


def _handle_sqs_task_token_with_heartbeats_and_success(aws_client, queue_url) -> None:
# Handle the state machine task token published in the sqs queue, by submitting 10 heartbeat
# notifications and a task success notification. Snapshot the response of each call.

# Read the expected sqs message and extract the body.
def _get_message_body():
receive_message_response = aws_client.sqs.receive_message(
QueueUrl=queue_url, MaxNumberOfMessages=1
)
return receive_message_response["Messages"][0]["Body"]

message_body_str = retry(_get_message_body, retries=100, sleep=1)
message_body = json.loads(message_body_str)

# Send the heartbeat notifications.
task_token = message_body["TaskToken"]
for i in range(10):
aws_client.stepfunctions.send_task_heartbeat(taskToken=task_token)

# Send the task success notification.
aws_client.stepfunctions.send_task_success(taskToken=task_token, output=message_body_str)


@markers.snapshot.skip_snapshot_verify(
paths=[
"$..loggingConfiguration",
Expand All @@ -26,7 +54,6 @@
]
)
class TestCallback:
@markers.snapshot.skip_snapshot_verify(paths=["$..MD5OfMessageBody"])
@markers.aws.needs_fixing
def test_sqs_wait_for_task_token(
self,
Expand All @@ -37,7 +64,7 @@ def test_sqs_wait_for_task_token(
sqs_send_task_success_state_machine,
sfn_snapshot,
):
sfn_snapshot.add_transformer(sfn_snapshot.transform.sqs_api())
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sqs_integration())
sfn_snapshot.add_transformer(
JsonpathTransformer(
jsonpath="$..TaskToken",
Expand Down Expand Up @@ -67,7 +94,6 @@ def test_sqs_wait_for_task_token(
exec_input,
)

@markers.snapshot.skip_snapshot_verify(paths=["$..MD5OfMessageBody"])
@markers.aws.needs_fixing
def test_sqs_wait_for_task_token_timeout(
self,
Expand All @@ -78,7 +104,7 @@ def test_sqs_wait_for_task_token_timeout(
sqs_send_task_success_state_machine,
sfn_snapshot,
):
sfn_snapshot.add_transformer(sfn_snapshot.transform.sqs_api())
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sqs_integration())
sfn_snapshot.add_transformer(
JsonpathTransformer(
jsonpath="$..TaskToken",
Expand Down Expand Up @@ -106,7 +132,6 @@ def test_sqs_wait_for_task_token_timeout(
exec_input,
)

@markers.snapshot.skip_snapshot_verify(paths=["$..MD5OfMessageBody"])
@markers.aws.needs_fixing
def test_sqs_failure_in_wait_for_task_token(
self,
Expand All @@ -117,7 +142,7 @@ def test_sqs_failure_in_wait_for_task_token(
sqs_send_task_failure_state_machine,
sfn_snapshot,
):
sfn_snapshot.add_transformer(sfn_snapshot.transform.sqs_api())
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sqs_integration())
sfn_snapshot.add_transformer(
JsonpathTransformer(
jsonpath="$..TaskToken",
Expand Down Expand Up @@ -147,7 +172,6 @@ def test_sqs_failure_in_wait_for_task_token(
exec_input,
)

@markers.snapshot.skip_snapshot_verify(paths=["$..MD5OfMessageBody"])
@markers.aws.needs_fixing
def test_sqs_wait_for_task_tok_with_heartbeat(
self,
Expand All @@ -158,7 +182,7 @@ def test_sqs_wait_for_task_tok_with_heartbeat(
sqs_send_heartbeat_and_task_success_state_machine,
sfn_snapshot,
):
sfn_snapshot.add_transformer(sfn_snapshot.transform.sqs_api())
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sqs_integration())
sfn_snapshot.add_transformer(
JsonpathTransformer(
jsonpath="$..TaskToken",
Expand Down Expand Up @@ -208,7 +232,7 @@ def test_sns_publish_wait_for_task_token(
replace_reference=True,
)
)
sfn_snapshot.add_transformer(sfn_snapshot.transform.sqs_api())
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sqs_integration())
sfn_snapshot.add_transformer(sfn_snapshot.transform.sns_api())

topic_info = sns_create_topic()
Expand Down Expand Up @@ -471,3 +495,129 @@ def test_start_execution_sync_delegate_timeout(
definition,
exec_input,
)

@markers.aws.validated
def test_multiple_heartbeat_notifications(
self,
aws_client,
create_iam_role_for_sfn,
create_state_machine,
sqs_create_queue,
sfn_snapshot,
):
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sqs_integration())
sfn_snapshot.add_transformer(
JsonpathTransformer(
jsonpath="$..TaskToken",
replacement="task_token",
replace_reference=True,
)
)

queue_name = f"queue-{short_uid()}"
queue_url = sqs_create_queue(QueueName=queue_name)
sfn_snapshot.add_transformer(RegexTransformer(queue_url, "sqs_queue_url"))
sfn_snapshot.add_transformer(RegexTransformer(queue_name, "sqs_queue_name"))

task_token_consumer_thread = threading.Thread(
target=_handle_sqs_task_token_with_heartbeats_and_success, args=(aws_client, queue_url)
)
task_token_consumer_thread.start()

template = CT.load_sfn_template(
TT.SERVICE_SQS_SEND_AND_WAIT_FOR_TASK_TOKEN_WITH_HEARTBEAT_PATH
)
definition = json.dumps(template)

exec_input = json.dumps(
{"QueueUrl": queue_url, "Message": "txt", "HeartbeatSecondsPath": 120}
)
create_and_record_execution(
aws_client.stepfunctions,
create_iam_role_for_sfn,
create_state_machine,
sfn_snapshot,
definition,
exec_input,
)

task_token_consumer_thread.join(timeout=300)

@markers.aws.validated
def test_multiple_executions_and_heartbeat_notifications(
self,
aws_client,
create_iam_role_for_sfn,
create_state_machine,
sqs_create_queue,
sfn_snapshot,
):
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sqs_integration())
sfn_snapshot.add_transformer(
JsonpathTransformer(
jsonpath="$..TaskToken",
replacement="a_task_token",
replace_reference=False,
)
)
sfn_snapshot.add_transformer(
JsonpathTransformer(
jsonpath="$..MessageId",
replacement="a_message_id",
replace_reference=False,
)
)

queue_name = f"queue-{short_uid()}"
queue_url = sqs_create_queue(QueueName=queue_name)
sfn_snapshot.add_transformer(RegexTransformer(queue_url, "sqs_queue_url"))
sfn_snapshot.add_transformer(RegexTransformer(queue_name, "sqs_queue_name"))

sfn_role_arn = create_iam_role_for_sfn()

template = CT.load_sfn_template(
TT.SERVICE_SQS_SEND_AND_WAIT_FOR_TASK_TOKEN_WITH_HEARTBEAT_PATH
)
definition = json.dumps(template)

creation_response = create_state_machine(
name=f"state_machine_{short_uid()}", definition=definition, roleArn=sfn_role_arn
)
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sm_create_arn(creation_response, 0))
state_machine_arn = creation_response["stateMachineArn"]

exec_input = json.dumps(
{"QueueUrl": queue_url, "Message": "txt", "HeartbeatSecondsPath": 120}
)

# Launch multiple execution of the same state machine.
execution_count = 6
execution_arns = list()
for _ in range(execution_count):
execution_arn = aws_client.stepfunctions.start_execution(
stateMachineArn=state_machine_arn, input=exec_input
)["executionArn"]
execution_arns.append(execution_arn)

# Launch one sqs task token handler per each execution, and await for all the terminate handling the task.
task_token_handler_latch = CountDownLatch(execution_count)

def _sqs_task_token_handler():
_handle_sqs_task_token_with_heartbeats_and_success(aws_client, queue_url)
task_token_handler_latch.count_down()

for _ in range(execution_count):
inner_handler_thread = threading.Thread(target=_sqs_task_token_handler, args=())
inner_handler_thread.start()

task_token_handler_latch.wait()

# For each execution, await terminate and record the event executions.
for i, execution_arn in enumerate(execution_arns):
await_execution_terminated(
stepfunctions_client=aws_client.stepfunctions, execution_arn=execution_arn
)
execution_history = aws_client.stepfunctions.get_execution_history(
executionArn=execution_arn
)
sfn_snapshot.match(f"execution_history_{i}", execution_history)

0 comments on commit 0a255e2

Please sign in to comment.