Skip to content

Commit

Permalink
xds: Use weighted_target LB provider in wrr_locality (#9195)
Browse files Browse the repository at this point in the history
Fixes a bug where WrrLocalityLoadBalancer would use the endpoint picking policy provider instead of WeightedTargetLoadBalancerProvider.

Also adds a test to fake control plane integration test that caught this bug. The test scaffolding is also updated to have the test server echo all client headers back in the response.

The test load balancer in the test is an almost straight copy of: https://github.com/grpc/grpc-java/blob/master/interop-testing/src/test/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProviderTest.java
  • Loading branch information
temawi committed May 27, 2022
1 parent 52bac7e commit 61604ac
Show file tree
Hide file tree
Showing 4 changed files with 355 additions and 17 deletions.
10 changes: 9 additions & 1 deletion xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java
Expand Up @@ -18,10 +18,12 @@

import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import static io.grpc.xds.XdsLbPolicies.WEIGHTED_TARGET_POLICY_NAME;

import com.google.common.base.MoreObjects;
import io.grpc.InternalLogId;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancerRegistry;
import io.grpc.Status;
import io.grpc.internal.ServiceConfigUtil.PolicySelection;
import io.grpc.util.GracefulSwitchLoadBalancer;
Expand All @@ -43,9 +45,15 @@ final class WrrLocalityLoadBalancer extends LoadBalancer {
private final XdsLogger logger;
private final Helper helper;
private final GracefulSwitchLoadBalancer switchLb;
private final LoadBalancerRegistry lbRegistry;

WrrLocalityLoadBalancer(Helper helper) {
this(helper, LoadBalancerRegistry.getDefaultRegistry());
}

WrrLocalityLoadBalancer(Helper helper, LoadBalancerRegistry lbRegistry) {
this.helper = checkNotNull(helper, "helper");
this.lbRegistry = lbRegistry;
switchLb = new GracefulSwitchLoadBalancer(helper);
logger = XdsLogger.withLogId(
InternalLogId.allocate("xds-wrr-locality-lb", helper.getAuthority()));
Expand Down Expand Up @@ -88,7 +96,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
.setAttributes(resolvedAddresses.getAttributes().toBuilder()
.discard(InternalXdsAttributes.ATTR_LOCALITY_WEIGHTS).build()).build();

switchLb.switchTo(wrrLocalityConfig.childPolicy.getProvider());
switchLb.switchTo(lbRegistry.getProvider(WEIGHTED_TARGET_POLICY_NAME));
switchLb.handleResolvedAddresses(
resolvedAddresses.toBuilder()
.setLoadBalancingPolicyConfig(new WeightedTargetConfig(weightedPolicySelections))
Expand Down
129 changes: 129 additions & 0 deletions xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java
Expand Up @@ -17,23 +17,30 @@

package io.grpc.xds;

import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS;
import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS;
import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_LDS;
import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_RDS;
import static org.junit.Assert.assertEquals;

import com.github.xds.type.v3.TypedStruct;
import com.google.common.collect.ImmutableMap;
import com.google.protobuf.Any;
import com.google.protobuf.Message;
import com.google.protobuf.Struct;
import com.google.protobuf.UInt32Value;
import com.google.protobuf.Value;
import io.envoyproxy.envoy.config.cluster.v3.Cluster;
import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy;
import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy;
import io.envoyproxy.envoy.config.core.v3.Address;
import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource;
import io.envoyproxy.envoy.config.core.v3.ConfigSource;
import io.envoyproxy.envoy.config.core.v3.HealthStatus;
import io.envoyproxy.envoy.config.core.v3.SocketAddress;
import io.envoyproxy.envoy.config.core.v3.TrafficDirection;
import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig;
import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment;
import io.envoyproxy.envoy.config.endpoint.v3.Endpoint;
import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint;
Expand All @@ -53,12 +60,27 @@
import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager;
import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter;
import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds;
import io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3.WrrLocality;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.ForwardingClientCallListener;
import io.grpc.ForwardingServerCall.SimpleForwardingServerCall;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.InsecureServerCredentials;
import io.grpc.LoadBalancerRegistry;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.NameResolverRegistry;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.protobuf.SimpleRequest;
Expand Down Expand Up @@ -100,6 +122,7 @@ public class FakeControlPlaneXdsIntegrationTest {
private Server controlPlane;
private XdsTestControlPlaneService controlPlaneService;
private XdsNameResolverProvider nameResolverProvider;
private MetadataLoadBalancerProvider metadataLoadBalancerProvider;

protected int testServerPort = 0;
protected int controlPlaneServicePort;
Expand Down Expand Up @@ -135,10 +158,13 @@ public class FakeControlPlaneXdsIntegrationTest {
*/
@Before
public void setUp() throws Exception {
ClientXdsClient.enableCustomLbConfig = true;
startControlPlane();
nameResolverProvider = XdsNameResolverProvider.createForTest(SCHEME,
defaultBootstrapOverride());
NameResolverRegistry.getDefaultRegistry().register(nameResolverProvider);
metadataLoadBalancerProvider = new MetadataLoadBalancerProvider();
LoadBalancerRegistry.getDefaultRegistry().register(metadataLoadBalancerProvider);
}

@After
Expand All @@ -156,6 +182,7 @@ public void tearDown() throws Exception {
}
}
NameResolverRegistry.getDefaultRegistry().deregister(nameResolverProvider);
LoadBalancerRegistry.getDefaultRegistry().deregister(metadataLoadBalancerProvider);
}

@Test
Expand Down Expand Up @@ -186,7 +213,108 @@ serverHostName, clientListener(serverHostName)
assertEquals(goldenResponse, blockingStub.unaryRpc(request));
}

@Test
public void pingPong_metadataLoadBalancer() throws Exception {
String tcpListenerName = SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT;
String serverHostName = "test-server";
controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, ImmutableMap.of(
tcpListenerName, serverListener(tcpListenerName),
serverHostName, clientListener(serverHostName)
));
startServer(defaultBootstrapOverride());
controlPlaneService.setXdsConfig(ADS_TYPE_URL_RDS,
ImmutableMap.of(RDS_NAME, rds(serverHostName)));

// Use the LoadBalancingPolicy to configure a custom LB that adds a header to server calls.
Policy metadataLbPolicy = Policy.newBuilder().setTypedExtensionConfig(
TypedExtensionConfig.newBuilder().setTypedConfig(Any.pack(
TypedStruct.newBuilder().setTypeUrl("type.googleapis.com/test.MetadataLoadBalancer")
.setValue(Struct.newBuilder()
.putFields("metadataKey", Value.newBuilder().setStringValue("foo").build())
.putFields("metadataValue", Value.newBuilder().setStringValue("bar").build()))
.build()))).build();
Policy wrrLocalityPolicy = Policy.newBuilder()
.setTypedExtensionConfig(TypedExtensionConfig.newBuilder().setTypedConfig(
Any.pack(WrrLocality.newBuilder().setEndpointPickingPolicy(
LoadBalancingPolicy.newBuilder().addPolicies(metadataLbPolicy)).build()))).build();
controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS,
ImmutableMap.<String, Message>of(CLUSTER_NAME, cds().toBuilder().setLoadBalancingPolicy(
LoadBalancingPolicy.newBuilder()
.addPolicies(wrrLocalityPolicy)).build()));

InetSocketAddress edsInetSocketAddress = (InetSocketAddress) server.getListenSockets().get(0);
controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS,
ImmutableMap.<String, Message>of(EDS_NAME, eds(edsInetSocketAddress.getHostName(),
edsInetSocketAddress.getPort())));
ManagedChannel channel = Grpc.newChannelBuilder(SCHEME + ":///" + serverHostName,
InsecureChannelCredentials.create()).build();
ResponseHeaderClientInterceptor responseHeaderInterceptor
= new ResponseHeaderClientInterceptor();

// We add an interceptor to catch the response headers from the server.
blockingStub = SimpleServiceGrpc.newBlockingStub(channel)
.withInterceptors(responseHeaderInterceptor);
SimpleRequest request = SimpleRequest.newBuilder()
.build();
SimpleResponse goldenResponse = SimpleResponse.newBuilder()
.setResponseMessage("Hi, xDS!")
.build();
assertEquals(goldenResponse, blockingStub.unaryRpc(request));

// Make sure we got back the header we configured the LB with.
assertThat(responseHeaderInterceptor.reponseHeaders.get(
Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER))).isEqualTo("bar");
}

// Captures response headers from the server.
private class ResponseHeaderClientInterceptor implements ClientInterceptor {
Metadata reponseHeaders;

@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method,
CallOptions callOptions, Channel next) {

return new SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {
@Override
public void start(ClientCall.Listener<RespT> responseListener, Metadata headers) {
super.start(new ForwardingClientCallListener<RespT>() {
@Override
protected ClientCall.Listener<RespT> delegate() {
return responseListener;
}

@Override
public void onHeaders(Metadata headers) {
reponseHeaders = headers;
}
}, headers);
}
};
}
}

private void startServer(Map<String, ?> bootstrapOverride) throws Exception {
ServerInterceptor metadataInterceptor = new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
Metadata requestHeaders, ServerCallHandler<ReqT, RespT> next) {
logger.fine("Received following metadata: " + requestHeaders);

return next.startCall(new SimpleForwardingServerCall<ReqT, RespT>(call) {
@Override
public void sendHeaders(Metadata responseHeaders) {
responseHeaders.merge(requestHeaders);
super.sendHeaders(responseHeaders);
}

@Override
public void close(Status status, Metadata trailers) {
super.close(status, trailers);
}
}, requestHeaders);
}
};

SimpleServiceGrpc.SimpleServiceImplBase simpleServiceImpl =
new SimpleServiceGrpc.SimpleServiceImplBase() {
@Override
Expand All @@ -202,6 +330,7 @@ public void unaryRpc(
XdsServerBuilder serverBuilder = XdsServerBuilder.forPort(
0, InsecureServerCredentials.create())
.addService(simpleServiceImpl)
.intercept(metadataInterceptor)
.overrideBootstrapForTest(bootstrapOverride);
server = serverBuilder.build().start();
testServerPort = server.getPort();
Expand Down

0 comments on commit 61604ac

Please sign in to comment.