diff --git a/opentracing-elasticsearch-client-common/src/main/java/io/opentracing/contrib/elasticsearch/common/TracingHttpClientConfigCallback.java b/opentracing-elasticsearch-client-common/src/main/java/io/opentracing/contrib/elasticsearch/common/TracingHttpClientConfigCallback.java index 0fa76f1..f7259e1 100644 --- a/opentracing-elasticsearch-client-common/src/main/java/io/opentracing/contrib/elasticsearch/common/TracingHttpClientConfigCallback.java +++ b/opentracing-elasticsearch-client-common/src/main/java/io/opentracing/contrib/elasticsearch/common/TracingHttpClientConfigCallback.java @@ -21,6 +21,7 @@ import io.opentracing.propagation.Format.Builtin; import io.opentracing.tag.Tags; import io.opentracing.util.GlobalTracer; +import java.lang.reflect.Field; import java.util.function.Function; import org.apache.http.HttpRequest; import org.apache.http.HttpRequestInterceptor; @@ -35,6 +36,8 @@ public class TracingHttpClientConfigCallback implements RestClientBuilder.HttpCl private final Tracer tracer; private final Function spanNameProvider; private final HttpClientConfigCallback callback; + private static final String OT_IS_AUTH_CACHING_DISABLED = "ot-is-auth-caching-disabled"; + private static final String OT_SPAN = "ot-span"; public TracingHttpClientConfigCallback(Tracer tracer, Function spanNameProvider, @@ -79,18 +82,36 @@ public TracingHttpClientConfigCallback(HttpClientConfigCallback callback) { this(GlobalTracer.get(), ClientSpanNameProvider.REQUEST_METHOD_NAME, callback); } + private boolean isAuthCachingDisabled(HttpAsyncClientBuilder httpAsyncClientBuilder) { + try { + final Field authCachingDisabledField = httpAsyncClientBuilder.getClass() + .getDeclaredField("authCachingDisabled"); + final boolean accessible = authCachingDisabledField.isAccessible(); + authCachingDisabledField.setAccessible(true); + final boolean isAuthCachingDisabled = (boolean) authCachingDisabledField + .get(httpAsyncClientBuilder); + authCachingDisabledField.setAccessible(accessible); + return isAuthCachingDisabled; + } catch (Exception ignore) { + } + return false; + } + @Override public HttpAsyncClientBuilder customizeHttpClient( final HttpAsyncClientBuilder httpAsyncClientBuilder) { HttpAsyncClientBuilder httpClientBuilder; + final boolean isAuthCachingDisabled; if (callback != null) { httpClientBuilder = callback.customizeHttpClient(httpAsyncClientBuilder); + isAuthCachingDisabled = isAuthCachingDisabled(httpClientBuilder); } else { httpClientBuilder = httpAsyncClientBuilder; + isAuthCachingDisabled = false; } - httpClientBuilder.addInterceptorFirst((HttpRequestInterceptor) (request, context) -> { + httpClientBuilder.addInterceptorLast((HttpRequestInterceptor) (request, context) -> { SpanBuilder spanBuilder = tracer.buildSpan(spanNameProvider.apply(request)) .ignoreActiveSpan() .withTag(Tags.SPAN_KIND.getKey(), Tags.SPAN_KIND_CLIENT); @@ -107,12 +128,24 @@ public HttpAsyncClientBuilder customizeHttpClient( tracer.inject(span.context(), Builtin.HTTP_HEADERS, new HttpTextMapInjectAdapter(request)); - context.setAttribute("span", span); + context.setAttribute(OT_SPAN, span); + if (isAuthCachingDisabled) { + context.setAttribute(OT_IS_AUTH_CACHING_DISABLED, "true"); + } }); httpClientBuilder.addInterceptorFirst((HttpResponseInterceptor) (response, context) -> { - Object spanObject = context.getAttribute("span"); + if (context.getAttribute(OT_IS_AUTH_CACHING_DISABLED) != null) { + context.removeAttribute(OT_IS_AUTH_CACHING_DISABLED); + if (response.getStatusLine().getStatusCode() == 401) { + // response interceptor is called twice if auth caching is disabled + // and server requires authentication + return; + } + } + Object spanObject = context.getAttribute(OT_SPAN); if (spanObject instanceof Span) { + context.removeAttribute(OT_SPAN); Span span = (Span) spanObject; SpanDecorator.onResponse(response, span); span.finish(); diff --git a/opentracing-elasticsearch6-client/src/test/java/io/opentracing/contrib/elasticsearch6/TracingTest.java b/opentracing-elasticsearch6-client/src/test/java/io/opentracing/contrib/elasticsearch6/TracingTest.java index 26aa164..1ac1db6 100644 --- a/opentracing-elasticsearch6-client/src/test/java/io/opentracing/contrib/elasticsearch6/TracingTest.java +++ b/opentracing-elasticsearch6-client/src/test/java/io/opentracing/contrib/elasticsearch6/TracingTest.java @@ -189,6 +189,57 @@ public void onFailure(Exception exception) { assertNull(mockTracer.activeSpan()); } + @Test + public void restClientWithCallbackDisabledAuthCaching() throws Exception { + RestClient restClient = RestClient.builder( + new HttpHost("localhost", HTTP_PORT, "http")) + .setHttpClientConfigCallback(new TracingHttpClientConfigCallback(mockTracer, + (HttpClientConfigCallback) httpClientBuilder -> { + httpClientBuilder.disableAuthCaching(); + return httpClientBuilder; + })) + .build(); + + HttpEntity entity = new NStringEntity( + "{\n" + + " \"user\" : \"kimchy\",\n" + + " \"post_date\" : \"2009-11-15T14:12:12\",\n" + + " \"message\" : \"trying out Elasticsearch\"\n" + + "}", ContentType.APPLICATION_JSON); + + Request request = new Request("PUT", "/twitter/tweet/1"); + request.setEntity(entity); + + Response indexResponse = restClient.performRequest(request); + + assertNotNull(indexResponse); + + Request request2 = new Request("PUT", "/twitter/tweet/2"); + request2.setEntity(entity); + + final CountDownLatch latch = new CountDownLatch(1); + restClient + .performRequestAsync(request2, new ResponseListener() { + @Override + public void onSuccess(Response response) { + latch.countDown(); + } + + @Override + public void onFailure(Exception exception) { + latch.countDown(); + } + }); + + latch.await(30, TimeUnit.SECONDS); + restClient.close(); + + List finishedSpans = mockTracer.finishedSpans(); + assertEquals(2, finishedSpans.size()); + checkSpans(finishedSpans, "PUT"); + assertNull(mockTracer.activeSpan()); + } + @Test public void transportClient() throws Exception { diff --git a/opentracing-elasticsearch7-client/src/test/java/io/opentracing/contrib/elasticsearch7/TracingTest.java b/opentracing-elasticsearch7-client/src/test/java/io/opentracing/contrib/elasticsearch7/TracingTest.java index 65099b3..ce2202d 100644 --- a/opentracing-elasticsearch7-client/src/test/java/io/opentracing/contrib/elasticsearch7/TracingTest.java +++ b/opentracing-elasticsearch7-client/src/test/java/io/opentracing/contrib/elasticsearch7/TracingTest.java @@ -43,6 +43,7 @@ import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseListener; import org.elasticsearch.client.RestClient; +import org.elasticsearch.client.RestClientBuilder.HttpClientConfigCallback; import org.elasticsearch.client.transport.TransportClient; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; @@ -141,6 +142,105 @@ public void onFailure(Exception exception) { assertNull(mockTracer.activeSpan()); } + @Test + public void restClientWithCallback() throws Exception { + RestClient restClient = RestClient.builder( + new HttpHost("localhost", HTTP_PORT, "http")) + .setHttpClientConfigCallback(new TracingHttpClientConfigCallback(mockTracer, + (HttpClientConfigCallback) httpClientBuilder -> httpClientBuilder)) + .build(); + + HttpEntity entity = new NStringEntity( + "{\n" + + " \"user\" : \"kimchy\",\n" + + " \"post_date\" : \"2009-11-15T14:12:12\",\n" + + " \"message\" : \"trying out Elasticsearch\"\n" + + "}", ContentType.APPLICATION_JSON); + + Request request = new Request("PUT", "/twitter/tweet/1"); + request.setEntity(entity); + + Response indexResponse = restClient.performRequest(request); + + assertNotNull(indexResponse); + + Request request2 = new Request("PUT", "/twitter/tweet/2"); + request2.setEntity(entity); + + final CountDownLatch latch = new CountDownLatch(1); + restClient + .performRequestAsync(request2, new ResponseListener() { + @Override + public void onSuccess(Response response) { + latch.countDown(); + } + + @Override + public void onFailure(Exception exception) { + latch.countDown(); + } + }); + + latch.await(30, TimeUnit.SECONDS); + restClient.close(); + + List finishedSpans = mockTracer.finishedSpans(); + assertEquals(2, finishedSpans.size()); + checkSpans(finishedSpans, "PUT"); + assertNull(mockTracer.activeSpan()); + } + + @Test + public void restClientWithCallbackDisabledAuthCaching() throws Exception { + RestClient restClient = RestClient.builder( + new HttpHost("localhost", HTTP_PORT, "http")) + .setHttpClientConfigCallback(new TracingHttpClientConfigCallback(mockTracer, + (HttpClientConfigCallback) httpClientBuilder -> { + httpClientBuilder.disableAuthCaching(); + return httpClientBuilder; + })) + .build(); + + HttpEntity entity = new NStringEntity( + "{\n" + + " \"user\" : \"kimchy\",\n" + + " \"post_date\" : \"2009-11-15T14:12:12\",\n" + + " \"message\" : \"trying out Elasticsearch\"\n" + + "}", ContentType.APPLICATION_JSON); + + Request request = new Request("PUT", "/twitter/tweet/1"); + request.setEntity(entity); + + Response indexResponse = restClient.performRequest(request); + + assertNotNull(indexResponse); + + Request request2 = new Request("PUT", "/twitter/tweet/2"); + request2.setEntity(entity); + + final CountDownLatch latch = new CountDownLatch(1); + restClient + .performRequestAsync(request2, new ResponseListener() { + @Override + public void onSuccess(Response response) { + latch.countDown(); + } + + @Override + public void onFailure(Exception exception) { + latch.countDown(); + } + }); + + latch.await(30, TimeUnit.SECONDS); + restClient.close(); + + List finishedSpans = mockTracer.finishedSpans(); + assertEquals(2, finishedSpans.size()); + checkSpans(finishedSpans, "PUT"); + assertNull(mockTracer.activeSpan()); + } + @Test public void transportClient() throws Exception {