-
Notifications
You must be signed in to change notification settings - Fork 4.2k
/
streamlayer.go
382 lines (318 loc) · 9.46 KB
/
streamlayer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
package raft
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
fmt "fmt"
"io"
"math/big"
mathrand "math/rand"
"net"
"net/url"
"sync"
"time"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/raft"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/vault/cluster"
)
// TLSKey is a single TLS keypair in the Keyring: a certificate plus the
// private key parameters it was issued for.
type TLSKey struct {
	// ID is a unique identifier for this Key. Keys produced by
	// GenerateTLSKey use the certificate's CommonName ("raft-<uuid>") here.
	ID string `json:"id"`

	// KeyType defines the algorithm used to generate the private keys
	// (setTLSKeyring only accepts certutil.PrivateKeyTypeP521 in KeyParams).
	KeyType string `json:"key_type"`

	// AppliedIndex is the earliest known raft index that safely contains this
	// key.
	AppliedIndex uint64 `json:"applied_index"`

	// CertBytes is the marshaled (DER) certificate.
	CertBytes []byte `json:"cluster_cert"`

	// KeyParams is the marshaled private key.
	KeyParams *certutil.ClusterKeyParams `json:"cluster_key_params"`

	// CreatedTime is the time this key was generated. This value is useful in
	// determining when the next rotation should be.
	CreatedTime time.Time `json:"created_time"`

	// parsedCert is CertBytes parsed by setTLSKeyring; not serialized.
	parsedCert *x509.Certificate
	// parsedKey is the P-521 private key rebuilt from KeyParams by
	// setTLSKeyring; not serialized.
	parsedKey *ecdsa.PrivateKey
}
// TLSKeyring is the set of keys that raft uses for network communication.
// Only the active key (ActiveKeyID) is used to dial, but every key in Keys
// is considered when accepting connections.
type TLSKeyring struct {
	// Keys is the set of available key pairs.
	Keys []*TLSKey `json:"keys"`

	// AppliedIndex is the earliest known raft index that safely contains the
	// latest key in the keyring.
	AppliedIndex uint64 `json:"applied_index"`

	// Term is an incrementing identifier value used to quickly determine if two
	// states of the keyring are different (setTLSKeyring skips work when the
	// Term is unchanged).
	Term uint64 `json:"term"`

	// ActiveKeyID is the key ID to track the active key in the keyring. Only
	// the active key is used for dialing.
	ActiveKeyID string `json:"active_key_id"`
}
// GetActive returns the key whose ID equals ActiveKeyID.
//
// It returns nil when the keyring itself is nil, when no key is marked
// active, or when no key in Keys carries the active ID. Nil entries in Keys
// are skipped rather than dereferenced, so callers may use this on a
// not-yet-initialized keyring without panicking.
func (k *TLSKeyring) GetActive() *TLSKey {
	if k == nil || k.ActiveKeyID == "" {
		return nil
	}
	for _, key := range k.Keys {
		if key != nil && key.ID == k.ActiveKeyID {
			return key
		}
	}
	return nil
}
// GenerateTLSKey creates a new self-signed P-521 ECDSA key pair and
// certificate for raft cluster TLS.
//
// The private key is generated from reader. The certificate is signed using
// crypto/rand. The certificate's CommonName, its only DNS SAN, and the
// returned key's ID are all "raft-<uuid>". The certificate is marked as a CA,
// allows both server and client auth, and is valid from 30 seconds before
// creation until roughly 30 years after.
//
// A single timestamp is taken so that NotBefore, NotAfter and CreatedTime
// are all derived from the same instant.
func GenerateTLSKey(reader io.Reader) (*TLSKey, error) {
	key, err := ecdsa.GenerateKey(elliptic.P521(), reader)
	if err != nil {
		return nil, err
	}

	host, err := uuid.GenerateUUID()
	if err != nil {
		return nil, err
	}
	host = fmt.Sprintf("raft-%s", host)

	now := time.Now()
	template := &x509.Certificate{
		Subject: pkix.Name{
			CommonName: host,
		},
		DNSNames: []string{host},
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageServerAuth,
			x509.ExtKeyUsageClientAuth,
		},
		KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign,
		// NOTE(review): math/rand is not a CSPRNG. Each certificate is
		// self-issued under a unique "raft-<uuid>" subject, so serial
		// uniqueness per issuer does not depend on it. Still, consider
		// crypto/rand if this code is ever reused for shared issuers.
		SerialNumber: big.NewInt(mathrand.Int63()),
		// Backdated slightly, presumably to tolerate clock skew between nodes.
		NotBefore: now.Add(-30 * time.Second),
		// 30 years ought to be enough for anybody
		NotAfter:              now.Add(262980 * time.Hour),
		BasicConstraintsValid: true,
		IsCA:                  true,
	}

	certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key)
	if err != nil {
		return nil, errwrap.Wrapf("unable to generate local cluster certificate: {{err}}", err)
	}

	return &TLSKey{
		ID:        host,
		KeyType:   certutil.PrivateKeyTypeP521,
		CertBytes: certBytes,
		KeyParams: &certutil.ClusterKeyParams{
			Type: certutil.PrivateKeyTypeP521,
			X:    key.PublicKey.X,
			Y:    key.PublicKey.Y,
			D:    key.D,
		},
		CreatedTime: now,
	}, nil
}
var (
// Make sure raftLayer satisfies the raft.StreamLayer interface
_ raft.StreamLayer = (*raftLayer)(nil)
// Make sure raftLayer satisfies the cluster.Handler and cluster.Client
// interfaces
_ cluster.Handler = (*raftLayer)(nil)
_ cluster.Client = (*raftLayer)(nil)
)
// raftLayer implements the raft.StreamLayer interface,
// so that we can use a single RPC layer for Raft and Vault. It also
// satisfies cluster.Handler and cluster.Client, so the cluster listener can
// hand it incoming connections and use it to dial.
type raftLayer struct {
	// addr is the listener address returned by Addr.
	addr net.Addr

	// connCh is used to accept connections: Handoff sends on it and Accept
	// receives from it. NewRaftLayer creates it unbuffered.
	connCh chan net.Conn

	// closed tracks if we are closed; read and written under closeLock.
	closed bool
	// closeCh is closed exactly once by Close to unblock Accept and any
	// pending Handoff goroutines.
	closeCh chan struct{}
	// closeLock guards closed so that closeCh is closed only once.
	closeLock sync.Mutex

	logger log.Logger

	// NOTE(review): dialerFunc is neither assigned nor read in the code shown
	// here; Dial obtains its dialer from clusterListener instead. Confirm
	// whether this field is still needed.
	dialerFunc func(string, time.Duration) (net.Conn, error)

	// keyring holds the TLS keys used to dial (active key only) and to
	// accept connections (all keys). It is replaced via setTLSKeyring.
	// NOTE(review): it is read by the TLS lookup methods without a lock;
	// confirm updates cannot race with in-flight handshakes.
	keyring *TLSKeyring
	// clusterListener provides the listen address and the dialer used by Dial.
	clusterListener cluster.ClusterHook
}
// NewRaftLayer builds a raftLayer bound to the cluster listener's address and
// installs raftTLSKeyring as its TLS keyring.
//
// It fails when the listener reports no address, when that address is an
// unspecified IP (0.0.0.0 or ::), or when the keyring does not validate.
func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterListener cluster.ClusterHook) (*raftLayer, error) {
	listenAddr := clusterListener.Addr()
	if listenAddr == nil {
		return nil, errors.New("no raft addr found")
	}

	// The advertised address must be a concrete host, not a wildcard bind.
	// Wrapping it in a URL lets Hostname() strip the port and IPv6 brackets.
	hostOnly := (&url.URL{Host: listenAddr.String()}).Hostname()
	if parsed := net.ParseIP(hostOnly); parsed != nil && parsed.IsUnspecified() {
		return nil, fmt.Errorf("cannot use unspecified IP with raft storage: %s", listenAddr.String())
	}

	rl := &raftLayer{
		addr:            listenAddr,
		connCh:          make(chan net.Conn),
		closeCh:         make(chan struct{}),
		logger:          logger,
		clusterListener: clusterListener,
	}

	err := rl.setTLSKeyring(raftTLSKeyring)
	if err != nil {
		return nil, err
	}
	return rl, nil
}
// setTLSKeyring validates keyring, fills in each key's parsed certificate and
// P-521 private key, and installs it as the layer's keyring.
//
// If the installed keyring already has the same Term, the call is a no-op.
// It returns an error, and does not replace the installed keyring, in these
// cases:
//   - keyring is nil;
//   - any key is nil, has missing or non-P521 key params, or has missing or
//     unparsable certificate bytes;
//   - the keyring has no active key.
func (l *raftLayer) setTLSKeyring(keyring *TLSKeyring) error {
	// Reject nil up front so callers get an error instead of a nil-pointer
	// panic on keyring.Term below.
	if keyring == nil {
		return errors.New("no raft TLS keyring provided")
	}

	// Fast path a noop update: Term changes whenever the keyring state does.
	if l.keyring != nil && l.keyring.Term == keyring.Term {
		return nil
	}

	for _, key := range keyring.Keys {
		switch {
		case key == nil:
			return errors.New("nil key found in raft TLS keyring")
		case key.KeyParams == nil:
			return errors.New("no raft cluster key params found")
		case key.KeyParams.X == nil, key.KeyParams.Y == nil, key.KeyParams.D == nil:
			return errors.New("failed to parse raft cluster key")
		case key.KeyParams.Type != certutil.PrivateKeyTypeP521:
			return errors.New("failed to find valid raft cluster key type")
		case len(key.CertBytes) == 0:
			return errors.New("no cluster cert found")
		}

		parsedCert, err := x509.ParseCertificate(key.CertBytes)
		if err != nil {
			return errwrap.Wrapf("error parsing raft cluster certificate: {{err}}", err)
		}

		key.parsedCert = parsedCert
		// Rebuild the private key from its stored components. The curve is
		// fixed to P-521 because only that key type is accepted above.
		key.parsedKey = &ecdsa.PrivateKey{
			PublicKey: ecdsa.PublicKey{
				Curve: elliptic.P521(),
				X:     key.KeyParams.X,
				Y:     key.KeyParams.Y,
			},
			D: key.KeyParams.D,
		}
	}

	// Dialing always uses the active key, so a keyring without one is unusable.
	if keyring.GetActive() == nil {
		return errors.New("expected one active key to be present in the keyring")
	}

	l.keyring = keyring
	return nil
}
// ServerName reports the CommonName of the active key's certificate, or the
// empty string when the keyring has no active key.
func (l *raftLayer) ServerName() string {
	if active := l.keyring.GetActive(); active != nil {
		return active.parsedCert.Subject.CommonName
	}
	return ""
}
// CACert returns the parsed certificate of the active key, or nil when the
// keyring has no active key. The context is not consulted.
func (l *raftLayer) CACert(ctx context.Context) *x509.Certificate {
	var caCert *x509.Certificate
	if active := l.keyring.GetActive(); active != nil {
		caCert = active.parsedCert
	}
	return caCert
}
// ClientLookup selects a client certificate for an outgoing handshake.
//
// For each distinguished name the server lists as acceptable (in order), it
// looks for a keyring key whose certificate was issued by that name. The
// first match is returned with a private copy of its DER bytes. When nothing
// matches it returns (nil, nil).
//
// NOTE(review): crypto/tls documents that GetClientCertificate must not
// return a nil *Certificate. Confirm the caller handles the nil case before
// passing it to crypto/tls.
func (l *raftLayer) ClientLookup(ctx context.Context, requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
	for _, acceptableDN := range requestInfo.AcceptableCAs {
		for _, candidate := range l.keyring.Keys {
			if !bytes.Equal(acceptableDN, candidate.parsedCert.RawIssuer) {
				continue
			}

			// Copy so the caller can never mutate the keyring's cert bytes.
			der := make([]byte, len(candidate.CertBytes))
			copy(der, candidate.CertBytes)

			return &tls.Certificate{
				Certificate: [][]byte{der},
				PrivateKey:  candidate.parsedKey,
				Leaf:        candidate.parsedCert,
			}, nil
		}
	}
	return nil, nil
}
// ServerLookup selects the server certificate for an incoming handshake.
//
// It returns the first keyring key whose ID equals the SNI server name
// requested by the client, with a private copy of its DER bytes. It returns
// (nil, nil) when no key matches, and an error when no keyring is installed.
func (l *raftLayer) ServerLookup(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	if l.keyring == nil {
		return nil, errors.New("got raft connection but no local cert")
	}

	var match *TLSKey
	for _, candidate := range l.keyring.Keys {
		if candidate.ID == clientHello.ServerName {
			match = candidate
			break
		}
	}
	if match == nil {
		return nil, nil
	}

	// Copy so the caller can never mutate the keyring's cert bytes.
	der := make([]byte, len(match.CertBytes))
	copy(der, match.CertBytes)

	return &tls.Certificate{
		Certificate: [][]byte{der},
		PrivateKey:  match.parsedKey,
		Leaf:        match.parsedCert,
	}, nil
}
// CALookup returns the CA certificates to use when validating this
// connection: the parsed certificate of every key in the keyring, in keyring
// order. The error is always nil.
func (l *raftLayer) CALookup(context.Context) ([]*x509.Certificate, error) {
	cas := make([]*x509.Certificate, 0, len(l.keyring.Keys))
	for _, k := range l.keyring.Keys {
		cas = append(cas, k.parsedCert)
	}
	return cas, nil
}
// Stop shuts down the raft layer by delegating to Close. It never fails,
// because Close always returns nil.
func (l *raftLayer) Stop() error {
	return l.Close()
}
// Handoff is used to hand off a connection to the raft layer so that a later
// Accept call can return it.
//
// If the layer is already closed it returns an error immediately and does not
// take ownership of conn. Otherwise it returns nil at once and delivers conn
// asynchronously from a goroutine registered with wg. That goroutine blocks
// until one of the following happens:
//   - Accept receives the connection;
//   - the layer is closed;
//   - ctx is cancelled;
//   - quit is closed.
//
// In the last three cases the connection was never handed to raft, so the
// goroutine closes it rather than leaking the socket.
func (l *raftLayer) Handoff(ctx context.Context, wg *sync.WaitGroup, quit chan struct{}, conn *tls.Conn) error {
	l.closeLock.Lock()
	closed := l.closed
	l.closeLock.Unlock()
	if closed {
		return errors.New("raft is shutdown")
	}

	wg.Add(1)
	go func() {
		defer wg.Done()
		select {
		case l.connCh <- conn:
			// Delivered: the Accept caller now owns the connection.
			return
		case <-l.closeCh:
		case <-ctx.Done():
		case <-quit:
		}
		// Not delivered; nobody else holds this connection, so release it.
		// The Close error is ignored because we are shutting down this path
		// and there is no caller left to report it to.
		_ = conn.Close()
	}()
	return nil
}
// Accept blocks until Handoff delivers a connection for the Raft layer, and
// returns it. Once the layer has been closed it returns an error instead.
func (l *raftLayer) Accept() (net.Conn, error) {
	var incoming net.Conn
	select {
	case incoming = <-l.connCh:
	case <-l.closeCh:
		return nil, errors.New("Raft RPC layer closed")
	}
	return incoming, nil
}
// Close stops listening for Raft connections. The first call marks the layer
// closed and closes closeCh, which unblocks Accept and pending hand-offs.
// Later calls do nothing. It always returns nil.
func (l *raftLayer) Close() error {
	l.closeLock.Lock()
	defer l.closeLock.Unlock()

	if l.closed {
		return nil
	}
	l.closed = true
	close(l.closeCh)
	return nil
}
// Addr reports the cluster listener address captured when the layer was
// constructed.
func (l *raftLayer) Addr() net.Addr {
	listenAddr := l.addr
	return listenAddr
}
// Dial opens a new outgoing raft connection to address. It uses the cluster
// listener's dialer for the raft storage ALPN, bounded by timeout.
func (l *raftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
	dial := l.clusterListener.GetDialerFunc(context.Background(), consts.RaftStorageALPN)
	target := string(address)
	return dial(target, timeout)
}