diff --git a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java index ab2a36d9492..de96a3306bd 100644 --- a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java @@ -29,9 +29,12 @@ import io.grpc.MethodDescriptor; import io.grpc.SecurityLevel; import io.grpc.Status; +import io.grpc.internal.MetadataApplierImpl.MetadataApplierListener; import java.net.SocketAddress; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.concurrent.GuardedBy; final class CallCredentialsApplyingTransportFactory implements ClientTransportFactory { private final ClientTransportFactory delegate; @@ -66,6 +69,21 @@ public void close() { private class CallCredentialsApplyingTransport extends ForwardingConnectionClientTransport { private final ConnectionClientTransport delegate; private final String authority; + // Negative value means transport active, non-negative value indicates shutdown invoked. + private final AtomicInteger pendingApplier = new AtomicInteger(Integer.MIN_VALUE + 1); + private volatile Status shutdownStatus; + @GuardedBy("this") + private Status savedShutdownStatus; + @GuardedBy("this") + private Status savedShutdownNowStatus; + private final MetadataApplierListener applierListener = new MetadataApplierListener() { + @Override + public void onComplete() { + if (pendingApplier.decrementAndGet() == 0) { + maybeShutdown(); + } + } + }; CallCredentialsApplyingTransport(ConnectionClientTransport delegate, String authority) { this.delegate = checkNotNull(delegate, "delegate"); @@ -89,7 +107,11 @@ public ClientStream newStream( } if (creds != null) { MetadataApplierImpl applier = new MetadataApplierImpl( - delegate, method, headers, callOptions); + delegate, method, headers, callOptions, applierListener); + if (pendingApplier.incrementAndGet() > 0) { + applierListener.onComplete(); + return new FailingClientStream(shutdownStatus); + } RequestInfo requestInfo = new RequestInfo() { @Override public MethodDescriptor getMethodDescriptor() { @@ -123,8 +145,69 @@ public Attributes getTransportAttrs() { } return applier.returnStream(); } else { + if (pendingApplier.get() >= 0) { + return new FailingClientStream(shutdownStatus); + } return delegate.newStream(method, headers, callOptions); } } + + @Override + public void shutdown(Status status) { + checkNotNull(status, "status"); + synchronized (this) { + if (pendingApplier.get() < 0) { + shutdownStatus = status; + pendingApplier.addAndGet(Integer.MAX_VALUE); + } else { + return; + } + if (pendingApplier.get() != 0) { + savedShutdownStatus = status; + return; + } + } + super.shutdown(status); + } + + // TODO(zivy): cancel pending applier here. + @Override + public void shutdownNow(Status status) { + checkNotNull(status, "status"); + synchronized (this) { + if (pendingApplier.get() < 0) { + shutdownStatus = status; + pendingApplier.addAndGet(Integer.MAX_VALUE); + } else if (savedShutdownNowStatus != null) { + return; + } + if (pendingApplier.get() != 0) { + savedShutdownNowStatus = status; + // TODO(zivy): propagate shutdownNow to the delegate immediately. + return; + } + } + super.shutdownNow(status); + } + + private void maybeShutdown() { + Status maybeShutdown; + Status maybeShutdownNow; + synchronized (this) { + if (pendingApplier.get() != 0) { + return; + } + maybeShutdown = savedShutdownStatus; + maybeShutdownNow = savedShutdownNowStatus; + savedShutdownStatus = null; + savedShutdownNowStatus = null; + } + if (maybeShutdown != null) { + super.shutdown(maybeShutdown); + } + if (maybeShutdownNow != null) { + super.shutdownNow(maybeShutdownNow); + } + } } } diff --git a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java index 4c49a14a06b..76d280b2d00 100644 --- a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java +++ b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java @@ -35,6 +35,7 @@ final class MetadataApplierImpl extends MetadataApplier { private final Metadata origHeaders; private final CallOptions callOptions; private final Context ctx; + private final MetadataApplierListener listener; private final Object lock = new Object(); @@ -51,12 +52,13 @@ final class MetadataApplierImpl extends MetadataApplier { MetadataApplierImpl( ClientTransport transport, MethodDescriptor method, Metadata origHeaders, - CallOptions callOptions) { + CallOptions callOptions, MetadataApplierListener listener) { this.transport = transport; this.method = method; this.origHeaders = origHeaders; this.callOptions = callOptions; this.ctx = Context.current(); + this.listener = listener; } @Override @@ -84,14 +86,19 @@ public void fail(Status status) { private void finalizeWith(ClientStream stream) { checkState(!finalized, "already finalized"); finalized = true; + boolean directStream = false; synchronized (lock) { if (returnedStream == null) { // Fast path: returnStream() hasn't been called, the call will use the // real stream directly. returnedStream = stream; - return; + directStream = true; } } + if (directStream) { + listener.onComplete(); + return; + } // returnStream() has been called before me, thus delayedStream must have been // created. checkState(delayedStream != null, "delayedStream is null"); @@ -100,6 +107,7 @@ private void finalizeWith(ClientStream stream) { // TODO(ejona): run this on a separate thread slow.run(); } + listener.onComplete(); } /** @@ -116,4 +124,11 @@ ClientStream returnStream() { } } } + + public interface MetadataApplierListener { + /** + * Notify that the metadata has been successfully applied, or failed. + * */ + void onComplete(); + } } diff --git a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java index c26944c16b2..7725c46726b 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java @@ -19,6 +19,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.doAnswer; @@ -203,6 +204,10 @@ public void credentialThrows() { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode()); assertSame(ex, stream.getError().getCause()); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -227,6 +232,10 @@ public Void answer(InvocationOnMock invocation) throws Throwable { assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -249,6 +258,10 @@ public Void answer(InvocationOnMock invocation) throws Throwable { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); assertSame(error, stream.getError()); + transport.shutdownNow(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdownNow(Status.UNAVAILABLE); } @Test @@ -263,6 +276,9 @@ public void applyMetadata_delayed() { any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + transport.shutdown(Status.UNAVAILABLE); + verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); + Metadata headers = new Metadata(); headers.put(CREDS_KEY, CREDS_VALUE); applierCaptor.getValue().apply(headers); @@ -271,6 +287,9 @@ public void applyMetadata_delayed() { assertSame(mockStream, stream.getRealStream()); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -290,6 +309,10 @@ public void fail_delayed() { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); FailingClientStream failingStream = (FailingClientStream) stream.getRealStream(); assertSame(error, failingStream.getError()); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -301,5 +324,9 @@ public void noCreds() { assertSame(mockStream, stream); assertNull(origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } } diff --git a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java index 6949ab7c310..61a221f73de 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java @@ -19,12 +19,14 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -179,6 +181,11 @@ public void credentialThrows() { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode()); assertSame(ex, stream.getError().getCause()); + + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -192,6 +199,10 @@ public void applyMetadata_inline() { assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -214,6 +225,10 @@ public Void answer(InvocationOnMock invocation) throws Throwable { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); assertSame(error, stream.getError()); + transport.shutdownNow(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdownNow(Status.UNAVAILABLE); } @Test @@ -228,6 +243,11 @@ public void applyMetadata_delayed() { same(mockExecutor), applierCaptor.capture()); verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + transport.shutdown(Status.UNAVAILABLE); + verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + Metadata headers = new Metadata(); headers.put(CREDS_KEY, CREDS_VALUE); applierCaptor.getValue().apply(headers); @@ -236,6 +256,79 @@ public void applyMetadata_delayed() { assertSame(mockStream, stream.getRealStream()); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + } + + @Test + public void delayedShutdown_shutdownShutdownNowThenApply() { + transport.newStream(method, origHeaders, callOptions); + ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), + same(mockExecutor), applierCaptor.capture()); + transport.shutdown(Status.UNAVAILABLE); + transport.shutdownNow(Status.ABORTED); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport, never()).shutdown(any(Status.class)); + verify(mockTransport, never()).shutdownNow(any(Status.class)); + Metadata headers = new Metadata(); + headers.put(CREDS_KEY, CREDS_VALUE); + applierCaptor.getValue().apply(headers); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + verify(mockTransport).shutdownNow(Status.ABORTED); + } + + @Test + public void delayedShutdown_shutdownThenApplyThenShutdownNow() { + transport.newStream(method, origHeaders, callOptions); + ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), + same(mockExecutor), applierCaptor.capture()); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport, never()).shutdown(any(Status.class)); + Metadata headers = new Metadata(); + headers.put(CREDS_KEY, CREDS_VALUE); + applierCaptor.getValue().apply(headers); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + transport.shutdownNow(Status.ABORTED); + verify(mockTransport).shutdownNow(Status.ABORTED); + + transport.shutdown(Status.UNAVAILABLE); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + transport.shutdownNow(Status.ABORTED); + verify(mockTransport, times(2)).shutdownNow(Status.ABORTED); + } + + @Test + public void delayedShutdown_shutdownMulti() { + Metadata headers = new Metadata(); + headers.put(CREDS_KEY, CREDS_VALUE); + + transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions); + ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + verify(mockCreds, times(3)).applyRequestMetadata(any(RequestInfo.class), + same(mockExecutor), applierCaptor.capture()); + applierCaptor.getAllValues().get(1).apply(headers); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); + + applierCaptor.getAllValues().get(0).apply(headers); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); + + applierCaptor.getAllValues().get(2).apply(headers); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -255,6 +348,10 @@ public void fail_delayed() { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); FailingClientStream failingStream = (FailingClientStream) stream.getRealStream(); assertSame(error, failingStream.getError()); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -266,6 +363,10 @@ public void noCreds() { assertSame(mockStream, stream); assertNull(origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test