Skip to content

Commit

Permalink
Make the task token key optional (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
cariad committed May 9, 2024
1 parent 05ee485 commit d612e34
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 22 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def main(event: Any, context: Any) -> Sum | None:
return handle_event(
event,
perform_sum,
"task_token",
task_token_key="task_token",
)

def perform_sum(inputs: Inputs, metadata: Metadata) -> Sum:
Expand All @@ -41,6 +41,8 @@ def perform_sum(inputs: Inputs, metadata: Metadata) -> Sum:

The `lambdaq.handle_event` function reads the invocation event, a reference to a message handler, and the key of the task token injected by the state machine.

If the task token key is omitted then the work will still be performed, but the state won't be reported back to Step Functions. This would be used, for example, for functions invoked by SQS queues that don't need to report back any status.

The message handler--`perform_sum` in this example--reads a strongly-typed message and returns a strongly-typed response.

## How does it work?
Expand Down
37 changes: 27 additions & 10 deletions lambdaq/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,27 @@


class EventHandler(Generic[TMessage, TResponse]):
"""
Event handler.
Arguments:
event: Function event.
handler: Reference to a message-handling function.
session: Optional Boto3 session. A new session will be created by
default.
task_token_key: Key of the Step Functions task token in each message.
Step Functions state will not be published if this is omitted.
"""

def __init__(
self,
event: Any,
handler: MessageHandler[TMessage, TResponse],
task_token_key: str,
session: Session | None = None,
task_token_key: str | None = None,
) -> None:
self.event = event
self.handler = handler
Expand Down Expand Up @@ -78,8 +93,8 @@ def handle_messages(
)

body = loads(record["body"])
token = body[self.task_token_key] if self.task_token_key else None
message = cast(TMessage, body)
token = str(message[self.task_token_key]) # type: ignore

try:
response = self.handler(
Expand All @@ -88,16 +103,18 @@ def handle_messages(
)

except Exception as ex:
self._send_task_state(
token,
exception=ex,
)
if token:
self._send_task_state(
token,
exception=ex,
)

continue

self._send_task_state(
token,
response=response,
)
if token:
self._send_task_state(
token,
response=response,
)

return None
9 changes: 5 additions & 4 deletions lambdaq/handle_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
def handle_event(
event: Any,
handler: MessageHandler[TMessage, TResponse],
task_token_key: str,
session: Session | None = None,
task_token_key: str | None = None,
) -> TResponse | None:
"""
Handles a Lambda function event.
Expand All @@ -20,11 +20,12 @@ def handle_event(
handler: Reference to a message-handling function.
task_token_key: Key of the task token in each message.
session: Optional Boto3 session. A new session will be created by
default.
task_token_key: Key of the Step Functions task token in each message.
Step Functions state will not be submitted if this is omitted.
Returns:
Message handling response if the function was invoked directly, or
`None` if the function was invoked by an SQS queue.
Expand All @@ -33,8 +34,8 @@ def handle_event(
event_handler = EventHandler(
event,
handler,
task_token_key,
session=session,
task_token_key=task_token_key,
)

return event_handler.handle_messages()
72 changes: 65 additions & 7 deletions tests/test_handle_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Message(TypedDict):
magic_word: str


class QueuedMessage(Message):
class QueuedMessage(Message, total=False):
task_token: str


Expand Down Expand Up @@ -43,8 +43,8 @@ def test_direct_event(
response = handle_event(
event,
handle_message,
"task_token",
session=session,
task_token_key="task_token",
)

send_task_failure.assert_not_called()
Expand All @@ -68,8 +68,8 @@ def test_direct_event_with_failure(
_ = handle_event(
event,
handle_message,
"task_token",
session=session,
task_token_key="task_token",
)

assert str(ex.value) == "Clowns are too cool for nerds"
Expand Down Expand Up @@ -99,8 +99,8 @@ def test_single_enqueued_event(
response = handle_event(
event,
handle_message,
"task_token",
session=session,
task_token_key="task_token",
)

send_task_failure.assert_not_called()
Expand All @@ -116,6 +116,64 @@ def test_single_enqueued_event(
assert response is None


def test_single_enqueued_event_without_token(
send_task_failure: Mock,
send_task_success: Mock,
session: Mock,
) -> None:
event = {
"Records": [
{
"body": dumps(
QueuedMessage(
magic_word="octopus",
),
)
}
],
}

response = handle_event(
event,
handle_message,
session=session,
)

send_task_failure.assert_not_called()
send_task_success.assert_not_called()

assert response is None


def test_single_enqueued_event_without_token_with_failure(
send_task_failure: Mock,
send_task_success: Mock,
session: Mock,
) -> None:
event = {
"Records": [
{
"body": dumps(
QueuedMessage(
magic_word="clown",
),
)
}
],
}

response = handle_event(
event,
handle_message,
session=session,
)

send_task_failure.assert_not_called()
send_task_success.assert_not_called()

assert response is None


def test_multiple_enqueued_events(
send_task_failure: Mock,
send_task_success: Mock,
Expand Down Expand Up @@ -145,8 +203,8 @@ def test_multiple_enqueued_events(
response = handle_event(
event,
handle_message,
"task_token",
session=session,
task_token_key="task_token",
)

send_task_failure.assert_not_called()
Expand Down Expand Up @@ -222,8 +280,8 @@ def test_multiple_enqueued_events_with_failures(
response = handle_event(
event,
handle_message,
"task_token",
session=session,
task_token_key="task_token",
)

assert send_task_failure.call_count == 2
Expand Down Expand Up @@ -295,8 +353,8 @@ def test_state_machine_time_out(
response = handle_event(
event,
handle_message,
"task_token",
session=session,
task_token_key="task_token",
)

send_task_failure.assert_not_called()
Expand Down

0 comments on commit d612e34

Please sign in to comment.