Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for alternative protobuf content types #4364

Merged
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 16 additions & 0 deletions core/src/main/java/com/linecorp/armeria/common/MediaType.java
Expand Up @@ -662,6 +662,8 @@ private static MediaType addKnownType(MediaType mediaType) {
* <a href="https://developers.google.com/protocol-buffers">Protocol buffers</a>.
*/
public static final MediaType PROTOBUF = createConstant(APPLICATION_TYPE, "protobuf");
public static final MediaType X_PROTOBUF = createConstant(APPLICATION_TYPE, "x-protobuf");
public static final MediaType X_GOOGLE_PROTOBUF = createConstant(APPLICATION_TYPE, "x-google-protobuf");

/**
* <a href="https://en.wikipedia.org/wiki/RDF/XML">RDF/XML</a> documents, which are XML
Expand Down Expand Up @@ -1052,6 +1054,20 @@ public boolean isJson() {
return is(JSON) || subtype().endsWith("+json");
}

/**
* Returns {@code true} when the subtype is one of {@link MediaType#PROTOBUF}, {@link MediaType#X_PROTOBUF}
* and {@link MediaType#X_GOOGLE_PROTOBUF}. Otherwise {@code false}.
*
* <pre>{@code
* PROTOBUF.isProtobuf() // true
* X_PROTOBUF.isProtobuf() // true
* X_GOOGLE_PROTOBUF.isProtobuf() // true
* }</pre>
*/
public boolean isProtobuf() {
return is(PROTOBUF) || is(X_PROTOBUF)|| is(X_GOOGLE_PROTOBUF);
}

/**
* Returns {@code true} if this {@link MediaType} belongs to the given {@link MediaType}.
* Similar to what {@link MediaType#is(MediaType)} does except that this one compares the parameters
Expand Down
Expand Up @@ -505,6 +505,14 @@ public final class MediaTypeNames {
* {@value #PROTOBUF}.
*/
public static final String PROTOBUF = "application/protobuf";
/**
* {@value #X_PROTOBUF}.
*/
public static final String X_PROTOBUF = "application/x-protobuf";
/**
* {@value #X_GOOGLE_PROTOBUF}.
*/
public static final String X_GOOGLE_PROTOBUF = "application/x-google-protobuf";
/**
* {@value #RDF_XML_UTF_8}.
*/
Expand Down
Expand Up @@ -138,7 +138,8 @@ protected void frameAndServe(
RequestHeaders grpcHeaders,
HttpData content,
CompletableFuture<HttpResponse> res,
@Nullable Function<HttpData, HttpData> responseBodyConverter) {
@Nullable Function<HttpData, HttpData> responseBodyConverter,
@Nullable MediaType responseContentType) {
minwoox marked this conversation as resolved.
Show resolved Hide resolved
final HttpRequest grpcRequest;
try (ArmeriaMessageFramer framer = new ArmeriaMessageFramer(
ctx.alloc(), ArmeriaMessageFramer.NO_MAX_OUTBOUND_MESSAGE_SIZE, false)) {
Expand Down Expand Up @@ -170,7 +171,7 @@ protected void frameAndServe(
res.completeExceptionally(t);
} else {
deframeAndRespond(ctx, framedResponse, res, unframedGrpcErrorHandler,
responseBodyConverter);
responseBodyConverter, responseContentType);
}
}
return null;
Expand All @@ -182,7 +183,8 @@ static void deframeAndRespond(ServiceRequestContext ctx,
AggregatedHttpResponse grpcResponse,
CompletableFuture<HttpResponse> res,
UnframedGrpcErrorHandler unframedGrpcErrorHandler,
@Nullable Function<HttpData, HttpData> responseBodyConverter) {
@Nullable Function<HttpData, HttpData> responseBodyConverter,
@Nullable MediaType responseContentType) {
minwoox marked this conversation as resolved.
Show resolved Hide resolved
final HttpHeaders trailers = !grpcResponse.trailers().isEmpty() ?
grpcResponse.trailers() : grpcResponse.headers();
final String grpcStatusCode = trailers.get(GrpcHeaderNames.GRPC_STATUS);
Expand Down Expand Up @@ -210,14 +212,20 @@ static void deframeAndRespond(ServiceRequestContext ctx,
}

final MediaType grpcMediaType = grpcResponse.contentType();
if (grpcMediaType == null) {
PooledObjects.close(grpcResponse.content());
res.completeExceptionally(new NullPointerException("MediaType is undefined"));
return;
}

final ResponseHeadersBuilder unframedHeaders = grpcResponse.headers().toBuilder();
unframedHeaders.set(GrpcHeaderNames.GRPC_STATUS, grpcStatusCode); // grpcStatusCode is 0 which is OK.
if (grpcMediaType != null) {
if (grpcMediaType.is(GrpcSerializationFormats.PROTO.mediaType())) {
unframedHeaders.contentType(MediaType.PROTOBUF);
} else if (grpcMediaType.is(GrpcSerializationFormats.JSON.mediaType())) {
unframedHeaders.contentType(MediaType.JSON_UTF_8);
}
if (responseContentType != null) {
minwoox marked this conversation as resolved.
Show resolved Hide resolved
unframedHeaders.contentType(responseContentType);
} else if (grpcMediaType.is(GrpcSerializationFormats.PROTO.mediaType())) {
unframedHeaders.contentType(MediaType.PROTOBUF);
} else if (grpcMediaType.is(GrpcSerializationFormats.JSON.mediaType())) {
unframedHeaders.contentType(MediaType.JSON_UTF_8);
}

final ArmeriaMessageDeframer deframer = new ArmeriaMessageDeframer(
Expand Down
Expand Up @@ -557,8 +557,9 @@ private HttpResponse serve0(ServiceRequestContext ctx, HttpRequest req,
"gRPC encoding is not supported for non-framed requests.");
}

final MediaType jsonContentType = GrpcSerializationFormats.JSON.mediaType();
grpcHeaders.method(HttpMethod.POST)
.contentType(GrpcSerializationFormats.JSON.mediaType());
.contentType(jsonContentType);
// All clients support no encoding, and we don't support gRPC encoding for non-framed requests, so just
// clear the header if it's present.
grpcHeaders.remove(GrpcHeaderNames.GRPC_ACCEPT_ENCODING);
Expand All @@ -576,7 +577,7 @@ private HttpResponse serve0(ServiceRequestContext ctx, HttpRequest req,
ctx.setAttr(FramedGrpcService.RESOLVED_GRPC_METHOD, spec.method);
frameAndServe(unwrap(), ctx, grpcHeaders.build(),
convertToJson(ctx, clientRequest, spec),
responseFuture, generateResponseBodyConverter(spec));
responseFuture, generateResponseBodyConverter(spec), jsonContentType);
} catch (IllegalArgumentException iae) {
responseFuture.completeExceptionally(
HttpStatusException.of(HttpStatus.BAD_REQUEST, iae));
Expand Down
Expand Up @@ -119,14 +119,16 @@ public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exc
final RequestHeadersBuilder grpcHeaders = clientHeaders.toBuilder();

final MediaType framedContentType;
if (contentType.is(MediaType.PROTOBUF)) {
if (contentType.isProtobuf()) {
framedContentType = GrpcSerializationFormats.PROTO.mediaType();
} else if (contentType.is(MediaType.JSON)) {
framedContentType = GrpcSerializationFormats.JSON.mediaType();
} else {
return HttpResponse.of(HttpStatus.UNSUPPORTED_MEDIA_TYPE,
MediaType.PLAIN_TEXT_UTF_8,
"Unsupported media type. Only application/protobuf is supported.");
"Unsupported media type. Only application/protobuf, " +
"application/x-protobuf, application/x-google-protobuf" +
"and application/json are supported.");
}
grpcHeaders.contentType(framedContentType);

Expand All @@ -149,8 +151,8 @@ public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exc
if (t != null) {
responseFuture.completeExceptionally(t);
} else {
frameAndServe(unwrap(), ctx, grpcHeaders.build(),
clientRequest.content(), responseFuture, null);
frameAndServe(unwrap(), ctx, grpcHeaders.build(), clientRequest.content(),
responseFuture, null, contentType);
}
}
return null;
Expand Down
@@ -0,0 +1,109 @@
/*
* Copyright 2022 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.server.grpc;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.stream.Stream;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;

import com.linecorp.armeria.common.AggregatedHttpResponse;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.grpc.testing.TestServiceGrpc;
import com.linecorp.armeria.protobuf.EmptyProtos;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.testing.junit5.common.EventLoopExtension;

import io.grpc.BindableService;
import io.grpc.stub.StreamObserver;

public class UnframedGrpcServiceResponseMediaTypeTest {

@RegisterExtension
static EventLoopExtension eventLoop = new EventLoopExtension();

private static class TestService extends TestServiceGrpc.TestServiceImplBase {

@Override
public void emptyCall(EmptyProtos.Empty request, StreamObserver<EmptyProtos.Empty> responseObserver) {
responseObserver.onNext(EmptyProtos.Empty.newBuilder().build());
responseObserver.onCompleted();
}
}

private static final TestService testService = new TestService();
private static final int MAX_MESSAGE_BYTES = 1024;

@Test
void respondWithCorrespondingJsonMediaType() throws Exception {
final UnframedGrpcService unframedGrpcService = buildUnframedGrpcService(testService);

final HttpRequest request = HttpRequest.of(HttpMethod.POST,
"/armeria.grpc.testing.TestService/EmptyCall",
MediaType.JSON_UTF_8, "{}");
final ServiceRequestContext ctx = ServiceRequestContext.builder(request)
.build();

final AggregatedHttpResponse res = unframedGrpcService.serve(ctx, request).aggregate().join();
assertThat(res.status()).isEqualTo(HttpStatus.OK);
assertThat(res.contentType()).isEqualTo(MediaType.JSON_UTF_8);
}

@ParameterizedTest
@ArgumentsSource(ProtobufMediaTypeProvider.class)
void respondWithCorrespondingProtobufMediaType(MediaType protobufType) throws Exception {
final UnframedGrpcService unframedGrpcService = buildUnframedGrpcService(testService);

final HttpRequest request = HttpRequest.of(HttpMethod.POST,
"/armeria.grpc.testing.TestService/EmptyCall",
protobufType,
EmptyProtos.Empty.getDefaultInstance().toByteArray());
final ServiceRequestContext ctx = ServiceRequestContext.builder(request)
.build();

final AggregatedHttpResponse res = unframedGrpcService.serve(ctx, request).aggregate().join();
assertThat(res.status()).isEqualTo(HttpStatus.OK);
assertThat(res.contentType()).isEqualTo(protobufType);
}

private static class ProtobufMediaTypeProvider implements ArgumentsProvider {
@Override
public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
return Stream.of(MediaType.PROTOBUF, MediaType.X_PROTOBUF, MediaType.X_GOOGLE_PROTOBUF)
.map(Arguments::of);
}
}

private static UnframedGrpcService buildUnframedGrpcService(BindableService bindableService) {
return (UnframedGrpcService) GrpcService.builder()
.addService(bindableService)
.maxRequestMessageLength(MAX_MESSAGE_BYTES)
.maxResponseMessageLength(MAX_MESSAGE_BYTES)
.enableUnframedRequests(true)
.build();
}
}
Expand Up @@ -122,24 +122,58 @@ void shouldClosePooledObjectsForNonOK() {
final ByteBuf byteBuf = Unpooled.buffer();
final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK)
.add(GrpcHeaderNames.GRPC_STATUS, "1")
.contentType(MediaType.PROTOBUF)
.build();
final AggregatedHttpResponse framedResponse = AggregatedHttpResponse.of(responseHeaders,
HttpData.wrap(byteBuf));
UnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(), null);
UnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(),
null, MediaType.PROTOBUF);
assertThat(byteBuf.refCnt()).isZero();
}

@Test
void shouldClosePooledObjectsForMissingMediaType() {
final CompletableFuture<HttpResponse> res = new CompletableFuture<>();
final ByteBuf byteBuf = Unpooled.buffer();
final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK)
.add(GrpcHeaderNames.GRPC_STATUS, "0")
.build();
final AggregatedHttpResponse framedResponse = AggregatedHttpResponse
.of(responseHeaders, HttpData.wrap(byteBuf));
AbstractUnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(),
null, MediaType.PROTOBUF);
assertThat(byteBuf.refCnt()).isZero();
}

@Test
void shouldClosePooledObjectsForMissingGrpcStatus() {
final CompletableFuture<HttpResponse> res = new CompletableFuture<>();
final ByteBuf byteBuf = Unpooled.buffer();
final ResponseHeaders responseHeaders = ResponseHeaders.of(HttpStatus.OK);
final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK)
.contentType(MediaType.PROTOBUF)
.build();
final AggregatedHttpResponse framedResponse = AggregatedHttpResponse.of(responseHeaders,
HttpData.wrap(byteBuf));
UnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(), null);
HttpData.wrap(byteBuf));
AbstractUnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(),
null, MediaType.PROTOBUF);
assertThat(byteBuf.refCnt()).isZero();
}

@Test
void succeedWithAllRequiredHeaders() throws Exception {
final CompletableFuture<HttpResponse> res = new CompletableFuture<>();
final ByteBuf byteBuf = Unpooled.buffer();
final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK)
.add(GrpcHeaderNames.GRPC_STATUS, "0")
.contentType(MediaType.PROTOBUF)
.build();
final AggregatedHttpResponse framedResponse = AggregatedHttpResponse
.of(responseHeaders, HttpData.wrap(byteBuf));
AbstractUnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(),
null, MediaType.PROTOBUF);
assertThat(HttpResponse.from(res).aggregate().get().status()).isEqualTo(HttpStatus.OK);
}

@Test
void unframedGrpcStatusFunction() throws Exception {
final TestService spyTestService = spy(testService);
Expand Down