Skip to content

Commit

Permalink
netty: Handle write queue promise failures (#11016)
Browse files Browse the repository at this point in the history
Handles Netty write frame failures caused by issues in the Netty
itself.

Normally we don't need to do anything on frame write failures because
the cause of a failed future would be an IO error that resulted in
the stream closure.  Prior to this PR we treated these issues as a
noop, except the initial headers write on the client side.

However, a case like netty/netty#13805 (a bug in generating next
stream id) resulted in an unclosed stream on our side. This PR adds
write frame future failure handlers that ensures the stream is
cancelled, and the cause is propagated via Status.

Fixes #10849
  • Loading branch information
sergiitk committed Apr 16, 2024
1 parent 497e155 commit e490273
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ public void inboundDataReceived(ReadableBuffer frame, boolean endOfStream) {
*/
public final void transportReportStatus(final Status status) {
Preconditions.checkArgument(!status.isOk(), "status must not be OK");
onStreamDeallocated();
if (deframerClosed) {
deframerClosedTask = null;
closeListener(status);
Expand All @@ -315,6 +316,7 @@ public void run() {
* #transportReportStatus}.
*/
public void complete() {
onStreamDeallocated();
if (deframerClosed) {
deframerClosedTask = null;
closeListener(Status.OK);
Expand Down Expand Up @@ -350,7 +352,6 @@ private void closeListener(Status newStatus) {
getTransportTracer().reportStreamClosed(closedStatus.isOk());
}
listenerClosed = true;
onStreamDeallocated();
listener().closed(newStatus);
}
}
Expand Down
6 changes: 6 additions & 0 deletions core/src/main/java/io/grpc/internal/AbstractStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,12 @@ protected final void onStreamDeallocated() {
}
}

protected boolean isStreamDeallocated() {
synchronized (onReadyLock) {
return deallocated;
}
}

/**
* Event handler to be called by the subclass when a number of bytes are being queued for
* sending to the remote endpoint.
Expand Down
39 changes: 26 additions & 13 deletions netty/src/main/java/io/grpc/netty/NettyClientStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -182,20 +182,10 @@ private void writeFrameInternal(
if (numBytes > 0) {
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
ChannelFutureListener failureListener =
future -> transportState().onWriteFrameData(future, numMessages, numBytes);
writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, endOfStream), flush)
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
// If the future succeeds when http2stream is null, the stream has been cancelled
// before it began and Netty is purging pending writes from the flow-controller.
if (future.isSuccess() && transportState().http2Stream() != null) {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
transportState().onSentBytes(numBytes);
NettyClientStream.this.getTransportTracer().reportMessageSent(numMessages);
}
}
});
.addListener(failureListener);
} else {
// The frame is empty and will not impact outbound flow control. Just send it.
writeQueue.enqueue(
Expand Down Expand Up @@ -307,6 +297,29 @@ protected void http2ProcessingFailed(Status status, boolean stopDelivery, Metada
handler.getWriteQueue().enqueue(new CancelClientStreamCommand(this, status), true);
}

private void onWriteFrameData(ChannelFuture future, int numMessages, int numBytes) {
// If the future succeeds when http2stream is null, the stream has been cancelled
// before it began and Netty is purging pending writes from the flow-controller.
if (future.isSuccess() && http2Stream() == null) {
return;
}

if (future.isSuccess()) {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
onSentBytes(numBytes);
getTransportTracer().reportMessageSent(numMessages);
} else if (!isStreamDeallocated()) {
// Future failed, fail RPC.
// Normally we don't need to do anything here because the cause of a failed future
// while writing DATA frames would be an IO error and the stream is already closed.
// However, we still need handle any unexpected failures raised in Netty.
// Note: isStreamDeallocated() protects from spamming stream resets by scheduling multiple
// CancelClientStreamCommand commands.
http2ProcessingFailed(statusFromFailedFuture(future), true, new Metadata());
}
}

@Override
public void runOnTransportThread(final Runnable r) {
if (eventLoop.inEventLoop()) {
Expand Down
3 changes: 1 addition & 2 deletions netty/src/main/java/io/grpc/netty/NettyServerHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 +502,7 @@ private void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers
state,
attributes,
authority,
statsTraceCtx,
transportTracer);
statsTraceCtx);
transportListener.streamCreated(stream, method, metadata);
state.onStreamAllocated();
http2Stream.setProperty(streamKey, state);
Expand Down
85 changes: 52 additions & 33 deletions netty/src/main/java/io/grpc/netty/NettyServerStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,19 @@ class NettyServerStream extends AbstractServerStream {
private final WriteQueue writeQueue;
private final Attributes attributes;
private final String authority;
private final TransportTracer transportTracer;
private final int streamId;

public NettyServerStream(
Channel channel,
TransportState state,
Attributes transportAttrs,
String authority,
StatsTraceContext statsTraceCtx,
TransportTracer transportTracer) {
StatsTraceContext statsTraceCtx) {
super(new NettyWritableBufferAllocator(channel.alloc()), statsTraceCtx);
this.state = checkNotNull(state, "transportState");
this.writeQueue = state.handler.getWriteQueue();
this.attributes = checkNotNull(transportAttrs);
this.authority = authority;
this.transportTracer = checkNotNull(transportTracer, "transportTracer");
// Read the id early to avoid reading transportState later.
this.streamId = transportState().id();
}
Expand Down Expand Up @@ -96,48 +93,37 @@ private class Sink implements AbstractServerStream.Sink {
@Override
public void writeHeaders(Metadata headers, boolean flush) {
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeHeaders")) {
writeQueue.enqueue(
SendResponseHeadersCommand.createHeaders(
transportState(),
Utils.convertServerHeaders(headers)),
flush);
Http2Headers http2headers = Utils.convertServerHeaders(headers);
SendResponseHeadersCommand headersCommand =
SendResponseHeadersCommand.createHeaders(transportState(), http2headers);
writeQueue.enqueue(headersCommand, flush)
.addListener((ChannelFutureListener) transportState()::handleWriteFutureFailures);
}
}

private void writeFrameInternal(WritableBuffer frame, boolean flush, final int numMessages) {
Preconditions.checkArgument(numMessages >= 0);
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf().touch();
final int numBytes = bytebuf.readableBytes();
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, false), flush)
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
transportState().onSentBytes(numBytes);
if (future.isSuccess()) {
transportTracer.reportMessageSent(numMessages);
}
}
});
}

@Override
public void writeFrame(WritableBuffer frame, boolean flush, final int numMessages) {
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeFrame")) {
writeFrameInternal(frame, flush, numMessages);
Preconditions.checkArgument(numMessages >= 0);
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf().touch();
final int numBytes = bytebuf.readableBytes();
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
ChannelFutureListener failureListener =
future -> transportState().onWriteFrameData(future, numMessages, numBytes);
writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, false), flush)
.addListener(failureListener);
}
}

@Override
public void writeTrailers(Metadata trailers, boolean headersSent, Status status) {
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeTrailers")) {
Http2Headers http2Trailers = Utils.convertTrailers(trailers, headersSent);
writeQueue.enqueue(
SendResponseHeadersCommand.createTrailers(transportState(), http2Trailers, status),
true);
SendResponseHeadersCommand trailersCommand =
SendResponseHeadersCommand.createTrailers(transportState(), http2Trailers, status);
writeQueue.enqueue(trailersCommand, true)
.addListener((ChannelFutureListener) transportState()::handleWriteFutureFailures);
}
}

Expand Down Expand Up @@ -206,6 +192,39 @@ public void deframeFailed(Throwable cause) {
handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true);
}

private void onWriteFrameData(ChannelFuture future, int numMessages, int numBytes) {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
if (future.isSuccess()) {
onSentBytes(numBytes);
getTransportTracer().reportMessageSent(numMessages);
} else {
handleWriteFutureFailures(future);
}
}

private void handleWriteFutureFailures(ChannelFuture future) {
// isStreamDeallocated() check protects from spamming stream resets by scheduling multiple
// CancelServerStreamCommand commands.
if (future.isSuccess() || isStreamDeallocated()) {
return;
}

// Future failed, fail RPC.
// Normally we don't need to do anything on frame write failures because the cause of
// the failed future would be an IO error that closed the stream.
// However, we still need handle any unexpected failures raised in Netty.
http2ProcessingFailed(Utils.statusFromThrowable(future.cause()));
}

/**
* Called to process a failure in HTTP/2 processing.
*/
protected void http2ProcessingFailed(Status status) {
transportReportStatus(status);
handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true);
}

void inboundDataReceived(ByteBuf frame, boolean endOfStream) {
super.inboundDataReceived(new NettyReadableBuffer(frame.retain()), endOfStream);
}
Expand Down
49 changes: 49 additions & 0 deletions netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
import static io.grpc.netty.Utils.STATUS_OK;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.connectionError;
import static io.netty.util.CharsetUtil.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
Expand All @@ -34,6 +36,7 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
Expand Down Expand Up @@ -62,6 +65,7 @@
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.Http2Exception;
import io.netty.handler.codec.http2.Http2Headers;
import io.netty.util.AsciiString;
import java.io.BufferedInputStream;
Expand All @@ -75,6 +79,7 @@
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
Expand Down Expand Up @@ -205,6 +210,50 @@ public void writeMessageShouldSendRequestUnknownLength() throws Exception {
eq(true));
}

@Test
public void writeFrameFutureFailedShouldCancelRpc() {
Http2Exception h2Error = connectionError(PROTOCOL_ERROR, "Stream does not exist %d", STREAM_ID);
// Fail all SendGrpcFrameCommands command sent to the queue.
when(writeQueue.enqueue(any(SendGrpcFrameCommand.class), anyBoolean())).thenReturn(
new DefaultChannelPromise(channel).setFailure(h2Error));

// Write multiple messages to ensure multiple SendGrpcFrameCommand are enqueued. We set up all
// of them to fail, which allows us to assert that only a single cancel is sent, and the stream
// isn't spammed with multiple RST_STREAM.
stream().transportState().setId(STREAM_ID);
stream.writeMessage(new ByteArrayInputStream(smallMessage()));
stream.writeMessage(new ByteArrayInputStream(largeMessage()));
stream.flush();

InOrder inOrder = Mockito.inOrder(writeQueue);
// Normal stream create and write frame.
inOrder.verify(writeQueue).enqueue(any(CreateStreamCommand.class), eq(false));
inOrder.verify(writeQueue).enqueue(any(SendGrpcFrameCommand.class), eq(false));
// Verify that failed SendGrpcFrameCommand results in immediate CancelClientStreamCommand.
inOrder.verify(writeQueue).enqueue(any(CancelClientStreamCommand.class), eq(true));
// Verify that any other failures do not produce another CancelClientStreamCommand in the queue.
inOrder.verify(writeQueue, atLeast(1)).enqueue(any(SendGrpcFrameCommand.class), eq(false));
inOrder.verify(writeQueue).enqueue(any(SendGrpcFrameCommand.class), eq(true));
inOrder.verifyNoMoreInteractions();

// Get the CancelClientStreamCommand written to the queue. Above we verified that there is
// only one CancelClientStreamCommand enqueued, and is the third enqueued command (create,
// frame write failure, cancel).
CancelClientStreamCommand cancelCommand = Mockito.mockingDetails(writeQueue).getInvocations()
// Get enqueue() innovations only
.stream().filter(invocation -> invocation.getMethod().getName().equals("enqueue"))
// Get the third invocation of enqueue()
.skip(2).findFirst().get()
// Get the first argument (QueuedCommand command)
.getArgument(0);

Status cancelReason = cancelCommand.reason();
assertThat(cancelReason.getCode()).isEqualTo(Status.INTERNAL.getCode());
assertThat(cancelReason.getCause()).isEqualTo(h2Error);
// Verify listener closed.
verify(listener).closed(same(cancelReason), eq(PROCESSED), any(Metadata.class));
}

@Test
public void setStatusWithOkShouldCloseStream() {
stream().transportState().setId(STREAM_ID);
Expand Down
Loading

0 comments on commit e490273

Please sign in to comment.