/
transport.go
131 lines (111 loc) · 3.75 KB
/
transport.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package noise
import (
"context"
"net"
"github.com/libp2p/go-libp2p/core/canonicallog"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/sec"
tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
manet "github.com/multiformats/go-multiaddr/net"
)
const (
	// ID is the protocol ID for noise.
	ID = "/noise"

	// maxProtoNum is the largest number of stream muxer IDs accepted from a
	// remote peer's handshake early data; longer lists are ignored entirely.
	maxProtoNum = 100
)
// Transport is a Noise security transport. It holds the configuration shared
// by every session it secures: the protocol ID it reports, the local libp2p
// identity, and the stream muxer IDs it advertises during the handshake.
// Construct it with New.
type Transport struct {
	protocolID protocol.ID    // reported by ID; the id argument passed to New
	localID    peer.ID        // peer ID derived from privateKey in New
	privateKey crypto.PrivKey // libp2p identity key supplied to New
	muxers     []protocol.ID  // muxer IDs in the order given to New; sent as handshake early data
}
// Compile-time assertion that *Transport implements sec.SecureTransport.
var _ sec.SecureTransport = (*Transport)(nil)
// New creates a new Noise transport using the given private key as its
// libp2p identity key.
//
// id is the protocol ID the transport reports from ID. The IDs of the given
// muxers are kept in order and advertised to the remote peer as handshake
// early data, so a stream muxer can be selected as part of the handshake.
// An error is returned if a peer ID cannot be derived from privkey.
func New(id protocol.ID, privkey crypto.PrivKey, muxers []tptu.StreamMuxer) (*Transport, error) {
	localPeer, err := peer.IDFromPrivateKey(privkey)
	if err != nil {
		return nil, err
	}

	ids := make([]protocol.ID, len(muxers))
	for idx := range muxers {
		ids[idx] = muxers[idx].ID
	}

	tpt := &Transport{
		protocolID: id,
		localID:    localPeer,
		privateKey: privkey,
		muxers:     ids,
	}
	return tpt, nil
}
// SecureInbound runs the Noise handshake as the responder.
// If p is empty, connections from any peer are accepted; a non-empty p is
// passed down as a request to check the remote's identity (the final
// `p != ""` argument to newSecureSession).
//
// On failure the handshake error is logged through canonicallog (only when
// the remote address can be expressed as a multiaddr). A nil SecureConn is
// then returned with the error. Returning the session value here would
// produce a non-nil interface wrapping a nil or unusable *secureSession.
// On success the muxer agreed on via early data is recorded on the session.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
	responderEDH := newTransportEDH(t)
	c, err := newSecureSession(t, ctx, insecure, p, nil, nil, responderEDH, false, p != "")
	if err != nil {
		if addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()); maErr == nil {
			// NOTE(review): 100 is presumably canonicallog's sampling rate — confirm.
			canonicallog.LogPeerStatus(100, p, addr, "handshake_failure", "noise", "err", err.Error())
		}
		return nil, err
	}
	// As responder, the initiator's (received) muxer order takes precedence.
	return SessionWithConnState(c, responderEDH.MatchMuxers(false)), nil
}
// SecureOutbound runs the Noise handshake as the initiator.
// Peer-ID checking (the final argument to newSecureSession) is always
// requested here, presumably because the dialer always knows which peer p
// it dialed.
//
// On failure a nil SecureConn is returned with the error. Returning the
// session value would produce a non-nil interface wrapping a nil or unusable
// *secureSession. On success the muxer agreed on via early data is recorded
// on the session.
func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
	initiatorEDH := newTransportEDH(t)
	c, err := newSecureSession(t, ctx, insecure, p, nil, initiatorEDH, nil, true, true)
	if err != nil {
		return nil, err
	}
	// As initiator, our own muxer order takes precedence.
	return SessionWithConnState(c, initiatorEDH.MatchMuxers(true)), nil
}
// WithSessionOptions returns a SessionTransport wrapping t, starting from
// t's protocol ID, with each option in opts applied in order. If any option
// returns an error, that error is returned and the SessionTransport is
// discarded.
func (t *Transport) WithSessionOptions(opts ...SessionOption) (*SessionTransport, error) {
	sessionTpt := &SessionTransport{t: t, protocolID: t.protocolID}
	for idx := range opts {
		if err := opts[idx](sessionTpt); err != nil {
			return nil, err
		}
	}
	return sessionTpt, nil
}
// ID returns the protocol ID this transport was created with (the id
// argument passed to New).
func (t *Transport) ID() (id protocol.ID) {
	id = t.protocolID
	return id
}
// matchMuxers returns the first entry of initiatorMuxers that also appears in
// responderMuxers, so the initiator's ordering decides ties. It returns ""
// when the lists share no muxer or either list is empty.
func matchMuxers(initiatorMuxers, responderMuxers []protocol.ID) protocol.ID {
	for i := range initiatorMuxers {
		candidate := initiatorMuxers[i]
		for j := range responderMuxers {
			if responderMuxers[j] == candidate {
				return candidate
			}
		}
	}
	return ""
}
// transportEarlyDataHandler is the EarlyDataHandler created for each
// SecureInbound/SecureOutbound call. It sends the transport's muxer list to
// the remote peer and records the muxer list the remote peer sends back.
type transportEarlyDataHandler struct {
	transport      *Transport    // source of the local muxer list
	receivedMuxers []protocol.ID // remote's muxers; left unset if no extension or more than maxProtoNum were received
}
// Compile-time assertion that *transportEarlyDataHandler implements EarlyDataHandler.
var _ EarlyDataHandler = (*transportEarlyDataHandler)(nil)
// newTransportEDH returns a fresh early-data handler bound to t, with no
// remote muxers recorded yet.
func newTransportEDH(t *Transport) *transportEarlyDataHandler {
	edh := new(transportEarlyDataHandler)
	edh.transport = t
	return edh
}
// Send returns the early-data extension advertising this transport's stream
// muxers, in the order they were configured. The context, connection and
// peer ID are not used.
func (i *transportEarlyDataHandler) Send(context.Context, net.Conn, peer.ID) *pb.NoiseExtensions {
	advertised := protocol.ConvertToStrings(i.transport.muxers)
	return &pb.NoiseExtensions{StreamMuxers: advertised}
}
// Received records the stream muxers advertised by the remote peer in its
// early data. It never returns an error. A nil extension, or one listing
// more than maxProtoNum muxers, is ignored, which leaves receivedMuxers
// untouched.
func (i *transportEarlyDataHandler) Received(_ context.Context, _ net.Conn, extension *pb.NoiseExtensions) error {
	if extension == nil {
		return nil
	}
	// Bound the number of muxer IDs accepted from an untrusted peer; an
	// over-long list is dropped as a whole rather than truncated.
	if len(extension.StreamMuxers) > maxProtoNum {
		return nil
	}
	i.receivedMuxers = protocol.ConvertFromStrings(extension.GetStreamMuxers())
	return nil
}
// MatchMuxers picks the stream muxer both sides support, or "" if none.
// The initiator's list order always decides. When isInitiator is true that
// is the local list; otherwise it is the list received from the remote.
func (i *transportEarlyDataHandler) MatchMuxers(isInitiator bool) protocol.ID {
	local, remote := i.transport.muxers, i.receivedMuxers
	if !isInitiator {
		return matchMuxers(remote, local)
	}
	return matchMuxers(local, remote)
}