Skip to content

Commit

Permalink
Use thread local when creating scopes off of the request context thre…
Browse files Browse the repository at this point in the history
…ad. (#1355)

While the `RequestContext` can be treated as the only thread-local for the request thread, this isn't the case for other threads, i.e., when making callbacks run in a threadpool context-aware. For such callbacks, we need to use a separate thread-local to prevent the multiple threads from clashing the context.
  • Loading branch information
anuraaga authored and trustin committed Sep 6, 2018
1 parent 8dbfcd9 commit d33deef
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.linecorp.armeria.common.tracing;

import java.util.Collections;
import java.util.function.Function;
import java.util.function.Supplier;

import javax.annotation.Nullable;
Expand Down Expand Up @@ -68,6 +69,9 @@ public final class RequestContextCurrentTraceContext extends CurrentTraceContext
private static final AttributeKey<TraceContext> TRACE_CONTEXT_KEY =
AttributeKey.valueOf(RequestContextCurrentTraceContext.class, "TRACE_CONTEXT");

// Thread-local for storing TraceContext when invoking callbacks off the request thread.
private static final ThreadLocal<TraceContext> THREAD_LOCAL_CONTEXT = new ThreadLocal<>();

private static final Scope INCOMPLETE_CONFIGURATION_SCOPE = new Scope() {
@Override
public void close() {
Expand Down Expand Up @@ -148,11 +152,27 @@ public static void copy(RequestContext src, RequestContext dst) {
dst.attr(TRACE_CONTEXT_KEY).set(src.attr(TRACE_CONTEXT_KEY).get());
}

private RequestContextCurrentTraceContext(Builder builder) {
super(builder);
}

@Override
@Nullable
public TraceContext get() {
final Attribute<TraceContext> traceContextAttribute = getTraceContextAttributeOrWarnOnce();
return traceContextAttribute != null ? traceContextAttribute.get() : null;
final RequestContext ctx = getRequestContextOrWarnOnce();
if (ctx == null) {
return null;
}
if (ctx.eventLoop().inEventLoop()) {
return ctx.attr(TRACE_CONTEXT_KEY).get();
} else {
final TraceContext threadLocalContext = THREAD_LOCAL_CONTEXT.get();
if (threadLocalContext != null) {
return threadLocalContext;
}
// First span on a non-request thread will use the request's TraceContext as a parent.
return ctx.attr(TRACE_CONTEXT_KEY).get();
}
}

@Override
Expand All @@ -162,11 +182,24 @@ public Scope newScope(@Nullable TraceContext currentSpan) {
return Scope.NOOP;
}

final Attribute<TraceContext> traceContextAttribute = getTraceContextAttributeOrWarnOnce();
if (traceContextAttribute == null) {
final RequestContext ctx = getRequestContextOrWarnOnce();
if (ctx == null) {
return INCOMPLETE_CONFIGURATION_SCOPE;
}

if (ctx.eventLoop().inEventLoop()) {
return createScopeForRequestThread(ctx, currentSpan);
} else {
// The RequestContext is the canonical thread-local storage for the thread processing the request.
// However, when creating spans on other threads (e.g., a thread-pool), we must use separate
// thread-local storage to prevent threads from replacing the same trace context.
return createScopeForNonRequestThread(currentSpan);
}
}

private static Scope createScopeForRequestThread(RequestContext ctx, @Nullable TraceContext currentSpan) {
final Attribute<TraceContext> traceContextAttribute = ctx.attr(TRACE_CONTEXT_KEY);

final TraceContext previous = traceContextAttribute.getAndSet(currentSpan);

// Don't remove the outer-most context (client or server request)
Expand All @@ -182,27 +215,56 @@ public void close() {
// re-lookup the attribute to avoid holding a reference to the request if this scope is leaked
getTraceContextAttributeOrWarnOnce().set(previous);
}

@Override
public String toString() {
return "RequestContextTraceContextScope";
}
}

return new RequestContextTraceContextScope();
}

private static Scope createScopeForNonRequestThread(@Nullable TraceContext currentSpan) {
final TraceContext previous = THREAD_LOCAL_CONTEXT.get();
THREAD_LOCAL_CONTEXT.set(currentSpan);
class ThreadLocalScope implements Scope {
@Override
public void close() {
THREAD_LOCAL_CONTEXT.set(previous);
}

@Override
public String toString() {
return "ThreadLocalScope";
}
}

return new ThreadLocalScope();
}

/** Armeria code should always have a request context available, and this won't work without it. */
@Nullable private static Attribute<TraceContext> getTraceContextAttributeOrWarnOnce() {
return RequestContext.mapCurrent(r -> r.attr(TRACE_CONTEXT_KEY), LogRequestContextWarningOnce.INSTANCE);
@Nullable
private static RequestContext getRequestContextOrWarnOnce() {
return RequestContext.mapCurrent(Function.identity(), LogRequestContextWarningOnce.INSTANCE);
}

private RequestContextCurrentTraceContext(Builder builder) {
super(builder);
/** Armeria code should always have a request context available, and this won't work without it. */
@Nullable private static Attribute<TraceContext> getTraceContextAttributeOrWarnOnce() {
final RequestContext ctx = getRequestContextOrWarnOnce();
if (ctx == null) {
return null;
}
return ctx.attr(TRACE_CONTEXT_KEY);
}

private enum LogRequestContextWarningOnce implements Supplier<Attribute<TraceContext>> {
private enum LogRequestContextWarningOnce implements Supplier<RequestContext> {

INSTANCE;

@Override
@Nullable
public Attribute<TraceContext> get() {
public RequestContext get() {
ClassLoaderHack.loadMe();
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import brave.propagation.CurrentTraceContext;
import brave.propagation.CurrentTraceContext.Scope;
import brave.propagation.TraceContext;
import io.netty.channel.EventLoop;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;

Expand All @@ -50,6 +51,8 @@ public class RequestContextCurrentTraceContextTest {
RequestContext mockRequestContext;
@Mock(answer = Answers.CALLS_REAL_METHODS)
RequestContext mockRequestContext2;
@Mock
EventLoop eventLoop;

final CurrentTraceContext currentTraceContext = RequestContextCurrentTraceContext.DEFAULT;
final DefaultAttributeMap attrs1 = new DefaultAttributeMap();
Expand All @@ -58,6 +61,10 @@ public class RequestContextCurrentTraceContextTest {

@Before
public void setup() {
when(mockRequestContext.eventLoop()).thenReturn(eventLoop);
when(mockRequestContext2.eventLoop()).thenReturn(eventLoop);
when(eventLoop.inEventLoop()).thenReturn(true);

when(mockRequestContext.attr(isA(AttributeKey.class))).thenAnswer(
(Answer<Attribute>) invocation -> attrs1.attr(invocation.getArgument(0)));
when(mockRequestContext2.attr(isA(AttributeKey.class))).thenAnswer(
Expand Down Expand Up @@ -112,11 +119,35 @@ public void newScope_closeDoesntClearFirstScope() {

try (SafeCloseable requestContextScope = mockRequestContext.push()) {
try (Scope traceContextScope = currentTraceContext.newScope(traceContext)) {
assertThat(traceContextScope).hasToString("InitialRequestScope");
assertThat(currentTraceContext.get()).isEqualTo(traceContext);

try (Scope traceContextScope2 = currentTraceContext.newScope(traceContext2)) {
assertThat(traceContextScope2).hasToString("RequestContextTraceContextScope");
assertThat(currentTraceContext.get()).isEqualTo(traceContext2);
}
assertThat(currentTraceContext.get()).isEqualTo(traceContext);
}
// the first scope is attached to the request context and cleared when that's destroyed
assertThat(currentTraceContext.get()).isEqualTo(traceContext);
}
}

@Test
public void newScope_notOnEventLoop() {
final TraceContext traceContext2 = TraceContext.newBuilder().traceId(1).spanId(2).build();

try (SafeCloseable requestContextScope = mockRequestContext.push()) {
try (Scope traceContextScope = currentTraceContext.newScope(traceContext)) {
assertThat(traceContextScope).hasToString("InitialRequestScope");
assertThat(currentTraceContext.get()).isEqualTo(traceContext);

when(eventLoop.inEventLoop()).thenReturn(false);
try (Scope traceContextScope2 = currentTraceContext.newScope(traceContext2)) {
assertThat(traceContextScope2).hasToString("ThreadLocalScope");
assertThat(currentTraceContext.get()).isEqualTo(traceContext2);
}
when(eventLoop.inEventLoop()).thenReturn(true);
assertThat(currentTraceContext.get()).isEqualTo(traceContext);
}
// the first scope is attached to the request context and cleared when that's destroyed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
Expand Down Expand Up @@ -66,6 +67,7 @@
import com.linecorp.armeria.testing.server.ServerRule;

import brave.ScopedSpan;
import brave.Tracer.SpanInScope;
import brave.Tracing;
import brave.propagation.CurrentTraceContext;
import brave.propagation.StrictScopeDecorator;
Expand Down Expand Up @@ -125,14 +127,23 @@ protected HttpResponse doGet(ServiceRequestContext ctx, HttpRequest req)
final ListenableFuture<List<Object>> spanAware = allAsList(IntStream.range(1, 3).mapToObj(
i -> executorService.submit(
RequestContext.current().makeContextAware(() -> {
if (i == 2) {
countDownLatch.countDown();
countDownLatch.await();
}
brave.Span span = Tracing.currentTracer().nextSpan().start();
countDownLatch.countDown();
countDownLatch.await();
try {
return null;
try (SpanInScope spanInScope =
Tracing.currentTracer().withSpanInScope(span)) {
if (i == 1) {
countDownLatch.countDown();
countDownLatch.await();
// to wait second task get span.
Thread.sleep(1000L);
}
} finally {
span.finish();
}
return null;
}))).collect(toImmutableList()));

final CompletableFuture<HttpResponse> responseFuture = new CompletableFuture<>();
Expand Down Expand Up @@ -313,6 +324,9 @@ public void testSpanInThreadPoolHasSameTraceId() throws Exception {
poolHttpClient.get("pool").aggregate().get();
final Span[] spans = spanReporter.take(5);
assertThat(Arrays.stream(spans).map(Span::traceId).collect(toImmutableSet())).hasSize(1);
assertThat(Arrays.stream(spans).map(Span::parentId)
.filter(Objects::nonNull)
.collect(toImmutableSet())).hasSize(1);
}

private static Span findSpan(Span[] spans, String serviceName) {
Expand Down

0 comments on commit d33deef

Please sign in to comment.