Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 59 additions & 34 deletions netty/src/main/java/io/grpc/netty/NettyServerHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.grpc.Attributes;
import io.grpc.InternalMetadata;
import io.grpc.InternalStatus;
import io.grpc.Metadata;
import io.grpc.ServerStreamTracer;
import io.grpc.Status;
Expand All @@ -54,6 +56,7 @@
import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder;
import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.DefaultHttp2LocalFlowController;
import io.netty.handler.codec.http2.DefaultHttp2RemoteFlowController;
import io.netty.handler.codec.http2.Http2Connection;
Expand Down Expand Up @@ -375,15 +378,49 @@ private void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers
throws Http2Exception {
if (!teWarningLogged && !TE_TRAILERS.equals(headers.get(TE_HEADER))) {
logger.warning(String.format("Expected header TE: %s, but %s is received. This means "
+ "some intermediate proxy may not support trailers",
TE_TRAILERS, headers.get(TE_HEADER)));
+ "some intermediate proxy may not support trailers",
TE_TRAILERS, headers.get(TE_HEADER)));
teWarningLogged = true;
}

try {

// Remove the leading slash of the path and get the fully qualified method name
CharSequence path = headers.path();

if (path == null) {
respondWithHttpError(ctx, streamId, 404, Status.Code.UNIMPLEMENTED,
"Expected path but is missing");
return;
}

if (path.charAt(0) != '/') {
respondWithHttpError(ctx, streamId, 404, Status.Code.UNIMPLEMENTED,
String.format("Expected path to start with /: %s", path));
return;
}

String method = path.subSequence(1, path.length()).toString();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: path.substring()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

path is a CharSequence, so it doesn't have substring.


// Verify that the Content-Type is correct in the request.
verifyContentType(streamId, headers);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops. I didn't delete these methods. They're now unused: determineMethod, verifyContentType

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went ahead and deleted them, since it was trivial.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ejona86 great..thanks..I should have done it my self. Did not pay attention. Sorry :-)

String method = determineMethod(streamId, headers);
CharSequence contentType = headers.get(CONTENT_TYPE_HEADER);
if (contentType == null) {
respondWithHttpError(
ctx, streamId, 415, Status.Code.INTERNAL, "Content-Type is missing from the request");
return;
}
String contentTypeString = contentType.toString();
if (!GrpcUtil.isGrpcContentType(contentTypeString)) {
respondWithHttpError(ctx, streamId, 415, Status.Code.INTERNAL,
String.format("Content-Type '%s' is not supported", contentTypeString));
return;
}

if (!HTTP_METHOD.equals(headers.method())) {
respondWithHttpError(ctx, streamId, 405, Status.Code.INTERNAL,
String.format("Method '%s' is not supported", headers.method()));
return;
}

// The Http2Stream object was put by AbstractHttp2ConnectionHandler before calling this
// method.
Expand All @@ -400,7 +437,7 @@ private void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers
maxMessageSize,
statsTraceCtx,
transportTracer);
String authority = getOrUpdateAuthority((AsciiString)headers.authority());
String authority = getOrUpdateAuthority((AsciiString) headers.authority());
NettyServerStream stream = new NettyServerStream(
ctx.channel(),
state,
Expand All @@ -411,10 +448,7 @@ private void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers
transportListener.streamCreated(stream, method, metadata);
state.onStreamAllocated();
http2Stream.setProperty(streamKey, state);

} catch (Http2Exception e) {
throw e;
} catch (Throwable e) {
} catch (Exception e) {
logger.log(Level.WARNING, "Exception in onHeadersRead()", e);
// Throw an exception that will get handled by onStreamError.
throw newStreamException(streamId, e);
Expand Down Expand Up @@ -634,17 +668,22 @@ public boolean visit(Http2Stream stream) throws Http2Exception {
});
}

private void verifyContentType(int streamId, Http2Headers headers) throws Http2Exception {
CharSequence contentType = headers.get(CONTENT_TYPE_HEADER);
if (contentType == null) {
throw Http2Exception.streamError(streamId, Http2Error.REFUSED_STREAM,
"Content-Type is missing from the request");
}
String contentTypeString = contentType.toString();
if (!GrpcUtil.isGrpcContentType(contentTypeString)) {
throw Http2Exception.streamError(streamId, Http2Error.REFUSED_STREAM,
"Content-Type '%s' is not supported", contentTypeString);
}
private void respondWithHttpError(
ChannelHandlerContext ctx, int streamId, int code, Status.Code statusCode, String msg) {
Metadata metadata = new Metadata();
metadata.put(InternalStatus.CODE_KEY, statusCode.toStatus());
metadata.put(InternalStatus.MESSAGE_KEY, msg);
byte[][] serialized = InternalMetadata.serialize(metadata);

Http2Headers headers = new DefaultHttp2Headers(true, serialized.length / 2)
.status("" + code)
.set(CONTENT_TYPE_HEADER, "text/plain; encoding=utf-8");
for (int i = 0; i < serialized.length; i += 2) {
headers.add(new AsciiString(serialized[i], false), new AsciiString(serialized[i + 1], false));
}
encoder().writeHeaders(ctx, streamId, headers, 0, false, ctx.newPromise());
ByteBuf msgBuf = ByteBufUtil.writeUtf8(ctx.alloc(), msg);
encoder().writeData(ctx, streamId, msgBuf, 0, true, ctx.newPromise());
}

private Http2Stream requireHttp2Stream(int streamId) {
Expand All @@ -656,20 +695,6 @@ private Http2Stream requireHttp2Stream(int streamId) {
return stream;
}

private String determineMethod(int streamId, Http2Headers headers) throws Http2Exception {
if (!HTTP_METHOD.equals(headers.method())) {
throw Http2Exception.streamError(streamId, Http2Error.REFUSED_STREAM,
"Method '%s' is not supported", headers.method());
}
// Remove the leading slash of the path and get the fully qualified method name
CharSequence path = headers.path();
if (path.charAt(0) != '/') {
throw Http2Exception.streamError(streamId, Http2Error.REFUSED_STREAM,
"Malformatted path: %s", path);
}
return path.subSequence(1, path.length()).toString();
}

/**
* Returns the server stream associated to the given HTTP/2 stream object.
*/
Expand Down
79 changes: 73 additions & 6 deletions netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import static io.grpc.netty.Utils.TE_HEADER;
import static io.grpc.netty.Utils.TE_TRAILERS;
import static io.netty.buffer.Unpooled.directBuffer;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
Expand All @@ -53,6 +54,7 @@
import com.google.common.io.ByteStreams;
import com.google.common.truth.Truth;
import io.grpc.Attributes;
import io.grpc.InternalStatus;
import io.grpc.Metadata;
import io.grpc.ServerStreamTracer;
import io.grpc.Status;
Expand Down Expand Up @@ -110,6 +112,9 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand

private static final int STREAM_ID = 3;

private static final AsciiString HTTP_FAKE_METHOD = AsciiString.of("FAKE");


@Mock
private ServerStreamListener streamListener;

Expand Down Expand Up @@ -406,14 +411,76 @@ public void cancelShouldSendRstStream() throws Exception {
public void headersWithInvalidContentTypeShouldFail() throws Exception {
manualSetUp();
Http2Headers headers = new DefaultHttp2Headers()
.method(HTTP_METHOD)
.set(CONTENT_TYPE_HEADER, new AsciiString("application/bad", UTF_8))
.set(TE_HEADER, TE_TRAILERS)
.path(new AsciiString("/foo/bar"));
.method(HTTP_METHOD)
.set(CONTENT_TYPE_HEADER, new AsciiString("application/bad", UTF_8))
.set(TE_HEADER, TE_TRAILERS)
.path(new AsciiString("/foo/bar"));
ByteBuf headersFrame = headersFrame(STREAM_ID, headers);
channelRead(headersFrame);
verifyWrite().writeRstStream(eq(ctx()), eq(STREAM_ID), eq(Http2Error.REFUSED_STREAM.code()),
any(ChannelPromise.class));
Http2Headers responseHeaders = new DefaultHttp2Headers()
.set(InternalStatus.CODE_KEY.name(), String.valueOf(Code.INTERNAL.value()))
.set(InternalStatus.MESSAGE_KEY.name(), "Content-Type 'application/bad' is not supported")
.status("" + 415)
.set(CONTENT_TYPE_HEADER, "text/plain; encoding=utf-8");

verifyWrite().writeHeaders(eq(ctx()), eq(STREAM_ID), eq(responseHeaders), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class));
}

@Test
public void headersWithInvalidMethodShouldFail() throws Exception {
manualSetUp();
Http2Headers headers = new DefaultHttp2Headers()
.method(HTTP_FAKE_METHOD)
.set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC)
.path(new AsciiString("/foo/bar"));
ByteBuf headersFrame = headersFrame(STREAM_ID, headers);
channelRead(headersFrame);
Http2Headers responseHeaders = new DefaultHttp2Headers()
.set(InternalStatus.CODE_KEY.name(), String.valueOf(Code.INTERNAL.value()))
.set(InternalStatus.MESSAGE_KEY.name(), "Method 'FAKE' is not supported")
.status("" + 405)
.set(CONTENT_TYPE_HEADER, "text/plain; encoding=utf-8");

verifyWrite().writeHeaders(eq(ctx()), eq(STREAM_ID), eq(responseHeaders), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class));
}

@Test
public void headersWithMissingPathShouldFail() throws Exception {
manualSetUp();
Http2Headers headers = new DefaultHttp2Headers()
.method(HTTP_METHOD)
.set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC);
ByteBuf headersFrame = headersFrame(STREAM_ID, headers);
channelRead(headersFrame);
Http2Headers responseHeaders = new DefaultHttp2Headers()
.set(InternalStatus.CODE_KEY.name(), String.valueOf(Code.UNIMPLEMENTED.value()))
.set(InternalStatus.MESSAGE_KEY.name(), "Expected path but is missing")
.status("" + 404)
.set(CONTENT_TYPE_HEADER, "text/plain; encoding=utf-8");

verifyWrite().writeHeaders(eq(ctx()), eq(STREAM_ID), eq(responseHeaders), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class));
}

@Test
public void headersWithInvalidPathShouldFail() throws Exception {
manualSetUp();
Http2Headers headers = new DefaultHttp2Headers()
.method(HTTP_METHOD)
.set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC)
.path(new AsciiString("foo/bar"));
ByteBuf headersFrame = headersFrame(STREAM_ID, headers);
channelRead(headersFrame);
Http2Headers responseHeaders = new DefaultHttp2Headers()
.set(InternalStatus.CODE_KEY.name(), String.valueOf(Code.UNIMPLEMENTED.value()))
.set(InternalStatus.MESSAGE_KEY.name(), "Expected path to start with /: foo/bar")
.status("" + 404)
.set(CONTENT_TYPE_HEADER, "text/plain; encoding=utf-8");

verifyWrite().writeHeaders(eq(ctx()), eq(STREAM_ID), eq(responseHeaders), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class));
}

@Test
Expand Down