diff --git a/grpc-client-utils/build.gradle.kts b/grpc-client-utils/build.gradle.kts index aeb4e84..886c438 100644 --- a/grpc-client-utils/build.gradle.kts +++ b/grpc-client-utils/build.gradle.kts @@ -22,6 +22,7 @@ dependencies { testImplementation("org.junit.jupiter:junit-jupiter:5.7.0") testImplementation("org.mockito:mockito-core:3.4.4") + testRuntimeOnly("io.grpc:grpc-netty:1.36.0") } tasks.test { diff --git a/grpc-client-utils/src/main/java/org/hypertrace/core/grpcutils/client/GrpcChannelRegistry.java b/grpc-client-utils/src/main/java/org/hypertrace/core/grpcutils/client/GrpcChannelRegistry.java new file mode 100644 index 0000000..72de4ba --- /dev/null +++ b/grpc-client-utils/src/main/java/org/hypertrace/core/grpcutils/client/GrpcChannelRegistry.java @@ -0,0 +1,34 @@ +package org.hypertrace.core.grpcutils.client; + +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class GrpcChannelRegistry { + private static final Logger LOG = LoggerFactory.getLogger(GrpcChannelRegistry.class); + private final Map channelMap = new ConcurrentHashMap<>(); + private volatile boolean isShutdown = false; + + public ManagedChannel forAddress(String host, int port) { + assert !this.isShutdown; + String channelId = this.getChannelId(host, port); + return this.channelMap.computeIfAbsent(channelId, unused -> this.buildNewChannel(host, port)); + } + + private ManagedChannel buildNewChannel(String host, int port) { + LOG.info("Creating new channel for {}:{}", host, port); + return ManagedChannelBuilder.forAddress(host, port).usePlaintext().build(); + } + + private String getChannelId(String host, int port) { + return host + ":" + port; + } + + public void shutdown() { + channelMap.values().forEach(ManagedChannel::shutdown); + this.isShutdown = true; + } +} diff --git a/grpc-client-utils/src/test/java/org/hypertrace/core/grpcutils/client/GrpcChannelRegistryTest.java b/grpc-client-utils/src/test/java/org/hypertrace/core/grpcutils/client/GrpcChannelRegistryTest.java new file mode 100644 index 0000000..26d2a74 --- /dev/null +++ b/grpc-client-utils/src/test/java/org/hypertrace/core/grpcutils/client/GrpcChannelRegistryTest.java @@ -0,0 +1,53 @@ +package org.hypertrace.core.grpcutils.client; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.grpc.Channel; +import io.grpc.ManagedChannel; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class GrpcChannelRegistryTest { + + GrpcChannelRegistry channelRegistry; + + @BeforeEach + void beforeEach() { + this.channelRegistry = new GrpcChannelRegistry(); + } + + @Test + void createsNewChannelsAsRequested() { + assertNotNull(this.channelRegistry.forAddress("foo", 1000)); + } + + @Test + void reusesChannelsForDuplicateRequests() { + Channel firstChannel = this.channelRegistry.forAddress("foo", 1000); + assertSame(firstChannel, this.channelRegistry.forAddress("foo", 1000)); + assertNotSame(firstChannel, this.channelRegistry.forAddress("foo", 1001)); + assertNotSame(firstChannel, this.channelRegistry.forAddress("bar", 1000)); + } + + @Test + void shutdownAllChannelsOnShutdown() { + ManagedChannel firstChannel = this.channelRegistry.forAddress("foo", 1000); + ManagedChannel secondChannel = this.channelRegistry.forAddress("foo", 1002); + assertFalse(firstChannel.isShutdown()); + assertFalse(secondChannel.isShutdown()); + this.channelRegistry.shutdown(); + assertTrue(firstChannel.isShutdown()); + assertTrue(secondChannel.isShutdown()); + } + + @Test + void throwsIfNewChannelRequestedAfterShutdown() { + this.channelRegistry.shutdown(); + assertThrows(AssertionError.class, () -> this.channelRegistry.forAddress("foo", 1000)); + } +}