Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
Expand Down Expand Up @@ -110,12 +111,14 @@ static final class RoundRobinLoadBalancer extends LoadBalancer {
private final Helper helper;
private final Map<EquivalentAddressGroup, Subchannel> subchannels =
new HashMap<EquivalentAddressGroup, Subchannel>();
private final Random random;

@Nullable
private StickinessState stickinessState;

RoundRobinLoadBalancer(Helper helper) {
this.helper = checkNotNull(helper, "helper");
this.random = new Random();
}

@Override
Expand Down Expand Up @@ -211,7 +214,12 @@ public void shutdown() {
*/
private void updateBalancingState(ConnectivityState state, Status error) {
List<Subchannel> activeList = filterNonFailingSubchannels(getSubchannels());
helper.updateBalancingState(state, new Picker(activeList, error, stickinessState));
// initialize the Picker to a random start index to ensure that a high frequency of Picker
// churn does not skew subchannel selection.
int startIndex = activeList.isEmpty() ? 0 : random.nextInt(activeList.size());
helper.updateBalancingState(
state,
new Picker(activeList, error, startIndex, stickinessState));
}

/**
Expand Down Expand Up @@ -388,14 +396,15 @@ static final class Picker extends SubchannelPicker {
@Nullable
private final RoundRobinLoadBalancer.StickinessState stickinessState;
@SuppressWarnings("unused")
private volatile int index = -1; // start off at -1 so the address on first use is 0.
private volatile int index;

Picker(
List<Subchannel> list, @Nullable Status status,
List<Subchannel> list, @Nullable Status status, int startIndex,
@Nullable RoundRobinLoadBalancer.StickinessState stickinessState) {
this.list = list;
this.status = status;
this.stickinessState = stickinessState;
this.index = startIndex - 1;
}

@Override
Expand Down
102 changes: 57 additions & 45 deletions core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,19 @@ public void pickAfterStateChange() throws Exception {
verifyNoMoreInteractions(mockHelper);
}

private Subchannel nextSubchannel(Subchannel current, List<Subchannel> allSubChannels) {
return allSubChannels.get((allSubChannels.indexOf(current) + 1) % allSubChannels.size());
}

@Test
public void pickerRoundRobin() throws Exception {
Subchannel subchannel = mock(Subchannel.class);
Subchannel subchannel1 = mock(Subchannel.class);
Subchannel subchannel2 = mock(Subchannel.class);

Picker picker = new Picker(Collections.unmodifiableList(Lists.<Subchannel>newArrayList(
subchannel, subchannel1, subchannel2)), null /* status */, null /* stickinessState */);
Picker picker = new Picker(Collections.unmodifiableList(Lists.newArrayList(
subchannel, subchannel1, subchannel2)), null /* status */, 0 /* startIndex */,
null /* stickinessState */);

assertThat(picker.getList()).containsExactly(subchannel, subchannel1, subchannel2);

Expand All @@ -292,7 +297,7 @@ public void pickerRoundRobin() throws Exception {
@Test
public void pickerEmptyList() throws Exception {
Picker picker =
new Picker(Lists.<Subchannel>newArrayList(), Status.UNKNOWN, null /* stickinessState */);
new Picker(Lists.<Subchannel>newArrayList(), Status.UNKNOWN, 0, null /* stickinessState */);

assertEquals(null, picker.pickSubchannel(mockArgs).getSubchannel());
assertEquals(Status.UNKNOWN,
Expand Down Expand Up @@ -395,14 +400,16 @@ public void noStickinessEnabled_withStickyHeader() {
headerWithStickinessValue.put(stickinessKey, "my-sticky-value");
doReturn(headerWithStickinessValue).when(mockArgs).getHeaders();

Iterator<Subchannel> subchannelIterator = loadBalancer.getSubchannels().iterator();
Subchannel sc1 = subchannelIterator.next();
Subchannel sc2 = subchannelIterator.next();
Subchannel sc3 = subchannelIterator.next();
assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel());
assertEquals(sc3, picker.pickSubchannel(mockArgs).getSubchannel());
assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
List<Subchannel> allSubchannels = picker.getList();
Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();
Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel();
Subchannel sc3 = picker.pickSubchannel(mockArgs).getSubchannel();
Subchannel sc4 = picker.pickSubchannel(mockArgs).getSubchannel();

assertEquals(nextSubchannel(sc1, allSubchannels), sc2);
assertEquals(nextSubchannel(sc2, allSubchannels), sc3);
assertEquals(nextSubchannel(sc3, allSubchannels), sc1);
assertEquals(sc4, sc1);

assertNull(loadBalancer.getStickinessMapForTest());
}
Expand All @@ -423,15 +430,17 @@ public void stickinessEnabled_withoutStickyHeader() {

doReturn(new Metadata()).when(mockArgs).getHeaders();

Iterator<Subchannel> subchannelIterator = loadBalancer.getSubchannels().iterator();
Subchannel sc1 = subchannelIterator.next();
Subchannel sc2 = subchannelIterator.next();
Subchannel sc3 = subchannelIterator.next();
assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel());
assertEquals(sc3, picker.pickSubchannel(mockArgs).getSubchannel());
assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
List<Subchannel> allSubchannels = picker.getList();

Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();
Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel();
Subchannel sc3 = picker.pickSubchannel(mockArgs).getSubchannel();
Subchannel sc4 = picker.pickSubchannel(mockArgs).getSubchannel();

assertEquals(nextSubchannel(sc1, allSubchannels), sc2);
assertEquals(nextSubchannel(sc2, allSubchannels), sc3);
assertEquals(nextSubchannel(sc3, allSubchannels), sc1);
assertEquals(sc4, sc1);
verify(mockArgs, times(4)).getHeaders();
assertNotNull(loadBalancer.getStickinessMapForTest());
assertThat(loadBalancer.getStickinessMapForTest()).isEmpty();
Expand All @@ -456,7 +465,7 @@ public void stickinessEnabled_withStickyHeader() {
headerWithStickinessValue.put(stickinessKey, "my-sticky-value");
doReturn(headerWithStickinessValue).when(mockArgs).getHeaders();

Subchannel sc1 = loadBalancer.getSubchannels().iterator().next();
Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();
assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
Expand Down Expand Up @@ -488,21 +497,24 @@ public void stickinessEnabled_withDifferentStickyHeaders() {
Metadata headerWithStickinessValue2 = new Metadata();
headerWithStickinessValue2.put(stickinessKey, "my-sticky-value2");

Iterator<Subchannel> subchannelIterator = loadBalancer.getSubchannels().iterator();
Subchannel sc1 = subchannelIterator.next();
Subchannel sc2 = subchannelIterator.next();
List<Subchannel> allSubchannels = picker.getList();

doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders();
assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
Subchannel sc1a = picker.pickSubchannel(mockArgs).getSubchannel();

doReturn(headerWithStickinessValue2).when(mockArgs).getHeaders();
assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel());
Subchannel sc2a = picker.pickSubchannel(mockArgs).getSubchannel();

doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders();
assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
Subchannel sc1b = picker.pickSubchannel(mockArgs).getSubchannel();

doReturn(headerWithStickinessValue2).when(mockArgs).getHeaders();
assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel());
Subchannel sc2b = picker.pickSubchannel(mockArgs).getSubchannel();

assertEquals(sc1a, sc1b);
assertEquals(sc2a, sc2b);
assertEquals(nextSubchannel(sc1a, allSubchannels), sc2a);
assertEquals(nextSubchannel(sc1b, allSubchannels), sc2b);

verify(mockArgs, atLeast(4)).getHeaders();
assertNotNull(loadBalancer.getStickinessMapForTest());
Expand All @@ -528,12 +540,10 @@ public void stickiness_goToTransientFailure_pick_backToReady() {
headerWithStickinessValue.put(stickinessKey, "my-sticky-value");
doReturn(headerWithStickinessValue).when(mockArgs).getHeaders();

Iterator<Subchannel> subchannelIterator = loadBalancer.getSubchannels().iterator();
Subchannel sc1 = subchannelIterator.next();
Subchannel sc2 = subchannelIterator.next();
List<Subchannel> allSubchannels = picker.getList();

// first pick
assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();

// go to transient failure
loadBalancer
Expand All @@ -542,18 +552,20 @@ public void stickiness_goToTransientFailure_pick_backToReady() {
verify(mockHelper, times(5))
.updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
picker = pickerCaptor.getValue();

// second pick
assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel());
Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel();

// go back to ready
loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY));

verify(mockHelper, times(6))
.updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
picker = pickerCaptor.getValue();
// third pick
assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel());

// third pick
Subchannel sc3 = picker.pickSubchannel(mockArgs).getSubchannel();
assertEquals(sc2, sc3);
verify(mockArgs, atLeast(3)).getHeaders();
assertNotNull(loadBalancer.getStickinessMapForTest());
assertThat(loadBalancer.getStickinessMapForTest()).hasSize(1);
Expand All @@ -578,12 +590,10 @@ public void stickiness_goToTransientFailure_backToReady_pick() {
headerWithStickinessValue1.put(stickinessKey, "my-sticky-value");
doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders();

Iterator<Subchannel> subchannelIterator = loadBalancer.getSubchannels().iterator();
Subchannel sc1 = subchannelIterator.next();
Subchannel sc2 = subchannelIterator.next();
List<Subchannel> allSubchannels = picker.getList();

// first pick
assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();

// go to transient failure
loadBalancer
Expand All @@ -595,8 +605,9 @@ public void stickiness_goToTransientFailure_backToReady_pick() {
verify(mockHelper, times(5))
.updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
picker = pickerCaptor.getValue();

// second pick with a different stickiness value
assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel());
Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel();

// go back to ready
loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY));
Expand All @@ -605,8 +616,10 @@ public void stickiness_goToTransientFailure_backToReady_pick() {
verify(mockHelper, times(6))
.updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
picker = pickerCaptor.getValue();

// third pick with my-sticky-value1
assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
Subchannel sc3 = picker.pickSubchannel(mockArgs).getSubchannel();
assertEquals(sc1, sc3);

verify(mockArgs, atLeast(3)).getHeaders();
assertNotNull(loadBalancer.getStickinessMapForTest());
Expand All @@ -632,18 +645,17 @@ public void stickiness_oneSubchannelShutdown() {
headerWithStickinessValue.put(stickinessKey, "my-sticky-value");
doReturn(headerWithStickinessValue).when(mockArgs).getHeaders();

Iterator<Subchannel> subchannelIterator = loadBalancer.getSubchannels().iterator();
Subchannel sc1 = subchannelIterator.next();
Subchannel sc2 = subchannelIterator.next();
List<Subchannel> allSubchannels = Lists.newArrayList(picker.getList());

assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();

loadBalancer
.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(ConnectivityState.SHUTDOWN));

assertNull(loadBalancer.getStickinessMapForTest().get("my-sticky-value").value);

assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel());
assertEquals(nextSubchannel(sc1, allSubchannels),
picker.pickSubchannel(mockArgs).getSubchannel());
assertThat(loadBalancer.getStickinessMapForTest()).hasSize(1);
verify(mockArgs, atLeast(2)).getHeaders();
}
Expand Down