diff --git a/durabletask/task.py b/durabletask/task.py index 2f49371..3570838 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -201,7 +201,7 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> Task[EntityLock]: pass @abstractmethod - def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *, + def call_sub_orchestrator(self, orchestrator: Union[Orchestrator[TInput, TOutput], str], *, input: Optional[TInput] = None, instance_id: Optional[str] = None, retry_policy: Optional[RetryPolicy] = None, diff --git a/durabletask/worker.py b/durabletask/worker.py index 09f6559..a244927 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -1029,7 +1029,7 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> task.Task[EntityLoc def call_sub_orchestrator( self, - orchestrator: task.Orchestrator[TInput, TOutput], + orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *, input: Optional[TInput] = None, instance_id: Optional[str] = None, @@ -1037,7 +1037,10 @@ def call_sub_orchestrator( version: Optional[str] = None, ) -> task.Task[TOutput]: id = self.next_sequence_number() - orchestrator_name = task.get_name(orchestrator) + if isinstance(orchestrator, str): + orchestrator_name = orchestrator + else: + orchestrator_name = task.get_name(orchestrator) default_version = self._registry.versioning.default_version if self._registry.versioning else None orchestrator_version = version if version else default_version self.call_activity_function_helper( diff --git a/tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py b/tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py index 6155733..1f47810 100644 --- a/tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py +++ b/tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py @@ -175,6 +175,34 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): assert activity_counter == 30 +def test_sub_orchestrator_by_name(): + sub_orchestrator_counter = 0 + + def orchestrator_child(ctx: task.OrchestrationContext, _): + nonlocal sub_orchestrator_counter + sub_orchestrator_counter += 1 + + def parent_orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_sub_orchestrator("orchestrator_child") + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(orchestrator_child) + w.add_orchestrator(parent_orchestrator) + w.start() + + task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=None) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.failure_details is None + assert sub_orchestrator_counter == 1 + + def test_wait_for_multiple_external_events(): def orchestrator(ctx: task.OrchestrationContext, _): a = yield ctx.wait_for_external_event('A') diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index 63f14c5..6dbc8e1 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -162,6 +162,32 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): assert activity_counter == 30 +def test_sub_orchestrator_by_name(): + sub_orchestrator_counter = 0 + + def orchestrator_child(ctx: task.OrchestrationContext, _): + nonlocal sub_orchestrator_counter + sub_orchestrator_counter += 1 + + def parent_orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_sub_orchestrator("orchestrator_child") + + # Start a worker, which will connect to the sidecar in a background thread + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator_child) + w.add_orchestrator(parent_orchestrator) + w.start() + + task_hub_client = client.TaskHubGrpcClient() + id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=None) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.failure_details is None + assert sub_orchestrator_counter == 1 + + def test_wait_for_multiple_external_events(): def orchestrator(ctx: task.OrchestrationContext, _): a = yield ctx.wait_for_external_event('A')