Skip to content

Commit

Permalink
fix server
Browse files Browse the repository at this point in the history
  • Loading branch information
YifeiZhuang committed Aug 10, 2023
1 parent a06be73 commit 6933356
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 12 deletions.
6 changes: 4 additions & 2 deletions okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,15 @@ public void runOnTransportThread(final Runnable r) {
* Must be called with holding the transport lock.
*/
@Override
public void inboundDataReceived(okio.Buffer frame, int windowConsumed, boolean endOfStream) {
public void inboundDataReceived(okio.Buffer frame, int dataLength, int paddingLength,
boolean endOfStream) {
synchronized (lock) {
PerfMark.event("OkHttpServerTransport$FrameHandler.data", tag);
if (endOfStream) {
this.receivedEndOfStream = true;
}
window -= windowConsumed;
window -= dataLength + paddingLength;
processedWindow -= paddingLength;
super.inboundDataReceived(new OkHttpReadableBuffer(frame), endOfStream);
}
}
Expand Down
28 changes: 18 additions & 10 deletions okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static io.grpc.okhttp.OkHttpServerBuilder.MAX_CONNECTION_AGE_NANOS_DISABLED;
import static io.grpc.okhttp.OkHttpServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
Expand Down Expand Up @@ -139,6 +140,8 @@ final class OkHttpServerTransport implements ServerTransport,
@GuardedBy("lock")
private Long gracefulShutdownPeriod = null;

private FrameHandler handler;

public OkHttpServerTransport(Config config, Socket bareSocket) {
this.config = Preconditions.checkNotNull(config, "config");
this.socket = Preconditions.checkNotNull(bareSocket, "bareSocket");
Expand Down Expand Up @@ -248,8 +251,8 @@ public void data(boolean outFinished, int streamId, Buffer source, int byteCount
TimeUnit.NANOSECONDS);
}

transportExecutor.execute(
new FrameHandler(variant.newReader(Okio.buffer(Okio.source(socket)), false)));
handler = new FrameHandler(variant.newReader(Okio.buffer(Okio.source(socket)), false));
transportExecutor.execute(handler);
} catch (Error | IOException | RuntimeException ex) {
synchronized (lock) {
if (!handshakeShutdown) {
Expand All @@ -261,6 +264,11 @@ public void data(boolean outFinished, int streamId, Buffer source, int byteCount
}
}

@VisibleForTesting
FrameHandler getHandler() {
return handler;
}

@Override
public void shutdown() {
shutdown(null);
Expand Down Expand Up @@ -708,7 +716,7 @@ public void headers(boolean outFinished,
return;
}
// Ignore the trailers, but still half-close the stream
stream.inboundDataReceived(new Buffer(), 0, true);
stream.inboundDataReceived(new Buffer(), 0, 0, true);
return;
}
} else {
Expand Down Expand Up @@ -799,7 +807,7 @@ public void headers(boolean outFinished,
listener.streamCreated(streamForApp, method, metadata);
stream.onStreamAllocated();
if (inFinished) {
stream.inboundDataReceived(new Buffer(), 0, inFinished);
stream.inboundDataReceived(new Buffer(), 0, 0, inFinished);
}
}
}
Expand Down Expand Up @@ -854,15 +862,15 @@ public void data(boolean inFinished, int streamId, BufferedSource in, int length
"Received DATA for half-closed (remote) stream. RFC7540 section 5.1");
return;
}
if (stream.inboundWindowAvailable() < length) {
if (stream.inboundWindowAvailable() < paddedLength) {
in.skip(length);
streamError(streamId, ErrorCode.FLOW_CONTROL_ERROR,
"Received DATA size exceeded window size. RFC7540 section 6.9");
return;
}
Buffer buf = new Buffer();
buf.write(in.getBuffer(), length);
stream.inboundDataReceived(buf, length, inFinished);
stream.inboundDataReceived(buf, length, paddedLength - length, inFinished);
}

// connection window update
Expand Down Expand Up @@ -1065,7 +1073,7 @@ private void respondWithHttpError(
}
streams.put(streamId, stream);
if (inFinished) {
stream.inboundDataReceived(new Buffer(), 0, true);
stream.inboundDataReceived(new Buffer(), 0, 0,true);

Check warning on line 1076 in okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java

View check run for this annotation

Codecov / codecov/patch

okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java#L1076

Added line #L1076 was not covered by tests
}
frameWriter.headers(streamId, headers);
outboundFlow.data(
Expand Down Expand Up @@ -1123,7 +1131,7 @@ public void onPingTimeout() {

interface StreamState {
/** Must be holding 'lock' when calling. */
void inboundDataReceived(Buffer frame, int windowConsumed, boolean endOfStream);
void inboundDataReceived(Buffer frame, int dataLength, int paddingLength, boolean endOfStream);

/** Must be holding 'lock' when calling. */
boolean hasReceivedEndOfStream();
Expand Down Expand Up @@ -1160,12 +1168,12 @@ static class Http2ErrorStreamState implements StreamState, OutboundFlowControlle
@Override public void onSentBytes(int frameBytes) {}

@Override public void inboundDataReceived(
Buffer frame, int windowConsumed, boolean endOfStream) {
Buffer frame, int dataLength, int paddingLength, boolean endOfStream) {
synchronized (lock) {
if (endOfStream) {
receivedEndOfStream = true;
}
window -= windowConsumed;
window -= dataLength + paddingLength;

Check warning on line 1176 in okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java

View check run for this annotation

Codecov / codecov/patch

okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java#L1176

Added line #L1176 was not covered by tests
try {
frame.skip(frame.size()); // Recycle segments
} catch (IOException ex) {
Expand Down
64 changes: 64 additions & 0 deletions okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;

Expand Down Expand Up @@ -60,6 +61,7 @@
import java.net.ServerSocket;
import java.net.Socket;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;
Expand Down Expand Up @@ -998,6 +1000,63 @@ public void httpErrorsAdhereToFlowControl() throws Exception {
shutdownAndTerminate(/*lastStreamId=*/ 1);
}

@Test
public void windowUpdate() throws Exception {
initTransport();
handshake();
OkHttpServerTransport.FrameHandler handler = serverTransport.getHandler();
List<Header> headers = Arrays.asList(
HTTP_SCHEME_HEADER,
METHOD_HEADER,
new Header(Header.TARGET_AUTHORITY, "example.com:80"),
new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"),
CONTENT_TYPE_HEADER,
TE_HEADER,
new Header("some-metadata", "this could be anything"));

handler.headers(false, false, 1, 0, new ArrayList<>(headers), HeadersMode.HTTP_20_HEADERS);
MockStreamListener streamListener1 = mockTransportListener.newStreams.pop();
handler.headers(false, false, 3, 0, new ArrayList<>(headers), HeadersMode.HTTP_20_HEADERS);
MockStreamListener streamListener2 = mockTransportListener.newStreams.pop();
reset(clientFramesRead);

int messageSize = INITIAL_WINDOW_SIZE / 4 ;
int paddingLength = 10;
Buffer requestMessageFrame = createMessageFrame(new String(new char[messageSize]),
paddingLength);
int frameSize = (int) requestMessageFrame.size();

handler.data(false, 1, requestMessageFrame, frameSize - paddingLength, frameSize);

requestMessageFrame = createMessageFrame(new String(new char[messageSize]), paddingLength);
handler.data(false, 3, requestMessageFrame, frameSize - paddingLength, frameSize);

assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).windowUpdate(0, frameSize * 2);

requestMessageFrame = createMessageFrame(new String(new char[messageSize]), 0);
handler.data(false, 3, requestMessageFrame, frameSize - paddingLength,
frameSize - paddingLength);

streamListener1.stream.request(1);
streamListener2.stream.request(2);

assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).windowUpdate(3, frameSize * 2 - paddingLength);


paddingLength = 2 * messageSize + 100;
requestMessageFrame = createMessageFrame(new String(new char[messageSize]), paddingLength);
frameSize = (int) requestMessageFrame.size();
handler.data(false, 1, requestMessageFrame, frameSize - paddingLength, frameSize);

assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).rstStream(eq(1), eq(ErrorCode.FLOW_CONTROL_ERROR));

handler.rstStream(3, ErrorCode.CANCEL);
shutdownAndTerminate(3);
}

@Test
public void dataForStream0_failsWithGoAway() throws Exception {
initTransport();
Expand Down Expand Up @@ -1220,11 +1279,16 @@ private void handshake(Settings settings) throws Exception {
}

private static Buffer createMessageFrame(String stringMessage) {
return createMessageFrame(stringMessage, 0);
}

private static Buffer createMessageFrame(String stringMessage, int paddingLength) {
byte[] message = stringMessage.getBytes(UTF_8);
Buffer buffer = new Buffer();
buffer.writeByte(0 /* UNCOMPRESSED */);
buffer.writeInt(message.length);
buffer.write(message);
buffer.write(new byte[paddingLength]);
return buffer;
}

Expand Down

0 comments on commit 6933356

Please sign in to comment.