|
1 | 1 | /*
|
2 |
| - * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. |
| 2 | + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. |
3 | 3 | * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
4 | 4 | *
|
5 | 5 | * This code is free software; you can redistribute it and/or modify it
|
|
30 | 30 | import java.util.ArrayDeque;
|
31 | 31 | import java.util.List;
|
32 | 32 | import java.util.Objects;
|
| 33 | +import java.util.concurrent.Callable; |
33 | 34 | import java.util.concurrent.ExecutionException;
|
34 | 35 | import java.util.concurrent.Future;
|
35 | 36 | import java.util.concurrent.FutureTask;
|
@@ -350,86 +351,103 @@ public static <T, R> Gatherer<T,?,R> mapConcurrent(
|
350 | 351 | final int maxConcurrency,
|
351 | 352 | final Function<? super T, ? extends R> mapper) {
|
352 | 353 | if (maxConcurrency < 1)
|
353 |
| - throw new IllegalArgumentException( |
354 |
| - "'maxConcurrency' must be greater than 0"); |
| 354 | + throw new IllegalArgumentException("'maxConcurrency' must be greater than 0"); |
355 | 355 |
|
356 | 356 | Objects.requireNonNull(mapper, "'mapper' must not be null");
|
357 | 357 |
|
358 |
| - class State { |
359 |
| - // ArrayDeque default initial size is 16 |
360 |
| - final ArrayDeque<Future<R>> window = |
361 |
| - new ArrayDeque<>(Math.min(maxConcurrency, 16)); |
362 |
| - final Semaphore windowLock = new Semaphore(maxConcurrency); |
363 |
| - |
364 |
| - final boolean integrate(T element, |
365 |
| - Downstream<? super R> downstream) { |
366 |
| - if (!downstream.isRejecting()) |
367 |
| - createTaskFor(element); |
368 |
| - return flush(0, downstream); |
| 358 | + final class MapConcurrentTask extends FutureTask<R> { |
| 359 | + final Thread thread; |
| 360 | + private MapConcurrentTask(Callable<R> callable) { |
| 361 | + super(callable); |
| 362 | + this.thread = Thread.ofVirtual().unstarted(this); |
369 | 363 | }
|
| 364 | + } |
370 | 365 |
|
371 |
| - final void createTaskFor(T element) { |
372 |
| - windowLock.acquireUninterruptibly(); |
| 366 | + final class State { |
| 367 | + private final ArrayDeque<MapConcurrentTask> wip = |
| 368 | + new ArrayDeque<>(Math.min(maxConcurrency, 16)); |
373 | 369 |
|
374 |
| - var task = new FutureTask<R>(() -> { |
375 |
| - try { |
376 |
| - return mapper.apply(element); |
377 |
| - } finally { |
378 |
| - windowLock.release(); |
379 |
| - } |
380 |
| - }); |
| 370 | + boolean integrate(T element, Downstream<? super R> downstream) { |
| 371 | + // Prepare the next task and add it to the work-in-progress |
| 372 | + final var task = new MapConcurrentTask(() -> mapper.apply(element)); |
| 373 | + wip.addLast(task); |
| 374 | + |
| 375 | + assert wip.peekLast() == task; |
| 376 | + assert wip.size() <= maxConcurrency; |
381 | 377 |
|
382 |
| - var wasAddedToWindow = window.add(task); |
383 |
| - assert wasAddedToWindow; |
| 378 | + // Start the next task |
| 379 | + task.thread.start(); |
384 | 380 |
|
385 |
| - Thread.startVirtualThread(task); |
| 381 | + // Flush at least 1 element if we're at capacity |
| 382 | + return flush(wip.size() < maxConcurrency ? 0 : 1, downstream); |
386 | 383 | }
|
387 | 384 |
|
388 |
| - final boolean flush(long atLeastN, |
389 |
| - Downstream<? super R> downstream) { |
390 |
| - boolean proceed = !downstream.isRejecting(); |
391 |
| - boolean interrupted = false; |
| 385 | + boolean flush(long atLeastN, Downstream<? super R> downstream) { |
| 386 | + boolean success = false, interrupted = false; |
392 | 387 | try {
|
393 |
| - Future<R> current; |
394 |
| - while (proceed |
395 |
| - && (current = window.peek()) != null |
396 |
| - && (current.isDone() || atLeastN > 0)) { |
397 |
| - proceed &= downstream.push(current.get()); |
| 388 | + boolean proceed = !downstream.isRejecting(); |
| 389 | + MapConcurrentTask current; |
| 390 | + while ( |
| 391 | + proceed |
| 392 | + && (current = wip.peekFirst()) != null |
| 393 | + && (current.isDone() || atLeastN > 0) |
| 394 | + ) { |
| 395 | + R result; |
| 396 | + |
| 397 | + // Ensure that the task is done before proceeding |
| 398 | + for (;;) { |
| 399 | + try { |
| 400 | + result = current.get(); |
| 401 | + break; |
| 402 | + } catch (InterruptedException ie) { |
| 403 | + interrupted = true; // ignore for now, and restore later |
| 404 | + } |
| 405 | + } |
| 406 | + |
| 407 | + proceed &= downstream.push(result); |
398 | 408 | atLeastN -= 1;
|
399 | 409 |
|
400 |
| - var correctRemoval = window.pop() == current; |
| 410 | + final var correctRemoval = wip.pollFirst() == current; |
401 | 411 | assert correctRemoval;
|
402 | 412 | }
|
403 |
| - } catch(InterruptedException ie) { |
404 |
| - proceed = false; |
405 |
| - interrupted = true; |
| 413 | + return (success = proceed); // Ensure that cleanup occurs if needed |
406 | 414 | } catch (ExecutionException e) {
|
407 |
| - proceed = false; // Ensure cleanup |
408 | 415 | final var cause = e.getCause();
|
409 | 416 | throw (cause instanceof RuntimeException re)
|
410 | 417 | ? re
|
411 | 418 | : new RuntimeException(cause == null ? e : cause);
|
412 | 419 | } finally {
|
413 |
| - // Clean up |
414 |
| - if (!proceed) { |
415 |
| - Future<R> next; |
416 |
| - while ((next = window.pollFirst()) != null) { |
417 |
| - next.cancel(true); |
| 420 | + // Clean up work-in-progress |
| 421 | + if (!success && !wip.isEmpty()) { |
| 422 | + // First signal cancellation for all tasks in progress |
| 423 | + for (var task : wip) |
| 424 | + task.cancel(true); |
| 425 | + |
| 426 | + // Then wait for all in progress task Threads to exit |
| 427 | + MapConcurrentTask next; |
| 428 | + while ((next = wip.pollFirst()) != null) { |
| 429 | + while (next.thread.isAlive()) { |
| 430 | + try { |
| 431 | + next.thread.join(); |
| 432 | + } catch (InterruptedException ie) { |
| 433 | + interrupted = true; // ignore, for now, and restore later |
| 434 | + } |
| 435 | + } |
418 | 436 | }
|
419 | 437 | }
|
420 |
| - } |
421 |
| - |
422 |
| - if (interrupted) |
423 |
| - Thread.currentThread().interrupt(); |
424 | 438 |
|
425 |
| - return proceed; |
| 439 | + // integrate(..) could be called from different threads each time |
| 440 | + // so we need to restore the interrupt on the calling thread |
| 441 | + if (interrupted) |
| 442 | + Thread.currentThread().interrupt(); |
| 443 | + } |
426 | 444 | }
|
427 | 445 | }
|
428 | 446 |
|
429 | 447 | return Gatherer.ofSequential(
|
430 |
| - State::new, |
431 |
| - Integrator.<State, T, R>ofGreedy(State::integrate), |
432 |
| - (state, downstream) -> state.flush(Long.MAX_VALUE, downstream) |
| 448 | + State::new, |
| 449 | + Integrator.<State, T, R>ofGreedy(State::integrate), |
| 450 | + (state, downstream) -> state.flush(Long.MAX_VALUE, downstream) |
433 | 451 | );
|
434 | 452 | }
|
435 | 453 |
|
|
0 commit comments