diff --git a/client.go b/client.go index a8c01e0..d3aa897 100644 --- a/client.go +++ b/client.go @@ -67,7 +67,7 @@ func (e *ErrClient) Unwrap(err error) error { type clientResponse struct { Jsonrpc string `json:"jsonrpc"` Result json.RawMessage `json:"result"` - ID int64 `json:"id"` + ID requestID `json:"id"` Error *respError `json:"error,omitempty"` } @@ -170,7 +170,7 @@ func httpClient(ctx context.Context, addr string, namespace string, outs []inter return clientResponse{}, xerrors.Errorf("unmarshaling response: %w", err) } - if resp.ID != *cr.req.ID { + if cr.req.ID.actual != resp.ID.actual { return clientResponse{}, xerrors.New("request and response id didn't match") } @@ -240,7 +240,7 @@ func websocketClient(ctx context.Context, addr string, namespace string, outs [] req: request{ Jsonrpc: "2.0", Method: wsCancel, - Params: []param{{v: reflect.ValueOf(*cr.req.ID)}}, + Params: []param{{v: reflect.ValueOf(cr.req.ID.actual)}}, }, } select { @@ -498,7 +498,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value) req := request{ Jsonrpc: "2.0", - ID: &id, + ID: requestID{id}, Method: fn.client.namespace + "." + fn.name, Params: params, } @@ -528,7 +528,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value) return fn.processError(fmt.Errorf("sendRequest failed: %w", err)) } - if resp.ID != *req.ID { + if req.ID.actual != resp.ID.actual { return fn.processError(xerrors.New("request and response id didn't match")) } diff --git a/handler.go b/handler.go index feb4342..a873ba3 100644 --- a/handler.go +++ b/handler.go @@ -37,12 +37,44 @@ type rpcHandler struct { type request struct { Jsonrpc string `json:"jsonrpc"` - ID *int64 `json:"id,omitempty"` + ID requestID `json:"id,omitempty"` Method string `json:"method"` Params []param `json:"params"` Meta map[string]string `json:"meta,omitempty"` } +type requestID struct { + actual interface{} // nil, int64, or string +} + +func (r *requestID) UnmarshalJSON(data []byte) error { + switch data[0] { + case 'n': // null + case '"': // string + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + r.actual = s + default: // number + var n int64 + if err := json.Unmarshal(data, &n); err != nil { + return err + } + r.actual = n + } + return nil +} + +func (r requestID) MarshalJSON() ([]byte, error) { + switch r.actual.(type) { + case nil, int64, string: + return json.Marshal(r.actual) + default: + return nil, fmt.Errorf("unexpected ID type: %T", r.actual) + } +} + // Limit request size. Ideally this limit should be specific for each field // in the JSON request but as a simple defensive measure we just limit the // entire HTTP body. @@ -64,7 +96,7 @@ func (e *respError) Error() string { type response struct { Jsonrpc string `json:"jsonrpc"` Result interface{} `json:"result,omitempty"` - ID int64 `json:"id"` + ID requestID `json:"id"` Error *respError `json:"error,omitempty"` } @@ -109,7 +141,7 @@ func (s *RPCServer) register(namespace string, r interface{}) { // Handle type rpcErrFunc func(w func(func(io.Writer)), req *request, code int, err error) -type chanOut func(reflect.Value, int64) error +type chanOut func(reflect.Value, requestID) error func (s *RPCServer) handleReader(ctx context.Context, r io.Reader, w io.Writer, rpcError rpcErrFunc) { wf := func(cb func(io.Writer)) { @@ -262,7 +294,7 @@ func (s *RPCServer) handle(ctx context.Context, req request, w func(func(io.Writ stats.Record(ctx, metrics.RPCRequestError.M(1)) return } - if req.ID == nil { + if req.ID.actual == nil { return // notification } @@ -270,7 +302,7 @@ func (s *RPCServer) handle(ctx context.Context, req request, w func(func(io.Writ resp := response{ Jsonrpc: "2.0", - ID: *req.ID, + ID: req.ID, } if handler.errOut != -1 { @@ -302,7 +334,7 @@ func (s *RPCServer) handle(ctx context.Context, req request, w func(func(io.Writ // sending channel messages before this rpc call returns //noinspection GoNilness // already checked above - err = chOut(callResult[handler.valOut], *req.ID) + err = chOut(callResult[handler.valOut], req.ID) if err == nil { return // channel goroutine handles responding } diff --git a/rpc_test.go b/rpc_test.go index 3701079..65d8cba 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -1,6 +1,7 @@ package jsonrpc import ( + "bytes" "context" "encoding/json" "errors" @@ -8,6 +9,7 @@ import ( "io" "io/ioutil" "net" + "net/http" "net/http/httptest" "reflect" "strconv" @@ -360,6 +362,81 @@ func TestRPCHttpClient(t *testing.T) { closer() } +func TestRPCCustomHttpClient(t *testing.T) { + // setup server + serverHandler := &SimpleServerHandler{} + rpcServer := NewServer() + rpcServer.Register("SimpleServerHandler", serverHandler) + testServ := httptest.NewServer(rpcServer) + defer testServ.Close() + + // setup custom client + addr := "http://" + testServ.Listener.Addr().String() + doReq := func(reqStr string) string { + hreq, err := http.NewRequest("POST", addr, bytes.NewReader([]byte(reqStr))) + require.NoError(t, err) + + hreq.Header = http.Header{} + hreq.Header.Set("Content-Type", "application/json") + + httpResp, err := testServ.Client().Do(hreq) + defer httpResp.Body.Close() + + respBytes, err := ioutil.ReadAll(httpResp.Body) + require.NoError(t, err) + + return string(respBytes) + } + + // Add(2) + reqStr := `{"jsonrpc":"2.0","method":"SimpleServerHandler.Add","params":[2],"id":100}"` + respBytes := doReq(reqStr) + require.Equal(t, `{"jsonrpc":"2.0","id":100}`+"\n", string(respBytes)) + require.Equal(t, 2, serverHandler.n) + + // Add(-3546) error + reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.Add","params":[-3546],"id":1010102}"` + respBytes = doReq(reqStr) + require.Equal(t, `{"jsonrpc":"2.0","id":1010102,"error":{"code":1,"message":"test"}}`+"\n", string(respBytes)) + require.Equal(t, 2, serverHandler.n) + + // AddGet(3) + reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.AddGet","params":[3],"id":0}"` + respBytes = doReq(reqStr) + require.Equal(t, `{"jsonrpc":"2.0","result":5,"id":0}`+"\n", string(respBytes)) + require.Equal(t, 5, serverHandler.n) + + // StringMatch("0", 0, 0) + reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.StringMatch","params":[{"S":"0","I":0},0],"id":1}"` + respBytes = doReq(reqStr) + require.Equal(t, `{"jsonrpc":"2.0","result":{"S":"0","I":0,"Ok":true},"id":1}`+"\n", string(respBytes)) + require.Equal(t, 5, serverHandler.n) + + // StringMatch("5", 0, 5) error + reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.StringMatch","params":[{"S":"5","I":0},5],"id":2}"` + respBytes = doReq(reqStr) + require.Equal(t, `{"jsonrpc":"2.0","id":2,"error":{"code":1,"message":":("}}`+"\n", string(respBytes)) + require.Equal(t, 5, serverHandler.n) + + // StringMatch("8", 8, 8) error + reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.StringMatch","params":[{"S":"8","I":8},8],"id":3}"` + respBytes = doReq(reqStr) + require.Equal(t, `{"jsonrpc":"2.0","result":{"S":"8","I":8,"Ok":true},"id":3}`+"\n", string(respBytes)) + require.Equal(t, 5, serverHandler.n) + + // Add(int) string ID + reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.Add","params":[2],"id":"100"}"` + respBytes = doReq(reqStr) + require.Equal(t, `{"jsonrpc":"2.0","id":"100"}`+"\n", string(respBytes)) + require.Equal(t, 7, serverHandler.n) + + // Add(int) random string ID + reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.Add","params":[2],"id":"OpenRPC says this can be whatever you want"}"` + respBytes = doReq(reqStr) + require.Equal(t, `{"jsonrpc":"2.0","id":"OpenRPC says this can be whatever you want"}`+"\n", string(respBytes)) + require.Equal(t, 9, serverHandler.n) +} + type CtxHandler struct { lk sync.Mutex diff --git a/server.go b/server.go index 30d8f1f..6331f63 100644 --- a/server.go +++ b/server.go @@ -100,13 +100,13 @@ func rpcError(wf func(func(io.Writer)), req *request, code int, err error) { log.Warnf("rpc error: %s", err) - if req.ID == nil { // notification + if req.ID.actual == nil { // notification return } resp := response{ Jsonrpc: "2.0", - ID: *req.ID, + ID: req.ID, Error: &respError{ Code: code, Message: err.Error(), diff --git a/websocket.go b/websocket.go index 44f4f25..67f1c28 100644 --- a/websocket.go +++ b/websocket.go @@ -22,7 +22,7 @@ const chClose = "xrpc.ch.close" type frame struct { // common Jsonrpc string `json:"jsonrpc"` - ID *int64 `json:"id,omitempty"` + ID requestID `json:"id,omitempty"` Meta map[string]string `json:"meta,omitempty"` // request @@ -35,7 +35,7 @@ type frame struct { } type outChanReg struct { - reqID int64 + reqID requestID chID uint64 ch reflect.Value @@ -66,16 +66,16 @@ type wsConn struct { // Client related // inflight are requests we've sent to the remote - inflight map[int64]clientRequest + inflight map[interface{}]clientRequest // chanHandlers is a map of client-side channel handlers - chanHandlers map[uint64]func(m []byte, ok bool) + chanHandlers map[interface{}]func(m []byte, ok bool) // //// // Server related // handling are the calls we handle - handling map[int64]context.CancelFunc + handling map[interface{}]context.CancelFunc handlingLk sync.Mutex spawnOutChanHandlerOnce sync.Once @@ -227,7 +227,7 @@ func (c *wsConn) handleOutChans() { if err := c.sendRequest(request{ Jsonrpc: "2.0", - ID: nil, // notification + ID: requestID{nil}, // notification Method: chClose, Params: []param{{v: reflect.ValueOf(id)}}, }); err != nil { @@ -239,7 +239,7 @@ func (c *wsConn) handleOutChans() { // forward message if err := c.sendRequest(request{ Jsonrpc: "2.0", - ID: nil, // notification + ID: requestID{nil}, // notification Method: chValue, Params: []param{{v: reflect.ValueOf(caseToID[chosen-internal])}, {v: val}}, }); err != nil { @@ -250,7 +250,7 @@ func (c *wsConn) handleOutChans() { } // handleChanOut registers output channel for forwarding to client -func (c *wsConn) handleChanOut(ch reflect.Value, req int64) error { +func (c *wsConn) handleChanOut(ch reflect.Value, req requestID) error { c.spawnOutChanHandlerOnce.Do(func() { go c.handleOutChans() }) @@ -279,21 +279,21 @@ func (c *wsConn) handleChanOut(ch reflect.Value, req int64) error { // This should also probably be a single goroutine, // Note that not doing this should be fine for now as long as we are using // contexts correctly (cancelling when async functions are no longer is use) -func (c *wsConn) handleCtxAsync(actx context.Context, id int64) { +func (c *wsConn) handleCtxAsync(actx context.Context, id requestID) { <-actx.Done() if err := c.sendRequest(request{ Jsonrpc: "2.0", Method: wsCancel, - Params: []param{{v: reflect.ValueOf(id)}}, + Params: []param{{v: reflect.ValueOf(id.actual)}}, }); err != nil { - log.Warnw("failed to send request", "method", wsCancel, "id", id, "error", err.Error()) + log.Warnw("failed to send request", "method", wsCancel, "id", id.actual, "error", err.Error()) } } // cancelCtx is a built-in rpc which handles context cancellation over rpc func (c *wsConn) cancelCtx(req frame) { - if req.ID != nil { + if req.ID.actual != nil { log.Warnf("%s call with ID set, won't respond", wsCancel) } @@ -317,15 +317,20 @@ func (c *wsConn) cancelCtx(req frame) { // // func (c *wsConn) handleChanMessage(frame frame) { - var chid uint64 + var chid requestID if err := json.Unmarshal(frame.Params[0].data, &chid); err != nil { log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err) return } - hnd, ok := c.chanHandlers[chid] + if chid.actual == nil { + log.Errorf("xrpc.ch.val: no handler ID") + return + } + + hnd, ok := c.chanHandlers[chid.actual] if !ok { - log.Errorf("xrpc.ch.val: handler %d not found", chid) + log.Errorf("xrpc.ch.val: handler %d not found", chid.actual) return } @@ -333,25 +338,30 @@ func (c *wsConn) handleChanMessage(frame frame) { } func (c *wsConn) handleChanClose(frame frame) { - var chid uint64 + var chid requestID if err := json.Unmarshal(frame.Params[0].data, &chid); err != nil { log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err) return } - hnd, ok := c.chanHandlers[chid] + if chid.actual == nil { + log.Errorf("xrpc.ch.val: no handler ID") + return + } + + hnd, ok := c.chanHandlers[chid.actual] if !ok { log.Errorf("xrpc.ch.val: handler %d not found", chid) return } - delete(c.chanHandlers, chid) + delete(c.chanHandlers, chid.actual) hnd(nil, false) } func (c *wsConn) handleResponse(frame frame) { - req, ok := c.inflight[*frame.ID] + req, ok := c.inflight[frame.ID.actual] if !ok { log.Error("client got unknown ID in response") return @@ -359,24 +369,29 @@ func (c *wsConn) handleResponse(frame frame) { if req.retCh != nil && frame.Result != nil { // output is channel - var chid uint64 + var chid requestID if err := json.Unmarshal(frame.Result, &chid); err != nil { log.Errorf("failed to unmarshal channel id response: %s, data '%s'", err, string(frame.Result)) return } + if chid.actual == nil { + log.Errorf("xrpc.ch.val: no handler ID") + return + } + var chanCtx context.Context - chanCtx, c.chanHandlers[chid] = req.retCh() - go c.handleCtxAsync(chanCtx, *frame.ID) + chanCtx, c.chanHandlers[chid.actual] = req.retCh() + go c.handleCtxAsync(chanCtx, frame.ID) } req.ready <- clientResponse{ Jsonrpc: frame.Jsonrpc, Result: frame.Result, - ID: *frame.ID, + ID: frame.ID, Error: frame.Error, } - delete(c.inflight, *frame.ID) + delete(c.inflight, frame.ID.actual) } func (c *wsConn) handleCall(ctx context.Context, frame frame) { @@ -403,11 +418,11 @@ func (c *wsConn) handleCall(ctx context.Context, frame frame) { cancel() } } - if frame.ID != nil { + if frame.ID.actual != nil { nextWriter = c.nextWriter c.handlingLk.Lock() - c.handling[*frame.ID] = cancel + c.handling[frame.ID.actual] = cancel c.handlingLk.Unlock() done = func(keepctx bool) { @@ -416,7 +431,7 @@ func (c *wsConn) handleCall(ctx context.Context, frame frame) { if !keepctx { cancel() - delete(c.handling, *frame.ID) + delete(c.handling, frame.ID.actual) } } } @@ -448,7 +463,7 @@ func (c *wsConn) closeInFlight() { for id, req := range c.inflight { req.ready <- clientResponse{ Jsonrpc: "2.0", - ID: id, + ID: requestID{id}, Error: &respError{ Message: "handler: websocket connection closed", Code: 2, @@ -462,8 +477,8 @@ func (c *wsConn) closeInFlight() { } c.handlingLk.Unlock() - c.inflight = map[int64]clientRequest{} - c.handling = map[int64]context.CancelFunc{} + c.inflight = map[interface{}]clientRequest{} + c.handling = map[interface{}]context.CancelFunc{} } func (c *wsConn) closeChans() { @@ -558,9 +573,9 @@ func (c *wsConn) tryReconnect(ctx context.Context) bool { func (c *wsConn) handleWsConn(ctx context.Context) { c.incoming = make(chan io.Reader) - c.inflight = map[int64]clientRequest{} - c.handling = map[int64]context.CancelFunc{} - c.chanHandlers = map[uint64]func(m []byte, ok bool){} + c.inflight = map[interface{}]clientRequest{} + c.handling = map[interface{}]context.CancelFunc{} + c.chanHandlers = map[interface{}]func(m []byte, ok bool){} c.pongs = make(chan struct{}, 1) c.registerCh = make(chan outChanReg) @@ -628,11 +643,11 @@ func (c *wsConn) handleWsConn(ctx context.Context) { } case req := <-c.requests: c.writeLk.Lock() - if req.req.ID != nil { + if req.req.ID.actual != nil { if c.incomingErr != nil { // No conn?, immediate fail req.ready <- clientResponse{ Jsonrpc: "2.0", - ID: *req.req.ID, + ID: req.req.ID, Error: &respError{ Message: "handler: websocket connection closed", Code: 2, @@ -641,7 +656,7 @@ func (c *wsConn) handleWsConn(ctx context.Context) { c.writeLk.Unlock() break } - c.inflight[*req.req.ID] = req + c.inflight[req.req.ID.actual] = req } c.writeLk.Unlock() if err := c.sendRequest(req.req); err != nil {