From 827760e7734884ae6cbbf975d06f08aa07af4933 Mon Sep 17 00:00:00 2001 From: Carl Mastrangelo Date: Fri, 13 Oct 2017 13:26:36 -0700 Subject: [PATCH] util: improve scalabiltiy of RR load balancer In relative order of importance: * Don't acquire lock when picking subchannel. * Use O(1) lookup for updating channel state * Use non synchronized ref instead of AtomicReference * Dont store size in picker. * make class final * remove test that was not valid --- .../util/RoundRobinLoadBalancerFactory.java | 82 +++++++++++-------- .../grpc/util/RoundRobinLoadBalancerTest.java | 34 +++----- 2 files changed, 57 insertions(+), 59 deletions(-) diff --git a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java index 6b936534146..b0c7cd5b2b1 100644 --- a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java +++ b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java @@ -37,7 +37,6 @@ import io.grpc.Status; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.EnumSet; import java.util.HashMap; import java.util.HashSet; @@ -45,9 +44,8 @@ import java.util.Map; import java.util.NoSuchElementException; import java.util.Set; -import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * A {@link LoadBalancer} that provides round-robin load balancing mechanism over the @@ -56,11 +54,23 @@ * what is then balanced across. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1771") -public class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory { +public final class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory { + private static final RoundRobinLoadBalancerFactory INSTANCE = new RoundRobinLoadBalancerFactory(); - private RoundRobinLoadBalancerFactory() { + private RoundRobinLoadBalancerFactory() {} + + /** + * A lighter weight Reference than AtomicReference. + */ + @VisibleForTesting + static final class Ref { + T value; + + Ref(T value) { + this.value = value; + } } /** @@ -76,15 +86,15 @@ public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { } @VisibleForTesting - static class RoundRobinLoadBalancer extends LoadBalancer { + static final class RoundRobinLoadBalancer extends LoadBalancer { + @VisibleForTesting + static final Attributes.Key> STATE_INFO = + Attributes.Key.of("state-info"); + private final Helper helper; private final Map subchannels = new HashMap(); - @VisibleForTesting - static final Attributes.Key> STATE_INFO = - Attributes.Key.of("state-info"); - RoundRobinLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); } @@ -106,12 +116,12 @@ public void handleResolvedAddressGroups( // NB(lukaszx0): because attributes are immutable we can't set new value for the key // after creation but since we can mutate the values we leverge that and set // AtomicReference which will allow mutating state info for given channel. - .set(STATE_INFO, new AtomicReference( - ConnectivityStateInfo.forNonError(IDLE))) + .set( + STATE_INFO, new Ref(ConnectivityStateInfo.forNonError(IDLE))) .build(); - Subchannel subchannel = checkNotNull(helper.createSubchannel(addressGroup, subchannelAttrs), - "subchannel"); + Subchannel subchannel = + checkNotNull(helper.createSubchannel(addressGroup, subchannelAttrs), "subchannel"); subchannels.put(addressGroup, subchannel); subchannel.requestConnection(); } @@ -132,13 +142,13 @@ public void handleNameResolutionError(Status error) { @Override public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { - if (!subchannels.containsValue(subchannel)) { + if (subchannels.get(subchannel.getAddresses()) != subchannel) { return; } if (stateInfo.getState() == IDLE) { subchannel.requestConnection(); } - getSubchannelStateInfoRef(subchannel).set(stateInfo); + getSubchannelStateInfoRef(subchannel).value = stateInfo; updateBalancingState(getAggregatedState(), getAggregatedError()); } @@ -164,7 +174,7 @@ private static List filterNonFailingSubchannels( Collection subchannels) { List readySubchannels = new ArrayList(subchannels.size()); for (Subchannel subchannel : subchannels) { - if (getSubchannelStateInfoRef(subchannel).get().getState() == READY) { + if (getSubchannelStateInfoRef(subchannel).value.getState() == READY) { readySubchannels.add(subchannel); } } @@ -176,7 +186,7 @@ private static List filterNonFailingSubchannels( * remove all attributes. */ private static Set stripAttrs(List groupList) { - Set addrs = new HashSet(); + Set addrs = new HashSet(groupList.size()); for (EquivalentAddressGroup group : groupList) { addrs.add(new EquivalentAddressGroup(group.getAddresses())); } @@ -191,7 +201,7 @@ private static Set stripAttrs(List states = EnumSet.noneOf(ConnectivityState.class); for (Subchannel subchannel : getSubchannels()) { - states.add(getSubchannelStateInfoRef(subchannel).get().getState()); + states.add(getSubchannelStateInfoRef(subchannel).value.getState()); } if (states.contains(READY)) { return READY; @@ -225,7 +235,7 @@ Collection getSubchannels() { return subchannels.values(); } - private static AtomicReference getSubchannelStateInfoRef( + private static Ref getSubchannelStateInfoRef( Subchannel subchannel) { return checkNotNull(subchannel.getAttributes().get(STATE_INFO), "STATE_INFO"); } @@ -239,22 +249,23 @@ private static Set setsDifference(Set a, Set b) { @VisibleForTesting static final class Picker extends SubchannelPicker { + private static final AtomicIntegerFieldUpdater indexUpdater = + AtomicIntegerFieldUpdater.newUpdater(Picker.class, "index"); + @Nullable private final Status status; private final List list; - private final int size; - @GuardedBy("this") - private int index = 0; + @SuppressWarnings("unused") + private volatile int index = -1; // start off at -1 so the address on first use is 0. Picker(List list, @Nullable Status status) { - this.list = Collections.unmodifiableList(list); - this.size = list.size(); + this.list = list; this.status = status; } @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - if (size > 0) { + if (list.size() > 0) { return PickResult.withSubchannel(nextSubchannel()); } @@ -266,17 +277,18 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { } private Subchannel nextSubchannel() { - if (size == 0) { + if (list.isEmpty()) { throw new NoSuchElementException(); } - synchronized (this) { - Subchannel val = list.get(index); - index++; - if (index >= size) { - index = 0; - } - return val; + int size = list.size(); + + int i = indexUpdater.incrementAndGet(this); + if (i >= size) { + int oldi = i; + i %= size; + indexUpdater.compareAndSet(this, oldi, i); } + return list.get(i); } @VisibleForTesting diff --git a/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java index e730f21dd10..0f26c76363a 100644 --- a/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java @@ -48,13 +48,13 @@ import io.grpc.LoadBalancer.Subchannel; import io.grpc.Status; import io.grpc.util.RoundRobinLoadBalancerFactory.Picker; +import io.grpc.util.RoundRobinLoadBalancerFactory.Ref; import io.grpc.util.RoundRobinLoadBalancerFactory.RoundRobinLoadBalancer; import java.net.SocketAddress; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -98,7 +98,9 @@ public void setUp() { SocketAddress addr = new FakeSocketAddress("server" + i); EquivalentAddressGroup eag = new EquivalentAddressGroup(addr); servers.add(eag); - subchannels.put(eag, mock(Subchannel.class)); + Subchannel sc = mock(Subchannel.class); + when(sc.getAddresses()).thenReturn(eag); + subchannels.put(eag, sc); } when(mockHelper.createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class))) @@ -155,7 +157,7 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { for (Subchannel subchannel : Lists.newArrayList(removedSubchannel, oldSubchannel, newSubchannel)) { when(subchannel.getAttributes()).thenReturn(Attributes.newBuilder().set(STATE_INFO, - new AtomicReference( + new Ref( ConnectivityStateInfo.forNonError(READY))).build()); } @@ -223,44 +225,28 @@ public Subchannel answer(InvocationOnMock invocation) throws Throwable { verifyNoMoreInteractions(mockHelper); } - @Test - public void pickAfterStateChangeBeforeResolution() throws Exception { - loadBalancer.handleSubchannelState(mockSubchannel, - ConnectivityStateInfo.forNonError(READY)); - verifyNoMoreInteractions(mockSubchannel); - verifyNoMoreInteractions(mockHelper); - } - - @Test - public void pickAfterStateChangeAndResolutionError() throws Exception { - loadBalancer.handleSubchannelState(mockSubchannel, - ConnectivityStateInfo.forNonError(READY)); - verifyNoMoreInteractions(mockSubchannel); - verifyNoMoreInteractions(mockHelper); - } - @Test public void pickAfterStateChange() throws Exception { InOrder inOrder = inOrder(mockHelper); loadBalancer.handleResolvedAddressGroups(servers, Attributes.EMPTY); Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); - AtomicReference subchannelStateInfo = subchannel.getAttributes().get( + Ref subchannelStateInfo = subchannel.getAttributes().get( STATE_INFO); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(Picker.class)); - assertThat(subchannelStateInfo.get()).isEqualTo(ConnectivityStateInfo.forNonError(IDLE)); + assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(IDLE)); loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertNull(pickerCaptor.getValue().getStatus()); - assertThat(subchannelStateInfo.get()).isEqualTo( + assertThat(subchannelStateInfo.value).isEqualTo( ConnectivityStateInfo.forNonError(READY)); Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯"); loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error)); - assertThat(subchannelStateInfo.get()).isEqualTo( + assertThat(subchannelStateInfo.value).isEqualTo( ConnectivityStateInfo.forTransientFailure(error)); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); assertNull(pickerCaptor.getValue().getStatus()); @@ -269,7 +255,7 @@ public void pickAfterStateChange() throws Exception { ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); assertNull(pickerCaptor.getValue().getStatus()); - assertThat(subchannelStateInfo.get()).isEqualTo( + assertThat(subchannelStateInfo.value).isEqualTo( ConnectivityStateInfo.forNonError(IDLE)); verify(subchannel, times(2)).requestConnection();