Skip to content

External events and other improvements #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
"type": "python",
"request": "launch",
"program": "${file}",
"cwd": "${fileDirname}",
"purpose": [
"debug-test"
],
"env": {
// pytest-cov breaks debugging, so we have to disable it during debug sessions
"PYTEST_ADDOPTS": "--no-cov"
"PYTEST_ADDOPTS": "--no-cov",
"PYTHONPATH": "${workspaceFolder}"
},
"console": "integratedTerminal",
"justMyCode": false
Expand Down
31 changes: 31 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,37 @@ def orchestrator(ctx: task.OrchestrationContext, _):

You can find the full sample [here](./examples/fanout_fanin.py).

### Human interaction and durable timers

An orchestration can wait for a user-defined event, such as a human approval event, before proceeding to the next step. In addition, the orchestration can create a timer with an arbitrary duration that triggers some alternate action if the external event hasn't been received:

```python
def purchase_order_workflow(ctx: task.OrchestrationContext, order: Order):
    """Orchestrator function that represents a purchase order workflow.

    Returns "Auto-approved" for orders under $1000, "Canceled" when the
    approval timer wins the race against the approval event, and otherwise
    a message naming the approver.
    """
    # Orders under $1000 are auto-approved
    if order.Cost < 1000:
        return "Auto-approved"

    # Orders of $1000 or more require manager approval
    yield ctx.call_activity(send_approval_request, input=order)

    # Approvals must be received within 24 hours or they will be canceled.
    approval_event = ctx.wait_for_external_event("approval_received")
    timeout_event = ctx.create_timer(timedelta(hours=24))
    winner = yield task.when_any([approval_event, timeout_event])
    if winner == timeout_event:
        return "Canceled"

    # The order was approved. Yield the activity task (as is done for
    # send_approval_request above) so the orchestration waits for the order
    # to be placed before it reports the approval.
    yield ctx.call_activity(place_order, input=order)
    approval_details = approval_event.get_result()
    return f"Approved by '{approval_details.approver}'"
```

As an aside, you'll also notice that the example orchestration above works with custom business objects. Support for custom business objects includes support for custom classes, custom data classes, and named tuples. Serialization and deserialization of these objects is handled automatically by the SDK.

You can find the full sample [here](./examples/human_interaction.py).

## Getting Started

### Prerequisites
Expand Down
46 changes: 38 additions & 8 deletions durabletask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import TypeVar
from typing import Any, TypeVar

import grpc
import simplejson as json
from google.protobuf import wrappers_pb2

import durabletask.internal.helpers as helpers
Expand Down Expand Up @@ -46,14 +45,38 @@ class OrchestrationState:
serialized_input: str | None
serialized_output: str | None
serialized_custom_status: str | None
failure_details: pb.TaskFailureDetails | None
failure_details: task.FailureDetails | None

def raise_if_failed(self):
    """Raise OrchestrationFailedError if this state carries failure details.

    Does nothing when `failure_details` is None.
    """
    details = self.failure_details
    if details is None:
        return
    raise OrchestrationFailedError(
        f"Orchestration '{self.instance_id}' failed: {details.message}",
        details)


class OrchestrationFailedError(Exception):
    """Exception that pairs an error message with an orchestration's failure details."""

    def __init__(self, message: str, failure_details: task.FailureDetails):
        # The message becomes the exception text; the details object is kept
        # as-is and exposed read-only through the property below.
        super().__init__(message)
        self._failure_details = failure_details

    @property
    def failure_details(self):
        """The failure details object passed to the constructor."""
        stored = self._failure_details
        return stored


def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> OrchestrationState | None:
if not res.exists:
return None

state = res.orchestrationState

failure_details = None
if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '':
failure_details = task.FailureDetails(
state.failureDetails.errorMessage,
state.failureDetails.errorType,
state.failureDetails.stackTrace.value if not helpers.is_empty(state.failureDetails.stackTrace) else None)

return OrchestrationState(
instance_id,
state.name,
Expand All @@ -63,7 +86,7 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Or
state.input.value if not helpers.is_empty(state.input) else None,
state.output.value if not helpers.is_empty(state.output) else None,
state.customStatus.value if not helpers.is_empty(state.customStatus) else None,
state.failureDetails if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '' else None)
failure_details)


class TaskHubGrpcClient:
Expand All @@ -86,7 +109,7 @@ def schedule_new_orchestration(self, orchestrator: task.Orchestrator[TInput, TOu
req = pb.CreateInstanceRequest(
name=name,
instanceId=instance_id if instance_id else uuid.uuid4().hex,
input=wrappers_pb2.StringValue(value=json.dumps(input)) if input else None,
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input else None,
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None)

self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.")
Expand Down Expand Up @@ -128,6 +151,16 @@ def wait_for_orchestration_completion(self, instance_id: str, *,
else:
raise

def raise_orchestration_event(self, instance_id: str, event_name: str, *,
                              data: Any | None = None):
    """Raise an event named `event_name` for the orchestration `instance_id`.

    Parameters
    ----------
    instance_id : str
        The ID of the orchestration instance that should receive the event.
    event_name : str
        The name of the event to raise.
    data : Any | None
        Optional event payload. It is serialized with `shared.to_json`;
        `None` means the event is sent without a payload.
    """
    # Test against None rather than truthiness so that falsy but meaningful
    # payloads (0, False, "", [], {}) are still serialized and delivered.
    encoded_data = wrappers_pb2.StringValue(value=shared.to_json(data)) if data is not None else None
    req = pb.RaiseEventRequest(
        instanceId=instance_id,
        name=event_name,
        input=encoded_data)

    self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
    self._stub.RaiseEvent(req)

def terminate_orchestration(self):
    """Placeholder with no implementation yet; calling it does nothing."""
    pass

Expand All @@ -136,6 +169,3 @@ def suspend_orchestration(self):

def resume_orchestration(self):
    """Placeholder with no implementation yet; calling it does nothing."""
    pass

def raise_orchestration_event(self):
pass
16 changes: 10 additions & 6 deletions durabletask/internal/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@

import traceback
from datetime import datetime
from typing import Any

import simplejson as json
from google.protobuf import timestamp_pb2, wrappers_pb2

import durabletask.internal.orchestrator_service_pb2 as pb
Expand Down Expand Up @@ -117,6 +115,14 @@ def new_failure_details(ex: Exception) -> pb.TaskFailureDetails:
)


def new_event_raised_event(name: str, encoded_input: str | None = None) -> pb.HistoryEvent:
    """Build a history event that records an external event named `name`.

    `encoded_input` is the already-serialized payload (or None); it is wrapped
    via `get_string_value`. The event is created with eventId -1 and a
    default-constructed timestamp.
    """
    raised = pb.EventRaisedEvent(name=name, input=get_string_value(encoded_input))
    return pb.HistoryEvent(
        eventId=-1,
        timestamp=timestamp_pb2.Timestamp(),
        eventRaised=raised,
    )


def get_string_value(val: str | None) -> wrappers_pb2.StringValue | None:
if val is None:
return None
Expand Down Expand Up @@ -146,8 +152,7 @@ def new_create_timer_action(id: int, fire_at: datetime) -> pb.OrchestratorAction
return pb.OrchestratorAction(id=id, createTimer=pb.CreateTimerAction(fireAt=timestamp))


def new_schedule_task_action(id: int, name: str, input: Any) -> pb.OrchestratorAction:
encoded_input = json.dumps(input) if input is not None else None
def new_schedule_task_action(id: int, name: str, encoded_input: str | None) -> pb.OrchestratorAction:
return pb.OrchestratorAction(id=id, scheduleTask=pb.ScheduleTaskAction(
name=name,
input=get_string_value(encoded_input)
Expand All @@ -164,8 +169,7 @@ def new_create_sub_orchestration_action(
id: int,
name: str,
instance_id: str | None,
input: Any) -> pb.OrchestratorAction:
encoded_input = json.dumps(input) if input is not None else None
encoded_input: str | None) -> pb.OrchestratorAction:
return pb.OrchestratorAction(id=id, createSubOrchestration=pb.CreateSubOrchestrationAction(
name=name,
instanceId=instance_id,
Expand Down
54 changes: 54 additions & 0 deletions durabletask/internal/shared.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import dataclasses
import json
import logging
from types import SimpleNamespace
from typing import Any, Dict

import grpc

# Field name used to indicate that an object was automatically serialized
# and should be deserialized as a SimpleNamespace
AUTO_SERIALIZED = "__durabletask_autoobject__"


def get_default_host_address() -> str:
return "localhost:4001"
Expand Down Expand Up @@ -35,3 +43,49 @@ def get_logger(
datefmt='%Y-%m-%d %H:%M:%S')
log_handler.setFormatter(log_formatter)
return logger


def to_json(obj):
    """Serialize `obj` to a JSON string using InternalJSONEncoder."""
    encoded = json.dumps(obj, cls=InternalJSONEncoder)
    return encoded


def from_json(json_str):
    """Deserialize `json_str` into Python objects using InternalJSONDecoder."""
    decoded = json.loads(json_str, cls=InternalJSONDecoder)
    return decoded


class InternalJSONEncoder(json.JSONEncoder):
    """JSON encoder that supports serializing specific Python types.

    On top of what `json.JSONEncoder` handles, this encoder converts
    namedtuples (when passed directly to `encode`), dataclass instances and
    `SimpleNamespace` objects into JSON objects. Each converted object gets
    the `AUTO_SERIALIZED` key set to true so that the receiver can recognise
    it as an automatically serialized object.
    """

    def encode(self, obj: Any) -> str:
        # A namedtuple is a tuple subclass, so the base encoder would never
        # hand it to default(). Detect it here (tuple with _fields/_asdict)
        # and convert it to a dict with the AUTO_SERIALIZED key added.
        if isinstance(obj, tuple) and hasattr(obj, "_fields") and hasattr(obj, "_asdict"):
            d = obj._asdict()  # type: ignore
            d[AUTO_SERIALIZED] = True
            obj = d
        return super().encode(obj)

    def default(self, obj):
        # is_dataclass() is also true for dataclass *types*, which asdict()
        # rejects; only instances are converted here.
        if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
            # Dataclasses are not serializable by default, so we convert them to a dict and mark them for
            # automatic deserialization by the receiver
            d = dataclasses.asdict(obj)
            d[AUTO_SERIALIZED] = True
            return d
        elif isinstance(obj, SimpleNamespace):
            # Most commonly used for serializing custom objects that were previously serialized using our encoder.
            # vars() returns the namespace's live __dict__, so copy it before adding the marker key;
            # otherwise encoding would add a marker attribute to the caller's object.
            d = dict(vars(obj))
            d[AUTO_SERIALIZED] = True
            return d
        # This will typically raise a TypeError
        return json.JSONEncoder.default(self, obj)


class InternalJSONDecoder(json.JSONDecoder):
    """JSON decoder that turns objects marked with AUTO_SERIALIZED into SimpleNamespace instances."""

    def __init__(self, *args, **kwargs):
        # Every decoded JSON object is routed through dict_to_object.
        super().__init__(*args, object_hook=self.dict_to_object, **kwargs)

    def dict_to_object(self, d: Dict[str, Any]):
        """Return a SimpleNamespace for marked dicts; return other dicts unchanged.

        The marker key is removed from `d` before the namespace is built.
        """
        marked = d.pop(AUTO_SERIALIZED, False)
        if not marked:
            return d
        return SimpleNamespace(**d)
54 changes: 42 additions & 12 deletions durabletask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from datetime import datetime
from datetime import datetime, timedelta
from typing import Any, Callable, Generator, Generic, List, TypeVar

import durabletask.internal.helpers as pbh
Expand Down Expand Up @@ -70,13 +70,13 @@ def is_replaying(self) -> bool:
pass

@abstractmethod
def create_timer(self, fire_at: datetime) -> Task:
def create_timer(self, fire_at: datetime | timedelta) -> Task:
"""Create a Timer Task that fires at the specified deadline or after the specified delay.

Parameters
----------
fire_at: datetime.datetime
The time for the timer to trigger
fire_at: datetime.datetime | datetime.timedelta
The time for the timer to trigger or a time delta from now.

Returns
-------
Expand Down Expand Up @@ -129,12 +129,27 @@ def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
"""
pass

# TODO: Add a timeout parameter, which allows the task to be canceled if the event is
# not received within the specified timeout. This requires support for task cancellation.
@abstractmethod
def wait_for_external_event(self, name: str) -> Task:
"""Wait asynchronously for an event to be raised with the name `name`.

class TaskFailedError(Exception):
"""Exception type for all orchestration task failures."""
Parameters
----------
name : str
The event name of the event that the task is waiting for.

Returns
-------
Task[TOutput]
A Durable Task that completes when the event is received.
"""
pass


class FailureDetails:
def __init__(self, message: str, error_type: str, stack_trace: str | None):
super().__init__(message)
self._message = message
self._error_type = error_type
self._stack_trace = stack_trace
Expand All @@ -152,6 +167,21 @@ def stack_trace(self) -> str | None:
return self._stack_trace


class TaskFailedError(Exception):
    """Exception type for all orchestration task failures."""

    def __init__(self, message: str, details: pb.TaskFailureDetails):
        super().__init__(message)
        # The stack trace is optional: use None when the wrapper is empty.
        if pbh.is_empty(details.stackTrace):
            stack_trace = None
        else:
            stack_trace = details.stackTrace.value
        self._details = FailureDetails(details.errorMessage, details.errorType, stack_trace)

    @property
    def details(self) -> FailureDetails:
        """The FailureDetails built from the failure message passed to the constructor."""
        return self._details


class NonDeterminismError(Exception):
pass

Expand Down Expand Up @@ -208,6 +238,8 @@ def __init__(self, tasks: List[Task]):
self._failed_tasks = 0
for task in tasks:
task._parent = self
if task.is_complete:
self.on_child_completed(task)

def get_tasks(self) -> List[Task]:
    """Return the task list stored on this instance (`self._tasks`)."""
    return self._tasks
Expand All @@ -230,13 +262,10 @@ def complete(self, result: T):
if self._parent is not None:
self._parent.on_child_completed(self)

def fail(self, details: pb.TaskFailureDetails):
def fail(self, message: str, details: pb.TaskFailureDetails):
if self._is_complete:
raise ValueError('The task has already completed.')
self._exception = TaskFailedError(
details.errorMessage,
details.errorType,
details.stackTrace.value if not pbh.is_empty(details.stackTrace) else None)
self._exception = TaskFailedError(message, details)
self._is_complete = True
if self._parent is not None:
self._parent.on_child_completed(self)
Expand Down Expand Up @@ -278,6 +307,7 @@ def __init__(self, tasks: List[Task]):
super().__init__(tasks)

def on_child_completed(self, task: Task):
# The first task to complete is the result of the WhenAnyTask.
if not self.is_complete:
self._is_complete = True
self._result = task
Expand Down
Loading