Grpclb: Support server list expiration #962

Merged
merged 9 commits on Nov 16, 2016

80 changes: 65 additions & 15 deletions grpclb/grpclb.go
@@ -40,6 +40,7 @@ import (
"errors"
"fmt"
"sync"
"time"

"golang.org/x/net/context"
"google.golang.org/grpc"
@@ -93,16 +94,17 @@ type addrInfo struct {
}

type balancer struct {
r naming.Resolver
mu sync.Mutex
seq int // a sequence number to make sure addrCh does not get stale addresses.
w naming.Watcher
addrCh chan []grpc.Address
rbs []remoteBalancerInfo
addrs []*addrInfo
next int
waitCh chan struct{}
done bool
r naming.Resolver
mu sync.Mutex
seq int // a sequence number to make sure addrCh does not get stale addresses.
w naming.Watcher
addrCh chan []grpc.Address
rbs []remoteBalancerInfo
addrs []*addrInfo
next int
waitCh chan struct{}
done bool
expTimer *time.Timer // expiration timer for the current server list; nil if no expiration is set
}

func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo) error {
@@ -180,14 +182,39 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo
return nil
}

func (b *balancer) serverListExpire(seq int) {
b.mu.Lock()
defer b.mu.Unlock()
// TODO: gRPC internals do not clear the connections when the server list is stale.
// This means RPCs will keep using the existing server list until b receives a new
// server list, even though the current list has expired. Revisit this behavior later.
if b.done || seq < b.seq {
return
}
b.next = 0
b.addrs = nil
// Ask grpc internals to close all the corresponding connections.
b.addrCh <- nil
}

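// convertDuration converts a protobuf lbpb.Duration into a time.Duration.
// For example, &lbpb.Duration{Seconds: 0, Nanos: 100000000} yields 100ms.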
func convertDuration(d *lbpb.Duration) time.Duration {
if d == nil {
return 0
}
return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond
}

func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
if l == nil {
return
}
servers := l.GetServers()
expiration := convertDuration(l.GetExpirationInterval())
var (
sl []*addrInfo
addrs []grpc.Address
)
for _, s := range servers {
// TODO: Support ExpirationInterval
md := metadata.Pairs("lb-token", s.LoadBalanceToken)
addr := grpc.Address{
Addr: fmt.Sprintf("%s:%d", s.IpAddress, s.Port),
@@ -209,11 +236,20 @@ func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
b.next = 0
b.addrs = sl
b.addrCh <- addrs
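// Stop any timer armed for the previous server list before arming
// one for this list's expiration interval.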
if b.expTimer != nil {
b.expTimer.Stop()
b.expTimer = nil
}
if expiration > 0 {
b.expTimer = time.AfterFunc(expiration, func() {
b.serverListExpire(seq)
})
}
}
return
}

func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool) {
func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient, seq int) (retry bool) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
stream, err := lbc.BalanceLoad(ctx, grpc.FailFast(false))
@@ -226,8 +262,6 @@ func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool)
b.mu.Unlock()
return
}
b.seq++
seq := b.seq
b.mu.Unlock()
initReq := &lbpb.LoadBalanceRequest{
LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{
@@ -260,6 +294,14 @@ func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool)
if err != nil {
break
}
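// Discard this reply if the balancer has been closed or a newer
// sequence (e.g. a new remote balancer connection) has superseded
// this stream's seq.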
b.mu.Lock()
if b.done || seq < b.seq {
b.mu.Unlock()
return
}
b.seq++ // tick when receiving a new list of servers.
seq = b.seq
b.mu.Unlock()
if serverList := reply.GetServerList(); serverList != nil {
b.processServerList(serverList, seq)
}
@@ -326,10 +368,15 @@ func (b *balancer) Start(target string, config grpc.BalancerConfig) error {
grpclog.Printf("Failed to setup a connection to the remote balancer %v: %v", rb.addr, err)
return
}
b.mu.Lock()
b.seq++ // tick when getting a new balancer address
seq := b.seq
b.next = 0
b.mu.Unlock()
go func(cc *grpc.ClientConn) {
lbc := lbpb.NewLoadBalancerClient(cc)
for {
if retry := b.callRemoteBalancer(lbc); !retry {
if retry := b.callRemoteBalancer(lbc, seq); !retry {
cc.Close()
return
}
@@ -497,6 +544,9 @@ func (b *balancer) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
b.done = true
if b.expTimer != nil {
b.expTimer.Stop()
}
if b.waitCh != nil {
close(b.waitCh)
}
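
The heart of this change is the sequence-guarded expiration timer: every accepted server list ticks b.seq, the timer is armed with the seq it was created under, and serverListExpire clears the addresses only if no newer list (or new balancer connection) has bumped the sequence in the meantime. Below is a minimal standalone sketch of that pattern, assuming nothing from the grpclb package; the lb type and its methods are illustrative only:

package main

import (
	"fmt"
	"sync"
	"time"
)

type lb struct {
	mu    sync.Mutex
	seq   int // ticks on every accepted server list
	addrs []string
	timer *time.Timer
}

// update installs a fresh server list and arms an expiration timer for it.
func (b *lb) update(addrs []string, ttl time.Duration) {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.seq++ // any timer armed for an older list is now stale
	seq := b.seq
	b.addrs = addrs
	if b.timer != nil {
		b.timer.Stop()
		b.timer = nil
	}
	if ttl > 0 {
		b.timer = time.AfterFunc(ttl, func() { b.expire(seq) })
	}
}

// expire clears the list, but only if it is still the list the timer was armed for.
func (b *lb) expire(seq int) {
	b.mu.Lock()
	defer b.mu.Unlock()
	if seq < b.seq {
		return // a newer list arrived first; this timer is stale
	}
	b.addrs = nil
	fmt.Println("server list expired")
}

func main() {
	b := &lb{}
	b.update([]string{"10.0.0.1:80"}, 100*time.Millisecond)
	time.Sleep(150 * time.Millisecond) // timer fires at 100ms and clears the list
	b.update([]string{"10.0.0.2:80"}, 0) // no expiration interval: list never expires
	time.Sleep(50 * time.Millisecond)
}

Stopping the old timer before arming a new one, as processServerList does above, keeps at most one expiration pending per balancer.
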
124 changes: 109 additions & 15 deletions grpclb/grpclb_test.go
@@ -162,14 +162,16 @@ func (c *serverNameCheckCreds) OverrideServerName(s string) error {
}

type remoteBalancer struct {
servers *lbpb.ServerList
done chan struct{}
sls       []*lbpb.ServerList // server lists to send to the client, in order
intervals []time.Duration    // intervals[k] is the delay before sending sls[k]
done chan struct{}
}

func newRemoteBalancer(servers *lbpb.ServerList) *remoteBalancer {
func newRemoteBalancer(sls []*lbpb.ServerList, intervals []time.Duration) *remoteBalancer {
return &remoteBalancer{
servers: servers,
done: make(chan struct{}),
sls: sls,
intervals: intervals,
done: make(chan struct{}),
}
}

@@ -186,13 +188,16 @@ func (b *remoteBalancer) BalanceLoad(stream lbpb.LoadBalancer_BalanceLoadServer)
if err := stream.Send(resp); err != nil {
return err
}
resp = &lbpb.LoadBalanceResponse{
LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
ServerList: b.servers,
},
}
if err := stream.Send(resp); err != nil {
return err
for k, v := range b.sls {
time.Sleep(b.intervals[k])
resp = &lbpb.LoadBalanceResponse{
LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
ServerList: v,
},
}
if err := stream.Send(resp); err != nil {
return err
}
}
<-b.done
return nil
@@ -268,7 +273,9 @@ func TestGRPCLB(t *testing.T) {
sl := &lbpb.ServerList{
Servers: bes,
}
ls := newRemoteBalancer(sl)
sls := []*lbpb.ServerList{sl}
intervals := []time.Duration{0}
ls := newRemoteBalancer(sls, intervals)
lbpb.RegisterLoadBalancerServer(lb, ls)
go func() {
lb.Serve(lbLis)
@@ -343,7 +350,9 @@ func TestDropRequest(t *testing.T) {
sl := &lbpb.ServerList{
Servers: bes,
}
ls := newRemoteBalancer(sl)
sls := []*lbpb.ServerList{sl}
intervals := []time.Duration{0}
ls := newRemoteBalancer(sls, intervals)
lbpb.RegisterLoadBalancerServer(lb, ls)
go func() {
lb.Serve(lbLis)
@@ -413,7 +422,9 @@ func TestDropRequestFailedNonFailFast(t *testing.T) {
sl := &lbpb.ServerList{
Servers: bes,
}
ls := newRemoteBalancer(sl)
sls := []*lbpb.ServerList{sl}
intervals := []time.Duration{0}
ls := newRemoteBalancer(sls, intervals)
lbpb.RegisterLoadBalancerServer(lb, ls)
go func() {
lb.Serve(lbLis)
@@ -439,3 +450,86 @@ func TestDropRequestFailedNonFailFast(t *testing.T) {
}
cc.Close()
}

func TestServerExpiration(t *testing.T) {
// Start a backend.
beLis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen %v", err)
}
beAddr := strings.Split(beLis.Addr().String(), ":")
bePort, err := strconv.Atoi(beAddr[1])
backends := startBackends(t, besn, beLis)
defer stopBackends(backends)

// Start a load balancer.
lbLis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to create the listener for the load balancer %v", err)
}
lbCreds := &serverNameCheckCreds{
sn: lbsn,
}
lb := grpc.NewServer(grpc.Creds(lbCreds))
if err != nil {
t.Fatalf("Failed to generate the port number %v", err)
}
be := &lbpb.Server{
IpAddress: []byte(beAddr[0]),
Port: int32(bePort),
LoadBalanceToken: lbToken,
}
var bes []*lbpb.Server
bes = append(bes, be)
exp := &lbpb.Duration{
Seconds: 0,
Nanos: 100000000, // 100ms
}
var sls []*lbpb.ServerList
sl := &lbpb.ServerList{
Servers: bes,
ExpirationInterval: exp,
}
sls = append(sls, sl)
sl = &lbpb.ServerList{
Servers: bes,
}
sls = append(sls, sl)
var intervals []time.Duration
intervals = append(intervals, 0)
intervals = append(intervals, 500*time.Millisecond)
ls := newRemoteBalancer(sls, intervals)
lbpb.RegisterLoadBalancerServer(lb, ls)
go func() {
lb.Serve(lbLis)
}()
defer func() {
ls.stop()
lb.Stop()
}()
creds := serverNameCheckCreds{
expected: besn,
}
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
addr: lbLis.Addr().String(),
})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
helloC := hwpb.NewGreeterClient(cc)
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
}
// Sleep until the first server list expires.
time.Sleep(150 * time.Millisecond)
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); grpc.Code(err) != codes.Unavailable {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.Unavailable)
}
// A non-failfast RPC should succeed after the second server list is received from
// the remote load balancer.
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
}
cc.Close()
}
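
For reference, the 100ms expiration used in TestServerExpiration follows directly from convertDuration. Here is a standalone sanity check, where pbDuration is a hypothetical stand-in for the generated lbpb.Duration message and convertDuration is reproduced from the diff above:

package main

import (
	"fmt"
	"time"
)

// pbDuration stands in for the generated lbpb.Duration message.
type pbDuration struct {
	Seconds int64
	Nanos   int32
}

// convertDuration mirrors the helper added in grpclb.go:
// protobuf duration fields -> time.Duration.
func convertDuration(d *pbDuration) time.Duration {
	if d == nil {
		return 0
	}
	return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond
}

func main() {
	exp := &pbDuration{Seconds: 0, Nanos: 100000000}
	fmt.Println(convertDuration(exp)) // prints "100ms"
}

The test sleeps 150ms, past the first list's 100ms expiration but well before the second list, which the remote balancer sends only after a 500ms interval. In that window a fail-fast RPC fails with codes.Unavailable, while a non-failfast RPC blocks until the second, non-expiring list arrives.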