diff --git a/graphql-kickstart-spring-support/src/main/java/graphql/kickstart/spring/error/GraphQLErrorStartupListener.java b/graphql-kickstart-spring-support/src/main/java/graphql/kickstart/spring/error/GraphQLErrorStartupListener.java index be4787e4..bfee83e1 100644 --- a/graphql-kickstart-spring-support/src/main/java/graphql/kickstart/spring/error/GraphQLErrorStartupListener.java +++ b/graphql-kickstart-spring-support/src/main/java/graphql/kickstart/spring/error/GraphQLErrorStartupListener.java @@ -18,9 +18,11 @@ public GraphQLErrorStartupListener(ErrorHandlerSupplier errorHandlerSupplier, bo @Override public void onApplicationEvent(@NonNull ApplicationReadyEvent event) { - ConfigurableApplicationContext context = event.getApplicationContext(); - GraphQLErrorHandler errorHandler = new GraphQLErrorHandlerFactory().create(context, exceptionHandlersEnabled); - context.getBeanFactory().registerSingleton(errorHandler.getClass().getCanonicalName(), errorHandler); - errorHandlerSupplier.setErrorHandler(errorHandler); + if (!errorHandlerSupplier.isPresent()) { + ConfigurableApplicationContext context = event.getApplicationContext(); + GraphQLErrorHandler errorHandler = new GraphQLErrorHandlerFactory().create(context, exceptionHandlersEnabled); + context.getBeanFactory().registerSingleton(errorHandler.getClass().getCanonicalName(), errorHandler); + errorHandlerSupplier.setErrorHandler(errorHandler); + } } } diff --git a/graphql-kickstart-spring-support/src/test/java/graphql/kickstart/spring/error/GraphQLErrorStartupListenerTest.java b/graphql-kickstart-spring-support/src/test/java/graphql/kickstart/spring/error/GraphQLErrorStartupListenerTest.java new file mode 100644 index 00000000..5287c894 --- /dev/null +++ b/graphql-kickstart-spring-support/src/test/java/graphql/kickstart/spring/error/GraphQLErrorStartupListenerTest.java @@ -0,0 +1,36 @@ +package graphql.kickstart.spring.error; + +import graphql.kickstart.execution.error.GraphQLErrorHandler; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.boot.SpringApplication; +import org.springframework.boot.context.event.ApplicationReadyEvent; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; + +public class GraphQLErrorStartupListenerTest { + + @Test + void error_handler_is_not_overridden_when_present() { + GraphQLErrorHandler expectedErrorHandler = Mockito.mock(GraphQLErrorHandler.class); + ErrorHandlerSupplier errorHandlerSupplier = new ErrorHandlerSupplier(expectedErrorHandler); + GraphQLErrorStartupListener graphQLErrorStartupListener = new GraphQLErrorStartupListener(errorHandlerSupplier, false); + graphQLErrorStartupListener.onApplicationEvent(getApplicationReadyEvent()); + Assertions.assertThat(errorHandlerSupplier.get()).isEqualTo(expectedErrorHandler); + } + + @Test + void error_handler_is_set_when_not_present() { + ErrorHandlerSupplier errorHandlerSupplier = new ErrorHandlerSupplier(null); + GraphQLErrorStartupListener graphQLErrorStartupListener = new GraphQLErrorStartupListener(errorHandlerSupplier, false); + graphQLErrorStartupListener.onApplicationEvent(getApplicationReadyEvent()); + Assertions.assertThat(errorHandlerSupplier.get()).isNotNull(); + } + + private ApplicationReadyEvent getApplicationReadyEvent() { + AnnotationConfigWebApplicationContext annotationConfigWebApplicationContext = new AnnotationConfigWebApplicationContext(); + annotationConfigWebApplicationContext.refresh(); + return new ApplicationReadyEvent(new SpringApplication(), new String[0], annotationConfigWebApplicationContext); + } + +}