diff --git a/spring-cloud-circuitbreaker-resilience4j/src/main/java/org/springframework/cloud/circuitbreaker/resilience4j/Resilience4JCircuitBreakerFactory.java b/spring-cloud-circuitbreaker-resilience4j/src/main/java/org/springframework/cloud/circuitbreaker/resilience4j/Resilience4JCircuitBreakerFactory.java index 9a6ca07..e43aac7 100644 --- a/spring-cloud-circuitbreaker-resilience4j/src/main/java/org/springframework/cloud/circuitbreaker/resilience4j/Resilience4JCircuitBreakerFactory.java +++ b/spring-cloud-circuitbreaker-resilience4j/src/main/java/org/springframework/cloud/circuitbreaker/resilience4j/Resilience4JCircuitBreakerFactory.java @@ -54,6 +54,8 @@ public class Resilience4JCircuitBreakerFactory extends private ExecutorService executorService = Executors.newCachedThreadPool(); + private Function groupExecutorServiceFactory = group -> Executors.newCachedThreadPool(); + private ConcurrentHashMap executorServices = new ConcurrentHashMap<>(); private Map> circuitBreakerCustomizers = new HashMap<>(); @@ -110,6 +112,14 @@ public void configureExecutorService(ExecutorService executorService) { this.executorService = executorService; } + /** + * configure GroupExecutorService. + * @param groupFactory GroupExecutorService Factory + */ + public void configureGroupExecutorService(Function groupFactory) { + this.groupExecutorServiceFactory = groupFactory; + } + @Override public org.springframework.cloud.client.circuitbreaker.CircuitBreaker create(String id) { Assert.hasText(id, "A CircuitBreaker must have an id."); @@ -121,8 +131,7 @@ public org.springframework.cloud.client.circuitbreaker.CircuitBreaker create(Str public org.springframework.cloud.client.circuitbreaker.CircuitBreaker create(String id, String groupName) { Assert.hasText(id, "A CircuitBreaker must have an id."); Assert.hasText(groupName, "A CircuitBreaker must have a group name."); - final ExecutorService groupExecutorService = executorServices.computeIfAbsent(groupName, - group -> Executors.newCachedThreadPool()); + final ExecutorService groupExecutorService = executorServices.computeIfAbsent(groupName, groupExecutorServiceFactory); Resilience4JCircuitBreaker resilience4JCircuitBreaker = create(id, groupName, groupExecutorService); return tryObservedCircuitBreaker(resilience4JCircuitBreaker); } diff --git a/spring-cloud-circuitbreaker-resilience4j/src/test/java/org/springframework/cloud/circuitbreaker/resilience4j/Resilience4JCircuitBreakerTest.java b/spring-cloud-circuitbreaker-resilience4j/src/test/java/org/springframework/cloud/circuitbreaker/resilience4j/Resilience4JCircuitBreakerTest.java index bf01957..0a98f7f 100644 --- a/spring-cloud-circuitbreaker-resilience4j/src/test/java/org/springframework/cloud/circuitbreaker/resilience4j/Resilience4JCircuitBreakerTest.java +++ b/spring-cloud-circuitbreaker-resilience4j/src/test/java/org/springframework/cloud/circuitbreaker/resilience4j/Resilience4JCircuitBreakerTest.java @@ -16,19 +16,21 @@ package org.springframework.cloud.circuitbreaker.resilience4j; -import java.util.concurrent.TimeUnit; - import io.github.resilience4j.bulkhead.BulkheadRegistry; import io.github.resilience4j.bulkhead.ThreadPoolBulkheadRegistry; import io.github.resilience4j.circuitbreaker.CircuitBreakerRegistry; import io.github.resilience4j.timelimiter.TimeLimiterRegistry; +import io.micrometer.core.instrument.util.NamedThreadFactory; import org.assertj.core.api.Assertions; import org.junit.Before; import org.junit.Test; - import org.springframework.cloud.client.circuitbreaker.CircuitBreaker; import org.springframework.cloud.client.circuitbreaker.NoFallbackAvailableException; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -230,4 +232,48 @@ public void runWithDisabledTimeLimiter() { })).isEqualTo("foobar"); } + /** + * Run the test with grouping and specify thread pool. + */ + @Test + public void runWithCustomGroupThreadPool() { + Resilience4JCircuitBreakerFactory factory = new Resilience4JCircuitBreakerFactory(CircuitBreakerRegistry.ofDefaults(), + TimeLimiterRegistry.ofDefaults(), null); + String groupName = "groupFoo"; + + // configure GroupExecutorService + factory.configureGroupExecutorService(group -> new ContextThreadPoolExecutor(groupName)); + + CircuitBreaker cb = factory.create("foo", groupName); + assertThat(cb.run(() -> Thread.currentThread().getName())).startsWith(groupName); + } + + /** + * Run tests without grouping and specify thread pool. + */ + @Test + public void runWithCustomNormalThreadPool() { + Resilience4JCircuitBreakerFactory factory = new Resilience4JCircuitBreakerFactory(CircuitBreakerRegistry.ofDefaults(), + TimeLimiterRegistry.ofDefaults(), null); + String threadPoolName = "demo-"; + + // configure ExecutorService + factory.configureExecutorService(new ContextThreadPoolExecutor(threadPoolName)); + + CircuitBreaker cb = factory.create("foo"); + assertThat(cb.run(() -> Thread.currentThread().getName())).startsWith(threadPoolName); + } + + static class ContextThreadPoolExecutor extends ThreadPoolExecutor { + + /** + * example ContextThreadPoolExecutor + * @param threadPoolName fixed threadPoolName + */ + public ContextThreadPoolExecutor(String threadPoolName) { + super(2, 5, 10, TimeUnit.SECONDS, new ArrayBlockingQueue<>(1024)); + this.setThreadFactory(new NamedThreadFactory(threadPoolName)); + } + } + }