forked from brocaar/chirpstack-application-server
/
pool.go
128 lines (108 loc) · 3.32 KB
/
pool.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package nsclient
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"sync"
"time"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"github.com/brocaar/loraserver/api/ns"
)
// Pool defines the network-server client pool.
//
// A Pool hands out network-server API clients keyed by server address.
type Pool interface {
// Get returns a network-server API client for the server at hostname.
// caCert, tlsCert and tlsKey are the PEM encoded CA certificate, client
// certificate and client key for the connection. The implementation in
// this file creates an insecure (non-TLS) connection when all three are
// empty.
Get(hostname string, caCert, tlsCert, tlsKey []byte) (ns.NetworkServerServiceClient, error)
}
// client is a single pool entry: an API client, the connection backing it
// and the credentials that connection was created with.
type client struct {
// client is the API client handed out by Get.
client ns.NetworkServerServiceClient
// clientConn is the gRPC connection behind client. Get closes it when the
// entry is replaced because the credentials changed.
clientConn *grpc.ClientConn
// caCert, tlsCert and tlsKey are the credentials passed to Get when the
// entry was created. Get compares them against the credentials of later
// calls to decide whether the cached connection can be reused.
caCert []byte
tlsCert []byte
tlsKey []byte
}
// pool is the Pool implementation returned by NewPool.
type pool struct {
// The embedded mutex guards clients. Get takes the write lock for the
// whole call.
sync.RWMutex
// clients holds the cached entries, keyed by the hostname passed to Get.
clients map[string]client
}
// NewPool creates a Pool with an empty client cache.
func NewPool() Pool {
	p := new(pool)
	// The map must be allocated up front: Get writes to it.
	p.clients = map[string]client{}
	return p
}
// Get returns a NetworkServerServiceClient for the given server. hostname is
// used both as the cache key and as the dial target (e.g. "host:port").
//
// Clients are cached per hostname. A cached client is returned as long as
// caCert, tlsCert and tlsKey are byte-for-byte equal to the values it was
// created with. When any of them changed, the old connection is closed and a
// new client is created with the new credentials.
//
// An error is returned when the new client can not be created; in that case
// nothing is cached for hostname.
func (p *pool) Get(hostname string, caCert, tlsCert, tlsKey []byte) (ns.NetworkServerServiceClient, error) {
	p.Lock()
	defer p.Unlock()

	c, ok := p.clients[hostname]
	if ok && bytes.Equal(c.caCert, caCert) && bytes.Equal(c.tlsCert, tlsCert) && bytes.Equal(c.tlsKey, tlsKey) {
		return c.client, nil
	}

	if ok {
		// The connection exists in the map, but the certificates changed:
		// close the connection and re-connect. A failure to close must not
		// prevent the re-connect, so it is only logged.
		if err := c.clientConn.Close(); err != nil {
			log.WithError(err).WithField("server", hostname).Warning("close network-server client connection error")
		}
		delete(p.clients, hostname)
	}

	clientConn, nsClient, err := p.createClient(hostname, caCert, tlsCert, tlsKey)
	if err != nil {
		return nil, errors.Wrap(err, "create network-server api client error")
	}

	p.clients[hostname] = client{
		client:     nsClient,
		clientConn: clientConn,
		caCert:     caCert,
		tlsCert:    tlsCert,
		tlsKey:     tlsKey,
	}

	return nsClient, nil
}
// createClient dials the network-server at hostname and returns the gRPC
// connection together with an API client using that connection.
//
// When caCert, tlsCert and tlsKey are all empty, an insecure connection is
// created. Otherwise tlsCert / tlsKey are loaded as the client key pair and
// caCert is used as the root CA pool, all PEM encoded.
//
// The dial blocks until the connection is up or a 500ms timeout expires.
// An error is returned when the key pair can not be loaded, when caCert
// contains no usable certificate, or when the dial fails.
func (p *pool) createClient(hostname string, caCert, tlsCert, tlsKey []byte) (*grpc.ClientConn, ns.NetworkServerServiceClient, error) {
	logrusEntry := log.NewEntry(log.StandardLogger())
	logrusOpts := []grpc_logrus.Option{
		grpc_logrus.WithLevels(grpc_logrus.DefaultCodeToLevel),
	}

	nsOpts := []grpc.DialOption{
		grpc.WithBlock(),
		grpc.WithUnaryInterceptor(
			grpc_logrus.UnaryClientInterceptor(logrusEntry, logrusOpts...),
		),
		grpc.WithStreamInterceptor(
			grpc_logrus.StreamClientInterceptor(logrusEntry, logrusOpts...),
		),
	}

	if len(caCert) == 0 && len(tlsCert) == 0 && len(tlsKey) == 0 {
		nsOpts = append(nsOpts, grpc.WithInsecure())
		log.WithField("server", hostname).Warning("creating insecure network-server client")
	} else {
		log.WithField("server", hostname).Info("creating network-server client")

		cert, err := tls.X509KeyPair(tlsCert, tlsKey)
		if err != nil {
			return nil, nil, errors.Wrap(err, "load x509 keypair error")
		}

		caCertPool := x509.NewCertPool()
		if !caCertPool.AppendCertsFromPEM(caCert) {
			// AppendCertsFromPEM only reports success as a bool, so there is
			// no error to wrap here. Wrapping the (nil) key pair error would
			// yield a nil error and make this failure look like a success.
			return nil, nil, errors.New("append ca cert to pool error")
		}

		nsOpts = append(nsOpts, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
			Certificates: []tls.Certificate{cert},
			RootCAs:      caCertPool,
		})))
	}

	// grpc.WithBlock makes DialContext wait for the connection, so the
	// context timeout bounds how long the dial may take.
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()

	clientConn, err := grpc.DialContext(ctx, hostname, nsOpts...)
	if err != nil {
		return nil, nil, errors.Wrap(err, "dial network-server api error")
	}

	return clientConn, ns.NewNetworkServerServiceClient(clientConn), nil
}