From 9b2c1fc6763809919bc79cffa2fc97e24599f185 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Thu, 9 Sep 2021 13:03:21 -0500 Subject: [PATCH 1/3] chore: cleanup wsnet/rtc Makes code a bit easier to read and uses xerrors for all errors. --- wsnet/cache.go | 1 + wsnet/rtc.go | 181 +++++++++++++++++++++++++++++++--------------- wsnet/rtc_test.go | 18 ++--- 3 files changed, 129 insertions(+), 71 deletions(-) diff --git a/wsnet/cache.go b/wsnet/cache.go index b16950ca..3f03bab4 100644 --- a/wsnet/cache.go +++ b/wsnet/cache.go @@ -136,6 +136,7 @@ func (d *DialerCache) Dial(ctx context.Context, key string, dialerFunc func() (* if err != nil { return nil, false, err } + select { case <-d.closed: return nil, false, errors.New("cache closed") diff --git a/wsnet/rtc.go b/wsnet/rtc.go index 32a089a2..b3a81426 100644 --- a/wsnet/rtc.go +++ b/wsnet/rtc.go @@ -18,15 +18,16 @@ import ( "github.com/pion/turn/v2" "github.com/pion/webrtc/v3" "golang.org/x/net/proxy" + "golang.org/x/xerrors" ) var ( - // ErrMismatchedProtocol occurs when a TURN is requested to a STUN server, - // or a TURN server is requested instead of TURNS. + // ErrMismatchedProtocol occurs when a TURN is requested to a STUN + // server, or a TURN server is requested instead of TURNS. ErrMismatchedProtocol = errors.New("mismatched protocols") - // ErrInvalidCredentials occurs when invalid credentials are passed to a - // TURN server. This error cannot occur for STUN servers, as they don't accept - // credentials. + // ErrInvalidCredentials occurs when invalid credentials are passed to + // a TURN server. This error cannot occur for STUN servers, as they + // don't accept credentials. ErrInvalidCredentials = errors.New("invalid credentials") // Constant for the control channel protocol. @@ -36,7 +37,7 @@ var ( // DialICEOptions provides options for dialing an ICE server. type DialICEOptions struct { Timeout time.Duration - // Whether to ignore TLS errors. + // InsecureSkipVerify determines whether to ignore TLS errors. InsecureSkipVerify bool } @@ -50,52 +51,79 @@ func DialICE(server webrtc.ICEServer, options *DialICEOptions) error { for _, rawURL := range server.URLs { err := dialICEURL(server, rawURL, options) if err != nil { - return err + return xerrors.Errorf("dial ice url: %w", err) } } + return nil } func dialICEURL(server webrtc.ICEServer, rawURL string, options *DialICEOptions) error { - url, err := ice.ParseURL(rawURL) - if err != nil { - return err - } var ( tcpConn net.Conn udpConn net.PacketConn - turnServerAddr = fmt.Sprintf("%s:%d", url.Host, url.Port) + turnServerAddr string + err error ) + + url, err := ice.ParseURL(rawURL) + if err != nil { + return xerrors.Errorf("parse ice url: %w", err) + } + turnServerAddr = fmt.Sprintf("%s:%d", url.Host, url.Port) + switch { case url.Scheme == ice.SchemeTypeTURN || url.Scheme == ice.SchemeTypeSTUN: switch url.Proto { case ice.ProtoTypeUDP: udpConn, err = net.ListenPacket("udp4", "0.0.0.0:0") + if err != nil { + return xerrors.Errorf("listen packet udp4: %w", err) + } + case ice.ProtoTypeTCP: tcpConn, err = net.Dial("tcp4", turnServerAddr) + if err != nil { + return xerrors.Errorf("dial tcp4: %w", err) + } + + default: + return xerrors.Errorf("unknown url proto: %q", url.Proto) } + case url.Scheme == ice.SchemeTypeTURNS || url.Scheme == ice.SchemeTypeSTUNS: switch url.Proto { case ice.ProtoTypeUDP: - udpAddr, resErr := net.ResolveUDPAddr("udp4", turnServerAddr) - if resErr != nil { - return resErr + udpAddr, err := net.ResolveUDPAddr("udp4", turnServerAddr) + if err != nil { + return xerrors.Errorf("resolve udp4 addr: %w", err) } - dconn, dialErr := dtls.Dial("udp4", udpAddr, &dtls.Config{ + + dconn, err := dtls.Dial("udp4", udpAddr, &dtls.Config{ InsecureSkipVerify: options.InsecureSkipVerify, }) - err = dialErr + if err != nil { + return xerrors.Errorf("dtls dial udp4: %w", err) + } + udpConn = turn.NewSTUNConn(dconn) + case ice.ProtoTypeTCP: tcpConn, err = tls.Dial("tcp4", turnServerAddr, &tls.Config{ InsecureSkipVerify: options.InsecureSkipVerify, }) + if err != nil { + return xerrors.Errorf("tls dial tcp4: %w", err) + } + + default: + return xerrors.Errorf("unknown url proto: %q", url.Proto) } - } - if err != nil { - return err + default: + return xerrors.Errorf("unknown url scheme: %q", url.Scheme) } + if tcpConn != nil { udpConn = turn.NewSTUNConn(tcpConn) } @@ -116,45 +144,61 @@ func dialICEURL(server webrtc.ICEServer, rawURL string, options *DialICEOptions) RTO: options.Timeout, }) if err != nil { - return err + return xerrors.Errorf("create turn client: %w", err) } defer client.Close() + err = client.Listen() if err != nil { - return err + return xerrors.Errorf("listen turn client: %w", err) } - // STUN servers are not authenticated with credentials. - // As long as the transport is valid, this should always work. + + // STUN servers are not authenticated with credentials. As long as the + // transport is valid, this should always work. _, err = client.SendBindingRequest() if err != nil { - // Transport failed to connect. - // https://github.com/pion/turn/blob/8231b69046f562420299916e9fb69cbff4754231/errors.go#L20 - if strings.Contains(err.Error(), "retransmissions failed") { - return ErrMismatchedProtocol + // Transport failed to connect. Convert error into a detectable + // one. + if errIsTurnAllRetransmissionsFailed(err) { + err = ErrMismatchedProtocol } - return fmt.Errorf("binding: %w", err) + + return xerrors.Errorf("send binding request: %w", err) } + if url.Scheme == ice.SchemeTypeTURN || url.Scheme == ice.SchemeTypeTURNS { // We TURN to validate server credentials are correct. pc, err := client.Allocate() if err != nil { if strings.Contains(err.Error(), "error 400") { - return ErrInvalidCredentials + err = ErrInvalidCredentials } + // Since TURN and STUN follow the same protocol, they can // both handshake, but once a tunnel is allocated it will // fail to transmit. - if strings.Contains(err.Error(), "retransmissions failed") { - return ErrMismatchedProtocol + if errIsTurnAllRetransmissionsFailed(err) { + err = ErrMismatchedProtocol } - return err + + return xerrors.Errorf("turn allocate: %w", err) } defer pc.Close() } + return nil } -// Generalizes creating a new peer connection with consistent options. +// errIsTurnAllRetransmissionsFailed detects the `errAllRetransmissionsFailed` +// error from pion/turn. +// +// See: https://github.com/pion/turn/blob/8231b69046f562420299916e9fb69cbff4754231/errors.go#L20 +func errIsTurnAllRetransmissionsFailed(err error) bool { + return strings.Contains(err.Error(), "retransmissions failed") +} + +// newPeerConnection generalizes creating a new peer connection with consistent +// options. func newPeerConnection(servers []webrtc.ICEServer, dialer proxy.Dialer) (*webrtc.PeerConnection, error) { se := webrtc.SettingEngine{} se.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeUDP4}) @@ -200,7 +244,7 @@ func newPeerConnection(servers []webrtc.ICEServer, dialer proxy.Dialer) (*webrtc }) } -// Proxies ICE candidates using the protocol to a writer. +// proxyICECandidates proxies ICE candidates using the protocol to a writer. func proxyICECandidates(conn *webrtc.PeerConnection, w io.Writer) func() { var ( mut sync.Mutex @@ -220,6 +264,7 @@ func proxyICECandidates(conn *webrtc.PeerConnection, w io.Writer) func() { } mut.Lock() defer mut.Unlock() + if !flushed { queue = append(queue, i) return @@ -227,58 +272,78 @@ func proxyICECandidates(conn *webrtc.PeerConnection, w io.Writer) func() { write(i) }) + return func() { mut.Lock() defer mut.Unlock() + for _, i := range queue { write(i) } + flushed = true } } -// Waits for a PeerConnection to hit the open state. +// waitForConnectionOpen waits for a PeerConnection to hit the open state. func waitForConnectionOpen(ctx context.Context, conn *webrtc.PeerConnection) error { if conn.ConnectionState() == webrtc.PeerConnectionStateConnected { return nil } - var cancel context.CancelFunc - if _, deadlineSet := ctx.Deadline(); deadlineSet { - ctx, cancel = context.WithCancel(ctx) - } else { - ctx, cancel = context.WithTimeout(ctx, time.Second*15) - } + + connected := make(chan struct{}) + ctx, cancel := ctxDeadlineIfNotSet(ctx, 15*time.Second) defer cancel() + conn.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) { if pcs == webrtc.PeerConnectionStateConnected { - cancel() + close(connected) } }) - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - return context.DeadlineExceeded + + select { + case <-ctx.Done(): + return ctx.Err() + case <-connected: + return nil } - return nil } -// Waits for a DataChannel to hit the open state. +// waitForDataChannelOpen waits for a DataChannel to hit the open state. func waitForDataChannelOpen(ctx context.Context, channel *webrtc.DataChannel) error { - if channel.ReadyState() == webrtc.DataChannelStateOpen { + switch channel.ReadyState() { + case webrtc.DataChannelStateOpen: return nil + + case webrtc.DataChannelStateClosed, + webrtc.DataChannelStateClosing: + return xerrors.New("channel closed") } - if channel.ReadyState() != webrtc.DataChannelStateConnecting { - return fmt.Errorf("channel closed") - } - ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15) - defer cancelFunc() + + connected := make(chan struct{}) + ctx, cancel := ctxDeadlineIfNotSet(ctx, 15*time.Second) + defer cancel() + channel.OnOpen(func() { - cancelFunc() + close(connected) }) - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { + + select { + case <-ctx.Done(): return ctx.Err() + case <-connected: + return nil } - return nil +} + +// ctxDeadlineIfNotSet sets a deadline from the parent context, if and only if +// a deadline does not already exist for the parent context. +func ctxDeadlineIfNotSet(ctx context.Context, deadline time.Duration) (_ctx context.Context, cancel func()) { + if _, ok := ctx.Deadline(); ok { + return context.WithCancel(ctx) + } + + return context.WithTimeout(ctx, deadline) } func stringPtr(s string) *string { diff --git a/wsnet/rtc_test.go b/wsnet/rtc_test.go index 73d1af2f..7ed3fb80 100644 --- a/wsnet/rtc_test.go +++ b/wsnet/rtc_test.go @@ -1,13 +1,13 @@ package wsnet import ( - "errors" "fmt" "testing" "time" "github.com/pion/ice/v2" "github.com/pion/webrtc/v3" + "github.com/stretchr/testify/assert" ) func TestDialICE(t *testing.T) { @@ -26,9 +26,7 @@ func TestDialICE(t *testing.T) { Timeout: time.Millisecond, InsecureSkipVerify: true, }) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) }) t.Run("Protocol mismatch", func(t *testing.T) { @@ -44,9 +42,7 @@ func TestDialICE(t *testing.T) { Timeout: time.Millisecond, InsecureSkipVerify: true, }) - if !errors.Is(err, ErrMismatchedProtocol) { - t.Error(err) - } + assert.ErrorIs(t, err, ErrMismatchedProtocol) }) t.Run("Invalid auth", func(t *testing.T) { @@ -62,9 +58,7 @@ func TestDialICE(t *testing.T) { Timeout: time.Millisecond, InsecureSkipVerify: true, }) - if !errors.Is(err, ErrInvalidCredentials) { - t.Error(err) - } + assert.ErrorIs(t, err, ErrInvalidCredentials) }) t.Run("Protocol mismatch public", func(t *testing.T) { @@ -76,8 +70,6 @@ func TestDialICE(t *testing.T) { Timeout: time.Millisecond, InsecureSkipVerify: true, }) - if !errors.Is(err, ErrMismatchedProtocol) { - t.Error(err) - } + assert.ErrorIs(t, err, ErrMismatchedProtocol) }) } From 3b3bcf29325ff8bfef1f9489556a901ee1e76454 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Thu, 9 Sep 2021 14:15:40 -0500 Subject: [PATCH 2/3] fixup! chore: cleanup wsnet/rtc --- wsnet/rtc.go | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/wsnet/rtc.go b/wsnet/rtc.go index b3a81426..9be95a2c 100644 --- a/wsnet/rtc.go +++ b/wsnet/rtc.go @@ -60,17 +60,16 @@ func DialICE(server webrtc.ICEServer, options *DialICEOptions) error { func dialICEURL(server webrtc.ICEServer, rawURL string, options *DialICEOptions) error { var ( - tcpConn net.Conn - udpConn net.PacketConn - turnServerAddr string - err error + tcpConn net.Conn + udpConn net.PacketConn + err error ) url, err := ice.ParseURL(rawURL) if err != nil { return xerrors.Errorf("parse ice url: %w", err) } - turnServerAddr = fmt.Sprintf("%s:%d", url.Host, url.Port) + turnServerAddr := fmt.Sprintf("%s:%d", url.Host, url.Port) switch { case url.Scheme == ice.SchemeTypeTURN || url.Scheme == ice.SchemeTypeSTUN: @@ -311,13 +310,11 @@ func waitForConnectionOpen(ctx context.Context, conn *webrtc.PeerConnection) err // waitForDataChannelOpen waits for a DataChannel to hit the open state. func waitForDataChannelOpen(ctx context.Context, channel *webrtc.DataChannel) error { - switch channel.ReadyState() { - case webrtc.DataChannelStateOpen: + state := channel.ReadyState() + if state == webrtc.DataChannelStateOpen { return nil - - case webrtc.DataChannelStateClosed, - webrtc.DataChannelStateClosing: - return xerrors.New("channel closed") + } else if state != webrtc.DataChannelStateConnecting { + return xerrors.New("channel is not connecting") } connected := make(chan struct{}) From ee2214f86b6de945fec42a6203fcbcbced1226a2 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Thu, 9 Sep 2021 14:17:47 -0500 Subject: [PATCH 3/3] fixup! chore: cleanup wsnet/rtc --- wsnet/rtc.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/wsnet/rtc.go b/wsnet/rtc.go index 9be95a2c..f0c40651 100644 --- a/wsnet/rtc.go +++ b/wsnet/rtc.go @@ -286,8 +286,11 @@ func proxyICECandidates(conn *webrtc.PeerConnection, w io.Writer) func() { // waitForConnectionOpen waits for a PeerConnection to hit the open state. func waitForConnectionOpen(ctx context.Context, conn *webrtc.PeerConnection) error { - if conn.ConnectionState() == webrtc.PeerConnectionStateConnected { + state := conn.ConnectionState() + if state == webrtc.PeerConnectionStateConnected { return nil + } else if state != webrtc.PeerConnectionStateConnecting { + return xerrors.New("connection is not connecting") } connected := make(chan struct{})