Skip to content

Commit

Permalink
[ML] Simplify shortcircuiting scores renormalizer (#85625) (#85645)
Browse files Browse the repository at this point in the history
In #85555 a Phaser was used to track ongoing renormalizations
and wait for them to complete. However, this is a very complex
way to do things.

It turns out there's an easier way, as submitting a task to a
threadpool returns a Future that can be used to track when the
submitted task completes.

There was also a bug in the ShortCircuitingRenormalizer class
where the possibility of RejectedExecutionException was not being
considered.
  • Loading branch information
droberts195 committed Apr 1, 2022
1 parent d47642d commit 0e713f1
Showing 1 changed file with 53 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@

import java.util.Date;
import java.util.Objects;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Phaser;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.Semaphore;

/**
* Renormalizer for one job that discards outdated quantiles if even newer ones are received while waiting for a prior renormalization
Expand All @@ -29,21 +33,13 @@ public class ShortCircuitingRenormalizer implements Renormalizer {
private final ScoresUpdater scoresUpdater;
private final ExecutorService executorService;
/**
* Each job may only have 1 normalization in progress at any time; the phaser enforces this.
* Registrations and arrivals must be synchronized.
*/
private final Phaser phaser = new Phaser() {
/**
* Don't terminate when registrations drops to zero.
*/
protected boolean onAdvance(int phase, int parties) {
return false;
}
};
/**
* Access to this must be synchronized.
* Each job may only have 1 normalization in progress at any time; the semaphore enforces this.
* Modifications must be synchronized.
*/
private final Semaphore semaphore = new Semaphore(1);
// Access to both of these must be synchronized.
private AugmentedQuantiles latestQuantilesHolder;
private Future<?> latestTask;

public ShortCircuitingRenormalizer(String jobId, ScoresUpdater scoresUpdater, ExecutorService executorService) {
this.jobId = jobId;
Expand All @@ -67,18 +63,14 @@ public void renormalize(Quantiles quantiles) {
latestQuantilesHolder = (latestQuantilesHolder == null)
? new AugmentedQuantiles(quantiles, null, new CountDownLatch(1))
: new AugmentedQuantiles(quantiles, latestQuantilesHolder.getEvictedTimestamp(), latestQuantilesHolder.getLatch());
// Don't start a thread if another normalization thread is still working. The existing thread will
// do this normalization when it finishes its current one. This means we serialise normalizations
// without hogging threads or queuing up many large quantiles documents.
if (tryStartWork()) {
executorService.submit(this::doRenormalizations);
}
tryStartWork();
}
}

@Override
public void waitUntilIdle() throws InterruptedException {
CountDownLatch latch;
Future<?> taskToWaitFor;
do {
// The first bit waits for any not-yet-started renormalization to complete.
synchronized (this) {
Expand All @@ -90,12 +82,21 @@ public void waitUntilIdle() throws InterruptedException {
// This next bit waits for any thread that's been started to run doRenormalizations() to exit the loop in that method.
// If no doRenormalizations() thread is running then we'll wait for the previous phase, and a call to do that should
// return immediately.
int phaseToWaitFor;
synchronized (this) {
phaseToWaitFor = phaser.getPhase() - 1 + phaser.getUnarrivedParties();
taskToWaitFor = latestTask;
}
phaser.awaitAdvanceInterruptibly(phaseToWaitFor);
} while (latch != null);
if (taskToWaitFor != null) {
try {
taskToWaitFor.get();
} catch (ExecutionException e) {
// This shouldn't happen, because we catch normalization errors inside the normalization loop
logger.error("[" + jobId + "] Error propagated from normalization", e);
} catch (CancellationException e) {
// Convert cancellations to interruptions to simplify the interface
throw new InterruptedException("Normalization cancelled");
}
}
} while (latch != null || taskToWaitFor != null);
}

@Override
Expand All @@ -106,7 +107,6 @@ public void shutdown() throws InterruptedException {
// scoresUpdater first means it won't do all pending work; it will stop as soon
// as it can without causing further errors.
waitUntilIdle();
phaser.forceTermination();
}

private synchronized AugmentedQuantiles getLatestAugmentedQuantilesAndClear() {
Expand All @@ -116,26 +116,43 @@ private synchronized AugmentedQuantiles getLatestAugmentedQuantilesAndClear() {
}

private synchronized boolean tryStartWork() {
if (phaser.getUnarrivedParties() > 0) {
if (latestQuantilesHolder == null) {
return false;
}
return phaser.register() >= 0;
// Don't start a thread if another normalization thread is still working. The existing thread will
// do this normalization when it finishes its current one. This means we serialise normalizations
// without hogging threads or queuing up many large quantiles documents.
if (semaphore.tryAcquire()) {
try {
latestTask = executorService.submit(this::doRenormalizations);
} catch (RejectedExecutionException e) {
latestQuantilesHolder.getLatch().countDown();
latestQuantilesHolder = null;
latestTask = null;
semaphore.release();
logger.warn("[{}] Normalization discarded as threadpool is shutting down", jobId);
return false;
}
return true;
}
return false;
}

private synchronized boolean tryFinishWork() {
// Synchronized because we cannot tolerate new work being added in between the null check and releasing the semaphore
if (latestQuantilesHolder != null) {
return false;
}
phaser.arriveAndDeregister();
semaphore.release();
latestTask = null;
return true;
}

private void doRenormalizations() {
do {
AugmentedQuantiles latestAugmentedQuantiles = getLatestAugmentedQuantilesAndClear();
assert latestAugmentedQuantiles != null;
if (latestAugmentedQuantiles != null) {
if (latestAugmentedQuantiles != null) { // TODO: remove this if the assert doesn't trip in CI over the next year or so
Quantiles latestQuantiles = latestAugmentedQuantiles.getQuantiles();
CountDownLatch latch = latestAugmentedQuantiles.getLatch();
try {
Expand All @@ -146,16 +163,21 @@ private void doRenormalizations() {
);
} catch (Exception e) {
logger.error("[" + jobId + "] Normalization failed", e);
} finally {
latch.countDown();
}
latch.countDown();
} else {
logger.warn("[{}] request to normalize null quantiles", jobId);
}
// Loop if more work has become available while we were working, because the
// tasks originally submitted to do that work will have exited early.
} while (tryFinishWork() == false);
}

/**
* Simple grouping of a {@linkplain Quantiles} object with its corresponding {@linkplain CountDownLatch} object.
* Grouping of a {@linkplain Quantiles} object with its corresponding {@linkplain CountDownLatch} object.
* Also stores the earliest timestamp that any set of discarded quantiles held, to allow the normalization
* window to be extended if multiple normalization requests are combined.
*/
private class AugmentedQuantiles {
private final Quantiles quantiles;
Expand Down

0 comments on commit 0e713f1

Please sign in to comment.