diff --git a/junit-platform-engine/src/main/java/org/junit/platform/engine/support/hierarchical/ConcurrentHierarchicalTestExecutorService.java b/junit-platform-engine/src/main/java/org/junit/platform/engine/support/hierarchical/ConcurrentHierarchicalTestExecutorService.java index cfb340758061..2d5ffaa76d98 100644 --- a/junit-platform-engine/src/main/java/org/junit/platform/engine/support/hierarchical/ConcurrentHierarchicalTestExecutorService.java +++ b/junit-platform-engine/src/main/java/org/junit/platform/engine/support/hierarchical/ConcurrentHierarchicalTestExecutorService.java @@ -20,8 +20,10 @@ import static org.junit.platform.engine.support.hierarchical.ExclusiveResource.GLOBAL_READ_WRITE; import static org.junit.platform.engine.support.hierarchical.Node.ExecutionMode.SAME_THREAD; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collection; +import java.util.Deque; import java.util.EnumMap; import java.util.Iterator; import java.util.List; @@ -105,7 +107,9 @@ public void close() { return completedFuture(null); } - return new WorkStealingFuture(enqueue(testTask)); + var entry = enqueue(testTask); + workerThread.trackSubmittedChild(entry); + return new WorkStealingFuture(entry); } @Override @@ -197,6 +201,8 @@ private class WorkerThread extends Thread { @Nullable WorkerLease workerLease; + private final Deque stateStack = new ArrayDeque<>(); + WorkerThread(Runnable runnable, String name) { super(runnable, name); } @@ -345,7 +351,7 @@ private WorkStealResult tryToStealWork(WorkQueue.Entry entry, BlockingMode block } var claimed = workQueue.remove(entry); if (claimed) { - LOGGER.trace(() -> "stole work: " + entry); + LOGGER.trace(() -> "stole work: " + entry.task); var executed = executeStolenWork(entry, blockingMode); if (executed) { return WorkStealResult.EXECUTED_BY_THIS_WORKER; @@ -474,10 +480,12 @@ private boolean tryExecuteTask(TestTask testTask) { private void doExecute(TestTask testTask) { LOGGER.trace(() -> "executing: " + testTask); + stateStack.push(new State()); try { testTask.execute(); } finally { + stateStack.pop(); LOGGER.trace(() -> "finished executing: " + testTask); } } @@ -490,8 +498,52 @@ private static CompletableFuture toCombinedFuture(List entri return CompletableFuture.allOf(futures); } + private void trackSubmittedChild(WorkQueue.Entry entry) { + stateStack.element().trackSubmittedChild(entry); + } + + private void tryToStealWorkFromSubmittedChildren() { + var currentState = stateStack.element(); + var currentSubmittedChildren = currentState.submittedChildren; + if (currentSubmittedChildren == null || currentSubmittedChildren.isEmpty()) { + return; + } + var iterator = currentSubmittedChildren.listIterator(currentSubmittedChildren.size()); + while (iterator.hasPrevious()) { + WorkQueue.Entry entry = iterator.previous(); + var result = tryToStealWork(entry, BlockingMode.NON_BLOCKING); + if (result.isExecuted()) { + iterator.remove(); + } + } + currentState.clearIfEmpty(); + } + + private static class State { + + @Nullable + private List submittedChildren; + + private void trackSubmittedChild(WorkQueue.Entry entry) { + if (submittedChildren == null) { + submittedChildren = new ArrayList<>(); + } + submittedChildren.add(entry); + } + + private void clearIfEmpty() { + if (submittedChildren != null && submittedChildren.isEmpty()) { + submittedChildren = null; + } + } + } + private enum WorkStealResult { - EXECUTED_BY_DIFFERENT_WORKER, RESOURCE_LOCK_UNAVAILABLE, EXECUTED_BY_THIS_WORKER + EXECUTED_BY_DIFFERENT_WORKER, RESOURCE_LOCK_UNAVAILABLE, EXECUTED_BY_THIS_WORKER; + + private boolean isExecuted() { + return this != RESOURCE_LOCK_UNAVAILABLE; + } } private interface BlockingAction { @@ -519,8 +571,11 @@ private static class WorkStealingFuture extends BlockingAwareFuture<@Nullable Vo if (entry.future.isDone()) { return callable.call(); } - // TODO steal other dynamic children until future is done and check again before blocking - LOGGER.trace(() -> "blocking for child task"); + workerThread.tryToStealWorkFromSubmittedChildren(); + if (entry.future.isDone()) { + return callable.call(); + } + LOGGER.trace(() -> "blocking for child task: " + entry.task); return workerThread.runBlocking(entry.future::isDone, () -> { try { return callable.call(); diff --git a/platform-tests/src/test/java/org/junit/platform/engine/support/hierarchical/ConcurrentHierarchicalTestExecutorServiceTests.java b/platform-tests/src/test/java/org/junit/platform/engine/support/hierarchical/ConcurrentHierarchicalTestExecutorServiceTests.java index 17db098e1e45..1491fbb2f4b1 100644 --- a/platform-tests/src/test/java/org/junit/platform/engine/support/hierarchical/ConcurrentHierarchicalTestExecutorServiceTests.java +++ b/platform-tests/src/test/java/org/junit/platform/engine/support/hierarchical/ConcurrentHierarchicalTestExecutorServiceTests.java @@ -525,6 +525,144 @@ void workIsStolenInReverseOrder() throws Exception { .isSorted(); } + @RepeatedTest(value = 100, failureThreshold = 1) + void stealsDynamicChildren() throws Exception { + service = new ConcurrentHierarchicalTestExecutorService(configuration(2, 2)); + + var child1Started = new CountDownLatch(1); + var child2Finished = new CountDownLatch(1); + var child1 = new TestTaskStub(ExecutionMode.CONCURRENT, () -> { + child1Started.countDown(); + child2Finished.await(); + }) // + .withName("child1").withLevel(2); + var child2 = new TestTaskStub(ExecutionMode.CONCURRENT, child2Finished::countDown) // + .withName("child2").withLevel(2); + + var root = new TestTaskStub(ExecutionMode.SAME_THREAD, () -> { + var future1 = requiredService().submit(child1); + child1Started.await(); + var future2 = requiredService().submit(child2); + future1.get(); + future2.get(); + }) // + .withName("root").withLevel(1); + + service.submit(root).get(); + + assertThat(Stream.of(root, child1, child2)) // + .allSatisfy(TestTaskStub::assertExecutedSuccessfully); + assertThat(child2.executionThread).isEqualTo(root.executionThread).isNotEqualTo(child1.executionThread); + } + + @RepeatedTest(value = 100, failureThreshold = 1) + void stealsNestedDynamicChildren() throws Exception { + service = new ConcurrentHierarchicalTestExecutorService(configuration(2, 2)); + + var barrier = new CyclicBarrier(2); + + var leaf1a = new TestTaskStub(ExecutionMode.CONCURRENT) // + .withName("leaf1a").withLevel(3); + var leaf1b = new TestTaskStub(ExecutionMode.CONCURRENT) // + .withName("leaf1b").withLevel(3); + + var child1 = new TestTaskStub(ExecutionMode.CONCURRENT, () -> { + barrier.await(); + var futureA = requiredService().submit(leaf1a); + barrier.await(); + var futureB = requiredService().submit(leaf1b); + futureA.get(); + futureB.get(); + barrier.await(); + }) // + .withName("child1").withLevel(2); + + var leaf2a = new TestTaskStub(ExecutionMode.CONCURRENT) // + .withName("leaf2a").withLevel(3); + var leaf2b = new TestTaskStub(ExecutionMode.CONCURRENT) // + .withName("leaf2b").withLevel(3); + + var child2 = new TestTaskStub(ExecutionMode.CONCURRENT, () -> { + barrier.await(); + var futureA = requiredService().submit(leaf2a); + barrier.await(); + var futureB = requiredService().submit(leaf2b); + futureB.get(); + futureA.get(); + barrier.await(); + }) // + .withName("child2").withLevel(2); + + var root = new TestTaskStub(ExecutionMode.SAME_THREAD, () -> { + var future1 = requiredService().submit(child1); + var future2 = requiredService().submit(child2); + future1.get(); + future2.get(); + }) // + .withName("root").withLevel(1); + + service.submit(root).get(); + + assertThat(Stream.of(root, child1, child2, leaf1a, leaf1b, leaf2a, leaf2b)) // + .allSatisfy(TestTaskStub::assertExecutedSuccessfully); + assertThat(child2.executionThread).isNotEqualTo(child1.executionThread); + assertThat(child1.executionThread).isEqualTo(leaf1a.executionThread).isEqualTo(leaf1b.executionThread); + assertThat(child2.executionThread).isEqualTo(leaf2a.executionThread).isEqualTo(leaf2b.executionThread); + } + + @RepeatedTest(value = 100, failureThreshold = 1) + void stealsSiblingDynamicChildrenOnly() throws Exception { + service = new ConcurrentHierarchicalTestExecutorService(configuration(2, 3)); + + var child1Started = new CountDownLatch(1); + var child2Started = new CountDownLatch(1); + var leaf1ASubmitted = new CountDownLatch(1); + var leaf1AStarted = new CountDownLatch(1); + + var leaf1a = new TestTaskStub(ExecutionMode.CONCURRENT, () -> { + leaf1AStarted.countDown(); + child2Started.await(); + }) // + .withName("leaf1a").withLevel(3); + + var child1 = new TestTaskStub(ExecutionMode.CONCURRENT, () -> { + child1Started.countDown(); + leaf1ASubmitted.await(); + }) // + .withName("child1").withLevel(2); + + var child2 = new TestTaskStub(ExecutionMode.CONCURRENT, child2Started::countDown) // + .withName("child2").withLevel(2); + + var child3 = new TestTaskStub(ExecutionMode.CONCURRENT, () -> { + var futureA = requiredService().submit(leaf1a); + leaf1ASubmitted.countDown(); + leaf1AStarted.await(); + futureA.get(); + }) // + .withName("child3").withLevel(2); + + var root = new TestTaskStub(ExecutionMode.SAME_THREAD, () -> { + var future1 = requiredService().submit(child1); + child1Started.await(); + var future2 = requiredService().submit(child2); + var future3 = requiredService().submit(child3); + future1.get(); + future2.get(); + future3.get(); + }) // + .withName("root").withLevel(1); + + service.submit(root).get(); + + assertThat(Stream.of(root, child1, child2, child3, leaf1a)) // + .allSatisfy(TestTaskStub::assertExecutedSuccessfully); + + assertThat(child3.executionThread).isNotEqualTo(child1.executionThread).isNotEqualTo(child2.executionThread); + assertThat(child1.executionThread).isNotEqualTo(child2.executionThread); + assertThat(child1.executionThread).isEqualTo(leaf1a.executionThread); + } + private static ExclusiveResource exclusiveResource(LockMode lockMode) { return new ExclusiveResource("key", lockMode); }