/
transport.go
241 lines (209 loc) · 5.36 KB
/
transport.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
package websocket
import (
"encoding/json"
"fmt"
"log"
"math"
"sync"
"time"
"github.com/beowulf-foundation/beowulf-go/types"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
)
// ErrShutdown is returned by Call after Close has been called or after the
// read loop has stopped, and by any Close after the first.
var ErrShutdown = errors.New("connection is shut down")

// Keepalive timing for the websocket connection.
var (
	// writeWait is the write deadline given to each ping control frame.
	writeWait = 10 * time.Second
	// pongWait is how far each received pong extends the read deadline.
	pongWait = 60 * time.Second
	// pingPeriod is 90% of pongWait, so a ping is sent before the deadline lapses.
	pingPeriod = pongWait * 9 / 10
)
// Transport is a JSON-RPC 2.0 client over a single websocket connection.
// Requests are sent with Call; a background read loop (input) matches
// responses to pending calls by id and hands other messages to callbacks
// registered with SetCallback.
type Transport struct {
	conn *websocket.Conn // underlying websocket connection

	// reqMutex is held by Call for an entire request/response round trip,
	// which serializes calls on this Transport.
	reqMutex  sync.Mutex
	requestID uint64                  // most recently assigned request id
	pending   map[uint64]*callRequest // in-flight calls keyed by request id

	// callbackMutex is taken by SetCallback around callbackID/callbacks updates.
	callbackMutex sync.Mutex
	callbackID    uint64                                // id of the most recently registered callback
	callbacks     map[uint64]func(args json.RawMessage) // notice handlers keyed by id

	closing  bool // user has called Close
	shutdown bool // set by stop when the read loop fails
	// mutex: Call, onCallResponse and Close access closing, shutdown,
	// requestID and pending while holding it.
	mutex sync.Mutex
}
// callRequest tracks one in-flight Call until its response arrives or the
// transport stops.
type callRequest struct {
	Error error            // set before Done is signalled: the server's error or the transport failure
	Done  chan bool        // receives one value when the call completes (created with capacity 1 by Call)
	Reply *json.RawMessage // raw "result" payload of the response; may be nil
}
// NewTransport dials the websocket endpoint at url and returns a Transport
// bound to that connection. It starts two goroutines: ping, which sends
// periodic keepalive pings, and input, which reads and dispatches incoming
// messages. A dial failure is returned unchanged.
func NewTransport(url string) (*Transport, error) {
	conn, _, dialErr := websocket.DefaultDialer.Dial(url, nil)
	if dialErr != nil {
		return nil, dialErr
	}

	t := new(Transport)
	t.conn = conn
	t.pending = map[uint64]*callRequest{}
	t.callbacks = map[uint64]func(args json.RawMessage){}

	go ping(conn)
	go t.input()
	return t, nil
}
// Call sends a JSON-RPC 2.0 request for method with params args and blocks
// until the matching response arrives or the transport stops.
//
// Calls are serialized: reqMutex is held for the whole round trip, so at most
// one Call is in flight per Transport. On success the response's result is
// decoded into reply; pass a nil reply to discard the result. scid is
// currently unused.
//
// It returns ErrShutdown if Close was called or the read loop has stopped,
// the write error if the request could not be sent, the error carried by the
// response or set by stop, or the error from decoding the result into reply.
func (caller *Transport) Call(method string, args []interface{}, reply interface{}, scid string) error {
	caller.reqMutex.Lock()
	defer caller.reqMutex.Unlock()

	caller.mutex.Lock()
	if caller.closing || caller.shutdown {
		caller.mutex.Unlock()
		return ErrShutdown
	}
	// Advance the request id; after wrap-around the sequence restarts at 1.
	if caller.requestID == math.MaxUint64 {
		caller.requestID = 0
	}
	caller.requestID++
	seq := caller.requestID
	c := &callRequest{
		// Capacity 1 so whoever completes the call never blocks on the send.
		Done: make(chan bool, 1),
	}
	caller.pending[seq] = c
	caller.mutex.Unlock()

	// Use seq, the id registered in pending, rather than re-reading the
	// shared requestID field outside the lock that guards it.
	request := types.RPCRequest{
		Method: method,
		JSON:   "2.0",
		ID:     seq,
		Params: args,
	}
	if err := caller.WriteJSON(request); err != nil {
		caller.mutex.Lock()
		delete(caller.pending, seq)
		caller.mutex.Unlock()
		return err
	}

	// Wait for onCallResponse (or stop) to complete the call.
	<-c.Done
	if c.Error != nil {
		return c.Error
	}
	if c.Reply != nil && reply != nil {
		return json.Unmarshal(*c.Reply, reply)
	}
	return nil
}
// SetCallback registers notice as the handler for notices carrying the id of
// the subscription request. It then issues a "call" request with params
// [api, method, [id]] and returns that request's error, if any. The
// subscription result itself is discarded. The callback stays registered
// even if the request fails.
//
// The id is predicted as the one the following Call will assign, using the
// same wrap-around rule as Call. NOTE(review): the prediction only holds if no
// other goroutine issues a Call between this computation and the Call below.
func (caller *Transport) SetCallback(api string, method string, notice func(args json.RawMessage)) error {
	// requestID is written by Call under caller.mutex; read it under that lock.
	caller.mutex.Lock()
	id := caller.requestID + 1
	if caller.requestID == math.MaxUint64 {
		id = 1
	}
	caller.mutex.Unlock()

	caller.callbackMutex.Lock()
	caller.callbackID = id
	caller.callbacks[id] = notice
	caller.callbackMutex.Unlock()

	// Decode into a pointer: passing a non-pointer map made json.Unmarshal
	// fail on every non-null result. RawMessage accepts any JSON value.
	var ans json.RawMessage
	return caller.Call("call", []interface{}{api, method, []interface{}{id}}, &ans, "")
}
// input is the read loop started by NewTransport. It reads messages from the
// websocket and dispatches each one:
//   - a message whose id matches a pending Call completes that call;
//   - otherwise it is decoded as a notice and passed to the callback
//     registered under its id, or logged if there is none.
//
// Any read, decode or callback-dispatch error is passed to stop and ends the
// loop.
func (caller *Transport) input() {
	caller.conn.SetPongHandler(func(string) error {
		// Each pong extends the read deadline. Errors are deliberately
		// ignored (best-effort keepalive).
		_ = caller.conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})
	for {
		_, message, err := caller.conn.ReadMessage()
		if err != nil {
			caller.stop(err)
			return
		}

		var response types.RPCResponse
		if err := json.Unmarshal(message, &response); err != nil {
			caller.stop(err)
			return
		}

		// Call writes pending under caller.mutex; look it up under the same
		// lock to avoid a concurrent map read/write.
		caller.mutex.Lock()
		call, isPending := caller.pending[response.ID]
		caller.mutex.Unlock()
		if isPending {
			caller.onCallResponse(response, call)
			continue
		}

		// Not a reply to a pending call; probably a callback notice.
		var incoming types.RPCIncoming
		if err := json.Unmarshal(message, &incoming); err != nil {
			caller.stop(err)
			return
		}

		// SetCallback writes callbacks under callbackMutex.
		caller.callbackMutex.Lock()
		_, registered := caller.callbacks[incoming.ID]
		caller.callbackMutex.Unlock()
		if !registered {
			log.Printf("protocol error: unknown message received: %+v\n", incoming)
			log.Printf("Answer: %+v\n", string(message))
			continue
		}
		if err := caller.onNotice(incoming); err != nil {
			caller.stop(err)
			return
		}
	}
}
// stop marks the transport as shut down and fails every pending call with
// err, which wakes the goroutines blocked in Call. It is called by input when
// reading or dispatching a message fails.
func (caller *Transport) stop(err error) {
	// Take caller.mutex, the lock Call uses for shutdown and pending.
	// Taking reqMutex here would deadlock: Call holds reqMutex while it waits
	// on call.Done, which nothing else signals once the read loop has died.
	caller.mutex.Lock()
	defer caller.mutex.Unlock()

	caller.shutdown = true
	for id, call := range caller.pending {
		call.Error = err
		// Done has capacity 1 and each call is signalled at most once, so
		// this send does not block.
		call.Done <- true
		delete(caller.pending, id)
	}
}
// onCallResponse completes call using response. Under caller.mutex it:
//   - removes the call from pending;
//   - records the response's error, if any;
//   - stores the raw result;
//   - signals Done, waking the goroutine blocked in Call.
func (caller *Transport) onCallResponse(response types.RPCResponse, call *callRequest) {
	caller.mutex.Lock()
	defer caller.mutex.Unlock()

	delete(caller.pending, response.ID)
	if rpcErr := response.Error; rpcErr != nil {
		call.Error = rpcErr
	}
	call.Reply = response.Result
	call.Done <- true
}
// onNotice invokes the callback registered under incoming.ID with the
// notice's result payload. It returns an error if no callback is registered
// for that id.
//
// NOTE: the callback runs on the read-loop goroutine. It must not block on a
// Call, because the response that Call waits for could only be read by this
// same goroutine.
func (caller *Transport) onNotice(incoming types.RPCIncoming) error {
	// SetCallback writes callbacks under callbackMutex; read it under the
	// same lock. The lock is released before the callback runs.
	caller.callbackMutex.Lock()
	notice := caller.callbacks[incoming.ID]
	caller.callbackMutex.Unlock()

	if notice == nil {
		return fmt.Errorf("callback %d is not registered", incoming.ID)
	}
	notice(incoming.Result)
	return nil
}
// Close marks the transport as closing, so subsequent Calls return
// ErrShutdown, and closes the underlying websocket connection. Any Close
// after the first returns ErrShutdown without touching the connection.
func (caller *Transport) Close() error {
	caller.mutex.Lock()
	alreadyClosing := caller.closing
	caller.closing = true
	caller.mutex.Unlock()

	if alreadyClosing {
		return ErrShutdown
	}
	return caller.conn.Close()
}
// ping sends a websocket ping on ws every pingPeriod, giving each write a
// deadline of writeWait. Failures are logged and the loop continues; the
// function never returns.
//
// NOTE(review): nothing stops this goroutine once the connection is closed,
// so it presumably keeps logging write errors every pingPeriod. Consider
// giving it a stop signal.
func ping(ws *websocket.Conn) {
	ticker := time.NewTicker(pingPeriod)
	defer ticker.Stop()

	for range ticker.C {
		deadline := time.Now().Add(writeWait)
		if err := ws.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil {
			log.Println("ping:", err)
		}
	}
}
// WriteJSON encodes v as JSON, with HTML escaping disabled, and sends it as
// one text frame. The frame writer is always closed, even when encoding fails.
// An encoding error takes precedence over the close error.
//
// WriteJSON does no locking of its own. gorilla/websocket permits only one
// concurrent writer, and Call serializes its writes by holding reqMutex.
// NOTE(review): direct external callers must not run concurrently with Call.
func (caller *Transport) WriteJSON(v interface{}) error {
	frame, err := caller.conn.NextWriter(websocket.TextMessage)
	if err != nil {
		return err
	}

	encoder := json.NewEncoder(frame)
	encoder.SetEscapeHTML(false)
	encodeErr := encoder.Encode(v)
	closeErr := frame.Close()

	if encodeErr == nil {
		return closeErr
	}
	return encodeErr
}