diff --git a/core/src/main/java/com/linecorp/armeria/common/MediaType.java b/core/src/main/java/com/linecorp/armeria/common/MediaType.java index adc7c3c7816..3ee9e94cd4d 100644 --- a/core/src/main/java/com/linecorp/armeria/common/MediaType.java +++ b/core/src/main/java/com/linecorp/armeria/common/MediaType.java @@ -662,6 +662,8 @@ private static MediaType addKnownType(MediaType mediaType) { * Protocol buffers. */ public static final MediaType PROTOBUF = createConstant(APPLICATION_TYPE, "protobuf"); + public static final MediaType X_PROTOBUF = createConstant(APPLICATION_TYPE, "x-protobuf"); + public static final MediaType X_GOOGLE_PROTOBUF = createConstant(APPLICATION_TYPE, "x-google-protobuf"); /** * RDF/XML documents, which are XML @@ -1052,6 +1054,20 @@ public boolean isJson() { return is(JSON) || subtype().endsWith("+json"); } + /** + * Returns {@code true} when the subtype is one of {@link MediaType#PROTOBUF}, {@link MediaType#X_PROTOBUF} + * and {@link MediaType#X_GOOGLE_PROTOBUF}. Otherwise {@code false}. + * + *
{@code
+     * PROTOBUF.isProtobuf() // true
+     * X_PROTOBUF.isProtobuf() // true
+     * X_GOOGLE_PROTOBUF.isProtobuf() // true
+     * }
+ */ + public boolean isProtobuf() { + return is(PROTOBUF) || is(X_PROTOBUF)|| is(X_GOOGLE_PROTOBUF); + } + /** * Returns {@code true} if this {@link MediaType} belongs to the given {@link MediaType}. * Similar to what {@link MediaType#is(MediaType)} does except that this one compares the parameters diff --git a/core/src/main/java/com/linecorp/armeria/common/MediaTypeNames.java b/core/src/main/java/com/linecorp/armeria/common/MediaTypeNames.java index ee8b0983943..3efd086c724 100644 --- a/core/src/main/java/com/linecorp/armeria/common/MediaTypeNames.java +++ b/core/src/main/java/com/linecorp/armeria/common/MediaTypeNames.java @@ -505,6 +505,14 @@ public final class MediaTypeNames { * {@value #PROTOBUF}. */ public static final String PROTOBUF = "application/protobuf"; + /** + * {@value #X_PROTOBUF}. + */ + public static final String X_PROTOBUF = "application/x-protobuf"; + /** + * {@value #X_GOOGLE_PROTOBUF}. + */ + public static final String X_GOOGLE_PROTOBUF = "application/x-google-protobuf"; /** * {@value #RDF_XML_UTF_8}. */ diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/AbstractUnframedGrpcService.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/AbstractUnframedGrpcService.java index e3c79095bdc..f4655117e33 100644 --- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/AbstractUnframedGrpcService.java +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/AbstractUnframedGrpcService.java @@ -138,7 +138,8 @@ protected void frameAndServe( RequestHeaders grpcHeaders, HttpData content, CompletableFuture res, - @Nullable Function responseBodyConverter) { + @Nullable Function responseBodyConverter, + MediaType responseContentType) { final HttpRequest grpcRequest; try (ArmeriaMessageFramer framer = new ArmeriaMessageFramer( ctx.alloc(), ArmeriaMessageFramer.NO_MAX_OUTBOUND_MESSAGE_SIZE, false)) { @@ -170,7 +171,7 @@ protected void frameAndServe( res.completeExceptionally(t); } else { deframeAndRespond(ctx, framedResponse, res, unframedGrpcErrorHandler, - responseBodyConverter); + responseBodyConverter, responseContentType); } } return null; @@ -182,7 +183,8 @@ static void deframeAndRespond(ServiceRequestContext ctx, AggregatedHttpResponse grpcResponse, CompletableFuture res, UnframedGrpcErrorHandler unframedGrpcErrorHandler, - @Nullable Function responseBodyConverter) { + @Nullable Function responseBodyConverter, + MediaType responseContentType) { final HttpHeaders trailers = !grpcResponse.trailers().isEmpty() ? grpcResponse.trailers() : grpcResponse.headers(); final String grpcStatusCode = trailers.get(GrpcHeaderNames.GRPC_STATUS); @@ -210,15 +212,15 @@ static void deframeAndRespond(ServiceRequestContext ctx, } final MediaType grpcMediaType = grpcResponse.contentType(); + if (grpcMediaType == null) { + PooledObjects.close(grpcResponse.content()); + res.completeExceptionally(new NullPointerException("MediaType is undefined")); + return; + } + final ResponseHeadersBuilder unframedHeaders = grpcResponse.headers().toBuilder(); unframedHeaders.set(GrpcHeaderNames.GRPC_STATUS, grpcStatusCode); // grpcStatusCode is 0 which is OK. - if (grpcMediaType != null) { - if (grpcMediaType.is(GrpcSerializationFormats.PROTO.mediaType())) { - unframedHeaders.contentType(MediaType.PROTOBUF); - } else if (grpcMediaType.is(GrpcSerializationFormats.JSON.mediaType())) { - unframedHeaders.contentType(MediaType.JSON_UTF_8); - } - } + unframedHeaders.contentType(responseContentType); final ArmeriaMessageDeframer deframer = new ArmeriaMessageDeframer( // Max outbound message size is handled by the GrpcService, so we don't need to set it here. diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/HttpJsonTranscodingService.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/HttpJsonTranscodingService.java index 2eb9496ad2c..a87d2313868 100644 --- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/HttpJsonTranscodingService.java +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/HttpJsonTranscodingService.java @@ -557,8 +557,9 @@ private HttpResponse serve0(ServiceRequestContext ctx, HttpRequest req, "gRPC encoding is not supported for non-framed requests."); } + final MediaType jsonContentType = GrpcSerializationFormats.JSON.mediaType(); grpcHeaders.method(HttpMethod.POST) - .contentType(GrpcSerializationFormats.JSON.mediaType()); + .contentType(jsonContentType); // All clients support no encoding, and we don't support gRPC encoding for non-framed requests, so just // clear the header if it's present. grpcHeaders.remove(GrpcHeaderNames.GRPC_ACCEPT_ENCODING); @@ -576,7 +577,7 @@ private HttpResponse serve0(ServiceRequestContext ctx, HttpRequest req, ctx.setAttr(FramedGrpcService.RESOLVED_GRPC_METHOD, spec.method); frameAndServe(unwrap(), ctx, grpcHeaders.build(), convertToJson(ctx, clientRequest, spec), - responseFuture, generateResponseBodyConverter(spec)); + responseFuture, generateResponseBodyConverter(spec), jsonContentType); } catch (IllegalArgumentException iae) { responseFuture.completeExceptionally( HttpStatusException.of(HttpStatus.BAD_REQUEST, iae)); diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcService.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcService.java index ae9da44a64b..9318ff68876 100644 --- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcService.java +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcService.java @@ -119,14 +119,16 @@ public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exc final RequestHeadersBuilder grpcHeaders = clientHeaders.toBuilder(); final MediaType framedContentType; - if (contentType.is(MediaType.PROTOBUF)) { + if (contentType.isProtobuf()) { framedContentType = GrpcSerializationFormats.PROTO.mediaType(); } else if (contentType.is(MediaType.JSON)) { framedContentType = GrpcSerializationFormats.JSON.mediaType(); } else { return HttpResponse.of(HttpStatus.UNSUPPORTED_MEDIA_TYPE, MediaType.PLAIN_TEXT_UTF_8, - "Unsupported media type. Only application/protobuf is supported."); + "Unsupported media type. Only application/protobuf, " + + "application/x-protobuf, application/x-google-protobuf" + + "and application/json are supported."); } grpcHeaders.contentType(framedContentType); @@ -149,8 +151,8 @@ public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exc if (t != null) { responseFuture.completeExceptionally(t); } else { - frameAndServe(unwrap(), ctx, grpcHeaders.build(), - clientRequest.content(), responseFuture, null); + frameAndServe(unwrap(), ctx, grpcHeaders.build(), clientRequest.content(), + responseFuture, null, contentType); } } return null; diff --git a/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcServiceResponseMediaTypeTest.java b/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcServiceResponseMediaTypeTest.java new file mode 100644 index 00000000000..d7fea89ace5 --- /dev/null +++ b/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcServiceResponseMediaTypeTest.java @@ -0,0 +1,109 @@ +/* + * Copyright 2022 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.grpc; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.stream.Stream; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; + +import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.grpc.testing.TestServiceGrpc; +import com.linecorp.armeria.protobuf.EmptyProtos; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.testing.junit5.common.EventLoopExtension; + +import io.grpc.BindableService; +import io.grpc.stub.StreamObserver; + +public class UnframedGrpcServiceResponseMediaTypeTest { + + @RegisterExtension + static EventLoopExtension eventLoop = new EventLoopExtension(); + + private static class TestService extends TestServiceGrpc.TestServiceImplBase { + + @Override + public void emptyCall(EmptyProtos.Empty request, StreamObserver responseObserver) { + responseObserver.onNext(EmptyProtos.Empty.newBuilder().build()); + responseObserver.onCompleted(); + } + } + + private static final TestService testService = new TestService(); + private static final int MAX_MESSAGE_BYTES = 1024; + + @Test + void respondWithCorrespondingJsonMediaType() throws Exception { + final UnframedGrpcService unframedGrpcService = buildUnframedGrpcService(testService); + + final HttpRequest request = HttpRequest.of(HttpMethod.POST, + "/armeria.grpc.testing.TestService/EmptyCall", + MediaType.JSON_UTF_8, "{}"); + final ServiceRequestContext ctx = ServiceRequestContext.builder(request) + .build(); + + final AggregatedHttpResponse res = unframedGrpcService.serve(ctx, request).aggregate().join(); + assertThat(res.status()).isEqualTo(HttpStatus.OK); + assertThat(res.contentType()).isEqualTo(MediaType.JSON_UTF_8); + } + + @ParameterizedTest + @ArgumentsSource(ProtobufMediaTypeProvider.class) + void respondWithCorrespondingProtobufMediaType(MediaType protobufType) throws Exception { + final UnframedGrpcService unframedGrpcService = buildUnframedGrpcService(testService); + + final HttpRequest request = HttpRequest.of(HttpMethod.POST, + "/armeria.grpc.testing.TestService/EmptyCall", + protobufType, + EmptyProtos.Empty.getDefaultInstance().toByteArray()); + final ServiceRequestContext ctx = ServiceRequestContext.builder(request) + .build(); + + final AggregatedHttpResponse res = unframedGrpcService.serve(ctx, request).aggregate().join(); + assertThat(res.status()).isEqualTo(HttpStatus.OK); + assertThat(res.contentType()).isEqualTo(protobufType); + } + + private static class ProtobufMediaTypeProvider implements ArgumentsProvider { + @Override + public Stream provideArguments(ExtensionContext context) { + return Stream.of(MediaType.PROTOBUF, MediaType.X_PROTOBUF, MediaType.X_GOOGLE_PROTOBUF) + .map(Arguments::of); + } + } + + private static UnframedGrpcService buildUnframedGrpcService(BindableService bindableService) { + return (UnframedGrpcService) GrpcService.builder() + .addService(bindableService) + .maxRequestMessageLength(MAX_MESSAGE_BYTES) + .maxResponseMessageLength(MAX_MESSAGE_BYTES) + .enableUnframedRequests(true) + .build(); + } +} diff --git a/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcServiceTest.java b/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcServiceTest.java index 26fb91af80b..906e6a8304a 100644 --- a/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcServiceTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcServiceTest.java @@ -122,10 +122,26 @@ void shouldClosePooledObjectsForNonOK() { final ByteBuf byteBuf = Unpooled.buffer(); final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK) .add(GrpcHeaderNames.GRPC_STATUS, "1") + .contentType(MediaType.PROTOBUF) .build(); final AggregatedHttpResponse framedResponse = AggregatedHttpResponse.of(responseHeaders, HttpData.wrap(byteBuf)); - UnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(), null); + UnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(), + null, MediaType.PROTOBUF); + assertThat(byteBuf.refCnt()).isZero(); + } + + @Test + void shouldClosePooledObjectsForMissingMediaType() { + final CompletableFuture res = new CompletableFuture<>(); + final ByteBuf byteBuf = Unpooled.buffer(); + final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK) + .add(GrpcHeaderNames.GRPC_STATUS, "0") + .build(); + final AggregatedHttpResponse framedResponse = AggregatedHttpResponse + .of(responseHeaders, HttpData.wrap(byteBuf)); + AbstractUnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(), + null, MediaType.PROTOBUF); assertThat(byteBuf.refCnt()).isZero(); } @@ -133,13 +149,31 @@ void shouldClosePooledObjectsForNonOK() { void shouldClosePooledObjectsForMissingGrpcStatus() { final CompletableFuture res = new CompletableFuture<>(); final ByteBuf byteBuf = Unpooled.buffer(); - final ResponseHeaders responseHeaders = ResponseHeaders.of(HttpStatus.OK); + final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK) + .contentType(MediaType.PROTOBUF) + .build(); final AggregatedHttpResponse framedResponse = AggregatedHttpResponse.of(responseHeaders, - HttpData.wrap(byteBuf)); - UnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(), null); + HttpData.wrap(byteBuf)); + AbstractUnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(), + null, MediaType.PROTOBUF); assertThat(byteBuf.refCnt()).isZero(); } + @Test + void succeedWithAllRequiredHeaders() throws Exception { + final CompletableFuture res = new CompletableFuture<>(); + final ByteBuf byteBuf = Unpooled.buffer(); + final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK) + .add(GrpcHeaderNames.GRPC_STATUS, "0") + .contentType(MediaType.PROTOBUF) + .build(); + final AggregatedHttpResponse framedResponse = AggregatedHttpResponse + .of(responseHeaders, HttpData.wrap(byteBuf)); + AbstractUnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(), + null, MediaType.PROTOBUF); + assertThat(HttpResponse.from(res).aggregate().get().status()).isEqualTo(HttpStatus.OK); + } + @Test void unframedGrpcStatusFunction() throws Exception { final TestService spyTestService = spy(testService);