Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion durabletask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> Task[EntityLock]:
pass

@abstractmethod
def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
def call_sub_orchestrator(self, orchestrator: Union[Orchestrator[TInput, TOutput], str], *,
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
retry_policy: Optional[RetryPolicy] = None,
Expand Down
7 changes: 5 additions & 2 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,15 +1029,18 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> task.Task[EntityLoc

def call_sub_orchestrator(
self,
orchestrator: task.Orchestrator[TInput, TOutput],
orchestrator: Union[task.Orchestrator[TInput, TOutput], str],
*,
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
retry_policy: Optional[task.RetryPolicy] = None,
version: Optional[str] = None,
) -> task.Task[TOutput]:
id = self.next_sequence_number()
orchestrator_name = task.get_name(orchestrator)
if isinstance(orchestrator, str):
orchestrator_name = orchestrator
else:
orchestrator_name = task.get_name(orchestrator)
default_version = self._registry.versioning.default_version if self._registry.versioning else None
orchestrator_version = version if version else default_version
self.call_activity_function_helper(
Expand Down
28 changes: 28 additions & 0 deletions tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,34 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int):
assert activity_counter == 30


def test_sub_orchestrator_by_name():
sub_orchestrator_counter = 0

def orchestrator_child(ctx: task.OrchestrationContext, _):
nonlocal sub_orchestrator_counter
sub_orchestrator_counter += 1

def parent_orchestrator(ctx: task.OrchestrationContext, _):
yield ctx.call_sub_orchestrator("orchestrator_child")

# Start a worker, which will connect to the sidecar in a background thread
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=None) as w:
w.add_orchestrator(orchestrator_child)
w.add_orchestrator(parent_orchestrator)
w.start()

task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=None)
id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=None)
state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)

assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
assert state.failure_details is None
assert sub_orchestrator_counter == 1


def test_wait_for_multiple_external_events():
def orchestrator(ctx: task.OrchestrationContext, _):
a = yield ctx.wait_for_external_event('A')
Expand Down
26 changes: 26 additions & 0 deletions tests/durabletask/test_orchestration_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,32 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int):
assert activity_counter == 30


def test_sub_orchestrator_by_name():
sub_orchestrator_counter = 0

def orchestrator_child(ctx: task.OrchestrationContext, _):
nonlocal sub_orchestrator_counter
sub_orchestrator_counter += 1

def parent_orchestrator(ctx: task.OrchestrationContext, _):
yield ctx.call_sub_orchestrator("orchestrator_child")

# Start a worker, which will connect to the sidecar in a background thread
with worker.TaskHubGrpcWorker() as w:
w.add_orchestrator(orchestrator_child)
w.add_orchestrator(parent_orchestrator)
w.start()

task_hub_client = client.TaskHubGrpcClient()
id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=None)
state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)

assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
assert state.failure_details is None
assert sub_orchestrator_counter == 1


def test_wait_for_multiple_external_events():
def orchestrator(ctx: task.OrchestrationContext, _):
a = yield ctx.wait_for_external_event('A')
Expand Down
Loading