Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 47 additions & 35 deletions core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,15 @@
import io.grpc.Status;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;

/**
* A {@link LoadBalancer} that provides round-robin load balancing mechanism over the
Expand All @@ -56,11 +54,23 @@
* what is then balanced across.
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1771")
public class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory {
public final class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory {

private static final RoundRobinLoadBalancerFactory INSTANCE =
new RoundRobinLoadBalancerFactory();

private RoundRobinLoadBalancerFactory() {
private RoundRobinLoadBalancerFactory() {}

/**
* A lighter weight Reference than AtomicReference.
*/
@VisibleForTesting
static final class Ref<T> {
T value;

Ref(T value) {
this.value = value;
}
}

/**
Expand All @@ -76,15 +86,15 @@ public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) {
}

@VisibleForTesting
static class RoundRobinLoadBalancer extends LoadBalancer {
static final class RoundRobinLoadBalancer extends LoadBalancer {
@VisibleForTesting
static final Attributes.Key<Ref<ConnectivityStateInfo>> STATE_INFO =
Attributes.Key.of("state-info");

private final Helper helper;
private final Map<EquivalentAddressGroup, Subchannel> subchannels =
new HashMap<EquivalentAddressGroup, Subchannel>();

@VisibleForTesting
static final Attributes.Key<AtomicReference<ConnectivityStateInfo>> STATE_INFO =
Attributes.Key.of("state-info");

RoundRobinLoadBalancer(Helper helper) {
this.helper = checkNotNull(helper, "helper");
}
Expand All @@ -106,12 +116,12 @@ public void handleResolvedAddressGroups(
// NB(lukaszx0): because attributes are immutable we can't set new value for the key
// after creation but since we can mutate the values we leverge that and set
// AtomicReference which will allow mutating state info for given channel.
.set(STATE_INFO, new AtomicReference<ConnectivityStateInfo>(
ConnectivityStateInfo.forNonError(IDLE)))
.set(
STATE_INFO, new Ref<ConnectivityStateInfo>(ConnectivityStateInfo.forNonError(IDLE)))
.build();

Subchannel subchannel = checkNotNull(helper.createSubchannel(addressGroup, subchannelAttrs),
"subchannel");
Subchannel subchannel =
checkNotNull(helper.createSubchannel(addressGroup, subchannelAttrs), "subchannel");
subchannels.put(addressGroup, subchannel);
subchannel.requestConnection();
}
Expand All @@ -132,13 +142,13 @@ public void handleNameResolutionError(Status error) {

@Override
public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) {
if (!subchannels.containsValue(subchannel)) {
if (subchannels.get(subchannel.getAddresses()) != subchannel) {
return;
}
if (stateInfo.getState() == IDLE) {
subchannel.requestConnection();
}
getSubchannelStateInfoRef(subchannel).set(stateInfo);
getSubchannelStateInfoRef(subchannel).value = stateInfo;
updateBalancingState(getAggregatedState(), getAggregatedError());
}

Expand All @@ -164,7 +174,7 @@ private static List<Subchannel> filterNonFailingSubchannels(
Collection<Subchannel> subchannels) {
List<Subchannel> readySubchannels = new ArrayList<Subchannel>(subchannels.size());
for (Subchannel subchannel : subchannels) {
if (getSubchannelStateInfoRef(subchannel).get().getState() == READY) {
if (getSubchannelStateInfoRef(subchannel).value.getState() == READY) {
readySubchannels.add(subchannel);
}
}
Expand All @@ -176,7 +186,7 @@ private static List<Subchannel> filterNonFailingSubchannels(
* remove all attributes.
*/
private static Set<EquivalentAddressGroup> stripAttrs(List<EquivalentAddressGroup> groupList) {
Set<EquivalentAddressGroup> addrs = new HashSet<EquivalentAddressGroup>();
Set<EquivalentAddressGroup> addrs = new HashSet<EquivalentAddressGroup>(groupList.size());
for (EquivalentAddressGroup group : groupList) {
addrs.add(new EquivalentAddressGroup(group.getAddresses()));
}
Expand All @@ -191,7 +201,7 @@ private static Set<EquivalentAddressGroup> stripAttrs(List<EquivalentAddressGrou
private Status getAggregatedError() {
Status status = null;
for (Subchannel subchannel : getSubchannels()) {
ConnectivityStateInfo stateInfo = getSubchannelStateInfoRef(subchannel).get();
ConnectivityStateInfo stateInfo = getSubchannelStateInfoRef(subchannel).value;
if (stateInfo.getState() != TRANSIENT_FAILURE) {
return null;
}
Expand All @@ -203,7 +213,7 @@ private Status getAggregatedError() {
private ConnectivityState getAggregatedState() {
Set<ConnectivityState> states = EnumSet.noneOf(ConnectivityState.class);
for (Subchannel subchannel : getSubchannels()) {
states.add(getSubchannelStateInfoRef(subchannel).get().getState());
states.add(getSubchannelStateInfoRef(subchannel).value.getState());
}
if (states.contains(READY)) {
return READY;
Expand All @@ -225,7 +235,7 @@ Collection<Subchannel> getSubchannels() {
return subchannels.values();
}

private static AtomicReference<ConnectivityStateInfo> getSubchannelStateInfoRef(
private static Ref<ConnectivityStateInfo> getSubchannelStateInfoRef(
Subchannel subchannel) {
return checkNotNull(subchannel.getAttributes().get(STATE_INFO), "STATE_INFO");
}
Expand All @@ -239,22 +249,23 @@ private static <T> Set<T> setsDifference(Set<T> a, Set<T> b) {

@VisibleForTesting
static final class Picker extends SubchannelPicker {
private static final AtomicIntegerFieldUpdater<Picker> indexUpdater =
AtomicIntegerFieldUpdater.newUpdater(Picker.class, "index");

@Nullable
private final Status status;
private final List<Subchannel> list;
private final int size;
@GuardedBy("this")
private int index = 0;
@SuppressWarnings("unused")
private volatile int index = -1; // start off at -1 so the address on first use is 0.

Picker(List<Subchannel> list, @Nullable Status status) {
this.list = Collections.unmodifiableList(list);
this.size = list.size();
this.list = list;
this.status = status;
}

@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
if (size > 0) {
if (list.size() > 0) {
return PickResult.withSubchannel(nextSubchannel());
}

Expand All @@ -266,17 +277,18 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {
}

private Subchannel nextSubchannel() {
if (size == 0) {
if (list.isEmpty()) {
throw new NoSuchElementException();
}
synchronized (this) {
Subchannel val = list.get(index);
index++;
if (index >= size) {
index = 0;
}
return val;
int size = list.size();

int i = indexUpdater.incrementAndGet(this);
if (i >= size) {
int oldi = i;
i %= size;
indexUpdater.compareAndSet(this, oldi, i);
}
return list.get(i);
}

@VisibleForTesting
Expand Down
34 changes: 10 additions & 24 deletions core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@
import io.grpc.LoadBalancer.Subchannel;
import io.grpc.Status;
import io.grpc.util.RoundRobinLoadBalancerFactory.Picker;
import io.grpc.util.RoundRobinLoadBalancerFactory.Ref;
import io.grpc.util.RoundRobinLoadBalancerFactory.RoundRobinLoadBalancer;
import java.net.SocketAddress;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -98,7 +98,9 @@ public void setUp() {
SocketAddress addr = new FakeSocketAddress("server" + i);
EquivalentAddressGroup eag = new EquivalentAddressGroup(addr);
servers.add(eag);
subchannels.put(eag, mock(Subchannel.class));
Subchannel sc = mock(Subchannel.class);
when(sc.getAddresses()).thenReturn(eag);
subchannels.put(eag, sc);
}

when(mockHelper.createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)))
Expand Down Expand Up @@ -155,7 +157,7 @@ public void pickAfterResolvedUpdatedHosts() throws Exception {
for (Subchannel subchannel : Lists.newArrayList(removedSubchannel, oldSubchannel,
newSubchannel)) {
when(subchannel.getAttributes()).thenReturn(Attributes.newBuilder().set(STATE_INFO,
new AtomicReference<ConnectivityStateInfo>(
new Ref<ConnectivityStateInfo>(
ConnectivityStateInfo.forNonError(READY))).build());
}

Expand Down Expand Up @@ -223,44 +225,28 @@ public Subchannel answer(InvocationOnMock invocation) throws Throwable {
verifyNoMoreInteractions(mockHelper);
}

@Test
public void pickAfterStateChangeBeforeResolution() throws Exception {
loadBalancer.handleSubchannelState(mockSubchannel,
ConnectivityStateInfo.forNonError(READY));
verifyNoMoreInteractions(mockSubchannel);
verifyNoMoreInteractions(mockHelper);
}

@Test
public void pickAfterStateChangeAndResolutionError() throws Exception {
loadBalancer.handleSubchannelState(mockSubchannel,
ConnectivityStateInfo.forNonError(READY));
verifyNoMoreInteractions(mockSubchannel);
verifyNoMoreInteractions(mockHelper);
}

@Test
public void pickAfterStateChange() throws Exception {
InOrder inOrder = inOrder(mockHelper);
loadBalancer.handleResolvedAddressGroups(servers, Attributes.EMPTY);
Subchannel subchannel = loadBalancer.getSubchannels().iterator().next();
AtomicReference<ConnectivityStateInfo> subchannelStateInfo = subchannel.getAttributes().get(
Ref<ConnectivityStateInfo> subchannelStateInfo = subchannel.getAttributes().get(
STATE_INFO);

inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(Picker.class));
assertThat(subchannelStateInfo.get()).isEqualTo(ConnectivityStateInfo.forNonError(IDLE));
assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(IDLE));

loadBalancer.handleSubchannelState(subchannel,
ConnectivityStateInfo.forNonError(READY));
inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture());
assertNull(pickerCaptor.getValue().getStatus());
assertThat(subchannelStateInfo.get()).isEqualTo(
assertThat(subchannelStateInfo.value).isEqualTo(
ConnectivityStateInfo.forNonError(READY));

Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯");
loadBalancer.handleSubchannelState(subchannel,
ConnectivityStateInfo.forTransientFailure(error));
assertThat(subchannelStateInfo.get()).isEqualTo(
assertThat(subchannelStateInfo.value).isEqualTo(
ConnectivityStateInfo.forTransientFailure(error));
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
assertNull(pickerCaptor.getValue().getStatus());
Expand All @@ -269,7 +255,7 @@ public void pickAfterStateChange() throws Exception {
ConnectivityStateInfo.forNonError(IDLE));
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
assertNull(pickerCaptor.getValue().getStatus());
assertThat(subchannelStateInfo.get()).isEqualTo(
assertThat(subchannelStateInfo.value).isEqualTo(
ConnectivityStateInfo.forNonError(IDLE));

verify(subchannel, times(2)).requestConnection();
Expand Down