Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Message ID can be int, string, or null as per OpenRPC spec #48

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 5 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (e *ErrClient) Unwrap(err error) error {
type clientResponse struct {
Jsonrpc string `json:"jsonrpc"`
Result json.RawMessage `json:"result"`
ID int64 `json:"id"`
ID requestID `json:"id"`
Error *respError `json:"error,omitempty"`
}

Expand Down Expand Up @@ -170,7 +170,7 @@ func httpClient(ctx context.Context, addr string, namespace string, outs []inter
return clientResponse{}, xerrors.Errorf("unmarshaling response: %w", err)
}

if resp.ID != *cr.req.ID {
if cr.req.ID.actual != resp.ID.actual {
return clientResponse{}, xerrors.New("request and response id didn't match")
}

Expand Down Expand Up @@ -240,7 +240,7 @@ func websocketClient(ctx context.Context, addr string, namespace string, outs []
req: request{
Jsonrpc: "2.0",
Method: wsCancel,
Params: []param{{v: reflect.ValueOf(*cr.req.ID)}},
Params: []param{{v: reflect.ValueOf(cr.req.ID.actual)}},
},
}
select {
Expand Down Expand Up @@ -498,7 +498,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)

req := request{
Jsonrpc: "2.0",
ID: &id,
ID: requestID{id},
Method: fn.client.namespace + "." + fn.name,
Params: params,
}
Expand Down Expand Up @@ -528,7 +528,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
return fn.processError(fmt.Errorf("sendRequest failed: %w", err))
}

if resp.ID != *req.ID {
if req.ID.actual != resp.ID.actual {
return fn.processError(xerrors.New("request and response id didn't match"))
}

Expand Down
44 changes: 38 additions & 6 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,44 @@ type rpcHandler struct {

type request struct {
Jsonrpc string `json:"jsonrpc"`
ID *int64 `json:"id,omitempty"`
ID requestID `json:"id,omitempty"`
Method string `json:"method"`
Params []param `json:"params"`
Meta map[string]string `json:"meta,omitempty"`
}

type requestID struct {
actual interface{} // nil, int64, or string
}

func (r *requestID) UnmarshalJSON(data []byte) error {
switch data[0] {
case 'n': // null
case '"': // string
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
r.actual = s
default: // number
var n int64
if err := json.Unmarshal(data, &n); err != nil {
return err
}
r.actual = n
}
return nil
}

func (r requestID) MarshalJSON() ([]byte, error) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could add a sanity check here that "actual" is one of the three allowed Go types, just as an extra sanity check since UarshalJSON does it too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @mvdan, added and it picked up a bug in my code! can you sanity check my latest commit to see if it's how you would have done it?

Copy link

@mvdan mvdan Apr 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM. You could simplify it a little bit, since you don't need to separate the json.Marshal calls:

	switch r.actual.(type) {
	case nil, int64, string:
		return json.Marshal(r.actual)
	default:
		return nil, fmt.Errorf("unexpected ID type: %T", r.actual)
	}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 thanks & done

switch r.actual.(type) {
case nil, int64, string:
return json.Marshal(r.actual)
default:
return nil, fmt.Errorf("unexpected ID type: %T", r.actual)
}
}

// Limit request size. Ideally this limit should be specific for each field
// in the JSON request but as a simple defensive measure we just limit the
// entire HTTP body.
Expand All @@ -64,7 +96,7 @@ func (e *respError) Error() string {
type response struct {
Jsonrpc string `json:"jsonrpc"`
Result interface{} `json:"result,omitempty"`
ID int64 `json:"id"`
ID requestID `json:"id"`
Error *respError `json:"error,omitempty"`
}

Expand Down Expand Up @@ -109,7 +141,7 @@ func (s *RPCServer) register(namespace string, r interface{}) {
// Handle

type rpcErrFunc func(w func(func(io.Writer)), req *request, code int, err error)
type chanOut func(reflect.Value, int64) error
type chanOut func(reflect.Value, requestID) error

func (s *RPCServer) handleReader(ctx context.Context, r io.Reader, w io.Writer, rpcError rpcErrFunc) {
wf := func(cb func(io.Writer)) {
Expand Down Expand Up @@ -262,15 +294,15 @@ func (s *RPCServer) handle(ctx context.Context, req request, w func(func(io.Writ
stats.Record(ctx, metrics.RPCRequestError.M(1))
return
}
if req.ID == nil {
if req.ID.actual == nil {
return // notification
}

///////////////////

resp := response{
Jsonrpc: "2.0",
ID: *req.ID,
ID: req.ID,
}

if handler.errOut != -1 {
Expand Down Expand Up @@ -302,7 +334,7 @@ func (s *RPCServer) handle(ctx context.Context, req request, w func(func(io.Writ
// sending channel messages before this rpc call returns

//noinspection GoNilness // already checked above
err = chOut(callResult[handler.valOut], *req.ID)
err = chOut(callResult[handler.valOut], req.ID)
if err == nil {
return // channel goroutine handles responding
}
Expand Down
77 changes: 77 additions & 0 deletions rpc_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package jsonrpc

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"reflect"
"strconv"
Expand Down Expand Up @@ -360,6 +362,81 @@ func TestRPCHttpClient(t *testing.T) {
closer()
}

func TestRPCCustomHttpClient(t *testing.T) {
// setup server
serverHandler := &SimpleServerHandler{}
rpcServer := NewServer()
rpcServer.Register("SimpleServerHandler", serverHandler)
testServ := httptest.NewServer(rpcServer)
defer testServ.Close()

// setup custom client
addr := "http://" + testServ.Listener.Addr().String()
doReq := func(reqStr string) string {
hreq, err := http.NewRequest("POST", addr, bytes.NewReader([]byte(reqStr)))
require.NoError(t, err)

hreq.Header = http.Header{}
hreq.Header.Set("Content-Type", "application/json")

httpResp, err := testServ.Client().Do(hreq)
defer httpResp.Body.Close()

respBytes, err := ioutil.ReadAll(httpResp.Body)
require.NoError(t, err)

return string(respBytes)
}

// Add(2)
reqStr := `{"jsonrpc":"2.0","method":"SimpleServerHandler.Add","params":[2],"id":100}"`
respBytes := doReq(reqStr)
require.Equal(t, `{"jsonrpc":"2.0","id":100}`+"\n", string(respBytes))
require.Equal(t, 2, serverHandler.n)

// Add(-3546) error
reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.Add","params":[-3546],"id":1010102}"`
respBytes = doReq(reqStr)
require.Equal(t, `{"jsonrpc":"2.0","id":1010102,"error":{"code":1,"message":"test"}}`+"\n", string(respBytes))
require.Equal(t, 2, serverHandler.n)

// AddGet(3)
reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.AddGet","params":[3],"id":0}"`
respBytes = doReq(reqStr)
require.Equal(t, `{"jsonrpc":"2.0","result":5,"id":0}`+"\n", string(respBytes))
require.Equal(t, 5, serverHandler.n)

// StringMatch("0", 0, 0)
reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.StringMatch","params":[{"S":"0","I":0},0],"id":1}"`
respBytes = doReq(reqStr)
require.Equal(t, `{"jsonrpc":"2.0","result":{"S":"0","I":0,"Ok":true},"id":1}`+"\n", string(respBytes))
require.Equal(t, 5, serverHandler.n)

// StringMatch("5", 0, 5) error
reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.StringMatch","params":[{"S":"5","I":0},5],"id":2}"`
respBytes = doReq(reqStr)
require.Equal(t, `{"jsonrpc":"2.0","id":2,"error":{"code":1,"message":":("}}`+"\n", string(respBytes))
require.Equal(t, 5, serverHandler.n)

// StringMatch("8", 8, 8) error
reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.StringMatch","params":[{"S":"8","I":8},8],"id":3}"`
respBytes = doReq(reqStr)
require.Equal(t, `{"jsonrpc":"2.0","result":{"S":"8","I":8,"Ok":true},"id":3}`+"\n", string(respBytes))
require.Equal(t, 5, serverHandler.n)

// Add(int) string ID
reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.Add","params":[2],"id":"100"}"`
respBytes = doReq(reqStr)
require.Equal(t, `{"jsonrpc":"2.0","id":"100"}`+"\n", string(respBytes))
require.Equal(t, 7, serverHandler.n)

// Add(int) random string ID
reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.Add","params":[2],"id":"OpenRPC says this can be whatever you want"}"`
respBytes = doReq(reqStr)
require.Equal(t, `{"jsonrpc":"2.0","id":"OpenRPC says this can be whatever you want"}`+"\n", string(respBytes))
require.Equal(t, 9, serverHandler.n)
}

type CtxHandler struct {
lk sync.Mutex

Expand Down
4 changes: 2 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ func rpcError(wf func(func(io.Writer)), req *request, code int, err error) {

log.Warnf("rpc error: %s", err)

if req.ID == nil { // notification
if req.ID.actual == nil { // notification
return
}

resp := response{
Jsonrpc: "2.0",
ID: *req.ID,
ID: req.ID,
Error: &respError{
Code: code,
Message: err.Error(),
Expand Down