forked from hyperledger/fabric
-
Notifications
You must be signed in to change notification settings - Fork 0
/
connections.go
134 lines (111 loc) · 3.79 KB
/
connections.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
/*
Copyright IBM Corp. 2017 All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
*/
package cluster
import (
"bytes"
"crypto/x509"
"sync"
"github.com/hyperledger/fabric/common/metrics"
"github.com/pkg/errors"
"google.golang.org/grpc"
)
// RemoteVerifier verifies the connection to the remote host.
// Its signature is identical to crypto/tls Config.VerifyPeerCertificate, where
// rawCerts are the certificates sent by the peer. Presumably SecureDialer
// implementations install it there, so a non-nil error aborts the handshake —
// TODO confirm against the dialer implementation.
type RemoteVerifier func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
//go:generate mockery -dir . -name SecureDialer -case underscore -output ./mocks/
// SecureDialer connects to a remote address.
type SecureDialer interface {
// Dial opens a gRPC client connection to address. verifyFunc is the check
// the remote node must pass; ConnectionStore supplies the predicate built by
// verifyHandshake. NOTE(review): presumably applied during the TLS handshake
// — confirm in implementations.
Dial(address string, verifyFunc RemoteVerifier) (*grpc.ClientConn, error)
}
// ConnectionMapper maps certificates to connections.
// ConnectionStore keys it by the raw certificate bytes it expects the remote
// node to present, which verifyHandshake compares against rawCerts[0].
// NOTE(review): presumably DER-encoded, as crypto/tls delivers rawCerts — confirm.
type ConnectionMapper interface {
// Lookup returns the connection mapped to cert and reports whether one exists.
Lookup(cert []byte) (*grpc.ClientConn, bool)
// Put maps cert to conn.
Put(cert []byte, conn *grpc.ClientConn)
// Remove deletes the mapping for cert.
Remove(cert []byte)
// Size returns the number of mappings held; connMapperReporter publishes
// it as the TLS connection count.
Size() int
}
// ConnectionStore stores connections to remote nodes, keyed by the TLS
// certificate each remote node is expected to present.
type ConnectionStore struct {
// lock guards Connections: Connection holds it for reading during its
// cached lookup, while connect and Disconnect hold it for writing.
lock sync.RWMutex
// Connections holds the certificate-to-connection mapping. It is exported,
// so code that uses it directly bypasses lock.
Connections ConnectionMapper
// dialer opens a new connection when no cached one exists for a certificate.
dialer SecureDialer
}
// NewConnectionStore creates a new ConnectionStore that dials through the
// given SecureDialer. Connections are kept in a fresh ConnByCertMap, wrapped
// by a reporter that publishes the mapping count to tlsConnectionCount after
// every Put and Remove.
func NewConnectionStore(dialer SecureDialer, tlsConnectionCount metrics.Gauge) *ConnectionStore {
	reporter := &connMapperReporter{
		ConnectionMapper:          make(ConnByCertMap),
		tlsConnectionCountMetrics: tlsConnectionCount,
	}
	return &ConnectionStore{
		dialer:      dialer,
		Connections: reporter,
	}
}
// verifyHandshake returns a predicate that verifies that the remote node
// authenticates itself with the given TLS certificate.
//
// The predicate accepts the peer only if the first certificate it presents
// (rawCerts[0], presumably the leaf) is byte-for-byte equal to certificate.
// verifiedChains is not consulted; acceptance depends solely on that exact
// match. endpoint is used only to build the error messages.
//
// An empty rawCerts is rejected with an error. The previous code indexed
// rawCerts[0] unconditionally, which panicked inside the verification
// callback.
func (c *ConnectionStore) verifyHandshake(endpoint string, certificate []byte) RemoteVerifier {
	return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
		if len(rawCerts) == 0 {
			return errors.Errorf("no certificate presented by %s", endpoint)
		}
		if bytes.Equal(certificate, rawCerts[0]) {
			return nil
		}
		return errors.Errorf("certificate presented by %s doesn't match any authorized certificate", endpoint)
	}
}
// Disconnect closes the gRPC connection mapped to expectedServerCert and
// removes it from the store. It does nothing when no such connection exists.
// The whole operation runs under the store's write lock.
func (c *ConnectionStore) Disconnect(expectedServerCert []byte) {
	c.lock.Lock()
	defer c.lock.Unlock()

	if conn, found := c.Connections.Lookup(expectedServerCert); found {
		// The Close error is deliberately dropped: the mapping is removed
		// regardless, so a failed close would leave nothing to retry here.
		_ = conn.Close()
		c.Connections.Remove(expectedServerCert)
	}
}
// Connection obtains a connection to the given endpoint and expects the
// given server certificate to be presented by the remote node.
//
// Fast path: a cached connection for expectedServerCert is looked up under
// the read lock and returned as is. endpoint is not compared in that case,
// so the cached connection is returned even if it was dialed to a different
// address. Otherwise connect dials under the write lock.
func (c *ConnectionStore) Connection(endpoint string, expectedServerCert []byte) (*grpc.ClientConn, error) {
	c.lock.RLock()
	existing, found := c.Connections.Lookup(expectedServerCert)
	c.lock.RUnlock()

	if !found {
		// Slow path; connect re-checks the map after acquiring the write lock.
		return c.connect(endpoint, expectedServerCert)
	}
	return existing, nil
}
// connect connects to the given endpoint and expects the given TLS server
// certificate to be presented at the time of authentication. On success the
// new connection is cached under expectedServerCert.
//
// NOTE: the write lock is held across Dial, so concurrent Connection and
// Disconnect calls block until dialing finishes.
func (c *ConnectionStore) connect(endpoint string, expectedServerCert []byte) (*grpc.ClientConn, error) {
	c.lock.Lock()
	defer c.lock.Unlock()

	// Another goroutine may have connected while we waited for the lock.
	if existing, found := c.Connections.Lookup(expectedServerCert); found {
		return existing, nil
	}

	conn, err := c.dialer.Dial(endpoint, c.verifyHandshake(endpoint, expectedServerCert))
	if err != nil {
		return nil, err
	}

	c.Connections.Put(expectedServerCert, conn)
	return conn, nil
}
// connMapperReporter decorates a ConnectionMapper so that every Put and
// Remove publishes the resulting Size to tlsConnectionCountMetrics.
type connMapperReporter struct {
// tlsConnectionCountMetrics receives the mapping count after each Put/Remove.
tlsConnectionCountMetrics metrics.Gauge
// ConnectionMapper is the wrapped mapper; Lookup and Size are not overridden
// in this file and are promoted unchanged.
ConnectionMapper
}
// Put maps cert to conn in the wrapped mapper, then publishes the updated
// mapping count.
func (cmg *connMapperReporter) Put(cert []byte, conn *grpc.ClientConn) {
	inner := cmg.ConnectionMapper
	inner.Put(cert, conn)
	cmg.reportSize()
}
// Remove deletes the mapping for cert from the wrapped mapper, then
// publishes the updated mapping count.
func (cmg *connMapperReporter) Remove(cert []byte) {
	inner := cmg.ConnectionMapper
	inner.Remove(cert)
	cmg.reportSize()
}
// reportSize sets the TLS connection-count gauge to the wrapped mapper's
// current Size.
func (cmg *connMapperReporter) reportSize() {
	current := cmg.ConnectionMapper.Size()
	cmg.tlsConnectionCountMetrics.Set(float64(current))
}