diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream.java b/core/src/main/java/io/grpc/internal/AbstractClientStream.java index 91edfb631e3..d78402730e3 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream.java @@ -92,7 +92,6 @@ void writeFrame( private final Framer framer; private boolean useGet; private Metadata headers; - private boolean outboundClosed; /** * Whether cancel() has been called. This is not strictly necessary, but removes the delay between * cancel() being called and isReady() beginning to return false, since cancel is commonly @@ -175,8 +174,8 @@ public final void deliverFrame( @Override public final void halfClose() { - if (!outboundClosed) { - outboundClosed = true; + if (!transportState().isOutboundClosed()) { + transportState().setOutboundClosed(); endOfMessages(); } } @@ -209,6 +208,9 @@ protected abstract static class TransportState extends AbstractStream.TransportS private boolean deframerClosed = false; private Runnable deframerClosedTask; + /** Whether the client has half-closed the stream. */ + private volatile boolean outboundClosed; + /** * Whether the stream is closed from the transport's perspective. This can differ from {@link * #listenerClosed} because there may still be messages buffered to deliver to the application. @@ -253,6 +255,14 @@ protected final ClientStreamListener listener() { return listener; } + private final void setOutboundClosed() { + outboundClosed = true; + } + + protected final boolean isOutboundClosed() { + return outboundClosed; + } + /** * Called by transport implementations when they receive headers. * diff --git a/netty/src/main/java/io/grpc/netty/CancelClientStreamCommand.java b/netty/src/main/java/io/grpc/netty/CancelClientStreamCommand.java index 9ea2725cd68..f13863e2579 100644 --- a/netty/src/main/java/io/grpc/netty/CancelClientStreamCommand.java +++ b/netty/src/main/java/io/grpc/netty/CancelClientStreamCommand.java @@ -18,18 +18,19 @@ import com.google.common.base.Preconditions; import io.grpc.Status; +import javax.annotation.Nullable; /** * Command sent from a Netty client stream to the handler to cancel the stream. */ class CancelClientStreamCommand extends WriteQueue.AbstractQueuedCommand { private final NettyClientStream.TransportState stream; - private final Status reason; + @Nullable private final Status reason; CancelClientStreamCommand(NettyClientStream.TransportState stream, Status reason) { this.stream = Preconditions.checkNotNull(stream, "stream"); - Preconditions.checkNotNull(reason, "reason"); - Preconditions.checkArgument(!reason.isOk(), "Should not cancel with OK status"); + Preconditions.checkArgument( + reason == null || !reason.isOk(), "Should not cancel with OK status"); this.reason = reason; } @@ -37,6 +38,7 @@ NettyClientStream.TransportState stream() { return stream; } + @Nullable Status reason() { return reason; } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java index 552d174bbc2..fd16746b438 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java @@ -524,7 +524,10 @@ public void operationComplete(ChannelFuture future) throws Exception { private void cancelStream(ChannelHandlerContext ctx, CancelClientStreamCommand cmd, ChannelPromise promise) { NettyClientStream.TransportState stream = cmd.stream(); - stream.transportReportStatus(cmd.reason(), true, new Metadata()); + Status reason = cmd.reason(); + if (reason != null) { + stream.transportReportStatus(reason, true, new Metadata()); + } encoder().writeRstStream(ctx, stream.id(), Http2Error.CANCEL.code(), promise); } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientStream.java b/netty/src/main/java/io/grpc/netty/NettyClientStream.java index 0e3ecfc874f..e569ca9d8f9 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientStream.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientStream.java @@ -293,6 +293,9 @@ public void deframeFailed(Throwable cause) { void transportHeadersReceived(Http2Headers headers, boolean endOfStream) { if (endOfStream) { + if (!isOutboundClosed()) { + handler.getWriteQueue().enqueue(new CancelClientStreamCommand(this, null), true); + } transportTrailersReceived(Utils.convertTrailers(headers)); } else { transportHeadersReceived(Utils.convertHeaders(headers)); diff --git a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java index fe67313231b..4ff3b4f6476 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java @@ -261,6 +261,28 @@ public void inboundTrailersClosesCall() throws Exception { stream().transportState().transportHeadersReceived(grpcResponseTrailers(Status.OK), true); } + @Test + public void inboundTrailersBeforeHalfCloseSendsRstStream() { + stream().transportState().setId(STREAM_ID); + stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false); + stream().transportState().transportHeadersReceived(grpcResponseTrailers(Status.OK), true); + + // Verify a cancel stream with reason=null is sent to the handler. + ArgumentCaptor captor = ArgumentCaptor + .forClass(CancelClientStreamCommand.class); + verify(writeQueue).enqueue(captor.capture(), eq(true)); + assertNull(captor.getValue().reason()); + } + + @Test + public void inboundTrailersAfterHalfCloseDoesNotSendRstStream() { + stream().transportState().setId(STREAM_ID); + stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false); + stream.halfClose(); + stream().transportState().transportHeadersReceived(grpcResponseTrailers(Status.OK), true); + verify(writeQueue, never()).enqueue(isA(CancelClientStreamCommand.class), eq(true)); + } + @Test public void inboundStatusShouldSetStatus() throws Exception { stream().transportState().setId(STREAM_ID); @@ -293,7 +315,7 @@ public void invalidInboundHeadersCancelStream() throws Exception { stream().transportState().transportDataReceived(Unpooled.buffer(1000).writeZero(1000), false); // Now verify that cancel is sent and an error is reported to the listener - verify(writeQueue).enqueue(any(CancelClientStreamCommand.class), eq(true)); + verify(writeQueue).enqueue(isA(CancelClientStreamCommand.class), eq(true)); ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); verify(listener).closed(captor.capture(), same(PROCESSED), metadataCaptor.capture()); diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java index 4cb6153a9c4..7a5ee95dab7 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java @@ -320,7 +320,7 @@ public void transportDataReceived(okio.Buffer frame, boolean endOfStream) { @GuardedBy("lock") private void onEndOfStream() { - if (!framer().isClosed()) { + if (!isOutboundClosed()) { // If server's end-of-stream is received before client sends end-of-stream, we just send a // reset to server to fully close the server side stream. transport.finishStream(id(),null, PROCESSED, false, ErrorCode.CANCEL, null);