/
caller.go
159 lines (145 loc) · 3.55 KB
/
caller.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
package wscaller
import (
"context"
"errors"
"log"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/ccpaging/rpc"
"github.com/gorilla/websocket"
)
// caller is the transport abstraction used to issue JSON-RPC calls to the
// aria2 daemon.
type caller interface {
	// Close shuts the caller down and releases whatever the implementation holds.
	Close() error
	// Call invokes the rpc method with params on the aria2 daemon; reply is
	// the destination for the result.
	Call(method string, params, reply interface{}) error
}
// websocketCaller issues JSON-RPC calls over a single websocket connection.
// It contains a sync.Once, so it must not be copied; use it via a pointer.
type websocketCaller struct {
	conn     *websocket.Conn    // connection read by the recv goroutine and written by the send goroutine
	sendChan chan *sendRequest  // queue of pending requests drained by the send goroutine
	cancel   context.CancelFunc // cancels the context that stops both goroutines
	wg       *sync.WaitGroup    // tracks both goroutines so Close can wait for them to exit
	once     sync.Once          // ensures the shutdown in Close runs only once
	timeout  time.Duration      // per-Call timeout, also used as the per-write deadline
}
// sendRequest is one RPC call handed from Call to the send goroutine.
type sendRequest struct {
	cancel  context.CancelFunc // cancels the Call's context to wake the waiting Call
	request *rpc.ClientRequest // JSON-RPC request to write to the websocket
	reply   interface{}        // destination passed to rpc.ClientResponse.Decode

	// err is the outcome of the request: the write error, the error returned
	// by Decode, or nil on success. It is set at most once, by finish, and
	// always before cancel is invoked.
	err        error
	finishOnce sync.Once
}

// finish records the outcome of r and wakes the Call waiting on it. Only the
// first call has any effect, so the write-failure path and a response
// callback can never both write r.err.
func (r *sendRequest) finish(err error) {
	r.finishOnce.Do(func() {
		r.err = err
		if r.cancel != nil {
			r.cancel()
		}
	})
}

// reqid returns a process-wide, atomically incremented id seeded from the
// wall clock when the package is initialized.
// NOTE(review): Call uses rpc.ReqId instead; reqid may be unused — confirm.
var reqid = func() func() uint64 {
	var id = uint64(time.Now().UnixNano())
	return func() uint64 {
		return atomic.AddUint64(&id, 1)
	}
}()

// NewWebsocketCaller dials uri and returns a caller that multiplexes JSON-RPC
// calls over that single websocket connection.
//
// It starts two goroutines: one reads responses and notifications, the other
// writes queued requests. Both stop when ctx is canceled, when Close is
// called, or when reading from the connection fails. timeout bounds every
// Call and every websocket write. notifier may be nil, in which case
// notifications are dropped.
func NewWebsocketCaller(ctx context.Context, uri string, timeout time.Duration, notifier Notifier) (*websocketCaller, error) {
	var header = http.Header{}
	conn, _, err := websocket.DefaultDialer.Dial(uri, header)
	if err != nil {
		return nil, err
	}

	sendChan := make(chan *sendRequest, 16)
	var wg sync.WaitGroup
	ctx, cancel := context.WithCancel(ctx)
	w := &websocketCaller{conn: conn, wg: &wg, cancel: cancel, sendChan: sendChan, timeout: timeout}
	processor := NewResponseProcessor()

	wg.Add(1)
	go func() { // routine:recv
		defer wg.Done()
		defer cancel()
		for {
			select {
			case <-ctx.Done():
				return
			default:
			}
			var resp websocketResponse
			if err := conn.ReadJSON(&resp); err != nil {
				select {
				case <-ctx.Done():
					// Shutdown is in progress (the send goroutine closes the
					// connection), so the read error is expected; don't log it.
					return
				default:
				}
				log.Printf("conn.ReadJSON|err:%v", err.Error())
				return
			}
			if resp.Id == nil { // RPC notifications
				if notifier != nil {
					switch resp.Method {
					case "aria2.onDownloadStart":
						notifier.OnDownloadStart(resp.Params)
					case "aria2.onDownloadPause":
						notifier.OnDownloadPause(resp.Params)
					case "aria2.onDownloadStop":
						notifier.OnDownloadStop(resp.Params)
					case "aria2.onDownloadComplete":
						notifier.OnDownloadComplete(resp.Params)
					case "aria2.onDownloadError":
						notifier.OnDownloadError(resp.Params)
					case "aria2.onBtDownloadComplete":
						notifier.OnBtDownloadComplete(resp.Params)
					default:
						log.Printf("unexpected notification: %s", resp.Method)
					}
				}
				continue
			}
			// Callbacks registered by the send goroutine report their outcome
			// to the waiting Call through sendRequest.finish, so nothing from
			// Process is needed here.
			processor.Process(resp.ClientResponse)
		}
	}()

	wg.Add(1)
	go func() { // routine:send
		defer wg.Done()
		defer cancel()
		defer w.conn.Close()
		for {
			select {
			case <-ctx.Done():
				// WriteControl takes its own deadline. WriteMessage would reuse
				// the deadline set for the last request, which may already have
				// passed and would make the close frame fail.
				closeMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
				if err := w.conn.WriteControl(websocket.CloseMessage, closeMsg, time.Now().Add(timeout)); err != nil {
					log.Printf("sending websocket close message: %v", err)
				}
				return
			case req := <-sendChan:
				// Register the handler before writing so the response cannot
				// arrive before a handler for its id exists.
				processor.Add(req.request.Id, func(resp rpc.ClientResponse) error {
					err := resp.Decode(req.reply)
					req.finish(err)
					return err
				})
				err := w.conn.SetWriteDeadline(time.Now().Add(timeout))
				if err == nil {
					err = w.conn.WriteJSON(req.request)
				}
				if err != nil {
					// Fail the Call now instead of letting it wait for a
					// response to a request that was never sent.
					req.finish(err)
				}
			}
		}
	}()
	return w, nil
}

// Close stops both goroutines, which sends a websocket close frame and closes
// the connection, and waits for them to exit. Only the first call does the
// work; concurrent callers block until it finishes. The result is always nil.
func (w *websocketCaller) Close() (err error) {
	w.once.Do(func() {
		w.cancel()
		w.wg.Wait()
	})
	return
}

// Call sends a JSON-RPC request for method with params and waits up to
// w.timeout for the matching response, which is decoded into reply.
//
// It returns an error if the send queue is full, context.DeadlineExceeded if
// no response arrives in time, the websocket error if the request could not
// be written, or the error returned by rpc.ClientResponse.Decode (presumably
// including server-reported RPC errors — confirm in the rpc package).
//
// NOTE(review): on timeout the response handler stays registered, so a late
// response may still be decoded into reply after Call has returned.
func (w *websocketCaller) Call(method string, params, reply interface{}) (err error) {
	ctx, cancel := context.WithTimeout(context.Background(), w.timeout)
	defer cancel()

	req := &sendRequest{
		cancel: cancel,
		request: &rpc.ClientRequest{
			Version: "2.0",
			Method:  method,
			Params:  params,
			Id:      rpc.ReqId(),
		},
		reply: reply,
	}
	select {
	case w.sendChan <- req:
	default:
		return errors.New("sending channel blocking")
	}

	<-ctx.Done()
	if ctxErr := ctx.Err(); errors.Is(ctxErr, context.DeadlineExceeded) {
		return ctxErr
	}
	// The context was canceled by req.finish, which stores req.err before
	// canceling, so this read is ordered after that write.
	return req.err
}