Skip to content

Commit

Permalink
core: implement Channel State API for RoundRobin
Browse files Browse the repository at this point in the history
  • Loading branch information
dapengzhang0 committed Aug 10, 2017
1 parent 04fd4bc commit e5ef92c
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 15 deletions.
33 changes: 28 additions & 5 deletions core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
package io.grpc.util;

import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.ConnectivityState.CONNECTING;
import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;

import com.google.common.annotations.VisibleForTesting;
import io.grpc.Attributes;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.ExperimentalApi;
Expand All @@ -36,6 +38,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -119,12 +122,12 @@ public void handleResolvedAddressGroups(
subchannel.shutdown();
}

updatePicker(getAggregatedError());
updateBalancingState(getAggregatedState(), getAggregatedError());
}

@Override
public void handleNameResolutionError(Status error) {
updatePicker(error);
updateBalancingState(TRANSIENT_FAILURE, error);
}

@Override
Expand All @@ -136,7 +139,7 @@ public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo s
subchannel.requestConnection();
}
getSubchannelStateInfoRef(subchannel).set(stateInfo);
updatePicker(getAggregatedError());
updateBalancingState(getAggregatedState(), getAggregatedError());
}

@Override
Expand All @@ -149,9 +152,9 @@ public void shutdown() {
/**
* Updates picker with the list of active subchannels (state == READY).
*/
private void updatePicker(@Nullable Status error) {
private void updateBalancingState(ConnectivityState state, Status error) {
List<Subchannel> activeList = filterNonFailingSubchannels(getSubchannels());
helper.updatePicker(new Picker(activeList, error));
helper.updateBalancingState(state, new Picker(activeList, error));
}

/**
Expand Down Expand Up @@ -197,6 +200,26 @@ private Status getAggregatedError() {
return status;
}

private ConnectivityState getAggregatedState() {
Set<ConnectivityState> states = EnumSet.noneOf(ConnectivityState.class);
for (Subchannel subchannel : getSubchannels()) {
states.add(getSubchannelStateInfoRef(subchannel).get().getState());
}
if (states.contains(READY)) {
return READY;
}
if (states.contains(CONNECTING)) {
return CONNECTING;
}
if (states.contains(IDLE)) {
// This subchannel IDLE is not because of channel IDLE_TIMEOUT, in which case LB is already
// shutdown.
// RRLB will request connection immediately on subchannel IDLE.
return CONNECTING;
}
return TRANSIENT_FAILURE;
}

@VisibleForTesting
Collection<Subchannel> getSubchannels() {
return subchannels.values();
Expand Down
43 changes: 33 additions & 10 deletions core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
package io.grpc.util;

import static com.google.common.truth.Truth.assertThat;
import static io.grpc.ConnectivityState.CONNECTING;
import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import static io.grpc.util.RoundRobinLoadBalancerFactory.RoundRobinLoadBalancer.STATE_INFO;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.inOrder;
Expand Down Expand Up @@ -77,6 +80,8 @@ public class RoundRobinLoadBalancerTest {
@Captor
private ArgumentCaptor<Picker> pickerCaptor;
@Captor
private ArgumentCaptor<ConnectivityState> stateCaptor;
@Captor
private ArgumentCaptor<EquivalentAddressGroup> eagCaptor;
@Mock
private Helper mockHelper;
Expand Down Expand Up @@ -131,8 +136,11 @@ public void pickAfterResolved() throws Exception {
verify(subchannel, never()).shutdown();
}

verify(mockHelper, times(2)).updatePicker(pickerCaptor.capture());
verify(mockHelper, times(2))
.updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());

assertEquals(CONNECTING, stateCaptor.getAllValues().get(0));
assertEquals(READY, stateCaptor.getAllValues().get(1));
assertThat(pickerCaptor.getValue().getList()).containsExactly(readySubchannel);

verifyNoMoreInteractions(mockHelper);
Expand Down Expand Up @@ -176,7 +184,7 @@ public Subchannel answer(InvocationOnMock invocation) throws Throwable {

InOrder inOrder = inOrder(mockHelper);

inOrder.verify(mockHelper).updatePicker(pickerCaptor.capture());
inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture());
Picker picker = pickerCaptor.getValue();
assertNull(picker.getStatus());
assertThat(picker.getList()).containsExactly(removedSubchannel, oldSubchannel);
Expand Down Expand Up @@ -206,7 +214,7 @@ public Subchannel answer(InvocationOnMock invocation) throws Throwable {

verify(mockHelper, times(3)).createSubchannel(any(EquivalentAddressGroup.class),
any(Attributes.class));
inOrder.verify(mockHelper).updatePicker(pickerCaptor.capture());
inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture());

picker = pickerCaptor.getValue();
assertNull(picker.getStatus());
Expand Down Expand Up @@ -239,12 +247,12 @@ public void pickAfterStateChange() throws Exception {
AtomicReference<ConnectivityStateInfo> subchannelStateInfo = subchannel.getAttributes().get(
STATE_INFO);

inOrder.verify(mockHelper).updatePicker(isA(Picker.class));
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(Picker.class));
assertThat(subchannelStateInfo.get()).isEqualTo(ConnectivityStateInfo.forNonError(IDLE));

loadBalancer.handleSubchannelState(subchannel,
ConnectivityStateInfo.forNonError(ConnectivityState.READY));
inOrder.verify(mockHelper).updatePicker(pickerCaptor.capture());
inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture());
assertNull(pickerCaptor.getValue().getStatus());
assertThat(subchannelStateInfo.get()).isEqualTo(
ConnectivityStateInfo.forNonError(ConnectivityState.READY));
Expand All @@ -254,12 +262,12 @@ public void pickAfterStateChange() throws Exception {
ConnectivityStateInfo.forTransientFailure(error));
assertThat(subchannelStateInfo.get()).isEqualTo(
ConnectivityStateInfo.forTransientFailure(error));
inOrder.verify(mockHelper).updatePicker(pickerCaptor.capture());
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
assertNull(pickerCaptor.getValue().getStatus());

loadBalancer.handleSubchannelState(subchannel,
ConnectivityStateInfo.forNonError(ConnectivityState.IDLE));
inOrder.verify(mockHelper).updatePicker(pickerCaptor.capture());
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
assertNull(pickerCaptor.getValue().getStatus());
assertThat(subchannelStateInfo.get()).isEqualTo(
ConnectivityStateInfo.forNonError(ConnectivityState.IDLE));
Expand Down Expand Up @@ -300,7 +308,7 @@ public void pickerEmptyList() throws Exception {
public void nameResolutionErrorWithNoChannels() throws Exception {
Status error = Status.NOT_FOUND.withDescription("nameResolutionError");
loadBalancer.handleNameResolutionError(error);
verify(mockHelper).updatePicker(pickerCaptor.capture());
verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs);
assertNull(pickResult.getSubchannel());
assertEquals(error, pickResult.getStatus());
Expand All @@ -316,7 +324,13 @@ public void nameResolutionErrorWithActiveChannels() throws Exception {

verify(mockHelper, times(3)).createSubchannel(any(EquivalentAddressGroup.class),
any(Attributes.class));
verify(mockHelper, times(3)).updatePicker(pickerCaptor.capture());
verify(mockHelper, times(3))
.updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());

Iterator<ConnectivityState> stateIterator = stateCaptor.getAllValues().iterator();
assertEquals(CONNECTING, stateIterator.next());
assertEquals(READY, stateIterator.next());
assertEquals(TRANSIENT_FAILURE, stateIterator.next());

LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs);
assertEquals(readySubchannel, pickResult.getSubchannel());
Expand Down Expand Up @@ -346,19 +360,28 @@ public void subchannelStateIsolation() throws Exception {
loadBalancer
.handleSubchannelState(sc3, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE));

verify(mockHelper, times(6)).updatePicker(pickerCaptor.capture());
verify(mockHelper, times(6))
.updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
Iterator<ConnectivityState> stateIterator = stateCaptor.getAllValues().iterator();
Iterator<Picker> pickers = pickerCaptor.getAllValues().iterator();
// The picker is incrementally updated as subchannels become READY
assertEquals(CONNECTING, stateIterator.next());
assertThat(pickers.next().getList()).isEmpty();
assertEquals(READY, stateIterator.next());
assertThat(pickers.next().getList()).containsExactly(sc1);
assertEquals(READY, stateIterator.next());
assertThat(pickers.next().getList()).containsExactly(sc1, sc2);
assertEquals(READY, stateIterator.next());
assertThat(pickers.next().getList()).containsExactly(sc1, sc2, sc3);
// The IDLE subchannel is dropped from the picker, but a reconnection is requested
assertEquals(READY, stateIterator.next());
assertThat(pickers.next().getList()).containsExactly(sc1, sc3);
verify(sc2, times(2)).requestConnection();
// The failing subchannel is dropped from the picker, with no requested reconnect
assertEquals(READY, stateIterator.next());
assertThat(pickers.next().getList()).containsExactly(sc1);
verify(sc3, times(1)).requestConnection();
assertThat(stateIterator.hasNext()).isFalse();
assertThat(pickers.hasNext()).isFalse();
}

Expand Down

0 comments on commit e5ef92c

Please sign in to comment.