Skip to content

Commit

Permalink
fix okhttp padding length
Browse files Browse the repository at this point in the history
  • Loading branch information
YifeiZhuang committed Aug 8, 2023
1 parent 40bff67 commit 3f54bfd
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1140,7 +1140,8 @@ public void run() {
*/
@SuppressWarnings("GuardedBy")
@Override
public void data(boolean inFinished, int streamId, BufferedSource in, int length)
public void data(boolean inFinished, int streamId, BufferedSource in, int length,
int paddedLength)
throws IOException {
logger.logData(OkHttpFrameLogger.Direction.INBOUND,
streamId, in.getBuffer(), length, inFinished);
Expand Down Expand Up @@ -1171,7 +1172,7 @@ public void data(boolean inFinished, int streamId, BufferedSource in, int length
}

// connection window update
connectionUnacknowledgedBytesRead += length;
connectionUnacknowledgedBytesRead += paddedLength;
if (connectionUnacknowledgedBytesRead >= initialWindowSize * DEFAULT_WINDOW_UPDATE_RATIO) {
synchronized (lock) {
frameWriter.windowUpdate(0, connectionUnacknowledgedBytesRead);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,8 @@ private int headerBlockSize(List<Header> headerBlock) {
* Handle an HTTP2 DATA frame.
*/
@Override
public void data(boolean inFinished, int streamId, BufferedSource in, int length)
public void data(boolean inFinished, int streamId, BufferedSource in, int length,
int paddedLength)
throws IOException {
frameLogger.logData(
OkHttpFrameLogger.Direction.INBOUND, streamId, in.getBuffer(), length, inFinished);
Expand Down Expand Up @@ -865,7 +866,7 @@ public void data(boolean inFinished, int streamId, BufferedSource in, int length
}

// connection window update
connectionUnacknowledgedBytesRead += length;
connectionUnacknowledgedBytesRead += paddedLength;
if (connectionUnacknowledgedBytesRead
>= config.flowControlWindow * Utils.DEFAULT_WINDOW_UPDATE_RATIO) {
synchronized (lock) {
Expand Down
72 changes: 48 additions & 24 deletions okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -291,14 +291,16 @@ public void close() throws SecurityException {

final String message = "Hello Client";
Buffer buffer = createMessageFrame(message);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
assertThat(logs).hasSize(1);
log = logs.remove(0);
assertThat(log.getMessage()).startsWith(Direction.INBOUND + " DATA: streamId=" + 3);
assertThat(log.getLevel()).isEqualTo(Level.FINE);

// At most 64 bytes of data frame will be logged.
frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])), 1000);
frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])),
1000, 1000);
assertThat(logs).hasSize(1);
log = logs.remove(0);
String data = log.getMessage();
Expand Down Expand Up @@ -377,7 +379,8 @@ public void maxMessageSizeShouldBeEnforced() throws Exception {
// Receive the message.
final String message = "Hello Client";
Buffer buffer = createMessageFrame(message);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());

listener.waitUntilStreamClosed();
assertEquals(Code.RESOURCE_EXHAUSTED, listener.status.getCode());
Expand Down Expand Up @@ -500,7 +503,8 @@ public void readMessages() throws Exception {
assertNotNull(listener.headers);
for (int i = 0; i < numMessages; i++) {
Buffer buffer = createMessageFrame(message + i);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
}
frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS);
listener.waitUntilStreamClosed();
Expand Down Expand Up @@ -529,7 +533,8 @@ public void receivedHeadersForInvalidStreamShouldKillConnection() throws Excepti
@Test
public void receivedDataForInvalidStreamShouldKillConnection() throws Exception {
initTransport();
frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])), 1000);
frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])),
1000, 1000);
verify(frameWriter, timeout(TIME_OUT_MS))
.goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class));
verify(transportListener).transportShutdown(isA(Status.class));
Expand All @@ -551,7 +556,8 @@ public void invalidInboundHeadersCancelStream() throws Exception {
HeadersMode.HTTP_20_HEADERS);
// Now wait to receive 1000 bytes of data so we can have a better error message before
// cancelling the streaam.
frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])), 1000);
frameHandler().data(false, 3,
createMessageFrame(new String(new char[1000])), 1000, 1000);
verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL));
assertNull(listener.headers);
assertEquals(Status.INTERNAL.getCode(), listener.status.getCode());
Expand Down Expand Up @@ -622,7 +628,8 @@ public void receiveResetNoError() throws Exception {
assertContainStream(3);
frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS);
Buffer buffer = createMessageFrame("a message");
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS);
frameHandler().rstStream(3, ErrorCode.NO_ERROR);
stream.request(1);
Expand Down Expand Up @@ -762,30 +769,34 @@ public void windowUpdate() throws Exception {

int messageLength = INITIAL_WINDOW_SIZE / 4;
byte[] fakeMessage = new byte[messageLength];
int paddingLength = 2;

// Stream 1 receives a message
Buffer buffer = createMessageFrame(fakeMessage);
Buffer buffer = createMessageFrame(fakeMessage, paddingLength);
int messageFrameLength = (int) buffer.size();
frameHandler().data(false, 3, buffer, messageFrameLength);
frameHandler().data(false, 3, buffer, messageFrameLength - paddingLength,
messageFrameLength);

// Stream 2 receives a message
buffer = createMessageFrame(fakeMessage);
frameHandler().data(false, 5, buffer, messageFrameLength);
buffer = createMessageFrame(fakeMessage, paddingLength);
frameHandler().data(false, 5, buffer, messageFrameLength - paddingLength,
messageFrameLength);

verify(frameWriter, timeout(TIME_OUT_MS))
.windowUpdate(eq(0), eq((long) 2 * messageFrameLength));
reset(frameWriter);

// Stream 1 receives another message
buffer = createMessageFrame(fakeMessage);
frameHandler().data(false, 3, buffer, messageFrameLength);
messageFrameLength = (int) buffer.size();
frameHandler().data(false, 3, buffer, messageFrameLength, messageFrameLength);

verify(frameWriter, timeout(TIME_OUT_MS))
.windowUpdate(eq(3), eq((long) 2 * messageFrameLength));

// Stream 2 receives another message
buffer = createMessageFrame(fakeMessage);
frameHandler().data(false, 5, buffer, messageFrameLength);
frameHandler().data(false, 5, buffer, messageFrameLength, messageFrameLength);

verify(frameWriter, timeout(TIME_OUT_MS))
.windowUpdate(eq(5), eq((long) 2 * messageFrameLength));
Expand Down Expand Up @@ -819,7 +830,8 @@ public void windowUpdateWithInboundFlowControl() throws Exception {
frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS);
Buffer buffer = createMessageFrame(fakeMessage);
long messageFrameLength = buffer.size();
frameHandler().data(false, 3, buffer, (int) messageFrameLength);
frameHandler().data(false, 3, buffer, (int) messageFrameLength,
(int) messageFrameLength);
ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
verify(frameWriter, timeout(TIME_OUT_MS)).windowUpdate(
idCaptor.capture(), eq(messageFrameLength));
Expand Down Expand Up @@ -1123,7 +1135,8 @@ public void receiveGoAway() throws Exception {
frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS);
final String receivedMessage = "No, you are fine.";
Buffer buffer = createMessageFrame(receivedMessage);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS);
listener1.waitUntilStreamClosed();
assertEquals(1, listener1.messages.size());
Expand Down Expand Up @@ -1154,12 +1167,12 @@ public void streamIdExhausted() throws Exception {
assertNotNull(listener.headers);
String message = "hello";
Buffer buffer = createMessageFrame(message);
frameHandler().data(false, startId, buffer, (int) buffer.size());
frameHandler().data(false, startId, buffer, (int) buffer.size(), (int) buffer.size());

getStream(startId).cancel(Status.CANCELLED);
// Receives the second message after be cancelled.
buffer = createMessageFrame(message);
frameHandler().data(false, startId, buffer, (int) buffer.size());
frameHandler().data(false, startId, buffer, (int) buffer.size(), (int) buffer.size());

listener.waitUntilStreamClosed();
// Should only have the first message delivered.
Expand Down Expand Up @@ -1329,7 +1342,7 @@ public void receivingWindowExceeded() throws Exception {
byte[] fakeMessage = new byte[messageLength];
Buffer buffer = createMessageFrame(fakeMessage);
int messageFrameLength = (int) buffer.size();
frameHandler().data(false, 3, buffer, messageFrameLength);
frameHandler().data(false, 3, buffer, messageFrameLength, messageFrameLength);

listener.waitUntilStreamClosed();
assertEquals(Status.INTERNAL.getCode(), listener.status.getCode());
Expand Down Expand Up @@ -1392,7 +1405,8 @@ public void receiveDataWithoutHeader() throws Exception {
stream.start(listener);
stream.request(1);
Buffer buffer = createMessageFrame(new byte[1]);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());

// Trigger the failure by a trailer.
frameHandler().headers(
Expand All @@ -1414,11 +1428,13 @@ public void receiveDataWithoutHeaderAndTrailer() throws Exception {
stream.start(listener);
stream.request(1);
Buffer buffer = createMessageFrame(new byte[1]);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());

// Trigger the failure by a data frame.
buffer = createMessageFrame(new byte[1]);
frameHandler().data(true, 3, buffer, (int) buffer.size());
frameHandler().data(true, 3, buffer, (int) buffer.size(),
(int) buffer.size());

listener.waitUntilStreamClosed();
assertEquals(Status.INTERNAL.getCode(), listener.status.getCode());
Expand All @@ -1436,7 +1452,8 @@ public void receiveLongEnoughDataWithoutHeaderAndTrailer() throws Exception {
stream.start(listener);
stream.request(1);
Buffer buffer = createMessageFrame(new byte[1000]);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());

// Once we receive enough detail, we cancel the stream. so we should have sent cancel.
verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL));
Expand All @@ -1459,15 +1476,17 @@ public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception

Buffer buffer = createMessageFrame(
new byte[INITIAL_WINDOW_SIZE / 2 + 1]);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
// Should still update the connection window even stream 3 is gone.
verify(frameWriter, timeout(TIME_OUT_MS)).windowUpdate(0,
HEADER_LENGTH + INITIAL_WINDOW_SIZE / 2 + 1);
buffer = createMessageFrame(
new byte[INITIAL_WINDOW_SIZE / 2 + 1]);

// This should kill the connection, since we never created stream 5.
frameHandler().data(false, 5, buffer, (int) buffer.size());
frameHandler().data(false, 5, buffer, (int) buffer.size(),
(int) buffer.size());
verify(frameWriter, timeout(TIME_OUT_MS))
.goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class));
verify(transportListener).transportShutdown(isA(Status.class));
Expand Down Expand Up @@ -2114,10 +2133,15 @@ private static Buffer createMessageFrame(String message) {
}

private static Buffer createMessageFrame(byte[] message) {
return createMessageFrame(message,0);
}

private static Buffer createMessageFrame(byte[] message, int paddingLength) {
Buffer buffer = new Buffer();
buffer.writeByte(0 /* UNCOMPRESSED */);
buffer.writeInt(message.length);
buffer.write(message);
buffer.write(new byte[paddingLength]);
return buffer;
}

Expand Down
18 changes: 12 additions & 6 deletions okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ public void setUp() throws Exception {
Buffer buf = new Buffer();
buf.write(in.getBuffer(), length);
clientDataFrames.data(outDone, streamId, buf);
})).when(clientFramesRead).data(anyBoolean(), anyInt(), any(BufferedSource.class), anyInt());
})).when(clientFramesRead).data(anyBoolean(), anyInt(), any(BufferedSource.class), anyInt(),
anyInt());
}

@After
Expand Down Expand Up @@ -379,7 +380,8 @@ public void basicRpc_succeeds() throws Exception {
Buffer responseMessageFrame = createMessageFrame("Howdy client");
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead)
.data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size()));
.data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size()),
eq((int) responseMessageFrame.size()));
verify(clientDataFrames).data(false, 1, responseMessageFrame);

List<Header> responseTrailers = Arrays.asList(
Expand Down Expand Up @@ -440,7 +442,8 @@ public void activeRpc_delaysShutdownTermination() throws Exception {
Buffer responseMessageFrame = createMessageFrame("Howdy client");
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead)
.data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size()));
.data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size()),
eq((int) responseMessageFrame.size()));
verify(clientDataFrames).data(false, 1, responseMessageFrame);
pingPong();
assertThat(serverTransport.getActiveStreams().length).isEqualTo(1);
Expand Down Expand Up @@ -975,7 +978,8 @@ public void httpErrorsAdhereToFlowControl() throws Exception {
Buffer responseDataFrame = new Buffer().writeUtf8(errorDescription.substring(0, 1));
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).data(
eq(false), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size()));
eq(false), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size()),
eq((int) responseDataFrame.size()));
verify(clientDataFrames).data(false, 1, responseDataFrame);

clientFrameWriter.windowUpdate(1, 1000);
Expand All @@ -984,7 +988,8 @@ public void httpErrorsAdhereToFlowControl() throws Exception {
responseDataFrame = new Buffer().writeUtf8(errorDescription.substring(1));
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).data(
eq(true), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size()));
eq(true), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size()),
eq((int) responseDataFrame.size()));
verify(clientDataFrames).data(true, 1, responseDataFrame);

assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
Expand Down Expand Up @@ -1279,7 +1284,8 @@ private void verifyHttpError(
Buffer responseDataFrame = new Buffer().writeUtf8(errorDescription);
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).data(
eq(true), eq(streamId), any(BufferedSource.class), eq((int) responseDataFrame.size()));
eq(true), eq(streamId), any(BufferedSource.class),
eq((int) responseDataFrame.size()), eq((int) responseDataFrame.size()));
verify(clientDataFrames).data(true, streamId, responseDataFrame);

assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public interface FrameReader extends Closeable {
boolean nextFrame(Handler handler) throws IOException;

interface Handler {
void data(boolean inFinished, int streamId, BufferedSource source, int length)
void data(boolean inFinished, int streamId, BufferedSource source, int length, int paddedLength)
throws IOException;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ private List<Header> readHeaderBlock(int length, short padding, byte flags, int
return hpackReader.getAndResetHeaderList();
}

private void readData(Handler handler, int length, byte flags, int streamId)
private void readData(Handler handler, int paddedLength, byte flags, int streamId)
throws IOException {
// TODO: checkState open or half-closed (local) or raise STREAM_CLOSED
boolean inFinished = (flags & FLAG_END_STREAM) != 0;
Expand All @@ -230,10 +230,10 @@ private void readData(Handler handler, int length, byte flags, int streamId)
}

short padding = (flags & FLAG_PADDED) != 0 ? (short) (source.readByte() & 0xff) : 0;
length = lengthWithoutPadding(length, flags, padding);
int length = lengthWithoutPadding(paddedLength, flags, padding);

// FIXME: pass padding length to handler because it should be included for flow control
handler.data(inFinished, streamId, source, length);
handler.data(inFinished, streamId, source, length, paddedLength);
source.skip(padding);
}

Expand Down

0 comments on commit 3f54bfd

Please sign in to comment.