Fix delta metric storage concurrency bug (#5932)
jack-berg committed Nov 10, 2023
commit 04f6d9c (1 parent: 83993e0)
Showing 3 changed files with 149 additions and 9 deletions.
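Context for the change: with delta temporality, collect() resets each aggregator handle and recycles it, so a measurement recorded between a record thread's lookup of a handle and the collector's reset could be silently dropped. The fix below wraps the handle map in an AggregatorHolder: recording threads hold a read lock while they look up and update a handle, and collection swaps in a fresh holder, then briefly takes the old holder's write lock to wait out in-flight recordings before reading its points. A minimal, self-contained sketch of that pattern, with hypothetical names rather than the SDK's internal API:

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/** Standalone sketch of the holder-swap pattern; not the SDK classes. */
class HolderSwapSketch {
  private static final class Holder {
    final Map<String, LongAdder> handles = new ConcurrentHashMap<>();
    final ReadWriteLock lock = new ReentrantReadWriteLock();
  }

  private volatile Holder holder = new Holder();

  void record(String key, long value) {
    Holder h = holder; // capture once so the lock and the write target the same holder
    Lock readLock = h.lock.readLock();
    readLock.lock();
    try {
      // Lookup and update happen entirely under the read lock, so a concurrent
      // collection cannot reset this handle between the two steps.
      h.handles.computeIfAbsent(key, k -> new LongAdder()).add(value);
    } finally {
      readLock.unlock();
    }
  }

  Map<String, Long> collectDelta() {
    Holder old = holder;
    holder = new Holder(); // new recordings now target a fresh holder
    Lock writeLock = old.lock.writeLock();
    writeLock.lock(); // blocks until every in-flight record() releases its read lock
    try {
      Map<String, Long> points = new ConcurrentHashMap<>();
      old.handles.forEach((k, adder) -> points.put(k, adder.sum()));
      return points;
    } finally {
      writeLock.unlock();
    }
  }
}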
File 1 of 3
@@ -38,7 +38,7 @@ Meter build() {
               .get("io.opentelemetry.sdk.metrics");
         }
       }),
-  SDK(
+  SDK_CUMULATIVE(
       new SdkBuilder() {
         @Override
         Meter build() {
@@ -50,6 +50,19 @@ Meter build() {
               .build()
               .get("io.opentelemetry.sdk.metrics");
         }
-      });
+      }),
+  SDK_DELTA(
+      new SdkBuilder() {
+        @Override
+        Meter build() {
+          return SdkMeterProvider.builder()
+              .setClock(Clock.getDefault())
+              .setResource(Resource.empty())
+              // Must register reader for real SDK.
+              .registerMetricReader(InMemoryMetricReader.createDelta())
+              .build()
+              .get("io.opentelemetry.sdk.metrics");
+        }
+      });
 
   private final SdkBuilder sdkBuilder;
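The new SDK_DELTA variant above runs the same scenario against a delta reader alongside the existing cumulative one. For readers less familiar with temporality, a small side example (not part of this commit; it assumes only the public SDK and testing APIs already used above) of how the two readers report the same counter:

import io.opentelemetry.api.metrics.LongCounter;
import io.opentelemetry.api.metrics.Meter;
import io.opentelemetry.sdk.metrics.SdkMeterProvider;
import io.opentelemetry.sdk.testing.exporter.InMemoryMetricReader;

class TemporalityDemo {
  public static void main(String[] args) {
    InMemoryMetricReader delta = InMemoryMetricReader.createDelta();
    InMemoryMetricReader cumulative = InMemoryMetricReader.create();
    SdkMeterProvider provider =
        SdkMeterProvider.builder()
            .registerMetricReader(delta)
            .registerMetricReader(cumulative)
            .build();
    Meter meter = provider.get("demo");
    LongCounter counter = meter.counterBuilder("requests").build();

    counter.add(5);
    delta.collectAllMetrics();      // delta point reports 5 (since last collection)
    cumulative.collectAllMetrics(); // cumulative point reports 5 (since start)

    counter.add(5);
    delta.collectAllMetrics();      // delta point reports 5 again
    cumulative.collectAllMetrics(); // cumulative point reports 10
    provider.shutdown();
  }
}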
File 2 of 3: DefaultSynchronousMetricStorage.java
@@ -26,6 +26,9 @@
 import java.util.Queue;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 import java.util.logging.Level;
 import java.util.logging.Logger;
 
@@ -46,8 +49,7 @@ public final class DefaultSynchronousMetricStorage<T extends PointData, U extends ExemplarData>
   private final MetricDescriptor metricDescriptor;
   private final AggregationTemporality aggregationTemporality;
   private final Aggregator<T, U> aggregator;
-  private final ConcurrentHashMap<Attributes, AggregatorHandle<T, U>> aggregatorHandles =
-      new ConcurrentHashMap<>();
+  private volatile AggregatorHolder<T, U> aggregatorHolder = new AggregatorHolder<>();
   private final AttributesProcessor attributesProcessor;
 
   /**
@@ -83,8 +85,15 @@ Queue<AggregatorHandle<T, U>> getAggregatorHandlePool() {
 
   @Override
   public void recordLong(long value, Attributes attributes, Context context) {
-    AggregatorHandle<T, U> handle = getAggregatorHandle(attributes, context);
-    handle.recordLong(value, attributes, context);
+    Lock readLock = aggregatorHolder.lock.readLock();
+    readLock.lock();
+    try {
+      AggregatorHandle<T, U> handle =
+          getAggregatorHandle(aggregatorHolder.aggregatorHandles, attributes, context);
+      handle.recordLong(value, attributes, context);
+    } finally {
+      readLock.unlock();
+    }
   }
 
   @Override
@@ -99,11 +108,21 @@ public void recordDouble(double value, Attributes attributes, Context context) {
                 + ". Dropping measurement.");
       return;
     }
-    AggregatorHandle<T, U> handle = getAggregatorHandle(attributes, context);
-    handle.recordDouble(value, attributes, context);
+    Lock readLock = aggregatorHolder.lock.readLock();
+    readLock.lock();
+    try {
+      AggregatorHandle<T, U> handle =
+          getAggregatorHandle(aggregatorHolder.aggregatorHandles, attributes, context);
+      handle.recordDouble(value, attributes, context);
+    } finally {
+      readLock.unlock();
+    }
   }
 
-  private AggregatorHandle<T, U> getAggregatorHandle(Attributes attributes, Context context) {
+  private AggregatorHandle<T, U> getAggregatorHandle(
+      ConcurrentHashMap<Attributes, AggregatorHandle<T, U>> aggregatorHandles,
+      Attributes attributes,
+      Context context) {
     Objects.requireNonNull(attributes, "attributes");
     attributes = attributesProcessor.process(attributes, context);
     AggregatorHandle<T, U> handle = aggregatorHandles.get(attributes);
@@ -146,13 +165,27 @@ public MetricData collect(
             ? registeredReader.getLastCollectEpochNanos()
             : startEpochNanos;
 
+    ConcurrentHashMap<Attributes, AggregatorHandle<T, U>> aggregatorHandles;
+    if (reset) {
+      AggregatorHolder<T, U> holder = this.aggregatorHolder;
+      this.aggregatorHolder = new AggregatorHolder<>();
+      Lock writeLock = holder.lock.writeLock();
+      writeLock.lock();
+      try {
+        aggregatorHandles = holder.aggregatorHandles;
+      } finally {
+        writeLock.unlock();
+      }
+    } else {
+      aggregatorHandles = this.aggregatorHolder.aggregatorHandles;
+    }
 
     // Grab aggregated points.
     List<T> points = new ArrayList<>(aggregatorHandles.size());
     aggregatorHandles.forEach(
         (attributes, handle) -> {
           T point = handle.aggregateThenMaybeReset(start, epochNanos, attributes, reset);
           if (reset) {
             aggregatorHandles.remove(attributes, handle);
             // Return the aggregator to the pool.
             aggregatorHandlePool.offer(handle);
           }
@@ -180,4 +213,10 @@ public MetricData collect(
   public MetricDescriptor getMetricDescriptor() {
     return metricDescriptor;
   }
+
+  private static class AggregatorHolder<T extends PointData, U extends ExemplarData> {
+    private final ConcurrentHashMap<Attributes, AggregatorHandle<T, U>> aggregatorHandles =
+        new ConcurrentHashMap<>();
+    private final ReadWriteLock lock = new ReentrantReadWriteLock();
+  }
 }
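For contrast, a simplified sketch of the pre-fix shape (hypothetical names, not the removed SDK code verbatim): nothing orders a record thread's handle lookup against the collector's reset-and-remove, so a write can land in a handle that no later collection will ever visit.

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;

/** Pre-fix shape, simplified: the lost-update window is marked below. */
class RacyDeltaSketch {
  private final Map<String, LongAdder> handles = new ConcurrentHashMap<>();

  void record(String key, long value) {
    LongAdder handle = handles.computeIfAbsent(key, k -> new LongAdder());
    // <-- if collectDelta() runs here, it sums and removes this handle first,
    //     and the add() below is never observed by any collection
    handle.add(value);
  }

  long collectDelta(String key) {
    LongAdder handle = handles.get(key);
    if (handle == null) {
      return 0L;
    }
    long sum = handle.sumThenReset(); // aggregate-then-reset, as delta storage does
    handles.remove(key, handle);      // late writers still hold a reference: lost update
    return sum;
  }
}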
File 3 of 3: tests for DefaultSynchronousMetricStorage
@@ -12,6 +12,8 @@
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 
+import com.google.common.util.concurrent.AtomicDouble;
+import com.google.common.util.concurrent.Uninterruptibles;
 import io.github.netmikey.logunit.api.LogCapturer;
 import io.opentelemetry.api.common.AttributeKey;
 import io.opentelemetry.api.common.Attributes;
@@ -21,9 +23,11 @@
 import io.opentelemetry.sdk.metrics.Aggregation;
 import io.opentelemetry.sdk.metrics.InstrumentType;
 import io.opentelemetry.sdk.metrics.InstrumentValueType;
+import io.opentelemetry.sdk.metrics.data.ExemplarData;
 import io.opentelemetry.sdk.metrics.data.LongExemplarData;
 import io.opentelemetry.sdk.metrics.data.LongPointData;
 import io.opentelemetry.sdk.metrics.data.MetricData;
+import io.opentelemetry.sdk.metrics.data.PointData;
 import io.opentelemetry.sdk.metrics.internal.aggregator.Aggregator;
 import io.opentelemetry.sdk.metrics.internal.aggregator.AggregatorFactory;
 import io.opentelemetry.sdk.metrics.internal.aggregator.EmptyMetricData;
@@ -37,8 +41,17 @@
 import io.opentelemetry.sdk.resources.Resource;
 import io.opentelemetry.sdk.testing.exporter.InMemoryMetricReader;
 import io.opentelemetry.sdk.testing.time.TestClock;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.function.BiConsumer;
+import java.util.stream.Stream;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.RegisterExtension;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
 import org.slf4j.event.Level;
 
@SuppressLogger(DefaultSynchronousMetricStorage.class)
@@ -370,4 +383,79 @@ void recordAndCollect_DeltaAtLimit() {
     assertThat(storage.getAggregatorHandlePool()).hasSize(CARDINALITY_LIMIT);
     logs.assertContains("Instrument name has exceeded the maximum allowed cardinality");
   }
+
+  @ParameterizedTest
+  @MethodSource("concurrentStressTestArguments")
+  void recordAndCollect_concurrentStressTest(
+      DefaultSynchronousMetricStorage<?, ?> storage, BiConsumer<Double, AtomicDouble> collect) {
+    // Define record threads. Each records a value of 1.0, 2000 times
+    List<Thread> threads = new ArrayList<>();
+    CountDownLatch latch = new CountDownLatch(4);
+    for (int i = 0; i < 4; i++) {
+      Thread thread =
+          new Thread(
+              () -> {
+                for (int j = 0; j < 2000; j++) {
+                  storage.recordDouble(1.0, Attributes.empty(), Context.current());
+                  Uninterruptibles.sleepUninterruptibly(Duration.ofMillis(1));
+                }
+                latch.countDown();
+              });
+      threads.add(thread);
+    }
+
+    // Define collect thread. Collect thread collects and aggregates the cumulative sum.
+    AtomicDouble cumulativeSum = new AtomicDouble();
+    Thread collectThread =
+        new Thread(
+            () -> {
+              do {
+                Uninterruptibles.sleepUninterruptibly(Duration.ofMillis(1));
+                MetricData metricData =
+                    storage.collect(Resource.empty(), InstrumentationScopeInfo.empty(), 0, 1);
+                if (metricData.isEmpty()) {
+                  continue;
+                }
+                metricData.getDoubleSumData().getPoints().stream()
+                    .findFirst()
+                    .ifPresent(pointData -> collect.accept(pointData.getValue(), cumulativeSum));
+              } while (latch.getCount() != 0);
+            });
+
+    // Start all the threads
+    collectThread.start();
+    threads.forEach(Thread::start);
+
+    // Wait for the collect thread to end, which collects until the record threads are done
+    Uninterruptibles.joinUninterruptibly(collectThread);
+
+    assertThat(cumulativeSum.get()).isEqualTo(8000.0);
+  }
+
+  private static Stream<Arguments> concurrentStressTestArguments() {
+    Aggregator<PointData, ExemplarData> aggregator =
+        ((AggregatorFactory) Aggregation.sum())
+            .createAggregator(DESCRIPTOR, ExemplarFilter.alwaysOff());
+    return Stream.of(
+        Arguments.of(
+            // Delta
+            new DefaultSynchronousMetricStorage<>(
+                RegisteredReader.create(InMemoryMetricReader.createDelta(), ViewRegistry.create()),
+                METRIC_DESCRIPTOR,
+                aggregator,
+                AttributesProcessor.noop(),
+                CARDINALITY_LIMIT),
+            (BiConsumer<Double, AtomicDouble>)
+                (value, cumulativeCount) -> cumulativeCount.addAndGet(value)),
+        Arguments.of(
+            // Cumulative
+            new DefaultSynchronousMetricStorage<>(
+                RegisteredReader.create(InMemoryMetricReader.create(), ViewRegistry.create()),
+                METRIC_DESCRIPTOR,
+                aggregator,
+                AttributesProcessor.noop(),
+                CARDINALITY_LIMIT),
+            (BiConsumer<Double, AtomicDouble>)
+                (value, cumulativeCount) -> cumulativeCount.set(value)));
+  }
 }
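Both parameterizations assert the same total: 4 record threads, each recording 1.0 a total of 2000 times, must yield 4 × 2000 × 1.0 = 8000.0. For the delta storage, each collection returns only the measurements since the previous collection, so the collect callback accumulates point values (addAndGet); for the cumulative storage, each point already carries the running total, so the callback keeps only the latest value (set). A delta storage that loses measurements recorded concurrently with collect() would finish below 8000.0, which is the regression this stress test guards against.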
