Skip to content

Commit

Permalink
StepFunctions: Fix TaskToken Creation Logic (#10707)
Browse files Browse the repository at this point in the history
  • Loading branch information
MEPalma committed Apr 23, 2024
1 parent ba58014 commit bcee501
Show file tree
Hide file tree
Showing 12 changed files with 590 additions and 15 deletions.
Expand Up @@ -19,12 +19,7 @@ def from_raw(cls, string_dollar: str, string_path_context_obj: str):
return cls(field=field, path_context_obj=path_context_obj)

def _eval_val(self, env: Environment) -> Any:
if self.path_context_obj.endswith("Task.Token"):
task_token = env.context_object_manager.update_task_token()
env.callback_pool_manager.add(task_token)
value = task_token
else:
value = JSONPathUtils.extract_json(
self.path_context_obj, env.context_object_manager.context_object
)
value = JSONPathUtils.extract_json(
self.path_context_obj, env.context_object_manager.context_object
)
return value
Expand Up @@ -140,6 +140,19 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent:
return self._get_callback_outcome_failure_event(env=env, ex=ex)
return super()._from_error(env=env, ex=ex)

def _eval_body(self, env: Environment) -> None:
# Generate a TaskToken uuid within the context object, if this task resources has a waitForTaskToken condition.
# This logic provisions a TaskToken callback uuid to support waitForTaskToken workflows as described in :
# https://docs.aws.amazon.com/step-functions/latest/dg/connect-to-resource.html#connect-wait-token
if self._is_condition() and self.resource.condition == ResourceCondition.WaitForTaskToken:
task_token = env.context_object_manager.update_task_token()
env.callback_pool_manager.add(task_token)

super()._eval_body(env=env)

# Ensure the TaskToken field is reset, as this is only available during waitForTaskToken task evaluations.
env.context_object_manager.context_object.pop("Task", None)

def _after_eval_execution(
self,
env: Environment,
Expand Down
Expand Up @@ -83,7 +83,3 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent:
if isinstance(ex, TimeoutError):
return self._get_timed_out_failure_event(env)
return super()._from_error(env=env, ex=ex)

def _eval_body(self, env: Environment) -> None:
super(StateTask, self)._eval_body(env=env)
env.context_object_manager.context_object["Task"] = None
@@ -1,4 +1,4 @@
from typing import Any, Final, Optional, TypedDict
from typing import Any, Final, NotRequired, Optional, TypedDict

from localstack.utils.strings import long_uid

Expand Down Expand Up @@ -41,7 +41,7 @@ class ContextObject(TypedDict):
Execution: Execution
State: Optional[State]
StateMachine: StateMachine
Task: Optional[Task] # Null if the Parameters field is outside a task state.
Task: NotRequired[Task] # Null if the Parameters field is outside a task state.
Map: Optional[Map] # Only available when processing a Map state.


Expand All @@ -60,3 +60,4 @@ def update_task_token(self) -> str:
class ContextObjectInitData(TypedDict):
Execution: Execution
StateMachine: StateMachine
Task: Optional[Task]
7 changes: 7 additions & 0 deletions localstack/services/stepfunctions/asl/eval/environment.py
Expand Up @@ -15,6 +15,7 @@
ContextObject,
ContextObjectInitData,
ContextObjectManager,
Task,
)
from localstack.services.stepfunctions.asl.eval.event.event_history import (
EventHistory,
Expand Down Expand Up @@ -70,12 +71,17 @@ def __init__(
self.aws_execution_details = aws_execution_details
self.callback_pool_manager = CallbackPoolManager(activity_store=activity_store)
self.map_run_record_pool_manager = MapRunRecordPoolManager()

self.context_object_manager = ContextObjectManager(
context_object=ContextObject(
Execution=context_object_init["Execution"],
StateMachine=context_object_init["StateMachine"],
)
)
task: Optional[Task] = context_object_init.get("Task")
if task is not None:
self.context_object_manager.context_object["Task"] = task

self.activity_store = activity_store

self._frames = list()
Expand All @@ -90,6 +96,7 @@ def as_frame_of(cls, env: Environment, event_history_frame_cache: EventHistoryCo
context_object_init = ContextObjectInitData(
Execution=env.context_object_manager.context_object["Execution"],
StateMachine=env.context_object_manager.context_object["StateMachine"],
Task=env.context_object_manager.context_object.get("Task"),
)
frame = cls(
aws_execution_details=env.aws_execution_details,
Expand Down
Expand Up @@ -33,6 +33,5 @@
"context_object.$": "$$",
}
}

}
}
Expand Up @@ -25,9 +25,15 @@ class CallbackTemplates(TemplateLoader):
SQS_WAIT_FOR_TASK_TOKEN: Final[str] = os.path.join(
_THIS_FOLDER, "statemachines/sqs_wait_for_task_token.json5"
)
SQS_WAIT_FOR_TASK_TOKEN_CALL_CHAIN: Final[str] = os.path.join(
_THIS_FOLDER, "statemachines/sqs_wait_for_task_token_call_chain.json5"
)
SQS_WAIT_FOR_TASK_TOKEN_WITH_TIMEOUT: Final[str] = os.path.join(
_THIS_FOLDER, "statemachines/sqs_wait_for_task_token_with_timeout.json5"
)
SQS_WAIT_FOR_TASK_TOKEN_NO_TOKEN_PARAMETER: Final[str] = os.path.join(
_THIS_FOLDER, "statemachines/sqs_wait_for_task_token_no_token_parameter.json5"
)
SQS_HEARTBEAT_SUCCESS_ON_TASK_TOKEN: Final[str] = os.path.join(
_THIS_FOLDER, "statemachines/sqs_hearbeat_success_on_task_token.json5"
)
@@ -0,0 +1,32 @@
{
"Comment": "SQS_WAIT_FOR_TASK_TOKEN_CALL_CHAIN",
"StartAt": "State1",
"States": {
"State1": {
"Type": "Task",
"Resource": "arn:aws:states:::sqs:sendMessage.waitForTaskToken",
"Parameters": {
"QueueUrl.$": "$.QueueUrl",
"MessageBody": {
"Message.$": "$.Message",
"TaskToken.$": "$$.Task.Token"
}
},
ResultPath: "$.State1Output",
"Next": "State2"
},
"State2": {
"Type": "Task",
"Resource": "arn:aws:states:::sqs:sendMessage.waitForTaskToken",
"Parameters": {
"QueueUrl.$": "$.QueueUrl",
"MessageBody": {
"Message.$": "$.Message",
"TaskToken.$": "$$.Task.Token"
}
},
ResultPath: "$.State2Output",
"End": true
}
}
}
@@ -0,0 +1,18 @@
{
"Comment": "SQS_WAIT_FOR_TASK_TOKEN_NO_TOKEN_PARAMETER",
"StartAt": "State1",
"States": {
"State1": {
"Type": "Task",
"Resource": "arn:aws:states:::sqs:sendMessage.waitForTaskToken",
"TimeoutSeconds": 5,
"Parameters": {
"QueueUrl.$": "$.QueueUrl",
"MessageBody": {
"Context.$": "$",
}
},
"End": true
},
}
}
76 changes: 76 additions & 0 deletions tests/aws/services/stepfunctions/v2/callback/test_callback.py
Expand Up @@ -621,3 +621,79 @@ def _sqs_task_token_handler():
executionArn=execution_arn
)
sfn_snapshot.match(f"execution_history_{i}", execution_history)

@markers.aws.validated
def test_sqs_wait_for_task_token_call_chain(
self,
aws_client,
create_iam_role_for_sfn,
create_state_machine,
sqs_create_queue,
sqs_send_task_success_state_machine,
sfn_snapshot,
):
sfn_snapshot.add_transformer(
JsonpathTransformer(
jsonpath="$..MessageId",
replacement="message_id",
replace_reference=True,
)
)
sfn_snapshot.add_transformer(
JsonpathTransformer(
jsonpath="$..TaskToken",
replacement="task_token",
replace_reference=True,
)
)
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sqs_integration())

queue_name = f"queue-{short_uid()}"
queue_url = sqs_create_queue(QueueName=queue_name)
sfn_snapshot.add_transformer(RegexTransformer(queue_url, "sqs_queue_url"))
sfn_snapshot.add_transformer(RegexTransformer(queue_name, "sqs_queue_name"))

sqs_send_task_success_state_machine(queue_url)

template = CT.load_sfn_template(CT.SQS_WAIT_FOR_TASK_TOKEN_CALL_CHAIN)
definition = json.dumps(template)

exec_input = json.dumps({"QueueUrl": queue_url, "Message": "HelloWorld"})
create_and_record_execution(
aws_client.stepfunctions,
create_iam_role_for_sfn,
create_state_machine,
sfn_snapshot,
definition,
exec_input,
)

@markers.aws.validated
def test_sqs_wait_for_task_token_no_token_parameter(
self,
aws_client,
create_iam_role_for_sfn,
create_state_machine,
sqs_create_queue,
sqs_send_task_success_state_machine,
sfn_snapshot,
):
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sqs_integration())

queue_name = f"queue-{short_uid()}"
queue_url = sqs_create_queue(QueueName=queue_name)
sfn_snapshot.add_transformer(RegexTransformer(queue_url, "sqs_queue_url"))
sfn_snapshot.add_transformer(RegexTransformer(queue_name, "sqs_queue_name"))

template = CT.load_sfn_template(CT.SQS_WAIT_FOR_TASK_TOKEN_NO_TOKEN_PARAMETER)
definition = json.dumps(template)

exec_input = json.dumps({"QueueUrl": queue_url})
create_and_record_execution(
aws_client.stepfunctions,
create_iam_role_for_sfn,
create_state_machine,
sfn_snapshot,
definition,
exec_input,
)

0 comments on commit bcee501

Please sign in to comment.