Skip to content

Commit

Permalink
fix: Create a pool of Channels for each target.
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelmenzella-google committed Sep 10, 2020
1 parent 61b1981 commit 36b472b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,27 @@
package com.google.cloud.pubsublite.internal;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.flogger.GoogleLogger;
import io.grpc.Channel;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.util.Deque;
import java.util.LinkedList;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

/** A ChannelCache creates and stores default channels for use with api methods. */
public class ChannelCache {
private static final GoogleLogger log = GoogleLogger.forEnclosingClass();

private final Function<String, ManagedChannel> channelFactory;
private final ConcurrentHashMap<String, ManagedChannel> channels = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, Deque<ManagedChannel>> channels =
new ConcurrentHashMap<>();

private static final int NUMBER_OF_CHANNELS_PER_TARGET = 10;
private static final String NUMBER_OF_CHANNELS_PER_TARGET_VM_OVERRIDE =
"google.cloud.pubsublite.channelCacheSize";

public ChannelCache() {
this(ChannelCache::newChannel);
Expand All @@ -40,20 +50,45 @@ public ChannelCache() {
}

@VisibleForTesting
void onShutdown() {
synchronized void onShutdown() {
channels.forEachValue(
channels.size(),
channel -> {
channels -> {
try {
channel.shutdownNow().awaitTermination(60, TimeUnit.SECONDS);
for (ManagedChannel channel : channels) {
channel.shutdownNow().awaitTermination(60, TimeUnit.SECONDS);
}
} catch (InterruptedException e) {
e.printStackTrace();
}
});
}

public Channel get(String target) {
return channels.computeIfAbsent(target, channelFactory);
public synchronized Channel get(String target) {
Deque<ManagedChannel> channelQueue = channels.computeIfAbsent(target, this::newChannels);
ManagedChannel channel = channelQueue.removeFirst();
channelQueue.addLast(channel);
return channel;
}

private Deque<ManagedChannel> newChannels(String target) {
int numberOfChannels = NUMBER_OF_CHANNELS_PER_TARGET;
String numberOfChannelsOverride = System.getProperty(NUMBER_OF_CHANNELS_PER_TARGET_VM_OVERRIDE);
if (numberOfChannelsOverride != null && !numberOfChannelsOverride.isEmpty()) {
try {
numberOfChannels = Integer.parseInt((numberOfChannelsOverride));
} catch (NumberFormatException e) {
log.atSevere().log(
"Unable to parse override for number of channels per target: %s",
numberOfChannelsOverride);
}
}

Deque<ManagedChannel> channels = new LinkedList<>();
for (int i = 0; i < numberOfChannels; i++) {
channels.add(channelFactory.apply(target));
}
return channels;
}

private static ManagedChannel newChannel(String target) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
package com.google.cloud.pubsublite.internal;

import static com.google.common.truth.Truth.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.MockitoAnnotations.initMocks;

import io.grpc.Channel;
import io.grpc.ManagedChannel;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Function;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -34,7 +34,6 @@

@RunWith(JUnit4.class)
public class ChannelCacheTest {
@Mock ManagedChannel mockChannel;
@Mock Function<String, ManagedChannel> channelFactory;

@Before
Expand All @@ -44,14 +43,22 @@ public void setUp() {

@Test
public void reusesChannels() {
when(channelFactory.apply(any())).thenReturn(mockChannel);
when(channelFactory.apply("abc"))
.thenAnswer(
(target) -> {
ManagedChannel channel = mock(ManagedChannel.class);
when(channel.shutdownNow()).thenReturn(channel);
return channel;
});
ChannelCache cache = new ChannelCache(channelFactory);
Channel chan1 = cache.get("abc");
Channel chan2 = cache.get("abc");
assertThat(chan1).isEqualTo(chan2);
verify(channelFactory, times(1)).apply("abc");
when(mockChannel.shutdownNow()).thenReturn(mockChannel);

// Only 10 Channels should be created.
Set<Channel> channels = new HashSet<>();
for (int i = 0; i < 20; i++) {
channels.add(cache.get("abc"));
}

assertThat(channels.size()).isEqualTo(10);
cache.onShutdown();
verify(mockChannel, times(1)).shutdownNow();
}
}

0 comments on commit 36b472b

Please sign in to comment.