From 3a54181e6d2647f4653ca022029f6737b09fe81d Mon Sep 17 00:00:00 2001 From: sukun Date: Wed, 20 Sep 2023 21:02:47 +0530 Subject: [PATCH 01/22] webrtcprivate: add transport --- p2p/transport/webrtc/connection.go | 7 +- p2p/transport/webrtc/listener.go | 2 +- p2p/transport/webrtc/transport.go | 2 +- p2p/transport/webrtcprivate/listener.go | 276 +++++++++++++ p2p/transport/webrtcprivate/pb/generate.go | 3 + p2p/transport/webrtcprivate/pb/msg.pb.go | 220 ++++++++++ p2p/transport/webrtcprivate/pb/msg.proto | 20 + p2p/transport/webrtcprivate/transport.go | 390 ++++++++++++++++++ p2p/transport/webrtcprivate/transport_test.go | 180 ++++++++ 9 files changed, 1095 insertions(+), 5 deletions(-) create mode 100644 p2p/transport/webrtcprivate/listener.go create mode 100644 p2p/transport/webrtcprivate/pb/generate.go create mode 100644 p2p/transport/webrtcprivate/pb/msg.pb.go create mode 100644 p2p/transport/webrtcprivate/pb/msg.proto create mode 100644 p2p/transport/webrtcprivate/transport.go create mode 100644 p2p/transport/webrtcprivate/transport_test.go diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index fd31f8351a..9ef2f03124 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -44,7 +44,7 @@ type dataChannel struct { type connection struct { pc *webrtc.PeerConnection - transport *WebRTCTransport + transport tpt.Transport scope network.ConnManagementScope closeErr error @@ -66,10 +66,11 @@ type connection struct { cancel context.CancelFunc } -func newConnection( +// NewWebRTCConnection creates a transport.CapableConn from a webrtc.PeerConnection +func NewWebRTCConnection( direction network.Direction, pc *webrtc.PeerConnection, - transport *WebRTCTransport, + transport tpt.Transport, scope network.ConnManagementScope, localPeer peer.ID, diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go index 0b29bf655d..4932f44a83 100644 --- a/p2p/transport/webrtc/listener.go +++ b/p2p/transport/webrtc/listener.go @@ -275,7 +275,7 @@ func (l *listener) setupConnection( // The connection is instantiated before performing the Noise handshake. This is // to handle the case where the remote is faster and attempts to initiate a stream // before the ondatachannel callback can be set. - conn, err := newConnection( + conn, err := NewWebRTCConnection( network.DirInbound, pc, l.transport, diff --git a/p2p/transport/webrtc/transport.go b/p2p/transport/webrtc/transport.go index dd4028d1f2..9174cde76d 100644 --- a/p2p/transport/webrtc/transport.go +++ b/p2p/transport/webrtc/transport.go @@ -392,7 +392,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement // we can only know the remote public key after the noise handshake, // but need to set up the callbacks on the peerconnection - conn, err := newConnection( + conn, err := NewWebRTCConnection( network.DirOutbound, pc, t, diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go new file mode 100644 index 0000000000..77f47eec61 --- /dev/null +++ b/p2p/transport/webrtcprivate/listener.go @@ -0,0 +1,276 @@ +package libp2pwebrtcprivate + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "time" + + "github.com/libp2p/go-libp2p/core/network" + tpt "github.com/libp2p/go-libp2p/core/transport" + libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" + "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate/pb" + "github.com/libp2p/go-msgio/pbio" + ma "github.com/multiformats/go-multiaddr" + "github.com/pion/webrtc/v3" +) + +type listener struct { + t *transport + webrtcConfig webrtc.Configuration + conns chan tpt.CapableConn + closeC chan struct{} +} + +var _ tpt.Listener = &listener{} + +type NetAddr struct{} + +var _ net.Addr = NetAddr{} + +func (n NetAddr) Network() string { + return "libp2p-webrtc" +} + +func (n NetAddr) String() string { + return "/webrtc" +} + +// Accept implements transport.Listener. +func (l *listener) Accept() (tpt.CapableConn, error) { + select { + case c := <-l.conns: + return c, nil + case <-l.closeC: + return nil, tpt.ErrListenerClosed + } +} + +// Addr implements transport.Listener. +func (l *listener) Addr() net.Addr { + return NetAddr{} +} + +// Close implements transport.Listener. +func (l *listener) Close() error { + l.t.RemoveListener(l) + close(l.closeC) + return nil +} + +// Multiaddr implements transport.Listener. +func (*listener) Multiaddr() ma.Multiaddr { + return ma.StringCast("/webrtc") +} + +func (l *listener) handleIncoming(s network.Stream) { + ctx, cancel := context.WithTimeout(context.Background(), streamTimeout) + defer cancel() + defer s.Close() + s.SetDeadline(time.Now().Add(streamTimeout)) + + scope, err := l.t.rcmgr.OpenConnection(network.DirInbound, false, ma.StringCast("/webrtc")) + if err != nil { + s.Reset() + log.Debug("failed to create connection scope:", err) + return + } + + settings := webrtc.SettingEngine{} + settings.DetachDataChannels() + api := webrtc.NewAPI(webrtc.WithSettingEngine(settings)) + pc, err := api.NewPeerConnection(l.webrtcConfig) + if err != nil { + s.Reset() + log.Debug("error creating a webrtc.PeerConnection:", err) + return + } + r := pbio.NewDelimitedReader(s, maxMsgSize) + w := pbio.NewDelimitedWriter(s) + + // register peerconnection state update callback + connectionState := make(chan webrtc.PeerConnectionState, 1) + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + switch state { + case webrtc.PeerConnectionStateConnected, webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed: + // We only use the first state written to connectionState. + select { + case connectionState <- state: + default: + } + } + }) + + // register local ICE Candidate found callback + writeErr := make(chan error, 1) + pc.OnICECandidate(func(candiate *webrtc.ICECandidate) { + if candiate == nil { + return + } + b, err := json.Marshal(candiate.ToJSON()) + if err != nil { + // We only want to write a single error on this channel + select { + case writeErr <- fmt.Errorf("failed to marshal candidate to JSON: %w", err): + default: + } + return + } + data := string(b) + + msg := &pb.Message{ + Type: pb.Message_ICE_CANDIDATE.Enum(), + Data: &data, + } + if err := w.WriteMsg(msg); err != nil { + // We only want to write a single error on this channel + select { + case writeErr <- fmt.Errorf("write candidate failed: %w", err): + default: + } + } + }) + + // de-register candidate callback + defer pc.OnICECandidate(func(_ *webrtc.ICECandidate) {}) + + // read an incoming offer + var msg pb.Message + if err := r.ReadMsg(&msg); err != nil { + s.Reset() + log.Debug("failed to read offer", err) + return + } + if msg.Type == nil || *msg.Type != pb.Message_SDP_OFFER { + s.Reset() + log.Debugf("invalid message: msg.Type expected %s got %s", pb.Message_SDP_OFFER, msg.Type) + return + } + if msg.Data == nil || *msg.Data == "" { + s.Reset() + log.Debugf("invalid message: empty data") + return + } + offer := webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: *msg.Data, + } + if err := pc.SetRemoteDescription(offer); err != nil { + s.Reset() + log.Debug("failed to set remote description: %v", err) + return + } + + // send an answer + answer, err := pc.CreateAnswer(nil) + if err != nil { + s.Reset() + log.Debug("failed to create answer: %v", err) + return + } + + answerMessage := &pb.Message{ + Type: pb.Message_SDP_ANSWER.Enum(), + Data: &answer.SDP, + } + if err := w.WriteMsg(answerMessage); err != nil { + s.Reset() + log.Debug("failed to write answer:", err) + return + } + + if err := pc.SetLocalDescription(answer); err != nil { + s.Reset() + log.Debug("failed to set local description:", err) + return + } + + readErr := make(chan error, 1) + // start a goroutine to read candidates + go func() { + for { + if ctx.Err() != nil { + return + } + + var msg pb.Message + err := r.ReadMsg(&msg) + if err == io.EOF { + return + } + if err != nil { + readErr <- fmt.Errorf("failed to read candidate: %w", err) + return + } + + if msg.Type == nil || *msg.Type != pb.Message_ICE_CANDIDATE { + readErr <- fmt.Errorf("invalid message: msg.Type expected %s got %s", pb.Message_ICE_CANDIDATE, msg.Type) + return + } + // Ignore without erroring on empty message. + // Pion has a case where OnCandidate callback may be called with a nil + // candidate + if msg.Data == nil || *msg.Data == "" { + log.Debugf("received empty candidate from %s", s.Conn().RemotePeer()) + continue + } + + var init webrtc.ICECandidateInit + if err := json.Unmarshal([]byte(*msg.Data), &init); err != nil { + readErr <- fmt.Errorf("failed to unmarshal ice candidate %w", err) + return + } + if err := pc.AddICECandidate(init); err != nil { + readErr <- fmt.Errorf("failed to add ice candidate: %w", err) + return + } + } + }() + + select { + case <-ctx.Done(): + pc.Close() + s.Reset() + log.Error(ctx.Err()) + return + case err := <-writeErr: + pc.Close() + s.Reset() + log.Error(err) + return + case err := <-readErr: + pc.Close() + s.Reset() + log.Error(err) + return + case state := <-connectionState: + switch state { + default: + pc.Close() + s.Reset() + return + case webrtc.PeerConnectionStateConnected: + conn, _ := libp2pwebrtc.NewWebRTCConnection( + network.DirInbound, + pc, + l.t, + scope, + l.t.host.ID(), + ma.StringCast("/webrtc"), + s.Conn().RemotePeer(), + l.t.host.Peerstore().PubKey(s.Conn().RemotePeer()), + ma.StringCast("/webrtc"), + ) + select { + case l.conns <- conn: + default: + s.Reset() + log.Debug("incoming conn queue full: dropping conn from %s", s.Conn().RemotePeer()) + conn.Close() + } + return + } + } +} diff --git a/p2p/transport/webrtcprivate/pb/generate.go b/p2p/transport/webrtcprivate/pb/generate.go new file mode 100644 index 0000000000..657f02bd6a --- /dev/null +++ b/p2p/transport/webrtcprivate/pb/generate.go @@ -0,0 +1,3 @@ +package pb + +//go:generate protoc --go_out=. --go_opt=paths=source_relative -I . msg.proto diff --git a/p2p/transport/webrtcprivate/pb/msg.pb.go b/p2p/transport/webrtcprivate/pb/msg.pb.go new file mode 100644 index 0000000000..337b4d7a19 --- /dev/null +++ b/p2p/transport/webrtcprivate/pb/msg.pb.go @@ -0,0 +1,220 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.30.0 +// protoc v3.21.12 +// source: msg.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// Specifies type in `data` field. +type Message_Type int32 + +const ( + // String of `RTCSessionDescription.sdp` + Message_SDP_OFFER Message_Type = 0 + // String of `RTCSessionDescription.sdp` + Message_SDP_ANSWER Message_Type = 1 + // String of `RTCIceCandidate.toJSON()` + Message_ICE_CANDIDATE Message_Type = 2 +) + +// Enum value maps for Message_Type. +var ( + Message_Type_name = map[int32]string{ + 0: "SDP_OFFER", + 1: "SDP_ANSWER", + 2: "ICE_CANDIDATE", + } + Message_Type_value = map[string]int32{ + "SDP_OFFER": 0, + "SDP_ANSWER": 1, + "ICE_CANDIDATE": 2, + } +) + +func (x Message_Type) Enum() *Message_Type { + p := new(Message_Type) + *p = x + return p +} + +func (x Message_Type) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (Message_Type) Descriptor() protoreflect.EnumDescriptor { + return file_msg_proto_enumTypes[0].Descriptor() +} + +func (Message_Type) Type() protoreflect.EnumType { + return &file_msg_proto_enumTypes[0] +} + +func (x Message_Type) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use Message_Type.Descriptor instead. +func (Message_Type) EnumDescriptor() ([]byte, []int) { + return file_msg_proto_rawDescGZIP(), []int{0, 0} +} + +type Message struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Type *Message_Type `protobuf:"varint,1,opt,name=type,proto3,enum=libp2pwebrtcprivate.pb.Message_Type,oneof" json:"type,omitempty"` + Data *string `protobuf:"bytes,2,opt,name=data,proto3,oneof" json:"data,omitempty"` +} + +func (x *Message) Reset() { + *x = Message{} + if protoimpl.UnsafeEnabled { + mi := &file_msg_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Message) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Message) ProtoMessage() {} + +func (x *Message) ProtoReflect() protoreflect.Message { + mi := &file_msg_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Message.ProtoReflect.Descriptor instead. +func (*Message) Descriptor() ([]byte, []int) { + return file_msg_proto_rawDescGZIP(), []int{0} +} + +func (x *Message) GetType() Message_Type { + if x != nil && x.Type != nil { + return *x.Type + } + return Message_SDP_OFFER +} + +func (x *Message) GetData() string { + if x != nil && x.Data != nil { + return *x.Data + } + return "" +} + +var File_msg_proto protoreflect.FileDescriptor + +var file_msg_proto_rawDesc = []byte{ + 0x0a, 0x09, 0x6d, 0x73, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x16, 0x6c, 0x69, 0x62, + 0x70, 0x32, 0x70, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, + 0x2e, 0x70, 0x62, 0x22, 0xad, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, + 0x3d, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x24, 0x2e, + 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x70, 0x72, 0x69, 0x76, + 0x61, 0x74, 0x65, 0x2e, 0x70, 0x62, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x54, + 0x79, 0x70, 0x65, 0x48, 0x00, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x88, 0x01, 0x01, 0x12, 0x17, + 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x01, 0x52, 0x04, + 0x64, 0x61, 0x74, 0x61, 0x88, 0x01, 0x01, 0x22, 0x38, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, + 0x0d, 0x0a, 0x09, 0x53, 0x44, 0x50, 0x5f, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0e, + 0x0a, 0x0a, 0x53, 0x44, 0x50, 0x5f, 0x41, 0x4e, 0x53, 0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x11, + 0x0a, 0x0d, 0x49, 0x43, 0x45, 0x5f, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41, 0x54, 0x45, 0x10, + 0x02, 0x42, 0x07, 0x0a, 0x05, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x42, 0x07, 0x0a, 0x05, 0x5f, 0x64, + 0x61, 0x74, 0x61, 0x42, 0x3c, 0x5a, 0x3a, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, + 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x67, 0x6f, 0x2d, 0x6c, 0x69, 0x62, 0x70, + 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, + 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x2f, 0x70, + 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_msg_proto_rawDescOnce sync.Once + file_msg_proto_rawDescData = file_msg_proto_rawDesc +) + +func file_msg_proto_rawDescGZIP() []byte { + file_msg_proto_rawDescOnce.Do(func() { + file_msg_proto_rawDescData = protoimpl.X.CompressGZIP(file_msg_proto_rawDescData) + }) + return file_msg_proto_rawDescData +} + +var file_msg_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_msg_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_msg_proto_goTypes = []interface{}{ + (Message_Type)(0), // 0: libp2pwebrtcprivate.pb.Message.Type + (*Message)(nil), // 1: libp2pwebrtcprivate.pb.Message +} +var file_msg_proto_depIdxs = []int32{ + 0, // 0: libp2pwebrtcprivate.pb.Message.type:type_name -> libp2pwebrtcprivate.pb.Message.Type + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_msg_proto_init() } +func file_msg_proto_init() { + if File_msg_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_msg_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Message); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_msg_proto_msgTypes[0].OneofWrappers = []interface{}{} + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_msg_proto_rawDesc, + NumEnums: 1, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_msg_proto_goTypes, + DependencyIndexes: file_msg_proto_depIdxs, + EnumInfos: file_msg_proto_enumTypes, + MessageInfos: file_msg_proto_msgTypes, + }.Build() + File_msg_proto = out.File + file_msg_proto_rawDesc = nil + file_msg_proto_goTypes = nil + file_msg_proto_depIdxs = nil +} diff --git a/p2p/transport/webrtcprivate/pb/msg.proto b/p2p/transport/webrtcprivate/pb/msg.proto new file mode 100644 index 0000000000..3674833ca2 --- /dev/null +++ b/p2p/transport/webrtcprivate/pb/msg.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package libp2pwebrtcprivate.pb; + +option go_package = "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate/pb"; + +message Message { + // Specifies type in `data` field. + enum Type { + // String of `RTCSessionDescription.sdp` + SDP_OFFER = 0; + // String of `RTCSessionDescription.sdp` + SDP_ANSWER = 1; + // String of `RTCIceCandidate.toJSON()` + ICE_CANDIDATE = 2; + } + + optional Type type = 1; + optional string data = 2; +} \ No newline at end of file diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go new file mode 100644 index 0000000000..897f8bb78f --- /dev/null +++ b/p2p/transport/webrtcprivate/transport.go @@ -0,0 +1,390 @@ +package libp2pwebrtcprivate + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "io" + "sync" + "time" + + logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + tpt "github.com/libp2p/go-libp2p/core/transport" + libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" + "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate/pb" + "github.com/libp2p/go-msgio/pbio" + "github.com/pion/webrtc/v3" + + ma "github.com/multiformats/go-multiaddr" + mafmt "github.com/multiformats/go-multiaddr-fmt" +) + +const ( + name = "webrtcprivate" + maxMsgSize = 4096 + streamTimeout = time.Minute + SignalingProtocol = "/webrtc-signaling" +) + +var log = logging.Logger("webrtcprivate") + +type transport struct { + host host.Host + rcmgr network.ResourceManager + webrtcConfig webrtc.Configuration + + mu sync.Mutex + l *listener +} + +var _ tpt.Transport = &transport{} + +func AddTransport(h host.Host) (*transport, error) { + n, ok := h.Network().(tpt.TransportNetwork) + if !ok { + return nil, fmt.Errorf("%v is not a transport network", h.Network()) + } + + t, err := newTransport(h) + if err != nil { + return nil, err + } + + if err := n.AddTransport(t); err != nil { + return nil, fmt.Errorf("failed to add transport to network: %w", err) + } + + if err := n.Listen(ma.StringCast("/webrtc")); err != nil { + return nil, err + } + + return t, nil +} + +func newTransport(h host.Host) (*transport, error) { + // We use elliptic P-256 since it is widely supported by browsers. + // + // Implementation note: Testing with the browser, + // it seems like Chromium only supports ECDSA P-256 or RSA key signatures in the webrtc TLS certificate. + // We tried using P-228 and P-384 which caused the DTLS handshake to fail with Illegal Parameter + // + // Please refer to this is a list of suggested algorithms for the WebCrypto API. + // The algorithm for generating a certificate for an RTCPeerConnection + // must adhere to the WebCrpyto API. From my observation, + // RSA and ECDSA P-256 is supported on almost all browsers. + // Ed25519 is not present on the list. + pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, fmt.Errorf("generate key for cert: %w", err) + } + cert, err := webrtc.GenerateCertificate(pk) + if err != nil { + return nil, fmt.Errorf("generate certificate: %w", err) + } + config := webrtc.Configuration{ + Certificates: []webrtc.Certificate{*cert}, + } + + return &transport{ + host: h, + rcmgr: h.Network().ResourceManager(), + webrtcConfig: config, + }, nil +} + +var dialMatcher = mafmt.And(mafmt.Base(ma.P_CIRCUIT), mafmt.Base(ma.P_WEBRTC)) + +// CanDial determines if we can dial to an address +func (t *transport) CanDial(addr ma.Multiaddr) bool { + return dialMatcher.Matches(addr) +} + +// Dial implements transport.Transport. +func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { + relayAddr := getRelayAddr(raddr) + err := t.host.Connect(ctx, peer.AddrInfo{ID: p, Addrs: []ma.Multiaddr{relayAddr}}) + if err != nil { + return nil, fmt.Errorf("failed to open %s stream: %w", SignalingProtocol, err) + } + scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) + if err != nil { + log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err) + return nil, err + } + + c, err := t.dialWithScope(ctx, p, scope) + if err != nil { + scope.Done() + log.Debug(err) + return nil, err + } + return c, nil +} + +func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { + ctx = network.WithUseTransient(ctx, "webrtc private dial") + s, err := t.host.NewStream(ctx, p, SignalingProtocol) + if err != nil { + return nil, fmt.Errorf("error opening stream %s: %w", SignalingProtocol, err) + } + + if err := s.Scope().SetService(name); err != nil { + s.Reset() + return nil, fmt.Errorf("error attaching signaling stream to %s transport: %w", name, err) + } + + if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil { + s.Reset() + return nil, fmt.Errorf("error reserving memory for signaling stream: %w", err) + } + defer s.Scope().ReleaseMemory(maxMsgSize) + defer s.Close() + + s.SetDeadline(time.Now().Add(streamTimeout)) + + pc, err := t.connect(ctx, s) + if err != nil { + s.Reset() + return nil, fmt.Errorf("error creating webrtc.PeerConnection: %w", err) + } + return libp2pwebrtc.NewWebRTCConnection( + network.DirOutbound, + pc, + t, + scope, + t.host.ID(), + ma.StringCast("/webrtc"), + p, + t.host.Network().Peerstore().PubKey(p), + ma.StringCast("/webrtc"), + ) +} + +func (t *transport) connect(ctx context.Context, s network.Stream) (*webrtc.PeerConnection, error) { + settings := webrtc.SettingEngine{} + settings.DetachDataChannels() + api := webrtc.NewAPI(webrtc.WithSettingEngine(settings)) + pc, err := api.NewPeerConnection(t.webrtcConfig) + if err != nil { + return nil, fmt.Errorf("error creating peer connection: %w", err) + } + + // Exchange offer and answer with peer + r := pbio.NewDelimitedReader(s, maxMsgSize) + w := pbio.NewDelimitedWriter(s) + + // register peerconnection state update callback + connectionState := make(chan webrtc.PeerConnectionState, 1) + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + switch state { + case webrtc.PeerConnectionStateConnected, webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed: + // We only use the first state written to connectionState. + select { + case connectionState <- state: + default: + } + } + }) + + // register local ICE Candidate found callback + writeErr := make(chan error, 1) + pc.OnICECandidate(func(candiate *webrtc.ICECandidate) { + // The callback can be called with a nil pointer + if candiate == nil { + return + } + b, err := json.Marshal(candiate.ToJSON()) + if err != nil { + // We only want to write a single error on this channel + select { + case writeErr <- fmt.Errorf("failed to marshal candidate to JSON: %w", err): + default: + } + return + } + data := string(b) + msg := &pb.Message{ + Type: pb.Message_ICE_CANDIDATE.Enum(), + Data: &data, + } + if err = w.WriteMsg(msg); err != nil { + // We only want to write a single error on this channel + select { + case writeErr <- fmt.Errorf("failed to write candidate: %w", err): + default: + } + } + }) + + // de-register candidate callback + defer pc.OnICECandidate(func(_ *webrtc.ICECandidate) {}) + + // We initialise a data channel otherwise the offer will have no ICE components + // https://stackoverflow.com/a/38872920/759687 + var streamID uint16 + dc, err := pc.CreateDataChannel("init", &webrtc.DataChannelInit{ID: &streamID}) + if err != nil { + return nil, fmt.Errorf("failed to create data channel: %w", err) + } + // Ensure that we close *this particular* data channel so that when the remote + // side does AcceptStream this data channel is not used for the new stream. + defer dc.Close() + + // create an offer + offer, err := pc.CreateOffer(nil) + if err != nil { + return nil, fmt.Errorf("failed to create offer: %w", err) + } + offerMessage := &pb.Message{ + Type: pb.Message_SDP_OFFER.Enum(), + Data: &offer.SDP, + } + + // send offer to peer + if err := w.WriteMsg(offerMessage); err != nil { + return nil, fmt.Errorf("failed to write to stream: %w", err) + } + if err := pc.SetLocalDescription(offer); err != nil { + return nil, fmt.Errorf("failed to set local description: %w", err) + } + + // read an incoming answer + var msg pb.Message + if err := r.ReadMsg(&msg); err != nil { + return nil, fmt.Errorf("failed to read from stream: %w", err) + } + if msg.Type == nil || *msg.Type != pb.Message_SDP_ANSWER { + return nil, fmt.Errorf("invalid message: expected %s, got %s", pb.Message_SDP_ANSWER, msg.Type) + } + if msg.Data == nil || *msg.Data == "" { + return nil, fmt.Errorf("invalid message: empty answer") + } + answer := webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, + SDP: *msg.Data, + } + if err := pc.SetRemoteDescription(answer); err != nil { + return nil, fmt.Errorf("failed to set remote description: %w", err) + } + + readErr := make(chan error, 1) + ctx, cancel := context.WithTimeout(ctx, streamTimeout) + defer cancel() + // start a goroutine to read candidates + go func() { + for { + if ctx.Err() != nil { + return + } + + var msg pb.Message + err := r.ReadMsg(&msg) + if err == io.EOF { + return + } + if err != nil { + readErr <- fmt.Errorf("read failed: %w", err) + return + } + if msg.Type == nil || *msg.Type != pb.Message_ICE_CANDIDATE { + readErr <- fmt.Errorf("invalid message: expected %s got %s", pb.Message_ICE_CANDIDATE, msg.Type) + return + } + // Ignore without erroring on empty message. + // Pion has a case where OnCandidate callback may be called with a nil + // candidate + if msg.Data == nil || *msg.Data == "" { + log.Debugf("received empty candidate from %s", s.Conn().RemotePeer()) + continue + } + + var init webrtc.ICECandidateInit + if err := json.Unmarshal([]byte(*msg.Data), &init); err != nil { + readErr <- fmt.Errorf("failed to unmarshal ice candidate %w", err) + return + } + if err := pc.AddICECandidate(init); err != nil { + readErr <- fmt.Errorf("failed to add ice candidate: %w", err) + return + } + } + }() + + select { + case <-ctx.Done(): + pc.Close() + return nil, ctx.Err() + case err := <-readErr: + pc.Close() + return nil, err + case state := <-connectionState: + switch state { + default: + pc.Close() + return nil, fmt.Errorf("conn establishment failed, state: %s", state) + case webrtc.PeerConnectionStateConnected: + return pc, nil + } + } +} + +// Listen implements transport.Transport. +func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { + if _, err := laddr.ValueForProtocol(ma.P_WEBRTC); err != nil { + return nil, fmt.Errorf("invalid listen multiaddr: %s", laddr) + } + t.mu.Lock() + defer t.mu.Unlock() + if t.l != nil { + return nil, errors.New("already listening on /webrtc") + } + + l := &listener{ + t: t, + webrtcConfig: t.webrtcConfig, + conns: make(chan tpt.CapableConn, 8), + closeC: make(chan struct{}), + } + t.l = l + t.host.SetStreamHandler(SignalingProtocol, l.handleIncoming) + return l, nil +} + +func (t *transport) RemoveListener(l *listener) { + t.mu.Lock() + defer t.mu.Unlock() + if t.l == l { + t.l = nil + t.host.RemoveStreamHandler(SignalingProtocol) + } +} + +// Protocols implements transport.Transport. +func (*transport) Protocols() []int { + return []int{ma.P_WEBRTC} +} + +// Proxy implements transport.Transport. +func (*transport) Proxy() bool { + return false +} + +// getRelayAddr removes /webrtc from addr and returns a circuit v2 only address +func getRelayAddr(addr ma.Multiaddr) ma.Multiaddr { + first, rest := ma.SplitFunc(addr, func(c ma.Component) bool { + return c.Protocol().Code == ma.P_WEBRTC + }) + // removes /webrtc prefix + _, rest = ma.SplitFirst(rest) + if rest == nil { + return first + } + return first.Encapsulate(rest) +} diff --git a/p2p/transport/webrtcprivate/transport_test.go b/p2p/transport/webrtcprivate/transport_test.go new file mode 100644 index 0000000000..ccd9fcde1c --- /dev/null +++ b/p2p/transport/webrtcprivate/transport_test.go @@ -0,0 +1,180 @@ +package libp2pwebrtcprivate + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/peer" + blankhost "github.com/libp2p/go-libp2p/p2p/host/blank" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" + "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// relayedHost is a webrtc enabled host with a relay reservation +type relayedHost struct { + webrtcHost + // R is the relay host + R host.Host + // Addr is the reachable /webrtc address + Addr ma.Multiaddr +} + +func (r *relayedHost) Close() { + r.R.Close() + r.webrtcHost.Close() +} + +type webrtcHost struct { + host.Host + // T is the webrtc transport used by the host + T *transport +} + +func newWebRTCHost(t *testing.T) *webrtcHost { + as := swarmt.GenSwarm(t) + a := blankhost.NewBlankHost(as) + upg := swarmt.GenUpgrader(t, as, nil) + err := client.AddTransport(a, upg) + require.NoError(t, err) + ta, err := newTransport(a) + require.NoError(t, err) + return &webrtcHost{ + Host: a, + T: ta, + } +} + +func newRelayedHost(t *testing.T) *relayedHost { + rh := blankhost.NewBlankHost(swarmt.GenSwarm(t)) + _, err := relay.New(rh) + require.NoError(t, err) + + ps := swarmt.GenSwarm(t) + p := blankhost.NewBlankHost(ps) + upg := swarmt.GenUpgrader(t, ps, nil) + client.AddTransport(p, upg) + _, err = client.Reserve(context.Background(), p, peer.AddrInfo{ID: rh.ID(), Addrs: rh.Addrs()}) + require.NoError(t, err) + tp, err := newTransport(p) + require.NoError(t, err) + return &relayedHost{ + webrtcHost: webrtcHost{ + Host: p, + T: tp, + }, + R: rh, + Addr: ma.StringCast(fmt.Sprintf("%s/p2p/%s/p2p-circuit/webrtc/", rh.Addrs()[0], rh.ID())), + } +} + +func TestSingleDial(t *testing.T) { + a := newWebRTCHost(t) + b := newRelayedHost(t) + defer b.Close() + defer a.Close() + + l, err := b.T.Listen(ma.StringCast("/webrtc")) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + ca, err := a.T.Dial(ctx, b.Addr, b.ID()) + require.NoError(t, err) + + cb, err := l.Accept() + require.NoError(t, err) + sa, err := ca.OpenStream(ctx) + require.NoError(t, err) + sb, err := cb.AcceptStream() + require.NoError(t, err) + sa.Write([]byte("hello world")) + recv := make([]byte, 24) + n, err := sb.Read(recv) + require.NoError(t, err) + require.Equal(t, "hello world", string(recv[:n])) +} + +func TestMultipleDials(t *testing.T) { + a := newWebRTCHost(t) + var wg sync.WaitGroup + + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + b := newRelayedHost(t) + defer b.Close() + + l, err := b.T.Listen(ma.StringCast("/webrtc")) + assert.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + ca, err := a.T.Dial(ctx, b.Addr, b.ID()) + assert.NoError(t, err) + + cb, err := l.Accept() + assert.NoError(t, err) + + sa, err := ca.OpenStream(ctx) + assert.NoError(t, err) + sb, err := cb.AcceptStream() + assert.NoError(t, err) + sa.Write([]byte("hello world")) + recv := make([]byte, 24) + n, err := sb.Read(recv) + assert.NoError(t, err) + assert.Equal(t, "hello world", string(recv[:n])) + wg.Done() + }() + } + wg.Wait() +} + +func TestMultipleDialsAndListeners(t *testing.T) { + var hosts []*webrtcHost + for i := 0; i < 5; i++ { + hosts = append(hosts, newWebRTCHost(t)) + } + var wg sync.WaitGroup + + for i := 0; i < 5; i++ { + for j := 0; j < 5; j++ { + wg.Add(1) + go func(j int) { + b := newRelayedHost(t) + defer b.Close() + + l, err := b.T.Listen(ma.StringCast("/webrtc")) + assert.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + ca, err := hosts[j].T.Dial(ctx, b.Addr, b.ID()) + assert.NoError(t, err) + + cb, err := l.Accept() + assert.NoError(t, err) + + sa, err := ca.OpenStream(ctx) + assert.NoError(t, err) + sb, err := cb.AcceptStream() + assert.NoError(t, err) + sa.Write([]byte("hello world")) + recv := make([]byte, 24) + n, err := sb.Read(recv) + assert.NoError(t, err) + assert.Equal(t, "hello world", string(recv[:n])) + wg.Done() + }(j) + } + } + wg.Wait() +} From 8f84284d2ca4963a928dbf4d435353786973819d Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 21 Sep 2023 15:41:50 +0530 Subject: [PATCH 02/22] webrtcprivate: fix deadline, limit inflight connection requests interim commit --- p2p/transport/webrtcprivate/listener.go | 55 ++- p2p/transport/webrtcprivate/transport.go | 123 +++-- p2p/transport/webrtcprivate/transport_test.go | 423 ++++++++++++++++-- 3 files changed, 498 insertions(+), 103 deletions(-) diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go index 77f47eec61..26d5fbe6c1 100644 --- a/p2p/transport/webrtcprivate/listener.go +++ b/p2p/transport/webrtcprivate/listener.go @@ -18,10 +18,10 @@ import ( ) type listener struct { - t *transport - webrtcConfig webrtc.Configuration - conns chan tpt.CapableConn - closeC chan struct{} + transport *transport + connQueue chan tpt.CapableConn + closeC chan struct{} + inflightQueue chan struct{} } var _ tpt.Listener = &listener{} @@ -41,7 +41,7 @@ func (n NetAddr) String() string { // Accept implements transport.Listener. func (l *listener) Accept() (tpt.CapableConn, error) { select { - case c := <-l.conns: + case c := <-l.connQueue: return c, nil case <-l.closeC: return nil, tpt.ErrListenerClosed @@ -55,7 +55,7 @@ func (l *listener) Addr() net.Addr { // Close implements transport.Listener. func (l *listener) Close() error { - l.t.RemoveListener(l) + l.transport.RemoveListener(l) close(l.closeC) return nil } @@ -66,22 +66,28 @@ func (*listener) Multiaddr() ma.Multiaddr { } func (l *listener) handleIncoming(s network.Stream) { - ctx, cancel := context.WithTimeout(context.Background(), streamTimeout) + select { + case l.inflightQueue <- struct{}{}: + defer func() { <-l.inflightQueue }() + case <-l.closeC: + s.Reset() + return + } + + ctx, cancel := context.WithTimeout(context.Background(), connectTimeout) defer cancel() defer s.Close() - s.SetDeadline(time.Now().Add(streamTimeout)) - scope, err := l.t.rcmgr.OpenConnection(network.DirInbound, false, ma.StringCast("/webrtc")) + s.SetDeadline(time.Now().Add(connectTimeout)) + + scope, err := l.transport.rcmgr.OpenConnection(network.DirInbound, false, ma.StringCast("/webrtc")) if err != nil { s.Reset() log.Debug("failed to create connection scope:", err) return } - settings := webrtc.SettingEngine{} - settings.DetachDataChannels() - api := webrtc.NewAPI(webrtc.WithSettingEngine(settings)) - pc, err := api.NewPeerConnection(l.webrtcConfig) + pc, err := l.transport.NewPeerConnection() if err != nil { s.Reset() log.Debug("error creating a webrtc.PeerConnection:", err) @@ -209,7 +215,7 @@ func (l *listener) handleIncoming(s network.Stream) { readErr <- fmt.Errorf("invalid message: msg.Type expected %s got %s", pb.Message_ICE_CANDIDATE, msg.Type) return } - // Ignore without erroring on empty message. + // Ignore without Debuging on empty message. // Pion has a case where OnCandidate callback may be called with a nil // candidate if msg.Data == nil || *msg.Data == "" { @@ -233,42 +239,45 @@ func (l *listener) handleIncoming(s network.Stream) { case <-ctx.Done(): pc.Close() s.Reset() - log.Error(ctx.Err()) + log.Debug(ctx.Err()) return case err := <-writeErr: pc.Close() s.Reset() - log.Error(err) + log.Debug(err) return case err := <-readErr: pc.Close() s.Reset() - log.Error(err) + log.Debug(err) return case state := <-connectionState: switch state { default: pc.Close() s.Reset() + log.Debugf("connection setup failed, got state: %s", state) return case webrtc.PeerConnectionStateConnected: conn, _ := libp2pwebrtc.NewWebRTCConnection( network.DirInbound, pc, - l.t, + l.transport, scope, - l.t.host.ID(), + l.transport.host.ID(), ma.StringCast("/webrtc"), s.Conn().RemotePeer(), - l.t.host.Peerstore().PubKey(s.Conn().RemotePeer()), + l.transport.host.Peerstore().PubKey(s.Conn().RemotePeer()), ma.StringCast("/webrtc"), ) + // Close the stream before we wait for the connection to be accepted + s.Close() select { - case l.conns <- conn: - default: + case l.connQueue <- conn: + case <-l.closeC: s.Reset() - log.Debug("incoming conn queue full: dropping conn from %s", s.Conn().RemotePeer()) conn.Close() + log.Debug("listener closed: dropping conn from %s", s.Conn().RemotePeer()) } return } diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go index 897f8bb78f..d206dbadcc 100644 --- a/p2p/transport/webrtcprivate/transport.go +++ b/p2p/transport/webrtcprivate/transport.go @@ -15,33 +15,43 @@ import ( logging "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" + pionlogger "github.com/pion/logging" + "github.com/libp2p/go-libp2p/core/peer" tpt "github.com/libp2p/go-libp2p/core/transport" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate/pb" "github.com/libp2p/go-msgio/pbio" "github.com/pion/webrtc/v3" + "go.uber.org/zap/zapcore" ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" ) const ( - name = "webrtcprivate" - maxMsgSize = 4096 - streamTimeout = time.Minute - SignalingProtocol = "/webrtc-signaling" + name = "webrtcprivate" + maxMsgSize = 4096 + connectTimeout = time.Minute + SignalingProtocol = "/webrtc-signaling" + disconnectedTimeout = 20 * time.Second + failedTimeout = 30 * time.Second + keepaliveTimeout = 15 * time.Second ) -var log = logging.Logger("webrtcprivate") +var ( + log = logging.Logger("webrtcprivate") + WebRTCAddr = ma.StringCast("/webrtc") +) type transport struct { - host host.Host - rcmgr network.ResourceManager - webrtcConfig webrtc.Configuration + host host.Host + rcmgr network.ResourceManager + webrtcConfig webrtc.Configuration + maxInFlightConnections int - mu sync.Mutex - l *listener + mu sync.Mutex + listener *listener } var _ tpt.Transport = &transport{} @@ -93,9 +103,10 @@ func newTransport(h host.Host) (*transport, error) { } return &transport{ - host: h, - rcmgr: h.Network().ResourceManager(), - webrtcConfig: config, + host: h, + rcmgr: h.Network().ResourceManager(), + webrtcConfig: config, + maxInFlightConnections: 16, }, nil } @@ -108,16 +119,21 @@ func (t *transport) CanDial(addr ma.Multiaddr) bool { // Dial implements transport.Transport. func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { + // Connect to the peer on the circuit address relayAddr := getRelayAddr(raddr) err := t.host.Connect(ctx, peer.AddrInfo{ID: p, Addrs: []ma.Multiaddr{relayAddr}}) if err != nil { return nil, fmt.Errorf("failed to open %s stream: %w", SignalingProtocol, err) } - scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) + + scope, err := t.rcmgr.OpenConnection(network.DirOutbound, true, raddr) if err != nil { log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err) return nil, err } + if err := scope.SetPeer(p); err != nil { + return nil, err + } c, err := t.dialWithScope(ctx, p, scope) if err != nil { @@ -129,7 +145,8 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp } func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { - ctx = network.WithUseTransient(ctx, "webrtc private dial") + // Start signaling protocol stream + ctx = network.WithUseTransient(ctx, "webrtcprivate dial") s, err := t.host.NewStream(ctx, p, SignalingProtocol) if err != nil { return nil, fmt.Errorf("error opening stream %s: %w", SignalingProtocol, err) @@ -140,19 +157,23 @@ func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network. return nil, fmt.Errorf("error attaching signaling stream to %s transport: %w", name, err) } - if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil { + if err := s.Scope().ReserveMemory(2*maxMsgSize, network.ReservationPriorityAlways); err != nil { s.Reset() return nil, fmt.Errorf("error reserving memory for signaling stream: %w", err) } defer s.Scope().ReleaseMemory(maxMsgSize) defer s.Close() - s.SetDeadline(time.Now().Add(streamTimeout)) + deadline := time.Now().Add(connectTimeout) + if d, ok := ctx.Deadline(); ok && d.Before(deadline) { + deadline = d + } + s.SetDeadline(deadline) - pc, err := t.connect(ctx, s) + pc, err := t.establishPeerConnection(ctx, s) if err != nil { s.Reset() - return nil, fmt.Errorf("error creating webrtc.PeerConnection: %w", err) + return nil, fmt.Errorf("error establishing webrtc.PeerConnection: %w", err) } return libp2pwebrtc.NewWebRTCConnection( network.DirOutbound, @@ -167,16 +188,11 @@ func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network. ) } -func (t *transport) connect(ctx context.Context, s network.Stream) (*webrtc.PeerConnection, error) { - settings := webrtc.SettingEngine{} - settings.DetachDataChannels() - api := webrtc.NewAPI(webrtc.WithSettingEngine(settings)) - pc, err := api.NewPeerConnection(t.webrtcConfig) +func (t *transport) establishPeerConnection(ctx context.Context, s network.Stream) (*webrtc.PeerConnection, error) { + pc, err := t.NewPeerConnection() if err != nil { - return nil, fmt.Errorf("error creating peer connection: %w", err) + return nil, fmt.Errorf("failed to create webrtc.PeerConnection: %w", err) } - - // Exchange offer and answer with peer r := pbio.NewDelimitedReader(s, maxMsgSize) w := pbio.NewDelimitedWriter(s) @@ -210,11 +226,11 @@ func (t *transport) connect(ctx context.Context, s network.Stream) (*webrtc.Peer return } data := string(b) - msg := &pb.Message{ + msg := pb.Message{ Type: pb.Message_ICE_CANDIDATE.Enum(), Data: &data, } - if err = w.WriteMsg(msg); err != nil { + if err = w.WriteMsg(&msg); err != nil { // We only want to write a single error on this channel select { case writeErr <- fmt.Errorf("failed to write candidate: %w", err): @@ -228,12 +244,11 @@ func (t *transport) connect(ctx context.Context, s network.Stream) (*webrtc.Peer // We initialise a data channel otherwise the offer will have no ICE components // https://stackoverflow.com/a/38872920/759687 - var streamID uint16 - dc, err := pc.CreateDataChannel("init", &webrtc.DataChannelInit{ID: &streamID}) + dc, err := pc.CreateDataChannel("init", nil) if err != nil { return nil, fmt.Errorf("failed to create data channel: %w", err) } - // Ensure that we close *this particular* data channel so that when the remote + // Ensure that we close this data channel so that when the remote // side does AcceptStream this data channel is not used for the new stream. defer dc.Close() @@ -275,7 +290,7 @@ func (t *transport) connect(ctx context.Context, s network.Stream) (*webrtc.Peer } readErr := make(chan error, 1) - ctx, cancel := context.WithTimeout(ctx, streamTimeout) + ctx, cancel := context.WithTimeout(ctx, connectTimeout) defer cancel() // start a goroutine to read candidates go func() { @@ -284,7 +299,6 @@ func (t *transport) connect(ctx context.Context, s network.Stream) (*webrtc.Peer return } - var msg pb.Message err := r.ReadMsg(&msg) if err == io.EOF { return @@ -324,6 +338,9 @@ func (t *transport) connect(ctx context.Context, s network.Stream) (*webrtc.Peer case err := <-readErr: pc.Close() return nil, err + case err := <-writeErr: + pc.Close() + return nil, err case state := <-connectionState: switch state { default: @@ -342,17 +359,17 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { } t.mu.Lock() defer t.mu.Unlock() - if t.l != nil { + if t.listener != nil { return nil, errors.New("already listening on /webrtc") } l := &listener{ - t: t, - webrtcConfig: t.webrtcConfig, - conns: make(chan tpt.CapableConn, 8), - closeC: make(chan struct{}), + transport: t, + connQueue: make(chan tpt.CapableConn), + inflightQueue: make(chan struct{}, t.maxInFlightConnections), + closeC: make(chan struct{}), } - t.l = l + t.listener = l t.host.SetStreamHandler(SignalingProtocol, l.handleIncoming) return l, nil } @@ -360,8 +377,8 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { func (t *transport) RemoveListener(l *listener) { t.mu.Lock() defer t.mu.Unlock() - if t.l == l { - t.l = nil + if t.listener == l { + t.listener = nil t.host.RemoveStreamHandler(SignalingProtocol) } } @@ -376,6 +393,28 @@ func (*transport) Proxy() bool { return false } +func (t *transport) NewPeerConnection() (*webrtc.PeerConnection, error) { + loggerFactory := pionlogger.NewDefaultLoggerFactory() + logLevel := pionlogger.LogLevelDisabled + switch log.Level() { + case zapcore.DebugLevel: + logLevel = pionlogger.LogLevelDebug + case zapcore.InfoLevel: + logLevel = pionlogger.LogLevelInfo + case zapcore.WarnLevel: + logLevel = pionlogger.LogLevelWarn + case zapcore.ErrorLevel: + logLevel = pionlogger.LogLevelError + } + loggerFactory.DefaultLogLevel = logLevel + s := webrtc.SettingEngine{LoggerFactory: loggerFactory} + s.SetICETimeouts(disconnectedTimeout, failedTimeout, keepaliveTimeout) + s.DetachDataChannels() + s.SetIncludeLoopbackCandidate(true) + api := webrtc.NewAPI(webrtc.WithSettingEngine(s)) + return api.NewPeerConnection(t.webrtcConfig) +} + // getRelayAddr removes /webrtc from addr and returns a circuit v2 only address func getRelayAddr(addr ma.Multiaddr) ma.Multiaddr { first, rest := ma.SplitFunc(addr, func(c ma.Component) bool { diff --git a/p2p/transport/webrtcprivate/transport_test.go b/p2p/transport/webrtcprivate/transport_test.go index ccd9fcde1c..efdf491f0e 100644 --- a/p2p/transport/webrtcprivate/transport_test.go +++ b/p2p/transport/webrtcprivate/transport_test.go @@ -3,7 +3,9 @@ package libp2pwebrtcprivate import ( "context" "fmt" + "os" "sync" + "sync/atomic" "testing" "time" @@ -54,7 +56,9 @@ func newWebRTCHost(t *testing.T) *webrtcHost { func newRelayedHost(t *testing.T) *relayedHost { rh := blankhost.NewBlankHost(swarmt.GenSwarm(t)) - _, err := relay.New(rh) + rr := relay.DefaultResources() + rr.MaxCircuits = 100 + _, err := relay.New(rh, relay.WithResources(rr)) require.NoError(t, err) ps := swarmt.GenSwarm(t) @@ -100,6 +104,9 @@ func TestSingleDial(t *testing.T) { n, err := sb.Read(recv) require.NoError(t, err) require.Equal(t, "hello world", string(recv[:n])) + + ca.Close() + cb.Close() } func TestMultipleDials(t *testing.T) { @@ -113,25 +120,39 @@ func TestMultipleDials(t *testing.T) { defer b.Close() l, err := b.T.Listen(ma.StringCast("/webrtc")) - assert.NoError(t, err) + if !assert.NoError(t, err) { + return + } ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() ca, err := a.T.Dial(ctx, b.Addr, b.ID()) - assert.NoError(t, err) + if !assert.NoError(t, err) { + return + } cb, err := l.Accept() - assert.NoError(t, err) + if !assert.NoError(t, err) { + return + } sa, err := ca.OpenStream(ctx) - assert.NoError(t, err) + if !assert.NoError(t, err) { + return + } sb, err := cb.AcceptStream() - assert.NoError(t, err) + if !assert.NoError(t, err) { + return + } sa.Write([]byte("hello world")) recv := make([]byte, 24) n, err := sb.Read(recv) - assert.NoError(t, err) - assert.Equal(t, "hello world", string(recv[:n])) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "hello world", string(recv[:n])) { + return + } wg.Done() }() } @@ -139,42 +160,368 @@ func TestMultipleDials(t *testing.T) { } func TestMultipleDialsAndListeners(t *testing.T) { - var hosts []*webrtcHost - for i := 0; i < 5; i++ { - hosts = append(hosts, newWebRTCHost(t)) + var dialHosts []*webrtcHost + const N = 5 + for i := 0; i < N; i++ { + dialHosts = append(dialHosts, newWebRTCHost(t)) + defer dialHosts[i].Close() + } + + var listenHosts []*relayedHost + for i := 0; i < N; i++ { + listenHosts = append(listenHosts, newRelayedHost(t)) + l, err := listenHosts[i].T.Listen(ma.StringCast("/webrtc")) + require.NoError(t, err) + defer listenHosts[i].Close() + defer l.Close() } var wg sync.WaitGroup - for i := 0; i < 5; i++ { - for j := 0; j < 5; j++ { + dialAndPing := func(h *webrtcHost, raddr ma.Multiaddr, p peer.ID) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + ca, err := h.T.Dial(ctx, raddr, p) + if !assert.NoError(t, err) { + return + } + defer ca.Close() + sa, err := ca.OpenStream(ctx) + if !assert.NoError(t, err) { + return + } + defer sa.Close() + sa.Write([]byte("hello world")) + recv := make([]byte, 24) + n, err := sa.Read(recv) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "hello world", string(recv[:n])) { + return + } + } + + acceptAndPong := func(r *relayedHost) { + cb, err := r.T.listener.Accept() + if !assert.NoError(t, err) { + return + } + + sb, err := cb.AcceptStream() + if !assert.NoError(t, err) { + return + } + defer sb.Close() + + recv := make([]byte, 24) + n, err := sb.Read(recv) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "hello world", string(recv[:n])) { + return + } + sb.Write(recv[:n]) + } + + for i := 0; i < N; i++ { + for j := 0; j < N; j++ { wg.Add(1) - go func(j int) { - b := newRelayedHost(t) - defer b.Close() - - l, err := b.T.Listen(ma.StringCast("/webrtc")) - assert.NoError(t, err) - - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - ca, err := hosts[j].T.Dial(ctx, b.Addr, b.ID()) - assert.NoError(t, err) - - cb, err := l.Accept() - assert.NoError(t, err) - - sa, err := ca.OpenStream(ctx) - assert.NoError(t, err) - sb, err := cb.AcceptStream() - assert.NoError(t, err) - sa.Write([]byte("hello world")) - recv := make([]byte, 24) - n, err := sb.Read(recv) - assert.NoError(t, err) - assert.Equal(t, "hello world", string(recv[:n])) + go func(i, j int) { + go dialAndPing(dialHosts[i], listenHosts[j].Addr, listenHosts[j].ID()) + acceptAndPong(listenHosts[j]) wg.Done() - }(j) + }(i, j) } } wg.Wait() } + +func TestDialerCanCreateStreams(t *testing.T) { + a := newWebRTCHost(t) + b := newRelayedHost(t) + listener, err := b.T.Listen(ma.StringCast("/webrtc")) + require.NoError(t, err) + + aC := make(chan bool) + go func() { + defer close(aC) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := a.T.Dial(ctx, b.Addr, b.ID()) + if !assert.NoError(t, err) { + return + } + s, err := conn.AcceptStream() + if !assert.NoError(t, err) { + return + } + recv := make([]byte, 24) + n, err := s.Read(recv) + if !assert.NoError(t, err) { + return + } + _, err = s.Write(recv[:n]) + if !assert.NoError(t, err) { + return + } + s.Close() + }() + + bC := make(chan bool) + go func() { + defer close(bC) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := listener.Accept() + if !assert.NoError(t, err) { + return + } + s, err := conn.OpenStream(ctx) + if !assert.NoError(t, err) { + return + } + defer s.Close() + + _, err = s.Write([]byte("hello world")) + if !assert.NoError(t, err) { + return + } + + recv := make([]byte, 24) + n, err := s.Read(recv) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "hello world", string(recv[:n])) { + return + } + }() + + select { + case <-aC: + case <-time.After(10 * time.Second): + t.Fatal("timeout") + } + select { + case <-bC: + case <-time.After(10 * time.Second): + t.Fatal("timeout") + } +} + +func TestDialerCanCreateStreamsMultiple(t *testing.T) { + count := 5 + a := newWebRTCHost(t) + b := newRelayedHost(t) + listener, err := b.T.Listen(WebRTCAddr) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + lconn, err := listener.Accept() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, a.ID(), lconn.RemotePeer()) { + return + } + var wg sync.WaitGroup + + for i := 0; i < count; i++ { + stream, err := lconn.AcceptStream() + if !assert.NoError(t, err) { + return + } + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 100) + n, err := stream.Read(buf) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "test", string(buf[:n])) { + return + } + _, err = stream.Write([]byte("test")) + if !assert.NoError(t, err) { + return + } + }() + } + + wg.Wait() + done <- struct{}{} + }() + + conn, err := a.T.Dial(context.Background(), b.Addr, b.ID()) + require.NoError(t, err) + + for i := 0; i < count; i++ { + idx := i + go func() { + stream, err := conn.OpenStream(context.Background()) + if !assert.NoError(t, err) { + return + } + t.Logf("dialer opened stream: %d", idx) + buf := make([]byte, 100) + _, err = stream.Write([]byte("test")) + if !assert.NoError(t, err) { + return + } + n, err := stream.Read(buf) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "test", string(buf[:n])) { + return + } + }() + } + select { + case <-done: + case <-time.After(20 * time.Second): + t.Fatal("timed out") + } +} + +func TestMaxInflightQueue(t *testing.T) { + b := newRelayedHost(t) + defer b.Close() + count := 3 + b.T.maxInFlightConnections = count + listener, err := b.T.Listen(WebRTCAddr) + require.NoError(t, err) + defer listener.Close() + + var success, failure atomic.Int32 + var wg sync.WaitGroup + for i := 0; i < count+1; i++ { + wg.Add(1) + go func() { + a := newWebRTCHost(t) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := a.T.Dial(ctx, b.Addr, b.ID()) + if err == nil { + success.Add(1) + } else { + failure.Add(1) + } + wg.Done() + }() + } + wg.Wait() + require.Equal(t, 1, int(failure.Load())) + require.Equal(t, count, int(success.Load())) +} + +func TestRemoteReadsAfterClose(t *testing.T) { + b := newRelayedHost(t) + listener, err := b.T.Listen(WebRTCAddr) + require.NoError(t, err) + + a := newWebRTCHost(t) + + done := make(chan error) + go func() { + lconn, err := listener.Accept() + if err != nil { + done <- err + return + } + stream, err := lconn.AcceptStream() + if err != nil { + done <- err + return + } + _, err = stream.Write([]byte{1, 2, 3, 4}) + if err != nil { + done <- err + return + } + err = stream.Close() + if err != nil { + done <- err + return + } + close(done) + }() + + conn, err := a.T.Dial(context.Background(), b.Addr, b.ID()) + require.NoError(t, err) + // create a stream + stream, err := conn.OpenStream(context.Background()) + + require.NoError(t, err) + // require write and close to complete + require.NoError(t, <-done) + + stream.SetReadDeadline(time.Now().Add(5 * time.Second)) + + buf := make([]byte, 10) + n, err := stream.Read(buf) + require.NoError(t, err) + require.Equal(t, n, 4) +} + +func TestStreamDeadline(t *testing.T) { + b := newRelayedHost(t) + listener, err := b.T.Listen(WebRTCAddr) + require.NoError(t, err) + a := newWebRTCHost(t) + + t.Run("SetReadDeadline", func(t *testing.T) { + go func() { + lconn, err := listener.Accept() + if !assert.NoError(t, err) { + return + } + _, err = lconn.AcceptStream() + if !assert.NoError(t, err) { + return + } + }() + + conn, err := a.T.Dial(context.Background(), b.Addr, b.ID()) + require.NoError(t, err) + stream, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + + // deadline set to the past + stream.SetReadDeadline(time.Now().Add(-200 * time.Millisecond)) + _, err = stream.Read([]byte{0, 0}) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + + // future deadline exceeded + stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + _, err = stream.Read([]byte{0, 0}) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + }) + + t.Run("SetWriteDeadline", func(t *testing.T) { + go func() { + lconn, err := listener.Accept() + if !assert.NoError(t, err) { + return + } + _, err = lconn.AcceptStream() + if !assert.NoError(t, err) { + return + } + }() + + conn, err := a.T.Dial(context.Background(), b.Addr, b.ID()) + require.NoError(t, err) + stream, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + + stream.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)) + time.Sleep(201 * time.Millisecond) + largeBuffer := make([]byte, 2*1024*1024) + _, err = stream.Write(largeBuffer) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + }) +} From 4f3c12e322f79bb34a11c99a03e7a575fb8ac75b Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 22 Sep 2023 13:20:19 +0530 Subject: [PATCH 03/22] webrtcprivate: integrate connection gater --- p2p/transport/webrtcprivate/listener.go | 58 +++++++++++++++++++----- p2p/transport/webrtcprivate/transport.go | 53 ++++++++++++++++++++-- 2 files changed, 96 insertions(+), 15 deletions(-) diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go index 26d5fbe6c1..a9217f624f 100644 --- a/p2p/transport/webrtcprivate/listener.go +++ b/p2p/transport/webrtcprivate/listener.go @@ -9,11 +9,13 @@ import ( "time" "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" tpt "github.com/libp2p/go-libp2p/core/transport" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate/pb" "github.com/libp2p/go-msgio/pbio" ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" "github.com/pion/webrtc/v3" ) @@ -259,17 +261,13 @@ func (l *listener) handleIncoming(s network.Stream) { log.Debugf("connection setup failed, got state: %s", state) return case webrtc.PeerConnectionStateConnected: - conn, _ := libp2pwebrtc.NewWebRTCConnection( - network.DirInbound, - pc, - l.transport, - scope, - l.transport.host.ID(), - ma.StringCast("/webrtc"), - s.Conn().RemotePeer(), - l.transport.host.Peerstore().PubKey(s.Conn().RemotePeer()), - ma.StringCast("/webrtc"), - ) + conn, err := l.setupConnection(pc, scope, s.Conn().RemotePeer()) + if err != nil { + pc.Close() + s.Reset() + log.Debug("connection setup with %s failed: %w", s.Conn().RemotePeer(), err) + return + } // Close the stream before we wait for the connection to be accepted s.Close() select { @@ -283,3 +281,41 @@ func (l *listener) handleIncoming(s network.Stream) { } } } + +func (l *listener) setupConnection(pc *webrtc.PeerConnection, scope network.ConnManagementScope, p peer.ID) (tpt.CapableConn, error) { + cp, err := getSelectedCandidate(pc) + if cp == nil || err != nil { + return nil, fmt.Errorf("failed to get selected candidate address, got: %s: %w", cp, err) + } + localAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Local.Address), Port: int(cp.Local.Port)}) + if err != nil { + return nil, fmt.Errorf("failed to infer local address from candidate %s: %w", cp, err) + } + localAddr = localAddr.Encapsulate(WebRTCAddr) + + remoteAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Remote.Address), Port: int(cp.Remote.Port)}) + if err != nil { + return nil, fmt.Errorf("failed to infer remote address from candidate %s: %w", cp, err) + } + remoteAddr = remoteAddr.Encapsulate(WebRTCAddr) + + conn, err := libp2pwebrtc.NewWebRTCConnection( + network.DirInbound, + pc, + l.transport, + scope, + l.transport.host.ID(), + localAddr, + p, + l.transport.host.Peerstore().PubKey(p), // we have the public key from the relayed connection + remoteAddr, + ) + if err != nil { + return nil, fmt.Errorf("failed to create tranport.CapableConn: %w", err) + } + if l.transport.gater != nil && l.transport.gater.InterceptSecured(network.DirOutbound, p, conn) { + conn.Close() + return nil, fmt.Errorf("conn gater refused connection to addr: %s", conn.RemoteMultiaddr()) + } + return conn, nil +} diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go index d206dbadcc..6dbdc3f5fd 100644 --- a/p2p/transport/webrtcprivate/transport.go +++ b/p2p/transport/webrtcprivate/transport.go @@ -9,10 +9,12 @@ import ( "errors" "fmt" "io" + "net" "sync" "time" logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/core/connmgr" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" pionlogger "github.com/pion/logging" @@ -27,6 +29,7 @@ import ( ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" + manet "github.com/multiformats/go-multiaddr/net" ) const ( @@ -48,6 +51,7 @@ type transport struct { host host.Host rcmgr network.ResourceManager webrtcConfig webrtc.Configuration + gater connmgr.ConnectionGater maxInFlightConnections int mu sync.Mutex @@ -175,17 +179,45 @@ func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network. s.Reset() return nil, fmt.Errorf("error establishing webrtc.PeerConnection: %w", err) } - return libp2pwebrtc.NewWebRTCConnection( + + cp, err := getSelectedCandidate(pc) + if cp == nil || err != nil { + s.Reset() + return nil, fmt.Errorf("failed to get selected candidate address, got: %s: %w", cp, err) + } + localAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Local.Address), Port: int(cp.Local.Port)}) + if err != nil { + return nil, fmt.Errorf("failed to infer local address from candidate %s: %w", cp, err) + } + localAddr = localAddr.Encapsulate(WebRTCAddr) + + remoteAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Remote.Address), Port: int(cp.Remote.Port)}) + if err != nil { + return nil, fmt.Errorf("failed to infer remote address from candidate %s: %w", cp, err) + } + remoteAddr = remoteAddr.Encapsulate(WebRTCAddr) + + conn, err := libp2pwebrtc.NewWebRTCConnection( network.DirOutbound, pc, t, scope, t.host.ID(), - ma.StringCast("/webrtc"), + localAddr, p, - t.host.Network().Peerstore().PubKey(p), - ma.StringCast("/webrtc"), + t.host.Network().Peerstore().PubKey(p), // we have the pubkey from the relayed connection + remoteAddr, ) + if err != nil { + pc.Close() + return nil, fmt.Errorf("failed to create transport.CapableConn: %w", err) + } + + if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, conn) { + conn.Close() + return nil, fmt.Errorf("conn gater refused connection to addr: %s", conn.RemoteMultiaddr()) + } + return conn, nil } func (t *transport) establishPeerConnection(ctx context.Context, s network.Stream) (*webrtc.PeerConnection, error) { @@ -427,3 +459,16 @@ func getRelayAddr(addr ma.Multiaddr) ma.Multiaddr { } return first.Encapsulate(rest) } + +func getSelectedCandidate(pc *webrtc.PeerConnection) (*webrtc.ICECandidatePair, error) { + if pc.SCTP() == nil { + return nil, errors.New("no sctp transport") + } + if pc.SCTP().Transport() == nil { + return nil, errors.New("no dtls transport") + } + if pc.SCTP().Transport().ICETransport() == nil { + return nil, errors.New("no ice transport") + } + return pc.SCTP().Transport().ICETransport().GetSelectedCandidatePair() +} From 5ab60e27c0f47725d6032b8bb26d76d935e15113 Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 22 Sep 2023 13:42:09 +0530 Subject: [PATCH 04/22] webrtcprivate: factor establishConnection out in listener --- p2p/transport/webrtcprivate/listener.go | 100 +++++++++++------------- 1 file changed, 46 insertions(+), 54 deletions(-) diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go index a9217f624f..33c207aebd 100644 --- a/p2p/transport/webrtcprivate/listener.go +++ b/p2p/transport/webrtcprivate/listener.go @@ -3,6 +3,7 @@ package libp2pwebrtcprivate import ( "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -82,19 +83,45 @@ func (l *listener) handleIncoming(s network.Stream) { s.SetDeadline(time.Now().Add(connectTimeout)) - scope, err := l.transport.rcmgr.OpenConnection(network.DirInbound, false, ma.StringCast("/webrtc")) + scope, err := l.transport.rcmgr.OpenConnection(network.DirInbound, false, ma.StringCast("/webrtc")) // we don't have a better remote adress right now if err != nil { s.Reset() log.Debug("failed to create connection scope:", err) return } - pc, err := l.transport.NewPeerConnection() + pc, err := l.establishPeerConnection(ctx, s) + if err != nil { + s.Reset() + log.Debug("failed to establish connection with %s: %s", s.Conn().RemotePeer(), err) + return + } + + conn, err := l.setupConnection(pc, scope, s.Conn().RemotePeer()) if err != nil { + pc.Close() s.Reset() - log.Debug("error creating a webrtc.PeerConnection:", err) + log.Debug("connection setup with %s failed: %w", s.Conn().RemotePeer(), err) return } + // Close the stream before we wait for the connection to be accepted + s.Close() + select { + case l.connQueue <- conn: + case <-l.closeC: + conn.Close() + log.Debug("listener closed: dropping conn from %s", s.Conn().RemotePeer()) + } + return +} + +func (l *listener) establishPeerConnection(ctx context.Context, s network.Stream) (*webrtc.PeerConnection, error) { + pc, err := l.transport.NewPeerConnection() + if err != nil { + err = fmt.Errorf("error creating a webrtc.PeerConnection: %w", err) + log.Debug(err) + return nil, err + } r := pbio.NewDelimitedReader(s, maxMsgSize) w := pbio.NewDelimitedWriter(s) @@ -147,36 +174,30 @@ func (l *listener) handleIncoming(s network.Stream) { // read an incoming offer var msg pb.Message if err := r.ReadMsg(&msg); err != nil { - s.Reset() - log.Debug("failed to read offer", err) - return + err = fmt.Errorf("failed to read offer: %w", err) + return nil, err } if msg.Type == nil || *msg.Type != pb.Message_SDP_OFFER { - s.Reset() - log.Debugf("invalid message: msg.Type expected %s got %s", pb.Message_SDP_OFFER, msg.Type) - return + err = fmt.Errorf("invalid message: msg.Type expected %s got %s", pb.Message_SDP_OFFER, msg.Type) + return nil, err } if msg.Data == nil || *msg.Data == "" { - s.Reset() - log.Debugf("invalid message: empty data") - return + err = errors.New("invalid message: empty data") + return nil, err } offer := webrtc.SessionDescription{ Type: webrtc.SDPTypeOffer, SDP: *msg.Data, } if err := pc.SetRemoteDescription(offer); err != nil { - s.Reset() - log.Debug("failed to set remote description: %v", err) - return + err = fmt.Errorf("failed to set remote description: %w", err) + return nil, err } // send an answer answer, err := pc.CreateAnswer(nil) if err != nil { - s.Reset() - log.Debug("failed to create answer: %v", err) - return + return nil, fmt.Errorf("failed to create answer: %w", err) } answerMessage := &pb.Message{ @@ -184,15 +205,11 @@ func (l *listener) handleIncoming(s network.Stream) { Data: &answer.SDP, } if err := w.WriteMsg(answerMessage); err != nil { - s.Reset() - log.Debug("failed to write answer:", err) - return + return nil, fmt.Errorf("failed to write answer: %w", err) } if err := pc.SetLocalDescription(answer); err != nil { - s.Reset() - log.Debug("failed to set local description:", err) - return + return nil, fmt.Errorf("failed to set local description: %w", err) } readErr := make(chan error, 1) @@ -203,7 +220,6 @@ func (l *listener) handleIncoming(s network.Stream) { return } - var msg pb.Message err := r.ReadMsg(&msg) if err == io.EOF { return @@ -240,44 +256,20 @@ func (l *listener) handleIncoming(s network.Stream) { select { case <-ctx.Done(): pc.Close() - s.Reset() - log.Debug(ctx.Err()) - return + return nil, ctx.Err() case err := <-writeErr: pc.Close() - s.Reset() - log.Debug(err) - return + return nil, fmt.Errorf("error writing candidate: %w", err) case err := <-readErr: pc.Close() - s.Reset() - log.Debug(err) - return + return nil, fmt.Errorf("error reading candidate: %w", err) case state := <-connectionState: switch state { default: pc.Close() - s.Reset() - log.Debugf("connection setup failed, got state: %s", state) - return + return nil, fmt.Errorf("failed to establish webrtc.PeerConnection, state: %s", state) case webrtc.PeerConnectionStateConnected: - conn, err := l.setupConnection(pc, scope, s.Conn().RemotePeer()) - if err != nil { - pc.Close() - s.Reset() - log.Debug("connection setup with %s failed: %w", s.Conn().RemotePeer(), err) - return - } - // Close the stream before we wait for the connection to be accepted - s.Close() - select { - case l.connQueue <- conn: - case <-l.closeC: - s.Reset() - conn.Close() - log.Debug("listener closed: dropping conn from %s", s.Conn().RemotePeer()) - } - return + return pc, nil } } } From 10624d12d329811cb903a8f597ce0497d92f7e7e Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 22 Sep 2023 20:16:27 +0530 Subject: [PATCH 05/22] webrtcprivate: setup addresses on connection properly --- p2p/transport/webrtcprivate/listener.go | 17 +------ p2p/transport/webrtcprivate/transport.go | 44 +++++++++++-------- p2p/transport/webrtcprivate/transport_test.go | 33 ++++++++++++++ 3 files changed, 60 insertions(+), 34 deletions(-) diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go index 33c207aebd..e074e56b23 100644 --- a/p2p/transport/webrtcprivate/listener.go +++ b/p2p/transport/webrtcprivate/listener.go @@ -16,7 +16,6 @@ import ( "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate/pb" "github.com/libp2p/go-msgio/pbio" ma "github.com/multiformats/go-multiaddr" - manet "github.com/multiformats/go-multiaddr/net" "github.com/pion/webrtc/v3" ) @@ -112,7 +111,6 @@ func (l *listener) handleIncoming(s network.Stream) { conn.Close() log.Debug("listener closed: dropping conn from %s", s.Conn().RemotePeer()) } - return } func (l *listener) establishPeerConnection(ctx context.Context, s network.Stream) (*webrtc.PeerConnection, error) { @@ -275,21 +273,10 @@ func (l *listener) establishPeerConnection(ctx context.Context, s network.Stream } func (l *listener) setupConnection(pc *webrtc.PeerConnection, scope network.ConnManagementScope, p peer.ID) (tpt.CapableConn, error) { - cp, err := getSelectedCandidate(pc) - if cp == nil || err != nil { - return nil, fmt.Errorf("failed to get selected candidate address, got: %s: %w", cp, err) - } - localAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Local.Address), Port: int(cp.Local.Port)}) - if err != nil { - return nil, fmt.Errorf("failed to infer local address from candidate %s: %w", cp, err) - } - localAddr = localAddr.Encapsulate(WebRTCAddr) - - remoteAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Remote.Address), Port: int(cp.Remote.Port)}) + localAddr, remoteAddr, err := getConnectionAddresses(pc) if err != nil { - return nil, fmt.Errorf("failed to infer remote address from candidate %s: %w", cp, err) + return nil, fmt.Errorf("failed to get connection addresses: %w", err) } - remoteAddr = remoteAddr.Encapsulate(WebRTCAddr) conn, err := libp2pwebrtc.NewWebRTCConnection( network.DirInbound, diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go index 6dbdc3f5fd..cb18188192 100644 --- a/p2p/transport/webrtcprivate/transport.go +++ b/p2p/transport/webrtcprivate/transport.go @@ -180,22 +180,11 @@ func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network. return nil, fmt.Errorf("error establishing webrtc.PeerConnection: %w", err) } - cp, err := getSelectedCandidate(pc) - if cp == nil || err != nil { - s.Reset() - return nil, fmt.Errorf("failed to get selected candidate address, got: %s: %w", cp, err) - } - localAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Local.Address), Port: int(cp.Local.Port)}) - if err != nil { - return nil, fmt.Errorf("failed to infer local address from candidate %s: %w", cp, err) - } - localAddr = localAddr.Encapsulate(WebRTCAddr) - - remoteAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Remote.Address), Port: int(cp.Remote.Port)}) + localAddr, remoteAddr, err := getConnectionAddresses(pc) if err != nil { - return nil, fmt.Errorf("failed to infer remote address from candidate %s: %w", cp, err) + s.Reset() + return nil, fmt.Errorf("failed to get connection addresses: %w", err) } - remoteAddr = remoteAddr.Encapsulate(WebRTCAddr) conn, err := libp2pwebrtc.NewWebRTCConnection( network.DirOutbound, @@ -460,15 +449,32 @@ func getRelayAddr(addr ma.Multiaddr) ma.Multiaddr { return first.Encapsulate(rest) } -func getSelectedCandidate(pc *webrtc.PeerConnection) (*webrtc.ICECandidatePair, error) { +func getConnectionAddresses(pc *webrtc.PeerConnection) (ma.Multiaddr, ma.Multiaddr, error) { if pc.SCTP() == nil { - return nil, errors.New("no sctp transport") + return nil, nil, errors.New("no sctp transport") } if pc.SCTP().Transport() == nil { - return nil, errors.New("no dtls transport") + return nil, nil, errors.New("no dtls transport") } if pc.SCTP().Transport().ICETransport() == nil { - return nil, errors.New("no ice transport") + return nil, nil, errors.New("no ice transport") + } + cp, err := pc.SCTP().Transport().ICETransport().GetSelectedCandidatePair() + if cp == nil || err != nil { + return nil, nil, fmt.Errorf("invalid candidate pair %s: %w", cp, err) + } + + localAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Local.Address), Port: int(cp.Local.Port)}) + if err != nil { + return nil, nil, fmt.Errorf("failed to infer local address from candidate %s: %w", cp, err) + } + localAddr = localAddr.Encapsulate(WebRTCAddr) + + remoteAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Remote.Address), Port: int(cp.Remote.Port)}) + if err != nil { + return nil, nil, fmt.Errorf("failed to infer remote address from candidate %s: %w", cp, err) } - return pc.SCTP().Transport().ICETransport().GetSelectedCandidatePair() + remoteAddr = remoteAddr.Encapsulate(WebRTCAddr) + + return localAddr, remoteAddr, nil } diff --git a/p2p/transport/webrtcprivate/transport_test.go b/p2p/transport/webrtcprivate/transport_test.go index efdf491f0e..fdf982561b 100644 --- a/p2p/transport/webrtcprivate/transport_test.go +++ b/p2p/transport/webrtcprivate/transport_test.go @@ -109,6 +109,39 @@ func TestSingleDial(t *testing.T) { cb.Close() } +func TestConnectionAddresses(t *testing.T) { + a := newWebRTCHost(t) + b := newRelayedHost(t) + defer b.Close() + defer a.Close() + + l, err := b.T.Listen(ma.StringCast("/webrtc")) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + ca, err := a.T.Dial(ctx, b.Addr, b.ID()) + require.NoError(t, err) + + cb, err := l.Accept() + require.NoError(t, err) + + // Test connection addresses + require.Equal(t, cb.RemoteMultiaddr(), ca.LocalMultiaddr()) + require.Equal(t, cb.LocalMultiaddr(), ca.RemoteMultiaddr()) + + testAddr := func(addr ma.Multiaddr) { + _, err := addr.ValueForProtocol(ma.P_UDP) + require.NoError(t, err) + _, err = addr.ValueForProtocol(ma.P_WEBRTC) + require.NoError(t, err) + } + testAddr(ca.LocalMultiaddr()) + testAddr(ca.RemoteMultiaddr()) + testAddr(cb.LocalMultiaddr()) + testAddr(cb.RemoteMultiaddr()) +} + func TestMultipleDials(t *testing.T) { a := newWebRTCHost(t) var wg sync.WaitGroup From 0b20ffe52b36b5894f2be5a3439a7d81d05a3a8d Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 22 Sep 2023 22:58:16 +0530 Subject: [PATCH 06/22] webrtcprivate: set negotiated=true for the init data channel --- p2p/transport/webrtcprivate/transport.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go index cb18188192..ce9cbea26a 100644 --- a/p2p/transport/webrtcprivate/transport.go +++ b/p2p/transport/webrtcprivate/transport.go @@ -265,12 +265,16 @@ func (t *transport) establishPeerConnection(ctx context.Context, s network.Strea // We initialise a data channel otherwise the offer will have no ICE components // https://stackoverflow.com/a/38872920/759687 - dc, err := pc.CreateDataChannel("init", nil) + // We use out-of-band negotiation(negotiated=true), to ensure that this channel doesn't + // get accepted as a stream in AcceptStream on the remote side + negotiated := true + // Any value here is fine since this will be closed on connection establishment. We use 0 since + // it is in line with the handshake channel used in /webrtc-direct stream + var initStreamID uint16 + dc, err := pc.CreateDataChannel("init", &webrtc.DataChannelInit{Negotiated: &negotiated, ID: &initStreamID}) if err != nil { return nil, fmt.Errorf("failed to create data channel: %w", err) } - // Ensure that we close this data channel so that when the remote - // side does AcceptStream this data channel is not used for the new stream. defer dc.Close() // create an offer From 072b38a8967bb49f6faae13e31e83caf5d770344 Mon Sep 17 00:00:00 2001 From: sukun Date: Sat, 23 Sep 2023 17:57:35 +0530 Subject: [PATCH 07/22] webrtcprivate: don't test for address equality When running with docker, the listener might see a different remote port than the port the dialer observes locally and vice versa --- p2p/transport/webrtcprivate/listener.go | 30 ++++----- p2p/transport/webrtcprivate/transport.go | 63 ++++++++++--------- p2p/transport/webrtcprivate/transport_test.go | 4 -- 3 files changed, 43 insertions(+), 54 deletions(-) diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go index e074e56b23..086fd59803 100644 --- a/p2p/transport/webrtcprivate/listener.go +++ b/p2p/transport/webrtcprivate/listener.go @@ -10,7 +10,6 @@ import ( "time" "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/peer" tpt "github.com/libp2p/go-libp2p/core/transport" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate/pb" @@ -89,19 +88,17 @@ func (l *listener) handleIncoming(s network.Stream) { return } - pc, err := l.establishPeerConnection(ctx, s) + conn, err := l.setupConnection(ctx, s, scope) if err != nil { s.Reset() + scope.Done() log.Debug("failed to establish connection with %s: %s", s.Conn().RemotePeer(), err) return } - conn, err := l.setupConnection(pc, scope, s.Conn().RemotePeer()) - if err != nil { - pc.Close() - s.Reset() - log.Debug("connection setup with %s failed: %w", s.Conn().RemotePeer(), err) - return + if l.transport.gater != nil && l.transport.gater.InterceptSecured(network.DirOutbound, s.Conn().RemotePeer(), conn) { + conn.Close() + log.Debugf("conn gater refused connection to addr: %s", conn.RemoteMultiaddr()) } // Close the stream before we wait for the connection to be accepted s.Close() @@ -113,7 +110,7 @@ func (l *listener) handleIncoming(s network.Stream) { } } -func (l *listener) establishPeerConnection(ctx context.Context, s network.Stream) (*webrtc.PeerConnection, error) { +func (l *listener) setupConnection(ctx context.Context, s network.Stream, scope network.ConnManagementScope) (tpt.CapableConn, error) { pc, err := l.transport.NewPeerConnection() if err != nil { err = fmt.Errorf("error creating a webrtc.PeerConnection: %w", err) @@ -267,14 +264,12 @@ func (l *listener) establishPeerConnection(ctx context.Context, s network.Stream pc.Close() return nil, fmt.Errorf("failed to establish webrtc.PeerConnection, state: %s", state) case webrtc.PeerConnectionStateConnected: - return pc, nil } } -} -func (l *listener) setupConnection(pc *webrtc.PeerConnection, scope network.ConnManagementScope, p peer.ID) (tpt.CapableConn, error) { localAddr, remoteAddr, err := getConnectionAddresses(pc) if err != nil { + pc.Close() return nil, fmt.Errorf("failed to get connection addresses: %w", err) } @@ -285,16 +280,13 @@ func (l *listener) setupConnection(pc *webrtc.PeerConnection, scope network.Conn scope, l.transport.host.ID(), localAddr, - p, - l.transport.host.Peerstore().PubKey(p), // we have the public key from the relayed connection + s.Conn().RemotePeer(), + l.transport.host.Peerstore().PubKey(s.Conn().RemotePeer()), // we have the public key from the relayed connection remoteAddr, ) if err != nil { - return nil, fmt.Errorf("failed to create tranport.CapableConn: %w", err) - } - if l.transport.gater != nil && l.transport.gater.InterceptSecured(network.DirOutbound, p, conn) { - conn.Close() - return nil, fmt.Errorf("conn gater refused connection to addr: %s", conn.RemoteMultiaddr()) + pc.Close() + return nil, fmt.Errorf("error establishing tpt.CapableConn: %w", err) } return conn, nil } diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go index ce9cbea26a..f670ec1f05 100644 --- a/p2p/transport/webrtcprivate/transport.go +++ b/p2p/transport/webrtcprivate/transport.go @@ -40,6 +40,7 @@ const ( disconnectedTimeout = 20 * time.Second failedTimeout = 30 * time.Second keepaliveTimeout = 15 * time.Second + maxAcceptQueueLen = 10 ) var ( @@ -174,48 +175,28 @@ func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network. } s.SetDeadline(deadline) - pc, err := t.establishPeerConnection(ctx, s) + conn, err := t.setupConnection(ctx, s, scope) if err != nil { s.Reset() return nil, fmt.Errorf("error establishing webrtc.PeerConnection: %w", err) } - localAddr, remoteAddr, err := getConnectionAddresses(pc) - if err != nil { - s.Reset() - return nil, fmt.Errorf("failed to get connection addresses: %w", err) - } - - conn, err := libp2pwebrtc.NewWebRTCConnection( - network.DirOutbound, - pc, - t, - scope, - t.host.ID(), - localAddr, - p, - t.host.Network().Peerstore().PubKey(p), // we have the pubkey from the relayed connection - remoteAddr, - ) - if err != nil { - pc.Close() - return nil, fmt.Errorf("failed to create transport.CapableConn: %w", err) - } - if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, conn) { conn.Close() + s.Reset() return nil, fmt.Errorf("conn gater refused connection to addr: %s", conn.RemoteMultiaddr()) } return conn, nil } -func (t *transport) establishPeerConnection(ctx context.Context, s network.Stream) (*webrtc.PeerConnection, error) { +func (t *transport) setupConnection(ctx context.Context, s network.Stream, scope network.ConnManagementScope) (tpt.CapableConn, error) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + w := pbio.NewDelimitedWriter(s) + pc, err := t.NewPeerConnection() if err != nil { return nil, fmt.Errorf("failed to create webrtc.PeerConnection: %w", err) } - r := pbio.NewDelimitedReader(s, maxMsgSize) - w := pbio.NewDelimitedWriter(s) // register peerconnection state update callback connectionState := make(chan webrtc.PeerConnectionState, 1) @@ -268,8 +249,8 @@ func (t *transport) establishPeerConnection(ctx context.Context, s network.Strea // We use out-of-band negotiation(negotiated=true), to ensure that this channel doesn't // get accepted as a stream in AcceptStream on the remote side negotiated := true - // Any value here is fine since this will be closed on connection establishment. We use 0 since - // it is in line with the handshake channel used in /webrtc-direct stream + // Any value here is fine since this will be closed on connection establishment. We use 0 as + // it is also used for the /webrtc-direct handshake channel var initStreamID uint16 dc, err := pc.CreateDataChannel("init", &webrtc.DataChannelInit{Negotiated: &negotiated, ID: &initStreamID}) if err != nil { @@ -286,7 +267,6 @@ func (t *transport) establishPeerConnection(ctx context.Context, s network.Strea Type: pb.Message_SDP_OFFER.Enum(), Data: &offer.SDP, } - // send offer to peer if err := w.WriteMsg(offerMessage); err != nil { return nil, fmt.Errorf("failed to write to stream: %w", err) @@ -372,9 +352,30 @@ func (t *transport) establishPeerConnection(ctx context.Context, s network.Strea pc.Close() return nil, fmt.Errorf("conn establishment failed, state: %s", state) case webrtc.PeerConnectionStateConnected: - return pc, nil } } + localAddr, remoteAddr, err := getConnectionAddresses(pc) + if err != nil { + pc.Close() + return nil, fmt.Errorf("failed to get connection addresses: %w", err) + } + + conn, err := libp2pwebrtc.NewWebRTCConnection( + network.DirOutbound, + pc, + t, + scope, + t.host.ID(), + localAddr, + s.Conn().RemotePeer(), + t.host.Network().Peerstore().PubKey(s.Conn().RemotePeer()), // we have the pubkey from the relayed connection + remoteAddr, + ) + if err != nil { + pc.Close() + return nil, fmt.Errorf("failed to create tpt.CapableConn: %w", err) + } + return conn, nil } // Listen implements transport.Transport. @@ -445,7 +446,7 @@ func getRelayAddr(addr ma.Multiaddr) ma.Multiaddr { first, rest := ma.SplitFunc(addr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_WEBRTC }) - // removes /webrtc prefix + // remove /webrtc prefix _, rest = ma.SplitFirst(rest) if rest == nil { return first diff --git a/p2p/transport/webrtcprivate/transport_test.go b/p2p/transport/webrtcprivate/transport_test.go index fdf982561b..a175cc3cc1 100644 --- a/p2p/transport/webrtcprivate/transport_test.go +++ b/p2p/transport/webrtcprivate/transport_test.go @@ -126,10 +126,6 @@ func TestConnectionAddresses(t *testing.T) { cb, err := l.Accept() require.NoError(t, err) - // Test connection addresses - require.Equal(t, cb.RemoteMultiaddr(), ca.LocalMultiaddr()) - require.Equal(t, cb.LocalMultiaddr(), ca.RemoteMultiaddr()) - testAddr := func(addr ma.Multiaddr) { _, err := addr.ValueForProtocol(ma.P_UDP) require.NoError(t, err) From 8bf790f336fa72481f73e43df326181e648a4138 Mon Sep 17 00:00:00 2001 From: sukun Date: Sun, 24 Sep 2023 00:54:05 +0530 Subject: [PATCH 08/22] webrtcprivate: setup incoming data channel handlers early --- p2p/transport/webrtc/connection.go | 83 ++++++++++++++---------- p2p/transport/webrtc/listener.go | 2 + p2p/transport/webrtc/transport.go | 2 + p2p/transport/webrtcprivate/listener.go | 3 + p2p/transport/webrtcprivate/transport.go | 3 + 5 files changed, 59 insertions(+), 34 deletions(-) diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index 9ef2f03124..a52d184aab 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -37,7 +37,7 @@ func (errConnectionTimeout) Error() string { return "connection timeout" } func (errConnectionTimeout) Timeout() bool { return true } func (errConnectionTimeout) Temporary() bool { return false } -type dataChannel struct { +type DetachedDataChannel struct { stream datachannel.ReadWriteCloser channel *webrtc.DataChannel } @@ -60,10 +60,9 @@ type connection struct { streams map[uint16]*stream nextStreamID atomic.Int32 - acceptQueue chan dataChannel - - ctx context.Context - cancel context.CancelFunc + acceptQueue chan DetachedDataChannel + ctx context.Context + cancel context.CancelFunc } // NewWebRTCConnection creates a transport.CapableConn from a webrtc.PeerConnection @@ -79,6 +78,7 @@ func NewWebRTCConnection( remotePeer peer.ID, remoteKey ic.PubKey, remoteMultiaddr ma.Multiaddr, + datachannelQueue chan DetachedDataChannel, ) (*connection, error) { ctx, cancel := context.WithCancel(context.Background()) c := &connection{ @@ -95,8 +95,7 @@ func NewWebRTCConnection( ctx: ctx, cancel: cancel, streams: make(map[uint16]*stream), - - acceptQueue: make(chan dataChannel, maxAcceptQueueLen), + acceptQueue: datachannelQueue, } switch direction { case network.DirInbound: @@ -107,33 +106,14 @@ func NewWebRTCConnection( } pc.OnConnectionStateChange(c.onConnectionStateChange) - pc.OnDataChannel(func(dc *webrtc.DataChannel) { - if c.IsClosed() { - return - } - // Limit the number of streams, since we're not able to actually properly close them. - // See https://github.com/libp2p/specs/issues/575 for details. - if *dc.ID() > maxDataChannelID { - c.Close() - return - } - dc.OnOpen(func() { - rwc, err := dc.Detach() - if err != nil { - log.Warnf("could not detach datachannel: id: %d", *dc.ID()) - return - } - select { - case c.acceptQueue <- dataChannel{rwc, dc}: - default: - log.Warnf("connection busy, rejecting stream") - b, _ := proto.Marshal(&pb.Message{Flag: pb.Message_RESET.Enum()}) - w := msgio.NewWriter(rwc) - w.WriteMsg(b) - rwc.Close() - } - }) - }) + + // Between the connection establishing and the callback update in the above line, the + // connection may have been closed + state := pc.ConnectionState() + if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed { + pc.Close() + return nil, errors.New("connection closed") + } return c, nil } @@ -316,3 +296,38 @@ func (c *connection) setRemotePeer(id peer.ID) { func (c *connection) setRemotePublicKey(key ic.PubKey) { c.remoteKey = key } + +// SetupDataChannelQueue sets callback on the peer connection to push incoming +// data channels on to the returned queue after detaching the data channel. +// +// We need to ensure that the data channel is enqueued from the onOpen callback +// to avoid a race condition in pion: https://github.com/pion/webrtc/issues/2586 +func SetupDataChannelQueue(pc *webrtc.PeerConnection, queueLen int) chan DetachedDataChannel { + queue := make(chan DetachedDataChannel, queueLen) + pc.OnDataChannel(func(dc *webrtc.DataChannel) { + // Limit the number of streams, since we're not able to actually properly close them. + // See https://github.com/libp2p/specs/issues/575 for details. + if *dc.ID() > maxDataChannelID { + dc.Close() + return + } + dc.OnOpen(func() { + rwc, err := dc.Detach() + if err != nil { + log.Warnf("could not detach datachannel: id: %d", *dc.ID()) + return + } + select { + case queue <- DetachedDataChannel{rwc, dc}: + default: + log.Warnf("connection busy, rejecting stream") + b, _ := proto.Marshal(&pb.Message{Flag: pb.Message_RESET.Enum()}) + w := msgio.NewWriter(rwc) + w.WriteMsg(b) + rwc.Close() + } + }) + + }) + return queue +} diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go index 4932f44a83..bf039a03db 100644 --- a/p2p/transport/webrtc/listener.go +++ b/p2p/transport/webrtc/listener.go @@ -229,6 +229,7 @@ func (l *listener) setupConnection( if err != nil { return nil, err } + dataChannelQueue := SetupDataChannelQueue(pc, maxAcceptQueueLen) negotiated, id := handshakeChannelNegotiated, handshakeChannelID rawDatachannel, err := pc.CreateDataChannel("", &webrtc.DataChannelInit{ @@ -285,6 +286,7 @@ func (l *listener) setupConnection( "", // remotePeer nil, // remoteKey remoteMultiaddr, + dataChannelQueue, ) if err != nil { return nil, err diff --git a/p2p/transport/webrtc/transport.go b/p2p/transport/webrtc/transport.go index 9174cde76d..8b2a722087 100644 --- a/p2p/transport/webrtc/transport.go +++ b/p2p/transport/webrtc/transport.go @@ -321,6 +321,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement if err != nil { return nil, fmt.Errorf("instantiate peerconnection: %w", err) } + dataChannelQueue := SetupDataChannelQueue(pc, maxAcceptQueueLen) errC := addOnConnectionStateChangeCallback(pc) // We need to set negotiated = true for this channel on both @@ -402,6 +403,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement p, nil, remoteMultiaddrWithoutCerthash, + dataChannelQueue, ) if err != nil { return nil, err diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go index 086fd59803..3e461ce658 100644 --- a/p2p/transport/webrtcprivate/listener.go +++ b/p2p/transport/webrtcprivate/listener.go @@ -117,6 +117,8 @@ func (l *listener) setupConnection(ctx context.Context, s network.Stream, scope log.Debug(err) return nil, err } + dataChannelQueue := libp2pwebrtc.SetupDataChannelQueue(pc, maxAcceptQueueLen) + r := pbio.NewDelimitedReader(s, maxMsgSize) w := pbio.NewDelimitedWriter(s) @@ -283,6 +285,7 @@ func (l *listener) setupConnection(ctx context.Context, s network.Stream, scope s.Conn().RemotePeer(), l.transport.host.Peerstore().PubKey(s.Conn().RemotePeer()), // we have the public key from the relayed connection remoteAddr, + dataChannelQueue, ) if err != nil { pc.Close() diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go index f670ec1f05..cd1aa664ef 100644 --- a/p2p/transport/webrtcprivate/transport.go +++ b/p2p/transport/webrtcprivate/transport.go @@ -198,6 +198,8 @@ func (t *transport) setupConnection(ctx context.Context, s network.Stream, scope return nil, fmt.Errorf("failed to create webrtc.PeerConnection: %w", err) } + dataChannelQueue := libp2pwebrtc.SetupDataChannelQueue(pc, maxAcceptQueueLen) + // register peerconnection state update callback connectionState := make(chan webrtc.PeerConnectionState, 1) pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { @@ -370,6 +372,7 @@ func (t *transport) setupConnection(ctx context.Context, s network.Stream, scope s.Conn().RemotePeer(), t.host.Network().Peerstore().PubKey(s.Conn().RemotePeer()), // we have the pubkey from the relayed connection remoteAddr, + dataChannelQueue, ) if err != nil { pc.Close() From 4622ab20782b857d667b1bf0e3e63983dcf5916f Mon Sep 17 00:00:00 2001 From: sukun Date: Sun, 24 Sep 2023 10:41:30 +0530 Subject: [PATCH 09/22] webrtcprivate: fix comments --- p2p/transport/webrtcprivate/listener.go | 6 ++---- p2p/transport/webrtcprivate/transport.go | 9 +++------ 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go index 3e461ce658..a08473b126 100644 --- a/p2p/transport/webrtcprivate/listener.go +++ b/p2p/transport/webrtcprivate/listener.go @@ -49,24 +49,22 @@ func (l *listener) Accept() (tpt.CapableConn, error) { } } -// Addr implements transport.Listener. +// Addr implements transport.Listener. The returned address always returns libp2p-webrtc:/webrtc func (l *listener) Addr() net.Addr { return NetAddr{} } -// Close implements transport.Listener. func (l *listener) Close() error { l.transport.RemoveListener(l) close(l.closeC) return nil } -// Multiaddr implements transport.Listener. func (*listener) Multiaddr() ma.Multiaddr { return ma.StringCast("/webrtc") } -func (l *listener) handleIncoming(s network.Stream) { +func (l *listener) handleSignalingStream(s network.Stream) { select { case l.inflightQueue <- struct{}{}: defer func() { <-l.inflightQueue }() diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go index cd1aa664ef..4248af8854 100644 --- a/p2p/transport/webrtcprivate/transport.go +++ b/p2p/transport/webrtcprivate/transport.go @@ -122,7 +122,6 @@ func (t *transport) CanDial(addr ma.Multiaddr) bool { return dialMatcher.Matches(addr) } -// Dial implements transport.Transport. func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { // Connect to the peer on the circuit address relayAddr := getRelayAddr(raddr) @@ -381,7 +380,6 @@ func (t *transport) setupConnection(ctx context.Context, s network.Stream, scope return conn, nil } -// Listen implements transport.Transport. func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { if _, err := laddr.ValueForProtocol(ma.P_WEBRTC); err != nil { return nil, fmt.Errorf("invalid listen multiaddr: %s", laddr) @@ -399,7 +397,7 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { closeC: make(chan struct{}), } t.listener = l - t.host.SetStreamHandler(SignalingProtocol, l.handleIncoming) + t.host.SetStreamHandler(SignalingProtocol, l.handleSignalingStream) return l, nil } @@ -412,12 +410,10 @@ func (t *transport) RemoveListener(l *listener) { } } -// Protocols implements transport.Transport. func (*transport) Protocols() []int { return []int{ma.P_WEBRTC} } -// Proxy implements transport.Transport. func (*transport) Proxy() bool { return false } @@ -457,7 +453,8 @@ func getRelayAddr(addr ma.Multiaddr) ma.Multiaddr { return first.Encapsulate(rest) } -func getConnectionAddresses(pc *webrtc.PeerConnection) (ma.Multiaddr, ma.Multiaddr, error) { +// getConnectionAddresses provides multiaddresses on the two sides of the connection pc +func getConnectionAddresses(pc *webrtc.PeerConnection) (local ma.Multiaddr, remote ma.Multiaddr, err error) { if pc.SCTP() == nil { return nil, nil, errors.New("no sctp transport") } From 144828645c0ef11465da9bb7db02e471b0acd89a Mon Sep 17 00:00:00 2001 From: sukun Date: Sun, 24 Sep 2023 10:49:28 +0530 Subject: [PATCH 10/22] webrtcprivate: add rcmgr reservation on listen side --- p2p/transport/webrtcprivate/listener.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go index a08473b126..1771de9d90 100644 --- a/p2p/transport/webrtcprivate/listener.go +++ b/p2p/transport/webrtcprivate/listener.go @@ -77,8 +77,6 @@ func (l *listener) handleSignalingStream(s network.Stream) { defer cancel() defer s.Close() - s.SetDeadline(time.Now().Add(connectTimeout)) - scope, err := l.transport.rcmgr.OpenConnection(network.DirInbound, false, ma.StringCast("/webrtc")) // we don't have a better remote adress right now if err != nil { s.Reset() @@ -86,6 +84,21 @@ func (l *listener) handleSignalingStream(s network.Stream) { return } + if err := s.Scope().SetService(name); err != nil { + log.Debugf("error attaching stream to /webrtc listener: %s", err) + s.Reset() + return + } + + if err := s.Scope().ReserveMemory(2*maxMsgSize, network.ReservationPriorityAlways); err != nil { + log.Debugf("error reserving memory for /webrtc signaling stream: %s", err) + s.Reset() + return + } + defer s.Scope().ReleaseMemory(maxMsgSize) + + s.SetDeadline(time.Now().Add(connectTimeout)) + conn, err := l.setupConnection(ctx, s, scope) if err != nil { s.Reset() From 6e464e5e27c55b7e29686535b9c8297aca8ff6dd Mon Sep 17 00:00:00 2001 From: sukun Date: Sun, 24 Sep 2023 11:00:22 +0530 Subject: [PATCH 11/22] webrtcprivate: setup conn.ConnState correctly --- p2p/transport/webrtc/connection.go | 19 +++++++++--- p2p/transport/webrtcprivate/transport_test.go | 31 ++++++++++++------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index a52d184aab..4cf853b08d 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -56,6 +56,8 @@ type connection struct { remoteKey ic.PubKey remoteMultiaddr ma.Multiaddr + connectionState network.ConnectionState + m sync.Mutex streams map[uint16]*stream nextStreamID atomic.Int32 @@ -81,6 +83,10 @@ func NewWebRTCConnection( datachannelQueue chan DetachedDataChannel, ) (*connection, error) { ctx, cancel := context.WithCancel(context.Background()) + connectionState := network.ConnectionState{Transport: "webrtc"} + if _, ok := transport.(*WebRTCTransport); ok { + connectionState = network.ConnectionState{Transport: "webrtc-direct"} + } c := &connection{ pc: pc, transport: transport, @@ -92,10 +98,13 @@ func NewWebRTCConnection( remotePeer: remotePeer, remoteKey: remoteKey, remoteMultiaddr: remoteMultiaddr, - ctx: ctx, - cancel: cancel, - streams: make(map[uint16]*stream), - acceptQueue: datachannelQueue, + + connectionState: connectionState, + + ctx: ctx, + cancel: cancel, + streams: make(map[uint16]*stream), + acceptQueue: datachannelQueue, } switch direction { case network.DirInbound: @@ -120,7 +129,7 @@ func NewWebRTCConnection( // ConnState implements transport.CapableConn func (c *connection) ConnState() network.ConnectionState { - return network.ConnectionState{Transport: "webrtc-direct"} + return c.connectionState } // Close closes the underlying peerconnection. diff --git a/p2p/transport/webrtcprivate/transport_test.go b/p2p/transport/webrtcprivate/transport_test.go index a175cc3cc1..74c5f5eadb 100644 --- a/p2p/transport/webrtcprivate/transport_test.go +++ b/p2p/transport/webrtcprivate/transport_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" blankhost "github.com/libp2p/go-libp2p/p2p/host/blank" swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" @@ -109,7 +110,7 @@ func TestSingleDial(t *testing.T) { cb.Close() } -func TestConnectionAddresses(t *testing.T) { +func TestConnectionProperties(t *testing.T) { a := newWebRTCHost(t) b := newRelayedHost(t) defer b.Close() @@ -126,16 +127,24 @@ func TestConnectionAddresses(t *testing.T) { cb, err := l.Accept() require.NoError(t, err) - testAddr := func(addr ma.Multiaddr) { - _, err := addr.ValueForProtocol(ma.P_UDP) - require.NoError(t, err) - _, err = addr.ValueForProtocol(ma.P_WEBRTC) - require.NoError(t, err) - } - testAddr(ca.LocalMultiaddr()) - testAddr(ca.RemoteMultiaddr()) - testAddr(cb.LocalMultiaddr()) - testAddr(cb.RemoteMultiaddr()) + t.Run("Addresses", func(t *testing.T) { + testAddr := func(addr ma.Multiaddr) { + _, err := addr.ValueForProtocol(ma.P_UDP) + require.NoError(t, err) + _, err = addr.ValueForProtocol(ma.P_WEBRTC) + require.NoError(t, err) + } + testAddr(ca.LocalMultiaddr()) + testAddr(ca.RemoteMultiaddr()) + testAddr(cb.LocalMultiaddr()) + testAddr(cb.RemoteMultiaddr()) + }) + + t.Run("ConnectionState", func(t *testing.T) { + require.Equal(t, network.ConnectionState{Transport: "webrtc"}, ca.ConnState()) + require.Equal(t, network.ConnectionState{Transport: "webrtc"}, cb.ConnState()) + }) + } func TestMultipleDials(t *testing.T) { From 8baee51e61e955ac9fb1db144e41bb3d5bc99f2a Mon Sep 17 00:00:00 2001 From: sukun Date: Sun, 24 Sep 2023 11:49:47 +0530 Subject: [PATCH 12/22] webrtcprivate: reuse protobuf msg structs --- p2p/transport/webrtcprivate/listener.go | 19 +++++++------------ p2p/transport/webrtcprivate/transport.go | 11 +++++------ 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go index 1771de9d90..855c67b3a6 100644 --- a/p2p/transport/webrtcprivate/listener.go +++ b/p2p/transport/webrtcprivate/listener.go @@ -148,11 +148,11 @@ func (l *listener) setupConnection(ctx context.Context, s network.Stream, scope // register local ICE Candidate found callback writeErr := make(chan error, 1) - pc.OnICECandidate(func(candiate *webrtc.ICECandidate) { - if candiate == nil { + pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { + if candidate == nil { return } - b, err := json.Marshal(candiate.ToJSON()) + b, err := json.Marshal(candidate.ToJSON()) if err != nil { // We only want to write a single error on this channel select { @@ -193,10 +193,7 @@ func (l *listener) setupConnection(ctx context.Context, s network.Stream, scope err = errors.New("invalid message: empty data") return nil, err } - offer := webrtc.SessionDescription{ - Type: webrtc.SDPTypeOffer, - SDP: *msg.Data, - } + offer := webrtc.SessionDescription{Type: webrtc.SDPTypeOffer, SDP: *msg.Data} if err := pc.SetRemoteDescription(offer); err != nil { err = fmt.Errorf("failed to set remote description: %w", err) return nil, err @@ -207,15 +204,13 @@ func (l *listener) setupConnection(ctx context.Context, s network.Stream, scope if err != nil { return nil, fmt.Errorf("failed to create answer: %w", err) } - - answerMessage := &pb.Message{ + msg = pb.Message{ Type: pb.Message_SDP_ANSWER.Enum(), Data: &answer.SDP, } - if err := w.WriteMsg(answerMessage); err != nil { + if err := w.WriteMsg(&msg); err != nil { return nil, fmt.Errorf("failed to write answer: %w", err) } - if err := pc.SetLocalDescription(answer); err != nil { return nil, fmt.Errorf("failed to set local description: %w", err) } @@ -227,9 +222,9 @@ func (l *listener) setupConnection(ctx context.Context, s network.Stream, scope if ctx.Err() != nil { return } - err := r.ReadMsg(&msg) if err == io.EOF { + // remote has done writing return } if err != nil { diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go index 4248af8854..b5fafe1cc5 100644 --- a/p2p/transport/webrtcprivate/transport.go +++ b/p2p/transport/webrtcprivate/transport.go @@ -214,12 +214,12 @@ func (t *transport) setupConnection(ctx context.Context, s network.Stream, scope // register local ICE Candidate found callback writeErr := make(chan error, 1) - pc.OnICECandidate(func(candiate *webrtc.ICECandidate) { + pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { // The callback can be called with a nil pointer - if candiate == nil { + if candidate == nil { return } - b, err := json.Marshal(candiate.ToJSON()) + b, err := json.Marshal(candidate.ToJSON()) if err != nil { // We only want to write a single error on this channel select { @@ -264,12 +264,12 @@ func (t *transport) setupConnection(ctx context.Context, s network.Stream, scope if err != nil { return nil, fmt.Errorf("failed to create offer: %w", err) } - offerMessage := &pb.Message{ + msg := pb.Message{ Type: pb.Message_SDP_OFFER.Enum(), Data: &offer.SDP, } // send offer to peer - if err := w.WriteMsg(offerMessage); err != nil { + if err := w.WriteMsg(&msg); err != nil { return nil, fmt.Errorf("failed to write to stream: %w", err) } if err := pc.SetLocalDescription(offer); err != nil { @@ -277,7 +277,6 @@ func (t *transport) setupConnection(ctx context.Context, s network.Stream, scope } // read an incoming answer - var msg pb.Message if err := r.ReadMsg(&msg); err != nil { return nil, fmt.Errorf("failed to read from stream: %w", err) } From 255fd3eb0e5caf312b0cb879e4435a7018f2556b Mon Sep 17 00:00:00 2001 From: sukun Date: Sun, 24 Sep 2023 15:24:38 +0530 Subject: [PATCH 13/22] webrtcprivate: use context for closing listener --- p2p/transport/webrtcprivate/listener.go | 14 ++-- p2p/transport/webrtcprivate/transport.go | 24 ++++-- p2p/transport/webrtcprivate/transport_test.go | 73 +++++++++++++++++++ 3 files changed, 100 insertions(+), 11 deletions(-) diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go index 855c67b3a6..09a4855aab 100644 --- a/p2p/transport/webrtcprivate/listener.go +++ b/p2p/transport/webrtcprivate/listener.go @@ -21,8 +21,9 @@ import ( type listener struct { transport *transport connQueue chan tpt.CapableConn - closeC chan struct{} inflightQueue chan struct{} + ctx context.Context + cancel context.CancelFunc } var _ tpt.Listener = &listener{} @@ -41,10 +42,13 @@ func (n NetAddr) String() string { // Accept implements transport.Listener. func (l *listener) Accept() (tpt.CapableConn, error) { + if l.ctx.Err() != nil { + return nil, tpt.ErrListenerClosed + } select { case c := <-l.connQueue: return c, nil - case <-l.closeC: + case <-l.ctx.Done(): return nil, tpt.ErrListenerClosed } } @@ -56,7 +60,7 @@ func (l *listener) Addr() net.Addr { func (l *listener) Close() error { l.transport.RemoveListener(l) - close(l.closeC) + l.cancel() return nil } @@ -68,7 +72,7 @@ func (l *listener) handleSignalingStream(s network.Stream) { select { case l.inflightQueue <- struct{}{}: defer func() { <-l.inflightQueue }() - case <-l.closeC: + case <-l.ctx.Done(): s.Reset() return } @@ -115,7 +119,7 @@ func (l *listener) handleSignalingStream(s network.Stream) { s.Close() select { case l.connQueue <- conn: - case <-l.closeC: + case <-l.ctx.Done(): conn.Close() log.Debug("listener closed: dropping conn from %s", s.Conn().RemotePeer()) } diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go index b5fafe1cc5..c945f20c5e 100644 --- a/p2p/transport/webrtcprivate/transport.go +++ b/p2p/transport/webrtcprivate/transport.go @@ -28,7 +28,6 @@ import ( "go.uber.org/zap/zapcore" ma "github.com/multiformats/go-multiaddr" - mafmt "github.com/multiformats/go-multiaddr-fmt" manet "github.com/multiformats/go-multiaddr/net" ) @@ -115,11 +114,23 @@ func newTransport(h host.Host) (*transport, error) { }, nil } -var dialMatcher = mafmt.And(mafmt.Base(ma.P_CIRCUIT), mafmt.Base(ma.P_WEBRTC)) - // CanDial determines if we can dial to an address func (t *transport) CanDial(addr ma.Multiaddr) bool { - return dialMatcher.Matches(addr) + circuit := false + webrtc := false + ma.ForEach(addr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_CIRCUIT { + circuit = true + return true + } + // next element after p2p-circuit should be webrtc + if circuit { + webrtc = c.Protocol().Code == ma.P_WEBRTC + return false + } + return true + }) + return circuit && webrtc } func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { @@ -388,12 +399,13 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { if t.listener != nil { return nil, errors.New("already listening on /webrtc") } - + ctx, cancel := context.WithCancel(context.Background()) l := &listener{ transport: t, connQueue: make(chan tpt.CapableConn), inflightQueue: make(chan struct{}, t.maxInFlightConnections), - closeC: make(chan struct{}), + ctx: ctx, + cancel: cancel, } t.listener = l t.host.SetStreamHandler(SignalingProtocol, l.handleSignalingStream) diff --git a/p2p/transport/webrtcprivate/transport_test.go b/p2p/transport/webrtcprivate/transport_test.go index 74c5f5eadb..a7e10d603e 100644 --- a/p2p/transport/webrtcprivate/transport_test.go +++ b/p2p/transport/webrtcprivate/transport_test.go @@ -563,3 +563,76 @@ func TestStreamDeadline(t *testing.T) { require.ErrorIs(t, err, os.ErrDeadlineExceeded) }) } + +func TestCanDial(t *testing.T) { + a := newWebRTCHost(t) + defer a.Close() + b := newWebRTCHost(t) + + tests := []struct { + addr ma.Multiaddr + canDial bool + }{ + { + addr: ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1234/p2p/%s/p2p-circuit/", b.ID())), + canDial: false, + }, + { + addr: ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1234/p2p/%s/p2p-circuit/webrtc", b.ID())), + canDial: true, + }, + { + addr: ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic-v1/p2p/%s/p2p-circuit/", b.ID())), + canDial: false, + }, + { + addr: ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic-v1/p2p/%s/p2p-circuit/webrtc/", b.ID())), + canDial: true, + }, + { + addr: ma.StringCast("/ip4/1.2.3.4/tcp/1234/webrtc"), + canDial: false, + }, + { + addr: ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1234/p2p/%s/webrtc/", b.ID())), + canDial: false, + }, + { + addr: ma.StringCast("/ip4/1.2.3.4/tcp/1234/"), + canDial: false, + }, + } + for i, tt := range tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + require.Equal(t, tt.canDial, a.T.CanDial(tt.addr), "args: %s", tt.addr) + }) + } +} + +func TestCanListenTwice(t *testing.T) { + b := newRelayedHost(t) + defer b.Close() + listener, err := b.T.Listen(WebRTCAddr) + require.NoError(t, err) + a := newWebRTCHost(t) + defer a.Close() + + ca, err := a.T.Dial(context.Background(), b.Addr, b.ID()) + require.NoError(t, err) + cb, err := listener.Accept() + require.NoError(t, err) + ca.Close() + cb.Close() + listener.Close() + _, err = listener.Accept() + require.Error(t, err) + + listener, err = b.T.Listen(WebRTCAddr) + require.NoError(t, err) + ca, err = a.T.Dial(context.Background(), b.Addr, b.ID()) + require.NoError(t, err) + cb, err = listener.Accept() + require.NoError(t, err) + ca.Close() + cb.Close() +} From f41c97c4a19587df546961d8c406533cea914578 Mon Sep 17 00:00:00 2001 From: sukun Date: Sun, 24 Sep 2023 15:36:55 +0530 Subject: [PATCH 14/22] webrtcprivate: fix Multiple Dialers test to use listeners for dialing --- p2p/transport/webrtcprivate/transport_test.go | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/p2p/transport/webrtcprivate/transport_test.go b/p2p/transport/webrtcprivate/transport_test.go index a7e10d603e..2ebacc4d5f 100644 --- a/p2p/transport/webrtcprivate/transport_test.go +++ b/p2p/transport/webrtcprivate/transport_test.go @@ -198,24 +198,18 @@ func TestMultipleDials(t *testing.T) { } func TestMultipleDialsAndListeners(t *testing.T) { - var dialHosts []*webrtcHost const N = 5 + var hosts []*relayedHost for i := 0; i < N; i++ { - dialHosts = append(dialHosts, newWebRTCHost(t)) - defer dialHosts[i].Close() - } - - var listenHosts []*relayedHost - for i := 0; i < N; i++ { - listenHosts = append(listenHosts, newRelayedHost(t)) - l, err := listenHosts[i].T.Listen(ma.StringCast("/webrtc")) + hosts = append(hosts, newRelayedHost(t)) + l, err := hosts[i].T.Listen(ma.StringCast("/webrtc")) require.NoError(t, err) - defer listenHosts[i].Close() + defer hosts[i].Close() defer l.Close() } var wg sync.WaitGroup - dialAndPing := func(h *webrtcHost, raddr ma.Multiaddr, p peer.ID) { + dialAndPing := func(h *relayedHost, raddr ma.Multiaddr, p peer.ID) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() ca, err := h.T.Dial(ctx, raddr, p) @@ -263,11 +257,11 @@ func TestMultipleDialsAndListeners(t *testing.T) { } for i := 0; i < N; i++ { - for j := 0; j < N; j++ { + for j := i + 1; j < N; j++ { wg.Add(1) go func(i, j int) { - go dialAndPing(dialHosts[i], listenHosts[j].Addr, listenHosts[j].ID()) - acceptAndPong(listenHosts[j]) + go acceptAndPong(hosts[j]) + dialAndPing(hosts[i], hosts[j].Addr, hosts[j].ID()) wg.Done() }(i, j) } From a2e3df3002b05357d01f35d5bfd672ebb03c2258 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 25 Sep 2023 13:22:09 +0530 Subject: [PATCH 15/22] swarm: integrate webrtc dialing --- core/network/context.go | 7 ++ p2p/net/swarm/dial_ranker.go | 6 +- p2p/net/swarm/swarm_dial.go | 5 -- p2p/net/swarm/swarm_transport.go | 5 +- p2p/test/swarm/swarm_test.go | 79 +++++++++++++++++++ p2p/transport/webrtcprivate/transport.go | 14 ++-- p2p/transport/webrtcprivate/transport_test.go | 4 +- 7 files changed, 106 insertions(+), 14 deletions(-) diff --git a/core/network/context.go b/core/network/context.go index 7fabfb53e0..41c72b99e7 100644 --- a/core/network/context.go +++ b/core/network/context.go @@ -29,6 +29,13 @@ func WithForceDirectDial(ctx context.Context, reason string) context.Context { return context.WithValue(ctx, forceDirectDial, reason) } +// WithoutForceDirectDial constructs a new context with the ForceDirectDial option dropped. +// This is useful in case establishing a direct connection first requires establishing a +// relayed connection e.g. dialing /webrtc addresses. +func WithoutForceDirectDial(ctx context.Context) context.Context { + return context.WithValue(ctx, forceDirectDial, nil) +} + // EXPERIMENTAL // GetForceDirectDial returns true if the force direct dial option is set in the context. func GetForceDirectDial(ctx context.Context) (forceDirect bool, reason string) { diff --git a/p2p/net/swarm/dial_ranker.go b/p2p/net/swarm/dial_ranker.go index 7e58876b91..bd806f8316 100644 --- a/p2p/net/swarm/dial_ranker.go +++ b/p2p/net/swarm/dial_ranker.go @@ -43,7 +43,10 @@ func NoDelayDialRanker(addrs []ma.Multiaddr) []network.AddrDelay { // no additional latency in the vast majority of cases. // // Private and public address groups are dialed in parallel. +// // Dialing relay addresses is delayed by 500 ms, if we have any non-relay alternatives. +// We treat webrtc addresses the same as relay addresses as we need a relay connection to establish a +// webrtc connection. So any available direct addresses are preferred over webrtc addresses. // // Within each group (private, public, relay addresses) we apply the following ranking logic: // @@ -72,7 +75,8 @@ func NoDelayDialRanker(addrs []ma.Multiaddr) []network.AddrDelay { // // We dial lowest ports first as they are more likely to be the listen port. func DefaultDialRanker(addrs []ma.Multiaddr) []network.AddrDelay { - relay, addrs := filterAddrs(addrs, isRelayAddr) + // includes /webrtc addresses too + relay, addrs := filterAddrs(addrs, func(a ma.Multiaddr) bool { return isProtocolAddr(a, ma.P_CIRCUIT) }) pvt, addrs := filterAddrs(addrs, manet.IsPrivateAddr) public, addrs := filterAddrs(addrs, func(a ma.Multiaddr) bool { return isProtocolAddr(a, ma.P_IP4) || isProtocolAddr(a, ma.P_IP6) }) diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index 288ad9cc7d..9543ac0db1 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -597,11 +597,6 @@ func isFdConsumingAddr(addr ma.Multiaddr) bool { return err1 == nil || err2 == nil } -func isRelayAddr(addr ma.Multiaddr) bool { - _, err := addr.ValueForProtocol(ma.P_CIRCUIT) - return err == nil -} - // filterLowPriorityAddresses removes addresses inplace for which we have a better alternative // 1. If a /quic-v1 address is present, filter out /quic and /webtransport address on the same 2-tuple: // QUIC v1 is preferred over the deprecated QUIC draft-29, and given the choice, we prefer using diff --git a/p2p/net/swarm/swarm_transport.go b/p2p/net/swarm/swarm_transport.go index 924f0384aa..e36b7d0925 100644 --- a/p2p/net/swarm/swarm_transport.go +++ b/p2p/net/swarm/swarm_transport.go @@ -27,7 +27,10 @@ func (s *Swarm) TransportForDialing(a ma.Multiaddr) transport.Transport { } return nil } - if isRelayAddr(a) { + if isProtocolAddr(a, ma.P_WEBRTC) { + return s.transports.m[ma.P_WEBRTC] + } + if isProtocolAddr(a, ma.P_CIRCUIT) { return s.transports.m[ma.P_CIRCUIT] } for _, t := range s.transports.m { diff --git a/p2p/test/swarm/swarm_test.go b/p2p/test/swarm/swarm_test.go index 9874431441..30c7d43387 100644 --- a/p2p/test/swarm/swarm_test.go +++ b/p2p/test/swarm/swarm_test.go @@ -14,6 +14,7 @@ import ( rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" + libp2pwebrtcprivate "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -243,3 +244,81 @@ func TestLimitStreamsWhenHangingHandlers(t *testing.T) { return false }, 5*time.Second, 100*time.Millisecond) } + +func TestDialPeerWebRTC(t *testing.T) { + h1, err := libp2p.New( + libp2p.NoListenAddrs, + libp2p.EnableRelay(), + ) + require.NoError(t, err) + + h2, err := libp2p.New( + libp2p.NoListenAddrs, + libp2p.EnableRelay(), + ) + require.NoError(t, err) + + relay1, err := libp2p.New() + require.NoError(t, err) + + _, err = relay.New(relay1) + require.NoError(t, err) + + relay1info := peer.AddrInfo{ + ID: relay1.ID(), + Addrs: relay1.Addrs(), + } + + err = h2.Connect(context.Background(), relay1info) + require.NoError(t, err) + + _, err = client.Reserve(context.Background(), h2, relay1info) + require.NoError(t, err) + + _, err = libp2pwebrtcprivate.AddTransport(h1, nil) + require.NoError(t, err) + _, err = libp2pwebrtcprivate.AddTransport(h2, nil) + require.NoError(t, err) + + webrtcAddr := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/webrtc/p2p/" + h2.ID().String()) + relayAddrs := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/p2p/" + h2.ID().String()) + + h1.Peerstore().AddAddrs(h2.ID(), []ma.Multiaddr{webrtcAddr, relayAddrs}, peerstore.TempAddrTTL) + + // swarm.DialPeer should connect over transient connections + conn1, err := h1.Network().DialPeer(context.Background(), h2.ID()) + require.NoError(t, err) + require.NotNil(t, conn1) + require.Condition(t, func() bool { + _, err1 := conn1.RemoteMultiaddr().ValueForProtocol(ma.P_CIRCUIT) + _, err2 := conn1.RemoteMultiaddr().ValueForProtocol(ma.P_WEBRTC) + return err1 == nil && err2 != nil + }) + + // should connect to webrtc address + ctx := network.WithForceDirectDial(context.Background(), "test") + conn, err := h1.Network().DialPeer(ctx, h2.ID()) + require.NoError(t, err) + require.NotNil(t, conn) + require.Condition(t, func() bool { + _, err1 := conn.RemoteMultiaddr().ValueForProtocol(ma.P_CIRCUIT) + _, err2 := conn.RemoteMultiaddr().ValueForProtocol(ma.P_WEBRTC) + return err1 != nil && err2 == nil + }) + + done := make(chan struct{}) + h2.SetStreamHandler("test-addr", func(s network.Stream) { + s.Conn().LocalMultiaddr() + _, err1 := conn.RemoteMultiaddr().ValueForProtocol(ma.P_CIRCUIT) + assert.Error(t, err1) + _, err2 := conn.RemoteMultiaddr().ValueForProtocol(ma.P_WEBRTC) + assert.NoError(t, err2) + s.Reset() + close(done) + }) + + s, err := h1.NewStream(context.Background(), h2.ID(), "test-addr") + require.NoError(t, err) + s.Write([]byte("test")) + <-done +} diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go index c945f20c5e..67d16909bf 100644 --- a/p2p/transport/webrtcprivate/transport.go +++ b/p2p/transport/webrtcprivate/transport.go @@ -60,13 +60,13 @@ type transport struct { var _ tpt.Transport = &transport{} -func AddTransport(h host.Host) (*transport, error) { +func AddTransport(h host.Host, gater connmgr.ConnectionGater) (*transport, error) { n, ok := h.Network().(tpt.TransportNetwork) if !ok { return nil, fmt.Errorf("%v is not a transport network", h.Network()) } - t, err := newTransport(h) + t, err := newTransport(h, gater) if err != nil { return nil, err } @@ -82,7 +82,7 @@ func AddTransport(h host.Host) (*transport, error) { return t, nil } -func newTransport(h host.Host) (*transport, error) { +func newTransport(h host.Host, gater connmgr.ConnectionGater) (*transport, error) { // We use elliptic P-256 since it is widely supported by browsers. // // Implementation note: Testing with the browser, @@ -111,6 +111,7 @@ func newTransport(h host.Host) (*transport, error) { rcmgr: h.Network().ResourceManager(), webrtcConfig: config, maxInFlightConnections: 16, + gater: gater, }, nil } @@ -136,6 +137,11 @@ func (t *transport) CanDial(addr ma.Multiaddr) bool { func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { // Connect to the peer on the circuit address relayAddr := getRelayAddr(raddr) + // We drop the ForceDirectDial option as we need a relayed connection before we can + // setup a direct connection + ctx = network.WithoutForceDirectDial(ctx) + // We need this for the signaling stream + ctx = network.WithUseTransient(ctx, "webrtcprivate dial") err := t.host.Connect(ctx, peer.AddrInfo{ID: p, Addrs: []ma.Multiaddr{relayAddr}}) if err != nil { return nil, fmt.Errorf("failed to open %s stream: %w", SignalingProtocol, err) @@ -160,8 +166,6 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp } func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { - // Start signaling protocol stream - ctx = network.WithUseTransient(ctx, "webrtcprivate dial") s, err := t.host.NewStream(ctx, p, SignalingProtocol) if err != nil { return nil, fmt.Errorf("error opening stream %s: %w", SignalingProtocol, err) diff --git a/p2p/transport/webrtcprivate/transport_test.go b/p2p/transport/webrtcprivate/transport_test.go index 2ebacc4d5f..5d3e69ae11 100644 --- a/p2p/transport/webrtcprivate/transport_test.go +++ b/p2p/transport/webrtcprivate/transport_test.go @@ -47,7 +47,7 @@ func newWebRTCHost(t *testing.T) *webrtcHost { upg := swarmt.GenUpgrader(t, as, nil) err := client.AddTransport(a, upg) require.NoError(t, err) - ta, err := newTransport(a) + ta, err := newTransport(a, nil) require.NoError(t, err) return &webrtcHost{ Host: a, @@ -68,7 +68,7 @@ func newRelayedHost(t *testing.T) *relayedHost { client.AddTransport(p, upg) _, err = client.Reserve(context.Background(), p, peer.AddrInfo{ID: rh.ID(), Addrs: rh.Addrs()}) require.NoError(t, err) - tp, err := newTransport(p) + tp, err := newTransport(p, nil) require.NoError(t, err) return &relayedHost{ webrtcHost: webrtcHost{ From 212159b75a5e7047c58f827c36e698677c78a0a9 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 25 Sep 2023 15:18:37 +0530 Subject: [PATCH 16/22] libp2p: provide option for enabling webrtcprivate --- config/config.go | 8 ++++++++ options.go | 8 ++++++++ p2p/test/swarm/swarm_test.go | 8 ++------ p2p/transport/webrtcprivate/transport.go | 11 ++++++++--- p2p/transport/webrtcprivate/transport_test.go | 4 ++-- 5 files changed, 28 insertions(+), 11 deletions(-) diff --git a/config/config.go b/config/config.go index 8be5a43999..eaab0ee168 100644 --- a/config/config.go +++ b/config/config.go @@ -34,6 +34,7 @@ import ( relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" + libp2pwebrtcprivate "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate" "github.com/prometheus/client_golang/prometheus" ma "github.com/multiformats/go-multiaddr" @@ -128,6 +129,9 @@ type Config struct { DialRanker network.DialRanker SwarmOpts []swarm.Option + + WebRTCPrivate bool + WebRTCStunServers []string } func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) { @@ -208,6 +212,7 @@ func (cfg *Config) addTransports(h host.Host) error { fx.Provide(func() pnet.PSK { return cfg.PSK }), fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }), fx.Provide(func() *madns.Resolver { return cfg.MultiaddrResolver }), + fx.Provide(func() []string { return cfg.WebRTCStunServers }), } fxopts = append(fxopts, cfg.Transports...) if cfg.Insecure { @@ -284,6 +289,9 @@ func (cfg *Config) addTransports(h host.Host) error { if cfg.Relay { fxopts = append(fxopts, fx.Invoke(circuitv2.AddTransport)) } + if cfg.WebRTCPrivate { + fxopts = append(fxopts, fx.Invoke(libp2pwebrtcprivate.AddTransport)) + } app := fx.New(fxopts...) if err := app.Err(); err != nil { h.Close() diff --git a/options.go b/options.go index 1a1e9d3982..550d359bb2 100644 --- a/options.go +++ b/options.go @@ -598,3 +598,11 @@ func SwarmOpts(opts ...swarm.Option) Option { return nil } } + +func EnableWebRTCPrivate(stunServers []string) Option { + return func(cfg *Config) error { + cfg.WebRTCPrivate = true + cfg.WebRTCStunServers = stunServers + return nil + } +} diff --git a/p2p/test/swarm/swarm_test.go b/p2p/test/swarm/swarm_test.go index 30c7d43387..5fdafb2913 100644 --- a/p2p/test/swarm/swarm_test.go +++ b/p2p/test/swarm/swarm_test.go @@ -14,7 +14,6 @@ import ( rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" - libp2pwebrtcprivate "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -249,12 +248,14 @@ func TestDialPeerWebRTC(t *testing.T) { h1, err := libp2p.New( libp2p.NoListenAddrs, libp2p.EnableRelay(), + libp2p.EnableWebRTCPrivate(nil), ) require.NoError(t, err) h2, err := libp2p.New( libp2p.NoListenAddrs, libp2p.EnableRelay(), + libp2p.EnableWebRTCPrivate(nil), ) require.NoError(t, err) @@ -275,11 +276,6 @@ func TestDialPeerWebRTC(t *testing.T) { _, err = client.Reserve(context.Background(), h2, relay1info) require.NoError(t, err) - _, err = libp2pwebrtcprivate.AddTransport(h1, nil) - require.NoError(t, err) - _, err = libp2pwebrtcprivate.AddTransport(h2, nil) - require.NoError(t, err) - webrtcAddr := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/webrtc/p2p/" + h2.ID().String()) relayAddrs := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/p2p/" + h2.ID().String()) diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go index 67d16909bf..d11fa8b6ab 100644 --- a/p2p/transport/webrtcprivate/transport.go +++ b/p2p/transport/webrtcprivate/transport.go @@ -60,13 +60,13 @@ type transport struct { var _ tpt.Transport = &transport{} -func AddTransport(h host.Host, gater connmgr.ConnectionGater) (*transport, error) { +func AddTransport(h host.Host, gater connmgr.ConnectionGater, stunServers []string) (*transport, error) { n, ok := h.Network().(tpt.TransportNetwork) if !ok { return nil, fmt.Errorf("%v is not a transport network", h.Network()) } - t, err := newTransport(h, gater) + t, err := newTransport(h, gater, stunServers) if err != nil { return nil, err } @@ -82,7 +82,7 @@ func AddTransport(h host.Host, gater connmgr.ConnectionGater) (*transport, error return t, nil } -func newTransport(h host.Host, gater connmgr.ConnectionGater) (*transport, error) { +func newTransport(h host.Host, gater connmgr.ConnectionGater, stunServers []string) (*transport, error) { // We use elliptic P-256 since it is widely supported by browsers. // // Implementation note: Testing with the browser, @@ -102,8 +102,13 @@ func newTransport(h host.Host, gater connmgr.ConnectionGater) (*transport, error if err != nil { return nil, fmt.Errorf("generate certificate: %w", err) } + servers := make([]webrtc.ICEServer, len(stunServers)) + for i := 0; i < len(stunServers); i++ { + servers[i] = webrtc.ICEServer{URLs: []string{stunServers[i]}} + } config := webrtc.Configuration{ Certificates: []webrtc.Certificate{*cert}, + ICEServers: servers, } return &transport{ diff --git a/p2p/transport/webrtcprivate/transport_test.go b/p2p/transport/webrtcprivate/transport_test.go index 5d3e69ae11..f9d889cee2 100644 --- a/p2p/transport/webrtcprivate/transport_test.go +++ b/p2p/transport/webrtcprivate/transport_test.go @@ -47,7 +47,7 @@ func newWebRTCHost(t *testing.T) *webrtcHost { upg := swarmt.GenUpgrader(t, as, nil) err := client.AddTransport(a, upg) require.NoError(t, err) - ta, err := newTransport(a, nil) + ta, err := newTransport(a, nil, nil) require.NoError(t, err) return &webrtcHost{ Host: a, @@ -68,7 +68,7 @@ func newRelayedHost(t *testing.T) *relayedHost { client.AddTransport(p, upg) _, err = client.Reserve(context.Background(), p, peer.AddrInfo{ID: rh.ID(), Addrs: rh.Addrs()}) require.NoError(t, err) - tp, err := newTransport(p, nil) + tp, err := newTransport(p, nil, nil) require.NoError(t, err) return &relayedHost{ webrtcHost: webrtcHost{ From 7e077d39c3fd80886330252f966c0847e51d678b Mon Sep 17 00:00:00 2001 From: sukun Date: Tue, 26 Sep 2023 11:38:32 +0530 Subject: [PATCH 17/22] libp2p: use webrtc.iceserver for configuration --- config/config.go | 7 +++++-- options.go | 2 +- p2p/transport/webrtcprivate/transport.go | 10 +++------- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/config/config.go b/config/config.go index eaab0ee168..69f2936612 100644 --- a/config/config.go +++ b/config/config.go @@ -35,6 +35,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" libp2pwebrtcprivate "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate" + "github.com/pion/webrtc/v3" "github.com/prometheus/client_golang/prometheus" ma "github.com/multiformats/go-multiaddr" @@ -66,6 +67,8 @@ type Security struct { Constructor interface{} } +type ICEServer = webrtc.ICEServer + // Config describes a set of settings for a libp2p node // // This is *not* a stable interface. Use the options defined in the root @@ -131,7 +134,7 @@ type Config struct { SwarmOpts []swarm.Option WebRTCPrivate bool - WebRTCStunServers []string + WebRTCStunServers []ICEServer } func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) { @@ -212,7 +215,7 @@ func (cfg *Config) addTransports(h host.Host) error { fx.Provide(func() pnet.PSK { return cfg.PSK }), fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }), fx.Provide(func() *madns.Resolver { return cfg.MultiaddrResolver }), - fx.Provide(func() []string { return cfg.WebRTCStunServers }), + fx.Provide(func() []ICEServer { return cfg.WebRTCStunServers }), } fxopts = append(fxopts, cfg.Transports...) if cfg.Insecure { diff --git a/options.go b/options.go index 550d359bb2..4f9ac2597b 100644 --- a/options.go +++ b/options.go @@ -599,7 +599,7 @@ func SwarmOpts(opts ...swarm.Option) Option { } } -func EnableWebRTCPrivate(stunServers []string) Option { +func EnableWebRTCPrivate(stunServers []config.ICEServer) Option { return func(cfg *Config) error { cfg.WebRTCPrivate = true cfg.WebRTCStunServers = stunServers diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go index d11fa8b6ab..63d562863d 100644 --- a/p2p/transport/webrtcprivate/transport.go +++ b/p2p/transport/webrtcprivate/transport.go @@ -60,7 +60,7 @@ type transport struct { var _ tpt.Transport = &transport{} -func AddTransport(h host.Host, gater connmgr.ConnectionGater, stunServers []string) (*transport, error) { +func AddTransport(h host.Host, gater connmgr.ConnectionGater, stunServers []webrtc.ICEServer) (*transport, error) { n, ok := h.Network().(tpt.TransportNetwork) if !ok { return nil, fmt.Errorf("%v is not a transport network", h.Network()) @@ -82,7 +82,7 @@ func AddTransport(h host.Host, gater connmgr.ConnectionGater, stunServers []stri return t, nil } -func newTransport(h host.Host, gater connmgr.ConnectionGater, stunServers []string) (*transport, error) { +func newTransport(h host.Host, gater connmgr.ConnectionGater, stunServers []webrtc.ICEServer) (*transport, error) { // We use elliptic P-256 since it is widely supported by browsers. // // Implementation note: Testing with the browser, @@ -102,13 +102,9 @@ func newTransport(h host.Host, gater connmgr.ConnectionGater, stunServers []stri if err != nil { return nil, fmt.Errorf("generate certificate: %w", err) } - servers := make([]webrtc.ICEServer, len(stunServers)) - for i := 0; i < len(stunServers); i++ { - servers[i] = webrtc.ICEServer{URLs: []string{stunServers[i]}} - } config := webrtc.Configuration{ Certificates: []webrtc.Certificate{*cert}, - ICEServers: servers, + ICEServers: stunServers, } return &transport{ From 13081dd46c8498070f1256a6aeb949206382925e Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 2 Oct 2023 21:21:12 +0530 Subject: [PATCH 18/22] holepunch: trigger holepunching for peers with webrtc addresses --- p2p/protocol/holepunch/holepunch_test.go | 55 ++++++++++++++++++++++++ p2p/protocol/holepunch/holepuncher.go | 15 +++++++ p2p/protocol/holepunch/svc.go | 38 ++++++++++++++++ 3 files changed, 108 insertions(+) diff --git a/p2p/protocol/holepunch/holepunch_test.go b/p2p/protocol/holepunch/holepunch_test.go index 29d589cd7a..3275dc49cc 100644 --- a/p2p/protocol/holepunch/holepunch_test.go +++ b/p2p/protocol/holepunch/holepunch_test.go @@ -13,6 +13,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/proto" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" @@ -511,3 +512,57 @@ func mkHostWithHolePunchSvc(t *testing.T, opts ...holepunch.Option) (host.Host, require.NoError(t, err) return h, hps } + +func TestWebRTCDirectConnect(t *testing.T) { + relay1, err := libp2p.New() + require.NoError(t, err) + + _, err = relayv2.New(relay1) + require.NoError(t, err) + + relay1info := peer.AddrInfo{ + ID: relay1.ID(), + Addrs: relay1.Addrs(), + } + + h1, err := libp2p.New( + libp2p.NoListenAddrs, + libp2p.EnableRelay(), + libp2p.EnableWebRTCPrivate(nil), + libp2p.EnableHolePunching(), + ) + require.NoError(t, err) + + h2, err := libp2p.New( + libp2p.NoListenAddrs, + libp2p.EnableRelay(), + libp2p.EnableWebRTCPrivate(nil), + ) + require.NoError(t, err) + + err = h2.Connect(context.Background(), relay1info) + require.NoError(t, err) + + _, err = client.Reserve(context.Background(), h2, relay1info) + require.NoError(t, err) + + webrtcAddr := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/webrtc") + relayAddrs := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/") + h1.Peerstore().AddAddrs(h2.ID(), []ma.Multiaddr{webrtcAddr, relayAddrs}, peerstore.TempAddrTTL) + + err = h1.Connect(context.Background(), peer.AddrInfo{ID: h2.ID()}) + require.NoError(t, err) + require.Eventually( + t, + func() bool { + for _, c := range h1.Network().ConnsToPeer(h2.ID()) { + if !c.Stat().Transient { + return true + } + } + return false + }, + 5*time.Second, + 100*time.Millisecond, + ) +} diff --git a/p2p/protocol/holepunch/holepuncher.go b/p2p/protocol/holepunch/holepuncher.go index b651bd7822..18b950f9e3 100644 --- a/p2p/protocol/holepunch/holepuncher.go +++ b/p2p/protocol/holepunch/holepuncher.go @@ -108,6 +108,9 @@ func (hp *holePuncher) directConnect(rp peer.ID) error { // short-circuit hole punching if a direct dial works. // attempt a direct connection ONLY if we have a public address for the remote peer for _, a := range hp.host.Peerstore().Addrs(rp) { + // Here we consider /webrtc addresses as relay addresses and skip them as they're + // also holepunched. We will dial the /webrtc addresses along with other addresses + // obtained in DCUtR if manet.IsPublicAddr(a) && !isRelayAddress(a) { forceDirectConnCtx := network.WithForceDirectDial(hp.ctx, "hole-punching") dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout) @@ -136,6 +139,7 @@ func (hp *holePuncher) directConnect(rp peer.ID) error { if err != nil { log.Debugw("hole punching failed", "peer", rp, "error", err) hp.tracer.ProtocolError(rp, err) + hp.maybeDialWebRTC(rp) return err } synTime := rtt / 2 @@ -171,6 +175,17 @@ func (hp *holePuncher) directConnect(rp peer.ID) error { return fmt.Errorf("all retries for hole punch with peer %s failed", rp) } +func (hp *holePuncher) maybeDialWebRTC(p peer.ID) { + addrs := hp.host.Peerstore().Addrs(p) + for _, a := range addrs { + if _, err := a.ValueForProtocol(ma.P_WEBRTC); err == nil { + ctx := network.WithForceDirectDial(hp.ctx, "webrtc holepunch") + hp.host.Connect(ctx, peer.AddrInfo{ID: p}) // address is already in peerstore + return + } + } +} + // initiateHolePunch opens a new hole punching coordination stream, // exchanges the addresses and measures the RTT. func (hp *holePuncher) initiateHolePunch(rp peer.ID) ([]ma.Multiaddr, []ma.Multiaddr, time.Duration, error) { diff --git a/p2p/protocol/holepunch/svc.go b/p2p/protocol/holepunch/svc.go index 47bf434fb1..4b19e5f811 100644 --- a/p2p/protocol/holepunch/svc.go +++ b/p2p/protocol/holepunch/svc.go @@ -84,6 +84,8 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, return nil, err } } + s.host.Network().Notify(s) + s.tracer.Start() s.refCount.Add(1) @@ -283,3 +285,39 @@ func (s *Service) DirectConnect(p peer.ID) error { s.holePuncherMx.Unlock() return holePuncher.DirectConnect(p) } + +var _ network.Notifiee = &Service{} + +func (s *Service) Connected(_ network.Network, conn network.Conn) { + // Dial /webrtc address if it's a relay connection to a browser node + if conn.Stat().Direction == network.DirOutbound && conn.Stat().Transient { + s.refCount.Add(1) + go func() { + defer s.refCount.Done() + select { + // waiting for Identify here will allow us to access the peer's public and observed addresses + // that we can dial to for a hole punch. + case <-s.ids.IdentifyWait(conn): + case <-s.ctx.Done(): + return + } + p := conn.RemotePeer() + // Peer supports DCUtR, let it trigger holepunch + if protos, err := s.host.Peerstore().SupportsProtocols(p, Protocol); err == nil && len(protos) > 0 { + return + } + // No DCUtR support, connect with peer over /webrtc + for _, addr := range s.host.Peerstore().Addrs(p) { + if _, err := addr.ValueForProtocol(ma.P_WEBRTC); err == nil { + ctx := network.WithForceDirectDial(s.ctx, "webrtc holepunch") + s.host.Connect(ctx, peer.AddrInfo{ID: p}) // address is already in peerstore + return + } + } + }() + } +} + +func (*Service) Disconnected(_ network.Network, v network.Conn) {} +func (*Service) Listen(n network.Network, a ma.Multiaddr) {} +func (*Service) ListenClose(n network.Network, a ma.Multiaddr) {} From 05f5d3398095e417b84a799404f0b6476b6dca70 Mon Sep 17 00:00:00 2001 From: sukun Date: Tue, 3 Oct 2023 18:07:21 +0530 Subject: [PATCH 19/22] webrtcprivate: integrate transport integration tests Some failing tests are skipped. They're failing because they require a different testing strategy. A webrtcprivate dial also requires a relay dial, causing some of the expectations on mocks to be incorrectly setup --- p2p/test/transport/gating_test.go | 19 +++++++- p2p/test/transport/rcmgr_test.go | 4 +- p2p/test/transport/transport_test.go | 60 ++++++++++++++++++++++++- p2p/transport/webrtc/listener.go | 12 ++--- p2p/transport/webrtcprivate/listener.go | 12 ++++- 5 files changed, 96 insertions(+), 11 deletions(-) diff --git a/p2p/test/transport/gating_test.go b/p2p/test/transport/gating_test.go index df53da6eeb..c914b04d31 100644 --- a/p2p/test/transport/gating_test.go +++ b/p2p/test/transport/gating_test.go @@ -83,6 +83,10 @@ func TestInterceptSecuredOutgoing(t *testing.T) { } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { + if strings.Contains(tc.Name, "WebRTCPrivate") { + t.Skip("webrtcprivate needs different handling because of the relay required for connection setup") + } + ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) @@ -92,7 +96,6 @@ func TestInterceptSecuredOutgoing(t *testing.T) { defer h1.Close() defer h2.Close() require.Len(t, h2.Addrs(), 1) - require.Len(t, h2.Addrs(), 1) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -104,6 +107,7 @@ func TestInterceptSecuredOutgoing(t *testing.T) { require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String()) }), ) + err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}) require.Error(t, err) require.NotErrorIs(t, err, context.DeadlineExceeded) @@ -117,6 +121,9 @@ func TestInterceptUpgradedOutgoing(t *testing.T) { } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { + if strings.Contains(tc.Name, "WebRTCPrivate") { + t.Skip("webrtc private to private needs special handling because the listener makes a dial to the relay") + } ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) @@ -153,6 +160,9 @@ func TestInterceptAccept(t *testing.T) { } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { + if strings.Contains(tc.Name, "WebRTCPrivate") { + t.Skip("webrtc private to private needs special handling because the listener makes a dial to the relay") + } ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) @@ -198,6 +208,10 @@ func TestInterceptSecuredIncoming(t *testing.T) { } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { + if strings.Contains(tc.Name, "WebRTCPrivate") { + t.Skip("webrtc private to private needs special handling because the listener makes a dial to the relay") + } + ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) @@ -231,6 +245,9 @@ func TestInterceptUpgradedIncoming(t *testing.T) { } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { + if strings.Contains(tc.Name, "WebRTCPrivate") { + t.Skip("webrtc private to private needs special handling because the listener makes a dial to the relay") + } ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) diff --git a/p2p/test/transport/rcmgr_test.go b/p2p/test/transport/rcmgr_test.go index 20f34de799..1ea2a0a69e 100644 --- a/p2p/test/transport/rcmgr_test.go +++ b/p2p/test/transport/rcmgr_test.go @@ -24,7 +24,9 @@ func TestResourceManagerIsUsed(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { for _, testDialer := range []bool{true, false} { t.Run(tc.Name+fmt.Sprintf(" test_dialer=%v", testDialer), func(t *testing.T) { - + if strings.Contains(tc.Name, "WebRTCPrivate") { + t.Skip("webrtcprivate needs different handling because of the relay required for connection setup") + } var reservedMemory, releasedMemory atomic.Int32 defer func() { require.Equal(t, reservedMemory.Load(), releasedMemory.Load()) diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index a7e98a0d85..16e45e6a5f 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -25,12 +25,14 @@ import ( rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" "github.com/libp2p/go-libp2p/p2p/muxer/yamux" "github.com/libp2p/go-libp2p/p2p/net/swarm" + "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/security/noise" tls "github.com/libp2p/go-libp2p/p2p/security/tls" + libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" - "github.com/multiformats/go-multiaddr" + ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" ) @@ -152,6 +154,56 @@ var transportsToTest = []TransportTestCase{ return h }, }, + { + Name: "WebRTCPrivate", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + // NoListenAddrs helps ensure that we are not listening for TCP, QUIC etc. We do need + // those transports to dial the relay for signaling stream + libp2pOpts = append(libp2pOpts, libp2p.EnableWebRTCPrivate(nil), libp2p.EnableRelay(), libp2p.NoListenAddrs) + + if !opts.NoListen { + r, err := libp2p.New( + libp2p.EnableRelayService(), + libp2p.ForceReachabilityPublic(), + libp2p.Transport(libp2pquic.NewTransport), + libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1")) + require.NoError(t, err) + libp2pOpts = append( + libp2pOpts, + libp2p.AddrsFactory(func(_ []ma.Multiaddr) []ma.Multiaddr { + raddrs := r.Addrs() + addrs := make([]ma.Multiaddr, len(raddrs)) + for i := 0; i < len(raddrs); i++ { + + addrs[i] = ma.StringCast(fmt.Sprintf("%s/p2p/%s/p2p-circuit/webrtc/", raddrs[i], r.ID())) + } + return addrs + })) + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + _, err = client.Reserve(context.Background(), h, peer.AddrInfo{ID: r.ID(), Addrs: r.Addrs()}) + require.NoError(t, err) + return &webrtcHost{Host: h, r: r} + } + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return &webrtcHost{Host: h} + }, + }, +} + +type webrtcHost struct { + host.Host + r host.Host +} + +func (h *webrtcHost) Close() error { + h.Host.Close() + if h.r != nil { + h.r.Close() + } + return nil } func TestPing(t *testing.T) { @@ -656,6 +708,10 @@ func TestDiscoverPeerIDFromSecurityNegotiation(t *testing.T) { for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { + if strings.Contains(tc.Name, "WebRTCPrivate") { + t.Skip("webrtcprivate needs different handling because of the relay required for connection setup") + } + h1 := tc.HostGenerator(t, TransportTestCaseOpts{}) h2 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) defer h1.Close() @@ -673,7 +729,7 @@ func TestDiscoverPeerIDFromSecurityNegotiation(t *testing.T) { ai := &peer.AddrInfo{ ID: bogusPeerId, - Addrs: []multiaddr.Multiaddr{h1.Addrs()[0]}, + Addrs: []ma.Multiaddr{h1.Addrs()[0]}, } // Try connecting with the bogus peer ID diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go index bf039a03db..bec3dbba35 100644 --- a/p2p/transport/webrtc/listener.go +++ b/p2p/transport/webrtc/listener.go @@ -25,14 +25,14 @@ import ( "go.uber.org/zap/zapcore" ) -type connMultiaddrs struct { - local, remote ma.Multiaddr +type ConnMultiaddrs struct { + Local, Remote ma.Multiaddr } -var _ network.ConnMultiaddrs = &connMultiaddrs{} +var _ network.ConnMultiaddrs = &ConnMultiaddrs{} -func (c *connMultiaddrs) LocalMultiaddr() ma.Multiaddr { return c.local } -func (c *connMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.remote } +func (c *ConnMultiaddrs) LocalMultiaddr() ma.Multiaddr { return c.Local } +func (c *ConnMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.Remote } const ( candidateSetupTimeout = 20 * time.Second @@ -158,7 +158,7 @@ func (l *listener) handleCandidate(ctx context.Context, candidate udpmux.Candida } if l.transport.gater != nil { localAddr, _ := ma.SplitFunc(l.localMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH }) - if !l.transport.gater.InterceptAccept(&connMultiaddrs{local: localAddr, remote: remoteMultiaddr}) { + if !l.transport.gater.InterceptAccept(&ConnMultiaddrs{Local: localAddr, Remote: remoteMultiaddr}) { // The connection attempt is rejected before we can send the client an error. // This means that the connection attempt will time out. return nil, errors.New("connection gated") diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go index 09a4855aab..133e53fd73 100644 --- a/p2p/transport/webrtcprivate/listener.go +++ b/p2p/transport/webrtcprivate/listener.go @@ -103,6 +103,15 @@ func (l *listener) handleSignalingStream(s network.Stream) { s.SetDeadline(time.Now().Add(connectTimeout)) + if l.transport.gater != nil { + localAddr := s.Conn().LocalMultiaddr().Encapsulate(WebRTCAddr) + remoteAddr := s.Conn().RemoteMultiaddr().Encapsulate(WebRTCAddr) + if !l.transport.gater.InterceptAccept(&libp2pwebrtc.ConnMultiaddrs{Local: localAddr, Remote: remoteAddr}) { + log.Debug("gater disallowed accepting connection from %s at %s", s.Conn().RemotePeer(), remoteAddr) + s.Reset() + } + } + conn, err := l.setupConnection(ctx, s, scope) if err != nil { s.Reset() @@ -111,7 +120,7 @@ func (l *listener) handleSignalingStream(s network.Stream) { return } - if l.transport.gater != nil && l.transport.gater.InterceptSecured(network.DirOutbound, s.Conn().RemotePeer(), conn) { + if l.transport.gater != nil && l.transport.gater.InterceptSecured(network.DirInbound, s.Conn().RemotePeer(), conn) { conn.Close() log.Debugf("conn gater refused connection to addr: %s", conn.RemoteMultiaddr()) } @@ -126,6 +135,7 @@ func (l *listener) handleSignalingStream(s network.Stream) { } func (l *listener) setupConnection(ctx context.Context, s network.Stream, scope network.ConnManagementScope) (tpt.CapableConn, error) { + pc, err := l.transport.NewPeerConnection() if err != nil { err = fmt.Errorf("error creating a webrtc.PeerConnection: %w", err) From 4a77baee42b0bd6c52009149a54b566c39544181 Mon Sep 17 00:00:00 2001 From: sukun Date: Wed, 11 Oct 2023 01:31:09 +0530 Subject: [PATCH 20/22] webrtcprivate: fix bug with gater intercept secured --- p2p/transport/webrtcprivate/listener.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go index 133e53fd73..f356746d85 100644 --- a/p2p/transport/webrtcprivate/listener.go +++ b/p2p/transport/webrtcprivate/listener.go @@ -120,7 +120,7 @@ func (l *listener) handleSignalingStream(s network.Stream) { return } - if l.transport.gater != nil && l.transport.gater.InterceptSecured(network.DirInbound, s.Conn().RemotePeer(), conn) { + if l.transport.gater != nil && !l.transport.gater.InterceptSecured(network.DirInbound, s.Conn().RemotePeer(), conn) { conn.Close() log.Debugf("conn gater refused connection to addr: %s", conn.RemoteMultiaddr()) } From 62497721e0c343a29e6a78269c46ab45c798296a Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 12 Oct 2023 15:17:13 +0530 Subject: [PATCH 21/22] webrtcprivate: fix rcmgr bugs on listener --- p2p/transport/webrtcprivate/listener.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go index f356746d85..97807e3ef5 100644 --- a/p2p/transport/webrtcprivate/listener.go +++ b/p2p/transport/webrtcprivate/listener.go @@ -81,12 +81,16 @@ func (l *listener) handleSignalingStream(s network.Stream) { defer cancel() defer s.Close() - scope, err := l.transport.rcmgr.OpenConnection(network.DirInbound, false, ma.StringCast("/webrtc")) // we don't have a better remote adress right now + scope, err := l.transport.rcmgr.OpenConnection(network.DirInbound, true, ma.StringCast("/webrtc")) // we don't have a better remote adress right now if err != nil { s.Reset() log.Debug("failed to create connection scope:", err) return } + if err := scope.SetPeer(s.Conn().RemotePeer()); err != nil { + log.Debugf("resource manager blocked incoming conn from peer %s: %s", s.Conn().RemotePeer(), err) + return + } if err := s.Scope().SetService(name); err != nil { log.Debugf("error attaching stream to /webrtc listener: %s", err) From bc90626fbcf7afaa35f3616838b1823489dcf375 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 16 Oct 2023 22:10:44 +0530 Subject: [PATCH 22/22] webrtcprivate: advertise /webrtc addresses --- p2p/host/autorelay/relay_finder.go | 2 +- p2p/host/basic/basic_host.go | 30 ++++++++++++++++++++ p2p/protocol/holepunch/holepuncher.go | 5 +++- p2p/protocol/holepunch/svc.go | 5 +++- p2p/test/basichost/basic_host_test.go | 41 +++++++++++++++++++++++++++ 5 files changed, 80 insertions(+), 3 deletions(-) diff --git a/p2p/host/autorelay/relay_finder.go b/p2p/host/autorelay/relay_finder.go index ef79950b7b..de041b9ad9 100644 --- a/p2p/host/autorelay/relay_finder.go +++ b/p2p/host/autorelay/relay_finder.go @@ -726,7 +726,7 @@ func (rf *relayFinder) relayAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { // only keep private addrs from the original addr set for _, addr := range addrs { - if manet.IsPrivateAddr(addr) { + if !manet.IsPublicAddr(addr) { raddrs = append(raddrs, addr) } } diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 6c3ba53e5b..5659da48ec 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -27,6 +27,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/protocol/identify" "github.com/libp2p/go-libp2p/p2p/protocol/ping" + libp2pwebrtcprivate "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate" libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" "github.com/prometheus/client_golang/prometheus" @@ -801,9 +802,38 @@ func (h *BasicHost) Addrs() []ma.Multiaddr { addrs[i] = addrWithCerthash } } + + // Append webrtc addresses to circuit-v2 addresses + hasWebRTCPrivate := false + for _, addr := range addrs { + if addr.Equal(libp2pwebrtcprivate.WebRTCAddr) { + hasWebRTCPrivate = true + break + } + } + if hasWebRTCPrivate { + for _, addr := range addrs { + if _, err := addr.ValueForProtocol(ma.P_CIRCUIT); err == nil { + if isBrowserDialableAddr(addr) { + addrs = append(addrs, addr.Encapsulate(libp2pwebrtcprivate.WebRTCAddr)) + } + } + } + } return addrs } +var browserProtocols = []int{ma.P_WEBTRANSPORT, ma.P_WEBRTC_DIRECT, ma.P_WSS} + +func isBrowserDialableAddr(addr ma.Multiaddr) bool { + for _, p := range browserProtocols { + if _, err := addr.ValueForProtocol(p); err == nil { + return true + } + } + return false +} + // NormalizeMultiaddr returns a multiaddr suitable for equality checks. // If the multiaddr is a webtransport component, it removes the certhashes. func (h *BasicHost) NormalizeMultiaddr(addr ma.Multiaddr) ma.Multiaddr { diff --git a/p2p/protocol/holepunch/holepuncher.go b/p2p/protocol/holepunch/holepuncher.go index 18b950f9e3..cdcba85186 100644 --- a/p2p/protocol/holepunch/holepuncher.go +++ b/p2p/protocol/holepunch/holepuncher.go @@ -180,7 +180,10 @@ func (hp *holePuncher) maybeDialWebRTC(p peer.ID) { for _, a := range addrs { if _, err := a.ValueForProtocol(ma.P_WEBRTC); err == nil { ctx := network.WithForceDirectDial(hp.ctx, "webrtc holepunch") - hp.host.Connect(ctx, peer.AddrInfo{ID: p}) // address is already in peerstore + err := hp.host.Connect(ctx, peer.AddrInfo{ID: p}) // address is already in peerstore + if err != nil { + log.Debugf("holepunch attempt to %s over /webrtc failed: %s", p, err) + } return } } diff --git a/p2p/protocol/holepunch/svc.go b/p2p/protocol/holepunch/svc.go index 4b19e5f811..1796ec9701 100644 --- a/p2p/protocol/holepunch/svc.go +++ b/p2p/protocol/holepunch/svc.go @@ -310,7 +310,10 @@ func (s *Service) Connected(_ network.Network, conn network.Conn) { for _, addr := range s.host.Peerstore().Addrs(p) { if _, err := addr.ValueForProtocol(ma.P_WEBRTC); err == nil { ctx := network.WithForceDirectDial(s.ctx, "webrtc holepunch") - s.host.Connect(ctx, peer.AddrInfo{ID: p}) // address is already in peerstore + err := s.host.Connect(ctx, peer.AddrInfo{ID: p}) // address is already in peerstore + if err != nil { + log.Debugf("holepunch attempt to %s over /webrtc failed: %s", p, err) + } return } } diff --git a/p2p/test/basichost/basic_host_test.go b/p2p/test/basichost/basic_host_test.go index d00dd9d5dc..0d621884d7 100644 --- a/p2p/test/basichost/basic_host_test.go +++ b/p2p/test/basichost/basic_host_test.go @@ -10,6 +10,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/p2p/host/autorelay" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" ma "github.com/multiformats/go-multiaddr" @@ -158,3 +159,43 @@ func TestNewStreamTransientConnection(t *testing.T) { <-done <-done } + +func TestWebRTCPrivateAddressAdvertisement(t *testing.T) { + r, err := libp2p.New( + // We need a public address for the relay + libp2p.AddrsFactory(func(addrs []ma.Multiaddr) []ma.Multiaddr { + return append(addrs, ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/webtransport")) + }), + libp2p.EnableRelayService(), + libp2p.ForceReachabilityPublic(), + ) + require.NoError(t, err) + + relay1info := peer.AddrInfo{ + ID: r.ID(), + Addrs: r.Addrs(), + } + + h, err := libp2p.New( + libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"), + libp2p.EnableRelay(), + libp2p.EnableWebRTCPrivate(nil), + libp2p.EnableAutoRelayWithStaticRelays( + []peer.AddrInfo{relay1info}, + autorelay.WithBootDelay(0), + ), + libp2p.ForceReachabilityPrivate(), + ) + require.NoError(t, err) + + require.Eventually(t, func() bool { + for _, a := range h.Addrs() { + _, rerr := a.ValueForProtocol(ma.P_CIRCUIT) + _, werr := a.ValueForProtocol(ma.P_WEBRTC) + if rerr == nil && werr == nil { + return true + } + } + return false + }, 5*time.Second, 50*time.Millisecond) +}