Skip to content

Commit

Permalink
Tell why a request is bad for a 400 Bad Request response (#1575)
Browse files Browse the repository at this point in the history
Motivation:

It is sometimes hard for a client to know why a request is bad.

Modifications:

- Added short error messages to 400 Bad Request responses
- Used `ProtocolViolationException` instead of `IllegalArgumentException` for bad requests.
- Removed unnecessary validation of HTTP method in `HttpServerHandler`.

Result:

- User friendliness.
  • Loading branch information
trustin committed Feb 11, 2019
1 parent 70f781f commit 82bc639
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 111 deletions.
Expand Up @@ -69,6 +69,17 @@ final class Http1RequestDecoder extends ChannelDuplexHandler {
private static final com.linecorp.armeria.common.HttpHeaders CONTINUE_RESPONSE =
com.linecorp.armeria.common.HttpHeaders.of(HttpStatus.CONTINUE);

private static final HttpData DATA_DECODER_FAILURE =
HttpData.ofUtf8(HttpResponseStatus.BAD_REQUEST + "\nDecoder failure");
private static final HttpData DATA_UNSUPPORTED_METHOD =
HttpData.ofUtf8(HttpResponseStatus.METHOD_NOT_ALLOWED + "\nUnsupported method");
private static final HttpData DATA_INVALID_CONTENT_LENGTH =
HttpData.ofUtf8(HttpResponseStatus.BAD_REQUEST + "\nInvalid content length");
private static final HttpData DATA_INVALID_REQUEST_PATH =
HttpData.ofUtf8(HttpResponseStatus.BAD_REQUEST + "\nInvalid request path");
private static final HttpData DATA_INVALID_DECODER_STATE =
HttpData.ofUtf8(HttpResponseStatus.BAD_REQUEST + "\nInvalid decoder state");

private final ServerConfig cfg;
private final AsciiString scheme;
private final InboundTrafficController inboundTrafficController;
Expand Down Expand Up @@ -116,15 +127,15 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
if (msg instanceof HttpRequest) {
final HttpRequest nettyReq = (HttpRequest) msg;
if (!nettyReq.decoderResult().isSuccess()) {
fail(id, HttpResponseStatus.BAD_REQUEST);
fail(id, HttpResponseStatus.BAD_REQUEST, DATA_DECODER_FAILURE);
return;
}

final HttpHeaders nettyHeaders = nettyReq.headers();

// Validate the method.
if (!HttpMethod.isSupported(nettyReq.method().name())) {
fail(id, HttpResponseStatus.METHOD_NOT_ALLOWED);
fail(id, HttpResponseStatus.METHOD_NOT_ALLOWED, DATA_UNSUPPORTED_METHOD);
return;
}

Expand All @@ -136,11 +147,11 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
try {
contentLength = Long.parseLong(contentLengthStr);
} catch (NumberFormatException ignored) {
fail(id, HttpResponseStatus.BAD_REQUEST);
fail(id, HttpResponseStatus.BAD_REQUEST, DATA_INVALID_CONTENT_LENGTH);
return;
}
if (contentLength < 0) {
fail(id, HttpResponseStatus.BAD_REQUEST);
fail(id, HttpResponseStatus.BAD_REQUEST, DATA_INVALID_CONTENT_LENGTH);
return;
}

Expand All @@ -151,7 +162,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception

if (!handle100Continue(id, nettyReq, nettyHeaders)) {
ctx.pipeline().fireUserEventTriggered(HttpExpectationFailedEvent.INSTANCE);
fail(id, HttpResponseStatus.EXPECTATION_FAILED);
fail(id, HttpResponseStatus.EXPECTATION_FAILED, null);
return;
}

Expand All @@ -173,7 +184,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception

ctx.fireChannelRead(req);
} else {
fail(id, HttpResponseStatus.BAD_REQUEST);
fail(id, HttpResponseStatus.BAD_REQUEST, DATA_INVALID_DECODER_STATE);
return;
}
}
Expand All @@ -182,7 +193,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
final HttpContent content = (HttpContent) msg;
final DecoderResult decoderResult = content.decoderResult();
if (!decoderResult.isSuccess()) {
fail(id, HttpResponseStatus.BAD_REQUEST);
fail(id, HttpResponseStatus.BAD_REQUEST, DATA_DECODER_FAILURE);
req.close(new ProtocolViolationException(decoderResult.cause()));
return;
}
Expand All @@ -193,7 +204,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
req.increaseTransferredBytes(dataLength);
final long maxContentLength = req.maxRequestLength();
if (maxContentLength > 0 && req.transferredBytes() > maxContentLength) {
fail(id, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE);
fail(id, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, null);
req.close(ContentTooLargeException.get());
return;
}
Expand All @@ -214,12 +225,12 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
}
}
} catch (URISyntaxException e) {
fail(id, HttpResponseStatus.BAD_REQUEST);
fail(id, HttpResponseStatus.BAD_REQUEST, DATA_INVALID_REQUEST_PATH);
if (req != null) {
req.close(e);
}
} catch (Throwable t) {
fail(id, HttpResponseStatus.INTERNAL_SERVER_ERROR);
fail(id, HttpResponseStatus.INTERNAL_SERVER_ERROR, null);
if (req != null) {
req.close(t);
} else {
Expand Down Expand Up @@ -255,11 +266,11 @@ private boolean handle100Continue(int id, HttpRequest nettyReq, HttpHeaders nett
return true;
}

private void fail(int id, HttpResponseStatus status) {
private void fail(int id, HttpResponseStatus status, @Nullable HttpData content) {
discarding = true;
req = null;

final HttpData data = HttpData.ofUtf8(status.toString());
final HttpData data = content != null ? content : HttpData.ofUtf8(status.toString());
final com.linecorp.armeria.common.HttpHeaders headers =
com.linecorp.armeria.common.HttpHeaders.of(status.code());
headers.set(HttpHeaderNames.CONNECTION, "close");
Expand Down
Expand Up @@ -23,6 +23,8 @@

import java.nio.charset.StandardCharsets;

import javax.annotation.Nullable;

import com.linecorp.armeria.common.ClosedSessionException;
import com.linecorp.armeria.common.ContentTooLargeException;
import com.linecorp.armeria.common.HttpHeaderNames;
Expand Down Expand Up @@ -54,6 +56,16 @@

final class Http2RequestDecoder extends Http2EventAdapter {

private static final ByteBuf DATA_MISSING_METHOD =
Unpooled.copiedBuffer(HttpResponseStatus.BAD_REQUEST + "\nMissing method",
StandardCharsets.UTF_8).asReadOnly();
private static final ByteBuf DATA_UNSUPPORTED_METHOD =
Unpooled.copiedBuffer(HttpResponseStatus.METHOD_NOT_ALLOWED + "\nUnsupported method",
StandardCharsets.UTF_8).asReadOnly();
private static final ByteBuf DATA_INVALID_CONTENT_LENGTH =
Unpooled.copiedBuffer(HttpResponseStatus.BAD_REQUEST + "\nInvalid content length",
StandardCharsets.UTF_8).asReadOnly();

private final ServerConfig cfg;
private final Channel channel;
private final Http2ConnectionEncoder writer;
Expand Down Expand Up @@ -88,11 +100,12 @@ public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers
// Validate the method.
final CharSequence method = headers.method();
if (method == null) {
writeErrorResponse(ctx, streamId, HttpResponseStatus.BAD_REQUEST);
writeErrorResponse(ctx, streamId, HttpResponseStatus.BAD_REQUEST, DATA_MISSING_METHOD);
return;
}
if (!HttpMethod.isSupported(method.toString())) {
writeErrorResponse(ctx, streamId, HttpResponseStatus.METHOD_NOT_ALLOWED);
writeErrorResponse(ctx, streamId, HttpResponseStatus.METHOD_NOT_ALLOWED,
DATA_UNSUPPORTED_METHOD);
return;
}

Expand All @@ -101,7 +114,8 @@ public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers
if (headers.contains(HttpHeaderNames.CONTENT_LENGTH)) {
final long contentLength = headers.getLong(HttpHeaderNames.CONTENT_LENGTH, -1L);
if (contentLength < 0) {
writeErrorResponse(ctx, streamId, HttpResponseStatus.BAD_REQUEST);
writeErrorResponse(ctx, streamId, HttpResponseStatus.BAD_REQUEST,
DATA_INVALID_CONTENT_LENGTH);
return;
}
contentEmpty = contentLength == 0;
Expand All @@ -110,7 +124,7 @@ public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers
}

if (!handle100Continue(ctx, streamId, headers)) {
writeErrorResponse(ctx, streamId, HttpResponseStatus.EXPECTATION_FAILED);
writeErrorResponse(ctx, streamId, HttpResponseStatus.EXPECTATION_FAILED, null);
return;
}

Expand Down Expand Up @@ -209,7 +223,7 @@ public int onDataRead(
if (maxContentLength > 0 && req.transferredBytes() > maxContentLength) {
final Http2Stream stream = writer.connection().stream(streamId);
if (isWritable(stream)) {
writeErrorResponse(ctx, streamId, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE);
writeErrorResponse(ctx, streamId, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, null);
writer.writeRstStream(ctx, streamId, Http2Error.CANCEL.code(), ctx.voidPromise());
if (req.isOpen()) {
req.close(ContentTooLargeException.get());
Expand Down Expand Up @@ -245,19 +259,21 @@ private static boolean isWritable(Http2Stream stream) {
}
}

private void writeErrorResponse(ChannelHandlerContext ctx, int streamId,
HttpResponseStatus status) throws Http2Exception {
final byte[] content = status.toString().getBytes(StandardCharsets.UTF_8);
private void writeErrorResponse(ChannelHandlerContext ctx, int streamId, HttpResponseStatus status,
@Nullable ByteBuf content) throws Http2Exception {
final ByteBuf data =
content != null ? content
: Unpooled.wrappedBuffer(status.toString().getBytes(StandardCharsets.UTF_8));

writer.writeHeaders(
ctx, streamId,
new DefaultHttp2Headers(false)
.status(status.codeAsText())
.set(HttpHeaderNames.CONTENT_TYPE, MediaType.PLAIN_TEXT_UTF_8.toString())
.setInt(HttpHeaderNames.CONTENT_LENGTH, content.length),
.setInt(HttpHeaderNames.CONTENT_LENGTH, data.readableBytes()),
0, false, ctx.voidPromise());

writer.writeData(ctx, streamId, Unpooled.wrappedBuffer(content), 0, true, ctx.voidPromise());
writer.writeData(ctx, streamId, data, 0, true, ctx.voidPromise());

final Http2Stream stream = writer.connection().stream(streamId);
if (stream != null && writer.flowController().hasFlowControlled(stream)) {
Expand Down

0 comments on commit 82bc639

Please sign in to comment.