diff --git a/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/internal/common/LeakTracingRequestContextStorageBenchmark.java b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/internal/common/LeakTracingRequestContextStorageBenchmark.java
new file mode 100644
index 00000000000..511aea3d336
--- /dev/null
+++ b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/internal/common/LeakTracingRequestContextStorageBenchmark.java
@@ -0,0 +1,95 @@
+/*
+ * Copyright 2022 LINE Corporation
+ *
+ * LINE Corporation licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.linecorp.armeria.internal.common;
+
+import org.openjdk.jmh.annotations.Benchmark;
+
+import com.linecorp.armeria.common.HttpMethod;
+import com.linecorp.armeria.common.HttpRequest;
+import com.linecorp.armeria.common.RequestContext;
+import com.linecorp.armeria.common.RequestContextStorage;
+import com.linecorp.armeria.common.util.Sampler;
+import com.linecorp.armeria.server.ServiceRequestContext;
+
+/**
+ * Microbenchmarks for LeakTracingRequestContextStorage.
+ */
+public class LeakTracingRequestContextStorageBenchmark {
+
+ private static final RequestContextStorage threadLocalReqCtxStorage =
+ RequestContextStorage.threadLocal();
+ private static final RequestContextStorage neverSample =
+ new LeakTracingRequestContextStorage(threadLocalReqCtxStorage, Sampler.never());
+ private static final RequestContextStorage rateLimited1 =
+ new LeakTracingRequestContextStorage(threadLocalReqCtxStorage, Sampler.of("rate-limited=1"));
+ private static final RequestContextStorage rateLimited10 =
+ new LeakTracingRequestContextStorage(threadLocalReqCtxStorage, Sampler.of("rate-limited=10"));
+ private static final RequestContextStorage random1 =
+ new LeakTracingRequestContextStorage(threadLocalReqCtxStorage, Sampler.of("random=0.01"));
+ private static final RequestContextStorage random10 =
+ new LeakTracingRequestContextStorage(threadLocalReqCtxStorage, Sampler.of("random=0.10"));
+ private static final RequestContextStorage alwaysSample =
+ new LeakTracingRequestContextStorage(threadLocalReqCtxStorage, Sampler.always());
+ private static final RequestContext reqCtx = newCtx("/");
+
+ private static ServiceRequestContext newCtx(String path) {
+ return ServiceRequestContext.builder(HttpRequest.of(HttpMethod.GET, path))
+ .build();
+ }
+
+ @Benchmark
+ public void baseline_threadLocal() {
+ final RequestContext oldCtx = threadLocalReqCtxStorage.push(reqCtx);
+ threadLocalReqCtxStorage.pop(reqCtx, oldCtx);
+ }
+
+ @Benchmark
+ public void leakTracing_never_sample() {
+ final RequestContext oldCtx = neverSample.push(reqCtx);
+ neverSample.pop(reqCtx, oldCtx);
+ }
+
+ @Benchmark
+ public void leakTracing_rateLimited_1() {
+ final RequestContext oldCtx = rateLimited1.push(reqCtx);
+ rateLimited1.pop(reqCtx, oldCtx);
+ }
+
+ @Benchmark
+ public void leakTracing_rateLimited_10() {
+ final RequestContext oldCtx = rateLimited10.push(reqCtx);
+ rateLimited10.pop(reqCtx, oldCtx);
+ }
+
+ @Benchmark
+ public void leakTracing_random_1() {
+ final RequestContext oldCtx = random1.push(reqCtx);
+ random1.pop(reqCtx, oldCtx);
+ }
+
+ @Benchmark
+ public void leakTracing_random_10() {
+ final RequestContext oldCtx = random10.push(reqCtx);
+ random10.pop(reqCtx, oldCtx);
+ }
+
+ @Benchmark
+ public void leakTracing_always_sample() {
+ final RequestContext oldCtx = alwaysSample.push(reqCtx);
+ alwaysSample.pop(reqCtx, oldCtx);
+ }
+}
diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java
index 9e26f63a0c3..af277a0f9e8 100644
--- a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java
+++ b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java
@@ -216,17 +216,17 @@ static ClientRequestContextBuilder builder(RpcRequest request, URI uri) {
@MustBeClosed
default SafeCloseable push() {
final RequestContext oldCtx = RequestContextUtil.getAndSet(this);
- if (oldCtx == this) {
- // Reentrance
- return noopSafeCloseable();
- }
-
if (oldCtx == null) {
return RequestContextUtil.invokeHookAndPop(this, null);
}
+ if (oldCtx.unwrapAll() == unwrapAll()) {
+ // Reentrance
+ return noopSafeCloseable();
+ }
+
final ServiceRequestContext root = root();
- if (oldCtx.root() == root) {
+ if (RequestContextUtil.equalsIgnoreWrapper(oldCtx.root(), root)) {
return RequestContextUtil.invokeHookAndPop(this, oldCtx);
}
diff --git a/core/src/main/java/com/linecorp/armeria/common/DefaultFlagsProvider.java b/core/src/main/java/com/linecorp/armeria/common/DefaultFlagsProvider.java
index a190ab84abb..d7a01a7069f 100644
--- a/core/src/main/java/com/linecorp/armeria/common/DefaultFlagsProvider.java
+++ b/core/src/main/java/com/linecorp/armeria/common/DefaultFlagsProvider.java
@@ -398,4 +398,9 @@ public Path defaultMultipartUploadsLocation() {
File.separatorChar + "armeria" +
File.separatorChar + "multipart-uploads");
}
+
+ @Override
+ public Sampler super RequestContext> requestContextLeakDetectionSampler() {
+ return Sampler.never();
+ }
}
diff --git a/core/src/main/java/com/linecorp/armeria/common/Flags.java b/core/src/main/java/com/linecorp/armeria/common/Flags.java
index eb470d54c04..c994cf3dbb7 100644
--- a/core/src/main/java/com/linecorp/armeria/common/Flags.java
+++ b/core/src/main/java/com/linecorp/armeria/common/Flags.java
@@ -49,6 +49,7 @@
import com.linecorp.armeria.client.retry.RetryingClient;
import com.linecorp.armeria.client.retry.RetryingRpcClient;
import com.linecorp.armeria.common.annotation.Nullable;
+import com.linecorp.armeria.common.annotation.UnstableApi;
import com.linecorp.armeria.common.util.Exceptions;
import com.linecorp.armeria.common.util.Sampler;
import com.linecorp.armeria.common.util.SystemInfo;
@@ -113,10 +114,6 @@ public final class Flags {
static {
final String strSpec = getNormalized("verboseExceptions",
DefaultFlagsProvider.VERBOSE_EXCEPTION_SAMPLER_SPEC, val -> {
- if ("true".equals(val) || "false".equals(val)) {
- return true;
- }
-
try {
Sampler.of(val);
return true;
@@ -377,6 +374,9 @@ public final class Flags {
private static final Path DEFAULT_MULTIPART_UPLOADS_LOCATION =
getValue(FlagsProvider::defaultMultipartUploadsLocation, "defaultMultipartUploadsLocation");
+ private static final Sampler super RequestContext> REQUEST_CONTEXT_LEAK_DETECTION_SAMPLER =
+ getValue(FlagsProvider::requestContextLeakDetectionSampler, "requestContextLeakDetectionSampler");
+
/**
* Returns the specification of the {@link Sampler} that determines whether to retain the stack
* trace of the exceptions that are thrown frequently by Armeria. A sampled exception will have the stack
@@ -1296,6 +1296,21 @@ public static boolean allowDoubleDotsInQueryString() {
return ALLOW_DOUBLE_DOTS_IN_QUERY_STRING;
}
+ /**
+ * Returns the {@link Sampler} that determines whether to trace the stack trace of request contexts leaks
+ * and how frequently to keeps stack trace. A sampled exception will have the stack trace while the others
+ * will have an empty stack trace to eliminate the cost of capturing the stack trace.
+ *
+ *
The default value of this flag is {@link Sampler#never()}.
+ * Specify the {@code -Dcom.linecorp.armeria.requestContextLeakDetectionSampler=} JVM option
+ * to override the default. This feature is disabled if {@link Sampler#never()} is specified.
+ * See {@link Sampler#of(String)} for the specification string format.
+ */
+ @UnstableApi
+ public static Sampler super RequestContext> requestContextLeakDetectionSampler() {
+ return REQUEST_CONTEXT_LEAK_DETECTION_SAMPLER;
+ }
+
@Nullable
private static String nullableCaffeineSpec(Function method, String flagName) {
return caffeineSpec(method, flagName, true);
diff --git a/core/src/main/java/com/linecorp/armeria/common/FlagsProvider.java b/core/src/main/java/com/linecorp/armeria/common/FlagsProvider.java
index ca6f44098e7..eb81e0704d2 100644
--- a/core/src/main/java/com/linecorp/armeria/common/FlagsProvider.java
+++ b/core/src/main/java/com/linecorp/armeria/common/FlagsProvider.java
@@ -960,4 +960,20 @@ default Boolean allowDoubleDotsInQueryString() {
default Path defaultMultipartUploadsLocation() {
return null;
}
+
+ /**
+ * Returns the {@link Sampler} that determines whether to trace the stack trace of request contexts leaks
+ * and how frequently to keeps stack trace. A sampled exception will have the stack trace while the others
+ * will have an empty stack trace to eliminate the cost of capturing the stack trace.
+ *
+ * The default value of this flag is {@link Sampler#never()}.
+ * Specify the {@code -Dcom.linecorp.armeria.requestContextLeakDetectionSampler=} JVM option
+ * to override the default. This feature is disabled if {@link Sampler#never()} is specified.
+ * See {@link Sampler#of(String)} for the specification string format.
+ */
+ @UnstableApi
+ @Nullable
+ default Sampler super RequestContext> requestContextLeakDetectionSampler() {
+ return null;
+ }
}
diff --git a/core/src/main/java/com/linecorp/armeria/common/SystemPropertyFlagsProvider.java b/core/src/main/java/com/linecorp/armeria/common/SystemPropertyFlagsProvider.java
index bffd83d138e..a05f58d757e 100644
--- a/core/src/main/java/com/linecorp/armeria/common/SystemPropertyFlagsProvider.java
+++ b/core/src/main/java/com/linecorp/armeria/common/SystemPropertyFlagsProvider.java
@@ -71,20 +71,12 @@ public Sampler> verboseExceptionSampler() {
if (spec == null) {
return null;
}
- if ("true".equals(spec) || "always".equals(spec)) {
- return Sampler.always();
- }
- if ("false".equals(spec) || "never".equals(spec)) {
- return Sampler.never();
- }
-
try {
Sampler.of(spec);
} catch (Exception e) {
// Invalid sampler specification
throw new IllegalArgumentException("invalid sampler spec: " + spec, e);
}
-
return new ExceptionSampler(spec);
}
@@ -438,6 +430,20 @@ public Path defaultMultipartUploadsLocation() {
return getAndParse("defaultMultipartUploadsLocation", Paths::get);
}
+ @Override
+ public Sampler super RequestContext> requestContextLeakDetectionSampler() {
+ final String spec = getNormalized("requestContextLeakDetectionSampler");
+ if (spec == null) {
+ return null;
+ }
+ try {
+ return Sampler.of(spec);
+ } catch (Exception e) {
+ // Invalid sampler specification
+ throw new IllegalArgumentException("invalid sampler spec: " + spec, e);
+ }
+ }
+
@Nullable
private static Long getLong(String name) {
return getAndParse(name, Long::parseLong);
diff --git a/core/src/main/java/com/linecorp/armeria/common/ThreadLocalRequestContextStorage.java b/core/src/main/java/com/linecorp/armeria/common/ThreadLocalRequestContextStorage.java
index 7fe5e0a59aa..cc110929c89 100644
--- a/core/src/main/java/com/linecorp/armeria/common/ThreadLocalRequestContextStorage.java
+++ b/core/src/main/java/com/linecorp/armeria/common/ThreadLocalRequestContextStorage.java
@@ -46,7 +46,7 @@ public void pop(RequestContext current, @Nullable RequestContext toRestore) {
requireNonNull(current, "current");
final InternalThreadLocalMap map = InternalThreadLocalMap.get();
final RequestContext contextInThreadLocal = context.get(map);
- if (current != contextInThreadLocal) {
+ if (contextInThreadLocal == null || current.unwrapAll() != contextInThreadLocal.unwrapAll()) {
throw newIllegalContextPoppingException(current, contextInThreadLocal);
}
context.set(map, toRestore);
diff --git a/core/src/main/java/com/linecorp/armeria/common/util/Samplers.java b/core/src/main/java/com/linecorp/armeria/common/util/Samplers.java
index 00436d55f49..1d8463d5426 100644
--- a/core/src/main/java/com/linecorp/armeria/common/util/Samplers.java
+++ b/core/src/main/java/com/linecorp/armeria/common/util/Samplers.java
@@ -70,8 +70,10 @@ static Sampler of(String specification) {
requireNonNull(specification, "specification");
switch (specification.trim()) {
case "always":
+ case "true":
return Sampler.always();
case "never":
+ case "false":
return Sampler.never();
}
diff --git a/core/src/main/java/com/linecorp/armeria/common/util/Unwrappable.java b/core/src/main/java/com/linecorp/armeria/common/util/Unwrappable.java
index 7244970ef3d..250ea693e22 100644
--- a/core/src/main/java/com/linecorp/armeria/common/util/Unwrappable.java
+++ b/core/src/main/java/com/linecorp/armeria/common/util/Unwrappable.java
@@ -129,4 +129,20 @@ default Object unwrapAll() {
unwrapped = inner;
}
}
+
+ /**
+ * Reference checking this {@link Unwrappable} to another {@link Unwrappable}, ignoring wrappers.
+ * Two {@link Unwrappable} are considered equal ignoring wrappers if they are of the same object reference
+ * after {@link #unwrapAll()}.
+ * @param other The {@link Unwrappable} to compare this {@link Unwrappable} against
+ * @return true if the argument is not {@code null}, and it represents a same object reference after
+ * {@code unwrapAll()}, false otherwise.
+ */
+ @UnstableApi
+ default boolean equalsIgnoreWrapper(@Nullable Unwrappable other) {
+ if (other == null) {
+ return false;
+ }
+ return unwrapAll() == other.unwrapAll();
+ }
}
diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/LeakTracingRequestContextStorage.java b/core/src/main/java/com/linecorp/armeria/internal/common/LeakTracingRequestContextStorage.java
new file mode 100644
index 00000000000..752b8c896da
--- /dev/null
+++ b/core/src/main/java/com/linecorp/armeria/internal/common/LeakTracingRequestContextStorage.java
@@ -0,0 +1,126 @@
+/*
+ * Copyright 2022 LINE Corporation
+ *
+ * LINE Corporation licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.linecorp.armeria.internal.common;
+
+import static java.lang.Thread.currentThread;
+import static java.util.Objects.requireNonNull;
+
+import com.linecorp.armeria.client.ClientRequestContext;
+import com.linecorp.armeria.client.ClientRequestContextWrapper;
+import com.linecorp.armeria.common.RequestContext;
+import com.linecorp.armeria.common.RequestContextStorage;
+import com.linecorp.armeria.common.RequestContextWrapper;
+import com.linecorp.armeria.common.annotation.Nullable;
+import com.linecorp.armeria.common.util.Sampler;
+import com.linecorp.armeria.server.ServiceRequestContext;
+import com.linecorp.armeria.server.ServiceRequestContextWrapper;
+
+/**
+ * A {@link RequestContextStorage} which keeps track of {@link RequestContext}s, reporting pushed thread
+ * information if a {@link RequestContext} is leaked.
+ */
+final class LeakTracingRequestContextStorage implements RequestContextStorage {
+
+ private final RequestContextStorage delegate;
+ private final Sampler super RequestContext> sampler;
+
+ /**
+ * Creates a new instance.
+ * @param delegate the underlying {@link RequestContextStorage} that stores {@link RequestContext}
+ * @param sampler the {@link Sampler} that determines whether to retain the stacktrace of the context leaks
+ */
+ LeakTracingRequestContextStorage(RequestContextStorage delegate,
+ Sampler super RequestContext> sampler) {
+ this.delegate = requireNonNull(delegate, "delegate");
+ this.sampler = requireNonNull(sampler, "sampler");
+ }
+
+ @Nullable
+ @Override
+ public T push(RequestContext toPush) {
+ requireNonNull(toPush, "toPush");
+ if (sampler.isSampled(toPush)) {
+ return delegate.push(wrapRequestContext(toPush));
+ }
+ return delegate.push(toPush);
+ }
+
+ @Override
+ public void pop(RequestContext current, @Nullable RequestContext toRestore) {
+ requireNonNull(current, "current");
+ delegate.pop(current, toRestore);
+ }
+
+ @Nullable
+ @Override
+ public T currentOrNull() {
+ return delegate.currentOrNull();
+ }
+
+ @Override
+ public RequestContextStorage unwrap() {
+ return delegate;
+ }
+
+ private static RequestContextWrapper> wrapRequestContext(RequestContext ctx) {
+ return ctx instanceof ClientRequestContext ?
+ new TraceableClientRequestContext((ClientRequestContext) ctx)
+ : new TraceableServiceRequestContext((ServiceRequestContext) ctx);
+ }
+
+ private static String stacktraceToString(StackTraceElement[] stackTrace,
+ RequestContext unwrap) {
+ final StringBuilder builder = new StringBuilder(512);
+ builder.append(unwrap).append(System.lineSeparator())
+ .append("The previous RequestContext is pushed at the following stacktrace")
+ .append(System.lineSeparator());
+ for (int i = 1; i < stackTrace.length; i++) {
+ builder.append("\tat ").append(stackTrace[i]).append(System.lineSeparator());
+ }
+ return builder.toString();
+ }
+
+ private static final class TraceableClientRequestContext extends ClientRequestContextWrapper {
+
+ private final StackTraceElement[] stackTrace;
+
+ private TraceableClientRequestContext(ClientRequestContext delegate) {
+ super(delegate);
+ stackTrace = currentThread().getStackTrace();
+ }
+
+ @Override
+ public String toString() {
+ return stacktraceToString(stackTrace, unwrap());
+ }
+ }
+
+ private static final class TraceableServiceRequestContext extends ServiceRequestContextWrapper {
+
+ private final StackTraceElement[] stackTrace;
+
+ private TraceableServiceRequestContext(ServiceRequestContext delegate) {
+ super(delegate);
+ stackTrace = currentThread().getStackTrace();
+ }
+
+ @Override
+ public String toString() {
+ return stacktraceToString(stackTrace, unwrap());
+ }
+ }
+}
diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/RequestContextUtil.java b/core/src/main/java/com/linecorp/armeria/internal/common/RequestContextUtil.java
index d74893c4585..bd6dc307722 100644
--- a/core/src/main/java/com/linecorp/armeria/internal/common/RequestContextUtil.java
+++ b/core/src/main/java/com/linecorp/armeria/internal/common/RequestContextUtil.java
@@ -37,6 +37,7 @@
import com.linecorp.armeria.common.RequestContextStorageProvider;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.util.SafeCloseable;
+import com.linecorp.armeria.common.util.Sampler;
import com.linecorp.armeria.internal.client.DefaultClientRequestContext;
import com.linecorp.armeria.internal.server.DefaultServiceRequestContext;
@@ -63,8 +64,13 @@ public final class RequestContextUtil {
static {
final RequestContextStorageProvider provider = Flags.requestContextStorageProvider();
+ final Sampler super RequestContext> sampler = Flags.requestContextLeakDetectionSampler();
try {
- requestContextStorage = provider.newStorage();
+ if (sampler == Sampler.never()) {
+ requestContextStorage = provider.newStorage();
+ } else {
+ requestContextStorage = new LeakTracingRequestContextStorage(provider.newStorage(), sampler);
+ }
} catch (Throwable t) {
throw new IllegalStateException("Failed to create context storage. provider: " + provider, t);
}
@@ -195,6 +201,13 @@ public static SafeCloseable invokeHookAndPop(RequestContext current, @Nullable R
}
}
+ public static boolean equalsIgnoreWrapper(@Nullable RequestContext ctx1, @Nullable RequestContext ctx2) {
+ if (ctx1 == null) {
+ return ctx2 == null;
+ }
+ return ctx1.equalsIgnoreWrapper(ctx2);
+ }
+
@Nullable
private static AutoCloseable invokeHook(RequestContext ctx) {
final Supplier extends AutoCloseable> hook;
diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContext.java b/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContext.java
index 7a18cda9e0b..ec1b3e8b43d 100644
--- a/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContext.java
+++ b/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContext.java
@@ -219,16 +219,16 @@ default ServiceRequestContext root() {
@MustBeClosed
default SafeCloseable push() {
final RequestContext oldCtx = RequestContextUtil.getAndSet(this);
- if (oldCtx == this) {
- // Reentrance
- return noopSafeCloseable();
- }
-
if (oldCtx == null) {
return RequestContextUtil.invokeHookAndPop(this, null);
}
- if (oldCtx.root() == this) {
+ if (oldCtx.unwrapAll() == unwrapAll()) {
+ // Reentrance
+ return noopSafeCloseable();
+ }
+
+ if (RequestContextUtil.equalsIgnoreWrapper(oldCtx.root(), this)) {
return RequestContextUtil.invokeHookAndPop(this, oldCtx);
}
diff --git a/core/src/test/java/com/linecorp/armeria/client/ClientRequestContextTest.java b/core/src/test/java/com/linecorp/armeria/client/ClientRequestContextTest.java
index b1193e12c68..c89477c4e4e 100644
--- a/core/src/test/java/com/linecorp/armeria/client/ClientRequestContextTest.java
+++ b/core/src/test/java/com/linecorp/armeria/client/ClientRequestContextTest.java
@@ -41,9 +41,9 @@ void current() {
final ClientRequestContext ctx = clientRequestContext();
assertThat(ctx.id()).isNotNull();
try (SafeCloseable unused = ctx.push()) {
- assertThat(ClientRequestContext.current()).isSameAs(ctx);
+ assertThat(ClientRequestContext.current().unwrapAll()).isSameAs(ctx);
}
- assertCurrentCtx(null);
+ assertUnwrapAllCurrentCtx(null);
try (SafeCloseable unused = serviceRequestContext().push()) {
assertThatThrownBy(ClientRequestContext::current)
@@ -58,9 +58,9 @@ void currentOrNull() {
final ClientRequestContext ctx = clientRequestContext();
try (SafeCloseable unused = ctx.push()) {
- assertThat(ClientRequestContext.currentOrNull()).isSameAs(ctx);
+ assertThat(ClientRequestContext.currentOrNull().unwrapAll()).isSameAs(ctx);
}
- assertCurrentCtx(null);
+ assertUnwrapAllCurrentCtx(null);
try (SafeCloseable unused = serviceRequestContext().push()) {
assertThat(ClientRequestContext.currentOrNull()).isNull();
@@ -75,9 +75,10 @@ void mapCurrent() {
final ClientRequestContext ctx = clientRequestContext();
try (SafeCloseable unused = ctx.push()) {
assertThat(ClientRequestContext.mapCurrent(c -> "foo", () -> "bar")).isEqualTo("foo");
- assertThat(ClientRequestContext.mapCurrent(Function.identity(), null)).isSameAs(ctx);
+ assertThat(ClientRequestContext.mapCurrent(Function.identity(), null).unwrapAll())
+ .isSameAs(ctx);
}
- assertCurrentCtx(null);
+ assertUnwrapAllCurrentCtx(null);
try (SafeCloseable unused = serviceRequestContext().push()) {
assertThatThrownBy(() -> ClientRequestContext.mapCurrent(c -> "foo", () -> "bar"))
@@ -90,32 +91,32 @@ void mapCurrent() {
void pushReentrance() {
final ClientRequestContext ctx = clientRequestContext();
try (SafeCloseable ignored = ctx.push()) {
- assertCurrentCtx(ctx);
+ assertUnwrapAllCurrentCtx(ctx);
try (SafeCloseable ignored2 = ctx.push()) {
- assertCurrentCtx(ctx);
+ assertUnwrapAllCurrentCtx(ctx);
}
- assertCurrentCtx(ctx);
+ assertUnwrapAllCurrentCtx(ctx);
}
- assertCurrentCtx(null);
+ assertUnwrapAllCurrentCtx(null);
}
@Test
void pushWithOldServiceCtx() {
final ServiceRequestContext sctx = serviceRequestContext();
try (SafeCloseable ignored = sctx.push()) {
- assertCurrentCtx(sctx);
+ assertUnwrapAllCurrentCtx(sctx);
// The root of ClientRequestContext is sctx.
final ClientRequestContext cctx = clientRequestContext();
try (SafeCloseable ignored1 = cctx.push()) {
- assertCurrentCtx(cctx);
+ assertUnwrapAllCurrentCtx(cctx);
try (SafeCloseable ignored2 = sctx.push()) {
- assertCurrentCtx(sctx);
+ assertUnwrapAllCurrentCtx(sctx);
}
- assertCurrentCtx(cctx);
+ assertUnwrapAllCurrentCtx(cctx);
}
- assertCurrentCtx(sctx);
+ assertUnwrapAllCurrentCtx(sctx);
}
- assertCurrentCtx(null);
+ assertUnwrapAllCurrentCtx(null);
}
@Test
@@ -135,65 +136,65 @@ void pushWithOldServiceCtx_exceptionWhenServiceCtxIsDifferFromRoot() {
void pushWithOldClientCtxWhoseRootIsSameServiceCtx_ctx2IsCreatedSameLayer() {
final ServiceRequestContext sctx = serviceRequestContext();
try (SafeCloseable ignored = sctx.push()) {
- assertCurrentCtx(sctx);
+ assertUnwrapAllCurrentCtx(sctx);
final ClientRequestContext cctx1 = clientRequestContext();
final ClientRequestContext cctx2 = clientRequestContext();
assertThat(cctx1.root()).isSameAs(cctx2.root());
try (SafeCloseable ignored1 = cctx1.push()) {
- assertCurrentCtx(cctx1);
+ assertUnwrapAllCurrentCtx(cctx1);
try (SafeCloseable ignored2 = cctx2.push()) {
- assertCurrentCtx(cctx2);
+ assertUnwrapAllCurrentCtx(cctx2);
}
- assertCurrentCtx(cctx1);
+ assertUnwrapAllCurrentCtx(cctx1);
}
- assertCurrentCtx(sctx);
+ assertUnwrapAllCurrentCtx(sctx);
}
- assertCurrentCtx(null);
+ assertUnwrapAllCurrentCtx(null);
}
@Test
- void pushWithOldClientCtxWhoseRootIsSameServiceCtx__ctx2IsCreatedUnderCtx1() {
+ void pushWithOldClientCtxWhoseRootIsSameServiceCtx_ctx2IsCreatedUnderCtx1() {
final ServiceRequestContext sctx = serviceRequestContext();
try (SafeCloseable ignored = sctx.push()) {
- assertCurrentCtx(sctx);
+ assertUnwrapAllCurrentCtx(sctx);
final ClientRequestContext cctx1 = clientRequestContext();
try (SafeCloseable ignored1 = cctx1.push()) {
- assertCurrentCtx(cctx1);
+ assertUnwrapAllCurrentCtx(cctx1);
final ClientRequestContext cctx2 = clientRequestContext();
assertThat(cctx1.root()).isSameAs(cctx2.root());
try (SafeCloseable ignored2 = cctx2.push()) {
- assertCurrentCtx(cctx2);
+ assertUnwrapAllCurrentCtx(cctx2);
}
- assertCurrentCtx(cctx1);
+ assertUnwrapAllCurrentCtx(cctx1);
}
- assertCurrentCtx(sctx);
+ assertUnwrapAllCurrentCtx(sctx);
}
- assertCurrentCtx(null);
+ assertUnwrapAllCurrentCtx(null);
}
@Test
void pushWithOldClientCtxWhoseRootIsSameServiceCtx_derivedCtx() {
final ServiceRequestContext sctx = serviceRequestContext();
try (SafeCloseable ignored = sctx.push()) {
- assertCurrentCtx(sctx);
+ assertUnwrapAllCurrentCtx(sctx);
final ClientRequestContext cctx1 = clientRequestContext();
final ClientRequestContext derived = cctx1.newDerivedContext(cctx1.id(), cctx1.request(),
cctx1.rpcRequest(), cctx1.endpoint());
try (SafeCloseable ignored1 = derived.push()) {
- assertCurrentCtx(derived);
+ assertUnwrapAllCurrentCtx(derived);
final ClientRequestContext cctx2 = clientRequestContext();
assertThat(derived.root()).isSameAs(cctx2.root());
try (SafeCloseable ignored2 = cctx2.push()) {
- assertCurrentCtx(cctx2);
+ assertUnwrapAllCurrentCtx(cctx2);
}
- assertCurrentCtx(derived);
+ assertUnwrapAllCurrentCtx(derived);
}
- assertCurrentCtx(sctx);
+ assertUnwrapAllCurrentCtx(sctx);
}
- assertCurrentCtx(null);
+ assertUnwrapAllCurrentCtx(null);
}
@Test
@@ -217,16 +218,16 @@ void pushWithOldClientCtxWhoseRootIsDifferent() {
void pushWithOldClientCtxWhoseRootIsNull() {
final ClientRequestContext cctx1 = clientRequestContext();
try (SafeCloseable ignored1 = cctx1.push()) {
- assertCurrentCtx(cctx1);
+ assertUnwrapAllCurrentCtx(cctx1);
final ClientRequestContext cctx2 = clientRequestContext();
assertThat(cctx1.root()).isNull();
assertThat(cctx2.root()).isNull();
try (SafeCloseable ignored2 = cctx2.push()) {
- assertCurrentCtx(cctx2);
+ assertUnwrapAllCurrentCtx(cctx2);
}
- assertCurrentCtx(cctx1);
+ assertUnwrapAllCurrentCtx(cctx1);
}
- assertCurrentCtx(null);
+ assertUnwrapAllCurrentCtx(null);
}
@Test
@@ -259,9 +260,13 @@ void hasOwnAttr() {
}
}
- private static void assertCurrentCtx(@Nullable RequestContext ctx) {
+ private static void assertUnwrapAllCurrentCtx(@Nullable RequestContext ctx) {
final RequestContext current = RequestContext.currentOrNull();
- assertThat(current).isSameAs(ctx);
+ if (current == null) {
+ assertThat(ctx).isNull();
+ } else {
+ assertThat(current.unwrapAll()).isSameAs(ctx);
+ }
}
private static ServiceRequestContext serviceRequestContext() {
diff --git a/core/src/test/java/com/linecorp/armeria/common/FlagsTest.java b/core/src/test/java/com/linecorp/armeria/common/FlagsTest.java
index 7d130324a66..f7106d8267f 100644
--- a/core/src/test/java/com/linecorp/armeria/common/FlagsTest.java
+++ b/core/src/test/java/com/linecorp/armeria/common/FlagsTest.java
@@ -210,7 +210,7 @@ void systemPropertyVerboseExceptionSampler() throws Throwable {
final Method method = flags.getDeclaredMethod("verboseExceptionSampler");
assertThat(method.invoke(flags))
.usingRecursiveComparison()
- .isEqualTo(Sampler.always());
+ .isEqualTo(new ExceptionSampler("true"));
}
@Test
@@ -222,6 +222,32 @@ void invalidSystemPropertyVerboseExceptionSampler() throws Throwable {
.isEqualTo(new ExceptionSampler("rate-limit=10"));
}
+ @Test
+ void defaultRequestContextLeakDetectionSampler() throws Exception {
+ final Method method = flags.getDeclaredMethod("requestContextLeakDetectionSampler");
+ assertThat(method.invoke(flags))
+ .usingRecursiveComparison()
+ .isEqualTo(Sampler.never());
+ }
+
+ @Test
+ @SetSystemProperty(key = "com.linecorp.armeria.requestContextLeakDetectionSampler", value = "always")
+ void systemPropertyRequestContextLeakDetectionSampler() throws Exception {
+ final Method method = flags.getDeclaredMethod("requestContextLeakDetectionSampler");
+ assertThat(method.invoke(flags))
+ .usingRecursiveComparison()
+ .isEqualTo(Sampler.always());
+ }
+
+ @Test
+ @SetSystemProperty(key = "com.linecorp.armeria.requestContextLeakDetectionSampler", value = "invalid-spec")
+ void invalidSystemPropertyRequestContextLeakDetectionSampler() throws Exception {
+ final Method method = flags.getDeclaredMethod("requestContextLeakDetectionSampler");
+ assertThat(method.invoke(flags))
+ .usingRecursiveComparison()
+ .isEqualTo(Sampler.never());
+ }
+
@Test
void testApiConsistencyBetweenFlagsAndFlagsProvider() {
//Check method consistency between Flags and FlagsProvider excluding deprecated methods
diff --git a/core/src/test/java/com/linecorp/armeria/common/util/SamplerTest.java b/core/src/test/java/com/linecorp/armeria/common/util/SamplerTest.java
index 92183c934ed..3e3c5302508 100644
--- a/core/src/test/java/com/linecorp/armeria/common/util/SamplerTest.java
+++ b/core/src/test/java/com/linecorp/armeria/common/util/SamplerTest.java
@@ -31,10 +31,14 @@ void goodOf() {
// 'always'
assertThat(Sampler.of("always")).isSameAs(Sampler.always());
assertThat(Sampler.of(" always ")).isSameAs(Sampler.always());
+ assertThat(Sampler.of("true")).isSameAs(Sampler.always());
+ assertThat(Sampler.of(" true ")).isSameAs(Sampler.always());
// 'never'
assertThat(Sampler.of("never")).isSameAs(Sampler.never());
assertThat(Sampler.of(" never ")).isSameAs(Sampler.never());
+ assertThat(Sampler.of("false")).isSameAs(Sampler.never());
+ assertThat(Sampler.of(" false ")).isSameAs(Sampler.never());
// 'random='
assertThat(Sampler.of("random=0")).isSameAs(Sampler.never());
diff --git a/core/src/test/java/com/linecorp/armeria/server/ServiceRequestContextTest.java b/core/src/test/java/com/linecorp/armeria/server/ServiceRequestContextTest.java
index b4c157cc9ab..927aecee1fd 100644
--- a/core/src/test/java/com/linecorp/armeria/server/ServiceRequestContextTest.java
+++ b/core/src/test/java/com/linecorp/armeria/server/ServiceRequestContextTest.java
@@ -20,6 +20,7 @@
import java.util.function.Function;
+import org.assertj.core.api.ObjectAssert;
import org.junit.jupiter.api.Test;
import com.google.common.collect.ImmutableList;
@@ -44,13 +45,13 @@ void current() {
assertThat(ServiceRequestContext.current()).isSameAs(sctx);
final ClientRequestContext cctx = clientRequestContext();
try (SafeCloseable unused1 = cctx.push()) {
- assertThat(ServiceRequestContext.current()).isSameAs(sctx);
- assertThat(ClientRequestContext.current()).isSameAs(cctx);
- assertThat((ClientRequestContext) RequestContext.current()).isSameAs(cctx);
+ assertThatUnwrapAll(ServiceRequestContext.current()).isSameAs(sctx);
+ assertThatUnwrapAll(ClientRequestContext.current()).isSameAs(cctx);
+ assertThatUnwrapAll((ClientRequestContext) RequestContext.current()).isSameAs(cctx);
}
- assertCurrentCtx(sctx);
+ assertUnwrapAllCurrentCtx(sctx);
}
- assertCurrentCtx(null);
+ assertUnwrapAllCurrentCtx(null);
try (SafeCloseable unused = clientRequestContext().push()) {
assertThatThrownBy(ServiceRequestContext::current)
@@ -68,13 +69,13 @@ void currentOrNull() {
assertThat(ServiceRequestContext.currentOrNull()).isSameAs(sctx);
final ClientRequestContext cctx = clientRequestContext();
try (SafeCloseable unused1 = cctx.push()) {
- assertThat(ServiceRequestContext.currentOrNull()).isSameAs(sctx);
- assertThat(ClientRequestContext.current()).isSameAs(cctx);
- assertThat((ClientRequestContext) RequestContext.current()).isSameAs(cctx);
+ assertThatUnwrapAll(ServiceRequestContext.currentOrNull()).isSameAs(sctx);
+ assertThatUnwrapAll(ClientRequestContext.current()).isSameAs(cctx);
+ assertThatUnwrapAll((ClientRequestContext) RequestContext.current()).isSameAs(cctx);
}
- assertCurrentCtx(sctx);
+ assertUnwrapAllCurrentCtx(sctx);
}
- assertCurrentCtx(null);
+ assertUnwrapAllCurrentCtx(null);
try (SafeCloseable unused = clientRequestContext().push()) {
assertThat(ServiceRequestContext.currentOrNull()).isNull();
@@ -98,16 +99,16 @@ void mapCurrent() {
assertThat(ServiceRequestContext.mapCurrent(c -> c == sctx ? "foo" : "bar",
() -> "defaultValue"))
.isEqualTo("foo");
- assertThat(ClientRequestContext.mapCurrent(c -> c == cctx ? "baz" : "qux",
+ assertThat(ClientRequestContext.mapCurrent(c -> c.unwrapAll() == cctx ? "baz" : "qux",
() -> "defaultValue"))
.isEqualTo("baz");
assertThat(ServiceRequestContext.mapCurrent(Function.identity(), null)).isSameAs(sctx);
- assertThat(ClientRequestContext.mapCurrent(Function.identity(), null)).isSameAs(cctx);
- assertThat(RequestContext.mapCurrent(Function.identity(), null)).isSameAs(cctx);
+ assertThatUnwrapAll(ClientRequestContext.mapCurrent(Function.identity(), null)).isSameAs(cctx);
+ assertThatUnwrapAll(RequestContext.mapCurrent(Function.identity(), null)).isSameAs(cctx);
}
- assertCurrentCtx(sctx);
+ assertUnwrapAllCurrentCtx(sctx);
}
- assertCurrentCtx(null);
+ assertUnwrapAllCurrentCtx(null);
try (SafeCloseable unused = clientRequestContext().push()) {
assertThatThrownBy(() -> ServiceRequestContext.mapCurrent(c -> "foo", () -> "bar"))
@@ -120,43 +121,43 @@ void mapCurrent() {
void pushReentrance() {
final ServiceRequestContext ctx = serviceRequestContext();
try (SafeCloseable ignored = ctx.push()) {
- assertCurrentCtx(ctx);
+ assertUnwrapAllCurrentCtx(ctx);
try (SafeCloseable ignored2 = ctx.push()) {
- assertCurrentCtx(ctx);
+ assertUnwrapAllCurrentCtx(ctx);
}
- assertCurrentCtx(ctx);
+ assertUnwrapAllCurrentCtx(ctx);
}
- assertCurrentCtx(null);
+ assertUnwrapAllCurrentCtx(null);
}
@Test
void pushWithOldClientCtxWhoseRootIsThisServiceCtx() {
final ServiceRequestContext sctx = serviceRequestContext();
try (SafeCloseable ignored = sctx.push()) {
- assertCurrentCtx(sctx);
+ assertUnwrapAllCurrentCtx(sctx);
// The root of ClientRequestContext is sctx.
final ClientRequestContext cctx = clientRequestContext();
try (SafeCloseable ignored1 = cctx.push()) {
- assertCurrentCtx(cctx);
+ assertUnwrapAllCurrentCtx(cctx);
try (SafeCloseable ignored2 = sctx.push()) {
- assertCurrentCtx(sctx);
+ assertUnwrapAllCurrentCtx(sctx);
}
- assertCurrentCtx(cctx);
+ assertUnwrapAllCurrentCtx(cctx);
}
- assertCurrentCtx(sctx);
+ assertUnwrapAllCurrentCtx(sctx);
}
- assertCurrentCtx(null);
+ assertUnwrapAllCurrentCtx(null);
}
@Test
void pushWithOldIrrelevantClientCtx() {
final ClientRequestContext cctx = clientRequestContext();
try (SafeCloseable ignored = cctx.push()) {
- assertCurrentCtx(cctx);
+ assertUnwrapAllCurrentCtx(cctx);
final ServiceRequestContext sctx = serviceRequestContext();
assertThatThrownBy(sctx::push).isInstanceOf(IllegalStateException.class);
}
- assertCurrentCtx(null);
+ assertUnwrapAllCurrentCtx(null);
}
@Test
@@ -164,10 +165,10 @@ void pushWithOldIrrelevantServiceCtx() {
final ServiceRequestContext sctx1 = serviceRequestContext();
final ServiceRequestContext sctx2 = serviceRequestContext();
try (SafeCloseable ignored = sctx1.push()) {
- assertCurrentCtx(sctx1);
+ assertUnwrapAllCurrentCtx(sctx1);
assertThatThrownBy(sctx2::push).isInstanceOf(IllegalStateException.class);
}
- assertCurrentCtx(null);
+ assertUnwrapAllCurrentCtx(null);
}
@Test
@@ -194,9 +195,17 @@ void queryParams() {
assertThat(ctx.queryParams("Not exist")).isEmpty();
}
- private static void assertCurrentCtx(@Nullable RequestContext ctx) {
+ private static void assertUnwrapAllCurrentCtx(@Nullable RequestContext ctx) {
final RequestContext current = RequestContext.currentOrNull();
- assertThat(current).isSameAs(ctx);
+ if (current == null) {
+ assertThat(ctx).isNull();
+ } else {
+ assertThatUnwrapAll(current).isEqualTo(ctx);
+ }
+ }
+
+ private static ObjectAssert assertThatUnwrapAll(T actual) {
+ return assertThat(actual.unwrapAll());
}
private static ServiceRequestContext serviceRequestContext() {
diff --git a/it/trace-context-leak/build.gradle b/it/trace-context-leak/build.gradle
new file mode 100644
index 00000000000..d083ea187dc
--- /dev/null
+++ b/it/trace-context-leak/build.gradle
@@ -0,0 +1,8 @@
+task generateSources(type: Copy) {
+ from "${rootProject.projectDir}/core/src/test/java"
+ into "${project.ext.genSrcDir}/test/java"
+ include '**/ServiceRequestContextTest.java'
+ include '**/ClientRequestContextTest.java'
+}
+
+tasks.compileJava.dependsOn(generateSources)
diff --git a/it/trace-context-leak/src/test/java/com/linecorp/armeria/internal/common/EnableLeakDetectionFlagsProvider.java b/it/trace-context-leak/src/test/java/com/linecorp/armeria/internal/common/EnableLeakDetectionFlagsProvider.java
new file mode 100644
index 00000000000..720e46ffb20
--- /dev/null
+++ b/it/trace-context-leak/src/test/java/com/linecorp/armeria/internal/common/EnableLeakDetectionFlagsProvider.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2022 LINE Corporation
+ *
+ * LINE Corporation licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.linecorp.armeria.internal.common;
+
+import com.linecorp.armeria.common.FlagsProvider;
+import com.linecorp.armeria.common.RequestContext;
+import com.linecorp.armeria.common.util.Sampler;
+
+public final class EnableLeakDetectionFlagsProvider implements FlagsProvider {
+
+ @Override
+ public int priority() {
+ return 10;
+ }
+
+ @Override
+ public Sampler super RequestContext> requestContextLeakDetectionSampler() {
+ return Sampler.always();
+ }
+}
diff --git a/it/trace-context-leak/src/test/java/com/linecorp/armeria/internal/common/TraceRequestContextLeakTest.java b/it/trace-context-leak/src/test/java/com/linecorp/armeria/internal/common/TraceRequestContextLeakTest.java
new file mode 100644
index 00000000000..82fff976316
--- /dev/null
+++ b/it/trace-context-leak/src/test/java/com/linecorp/armeria/internal/common/TraceRequestContextLeakTest.java
@@ -0,0 +1,294 @@
+/*
+ * Copyright 2022 LINE Corporation
+ *
+ * LINE Corporation licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.linecorp.armeria.internal.common;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.awaitility.Awaitility.await;
+
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Executor;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import com.linecorp.armeria.client.ClientRequestContext;
+import com.linecorp.armeria.common.HttpMethod;
+import com.linecorp.armeria.common.HttpRequest;
+import com.linecorp.armeria.common.RequestContext;
+import com.linecorp.armeria.common.util.SafeCloseable;
+import com.linecorp.armeria.server.ServiceRequestContext;
+import com.linecorp.armeria.testing.junit5.common.EventLoopExtension;
+import com.linecorp.armeria.testing.junit5.common.EventLoopGroupExtension;
+
+import io.netty.channel.EventLoop;
+import io.netty.channel.EventLoopGroup;
+
+class TraceRequestContextLeakTest {
+
+ @RegisterExtension
+ static final EventLoopExtension eventLoopExtension = new EventLoopExtension();
+
+ @RegisterExtension
+ static final EventLoopGroupExtension eventLoopGroupExtension = new EventLoopGroupExtension(2);
+
+ @Test
+ void singleThreadContextNotLeak() throws InterruptedException {
+ final AtomicBoolean isThrown = new AtomicBoolean(false);
+ final CountDownLatch latch = new CountDownLatch(2);
+
+ final EventLoop executor = eventLoopExtension.get();
+
+ executor.execute(() -> {
+ final ServiceRequestContext ctx = newCtx("/1");
+ try (SafeCloseable ignore = ctx.push()) {
+ // Ignore
+ } catch (Exception ex) {
+ isThrown.set(true);
+ } finally {
+ latch.countDown();
+ }
+ });
+
+ executor.execute(() -> {
+ final ServiceRequestContext anotherCtx = newCtx("/2");
+ try (SafeCloseable ignore = anotherCtx.push()) {
+ final ClientRequestContext clientCtx = newClientCtx("/3");
+ try (SafeCloseable ignore2 = clientCtx.push()) {
+ // Ignore
+ }
+ } catch (Exception ex) {
+ isThrown.set(true);
+ } finally {
+ latch.countDown();
+ }
+ });
+
+ latch.await();
+ assertThat(isThrown).isFalse();
+ }
+
+ @Test
+ @SuppressWarnings("MustBeClosedChecker")
+ void singleThreadContextLeak() throws InterruptedException {
+ final AtomicBoolean isThrown = new AtomicBoolean();
+ final AtomicReference exception = new AtomicReference<>();
+
+ try (DeferredClose deferredClose = new DeferredClose()) {
+ final EventLoop executor = eventLoopExtension.get();
+
+ executor.execute(() -> {
+ final ServiceRequestContext ctx = newCtx("/1");
+ final SafeCloseable leaked = ctx.push(); // <- Leaked, should show in error.
+ deferredClose.add(executor, leaked);
+ });
+
+ executor.execute(() -> {
+ final ServiceRequestContext anotherCtx = newCtx("/2");
+ try (SafeCloseable ignore = anotherCtx.push()) {
+ // Ignore
+ } catch (Exception ex) {
+ exception.set(ex);
+ isThrown.set(true);
+ }
+ });
+
+ await().untilTrue(isThrown);
+ assertThat(exception.get())
+ .hasMessageContaining("singleThreadContextLeak$2(TraceRequestContextLeakTest.java:101)");
+ }
+ }
+
+ @Test
+ @SuppressWarnings("MustBeClosedChecker")
+ void multiThreadContextLeakNotInterfereOthersEventLoop() throws InterruptedException {
+ final AtomicBoolean isThrown = new AtomicBoolean(false);
+ final CountDownLatch latch = new CountDownLatch(2);
+
+ final EventLoopGroup executor = eventLoopGroupExtension.get();
+
+ final Executor ex1 = executor.next();
+ final Executor ex2 = executor.next();
+
+ try (DeferredClose deferredClose = new DeferredClose()) {
+ ex1.execute(() -> {
+ final ServiceRequestContext ctx = newCtx("/1");
+ final SafeCloseable leaked = ctx.push();
+ latch.countDown();
+ deferredClose.add(executor, leaked);
+ });
+
+ ex2.execute(() -> {
+ // Leak happened on the first eventLoop shouldn't affect 2nd eventLoop when trying to push
+ await().until(() -> latch.getCount() == 1);
+ final ServiceRequestContext anotherCtx = newCtx("/2");
+ try (SafeCloseable ignore1 = anotherCtx.push()) {
+ final ClientRequestContext cctx = newClientCtx("/3");
+ try (SafeCloseable ignore2 = cctx.push()) {
+ // Ignore
+ }
+ } catch (Exception ex) {
+ // Not suppose to throw exception on the second event loop
+ isThrown.set(true);
+ } finally {
+ latch.countDown();
+ }
+ });
+
+ latch.await();
+ assertThat(isThrown).isFalse();
+ }
+ }
+
+ @Test
+ @SuppressWarnings("MustBeClosedChecker")
+ void multiThreadContextLeak() throws InterruptedException {
+ final AtomicBoolean isThrown = new AtomicBoolean(false);
+ final AtomicReference exception = new AtomicReference<>();
+ final CountDownLatch waitForExecutor2 = new CountDownLatch(1);
+
+ final EventLoopGroup executor = eventLoopGroupExtension.get();
+
+ final ServiceRequestContext leakingCtx = newCtx("/1-leak");
+ final ServiceRequestContext anotherCtx2 = newCtx("/2-leak");
+ final ServiceRequestContext anotherCtx3 = newCtx("/3-leak");
+
+ final Executor ex1 = executor.next();
+ final Executor ex2 = executor.next();
+
+ try (DeferredClose deferredClose = new DeferredClose()) {
+ ex1.execute(() -> {
+ final SafeCloseable leaked = leakingCtx.push(); // <- Leaked, should show in error.
+ deferredClose.add(ex1, leaked);
+ });
+
+ ex2.execute(() -> {
+ try {
+ final SafeCloseable leaked = anotherCtx2.push();
+ deferredClose.add(ex2, leaked);
+ } catch (Exception ex) {
+ isThrown.set(true);
+ } finally {
+ waitForExecutor2.countDown();
+ }
+ });
+
+ waitForExecutor2.await();
+ assertThat(isThrown).isFalse();
+
+ ex1.execute(() -> {
+ try (SafeCloseable ignore = anotherCtx3.push()) {
+ // Ignore
+ } catch (Exception ex) {
+ exception.set(ex);
+ isThrown.set(true);
+ }
+ });
+
+ await().untilTrue(isThrown);
+ assertThat(exception.get())
+ .hasMessageContaining("multiThreadContextLeak$7(TraceRequestContextLeakTest.java:180)");
+ }
+ }
+
+ @Test
+ void pushIllegalServiceRequestContext() {
+ final ServiceRequestContext sctx1 = newCtx("/1");
+ final ServiceRequestContext sctx2 = newCtx("/2");
+ try (SafeCloseable ignored = sctx1.push()) {
+ assertThatThrownBy(sctx2::push).isInstanceOf(IllegalStateException.class)
+ .hasMessageContaining("pushed at the following stacktrace");
+ }
+ }
+
+ @Test
+ void multipleRequestContextPushBeforeLeak() {
+ final ServiceRequestContext sctx1 = newCtx("/1");
+ final ServiceRequestContext sctx2 = newCtx("/2");
+ try (SafeCloseable ignore1 = sctx1.push()) {
+ final ClientRequestContext cctx1 = newClientCtx("/3");
+ try (SafeCloseable ignore3 = cctx1.push()) {
+ assertThatThrownBy(sctx2::push).isInstanceOf(IllegalStateException.class)
+ .hasMessageContaining("pushed at the following stacktrace");
+ }
+ }
+ }
+
+ @Test
+ @SuppressWarnings("MustBeClosedChecker")
+ void connerCase() {
+ final AtomicReference exception = new AtomicReference<>();
+
+ try (DeferredClose deferredClose = new DeferredClose()) {
+ final ServiceRequestContext ctx = newCtx("/1");
+ try (SafeCloseable ignored = ctx.push()) {
+ final ClientRequestContext ctx2 = newClientCtx("/2");
+ ctx2.push(); // <- Leaked, should show in error.
+ deferredClose.add(ctx2);
+ final ClientRequestContext ctx3 = newClientCtx("/3");
+ try (SafeCloseable ignored1 = ctx3.push()) {
+ // Ignore
+ }
+ } catch (Exception ex) {
+ exception.set(ex);
+ }
+ }
+ assertThat(exception.get())
+ .hasMessageContaining("connerCase(TraceRequestContextLeakTest.java:245)");
+ }
+
+ private static ServiceRequestContext newCtx(String path) {
+ return ServiceRequestContext.builder(HttpRequest.of(HttpMethod.GET, path))
+ .build();
+ }
+
+ private static ClientRequestContext newClientCtx(String path) {
+ return ClientRequestContext.builder(HttpRequest.of(HttpMethod.GET, path))
+ .build();
+ }
+
+ // Utility to clean up RequestContext leak after test
+ private static class DeferredClose implements SafeCloseable {
+
+ private final ConcurrentHashMap toClose;
+ private final Set toRemoveFromThreadLocal;
+
+ DeferredClose() {
+ toClose = new ConcurrentHashMap<>();
+ toRemoveFromThreadLocal = new HashSet<>();
+ }
+
+ void add(Executor executor, SafeCloseable closeable) {
+ toClose.put(executor, closeable);
+ }
+
+ void add(RequestContext requestContext) {
+ toRemoveFromThreadLocal.add(requestContext);
+ }
+
+ @Override
+ public void close() {
+ toClose.forEach((executor, closeable) -> executor.execute(closeable::close));
+ toRemoveFromThreadLocal.forEach(ctx -> RequestContextUtil.pop(ctx, null));
+ }
+ }
+}
diff --git a/it/trace-context-leak/src/test/resources/META-INF/services/com.linecorp.armeria.common.FlagsProvider b/it/trace-context-leak/src/test/resources/META-INF/services/com.linecorp.armeria.common.FlagsProvider
new file mode 100644
index 00000000000..0c5a2c315a9
--- /dev/null
+++ b/it/trace-context-leak/src/test/resources/META-INF/services/com.linecorp.armeria.common.FlagsProvider
@@ -0,0 +1 @@
+com.linecorp.armeria.internal.common.EnableLeakDetectionFlagsProvider
diff --git a/settings.gradle b/settings.gradle
index 71d083fba1f..7ac52e789d3 100644
--- a/settings.gradle
+++ b/settings.gradle
@@ -100,6 +100,7 @@ includeWithFlags ':it:spring:boot2-tomcat8', 'java', 'relocate'
includeWithFlags ':it:spring:boot2-tomcat9', 'java', 'relocate'
includeWithFlags ':it:spring:webflux-security', 'java', 'relocate'
includeWithFlags ':it:thrift-fullcamel', 'java', 'relocate'
+includeWithFlags ':it:trace-context-leak', 'java', 'relocate'
includeWithFlags ':jetty9.3', 'java', 'relocate'
includeWithFlags ':testing-internal', 'java', 'relocate'
includeWithFlags ':thrift0.12', 'java', 'relocate'