Skip to content

Commit 15f02ba

Browse files
author
nmittler
committed
Adding maxMessageSize config option
Fixes #832
1 parent 40c66a1 commit 15f02ba

26 files changed

+295
-105
lines changed

core/src/main/java/io/grpc/internal/AbstractClientStream.java

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,15 @@
3232
package io.grpc.internal;
3333

3434
import static com.google.common.base.Preconditions.checkArgument;
35-
import static io.grpc.Status.Code.CANCELLED;
36-
import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
35+
import static io.grpc.internal.GrpcUtil.CANCEL_REASONS;
3736

3837
import com.google.common.base.Objects;
3938
import com.google.common.base.Preconditions;
4039

4140
import io.grpc.Metadata;
4241
import io.grpc.Status;
43-
import io.grpc.Status.Code;
4442

4543
import java.io.InputStream;
46-
import java.util.EnumSet;
47-
import java.util.Set;
4844
import java.util.logging.Level;
4945
import java.util.logging.Logger;
5046

@@ -55,8 +51,6 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
5551
implements ClientStream {
5652

5753
private static final Logger log = Logger.getLogger(AbstractClientStream.class.getName());
58-
private static final Set<Code> CANCEL_REASONS =
59-
EnumSet.of(CANCELLED, DEADLINE_EXCEEDED, Code.INTERNAL, Code.UNKNOWN);
6054

6155
private final ClientStreamListener listener;
6256
private boolean listenerClosed;
@@ -67,15 +61,10 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
6761
private Metadata trailers;
6862
private Runnable closeListenerTask;
6963

70-
71-
/**
72-
* Constructor used by subclasses.
73-
*
74-
* @param listener the listener to receive notifications
75-
*/
7664
protected AbstractClientStream(WritableBufferAllocator bufferAllocator,
77-
ClientStreamListener listener) {
78-
super(bufferAllocator);
65+
ClientStreamListener listener,
66+
int maxMessageSize) {
67+
super(bufferAllocator, maxMessageSize);
7968
this.listener = Preconditions.checkNotNull(listener);
8069
}
8170

core/src/main/java/io/grpc/internal/AbstractServerStream.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
6363
/** Saved trailers from close() that need to be sent once the framer has sent all messages. */
6464
private Metadata stashedTrailers;
6565

66-
protected AbstractServerStream(WritableBufferAllocator bufferAllocator) {
67-
super(bufferAllocator);
66+
protected AbstractServerStream(WritableBufferAllocator bufferAllocator,
67+
int maxMessageSize) {
68+
super(bufferAllocator, maxMessageSize);
6869
}
6970

7071
/**

core/src/main/java/io/grpc/internal/AbstractStream.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ protected enum Phase {
100100

101101
private final Object onReadyLock = new Object();
102102

103-
AbstractStream(WritableBufferAllocator bufferAllocator) {
103+
AbstractStream(WritableBufferAllocator bufferAllocator, int maxMessageSize) {
104104
MessageDeframer.Listener inboundMessageHandler = new MessageDeframer.Listener() {
105105
@Override
106106
public void bytesRead(int numBytes) {
@@ -130,7 +130,7 @@ public void deliverFrame(WritableBuffer frame, boolean endOfStream, boolean flus
130130
};
131131

132132
framer = new MessageFramer(outboundFrameHandler, bufferAllocator);
133-
deframer = new MessageDeframer(inboundMessageHandler);
133+
deframer = new MessageDeframer(inboundMessageHandler, MessageEncoding.NONE, maxMessageSize);
134134
}
135135

136136
@Override

core/src/main/java/io/grpc/internal/GrpcUtil.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,18 @@
3131

3232
package io.grpc.internal;
3333

34+
import static io.grpc.Status.Code.CANCELLED;
35+
import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
36+
3437
import com.google.common.annotations.VisibleForTesting;
3538
import com.google.common.base.Preconditions;
3639

3740
import io.grpc.Metadata;
3841
import io.grpc.Status;
3942

4043
import java.net.HttpURLConnection;
44+
import java.util.EnumSet;
45+
import java.util.Set;
4146
import java.util.concurrent.Executors;
4247
import java.util.concurrent.ScheduledExecutorService;
4348
import java.util.concurrent.ThreadFactory;
@@ -109,6 +114,17 @@ public final class GrpcUtil {
109114
*/
110115
public static final String MESSAGE_ENCODING = "grpc-encoding";
111116

117+
/**
118+
* The default maximum uncompressed size (in bytes) for inbound messages. Defaults to 100 MiB.
119+
*/
120+
public static final int DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024;
121+
122+
/**
123+
* The set of valid status codes for client cancellation.
124+
*/
125+
public static final Set<Status.Code> CANCEL_REASONS =
126+
EnumSet.of(CANCELLED, DEADLINE_EXCEEDED, Status.Code.INTERNAL, Status.Code.UNKNOWN);
127+
112128
/**
113129
* Maps HTTP error response status codes to transport codes.
114130
*/

core/src/main/java/io/grpc/internal/Http2ClientStream.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,9 @@ public Integer parseAsciiString(String serialized) {
7171
private boolean contentTypeChecked;
7272

7373
protected Http2ClientStream(WritableBufferAllocator bufferAllocator,
74-
ClientStreamListener listener) {
75-
super(bufferAllocator, listener);
74+
ClientStreamListener listener,
75+
int maxMessageSize) {
76+
super(bufferAllocator, listener, maxMessageSize);
7677
}
7778

7879
/**

core/src/main/java/io/grpc/internal/MessageDeframer.java

Lines changed: 83 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import io.grpc.Status;
4040

4141
import java.io.Closeable;
42+
import java.io.FilterInputStream;
4243
import java.io.IOException;
4344
import java.io.InputStream;
4445

@@ -94,6 +95,7 @@ private enum State {
9495
}
9596

9697
private final Listener listener;
98+
private final int maxMessageSize;
9799
private MessageEncoding.Decompressor decompressor;
98100
private State state = State.HEADER;
99101
private int requiredLength = HEADER_LENGTH;
@@ -105,25 +107,19 @@ private enum State {
105107
private boolean deliveryStalled = true;
106108
private boolean inDelivery = false;
107109

108-
/**
109-
* Creates a deframer. Compression will not be supported.
110-
*
111-
* @param listener listener for deframer events.
112-
*/
113-
public MessageDeframer(Listener listener) {
114-
this(listener, MessageEncoding.NONE);
115-
}
116-
117110
/**
118111
* Create a deframer.
119112
*
120113
* @param listener listener for deframer events.
121114
* @param decompressor the compression used if a compressed frame is encountered, with
122115
* {@code NONE} meaning unsupported
116+
* @param maxMessageSize the maximum allowed size for received messages.
123117
*/
124-
public MessageDeframer(Listener listener, MessageEncoding.Decompressor decompressor) {
125-
this.listener = Preconditions.checkNotNull(listener, "listener");
118+
public MessageDeframer(Listener listener, MessageEncoding.Decompressor decompressor,
119+
int maxMessageSize) {
120+
this.listener = Preconditions.checkNotNull(listener, "sink");
126121
this.decompressor = Preconditions.checkNotNull(decompressor, "decompressor");
122+
this.maxMessageSize = maxMessageSize;
127123
}
128124

129125
/**
@@ -162,8 +158,7 @@ public void request(int numMessages) {
162158
* the remote endpoint. End of stream should not be used in the event of a transport
163159
* error, such as a stream reset.
164160
* @throws IllegalStateException if {@link #close()} has been called previously or if
165-
* {@link #deframe(ReadableBuffer, boolean)} has previously been called with
166-
* {@code endOfStream=true}.
161+
* this method has previously been called with {@code endOfStream=true}.
167162
*/
168163
public void deframe(ReadableBuffer data, boolean endOfStream) {
169164
Preconditions.checkNotNull(data, "data");
@@ -291,10 +286,6 @@ private void deliver() {
291286
}
292287
}
293288

294-
private boolean isDataAvailable() {
295-
return unprocessed.readableBytes() > 0 || (nextFrame != null && nextFrame.readableBytes() > 0);
296-
}
297-
298289
/**
299290
* Attempts to read the required bytes into nextFrame.
300291
*
@@ -340,6 +331,10 @@ private void processHeader() {
340331

341332
// Update the required length to include the length of the frame.
342333
requiredLength = nextFrame.readInt();
334+
if (requiredLength < 0 || requiredLength > maxMessageSize) {
335+
throw Status.INTERNAL.withDescription(String.format("Frame size %d exceeds maximum: %d, ",
336+
requiredLength, maxMessageSize)).asRuntimeException();
337+
}
343338

344339
// Continue reading the frame body.
345340
state = State.BODY;
@@ -370,9 +365,79 @@ private InputStream getCompressedBody() {
370365
}
371366

372367
try {
373-
return decompressor.decompress(ReadableBuffers.openStream(nextFrame, true));
368+
// Enforce the maxMessageSize limit on the returned stream.
369+
return new SizeEnforcingInputStream(decompressor.decompress(
370+
ReadableBuffers.openStream(nextFrame, true)));
374371
} catch (IOException e) {
375372
throw new RuntimeException(e);
376373
}
377374
}
375+
376+
/**
377+
* An {@link InputStream} that enforces the {@link #maxMessageSize} limit for compressed frames.
378+
*/
379+
private final class SizeEnforcingInputStream extends FilterInputStream {
380+
private long count;
381+
private long mark = -1;
382+
383+
public SizeEnforcingInputStream(InputStream in) {
384+
super(in);
385+
}
386+
387+
@Override
388+
public int read() throws IOException {
389+
int result = in.read();
390+
if (result != -1) {
391+
count++;
392+
}
393+
verifySize();
394+
return result;
395+
}
396+
397+
@Override
398+
public int read(byte[] b, int off, int len) throws IOException {
399+
int result = in.read(b, off, len);
400+
if (result != -1) {
401+
count += result;
402+
}
403+
verifySize();
404+
return result;
405+
}
406+
407+
@Override
408+
public long skip(long n) throws IOException {
409+
long result = in.skip(n);
410+
count += result;
411+
verifySize();
412+
return result;
413+
}
414+
415+
@Override
416+
public synchronized void mark(int readlimit) {
417+
in.mark(readlimit);
418+
mark = count;
419+
// it's okay to mark even if mark isn't supported, as reset won't work
420+
}
421+
422+
@Override
423+
public synchronized void reset() throws IOException {
424+
if (!in.markSupported()) {
425+
throw new IOException("Mark not supported");
426+
}
427+
if (mark == -1) {
428+
throw new IOException("Mark not set");
429+
}
430+
431+
in.reset();
432+
count = mark;
433+
}
434+
435+
private void verifySize() {
436+
if (count > maxMessageSize) {
437+
throw Status.INTERNAL.withDescription(String.format(
438+
"Compressed frame exceeds maximum frame size: %d. Bytes read: %d",
439+
maxMessageSize, count)).asRuntimeException();
440+
}
441+
}
442+
}
378443
}

core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
package io.grpc.internal;
3333

34+
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
3435
import static org.junit.Assert.assertEquals;
3536
import static org.junit.Assert.fail;
3637
import static org.mockito.Matchers.isA;
@@ -242,7 +243,7 @@ public void inboundHeadersReceived_notifiesListenerOnBadEncoding() {
242243
private static class BaseAbstractClientStream<T> extends AbstractClientStream<T> {
243244
protected BaseAbstractClientStream(
244245
WritableBufferAllocator allocator, ClientStreamListener listener) {
245-
super(allocator, listener);
246+
super(allocator, listener, DEFAULT_MAX_MESSAGE_SIZE);
246247
}
247248

248249
@Override

core/src/test/java/io/grpc/internal/AbstractStreamTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
package io.grpc.internal;
3434

35+
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
3536
import static org.junit.Assert.fail;
3637
import static org.mockito.Mockito.verify;
3738

@@ -102,7 +103,7 @@ public void validPhaseTransitions() {
102103
*/
103104
private class AbstractStreamBase<IdT> extends AbstractStream<IdT> {
104105
private AbstractStreamBase(WritableBufferAllocator bufferAllocator) {
105-
super(bufferAllocator);
106+
super(bufferAllocator, DEFAULT_MAX_MESSAGE_SIZE);
106107
}
107108

108109
@Override

core/src/test/java/io/grpc/internal/MessageDeframerTest.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
package io.grpc.internal;
3333

34+
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
3435
import static org.junit.Assert.assertEquals;
3536
import static org.junit.Assert.assertTrue;
3637
import static org.mockito.Matchers.anyInt;
@@ -67,7 +68,8 @@
6768
@RunWith(JUnit4.class)
6869
public class MessageDeframerTest {
6970
private Listener listener = mock(Listener.class);
70-
private MessageDeframer deframer = new MessageDeframer(listener);
71+
private MessageDeframer deframer = new MessageDeframer(listener, MessageEncoding.NONE,
72+
DEFAULT_MAX_MESSAGE_SIZE);
7173
private ArgumentCaptor<InputStream> messages = ArgumentCaptor.forClass(InputStream.class);
7274

7375
@Test
@@ -176,7 +178,7 @@ public void endOfStreamCallbackShouldWaitForMessageDelivery() {
176178

177179
@Test
178180
public void compressed() {
179-
deframer = new MessageDeframer(listener, new MessageEncoding.Gzip());
181+
deframer = new MessageDeframer(listener, new MessageEncoding.Gzip(), DEFAULT_MAX_MESSAGE_SIZE);
180182
deframer.request(1);
181183

182184
byte[] payload = compress(new byte[1000]);

netty/src/main/java/io/grpc/netty/CancelClientStreamCommand.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,12 @@
3131

3232
package io.grpc.netty;
3333

34-
import static io.grpc.Status.Code.CANCELLED;
35-
import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
34+
import static io.grpc.internal.GrpcUtil.CANCEL_REASONS;
3635

3736
import com.google.common.base.Preconditions;
3837

3938
import io.grpc.Status;
4039

41-
import java.util.EnumSet;
42-
4340
/**
4441
* Command sent from a Netty client stream to the handler to cancel the stream.
4542
*/
@@ -50,8 +47,8 @@ class CancelClientStreamCommand {
5047
CancelClientStreamCommand(NettyClientStream stream, Status reason) {
5148
this.stream = Preconditions.checkNotNull(stream, "stream");
5249
Preconditions.checkNotNull(reason);
53-
Preconditions.checkArgument(EnumSet.of(CANCELLED, DEADLINE_EXCEEDED).contains(reason.getCode()),
54-
"Invalid cancellation reason");
50+
Preconditions.checkArgument(CANCEL_REASONS.contains(reason.getCode()),
51+
"Invalid cancellation reason");
5552
this.reason = reason;
5653
}
5754

0 commit comments

Comments
 (0)