/
fake_grpclb.go
169 lines (162 loc) · 5.87 KB
/
fake_grpclb.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// This file is for testing only. Runs a fake grpclb balancer server.
// The name of the service to load balance for and the addresses
// of that service are provided by command line flags.
package main
import (
"flag"
"net"
"strconv"
"strings"
"time"
"google.golang.org/grpc"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/alts"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"
)
// port is the TCP port the fake balancer listens on (all interfaces).
var port = flag.Int("port", 10000, "Port to listen on.")

// backendAddrs is a comma-separated list of "ip:port" pairs; each entry is
// parsed into one server in the list handed to every client.
var backendAddrs = flag.String("backend_addrs", "", "Comma separated list of backend IP/port addresses.")

// useALTS serves with ALTS credentials. It is only consulted when useTLS is
// false.
var useALTS = flag.Bool("use_alts", false, "Listen on ALTS credentials.")

// useTLS serves with TLS using the server1.pem / server1.key test pair.
var useTLS = flag.Bool("use_tls", false, "Listen on TLS credentials, using a test certificate.")

// shortStream makes BalanceLoad return right after the first server list
// instead of re-sending it periodically.
var shortStream = flag.Bool("short_stream", false, "End the balancer stream immediately after sending the first server list.")

// serviceName is the name clients must request (optionally with a ":443"
// suffix) for the balancer to answer.
var serviceName = flag.String("service_name", "UNSET", "Name of the service being load balanced for.")
// serverListResendInterval is how often BalanceLoad re-sends the server list
// on a long-lived (non --short_stream) balancer stream.
const serverListResendInterval = 10 * time.Second

// loadBalancerServer is a fake implementation of the grpclb LoadBalancer
// service that hands every client the same, fixed server list.
type loadBalancerServer struct {
	// serverListResponse is the response sent to each client after the
	// initial response. This type only reads it.
	serverListResponse *lbpb.LoadBalanceResponse
}

// BalanceLoad handles one grpclb balancer stream.
//
// The first client message must be an InitialRequest whose Name equals
// --service_name, optionally suffixed with ":443"; otherwise the stream is
// failed with Unknown (wrong message type or port) or NotFound (wrong name).
// On success it sends an InitialResponse followed by the server list. With
// --short_stream the stream then ends; otherwise the server list is sent once
// more immediately and again every serverListResendInterval until a send
// fails or the client's stream context is done.
func (l *loadBalancerServer) BalanceLoad(stream lbpb.LoadBalancer_BalanceLoadServer) error {
	grpclog.Info("Begin handling new BalancerLoad request.")
	lbReq, err := stream.Recv()
	if err != nil {
		grpclog.Errorf("Error receiving LoadBalanceRequest: %v", err)
		return err
	}
	grpclog.Info("LoadBalancerRequest received.")
	initialReq := lbReq.GetInitialRequest()
	if initialReq == nil {
		grpclog.Infof("Expected first request to be an InitialRequest. Got: %v", lbReq)
		return status.Error(codes.Unknown, "First request not an InitialRequest")
	}
	// gRPC clients targeting foo.bar.com:443 can sometimes include the ":443"
	// suffix in their requested names; handle this case. A name without a
	// port fails SplitHostPort and is used verbatim.
	// TODO: make 443 configurable?
	cleanedName := initialReq.Name
	if host, namePort, splitErr := net.SplitHostPort(initialReq.Name); splitErr == nil {
		if namePort != "443" {
			grpclog.Infof("Bad requested service name port number: %v.", namePort)
			return status.Error(codes.Unknown, "Bad requested service name port number")
		}
		cleanedName = host
	}
	if cleanedName != *serviceName {
		grpclog.Infof("Expected requested service name: %v. Got: %v", *serviceName, initialReq.Name)
		return status.Error(codes.NotFound, "Bad requested service name")
	}
	if err := stream.Send(&lbpb.LoadBalanceResponse{
		LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{
			InitialResponse: &lbpb.InitialLoadBalanceResponse{},
		},
	}); err != nil {
		grpclog.Errorf("Error sending initial LB response: %v", err)
		return status.Error(codes.Unknown, "Error sending initial response")
	}
	if err := l.sendServerList(stream); err != nil {
		return err
	}
	if *shortStream {
		return nil
	}
	// Keep the stream alive, periodically re-sending the list. Watching the
	// stream context lets the handler exit as soon as the client goes away
	// instead of sleeping until the next (failing) Send.
	ctx := stream.Context()
	ticker := time.NewTicker(serverListResendInterval)
	defer ticker.Stop()
	for {
		if err := l.sendServerList(stream); err != nil {
			return err
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
		}
	}
}

// sendServerList logs and sends l.serverListResponse on stream. A send
// failure is logged and reported as a codes.Unknown status error.
func (l *loadBalancerServer) sendServerList(stream lbpb.LoadBalancer_BalanceLoadServer) error {
	grpclog.Infof("Send LoadBalanceResponse: %v", l.serverListResponse)
	if err := stream.Send(l.serverListResponse); err != nil {
		grpclog.Errorf("Error sending LB response: %v", err)
		return status.Error(codes.Unknown, "Error sending response")
	}
	return nil
}
// main parses flags, builds the fixed server list from --backend_addrs, and
// serves the fake LoadBalancer service on --port. Any configuration, listen
// or serve error is fatal.
func main() {
	flag.Parse()
	opts := serverOptions()
	serverListResponse := &lbpb.LoadBalanceResponse{
		LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
			ServerList: &lbpb.ServerList{
				Servers: parseBackendAddrs(*backendAddrs),
			},
		},
	}
	server := grpc.NewServer(opts...)
	grpclog.Infof("Begin listening on %d.", *port)
	lis, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
	if err != nil {
		grpclog.Fatalf("Failed to listen on port %v: %v", *port, err)
	}
	lbpb.RegisterLoadBalancerServer(server, &loadBalancerServer{
		serverListResponse: serverListResponse,
	})
	// Serve only returns on failure (or Stop); don't exit silently.
	if err := server.Serve(lis); err != nil {
		grpclog.Fatalf("Failed to serve: %v", err)
	}
}

// serverOptions returns the transport credentials option selected by
// --use_tls or --use_alts (TLS wins if both are set), or nil for insecure.
func serverOptions() []grpc.ServerOption {
	switch {
	case *useTLS:
		certFile := testdata.Path("server1.pem")
		keyFile := testdata.Path("server1.key")
		creds, err := credentials.NewServerTLSFromFile(certFile, keyFile)
		if err != nil {
			grpclog.Fatalf("Failed to generate credentials %v", err)
		}
		return []grpc.ServerOption{grpc.Creds(creds)}
	case *useALTS:
		altsTC := alts.NewServerCreds(alts.DefaultServerOptions())
		return []grpc.ServerOption{grpc.Creds(altsTC)}
	}
	return nil
}

// parseBackendAddrs converts a comma-separated list of "ip:port" entries into
// grpclb Server messages, in input order. An empty string yields an empty
// list. A malformed IP, a non-numeric port or a port outside 1-65535 is fatal.
func parseBackendAddrs(raw string) []*lbpb.Server {
	if raw == "" {
		return []*lbpb.Server{}
	}
	rawBackendAddrs := strings.Split(raw, ",")
	servers := make([]*lbpb.Server, 0, len(rawBackendAddrs))
	for i, addr := range rawBackendAddrs {
		rawIP, rawPort, err := net.SplitHostPort(addr)
		if err != nil {
			grpclog.Fatalf("Failed to parse --backend_addrs[%d]=%v, error: %v", i, addr, err)
		}
		ip := net.ParseIP(rawIP)
		if ip == nil {
			grpclog.Fatalf("Failed to parse ip: %v", rawIP)
		}
		// net.ParseIP returns IPv4 addresses in 16-byte IPv4-mapped form; send
		// them as 4 bytes so clients that dispatch on the ip_address length
		// do not mistake them for IPv6.
		if v4 := ip.To4(); v4 != nil {
			ip = v4
		}
		numericPort, err := strconv.Atoi(rawPort)
		if err != nil {
			grpclog.Fatalf("Failed to convert port %v to int", rawPort)
		}
		// Guard the int32 conversion below against silently wrapping.
		if numericPort <= 0 || numericPort > 65535 {
			grpclog.Fatalf("Port %d for --backend_addrs[%d] is out of range (1-65535)", numericPort, i)
		}
		grpclog.Infof("Adding backend ip: %v, port: %d", ip.String(), numericPort)
		servers = append(servers, &lbpb.Server{
			IpAddress: ip,
			Port:      int32(numericPort),
		})
	}
	return servers
}