diff --git a/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientBuilder.java b/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientBuilder.java index 5267169a37a..039761f0511 100644 --- a/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientBuilder.java +++ b/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientBuilder.java @@ -75,6 +75,7 @@ import com.linecorp.armeria.common.grpc.GrpcSerializationFormats; import com.linecorp.armeria.common.grpc.protocol.ArmeriaMessageDeframer; import com.linecorp.armeria.common.grpc.protocol.ArmeriaMessageFramer; +import com.linecorp.armeria.internal.common.grpc.UnwrappingGrpcExceptionHandleFunction; import com.linecorp.armeria.unsafe.grpc.GrpcUnsafeBufferUtil; import io.grpc.CallCredentials; @@ -418,7 +419,8 @@ public T build(Class clientType) { option(INTERCEPTORS.newValue(clientInterceptors)); } if (exceptionHandler != null) { - option(EXCEPTION_HANDLER.newValue(exceptionHandler)); + option(EXCEPTION_HANDLER.newValue(new UnwrappingGrpcExceptionHandleFunction(exceptionHandler.orElse( + GrpcExceptionHandlerFunction.of())))); } final Object client; diff --git a/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientOptions.java b/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientOptions.java index 8d7a632988d..09df88b849c 100644 --- a/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientOptions.java +++ b/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientOptions.java @@ -37,7 +37,7 @@ import com.linecorp.armeria.common.grpc.protocol.ArmeriaMessageFramer; import com.linecorp.armeria.internal.client.grpc.NullCallCredentials; import com.linecorp.armeria.internal.client.grpc.NullGrpcClientStubFactory; -import com.linecorp.armeria.internal.common.grpc.GrpcStatus; +import com.linecorp.armeria.internal.common.grpc.UnwrappingGrpcExceptionHandleFunction; import com.linecorp.armeria.unsafe.grpc.GrpcUnsafeBufferUtil; import io.grpc.CallCredentials; @@ -174,8 +174,8 @@ public final class GrpcClientOptions { * to a gRPC {@link Status}. */ public static final ClientOption EXCEPTION_HANDLER = - ClientOption.define("EXCEPTION_HANDLER", - (ctx, cause, metadata) -> GrpcStatus.fromThrowable(cause)); + ClientOption.define("EXCEPTION_HANDLER", new UnwrappingGrpcExceptionHandleFunction( + GrpcExceptionHandlerFunction.of())); /** * Sets whether to respect the marshaller specified in gRPC {@link MethodDescriptor}. diff --git a/grpc/src/main/java/com/linecorp/armeria/common/grpc/DefaultGrpcExceptionHandlerFunction.java b/grpc/src/main/java/com/linecorp/armeria/common/grpc/DefaultGrpcExceptionHandlerFunction.java new file mode 100644 index 00000000000..275a0e023bd --- /dev/null +++ b/grpc/src/main/java/com/linecorp/armeria/common/grpc/DefaultGrpcExceptionHandlerFunction.java @@ -0,0 +1,85 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common.grpc; + +import java.io.IOException; +import java.nio.channels.ClosedChannelException; + +import com.google.protobuf.InvalidProtocolBufferException; + +import com.linecorp.armeria.client.UnprocessedRequestException; +import com.linecorp.armeria.client.circuitbreaker.FailFastException; +import com.linecorp.armeria.common.ClosedSessionException; +import com.linecorp.armeria.common.ContentTooLargeException; +import com.linecorp.armeria.common.RequestContext; +import com.linecorp.armeria.common.TimeoutException; +import com.linecorp.armeria.common.stream.ClosedStreamException; +import com.linecorp.armeria.server.RequestTimeoutException; + +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.Status.Code; +import io.netty.handler.codec.http2.Http2Error; +import io.netty.handler.codec.http2.Http2Exception; + +enum DefaultGrpcExceptionHandlerFunction implements GrpcExceptionHandlerFunction { + INSTANCE; + + /** + * Converts the {@link Throwable} to a {@link Status}, taking into account exceptions specific to Armeria as + * well and the protocol package. + */ + @Override + public Status apply(RequestContext ctx, Throwable cause, Metadata metadata) { + final Status s = Status.fromThrowable(cause); + if (s.getCode() != Code.UNKNOWN) { + return s; + } + + if (cause instanceof ClosedSessionException || cause instanceof ClosedChannelException) { + // ClosedChannelException is used any time the Netty channel is closed. Proper error + // processing requires remembering the error that occurred before this one and using it + // instead. + return s; + } + if (cause instanceof ClosedStreamException || cause instanceof RequestTimeoutException) { + return Status.CANCELLED.withCause(cause); + } + if (cause instanceof InvalidProtocolBufferException) { + return Status.INVALID_ARGUMENT.withCause(cause); + } + if (cause instanceof UnprocessedRequestException || + cause instanceof IOException || + cause instanceof FailFastException) { + return Status.UNAVAILABLE.withCause(cause); + } + if (cause instanceof Http2Exception) { + if (cause instanceof Http2Exception.StreamException && + ((Http2Exception.StreamException) cause).error() == Http2Error.CANCEL) { + return Status.CANCELLED; + } + return Status.INTERNAL.withCause(cause); + } + if (cause instanceof TimeoutException) { + return Status.DEADLINE_EXCEEDED.withCause(cause); + } + if (cause instanceof ContentTooLargeException) { + return Status.RESOURCE_EXHAUSTED.withCause(cause); + } + return s; + } +} diff --git a/grpc/src/main/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunction.java b/grpc/src/main/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunction.java index 8bf029af4c5..c7853f05844 100644 --- a/grpc/src/main/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunction.java +++ b/grpc/src/main/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunction.java @@ -39,6 +39,14 @@ static GrpcExceptionHandlerFunctionBuilder builder() { return new GrpcExceptionHandlerFunctionBuilder(); } + /** + * Returns the default {@link GrpcExceptionHandlerFunction}. + */ + @UnstableApi + static GrpcExceptionHandlerFunction of() { + return DefaultGrpcExceptionHandlerFunction.INSTANCE; + } + /** * Maps the specified {@link Throwable} to a gRPC {@link Status}, * and mutates the specified {@link Metadata}. diff --git a/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaClientCall.java b/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaClientCall.java index ee94a142869..87aa120d312 100644 --- a/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaClientCall.java +++ b/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaClientCall.java @@ -248,10 +248,9 @@ public void start(Listener responseListener, Metadata metadata) { prepareHeaders(compressor, metadata, remainingNanos); final BiFunction errorResponseFactory = - (unused, cause) -> HttpResponse.ofFailure( - GrpcStatus.fromThrowable(exceptionHandler, ctx, cause, metadata) - .withDescription(cause.getMessage()) - .asRuntimeException()); + (unused, cause) -> HttpResponse.ofFailure(exceptionHandler.apply(ctx, cause, metadata) + .withDescription(cause.getMessage()) + .asRuntimeException()); final HttpResponse res = initContextAndExecuteWithFallback( httpClient, ctx, endpointGroup, HttpResponse::of, errorResponseFactory); @@ -455,7 +454,7 @@ public void onNext(DeframedMessage message) { }); } catch (Throwable t) { final Metadata metadata = new Metadata(); - close(GrpcStatus.fromThrowable(exceptionHandler, ctx, t, metadata), metadata); + close(exceptionHandler.apply(ctx, t, metadata), metadata); } } @@ -512,7 +511,7 @@ private void prepareHeaders(Compressor compressor, Metadata metadata, long remai private void closeWhenListenerThrows(Throwable t) { final Metadata metadata = new Metadata(); - closeWhenEos(GrpcStatus.fromThrowable(exceptionHandler, ctx, t, metadata), metadata); + closeWhenEos(exceptionHandler.apply(ctx, t, metadata), metadata); } private void closeWhenEos(Status status, Metadata metadata) { diff --git a/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/GrpcStatus.java b/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/GrpcStatus.java index b1615c33597..d0f55873083 100644 --- a/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/GrpcStatus.java +++ b/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/GrpcStatus.java @@ -31,11 +31,7 @@ package com.linecorp.armeria.internal.common.grpc; -import static java.util.Objects.requireNonNull; - -import java.io.IOException; import java.net.HttpURLConnection; -import java.nio.channels.ClosedChannelException; import java.util.Base64; import org.slf4j.Logger; @@ -44,34 +40,19 @@ import com.google.common.base.Strings; import com.google.protobuf.InvalidProtocolBufferException; -import com.linecorp.armeria.client.UnprocessedRequestException; -import com.linecorp.armeria.client.circuitbreaker.FailFastException; -import com.linecorp.armeria.common.ClosedSessionException; -import com.linecorp.armeria.common.ContentTooLargeException; import com.linecorp.armeria.common.HttpHeaders; import com.linecorp.armeria.common.HttpStatus; -import com.linecorp.armeria.common.RequestContext; -import com.linecorp.armeria.common.TimeoutException; -import com.linecorp.armeria.common.annotation.Nullable; -import com.linecorp.armeria.common.grpc.GrpcExceptionHandlerFunction; -import com.linecorp.armeria.common.grpc.GrpcStatusFunction; import com.linecorp.armeria.common.grpc.StackTraceElementProto; import com.linecorp.armeria.common.grpc.StatusCauseException; import com.linecorp.armeria.common.grpc.ThrowableProto; -import com.linecorp.armeria.common.grpc.protocol.ArmeriaStatusException; import com.linecorp.armeria.common.grpc.protocol.DeframedMessage; import com.linecorp.armeria.common.grpc.protocol.GrpcHeaderNames; import com.linecorp.armeria.common.grpc.protocol.StatusMessageEscaper; -import com.linecorp.armeria.common.stream.ClosedStreamException; import com.linecorp.armeria.common.stream.StreamMessage; -import com.linecorp.armeria.common.util.Exceptions; -import com.linecorp.armeria.server.RequestTimeoutException; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.Status.Code; -import io.netty.handler.codec.http2.Http2Error; -import io.netty.handler.codec.http2.Http2Exception; /** * Utilities for handling {@link Status} in Armeria. @@ -80,134 +61,6 @@ public final class GrpcStatus { private static final Logger logger = LoggerFactory.getLogger(GrpcStatus.class); - /** - * Converts the {@link Throwable} to a {@link Status}, taking into account exceptions specific to Armeria as - * well and the protocol package. - */ - public static Status fromThrowable(Throwable t) { - t = peelAndUnwrap(requireNonNull(t, "t")); - return statusFromThrowable(t); - } - - /** - * Converts the {@link Throwable} to a {@link Status}. - * If the specified {@code statusFunction} returns {@code null}, - * the built-in exception mapping rule, which takes into account exceptions specific to Armeria as well - * and the protocol package, is used by default. - */ - public static Status fromThrowable(@Nullable GrpcStatusFunction statusFunction, RequestContext ctx, - Throwable t, Metadata metadata) { - final GrpcExceptionHandlerFunction exceptionHandler = - statusFunction != null ? statusFunction::apply : null; - return fromThrowable(exceptionHandler, ctx, t, metadata); - } - - /** - * Converts the {@link Throwable} to a {@link Status}. - * If the specified {@link GrpcExceptionHandlerFunction} returns {@code null}, - * the built-in exception mapping rule, which takes into account exceptions specific to Armeria as well - * and the protocol package, is used by default. - */ - public static Status fromThrowable(@Nullable GrpcExceptionHandlerFunction exceptionHandler, - RequestContext ctx, Throwable t, Metadata metadata) { - t = peelAndUnwrap(requireNonNull(t, "t")); - - if (exceptionHandler != null) { - final Status status = exceptionHandler.apply(ctx, t, metadata); - if (status != null) { - return status; - } - } - - return statusFromThrowable(t); - } - - private static Status statusFromThrowable(Throwable t) { - final Status s = Status.fromThrowable(t); - if (s.getCode() != Code.UNKNOWN) { - return s; - } - - if (t instanceof ClosedSessionException || t instanceof ClosedChannelException) { - // ClosedChannelException is used any time the Netty channel is closed. Proper error - // processing requires remembering the error that occurred before this one and using it - // instead. - return s; - } - if (t instanceof ClosedStreamException || t instanceof RequestTimeoutException) { - return Status.CANCELLED.withCause(t); - } - if (t instanceof InvalidProtocolBufferException) { - return Status.INVALID_ARGUMENT.withCause(t); - } - if (t instanceof UnprocessedRequestException || - t instanceof IOException || - t instanceof FailFastException) { - return Status.UNAVAILABLE.withCause(t); - } - if (t instanceof Http2Exception) { - if (t instanceof Http2Exception.StreamException && - ((Http2Exception.StreamException) t).error() == Http2Error.CANCEL) { - return Status.CANCELLED; - } - return Status.INTERNAL.withCause(t); - } - if (t instanceof TimeoutException) { - return Status.DEADLINE_EXCEEDED.withCause(t); - } - if (t instanceof ContentTooLargeException) { - return Status.RESOURCE_EXHAUSTED.withCause(t); - } - return s; - } - - /** - * Converts the specified {@link Status} to a new user-specified {@link Status} - * using the specified {@link GrpcStatusFunction}. - * Returns the given {@link Status} as is if the {@link GrpcStatusFunction} returns {@code null}. - */ - public static Status fromStatusFunction(@Nullable GrpcStatusFunction statusFunction, - RequestContext ctx, Status status, Metadata metadata) { - final GrpcExceptionHandlerFunction exceptionHandler = - statusFunction != null ? statusFunction::apply : null; - return fromExceptionHandler(exceptionHandler, ctx, status, metadata); - } - - /** - * Converts the specified {@link Status} to a new user-specified {@link Status} - * using the specified {@link GrpcExceptionHandlerFunction}. - * Returns the given {@link Status} as is if the {@link GrpcExceptionHandlerFunction} returns {@code null}. - */ - public static Status fromExceptionHandler(@Nullable GrpcExceptionHandlerFunction exceptionHandler, - RequestContext ctx, Status status, Metadata metadata) { - requireNonNull(status, "status"); - - if (exceptionHandler != null) { - final Throwable cause = status.getCause(); - if (cause != null) { - final Throwable unwrapped = peelAndUnwrap(cause); - final Status newStatus = exceptionHandler.apply(ctx, unwrapped, metadata); - if (newStatus != null) { - return newStatus; - } - } - } - return status; - } - - private static Throwable peelAndUnwrap(Throwable t) { - t = Exceptions.peel(t); - Throwable cause = t; - while (cause != null) { - if (cause instanceof ArmeriaStatusException) { - t = StatusExceptionConverter.toGrpc((ArmeriaStatusException) cause); - break; - } - cause = cause.getCause(); - } - return t; - } - /** * Maps gRPC {@link Status} to {@link HttpStatus}. If there is no matched rule for the specified * {@link Status}, the mapping rules defined in upstream Google APIs diff --git a/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/HttpStreamDeframer.java b/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/HttpStreamDeframer.java index 4896eca16c8..f8c28fb6c5a 100644 --- a/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/HttpStreamDeframer.java +++ b/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/HttpStreamDeframer.java @@ -44,7 +44,6 @@ public final class HttpStreamDeframer extends ArmeriaMessageDeframer { private final RequestContext ctx; private final DecompressorRegistry decompressorRegistry; private final TransportStatusListener transportStatusListener; - @Nullable private final GrpcExceptionHandlerFunction exceptionHandler; @Nullable @@ -56,7 +55,7 @@ public HttpStreamDeframer( DecompressorRegistry decompressorRegistry, RequestContext ctx, TransportStatusListener transportStatusListener, - @Nullable GrpcExceptionHandlerFunction exceptionHandler, + GrpcExceptionHandlerFunction exceptionHandler, int maxMessageLength, boolean grpcWebText, boolean server) { super(maxMessageLength, ctx.alloc(), grpcWebText); this.ctx = requireNonNull(ctx, "ctx"); @@ -121,9 +120,8 @@ public void processHeaders(HttpHeaders headers, StreamDecoderOutput extends ServerCall { @Nullable private final Executor blockingExecutor; - @Nullable private final GrpcExceptionHandlerFunction exceptionHandler; // Only set once. @@ -149,7 +148,7 @@ protected AbstractServerCall(HttpRequest req, @Nullable GrpcJsonMarshaller jsonMarshaller, boolean unsafeWrapRequestBuffers, ResponseHeaders defaultHeaders, - @Nullable GrpcExceptionHandlerFunction exceptionHandler, + GrpcExceptionHandlerFunction exceptionHandler, @Nullable Executor blockingExecutor, boolean autoCompression, boolean useMethodMarshaller) { @@ -214,16 +213,23 @@ public final void close(Throwable exception) { public final void close(Throwable exception, boolean cancelled) { exception = Exceptions.peel(exception); final Metadata metadata = generateMetadataFromThrowable(exception); - final Status status = GrpcStatus.fromThrowable(exceptionHandler, ctx, exception, metadata); + final Status status = exceptionHandler.apply(ctx, exception, metadata); close(new ServerStatusAndMetadata(status, metadata, false, cancelled), exception); } @Override public final void close(Status status, Metadata metadata) { + if (status.getCause() == null) { + close(new ServerStatusAndMetadata(status, metadata, false)); + return; + } + Status newStatus = exceptionHandler.apply(ctx, status.getCause(), metadata); + assert newStatus != null; + if (status.getDescription() != null) { + newStatus = newStatus.withDescription(status.getDescription()); + } final ServerStatusAndMetadata statusAndMetadata = - new ServerStatusAndMetadata(GrpcStatus.fromExceptionHandler(exceptionHandler, ctx, - status, metadata), - metadata, false); + new ServerStatusAndMetadata(newStatus, metadata, false); close(statusAndMetadata); } diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/FramedGrpcService.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/FramedGrpcService.java index a683deb928a..3bd5ef85ba3 100644 --- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/FramedGrpcService.java +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/FramedGrpcService.java @@ -60,7 +60,6 @@ import com.linecorp.armeria.common.logging.RequestLogProperty; import com.linecorp.armeria.common.util.SafeCloseable; import com.linecorp.armeria.common.util.TimeoutMode; -import com.linecorp.armeria.internal.common.grpc.GrpcStatus; import com.linecorp.armeria.internal.common.grpc.MetadataUtil; import com.linecorp.armeria.internal.common.grpc.TimeoutHeaderUtil; import com.linecorp.armeria.internal.server.grpc.AbstractServerCall; @@ -241,8 +240,7 @@ protected HttpResponse doPost(ServiceRequestContext ctx, HttpRequest req) throws return HttpResponse.of( (ResponseHeaders) AbstractServerCall.statusToTrailers( ctx, defaultHeaders.get(serializationFormat).toBuilder(), - GrpcStatus.fromThrowable(exceptionHandler, ctx, e, metadata), - metadata)); + exceptionHandler.apply(ctx, e, metadata), metadata)); } } else { if (Boolean.TRUE.equals(ctx.attr(AbstractUnframedGrpcService.IS_UNFRAMED_GRPC))) { diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/GrpcServiceBuilder.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/GrpcServiceBuilder.java index 5eb3b521a6c..431980bc37d 100644 --- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/GrpcServiceBuilder.java +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/GrpcServiceBuilder.java @@ -51,6 +51,7 @@ import com.linecorp.armeria.common.grpc.GrpcStatusFunction; import com.linecorp.armeria.common.grpc.protocol.AbstractMessageDeframer; import com.linecorp.armeria.common.grpc.protocol.ArmeriaMessageFramer; +import com.linecorp.armeria.internal.common.grpc.UnwrappingGrpcExceptionHandleFunction; import com.linecorp.armeria.server.HttpService; import com.linecorp.armeria.server.HttpServiceWithRoutes; import com.linecorp.armeria.server.Server; @@ -996,16 +997,16 @@ public GrpcService build() { registryBuilder.addService(grpcHealthCheckService.bindService(), null, ImmutableList.of()); } - final GrpcExceptionHandlerFunction grpcExceptionHandler; + GrpcExceptionHandlerFunction grpcExceptionHandler; if (exceptionMappingsBuilder != null) { - grpcExceptionHandler = exceptionMappingsBuilder.build(); + grpcExceptionHandler = exceptionMappingsBuilder.build().orElse(GrpcExceptionHandlerFunction.of()); + } else if (exceptionHandler != null) { + grpcExceptionHandler = exceptionHandler.orElse(GrpcExceptionHandlerFunction.of()); } else { - grpcExceptionHandler = exceptionHandler; - } - - if (grpcExceptionHandler != null) { - registryBuilder.setDefaultExceptionHandler(grpcExceptionHandler); + grpcExceptionHandler = GrpcExceptionHandlerFunction.of(); } + grpcExceptionHandler = new UnwrappingGrpcExceptionHandleFunction(grpcExceptionHandler); + registryBuilder.setDefaultExceptionHandler(grpcExceptionHandler); if (interceptors != null) { final HandlerRegistry.Builder newRegistryBuilder = new HandlerRegistry.Builder(); diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/HandlerRegistry.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/HandlerRegistry.java index 6044ef4c969..47b8d63e9c0 100644 --- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/HandlerRegistry.java +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/HandlerRegistry.java @@ -71,6 +71,7 @@ import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.grpc.GrpcExceptionHandlerFunction; import com.linecorp.armeria.internal.common.ReflectiveDependencyInjector; +import com.linecorp.armeria.internal.common.grpc.UnwrappingGrpcExceptionHandleFunction; import com.linecorp.armeria.internal.server.annotation.AnnotationUtil; import com.linecorp.armeria.internal.server.annotation.DecoratorAnnotationUtil; import com.linecorp.armeria.internal.server.annotation.DecoratorAnnotationUtil.DecoratorAndOrder; @@ -281,7 +282,8 @@ private static void putGrpcExceptionHandlerIfPresent( grpcExceptionHandler.ifPresent(exceptionHandler -> { GrpcExceptionHandlerFunction grpcExceptionHandler0 = exceptionHandler; if (defaultExceptionHandler != null) { - grpcExceptionHandler0 = exceptionHandler.orElse(defaultExceptionHandler); + grpcExceptionHandler0 = new UnwrappingGrpcExceptionHandleFunction( + exceptionHandler.orElse(defaultExceptionHandler)); } grpcExceptionHandlersBuilder.put(methodDefinition, grpcExceptionHandler0); }); diff --git a/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientBuilderTest.java b/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientBuilderTest.java index fd7ffb4992b..15a38c3bf5d 100644 --- a/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientBuilderTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientBuilderTest.java @@ -34,7 +34,9 @@ import com.linecorp.armeria.client.Clients; import com.linecorp.armeria.client.endpoint.EndpointGroup; import com.linecorp.armeria.common.CommonPools; +import com.linecorp.armeria.common.ContentTooLargeException; import com.linecorp.armeria.common.SerializationFormat; +import com.linecorp.armeria.common.grpc.GrpcExceptionHandlerFunction; import com.linecorp.armeria.common.grpc.GrpcSerializationFormats; import com.linecorp.armeria.internal.common.grpc.TestServiceImpl; import com.linecorp.armeria.server.ServerBuilder; @@ -47,6 +49,9 @@ import io.grpc.ClientInterceptor; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.PrototypeMarshaller; +import io.grpc.Status; +import io.grpc.Status.Code; +import io.grpc.StatusRuntimeException; import testing.grpc.Messages.Payload; import testing.grpc.Messages.SimpleRequest; import testing.grpc.Messages.SimpleResponse; @@ -282,4 +287,24 @@ public O parse(InputStream inputStream) { return next.newCall(methodDescriptor, callOptions); } } + + @Test + void useDefaultGrpcExceptionHandlerFunctionAsFallback() { + final GrpcExceptionHandlerFunction noopExceptionHandler = (ctx, cause, metadata) -> null; + final GrpcExceptionHandlerFunction exceptionHandler = + GrpcExceptionHandlerFunction.builder() + .on(ContentTooLargeException.class, noopExceptionHandler) + .build(); + final TestServiceBlockingStub client = GrpcClients.builder(server.httpUri()) + .maxResponseLength(1) + .exceptionHandler(exceptionHandler) + .build(TestServiceBlockingStub.class); + + // Fallback exception handler expected to return RESOURCE_EXHAUSTED for the ContentTooLargeException + assertThatThrownBy(() -> client.unaryCall(SimpleRequest.getDefaultInstance())) + .isInstanceOf(StatusRuntimeException.class) + .extracting(e -> ((StatusRuntimeException) e).getStatus()) + .extracting(Status::getCode) + .isEqualTo(Code.RESOURCE_EXHAUSTED); + } } diff --git a/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientTest.java b/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientTest.java index f11120abf6e..fec8017a51b 100644 --- a/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientTest.java @@ -77,6 +77,7 @@ import com.linecorp.armeria.common.RpcResponse; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.grpc.GrpcCallOptions; +import com.linecorp.armeria.common.grpc.GrpcExceptionHandlerFunction; import com.linecorp.armeria.common.grpc.GrpcSerializationFormats; import com.linecorp.armeria.common.grpc.protocol.GrpcHeaderNames; import com.linecorp.armeria.common.logging.RequestLog; @@ -84,7 +85,6 @@ import com.linecorp.armeria.common.util.ThreadFactories; import com.linecorp.armeria.internal.common.RequestTargetCache; import com.linecorp.armeria.internal.common.grpc.GrpcLogUtil; -import com.linecorp.armeria.internal.common.grpc.GrpcStatus; import com.linecorp.armeria.internal.common.grpc.MetadataUtil; import com.linecorp.armeria.internal.common.grpc.StreamRecorder; import com.linecorp.armeria.internal.common.grpc.TestServiceImpl; @@ -746,7 +746,9 @@ void cancelAfterBegin() throws Exception { requestObserver.onError(new RuntimeException()); responseObserver.awaitCompletion(); assertThat(responseObserver.getValues()).isEmpty(); - assertThat(GrpcStatus.fromThrowable(responseObserver.getError()).getCode()).isEqualTo(Code.CANCELLED); + assertThat(GrpcExceptionHandlerFunction.of() + .apply(null, responseObserver.getError(), null) + .getCode()).isEqualTo(Code.CANCELLED); final RequestLog log = requestLogQueue.take(); assertThat(log.isComplete()).isTrue(); @@ -780,7 +782,9 @@ void cancelAfterFirstResponse() throws Exception { requestObserver.onError(new RuntimeException()); responseObserver.awaitCompletion(operationTimeoutMillis(), TimeUnit.MILLISECONDS); assertThat(responseObserver.getValues()).hasSize(1); - assertThat(GrpcStatus.fromThrowable(responseObserver.getError()).getCode()).isEqualTo(Code.CANCELLED); + assertThat(GrpcExceptionHandlerFunction.of() + .apply(null, responseObserver.getError(), null) + .getCode()).isEqualTo(Code.CANCELLED); checkRequestLog((rpcReq, rpcRes, grpcStatus) -> { assertThat(rpcReq.params()).containsExactly(request); @@ -1413,7 +1417,9 @@ void deadlineExceededServerStreaming() throws Exception { recorder.awaitCompletion(); assertThat(recorder.getError()).isNotNull(); - assertThat(GrpcStatus.fromThrowable(recorder.getError()).getCode()) + assertThat(GrpcExceptionHandlerFunction.of() + .apply(null, recorder.getError(), null) + .getCode()) .isEqualTo(Status.DEADLINE_EXCEEDED.getCode()); checkRequestLogError((headers, rpcReq, cause) -> { @@ -1611,8 +1617,12 @@ void statusCodeAndMessage() throws Exception { final ArgumentCaptor captor = ArgumentCaptor.forClass(Throwable.class); verify(responseObserver, timeout(operationTimeoutMillis())).onError(captor.capture()); - assertThat(GrpcStatus.fromThrowable(captor.getValue()).getCode()).isEqualTo(Status.UNKNOWN.getCode()); - assertThat(GrpcStatus.fromThrowable(captor.getValue()).getDescription()).isEqualTo(errorMessage); + assertThat(GrpcExceptionHandlerFunction.of() + .apply(null, captor.getValue(), null) + .getCode()).isEqualTo(Status.UNKNOWN.getCode()); + assertThat(GrpcExceptionHandlerFunction.of() + .apply(null, captor.getValue(), null) + .getDescription()).isEqualTo(errorMessage); verifyNoMoreInteractions(responseObserver); checkRequestLog((rpcReq, rpcRes, grpcStatus) -> { diff --git a/grpc/src/test/java/com/linecorp/armeria/common/grpc/DefaultGrpcExceptionHandlerFunctionTest.java b/grpc/src/test/java/com/linecorp/armeria/common/grpc/DefaultGrpcExceptionHandlerFunctionTest.java new file mode 100644 index 00000000000..17a20ceb32f --- /dev/null +++ b/grpc/src/test/java/com/linecorp/armeria/common/grpc/DefaultGrpcExceptionHandlerFunctionTest.java @@ -0,0 +1,47 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common.grpc; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +import com.google.protobuf.InvalidProtocolBufferException; + +import com.linecorp.armeria.client.circuitbreaker.CircuitBreaker; +import com.linecorp.armeria.client.circuitbreaker.FailFastException; + +import io.grpc.Status; + +class DefaultGrpcExceptionHandlerFunctionTest { + + @Test + void failFastExceptionToUnavailableCode() { + assertThat(GrpcExceptionHandlerFunction + .of() + .apply(null, new FailFastException(CircuitBreaker.ofDefaultName()), null) + .getCode()).isEqualTo(Status.Code.UNAVAILABLE); + } + + @Test + void invalidProtocolBufferExceptionToInvalidArgumentCode() { + assertThat(GrpcExceptionHandlerFunction + .of() + .apply(null, new InvalidProtocolBufferException("Failed to parse message"), null) + .getCode()).isEqualTo(Status.Code.INVALID_ARGUMENT); + } +} diff --git a/grpc/src/test/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunctionBuilderTest.java b/grpc/src/test/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunctionBuilderTest.java index 6590ae146b4..5fc512e8858 100644 --- a/grpc/src/test/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunctionBuilderTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunctionBuilderTest.java @@ -25,7 +25,6 @@ import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.HttpRequest; -import com.linecorp.armeria.internal.common.grpc.GrpcStatus; import com.linecorp.armeria.server.ServiceRequestContext; import io.grpc.Metadata; @@ -88,20 +87,21 @@ void sortExceptionHandler() { B2Exception.class, B1Exception.class); - final GrpcExceptionHandlerFunction exceptionHandler = builder.build(); - Status status = GrpcStatus.fromThrowable(exceptionHandler, ctx, new A3Exception(), new Metadata()); + final GrpcExceptionHandlerFunction exceptionHandler = builder.build().orElse( + GrpcExceptionHandlerFunction.of()); + Status status = exceptionHandler.apply(ctx, new A3Exception(), new Metadata()); assertThat(status.getCode()).isEqualTo(Code.UNAUTHENTICATED); - status = GrpcStatus.fromThrowable(exceptionHandler, ctx, new A2Exception(), new Metadata()); + status = exceptionHandler.apply(ctx, new A2Exception(), new Metadata()); assertThat(status.getCode()).isEqualTo(Code.UNIMPLEMENTED); - status = GrpcStatus.fromThrowable(exceptionHandler, ctx, new A1Exception(), new Metadata()); + status = exceptionHandler.apply(ctx, new A1Exception(), new Metadata()); assertThat(status.getCode()).isEqualTo(Code.RESOURCE_EXHAUSTED); - status = GrpcStatus.fromThrowable(exceptionHandler, ctx, new B2Exception(), new Metadata()); + status = exceptionHandler.apply(ctx, new B2Exception(), new Metadata()); assertThat(status.getCode()).isEqualTo(Code.NOT_FOUND); - status = GrpcStatus.fromThrowable(exceptionHandler, ctx, new B1Exception(), new Metadata()); + status = exceptionHandler.apply(ctx, new B1Exception(), new Metadata()); assertThat(status.getCode()).isEqualTo(Code.UNAUTHENTICATED); } @@ -111,21 +111,23 @@ void mapStatus() { GrpcExceptionHandlerFunction .builder() .on(A2Exception.class, (ctx, throwable, metadata) -> Status.PERMISSION_DENIED) + .on(A1Exception.class, (ctx1, cause, metadata) -> Status.DEADLINE_EXCEEDED) .build(); for (Throwable ex : ImmutableList.of(new A2Exception(), new A3Exception())) { - final Status status = Status.UNKNOWN.withCause(ex); final Metadata metadata = new Metadata(); - final Status newStatus = GrpcStatus.fromExceptionHandler(exceptionHandler, ctx, status, metadata); + final Status newStatus = exceptionHandler.apply(ctx, ex, metadata); assertThat(newStatus.getCode()).isEqualTo(Code.PERMISSION_DENIED); assertThat(newStatus.getCause()).isEqualTo(ex); assertThat(metadata.keys()).isEmpty(); } - final Status status = Status.DEADLINE_EXCEEDED.withCause(new A1Exception()); + final A1Exception cause = new A1Exception(); final Metadata metadata = new Metadata(); - final Status newStatus = GrpcStatus.fromExceptionHandler(exceptionHandler, ctx, status, metadata); - assertThat(newStatus).isSameAs(status); + final Status newStatus = exceptionHandler.apply(ctx, cause, metadata); + + assertThat(newStatus.getCode()).isEqualTo(Code.DEADLINE_EXCEEDED); + assertThat(newStatus.getCause()).isEqualTo(cause); assertThat(metadata.keys()).isEmpty(); } @@ -140,17 +142,17 @@ void mapStatusAndMetadata() { }) .build(); - final Status status = Status.UNKNOWN.withCause(new B1Exception()); - + final B1Exception cause = new B1Exception(); final Metadata metadata1 = new Metadata(); - final Status newStatus1 = GrpcStatus.fromExceptionHandler(exceptionHandler, ctx, status, metadata1); + final Status newStatus1 = exceptionHandler.apply(ctx, cause, metadata1); assertThat(newStatus1.getCode()).isEqualTo(Code.ABORTED); assertThat(metadata1.get(TEST_KEY)).isEqualTo("B1Exception"); assertThat(metadata1.keys()).containsOnly(TEST_KEY.name()); final Metadata metadata2 = new Metadata(); metadata2.put(TEST_KEY2, "test"); - final Status newStatus2 = GrpcStatus.fromExceptionHandler(exceptionHandler, ctx, status, metadata2); + final Status newStatus2 = exceptionHandler.apply(ctx, cause, metadata2); + assertThat(newStatus2.getCode()).isEqualTo(Code.ABORTED); assertThat(metadata2.get(TEST_KEY)).isEqualTo("B1Exception"); assertThat(metadata2.keys()).containsOnly(TEST_KEY.name(), TEST_KEY2.name()); diff --git a/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/GrpcStatusTest.java b/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/GrpcStatusTest.java index 9d1c2cba024..74d20a17cb3 100644 --- a/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/GrpcStatusTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/GrpcStatusTest.java @@ -36,10 +36,6 @@ import org.junit.jupiter.api.Test; import com.google.api.gax.grpc.GrpcStatusCode; -import com.google.protobuf.InvalidProtocolBufferException; - -import com.linecorp.armeria.client.circuitbreaker.CircuitBreaker; -import com.linecorp.armeria.client.circuitbreaker.FailFastException; import io.grpc.Status; @@ -52,18 +48,4 @@ void grpcCodeToHttpStatus() { .isEqualTo(GrpcStatusCode.of(code).getCode().getHttpStatusCode()); } } - - @Test - void failFastExceptionToUnavailableCode() { - assertThat(GrpcStatus.fromThrowable(new FailFastException(CircuitBreaker.ofDefaultName())) - .getCode()) - .isEqualTo(Status.Code.UNAVAILABLE); - } - - @Test - void invalidProtocolBufferExceptionToInvalidArgumentCode() { - assertThat(GrpcStatus.fromThrowable(new InvalidProtocolBufferException("Failed to parse message")) - .getCode()) - .isEqualTo(Status.Code.INVALID_ARGUMENT); - } } diff --git a/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/HttpStreamDeframerTest.java b/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/HttpStreamDeframerTest.java index da0c678c60f..d3ad2170874 100644 --- a/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/HttpStreamDeframerTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/HttpStreamDeframerTest.java @@ -31,6 +31,7 @@ import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.ResponseHeaders; +import com.linecorp.armeria.common.grpc.GrpcExceptionHandlerFunction; import com.linecorp.armeria.common.grpc.protocol.ArmeriaStatusException; import com.linecorp.armeria.common.grpc.protocol.DeframedMessage; import com.linecorp.armeria.common.grpc.protocol.GrpcHeaderNames; @@ -57,7 +58,9 @@ void setUp() { final ServiceRequestContext ctx = ServiceRequestContext.of(HttpRequest.of(HttpMethod.GET, "/")); final TransportStatusListener statusListener = (status, metadata) -> statusRef.set(status); deframer = new HttpStreamDeframer(DecompressorRegistry.getDefaultInstance(), ctx, statusListener, - null, Integer.MAX_VALUE, false, true); + new UnwrappingGrpcExceptionHandleFunction( + GrpcExceptionHandlerFunction.of()), Integer.MAX_VALUE, + false, true); } @Test diff --git a/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/MetadataUtilTest.java b/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/MetadataUtilTest.java index c7725577897..c33bb0b2357 100644 --- a/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/MetadataUtilTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/MetadataUtilTest.java @@ -49,8 +49,8 @@ class MetadataUtilTest { private static final Metadata.Key TEST_BIN_KEY = Metadata.Key.of("testBinary-bin", Metadata.BINARY_BYTE_MARSHALLER); - private static final ThrowableProto THROWABLE_PROTO = - GrpcStatus.serializeThrowable(new RuntimeException("test")); + private static final ThrowableProto THROWABLE_PROTO = GrpcStatus.serializeThrowable( + new RuntimeException("test")); @Test void fillHeadersTest() { diff --git a/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/TestServiceImpl.java b/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/TestServiceImpl.java index 653762e3529..889ad6f781f 100644 --- a/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/TestServiceImpl.java +++ b/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/TestServiceImpl.java @@ -34,6 +34,7 @@ import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.grpc.GrpcExceptionHandlerFunction; import com.linecorp.armeria.server.HttpService; import com.linecorp.armeria.server.Service; import com.linecorp.armeria.server.ServiceRequestContext; @@ -365,7 +366,9 @@ private synchronized void dispatchChunk() { } } catch (Throwable e) { failure = e; - if (GrpcStatus.fromThrowable(e).getCode() == Status.CANCELLED.getCode()) { + if (GrpcExceptionHandlerFunction.of() + .apply(ServiceRequestContext.current(), e, new Metadata()) + .getCode() == Status.CANCELLED.getCode()) { // Stream was cancelled by client, responseStream.onError() might be called already or // will be called soon by inbounding StreamObserver. chunks.clear(); diff --git a/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcExceptionHandlerTest.java b/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcExceptionHandlerTest.java index 40b8bf77066..71cd9ef3428 100644 --- a/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcExceptionHandlerTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcExceptionHandlerTest.java @@ -19,6 +19,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import java.io.IOException; import java.util.Objects; import java.util.concurrent.BlockingDeque; import java.util.concurrent.LinkedBlockingDeque; @@ -81,6 +82,17 @@ protected void configure(ServerBuilder sb) throws Exception { } }; + @RegisterExtension + static final ServerExtension serverWithDefaultExceptionHandler = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) throws Exception { + sb.requestTimeoutMillis(5000) + .service(GrpcService.builder() + .addService(new TestServiceIOException()) + .build()); + } + }; + @Test void classAndMethodHaveMultipleExceptionHandlers() { final TestServiceBlockingStub client = @@ -430,6 +442,25 @@ void solelyAddedMethod() { assertThat(client.unaryCall(fifthRequest)).isNotNull(); } + @Test + void defaultGrpcExceptionHandlerConvertIOExceptionToUnavailable() { + final TestServiceBlockingStub client = + GrpcClients.newClient(serverWithDefaultExceptionHandler.httpUri(), + TestServiceBlockingStub.class); + + final SimpleRequest globalRequest = SimpleRequest + .newBuilder() + .setNestedRequest(NestedRequest + .newBuilder() + .setNestedPayload("global") + .build()) + .build(); + assertThatThrownBy(() -> client.unaryCall(globalRequest)) + .isInstanceOfSatisfying(StatusRuntimeException.class, e -> { + assertThat(e.getStatus().getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); + }); + } + private static class FirstGrpcExceptionHandler implements GrpcExceptionHandlerFunction { @Override @@ -562,4 +593,12 @@ public void unaryCall(SimpleRequest request, StreamObserver resp responseObserver.onCompleted(); } } + + // TestServiceIOException has DefaultGRPCExceptionHandlerFunction as fallback exception handler + private static class TestServiceIOException extends TestServiceImpl { + @Override + public void unaryCall(SimpleRequest request, StreamObserver responseObserver) { + responseObserver.onError(new IOException()); + } + } }