-
Notifications
You must be signed in to change notification settings - Fork 84
/
wsconn.go
651 lines (567 loc) · 19.5 KB
/
wsconn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
// This code is available on the terms of the project LICENSE.md file,
// also available online at https://blueoakcouncil.org/license/1.0.0.
package comms
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
"decred.org/dcrdex/dex"
"decred.org/dcrdex/dex/msgjson"
"github.com/gorilla/websocket"
)
const (
	// readBuffSize is the buffer size for a websocket connection's read
	// channel.
	readBuffSize = 128

	// writeWait is the maximum time allowed for a write to a connection.
	writeWait = time.Second * 3

	// reconnectInterval is the initial and increment between reconnect tries.
	reconnectInterval = 5 * time.Second

	// maxReconnectInterval is the maximum allowed reconnect interval.
	maxReconnectInterval = time.Minute

	// DefaultResponseTimeout is the default timeout for responses after a
	// request is successfully sent.
	DefaultResponseTimeout = 30 * time.Second
)
// ConnectionStatus represents the current status of the websocket connection.
type ConnectionStatus uint32

// The connection states that a websocket client can report.
const (
	Disconnected ConnectionStatus = iota
	Connected
	InvalidCert
)

// String gives a human readable string for each connection status. Any value
// outside of the declared set is reported as "unknown status".
func (cs ConnectionStatus) String() string {
	// Indexed by status value; the declared statuses are contiguous from zero.
	names := [...]string{
		Disconnected: "disconnected",
		Connected:    "connected",
		InvalidCert:  "invalid certificate",
	}
	if uint64(cs) < uint64(len(names)) {
		return names[cs]
	}
	return "unknown status"
}
// invalidCertRegexp is a regexp that helps check for non-typed x509 errors
// caused by or related to an invalid cert. MatchString is unanchored, so no
// leading wildcard is required.
var invalidCertRegexp = regexp.MustCompile(`unknown authority|not standards compliant|not trusted`)

// IsErrorInvalidCert checks if the provided error is one of the different
// variants of an invalid cert error returned from the x509 package or is
// ErrInvalidCert. A nil error is never an invalid cert error.
func IsErrorInvalidCert(err error) bool {
	// Guard against a nil error, which would otherwise panic on err.Error().
	if err == nil {
		return false
	}
	var invalidCert x509.CertificateInvalidError
	var unknownCertAuth x509.UnknownAuthorityError
	return errors.Is(err, ErrInvalidCert) || errors.As(err, &invalidCert) ||
		errors.As(err, &unknownCertAuth) || invalidCertRegexp.MatchString(err.Error())
}

// ErrInvalidCert is the error returned when attempting to use an invalid cert
// to set up a ws connection.
var ErrInvalidCert = errors.New("invalid certificate")

// ErrCertRequired is the error returned when a ws connection fails because no
// cert was provided.
var ErrCertRequired = errors.New("certificate required")
// WsConn is an interface for a websocket client.
type WsConn interface {
	// NextID returns the next request ID.
	NextID() uint64
	// IsDown indicates if the connection is known to be down.
	IsDown() bool
	// Send pushes an outgoing message over the websocket connection.
	Send(msg *msgjson.Message) error
	// Request sends a Request-type message and records respHandler to be run
	// when the response is received.
	Request(msg *msgjson.Message, respHandler func(*msgjson.Message)) error
	// RequestWithTimeout is like Request, but the caller chooses how long to
	// wait for the response (expireTime) and supplies an expire function to
	// run if no response is received in that time.
	RequestWithTimeout(msg *msgjson.Message, respHandler func(*msgjson.Message), expireTime time.Duration, expire func()) error
	// Connect connects the client. The returned WaitGroup tracks the
	// goroutines started for the connection.
	Connect(ctx context.Context) (*sync.WaitGroup, error)
	// MessageSource returns the channel on which requests and notifications
	// from the server are delivered.
	MessageSource() <-chan *msgjson.Message
}
// responseHandler holds the callbacks registered for a request sent to the
// server. It is created when the request is logged (see logReq) and removed
// when the response arrives, the request expires, or the connection shuts
// down.
type responseHandler struct {
	// expiration is the timer that expires the request if no response has
	// been received in time.
	expiration *time.Timer
	// f is run with the response message.
	f     func(*msgjson.Message)
	abort func() // only to be run at most once, and not if f ran
}
// WsCfg is the configuration struct for initializing a WsConn.
type WsCfg struct {
	// URL is the websocket endpoint URL.
	URL string

	// PingWait is the maximum time to wait for a ping from the server. This
	// should be larger than the server's ping interval to allow for network
	// latency. It must not be negative.
	PingWait time.Duration

	// Cert is the server's certificate, PEM encoded. When provided, it is
	// added to the pool of root CAs used to verify the server.
	Cert []byte

	// ReconnectSync runs the needed reconnection synchronization after
	// a reconnect.
	ReconnectSync func()

	// ConnectEventFunc runs whenever connection status changes.
	//
	// NOTE: Disconnect event notifications may lag behind actual
	// disconnections.
	ConnectEventFunc func(ConnectionStatus)

	// Logger is the logger for the WsConn.
	Logger dex.Logger

	// NetDialContext specifies an optional dialer context to use. When nil,
	// the proxy settings from the environment are used instead.
	NetDialContext func(context.Context, string, string) (net.Conn, error)
}
// wsConn represents a client websocket connection.
type wsConn struct {
	// 64-bit atomic variables first. See
	// https://golang.org/pkg/sync/atomic/#pkg-note-BUG.
	rID    uint64             // last issued request ID, incremented atomically by NextID
	cancel context.CancelFunc // cancels the internal context created by Connect
	wg     sync.WaitGroup     // tracks the goroutines started for this connection
	log    dex.Logger
	cfg    *WsCfg
	tlsCfg *tls.Config // nil when no certificate was provided

	readCh chan *msgjson.Message // non-response messages, exposed via MessageSource

	wsMtx sync.Mutex // guards ws
	ws    *websocket.Conn

	connectionStatus uint32 // atomic

	reqMtx       sync.RWMutex // guards respHandlers
	respHandlers map[uint64]*responseHandler

	reconnectCh chan struct{} // trigger for immediate reconnect
}
// NewWsConn creates a client websocket connection from the provided
// configuration. No network activity happens here; use Connect to dial. An
// error is returned for a negative PingWait, and when a certificate is
// provided, for an unparsable URL or a certificate that cannot be loaded
// (ErrInvalidCert).
func NewWsConn(cfg *WsCfg) (WsConn, error) {
	if cfg.PingWait < 0 {
		return nil, fmt.Errorf("ping wait cannot be negative")
	}

	conn := &wsConn{
		cfg:          cfg,
		log:          cfg.Logger,
		readCh:       make(chan *msgjson.Message, readBuffSize),
		respHandlers: make(map[uint64]*responseHandler),
		reconnectCh:  make(chan struct{}, 1),
	}

	// Without a certificate there is no custom TLS configuration to build.
	if len(cfg.Cert) == 0 {
		return conn, nil
	}

	parsedURL, err := url.Parse(cfg.URL)
	if err != nil {
		return nil, fmt.Errorf("error parsing URL: %w", err)
	}

	// Start from the system pool when available so that the provided
	// certificate is trusted in addition to the system roots.
	pool, _ := x509.SystemCertPool()
	if pool == nil {
		pool = x509.NewCertPool()
	}
	if !pool.AppendCertsFromPEM(cfg.Cert) {
		return nil, ErrInvalidCert
	}

	conn.tlsCfg = &tls.Config{
		RootCAs:    pool,
		MinVersion: tls.VersionTLS12,
		ServerName: parsedURL.Hostname(),
	}
	return conn, nil
}
// IsDown reports whether the connection is known to be down, i.e. its
// current status is anything other than Connected.
func (conn *wsConn) IsDown() bool {
	current := ConnectionStatus(atomic.LoadUint32(&conn.connectionStatus))
	return current != Connected
}
// setConnectionStatus atomically records the new connection status. If the
// status actually changed and a ConnectEventFunc is configured, it is called
// with the new status.
func (conn *wsConn) setConnectionStatus(status ConnectionStatus) {
	previous := atomic.SwapUint32(&conn.connectionStatus, uint32(status))
	if previous == uint32(status) {
		return // no change, nothing to report
	}
	if notify := conn.cfg.ConnectEventFunc; notify != nil {
		notify(status)
	}
}
// connect attempts to establish a websocket connection. On success, the new
// websocket.Conn replaces any existing one (which is closed), the status is
// set to Connected, and a read goroutine tracked by conn.wg is started.
//
// A dial failure caused by a certificate or hostname verification problem
// sets the status to InvalidCert and returns ErrCertRequired if no TLS
// configuration was set, or ErrInvalidCert otherwise. Any other dial failure
// sets the status to Disconnected and returns the dial error.
func (conn *wsConn) connect(ctx context.Context) error {
	dialer := &websocket.Dialer{
		HandshakeTimeout: 10 * time.Second,
		TLSClientConfig:  conn.tlsCfg,
	}
	if conn.cfg.NetDialContext != nil {
		dialer.NetDialContext = conn.cfg.NetDialContext
	} else {
		dialer.Proxy = http.ProxyFromEnvironment
	}

	ws, _, err := dialer.DialContext(ctx, conn.cfg.URL, nil)
	if err != nil {
		var e x509.HostnameError // No need to retry...
		if IsErrorInvalidCert(err) || errors.As(err, &e) {
			conn.setConnectionStatus(InvalidCert)
			if conn.tlsCfg == nil {
				return ErrCertRequired
			}
			return ErrInvalidCert
		}
		conn.setConnectionStatus(Disconnected)
		return err
	}

	// Set the initial read deadline for the first ping. Subsequent read
	// deadlines are set in the ping handler.
	err = ws.SetReadDeadline(time.Now().Add(conn.cfg.PingWait))
	if err != nil {
		conn.log.Errorf("set read deadline failed: %v", err)
		// The new connection will not be used, so close it rather than leak
		// it. Any close error is of no interest on top of the one returned.
		_ = ws.Close()
		return err
	}

	ws.SetPingHandler(func(string) error {
		now := time.Now()

		// Set the deadline for the next ping.
		err := ws.SetReadDeadline(now.Add(conn.cfg.PingWait))
		if err != nil {
			conn.log.Errorf("set read deadline failed: %v", err)
			return err
		}

		// Respond with a pong.
		err = ws.WriteControl(websocket.PongMessage, []byte{}, now.Add(writeWait))
		if err != nil {
			// read loop handles reconnect
			conn.log.Errorf("pong write error: %v", err)
			return err
		}

		return nil
	})

	conn.wsMtx.Lock()
	// If keepAlive called connect, the wsConn's current websocket.Conn may need
	// to be closed depending on the error that triggered the reconnect.
	if conn.ws != nil {
		conn.close()
	}
	conn.ws = ws
	conn.wsMtx.Unlock()

	conn.setConnectionStatus(Connected)

	conn.wg.Add(1)
	go func() {
		defer conn.wg.Done()
		conn.read(ctx)
	}()

	return nil
}
// close shuts down the current websocket.Conn. conn.ws must be non-nil. All
// callers in this file hold wsMtx while calling it.
func (conn *wsConn) close() {
	ws := conn.ws

	// Politely say goodbye first in case the peer is still listening. This is
	// best effort with a short deadline, so the result is ignored.
	goodbye := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "bye")
	deadline := time.Now().Add(50 * time.Millisecond)
	_ = ws.WriteControl(websocket.CloseMessage, goodbye, deadline)

	// Then tear down the underlying connection regardless.
	ws.Close()
}
// read fetches and parses incoming messages for processing. This should be
// run as a goroutine. Increment the wg before calling read.
//
// Responses are dispatched to their registered handler in a new goroutine
// tracked by conn.wg. All other messages are sent to readCh. On any read
// error other than a JSON type error, the connection is flagged as
// Disconnected, a reconnect is requested, and read returns. read also returns
// when ctx is canceled.
func (conn *wsConn) read(ctx context.Context) {
	// reconnect flags the connection as down and requests a reconnect. The
	// request is abandoned if the context is canceled so that this goroutine
	// cannot get stuck on a full reconnectCh during shutdown.
	reconnect := func() {
		conn.setConnectionStatus(Disconnected)
		select {
		case conn.reconnectCh <- struct{}{}:
		case <-ctx.Done():
		}
	}

	for {
		msg := new(msgjson.Message)

		// Lock since conn.ws may be set by connect.
		conn.wsMtx.Lock()
		ws := conn.ws
		conn.wsMtx.Unlock()

		// The read itself does not require locking since only this goroutine
		// uses read functions that are not safe for concurrent use.
		err := ws.ReadJSON(msg)
		// Drop the read error on context cancellation.
		if ctx.Err() != nil {
			return
		}
		if err != nil {
			// Read timeout should flag the connection as down asap.
			var netErr net.Error
			if errors.As(err, &netErr) && netErr.Timeout() {
				conn.log.Errorf("Read timeout on connection to %s.", conn.cfg.URL)
				reconnect()
				return
			}

			var mErr *json.UnmarshalTypeError
			if errors.As(err, &mErr) {
				// JSON decode errors are not fatal, log and proceed.
				conn.log.Errorf("json decode error: %v", mErr)
				continue
			}

			// TODO: Now that wsConn goroutines have contexts that are canceled
			// on shutdown, we do not have to infer the source and severity of
			// the error; just reconnect in ALL other cases, and remove the
			// following legacy checks.

			// Expected close errors (1000 and 1001) ... but if the server
			// closes we still want to reconnect. (???)
			if websocket.IsCloseError(err, websocket.CloseGoingAway,
				websocket.CloseNormalClosure) ||
				strings.Contains(err.Error(), "websocket: close sent") {
				reconnect()
				return
			}

			var opErr *net.OpError
			if errors.As(err, &opErr) && opErr.Op == "read" && opErr.Err != nil {
				if strings.Contains(opErr.Err.Error(),
					"use of closed network connection") {
					conn.log.Errorf("read quitting: %v", err)
					reconnect()
					return
				}
			}

			// Log all other errors and trigger a reconnection.
			conn.log.Errorf("read error (%v), attempting reconnection", err)
			reconnect()
			// Successful reconnect via connect() will start read() again.
			return
		}

		// If the message is a response, find the handler.
		if msg.Type == msgjson.Response {
			handler := conn.respHandler(msg.ID)
			if handler == nil {
				conn.log.Errorf("unhandled response with error msg: %v", handleUnknownResponse(msg))
				continue
			}
			// Run handlers in a goroutine so that other messages can be
			// received. Include the handler goroutines in the WaitGroup to
			// allow them to complete if the connection master desires.
			conn.wg.Add(1)
			go func() {
				defer conn.wg.Done()
				handler.f(msg)
			}()
			continue
		}

		// Deliver the message, but do not wait on a full readCh once the
		// context has been canceled.
		select {
		case conn.readCh <- msg:
		case <-ctx.Done():
			return
		}
	}
}
// keepAlive maintains an active websocket connection by reconnecting whenever
// a reconnect is requested via reconnectCh. It should be run as a goroutine,
// and it returns when ctx is canceled.
//
// After a failed attempt, another attempt is scheduled after the current
// delay, which starts at reconnectInterval and grows by reconnectInterval per
// failure up to maxReconnectInterval. A successful reconnect resets the delay
// and runs the configured ReconnectSync, if any.
func (conn *wsConn) keepAlive(ctx context.Context) {
	delay := reconnectInterval

	// scheduleRetry requests another reconnect attempt after the given wait.
	scheduleRetry := func(wait time.Duration) {
		time.AfterFunc(wait, func() {
			conn.reconnectCh <- struct{}{}
		})
	}

	for {
		select {
		case <-ctx.Done():
			return
		case <-conn.reconnectCh:
		}

		// Prioritize context cancellation even if there are reconnect
		// requests.
		if ctx.Err() != nil {
			return
		}

		conn.log.Infof("Attempting to reconnect to %s...", conn.cfg.URL)
		if err := conn.connect(ctx); err != nil {
			conn.log.Errorf("Reconnect failed. Scheduling reconnect to %s in %.1f seconds.",
				conn.cfg.URL, delay.Seconds())
			scheduleRetry(delay)
			// Back off further, up to maxReconnectInterval.
			if delay < maxReconnectInterval {
				delay += reconnectInterval
			}
			continue
		}

		conn.log.Info("Successfully reconnected.")
		delay = reconnectInterval

		// Synchronize after a reconnection.
		if resync := conn.cfg.ReconnectSync; resync != nil {
			resync()
		}
	}
}
// NextID atomically increments the request ID counter and returns the new
// value, so the first ID issued is 1.
func (conn *wsConn) NextID() uint64 {
	const step = 1
	return atomic.AddUint64(&conn.rID, step)
}
// Connect connects the client. Any error encountered during the initial
// connection will be returned. An auto-(re)connect goroutine will be started,
// even on error, unless the error is ErrInvalidCert or ErrCertRequired, in
// which case nothing is left running, readCh is closed, and a nil WaitGroup
// is returned. To terminate the goroutines, use Stop() or cancel the context.
//
// The returned WaitGroup is done once the goroutines started for the
// connection have returned.
func (conn *wsConn) Connect(ctx context.Context) (*sync.WaitGroup, error) {
	// The internal context is a child of ctx so that either Stop or the
	// caller canceling ctx shuts everything down.
	var ctxInternal context.Context
	ctxInternal, conn.cancel = context.WithCancel(ctx)

	err := conn.connect(ctxInternal)
	if err != nil {
		// If the certificate is invalid or missing, do not start the reconnect
		// loop, and return an error with no WaitGroup.
		if errors.Is(err, ErrInvalidCert) || errors.Is(err, ErrCertRequired) {
			conn.cancel()
			conn.wg.Wait() // probably a no-op
			close(conn.readCh)
			return nil, err
		}

		// The read loop would normally trigger keepAlive, but it wasn't started
		// on account of a connect error.
		conn.log.Errorf("Initial connection failed, starting reconnect loop: %v", err)
		time.AfterFunc(5*time.Second, func() {
			conn.reconnectCh <- struct{}{}
		})
	}

	// Two goroutines: the reconnect loop and the shutdown handler below.
	conn.wg.Add(2)
	go func() {
		defer conn.wg.Done()
		conn.keepAlive(ctxInternal)
	}()
	go func() {
		defer conn.wg.Done()
		// Wait for shutdown, then tear down the connection.
		<-ctxInternal.Done()
		conn.setConnectionStatus(Disconnected)
		conn.wsMtx.Lock()
		if conn.ws != nil {
			conn.log.Debug("Sending close 1000 (normal) message.")
			conn.close()
		}
		conn.wsMtx.Unlock()

		// Run the expire funcs so request callers don't hang.
		conn.reqMtx.Lock()
		defer conn.reqMtx.Unlock()
		for id, h := range conn.respHandlers {
			delete(conn.respHandlers, id)
			// Since we are holding reqMtx and deleting the handler, no need to
			// check if expiration fired (see logReq), but good to stop it.
			h.expiration.Stop()
			h.abort()
		}

		// NOTE(review): the read goroutine also sends on readCh; confirm it
		// cannot still be sending when the channel is closed here.
		close(conn.readCh) // signal to MessageSource receivers that the wsConn is dead
	}()

	return &conn.wg, err
}
// Stop can be used to close the connection and all of the goroutines started by
// Connect. Alternatively, the context passed to Connect may be canceled. Stop
// does nothing if Connect has not been called.
func (conn *wsConn) Stop() {
	// cancel is only assigned by Connect; calling a nil func would panic.
	if conn.cancel == nil {
		return
	}
	conn.cancel()
}
// SendRaw writes the provided bytes to the websocket connection as a text
// message, without any validation of the payload. Unlike Send, the connection
// status is not checked first. An error is returned if no connection has been
// established yet, if setting the write deadline fails, or if the write fails.
func (conn *wsConn) SendRaw(b []byte) error {
	conn.wsMtx.Lock()
	defer conn.wsMtx.Unlock()
	// conn.ws is only set once a connection has been established.
	if conn.ws == nil {
		return fmt.Errorf("cannot send on a broken connection")
	}
	// Set a fresh write deadline as Send does. Without this, the write would
	// be subject to whatever deadline a previous Send left on the connection,
	// which has likely already passed.
	if err := conn.ws.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
		return err
	}
	return conn.ws.WriteMessage(websocket.TextMessage, b)
}
// Send writes the message to the websocket connection as JSON. The write is
// synchronous, so a nil error means the message was successfully sent. A
// non-nil error is returned when the connection is known to be down, when the
// message cannot be marshalled to JSON, or when writing to the websocket
// link fails.
func (conn *wsConn) Send(msg *msgjson.Message) error {
	if conn.IsDown() {
		return fmt.Errorf("cannot send on a broken connection")
	}

	// Encode up front rather than using gorilla/websocket.WriteJSON, which
	// may send junk to the peer when the message fails to marshal completely.
	payload, err := json.Marshal(msg)
	if err != nil {
		conn.log.Errorf("Failed to marshal message: %v", err)
		return err
	}

	conn.wsMtx.Lock()
	defer conn.wsMtx.Unlock()

	if err := conn.ws.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
		conn.log.Errorf("Send: failed to set write deadline: %v", err)
		return err
	}
	if err := conn.ws.WriteMessage(websocket.TextMessage, payload); err != nil {
		conn.log.Errorf("Send: WriteMessage error: %v", err)
		return err
	}
	return nil
}
// Request sends the Request-type msgjson.Message to the server without
// waiting for a response, recording f to be run when the response is
// received. The response must arrive within DefaultResponseTimeout of the
// request; after that the handler expires and a late response is ignored. To
// handle expiration or to choose the timeout, use RequestWithTimeout. Sending
// is synchronous, so a nil error guarantees that the request message was
// successfully sent.
func (conn *wsConn) Request(msg *msgjson.Message, f func(*msgjson.Message)) error {
	// Expiration is silent for plain requests.
	ignoreExpiry := func() {}
	return conn.RequestWithTimeout(msg, f, DefaultResponseTimeout, ignoreExpiry)
}
// RequestWithTimeout sends the Request-type message without waiting for a
// response, recording f to be run when the response is received. If the
// server responds within expireTime of the request, f is called, otherwise
// the expire function is called. When f is called, the response Message.ID is
// guaranteed to equal the request Message.ID. Sending is synchronous, so a
// nil error guarantees that the request message was successfully sent and
// that either f or expire will be run; a non-nil error guarantees that
// neither function will run.
//
// For example, to wait on a response or timeout:
//
//	errChan := make(chan error, 1)
//	err := conn.RequestWithTimeout(reqMsg, func(msg *msgjson.Message) {
//	    errChan <- msg.UnmarshalResult(responseStructPointer)
//	}, timeout, func() {
//	    errChan <- fmt.Errorf("timed out waiting for '%s' response.", route)
//	})
//	if err != nil {
//	    return err // request error
//	}
//	return <-errChan // timeout or response error
func (conn *wsConn) RequestWithTimeout(msg *msgjson.Message, f func(*msgjson.Message), expireTime time.Duration, expire func()) error {
	if msg.Type != msgjson.Request {
		return fmt.Errorf("Message is not a request: %v", msg.Type)
	}

	// The handlers must be in place before the request goes out so that a
	// quick response cannot be missed.
	conn.logReq(msg.ID, f, expireTime, expire)

	sendErr := conn.Send(msg)
	if sendErr == nil {
		return nil
	}

	// The request never went out, so neither f nor expire should run. Pulling
	// the handler back out of the map also stops its expiration timer. The
	// caller deals with the returned error.
	conn.log.Errorf("(*wsConn).Request(route '%s') Send error (%v), unregistering msg ID %d handler",
		msg.Route, sendErr, msg.ID)
	_ = conn.respHandler(msg.ID)
	return sendErr
}
// expire removes the response handler for the given request ID from the
// respHandlers map. It returns true if a handler was registered for the ID,
// and false if there was nothing to remove.
func (conn *wsConn) expire(id uint64) bool {
	conn.reqMtx.Lock()
	defer conn.reqMtx.Unlock()

	if _, registered := conn.respHandlers[id]; !registered {
		return false
	}
	delete(conn.respHandlers, id)
	return true
}
// logReq registers the response handler for the request with the given ID in
// the respHandlers map, and starts a timer that runs the expire function
// after expireTime unless the handler has been removed from the map by then.
func (conn *wsConn) logReq(id uint64, respHandler func(*msgjson.Message), expireTime time.Duration, expire func()) {
	// onTimeout only runs the caller's expire function if the handler was
	// still registered, i.e. (*wsConn).respHandler has not already retrieved
	// it for execution.
	onTimeout := func() {
		if conn.expire(id) {
			expire()
		}
	}

	conn.reqMtx.Lock()
	defer conn.reqMtx.Unlock()

	// The timer is started while holding reqMtx, so onTimeout cannot look for
	// the handler before it has been stored.
	handler := &responseHandler{
		f:     respHandler,
		abort: expire,
	}
	handler.expiration = time.AfterFunc(expireTime, onTimeout)
	conn.respHandlers[id] = handler
}
// respHandler removes and returns the response handler registered for the
// provided request ID, stopping its expiration timer. It returns nil if no
// handler is registered for the ID.
func (conn *wsConn) respHandler(id uint64) *responseHandler {
	conn.reqMtx.Lock()
	defer conn.reqMtx.Unlock()

	handler, found := conn.respHandlers[id]
	if !found {
		return nil
	}
	handler.expiration.Stop()
	delete(conn.respHandlers, id)
	return handler
}
// MessageSource returns the connection's read source, a receive-only view of
// the channel that carries requests and notifications from the server.
// Responses are not delivered here, since they go to the handler registered
// with their request. Every call returns the same channel, so there must
// only be one receiver. The channel is closed when the connection is shut
// down.
func (conn *wsConn) MessageSource() <-chan *msgjson.Message {
	var source <-chan *msgjson.Message = conn.readCh
	return source
}
// handleUnknownResponse extracts the error message sent for a response without
// a handler. If the response payload cannot be decoded, the decoding error is
// returned instead. A nil error is returned when the response carries no
// error.
func handleUnknownResponse(msg *msgjson.Message) error {
	resp, err := msg.Response()
	if err != nil {
		return err
	}
	// Return a literal nil when there is no error. Converting a nil pointer
	// to the error interface would produce a non-nil error value.
	// NOTE(review): presumably resp.Error is a pointer type; confirm against
	// the msgjson package.
	if resp.Error == nil {
		return nil
	}
	return resp.Error
}