Skip to content

Commit

Permalink
Add support for alternative protobuf content types (#4364)
Browse files Browse the repository at this point in the history
Motivation:
Armeria `UnframedGrpcService` doesn't support alternative protobuf content types

Modifications:
- add `application/x-protobuf` and `application/x-google-protobuf` to MediaType
- add isProtobuf() function
- identify Protobuf contentType in UnframedGrpcService using isProtobuf()
- respond with request-given protobuf content type for deframed response

Result:

- Closes #4355
- Armeria `UnframedGrpcService` now supports `application/x-protobuf` and `application/x-google-protobuf` media types.
  • Loading branch information
mscheong01 committed Sep 5, 2022
1 parent 70d410d commit e7da834
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 20 deletions.
16 changes: 16 additions & 0 deletions core/src/main/java/com/linecorp/armeria/common/MediaType.java
Expand Up @@ -662,6 +662,8 @@ private static MediaType addKnownType(MediaType mediaType) {
* <a href="https://developers.google.com/protocol-buffers">Protocol buffers</a>.
*/
public static final MediaType PROTOBUF = createConstant(APPLICATION_TYPE, "protobuf");
public static final MediaType X_PROTOBUF = createConstant(APPLICATION_TYPE, "x-protobuf");
public static final MediaType X_GOOGLE_PROTOBUF = createConstant(APPLICATION_TYPE, "x-google-protobuf");

/**
* <a href="https://en.wikipedia.org/wiki/RDF/XML">RDF/XML</a> documents, which are XML
Expand Down Expand Up @@ -1052,6 +1054,20 @@ public boolean isJson() {
return is(JSON) || subtype().endsWith("+json");
}

/**
* Returns {@code true} when the subtype is one of {@link MediaType#PROTOBUF}, {@link MediaType#X_PROTOBUF}
* and {@link MediaType#X_GOOGLE_PROTOBUF}. Otherwise {@code false}.
*
* <pre>{@code
* PROTOBUF.isProtobuf() // true
* X_PROTOBUF.isProtobuf() // true
* X_GOOGLE_PROTOBUF.isProtobuf() // true
* }</pre>
*/
public boolean isProtobuf() {
return is(PROTOBUF) || is(X_PROTOBUF)|| is(X_GOOGLE_PROTOBUF);
}

/**
* Returns {@code true} if this {@link MediaType} belongs to the given {@link MediaType}.
* Similar to what {@link MediaType#is(MediaType)} does except that this one compares the parameters
Expand Down
Expand Up @@ -505,6 +505,14 @@ public final class MediaTypeNames {
* {@value #PROTOBUF}.
*/
public static final String PROTOBUF = "application/protobuf";
/**
* {@value #X_PROTOBUF}.
*/
public static final String X_PROTOBUF = "application/x-protobuf";
/**
* {@value #X_GOOGLE_PROTOBUF}.
*/
public static final String X_GOOGLE_PROTOBUF = "application/x-google-protobuf";
/**
* {@value #RDF_XML_UTF_8}.
*/
Expand Down
Expand Up @@ -138,7 +138,8 @@ protected void frameAndServe(
RequestHeaders grpcHeaders,
HttpData content,
CompletableFuture<HttpResponse> res,
@Nullable Function<HttpData, HttpData> responseBodyConverter) {
@Nullable Function<HttpData, HttpData> responseBodyConverter,
MediaType responseContentType) {
final HttpRequest grpcRequest;
try (ArmeriaMessageFramer framer = new ArmeriaMessageFramer(
ctx.alloc(), ArmeriaMessageFramer.NO_MAX_OUTBOUND_MESSAGE_SIZE, false)) {
Expand Down Expand Up @@ -170,7 +171,7 @@ protected void frameAndServe(
res.completeExceptionally(t);
} else {
deframeAndRespond(ctx, framedResponse, res, unframedGrpcErrorHandler,
responseBodyConverter);
responseBodyConverter, responseContentType);
}
}
return null;
Expand All @@ -182,7 +183,8 @@ static void deframeAndRespond(ServiceRequestContext ctx,
AggregatedHttpResponse grpcResponse,
CompletableFuture<HttpResponse> res,
UnframedGrpcErrorHandler unframedGrpcErrorHandler,
@Nullable Function<HttpData, HttpData> responseBodyConverter) {
@Nullable Function<HttpData, HttpData> responseBodyConverter,
MediaType responseContentType) {
final HttpHeaders trailers = !grpcResponse.trailers().isEmpty() ?
grpcResponse.trailers() : grpcResponse.headers();
final String grpcStatusCode = trailers.get(GrpcHeaderNames.GRPC_STATUS);
Expand Down Expand Up @@ -210,15 +212,15 @@ static void deframeAndRespond(ServiceRequestContext ctx,
}

final MediaType grpcMediaType = grpcResponse.contentType();
if (grpcMediaType == null) {
PooledObjects.close(grpcResponse.content());
res.completeExceptionally(new NullPointerException("MediaType is undefined"));
return;
}

final ResponseHeadersBuilder unframedHeaders = grpcResponse.headers().toBuilder();
unframedHeaders.set(GrpcHeaderNames.GRPC_STATUS, grpcStatusCode); // grpcStatusCode is 0 which is OK.
if (grpcMediaType != null) {
if (grpcMediaType.is(GrpcSerializationFormats.PROTO.mediaType())) {
unframedHeaders.contentType(MediaType.PROTOBUF);
} else if (grpcMediaType.is(GrpcSerializationFormats.JSON.mediaType())) {
unframedHeaders.contentType(MediaType.JSON_UTF_8);
}
}
unframedHeaders.contentType(responseContentType);

final ArmeriaMessageDeframer deframer = new ArmeriaMessageDeframer(
// Max outbound message size is handled by the GrpcService, so we don't need to set it here.
Expand Down
Expand Up @@ -557,8 +557,9 @@ private HttpResponse serve0(ServiceRequestContext ctx, HttpRequest req,
"gRPC encoding is not supported for non-framed requests.");
}

final MediaType jsonContentType = GrpcSerializationFormats.JSON.mediaType();
grpcHeaders.method(HttpMethod.POST)
.contentType(GrpcSerializationFormats.JSON.mediaType());
.contentType(jsonContentType);
// All clients support no encoding, and we don't support gRPC encoding for non-framed requests, so just
// clear the header if it's present.
grpcHeaders.remove(GrpcHeaderNames.GRPC_ACCEPT_ENCODING);
Expand All @@ -576,7 +577,7 @@ private HttpResponse serve0(ServiceRequestContext ctx, HttpRequest req,
ctx.setAttr(FramedGrpcService.RESOLVED_GRPC_METHOD, spec.method);
frameAndServe(unwrap(), ctx, grpcHeaders.build(),
convertToJson(ctx, clientRequest, spec),
responseFuture, generateResponseBodyConverter(spec));
responseFuture, generateResponseBodyConverter(spec), jsonContentType);
} catch (IllegalArgumentException iae) {
responseFuture.completeExceptionally(
HttpStatusException.of(HttpStatus.BAD_REQUEST, iae));
Expand Down
Expand Up @@ -119,14 +119,16 @@ public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exc
final RequestHeadersBuilder grpcHeaders = clientHeaders.toBuilder();

final MediaType framedContentType;
if (contentType.is(MediaType.PROTOBUF)) {
if (contentType.isProtobuf()) {
framedContentType = GrpcSerializationFormats.PROTO.mediaType();
} else if (contentType.is(MediaType.JSON)) {
framedContentType = GrpcSerializationFormats.JSON.mediaType();
} else {
return HttpResponse.of(HttpStatus.UNSUPPORTED_MEDIA_TYPE,
MediaType.PLAIN_TEXT_UTF_8,
"Unsupported media type. Only application/protobuf is supported.");
"Unsupported media type. Only application/protobuf, " +
"application/x-protobuf, application/x-google-protobuf" +
"and application/json are supported.");
}
grpcHeaders.contentType(framedContentType);

Expand All @@ -149,8 +151,8 @@ public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exc
if (t != null) {
responseFuture.completeExceptionally(t);
} else {
frameAndServe(unwrap(), ctx, grpcHeaders.build(),
clientRequest.content(), responseFuture, null);
frameAndServe(unwrap(), ctx, grpcHeaders.build(), clientRequest.content(),
responseFuture, null, contentType);
}
}
return null;
Expand Down
@@ -0,0 +1,109 @@
/*
* Copyright 2022 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.server.grpc;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.stream.Stream;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;

import com.linecorp.armeria.common.AggregatedHttpResponse;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.grpc.testing.TestServiceGrpc;
import com.linecorp.armeria.protobuf.EmptyProtos;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.testing.junit5.common.EventLoopExtension;

import io.grpc.BindableService;
import io.grpc.stub.StreamObserver;

public class UnframedGrpcServiceResponseMediaTypeTest {

@RegisterExtension
static EventLoopExtension eventLoop = new EventLoopExtension();

private static class TestService extends TestServiceGrpc.TestServiceImplBase {

@Override
public void emptyCall(EmptyProtos.Empty request, StreamObserver<EmptyProtos.Empty> responseObserver) {
responseObserver.onNext(EmptyProtos.Empty.newBuilder().build());
responseObserver.onCompleted();
}
}

private static final TestService testService = new TestService();
private static final int MAX_MESSAGE_BYTES = 1024;

@Test
void respondWithCorrespondingJsonMediaType() throws Exception {
final UnframedGrpcService unframedGrpcService = buildUnframedGrpcService(testService);

final HttpRequest request = HttpRequest.of(HttpMethod.POST,
"/armeria.grpc.testing.TestService/EmptyCall",
MediaType.JSON_UTF_8, "{}");
final ServiceRequestContext ctx = ServiceRequestContext.builder(request)
.build();

final AggregatedHttpResponse res = unframedGrpcService.serve(ctx, request).aggregate().join();
assertThat(res.status()).isEqualTo(HttpStatus.OK);
assertThat(res.contentType()).isEqualTo(MediaType.JSON_UTF_8);
}

@ParameterizedTest
@ArgumentsSource(ProtobufMediaTypeProvider.class)
void respondWithCorrespondingProtobufMediaType(MediaType protobufType) throws Exception {
final UnframedGrpcService unframedGrpcService = buildUnframedGrpcService(testService);

final HttpRequest request = HttpRequest.of(HttpMethod.POST,
"/armeria.grpc.testing.TestService/EmptyCall",
protobufType,
EmptyProtos.Empty.getDefaultInstance().toByteArray());
final ServiceRequestContext ctx = ServiceRequestContext.builder(request)
.build();

final AggregatedHttpResponse res = unframedGrpcService.serve(ctx, request).aggregate().join();
assertThat(res.status()).isEqualTo(HttpStatus.OK);
assertThat(res.contentType()).isEqualTo(protobufType);
}

private static class ProtobufMediaTypeProvider implements ArgumentsProvider {
@Override
public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
return Stream.of(MediaType.PROTOBUF, MediaType.X_PROTOBUF, MediaType.X_GOOGLE_PROTOBUF)
.map(Arguments::of);
}
}

private static UnframedGrpcService buildUnframedGrpcService(BindableService bindableService) {
return (UnframedGrpcService) GrpcService.builder()
.addService(bindableService)
.maxRequestMessageLength(MAX_MESSAGE_BYTES)
.maxResponseMessageLength(MAX_MESSAGE_BYTES)
.enableUnframedRequests(true)
.build();
}
}
Expand Up @@ -122,24 +122,58 @@ void shouldClosePooledObjectsForNonOK() {
final ByteBuf byteBuf = Unpooled.buffer();
final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK)
.add(GrpcHeaderNames.GRPC_STATUS, "1")
.contentType(MediaType.PROTOBUF)
.build();
final AggregatedHttpResponse framedResponse = AggregatedHttpResponse.of(responseHeaders,
HttpData.wrap(byteBuf));
UnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(), null);
UnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(),
null, MediaType.PROTOBUF);
assertThat(byteBuf.refCnt()).isZero();
}

@Test
void shouldClosePooledObjectsForMissingMediaType() {
final CompletableFuture<HttpResponse> res = new CompletableFuture<>();
final ByteBuf byteBuf = Unpooled.buffer();
final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK)
.add(GrpcHeaderNames.GRPC_STATUS, "0")
.build();
final AggregatedHttpResponse framedResponse = AggregatedHttpResponse
.of(responseHeaders, HttpData.wrap(byteBuf));
AbstractUnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(),
null, MediaType.PROTOBUF);
assertThat(byteBuf.refCnt()).isZero();
}

@Test
void shouldClosePooledObjectsForMissingGrpcStatus() {
final CompletableFuture<HttpResponse> res = new CompletableFuture<>();
final ByteBuf byteBuf = Unpooled.buffer();
final ResponseHeaders responseHeaders = ResponseHeaders.of(HttpStatus.OK);
final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK)
.contentType(MediaType.PROTOBUF)
.build();
final AggregatedHttpResponse framedResponse = AggregatedHttpResponse.of(responseHeaders,
HttpData.wrap(byteBuf));
UnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(), null);
HttpData.wrap(byteBuf));
AbstractUnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(),
null, MediaType.PROTOBUF);
assertThat(byteBuf.refCnt()).isZero();
}

@Test
void succeedWithAllRequiredHeaders() throws Exception {
final CompletableFuture<HttpResponse> res = new CompletableFuture<>();
final ByteBuf byteBuf = Unpooled.buffer();
final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK)
.add(GrpcHeaderNames.GRPC_STATUS, "0")
.contentType(MediaType.PROTOBUF)
.build();
final AggregatedHttpResponse framedResponse = AggregatedHttpResponse
.of(responseHeaders, HttpData.wrap(byteBuf));
AbstractUnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(),
null, MediaType.PROTOBUF);
assertThat(HttpResponse.from(res).aggregate().get().status()).isEqualTo(HttpStatus.OK);
}

@Test
void unframedGrpcStatusFunction() throws Exception {
final TestService spyTestService = spy(testService);
Expand Down

0 comments on commit e7da834

Please sign in to comment.