Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

okhttp: okhttp client and server transport should use padded length for flow control #10422

Merged
merged 7 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 3 additions & 2 deletions okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,12 @@ public void transportHeadersReceived(List<Header> headers, boolean endOfStream)
* Must be called with holding the transport lock.
*/
@GuardedBy("lock")
public void transportDataReceived(okio.Buffer frame, boolean endOfStream) {
public void transportDataReceived(okio.Buffer frame, boolean endOfStream, int paddingLen) {
// We only support 16 KiB frames, and the max permitted in HTTP/2 is 16 MiB. This is verified
// in OkHttp's Http2 deframer. In addition, this code is after the data has been read.
int length = (int) frame.size();
window -= length;
window -= (length + paddingLen);
processedWindow -= paddingLen;
if (window < 0) {
frameWriter.rstStream(id(), ErrorCode.FLOW_CONTROL_ERROR);
transport.finishStream(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1140,7 +1140,8 @@ public void run() {
*/
@SuppressWarnings("GuardedBy")
@Override
public void data(boolean inFinished, int streamId, BufferedSource in, int length)
public void data(boolean inFinished, int streamId, BufferedSource in, int length,
int paddedLength)
throws IOException {
logger.logData(OkHttpFrameLogger.Direction.INBOUND,
streamId, in.getBuffer(), length, inFinished);
Expand All @@ -1166,12 +1167,12 @@ public void data(boolean inFinished, int streamId, BufferedSource in, int length
synchronized (lock) {
// TODO(b/145386688): This access should be guarded by 'stream.transportState().lock';
// instead found: 'OkHttpClientTransport.this.lock'
stream.transportState().transportDataReceived(buf, inFinished);
stream.transportState().transportDataReceived(buf, inFinished, paddedLength - length);
}
}

// connection window update
connectionUnacknowledgedBytesRead += length;
connectionUnacknowledgedBytesRead += paddedLength;
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
if (connectionUnacknowledgedBytesRead >= initialWindowSize * DEFAULT_WINDOW_UPDATE_RATIO) {
synchronized (lock) {
frameWriter.windowUpdate(0, connectionUnacknowledgedBytesRead);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,8 @@ private int headerBlockSize(List<Header> headerBlock) {
* Handle an HTTP2 DATA frame.
*/
@Override
public void data(boolean inFinished, int streamId, BufferedSource in, int length)
public void data(boolean inFinished, int streamId, BufferedSource in, int length,
int paddedLength)
throws IOException {
frameLogger.logData(
OkHttpFrameLogger.Direction.INBOUND, streamId, in.getBuffer(), length, inFinished);
Expand Down Expand Up @@ -865,7 +866,7 @@ public void data(boolean inFinished, int streamId, BufferedSource in, int length
}

// connection window update
connectionUnacknowledgedBytesRead += length;
connectionUnacknowledgedBytesRead += paddedLength;
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
if (connectionUnacknowledgedBytesRead
>= config.flowControlWindow * Utils.DEFAULT_WINDOW_UPDATE_RATIO) {
synchronized (lock) {
Expand Down
76 changes: 50 additions & 26 deletions okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -291,14 +291,16 @@ public void close() throws SecurityException {

final String message = "Hello Client";
Buffer buffer = createMessageFrame(message);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
assertThat(logs).hasSize(1);
log = logs.remove(0);
assertThat(log.getMessage()).startsWith(Direction.INBOUND + " DATA: streamId=" + 3);
assertThat(log.getLevel()).isEqualTo(Level.FINE);

// At most 64 bytes of data frame will be logged.
frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])), 1000);
frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])),
1000, 1000);
assertThat(logs).hasSize(1);
log = logs.remove(0);
String data = log.getMessage();
Expand Down Expand Up @@ -377,7 +379,8 @@ public void maxMessageSizeShouldBeEnforced() throws Exception {
// Receive the message.
final String message = "Hello Client";
Buffer buffer = createMessageFrame(message);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());

listener.waitUntilStreamClosed();
assertEquals(Code.RESOURCE_EXHAUSTED, listener.status.getCode());
Expand Down Expand Up @@ -500,7 +503,8 @@ public void readMessages() throws Exception {
assertNotNull(listener.headers);
for (int i = 0; i < numMessages; i++) {
Buffer buffer = createMessageFrame(message + i);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
}
frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS);
listener.waitUntilStreamClosed();
Expand Down Expand Up @@ -529,7 +533,8 @@ public void receivedHeadersForInvalidStreamShouldKillConnection() throws Excepti
@Test
public void receivedDataForInvalidStreamShouldKillConnection() throws Exception {
initTransport();
frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])), 1000);
frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])),
1000, 1000);
verify(frameWriter, timeout(TIME_OUT_MS))
.goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class));
verify(transportListener).transportShutdown(isA(Status.class));
Expand All @@ -551,7 +556,8 @@ public void invalidInboundHeadersCancelStream() throws Exception {
HeadersMode.HTTP_20_HEADERS);
// Now wait to receive 1000 bytes of data so we can have a better error message before
// cancelling the streaam.
frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])), 1000);
frameHandler().data(false, 3,
createMessageFrame(new String(new char[1000])), 1000, 1000);
verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL));
assertNull(listener.headers);
assertEquals(Status.INTERNAL.getCode(), listener.status.getCode());
Expand Down Expand Up @@ -622,7 +628,8 @@ public void receiveResetNoError() throws Exception {
assertContainStream(3);
frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS);
Buffer buffer = createMessageFrame("a message");
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS);
frameHandler().rstStream(3, ErrorCode.NO_ERROR);
stream.request(1);
Expand Down Expand Up @@ -762,33 +769,37 @@ public void windowUpdate() throws Exception {

int messageLength = INITIAL_WINDOW_SIZE / 4;
byte[] fakeMessage = new byte[messageLength];
int paddingLength = 2;

// Stream 1 receives a message
Buffer buffer = createMessageFrame(fakeMessage);
Buffer buffer = createMessageFrame(fakeMessage, paddingLength);
int messageFrameLength = (int) buffer.size();
frameHandler().data(false, 3, buffer, messageFrameLength);
frameHandler().data(false, 3, buffer, messageFrameLength - paddingLength,
messageFrameLength);

// Stream 2 receives a message
buffer = createMessageFrame(fakeMessage);
frameHandler().data(false, 5, buffer, messageFrameLength);
buffer = createMessageFrame(fakeMessage, paddingLength);
frameHandler().data(false, 5, buffer, messageFrameLength - paddingLength,
messageFrameLength);

verify(frameWriter, timeout(TIME_OUT_MS))
.windowUpdate(eq(0), eq((long) 2 * messageFrameLength));
reset(frameWriter);

// Stream 1 receives another message
buffer = createMessageFrame(fakeMessage);
frameHandler().data(false, 3, buffer, messageFrameLength);
messageFrameLength = (int) buffer.size();
frameHandler().data(false, 3, buffer, messageFrameLength, messageFrameLength);

verify(frameWriter, timeout(TIME_OUT_MS))
.windowUpdate(eq(3), eq((long) 2 * messageFrameLength));
.windowUpdate(eq(3), eq((long) 2 * messageFrameLength + paddingLength));

// Stream 2 receives another message
buffer = createMessageFrame(fakeMessage);
frameHandler().data(false, 5, buffer, messageFrameLength);
frameHandler().data(false, 5, buffer, messageFrameLength, messageFrameLength);

verify(frameWriter, timeout(TIME_OUT_MS))
.windowUpdate(eq(5), eq((long) 2 * messageFrameLength));
.windowUpdate(eq(5), eq((long) 2 * messageFrameLength + paddingLength));
verify(frameWriter, timeout(TIME_OUT_MS))
.windowUpdate(eq(0), eq((long) 2 * messageFrameLength));

Expand Down Expand Up @@ -819,7 +830,8 @@ public void windowUpdateWithInboundFlowControl() throws Exception {
frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS);
Buffer buffer = createMessageFrame(fakeMessage);
long messageFrameLength = buffer.size();
frameHandler().data(false, 3, buffer, (int) messageFrameLength);
frameHandler().data(false, 3, buffer, (int) messageFrameLength,
(int) messageFrameLength);
ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
verify(frameWriter, timeout(TIME_OUT_MS)).windowUpdate(
idCaptor.capture(), eq(messageFrameLength));
Expand Down Expand Up @@ -1123,7 +1135,8 @@ public void receiveGoAway() throws Exception {
frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS);
final String receivedMessage = "No, you are fine.";
Buffer buffer = createMessageFrame(receivedMessage);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS);
listener1.waitUntilStreamClosed();
assertEquals(1, listener1.messages.size());
Expand Down Expand Up @@ -1154,12 +1167,12 @@ public void streamIdExhausted() throws Exception {
assertNotNull(listener.headers);
String message = "hello";
Buffer buffer = createMessageFrame(message);
frameHandler().data(false, startId, buffer, (int) buffer.size());
frameHandler().data(false, startId, buffer, (int) buffer.size(), (int) buffer.size());

getStream(startId).cancel(Status.CANCELLED);
// Receives the second message after be cancelled.
buffer = createMessageFrame(message);
frameHandler().data(false, startId, buffer, (int) buffer.size());
frameHandler().data(false, startId, buffer, (int) buffer.size(), (int) buffer.size());

listener.waitUntilStreamClosed();
// Should only have the first message delivered.
Expand Down Expand Up @@ -1329,7 +1342,7 @@ public void receivingWindowExceeded() throws Exception {
byte[] fakeMessage = new byte[messageLength];
Buffer buffer = createMessageFrame(fakeMessage);
int messageFrameLength = (int) buffer.size();
frameHandler().data(false, 3, buffer, messageFrameLength);
frameHandler().data(false, 3, buffer, messageFrameLength, messageFrameLength);

listener.waitUntilStreamClosed();
assertEquals(Status.INTERNAL.getCode(), listener.status.getCode());
Expand Down Expand Up @@ -1392,7 +1405,8 @@ public void receiveDataWithoutHeader() throws Exception {
stream.start(listener);
stream.request(1);
Buffer buffer = createMessageFrame(new byte[1]);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());

// Trigger the failure by a trailer.
frameHandler().headers(
Expand All @@ -1414,11 +1428,13 @@ public void receiveDataWithoutHeaderAndTrailer() throws Exception {
stream.start(listener);
stream.request(1);
Buffer buffer = createMessageFrame(new byte[1]);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());

// Trigger the failure by a data frame.
buffer = createMessageFrame(new byte[1]);
frameHandler().data(true, 3, buffer, (int) buffer.size());
frameHandler().data(true, 3, buffer, (int) buffer.size(),
(int) buffer.size());

listener.waitUntilStreamClosed();
assertEquals(Status.INTERNAL.getCode(), listener.status.getCode());
Expand All @@ -1436,7 +1452,8 @@ public void receiveLongEnoughDataWithoutHeaderAndTrailer() throws Exception {
stream.start(listener);
stream.request(1);
Buffer buffer = createMessageFrame(new byte[1000]);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());

// Once we receive enough detail, we cancel the stream. so we should have sent cancel.
verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL));
Expand All @@ -1459,15 +1476,17 @@ public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception

Buffer buffer = createMessageFrame(
new byte[INITIAL_WINDOW_SIZE / 2 + 1]);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
// Should still update the connection window even stream 3 is gone.
verify(frameWriter, timeout(TIME_OUT_MS)).windowUpdate(0,
HEADER_LENGTH + INITIAL_WINDOW_SIZE / 2 + 1);
buffer = createMessageFrame(
new byte[INITIAL_WINDOW_SIZE / 2 + 1]);

// This should kill the connection, since we never created stream 5.
frameHandler().data(false, 5, buffer, (int) buffer.size());
frameHandler().data(false, 5, buffer, (int) buffer.size(),
(int) buffer.size());
verify(frameWriter, timeout(TIME_OUT_MS))
.goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class));
verify(transportListener).transportShutdown(isA(Status.class));
Expand Down Expand Up @@ -2114,10 +2133,15 @@ private static Buffer createMessageFrame(String message) {
}

private static Buffer createMessageFrame(byte[] message) {
return createMessageFrame(message,0);
}

private static Buffer createMessageFrame(byte[] message, int paddingLength) {
Buffer buffer = new Buffer();
buffer.writeByte(0 /* UNCOMPRESSED */);
buffer.writeInt(message.length);
buffer.write(message);
buffer.write(new byte[paddingLength]);
return buffer;
}

Expand Down
18 changes: 12 additions & 6 deletions okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ public void setUp() throws Exception {
Buffer buf = new Buffer();
buf.write(in.getBuffer(), length);
clientDataFrames.data(outDone, streamId, buf);
})).when(clientFramesRead).data(anyBoolean(), anyInt(), any(BufferedSource.class), anyInt());
})).when(clientFramesRead).data(anyBoolean(), anyInt(), any(BufferedSource.class), anyInt(),
anyInt());
}

@After
Expand Down Expand Up @@ -379,7 +380,8 @@ public void basicRpc_succeeds() throws Exception {
Buffer responseMessageFrame = createMessageFrame("Howdy client");
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead)
.data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size()));
.data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size()),
eq((int) responseMessageFrame.size()));
verify(clientDataFrames).data(false, 1, responseMessageFrame);

List<Header> responseTrailers = Arrays.asList(
Expand Down Expand Up @@ -440,7 +442,8 @@ public void activeRpc_delaysShutdownTermination() throws Exception {
Buffer responseMessageFrame = createMessageFrame("Howdy client");
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead)
.data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size()));
.data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size()),
eq((int) responseMessageFrame.size()));
verify(clientDataFrames).data(false, 1, responseMessageFrame);
pingPong();
assertThat(serverTransport.getActiveStreams().length).isEqualTo(1);
Expand Down Expand Up @@ -975,7 +978,8 @@ public void httpErrorsAdhereToFlowControl() throws Exception {
Buffer responseDataFrame = new Buffer().writeUtf8(errorDescription.substring(0, 1));
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).data(
eq(false), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size()));
eq(false), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size()),
eq((int) responseDataFrame.size()));
verify(clientDataFrames).data(false, 1, responseDataFrame);

clientFrameWriter.windowUpdate(1, 1000);
Expand All @@ -984,7 +988,8 @@ public void httpErrorsAdhereToFlowControl() throws Exception {
responseDataFrame = new Buffer().writeUtf8(errorDescription.substring(1));
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).data(
eq(true), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size()));
eq(true), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size()),
eq((int) responseDataFrame.size()));
verify(clientDataFrames).data(true, 1, responseDataFrame);

assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
Expand Down Expand Up @@ -1279,7 +1284,8 @@ private void verifyHttpError(
Buffer responseDataFrame = new Buffer().writeUtf8(errorDescription);
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).data(
eq(true), eq(streamId), any(BufferedSource.class), eq((int) responseDataFrame.size()));
eq(true), eq(streamId), any(BufferedSource.class),
eq((int) responseDataFrame.size()), eq((int) responseDataFrame.size()));
verify(clientDataFrames).data(true, streamId, responseDataFrame);

assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public interface FrameReader extends Closeable {
boolean nextFrame(Handler handler) throws IOException;

interface Handler {
void data(boolean inFinished, int streamId, BufferedSource source, int length)
void data(boolean inFinished, int streamId, BufferedSource source, int length, int paddedLength)
throws IOException;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ private List<Header> readHeaderBlock(int length, short padding, byte flags, int
return hpackReader.getAndResetHeaderList();
}

private void readData(Handler handler, int length, byte flags, int streamId)
private void readData(Handler handler, int paddedLength, byte flags, int streamId)
throws IOException {
// TODO: checkState open or half-closed (local) or raise STREAM_CLOSED
boolean inFinished = (flags & FLAG_END_STREAM) != 0;
Expand All @@ -230,10 +230,10 @@ private void readData(Handler handler, int length, byte flags, int streamId)
}

short padding = (flags & FLAG_PADDED) != 0 ? (short) (source.readByte() & 0xff) : 0;
length = lengthWithoutPadding(length, flags, padding);
int length = lengthWithoutPadding(paddedLength, flags, padding);

// FIXME: pass padding length to handler because it should be included for flow control
handler.data(inFinished, streamId, source, length);
handler.data(inFinished, streamId, source, length, paddedLength);
source.skip(padding);
}

Expand Down