diff --git a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java index a8b03315fed..cfb989e5bc8 100644 --- a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java +++ b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java @@ -49,6 +49,7 @@ import java.util.Map; import java.util.NoSuchElementException; import java.util.Queue; +import java.util.Random; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; @@ -110,12 +111,14 @@ static final class RoundRobinLoadBalancer extends LoadBalancer { private final Helper helper; private final Map subchannels = new HashMap(); + private final Random random; @Nullable private StickinessState stickinessState; RoundRobinLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); + this.random = new Random(); } @Override @@ -211,7 +214,12 @@ public void shutdown() { */ private void updateBalancingState(ConnectivityState state, Status error) { List activeList = filterNonFailingSubchannels(getSubchannels()); - helper.updateBalancingState(state, new Picker(activeList, error, stickinessState)); + // initialize the Picker to a random start index to ensure that a high frequency of Picker + // churn does not skew subchannel selection. + int startIndex = activeList.isEmpty() ? 0 : random.nextInt(activeList.size()); + helper.updateBalancingState( + state, + new Picker(activeList, error, startIndex, stickinessState)); } /** @@ -388,14 +396,15 @@ static final class Picker extends SubchannelPicker { @Nullable private final RoundRobinLoadBalancer.StickinessState stickinessState; @SuppressWarnings("unused") - private volatile int index = -1; // start off at -1 so the address on first use is 0. + private volatile int index; Picker( - List list, @Nullable Status status, + List list, @Nullable Status status, int startIndex, @Nullable RoundRobinLoadBalancer.StickinessState stickinessState) { this.list = list; this.status = status; this.stickinessState = stickinessState; + this.index = startIndex - 1; } @Override diff --git a/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java index 257cdef9e53..e6d2dd23ee3 100644 --- a/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java @@ -272,14 +272,19 @@ public void pickAfterStateChange() throws Exception { verifyNoMoreInteractions(mockHelper); } + private Subchannel nextSubchannel(Subchannel current, List allSubChannels) { + return allSubChannels.get((allSubChannels.indexOf(current) + 1) % allSubChannels.size()); + } + @Test public void pickerRoundRobin() throws Exception { Subchannel subchannel = mock(Subchannel.class); Subchannel subchannel1 = mock(Subchannel.class); Subchannel subchannel2 = mock(Subchannel.class); - Picker picker = new Picker(Collections.unmodifiableList(Lists.newArrayList( - subchannel, subchannel1, subchannel2)), null /* status */, null /* stickinessState */); + Picker picker = new Picker(Collections.unmodifiableList(Lists.newArrayList( + subchannel, subchannel1, subchannel2)), null /* status */, 0 /* startIndex */, + null /* stickinessState */); assertThat(picker.getList()).containsExactly(subchannel, subchannel1, subchannel2); @@ -292,7 +297,7 @@ public void pickerRoundRobin() throws Exception { @Test public void pickerEmptyList() throws Exception { Picker picker = - new Picker(Lists.newArrayList(), Status.UNKNOWN, null /* stickinessState */); + new Picker(Lists.newArrayList(), Status.UNKNOWN, 0, null /* stickinessState */); assertEquals(null, picker.pickSubchannel(mockArgs).getSubchannel()); assertEquals(Status.UNKNOWN, @@ -395,14 +400,16 @@ public void noStickinessEnabled_withStickyHeader() { headerWithStickinessValue.put(stickinessKey, "my-sticky-value"); doReturn(headerWithStickinessValue).when(mockArgs).getHeaders(); - Iterator subchannelIterator = loadBalancer.getSubchannels().iterator(); - Subchannel sc1 = subchannelIterator.next(); - Subchannel sc2 = subchannelIterator.next(); - Subchannel sc3 = subchannelIterator.next(); - assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); - assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel()); - assertEquals(sc3, picker.pickSubchannel(mockArgs).getSubchannel()); - assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + List allSubchannels = picker.getList(); + Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel(); + Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel(); + Subchannel sc3 = picker.pickSubchannel(mockArgs).getSubchannel(); + Subchannel sc4 = picker.pickSubchannel(mockArgs).getSubchannel(); + + assertEquals(nextSubchannel(sc1, allSubchannels), sc2); + assertEquals(nextSubchannel(sc2, allSubchannels), sc3); + assertEquals(nextSubchannel(sc3, allSubchannels), sc1); + assertEquals(sc4, sc1); assertNull(loadBalancer.getStickinessMapForTest()); } @@ -423,15 +430,17 @@ public void stickinessEnabled_withoutStickyHeader() { doReturn(new Metadata()).when(mockArgs).getHeaders(); - Iterator subchannelIterator = loadBalancer.getSubchannels().iterator(); - Subchannel sc1 = subchannelIterator.next(); - Subchannel sc2 = subchannelIterator.next(); - Subchannel sc3 = subchannelIterator.next(); - assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); - assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel()); - assertEquals(sc3, picker.pickSubchannel(mockArgs).getSubchannel()); - assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + List allSubchannels = picker.getList(); + Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel(); + Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel(); + Subchannel sc3 = picker.pickSubchannel(mockArgs).getSubchannel(); + Subchannel sc4 = picker.pickSubchannel(mockArgs).getSubchannel(); + + assertEquals(nextSubchannel(sc1, allSubchannels), sc2); + assertEquals(nextSubchannel(sc2, allSubchannels), sc3); + assertEquals(nextSubchannel(sc3, allSubchannels), sc1); + assertEquals(sc4, sc1); verify(mockArgs, times(4)).getHeaders(); assertNotNull(loadBalancer.getStickinessMapForTest()); assertThat(loadBalancer.getStickinessMapForTest()).isEmpty(); @@ -456,7 +465,7 @@ public void stickinessEnabled_withStickyHeader() { headerWithStickinessValue.put(stickinessKey, "my-sticky-value"); doReturn(headerWithStickinessValue).when(mockArgs).getHeaders(); - Subchannel sc1 = loadBalancer.getSubchannels().iterator().next(); + Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel(); assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); @@ -488,21 +497,24 @@ public void stickinessEnabled_withDifferentStickyHeaders() { Metadata headerWithStickinessValue2 = new Metadata(); headerWithStickinessValue2.put(stickinessKey, "my-sticky-value2"); - Iterator subchannelIterator = loadBalancer.getSubchannels().iterator(); - Subchannel sc1 = subchannelIterator.next(); - Subchannel sc2 = subchannelIterator.next(); + List allSubchannels = picker.getList(); doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders(); - assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + Subchannel sc1a = picker.pickSubchannel(mockArgs).getSubchannel(); doReturn(headerWithStickinessValue2).when(mockArgs).getHeaders(); - assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel()); + Subchannel sc2a = picker.pickSubchannel(mockArgs).getSubchannel(); doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders(); - assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + Subchannel sc1b = picker.pickSubchannel(mockArgs).getSubchannel(); doReturn(headerWithStickinessValue2).when(mockArgs).getHeaders(); - assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel()); + Subchannel sc2b = picker.pickSubchannel(mockArgs).getSubchannel(); + + assertEquals(sc1a, sc1b); + assertEquals(sc2a, sc2b); + assertEquals(nextSubchannel(sc1a, allSubchannels), sc2a); + assertEquals(nextSubchannel(sc1b, allSubchannels), sc2b); verify(mockArgs, atLeast(4)).getHeaders(); assertNotNull(loadBalancer.getStickinessMapForTest()); @@ -528,12 +540,10 @@ public void stickiness_goToTransientFailure_pick_backToReady() { headerWithStickinessValue.put(stickinessKey, "my-sticky-value"); doReturn(headerWithStickinessValue).when(mockArgs).getHeaders(); - Iterator subchannelIterator = loadBalancer.getSubchannels().iterator(); - Subchannel sc1 = subchannelIterator.next(); - Subchannel sc2 = subchannelIterator.next(); + List allSubchannels = picker.getList(); // first pick - assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel(); // go to transient failure loadBalancer @@ -542,8 +552,9 @@ public void stickiness_goToTransientFailure_pick_backToReady() { verify(mockHelper, times(5)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); picker = pickerCaptor.getValue(); + // second pick - assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel()); + Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel(); // go back to ready loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY)); @@ -551,9 +562,10 @@ public void stickiness_goToTransientFailure_pick_backToReady() { verify(mockHelper, times(6)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); picker = pickerCaptor.getValue(); - // third pick - assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel()); + // third pick + Subchannel sc3 = picker.pickSubchannel(mockArgs).getSubchannel(); + assertEquals(sc2, sc3); verify(mockArgs, atLeast(3)).getHeaders(); assertNotNull(loadBalancer.getStickinessMapForTest()); assertThat(loadBalancer.getStickinessMapForTest()).hasSize(1); @@ -578,12 +590,10 @@ public void stickiness_goToTransientFailure_backToReady_pick() { headerWithStickinessValue1.put(stickinessKey, "my-sticky-value"); doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders(); - Iterator subchannelIterator = loadBalancer.getSubchannels().iterator(); - Subchannel sc1 = subchannelIterator.next(); - Subchannel sc2 = subchannelIterator.next(); + List allSubchannels = picker.getList(); // first pick - assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel(); // go to transient failure loadBalancer @@ -595,8 +605,9 @@ public void stickiness_goToTransientFailure_backToReady_pick() { verify(mockHelper, times(5)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); picker = pickerCaptor.getValue(); + // second pick with a different stickiness value - assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel()); + Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel(); // go back to ready loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY)); @@ -605,8 +616,10 @@ public void stickiness_goToTransientFailure_backToReady_pick() { verify(mockHelper, times(6)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); picker = pickerCaptor.getValue(); + // third pick with my-sticky-value1 - assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + Subchannel sc3 = picker.pickSubchannel(mockArgs).getSubchannel(); + assertEquals(sc1, sc3); verify(mockArgs, atLeast(3)).getHeaders(); assertNotNull(loadBalancer.getStickinessMapForTest()); @@ -632,18 +645,17 @@ public void stickiness_oneSubchannelShutdown() { headerWithStickinessValue.put(stickinessKey, "my-sticky-value"); doReturn(headerWithStickinessValue).when(mockArgs).getHeaders(); - Iterator subchannelIterator = loadBalancer.getSubchannels().iterator(); - Subchannel sc1 = subchannelIterator.next(); - Subchannel sc2 = subchannelIterator.next(); + List allSubchannels = Lists.newArrayList(picker.getList()); - assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel(); loadBalancer .handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(ConnectivityState.SHUTDOWN)); assertNull(loadBalancer.getStickinessMapForTest().get("my-sticky-value").value); - assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel()); + assertEquals(nextSubchannel(sc1, allSubchannels), + picker.pickSubchannel(mockArgs).getSubchannel()); assertThat(loadBalancer.getStickinessMapForTest()).hasSize(1); verify(mockArgs, atLeast(2)).getHeaders(); }