From 1e0d82e9f00e430f9d6c903b814d1aea0a6566e8 Mon Sep 17 00:00:00 2001
From: Zach Reyes <39203661+zasweq@users.noreply.github.com>
Date: Tue, 5 Sep 2023 16:27:51 -0400
Subject: [PATCH] balancer/leastrequest: Cache atomic load and also add concurrent rpc test (#6602)

---
 balancer/leastrequest/balancer_test.go | 55 ++++++++++++++++++++++++++
 balancer/leastrequest/leastrequest.go  |  9 ++---
 2 files changed, 59 insertions(+), 5 deletions(-)

diff --git a/balancer/leastrequest/balancer_test.go b/balancer/leastrequest/balancer_test.go
index 39bf1b94abd..44bb21c9e9f 100644
--- a/balancer/leastrequest/balancer_test.go
+++ b/balancer/leastrequest/balancer_test.go
@@ -22,6 +22,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -455,3 +456,57 @@ func (s) TestLeastRequestPersistsCounts(t *testing.T) {
 		t.Fatalf("addr count (-got:, +want): %v", diff)
 	}
 }
+
+// TestConcurrentRPCs tests concurrent RPCs on the least request balancer. It
+// configures a channel with a least request balancer as the top level balancer,
+// and makes 100 RPCs asynchronously. This makes sure no race conditions happen
+// in this scenario.
+func (s) TestConcurrentRPCs(t *testing.T) {
+	addresses := setupBackends(t)
+
+	mr := manual.NewBuilderWithScheme("lr-e2e")
+	defer mr.Close()
+
+	// Configure least request as top level balancer of channel.
+	lrscJSON := `
+{
+  "loadBalancingConfig": [
+    {
+      "least_request_experimental": {
+        "choiceCount": 2
+      }
+    }
+  ]
+}`
+	sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(lrscJSON)
+	firstTwoAddresses := []resolver.Address{
+		{Addr: addresses[0]},
+		{Addr: addresses[1]},
+	}
+	mr.InitialState(resolver.State{
+		Addresses:     firstTwoAddresses,
+		ServiceConfig: sc,
+	})
+
+	cc, err := grpc.Dial(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()))
+	if err != nil {
+		t.Fatalf("grpc.Dial() failed: %v", err)
+	}
+	defer cc.Close()
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	testServiceClient := testgrpc.NewTestServiceClient(cc)
+
+	var wg sync.WaitGroup
+	for i := 0; i < 100; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for j := 0; j < 5; j++ {
+				testServiceClient.EmptyCall(ctx, &testpb.Empty{})
+			}
+		}()
+	}
+	wg.Wait()
+
+}
diff --git a/balancer/leastrequest/leastrequest.go b/balancer/leastrequest/leastrequest.go
index 6ef86dc267e..3289f2869f8 100644
--- a/balancer/leastrequest/leastrequest.go
+++ b/balancer/leastrequest/leastrequest.go
@@ -155,15 +155,14 @@ type picker struct {
 
 func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
 	var pickedSC *scWithRPCCount
+	var pickedSCNumRPCs int32
 	for i := 0; i < int(p.choiceCount); i++ {
 		index := grpcranduint32() % uint32(len(p.subConns))
 		sc := p.subConns[index]
-		if pickedSC == nil {
-			pickedSC = &sc
-			continue
-		}
-		if sc.numRPCs.Load() < pickedSC.numRPCs.Load() {
+		n := sc.numRPCs.Load()
+		if pickedSC == nil || n < pickedSCNumRPCs {
 			pickedSC = &sc
+			pickedSCNumRPCs = n
 		}
 	}
 	// "The counter for a subchannel should be atomically incremented by one
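
Note (not part of the patch): the Pick change above loads each sampled SubConn's RPC counter exactly once and caches the value of the current best pick, instead of re-loading the picked SubConn's counter on every comparison. The sketch below is a minimal, standalone illustration of that power-of-two-choices pattern; the names (backend, pickLeastLoaded) are hypothetical and are not the identifiers used in the gRPC code base.

package main

import (
	"fmt"
	"math/rand"
	"sync/atomic"
)

// backend is a stand-in for a subchannel with an atomic in-flight RPC counter.
type backend struct {
	addr    string
	numRPCs atomic.Int32
}

// pickLeastLoaded samples choiceCount random backends and returns the one
// whose counter was lowest at the moment it was loaded. Each counter is read
// with a single atomic load, and the winning value is cached for later
// comparisons rather than re-read.
func pickLeastLoaded(backends []*backend, choiceCount int) *backend {
	var picked *backend
	var pickedNumRPCs int32
	for i := 0; i < choiceCount; i++ {
		candidate := backends[rand.Intn(len(backends))]
		n := candidate.numRPCs.Load() // one atomic load per candidate
		if picked == nil || n < pickedNumRPCs {
			picked = candidate
			pickedNumRPCs = n // cached snapshot used for the remaining comparisons
		}
	}
	return picked
}

func main() {
	backends := []*backend{{addr: "a"}, {addr: "b"}, {addr: "c"}}
	backends[0].numRPCs.Store(3)
	backends[1].numRPCs.Store(1)

	picked := pickLeastLoaded(backends, 2)
	picked.numRPCs.Add(1) // increment after a successful pick
	fmt.Println("picked", picked.addr)
}

Caching the loaded value also keeps each comparison internally consistent: the decision is made against a snapshot taken during this pick, even if other goroutines are concurrently incrementing or decrementing the counters.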