From 7849d3be311e8037b0eea589041cbe7289028ce8 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Sat, 31 Jul 2021 18:33:02 -0700 Subject: [PATCH] all: API refactoring in preparation to support retry stats (#8355) Rebased PR #8343 into the first commit of this PR, then (the 2nd commit) reverted the part for metric recording of retry attempts. The PR as a whole is mechanical refactoring. No behavior change (except that some of the old code path when tracer is created is moved into the new method `streamCreated()`). The API change is documented in go/grpc-stats-api-change-for-retry-java --- .../main/java/io/grpc/ClientStreamTracer.java | 82 ++++-- .../test/java/io/grpc/CallOptionsTest.java | 6 +- .../grpc/binder/internal/BinderTransport.java | 21 +- .../io/grpc/census/CensusStatsModule.java | 57 +++-- .../io/grpc/census/CensusTracingModule.java | 21 +- .../io/grpc/census/CensusModulesTest.java | 71 ++++-- .../internal/StatsTraceContextBenchmark.java | 5 +- .../io/grpc/inprocess/InProcessTransport.java | 24 +- ...llCredentialsApplyingTransportFactory.java | 12 +- .../java/io/grpc/internal/ClientCallImpl.java | 5 +- .../io/grpc/internal/ClientTransport.java | 8 +- .../grpc/internal/DelayedClientTransport.java | 35 ++- .../java/io/grpc/internal/DelayedStream.java | 4 + .../io/grpc/internal/FailingClientStream.java | 13 +- .../grpc/internal/FailingClientTransport.java | 6 +- .../ForwardingClientStreamTracer.java | 101 ++++++++ .../ForwardingConnectionClientTransport.java | 6 +- .../main/java/io/grpc/internal/GrpcUtil.java | 77 +++++- .../io/grpc/internal/InternalSubchannel.java | 6 +- .../io/grpc/internal/ManagedChannelImpl.java | 13 +- .../io/grpc/internal/MetadataApplierImpl.java | 10 +- .../java/io/grpc/internal/OobChannel.java | 5 +- .../io/grpc/internal/RetriableStream.java | 20 +- .../io/grpc/internal/StatsTraceContext.java | 20 +- .../io/grpc/internal/SubchannelChannel.java | 5 +- .../util/ForwardingClientStreamTracer.java | 6 + .../java/io/grpc/ClientStreamTracerTest.java | 4 + .../grpc/internal/AbstractTransportTest.java | 181 +++++++------ .../CallCredentials2ApplyingTest.java | 67 +++-- .../internal/CallCredentialsApplyingTest.java | 102 +++++--- .../io/grpc/internal/ClientCallImplTest.java | 7 +- .../internal/DelayedClientTransportTest.java | 116 ++++++--- .../internal/FailingClientStreamTest.java | 8 +- .../internal/FailingClientTransportTest.java | 6 +- .../ForwardingClientStreamTracerTest.java | 49 ++++ .../java/io/grpc/internal/GrpcUtilTest.java | 51 +++- .../grpc/internal/ManagedChannelImplTest.java | 239 ++++++++++++------ .../io/grpc/internal/RetriableStreamTest.java | 5 +- .../test/java/io/grpc/internal/TestUtils.java | 5 +- .../ForwardingClientStreamTracerTest.java | 1 + .../io/grpc/cronet/CronetClientTransport.java | 5 +- .../grpc/cronet/CronetChannelBuilderTest.java | 9 +- .../cronet/CronetClientTransportTest.java | 9 +- .../grpc/grpclb/GrpclbClientLoadRecorder.java | 2 +- .../grpclb/TokenAttachingTracerFactory.java | 36 ++- .../grpc/grpclb/GrpclbLoadBalancerTest.java | 3 + .../TokenAttachingTracerFactoryTest.java | 47 ++-- .../integration/AbstractInteropTest.java | 14 +- .../io/grpc/netty/NettyClientTransport.java | 8 +- .../grpc/netty/NettyClientTransportTest.java | 5 +- .../io/grpc/okhttp/OkHttpClientTransport.java | 12 +- .../okhttp/OkHttpClientTransportTest.java | 123 ++++----- .../io/grpc/xds/ClusterImplLoadBalancer.java | 5 +- .../java/io/grpc/xds/OrcaPerRequestUtil.java | 7 +- .../grpc/xds/ClusterImplLoadBalancerTest.java | 8 +- 55 files changed, 1210 insertions(+), 563 deletions(-) create mode 100644 core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java create mode 100644 core/src/test/java/io/grpc/internal/ForwardingClientStreamTracerTest.java diff --git a/api/src/main/java/io/grpc/ClientStreamTracer.java b/api/src/main/java/io/grpc/ClientStreamTracer.java index 6259522487a..6a5d3cc3397 100644 --- a/api/src/main/java/io/grpc/ClientStreamTracer.java +++ b/api/src/main/java/io/grpc/ClientStreamTracer.java @@ -19,7 +19,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.MoreObjects; -import io.grpc.Grpc; import javax.annotation.concurrent.ThreadSafe; /** @@ -28,6 +27,18 @@ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2861") @ThreadSafe public abstract class ClientStreamTracer extends StreamTracer { + + /** + * The stream is being created on a ready transport. + * + * @param headers the mutable initial metadata. Modifications to it will be sent to the socket but + * not be seen by client interceptors and the application. + * + * @since 1.40.0 + */ + public void streamCreated(@Grpc.TransportAttr Attributes transportAttrs, Metadata headers) { + } + /** * Headers has been sent to the socket. */ @@ -54,22 +65,6 @@ public void inboundTrailers(Metadata trailers) { * Factory class for {@link ClientStreamTracer}. */ public abstract static class Factory { - /** - * Creates a {@link ClientStreamTracer} for a new client stream. - * - * @param callOptions the effective CallOptions of the call - * @param headers the mutable headers of the stream. It can be safely mutated within this - * method. It should not be saved because it is not safe for read or write after the - * method returns. - * - * @deprecated use {@link - * #newClientStreamTracer(io.grpc.ClientStreamTracer.StreamInfo, io.grpc.Metadata)} instead. - */ - @Deprecated - public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadata headers) { - throw new UnsupportedOperationException("Not implemented"); - } - /** * Creates a {@link ClientStreamTracer} for a new client stream. This is called inside the * transport when it's creating the stream. @@ -81,12 +76,15 @@ public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadat * * @since 1.20.0 */ - @SuppressWarnings("deprecation") public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { - return newClientStreamTracer(info.getCallOptions(), headers); + throw new UnsupportedOperationException("Not implemented"); } } + /** An abstract class for internal use only. */ + @Internal + public abstract static class InternalLimitedInfoFactory extends Factory {} + /** * Information about a stream. * @@ -99,15 +97,21 @@ public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata header public static final class StreamInfo { private final Attributes transportAttrs; private final CallOptions callOptions; + private final boolean isTransparentRetry; - StreamInfo(Attributes transportAttrs, CallOptions callOptions) { + StreamInfo(Attributes transportAttrs, CallOptions callOptions, boolean isTransparentRetry) { this.transportAttrs = checkNotNull(transportAttrs, "transportAttrs"); this.callOptions = checkNotNull(callOptions, "callOptions"); + this.isTransparentRetry = isTransparentRetry; } /** * Returns the attributes of the transport that this stream was created on. + * + * @deprecated Use {@link ClientStreamTracer#streamCreated(Attributes, Metadata)} to handle + * the transport Attributes instead. */ + @Deprecated @Grpc.TransportAttr public Attributes getTransportAttrs() { return transportAttrs; @@ -120,16 +124,25 @@ public CallOptions getCallOptions() { return callOptions; } + /** + * Whether the stream is a transparent retry. + * + * @since 1.40.0 + */ + public boolean isTransparentRetry() { + return isTransparentRetry; + } + /** * Converts this StreamInfo into a new Builder. * * @since 1.21.0 */ public Builder toBuilder() { - Builder builder = new Builder(); - builder.setTransportAttrs(transportAttrs); - builder.setCallOptions(callOptions); - return builder; + return new Builder() + .setCallOptions(callOptions) + .setTransportAttrs(transportAttrs) + .setIsTransparentRetry(isTransparentRetry); } /** @@ -146,6 +159,7 @@ public String toString() { return MoreObjects.toStringHelper(this) .add("transportAttrs", transportAttrs) .add("callOptions", callOptions) + .add("isTransparentRetry", isTransparentRetry) .toString(); } @@ -157,6 +171,7 @@ public String toString() { public static final class Builder { private Attributes transportAttrs = Attributes.EMPTY; private CallOptions callOptions = CallOptions.DEFAULT; + private boolean isTransparentRetry; Builder() { } @@ -164,9 +179,12 @@ public static final class Builder { /** * Sets the attributes of the transport that this stream was created on. This field is * optional. + * + * @deprecated Use {@link ClientStreamTracer#streamCreated(Attributes, Metadata)} to handle + * the transport Attributes instead. */ - @Grpc.TransportAttr - public Builder setTransportAttrs(Attributes transportAttrs) { + @Deprecated + public Builder setTransportAttrs(@Grpc.TransportAttr Attributes transportAttrs) { this.transportAttrs = checkNotNull(transportAttrs, "transportAttrs cannot be null"); return this; } @@ -179,11 +197,21 @@ public Builder setCallOptions(CallOptions callOptions) { return this; } + /** + * Sets whether the stream is a transparent retry. + * + * @since 1.40.0 + */ + public Builder setIsTransparentRetry(boolean isTransparentRetry) { + this.isTransparentRetry = isTransparentRetry; + return this; + } + /** * Builds a new StreamInfo. */ public StreamInfo build() { - return new StreamInfo(transportAttrs, callOptions); + return new StreamInfo(transportAttrs, callOptions, isTransparentRetry); } } } diff --git a/api/src/test/java/io/grpc/CallOptionsTest.java b/api/src/test/java/io/grpc/CallOptionsTest.java index 31861306891..0bc0d357358 100644 --- a/api/src/test/java/io/grpc/CallOptionsTest.java +++ b/api/src/test/java/io/grpc/CallOptionsTest.java @@ -30,6 +30,7 @@ import static org.mockito.Mockito.mock; import com.google.common.base.Objects; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.internal.SerializingExecutor; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; @@ -271,7 +272,7 @@ public void increment(long period, TimeUnit unit) { } } - private static class FakeTracerFactory extends ClientStreamTracer.Factory { + private static class FakeTracerFactory extends ClientStreamTracer.InternalLimitedInfoFactory { final String name; FakeTracerFactory(String name) { @@ -279,8 +280,7 @@ private static class FakeTracerFactory extends ClientStreamTracer.Factory { } @Override - public ClientStreamTracer newClientStreamTracer( - ClientStreamTracer.StreamInfo info, Metadata headers) { + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { return new ClientStreamTracer() {}; } diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java index 04070ddfcef..b132844069c 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java @@ -32,6 +32,7 @@ import com.google.common.util.concurrent.ListenableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Grpc; import io.grpc.Internal; import io.grpc.InternalChannelz.SocketStats; @@ -632,28 +633,28 @@ public synchronized Runnable start(ManagedClientTransport.Listener clientTranspo public synchronized ClientStream newStream( final MethodDescriptor method, final Metadata headers, - final CallOptions callOptions) { + final CallOptions callOptions, + ClientStreamTracer[] tracers) { if (isShutdown()) { - return newFailingClientStream(shutdownStatus, callOptions, attributes, headers); + return newFailingClientStream(shutdownStatus, attributes, headers, tracers); } else { int callId = latestCallId++; if (latestCallId == LAST_CALL_ID) { latestCallId = FIRST_CALL_ID; } + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(tracers, attributes, headers); Inbound.ClientInbound inbound = new Inbound.ClientInbound( this, attributes, callId, GrpcUtil.shouldBeCountedForInUse(callOptions)); if (ongoingCalls.putIfAbsent(callId, inbound) != null) { Status failure = Status.INTERNAL.withDescription("Clashing call IDs"); shutdownInternal(failure, true); - return newFailingClientStream(failure, callOptions, attributes, headers); + return newFailingClientStream(failure, attributes, headers, tracers); } else { if (inbound.countsForInUse() && numInUseStreams.getAndIncrement() == 0) { clientTransportListener.transportInUse(true); } - StatsTraceContext statsTraceContext = - StatsTraceContext.newClientContext(callOptions, attributes, headers); - Outbound.ClientOutbound outbound = new Outbound.ClientOutbound(this, callId, method, headers, statsTraceContext); if (method.getType().clientSendsOneMessage()) { @@ -763,12 +764,12 @@ protected void handlePingResponse(Parcel parcel) { } private static ClientStream newFailingClientStream( - Status failure, CallOptions callOptions, Attributes attributes, Metadata headers) { + Status failure, Attributes attributes, Metadata headers, + ClientStreamTracer[] tracers) { StatsTraceContext statsTraceContext = - StatsTraceContext.newClientContext(callOptions, attributes, headers); + StatsTraceContext.newClientContext(tracers, attributes, headers); statsTraceContext.clientOutboundHeaders(); - statsTraceContext.streamClosed(failure); - return new FailingClientStream(failure); + return new FailingClientStream(failure, tracers); } private static InternalLogId buildLogId( diff --git a/census/src/main/java/io/grpc/census/CensusStatsModule.java b/census/src/main/java/io/grpc/census/CensusStatsModule.java index d625a6f5c6f..ac5f4e705e3 100644 --- a/census/src/main/java/io/grpc/census/CensusStatsModule.java +++ b/census/src/main/java/io/grpc/census/CensusStatsModule.java @@ -22,6 +22,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; +import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; @@ -138,15 +139,6 @@ public TagContext parseBytes(byte[] serialized) { }); } - /** - * Creates a {@link ClientCallTracer} for a new call. - */ - @VisibleForTesting - ClientCallTracer newClientCallTracer( - TagContext parentCtx, String fullMethodName) { - return new ClientCallTracer(this, parentCtx, fullMethodName); - } - /** * Returns the server tracer factory. */ @@ -231,6 +223,7 @@ private static final class ClientTracer extends ClientStreamTracer { } private final CensusStatsModule module; + final TagContext parentCtx; private final TagContext startCtx; volatile long outboundMessageCount; @@ -240,11 +233,22 @@ private static final class ClientTracer extends ClientStreamTracer { volatile long outboundUncompressedSize; volatile long inboundUncompressedSize; - ClientTracer(CensusStatsModule module, TagContext startCtx) { + ClientTracer(CensusStatsModule module, TagContext parentCtx, TagContext startCtx) { this.module = checkNotNull(module, "module"); + this.parentCtx = parentCtx; this.startCtx = checkNotNull(startCtx, "startCtx"); } + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + if (module.propagateTags) { + headers.discardAll(module.statsHeader); + if (!module.tagger.empty().equals(parentCtx)) { + headers.put(module.statsHeader, parentCtx); + } + } + } + @Override @SuppressWarnings("NonAtomicVolatileUpdate") public void outboundWireSize(long bytes) { @@ -315,12 +319,14 @@ public void outboundMessage(int seqNo) { } @VisibleForTesting - static final class ClientCallTracer extends ClientStreamTracer.Factory { + static final class CallAttemptsTracerFactory extends + ClientStreamTracer.InternalLimitedInfoFactory { @Nullable - private static final AtomicReferenceFieldUpdater + private static final AtomicReferenceFieldUpdater streamTracerUpdater; - @Nullable private static final AtomicIntegerFieldUpdater callEndedUpdater; + @Nullable + private static final AtomicIntegerFieldUpdater callEndedUpdater; /** * When using Atomic*FieldUpdater, some Samsung Android 5.0.x devices encounter a bug in their @@ -328,14 +334,14 @@ static final class ClientCallTracer extends ClientStreamTracer.Factory { * (potentially racy) direct updates of the volatile variables. */ static { - AtomicReferenceFieldUpdater tmpStreamTracerUpdater; - AtomicIntegerFieldUpdater tmpCallEndedUpdater; + AtomicReferenceFieldUpdater tmpStreamTracerUpdater; + AtomicIntegerFieldUpdater tmpCallEndedUpdater; try { tmpStreamTracerUpdater = AtomicReferenceFieldUpdater.newUpdater( - ClientCallTracer.class, ClientTracer.class, "streamTracer"); + CallAttemptsTracerFactory.class, ClientTracer.class, "streamTracer"); tmpCallEndedUpdater = - AtomicIntegerFieldUpdater.newUpdater(ClientCallTracer.class, "callEnded"); + AtomicIntegerFieldUpdater.newUpdater(CallAttemptsTracerFactory.class, "callEnded"); } catch (Throwable t) { logger.log(Level.SEVERE, "Creating atomic field updaters failed", t); tmpStreamTracerUpdater = null; @@ -352,7 +358,8 @@ static final class ClientCallTracer extends ClientStreamTracer.Factory { private final TagContext parentCtx; private final TagContext startCtx; - ClientCallTracer(CensusStatsModule module, TagContext parentCtx, String fullMethodName) { + CallAttemptsTracerFactory( + CensusStatsModule module, TagContext parentCtx, String fullMethodName) { this.module = checkNotNull(module); this.parentCtx = checkNotNull(parentCtx); TagValue methodTag = TagValue.create(fullMethodName); @@ -370,7 +377,7 @@ static final class ClientCallTracer extends ClientStreamTracer.Factory { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { - ClientTracer tracer = new ClientTracer(module, startCtx); + ClientTracer tracer = new ClientTracer(module, parentCtx, startCtx); // TODO(zhangkun83): Once retry or hedging is implemented, a ClientCall may start more than // one streams. We will need to update this file to support them. if (streamTracerUpdater != null) { @@ -383,12 +390,6 @@ public ClientStreamTracer newClientStreamTracer( "Are you creating multiple streams per call? This class doesn't yet support this case"); streamTracer = tracer; } - if (module.propagateTags) { - headers.discardAll(module.statsHeader); - if (!module.tagger.empty().equals(parentCtx)) { - headers.put(module.statsHeader, parentCtx); - } - } return tracer; } @@ -416,7 +417,7 @@ void callEnded(Status status) { long roundtripNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS); ClientTracer tracer = streamTracer; if (tracer == null) { - tracer = new ClientTracer(module, startCtx); + tracer = new ClientTracer(module, parentCtx, startCtx); } MeasureMap measureMap = module.statsRecorder.newMeasureMap() // TODO(songya): remove the deprecated measure constants once they are completed removed. @@ -686,8 +687,8 @@ public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { // New RPCs on client-side inherit the tag context from the current Context. TagContext parentCtx = tagger.getCurrentTagContext(); - final ClientCallTracer tracerFactory = - newClientCallTracer(parentCtx, method.getFullMethodName()); + final CallAttemptsTracerFactory tracerFactory = new CallAttemptsTracerFactory( + CensusStatsModule.this, parentCtx, method.getFullMethodName()); ClientCall call = next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory)); return new SimpleForwardingClientCall(call) { diff --git a/census/src/main/java/io/grpc/census/CensusTracingModule.java b/census/src/main/java/io/grpc/census/CensusTracingModule.java index fc35d89db55..dac62206fd2 100644 --- a/census/src/main/java/io/grpc/census/CensusTracingModule.java +++ b/census/src/main/java/io/grpc/census/CensusTracingModule.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; @@ -222,7 +223,7 @@ private static void recordMessageEvent( } @VisibleForTesting - final class ClientCallTracer extends ClientStreamTracer.Factory { + final class ClientCallTracer extends ClientStreamTracer.InternalLimitedInfoFactory { volatile int callEnded; private final boolean isSampledToLocalTracing; @@ -243,11 +244,7 @@ final class ClientCallTracer extends ClientStreamTracer.Factory { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { - if (span != BlankSpan.INSTANCE) { - headers.discardAll(tracingHeader); - headers.put(tracingHeader, span.getContext()); - } - return new ClientTracer(span); + return new ClientTracer(span, tracingHeader); } /** @@ -273,9 +270,19 @@ void callEnded(io.grpc.Status status) { private static final class ClientTracer extends ClientStreamTracer { private final Span span; + final Metadata.Key tracingHeader; - ClientTracer(Span span) { + ClientTracer(Span span, Metadata.Key tracingHeader) { this.span = checkNotNull(span, "span"); + this.tracingHeader = tracingHeader; + } + + @Override + public void streamCreated(Attributes transportAtts, Metadata headers) { + if (span != BlankSpan.INSTANCE) { + headers.discardAll(tracingHeader); + headers.put(tracingHeader, span.getContext()); + } } @Override diff --git a/census/src/test/java/io/grpc/census/CensusModulesTest.java b/census/src/test/java/io/grpc/census/CensusModulesTest.java index fbbcd44150c..fd3a049f7a4 100644 --- a/census/src/test/java/io/grpc/census/CensusModulesTest.java +++ b/census/src/test/java/io/grpc/census/CensusModulesTest.java @@ -295,7 +295,7 @@ public ClientCall interceptCall( instanceof CensusTracingModule.ClientCallTracer); assertTrue( capturedCallOptions.get().getStreamTracerFactories().get(1) - instanceof CensusStatsModule.ClientCallTracer); + instanceof CensusStatsModule.CallAttemptsTracerFactory); // Make the call Metadata headers = new Metadata(); @@ -388,11 +388,12 @@ private void subtestClientBasicStatsDefaultContext( new CensusStatsModule( tagger, tagCtxSerializer, statsRecorder, fakeClock.getStopwatchSupplier(), true, recordStarts, recordFinishes, recordRealTime); - CensusStatsModule.ClientCallTracer callTracer = - localCensusStats.newClientCallTracer( - tagger.empty(), method.getFullMethodName()); + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + localCensusStats, tagger.empty(), method.getFullMethodName()); Metadata headers = new Metadata(); - ClientStreamTracer tracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, headers); if (recordStarts) { StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord(); @@ -455,7 +456,7 @@ private void subtestClientBasicStatsDefaultContext( tracer.inboundUncompressedSize(552); tracer.streamClosed(Status.OK); - callTracer.callEnded(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK); if (recordFinishes) { StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord(); @@ -522,6 +523,7 @@ public void clientBasicTracingDefaultSpan() { censusTracing.newClientCallTracer(null, method); Metadata headers = new Metadata(); ClientStreamTracer clientStreamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + clientStreamTracer.streamCreated(Attributes.EMPTY, headers); verify(tracer).spanBuilderWithExplicitParent( eq("Sent.package1.service2.method3"), ArgumentMatchers.isNull()); verify(spyClientSpan, never()).end(any(EndSpanOptions.class)); @@ -575,11 +577,15 @@ public void clientTracingSampledToLocalSpanStore() { @Test public void clientStreamNeverCreatedStillRecordStats() { - CensusStatsModule.ClientCallTracer callTracer = - censusStats.newClientCallTracer(tagger.empty(), method.getFullMethodName()); - + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + censusStats, tagger.empty(), method.getFullMethodName()); + ClientStreamTracer streamTracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); fakeClock.forwardTime(3000, MILLISECONDS); - callTracer.callEnded(Status.DEADLINE_EXCEEDED.withDescription("3 seconds")); + Status status = Status.DEADLINE_EXCEEDED.withDescription("3 seconds"); + streamTracer.streamClosed(status); + callAttemptsTracerFactory.callEnded(status); // Upstart record StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord(); @@ -680,10 +686,13 @@ private void subtestStatsHeadersPropagateTags(boolean propagate, boolean recordS fakeClock.getStopwatchSupplier(), propagate, recordStats, recordStats, recordStats); Metadata headers = new Metadata(); - CensusStatsModule.ClientCallTracer callTracer = - census.newClientCallTracer(clientCtx, method.getFullMethodName()); + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + census, clientCtx, method.getFullMethodName()); // This propagates clientCtx to headers if propagates==true - callTracer.newClientStreamTracer(STREAM_INFO, headers); + ClientStreamTracer streamTracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, headers); + streamTracer.streamCreated(Attributes.EMPTY, headers); if (recordStats) { // Client upstart record StatsTestUtils.MetricsRecord clientRecord = statsRecorder.pollRecord(); @@ -746,7 +755,8 @@ private void subtestStatsHeadersPropagateTags(boolean propagate, boolean recordS // Verifies that the client tracer factory uses clientCtx, which includes the custom tags, to // record stats. - callTracer.callEnded(Status.OK); + streamTracer.streamClosed(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK); if (recordStats) { // Client completion record @@ -769,10 +779,12 @@ private void subtestStatsHeadersPropagateTags(boolean propagate, boolean recordS @Test public void statsHeadersNotPropagateDefaultContext() { - CensusStatsModule.ClientCallTracer callTracer = - censusStats.newClientCallTracer(tagger.empty(), method.getFullMethodName()); + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + censusStats, tagger.empty(), method.getFullMethodName()); Metadata headers = new Metadata(); - callTracer.newClientStreamTracer(STREAM_INFO, headers); + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, headers) + .streamCreated(Attributes.EMPTY, headers); assertFalse(headers.containsKey(censusStats.statsHeader)); // Clear recorded stats to satisfy the assertions in wrapUp() statsRecorder.rolloverRecords(); @@ -803,7 +815,8 @@ public void traceHeadersPropagateSpanContext() throws Exception { CensusTracingModule.ClientCallTracer callTracer = censusTracing.newClientCallTracer(fakeClientParentSpan, method); Metadata headers = new Metadata(); - callTracer.newClientStreamTracer(STREAM_INFO, headers); + ClientStreamTracer streamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + streamTracer.streamCreated(Attributes.EMPTY, headers); verify(mockTracingPropagationHandler).toByteArray(same(fakeClientSpanContext)); verifyNoMoreInteractions(mockTracingPropagationHandler); @@ -831,7 +844,8 @@ public void traceHeaders_propagateSpanContext() throws Exception { censusTracing.newClientCallTracer(fakeClientParentSpan, method); Metadata headers = new Metadata(); - callTracer.newClientStreamTracer(STREAM_INFO, headers); + ClientStreamTracer streamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + streamTracer.streamCreated(Attributes.EMPTY, headers); assertThat(headers.keys()).isNotEmpty(); } @@ -845,7 +859,7 @@ public void traceHeaders_missingCensusImpl_notPropagateSpanContext() CensusTracingModule.ClientCallTracer callTracer = censusTracing.newClientCallTracer(BlankSpan.INSTANCE, method); - callTracer.newClientStreamTracer(STREAM_INFO, headers); + callTracer.newClientStreamTracer(STREAM_INFO, headers).streamCreated(Attributes.EMPTY, headers); assertThat(headers.keys()).isEmpty(); } @@ -862,7 +876,7 @@ public void traceHeaders_clientMissingCensusImpl_preservingHeaders() throws Exce CensusTracingModule.ClientCallTracer callTracer = censusTracing.newClientCallTracer(BlankSpan.INSTANCE, method); - callTracer.newClientStreamTracer(STREAM_INFO, headers); + callTracer.newClientStreamTracer(STREAM_INFO, headers).streamCreated(Attributes.EMPTY, headers); assertThat(headers.keys()).containsExactlyElementsIn(originalHeaderKeys); } @@ -1186,13 +1200,18 @@ public void newTagsPopulateOldViews() throws InterruptedException { tagger, tagCtxSerializer, localStats.getStatsRecorder(), fakeClock.getStopwatchSupplier(), false, false, true, false /* real-time */); - CensusStatsModule.ClientCallTracer callTracer = - localCensusStats.newClientCallTracer( - tagger.empty(), method.getFullMethodName()); + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + localCensusStats, tagger.empty(), method.getFullMethodName()); - callTracer.newClientStreamTracer(STREAM_INFO, new Metadata()); + Metadata headers = new Metadata(); + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, headers); + tracer.streamCreated(Attributes.EMPTY, headers); fakeClock.forwardTime(30, MILLISECONDS); - callTracer.callEnded(Status.PERMISSION_DENIED.withDescription("No you don't")); + Status status = Status.PERMISSION_DENIED.withDescription("No you don't"); + tracer.streamClosed(status); + callAttemptsTracerFactory.callEnded(status); // Give OpenCensus a chance to update the views asynchronously. Thread.sleep(100); diff --git a/core/src/jmh/java/io/grpc/internal/StatsTraceContextBenchmark.java b/core/src/jmh/java/io/grpc/internal/StatsTraceContextBenchmark.java index aec2659f024..4d4349eef1b 100644 --- a/core/src/jmh/java/io/grpc/internal/StatsTraceContextBenchmark.java +++ b/core/src/jmh/java/io/grpc/internal/StatsTraceContextBenchmark.java @@ -17,7 +17,7 @@ package io.grpc.internal; import io.grpc.Attributes; -import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.ServerStreamTracer; @@ -50,7 +50,8 @@ public class StatsTraceContextBenchmark { @BenchmarkMode(Mode.SampleTime) @OutputTimeUnit(TimeUnit.NANOSECONDS) public StatsTraceContext newClientContext() { - return StatsTraceContext.newClientContext(CallOptions.DEFAULT, Attributes.EMPTY, emptyMetadata); + return StatsTraceContext.newClientContext( + new ClientStreamTracer[] { new ClientStreamTracer() {} }, Attributes.EMPTY, emptyMetadata); } /** diff --git a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java index 58df4371e72..895b709559b 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java @@ -26,6 +26,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Compressor; import io.grpc.Deadline; import io.grpc.Decompressor; @@ -205,10 +206,12 @@ public void run() { @Override public synchronized ClientStream newStream( - final MethodDescriptor method, final Metadata headers, final CallOptions callOptions) { + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(tracers, getAttributes(), headers); if (shutdownStatus != null) { - return failedClientStream( - StatsTraceContext.newClientContext(callOptions, attributes, headers), shutdownStatus); + return failedClientStream(statsTraceContext, shutdownStatus); } headers.put(GrpcUtil.USER_AGENT_KEY, userAgent); @@ -226,12 +229,12 @@ public synchronized ClientStream newStream( "Request metadata larger than %d: %d", serverMaxInboundMetadataSize, metadataSize)); - return failedClientStream( - StatsTraceContext.newClientContext(callOptions, attributes, headers), status); + return failedClientStream(statsTraceContext, status); } } - return new InProcessStream(method, headers, callOptions, authority).clientStream; + return new InProcessStream(method, headers, callOptions, authority, statsTraceContext) + .clientStream; } private ClientStream failedClientStream( @@ -377,12 +380,12 @@ private class InProcessStream { private InProcessStream( MethodDescriptor method, Metadata headers, CallOptions callOptions, - String authority) { + String authority , StatsTraceContext statsTraceContext) { this.method = checkNotNull(method, "method"); this.headers = checkNotNull(headers, "headers"); this.callOptions = checkNotNull(callOptions, "callOptions"); this.authority = authority; - this.clientStream = new InProcessClientStream(callOptions, headers); + this.clientStream = new InProcessClientStream(callOptions, statsTraceContext); this.serverStream = new InProcessServerStream(method, headers); } @@ -673,9 +676,10 @@ private class InProcessClientStream implements ClientStream { @GuardedBy("this") private int outboundSeqNo; - InProcessClientStream(CallOptions callOptions, Metadata headers) { + InProcessClientStream( + CallOptions callOptions, StatsTraceContext statsTraceContext) { this.callOptions = callOptions; - statsTraceCtx = StatsTraceContext.newClientContext(callOptions, attributes, headers); + statsTraceCtx = statsTraceContext; } private synchronized void setListener(ServerStreamListener listener) { diff --git a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java index 0b1ce3514a2..6b6472825d2 100644 --- a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java @@ -25,6 +25,7 @@ import io.grpc.CallOptions; import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.CompositeCallCredentials; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -104,7 +105,8 @@ protected ConnectionClientTransport delegate() { @Override @SuppressWarnings("deprecation") public ClientStream newStream( - final MethodDescriptor method, Metadata headers, final CallOptions callOptions) { + final MethodDescriptor method, Metadata headers, final CallOptions callOptions, + ClientStreamTracer[] tracers) { CallCredentials creds = callOptions.getCredentials(); if (creds == null) { creds = channelCallCredentials; @@ -113,10 +115,10 @@ public ClientStream newStream( } if (creds != null) { MetadataApplierImpl applier = new MetadataApplierImpl( - delegate, method, headers, callOptions, applierListener); + delegate, method, headers, callOptions, applierListener, tracers); if (pendingApplier.incrementAndGet() > 0) { applierListener.onComplete(); - return new FailingClientStream(shutdownStatus); + return new FailingClientStream(shutdownStatus, tracers); } RequestInfo requestInfo = new RequestInfo() { @Override @@ -152,9 +154,9 @@ public Attributes getTransportAttrs() { return applier.returnStream(); } else { if (pendingApplier.get() >= 0) { - return new FailingClientStream(shutdownStatus); + return new FailingClientStream(shutdownStatus, tracers); } - return delegate.newStream(method, headers, callOptions); + return delegate.newStream(method, headers, callOptions, tracers); } } diff --git a/core/src/main/java/io/grpc/internal/ClientCallImpl.java b/core/src/main/java/io/grpc/internal/ClientCallImpl.java index c2e1bd2b1f2..28cd3351203 100644 --- a/core/src/main/java/io/grpc/internal/ClientCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ClientCallImpl.java @@ -33,6 +33,7 @@ import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientCall; +import io.grpc.ClientStreamTracer; import io.grpc.Codec; import io.grpc.Compressor; import io.grpc.CompressorRegistry; @@ -254,9 +255,11 @@ public void runInContext() { effectiveDeadline, context.getDeadline(), callOptions.getDeadline()); stream = clientStreamProvider.newStream(method, callOptions, headers, context); } else { + ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers(callOptions, headers, false); stream = new FailingClientStream( DEADLINE_EXCEEDED.withDescription( - "ClientCall started after deadline exceeded: " + effectiveDeadline)); + "ClientCall started after deadline exceeded: " + effectiveDeadline), + tracers); } if (callExecutorIsDirect) { diff --git a/core/src/main/java/io/grpc/internal/ClientTransport.java b/core/src/main/java/io/grpc/internal/ClientTransport.java index cc8471ab6a3..a569a7922df 100644 --- a/core/src/main/java/io/grpc/internal/ClientTransport.java +++ b/core/src/main/java/io/grpc/internal/ClientTransport.java @@ -17,6 +17,7 @@ package io.grpc.internal; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalInstrumented; import io.grpc.Metadata; @@ -46,10 +47,15 @@ public interface ClientTransport extends InternalInstrumented { * @param method the descriptor of the remote method to be called for this stream. * @param headers to send at the beginning of the call * @param callOptions runtime options of the call + * @param tracers a non-empty array of tracers. The last element in it is reserved to be set by + * the load balancer's pick result and otherwise is a no-op tracer. * @return the newly created stream. */ // TODO(nmittler): Consider also throwing for stopping. - ClientStream newStream(MethodDescriptor method, Metadata headers, CallOptions callOptions); + ClientStream newStream( + MethodDescriptor method, Metadata headers, CallOptions callOptions, + // Using array for tracers instead of a list or composition for better performance. + ClientStreamTracer[] tracers); /** * Pings a remote endpoint. When an acknowledgement is received, the given callback will be diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java index 6a72eb7c21e..2b1145d1c4b 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java @@ -20,6 +20,7 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Context; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; @@ -133,7 +134,8 @@ public void run() { */ @Override public final ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { try { PickSubchannelArgs args = new PickSubchannelArgsImpl(method, headers, callOptions); SubchannelPicker picker = null; @@ -141,14 +143,14 @@ public final ClientStream newStream( while (true) { synchronized (lock) { if (shutdownStatus != null) { - return new FailingClientStream(shutdownStatus); + return new FailingClientStream(shutdownStatus, tracers); } if (lastPicker == null) { - return createPendingStream(args); + return createPendingStream(args, tracers); } // Check for second time through the loop, and whether anything changed if (picker != null && pickerVersion == lastPickerVersion) { - return createPendingStream(args); + return createPendingStream(args, tracers); } picker = lastPicker; pickerVersion = lastPickerVersion; @@ -158,7 +160,8 @@ public final ClientStream newStream( callOptions.isWaitForReady()); if (transport != null) { return transport.newStream( - args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions()); + args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions(), + tracers); } // This picker's conclusion is "buffer". If there hasn't been a newer picker set (possible // race with reprocess()), we will buffer it. Otherwise, will try with the new picker. @@ -173,8 +176,9 @@ public final ClientStream newStream( * schedule tasks on syncContext. */ @GuardedBy("lock") - private PendingStream createPendingStream(PickSubchannelArgs args) { - PendingStream pendingStream = new PendingStream(args); + private PendingStream createPendingStream( + PickSubchannelArgs args, ClientStreamTracer[] tracers) { + PendingStream pendingStream = new PendingStream(args, tracers); pendingStreams.add(pendingStream); if (getPendingStreamsCount() == 1) { syncContext.executeLater(reportTransportInUse); @@ -239,7 +243,8 @@ public final void shutdownNow(Status status) { } if (savedReportTransportTerminated != null) { for (PendingStream stream : savedPendingStreams) { - Runnable runnable = stream.setStream(new FailingClientStream(status, RpcProgress.REFUSED)); + Runnable runnable = stream.setStream( + new FailingClientStream(status, RpcProgress.REFUSED, stream.tracers)); if (runnable != null) { // Drain in-line instead of using an executor as failing stream just throws everything // away. This is essentially the same behavior as DelayedStream.cancel() but can be done @@ -346,9 +351,11 @@ public InternalLogId getLogId() { private class PendingStream extends DelayedStream { private final PickSubchannelArgs args; private final Context context = Context.current(); + private final ClientStreamTracer[] tracers; - private PendingStream(PickSubchannelArgs args) { + private PendingStream(PickSubchannelArgs args, ClientStreamTracer[] tracers) { this.args = args; + this.tracers = tracers; } /** Runnable may be null. */ @@ -357,7 +364,8 @@ private Runnable createRealStream(ClientTransport transport) { Context origContext = context.attach(); try { realStream = transport.newStream( - args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions()); + args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions(), + tracers); } finally { context.detach(origContext); } @@ -382,6 +390,13 @@ public void cancel(Status reason) { syncContext.drain(); } + @Override + protected void onEarlyCancellation(Status reason) { + for (ClientStreamTracer tracer : tracers) { + tracer.streamClosed(reason); + } + } + @Override public void appendTimeoutInsight(InsightBuilder insight) { if (args.getCallOptions().isWaitForReady()) { diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index f0a378e8124..28ce2764c75 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -324,11 +324,15 @@ public void run() { }); } else { drainPendingCalls(); + onEarlyCancellation(reason); // Note that listener is a DelayedStreamListener listener.closed(reason, RpcProgress.PROCESSED, new Metadata()); } } + protected void onEarlyCancellation(Status reason) { + } + @GuardedBy("this") private void setRealStream(ClientStream realStream) { checkState(this.realStream == null, "realStream already set to %s", this.realStream); diff --git a/core/src/main/java/io/grpc/internal/FailingClientStream.java b/core/src/main/java/io/grpc/internal/FailingClientStream.java index 6d368b6975f..6388ef8b6ee 100644 --- a/core/src/main/java/io/grpc/internal/FailingClientStream.java +++ b/core/src/main/java/io/grpc/internal/FailingClientStream.java @@ -18,6 +18,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.internal.ClientStreamListener.RpcProgress; @@ -30,27 +31,33 @@ public final class FailingClientStream extends NoopClientStream { private boolean started; private final Status error; private final RpcProgress rpcProgress; + private final ClientStreamTracer[] tracers; /** * Creates a {@code FailingClientStream} that would fail with the given error. */ - public FailingClientStream(Status error) { - this(error, RpcProgress.PROCESSED); + public FailingClientStream(Status error, ClientStreamTracer[] tracers) { + this(error, RpcProgress.PROCESSED, tracers); } /** * Creates a {@code FailingClientStream} that would fail with the given error. */ - public FailingClientStream(Status error, RpcProgress rpcProgress) { + public FailingClientStream( + Status error, RpcProgress rpcProgress, ClientStreamTracer[] tracers) { Preconditions.checkArgument(!error.isOk(), "error must not be OK"); this.error = error; this.rpcProgress = rpcProgress; + this.tracers = tracers; } @Override public void start(ClientStreamListener listener) { Preconditions.checkState(!started, "already started"); started = true; + for (ClientStreamTracer tracer : tracers) { + tracer.streamClosed(error); + } listener.closed(error, rpcProgress, new Metadata()); } diff --git a/core/src/main/java/io/grpc/internal/FailingClientTransport.java b/core/src/main/java/io/grpc/internal/FailingClientTransport.java index 25d20017c92..5b31e6e5073 100644 --- a/core/src/main/java/io/grpc/internal/FailingClientTransport.java +++ b/core/src/main/java/io/grpc/internal/FailingClientTransport.java @@ -21,6 +21,7 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.Metadata; @@ -45,8 +46,9 @@ class FailingClientTransport implements ClientTransport { @Override public ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { - return new FailingClientStream(error, rpcProgress); + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { + return new FailingClientStream(error, rpcProgress, tracers); } @Override diff --git a/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java b/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java new file mode 100644 index 00000000000..fd03564d396 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java @@ -0,0 +1,101 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import com.google.common.base.MoreObjects; +import io.grpc.Attributes; +import io.grpc.ClientStreamTracer; +import io.grpc.Metadata; +import io.grpc.Status; + +public abstract class ForwardingClientStreamTracer extends ClientStreamTracer { + + /** + * Returns the underlying {@code ClientStreamTracer}. + */ + protected abstract ClientStreamTracer delegate(); + + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + delegate().streamCreated(transportAttrs, headers); + } + + @Override + public void outboundHeaders() { + delegate().outboundHeaders(); + } + + @Override + public void inboundHeaders() { + delegate().inboundHeaders(); + } + + @Override + public void inboundTrailers(Metadata trailers) { + delegate().inboundTrailers(trailers); + } + + @Override + public void streamClosed(Status status) { + delegate().streamClosed(status); + } + + @Override + public void outboundMessage(int seqNo) { + delegate().outboundMessage(seqNo); + } + + @Override + public void inboundMessage(int seqNo) { + delegate().inboundMessage(seqNo); + } + + @Override + public void outboundMessageSent(int seqNo, long optionalWireSize, long optionalUncompressedSize) { + delegate().outboundMessageSent(seqNo, optionalWireSize, optionalUncompressedSize); + } + + @Override + public void inboundMessageRead(int seqNo, long optionalWireSize, long optionalUncompressedSize) { + delegate().inboundMessageRead(seqNo, optionalWireSize, optionalUncompressedSize); + } + + @Override + public void outboundWireSize(long bytes) { + delegate().outboundWireSize(bytes); + } + + @Override + public void outboundUncompressedSize(long bytes) { + delegate().outboundUncompressedSize(bytes); + } + + @Override + public void inboundWireSize(long bytes) { + delegate().inboundWireSize(bytes); + } + + @Override + public void inboundUncompressedSize(long bytes) { + delegate().inboundUncompressedSize(bytes); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString(); + } +} diff --git a/core/src/main/java/io/grpc/internal/ForwardingConnectionClientTransport.java b/core/src/main/java/io/grpc/internal/ForwardingConnectionClientTransport.java index e54f8b169d6..bfdccbe5d6a 100644 --- a/core/src/main/java/io/grpc/internal/ForwardingConnectionClientTransport.java +++ b/core/src/main/java/io/grpc/internal/ForwardingConnectionClientTransport.java @@ -20,6 +20,7 @@ import com.google.common.util.concurrent.ListenableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.Metadata; @@ -45,8 +46,9 @@ public void shutdownNow(Status status) { @Override public ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { - return delegate().newStream(method, headers, callOptions); + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { + return delegate().newStream(method, headers, callOptions, tracers); } @Override diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index 45c0fce7122..782ae21d3b3 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -17,6 +17,7 @@ package io.grpc.internal; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Objects; @@ -26,8 +27,11 @@ import com.google.common.base.Supplier; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.InternalLimitedInfoFactory; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.InternalMetadata; @@ -54,12 +58,14 @@ import java.net.URISyntaxException; import java.nio.charset.Charset; import java.util.Collection; +import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -253,6 +259,8 @@ public ProxiedSocketAddress proxyFor(SocketAddress targetServerAddress) { public static final CallOptions.Key CALL_OPTIONS_RPC_OWNED_BY_BALANCER = CallOptions.Key.create("io.grpc.internal.CALL_OPTIONS_RPC_OWNED_BY_BALANCER"); + private static final ClientStreamTracer NOOP_TRACER = new ClientStreamTracer() {}; + /** * Returns true if an RPC with the given properties should be counted when calculating the * in-use state of a transport. @@ -711,9 +719,14 @@ static ClientTransport getTransportFromPickResult(PickResult result, boolean isW return new ClientTransport() { @Override public ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { - return transport.newStream( - method, headers, callOptions.withStreamTracerFactory(streamTracerFactory)); + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { + StreamInfo info = StreamInfo.newBuilder().setCallOptions(callOptions).build(); + ClientStreamTracer streamTracer = + newClientStreamTracer(streamTracerFactory, info, headers); + checkState(tracers[tracers.length - 1] == NOOP_TRACER, "lb tracer already assigned"); + tracers[tracers.length - 1] = streamTracer; + return transport.newStream(method, headers, callOptions, tracers); } @Override @@ -743,6 +756,64 @@ public ListenableFuture getStats() { return null; } + /** Gets stream tracers based on CallOptions. */ + public static ClientStreamTracer[] getClientStreamTracers( + CallOptions callOptions, Metadata headers, boolean isTransparentRetry) { + List factories = callOptions.getStreamTracerFactories(); + ClientStreamTracer[] tracers = new ClientStreamTracer[factories.size() + 1]; + StreamInfo streamInfo = StreamInfo.newBuilder() + .setCallOptions(callOptions) + .setIsTransparentRetry(isTransparentRetry) + .build(); + for (int i = 0; i < factories.size(); i++) { + tracers[i] = newClientStreamTracer(factories.get(i), streamInfo, headers); + } + // Reserved to be set later by the lb as per the API contract of ClientTransport.newStream(). + // See also GrpcUtil.getTransportFromPickResult() + tracers[tracers.length - 1] = NOOP_TRACER; + return tracers; + } + + // A util function for backward compatibility to support deprecated StreamInfo.getAttributes(). + @VisibleForTesting + static ClientStreamTracer newClientStreamTracer( + final ClientStreamTracer.Factory streamTracerFactory, final StreamInfo info, + final Metadata headers) { + ClientStreamTracer streamTracer; + if (streamTracerFactory instanceof InternalLimitedInfoFactory) { + streamTracer = streamTracerFactory.newClientStreamTracer(info, headers); + } else { + streamTracer = new ForwardingClientStreamTracer() { + final ClientStreamTracer noop = new ClientStreamTracer() {}; + AtomicReference delegate = new AtomicReference<>(noop); + + void maybeInit(StreamInfo info, Metadata headers) { + delegate.compareAndSet(noop, streamTracerFactory.newClientStreamTracer(info, headers)); + } + + @Override + protected ClientStreamTracer delegate() { + return delegate.get(); + } + + @SuppressWarnings("deprecation") + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + StreamInfo streamInfo = info.toBuilder().setTransportAttrs(transportAttrs).build(); + maybeInit(streamInfo, headers); + delegate().streamCreated(transportAttrs, headers); + } + + @Override + public void streamClosed(Status status) { + maybeInit(info, headers); + delegate().streamClosed(status); + } + }; + } + return streamTracer; + } + /** Quietly closes all messages in MessageProducer. */ static void closeQuietly(MessageProducer producer) { InputStream message; diff --git a/core/src/main/java/io/grpc/internal/InternalSubchannel.java b/core/src/main/java/io/grpc/internal/InternalSubchannel.java index 331add6c8c4..fa2bf2e46bc 100644 --- a/core/src/main/java/io/grpc/internal/InternalSubchannel.java +++ b/core/src/main/java/io/grpc/internal/InternalSubchannel.java @@ -34,6 +34,7 @@ import io.grpc.CallOptions; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; +import io.grpc.ClientStreamTracer; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; @@ -667,8 +668,9 @@ protected ConnectionClientTransport delegate() { @Override public ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { - final ClientStream streamDelegate = super.newStream(method, headers, callOptions); + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { + final ClientStream streamDelegate = super.newStream(method, headers, callOptions, tracers); return new ForwardingClientStream() { @Override protected ClientStream delegate() { diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index a9d24cd247a..87162d9aba2 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -532,8 +532,10 @@ public ClientStream newStream( ClientTransport transport = getTransport(new PickSubchannelArgsImpl(method, headers, callOptions)); Context origContext = context.attach(); + ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( + callOptions, headers, /* isTransparentRetry= */ false); try { - return transport.newStream(method, headers, callOptions); + return transport.newStream(method, headers, callOptions, tracers); } finally { context.detach(origContext); } @@ -569,13 +571,16 @@ void postCommit() { } @Override - ClientStream newSubstream(ClientStreamTracer.Factory tracerFactory, Metadata newHeaders) { - CallOptions newOptions = callOptions.withStreamTracerFactory(tracerFactory); + ClientStream newSubstream( + Metadata newHeaders, ClientStreamTracer.Factory factory, boolean isTransparentRetry) { + CallOptions newOptions = callOptions.withStreamTracerFactory(factory); + ClientStreamTracer[] tracers = + GrpcUtil.getClientStreamTracers(newOptions, newHeaders, isTransparentRetry); ClientTransport transport = getTransport(new PickSubchannelArgsImpl(method, newHeaders, newOptions)); Context origContext = context.attach(); try { - return transport.newStream(method, newHeaders, newOptions); + return transport.newStream(method, newHeaders, newOptions, tracers); } finally { context.detach(origContext); } diff --git a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java index 76d280b2d00..6893713c1d2 100644 --- a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java +++ b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java @@ -22,6 +22,7 @@ import io.grpc.CallCredentials.MetadataApplier; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Context; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -36,7 +37,7 @@ final class MetadataApplierImpl extends MetadataApplier { private final CallOptions callOptions; private final Context ctx; private final MetadataApplierListener listener; - + private final ClientStreamTracer[] tracers; private final Object lock = new Object(); // null if neither apply() or returnStream() are called. @@ -52,13 +53,14 @@ final class MetadataApplierImpl extends MetadataApplier { MetadataApplierImpl( ClientTransport transport, MethodDescriptor method, Metadata origHeaders, - CallOptions callOptions, MetadataApplierListener listener) { + CallOptions callOptions, MetadataApplierListener listener, ClientStreamTracer[] tracers) { this.transport = transport; this.method = method; this.origHeaders = origHeaders; this.callOptions = callOptions; this.ctx = Context.current(); this.listener = listener; + this.tracers = tracers; } @Override @@ -69,7 +71,7 @@ public void apply(Metadata headers) { ClientStream realStream; Context origCtx = ctx.attach(); try { - realStream = transport.newStream(method, origHeaders, callOptions); + realStream = transport.newStream(method, origHeaders, callOptions, tracers); } finally { ctx.detach(origCtx); } @@ -80,7 +82,7 @@ public void apply(Metadata headers) { public void fail(Status status) { checkArgument(!status.isOk(), "Cannot fail with OK status"); checkState(!finalized, "apply() or fail() already called"); - finalizeWith(new FailingClientStream(status)); + finalizeWith(new FailingClientStream(status, tracers)); } private void finalizeWith(ClientStream stream) { diff --git a/core/src/main/java/io/grpc/internal/OobChannel.java b/core/src/main/java/io/grpc/internal/OobChannel.java index f69fd17e5c4..b628842efe4 100644 --- a/core/src/main/java/io/grpc/internal/OobChannel.java +++ b/core/src/main/java/io/grpc/internal/OobChannel.java @@ -26,6 +26,7 @@ import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientCall; +import io.grpc.ClientStreamTracer; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.Context; @@ -86,12 +87,14 @@ final class OobChannel extends ManagedChannel implements InternalInstrumented method, CallOptions callOptions, Metadata headers, Context context) { + ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( + callOptions, headers, /* isTransparentRetry= */ false); Context origContext = context.attach(); // delayed transport's newStream() always acquires a lock, but concurrent performance doesn't // matter here because OOB communication should be sparse, and it's not on application RPC's // critical path. try { - return delayedTransport.newStream(method, headers, callOptions); + return delayedTransport.newStream(method, headers, callOptions, tracers); } finally { context.detach(origContext); } diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index 9d752b86576..23725788466 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -203,11 +203,11 @@ private void commitAndRun(Substream winningSubstream) { } } - private Substream createSubstream(int previousAttemptCount) { + private Substream createSubstream(int previousAttemptCount, boolean isTransparentRetry) { Substream sub = new Substream(previousAttemptCount); // one tracer per substream final ClientStreamTracer bufferSizeTracer = new BufferSizeTracer(sub); - ClientStreamTracer.Factory tracerFactory = new ClientStreamTracer.Factory() { + ClientStreamTracer.Factory tracerFactory = new ClientStreamTracer.InternalLimitedInfoFactory() { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { @@ -217,7 +217,7 @@ public ClientStreamTracer newClientStreamTracer( Metadata newHeaders = updateHeaders(headers, previousAttemptCount); // NOTICE: This set _must_ be done before stream.start() and it actually is. - sub.stream = newSubstream(tracerFactory, newHeaders); + sub.stream = newSubstream(newHeaders, tracerFactory, isTransparentRetry); return sub; } @@ -226,7 +226,7 @@ public ClientStreamTracer newClientStreamTracer( * Client stream is not yet started. */ abstract ClientStream newSubstream( - ClientStreamTracer.Factory tracerFactory, Metadata headers); + Metadata headers, ClientStreamTracer.Factory tracerFactory, boolean isTransparentRetry); /** Adds grpc-previous-rpc-attempts in the headers of a retry/hedging RPC. */ @VisibleForTesting @@ -322,7 +322,7 @@ public void runWith(Substream substream) { state.buffer.add(new StartEntry()); } - Substream substream = createSubstream(0); + Substream substream = createSubstream(0, false); if (isHedging) { FutureCanceller scheduledHedgingRef = null; @@ -399,7 +399,7 @@ public void run() { // If this run is not cancelled, the value of state.hedgingAttemptCount won't change // until state.addActiveHedge() is called subsequently, even the state could possibly // change. - Substream newSubstream = createSubstream(state.hedgingAttemptCount); + Substream newSubstream = createSubstream(state.hedgingAttemptCount, false); boolean cancelled = false; FutureCanceller future = null; @@ -784,8 +784,7 @@ public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { if (rpcProgress == RpcProgress.REFUSED && noMoreTransparentRetry.compareAndSet(false, true)) { // transparent retry - final Substream newSubstream = createSubstream( - substream.previousAttemptCount); + final Substream newSubstream = createSubstream(substream.previousAttemptCount, true); if (isHedging) { boolean commit = false; synchronized (lock) { @@ -863,8 +862,9 @@ public void run() { @Override public void run() { // retry - Substream newSubstream = - createSubstream(substream.previousAttemptCount + 1); + Substream newSubstream = createSubstream( + substream.previousAttemptCount + 1, + false); drain(newSubstream); } }); diff --git a/core/src/main/java/io/grpc/internal/StatsTraceContext.java b/core/src/main/java/io/grpc/internal/StatsTraceContext.java index adb0b63ec8a..33e84e5a0b8 100644 --- a/core/src/main/java/io/grpc/internal/StatsTraceContext.java +++ b/core/src/main/java/io/grpc/internal/StatsTraceContext.java @@ -20,7 +20,6 @@ import com.google.common.annotations.VisibleForTesting; import io.grpc.Attributes; -import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; import io.grpc.Context; import io.grpc.Metadata; @@ -48,21 +47,12 @@ public final class StatsTraceContext { * Factory method for the client-side. */ public static StatsTraceContext newClientContext( - final CallOptions callOptions, final Attributes transportAttrs, Metadata headers) { - List factories = callOptions.getStreamTracerFactories(); - if (factories.isEmpty()) { - return NOOP; + ClientStreamTracer[] tracers, Attributes transportAtts, Metadata headers) { + StatsTraceContext ctx = new StatsTraceContext(tracers); + for (ClientStreamTracer tracer : tracers) { + tracer.streamCreated(transportAtts, headers); } - ClientStreamTracer.StreamInfo info = - ClientStreamTracer.StreamInfo.newBuilder() - .setTransportAttrs(transportAttrs).setCallOptions(callOptions).build(); - // This array will be iterated multiple times per RPC. Use primitive array instead of Collection - // so that for-each doesn't create an Iterator every time. - StreamTracer[] tracers = new StreamTracer[factories.size()]; - for (int i = 0; i < tracers.length; i++) { - tracers[i] = factories.get(i).newClientStreamTracer(info, headers); - } - return new StatsTraceContext(tracers); + return ctx; } /** diff --git a/core/src/main/java/io/grpc/internal/SubchannelChannel.java b/core/src/main/java/io/grpc/internal/SubchannelChannel.java index 6c316e4f185..1380a6bc716 100644 --- a/core/src/main/java/io/grpc/internal/SubchannelChannel.java +++ b/core/src/main/java/io/grpc/internal/SubchannelChannel.java @@ -22,6 +22,7 @@ import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; +import io.grpc.ClientStreamTracer; import io.grpc.Context; import io.grpc.InternalConfigSelector; import io.grpc.Metadata; @@ -57,9 +58,11 @@ public ClientStream newStream(MethodDescriptor method, if (transport == null) { transport = notReadyTransport; } + ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( + callOptions, headers, /* isTransparentRetry= */ false); Context origContext = context.attach(); try { - return transport.newStream(method, headers, callOptions); + return transport.newStream(method, headers, callOptions, tracers); } finally { context.detach(origContext); } diff --git a/core/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java b/core/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java index de7d12e397c..7bb9d8cf71a 100644 --- a/core/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java +++ b/core/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java @@ -17,6 +17,7 @@ package io.grpc.util; import com.google.common.base.MoreObjects; +import io.grpc.Attributes; import io.grpc.ClientStreamTracer; import io.grpc.ExperimentalApi; import io.grpc.Metadata; @@ -27,6 +28,11 @@ public abstract class ForwardingClientStreamTracer extends ClientStreamTracer { /** Returns the underlying {@code ClientStreamTracer}. */ protected abstract ClientStreamTracer delegate(); + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + delegate().streamCreated(transportAttrs, headers); + } + @Override public void outboundHeaders() { delegate().outboundHeaders(); diff --git a/core/src/test/java/io/grpc/ClientStreamTracerTest.java b/core/src/test/java/io/grpc/ClientStreamTracerTest.java index 2008a3de5c7..df450adc630 100644 --- a/core/src/test/java/io/grpc/ClientStreamTracerTest.java +++ b/core/src/test/java/io/grpc/ClientStreamTracerTest.java @@ -34,6 +34,7 @@ public class ClientStreamTracerTest { Attributes.newBuilder().set(TRANSPORT_ATTR_KEY, "value").build(); @Test + @SuppressWarnings("deprecation") // info.getTransportAttrs() public void streamInfo_empty() { StreamInfo info = StreamInfo.newBuilder().build(); assertThat(info.getCallOptions()).isSameInstanceAs(CallOptions.DEFAULT); @@ -41,6 +42,7 @@ public void streamInfo_empty() { } @Test + @SuppressWarnings("deprecation") // info.getTransportAttrs() public void streamInfo_withInfo() { StreamInfo info = StreamInfo.newBuilder() .setCallOptions(callOptions).setTransportAttrs(transportAttrs).build(); @@ -49,6 +51,7 @@ public void streamInfo_withInfo() { } @Test + @SuppressWarnings("deprecation") // info.setTransportAttrs() public void streamInfo_noEquality() { StreamInfo info1 = StreamInfo.newBuilder() .setCallOptions(callOptions).setTransportAttrs(transportAttrs).build(); @@ -60,6 +63,7 @@ public void streamInfo_noEquality() { } @Test + @SuppressWarnings("deprecation") // info.getTransportAttrs() public void streamInfo_toBuilder() { StreamInfo info1 = StreamInfo.newBuilder() .setCallOptions(callOptions).setTransportAttrs(transportAttrs).build(); diff --git a/core/src/test/java/io/grpc/internal/AbstractTransportTest.java b/core/src/test/java/io/grpc/internal/AbstractTransportTest.java index 091415efadc..cd522181311 100644 --- a/core/src/test/java/io/grpc/internal/AbstractTransportTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractTransportTest.java @@ -48,7 +48,6 @@ import io.grpc.CallOptions; import io.grpc.ChannelLogger; import io.grpc.ClientStreamTracer; -import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.Grpc; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalChannelz.TransportStats; @@ -172,7 +171,7 @@ public void log(ChannelLogLevel level, String messageFormat, Object... args) {} .setRequestMarshaller(StringMarshaller.INSTANCE) .setResponseMarshaller(StringMarshaller.INSTANCE) .build(); - private CallOptions callOptions; + private final CallOptions callOptions = CallOptions.DEFAULT; private Metadata.Key asciiKey = Metadata.Key.of( "ascii-key", Metadata.ASCII_STRING_MARSHALLER); @@ -186,24 +185,14 @@ public void log(ChannelLogLevel level, String messageFormat, Object... args) {} = mock(ManagedClientTransport.Listener.class); private MockServerListener serverListener = new MockServerListener(); private ArgumentCaptor throwableCaptor = ArgumentCaptor.forClass(Throwable.class); - private final TestClientStreamTracer clientStreamTracer1 = new TestClientStreamTracer(); - private final TestClientStreamTracer clientStreamTracer2 = new TestClientStreamTracer(); - private final ClientStreamTracer.Factory clientStreamTracerFactory = mock( - ClientStreamTracer.Factory.class, - delegatesTo(new ClientStreamTracer.Factory() { - final ArrayDeque tracers = - new ArrayDeque<>(Arrays.asList(clientStreamTracer1, clientStreamTracer2)); - - @Override - public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata metadata) { - metadata.put(tracerHeaderKey, tracerKeyValue); - TestClientStreamTracer tracer = tracers.poll(); - if (tracer != null) { - return tracer; - } - return new TestClientStreamTracer(); - } - })); + private final TestClientStreamTracer clientStreamTracer1 = new TestHeaderClientStreamTracer(); + private final TestClientStreamTracer clientStreamTracer2 = new TestHeaderClientStreamTracer(); + private final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + clientStreamTracer1, clientStreamTracer2 + }; + private final ClientStreamTracer[] noopTracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; private final TestServerStreamTracer serverStreamTracer1 = new TestServerStreamTracer(); private final TestServerStreamTracer serverStreamTracer2 = new TestServerStreamTracer(); @@ -230,7 +219,6 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata @Before public void setUp() { server = newServer(Arrays.asList(serverStreamTracerFactory)); - callOptions = CallOptions.DEFAULT.withStreamTracerFactory(clientStreamTracerFactory); } @After @@ -291,7 +279,8 @@ public void frameAfterRstStreamShouldNotBreakClientChannel() throws Exception { // after having sent a RST_STREAM to the server. Previously, this would have broken the // Netty channel. - ClientStream stream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -314,7 +303,8 @@ public void frameAfterRstStreamShouldNotBreakClientChannel() throws Exception { // Test that the channel is still usable i.e. we can receive headers from the server on a // new stream. - stream = client.newStream(methodDescriptor, new Metadata(), callOptions); + stream = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); stream.start(mockClientStreamListener2); serverStreamCreation = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); @@ -449,7 +439,8 @@ public void openStreamPreventsTermination() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); @@ -501,7 +492,8 @@ public void shutdownNowKillsClientStream() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); @@ -539,7 +531,8 @@ public void shutdownNowKillsServerStream() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); @@ -594,7 +587,8 @@ public void ping_duringShutdown() throws Exception { client = newClientTransport(server); startTransport(client, mockClientTransportListener); // Stream prevents termination - ClientStream stream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); client.shutdown(Status.UNAVAILABLE); @@ -633,22 +627,19 @@ public void ping_afterTermination() throws Exception { @Test public void newStream_duringShutdown() throws Exception { - InOrder inOrder = inOrder(clientStreamTracerFactory); server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); // Stream prevents termination - ClientStream stream = client.newStream(methodDescriptor, new Metadata(), callOptions); - inOrder.verify(clientStreamTracerFactory).newClientStreamTracer( - any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)); + ClientStream stream = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); client.shutdown(Status.UNAVAILABLE); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class)); - ClientStream stream2 = client.newStream(methodDescriptor, new Metadata(), callOptions); - inOrder.verify(clientStreamTracerFactory).newClientStreamTracer( - any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)); + ClientStream stream2 = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); stream2.start(clientStreamListener2); Status clientStreamStatus2 = @@ -683,15 +674,14 @@ public void newStream_afterTermination() throws Exception { client.shutdown(shutdownReason); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportTerminated(); Thread.sleep(100); - ClientStream stream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); assertEquals( shutdownReason, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); verify(mockClientTransportListener, never()).transportInUse(anyBoolean()); - verify(clientStreamTracerFactory).newClientStreamTracer( - any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)); assertNull(clientStreamTracer1.getInboundTrailers()); assertSame(shutdownReason, clientStreamTracer1.getStatus()); // Assert no interactions @@ -708,7 +698,8 @@ public void transportInUse_balancerRpcsNotCounted() throws Exception { // CALL_OPTIONS_RPC_OWNED_BY_BALANCER in CallOptions. It won't be counted for in-use signal. ClientStream stream1 = client.newStream( methodDescriptor, new Metadata(), - callOptions.withOption(GrpcUtil.CALL_OPTIONS_RPC_OWNED_BY_BALANCER, Boolean.TRUE)); + callOptions.withOption(GrpcUtil.CALL_OPTIONS_RPC_OWNED_BY_BALANCER, Boolean.TRUE), + noopTracers); ClientStreamListenerBase clientStreamListener1 = new ClientStreamListenerBase(); stream1.start(clientStreamListener1); MockServerTransportListener serverTransportListener @@ -717,7 +708,8 @@ methodDescriptor, new Metadata(), = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); // stream2 is the normal RPC, and will be counted for in-use - ClientStream stream2 = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream2 = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); stream2.start(clientStreamListener2); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); @@ -743,7 +735,8 @@ public void transportInUse_normalClose() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream stream1 = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream1 = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener1 = new ClientStreamListenerBase(); stream1.start(clientStreamListener1); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); @@ -751,7 +744,8 @@ public void transportInUse_normalClose() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); StreamCreation serverStreamCreation1 = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); - ClientStream stream2 = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream2 = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); stream2.start(clientStreamListener2); StreamCreation serverStreamCreation2 @@ -773,11 +767,13 @@ public void transportInUse_clientCancel() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream stream1 = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream1 = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener1 = new ClientStreamListenerBase(); stream1.start(clientStreamListener1); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); - ClientStream stream2 = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream2 = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); stream2.start(clientStreamListener2); @@ -792,7 +788,6 @@ public void transportInUse_clientCancel() throws Exception { @Test public void basicStream() throws Exception { - InOrder clientInOrder = inOrder(clientStreamTracerFactory); InOrder serverInOrder = inOrder(serverStreamTracerFactory); server.start(serverListener); client = newClientTransport(server); @@ -816,14 +811,10 @@ public void basicStream() throws Exception { Metadata clientHeadersCopy = new Metadata(); clientHeadersCopy.merge(clientHeaders); - ClientStream clientStream = client.newStream(methodDescriptor, clientHeaders, callOptions); - ArgumentCaptor streamInfoCaptor = ArgumentCaptor.forClass(null); - clientInOrder.verify(clientStreamTracerFactory).newClientStreamTracer( - streamInfoCaptor.capture(), same(clientHeaders)); - ClientStreamTracer.StreamInfo streamInfo = streamInfoCaptor.getValue(); - assertThat(streamInfo.getTransportAttrs()).isSameInstanceAs( - ((ConnectionClientTransport) client).getAttributes()); - assertThat(streamInfo.getCallOptions()).isSameInstanceAs(callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, clientHeaders, callOptions, tracers); + assertThat(((TestHeaderClientStreamTracer) clientStreamTracer1).transportAttrs) + .isSameInstanceAs(((ConnectionClientTransport) client).getAttributes()); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -974,7 +965,8 @@ public void authorityPropagation() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); Metadata clientHeaders = new Metadata(); - ClientStream clientStream = client.newStream(methodDescriptor, clientHeaders, callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, clientHeaders, callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1005,7 +997,8 @@ public void zeroMessageStream() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1044,7 +1037,8 @@ public void earlyServerClose_withServerHeaders() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1080,7 +1074,8 @@ public void earlyServerClose_noServerHeaders() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1122,7 +1117,8 @@ public void earlyServerClose_serverFailure() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1155,7 +1151,8 @@ public void earlyServerClose_serverFailure_withClientCancelOnListenerClosed() th serverTransport = serverTransportListener.transport; final ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase() { @Override @@ -1196,7 +1193,8 @@ public void clientCancel() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1230,7 +1228,8 @@ public void clientCancelFromWithinMessageRead() throws Exception { final SettableFuture closedCalled = SettableFuture.create(); final ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); final Status status = Status.CANCELLED.withDescription("nevermind"); clientStream.start(new ClientStreamListener() { private boolean messageReceived = false; @@ -1311,7 +1310,8 @@ public void serverCancel() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1331,8 +1331,6 @@ public void serverCancel() throws Exception { // Cause should not be transmitted between server and client assertNull(clientStreamStatus.getCause()); - verify(clientStreamTracerFactory).newClientStreamTracer( - any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)); assertTrue(clientStreamTracer1.getOutboundHeaders()); assertNull(clientStreamTracer1.getInboundTrailers()); assertSame(clientStreamStatus, clientStreamTracer1.getStatus()); @@ -1353,7 +1351,8 @@ public void flowControlPushBack() throws Exception { serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = @@ -1515,7 +1514,8 @@ public void interactionsAfterServerStreamCloseAreNoops() throws Exception { // boilerplate ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation server @@ -1547,7 +1547,8 @@ public void interactionsAfterClientStreamCancelAreNoops() throws Exception { // boilerplate ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListener clientListener = mock(ClientStreamListener.class); clientStream.start(clientListener); StreamCreation server @@ -1594,7 +1595,8 @@ public void transportTracer_streamStarted() throws Exception { assertEquals(0, clientBefore.streamsStarted); assertEquals(0, clientBefore.lastRemoteStreamCreatedTimeNanos); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = serverTransportListener @@ -1624,7 +1626,8 @@ public void transportTracer_streamStarted() throws Exception { TransportStats clientBefore = getTransportStats(client); assertEquals(1, clientBefore.streamsStarted); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = serverTransportListener @@ -1654,7 +1657,8 @@ public void transportTracer_server_streamEnded_ok() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); MockServerTransportListener serverTransportListener @@ -1693,7 +1697,8 @@ public void transportTracer_server_streamEnded_nonOk() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); MockServerTransportListener serverTransportListener @@ -1733,7 +1738,8 @@ public void transportTracer_client_streamEnded_nonOk() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); MockServerTransportListener serverTransportListener = @@ -1768,7 +1774,8 @@ public void transportTracer_server_receive_msg() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); MockServerTransportListener serverTransportListener @@ -1809,7 +1816,8 @@ public void transportTracer_server_send_msg() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); MockServerTransportListener serverTransportListener @@ -1849,7 +1857,8 @@ public void socketStats() throws Exception { server.start(serverListener); ManagedClientTransport client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -1896,8 +1905,8 @@ public void serverChecksInboundMetadataSize() throws Exception { Metadata.Key.of("foo-bin", Metadata.BINARY_BYTE_MARSHALLER), new byte[GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE]); - ClientStream clientStream = - client.newStream(methodDescriptor, tooLargeMetadata, callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, tooLargeMetadata, callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -1931,7 +1940,8 @@ public void clientChecksInboundMetadataSize_header() throws Exception { new byte[GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE]); ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -1975,7 +1985,8 @@ public void clientChecksInboundMetadataSize_trailer() throws Exception { new byte[GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE]); ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -2011,7 +2022,9 @@ private void doPingPong(MockServerListener serverListener) throws Exception { ManagedClientTransport client = newClientTransport(server); ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class); startTransport(client, listener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, + new ClientStreamTracer[] { new ClientStreamTracer() {} }); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -2092,6 +2105,16 @@ private static void startTransport( verify(listener, timeout(TIMEOUT_MS)).transportReady(); } + private final class TestHeaderClientStreamTracer extends TestClientStreamTracer { + Attributes transportAttrs; + + @Override + public void streamCreated(Attributes transportAttrs, Metadata metadata) { + this.transportAttrs = transportAttrs; + metadata.put(tracerHeaderKey, tracerKeyValue); + } + } + private static class MockServerListener implements ServerListener { public final BlockingQueue listeners = new LinkedBlockingQueue<>(); diff --git a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java index 7725c46726b..963a586319b 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java @@ -34,6 +34,7 @@ import io.grpc.CallCredentials.RequestInfo; import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.IntegerMarshaller; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -48,6 +49,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnit; @@ -103,6 +105,9 @@ public class CallCredentials2ApplyingTest { private static final Metadata.Key CREDS_KEY = Metadata.Key.of("test-creds", Metadata.ASCII_STRING_MARSHALLER); private static final String CREDS_VALUE = "some credentials"; + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; private final Metadata origHeaders = new Metadata(); private ForwardingConnectionClientTransport transport; @@ -118,7 +123,9 @@ public void setUp() { origHeaders.put(ORIG_HEADER_KEY, ORIG_HEADER_VALUE); when(mockTransportFactory.newClientTransport(address, clientTransportOptions, channelLogger)) .thenReturn(mockTransport); - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory( mockTransportFactory, null, mockExecutor); @@ -134,7 +141,7 @@ public void parameterPropagation_base() { Attributes transportAttrs = Attributes.newBuilder().set(ATTR_KEY, ATTR_VALUE).build(); when(mockTransport.getAttributes()).thenReturn(transportAttrs); - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata( @@ -155,7 +162,7 @@ public void parameterPropagation_transportSetSecurityLevel() { .build(); when(mockTransport.getAttributes()).thenReturn(transportAttrs); - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata( @@ -176,8 +183,10 @@ public void parameterPropagation_callOptionsSetAuthority() { when(mockTransport.getAttributes()).thenReturn(transportAttrs); Executor anotherExecutor = mock(Executor.class); - transport.newStream(method, origHeaders, - callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor)); + transport.newStream( + method, origHeaders, + callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor), + tracers); ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata( @@ -199,13 +208,15 @@ public void credentialThrows() { any(io.grpc.CallCredentials2.MetadataApplier.class)); FailingClientStream stream = - (FailingClientStream) transport.newStream(method, origHeaders, callOptions); + (FailingClientStream) transport.newStream(method, origHeaders, callOptions, tracers); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode()); assertSame(ex, stream.getError().getCause()); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -226,14 +237,14 @@ public Void answer(InvocationOnMock invocation) throws Throwable { any(RequestInfo.class), same(mockExecutor), any(io.grpc.CallCredentials2.MetadataApplier.class)); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream(method, origHeaders, callOptions, tracers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -254,12 +265,14 @@ public Void answer(InvocationOnMock invocation) throws Throwable { any(io.grpc.CallCredentials2.MetadataApplier.class)); FailingClientStream stream = - (FailingClientStream) transport.newStream(method, origHeaders, callOptions); + (FailingClientStream) transport.newStream(method, origHeaders, callOptions, tracers); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); assertSame(error, stream.getError()); transport.shutdownNow(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdownNow(Status.UNAVAILABLE); } @@ -269,12 +282,15 @@ public void applyMetadata_delayed() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); // Will call applyRequestMetadata(), which is no-op. - DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions); + DelayedStream stream = (DelayedStream) transport.newStream( + method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata( any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); transport.shutdown(Status.UNAVAILABLE); verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); @@ -283,11 +299,11 @@ public void applyMetadata_delayed() { headers.put(CREDS_KEY, CREDS_VALUE); applierCaptor.getValue().apply(headers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream.getRealStream()); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -297,7 +313,8 @@ public void fail_delayed() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); // Will call applyRequestMetadata(), which is no-op. - DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions); + DelayedStream stream = (DelayedStream) transport.newStream( + method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata( @@ -306,11 +323,13 @@ public void fail_delayed() { Status error = Status.FAILED_PRECONDITION.withDescription("channel not secure for creds"); applierCaptor.getValue().fail(error); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); FailingClientStream failingStream = (FailingClientStream) stream.getRealStream(); assertSame(error, failingStream.getError()); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -318,14 +337,14 @@ public void fail_delayed() { @Test public void noCreds() { callOptions = callOptions.withCallCredentials(null); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream(method, origHeaders, callOptions, tracers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertNull(origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } diff --git a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java index 61a221f73de..ef49e66bf2d 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java @@ -35,6 +35,7 @@ import io.grpc.CallCredentials.RequestInfo; import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.IntegerMarshaller; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -49,6 +50,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnit; @@ -86,6 +88,9 @@ public class CallCredentialsApplyingTest { @Mock private ChannelLogger channelLogger; + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; private static final String AUTHORITY = "testauthority"; private static final String USER_AGENT = "testuseragent"; private static final Attributes.Key ATTR_KEY = Attributes.Key.create("somekey"); @@ -117,7 +122,9 @@ public void setUp() { origHeaders.put(ORIG_HEADER_KEY, ORIG_HEADER_VALUE); when(mockTransportFactory.newClientTransport(address, clientTransportOptions, channelLogger)) .thenReturn(mockTransport); - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory( mockTransportFactory, null, mockExecutor); @@ -133,7 +140,7 @@ public void parameterPropagation_base() { Attributes transportAttrs = Attributes.newBuilder().set(ATTR_KEY, ATTR_VALUE).build(); when(mockTransport.getAttributes()).thenReturn(transportAttrs); - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(infoCaptor.capture(), same(mockExecutor), @@ -154,8 +161,10 @@ public void parameterPropagation_overrideByCallOptions() { when(mockTransport.getAttributes()).thenReturn(transportAttrs); Executor anotherExecutor = mock(Executor.class); - transport.newStream(method, origHeaders, - callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor)); + transport.newStream( + method, origHeaders, + callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor), + tracers); ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(infoCaptor.capture(), @@ -175,15 +184,17 @@ public void credentialThrows() { any(RequestInfo.class), same(mockExecutor), any(CallCredentials.MetadataApplier.class)); - FailingClientStream stream = - (FailingClientStream) transport.newStream(method, origHeaders, callOptions); + FailingClientStream stream = (FailingClientStream) transport.newStream( + method, origHeaders, callOptions, tracers); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode()); assertSame(ex, stream.getError().getCause()); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -193,14 +204,15 @@ public void applyMetadata_inline() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); callOptions = callOptions.withCallCredentials(new FakeCallCredentials(CREDS_KEY, CREDS_VALUE)); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream( + method, origHeaders, callOptions, tracers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -220,13 +232,15 @@ public Void answer(InvocationOnMock invocation) throws Throwable { }).when(mockCreds).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), any(CallCredentials.MetadataApplier.class)); - FailingClientStream stream = - (FailingClientStream) transport.newStream(method, origHeaders, callOptions); + FailingClientStream stream = (FailingClientStream) transport.newStream( + method, origHeaders, callOptions, tracers); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); assertSame(error, stream.getError()); transport.shutdownNow(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdownNow(Status.UNAVAILABLE); } @@ -236,23 +250,26 @@ public void applyMetadata_delayed() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); // Will call applyRequestMetadata(), which is no-op. - DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions); + DelayedStream stream = (DelayedStream) transport.newStream( + method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); transport.shutdown(Status.UNAVAILABLE); verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); Metadata headers = new Metadata(); headers.put(CREDS_KEY, CREDS_VALUE); applierCaptor.getValue().apply(headers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream.getRealStream()); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); @@ -261,20 +278,20 @@ public void applyMetadata_delayed() { @Test public void delayedShutdown_shutdownShutdownNowThenApply() { - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); transport.shutdown(Status.UNAVAILABLE); transport.shutdownNow(Status.ABORTED); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport, never()).shutdown(any(Status.class)); verify(mockTransport, never()).shutdownNow(any(Status.class)); Metadata headers = new Metadata(); headers.put(CREDS_KEY, CREDS_VALUE); applierCaptor.getValue().apply(headers); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); verify(mockTransport).shutdownNow(Status.ABORTED); @@ -282,12 +299,12 @@ public void delayedShutdown_shutdownShutdownNowThenApply() { @Test public void delayedShutdown_shutdownThenApplyThenShutdownNow() { - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport, never()).shutdown(any(Status.class)); Metadata headers = new Metadata(); @@ -308,25 +325,25 @@ public void delayedShutdown_shutdownMulti() { Metadata headers = new Metadata(); headers.put(CREDS_KEY, CREDS_VALUE); - transport.newStream(method, origHeaders, callOptions); - transport.newStream(method, origHeaders, callOptions); - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); + transport.newStream(method, origHeaders, callOptions, tracers); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds, times(3)).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); applierCaptor.getAllValues().get(1).apply(headers); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); applierCaptor.getAllValues().get(0).apply(headers); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); applierCaptor.getAllValues().get(2).apply(headers); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -336,7 +353,8 @@ public void fail_delayed() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); // Will call applyRequestMetadata(), which is no-op. - DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions); + DelayedStream stream = (DelayedStream) transport.newStream( + method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), @@ -345,11 +363,13 @@ public void fail_delayed() { Status error = Status.FAILED_PRECONDITION.withDescription("channel not secure for creds"); applierCaptor.getValue().fail(error); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); FailingClientStream failingStream = (FailingClientStream) stream.getRealStream(); assertSame(error, failingStream.getError()); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -357,14 +377,15 @@ public void fail_delayed() { @Test public void noCreds() { callOptions = callOptions.withCallCredentials(null); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream( + method, origHeaders, callOptions, tracers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertNull(origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -373,7 +394,8 @@ public void noCreds() { public void justCallOptionCreds() { callOptions = callOptions.withCallCredentials(new FakeCallCredentials(CREDS_KEY, CREDS_VALUE)); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream( + method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); @@ -388,7 +410,8 @@ public void justChannelCreds() { transportFactory.newClientTransport(address, clientTransportOptions, channelLogger); callOptions = callOptions.withCallCredentials(null); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream( + method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); @@ -406,7 +429,8 @@ public void callOptionAndChanelCreds() { String creds2Value = "some more credentials"; callOptions = callOptions.withCallCredentials(new FakeCallCredentials(creds2Key, creds2Value)); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream( + method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); diff --git a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java index 1808a4bd478..0e5e5f50599 100644 --- a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java @@ -37,6 +37,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableMap; @@ -47,6 +48,7 @@ import io.grpc.CallOptions; import io.grpc.ClientCall; import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.Codec; import io.grpc.Context; import io.grpc.Deadline; @@ -143,6 +145,8 @@ public void setUp() { any(Metadata.class), any(Context.class))) .thenReturn(stream); + when(streamTracerFactory.newClientStreamTracer(any(StreamInfo.class), any(Metadata.class))) + .thenReturn(new ClientStreamTracer() {}); doAnswer(new Answer() { @Override public Void answer(InvocationOnMock in) { @@ -156,7 +160,7 @@ public Void answer(InvocationOnMock in) { @After public void tearDown() { - verifyNoInteractions(streamTracerFactory); + verifyNoMoreInteractions(streamTracerFactory); } @Test @@ -763,6 +767,7 @@ public void deadlineExceededBeforeCallStarted() { channelCallTracer, configSelector) .setDecompressorRegistry(decompressorRegistry); call.start(callListener, new Metadata()); + verify(streamTracerFactory).newClientStreamTracer(any(StreamInfo.class), any(Metadata.class)); verify(clientStreamProvider, never()) .newStream( (MethodDescriptor) any(MethodDescriptor.class), diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index 9f48b8987d1..4cae565a19e 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -36,6 +36,7 @@ import static org.mockito.Mockito.when; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.IntegerMarshaller; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; @@ -57,6 +58,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; @@ -89,6 +91,9 @@ public class DelayedClientTransportTest { = CallOptions.Key.createWithDefault("shard-id", -1); private static final Status SHUTDOWN_STATUS = Status.UNAVAILABLE.withDescription("shutdown called"); + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; private final MethodDescriptor method = MethodDescriptor.newBuilder() @@ -122,9 +127,13 @@ public void uncaughtException(Thread t, Throwable e) { .thenReturn(PickResult.withSubchannel(mockSubchannel)); when(mockSubchannel.getInternalSubchannel()).thenReturn(mockInternalSubchannel); when(mockInternalSubchannel.obtainActiveTransport()).thenReturn(mockRealTransport); - when(mockRealTransport.newStream(same(method), same(headers), same(callOptions))) + when(mockRealTransport.newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any())) .thenReturn(mockRealStream); - when(mockRealTransport2.newStream(same(method2), same(headers2), same(callOptions2))) + when(mockRealTransport2.newStream( + same(method2), same(headers2), same(callOptions2), + ArgumentMatchers.any())) .thenReturn(mockRealStream2); delayedTransport.start(transportListener); } @@ -135,7 +144,8 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void streamStartThenAssignTransport() { assertFalse(delayedTransport.hasPendingStreams()); - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions, tracers); stream.start(streamListener); assertEquals(1, delayedTransport.getPendingStreamsCount()); assertTrue(delayedTransport.hasPendingStreams()); @@ -145,7 +155,9 @@ public void uncaughtException(Thread t, Throwable e) { assertEquals(0, delayedTransport.getPendingStreamsCount()); assertFalse(delayedTransport.hasPendingStreams()); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); + verify(mockRealTransport).newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any()); verify(mockRealStream).start(listenerCaptor.capture()); verifyNoMoreInteractions(streamListener); listenerCaptor.getValue().onReady(); @@ -154,7 +166,7 @@ public void uncaughtException(Thread t, Throwable e) { } @Test public void newStreamThenAssignTransportThenShutdown() { - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream(method, headers, callOptions, tracers); assertEquals(1, delayedTransport.getPendingStreamsCount()); assertTrue(stream instanceof DelayedStream); delayedTransport.reprocess(mockPicker); @@ -163,7 +175,9 @@ public void uncaughtException(Thread t, Throwable e) { verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener).transportTerminated(); assertEquals(0, fakeExecutor.runDueTasks()); - verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); + verify(mockRealTransport).newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any()); stream.start(streamListener); verify(mockRealStream).start(same(streamListener)); } @@ -181,11 +195,13 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener).transportTerminated(); - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions, tracers); assertEquals(0, delayedTransport.getPendingStreamsCount()); assertTrue(stream instanceof FailingClientStream); verify(mockRealTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); } @Test public void assignTransportThenShutdownNowThenNewStream() { @@ -193,15 +209,18 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdownNow(Status.UNAVAILABLE); verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportTerminated(); - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions, tracers); assertEquals(0, delayedTransport.getPendingStreamsCount()); assertTrue(stream instanceof FailingClientStream); verify(mockRealTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); } @Test public void startThenCancelStreamWithoutSetTransport() { - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); assertEquals(1, delayedTransport.getPendingStreamsCount()); stream.cancel(Status.CANCELLED); @@ -213,7 +232,8 @@ public void uncaughtException(Thread t, Throwable e) { } @Test public void newStreamThenShutdownTransportThenAssignTransport() { - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions, tracers); stream.start(streamListener); delayedTransport.shutdown(SHUTDOWN_STATUS); @@ -225,7 +245,8 @@ public void uncaughtException(Thread t, Throwable e) { // ... and will proceed if a real transport is available delayedTransport.reprocess(mockPicker); fakeExecutor.runDueTasks(); - verify(mockRealTransport).newStream(method, headers, callOptions); + verify(mockRealTransport).newStream( + method, headers, callOptions, tracers); verify(mockRealStream).start(any(ClientStreamListener.class)); // Since no more streams are pending, delayed transport is now terminated @@ -233,7 +254,8 @@ public void uncaughtException(Thread t, Throwable e) { verify(transportListener).transportTerminated(); // Further newStream() will return a failing stream - stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); verify(streamListener, never()).closed( any(Status.class), any(RpcProgress.class), any(Metadata.class)); stream.start(streamListener); @@ -247,7 +269,8 @@ public void uncaughtException(Thread t, Throwable e) { } @Test public void newStreamThenShutdownTransportThenCancelStream() { - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener, times(0)).transportTerminated(); @@ -264,7 +287,8 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener).transportTerminated(); - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); verify(streamListener).closed( statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); @@ -272,7 +296,8 @@ public void uncaughtException(Thread t, Throwable e) { } @Test public void startStreamThenShutdownNow() { - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); delayedTransport.shutdownNow(Status.UNAVAILABLE); verify(transportListener).transportShutdown(any(Status.class)); @@ -286,7 +311,8 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdownNow(Status.UNAVAILABLE); verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportTerminated(); - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); verify(streamListener).closed( statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); @@ -301,55 +327,59 @@ public void uncaughtException(Thread t, Throwable e) { AbstractSubchannel subchannel1 = mock(AbstractSubchannel.class); AbstractSubchannel subchannel2 = mock(AbstractSubchannel.class); AbstractSubchannel subchannel3 = mock(AbstractSubchannel.class); - when(mockRealTransport.newStream(any(MethodDescriptor.class), any(Metadata.class), - any(CallOptions.class))).thenReturn(mockRealStream); - when(mockRealTransport2.newStream(any(MethodDescriptor.class), any(Metadata.class), - any(CallOptions.class))).thenReturn(mockRealStream2); + when(mockRealTransport.newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) + .thenReturn(mockRealStream); + when(mockRealTransport2.newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) + .thenReturn(mockRealStream2); when(subchannel1.getInternalSubchannel()).thenReturn(newTransportProvider(mockRealTransport)); when(subchannel2.getInternalSubchannel()).thenReturn(newTransportProvider(mockRealTransport2)); when(subchannel3.getInternalSubchannel()).thenReturn(newTransportProvider(null)); // Fail-fast streams DelayedStream ff1 = (DelayedStream) delayedTransport.newStream( - method, headers, failFastCallOptions); + method, headers, failFastCallOptions, tracers); ff1.start(mock(ClientStreamListener.class)); ff1.halfClose(); PickSubchannelArgsImpl ff1args = new PickSubchannelArgsImpl(method, headers, failFastCallOptions); verify(transportListener).transportInUse(true); DelayedStream ff2 = (DelayedStream) delayedTransport.newStream( - method2, headers2, failFastCallOptions); + method2, headers2, failFastCallOptions, tracers); PickSubchannelArgsImpl ff2args = new PickSubchannelArgsImpl(method2, headers2, failFastCallOptions); DelayedStream ff3 = (DelayedStream) delayedTransport.newStream( - method, headers, failFastCallOptions); + method, headers, failFastCallOptions, tracers); PickSubchannelArgsImpl ff3args = new PickSubchannelArgsImpl(method, headers, failFastCallOptions); DelayedStream ff4 = (DelayedStream) delayedTransport.newStream( - method2, headers2, failFastCallOptions); + method2, headers2, failFastCallOptions, tracers); PickSubchannelArgsImpl ff4args = new PickSubchannelArgsImpl(method2, headers2, failFastCallOptions); // Wait-for-ready streams FakeClock wfr3Executor = new FakeClock(); DelayedStream wfr1 = (DelayedStream) delayedTransport.newStream( - method, headers, waitForReadyCallOptions); + method, headers, waitForReadyCallOptions, tracers); PickSubchannelArgsImpl wfr1args = new PickSubchannelArgsImpl(method, headers, waitForReadyCallOptions); DelayedStream wfr2 = (DelayedStream) delayedTransport.newStream( - method2, headers2, waitForReadyCallOptions); + method2, headers2, waitForReadyCallOptions, tracers); PickSubchannelArgsImpl wfr2args = new PickSubchannelArgsImpl(method2, headers2, waitForReadyCallOptions); CallOptions wfr3callOptions = waitForReadyCallOptions.withExecutor( wfr3Executor.getScheduledExecutorService()); DelayedStream wfr3 = (DelayedStream) delayedTransport.newStream( - method, headers, wfr3callOptions); + method, headers, wfr3callOptions, tracers); wfr3.start(mock(ClientStreamListener.class)); wfr3.halfClose(); PickSubchannelArgsImpl wfr3args = new PickSubchannelArgsImpl(method, headers, wfr3callOptions); DelayedStream wfr4 = (DelayedStream) delayedTransport.newStream( - method2, headers2, waitForReadyCallOptions); + method2, headers2, waitForReadyCallOptions, tracers); PickSubchannelArgsImpl wfr4args = new PickSubchannelArgsImpl(method2, headers2, waitForReadyCallOptions); @@ -386,8 +416,10 @@ public void uncaughtException(Thread t, Throwable e) { // streams are now owned by a real transport (which should prevent the Channel from // terminating). // ff1 and wfr1 went through - verify(mockRealTransport).newStream(method, headers, failFastCallOptions); - verify(mockRealTransport2).newStream(method, headers, waitForReadyCallOptions); + verify(mockRealTransport).newStream( + method, headers, failFastCallOptions, tracers); + verify(mockRealTransport2).newStream( + method, headers, waitForReadyCallOptions, tracers); assertSame(mockRealStream, ff1.getRealStream()); assertSame(mockRealStream2, wfr1.getRealStream()); verify(mockRealStream).start(any(ClientStreamListener.class)); @@ -443,7 +475,7 @@ public void uncaughtException(Thread t, Throwable e) { // New streams will use the last picker DelayedStream wfr5 = (DelayedStream) delayedTransport.newStream( - method, headers, waitForReadyCallOptions); + method, headers, waitForReadyCallOptions, tracers); assertNull(wfr5.getRealStream()); inOrder.verify(picker).pickSubchannel( new PickSubchannelArgsImpl(method, headers, waitForReadyCallOptions)); @@ -474,14 +506,17 @@ public void reprocess_NoPendingStream() { when(subchannel.getInternalSubchannel()).thenReturn(mockInternalSubchannel); when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( PickResult.withSubchannel(subchannel)); - when(mockRealTransport.newStream(any(MethodDescriptor.class), any(Metadata.class), - any(CallOptions.class))).thenReturn(mockRealStream); + when(mockRealTransport.newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) + .thenReturn(mockRealStream); delayedTransport.reprocess(picker); verifyNoMoreInteractions(picker); verifyNoMoreInteractions(transportListener); // Though picker was not originally used, it will be saved and serve future streams. - ClientStream stream = delayedTransport.newStream(method, headers, CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, headers, CallOptions.DEFAULT, tracers); verify(picker).pickSubchannel(new PickSubchannelArgsImpl(method, headers, CallOptions.DEFAULT)); verify(mockInternalSubchannel).obtainActiveTransport(); assertSame(mockRealStream, stream); @@ -519,7 +554,7 @@ public PickResult answer(InvocationOnMock invocation) throws Throwable { @Override public void run() { // Will call pickSubchannel and wait on barrier - delayedTransport.newStream(method, headers, callOptions); + delayedTransport.newStream(method, headers, callOptions, tracers); } }; sideThread.start(); @@ -552,7 +587,7 @@ public void run() { @Override public void run() { // Will call pickSubchannel and wait on barrier - delayedTransport.newStream(method, headers2, callOptions); + delayedTransport.newStream(method, headers2, callOptions, tracers); } }; sideThread2.start(); @@ -600,7 +635,8 @@ public void newStream_racesWithReprocessIdleMode() throws Exception { // Because there is no pending stream yet, it will do nothing but save the picker. delayedTransport.reprocess(picker); - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions, tracers); stream.start(streamListener); assertTrue(delayedTransport.hasPendingStreams()); verify(transportListener).transportInUse(true); @@ -609,7 +645,7 @@ public void newStream_racesWithReprocessIdleMode() throws Exception { @Test public void pendingStream_appendTimeoutInsight_waitForReady() { ClientStream stream = delayedTransport.newStream( - method, headers, callOptions.withWaitForReady()); + method, headers, callOptions.withWaitForReady(), tracers); stream.start(streamListener); InsightBuilder insight = new InsightBuilder(); stream.appendTimeoutInsight(insight); diff --git a/core/src/test/java/io/grpc/internal/FailingClientStreamTest.java b/core/src/test/java/io/grpc/internal/FailingClientStreamTest.java index dad82902395..c07812577d5 100644 --- a/core/src/test/java/io/grpc/internal/FailingClientStreamTest.java +++ b/core/src/test/java/io/grpc/internal/FailingClientStreamTest.java @@ -21,6 +21,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.internal.ClientStreamListener.RpcProgress; @@ -33,13 +34,16 @@ */ @RunWith(JUnit4.class) public class FailingClientStreamTest { + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; @Test public void processedRpcProgressPopulatedToListener() { ClientStreamListener listener = mock(ClientStreamListener.class); Status status = Status.UNAVAILABLE; - ClientStream stream = new FailingClientStream(status); + ClientStream stream = new FailingClientStream(status, RpcProgress.PROCESSED, tracers); stream.start(listener); verify(listener).closed(eq(status), eq(RpcProgress.PROCESSED), any(Metadata.class)); } @@ -49,7 +53,7 @@ public void droppedRpcProgressPopulatedToListener() { ClientStreamListener listener = mock(ClientStreamListener.class); Status status = Status.UNAVAILABLE; - ClientStream stream = new FailingClientStream(status, RpcProgress.DROPPED); + ClientStream stream = new FailingClientStream(status, RpcProgress.DROPPED, tracers); stream.start(listener); verify(listener).closed(eq(status), eq(RpcProgress.DROPPED), any(Metadata.class)); } diff --git a/core/src/test/java/io/grpc/internal/FailingClientTransportTest.java b/core/src/test/java/io/grpc/internal/FailingClientTransportTest.java index ff15ef7ff02..98749d74910 100644 --- a/core/src/test/java/io/grpc/internal/FailingClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/FailingClientTransportTest.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.verify; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.internal.ClientStreamListener.RpcProgress; @@ -41,8 +42,9 @@ public void newStreamStart() { Status error = Status.UNAVAILABLE; RpcProgress rpcProgress = RpcProgress.DROPPED; FailingClientTransport transport = new FailingClientTransport(error, rpcProgress); - ClientStream stream = transport - .newStream(TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT); + ClientStream stream = transport.newStream( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, + new ClientStreamTracer[] { new ClientStreamTracer() {} }); ClientStreamListener listener = mock(ClientStreamListener.class); stream.start(listener); diff --git a/core/src/test/java/io/grpc/internal/ForwardingClientStreamTracerTest.java b/core/src/test/java/io/grpc/internal/ForwardingClientStreamTracerTest.java new file mode 100644 index 00000000000..5eb5b49fa19 --- /dev/null +++ b/core/src/test/java/io/grpc/internal/ForwardingClientStreamTracerTest.java @@ -0,0 +1,49 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.mockito.Mockito.mock; + +import io.grpc.ClientStreamTracer; +import io.grpc.ForwardingTestUtil; +import java.lang.reflect.Method; +import java.util.Collections; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ForwardingClientStreamTracer}. */ +@RunWith(JUnit4.class) +public class ForwardingClientStreamTracerTest { + private final ClientStreamTracer mockDelegate = mock(ClientStreamTracer.class); + + @Test + public void allMethodsForwarded() throws Exception { + ForwardingTestUtil.testMethodsForwarded( + ClientStreamTracer.class, + mockDelegate, + new ForwardingClientStreamTracerTest.TestClientStreamTracer(), + Collections.emptyList()); + } + + private final class TestClientStreamTracer extends ForwardingClientStreamTracer { + @Override + protected ClientStreamTracer delegate() { + return mockDelegate; + } + } +} diff --git a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java index 7a3808de6e3..95d1c448f4f 100644 --- a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java +++ b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java @@ -16,6 +16,7 @@ package io.grpc.internal; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -27,13 +28,17 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.LoadBalancer.PickResult; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.GrpcUtil.Http2Error; import io.grpc.testing.TestMethodDescriptors; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -44,6 +49,10 @@ @RunWith(JUnit4.class) public class GrpcUtilTest { + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; + @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 @Rule public final ExpectedException thrown = ExpectedException.none(); @@ -244,8 +253,9 @@ public void getTransportFromPickResult_errorPickResult_failFast() { assertNotNull(transport); - ClientStream stream = transport - .newStream(TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT); + ClientStream stream = transport.newStream( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, + tracers); ClientStreamListener listener = mock(ClientStreamListener.class); stream.start(listener); @@ -260,8 +270,9 @@ public void getTransportFromPickResult_dropPickResult_waitForReady() { assertNotNull(transport); - ClientStream stream = transport - .newStream(TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT); + ClientStream stream = transport.newStream( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, + tracers); ClientStreamListener listener = mock(ClientStreamListener.class); stream.start(listener); @@ -276,11 +287,39 @@ public void getTransportFromPickResult_dropPickResult_failFast() { assertNotNull(transport); - ClientStream stream = transport - .newStream(TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT); + ClientStream stream = transport.newStream( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, + tracers); ClientStreamListener listener = mock(ClientStreamListener.class); stream.start(listener); verify(listener).closed(eq(status), eq(RpcProgress.DROPPED), any(Metadata.class)); } + + @Test + public void clientStreamTracerFactoryBackwardCompatibility() { + final AtomicReference transportAttrsRef = new AtomicReference<>(); + final ClientStreamTracer mockTracer = mock(ClientStreamTracer.class); + final Metadata.Key key = Metadata.Key.of("fake-key", Metadata.ASCII_STRING_MARSHALLER); + ClientStreamTracer.Factory oldFactoryImpl = new ClientStreamTracer.Factory() { + @SuppressWarnings("deprecation") + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + transportAttrsRef.set(info.getTransportAttrs()); + headers.put(key, "fake-value"); + return mockTracer; + } + }; + + StreamInfo info = + StreamInfo.newBuilder().setCallOptions(CallOptions.DEFAULT.withWaitForReady()).build(); + Metadata metadata = new Metadata(); + Attributes transAttrs = + Attributes.newBuilder().set(Attributes.Key.create("foo"), "bar").build(); + ClientStreamTracer tracer = GrpcUtil.newClientStreamTracer(oldFactoryImpl, info, metadata); + tracer.streamCreated(transAttrs, metadata); + + assertThat(transportAttrsRef.get()).isEqualTo(transAttrs); + assertThat(metadata.get(key)).isEqualTo("fake-value"); + } } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index e5e017d756a..ccfb5f074c5 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -71,6 +71,7 @@ import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptors; import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.CompositeChannelCredentials; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; @@ -151,6 +152,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; @@ -225,6 +227,8 @@ public boolean shouldAccept(Runnable command) { private ArgumentCaptor statusCaptor; @Captor private ArgumentCaptor callOptionsCaptor; + @Captor + private ArgumentCaptor tracersCaptor; @Mock private LoadBalancer mockLoadBalancer; @Mock @@ -525,7 +529,9 @@ channelBuilder, mockTransportFactory, new FakeBackoffPolicyProvider(), MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) @@ -534,7 +540,9 @@ channelBuilder, mockTransportFactory, new FakeBackoffPolicyProvider(), executor.runDueTasks(); ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(null); - verify(mockTransport).newStream(same(method), same(headers), callOptionsCaptor.capture()); + verify(mockTransport).newStream( + same(method), same(headers), callOptionsCaptor.capture(), + ArgumentMatchers.any()); assertThat(callOptionsCaptor.getValue().isWaitForReady()).isTrue(); verify(mockStream).start(streamListenerCaptor.capture()); @@ -600,7 +608,9 @@ public ClientCall interceptCall( MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) @@ -609,7 +619,9 @@ public ClientCall interceptCall( executor.runDueTasks(); ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(null); - verify(mockTransport).newStream(same(method), same(headers), callOptionsCaptor.capture()); + verify(mockTransport).newStream( + same(method), same(headers), callOptionsCaptor.capture(), + ArgumentMatchers.any()); assertThat(callOptionsCaptor.getValue().getOption(callOptionsKey)).isEqualTo("fooValue"); verify(mockStream).start(streamListenerCaptor.capture()); @@ -800,9 +812,13 @@ private void subtestCallsAndShutdown(boolean shutdownNow, boolean shutdownNowAft ConnectionClientTransport mockTransport = transportInfo.transport; verify(mockTransport).start(any(ManagedClientTransport.Listener.class)); ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT))) + when(mockTransport.newStream( + same(method), same(headers), same(CallOptions.DEFAULT), + ArgumentMatchers.any())) .thenReturn(mockStream); - when(mockTransport.newStream(same(method), same(headers2), same(CallOptions.DEFAULT))) + when(mockTransport.newStream( + same(method), same(headers2), same(CallOptions.DEFAULT), + ArgumentMatchers.any())) .thenReturn(mockStream2); transportListener.transportReady(); when(mockPicker.pickSubchannel( @@ -820,14 +836,19 @@ private void subtestCallsAndShutdown(boolean shutdownNow, boolean shutdownNowAft any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); call.start(mockCallListener, headers); - verify(mockTransport, never()) - .newStream(same(method), same(headers), same(CallOptions.DEFAULT)); + verify(mockTransport, never()).newStream( + same(method), same(headers), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); // Second RPC, will be assigned to the real transport ClientCall call2 = channel.newCall(method, CallOptions.DEFAULT); call2.start(mockCallListener2, headers2); - verify(mockTransport).newStream(same(method), same(headers2), same(CallOptions.DEFAULT)); - verify(mockTransport).newStream(same(method), same(headers2), same(CallOptions.DEFAULT)); + verify(mockTransport).newStream( + same(method), same(headers2), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); + verify(mockTransport).newStream( + same(method), same(headers2), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); verify(mockStream2).start(any(ClientStreamListener.class)); // Shutdown @@ -872,7 +893,9 @@ private void subtestCallsAndShutdown(boolean shutdownNow, boolean shutdownNowAft .thenReturn(PickResult.withSubchannel(subchannel)); updateBalancingStateSafely(helper, READY, picker2); executor.runDueTasks(); - verify(mockTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT)); + verify(mockTransport).newStream( + same(method), same(headers), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } @@ -1021,7 +1044,9 @@ public void callOptionsExecutor() { MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) @@ -1031,7 +1056,8 @@ public void callOptionsExecutor() { // Real streams are started in the call executor if they were previously buffered. assertEquals(1, callExecutor.runDueTasks()); - verify(mockTransport).newStream(same(method), same(headers), same(options)); + verify(mockTransport).newStream( + same(method), same(headers), same(options), ArgumentMatchers.any()); verify(mockStream).start(streamListenerCaptor.capture()); // Call listener callbacks are also run in the call executor @@ -1298,7 +1324,8 @@ public void firstResolvedServerFailedToConnect() throws Exception { same(goodAddress), any(ClientTransportOptions.class), any(ChannelLogger.class)); MockClientTransportInfo goodTransportInfo = transports.poll(); when(goodTransportInfo.transport.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mock(ClientStream.class)); goodTransportInfo.listener.transportReady(); @@ -1310,11 +1337,13 @@ public void firstResolvedServerFailedToConnect() throws Exception { // Delayed transport uses the app executor to create real streams. executor.runDueTasks(); - verify(goodTransportInfo.transport).newStream(same(method), same(headers), - same(CallOptions.DEFAULT)); + verify(goodTransportInfo.transport).newStream( + same(method), same(headers), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); // The bad transport was never used. - verify(badTransportInfo.transport, times(0)).newStream(any(MethodDescriptor.class), - any(Metadata.class), any(CallOptions.class)); + verify(badTransportInfo.transport, times(0)).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); } @Test @@ -1464,10 +1493,12 @@ public void allServersFailedToConnect() throws Exception { // ... while the wait-for-ready call stays verifyNoMoreInteractions(mockCallListener); // No real stream was ever created - verify(transportInfo1.transport, times(0)) - .newStream(any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); - verify(transportInfo2.transport, times(0)) - .newStream(any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + verify(transportInfo1.transport, times(0)).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); + verify(transportInfo2.transport, times(0)).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); } @Test @@ -1763,8 +1794,9 @@ public void oobchannels() { assertEquals(0, balancerRpcExecutor.numPendingTasks()); transportInfo.listener.transportReady(); assertEquals(1, balancerRpcExecutor.runDueTasks()); - verify(transportInfo.transport).newStream(same(method), same(headers), - same(CallOptions.DEFAULT)); + verify(transportInfo.transport).newStream( + same(method), same(headers), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); // The transport goes away transportInfo.listener.transportShutdown(Status.UNAVAILABLE); @@ -1870,7 +1902,9 @@ public void oobChannelHasNoChannelCallCredentials() { ClientCall call = channel.newCall(method, callOptions); call.start(mockCallListener, headers); - verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions)); + verify(transportInfo.transport).newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)) .containsExactly(channelCredValue, callCredValue).inOrder(); @@ -1887,7 +1921,9 @@ public void oobChannelHasNoChannelCallCredentials() { transportInfo.listener.transportReady(); balancerRpcExecutor.runDueTasks(); - verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions)); + verify(transportInfo.transport).newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)).containsExactly(callCredValue); oob.shutdownNow(); @@ -1919,7 +1955,9 @@ public void oobChannelHasNoChannelCallCredentials() { call.start(mockCallListener2, headers); // CallOptions may contain StreamTracerFactory for census that is added by default. - verify(transportInfo.transport).newStream(same(method), same(headers), any(CallOptions.class)); + verify(transportInfo.transport).newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)).containsExactly(callCredValue); oob.shutdownNow(); } @@ -1962,7 +2000,9 @@ public SwapChannelCredentialsResult answer(InvocationOnMock invocation) { ClientCall call = channel.newCall(method, callOptions); call.start(mockCallListener, headers); - verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions)); + verify(transportInfo.transport).newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)) .containsExactly(channelCredValue, callCredValue).inOrder(); @@ -1998,7 +2038,9 @@ public SwapChannelCredentialsResult answer(InvocationOnMock invocation) { call.start(mockCallListener2, headers); // CallOptions may contain StreamTracerFactory for census that is added by default. - verify(transportInfo.transport).newStream(same(method), same(headers), any(CallOptions.class)); + verify(transportInfo.transport).newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)) .containsExactly(oobChannelCredValue, callCredValue).inOrder(); oob.shutdownNow(); @@ -2097,7 +2139,9 @@ public void subchannelChannel_normalUsage() { ClientCall call = sChannel.newCall(method, callOptions); call.start(mockCallListener, headers); - verify(mockTransport).newStream(same(method), same(headers), callOptionsCaptor.capture()); + verify(mockTransport).newStream( + same(method), same(headers), callOptionsCaptor.capture(), + ArgumentMatchers.any()); CallOptions capturedCallOption = callOptionsCaptor.getValue(); assertThat(capturedCallOption.getDeadline()).isSameInstanceAs(callOptions.getDeadline()); @@ -2125,7 +2169,8 @@ public void subchannelChannel_failWhenNotReady() { ClientCall call = sChannel.newCall(method, CallOptions.DEFAULT); call.start(mockCallListener, headers); verify(mockTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verifyNoInteractions(mockCallListener); assertEquals(1, balancerRpcExecutor.runDueTasks()); @@ -2157,7 +2202,8 @@ public void subchannelChannel_failWaitForReady() { sChannel.newCall(method, CallOptions.DEFAULT.withWaitForReady()); call.start(mockCallListener, headers); verify(mockTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verifyNoInteractions(mockCallListener); assertEquals(1, balancerRpcExecutor.runDueTasks()); @@ -2332,7 +2378,8 @@ public ClientStream answer(InvocationOnMock in) throws Throwable { return mock(ClientStream.class); } }).when(transport).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(creds, never()).applyRequestMetadata( any(RequestInfo.class), any(Executor.class), any(CallCredentials.MetadataApplier.class)); @@ -2351,11 +2398,14 @@ public ClientStream answer(InvocationOnMock in) throws Throwable { assertEquals(AUTHORITY, infoCaptor.getValue().getAuthority()); assertEquals(SecurityLevel.NONE, infoCaptor.getValue().getSecurityLevel()); verify(transport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); // newStream() is called after apply() is called applierCaptor.getValue().apply(new Metadata()); - verify(transport).newStream(same(method), any(Metadata.class), same(callOptions)); + verify(transport).newStream( + same(method), any(Metadata.class), same(callOptions), + ArgumentMatchers.any()); assertEquals("testValue", testKey.get(newStreamContexts.poll())); // The context should not live beyond the scope of newStream() and applyRequestMetadata() assertNull(testKey.get()); @@ -2374,11 +2424,14 @@ public ClientStream answer(InvocationOnMock in) throws Throwable { assertEquals(SecurityLevel.NONE, infoCaptor.getValue().getSecurityLevel()); // This is from the first call verify(transport).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); // Still, newStream() is called after apply() is called applierCaptor.getValue().apply(new Metadata()); - verify(transport, times(2)).newStream(same(method), any(Metadata.class), same(callOptions)); + verify(transport, times(2)).newStream( + same(method), any(Metadata.class), same(callOptions), + ArgumentMatchers.any()); assertEquals("testValue", testKey.get(newStreamContexts.poll())); assertNull(testKey.get()); @@ -2387,8 +2440,20 @@ public ClientStream answer(InvocationOnMock in) throws Throwable { @Test public void pickerReturnsStreamTracer_noDelay() { ClientStream mockStream = mock(ClientStream.class); - ClientStreamTracer.Factory factory1 = mock(ClientStreamTracer.Factory.class); - ClientStreamTracer.Factory factory2 = mock(ClientStreamTracer.Factory.class); + final ClientStreamTracer tracer1 = new ClientStreamTracer() {}; + final ClientStreamTracer tracer2 = new ClientStreamTracer() {}; + ClientStreamTracer.Factory factory1 = new ClientStreamTracer.InternalLimitedInfoFactory() { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return tracer1; + } + }; + ClientStreamTracer.Factory factory2 = new ClientStreamTracer.InternalLimitedInfoFactory() { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return tracer2; + } + }; createChannel(); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); @@ -2397,7 +2462,8 @@ public void pickerReturnsStreamTracer_noDelay() { transportInfo.listener.transportReady(); ClientTransport mockTransport = transportInfo.transport; when(mockTransport.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( @@ -2409,20 +2475,29 @@ public void pickerReturnsStreamTracer_noDelay() { call.start(mockCallListener, new Metadata()); verify(mockPicker).pickSubchannel(any(PickSubchannelArgs.class)); - verify(mockTransport).newStream(same(method), any(Metadata.class), callOptionsCaptor.capture()); - assertEquals( - Arrays.asList(factory1, factory2), - callOptionsCaptor.getValue().getStreamTracerFactories()); - // The factories are safely not stubbed because we do not expect any usage of them. - verifyNoInteractions(factory1); - verifyNoInteractions(factory2); + verify(mockTransport).newStream( + same(method), any(Metadata.class), callOptionsCaptor.capture(), + tracersCaptor.capture()); + assertThat(tracersCaptor.getValue()).isEqualTo(new ClientStreamTracer[] {tracer1, tracer2}); } @Test public void pickerReturnsStreamTracer_delayed() { ClientStream mockStream = mock(ClientStream.class); - ClientStreamTracer.Factory factory1 = mock(ClientStreamTracer.Factory.class); - ClientStreamTracer.Factory factory2 = mock(ClientStreamTracer.Factory.class); + final ClientStreamTracer tracer1 = new ClientStreamTracer() {}; + final ClientStreamTracer tracer2 = new ClientStreamTracer() {}; + ClientStreamTracer.Factory factory1 = new ClientStreamTracer.InternalLimitedInfoFactory() { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return tracer1; + } + }; + ClientStreamTracer.Factory factory2 = new ClientStreamTracer.InternalLimitedInfoFactory() { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return tracer2; + } + }; createChannel(); CallOptions callOptions = CallOptions.DEFAULT.withStreamTracerFactory(factory1); @@ -2436,7 +2511,8 @@ public void pickerReturnsStreamTracer_delayed() { transportInfo.listener.transportReady(); ClientTransport mockTransport = transportInfo.transport; when(mockTransport.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( PickResult.withSubchannel(subchannel, factory2)); @@ -2445,13 +2521,10 @@ public void pickerReturnsStreamTracer_delayed() { assertEquals(1, executor.runDueTasks()); verify(mockPicker).pickSubchannel(any(PickSubchannelArgs.class)); - verify(mockTransport).newStream(same(method), any(Metadata.class), callOptionsCaptor.capture()); - assertEquals( - Arrays.asList(factory1, factory2), - callOptionsCaptor.getValue().getStreamTracerFactories()); - // The factories are safely not stubbed because we do not expect any usage of them. - verifyNoInteractions(factory1); - verifyNoInteractions(factory2); + verify(mockTransport).newStream( + same(method), any(Metadata.class), callOptionsCaptor.capture(), + tracersCaptor.capture()); + assertThat(tracersCaptor.getValue()).isEqualTo(new ClientStreamTracer[] {tracer1, tracer2}); } @Test @@ -2818,7 +2891,9 @@ public void idleMode_resetsDelayedTransportPicker() { MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); @@ -2829,7 +2904,9 @@ public void idleMode_resetsDelayedTransportPicker() { executor.runDueTasks(); // Verify the buffered call was drained - verify(mockTransport).newStream(same(method), any(Metadata.class), any(CallOptions.class)); + verify(mockTransport).newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } @@ -2888,7 +2965,9 @@ public void enterIdle_exitsIdleIfDelayedStreamPending() { MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) @@ -2898,7 +2977,9 @@ public void enterIdle_exitsIdleIfDelayedStreamPending() { // Verify the original call was drained executor.runDueTasks(); - verify(mockTransport).newStream(same(method), any(Metadata.class), any(CallOptions.class)); + verify(mockTransport).newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } @@ -2920,7 +3001,9 @@ public void updateBalancingStateDoesUpdatePicker() { MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); @@ -2929,8 +3012,9 @@ public void updateBalancingStateDoesUpdatePicker() { updateBalancingStateSafely(helper, READY, mockPicker); executor.runDueTasks(); - verify(mockTransport, never()) - .newStream(any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream, never()).start(any(ClientStreamListener.class)); @@ -2939,7 +3023,9 @@ public void updateBalancingStateDoesUpdatePicker() { updateBalancingStateSafely(helper, READY, mockPicker); executor.runDueTasks(); - verify(mockTransport).newStream(same(method), any(Metadata.class), any(CallOptions.class)); + verify(mockTransport).newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } @@ -2958,7 +3044,9 @@ public void updateBalancingState_withWrappedSubchannel() { MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); @@ -2973,7 +3061,9 @@ protected Subchannel delegate() { updateBalancingStateSafely(helper, READY, mockPicker); executor.runDueTasks(); - verify(mockTransport).newStream(same(method), any(Metadata.class), any(CallOptions.class)); + verify(mockTransport).newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } @@ -3405,7 +3495,8 @@ private void channelsAndSubchannels_instrumented0(boolean success) throws Except transportInfo.listener.transportReady(); ClientTransport mockTransport = transportInfo.transport; when(mockTransport.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( PickResult.withSubchannel(subchannel, factory)); @@ -3478,7 +3569,9 @@ private void channelsAndSubchannels_oob_instrumented0(boolean success) throws Ex MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); // subchannel stat bumped when call gets assigned to it @@ -3650,7 +3743,9 @@ public double nextDouble() { ConnectionClientTransport mockTransport = transportInfo.transport; ClientStream mockStream = mock(ClientStream.class); ClientStream mockStream2 = mock(ClientStream.class); - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream).thenReturn(mockStream2); transportInfo.listener.transportReady(); updateBalancingStateSafely(helper, READY, mockPicker); @@ -3754,7 +3849,9 @@ public void hedgingScheduledThenChannelShutdown_hedgeShouldStillHappen_newCallSh ConnectionClientTransport mockTransport = transportInfo.transport; ClientStream mockStream = mock(ClientStream.class); ClientStream mockStream2 = mock(ClientStream.class); - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream).thenReturn(mockStream2); transportInfo.listener.transportReady(); updateBalancingStateSafely(helper, READY, mockPicker); diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index a83964b5e91..26c6fcf9b4e 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -163,9 +163,10 @@ void postCommit() { } @Override - ClientStream newSubstream(ClientStreamTracer.Factory tracerFactory, Metadata metadata) { + ClientStream newSubstream( + Metadata metadata, ClientStreamTracer.Factory tracerFactory, boolean isTransparentRetry) { bufferSizeTracer = - tracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + tracerFactory.newClientStreamTracer(STREAM_INFO, metadata); int actualPreviousRpcAttemptsInHeader = metadata.get(GRPC_PREVIOUS_RPC_ATTEMPTS) == null ? 0 : Integer.valueOf(metadata.get(GRPC_PREVIOUS_RPC_ATTEMPTS)); return retriableStreamRecorder.newSubstream(actualPreviousRpcAttemptsInHeader); diff --git a/core/src/test/java/io/grpc/internal/TestUtils.java b/core/src/test/java/io/grpc/internal/TestUtils.java index d5b4ce4949e..974f36e595c 100644 --- a/core/src/test/java/io/grpc/internal/TestUtils.java +++ b/core/src/test/java/io/grpc/internal/TestUtils.java @@ -23,6 +23,7 @@ import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.InternalLogId; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; @@ -35,6 +36,7 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import javax.annotation.Nullable; +import org.mockito.ArgumentMatchers; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -118,7 +120,8 @@ public ConnectionClientTransport answer(InvocationOnMock invocation) throws Thro when(mockTransport.getLogId()) .thenReturn(InternalLogId.allocate("mocktransport", /*details=*/ null)); when(mockTransport.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mock(ClientStream.class)); // Save the listener doAnswer(new Answer() { diff --git a/core/src/test/java/io/grpc/util/ForwardingClientStreamTracerTest.java b/core/src/test/java/io/grpc/util/ForwardingClientStreamTracerTest.java index fcb19b69eb8..dbd7e99b29a 100644 --- a/core/src/test/java/io/grpc/util/ForwardingClientStreamTracerTest.java +++ b/core/src/test/java/io/grpc/util/ForwardingClientStreamTracerTest.java @@ -40,6 +40,7 @@ public void allMethodsForwarded() throws Exception { Collections.emptyList()); } + @SuppressWarnings("deprecation") private final class TestClientStreamTracer extends ForwardingClientStreamTracer { @Override protected ClientStreamTracer delegate() { diff --git a/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java b/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java index dc4fc45ae4e..d41ec372d4c 100644 --- a/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java +++ b/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java @@ -21,6 +21,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.Metadata; @@ -118,7 +119,7 @@ public ListenableFuture getStats() { @Override public CronetClientStream newStream(final MethodDescriptor method, final Metadata headers, - final CallOptions callOptions) { + final CallOptions callOptions, ClientStreamTracer[] tracers) { Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(headers, "headers"); @@ -126,7 +127,7 @@ public CronetClientStream newStream(final MethodDescriptor method, final M final String url = "https://" + authority + defaultPath; final StatsTraceContext statsTraceCtx = - StatsTraceContext.newClientContext(callOptions, attrs, headers); + StatsTraceContext.newClientContext(tracers, attrs, headers); class StartCallback implements Runnable { final CronetClientStream clientStream = new CronetClientStream( url, userAgent, executor, headers, CronetClientTransport.this, this, lock, maxMessageSize, diff --git a/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java b/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java index 39fe03991e4..c27963c6d56 100644 --- a/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java +++ b/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java @@ -25,6 +25,7 @@ import android.os.Build; import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.cronet.CronetChannelBuilder.CronetTransportFactory; @@ -50,6 +51,8 @@ public final class CronetChannelBuilderTest { @Mock private ExperimentalCronetEngine mockEngine; @Mock private ChannelLogger channelLogger; + private final ClientStreamTracer[] tracers = + new ClientStreamTracer[]{ new ClientStreamTracer() {} }; private MethodDescriptor method = TestMethodDescriptors.voidMethod(); @Before @@ -69,7 +72,8 @@ public void alwaysUsePutTrue_cronetStreamIsIdempotent() throws Exception { new InetSocketAddress("localhost", 443), new ClientTransportOptions(), channelLogger); - CronetClientStream stream = transport.newStream(method, new Metadata(), CallOptions.DEFAULT); + CronetClientStream stream = transport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); assertTrue(stream.idempotent); } @@ -85,7 +89,8 @@ public void alwaysUsePut_defaultsToFalse() throws Exception { new InetSocketAddress("localhost", 443), new ClientTransportOptions(), channelLogger); - CronetClientStream stream = transport.newStream(method, new Metadata(), CallOptions.DEFAULT); + CronetClientStream stream = transport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); assertFalse(stream.idempotent); } diff --git a/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java b/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java index 50017cb43f8..9503481e747 100644 --- a/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java +++ b/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java @@ -27,6 +27,7 @@ import android.os.Build; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.SecurityLevel; @@ -60,6 +61,8 @@ public final class CronetClientTransportTest { private static final Attributes EAG_ATTRS = Attributes.newBuilder().set(EAG_ATTR_KEY, "value").build(); + private final ClientStreamTracer[] tracers = + new ClientStreamTracer[]{ new ClientStreamTracer() {} }; private CronetClientTransport transport; @Mock private StreamBuilderFactory streamFactory; @Mock private Executor executor; @@ -101,9 +104,9 @@ public void transportAttributes() { @Test public void shutdownTransport() throws Exception { CronetClientStream stream1 = - transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT); + transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT, tracers); CronetClientStream stream2 = - transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT); + transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT, tracers); // Create a transport and start two streams on it. ArgumentCaptor callbackCaptor = @@ -137,7 +140,7 @@ public void shutdownTransport() throws Exception { @Test public void startStreamAfterShutdown() throws Exception { CronetClientStream stream = - transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT); + transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT, tracers); transport.shutdown(); BaseClientStreamListener listener = new BaseClientStreamListener(); stream.start(listener); diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java index d27c485dc13..75f2481254d 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java @@ -37,7 +37,7 @@ * span of an LB stream with the remote load-balancer. */ @ThreadSafe -final class GrpclbClientLoadRecorder extends ClientStreamTracer.Factory { +final class GrpclbClientLoadRecorder extends ClientStreamTracer.InternalLimitedInfoFactory { private static final AtomicLongFieldUpdater callsStartedUpdater = AtomicLongFieldUpdater.newUpdater(GrpclbClientLoadRecorder.class, "callsStarted"); diff --git a/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java b/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java index 03b9bdf7f1b..03e1447bb2c 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java +++ b/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java @@ -22,6 +22,7 @@ import io.grpc.Attributes; import io.grpc.ClientStreamTracer; import io.grpc.Metadata; +import io.grpc.internal.ForwardingClientStreamTracer; import io.grpc.internal.GrpcAttributes; import javax.annotation.Nullable; @@ -29,7 +30,7 @@ * Wraps a {@link ClientStreamTracer.Factory}, retrieves tokens from transport attributes and * attaches them to headers. This is only used in the PICK_FIRST mode. */ -final class TokenAttachingTracerFactory extends ClientStreamTracer.Factory { +final class TokenAttachingTracerFactory extends ClientStreamTracer.InternalLimitedInfoFactory { private static final ClientStreamTracer NOOP_TRACER = new ClientStreamTracer() {}; @Nullable @@ -42,19 +43,30 @@ final class TokenAttachingTracerFactory extends ClientStreamTracer.Factory { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { - Attributes transportAttrs = checkNotNull(info.getTransportAttrs(), "transportAttrs"); - Attributes eagAttrs = - checkNotNull(transportAttrs.get(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS), "eagAttrs"); - String token = eagAttrs.get(GrpclbConstants.TOKEN_ATTRIBUTE_KEY); - headers.discardAll(GrpclbConstants.TOKEN_METADATA_KEY); - if (token != null) { - headers.put(GrpclbConstants.TOKEN_METADATA_KEY, token); - } - if (delegate != null) { - return delegate.newClientStreamTracer(info, headers); - } else { + if (delegate == null) { return NOOP_TRACER; } + final ClientStreamTracer clientStreamTracer = delegate.newClientStreamTracer(info, headers); + class TokenPropagationTracer extends ForwardingClientStreamTracer { + @Override + protected ClientStreamTracer delegate() { + return clientStreamTracer; + } + + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + Attributes eagAttrs = + checkNotNull(transportAttrs.get(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS), "eagAttrs"); + String token = eagAttrs.get(GrpclbConstants.TOKEN_ATTRIBUTE_KEY); + headers.discardAll(GrpclbConstants.TOKEN_METADATA_KEY); + if (token != null) { + headers.put(GrpclbConstants.TOKEN_METADATA_KEY, token); + } + delegate().streamCreated(transportAttrs, headers); + } + } + + return new TokenPropagationTracer(); } @Override diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java index 39f736dbcf4..a68962ad7d9 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java @@ -481,6 +481,7 @@ public void loadReporting() { ClientStreamTracer tracer1 = pick1.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, new Metadata()); + tracer1.streamCreated(Attributes.EMPTY, new Metadata()); PickResult pick2 = picker.pickSubchannel(args); assertNull(pick2.getSubchannel()); @@ -504,6 +505,7 @@ public void loadReporting() { assertSame(getLoadRecorder(), pick3.getStreamTracerFactory()); ClientStreamTracer tracer3 = pick3.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, new Metadata()); + tracer3.streamCreated(Attributes.EMPTY, new Metadata()); // pick3 has sent out headers tracer3.outboundHeaders(); @@ -541,6 +543,7 @@ public void loadReporting() { assertSame(getLoadRecorder(), pick5.getStreamTracerFactory()); ClientStreamTracer tracer5 = pick5.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, new Metadata()); + tracer5.streamCreated(Attributes.EMPTY, new Metadata()); // pick3 ended without receiving response headers tracer3.streamClosed(Status.DEADLINE_EXCEEDED); diff --git a/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java b/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java index 34b0a8ea1aa..29ded18d913 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java @@ -33,12 +33,23 @@ /** Unit tests for {@link TokenAttachingTracerFactory}. */ @RunWith(JUnit4.class) public class TokenAttachingTracerFactoryTest { - private static final ClientStreamTracer fakeTracer = new ClientStreamTracer() {}; + private static final class FakeClientStreamTracer extends ClientStreamTracer { + Attributes transportAttrs; + Metadata headers; + + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + this.transportAttrs = transportAttrs; + this.headers = headers; + } + } + + private static final FakeClientStreamTracer fakeTracer = new FakeClientStreamTracer(); private final ClientStreamTracer.Factory delegate = mock( ClientStreamTracer.Factory.class, delegatesTo( - new ClientStreamTracer.Factory() { + new ClientStreamTracer.InternalLimitedInfoFactory() { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { @@ -51,28 +62,25 @@ public void hasToken() { TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(delegate); Attributes eagAttrs = Attributes.newBuilder() .set(GrpclbConstants.TOKEN_ATTRIBUTE_KEY, "token0001").build(); - ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder() - .setTransportAttrs( - Attributes.newBuilder().set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs).build()) - .build(); + ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder().build(); Metadata headers = new Metadata(); // Preexisting token should be replaced headers.put(GrpclbConstants.TOKEN_METADATA_KEY, "preexisting-token"); ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers); verify(delegate).newClientStreamTracer(same(info), same(headers)); - assertThat(tracer).isSameInstanceAs(fakeTracer); + Attributes transportAttrs = + Attributes.newBuilder().set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs).build(); + tracer.streamCreated(transportAttrs, headers); + assertThat(fakeTracer.transportAttrs).isSameInstanceAs(transportAttrs); + assertThat(fakeTracer.headers).isSameInstanceAs(headers); assertThat(headers.getAll(GrpclbConstants.TOKEN_METADATA_KEY)).containsExactly("token0001"); } @Test public void noToken() { TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(delegate); - ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder() - .setTransportAttrs( - Attributes.newBuilder() - .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build()) - .build(); + ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder().build(); Metadata headers = new Metadata(); // Preexisting token should be removed @@ -80,22 +88,25 @@ public void noToken() { ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers); verify(delegate).newClientStreamTracer(same(info), same(headers)); - assertThat(tracer).isSameInstanceAs(fakeTracer); + Attributes transportAttrs = + Attributes.newBuilder().set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build(); + tracer.streamCreated(transportAttrs, headers); + assertThat(fakeTracer.transportAttrs).isSameInstanceAs(transportAttrs); + assertThat(fakeTracer.headers).isSameInstanceAs(headers); assertThat(headers.get(GrpclbConstants.TOKEN_METADATA_KEY)).isNull(); } @Test public void nullDelegate() { TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(null); - ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder() - .setTransportAttrs( - Attributes.newBuilder() - .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build()) - .build(); + ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder().build(); Metadata headers = new Metadata(); ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers); + tracer.streamCreated( + Attributes.newBuilder().set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build(), + headers); assertThat(tracer).isNotNull(); assertThat(headers.get(GrpclbConstants.TOKEN_METADATA_KEY)).isNull(); } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java index 758f99d5353..1b447a63c32 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java @@ -289,7 +289,7 @@ final SocketAddress getListenAddress() { new LinkedBlockingQueue<>(); private final ClientStreamTracer.Factory clientStreamTracerFactory = - new ClientStreamTracer.Factory() { + new ClientStreamTracer.InternalLimitedInfoFactory() { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { @@ -375,7 +375,8 @@ protected final ClientInterceptor createCensusStatsClientInterceptor() { .getClientInterceptor( tagger, tagContextBinarySerializer, clientStatsRecorder, GrpcUtil.STOPWATCH_SUPPLIER, - true, true, true, false /* real-time metrics */); + true, true, true, + /* recordRealTimeMetrics= */ false); } protected final ServerStreamTracer.Factory createCustomCensusTracerFactory() { @@ -1179,6 +1180,7 @@ public void deadlineExceeded() throws Exception { public void deadlineExceededServerStreaming() throws Exception { // warm up the channel and JVM blockingStub.emptyCall(Empty.getDefaultInstance()); + assertStatsTrace("grpc.testing.TestService/EmptyCall", Status.Code.OK); ResponseParameters.Builder responseParameters = ResponseParameters.newBuilder() .setSize(1) .setIntervalUs(10000); @@ -1195,7 +1197,6 @@ public void deadlineExceededServerStreaming() throws Exception { recorder.awaitCompletion(); assertEquals(Status.DEADLINE_EXCEEDED.getCode(), Status.fromThrowable(recorder.getError()).getCode()); - assertStatsTrace("grpc.testing.TestService/EmptyCall", Status.Code.OK); if (metricsExpected()) { // Stream may not have been created when deadline is exceeded, thus we don't check tracer // stats. @@ -1239,6 +1240,12 @@ public void deadlineInPast() throws Exception { // warm up the channel blockingStub.emptyCall(Empty.getDefaultInstance()); + if (metricsExpected()) { + // clientStartRecord + clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS); + // clientEndRecord + clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS); + } try { blockingStub .withDeadlineAfter(-10, TimeUnit.SECONDS) @@ -1249,7 +1256,6 @@ public void deadlineInPast() throws Exception { assertThat(ex.getStatus().getDescription()) .startsWith("ClientCall started after deadline exceeded"); } - assertStatsTrace("grpc.testing.TestService/EmptyCall", Status.Code.OK); if (metricsExpected()) { MetricsRecord clientStartRecord = clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS); checkStartTags(clientStartRecord, "grpc.testing.TestService/EmptyCall", true); diff --git a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java index c3807986c9f..a7a1044059c 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java @@ -28,6 +28,7 @@ import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.Metadata; @@ -167,14 +168,15 @@ public void operationComplete(ChannelFuture future) throws Exception { @Override public ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(headers, "headers"); if (channel == null) { - return new FailingClientStream(statusExplainingWhyTheChannelIsNull); + return new FailingClientStream(statusExplainingWhyTheChannelIsNull, tracers); } StatsTraceContext statsTraceCtx = - StatsTraceContext.newClientContext(callOptions, getAttributes(), headers); + StatsTraceContext.newClientContext(tracers, getAttributes(), headers); return new NettyClientStream( new NettyClientStream.TransportState( handler, diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index e4165e89243..018ca9b6594 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -41,6 +41,7 @@ import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.Grpc; import io.grpc.InternalChannelz; import io.grpc.Metadata; @@ -828,7 +829,9 @@ private static class Rpc { } Rpc(NettyClientTransport transport, Metadata headers) { - stream = transport.newStream(METHOD, headers, CallOptions.DEFAULT); + stream = transport.newStream( + METHOD, headers, CallOptions.DEFAULT, + new ClientStreamTracer[]{ new ClientStreamTracer() {} }); stream.start(listener); stream.request(1); stream.writeMessage(new ByteArrayInputStream(MESSAGE.getBytes(UTF_8))); diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index a001ddb73e7..121093716db 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -34,6 +34,7 @@ import com.squareup.okhttp.internal.http.StatusLine; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Grpc; import io.grpc.HttpConnectProxiedSocketAddress; import io.grpc.InternalChannelz; @@ -387,12 +388,13 @@ public void ping(final PingCallback callback, Executor executor) { } @Override - public OkHttpClientStream newStream(final MethodDescriptor method, - final Metadata headers, CallOptions callOptions) { + public OkHttpClientStream newStream( + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(headers, "headers"); - StatsTraceContext statsTraceCtx = - StatsTraceContext.newClientContext(callOptions, attributes, headers); + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(tracers, getAttributes(), headers); // FIXME: it is likely wrong to pass the transportTracer here as it'll exit the lock's scope synchronized (lock) { // to make @GuardedBy linter happy return new OkHttpClientStream( @@ -406,7 +408,7 @@ public OkHttpClientStream newStream(final MethodDescriptor method, initialWindowSize, defaultAuthority, userAgent, - statsTraceCtx, + statsTraceContext, transportTracer, callOptions, useGetForSafeMethods); diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index b03a2dedc00..b70b832a797 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -56,6 +56,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.HttpConnectProxiedSocketAddress; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalChannelz.TransportStats; @@ -146,6 +147,9 @@ public class OkHttpClientTransportTest { private static final int DEFAULT_MAX_INBOUND_METADATA_SIZE = Integer.MAX_VALUE; private static final Attributes EAG_ATTRS = Attributes.EMPTY; private static final Logger logger = Logger.getLogger(OkHttpClientTransport.class.getName()); + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; @Rule public final Timeout globalTimeout = Timeout.seconds(10); @@ -299,7 +303,7 @@ public void close() throws SecurityException { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -387,7 +391,7 @@ public void maxMessageSizeShouldBeEnforced() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); assertContainStream(3); @@ -443,11 +447,11 @@ public void nextFrameThrowIoException() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(1); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(1); assertEquals(2, activeStreamCount()); @@ -477,7 +481,7 @@ public void nextFrameThrowsError() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); assertEquals(1, activeStreamCount()); @@ -498,7 +502,7 @@ public void nextFrameReturnFalse() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); frameReader.nextFrameAtEndOfStream(); @@ -516,7 +520,7 @@ public void readMessages() throws Exception { final String message = "Hello Client"; MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(numMessages); assertContainStream(3); @@ -566,7 +570,7 @@ public void invalidInboundHeadersCancelStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); assertContainStream(3); @@ -590,7 +594,7 @@ public void invalidInboundTrailersPropagateToMetadata() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); assertContainStream(3); @@ -610,7 +614,7 @@ public void readStatus() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS); @@ -624,7 +628,7 @@ public void receiveReset() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); frameHandler().rstStream(3, ErrorCode.PROTOCOL_ERROR); @@ -641,7 +645,7 @@ public void receiveResetNoError() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); @@ -661,7 +665,7 @@ public void cancelStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); getStream(3).cancel(Status.CANCELLED); verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL)); @@ -676,7 +680,7 @@ public void addDefaultUserAgent() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); Header userAgentHeader = new Header(GrpcUtil.USER_AGENT_KEY.name(), GrpcUtil.getGrpcUserAgent("okhttp", null)); @@ -695,7 +699,7 @@ public void overrideDefaultUserAgent() throws Exception { startTransport(3, null, true, DEFAULT_MAX_MESSAGE_SIZE, INITIAL_WINDOW_SIZE, "fakeUserAgent"); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); List
expectedHeaders = Arrays.asList(HTTP_SCHEME_HEADER, METHOD_HEADER, new Header(Header.TARGET_AUTHORITY, "notarealauthority:80"), @@ -714,7 +718,7 @@ public void cancelStreamForDeadlineExceeded() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); getStream(3).cancel(Status.DEADLINE_EXCEEDED); verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL)); @@ -728,7 +732,7 @@ public void writeMessage() throws Exception { final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); assertEquals(12, input.available()); @@ -772,12 +776,12 @@ public void windowUpdate() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(2); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(2); assertEquals(2, activeStreamCount()); @@ -838,7 +842,7 @@ public void windowUpdateWithInboundFlowControl() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = INITIAL_WINDOW_SIZE / 2 + 1; byte[] fakeMessage = new byte[messageLength]; @@ -874,7 +878,7 @@ public void outboundFlowControl() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); // Outbound window always starts at 65535 until changed by Settings.INITIAL_WINDOW_SIZE @@ -920,7 +924,7 @@ public void outboundFlowControl_smallWindowSize() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 75; @@ -963,7 +967,7 @@ public void outboundFlowControl_bigWindowSize() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 100000; @@ -999,7 +1003,7 @@ public void outboundFlowControlWithInitialWindowSizeChange() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 20; setInitialWindowSize(HEADER_LENGTH + 10); @@ -1045,7 +1049,7 @@ public void outboundFlowControlWithInitialWindowSizeChangeInMiddleOfStream() thr initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 20; setInitialWindowSize(HEADER_LENGTH + 10); @@ -1080,10 +1084,10 @@ public void stopNormally() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); assertEquals(2, activeStreamCount()); clientTransport.shutdown(SHUTDOWN_REASON); @@ -1110,11 +1114,11 @@ public void receiveGoAway() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(1); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(1); assertEquals(2, activeStreamCount()); @@ -1168,7 +1172,7 @@ public void streamIdExhausted() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1204,11 +1208,11 @@ public void pendingStreamSucceed() throws Exception { final MockStreamListener listener1 = new MockStreamListener(); final MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second stream should be pending. OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); String sentMessage = "hello"; InputStream input = new ByteArrayInputStream(sentMessage.getBytes(UTF_8)); @@ -1241,7 +1245,7 @@ public void pendingStreamCancelled() throws Exception { setMaxConcurrentStreams(0); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); waitForStreamPending(1); stream.cancel(Status.CANCELLED); @@ -1260,11 +1264,11 @@ public void pendingStreamFailedByGoAway() throws Exception { final MockStreamListener listener1 = new MockStreamListener(); final MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second stream should be pending. OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); waitForStreamPending(1); @@ -1290,7 +1294,7 @@ public void pendingStreamSucceedAfterShutdown() throws Exception { final MockStreamListener listener = new MockStreamListener(); // The second stream should be pending. OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); waitForStreamPending(1); @@ -1314,15 +1318,15 @@ public void pendingStreamFailedByIdExhausted() throws Exception { final MockStreamListener listener3 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second and third stream should be pending. OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); OkHttpClientStream stream3 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); waitForStreamPending(2); @@ -1346,7 +1350,7 @@ public void receivingWindowExceeded() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1398,7 +1402,7 @@ private void shouldHeadersBeFlushed(boolean shouldBeFlushed) throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); verify(frameWriter, timeout(TIME_OUT_MS)).synStream( eq(false), eq(false), eq(3), eq(0), ArgumentMatchers.
anyList()); @@ -1415,7 +1419,7 @@ public void receiveDataWithoutHeader() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); Buffer buffer = createMessageFrame(new byte[1]); @@ -1437,7 +1441,7 @@ public void receiveDataWithoutHeaderAndTrailer() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); Buffer buffer = createMessageFrame(new byte[1]); @@ -1459,7 +1463,7 @@ public void receiveLongEnoughDataWithoutHeaderAndTrailer() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); Buffer buffer = createMessageFrame(new byte[1000]); @@ -1480,7 +1484,7 @@ public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.cancel(Status.CANCELLED); @@ -1507,7 +1511,7 @@ public void receiveWindowUpdateForUnknownStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.cancel(Status.CANCELLED); // This should be ignored. @@ -1527,7 +1531,7 @@ public void shouldBeInitiallyReady() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertTrue(stream.isReady()); assertTrue(listener.isOnReadyCalled()); @@ -1545,7 +1549,7 @@ public void notifyOnReady() throws Exception { setInitialWindowSize(0); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertTrue(stream.isReady()); // Be notified at the beginning. @@ -1695,7 +1699,7 @@ public void writeBeforeConnected() throws Exception { final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); stream.writeMessage(input); @@ -1720,7 +1724,7 @@ public void cancelBeforeConnected() throws Exception { final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); stream.writeMessage(input); @@ -1738,7 +1742,7 @@ public void shutdownDuringConnecting() throws Exception { initTransportAndDelayConnected(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); clientTransport.shutdown(SHUTDOWN_REASON); allowTransportConnected(); @@ -1810,7 +1814,8 @@ public void unreachableServer() throws Exception { assertTrue(status.getCause().toString(), status.getCause() instanceof IOException); MockStreamListener streamListener = new MockStreamListener(); - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT).start(streamListener); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers) + .start(streamListener); streamListener.waitUntilStreamClosed(); assertEquals(Status.UNAVAILABLE.getCode(), streamListener.status.getCode()); } @@ -2054,13 +2059,13 @@ public void goAway_streamListenerRpcProgress() throws Exception { MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); OkHttpClientStream stream3 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); waitForStreamPending(1); @@ -2094,13 +2099,13 @@ public void reset_streamListenerRpcProgress() throws Exception { MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); OkHttpClientStream stream3 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); assertEquals(3, activeStreamCount()); @@ -2158,7 +2163,7 @@ private void waitForStreamPending(int expected) throws Exception { private void assertNewStreamFail() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); listener.waitUntilStreamClosed(); assertFalse(listener.status.isOk()); diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index 5beefc3384c..dffbe3dade7 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -31,8 +31,8 @@ import io.grpc.LoadBalancer; import io.grpc.Metadata; import io.grpc.Status; +import io.grpc.internal.ForwardingClientStreamTracer; import io.grpc.internal.ObjectPool; -import io.grpc.util.ForwardingClientStreamTracer; import io.grpc.util.ForwardingLoadBalancerHelper; import io.grpc.util.ForwardingSubchannel; import io.grpc.xds.ClusterImplLoadBalancerProvider.ClusterImplConfig; @@ -329,7 +329,8 @@ public String toString() { } } - private static final class CountingStreamTracerFactory extends ClientStreamTracer.Factory { + private static final class CountingStreamTracerFactory extends + ClientStreamTracer.InternalLimitedInfoFactory { private ClusterLocalityStats stats; private final AtomicLong inFlights; @Nullable diff --git a/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java b/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java index c193f5e35e5..156d53f638e 100644 --- a/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java +++ b/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java @@ -25,8 +25,8 @@ import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.LoadBalancer; import io.grpc.Metadata; +import io.grpc.internal.ForwardingClientStreamTracer; import io.grpc.protobuf.ProtoUtils; -import io.grpc.util.ForwardingClientStreamTracer; import java.util.ArrayList; import java.util.List; @@ -37,7 +37,7 @@ abstract class OrcaPerRequestUtil { private static final ClientStreamTracer NOOP_CLIENT_STREAM_TRACER = new ClientStreamTracer() {}; private static final ClientStreamTracer.Factory NOOP_CLIENT_STREAM_TRACER_FACTORY = - new ClientStreamTracer.Factory() { + new ClientStreamTracer.InternalLimitedInfoFactory() { @Override public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { return NOOP_CLIENT_STREAM_TRACER; @@ -189,7 +189,8 @@ public interface OrcaPerRequestReportListener { * per-request ORCA reports and push to registered listeners for calls they trace. */ @VisibleForTesting - static final class OrcaReportingTracerFactory extends ClientStreamTracer.Factory { + static final class OrcaReportingTracerFactory extends + ClientStreamTracer.InternalLimitedInfoFactory { @VisibleForTesting static final Metadata.Key ORCA_ENDPOINT_LOAD_METRICS_KEY = diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index 74aa85501a9..3b2a54c2c25 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -341,8 +341,8 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); assertThat(result.getStatus().isOk()).isTrue(); ClientStreamTracer.Factory streamTracerFactory = result.getStreamTracerFactory(); - streamTracerFactory.newClientStreamTracer(ClientStreamTracer.StreamInfo.newBuilder().build(), - new Metadata()); + streamTracerFactory.newClientStreamTracer( + ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); } ClusterStats clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); @@ -429,8 +429,8 @@ private void subtest_maxConcurrentRequests_appliedWithDefaultValue( PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); assertThat(result.getStatus().isOk()).isTrue(); ClientStreamTracer.Factory streamTracerFactory = result.getStreamTracerFactory(); - streamTracerFactory.newClientStreamTracer(ClientStreamTracer.StreamInfo.newBuilder().build(), - new Metadata()); + streamTracerFactory.newClientStreamTracer( + ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); } ClusterStats clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER));