Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 0599955

Browse files
committed
chore: cleanup wsnet/rtc
Makes code a bit easier to read and uses xerrors for all errors.
1 parent 3582a0d commit 0599955

File tree

3 files changed

+129
-71
lines changed

3 files changed

+129
-71
lines changed

wsnet/cache.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ func (d *DialerCache) Dial(ctx context.Context, key string, dialerFunc func() (*
136136
if err != nil {
137137
return nil, false, err
138138
}
139+
139140
select {
140141
case <-d.closed:
141142
return nil, false, errors.New("cache closed")

wsnet/rtc.go

Lines changed: 123 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@ import (
1818
"github.com/pion/turn/v2"
1919
"github.com/pion/webrtc/v3"
2020
"golang.org/x/net/proxy"
21+
"golang.org/x/xerrors"
2122
)
2223

2324
var (
24-
// ErrMismatchedProtocol occurs when a TURN is requested to a STUN server,
25-
// or a TURN server is requested instead of TURNS.
25+
// ErrMismatchedProtocol occurs when a TURN is requested to a STUN
26+
// server, or a TURN server is requested instead of TURNS.
2627
ErrMismatchedProtocol = errors.New("mismatched protocols")
27-
// ErrInvalidCredentials occurs when invalid credentials are passed to a
28-
// TURN server. This error cannot occur for STUN servers, as they don't accept
29-
// credentials.
28+
// ErrInvalidCredentials occurs when invalid credentials are passed to
29+
// a TURN server. This error cannot occur for STUN servers, as they
30+
// don't accept credentials.
3031
ErrInvalidCredentials = errors.New("invalid credentials")
3132

3233
// Constant for the control channel protocol.
@@ -36,7 +37,7 @@ var (
3637
// DialICEOptions provides options for dialing an ICE server.
3738
type DialICEOptions struct {
3839
Timeout time.Duration
39-
// Whether to ignore TLS errors.
40+
// InsecureSkipVerify determines whether to ignore TLS errors.
4041
InsecureSkipVerify bool
4142
}
4243

@@ -50,52 +51,79 @@ func DialICE(server webrtc.ICEServer, options *DialICEOptions) error {
5051
for _, rawURL := range server.URLs {
5152
err := dialICEURL(server, rawURL, options)
5253
if err != nil {
53-
return err
54+
return xerrors.Errorf("dial ice url: %w", err)
5455
}
5556
}
57+
5658
return nil
5759
}
5860

5961
func dialICEURL(server webrtc.ICEServer, rawURL string, options *DialICEOptions) error {
60-
url, err := ice.ParseURL(rawURL)
61-
if err != nil {
62-
return err
63-
}
6462
var (
6563
tcpConn net.Conn
6664
udpConn net.PacketConn
67-
turnServerAddr = fmt.Sprintf("%s:%d", url.Host, url.Port)
65+
turnServerAddr string
66+
err error
6867
)
68+
69+
url, err := ice.ParseURL(rawURL)
70+
if err != nil {
71+
return xerrors.Errorf("parse ice url: %w", err)
72+
}
73+
turnServerAddr = fmt.Sprintf("%s:%d", url.Host, url.Port)
74+
6975
switch {
7076
case url.Scheme == ice.SchemeTypeTURN || url.Scheme == ice.SchemeTypeSTUN:
7177
switch url.Proto {
7278
case ice.ProtoTypeUDP:
7379
udpConn, err = net.ListenPacket("udp4", "0.0.0.0:0")
80+
if err != nil {
81+
return xerrors.Errorf("listen packet udp4: %w", err)
82+
}
83+
7484
case ice.ProtoTypeTCP:
7585
tcpConn, err = net.Dial("tcp4", turnServerAddr)
86+
if err != nil {
87+
return xerrors.Errorf("dial tcp4: %w", err)
88+
}
89+
90+
default:
91+
return xerrors.Errorf("unknown url proto: %q", url.Proto)
7692
}
93+
7794
case url.Scheme == ice.SchemeTypeTURNS || url.Scheme == ice.SchemeTypeSTUNS:
7895
switch url.Proto {
7996
case ice.ProtoTypeUDP:
80-
udpAddr, resErr := net.ResolveUDPAddr("udp4", turnServerAddr)
81-
if resErr != nil {
82-
return resErr
97+
udpAddr, err := net.ResolveUDPAddr("udp4", turnServerAddr)
98+
if err != nil {
99+
return xerrors.Errorf("resolve udp4 addr: %w", err)
83100
}
84-
dconn, dialErr := dtls.Dial("udp4", udpAddr, &dtls.Config{
101+
102+
dconn, err := dtls.Dial("udp4", udpAddr, &dtls.Config{
85103
InsecureSkipVerify: options.InsecureSkipVerify,
86104
})
87-
err = dialErr
105+
if err != nil {
106+
return xerrors.Errorf("dtls dial udp4: %w", err)
107+
}
108+
88109
udpConn = turn.NewSTUNConn(dconn)
110+
89111
case ice.ProtoTypeTCP:
90112
tcpConn, err = tls.Dial("tcp4", turnServerAddr, &tls.Config{
91113
InsecureSkipVerify: options.InsecureSkipVerify,
92114
})
115+
if err != nil {
116+
return xerrors.Errorf("tls dial tcp4: %w", err)
117+
}
118+
119+
default:
120+
return xerrors.Errorf("unknown url proto: %q", url.Proto)
93121
}
94-
}
95122

96-
if err != nil {
97-
return err
123+
default:
124+
return xerrors.Errorf("unknown url scheme: %q", url.Scheme)
98125
}
126+
99127
if tcpConn != nil {
100128
udpConn = turn.NewSTUNConn(tcpConn)
101129
}
@@ -116,45 +144,61 @@ func dialICEURL(server webrtc.ICEServer, rawURL string, options *DialICEOptions)
116144
RTO: options.Timeout,
117145
})
118146
if err != nil {
119-
return err
147+
return xerrors.Errorf("create turn client: %w", err)
120148
}
121149
defer client.Close()
150+
122151
err = client.Listen()
123152
if err != nil {
124-
return err
153+
return xerrors.Errorf("listen turn client: %w", err)
125154
}
126-
// STUN servers are not authenticated with credentials.
127-
// As long as the transport is valid, this should always work.
155+
156+
// STUN servers are not authenticated with credentials. As long as the
157+
// transport is valid, this should always work.
128158
_, err = client.SendBindingRequest()
129159
if err != nil {
130-
// Transport failed to connect.
131-
// https://github.com/pion/turn/blob/8231b69046f562420299916e9fb69cbff4754231/errors.go#L20
132-
if strings.Contains(err.Error(), "retransmissions failed") {
133-
return ErrMismatchedProtocol
160+
// Transport failed to connect. Convert error into a detectable
161+
// one.
162+
if errIsTurnAllRetransmissionsFailed(err) {
163+
err = ErrMismatchedProtocol
134164
}
135-
return fmt.Errorf("binding: %w", err)
165+
166+
return xerrors.Errorf("send binding request: %w", err)
136167
}
168+
137169
if url.Scheme == ice.SchemeTypeTURN || url.Scheme == ice.SchemeTypeTURNS {
138170
// We TURN to validate server credentials are correct.
139171
pc, err := client.Allocate()
140172
if err != nil {
141173
if strings.Contains(err.Error(), "error 400") {
142-
return ErrInvalidCredentials
174+
err = ErrInvalidCredentials
143175
}
176+
144177
// Since TURN and STUN follow the same protocol, they can
145178
// both handshake, but once a tunnel is allocated it will
146179
// fail to transmit.
147-
if strings.Contains(err.Error(), "retransmissions failed") {
148-
return ErrMismatchedProtocol
180+
if errIsTurnAllRetransmissionsFailed(err) {
181+
err = ErrMismatchedProtocol
149182
}
150-
return err
183+
184+
return xerrors.Errorf("turn allocate: %w", err)
151185
}
152186
defer pc.Close()
153187
}
188+
154189
return nil
155190
}
156191

157-
// Generalizes creating a new peer connection with consistent options.
192+
// errIsTurnAllRetransmissionsFailed detects the `errAllRetransmissionsFailed`
193+
// error from pion/turn.
194+
//
195+
// See: https://github.com/pion/turn/blob/8231b69046f562420299916e9fb69cbff4754231/errors.go#L20
196+
func errIsTurnAllRetransmissionsFailed(err error) bool {
197+
return strings.Contains(err.Error(), "retransmissions failed")
198+
}
199+
200+
// newPeerConnection generalizes creating a new peer connection with consistent
201+
// options.
158202
func newPeerConnection(servers []webrtc.ICEServer, dialer proxy.Dialer) (*webrtc.PeerConnection, error) {
159203
se := webrtc.SettingEngine{}
160204
se.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeUDP4})
@@ -200,7 +244,7 @@ func newPeerConnection(servers []webrtc.ICEServer, dialer proxy.Dialer) (*webrtc
200244
})
201245
}
202246

203-
// Proxies ICE candidates using the protocol to a writer.
247+
// proxyICECandidates proxies ICE candidates using the protocol to a writer.
204248
func proxyICECandidates(conn *webrtc.PeerConnection, w io.Writer) func() {
205249
var (
206250
mut sync.Mutex
@@ -220,65 +264,86 @@ func proxyICECandidates(conn *webrtc.PeerConnection, w io.Writer) func() {
220264
}
221265
mut.Lock()
222266
defer mut.Unlock()
267+
223268
if !flushed {
224269
queue = append(queue, i)
225270
return
226271
}
227272

228273
write(i)
229274
})
275+
230276
return func() {
231277
mut.Lock()
232278
defer mut.Unlock()
279+
233280
for _, i := range queue {
234281
write(i)
235282
}
283+
236284
flushed = true
237285
}
238286
}
239287

240-
// Waits for a PeerConnection to hit the open state.
288+
// waitForConnectionOpen waits for a PeerConnection to hit the open state.
241289
func waitForConnectionOpen(ctx context.Context, conn *webrtc.PeerConnection) error {
242290
if conn.ConnectionState() == webrtc.PeerConnectionStateConnected {
243291
return nil
244292
}
245-
var cancel context.CancelFunc
246-
if _, deadlineSet := ctx.Deadline(); deadlineSet {
247-
ctx, cancel = context.WithCancel(ctx)
248-
} else {
249-
ctx, cancel = context.WithTimeout(ctx, time.Second*15)
250-
}
293+
294+
connected := make(chan struct{})
295+
ctx, cancel := ctxDeadlineIfNotSet(ctx, 15*time.Second)
251296
defer cancel()
297+
252298
conn.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) {
253299
if pcs == webrtc.PeerConnectionStateConnected {
254-
cancel()
300+
close(connected)
255301
}
256302
})
257-
<-ctx.Done()
258-
if ctx.Err() == context.DeadlineExceeded {
259-
return context.DeadlineExceeded
303+
304+
select {
305+
case <-ctx.Done():
306+
return ctx.Err()
307+
case <-connected:
308+
return nil
260309
}
261-
return nil
262310
}
263311

264-
// Waits for a DataChannel to hit the open state.
312+
// waitForDataChannelOpen waits for a DataChannel to hit the open state.
265313
func waitForDataChannelOpen(ctx context.Context, channel *webrtc.DataChannel) error {
266-
if channel.ReadyState() == webrtc.DataChannelStateOpen {
314+
switch channel.ReadyState() {
315+
case webrtc.DataChannelStateOpen:
267316
return nil
317+
318+
case webrtc.DataChannelStateClosed,
319+
webrtc.DataChannelStateClosing:
320+
return xerrors.New("channel closed")
268321
}
269-
if channel.ReadyState() != webrtc.DataChannelStateConnecting {
270-
return fmt.Errorf("channel closed")
271-
}
272-
ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15)
273-
defer cancelFunc()
322+
323+
connected := make(chan struct{})
324+
ctx, cancel := ctxDeadlineIfNotSet(ctx, 15*time.Second)
325+
defer cancel()
326+
274327
channel.OnOpen(func() {
275-
cancelFunc()
328+
close(connected)
276329
})
277-
<-ctx.Done()
278-
if ctx.Err() == context.DeadlineExceeded {
330+
331+
select {
332+
case <-ctx.Done():
279333
return ctx.Err()
334+
case <-connected:
335+
return nil
336+
}
337+
}
338+
339+
// ctxDeadlineIfNotSet sets a deadline from the parent context, if and only if
340+
// a deadline does not already exist for the parent context.
341+
func ctxDeadlineIfNotSet(ctx context.Context, deadline time.Duration) (_ctx context.Context, cancel func()) {
342+
if _, ok := ctx.Deadline(); ok {
343+
return context.WithCancel(ctx)
344+
} else {
345+
return context.WithTimeout(ctx, deadline)
280346
}
281-
return nil
282347
}
283348

284349
func stringPtr(s string) *string {

wsnet/rtc_test.go

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
package wsnet
22

33
import (
4-
"errors"
54
"fmt"
65
"testing"
76
"time"
87

98
"github.com/pion/ice/v2"
109
"github.com/pion/webrtc/v3"
10+
"github.com/stretchr/testify/assert"
1111
)
1212

1313
func TestDialICE(t *testing.T) {
@@ -26,9 +26,7 @@ func TestDialICE(t *testing.T) {
2626
Timeout: time.Millisecond,
2727
InsecureSkipVerify: true,
2828
})
29-
if err != nil {
30-
t.Error(err)
31-
}
29+
assert.NoError(t, err)
3230
})
3331

3432
t.Run("Protocol mismatch", func(t *testing.T) {
@@ -44,9 +42,7 @@ func TestDialICE(t *testing.T) {
4442
Timeout: time.Millisecond,
4543
InsecureSkipVerify: true,
4644
})
47-
if !errors.Is(err, ErrMismatchedProtocol) {
48-
t.Error(err)
49-
}
45+
assert.ErrorIs(t, err, ErrMismatchedProtocol)
5046
})
5147

5248
t.Run("Invalid auth", func(t *testing.T) {
@@ -62,9 +58,7 @@ func TestDialICE(t *testing.T) {
6258
Timeout: time.Millisecond,
6359
InsecureSkipVerify: true,
6460
})
65-
if !errors.Is(err, ErrInvalidCredentials) {
66-
t.Error(err)
67-
}
61+
assert.ErrorIs(t, err, ErrInvalidCredentials)
6862
})
6963

7064
t.Run("Protocol mismatch public", func(t *testing.T) {
@@ -76,8 +70,6 @@ func TestDialICE(t *testing.T) {
7670
Timeout: time.Millisecond,
7771
InsecureSkipVerify: true,
7872
})
79-
if !errors.Is(err, ErrMismatchedProtocol) {
80-
t.Error(err)
81-
}
73+
assert.ErrorIs(t, err, ErrMismatchedProtocol)
8274
})
8375
}

0 commit comments

Comments
 (0)