-
Notifications
You must be signed in to change notification settings - Fork 1
/
websocket.go
118 lines (105 loc) · 2.72 KB
/
websocket.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package transport
import (
"context"
"errors"
"fmt"
"net/http"
"time"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
)
// Websocket is a Transport implementation that uses the websocket
// protocol.
//
// The request/response plumbing (context, channels, error channel and
// timeout) lives in the embedded *stream. This type adds the underlying
// connection and the goroutines that move JSON messages between that
// connection and the stream's channels. Presumably the Transport methods
// are promoted from *stream — confirm against stream's definition.
type Websocket struct {
	*stream
	// conn is the connection opened by websocket.Dial in NewWebsocket.
	// It is read by readerRoutine, written by writerRoutine and closed by
	// close.
	conn *websocket.Conn
}
// WebsocketOptions contains options for the websocket transport.
//
// URL and Context are required by NewWebsocket; every other field is
// optional.
type WebsocketOptions struct {
	// Context used to close the connection. It is required (NewWebsocket
	// rejects a nil Context). It also bounds the websocket handshake and
	// every message write, and the writer goroutine stops once it is done.
	Context context.Context
	// URL of the websocket endpoint. Must not be empty.
	URL string
	// HTTPClient is the HTTP client to use. If nil, http.DefaultClient is
	// used.
	HTTPClient *http.Client
	// HTTPHeader specifies the HTTP headers to be included in the
	// websocket handshake request.
	HTTPHeader http.Header
	// Timout is the timeout for the websocket requests. NewWebsocket
	// replaces a zero value with 60s; any other value (including a
	// negative one) is passed through unchanged to the embedded stream,
	// which decides how it is applied.
	//
	// NOTE: the field name is misspelled ("Timout"). It is kept as-is
	// because renaming an exported field would break existing callers.
	Timout time.Duration
	// ErrorCh is an optional channel used to report errors.
	//
	// Read, write and close errors are delivered with a blocking send.
	// A non-nil channel must therefore be drained by the caller (or be
	// buffered); otherwise the goroutine reporting the error stalls.
	ErrorCh chan error
}
// NewWebsocket dials opts.URL and returns a Websocket transport bound to
// opts.Context.
//
// opts.URL and opts.Context must be set, otherwise an error is returned
// before any network activity. A zero opts.Timout is replaced with 60
// seconds. A dial failure is returned wrapped. On success the reader and
// writer goroutines are already running when the transport is returned.
func NewWebsocket(opts WebsocketOptions) (*Websocket, error) {
	switch {
	case opts.URL == "":
		return nil, errors.New("URL cannot be empty")
	case opts.Context == nil:
		return nil, errors.New("context cannot be nil")
	}

	timeout := opts.Timout
	if timeout == 0 {
		timeout = 60 * time.Second
	}

	dialOpts := &websocket.DialOptions{
		HTTPClient: opts.HTTPClient,
		HTTPHeader: opts.HTTPHeader,
	}
	// nhooyr.io/websocket documents that callers never need to close the
	// handshake response body, hence the bodyclose suppression.
	conn, _, err := websocket.Dial(opts.Context, opts.URL, dialOpts) //nolint:bodyclose
	if err != nil {
		return nil, fmt.Errorf("failed to dial websocket: %w", err)
	}

	ws := &Websocket{
		conn: conn,
		stream: &stream{
			ctx:     opts.Context,
			errCh:   opts.ErrorCh,
			timeout: timeout,
		},
	}
	// The stream invokes onClose to tear down the transport-specific
	// connection.
	ws.onClose = ws.close
	ws.stream.initStream()

	go ws.readerRoutine()
	go ws.writerRoutine()

	return ws, nil
}
// readerRoutine decodes JSON messages from the connection and forwards
// them to the stream's readerCh until the transport shuts down.
//
// It returns when ws.ctx is done, when the connection is ended by a close
// frame (websocket.CloseError), or after the first other read error.
// nhooyr.io/websocket closes the connection whenever a read fails; this
// includes a JSON decoding failure inside wsjson.Read. Every later read
// would therefore fail immediately, and retrying would only spin the CPU
// or flood errCh. Such an error is reported on errCh (when set) before
// returning.
func (ws *Websocket) readerRoutine() {
	// The background context is used here because closing context will
	// cause the nhooyr.io/websocket package to close a connection with
	// a close code of 1008 (policy violation) which is not what we want.
	ctx := context.Background()
	for {
		res := rpcResponse{}
		if err := wsjson.Read(ctx, ws.conn, &res); err != nil {
			if ws.ctx.Err() != nil || errors.As(err, &websocket.CloseError{}) {
				return
			}
			if ws.errCh != nil {
				// Do not let an undrained errCh keep this goroutine alive
				// past shutdown.
				select {
				case ws.errCh <- fmt.Errorf("websocket reading error: %w", err):
				case <-ws.ctx.Done():
				}
			}
			// The connection is unusable after a read error; stop instead
			// of busy-looping on reads that fail instantly.
			return
		}
		// Likewise, stop waiting for a consumer once the transport is done.
		select {
		case ws.readerCh <- res:
		case <-ws.ctx.Done():
			return
		}
	}
}
// writerRoutine sends every request received on the stream's writerCh to
// the connection as a JSON message, until ws.ctx is done.
//
// A failed write is reported on errCh (when set), and the routine then
// moves on to the next request. The report gives up once ws.ctx is done,
// so an undrained errCh cannot keep this goroutine alive after the
// transport has shut down.
func (ws *Websocket) writerRoutine() {
	for {
		select {
		case <-ws.ctx.Done():
			return
		case req := <-ws.writerCh:
			err := wsjson.Write(ws.ctx, ws.conn, req)
			if err == nil || ws.errCh == nil {
				continue
			}
			select {
			case ws.errCh <- fmt.Errorf("websocket writing error: %w", err):
			case <-ws.ctx.Done():
				return
			}
		}
	}
}
// close sends a normal-closure close frame and shuts the connection down.
// It is installed as the stream's onClose hook by NewWebsocket. A failure
// is forwarded to errCh when one was configured; otherwise it is dropped.
func (ws *Websocket) close() {
	closeErr := ws.conn.Close(websocket.StatusNormalClosure, "")
	if closeErr == nil || ws.errCh == nil {
		return
	}
	ws.errCh <- fmt.Errorf("websocket closing error: %w", closeErr)
}