-
Notifications
You must be signed in to change notification settings - Fork 0
/
client.go
324 lines (283 loc) · 8.19 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
package tsrpc
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/dgdts/tsrpc/codec"
)
// Call represents a single RPC invocation, either in flight or completed.
type Call struct {
	Seq           uint64      // sequence number assigned when the call is registered with the client
	ServiceMethod string      // name of the remote service method to invoke
	Args          interface{} // arguments sent to the remote method
	Reply         interface{} // destination the response body is decoded into
	Error         error       // set when the call fails; nil on success
	Done          chan *Call  // receives this Call once it has completed
}

// done signals completion by sending the call on its Done channel.
//
// The send never blocks. Client.send and Client.terminateCalls invoke done
// while holding the client's locks, and Client.receive invokes it from the
// single goroutine that reads responses. A full Done channel must not be able
// to stall any of them. If the channel has no free capacity, the notification
// is dropped and logged. Callers that share one Done channel across several
// calls must therefore give it enough buffer space.
func (c *Call) done() {
	select {
	case c.Done <- c:
	default:
		log.Println("rpc client: discarding Call reply due to insufficient Done chan capacity")
	}
}
// Client is an RPC client that talks to a server over a single codec
// connection. Shared state is coordinated by two mutexes:
//   - sending serializes request writes.
//   - mu guards the call bookkeeping.
type Client struct {
	cc  codec.Codec // encodes requests and decodes responses on the connection
	opt *Option     // options exchanged with the server during the handshake

	// sending is held for the duration of each request write; header is only
	// modified while it is held.
	sending sync.Mutex
	header  codec.Header // request header reused by every send

	// mu guards every field below it.
	mu       sync.Mutex
	seq      uint64           // next sequence number to hand out
	pending  map[uint64]*Call // calls awaiting a response, keyed by Seq
	closing  bool             // set by Close: the user asked to close the client
	shutdown bool             // set by terminateCalls: the receive loop has stopped
}
var (
	// Compile-time check that *Client implements io.Closer.
	_ io.Closer = (*Client)(nil)

	// ErrShutdown is returned when the client is closing or its connection has
	// shut down.
	ErrShutdown = errors.New("connection is shutdown")
)
// Close marks the client as closing and closes the underlying codec. Calling
// Close on a client that is already closing returns ErrShutdown instead.
func (c *Client) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	alreadyClosing := c.closing
	if !alreadyClosing {
		c.closing = true
		return c.cc.Close()
	}
	return ErrShutdown
}
// IsAvailable reports whether the client can still issue calls. This means it
// has neither been closed via Close nor been shut down after its receive loop
// stopped.
func (c *Client) IsAvailable() bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	return !c.shutdown && !c.closing
}

// IsAvaliable reports whether the client can still issue calls.
//
// Deprecated: use IsAvailable. This misspelled name is kept so existing
// callers keep compiling.
func (c *Client) IsAvaliable() bool {
	return c.IsAvailable()
}
// resisterCall assigns the next sequence number to call, records the call as
// pending, and returns that sequence number. It fails with ErrShutdown once
// the client is closing or shut down. The name is presumably a misspelling of
// "registerCall"; it is kept because send refers to it.
func (c *Client) resisterCall(call *Call) (uint64, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.closing && !c.shutdown {
		seq := c.seq
		c.seq = seq + 1
		call.Seq = seq
		c.pending[seq] = call
		return seq, nil
	}
	return 0, ErrShutdown
}
// removeCall deletes the pending call with sequence number seq and returns
// it. If no such call is pending, it returns nil.
func (c *Client) removeCall(seq uint64) *Call {
	c.mu.Lock()
	defer c.mu.Unlock()

	call, ok := c.pending[seq]
	if ok {
		delete(c.pending, seq)
	}
	return call
}
// terminateCalls marks the client as shut down and completes every pending
// call with err. It acquires sending before mu. This is the same order send
// uses, since send holds sending while resisterCall takes mu. As a result, no
// request write is in progress while the pending calls are failed.
func (c *Client) terminateCalls(err error) {
	c.sending.Lock()
	defer c.sending.Unlock()
	c.mu.Lock()
	defer c.mu.Unlock()

	c.shutdown = true
	for seq := range c.pending {
		pendingCall := c.pending[seq]
		pendingCall.Error = err
		pendingCall.done()
	}
}
// receive reads responses from the codec until a read fails. newClientCodec
// starts it in its own goroutine.
//
// Each response header is matched to its pending call by sequence number and
// the call is completed. Responses with no matching pending call still have
// their body read, so the stream stays in sync. When a read fails, every call
// that is still pending is failed with that error via terminateCalls.
func (c *Client) receive() {
	var err error
	for err == nil {
		var h codec.Header
		if err = c.cc.ReadHeader(&h); err != nil {
			break
		}
		call := c.removeCall(h.Seq)
		switch {
		case call == nil:
			// Nobody is waiting for this response, e.g. Call removed it after
			// its context was done. The body is read with a nil destination,
			// presumably discarding it.
			err = c.cc.ReadBody(nil)
		case h.Error != "":
			// The server reported an error. errors.New keeps the message
			// verbatim. Passing it to fmt.Errorf as a format string would
			// misinterpret any '%' in the server's text.
			call.Error = errors.New(h.Error)
			err = c.cc.ReadBody(nil)
			call.done()
		default:
			err = c.cc.ReadBody(call.Reply)
			if err != nil {
				// Same message text as before, but the cause stays reachable
				// via errors.Is / errors.As.
				call.Error = fmt.Errorf("reading body %w", err)
			}
			call.done()
		}
	}
	c.terminateCalls(err)
}
// NewClient performs the option handshake on conn and returns a client that
// uses the codec registered for opt.CodecType.
//
// The handshake JSON-encodes opt to the server and then JSON-decodes the
// server's reply back into opt, which therefore may be modified. If the codec
// type is unknown or the handshake fails, conn is closed and the error is
// returned.
func NewClient(conn net.Conn, opt *Option) (*Client, error) {
	f := codec.NewCodecFuncMap[opt.CodecType]
	if f == nil {
		err := fmt.Errorf("invalid codec type %s", opt.CodecType)
		log.Println("rpc client: codec error:", err)
		// Close here too, so every failure path leaves conn closed.
		_ = conn.Close()
		return nil, err
	}
	if err := json.NewEncoder(conn).Encode(opt); err != nil {
		log.Println("rpc client: option error:", err)
		_ = conn.Close()
		return nil, err
	}
	if err := json.NewDecoder(conn).Decode(opt); err != nil {
		log.Println("rpc client: option err:", err)
		_ = conn.Close()
		return nil, err
	}
	return newClientCodec(f(conn), opt), nil
}
// newClientCodec wraps cc in a new Client whose sequence numbers start at 1.
// It starts the client's receive loop in a new goroutine and returns the
// client.
func newClientCodec(cc codec.Codec, opt *Option) *Client {
	c := new(Client)
	c.cc = cc
	c.opt = opt
	c.seq = 1
	c.pending = make(map[uint64]*Call)

	go c.receive()
	return c
}
// parseOptions resolves the variadic options argument passed to dialTimeout.
//   - More than one option is an error.
//   - With no options, or a nil option, it returns a copy of DefaultOption.
//   - With exactly one option, it returns a copy of that option with
//     MagicNumber forced to the default and an empty CodecType filled in from
//     DefaultOption.
//
// A copy is always returned because NewClient decodes the server's handshake
// reply into the option it receives. Returning DefaultOption itself, or a
// caller's shared *Option, would let that reply overwrite the shared value
// and let concurrent dials race on it.
func parseOptions(opts ...*Option) (*Option, error) {
	// Check the count first so that (nil, extra...) is rejected rather than
	// silently treated as "use the defaults".
	if len(opts) > 1 {
		return nil, errors.New("number of options is more than 1")
	}
	if len(opts) == 0 || opts[0] == nil {
		opt := *DefaultOption
		return &opt, nil
	}
	opt := *opts[0]
	opt.MagicNumber = DefaultOption.MagicNumber
	if opt.CodecType == "" {
		opt.CodecType = DefaultOption.CodecType
	}
	return &opt, nil
}
// Dial connects to address on network and performs the RPC option handshake.
// It applies the option's ConnectTimeout and accepts at most one Option.
func Dial(network string, address string, opts ...*Option) (client *Client, err error) {
	client, err = dialTimeout(NewClient, network, address, opts...)
	return client, err
}
// send registers call and writes its request to the connection. The sending
// lock serializes requests and also protects the shared header. If either
// registration or the write fails, the call is completed with that error.
func (c *Client) send(call *Call) {
	c.sending.Lock()
	defer c.sending.Unlock()

	seq, regErr := c.resisterCall(call)
	if regErr != nil {
		call.Error = regErr
		call.done()
		return
	}

	h := &c.header
	h.ServiceMethod = call.ServiceMethod
	h.Seq = call.Seq
	h.Error = ""

	writeErr := c.cc.Write(h, call.Args)
	if writeErr == nil {
		return
	}
	// removeCall returns nil if the call is no longer pending, in which case
	// it is not completed again here.
	if failed := c.removeCall(seq); failed != nil {
		failed.Error = writeErr
		failed.done()
	}
}
// Go starts an asynchronous call to serviceMethod and returns the Call
// describing it. The Call is delivered on done once it completes.
//   - If done is nil, a new channel with a buffer of 10 is used.
//   - If done is unbuffered, Go panics.
func (c *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
	switch {
	case done == nil:
		done = make(chan *Call, 10)
	case cap(done) == 0:
		log.Panic("rpc client: done channel is unbuffered")
	}

	call := &Call{ServiceMethod: serviceMethod, Args: args, Reply: reply, Done: done}
	c.send(call)
	return call
}
// Call invokes serviceMethod and waits until the call completes or ctx is
// done, whichever comes first.
//
// If ctx is done first, the call is removed from the pending set, so a late
// response finds no pending call. The returned error then wraps ctx.Err(),
// which lets callers detect cancellation with
// errors.Is(err, context.DeadlineExceeded) or errors.Is(err, context.Canceled).
func (c *Client) Call(ctx context.Context, serviceMethod string, args interface{}, reply interface{}) error {
	call := c.Go(serviceMethod, args, reply, make(chan *Call, 1))
	select {
	case <-ctx.Done():
		c.removeCall(call.Seq)
		return fmt.Errorf("rpc client: call failed: %w", ctx.Err())
	case completed := <-call.Done:
		return completed.Error
	}
}
// clientResult carries the values returned by a client constructor that
// dialTimeout runs in a background goroutine.
type clientResult struct {
	client *Client // client returned by the constructor
	err    error   // error returned by the constructor
}
// newClientFunc builds a Client on top of an already established connection.
// NewClient and NewHTTPClient both have this signature.
type newClientFunc func(c net.Conn, o *Option) (*Client, error)
// dialTimeout dials address on network and builds a client on the connection
// with f. opt.ConnectTimeout bounds both the dial and the client construction
// (the handshake). A zero ConnectTimeout means no limit for either.
//
// Whenever the named result err is non-nil on return, including on a
// construction timeout, the connection is closed.
func dialTimeout(f newClientFunc, network string, address string, opts ...*Option) (client *Client, err error) {
	opt, err := parseOptions(opts...)
	if err != nil {
		return nil, err
	}
	conn, err := net.DialTimeout(network, address, opt.ConnectTimeout)
	if err != nil {
		return nil, err
	}
	defer func() {
		if err != nil {
			_ = conn.Close()
		}
	}()

	// The channel is buffered so the constructor goroutine can always deliver
	// its result and exit, even after we have stopped waiting because of a
	// timeout. With an unbuffered channel it would block forever (a goroutine
	// leak). A client built after the timeout sees its conn closed by the
	// deferred Close above.
	ch := make(chan clientResult, 1)
	go func() {
		c, e := f(conn, opt)
		ch <- clientResult{client: c, err: e}
	}()

	if opt.ConnectTimeout == 0 {
		result := <-ch
		return result.client, result.err
	}

	timer := time.NewTimer(opt.ConnectTimeout)
	defer timer.Stop()
	select {
	case <-timer.C:
		return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout)
	case result := <-ch:
		return result.client, result.err
	}
}
// NewHTTPClient sends an HTTP CONNECT request for defaultRPCPath over conn.
// If the response status equals connected, it continues with the normal RPC
// handshake via NewClient.
//
// If the request cannot be written, the response cannot be read, or the
// status is unexpected, conn is closed and an error is returned.
func NewHTTPClient(conn net.Conn, opt *Option) (*Client, error) {
	if _, err := io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath)); err != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("rpc client: write CONNECT request: %w", err)
	}
	resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
	if err != nil {
		_ = conn.Close()
		return nil, err
	}
	if resp.Status != connected {
		_ = conn.Close()
		return nil, errors.New("unexpected HTTP response: " + resp.Status)
	}
	return NewClient(conn, opt)
}
// DialHTTP connects to address on network and tunnels the RPC protocol through
// an HTTP CONNECT request (see NewHTTPClient). It applies the option's
// ConnectTimeout and accepts at most one Option.
func DialHTTP(network string, address string, opts ...*Option) (*Client, error) {
	var newHTTP newClientFunc = NewHTTPClient
	return dialTimeout(newHTTP, network, address, opts...)
}
// XDial dials an RPC server described by rpcAddr, which has the form
// protocol@addr, e.g. "http@10.0.0.1:7001" or "tcp@10.0.0.1:9999".
//   - The protocol "http" uses DialHTTP over "tcp".
//   - Any other protocol is passed to Dial as the network name.
//
// Only the first '@' separates the protocol from the address, so addresses
// that themselves contain '@' are accepted. For example, a Linux abstract
// Unix socket is written "unix@@name".
func XDial(rpcAddr string, opts ...*Option) (*Client, error) {
	parts := strings.SplitN(rpcAddr, "@", 2)
	if len(parts) != 2 {
		return nil, fmt.Errorf("rpc client err: wrong format '%s', expect protocol@addr", rpcAddr)
	}
	protocol, addr := parts[0], parts[1]
	switch protocol {
	case "http":
		return DialHTTP("tcp", addr, opts...)
	default:
		return Dial(protocol, addr, opts...)
	}
}