Skip to content

Commit

Permalink
Add ServerErrorHandler to inject dependencies in annotations (#5446)
Browse files Browse the repository at this point in the history
Motivation:
- It seems good to directly inject ServerErrorHandler via a bean.
- #5440

Modifications:
- Introduced `Optional<List<ServerErrorHandler>>` serverErrorHandlers in
the `armeriaServer` method to enable the injection of ServerErrorHandler
beans.
- Updated the `configureServerWithArmeriaSettings` method to incorporate
error handlers.

Result:
- Users can now define custom ServerErrorHandler beans and utilize an
Armeria server that has been pre-configured with these user-defined
error handlers.
- Closes #5443
  • Loading branch information
kth496 committed Apr 5, 2024
1 parent b1a1044 commit 4ca7e89
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import com.linecorp.armeria.common.metric.MeterIdPrefixFunction;
import com.linecorp.armeria.server.HttpService;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.ServerErrorHandler;
import com.linecorp.armeria.server.encoding.EncodingService;
import com.linecorp.armeria.server.metric.MetricCollectingService;
import com.linecorp.armeria.server.metric.MetricCollectingServiceBuilder;
Expand Down Expand Up @@ -110,6 +111,7 @@ public static void configureServerWithArmeriaSettings(
MeterIdPrefixFunction meterIdPrefixFunction,
List<MetricCollectingServiceConfigurator> metricCollectingServiceConfigurators,
List<DependencyInjector> dependencyInjectors,
List<ServerErrorHandler> serverErrorHandlers,
BeanFactory beanFactory) {

requireNonNull(server, "server");
Expand Down Expand Up @@ -203,6 +205,7 @@ public static void configureServerWithArmeriaSettings(
if (settings.isEnableAutoInjection()) {
server.dependencyInjector(SpringDependencyInjector.of(beanFactory), false);
}
serverErrorHandlers.forEach(server::errorHandler);
}

private static void configureInternalService(ServerBuilder server, InternalServiceId serviceId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import com.linecorp.armeria.common.metric.MeterIdPrefixFunction;
import com.linecorp.armeria.server.Server;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.ServerErrorHandler;
import com.linecorp.armeria.server.ServerPort;
import com.linecorp.armeria.server.docs.DocService;
import com.linecorp.armeria.server.healthcheck.HealthCheckService;
Expand Down Expand Up @@ -73,6 +74,7 @@ public Server armeriaServer(
Optional<List<ArmeriaServerConfigurator>> armeriaServerConfigurators,
Optional<List<Consumer<ServerBuilder>>> armeriaServerBuilderConsumers,
Optional<List<DependencyInjector>> dependencyInjectors,
Optional<List<ServerErrorHandler>> serverErrorHandlers,
BeanFactory beanFactory) {

if (!armeriaServerConfigurators.isPresent() &&
Expand All @@ -98,6 +100,7 @@ public Server armeriaServer(
MeterIdPrefixFunction.ofDefault("armeria.server")),
metricCollectingServiceConfigurators.orElse(ImmutableList.of()),
dependencyInjectors.orElse(ImmutableList.of()),
serverErrorHandlers.orElse(ImmutableList.of()),
beanFactory);

return serverBuilder.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import com.linecorp.armeria.common.grpc.GrpcSerializationFormats;
import com.linecorp.armeria.common.metric.MeterIdPrefixFunction;
import com.linecorp.armeria.server.Server;
import com.linecorp.armeria.server.ServerErrorHandler;
import com.linecorp.armeria.server.ServerPort;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.annotation.ExceptionHandlerFunction;
Expand Down Expand Up @@ -170,6 +171,26 @@ public MetricCollectingServiceConfigurator metricCollectingServiceConfigurator()
return (statusCode >= 200 && statusCode < 400) || statusCode == 404;
});
}

@Bean
public ServerErrorHandler serverErrorHandler1() {
return (ctx, cause) -> {
if (cause instanceof ArithmeticException) {
return HttpResponse.of("ArithmeticException was handled by serverErrorHandler!");
}
return null;
};
}

@Bean
public ServerErrorHandler serverErrorHandler2() {
return (ctx, cause) -> {
if (cause instanceof IllegalStateException) {
return HttpResponse.of("IllegalStateException was handled by serverErrorHandler!");
}
return null;
};
}
}

public static class IllegalArgumentExceptionHandler implements ExceptionHandlerFunction {
Expand Down Expand Up @@ -221,6 +242,21 @@ public AggregatedHttpResponse getV2() {
public JsonNode post(@RequestObject JsonNode jsonNode) {
return jsonNode;
}

@Get("/unhandled1")
public AggregatedHttpResponse unhandled1() throws Exception {
throw new ArithmeticException();
}

@Get("/unhandled2")
public AggregatedHttpResponse unhandled2() throws Exception {
throw new IllegalStateException();
}

@Get("/unhandled3")
public AggregatedHttpResponse unhandled3() throws Exception {
throw new IllegalAccessException();
}
}

public static class HelloGrpcService extends TestServiceImplBase {
Expand Down Expand Up @@ -294,7 +330,7 @@ void testAnnotatedService() throws Exception {
@Test
void testThriftService() throws Exception {
final TestService.Iface client = ThriftClients.newClient(newUrl("h1c") + "/thrift",
TestService.Iface.class);
TestService.Iface.class);
assertThat(client.hello("world")).isEqualTo("hello world");

final WebClient webClient = WebClient.of(newUrl("h1c"));
Expand All @@ -314,7 +350,7 @@ void testThriftService() throws Exception {
@Test
void testGrpcService() throws Exception {
final TestServiceBlockingStub client = GrpcClients.newClient(newUrl("h2c") + '/',
TestServiceBlockingStub.class);
TestServiceBlockingStub.class);
final HelloRequest request = HelloRequest.newBuilder()
.setName("world")
.build();
Expand Down Expand Up @@ -394,4 +430,38 @@ void testHealthCheckService() throws Exception {
res = response.aggregate().get();
assertThat(res.status()).isEqualTo(HttpStatus.SERVICE_UNAVAILABLE);
}

/**
* When a ServerErrorHandler @Bean is present,
* Server.config().errorHandler() does not register a DefaultServerErrorHandler.
* Since DefaultServerErrorHandler is not public, test were forced to compare toString.
* Needs to be improved.
*/
@Test
void testServerErrorHandlerRegistration() {
assertThat(server.config().errorHandler().toString()).isNotEqualTo("INSTANCE");
}

@Test
void testServerErrorHandler() throws Exception {
final WebClient client = WebClient.of(newUrl("h1c"));

// ArithmeticException will be handled by serverErrorHandler
final HttpResponse response1 = client.get("/annotated/unhandled1");
final AggregatedHttpResponse res1 = response1.aggregate().join();
assertThat(res1.status()).isEqualTo(HttpStatus.OK);
assertThat(res1.contentUtf8()).isEqualTo("ArithmeticException was handled by serverErrorHandler!");

// IllegalStateException will be handled by serverErrorHandler
final HttpResponse response2 = client.get("/annotated/unhandled2");
final AggregatedHttpResponse res2 = response2.aggregate().join();
assertThat(res2.status()).isEqualTo(HttpStatus.OK);
assertThat(res2.contentUtf8()).isEqualTo("IllegalStateException was handled by serverErrorHandler!");

// IllegalAccessException will be handled by DefaultServerErrorHandler which is used as the
// final fallback when all customized handlers return null
final HttpResponse response3 = client.get("/annotated/unhandled3");
final AggregatedHttpResponse res3 = response3.aggregate().join();
assertThat(res3.status()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import com.linecorp.armeria.server.Route;
import com.linecorp.armeria.server.Server;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.ServerErrorHandler;
import com.linecorp.armeria.server.ServerPort;
import com.linecorp.armeria.spring.ArmeriaServerConfigurator;
import com.linecorp.armeria.spring.ArmeriaSettings;
Expand Down Expand Up @@ -167,6 +168,7 @@ public WebServer getWebServer(HttpHandler httpHandler) {
meterIdPrefixFunctionOrDefault(),
findBeans(MetricCollectingServiceConfigurator.class),
findBeans(DependencyInjector.class),
findBeans(ServerErrorHandler.class),
beanFactory);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RootBeanDefinition;
Expand Down Expand Up @@ -62,12 +63,14 @@
import com.linecorp.armeria.common.HttpData;
import com.linecorp.armeria.common.HttpHeaderNames;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.metric.PrometheusMeterRegistries;
import com.linecorp.armeria.internal.common.util.PortUtil;
import com.linecorp.armeria.internal.testing.MockAddressResolverGroup;
import com.linecorp.armeria.server.HttpStatusException;
import com.linecorp.armeria.server.ServerErrorHandler;
import com.linecorp.armeria.server.annotation.Get;
import com.linecorp.armeria.server.annotation.Param;
import com.linecorp.armeria.server.healthcheck.HealthChecker;
Expand Down Expand Up @@ -567,4 +570,24 @@ void testManagementPort() throws JsonProcessingException {
.isEqualTo("/hello/foo");
}
}

@Test
void testServerErrorHandlerRegistration() {
beanFactory.registerBeanDefinition("armeriaSettings", new RootBeanDefinition(ArmeriaSettings.class));
registerInternalServices(beanFactory);

// Add ServerErrorHandler @Bean which handles all exceptions and returns 200 with empty string content.
final ServerErrorHandler handler = (ctx, req) -> HttpResponse.of("");
final BeanDefinition rbd2 = new RootBeanDefinition(ServerErrorHandler.class, () -> handler);
beanFactory.registerBeanDefinition("serverErrorHandler", rbd2);

final ArmeriaReactiveWebServerFactory factory = factory();
runServer(factory, (req, res) -> {
throw new IllegalArgumentException(); // Always raise exception handler
}, server -> {
final WebClient client = httpClient(server);
final AggregatedHttpResponse res1 = client.post("/hello", "hello").aggregate().join();
assertThat(res1.status()).isEqualTo(com.linecorp.armeria.common.HttpStatus.OK);
});
}
}

0 comments on commit 4ca7e89

Please sign in to comment.