Skip to content

Commit

Permalink
StepFunctions: Add Support for MaxConcurrencyPath (#10705)
Browse files Browse the repository at this point in the history
  • Loading branch information
MEPalma committed Apr 23, 2024
1 parent 536cc28 commit cc5f570
Show file tree
Hide file tree
Showing 17 changed files with 2,777 additions and 2,006 deletions.
2 changes: 2 additions & 0 deletions localstack/services/stepfunctions/asl/antlr/ASLLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ ITERATOR: '"Iterator"';

ITEMSELECTOR: '"ItemSelector"';

MAXCONCURRENCYPATH: '"MaxConcurrencyPath"';

MAXCONCURRENCY: '"MaxConcurrency"';

RESOURCE: '"Resource"';
Expand Down
4 changes: 4 additions & 0 deletions localstack/services/stepfunctions/asl/antlr/ASLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ state_stmt:
| item_selector_decl
| item_reader_decl
| max_concurrency_decl
| max_concurrency_path_decl
| timeout_seconds_decl
| timeout_seconds_path_decl
| heartbeat_seconds_decl
Expand Down Expand Up @@ -110,6 +111,8 @@ items_path_decl: ITEMSPATH COLON keyword_or_string;

max_concurrency_decl: MAXCONCURRENCY COLON INT;

max_concurrency_path_decl: MAXCONCURRENCYPATH COLON STRINGPATH;

parameters_decl: PARAMETERS COLON payload_tmpl_decl;

timeout_seconds_decl: TIMEOUTSECONDS COLON INT;
Expand Down Expand Up @@ -418,6 +421,7 @@ keyword_or_string:
| ITERATOR
| ITEMSELECTOR
| MAXCONCURRENCY
| MAXCONCURRENCYPATH
| RESOURCE
| INPUTPATH
| OUTPUTPATH
Expand Down
1,899 changes: 955 additions & 944 deletions localstack/services/stepfunctions/asl/antlr/runtime/ASLLexer.py

Large diffs are not rendered by default.

2,150 changes: 1,113 additions & 1,037 deletions localstack/services/stepfunctions/asl/antlr/runtime/ASLParser.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,15 @@ def exitMax_concurrency_decl(self, ctx:ASLParser.Max_concurrency_declContext):
pass


# Enter a parse tree produced by ASLParser#max_concurrency_path_decl.
def enterMax_concurrency_path_decl(self, ctx:ASLParser.Max_concurrency_path_declContext):
pass

# Exit a parse tree produced by ASLParser#max_concurrency_path_decl.
def exitMax_concurrency_path_decl(self, ctx:ASLParser.Max_concurrency_path_declContext):
pass


# Enter a parse tree produced by ASLParser#parameters_decl.
def enterParameters_decl(self, ctx:ASLParser.Parameters_declContext):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ def visitMax_concurrency_decl(self, ctx:ASLParser.Max_concurrency_declContext):
return self.visitChildren(ctx)


# Visit a parse tree produced by ASLParser#max_concurrency_path_decl.
def visitMax_concurrency_path_decl(self, ctx:ASLParser.Max_concurrency_path_declContext):
return self.visitChildren(ctx)


# Visit a parse tree produced by ASLParser#parameters_decl.
def visitParameters_decl(self, ctx:ASLParser.Parameters_declContext):
return self.visitChildren(ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
JobPool,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.max_concurrency import (
MaxConcurrency,
DEFAULT_MAX_CONCURRENCY_VALUE,
)
from localstack.services.stepfunctions.asl.component.states import States
from localstack.services.stepfunctions.asl.eval.environment import Environment
Expand Down Expand Up @@ -123,7 +123,9 @@ def _map_run(self, env: Environment) -> None:
# TODO: add watch on map_run_record update event and adjust the number of running workers accordingly.
max_concurrency = self._map_run_record.max_concurrency
workers_number = (
len(input_items) if max_concurrency == MaxConcurrency.DEFAULT else max_concurrency
len(input_items)
if max_concurrency == DEFAULT_MAX_CONCURRENCY_VALUE
else max_concurrency
)
self._set_active_workers(workers_number=workers_number, env=env)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
JobPool,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.max_concurrency import (
MaxConcurrency,
DEFAULT_MAX_CONCURRENCY_VALUE,
)
from localstack.services.stepfunctions.asl.component.states import States
from localstack.services.stepfunctions.asl.eval.environment import Environment
Expand Down Expand Up @@ -99,7 +99,9 @@ def _eval_body(self, env: Environment) -> None:
)

number_of_workers = (
len(input_items) if max_concurrency == MaxConcurrency.DEFAULT else max_concurrency
len(input_items)
if max_concurrency == DEFAULT_MAX_CONCURRENCY_VALUE
else max_concurrency
)
for _ in range(number_of_workers):
self._launch_worker(env=env)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,80 @@
import abc
from typing import Final

from localstack.services.stepfunctions.asl.component.component import Component
from localstack.aws.api.stepfunctions import ExecutionFailedEventDetails, HistoryEventType
from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import (
FailureEvent,
FailureEventException,
)
from localstack.services.stepfunctions.asl.component.common.error_name.states_error_name import (
StatesErrorName,
)
from localstack.services.stepfunctions.asl.component.common.error_name.states_error_name_type import (
StatesErrorNameType,
)
from localstack.services.stepfunctions.asl.component.eval_component import EvalComponent
from localstack.services.stepfunctions.asl.eval.environment import Environment
from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
from localstack.services.stepfunctions.asl.utils.json_path import JSONPathUtils

DEFAULT_MAX_CONCURRENCY_VALUE: Final[int] = 0 # No limit.

class MaxConcurrency(Component):
DEFAULT: Final[int] = 0 # No limit.

def __init__(self, num: int = DEFAULT):
self.num: Final[int] = num
class MaxConcurrencyDecl(EvalComponent, abc.ABC):
@abc.abstractmethod
def _eval_max_concurrency(self, env: Environment) -> int: ...

def _eval_body(self, env: Environment) -> None:
max_concurrency_value = self._eval_max_concurrency(env=env)
env.stack.append(max_concurrency_value)


class MaxConcurrency(MaxConcurrencyDecl):
max_concurrency_value: Final[int]

def __init__(self, num: int = DEFAULT_MAX_CONCURRENCY_VALUE):
super().__init__()
self.max_concurrency_value = num

def _eval_max_concurrency(self, env: Environment) -> int:
return self.max_concurrency_value


class MaxConcurrencyPath(MaxConcurrency):
max_concurrency_path: Final[str]

def __init__(self, max_concurrency_path: str):
super().__init__()
self.max_concurrency_path = max_concurrency_path

def _eval_max_concurrency(self, env: Environment) -> int:
inp = env.stack[-1]
max_concurrency_value = JSONPathUtils.extract_json(self.max_concurrency_path, inp)

error_cause = None
if not isinstance(max_concurrency_value, int):
value_str = (
to_json_str(max_concurrency_value)
if not isinstance(max_concurrency_value, str)
else max_concurrency_value
)
error_cause = f'The MaxConcurrencyPath field refers to value "{value_str}" which is not a valid integer: {self.max_concurrency_path}'
elif max_concurrency_value < 0:
error_cause = f"Expected non-negative integer for MaxConcurrency, got '{max_concurrency_value}' instead."

if error_cause is not None:
raise FailureEventException(
failure_event=FailureEvent(
env=env,
error_name=StatesErrorName(typ=StatesErrorNameType.StatesRuntime),
event_type=HistoryEventType.ExecutionFailed,
event_details=EventDetails(
executionFailedEventDetails=ExecutionFailedEventDetails(
error=StatesErrorNameType.StatesRuntime.to_name(), cause=error_cause
)
),
)
)

return max_concurrency_value
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.max_concurrency import (
MaxConcurrency,
MaxConcurrencyDecl,
)
from localstack.services.stepfunctions.asl.component.state.state_props import StateProps
from localstack.services.stepfunctions.asl.eval.environment import Environment
Expand All @@ -66,7 +67,7 @@ class StateMap(ExecutionState):
item_reader: Optional[ItemReader]
item_selector: Optional[ItemSelector]
parameters: Optional[Parameters]
max_concurrency: MaxConcurrency
max_concurrency_decl: MaxConcurrencyDecl
result_path: Optional[ResultPath]
result_selector: ResultSelector
retry: Optional[RetryDecl]
Expand All @@ -84,7 +85,7 @@ def from_state_props(self, state_props: StateProps) -> None:
self.item_reader = state_props.get(ItemReader)
self.item_selector = state_props.get(ItemSelector)
self.parameters = state_props.get(Parameters)
self.max_concurrency = state_props.get(MaxConcurrency) or MaxConcurrency()
self.max_concurrency_decl = state_props.get(MaxConcurrencyDecl) or MaxConcurrency()
self.result_path = state_props.get(ResultPath) or ResultPath(
result_path_src=ResultPath.DEFAULT_PATH
)
Expand Down Expand Up @@ -112,6 +113,8 @@ def from_state_props(self, state_props: StateProps) -> None:
raise ValueError(f"Unknown value for IteratorDecl '{iteration_decl}'.")

def _eval_execution(self, env: Environment) -> None:
max_concurrency_num = env.stack.pop()

self.items_path.eval(env)
if self.item_reader:
env.event_history.add_event(
Expand All @@ -135,15 +138,15 @@ def _eval_execution(self, env: Environment) -> None:
if isinstance(self.iteration_component, InlineIterator):
eval_input = InlineIteratorEvalInput(
state_name=self.name,
max_concurrency=self.max_concurrency.num,
max_concurrency=max_concurrency_num,
input_items=input_items,
parameters=self.parameters,
item_selector=self.item_selector,
)
elif isinstance(self.iteration_component, DistributedIterator):
eval_input = DistributedIteratorEvalInput(
state_name=self.name,
max_concurrency=self.max_concurrency.num,
max_concurrency=max_concurrency_num,
input_items=input_items,
parameters=self.parameters,
item_selector=self.item_selector,
Expand All @@ -152,15 +155,15 @@ def _eval_execution(self, env: Environment) -> None:
elif isinstance(self.iteration_component, InlineItemProcessor):
eval_input = InlineItemProcessorEvalInput(
state_name=self.name,
max_concurrency=self.max_concurrency.num,
max_concurrency=max_concurrency_num,
input_items=input_items,
item_selector=self.item_selector,
parameters=self.parameters,
)
elif isinstance(self.iteration_component, DistributedItemProcessor):
eval_input = DistributedItemProcessorEvalInput(
state_name=self.name,
max_concurrency=self.max_concurrency.num,
max_concurrency=max_concurrency_num,
input_items=input_items,
item_reader=self.item_reader,
item_selector=self.item_selector,
Expand All @@ -184,6 +187,9 @@ def _eval_state(self, env: Environment) -> None:
# Initialise the retry counter for execution states.
env.context_object_manager.context_object["State"]["RetryCount"] = 0

# Evaluate state level properties.
self.max_concurrency_decl.eval(env=env)

# Attempt to evaluate the state's logic through until it's successful, caught, or retries have run out.
while True:
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.item_reader.reader_config.max_items_decl import (
MaxItemsDecl,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.max_concurrency import (
MaxConcurrencyDecl,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
Resource,
)
Expand All @@ -17,17 +20,19 @@
)
from localstack.services.stepfunctions.asl.parse.typed_props import TypedProps

UNIQUE_SUBINSTANCES: Final[set[type]] = {
Resource,
WaitFunction,
Timeout,
Heartbeat,
MaxItemsDecl,
MaxConcurrencyDecl,
ErrorDecl,
CauseDecl,
}


class StateProps(TypedProps):
_UNIQUE_SUBINSTANCES: Final[set[type]] = {
Resource,
WaitFunction,
Timeout,
Heartbeat,
MaxItemsDecl,
ErrorDecl,
CauseDecl,
}
name: str

def add(self, instance: Any) -> None:
Expand All @@ -40,7 +45,7 @@ def add(self, instance: Any) -> None:
raise ValueError(f"Next redefines End, from '{self.get(End)}' to '{instance}'.")

# Subclasses
for typ in self._UNIQUE_SUBINSTANCES:
for typ in UNIQUE_SUBINSTANCES:
if issubclass(inst_type, typ):
super()._add(typ, instance)
return
Expand Down
5 changes: 5 additions & 0 deletions localstack/services/stepfunctions/asl/parse/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.max_concurrency import (
MaxConcurrency,
MaxConcurrencyPath,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.mode import (
Mode,
Expand Down Expand Up @@ -546,6 +547,10 @@ def visitMax_concurrency_decl(
) -> MaxConcurrency:
return MaxConcurrency(num=int(ctx.INT().getText()))

def visitMax_concurrency_path_decl(self, ctx: ASLParser.Max_concurrency_path_declContext):
max_concurrency_path: str = self._inner_string_of(parse_tree=ctx.STRINGPATH())
return MaxConcurrencyPath(max_concurrency_path=max_concurrency_path)

def visitMode_decl(self, ctx: ASLParser.Mode_declContext) -> Mode:
mode_type: int = self.visit(ctx.mode_type())
return Mode(mode_type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ class ScenariosTemplate(TemplateLoader):
_THIS_FOLDER, "statemachines/catch_states_runtime.json5"
)
PARALLEL_STATE: Final[str] = os.path.join(_THIS_FOLDER, "statemachines/parallel_state.json5")
MAX_CONCURRENCY: Final[str] = os.path.join(
_THIS_FOLDER, "statemachines/max_concurrency_path.json5"
)
PARALLEL_STATE_FAIL: Final[str] = os.path.join(
_THIS_FOLDER, "statemachines/parallel_state_fail.json5"
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"Comment": "MAX_CONCURRENCY_PATH",
"StartAt": "MapState",
"States": {
"MapState": {
"Type": "Map",
"ItemsPath": "$.Values",
"MaxConcurrencyPath": "$.MaxConcurrencyValue",
"ItemProcessor": {
"ProcessorConfig": {
"Mode": "INLINE"
},
"StartAt": "HandleItem",
"States": {
"HandleItem": {
"Type": "Pass",
"End": true
}
}
},
"Next": "Final",
},
"Final": {
"Type": "Pass",
"End": true
}
}
}

0 comments on commit cc5f570

Please sign in to comment.