Skip to content

Commit 450636a

Browse files
author
Viktor Klang
committed
8347274: Gatherers.mapConcurrent exhibits undesired behavior under variable delays, interruption, and finishing
Reviewed-by: alanb
1 parent 82e2a79 commit 450636a

File tree

2 files changed

+132
-55
lines changed

2 files changed

+132
-55
lines changed

src/java.base/share/classes/java/util/stream/Gatherers.java

+71-53
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -30,6 +30,7 @@
3030
import java.util.ArrayDeque;
3131
import java.util.List;
3232
import java.util.Objects;
33+
import java.util.concurrent.Callable;
3334
import java.util.concurrent.ExecutionException;
3435
import java.util.concurrent.Future;
3536
import java.util.concurrent.FutureTask;
@@ -350,86 +351,103 @@ public static <T, R> Gatherer<T,?,R> mapConcurrent(
350351
final int maxConcurrency,
351352
final Function<? super T, ? extends R> mapper) {
352353
if (maxConcurrency < 1)
353-
throw new IllegalArgumentException(
354-
"'maxConcurrency' must be greater than 0");
354+
throw new IllegalArgumentException("'maxConcurrency' must be greater than 0");
355355

356356
Objects.requireNonNull(mapper, "'mapper' must not be null");
357357

358-
class State {
359-
// ArrayDeque default initial size is 16
360-
final ArrayDeque<Future<R>> window =
361-
new ArrayDeque<>(Math.min(maxConcurrency, 16));
362-
final Semaphore windowLock = new Semaphore(maxConcurrency);
363-
364-
final boolean integrate(T element,
365-
Downstream<? super R> downstream) {
366-
if (!downstream.isRejecting())
367-
createTaskFor(element);
368-
return flush(0, downstream);
358+
final class MapConcurrentTask extends FutureTask<R> {
359+
final Thread thread;
360+
private MapConcurrentTask(Callable<R> callable) {
361+
super(callable);
362+
this.thread = Thread.ofVirtual().unstarted(this);
369363
}
364+
}
370365

371-
final void createTaskFor(T element) {
372-
windowLock.acquireUninterruptibly();
366+
final class State {
367+
private final ArrayDeque<MapConcurrentTask> wip =
368+
new ArrayDeque<>(Math.min(maxConcurrency, 16));
373369

374-
var task = new FutureTask<R>(() -> {
375-
try {
376-
return mapper.apply(element);
377-
} finally {
378-
windowLock.release();
379-
}
380-
});
370+
boolean integrate(T element, Downstream<? super R> downstream) {
371+
// Prepare the next task and add it to the work-in-progress
372+
final var task = new MapConcurrentTask(() -> mapper.apply(element));
373+
wip.addLast(task);
374+
375+
assert wip.peekLast() == task;
376+
assert wip.size() <= maxConcurrency;
381377

382-
var wasAddedToWindow = window.add(task);
383-
assert wasAddedToWindow;
378+
// Start the next task
379+
task.thread.start();
384380

385-
Thread.startVirtualThread(task);
381+
// Flush at least 1 element if we're at capacity
382+
return flush(wip.size() < maxConcurrency ? 0 : 1, downstream);
386383
}
387384

388-
final boolean flush(long atLeastN,
389-
Downstream<? super R> downstream) {
390-
boolean proceed = !downstream.isRejecting();
391-
boolean interrupted = false;
385+
boolean flush(long atLeastN, Downstream<? super R> downstream) {
386+
boolean success = false, interrupted = false;
392387
try {
393-
Future<R> current;
394-
while (proceed
395-
&& (current = window.peek()) != null
396-
&& (current.isDone() || atLeastN > 0)) {
397-
proceed &= downstream.push(current.get());
388+
boolean proceed = !downstream.isRejecting();
389+
MapConcurrentTask current;
390+
while (
391+
proceed
392+
&& (current = wip.peekFirst()) != null
393+
&& (current.isDone() || atLeastN > 0)
394+
) {
395+
R result;
396+
397+
// Ensure that the task is done before proceeding
398+
for (;;) {
399+
try {
400+
result = current.get();
401+
break;
402+
} catch (InterruptedException ie) {
403+
interrupted = true; // ignore for now, and restore later
404+
}
405+
}
406+
407+
proceed &= downstream.push(result);
398408
atLeastN -= 1;
399409

400-
var correctRemoval = window.pop() == current;
410+
final var correctRemoval = wip.pollFirst() == current;
401411
assert correctRemoval;
402412
}
403-
} catch(InterruptedException ie) {
404-
proceed = false;
405-
interrupted = true;
413+
return (success = proceed); // Ensure that cleanup occurs if needed
406414
} catch (ExecutionException e) {
407-
proceed = false; // Ensure cleanup
408415
final var cause = e.getCause();
409416
throw (cause instanceof RuntimeException re)
410417
? re
411418
: new RuntimeException(cause == null ? e : cause);
412419
} finally {
413-
// Clean up
414-
if (!proceed) {
415-
Future<R> next;
416-
while ((next = window.pollFirst()) != null) {
417-
next.cancel(true);
420+
// Clean up work-in-progress
421+
if (!success && !wip.isEmpty()) {
422+
// First signal cancellation for all tasks in progress
423+
for (var task : wip)
424+
task.cancel(true);
425+
426+
// Then wait for all in progress task Threads to exit
427+
MapConcurrentTask next;
428+
while ((next = wip.pollFirst()) != null) {
429+
while (next.thread.isAlive()) {
430+
try {
431+
next.thread.join();
432+
} catch (InterruptedException ie) {
433+
interrupted = true; // ignore, for now, and restore later
434+
}
435+
}
418436
}
419437
}
420-
}
421-
422-
if (interrupted)
423-
Thread.currentThread().interrupt();
424438

425-
return proceed;
439+
// integrate(..) could be called from different threads each time
440+
// so we need to restore the interrupt on the calling thread
441+
if (interrupted)
442+
Thread.currentThread().interrupt();
443+
}
426444
}
427445
}
428446

429447
return Gatherer.ofSequential(
430-
State::new,
431-
Integrator.<State, T, R>ofGreedy(State::integrate),
432-
(state, downstream) -> state.flush(Long.MAX_VALUE, downstream)
448+
State::new,
449+
Integrator.<State, T, R>ofGreedy(State::integrate),
450+
(state, downstream) -> state.flush(Long.MAX_VALUE, downstream)
433451
);
434452
}
435453

test/jdk/java/util/stream/GatherersMapConcurrentTest.java

+61-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -24,6 +24,9 @@
2424
import java.util.List;
2525
import java.util.concurrent.CountDownLatch;
2626
import java.util.concurrent.Semaphore;
27+
import java.util.concurrent.atomic.AtomicLong;
28+
import java.util.concurrent.locks.LockSupport;
29+
import java.util.function.Function;
2730
import java.util.stream.Gatherer;
2831
import java.util.stream.Gatherers;
2932
import java.util.stream.Stream;
@@ -298,7 +301,7 @@ public void behavesAsExpected(ConcurrencyConfig cc) {
298301

299302
@ParameterizedTest
300303
@MethodSource("concurrencyConfigurations")
301-
public void behavesAsExpectedWhenShortCircuited(ConcurrencyConfig cc) {
304+
public void shortCircuits(ConcurrencyConfig cc) {
302305
final var limitTo = Math.max(cc.config().streamSize() / 2, 1);
303306

304307
final var expectedResult = cc.config().stream()
@@ -313,4 +316,60 @@ public void behavesAsExpectedWhenShortCircuited(ConcurrencyConfig cc) {
313316

314317
assertEquals(expectedResult, result);
315318
}
319+
320+
@ParameterizedTest
321+
@MethodSource("concurrencyConfigurations")
322+
public void ignoresAndRestoresCallingThreadInterruption(ConcurrencyConfig cc) {
323+
final var limitTo = Math.max(cc.config().streamSize() / 2, 1);
324+
325+
final var expectedResult = cc.config().stream()
326+
.map(x -> x * x)
327+
.limit(limitTo)
328+
.toList();
329+
330+
// Ensure calling thread is interrupted
331+
Thread.currentThread().interrupt();
332+
333+
final var result = cc.config().stream()
334+
.gather(Gatherers.mapConcurrent(cc.concurrencyLevel(), x -> {
335+
LockSupport.parkNanos(10000); // 10 us
336+
return x * x;
337+
}))
338+
.limit(limitTo)
339+
.toList();
340+
341+
// Ensure calling thread remains interrupted
342+
assertEquals(true, Thread.interrupted());
343+
344+
assertEquals(expectedResult, result);
345+
}
346+
347+
@ParameterizedTest
348+
@MethodSource("concurrencyConfigurations")
349+
public void limitsWorkInProgressToMaxConcurrency(ConcurrencyConfig cc) {
350+
final var elementNum = new AtomicLong(0);
351+
final var wipCount = new AtomicLong(0);
352+
final var limitTo = Math.max(cc.config().streamSize() / 2, 1);
353+
354+
final var expectedResult = cc.config().stream()
355+
.map(x -> x * x)
356+
.limit(limitTo)
357+
.toList();
358+
359+
Function<Integer, Integer> fun = x -> {
360+
if (wipCount.incrementAndGet() > cc.concurrencyLevel)
361+
throw new IllegalStateException("Too much wip!");
362+
if (elementNum.getAndIncrement() == 0)
363+
LockSupport.parkNanos(500_000_000); // 500 ms
364+
return x * x;
365+
};
366+
367+
final var result = cc.config().stream()
368+
.gather(Gatherers.mapConcurrent(cc.concurrencyLevel(), fun))
369+
.gather(Gatherer.of((v, e, d) -> wipCount.decrementAndGet() >= 0 && d.push(e)))
370+
.limit(limitTo)
371+
.toList();
372+
373+
assertEquals(expectedResult, result);
374+
}
316375
}

0 commit comments

Comments
 (0)