From 1d268aa278524997ad5c1cc8f66661cfabbb6dd1 Mon Sep 17 00:00:00 2001 From: Koki Sato <38124381+sato9818@users.noreply.github.com> Date: Wed, 17 Apr 2024 12:22:32 +0900 Subject: [PATCH] Support custom json marshaller for unframed gRPC error (#5555) Motivation: - A user can't use custom JSON marshaller for unframed grpc error. Modifications: - Implement `UnframedGrpcErrorHandlerBuilder` which takes a custom json marshaller Result: - Closes #4723. - A user can use their custom JSON marshaller for unframed grpc error. --- .../server/grpc/UnframedGrpcErrorHandler.java | 8 + .../grpc/UnframedGrpcErrorHandlerBuilder.java | 250 ++++++++++++++++++ .../grpc/UnframedGrpcErrorHandlers.java | 24 +- .../grpc/UnframedGrpcErrorResponseType.java | 38 +++ .../grpc/ErrorDetailsMarshallerTest.java | 39 ++- .../UnframedGrpcErrorHandlerBuilderTest.java | 196 ++++++++++++++ .../grpc/UnframedGrpcErrorHandlerTest.java | 188 ++++++++++++- 7 files changed, 727 insertions(+), 16 deletions(-) create mode 100644 grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandlerBuilder.java create mode 100644 grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorResponseType.java create mode 100644 grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandlerBuilderTest.java diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandler.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandler.java index 6d7db6da278d..0674585b7d01 100644 --- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandler.java +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandler.java @@ -31,6 +31,14 @@ @UnstableApi public interface UnframedGrpcErrorHandler { + /** + * Returns a new {@link UnframedGrpcErrorHandlerBuilder}. + */ + @UnstableApi + static UnframedGrpcErrorHandlerBuilder builder() { + return new UnframedGrpcErrorHandlerBuilder(); + } + /** * Returns a plain text or json response based on the content type. */ diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandlerBuilder.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandlerBuilder.java new file mode 100644 index 000000000000..b6d48bd15c1c --- /dev/null +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandlerBuilder.java @@ -0,0 +1,250 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.grpc; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.EnumSet; +import java.util.List; +import java.util.Set; + +import org.curioswitch.common.protobuf.json.MessageMarshaller; + +import com.google.protobuf.Message; + +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; + +/** + * Constructs a {@link UnframedGrpcErrorHandler} to handle unframed gRPC errors. + */ +@UnstableApi +public final class UnframedGrpcErrorHandlerBuilder { + private UnframedGrpcStatusMappingFunction statusMappingFunction = UnframedGrpcStatusMappingFunction.of(); + + @Nullable + private MessageMarshaller jsonMarshaller; + + @Nullable + private List marshalledMessages; + + @Nullable + private List> marshalledMessageTypes; + + @Nullable + private Set responseTypes; + + UnframedGrpcErrorHandlerBuilder() {} + + /** + * Sets a custom JSON marshaller to be used by the error handler. + * + *

This method allows the caller to specify a custom JSON marshaller + * for encoding the error responses. If messages or message types have + * already been registered, calling this method will result in an + * {@link IllegalStateException}. If nothing is specified, + * {@link UnframedGrpcErrorHandlers#ERROR_DETAILS_MARSHALLER} is used as + * default json marshaller. + * + * @param jsonMarshaller The custom JSON marshaller to use + */ + public UnframedGrpcErrorHandlerBuilder jsonMarshaller(MessageMarshaller jsonMarshaller) { + requireNonNull(jsonMarshaller, "jsonMarshaller"); + checkState(marshalledMessages == null && marshalledMessageTypes == null, + "Cannot set a custom JSON marshaller because one or more Message instances or " + + "Message types have already been registered. To set a custom marshaller, " + + "ensure that no Message or Message type registrations have been made before " + + "calling this method." + ); + this.jsonMarshaller = jsonMarshaller; + return this; + } + + /** + * Specifies the status mapping function to be used by the error handler. + * + *

This function determines how gRPC statuses are mapped to HTTP statuses + * in the error response. + * + * @param statusMappingFunction The status mapping function + */ + public UnframedGrpcErrorHandlerBuilder statusMappingFunction( + UnframedGrpcStatusMappingFunction statusMappingFunction) { + this.statusMappingFunction = requireNonNull(statusMappingFunction, "statusMappingFunction"); + return this; + } + + /** + * Specifies the response types that the error handler will support. + * + *

This method allows specifying one or more response types (e.g., JSON, PLAINTEXT) + * that the error handler can produce. If nothing is specified or multiple types are specified, the actual + * response type is determined by the response's content type. + * + * @param responseTypes The response types to support + */ + public UnframedGrpcErrorHandlerBuilder responseTypes(UnframedGrpcErrorResponseType... responseTypes) { + requireNonNull(responseTypes, "responseTypes"); + + if (this.responseTypes == null) { + this.responseTypes = EnumSet.noneOf(UnframedGrpcErrorResponseType.class); + } + Collections.addAll(this.responseTypes, responseTypes); + return this; + } + + /** + * Registers custom messages to be marshalled by the error handler. + * + *

This method registers specific message instances for custom error responses. + * If a custom JSON marshaller has already been set, calling this method will + * result in an {@link IllegalStateException}. + * + * @param messages The message instances to register + */ + public UnframedGrpcErrorHandlerBuilder registerMarshalledMessages(Message... messages) { + requireNonNull(messages, "messages"); + checkState(jsonMarshaller == null, + "Cannot register custom messages because a custom JSON marshaller has already been set. " + + "Use the custom marshaller to register custom messages."); + + if (marshalledMessages == null) { + marshalledMessages = new ArrayList<>(); + } + Collections.addAll(marshalledMessages, messages); + return this; + } + + /** + * Registers custom messages to be marshalled by the error handler. + * + *

This method allows registering message instances for custom error responses. + * If a custom JSON marshaller has already been set, calling this method will + * result in an {@link IllegalStateException}. + * + * @param messages The collection of messages to register + */ + public UnframedGrpcErrorHandlerBuilder registerMarshalledMessages( + Iterable messages) { + requireNonNull(messages, "messages"); + checkState(jsonMarshaller == null, + "Cannot register the collection of messages because a custom JSON marshaller has " + + "already been set. Use the custom marshaller to register custom messages."); + + if (marshalledMessages == null) { + marshalledMessages = new ArrayList<>(); + } + messages.forEach(marshalledMessages::add); + return this; + } + + /** + * Registers custom message types to be marshalled by the error handler. + * + *

This method registers specific message types for custom error responses. + * If a custom JSON marshaller has already been set, calling this method will + * result in an {@link IllegalStateException}. + * + * @param messageTypes The message types to register + */ + @SafeVarargs + public final UnframedGrpcErrorHandlerBuilder registerMarshalledMessageTypes( + Class... messageTypes) { + requireNonNull(messageTypes, "messageTypes"); + checkState(jsonMarshaller == null, + "Cannot register custom messageTypes because a custom JSON marshaller has already been " + + "set. Use the custom marshaller to register custom message types."); + + if (marshalledMessageTypes == null) { + marshalledMessageTypes = new ArrayList<>(); + } + Collections.addAll(marshalledMessageTypes, messageTypes); + return this; + } + + /** + * Registers custom message types to be marshalled by the error handler. + * + *

This method allows registering message instances for custom error responses. + * If a custom JSON marshaller has already been set, calling this method will + * result in an {@link IllegalStateException}. + * + * @param messageTypes The collection of message types to register + */ + public UnframedGrpcErrorHandlerBuilder registerMarshalledMessageTypes( + Iterable> messageTypes) { + requireNonNull(messageTypes, "messageTypes"); + checkState(jsonMarshaller == null, + "Cannot register the collection of messageTypes because a custom JSON marshaller has " + + "already been set. Use the custom marshaller to register custom message types."); + + if (marshalledMessageTypes == null) { + marshalledMessageTypes = new ArrayList<>(); + } + messageTypes.forEach(marshalledMessageTypes::add); + return this; + } + + /** + * Returns a newly created {@link UnframedGrpcErrorHandler}. + * + *

This method constructs a new {@code UnframedGrpcErrorHandler} with the + * current configuration of this builder. + */ + public UnframedGrpcErrorHandler build() { + if (jsonMarshaller == null) { + jsonMarshaller = UnframedGrpcErrorHandlers.ERROR_DETAILS_MARSHALLER; + final MessageMarshaller.Builder builder = jsonMarshaller.toBuilder(); + + if (marshalledMessages != null) { + for (final Message message : marshalledMessages) { + builder.register(message); + } + } + + if (marshalledMessageTypes != null) { + for (final Class messageType : marshalledMessageTypes) { + builder.register(messageType); + } + } + + jsonMarshaller = builder.build(); + } + + if (responseTypes == null) { + return UnframedGrpcErrorHandlers.of(statusMappingFunction, jsonMarshaller); + } + + if (responseTypes.contains(UnframedGrpcErrorResponseType.JSON) && + responseTypes.contains(UnframedGrpcErrorResponseType.PLAINTEXT)) { + return UnframedGrpcErrorHandlers.of(statusMappingFunction, jsonMarshaller); + } + + if (responseTypes.contains(UnframedGrpcErrorResponseType.JSON)) { + return UnframedGrpcErrorHandlers.ofJson(statusMappingFunction, jsonMarshaller); + } + + if (responseTypes.contains(UnframedGrpcErrorResponseType.PLAINTEXT)) { + return UnframedGrpcErrorHandlers.ofPlaintext(statusMappingFunction); + } + + return UnframedGrpcErrorHandlers.of(statusMappingFunction, jsonMarshaller); + } +} diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandlers.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandlers.java index 32b6f3fe2191..d5872d820bc1 100644 --- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandlers.java +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandlers.java @@ -68,8 +68,7 @@ final class UnframedGrpcErrorHandlers { private static final Logger logger = LoggerFactory.getLogger(UnframedGrpcErrorHandlers.class); - // XXX(ikhoon): Support custom JSON marshaller? - private static final MessageMarshaller ERROR_DETAILS_MARSHALLER = + static final MessageMarshaller ERROR_DETAILS_MARSHALLER = MessageMarshaller.builder() .omittingInsignificantWhitespace(true) .register(RetryInfo.getDefaultInstance()) @@ -93,11 +92,16 @@ final class UnframedGrpcErrorHandlers { * to an {@link HttpStatus} code. */ static UnframedGrpcErrorHandler of(UnframedGrpcStatusMappingFunction statusMappingFunction) { + return of(statusMappingFunction, ERROR_DETAILS_MARSHALLER); + } + + static UnframedGrpcErrorHandler of( + UnframedGrpcStatusMappingFunction statusMappingFunction, MessageMarshaller jsonMarshaller) { final UnframedGrpcStatusMappingFunction mappingFunction = withDefault(statusMappingFunction); return (ctx, status, response) -> { final MediaType grpcMediaType = response.contentType(); if (grpcMediaType != null && grpcMediaType.isJson()) { - return ofJson(mappingFunction).handle(ctx, status, response); + return ofJson(mappingFunction, jsonMarshaller).handle(ctx, status, response); } else { return ofPlaintext(mappingFunction).handle(ctx, status, response); } @@ -110,7 +114,8 @@ static UnframedGrpcErrorHandler of(UnframedGrpcStatusMappingFunction statusMappi * @param statusMappingFunction The function which maps the {@link Throwable} or gRPC {@link Status} code * to an {@link HttpStatus} code. */ - static UnframedGrpcErrorHandler ofJson(UnframedGrpcStatusMappingFunction statusMappingFunction) { + static UnframedGrpcErrorHandler ofJson( + UnframedGrpcStatusMappingFunction statusMappingFunction, MessageMarshaller jsonMarshaller) { final UnframedGrpcStatusMappingFunction mappingFunction = withDefault(statusMappingFunction); return (ctx, status, response) -> { final ByteBuf buffer = ctx.alloc().buffer(); @@ -148,7 +153,7 @@ static UnframedGrpcErrorHandler ofJson(UnframedGrpcStatusMappingFunction statusM } if (rpcStatus != null) { jsonGenerator.writeFieldName("details"); - writeErrorDetails(rpcStatus.getDetailsList(), jsonGenerator); + writeErrorDetails(rpcStatus.getDetailsList(), jsonGenerator, jsonMarshaller); } } jsonGenerator.writeEndObject(); @@ -169,6 +174,10 @@ static UnframedGrpcErrorHandler ofJson(UnframedGrpcStatusMappingFunction statusM }; } + static UnframedGrpcErrorHandler ofJson(UnframedGrpcStatusMappingFunction statusMappingFunction) { + return ofJson(statusMappingFunction, ERROR_DETAILS_MARSHALLER); + } + /** * Returns a plaintext response. * @@ -217,11 +226,12 @@ private static UnframedGrpcStatusMappingFunction withDefault( } @VisibleForTesting - static void writeErrorDetails(List details, JsonGenerator jsonGenerator) throws IOException { + static void writeErrorDetails(List details, JsonGenerator jsonGenerator, + MessageMarshaller jsonMarshaller) throws IOException { jsonGenerator.writeStartArray(); for (Any detail : details) { try { - ERROR_DETAILS_MARSHALLER.writeValue(detail, jsonGenerator); + jsonMarshaller.writeValue(detail, jsonGenerator); } catch (IOException e) { logger.warn("Unexpected exception while writing an error detail to JSON. detail: {}", detail, e); diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorResponseType.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorResponseType.java new file mode 100644 index 000000000000..f35c741420f2 --- /dev/null +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorResponseType.java @@ -0,0 +1,38 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.grpc; + +import com.linecorp.armeria.common.annotation.UnstableApi; + +/** + * The types of responses that can be sent when handling errors in an unframed gRPC service. + * + *

When multiple {@code UnframedGrpcErrorResponseType} values are selected, the actual response type + * is determined by the response's {@code contentType}. + */ +@UnstableApi +public enum UnframedGrpcErrorResponseType { + /** + * The error response will be formatted as a JSON object. + */ + JSON, + + /** + * The error response will be sent as plain text. + */ + PLAINTEXT, +} diff --git a/grpc/src/test/java/com/linecorp/armeria/server/grpc/ErrorDetailsMarshallerTest.java b/grpc/src/test/java/com/linecorp/armeria/server/grpc/ErrorDetailsMarshallerTest.java index 10617c71f541..ae6946191d43 100644 --- a/grpc/src/test/java/com/linecorp/armeria/server/grpc/ErrorDetailsMarshallerTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/server/grpc/ErrorDetailsMarshallerTest.java @@ -16,12 +16,14 @@ package com.linecorp.armeria.server.grpc; +import static com.linecorp.armeria.server.grpc.UnframedGrpcErrorHandlers.ERROR_DETAILS_MARSHALLER; import static net.javacrumbs.jsonunit.fluent.JsonFluentAssert.assertThatJson; import static org.assertj.core.api.Assertions.assertThatThrownBy; import java.io.IOException; import java.io.StringWriter; +import org.curioswitch.common.protobuf.json.MessageMarshaller; import org.junit.jupiter.api.Test; import com.fasterxml.jackson.core.JsonGenerator; @@ -45,6 +47,7 @@ import com.google.rpc.RetryInfo; import com.google.rpc.Status; +import com.linecorp.armeria.grpc.testing.Error.AuthError; import com.linecorp.armeria.internal.common.JacksonUtil; class ErrorDetailsMarshallerTest { @@ -119,7 +122,8 @@ void convertErrorDetailToJsonNodeTest() throws IOException { final StringWriter jsonObjectWriter = new StringWriter(); final JsonGenerator jsonGenerator = mapper.createGenerator(jsonObjectWriter); - UnframedGrpcErrorHandlers.writeErrorDetails(status.getDetailsList(), jsonGenerator); + UnframedGrpcErrorHandlers.writeErrorDetails( + status.getDetailsList(), jsonGenerator, ERROR_DETAILS_MARSHALLER); jsonGenerator.flush(); final String expectedJsonString = "[\n" + @@ -191,6 +195,36 @@ void convertErrorDetailToJsonNodeTest() throws IOException { assertThatJson(mapper.readTree(jsonObjectWriter.toString())).isEqualTo(expectedJsonString); } + @Test + void convertCustomErrorDetailToJsonNodeTest() throws IOException { + final AuthError authError = AuthError.newBuilder() + .setCode(401) + .setMessage("Auth error.") + .build(); + final Status status = Status.newBuilder() + .setCode(Code.UNKNOWN.getNumber()) + .setMessage("Unknown Exceptions Test") + .addDetails(Any.pack(authError)) + .build(); + final StringWriter jsonObjectWriter = new StringWriter(); + final JsonGenerator jsonGenerator = mapper.createGenerator(jsonObjectWriter); + final MessageMarshaller jsonMarshaller = ERROR_DETAILS_MARSHALLER.toBuilder() + .register(authError) + .build(); + UnframedGrpcErrorHandlers.writeErrorDetails( + status.getDetailsList(), jsonGenerator, jsonMarshaller); + jsonGenerator.flush(); + final String expectedJsonString = + "[\n" + + " {\n" + + " \"@type\":\"type.googleapis.com/armeria.grpc.testing.AuthError\",\n" + + " \"code\": 401," + + " \"message\": \"Auth error.\"" + + " }\n" + + ']'; + assertThatJson(mapper.readTree(jsonObjectWriter.toString())).isEqualTo(expectedJsonString); + } + @Test void shouldThrowIOException() throws IOException { final Empty empty = Empty.getDefaultInstance(); @@ -199,6 +233,7 @@ void shouldThrowIOException() throws IOException { final JsonGenerator jsonGenerator = mapper.createGenerator(jsonObjectWriter); assertThatThrownBy(() -> UnframedGrpcErrorHandlers.writeErrorDetails( - status.getDetailsList(), jsonGenerator)).isInstanceOf(IOException.class); + status.getDetailsList(), jsonGenerator, ERROR_DETAILS_MARSHALLER)).isInstanceOf( + IOException.class); } } diff --git a/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandlerBuilderTest.java b/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandlerBuilderTest.java new file mode 100644 index 000000000000..51d10e427c63 --- /dev/null +++ b/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandlerBuilderTest.java @@ -0,0 +1,196 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.grpc; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import org.curioswitch.common.protobuf.json.MessageMarshaller; +import org.junit.jupiter.api.Test; + +import com.google.common.collect.ImmutableList; + +import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.grpc.testing.Error.AuthError; +import com.linecorp.armeria.grpc.testing.Error.InternalError; +import com.linecorp.armeria.server.ServiceRequestContext; + +import io.grpc.Status; + +public class UnframedGrpcErrorHandlerBuilderTest { + @Test + void cannotCallRegisterMarshalledMessagesAndJsonMarshallerSimultaneously() { + assertThatThrownBy( + () -> UnframedGrpcErrorHandler.builder() + .jsonMarshaller( + UnframedGrpcErrorHandlers.ERROR_DETAILS_MARSHALLER) + .registerMarshalledMessageTypes(InternalError.class)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining( + "Cannot register custom messageTypes because a custom JSON marshaller has " + + "already been set. Use the custom marshaller to register custom message types."); + + assertThatThrownBy( + () -> UnframedGrpcErrorHandler.builder() + .jsonMarshaller( + UnframedGrpcErrorHandlers.ERROR_DETAILS_MARSHALLER) + .registerMarshalledMessages( + InternalError.newBuilder().build())) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining( + "Cannot register custom messages because a custom JSON marshaller has " + + "already been set. Use the custom marshaller to register custom messages."); + + assertThatThrownBy( + () -> UnframedGrpcErrorHandler.builder() + .jsonMarshaller( + UnframedGrpcErrorHandlers.ERROR_DETAILS_MARSHALLER) + .registerMarshalledMessages( + ImmutableList.of(InternalError.newBuilder().build()))) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining( + "Cannot register the collection of messages because a custom JSON marshaller has " + + "already been set. Use the custom marshaller to register custom messages."); + + assertThatThrownBy( + () -> UnframedGrpcErrorHandler.builder() + .jsonMarshaller( + UnframedGrpcErrorHandlers.ERROR_DETAILS_MARSHALLER) + .registerMarshalledMessageTypes( + ImmutableList.of(InternalError.class))) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining( + "Cannot register the collection of messageTypes because a custom JSON marshaller has " + + "already been set. Use the custom marshaller to register custom message types."); + + assertThatThrownBy( + () -> UnframedGrpcErrorHandler.builder() + .registerMarshalledMessages(InternalError.newBuilder().build()) + .jsonMarshaller( + UnframedGrpcErrorHandlers.ERROR_DETAILS_MARSHALLER)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining( + "Cannot set a custom JSON marshaller because one or more Message instances or " + + "Message types have already been registered. To set a custom marshaller, " + + "ensure that no Message or Message type registrations have been made before " + + "calling this method."); + } + + @Test + void buildWithoutOptions() { + final UnframedGrpcErrorHandler unframedGrpcErrorHandler = UnframedGrpcErrorHandler.builder().build(); + final ServiceRequestContext ctx = + ServiceRequestContext.of(HttpRequest.of(RequestHeaders.of(HttpMethod.GET, "/test"))); + final AggregatedHttpResponse jsonResponse = + AggregatedHttpResponse.of(HttpStatus.INTERNAL_SERVER_ERROR, + MediaType.JSON_UTF_8, + "{\"message\":\"Internal Server Error\"}"); + final HttpResponse httpResponseJson = + unframedGrpcErrorHandler.handle(ctx, Status.INTERNAL, jsonResponse); + final AggregatedHttpResponse aggregatedHttpResponse = httpResponseJson.aggregate().join(); + assertThat(aggregatedHttpResponse.headers().contentType().isJson()).isTrue(); + + final AggregatedHttpResponse plaintextResponse = + AggregatedHttpResponse.of(HttpStatus.INTERNAL_SERVER_ERROR, + MediaType.PLAIN_TEXT_UTF_8, + "Internal Server Error"); + final HttpResponse httpResponsePlaintext = + unframedGrpcErrorHandler.handle(ctx, Status.INTERNAL, plaintextResponse); + assertThat(httpResponsePlaintext.aggregate().join().headers() + .contentType().is(MediaType.PLAIN_TEXT)).isTrue(); + } + + @Test + void buildWithResponseType() { + final UnframedGrpcErrorHandler unframedGrpcErrorHandlerJson = + UnframedGrpcErrorHandler.builder() + .responseTypes(UnframedGrpcErrorResponseType.JSON) + .build(); + final ServiceRequestContext ctx = ServiceRequestContext.of( + HttpRequest.of(RequestHeaders.of(HttpMethod.GET, "/test"))); + final AggregatedHttpResponse response = AggregatedHttpResponse.of(HttpStatus.INTERNAL_SERVER_ERROR); + final HttpResponse httpResponseJson = + unframedGrpcErrorHandlerJson.handle(ctx, Status.INTERNAL, response); + assertThat(httpResponseJson.aggregate().join().headers().contentType() + .isJson()).isTrue(); + + final UnframedGrpcErrorHandler unframedGrpcErrorHandlerPlaintext = + UnframedGrpcErrorHandler.builder() + .responseTypes(UnframedGrpcErrorResponseType.PLAINTEXT) + .build(); + final HttpResponse httpResponsePlaintext = + unframedGrpcErrorHandlerPlaintext.handle(ctx, Status.INTERNAL, response); + assertThat(httpResponsePlaintext.aggregate().join().headers().contentType() + .is(MediaType.PLAIN_TEXT)).isTrue(); + + final UnframedGrpcErrorHandler unframedGrpcErrorHandler = + UnframedGrpcErrorHandler.builder() + .responseTypes( + UnframedGrpcErrorResponseType.JSON, + UnframedGrpcErrorResponseType.PLAINTEXT) + .build(); + final AggregatedHttpResponse jsonResponse = + AggregatedHttpResponse.of(HttpStatus.INTERNAL_SERVER_ERROR, + MediaType.JSON_UTF_8, + "{\"message\":\"Internal Server Error\"}"); + + final HttpResponse httpResponse = unframedGrpcErrorHandler.handle(ctx, Status.INTERNAL, jsonResponse); + assertThat(httpResponse.aggregate().join().headers().contentType() + .isJson()).isTrue(); + } + + @Test + void buildWithCustomJsonMarshaller() { + final MessageMarshaller messageMarshaller = MessageMarshaller.builder().build(); + assertDoesNotThrow(() -> UnframedGrpcErrorHandler.builder() + .jsonMarshaller(messageMarshaller) + .build()); + } + + @Test + void buildWithCustomMessage() { + assertDoesNotThrow(() -> UnframedGrpcErrorHandler.builder() + .registerMarshalledMessageTypes( + InternalError.class, + AuthError.class) + .build()); + assertDoesNotThrow(() -> UnframedGrpcErrorHandler.builder() + .registerMarshalledMessages( + InternalError.newBuilder().build(), + AuthError.newBuilder().build()) + .build()); + assertDoesNotThrow(() -> UnframedGrpcErrorHandler.builder() + .registerMarshalledMessageTypes( + ImmutableList.of(InternalError.class, + AuthError.class)) + .build()); + assertDoesNotThrow(() -> UnframedGrpcErrorHandler.builder() + .registerMarshalledMessages( + ImmutableList.of(InternalError.newBuilder() + .build(), + AuthError.newBuilder() + .build())) + .build()); + } +} diff --git a/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandlerTest.java b/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandlerTest.java index 33c8bda10c00..815511988c0e 100644 --- a/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandlerTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcErrorHandlerTest.java @@ -19,6 +19,7 @@ import static net.javacrumbs.jsonunit.fluent.JsonFluentAssert.assertThatJson; import static org.assertj.core.api.Assertions.assertThat; +import org.curioswitch.common.protobuf.json.MessageMarshaller; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; @@ -32,6 +33,8 @@ import com.linecorp.armeria.common.AggregatedHttpResponse; import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.grpc.testing.Error.AuthError; +import com.linecorp.armeria.grpc.testing.Error.InternalError; import com.linecorp.armeria.internal.common.JacksonUtil; import com.linecorp.armeria.server.ServerBuilder; import com.linecorp.armeria.testing.junit5.server.ServerExtension; @@ -53,7 +56,7 @@ protected void configure(ServerBuilder sb) { }; @RegisterExtension - static ServerExtension verbosePlainTextResServer = new ServerExtension() { + static ServerExtension verbosePlaintextResServer = new ServerExtension() { @Override protected void configure(ServerBuilder sb) { configureServer(sb, true, UnframedGrpcErrorHandler.ofPlainText(), testService); @@ -76,6 +79,61 @@ protected void configure(ServerBuilder sb) { } }; + @RegisterExtension + static ServerExtension plaintextResServerWithBuilder = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + configureServer(sb, false, + UnframedGrpcErrorHandler.builder() + .build(), + testService); + } + }; + + @RegisterExtension + static ServerExtension jsonResServerWithMarshalledMessage = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + configureServer(sb, false, + UnframedGrpcErrorHandler.builder() + .registerMarshalledMessageTypes(InternalError.class) + .registerMarshalledMessages(AuthError.newBuilder().build()) + .responseTypes(UnframedGrpcErrorResponseType.JSON) + .build(), + testServiceWithCustomMessage); + } + }; + + @RegisterExtension + static ServerExtension jsonResServerWithCustomJsonMarshaller = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + final MessageMarshaller jsonMarshaller = MessageMarshaller.builder() + .register(InternalError.class) + .register(AuthError.class) + .build(); + configureServer(sb, false, + UnframedGrpcErrorHandler.builder() + .jsonMarshaller(jsonMarshaller) + .responseTypes(UnframedGrpcErrorResponseType.JSON) + .build(), + testServiceWithCustomMessage); + } + }; + + @RegisterExtension + static ServerExtension plaintextResServerWithCustomStatusMapping = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + final UnframedGrpcStatusMappingFunction mappingFunction = (ctx, status, response) -> HttpStatus.OK; + configureServer(sb, false, + UnframedGrpcErrorHandler.builder() + .statusMappingFunction(mappingFunction) + .build(), + testService); + } + }; + private static class TestService extends TestServiceImplBase { @Override @@ -94,11 +152,35 @@ public void emptyCall(Empty request, StreamObserver responseObserver) { final com.google.rpc.Status status = com.google.rpc.Status.newBuilder() - .setCode( - Code.UNKNOWN.getNumber()) + .setCode(Code.UNKNOWN.getNumber()) .setMessage("Unknown Exceptions Test") - .addDetails( - Any.pack(errorInfo)) + .addDetails(Any.pack(errorInfo)) + .build(); + + responseObserver.onError(StatusProto.toStatusRuntimeException(status)); + } + } + + private static class TestServiceWithCustomMessage extends TestServiceImplBase { + + @Override + public void emptyCall(Empty request, StreamObserver responseObserver) { + final InternalError internalError = InternalError.newBuilder() + .setCode(500) + .setMessage("Internal server error.") + .build(); + + final AuthError authError = AuthError.newBuilder() + .setCode(500) + .setMessage("Auth server error.") + .build(); + + final com.google.rpc.Status + status = com.google.rpc.Status.newBuilder() + .setCode(Code.INTERNAL.getNumber()) + .setMessage("Custom error message test.") + .addDetails(Any.pack(internalError)) + .addDetails(Any.pack(authError)) .build(); responseObserver.onError(StatusProto.toStatusRuntimeException(status)); @@ -111,6 +193,9 @@ public void emptyCall(Empty request, StreamObserver responseObserver) { private static final TestServiceGrpcStatus testServiceGrpcStatus = new TestServiceGrpcStatus(); + private static final TestServiceWithCustomMessage testServiceWithCustomMessage = + new TestServiceWithCustomMessage(); + private static void configureServer(ServerBuilder sb, boolean verboseResponses, UnframedGrpcErrorHandler errorHandler, TestServiceImplBase testServiceImplBase) { @@ -137,8 +222,8 @@ void withoutStackTrace() { } @Test - void plainTextWithStackTrace() { - final BlockingWebClient client = verbosePlainTextResServer.webClient().blocking(); + void plaintextWithStackTrace() { + final BlockingWebClient client = verbosePlaintextResServer.webClient().blocking(); final AggregatedHttpResponse response = client.prepare() .post(TestServiceGrpc.getEmptyCallMethod().getFullMethodName()) @@ -192,4 +277,93 @@ void richJson() throws JsonProcessingException { '}'); assertThat(response.trailers()).isEmpty(); } + + @Test + void plainTestUsingBuilder() { + final BlockingWebClient client = plaintextResServerWithBuilder.webClient().blocking(); + final AggregatedHttpResponse response = + client.prepare() + .post(TestServiceGrpc.getEmptyCallMethod().getFullMethodName()) + .content(MediaType.PROTOBUF, Empty.getDefaultInstance().toByteArray()) + .execute(); + assertThat(response.status()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR); + final String content = response.contentUtf8(); + assertThat(content).isEqualTo("grpc-code: UNKNOWN, grpc error message"); + assertThat(response.trailers()).isEmpty(); + } + + @Test + void jsonWithCustomMessage() throws JsonProcessingException { + final BlockingWebClient client = jsonResServerWithMarshalledMessage.webClient().blocking(); + final AggregatedHttpResponse response = + client.prepare() + .post(TestServiceGrpc.getEmptyCallMethod().getFullMethodName()) + .content(MediaType.PROTOBUF, Empty.getDefaultInstance().toByteArray()) + .execute(); + assertThat(response.status()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR); + assertThatJson(mapper.readTree(response.contentUtf8())) + .isEqualTo( + '{' + + " \"code\": 13," + + " \"grpc-code\": \"INTERNAL\"," + + " \"message\": \"Custom error message test.\"," + + " \"details\": [" + + " {" + + " \"@type\": \"type.googleapis.com/armeria.grpc.testing.InternalError\"," + + " \"code\": 500," + + " \"message\": \"Internal server error.\"" + + " }," + + " {" + + " \"@type\": \"type.googleapis.com/armeria.grpc.testing.AuthError\"," + + " \"code\": 500," + + " \"message\": \"Auth server error.\"" + + " }" + + " ]" + + '}'); + assertThat(response.trailers()).isEmpty(); + } + + @Test + void jsonWithCustomJsonMarshaller() throws JsonProcessingException { + final BlockingWebClient client = jsonResServerWithCustomJsonMarshaller.webClient().blocking(); + final AggregatedHttpResponse response = + client.prepare() + .post(TestServiceGrpc.getEmptyCallMethod().getFullMethodName()) + .content(MediaType.PROTOBUF, Empty.getDefaultInstance().toByteArray()) + .execute(); + assertThat(response.status()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR); + assertThatJson(mapper.readTree(response.contentUtf8())) + .isEqualTo( + '{' + + " \"code\": 13," + + " \"grpc-code\": \"INTERNAL\"," + + " \"message\": \"Custom error message test.\"," + + " \"details\": [" + + " {" + + " \"@type\": \"type.googleapis.com/armeria.grpc.testing.InternalError\"," + + " \"code\": 500," + + " \"message\": \"Internal server error.\"" + + " }," + + " {" + + " \"@type\": \"type.googleapis.com/armeria.grpc.testing.AuthError\"," + + " \"code\": 500," + + " \"message\": \"Auth server error.\"" + + " }" + + " ]" + + '}'); + assertThat(response.trailers()).isEmpty(); + } + + @Test + void plaintextWithCustomStatusMapping() { + final BlockingWebClient client = plaintextResServerWithCustomStatusMapping.webClient().blocking(); + final AggregatedHttpResponse response = + client.prepare() + .post(TestServiceGrpc.getEmptyCallMethod().getFullMethodName()) + .content(MediaType.PROTOBUF, Empty.getDefaultInstance().toByteArray()) + .execute(); + assertThat(response.status()).isEqualTo(HttpStatus.OK); + assertThat(response.contentUtf8()).isEqualTo("grpc-code: UNKNOWN, grpc error message"); + assertThat(response.trailers()).isEmpty(); + } }