diff --git a/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/internal/common/LeakTracingRequestContextStorageBenchmark.java b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/internal/common/LeakTracingRequestContextStorageBenchmark.java new file mode 100644 index 00000000000..511aea3d336 --- /dev/null +++ b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/internal/common/LeakTracingRequestContextStorageBenchmark.java @@ -0,0 +1,95 @@ +/* + * Copyright 2022 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.internal.common; + +import org.openjdk.jmh.annotations.Benchmark; + +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.RequestContext; +import com.linecorp.armeria.common.RequestContextStorage; +import com.linecorp.armeria.common.util.Sampler; +import com.linecorp.armeria.server.ServiceRequestContext; + +/** + * Microbenchmarks for LeakTracingRequestContextStorage. + */ +public class LeakTracingRequestContextStorageBenchmark { + + private static final RequestContextStorage threadLocalReqCtxStorage = + RequestContextStorage.threadLocal(); + private static final RequestContextStorage neverSample = + new LeakTracingRequestContextStorage(threadLocalReqCtxStorage, Sampler.never()); + private static final RequestContextStorage rateLimited1 = + new LeakTracingRequestContextStorage(threadLocalReqCtxStorage, Sampler.of("rate-limited=1")); + private static final RequestContextStorage rateLimited10 = + new LeakTracingRequestContextStorage(threadLocalReqCtxStorage, Sampler.of("rate-limited=10")); + private static final RequestContextStorage random1 = + new LeakTracingRequestContextStorage(threadLocalReqCtxStorage, Sampler.of("random=0.01")); + private static final RequestContextStorage random10 = + new LeakTracingRequestContextStorage(threadLocalReqCtxStorage, Sampler.of("random=0.10")); + private static final RequestContextStorage alwaysSample = + new LeakTracingRequestContextStorage(threadLocalReqCtxStorage, Sampler.always()); + private static final RequestContext reqCtx = newCtx("/"); + + private static ServiceRequestContext newCtx(String path) { + return ServiceRequestContext.builder(HttpRequest.of(HttpMethod.GET, path)) + .build(); + } + + @Benchmark + public void baseline_threadLocal() { + final RequestContext oldCtx = threadLocalReqCtxStorage.push(reqCtx); + threadLocalReqCtxStorage.pop(reqCtx, oldCtx); + } + + @Benchmark + public void leakTracing_never_sample() { + final RequestContext oldCtx = neverSample.push(reqCtx); + neverSample.pop(reqCtx, oldCtx); + } + + @Benchmark + public void leakTracing_rateLimited_1() { + final RequestContext oldCtx = rateLimited1.push(reqCtx); + rateLimited1.pop(reqCtx, oldCtx); + } + + @Benchmark + public void leakTracing_rateLimited_10() { + final RequestContext oldCtx = rateLimited10.push(reqCtx); + rateLimited10.pop(reqCtx, oldCtx); + } + + @Benchmark + public void leakTracing_random_1() { + final RequestContext oldCtx = random1.push(reqCtx); + random1.pop(reqCtx, oldCtx); + } + + @Benchmark + public void leakTracing_random_10() { + final RequestContext oldCtx = random10.push(reqCtx); + random10.pop(reqCtx, oldCtx); + } + + @Benchmark + public void leakTracing_always_sample() { + final RequestContext oldCtx = alwaysSample.push(reqCtx); + alwaysSample.pop(reqCtx, oldCtx); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java index 9e26f63a0c3..af277a0f9e8 100644 --- a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java @@ -216,17 +216,17 @@ static ClientRequestContextBuilder builder(RpcRequest request, URI uri) { @MustBeClosed default SafeCloseable push() { final RequestContext oldCtx = RequestContextUtil.getAndSet(this); - if (oldCtx == this) { - // Reentrance - return noopSafeCloseable(); - } - if (oldCtx == null) { return RequestContextUtil.invokeHookAndPop(this, null); } + if (oldCtx.unwrapAll() == unwrapAll()) { + // Reentrance + return noopSafeCloseable(); + } + final ServiceRequestContext root = root(); - if (oldCtx.root() == root) { + if (RequestContextUtil.equalsIgnoreWrapper(oldCtx.root(), root)) { return RequestContextUtil.invokeHookAndPop(this, oldCtx); } diff --git a/core/src/main/java/com/linecorp/armeria/common/DefaultFlagsProvider.java b/core/src/main/java/com/linecorp/armeria/common/DefaultFlagsProvider.java index a190ab84abb..d7a01a7069f 100644 --- a/core/src/main/java/com/linecorp/armeria/common/DefaultFlagsProvider.java +++ b/core/src/main/java/com/linecorp/armeria/common/DefaultFlagsProvider.java @@ -398,4 +398,9 @@ public Path defaultMultipartUploadsLocation() { File.separatorChar + "armeria" + File.separatorChar + "multipart-uploads"); } + + @Override + public Sampler requestContextLeakDetectionSampler() { + return Sampler.never(); + } } diff --git a/core/src/main/java/com/linecorp/armeria/common/Flags.java b/core/src/main/java/com/linecorp/armeria/common/Flags.java index eb470d54c04..c994cf3dbb7 100644 --- a/core/src/main/java/com/linecorp/armeria/common/Flags.java +++ b/core/src/main/java/com/linecorp/armeria/common/Flags.java @@ -49,6 +49,7 @@ import com.linecorp.armeria.client.retry.RetryingClient; import com.linecorp.armeria.client.retry.RetryingRpcClient; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.common.util.Exceptions; import com.linecorp.armeria.common.util.Sampler; import com.linecorp.armeria.common.util.SystemInfo; @@ -113,10 +114,6 @@ public final class Flags { static { final String strSpec = getNormalized("verboseExceptions", DefaultFlagsProvider.VERBOSE_EXCEPTION_SAMPLER_SPEC, val -> { - if ("true".equals(val) || "false".equals(val)) { - return true; - } - try { Sampler.of(val); return true; @@ -377,6 +374,9 @@ public final class Flags { private static final Path DEFAULT_MULTIPART_UPLOADS_LOCATION = getValue(FlagsProvider::defaultMultipartUploadsLocation, "defaultMultipartUploadsLocation"); + private static final Sampler REQUEST_CONTEXT_LEAK_DETECTION_SAMPLER = + getValue(FlagsProvider::requestContextLeakDetectionSampler, "requestContextLeakDetectionSampler"); + /** * Returns the specification of the {@link Sampler} that determines whether to retain the stack * trace of the exceptions that are thrown frequently by Armeria. A sampled exception will have the stack @@ -1296,6 +1296,21 @@ public static boolean allowDoubleDotsInQueryString() { return ALLOW_DOUBLE_DOTS_IN_QUERY_STRING; } + /** + * Returns the {@link Sampler} that determines whether to trace the stack trace of request contexts leaks + * and how frequently to keeps stack trace. A sampled exception will have the stack trace while the others + * will have an empty stack trace to eliminate the cost of capturing the stack trace. + * + *

The default value of this flag is {@link Sampler#never()}. + * Specify the {@code -Dcom.linecorp.armeria.requestContextLeakDetectionSampler=} JVM option + * to override the default. This feature is disabled if {@link Sampler#never()} is specified. + * See {@link Sampler#of(String)} for the specification string format.

+ */ + @UnstableApi + public static Sampler requestContextLeakDetectionSampler() { + return REQUEST_CONTEXT_LEAK_DETECTION_SAMPLER; + } + @Nullable private static String nullableCaffeineSpec(Function method, String flagName) { return caffeineSpec(method, flagName, true); diff --git a/core/src/main/java/com/linecorp/armeria/common/FlagsProvider.java b/core/src/main/java/com/linecorp/armeria/common/FlagsProvider.java index ca6f44098e7..eb81e0704d2 100644 --- a/core/src/main/java/com/linecorp/armeria/common/FlagsProvider.java +++ b/core/src/main/java/com/linecorp/armeria/common/FlagsProvider.java @@ -960,4 +960,20 @@ default Boolean allowDoubleDotsInQueryString() { default Path defaultMultipartUploadsLocation() { return null; } + + /** + * Returns the {@link Sampler} that determines whether to trace the stack trace of request contexts leaks + * and how frequently to keeps stack trace. A sampled exception will have the stack trace while the others + * will have an empty stack trace to eliminate the cost of capturing the stack trace. + * + *

The default value of this flag is {@link Sampler#never()}. + * Specify the {@code -Dcom.linecorp.armeria.requestContextLeakDetectionSampler=} JVM option + * to override the default. This feature is disabled if {@link Sampler#never()} is specified. + * See {@link Sampler#of(String)} for the specification string format.

+ */ + @UnstableApi + @Nullable + default Sampler requestContextLeakDetectionSampler() { + return null; + } } diff --git a/core/src/main/java/com/linecorp/armeria/common/SystemPropertyFlagsProvider.java b/core/src/main/java/com/linecorp/armeria/common/SystemPropertyFlagsProvider.java index bffd83d138e..a05f58d757e 100644 --- a/core/src/main/java/com/linecorp/armeria/common/SystemPropertyFlagsProvider.java +++ b/core/src/main/java/com/linecorp/armeria/common/SystemPropertyFlagsProvider.java @@ -71,20 +71,12 @@ public Sampler> verboseExceptionSampler() { if (spec == null) { return null; } - if ("true".equals(spec) || "always".equals(spec)) { - return Sampler.always(); - } - if ("false".equals(spec) || "never".equals(spec)) { - return Sampler.never(); - } - try { Sampler.of(spec); } catch (Exception e) { // Invalid sampler specification throw new IllegalArgumentException("invalid sampler spec: " + spec, e); } - return new ExceptionSampler(spec); } @@ -438,6 +430,20 @@ public Path defaultMultipartUploadsLocation() { return getAndParse("defaultMultipartUploadsLocation", Paths::get); } + @Override + public Sampler requestContextLeakDetectionSampler() { + final String spec = getNormalized("requestContextLeakDetectionSampler"); + if (spec == null) { + return null; + } + try { + return Sampler.of(spec); + } catch (Exception e) { + // Invalid sampler specification + throw new IllegalArgumentException("invalid sampler spec: " + spec, e); + } + } + @Nullable private static Long getLong(String name) { return getAndParse(name, Long::parseLong); diff --git a/core/src/main/java/com/linecorp/armeria/common/ThreadLocalRequestContextStorage.java b/core/src/main/java/com/linecorp/armeria/common/ThreadLocalRequestContextStorage.java index 7fe5e0a59aa..cc110929c89 100644 --- a/core/src/main/java/com/linecorp/armeria/common/ThreadLocalRequestContextStorage.java +++ b/core/src/main/java/com/linecorp/armeria/common/ThreadLocalRequestContextStorage.java @@ -46,7 +46,7 @@ public void pop(RequestContext current, @Nullable RequestContext toRestore) { requireNonNull(current, "current"); final InternalThreadLocalMap map = InternalThreadLocalMap.get(); final RequestContext contextInThreadLocal = context.get(map); - if (current != contextInThreadLocal) { + if (contextInThreadLocal == null || current.unwrapAll() != contextInThreadLocal.unwrapAll()) { throw newIllegalContextPoppingException(current, contextInThreadLocal); } context.set(map, toRestore); diff --git a/core/src/main/java/com/linecorp/armeria/common/util/Samplers.java b/core/src/main/java/com/linecorp/armeria/common/util/Samplers.java index 00436d55f49..1d8463d5426 100644 --- a/core/src/main/java/com/linecorp/armeria/common/util/Samplers.java +++ b/core/src/main/java/com/linecorp/armeria/common/util/Samplers.java @@ -70,8 +70,10 @@ static Sampler of(String specification) { requireNonNull(specification, "specification"); switch (specification.trim()) { case "always": + case "true": return Sampler.always(); case "never": + case "false": return Sampler.never(); } diff --git a/core/src/main/java/com/linecorp/armeria/common/util/Unwrappable.java b/core/src/main/java/com/linecorp/armeria/common/util/Unwrappable.java index 7244970ef3d..250ea693e22 100644 --- a/core/src/main/java/com/linecorp/armeria/common/util/Unwrappable.java +++ b/core/src/main/java/com/linecorp/armeria/common/util/Unwrappable.java @@ -129,4 +129,20 @@ default Object unwrapAll() { unwrapped = inner; } } + + /** + * Reference checking this {@link Unwrappable} to another {@link Unwrappable}, ignoring wrappers. + * Two {@link Unwrappable} are considered equal ignoring wrappers if they are of the same object reference + * after {@link #unwrapAll()}. + * @param other The {@link Unwrappable} to compare this {@link Unwrappable} against + * @return true if the argument is not {@code null}, and it represents a same object reference after + * {@code unwrapAll()}, false otherwise. + */ + @UnstableApi + default boolean equalsIgnoreWrapper(@Nullable Unwrappable other) { + if (other == null) { + return false; + } + return unwrapAll() == other.unwrapAll(); + } } diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/LeakTracingRequestContextStorage.java b/core/src/main/java/com/linecorp/armeria/internal/common/LeakTracingRequestContextStorage.java new file mode 100644 index 00000000000..752b8c896da --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/internal/common/LeakTracingRequestContextStorage.java @@ -0,0 +1,126 @@ +/* + * Copyright 2022 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.internal.common; + +import static java.lang.Thread.currentThread; +import static java.util.Objects.requireNonNull; + +import com.linecorp.armeria.client.ClientRequestContext; +import com.linecorp.armeria.client.ClientRequestContextWrapper; +import com.linecorp.armeria.common.RequestContext; +import com.linecorp.armeria.common.RequestContextStorage; +import com.linecorp.armeria.common.RequestContextWrapper; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.util.Sampler; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.server.ServiceRequestContextWrapper; + +/** + * A {@link RequestContextStorage} which keeps track of {@link RequestContext}s, reporting pushed thread + * information if a {@link RequestContext} is leaked. + */ +final class LeakTracingRequestContextStorage implements RequestContextStorage { + + private final RequestContextStorage delegate; + private final Sampler sampler; + + /** + * Creates a new instance. + * @param delegate the underlying {@link RequestContextStorage} that stores {@link RequestContext} + * @param sampler the {@link Sampler} that determines whether to retain the stacktrace of the context leaks + */ + LeakTracingRequestContextStorage(RequestContextStorage delegate, + Sampler sampler) { + this.delegate = requireNonNull(delegate, "delegate"); + this.sampler = requireNonNull(sampler, "sampler"); + } + + @Nullable + @Override + public T push(RequestContext toPush) { + requireNonNull(toPush, "toPush"); + if (sampler.isSampled(toPush)) { + return delegate.push(wrapRequestContext(toPush)); + } + return delegate.push(toPush); + } + + @Override + public void pop(RequestContext current, @Nullable RequestContext toRestore) { + requireNonNull(current, "current"); + delegate.pop(current, toRestore); + } + + @Nullable + @Override + public T currentOrNull() { + return delegate.currentOrNull(); + } + + @Override + public RequestContextStorage unwrap() { + return delegate; + } + + private static RequestContextWrapper wrapRequestContext(RequestContext ctx) { + return ctx instanceof ClientRequestContext ? + new TraceableClientRequestContext((ClientRequestContext) ctx) + : new TraceableServiceRequestContext((ServiceRequestContext) ctx); + } + + private static String stacktraceToString(StackTraceElement[] stackTrace, + RequestContext unwrap) { + final StringBuilder builder = new StringBuilder(512); + builder.append(unwrap).append(System.lineSeparator()) + .append("The previous RequestContext is pushed at the following stacktrace") + .append(System.lineSeparator()); + for (int i = 1; i < stackTrace.length; i++) { + builder.append("\tat ").append(stackTrace[i]).append(System.lineSeparator()); + } + return builder.toString(); + } + + private static final class TraceableClientRequestContext extends ClientRequestContextWrapper { + + private final StackTraceElement[] stackTrace; + + private TraceableClientRequestContext(ClientRequestContext delegate) { + super(delegate); + stackTrace = currentThread().getStackTrace(); + } + + @Override + public String toString() { + return stacktraceToString(stackTrace, unwrap()); + } + } + + private static final class TraceableServiceRequestContext extends ServiceRequestContextWrapper { + + private final StackTraceElement[] stackTrace; + + private TraceableServiceRequestContext(ServiceRequestContext delegate) { + super(delegate); + stackTrace = currentThread().getStackTrace(); + } + + @Override + public String toString() { + return stacktraceToString(stackTrace, unwrap()); + } + } +} diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/RequestContextUtil.java b/core/src/main/java/com/linecorp/armeria/internal/common/RequestContextUtil.java index d74893c4585..bd6dc307722 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/RequestContextUtil.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/RequestContextUtil.java @@ -37,6 +37,7 @@ import com.linecorp.armeria.common.RequestContextStorageProvider; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.util.SafeCloseable; +import com.linecorp.armeria.common.util.Sampler; import com.linecorp.armeria.internal.client.DefaultClientRequestContext; import com.linecorp.armeria.internal.server.DefaultServiceRequestContext; @@ -63,8 +64,13 @@ public final class RequestContextUtil { static { final RequestContextStorageProvider provider = Flags.requestContextStorageProvider(); + final Sampler sampler = Flags.requestContextLeakDetectionSampler(); try { - requestContextStorage = provider.newStorage(); + if (sampler == Sampler.never()) { + requestContextStorage = provider.newStorage(); + } else { + requestContextStorage = new LeakTracingRequestContextStorage(provider.newStorage(), sampler); + } } catch (Throwable t) { throw new IllegalStateException("Failed to create context storage. provider: " + provider, t); } @@ -195,6 +201,13 @@ public static SafeCloseable invokeHookAndPop(RequestContext current, @Nullable R } } + public static boolean equalsIgnoreWrapper(@Nullable RequestContext ctx1, @Nullable RequestContext ctx2) { + if (ctx1 == null) { + return ctx2 == null; + } + return ctx1.equalsIgnoreWrapper(ctx2); + } + @Nullable private static AutoCloseable invokeHook(RequestContext ctx) { final Supplier hook; diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContext.java b/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContext.java index 7a18cda9e0b..ec1b3e8b43d 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContext.java @@ -219,16 +219,16 @@ default ServiceRequestContext root() { @MustBeClosed default SafeCloseable push() { final RequestContext oldCtx = RequestContextUtil.getAndSet(this); - if (oldCtx == this) { - // Reentrance - return noopSafeCloseable(); - } - if (oldCtx == null) { return RequestContextUtil.invokeHookAndPop(this, null); } - if (oldCtx.root() == this) { + if (oldCtx.unwrapAll() == unwrapAll()) { + // Reentrance + return noopSafeCloseable(); + } + + if (RequestContextUtil.equalsIgnoreWrapper(oldCtx.root(), this)) { return RequestContextUtil.invokeHookAndPop(this, oldCtx); } diff --git a/core/src/test/java/com/linecorp/armeria/client/ClientRequestContextTest.java b/core/src/test/java/com/linecorp/armeria/client/ClientRequestContextTest.java index b1193e12c68..c89477c4e4e 100644 --- a/core/src/test/java/com/linecorp/armeria/client/ClientRequestContextTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/ClientRequestContextTest.java @@ -41,9 +41,9 @@ void current() { final ClientRequestContext ctx = clientRequestContext(); assertThat(ctx.id()).isNotNull(); try (SafeCloseable unused = ctx.push()) { - assertThat(ClientRequestContext.current()).isSameAs(ctx); + assertThat(ClientRequestContext.current().unwrapAll()).isSameAs(ctx); } - assertCurrentCtx(null); + assertUnwrapAllCurrentCtx(null); try (SafeCloseable unused = serviceRequestContext().push()) { assertThatThrownBy(ClientRequestContext::current) @@ -58,9 +58,9 @@ void currentOrNull() { final ClientRequestContext ctx = clientRequestContext(); try (SafeCloseable unused = ctx.push()) { - assertThat(ClientRequestContext.currentOrNull()).isSameAs(ctx); + assertThat(ClientRequestContext.currentOrNull().unwrapAll()).isSameAs(ctx); } - assertCurrentCtx(null); + assertUnwrapAllCurrentCtx(null); try (SafeCloseable unused = serviceRequestContext().push()) { assertThat(ClientRequestContext.currentOrNull()).isNull(); @@ -75,9 +75,10 @@ void mapCurrent() { final ClientRequestContext ctx = clientRequestContext(); try (SafeCloseable unused = ctx.push()) { assertThat(ClientRequestContext.mapCurrent(c -> "foo", () -> "bar")).isEqualTo("foo"); - assertThat(ClientRequestContext.mapCurrent(Function.identity(), null)).isSameAs(ctx); + assertThat(ClientRequestContext.mapCurrent(Function.identity(), null).unwrapAll()) + .isSameAs(ctx); } - assertCurrentCtx(null); + assertUnwrapAllCurrentCtx(null); try (SafeCloseable unused = serviceRequestContext().push()) { assertThatThrownBy(() -> ClientRequestContext.mapCurrent(c -> "foo", () -> "bar")) @@ -90,32 +91,32 @@ void mapCurrent() { void pushReentrance() { final ClientRequestContext ctx = clientRequestContext(); try (SafeCloseable ignored = ctx.push()) { - assertCurrentCtx(ctx); + assertUnwrapAllCurrentCtx(ctx); try (SafeCloseable ignored2 = ctx.push()) { - assertCurrentCtx(ctx); + assertUnwrapAllCurrentCtx(ctx); } - assertCurrentCtx(ctx); + assertUnwrapAllCurrentCtx(ctx); } - assertCurrentCtx(null); + assertUnwrapAllCurrentCtx(null); } @Test void pushWithOldServiceCtx() { final ServiceRequestContext sctx = serviceRequestContext(); try (SafeCloseable ignored = sctx.push()) { - assertCurrentCtx(sctx); + assertUnwrapAllCurrentCtx(sctx); // The root of ClientRequestContext is sctx. final ClientRequestContext cctx = clientRequestContext(); try (SafeCloseable ignored1 = cctx.push()) { - assertCurrentCtx(cctx); + assertUnwrapAllCurrentCtx(cctx); try (SafeCloseable ignored2 = sctx.push()) { - assertCurrentCtx(sctx); + assertUnwrapAllCurrentCtx(sctx); } - assertCurrentCtx(cctx); + assertUnwrapAllCurrentCtx(cctx); } - assertCurrentCtx(sctx); + assertUnwrapAllCurrentCtx(sctx); } - assertCurrentCtx(null); + assertUnwrapAllCurrentCtx(null); } @Test @@ -135,65 +136,65 @@ void pushWithOldServiceCtx_exceptionWhenServiceCtxIsDifferFromRoot() { void pushWithOldClientCtxWhoseRootIsSameServiceCtx_ctx2IsCreatedSameLayer() { final ServiceRequestContext sctx = serviceRequestContext(); try (SafeCloseable ignored = sctx.push()) { - assertCurrentCtx(sctx); + assertUnwrapAllCurrentCtx(sctx); final ClientRequestContext cctx1 = clientRequestContext(); final ClientRequestContext cctx2 = clientRequestContext(); assertThat(cctx1.root()).isSameAs(cctx2.root()); try (SafeCloseable ignored1 = cctx1.push()) { - assertCurrentCtx(cctx1); + assertUnwrapAllCurrentCtx(cctx1); try (SafeCloseable ignored2 = cctx2.push()) { - assertCurrentCtx(cctx2); + assertUnwrapAllCurrentCtx(cctx2); } - assertCurrentCtx(cctx1); + assertUnwrapAllCurrentCtx(cctx1); } - assertCurrentCtx(sctx); + assertUnwrapAllCurrentCtx(sctx); } - assertCurrentCtx(null); + assertUnwrapAllCurrentCtx(null); } @Test - void pushWithOldClientCtxWhoseRootIsSameServiceCtx__ctx2IsCreatedUnderCtx1() { + void pushWithOldClientCtxWhoseRootIsSameServiceCtx_ctx2IsCreatedUnderCtx1() { final ServiceRequestContext sctx = serviceRequestContext(); try (SafeCloseable ignored = sctx.push()) { - assertCurrentCtx(sctx); + assertUnwrapAllCurrentCtx(sctx); final ClientRequestContext cctx1 = clientRequestContext(); try (SafeCloseable ignored1 = cctx1.push()) { - assertCurrentCtx(cctx1); + assertUnwrapAllCurrentCtx(cctx1); final ClientRequestContext cctx2 = clientRequestContext(); assertThat(cctx1.root()).isSameAs(cctx2.root()); try (SafeCloseable ignored2 = cctx2.push()) { - assertCurrentCtx(cctx2); + assertUnwrapAllCurrentCtx(cctx2); } - assertCurrentCtx(cctx1); + assertUnwrapAllCurrentCtx(cctx1); } - assertCurrentCtx(sctx); + assertUnwrapAllCurrentCtx(sctx); } - assertCurrentCtx(null); + assertUnwrapAllCurrentCtx(null); } @Test void pushWithOldClientCtxWhoseRootIsSameServiceCtx_derivedCtx() { final ServiceRequestContext sctx = serviceRequestContext(); try (SafeCloseable ignored = sctx.push()) { - assertCurrentCtx(sctx); + assertUnwrapAllCurrentCtx(sctx); final ClientRequestContext cctx1 = clientRequestContext(); final ClientRequestContext derived = cctx1.newDerivedContext(cctx1.id(), cctx1.request(), cctx1.rpcRequest(), cctx1.endpoint()); try (SafeCloseable ignored1 = derived.push()) { - assertCurrentCtx(derived); + assertUnwrapAllCurrentCtx(derived); final ClientRequestContext cctx2 = clientRequestContext(); assertThat(derived.root()).isSameAs(cctx2.root()); try (SafeCloseable ignored2 = cctx2.push()) { - assertCurrentCtx(cctx2); + assertUnwrapAllCurrentCtx(cctx2); } - assertCurrentCtx(derived); + assertUnwrapAllCurrentCtx(derived); } - assertCurrentCtx(sctx); + assertUnwrapAllCurrentCtx(sctx); } - assertCurrentCtx(null); + assertUnwrapAllCurrentCtx(null); } @Test @@ -217,16 +218,16 @@ void pushWithOldClientCtxWhoseRootIsDifferent() { void pushWithOldClientCtxWhoseRootIsNull() { final ClientRequestContext cctx1 = clientRequestContext(); try (SafeCloseable ignored1 = cctx1.push()) { - assertCurrentCtx(cctx1); + assertUnwrapAllCurrentCtx(cctx1); final ClientRequestContext cctx2 = clientRequestContext(); assertThat(cctx1.root()).isNull(); assertThat(cctx2.root()).isNull(); try (SafeCloseable ignored2 = cctx2.push()) { - assertCurrentCtx(cctx2); + assertUnwrapAllCurrentCtx(cctx2); } - assertCurrentCtx(cctx1); + assertUnwrapAllCurrentCtx(cctx1); } - assertCurrentCtx(null); + assertUnwrapAllCurrentCtx(null); } @Test @@ -259,9 +260,13 @@ void hasOwnAttr() { } } - private static void assertCurrentCtx(@Nullable RequestContext ctx) { + private static void assertUnwrapAllCurrentCtx(@Nullable RequestContext ctx) { final RequestContext current = RequestContext.currentOrNull(); - assertThat(current).isSameAs(ctx); + if (current == null) { + assertThat(ctx).isNull(); + } else { + assertThat(current.unwrapAll()).isSameAs(ctx); + } } private static ServiceRequestContext serviceRequestContext() { diff --git a/core/src/test/java/com/linecorp/armeria/common/FlagsTest.java b/core/src/test/java/com/linecorp/armeria/common/FlagsTest.java index 7d130324a66..f7106d8267f 100644 --- a/core/src/test/java/com/linecorp/armeria/common/FlagsTest.java +++ b/core/src/test/java/com/linecorp/armeria/common/FlagsTest.java @@ -210,7 +210,7 @@ void systemPropertyVerboseExceptionSampler() throws Throwable { final Method method = flags.getDeclaredMethod("verboseExceptionSampler"); assertThat(method.invoke(flags)) .usingRecursiveComparison() - .isEqualTo(Sampler.always()); + .isEqualTo(new ExceptionSampler("true")); } @Test @@ -222,6 +222,32 @@ void invalidSystemPropertyVerboseExceptionSampler() throws Throwable { .isEqualTo(new ExceptionSampler("rate-limit=10")); } + @Test + void defaultRequestContextLeakDetectionSampler() throws Exception { + final Method method = flags.getDeclaredMethod("requestContextLeakDetectionSampler"); + assertThat(method.invoke(flags)) + .usingRecursiveComparison() + .isEqualTo(Sampler.never()); + } + + @Test + @SetSystemProperty(key = "com.linecorp.armeria.requestContextLeakDetectionSampler", value = "always") + void systemPropertyRequestContextLeakDetectionSampler() throws Exception { + final Method method = flags.getDeclaredMethod("requestContextLeakDetectionSampler"); + assertThat(method.invoke(flags)) + .usingRecursiveComparison() + .isEqualTo(Sampler.always()); + } + + @Test + @SetSystemProperty(key = "com.linecorp.armeria.requestContextLeakDetectionSampler", value = "invalid-spec") + void invalidSystemPropertyRequestContextLeakDetectionSampler() throws Exception { + final Method method = flags.getDeclaredMethod("requestContextLeakDetectionSampler"); + assertThat(method.invoke(flags)) + .usingRecursiveComparison() + .isEqualTo(Sampler.never()); + } + @Test void testApiConsistencyBetweenFlagsAndFlagsProvider() { //Check method consistency between Flags and FlagsProvider excluding deprecated methods diff --git a/core/src/test/java/com/linecorp/armeria/common/util/SamplerTest.java b/core/src/test/java/com/linecorp/armeria/common/util/SamplerTest.java index 92183c934ed..3e3c5302508 100644 --- a/core/src/test/java/com/linecorp/armeria/common/util/SamplerTest.java +++ b/core/src/test/java/com/linecorp/armeria/common/util/SamplerTest.java @@ -31,10 +31,14 @@ void goodOf() { // 'always' assertThat(Sampler.of("always")).isSameAs(Sampler.always()); assertThat(Sampler.of(" always ")).isSameAs(Sampler.always()); + assertThat(Sampler.of("true")).isSameAs(Sampler.always()); + assertThat(Sampler.of(" true ")).isSameAs(Sampler.always()); // 'never' assertThat(Sampler.of("never")).isSameAs(Sampler.never()); assertThat(Sampler.of(" never ")).isSameAs(Sampler.never()); + assertThat(Sampler.of("false")).isSameAs(Sampler.never()); + assertThat(Sampler.of(" false ")).isSameAs(Sampler.never()); // 'random=' assertThat(Sampler.of("random=0")).isSameAs(Sampler.never()); diff --git a/core/src/test/java/com/linecorp/armeria/server/ServiceRequestContextTest.java b/core/src/test/java/com/linecorp/armeria/server/ServiceRequestContextTest.java index b4c157cc9ab..927aecee1fd 100644 --- a/core/src/test/java/com/linecorp/armeria/server/ServiceRequestContextTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/ServiceRequestContextTest.java @@ -20,6 +20,7 @@ import java.util.function.Function; +import org.assertj.core.api.ObjectAssert; import org.junit.jupiter.api.Test; import com.google.common.collect.ImmutableList; @@ -44,13 +45,13 @@ void current() { assertThat(ServiceRequestContext.current()).isSameAs(sctx); final ClientRequestContext cctx = clientRequestContext(); try (SafeCloseable unused1 = cctx.push()) { - assertThat(ServiceRequestContext.current()).isSameAs(sctx); - assertThat(ClientRequestContext.current()).isSameAs(cctx); - assertThat((ClientRequestContext) RequestContext.current()).isSameAs(cctx); + assertThatUnwrapAll(ServiceRequestContext.current()).isSameAs(sctx); + assertThatUnwrapAll(ClientRequestContext.current()).isSameAs(cctx); + assertThatUnwrapAll((ClientRequestContext) RequestContext.current()).isSameAs(cctx); } - assertCurrentCtx(sctx); + assertUnwrapAllCurrentCtx(sctx); } - assertCurrentCtx(null); + assertUnwrapAllCurrentCtx(null); try (SafeCloseable unused = clientRequestContext().push()) { assertThatThrownBy(ServiceRequestContext::current) @@ -68,13 +69,13 @@ void currentOrNull() { assertThat(ServiceRequestContext.currentOrNull()).isSameAs(sctx); final ClientRequestContext cctx = clientRequestContext(); try (SafeCloseable unused1 = cctx.push()) { - assertThat(ServiceRequestContext.currentOrNull()).isSameAs(sctx); - assertThat(ClientRequestContext.current()).isSameAs(cctx); - assertThat((ClientRequestContext) RequestContext.current()).isSameAs(cctx); + assertThatUnwrapAll(ServiceRequestContext.currentOrNull()).isSameAs(sctx); + assertThatUnwrapAll(ClientRequestContext.current()).isSameAs(cctx); + assertThatUnwrapAll((ClientRequestContext) RequestContext.current()).isSameAs(cctx); } - assertCurrentCtx(sctx); + assertUnwrapAllCurrentCtx(sctx); } - assertCurrentCtx(null); + assertUnwrapAllCurrentCtx(null); try (SafeCloseable unused = clientRequestContext().push()) { assertThat(ServiceRequestContext.currentOrNull()).isNull(); @@ -98,16 +99,16 @@ void mapCurrent() { assertThat(ServiceRequestContext.mapCurrent(c -> c == sctx ? "foo" : "bar", () -> "defaultValue")) .isEqualTo("foo"); - assertThat(ClientRequestContext.mapCurrent(c -> c == cctx ? "baz" : "qux", + assertThat(ClientRequestContext.mapCurrent(c -> c.unwrapAll() == cctx ? "baz" : "qux", () -> "defaultValue")) .isEqualTo("baz"); assertThat(ServiceRequestContext.mapCurrent(Function.identity(), null)).isSameAs(sctx); - assertThat(ClientRequestContext.mapCurrent(Function.identity(), null)).isSameAs(cctx); - assertThat(RequestContext.mapCurrent(Function.identity(), null)).isSameAs(cctx); + assertThatUnwrapAll(ClientRequestContext.mapCurrent(Function.identity(), null)).isSameAs(cctx); + assertThatUnwrapAll(RequestContext.mapCurrent(Function.identity(), null)).isSameAs(cctx); } - assertCurrentCtx(sctx); + assertUnwrapAllCurrentCtx(sctx); } - assertCurrentCtx(null); + assertUnwrapAllCurrentCtx(null); try (SafeCloseable unused = clientRequestContext().push()) { assertThatThrownBy(() -> ServiceRequestContext.mapCurrent(c -> "foo", () -> "bar")) @@ -120,43 +121,43 @@ void mapCurrent() { void pushReentrance() { final ServiceRequestContext ctx = serviceRequestContext(); try (SafeCloseable ignored = ctx.push()) { - assertCurrentCtx(ctx); + assertUnwrapAllCurrentCtx(ctx); try (SafeCloseable ignored2 = ctx.push()) { - assertCurrentCtx(ctx); + assertUnwrapAllCurrentCtx(ctx); } - assertCurrentCtx(ctx); + assertUnwrapAllCurrentCtx(ctx); } - assertCurrentCtx(null); + assertUnwrapAllCurrentCtx(null); } @Test void pushWithOldClientCtxWhoseRootIsThisServiceCtx() { final ServiceRequestContext sctx = serviceRequestContext(); try (SafeCloseable ignored = sctx.push()) { - assertCurrentCtx(sctx); + assertUnwrapAllCurrentCtx(sctx); // The root of ClientRequestContext is sctx. final ClientRequestContext cctx = clientRequestContext(); try (SafeCloseable ignored1 = cctx.push()) { - assertCurrentCtx(cctx); + assertUnwrapAllCurrentCtx(cctx); try (SafeCloseable ignored2 = sctx.push()) { - assertCurrentCtx(sctx); + assertUnwrapAllCurrentCtx(sctx); } - assertCurrentCtx(cctx); + assertUnwrapAllCurrentCtx(cctx); } - assertCurrentCtx(sctx); + assertUnwrapAllCurrentCtx(sctx); } - assertCurrentCtx(null); + assertUnwrapAllCurrentCtx(null); } @Test void pushWithOldIrrelevantClientCtx() { final ClientRequestContext cctx = clientRequestContext(); try (SafeCloseable ignored = cctx.push()) { - assertCurrentCtx(cctx); + assertUnwrapAllCurrentCtx(cctx); final ServiceRequestContext sctx = serviceRequestContext(); assertThatThrownBy(sctx::push).isInstanceOf(IllegalStateException.class); } - assertCurrentCtx(null); + assertUnwrapAllCurrentCtx(null); } @Test @@ -164,10 +165,10 @@ void pushWithOldIrrelevantServiceCtx() { final ServiceRequestContext sctx1 = serviceRequestContext(); final ServiceRequestContext sctx2 = serviceRequestContext(); try (SafeCloseable ignored = sctx1.push()) { - assertCurrentCtx(sctx1); + assertUnwrapAllCurrentCtx(sctx1); assertThatThrownBy(sctx2::push).isInstanceOf(IllegalStateException.class); } - assertCurrentCtx(null); + assertUnwrapAllCurrentCtx(null); } @Test @@ -194,9 +195,17 @@ void queryParams() { assertThat(ctx.queryParams("Not exist")).isEmpty(); } - private static void assertCurrentCtx(@Nullable RequestContext ctx) { + private static void assertUnwrapAllCurrentCtx(@Nullable RequestContext ctx) { final RequestContext current = RequestContext.currentOrNull(); - assertThat(current).isSameAs(ctx); + if (current == null) { + assertThat(ctx).isNull(); + } else { + assertThatUnwrapAll(current).isEqualTo(ctx); + } + } + + private static ObjectAssert assertThatUnwrapAll(T actual) { + return assertThat(actual.unwrapAll()); } private static ServiceRequestContext serviceRequestContext() { diff --git a/it/trace-context-leak/build.gradle b/it/trace-context-leak/build.gradle new file mode 100644 index 00000000000..d083ea187dc --- /dev/null +++ b/it/trace-context-leak/build.gradle @@ -0,0 +1,8 @@ +task generateSources(type: Copy) { + from "${rootProject.projectDir}/core/src/test/java" + into "${project.ext.genSrcDir}/test/java" + include '**/ServiceRequestContextTest.java' + include '**/ClientRequestContextTest.java' +} + +tasks.compileJava.dependsOn(generateSources) diff --git a/it/trace-context-leak/src/test/java/com/linecorp/armeria/internal/common/EnableLeakDetectionFlagsProvider.java b/it/trace-context-leak/src/test/java/com/linecorp/armeria/internal/common/EnableLeakDetectionFlagsProvider.java new file mode 100644 index 00000000000..720e46ffb20 --- /dev/null +++ b/it/trace-context-leak/src/test/java/com/linecorp/armeria/internal/common/EnableLeakDetectionFlagsProvider.java @@ -0,0 +1,34 @@ +/* + * Copyright 2022 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.internal.common; + +import com.linecorp.armeria.common.FlagsProvider; +import com.linecorp.armeria.common.RequestContext; +import com.linecorp.armeria.common.util.Sampler; + +public final class EnableLeakDetectionFlagsProvider implements FlagsProvider { + + @Override + public int priority() { + return 10; + } + + @Override + public Sampler requestContextLeakDetectionSampler() { + return Sampler.always(); + } +} diff --git a/it/trace-context-leak/src/test/java/com/linecorp/armeria/internal/common/TraceRequestContextLeakTest.java b/it/trace-context-leak/src/test/java/com/linecorp/armeria/internal/common/TraceRequestContextLeakTest.java new file mode 100644 index 00000000000..82fff976316 --- /dev/null +++ b/it/trace-context-leak/src/test/java/com/linecorp/armeria/internal/common/TraceRequestContextLeakTest.java @@ -0,0 +1,294 @@ +/* + * Copyright 2022 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.internal.common; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.awaitility.Awaitility.await; + +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.linecorp.armeria.client.ClientRequestContext; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.RequestContext; +import com.linecorp.armeria.common.util.SafeCloseable; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.testing.junit5.common.EventLoopExtension; +import com.linecorp.armeria.testing.junit5.common.EventLoopGroupExtension; + +import io.netty.channel.EventLoop; +import io.netty.channel.EventLoopGroup; + +class TraceRequestContextLeakTest { + + @RegisterExtension + static final EventLoopExtension eventLoopExtension = new EventLoopExtension(); + + @RegisterExtension + static final EventLoopGroupExtension eventLoopGroupExtension = new EventLoopGroupExtension(2); + + @Test + void singleThreadContextNotLeak() throws InterruptedException { + final AtomicBoolean isThrown = new AtomicBoolean(false); + final CountDownLatch latch = new CountDownLatch(2); + + final EventLoop executor = eventLoopExtension.get(); + + executor.execute(() -> { + final ServiceRequestContext ctx = newCtx("/1"); + try (SafeCloseable ignore = ctx.push()) { + // Ignore + } catch (Exception ex) { + isThrown.set(true); + } finally { + latch.countDown(); + } + }); + + executor.execute(() -> { + final ServiceRequestContext anotherCtx = newCtx("/2"); + try (SafeCloseable ignore = anotherCtx.push()) { + final ClientRequestContext clientCtx = newClientCtx("/3"); + try (SafeCloseable ignore2 = clientCtx.push()) { + // Ignore + } + } catch (Exception ex) { + isThrown.set(true); + } finally { + latch.countDown(); + } + }); + + latch.await(); + assertThat(isThrown).isFalse(); + } + + @Test + @SuppressWarnings("MustBeClosedChecker") + void singleThreadContextLeak() throws InterruptedException { + final AtomicBoolean isThrown = new AtomicBoolean(); + final AtomicReference exception = new AtomicReference<>(); + + try (DeferredClose deferredClose = new DeferredClose()) { + final EventLoop executor = eventLoopExtension.get(); + + executor.execute(() -> { + final ServiceRequestContext ctx = newCtx("/1"); + final SafeCloseable leaked = ctx.push(); // <- Leaked, should show in error. + deferredClose.add(executor, leaked); + }); + + executor.execute(() -> { + final ServiceRequestContext anotherCtx = newCtx("/2"); + try (SafeCloseable ignore = anotherCtx.push()) { + // Ignore + } catch (Exception ex) { + exception.set(ex); + isThrown.set(true); + } + }); + + await().untilTrue(isThrown); + assertThat(exception.get()) + .hasMessageContaining("singleThreadContextLeak$2(TraceRequestContextLeakTest.java:101)"); + } + } + + @Test + @SuppressWarnings("MustBeClosedChecker") + void multiThreadContextLeakNotInterfereOthersEventLoop() throws InterruptedException { + final AtomicBoolean isThrown = new AtomicBoolean(false); + final CountDownLatch latch = new CountDownLatch(2); + + final EventLoopGroup executor = eventLoopGroupExtension.get(); + + final Executor ex1 = executor.next(); + final Executor ex2 = executor.next(); + + try (DeferredClose deferredClose = new DeferredClose()) { + ex1.execute(() -> { + final ServiceRequestContext ctx = newCtx("/1"); + final SafeCloseable leaked = ctx.push(); + latch.countDown(); + deferredClose.add(executor, leaked); + }); + + ex2.execute(() -> { + // Leak happened on the first eventLoop shouldn't affect 2nd eventLoop when trying to push + await().until(() -> latch.getCount() == 1); + final ServiceRequestContext anotherCtx = newCtx("/2"); + try (SafeCloseable ignore1 = anotherCtx.push()) { + final ClientRequestContext cctx = newClientCtx("/3"); + try (SafeCloseable ignore2 = cctx.push()) { + // Ignore + } + } catch (Exception ex) { + // Not suppose to throw exception on the second event loop + isThrown.set(true); + } finally { + latch.countDown(); + } + }); + + latch.await(); + assertThat(isThrown).isFalse(); + } + } + + @Test + @SuppressWarnings("MustBeClosedChecker") + void multiThreadContextLeak() throws InterruptedException { + final AtomicBoolean isThrown = new AtomicBoolean(false); + final AtomicReference exception = new AtomicReference<>(); + final CountDownLatch waitForExecutor2 = new CountDownLatch(1); + + final EventLoopGroup executor = eventLoopGroupExtension.get(); + + final ServiceRequestContext leakingCtx = newCtx("/1-leak"); + final ServiceRequestContext anotherCtx2 = newCtx("/2-leak"); + final ServiceRequestContext anotherCtx3 = newCtx("/3-leak"); + + final Executor ex1 = executor.next(); + final Executor ex2 = executor.next(); + + try (DeferredClose deferredClose = new DeferredClose()) { + ex1.execute(() -> { + final SafeCloseable leaked = leakingCtx.push(); // <- Leaked, should show in error. + deferredClose.add(ex1, leaked); + }); + + ex2.execute(() -> { + try { + final SafeCloseable leaked = anotherCtx2.push(); + deferredClose.add(ex2, leaked); + } catch (Exception ex) { + isThrown.set(true); + } finally { + waitForExecutor2.countDown(); + } + }); + + waitForExecutor2.await(); + assertThat(isThrown).isFalse(); + + ex1.execute(() -> { + try (SafeCloseable ignore = anotherCtx3.push()) { + // Ignore + } catch (Exception ex) { + exception.set(ex); + isThrown.set(true); + } + }); + + await().untilTrue(isThrown); + assertThat(exception.get()) + .hasMessageContaining("multiThreadContextLeak$7(TraceRequestContextLeakTest.java:180)"); + } + } + + @Test + void pushIllegalServiceRequestContext() { + final ServiceRequestContext sctx1 = newCtx("/1"); + final ServiceRequestContext sctx2 = newCtx("/2"); + try (SafeCloseable ignored = sctx1.push()) { + assertThatThrownBy(sctx2::push).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("pushed at the following stacktrace"); + } + } + + @Test + void multipleRequestContextPushBeforeLeak() { + final ServiceRequestContext sctx1 = newCtx("/1"); + final ServiceRequestContext sctx2 = newCtx("/2"); + try (SafeCloseable ignore1 = sctx1.push()) { + final ClientRequestContext cctx1 = newClientCtx("/3"); + try (SafeCloseable ignore3 = cctx1.push()) { + assertThatThrownBy(sctx2::push).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("pushed at the following stacktrace"); + } + } + } + + @Test + @SuppressWarnings("MustBeClosedChecker") + void connerCase() { + final AtomicReference exception = new AtomicReference<>(); + + try (DeferredClose deferredClose = new DeferredClose()) { + final ServiceRequestContext ctx = newCtx("/1"); + try (SafeCloseable ignored = ctx.push()) { + final ClientRequestContext ctx2 = newClientCtx("/2"); + ctx2.push(); // <- Leaked, should show in error. + deferredClose.add(ctx2); + final ClientRequestContext ctx3 = newClientCtx("/3"); + try (SafeCloseable ignored1 = ctx3.push()) { + // Ignore + } + } catch (Exception ex) { + exception.set(ex); + } + } + assertThat(exception.get()) + .hasMessageContaining("connerCase(TraceRequestContextLeakTest.java:245)"); + } + + private static ServiceRequestContext newCtx(String path) { + return ServiceRequestContext.builder(HttpRequest.of(HttpMethod.GET, path)) + .build(); + } + + private static ClientRequestContext newClientCtx(String path) { + return ClientRequestContext.builder(HttpRequest.of(HttpMethod.GET, path)) + .build(); + } + + // Utility to clean up RequestContext leak after test + private static class DeferredClose implements SafeCloseable { + + private final ConcurrentHashMap toClose; + private final Set toRemoveFromThreadLocal; + + DeferredClose() { + toClose = new ConcurrentHashMap<>(); + toRemoveFromThreadLocal = new HashSet<>(); + } + + void add(Executor executor, SafeCloseable closeable) { + toClose.put(executor, closeable); + } + + void add(RequestContext requestContext) { + toRemoveFromThreadLocal.add(requestContext); + } + + @Override + public void close() { + toClose.forEach((executor, closeable) -> executor.execute(closeable::close)); + toRemoveFromThreadLocal.forEach(ctx -> RequestContextUtil.pop(ctx, null)); + } + } +} diff --git a/it/trace-context-leak/src/test/resources/META-INF/services/com.linecorp.armeria.common.FlagsProvider b/it/trace-context-leak/src/test/resources/META-INF/services/com.linecorp.armeria.common.FlagsProvider new file mode 100644 index 00000000000..0c5a2c315a9 --- /dev/null +++ b/it/trace-context-leak/src/test/resources/META-INF/services/com.linecorp.armeria.common.FlagsProvider @@ -0,0 +1 @@ +com.linecorp.armeria.internal.common.EnableLeakDetectionFlagsProvider diff --git a/settings.gradle b/settings.gradle index 71d083fba1f..7ac52e789d3 100644 --- a/settings.gradle +++ b/settings.gradle @@ -100,6 +100,7 @@ includeWithFlags ':it:spring:boot2-tomcat8', 'java', 'relocate' includeWithFlags ':it:spring:boot2-tomcat9', 'java', 'relocate' includeWithFlags ':it:spring:webflux-security', 'java', 'relocate' includeWithFlags ':it:thrift-fullcamel', 'java', 'relocate' +includeWithFlags ':it:trace-context-leak', 'java', 'relocate' includeWithFlags ':jetty9.3', 'java', 'relocate' includeWithFlags ':testing-internal', 'java', 'relocate' includeWithFlags ':thrift0.12', 'java', 'relocate'