/
stream_layer.go
98 lines (82 loc) · 1.96 KB
/
stream_layer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
package log
import (
"bytes"
"crypto/tls"
"fmt"
"net"
"time"
"github.com/hashicorp/raft"
)
// Compile-time assertion that *StreamLayer implements raft.StreamLayer.
var _ raft.StreamLayer = (*StreamLayer)(nil)
// StreamLayer implements the raft.StreamLayer interface on top of a
// net.Listener, optionally wrapping connections in TLS.
type StreamLayer struct {
	// ln is the listener used by Accept, Close and Addr.
	ln net.Listener
	// serverTLSConfig, when non-nil, is used by Accept to wrap incoming
	// connections with tls.Server.
	serverTLSConfig *tls.Config
	// peerTLSConfig, when non-nil, is used by Dial to wrap outgoing
	// connections with tls.Client.
	peerTLSConfig *tls.Config
}
// NewStreamLayer returns a StreamLayer backed by ln.
//
// serverTLSConfig is stored for use on connections returned by Accept and
// peerTLSConfig for connections returned by Dial. Either may be nil, in which
// case the corresponding connections are not wrapped in TLS.
func NewStreamLayer(
	ln net.Listener,
	serverTLSConfig,
	peerTLSConfig *tls.Config,
) *StreamLayer {
	layer := new(StreamLayer)
	layer.ln = ln
	layer.serverTLSConfig = serverTLSConfig
	layer.peerTLSConfig = peerTLSConfig
	return layer
}
// RaftRPC is the value of the single marker byte that Dial writes at the
// start of every outgoing connection and that Accept requires as the first
// byte of every incoming connection.
const RaftRPC = 1
// Dial makes an outgoing TCP connection to another server in the Raft
// cluster.
//
// addr is the address of the remote server and timeout bounds how long the
// TCP dial may take. After connecting, Dial writes the single RaftRPC marker
// byte in plaintext so the receiving side can identify the connection as a
// Raft RPC. If peerTLSConfig is set, the connection is then wrapped in a TLS
// client.
//
// On error the returned connection is nil and any connection that was opened
// has already been closed.
func (s *StreamLayer) Dial(
	addr raft.ServerAddress,
	timeout time.Duration,
) (net.Conn, error) {
	dialer := &net.Dialer{Timeout: timeout}
	conn, err := dialer.Dial("tcp", string(addr))
	if err != nil {
		return nil, err
	}

	// Identify to the mux that this is a Raft RPC.
	if _, err := conn.Write([]byte{byte(RaftRPC)}); err != nil {
		// The caller never receives conn on this path, so it must be closed
		// here or the socket leaks. The Close error is dropped because the
		// write error is the one worth reporting.
		_ = conn.Close()
		return nil, err
	}

	if s.peerTLSConfig != nil {
		conn = tls.Client(conn, s.peerTLSConfig)
	}
	return conn, nil
}
// Accept is the mirror of Dial. It accepts an incoming connection from the
// listener, reads the first byte and verifies that it equals RaftRPC, and, if
// serverTLSConfig is set, wraps the connection in a TLS server.
//
// On error the returned connection is nil and the accepted connection, if
// any, has already been closed. A first byte other than RaftRPC yields the
// error "not a raft rpc".
//
// NOTE(review): the marker read has no deadline, so Accept blocks until the
// peer sends a byte or closes the connection.
func (s *StreamLayer) Accept() (net.Conn, error) {
	conn, err := s.ln.Accept()
	if err != nil {
		return nil, err
	}

	b := make([]byte, 1)
	if _, err := conn.Read(b); err != nil {
		// The caller never receives conn on this path, so close it here to
		// avoid leaking the socket. The read error is the one reported.
		_ = conn.Close()
		return nil, err
	}
	if !bytes.Equal([]byte{byte(RaftRPC)}, b) {
		// Not a Raft connection: release it rather than leaving it open.
		_ = conn.Close()
		return nil, fmt.Errorf("not a raft rpc")
	}

	if s.serverTLSConfig != nil {
		return tls.Server(conn, s.serverTLSConfig), nil
	}
	return conn, nil
}
// Close closes the underlying listener and returns the error, if any, that
// the listener reports. It implements the Close method of the
// raft.StreamLayer interface.
func (s *StreamLayer) Close() error {
	err := s.ln.Close()
	return err
}
// Addr reports the network address of the underlying listener.
func (s *StreamLayer) Addr() net.Addr {
	listener := s.ln
	return listener.Addr()
}