Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions wsnet/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ func (d *DialerCache) Dial(ctx context.Context, key string, dialerFunc func() (*
if err != nil {
return nil, false, err
}

select {
case <-d.closed:
return nil, false, errors.New("cache closed")
Expand Down
185 changes: 125 additions & 60 deletions wsnet/rtc.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@ import (
"github.com/pion/turn/v2"
"github.com/pion/webrtc/v3"
"golang.org/x/net/proxy"
"golang.org/x/xerrors"
)

var (
// ErrMismatchedProtocol occurs when a TURN is requested to a STUN server,
// or a TURN server is requested instead of TURNS.
// ErrMismatchedProtocol occurs when a TURN is requested to a STUN
// server, or a TURN server is requested instead of TURNS.
ErrMismatchedProtocol = errors.New("mismatched protocols")
// ErrInvalidCredentials occurs when invalid credentials are passed to a
// TURN server. This error cannot occur for STUN servers, as they don't accept
// credentials.
// ErrInvalidCredentials occurs when invalid credentials are passed to
// a TURN server. This error cannot occur for STUN servers, as they
// don't accept credentials.
ErrInvalidCredentials = errors.New("invalid credentials")

// Constant for the control channel protocol.
Expand All @@ -36,7 +37,7 @@ var (
// DialICEOptions provides options for dialing an ICE server.
type DialICEOptions struct {
Timeout time.Duration
// Whether to ignore TLS errors.
// InsecureSkipVerify determines whether to ignore TLS errors.
InsecureSkipVerify bool
}

Expand All @@ -50,52 +51,78 @@ func DialICE(server webrtc.ICEServer, options *DialICEOptions) error {
for _, rawURL := range server.URLs {
err := dialICEURL(server, rawURL, options)
if err != nil {
return err
return xerrors.Errorf("dial ice url: %w", err)
}
}

return nil
}

func dialICEURL(server webrtc.ICEServer, rawURL string, options *DialICEOptions) error {
var (
tcpConn net.Conn
udpConn net.PacketConn
err error
)

url, err := ice.ParseURL(rawURL)
if err != nil {
return err
return xerrors.Errorf("parse ice url: %w", err)
}
var (
tcpConn net.Conn
udpConn net.PacketConn
turnServerAddr = fmt.Sprintf("%s:%d", url.Host, url.Port)
)
turnServerAddr := fmt.Sprintf("%s:%d", url.Host, url.Port)

switch {
case url.Scheme == ice.SchemeTypeTURN || url.Scheme == ice.SchemeTypeSTUN:
switch url.Proto {
case ice.ProtoTypeUDP:
udpConn, err = net.ListenPacket("udp4", "0.0.0.0:0")
if err != nil {
return xerrors.Errorf("listen packet udp4: %w", err)
}

case ice.ProtoTypeTCP:
tcpConn, err = net.Dial("tcp4", turnServerAddr)
if err != nil {
return xerrors.Errorf("dial tcp4: %w", err)
}

default:
return xerrors.Errorf("unknown url proto: %q", url.Proto)
}

case url.Scheme == ice.SchemeTypeTURNS || url.Scheme == ice.SchemeTypeSTUNS:
switch url.Proto {
case ice.ProtoTypeUDP:
udpAddr, resErr := net.ResolveUDPAddr("udp4", turnServerAddr)
if resErr != nil {
return resErr
udpAddr, err := net.ResolveUDPAddr("udp4", turnServerAddr)
if err != nil {
return xerrors.Errorf("resolve udp4 addr: %w", err)
}
dconn, dialErr := dtls.Dial("udp4", udpAddr, &dtls.Config{

dconn, err := dtls.Dial("udp4", udpAddr, &dtls.Config{
InsecureSkipVerify: options.InsecureSkipVerify,
})
err = dialErr
if err != nil {
return xerrors.Errorf("dtls dial udp4: %w", err)
}

udpConn = turn.NewSTUNConn(dconn)

case ice.ProtoTypeTCP:
tcpConn, err = tls.Dial("tcp4", turnServerAddr, &tls.Config{
InsecureSkipVerify: options.InsecureSkipVerify,
})
if err != nil {
return xerrors.Errorf("tls dial tcp4: %w", err)
}

default:
return xerrors.Errorf("unknown url proto: %q", url.Proto)
}
}

if err != nil {
return err
default:
return xerrors.Errorf("unknown url scheme: %q", url.Scheme)
}

if tcpConn != nil {
udpConn = turn.NewSTUNConn(tcpConn)
}
Expand All @@ -116,45 +143,61 @@ func dialICEURL(server webrtc.ICEServer, rawURL string, options *DialICEOptions)
RTO: options.Timeout,
})
if err != nil {
return err
return xerrors.Errorf("create turn client: %w", err)
}
defer client.Close()

err = client.Listen()
if err != nil {
return err
return xerrors.Errorf("listen turn client: %w", err)
}
// STUN servers are not authenticated with credentials.
// As long as the transport is valid, this should always work.

// STUN servers are not authenticated with credentials. As long as the
// transport is valid, this should always work.
_, err = client.SendBindingRequest()
if err != nil {
// Transport failed to connect.
// https://github.com/pion/turn/blob/8231b69046f562420299916e9fb69cbff4754231/errors.go#L20
if strings.Contains(err.Error(), "retransmissions failed") {
return ErrMismatchedProtocol
// Transport failed to connect. Convert error into a detectable
// one.
if errIsTurnAllRetransmissionsFailed(err) {
err = ErrMismatchedProtocol
}
return fmt.Errorf("binding: %w", err)

return xerrors.Errorf("send binding request: %w", err)
}

if url.Scheme == ice.SchemeTypeTURN || url.Scheme == ice.SchemeTypeTURNS {
// We TURN to validate server credentials are correct.
pc, err := client.Allocate()
if err != nil {
if strings.Contains(err.Error(), "error 400") {
return ErrInvalidCredentials
err = ErrInvalidCredentials
}

// Since TURN and STUN follow the same protocol, they can
// both handshake, but once a tunnel is allocated it will
// fail to transmit.
if strings.Contains(err.Error(), "retransmissions failed") {
return ErrMismatchedProtocol
if errIsTurnAllRetransmissionsFailed(err) {
err = ErrMismatchedProtocol
}
return err

return xerrors.Errorf("turn allocate: %w", err)
}
defer pc.Close()
}

return nil
}

// Generalizes creating a new peer connection with consistent options.
// errIsTurnAllRetransmissionsFailed detects the `errAllRetransmissionsFailed`
// error from pion/turn.
//
// See: https://github.com/pion/turn/blob/8231b69046f562420299916e9fb69cbff4754231/errors.go#L20
func errIsTurnAllRetransmissionsFailed(err error) bool {
return strings.Contains(err.Error(), "retransmissions failed")
}

// newPeerConnection generalizes creating a new peer connection with consistent
// options.
func newPeerConnection(servers []webrtc.ICEServer, dialer proxy.Dialer) (*webrtc.PeerConnection, error) {
se := webrtc.SettingEngine{}
se.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeUDP4})
Expand Down Expand Up @@ -200,7 +243,7 @@ func newPeerConnection(servers []webrtc.ICEServer, dialer proxy.Dialer) (*webrtc
})
}

// Proxies ICE candidates using the protocol to a writer.
// proxyICECandidates proxies ICE candidates using the protocol to a writer.
func proxyICECandidates(conn *webrtc.PeerConnection, w io.Writer) func() {
var (
mut sync.Mutex
Expand All @@ -220,65 +263,87 @@ func proxyICECandidates(conn *webrtc.PeerConnection, w io.Writer) func() {
}
mut.Lock()
defer mut.Unlock()

if !flushed {
queue = append(queue, i)
return
}

write(i)
})

return func() {
mut.Lock()
defer mut.Unlock()

for _, i := range queue {
write(i)
}

flushed = true
}
}

// Waits for a PeerConnection to hit the open state.
// waitForConnectionOpen waits for a PeerConnection to hit the open state.
func waitForConnectionOpen(ctx context.Context, conn *webrtc.PeerConnection) error {
if conn.ConnectionState() == webrtc.PeerConnectionStateConnected {
state := conn.ConnectionState()
if state == webrtc.PeerConnectionStateConnected {
return nil
} else if state != webrtc.PeerConnectionStateConnecting {
return xerrors.New("connection is not connecting")
}
var cancel context.CancelFunc
if _, deadlineSet := ctx.Deadline(); deadlineSet {
ctx, cancel = context.WithCancel(ctx)
} else {
ctx, cancel = context.WithTimeout(ctx, time.Second*15)
}

connected := make(chan struct{})
ctx, cancel := ctxDeadlineIfNotSet(ctx, 15*time.Second)
defer cancel()

conn.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) {
if pcs == webrtc.PeerConnectionStateConnected {
cancel()
close(connected)
}
})
<-ctx.Done()
if ctx.Err() == context.DeadlineExceeded {
return context.DeadlineExceeded

select {
case <-ctx.Done():
return ctx.Err()
case <-connected:
return nil
}
return nil
}

// Waits for a DataChannel to hit the open state.
// waitForDataChannelOpen waits for a DataChannel to hit the open state.
func waitForDataChannelOpen(ctx context.Context, channel *webrtc.DataChannel) error {
if channel.ReadyState() == webrtc.DataChannelStateOpen {
state := channel.ReadyState()
if state == webrtc.DataChannelStateOpen {
return nil
} else if state != webrtc.DataChannelStateConnecting {
return xerrors.New("channel is not connecting")
}
if channel.ReadyState() != webrtc.DataChannelStateConnecting {
return fmt.Errorf("channel closed")
}
ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15)
defer cancelFunc()

connected := make(chan struct{})
ctx, cancel := ctxDeadlineIfNotSet(ctx, 15*time.Second)
defer cancel()

channel.OnOpen(func() {
cancelFunc()
close(connected)
})
<-ctx.Done()
if ctx.Err() == context.DeadlineExceeded {

select {
case <-ctx.Done():
return ctx.Err()
case <-connected:
return nil
}
return nil
}

// ctxDeadlineIfNotSet sets a deadline from the parent context, if and only if
// a deadline does not already exist for the parent context.
func ctxDeadlineIfNotSet(ctx context.Context, deadline time.Duration) (_ctx context.Context, cancel func()) {
if _, ok := ctx.Deadline(); ok {
return context.WithCancel(ctx)
}

return context.WithTimeout(ctx, deadline)
}

func stringPtr(s string) *string {
Expand Down
18 changes: 5 additions & 13 deletions wsnet/rtc_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package wsnet

import (
"errors"
"fmt"
"testing"
"time"

"github.com/pion/ice/v2"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/assert"
)

func TestDialICE(t *testing.T) {
Expand All @@ -26,9 +26,7 @@ func TestDialICE(t *testing.T) {
Timeout: time.Millisecond,
InsecureSkipVerify: true,
})
if err != nil {
t.Error(err)
}
assert.NoError(t, err)
})

t.Run("Protocol mismatch", func(t *testing.T) {
Expand All @@ -44,9 +42,7 @@ func TestDialICE(t *testing.T) {
Timeout: time.Millisecond,
InsecureSkipVerify: true,
})
if !errors.Is(err, ErrMismatchedProtocol) {
t.Error(err)
}
assert.ErrorIs(t, err, ErrMismatchedProtocol)
})

t.Run("Invalid auth", func(t *testing.T) {
Expand All @@ -62,9 +58,7 @@ func TestDialICE(t *testing.T) {
Timeout: time.Millisecond,
InsecureSkipVerify: true,
})
if !errors.Is(err, ErrInvalidCredentials) {
t.Error(err)
}
assert.ErrorIs(t, err, ErrInvalidCredentials)
})

t.Run("Protocol mismatch public", func(t *testing.T) {
Expand All @@ -76,8 +70,6 @@ func TestDialICE(t *testing.T) {
Timeout: time.Millisecond,
InsecureSkipVerify: true,
})
if !errors.Is(err, ErrMismatchedProtocol) {
t.Error(err)
}
assert.ErrorIs(t, err, ErrMismatchedProtocol)
})
}