Skip to content

Commit

Permalink
CAS and volatile approach to fix delta concurrency bug (#5976)
Browse files Browse the repository at this point in the history
  • Loading branch information
jack-berg committed Nov 13, 2023
1 parent aca4157 commit 72a5bb1
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;

Expand Down Expand Up @@ -85,14 +83,13 @@ Queue<AggregatorHandle<T, U>> getAggregatorHandlePool() {

@Override
public void recordLong(long value, Attributes attributes, Context context) {
Lock readLock = aggregatorHolder.lock.readLock();
readLock.lock();
AggregatorHolder<T, U> aggregatorHolder = getHolderForRecord();
try {
AggregatorHandle<T, U> handle =
getAggregatorHandle(aggregatorHolder.aggregatorHandles, attributes, context);
handle.recordLong(value, attributes, context);
} finally {
readLock.unlock();
releaseHolderForRecord(aggregatorHolder);
}
}

Expand All @@ -108,17 +105,46 @@ public void recordDouble(double value, Attributes attributes, Context context) {
+ ". Dropping measurement.");
return;
}
Lock readLock = aggregatorHolder.lock.readLock();
readLock.lock();
AggregatorHolder<T, U> aggregatorHolder = getHolderForRecord();
try {
AggregatorHandle<T, U> handle =
getAggregatorHandle(aggregatorHolder.aggregatorHandles, attributes, context);
handle.recordDouble(value, attributes, context);
} finally {
readLock.unlock();
releaseHolderForRecord(aggregatorHolder);
}
}

/**
 * Returns the {@link AggregatorHolder} that the calling record operation should record against.
 *
 * <p>Each attempt reads {@code this.aggregatorHolder} and adds 2 to its {@code
 * activeRecordingThreads} counter. An even result means no collect has claimed that holder, so
 * it is returned with the increment still in place. An odd result means collect has added 1 to
 * flag the holder as stale and being replaced; the increment is then undone so that collect can
 * proceed, and the field is read again.
 *
 * <p>Callers MUST call {@link #releaseHolderForRecord(AggregatorHolder)} once the record
 * operation completes, to signal that it is safe for collect to proceed.
 */
private AggregatorHolder<T, U> getHolderForRecord() {
  while (true) {
    AggregatorHolder<T, U> candidate = this.aggregatorHolder;
    int countAfterAcquire = candidate.activeRecordingThreads.addAndGet(2);
    boolean claimedByCollect = countAfterAcquire % 2 != 0;
    if (!claimedByCollect) {
      return candidate;
    }
    // Collect is in progress on this holder: give back our increment so collect is not blocked
    // by us, then loop to pick up the replacement holder.
    candidate.activeRecordingThreads.addAndGet(-2);
  }
}

/**
 * Signals that a record operation against {@code aggregatorHolder} has finished by subtracting
 * the 2 that {@link #getHolderForRecord()} added to its {@code activeRecordingThreads} counter.
 * Must be called exactly once for every holder returned by {@link #getHolderForRecord()}, so
 * that collect can tell when it is safe to read the holder.
 */
private void releaseHolderForRecord(AggregatorHolder<T, U> aggregatorHolder) {
  AtomicInteger activeRecorders = aggregatorHolder.activeRecordingThreads;
  activeRecorders.addAndGet(-2);
}

private AggregatorHandle<T, U> getAggregatorHandle(
ConcurrentHashMap<Attributes, AggregatorHandle<T, U>> aggregatorHandles,
Attributes attributes,
Expand Down Expand Up @@ -169,13 +195,15 @@ public MetricData collect(
if (reset) {
AggregatorHolder<T, U> holder = this.aggregatorHolder;
this.aggregatorHolder = new AggregatorHolder<>();
Lock writeLock = holder.lock.writeLock();
writeLock.lock();
try {
aggregatorHandles = holder.aggregatorHandles;
} finally {
writeLock.unlock();
// Increment recordsInProgress by 1, which produces an odd number acting as a signal that
// record operations should re-read the volatile this.aggregatorHolder.
// Repeatedly grab recordsInProgress until it is <= 1, which signals all active record
// operations are complete.
int recordsInProgress = holder.activeRecordingThreads.addAndGet(1);
while (recordsInProgress > 1) {
recordsInProgress = holder.activeRecordingThreads.get();
}
aggregatorHandles = holder.aggregatorHandles;
} else {
aggregatorHandles = this.aggregatorHolder.aggregatorHandles;
}
Expand Down Expand Up @@ -217,6 +245,20 @@ public MetricDescriptor getMetricDescriptor() {
private static class AggregatorHolder<T extends PointData, U extends ExemplarData> {
private final ConcurrentHashMap<Attributes, AggregatorHandle<T, U>> aggregatorHandles =
new ConcurrentHashMap<>();
private final ReadWriteLock lock = new ReentrantReadWriteLock();
// Recording threads grab the current interval (AggregatorHolder) and atomically increment
// this by 2 before recording against it (and then decrement by two when done).
//
// The collection thread grabs the current interval (AggregatorHolder) and atomically
// increments this by 1 to "lock" this interval (and then waits for any active recording
// threads to complete before collecting it).
//
// Recording threads check the return value of their atomic increment, and if it's odd
// that means the collector thread has "locked" this interval for collection.
//
// But before the collector "locks" the interval it sets up a new current interval
// (AggregatorHolder), and so if a recording thread encounters an odd value,
// all it needs to do is release the "read lock" it just obtained (decrementing by 2),
// and then grab and record against the new current interval (AggregatorHolder).
private final AtomicInteger activeRecordingThreads = new AtomicInteger(0);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,11 @@ void recordAndCollect_concurrentStressTest(
Thread collectThread =
new Thread(
() -> {
do {
int extraCollects = 0;
// If we terminate when latch.getCount() == 0, the last collect may have occurred before
// the last recorded measurement. To ensure we collect all measurements, we collect
// one extra time after latch.getCount() == 0.
while (latch.getCount() != 0 && extraCollects <= 1) {
Uninterruptibles.sleepUninterruptibly(Duration.ofMillis(1));
MetricData metricData =
storage.collect(Resource.empty(), InstrumentationScopeInfo.empty(), 0, 1);
Expand All @@ -419,7 +423,10 @@ void recordAndCollect_concurrentStressTest(
metricData.getDoubleSumData().getPoints().stream()
.findFirst()
.ifPresent(pointData -> collect.accept(pointData.getValue(), cumulativeSum));
} while (latch.getCount() != 0);
if (latch.getCount() == 0) {
extraCollects++;
}
}
});

// Start all the threads
Expand Down

0 comments on commit 72a5bb1

Please sign in to comment.