Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@
import io.grpc.MethodDescriptor;
import io.grpc.SecurityLevel;
import io.grpc.Status;
import io.grpc.internal.MetadataApplierImpl.MetadataApplierListener;
import java.net.SocketAddress;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.concurrent.GuardedBy;

final class CallCredentialsApplyingTransportFactory implements ClientTransportFactory {
private final ClientTransportFactory delegate;
Expand Down Expand Up @@ -66,6 +69,21 @@ public void close() {
private class CallCredentialsApplyingTransport extends ForwardingConnectionClientTransport {
private final ConnectionClientTransport delegate;
private final String authority;
// Negative value means transport active, non-negative value indicates shutdown invoked.
private final AtomicInteger pendingApplier = new AtomicInteger(Integer.MIN_VALUE + 1);
private volatile Status shutdownStatus;
@GuardedBy("this")
private Status savedShutdownStatus;
@GuardedBy("this")
private Status savedShutdownNowStatus;
private final MetadataApplierListener applierListener = new MetadataApplierListener() {
@Override
public void onComplete() {
if (pendingApplier.decrementAndGet() == 0) {
maybeShutdown();
}
}
};

CallCredentialsApplyingTransport(ConnectionClientTransport delegate, String authority) {
this.delegate = checkNotNull(delegate, "delegate");
Expand All @@ -89,7 +107,11 @@ public ClientStream newStream(
}
if (creds != null) {
MetadataApplierImpl applier = new MetadataApplierImpl(
delegate, method, headers, callOptions);
delegate, method, headers, callOptions, applierListener);
if (pendingApplier.incrementAndGet() > 0) {
applierListener.onComplete();
return new FailingClientStream(shutdownStatus);
}
RequestInfo requestInfo = new RequestInfo() {
@Override
public MethodDescriptor<?, ?> getMethodDescriptor() {
Expand Down Expand Up @@ -123,8 +145,69 @@ public Attributes getTransportAttrs() {
}
return applier.returnStream();
} else {
if (pendingApplier.get() >= 0) {
return new FailingClientStream(shutdownStatus);
}
return delegate.newStream(method, headers, callOptions);
}
}

@Override
public void shutdown(Status status) {
checkNotNull(status, "status");
synchronized (this) {
if (pendingApplier.get() < 0) {
shutdownStatus = status;
pendingApplier.addAndGet(Integer.MAX_VALUE);
} else {
return;
}
if (pendingApplier.get() != 0) {
savedShutdownStatus = status;
return;
}
}
super.shutdown(status);
}

// TODO(zivy): cancel pending applier here.
@Override
public void shutdownNow(Status status) {
checkNotNull(status, "status");
synchronized (this) {
if (pendingApplier.get() < 0) {
shutdownStatus = status;
pendingApplier.addAndGet(Integer.MAX_VALUE);
} else if (savedShutdownNowStatus != null) {
return;
}
if (pendingApplier.get() != 0) {
savedShutdownNowStatus = status;
// TODO(zivy): propagate shutdownNow to the delegate immediately.
return;
}
}
super.shutdownNow(status);
}

private void maybeShutdown() {
Status maybeShutdown;
Status maybeShutdownNow;
synchronized (this) {
if (pendingApplier.get() != 0) {
return;
}
maybeShutdown = savedShutdownStatus;
maybeShutdownNow = savedShutdownNowStatus;
savedShutdownStatus = null;
savedShutdownNowStatus = null;
}
if (maybeShutdown != null) {
super.shutdown(maybeShutdown);
}
if (maybeShutdownNow != null) {
super.shutdownNow(maybeShutdownNow);
}
}
}
}
19 changes: 17 additions & 2 deletions core/src/main/java/io/grpc/internal/MetadataApplierImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ final class MetadataApplierImpl extends MetadataApplier {
private final Metadata origHeaders;
private final CallOptions callOptions;
private final Context ctx;
private final MetadataApplierListener listener;

private final Object lock = new Object();

Expand All @@ -51,12 +52,13 @@ final class MetadataApplierImpl extends MetadataApplier {

MetadataApplierImpl(
ClientTransport transport, MethodDescriptor<?, ?> method, Metadata origHeaders,
CallOptions callOptions) {
CallOptions callOptions, MetadataApplierListener listener) {
this.transport = transport;
this.method = method;
this.origHeaders = origHeaders;
this.callOptions = callOptions;
this.ctx = Context.current();
this.listener = listener;
}

@Override
Expand Down Expand Up @@ -84,14 +86,19 @@ public void fail(Status status) {
private void finalizeWith(ClientStream stream) {
checkState(!finalized, "already finalized");
finalized = true;
boolean directStream = false;
synchronized (lock) {
if (returnedStream == null) {
// Fast path: returnStream() hasn't been called, the call will use the
// real stream directly.
returnedStream = stream;
return;
directStream = true;
}
}
if (directStream) {
listener.onComplete();
return;
}
// returnStream() has been called before me, thus delayedStream must have been
// created.
checkState(delayedStream != null, "delayedStream is null");
Expand All @@ -100,6 +107,7 @@ private void finalizeWith(ClientStream stream) {
// TODO(ejona): run this on a separate thread
slow.run();
}
listener.onComplete();
}

/**
Expand All @@ -116,4 +124,11 @@ ClientStream returnStream() {
}
}
}

public interface MetadataApplierListener {
/**
* Notify that the metadata has been successfully applied, or failed.
* */
void onComplete();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doAnswer;
Expand Down Expand Up @@ -203,6 +204,10 @@ public void credentialThrows() {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode());
assertSame(ex, stream.getError().getCause());
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -227,6 +232,10 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
assertSame(mockStream, stream);
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -249,6 +258,10 @@ public Void answer(InvocationOnMock invocation) throws Throwable {

verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertSame(error, stream.getError());
transport.shutdownNow(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdownNow(Status.UNAVAILABLE);
}

@Test
Expand All @@ -263,6 +276,9 @@ public void applyMetadata_delayed() {
any(RequestInfo.class), same(mockExecutor), applierCaptor.capture());
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);

transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport, never()).shutdown(Status.UNAVAILABLE);

Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);
applierCaptor.getValue().apply(headers);
Expand All @@ -271,6 +287,9 @@ public void applyMetadata_delayed() {
assertSame(mockStream, stream.getRealStream());
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -290,6 +309,10 @@ public void fail_delayed() {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
FailingClientStream failingStream = (FailingClientStream) stream.getRealStream();
assertSame(error, failingStream.getError());
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -301,5 +324,9 @@ public void noCreds() {
assertSame(mockStream, stream);
assertNull(origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}
}
Loading