diff --git a/client/src/main/java/com/microsoft/durabletask/TaskOrchestrationExecutor.java b/client/src/main/java/com/microsoft/durabletask/TaskOrchestrationExecutor.java index 3cbb0ced..3e1124ce 100644 --- a/client/src/main/java/com/microsoft/durabletask/TaskOrchestrationExecutor.java +++ b/client/src/main/java/com/microsoft/durabletask/TaskOrchestrationExecutor.java @@ -180,7 +180,8 @@ public Task completedTask(V value) { @Override public Task> allOf(List> tasks) { Helpers.throwIfArgumentNull(tasks, "tasks"); - + // Mark all retrialbe task that they are inside a Compound task returned by allOf + markRetriableTask(tasks); CompletableFuture[] futures = tasks.stream() .map(t -> t.future) .toArray((IntFunction[]>) CompletableFuture[]::new); @@ -224,6 +225,15 @@ public Task> allOf(List> tasks) { return new CompoundTask<>(tasks, future); } + private void markRetriableTask(List> tasks) { + for (Task task : tasks) { + if (task instanceof RetriableTask) { + RetriableTask retriableTask = (RetriableTask) task; + retriableTask.markInCompoundTask(); + } + } + } + @Override public Task> anyOf(List> tasks) { Helpers.throwIfArgumentNull(tasks, "tasks"); @@ -1032,6 +1042,7 @@ private class RetriableTask extends CompletableTask { private Instant startTime; private int attemptNumber; private Task childTask; + private boolean isInCompoundTask; public RetriableTask(TaskOrchestrationContext context, TaskFactory taskFactory, RetryPolicy policy) { @@ -1111,24 +1122,45 @@ public void tryRetry(TaskFailedException ex) { @Override public V await() { this.init(); - // when awaiting the first child task, we will continue iterating over the history until a result is found - // for that task. If the result is an exception, the child task will invoke "handleChildException" on this - // object, which awaits a timer, *re-sets the current child task to correspond to a retry of this task*, - // and then awaits that child. - // This logic continues until either the operation succeeds, or are our retry quota is met. - // At that point, we break the `await()` on the child task. - // Therefore, once we return from the following `await`, - // we just need to await again on the *current* child task to obtain the result of this task + /** + * when awaiting the first child task, we will continue iterating over the history until a result is found + * for that task. If the result is an exception, the child task will invoke "handleChildException" on this + * object, which awaits a timer, *re-sets the current child task to correspond to a retry of this task*, + * and then awaits that child. + * This logic continues until either the operation succeeds, or are our retry quota is met. + * At that point, we break the `await()` on the child task. + * Therefore, once we return from the following `await`, + * we just need to await again on the *current* child task to obtain the result of this task + */ try{ this.getChildTask().await(); } catch (OrchestratorBlockedException ex) { throw ex; - } catch (Exception ignored) { + } catch (Exception ignore) { // ignore the exception from previous child tasks. - // Only needs to return result from the last child task, which is on next line. + // Only needs to return result from the last child task, which is on next try-catch block. } - // Always return the last child task result. - return this.getChildTask().await(); + /** + * Need to have two try-catch block, the first one we ignore the exception from the previous child task + * Once completed the future, method stack frame will pop out till `this.getChildTask().await();` + * in the first try-catch block. The child task now is previous child task we are awaiting on, for any exception + * throw by previous child task we should ignore, we only care about the last child task which is in next + * try-catch block. + */ + try { + // Always return the last child task result. + return this.getChildTask().await(); + } catch (Exception exception) { + /** + * If this RetriableTask is not configured as part of an allOf method (CompoundTask), + * it throws an exception, marking the orchestration as failed. However, if this RetriableTask + * is configured within an allOf method (CompoundTask), any exceptions are ignored. + * This approach ensures that when awaiting the future of the allOf method, + * it throws the CompositeTaskFailedException defined in its exceptionPath. + */ + if (!this.isInCompoundTask) throw exception; + } + return null; } private boolean shouldRetry() { @@ -1201,6 +1233,10 @@ private Duration getNextDelay() { // is responsible for implementing any delays between retry attempts. return Duration.ZERO; } + + public void markInCompoundTask() { + this.isInCompoundTask = true; + } } private class CompoundTask extends CompletableTask { diff --git a/client/src/test/java/com/microsoft/durabletask/IntegrationTests.java b/client/src/test/java/com/microsoft/durabletask/IntegrationTests.java index 94f73331..3f245f40 100644 --- a/client/src/test/java/com/microsoft/durabletask/IntegrationTests.java +++ b/client/src/test/java/com/microsoft/durabletask/IntegrationTests.java @@ -1374,6 +1374,68 @@ void activityAllOf() throws IOException, TimeoutException { } } + @Test + void activityAllOfException() throws IOException, TimeoutException { + final String orchestratorName = "ActivityAllOf"; + final String activityName = "ToString"; + final String retryActivityName = "RetryToStringException"; + final String result = "test fail"; + final int activityMiddle = 5; + final RetryPolicy retryPolicy = new RetryPolicy(2, Duration.ofSeconds(5)); + final TaskOptions taskOptions = new TaskOptions(retryPolicy); + + DurableTaskGrpcWorker worker = this.createWorkerBuilder() + .addOrchestrator(orchestratorName, ctx -> { + List> parallelTasks = IntStream.range(0, activityMiddle * 2) + .mapToObj(i -> { + if (i < activityMiddle) { + return ctx.callActivity(activityName, i, String.class); + } else { + return ctx.callActivity(retryActivityName, i, taskOptions, String.class); + } + }) + .collect(Collectors.toList()); + + // Wait for all tasks to complete, then sort and reverse the results + try { + List results = null; + results = ctx.allOf(parallelTasks).await(); + Collections.sort(results); + Collections.reverse(results); + ctx.complete(results); + } catch (CompositeTaskFailedException e) { + // only catch this type of exception to ensure the expected type of exception is thrown out. + for (Exception exception : e.getExceptions()) { + if (exception instanceof TaskFailedException) { + TaskFailedException taskFailedException = (TaskFailedException) exception; + System.out.println("Task: " + taskFailedException.getTaskName() + + " Failed for cause: " + taskFailedException.getErrorDetails().getErrorMessage()); + } + } + } + ctx.complete(result); + }) + .addActivity(activityName, ctx -> ctx.getInput(Object.class).toString()) + .addActivity(retryActivityName, ctx -> { + // only throw exception + throw new RuntimeException("test retry"); + }) + .buildAndStart(); + + DurableTaskClient client = new DurableTaskGrpcClientBuilder().build(); + try (worker; client) { + String instanceId = client.scheduleNewOrchestrationInstance(orchestratorName, 0); + OrchestrationMetadata instance = client.waitForInstanceCompletion(instanceId, defaultTimeout, true); + assertNotNull(instance); + assertEquals(OrchestrationRuntimeStatus.COMPLETED, instance.getRuntimeStatus()); + + String output = instance.readOutputAs(String.class); + assertNotNull(output); + assertEquals(String.class, output.getClass()); + assertEquals(result, output); + } + } + @Test void activityAnyOf() throws IOException, TimeoutException { final String orchestratorName = "ActivityAnyOf"; diff --git a/samples-azure-functions/src/main/java/com/functions/ParallelFunctions.java b/samples-azure-functions/src/main/java/com/functions/ParallelFunctions.java index 24d31db3..f7797705 100644 --- a/samples-azure-functions/src/main/java/com/functions/ParallelFunctions.java +++ b/samples-azure-functions/src/main/java/com/functions/ParallelFunctions.java @@ -89,4 +89,49 @@ public Object parallelAnyOf( tasks.add(ctx.callActivity("AppendHappy", 1, Integer.class)); return ctx.anyOf(tasks).await().await(); } + + @FunctionName("StartParallelCatchException") + public HttpResponseMessage startParallelCatchException( + @HttpTrigger(name = "req", methods = {HttpMethod.GET, HttpMethod.POST}, authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + final ExecutionContext context) { + context.getLogger().info("Java HTTP trigger processed a request."); + + DurableTaskClient client = durableContext.getClient(); + String instanceId = client.scheduleNewOrchestrationInstance("ParallelCatchException"); + context.getLogger().info("Created new Java orchestration with instance ID = " + instanceId); + return durableContext.createCheckStatusResponse(request, instanceId); + } + + @FunctionName("ParallelCatchException") + public List parallelCatchException( + @DurableOrchestrationTrigger(name = "ctx") TaskOrchestrationContext ctx, + ExecutionContext context) { + try { + List> tasks = new ArrayList<>(); + RetryPolicy policy = new RetryPolicy(2, Duration.ofSeconds(1)); + TaskOptions options = new TaskOptions(policy); + tasks.add(ctx.callActivity("AlwaysException", "Input1", options, String.class)); + tasks.add(ctx.callActivity("AppendHappy", "Input2", options, String.class)); + return ctx.allOf(tasks).await(); + } catch (CompositeTaskFailedException e) { + // only catch this type of exception to ensure the expected type of exception is thrown out. + for (Exception exception : e.getExceptions()) { + if (exception instanceof TaskFailedException) { + TaskFailedException taskFailedException = (TaskFailedException) exception; + context.getLogger().info("Task: " + taskFailedException.getTaskName() + + " Failed for cause: " + taskFailedException.getErrorDetails().getErrorMessage()); + } + } + } + return null; + } + + @FunctionName("AlwaysException") + public String alwaysException( + @DurableActivityTrigger(name = "name") String name, + final ExecutionContext context) { + context.getLogger().info("Throw Test AlwaysException: " + name); + throw new RuntimeException("Test AlwaysException"); + } } \ No newline at end of file diff --git a/samples-azure-functions/src/test/java/com/functions/EndToEndTests.java b/samples-azure-functions/src/test/java/com/functions/EndToEndTests.java index ace80c3b..411dbbaa 100644 --- a/samples-azure-functions/src/test/java/com/functions/EndToEndTests.java +++ b/samples-azure-functions/src/test/java/com/functions/EndToEndTests.java @@ -33,7 +33,8 @@ public void setupHost() { @ValueSource(strings = { "StartOrchestration", "StartParallelOrchestration", - "StartParallelAnyOf" + "StartParallelAnyOf", + "StartParallelCatchException" }) public void generalFunctions(String functionName) throws InterruptedException { Set continueStates = new HashSet<>();