diff --git a/call.go b/call.go index 4af07e6..1528f40 100644 --- a/call.go +++ b/call.go @@ -26,7 +26,6 @@ type Call struct { errorMu sync.Mutex Error error // After completion, the error status. - } func newCall(ctx context.Context, dest peer.ID, svcName, svcMethod string, args interface{}, reply interface{}, done chan *Call) *Call { diff --git a/client.go b/client.go index aea24de..2e4f551 100644 --- a/client.go +++ b/client.go @@ -2,7 +2,6 @@ package rpc import ( "context" - "errors" "io" "sync" @@ -227,17 +226,21 @@ func checkMatchingLengths(l ...int) bool { // makeCall decides if a call can be performed. If it's a local // call it will use the configured server if set. func (c *Client) makeCall(call *Call) { - logger.Debugf("makeCall: %s.%s", + logger.Debugf( + "makeCall: %s.%s", call.SvcID.Name, - call.SvcID.Method) + call.SvcID.Method, + ) // Handle local RPC calls if call.Dest == "" || call.Dest == c.host.ID() { - logger.Debugf("local call: %s.%s", - call.SvcID.Name, call.SvcID.Method) + logger.Debugf( + "local call: %s.%s", + call.SvcID.Name, + call.SvcID.Method, + ) if c.server == nil { - err := errors.New( - "Cannot make local calls: server not set") + err := &clientError{"Cannot make local calls: server not set"} call.doneWithError(err) return } @@ -263,7 +266,7 @@ func (c *Client) send(call *Call) { s, err := c.host.NewStream(call.ctx, call.Dest, c.protocol) if err != nil { - call.doneWithError(err) + call.doneWithError(newClientError(err)) return } defer s.Close() @@ -271,21 +274,25 @@ func (c *Client) send(call *Call) { sWrap := wrapStream(s) - logger.Debugf("sending RPC %s.%s to %s", call.SvcID.Name, - call.SvcID.Method, call.Dest) + logger.Debugf( + "sending RPC %s.%s to %s", + call.SvcID.Name, + call.SvcID.Method, + call.Dest, + ) if err := sWrap.enc.Encode(call.SvcID); err != nil { - call.doneWithError(err) + call.doneWithError(newClientError(err)) s.Reset() return } if err := sWrap.enc.Encode(call.Args); err != nil { - call.doneWithError(err) + call.doneWithError(newClientError(err)) s.Reset() return } if err := sWrap.w.Flush(); err != nil { - call.doneWithError(err) + call.doneWithError(newClientError(err)) s.Reset() return } @@ -294,24 +301,28 @@ func (c *Client) send(call *Call) { // receiveResponse reads a response to an RPC call func receiveResponse(s *streamWrap, call *Call) { - logger.Debugf("waiting response for %s.%s to %s", call.SvcID.Name, - call.SvcID.Method, call.Dest) + logger.Debugf( + "waiting response for %s.%s to %s", + call.SvcID.Name, + call.SvcID.Method, + call.Dest, + ) var resp Response if err := s.dec.Decode(&resp); err != nil { - call.doneWithError(err) + call.doneWithError(newClientError(err)) s.stream.Reset() return } defer call.done() if e := resp.Error; e != "" { - call.setError(errors.New(e)) + call.setError(responseError(resp.ErrType, e)) } // Even on error we sent the reply so it needs to be // read if err := s.dec.Decode(call.Reply); err != nil && err != io.EOF { - call.setError(err) + call.setError(newClientError(err)) } return } diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..6e87dd0 --- /dev/null +++ b/errors.go @@ -0,0 +1,104 @@ +package rpc + +import "errors" + +// responseErr is an enum type for providing error type +// information over the wire between rpc server and client. +type responseErr int + +const ( + // nonRPCErr is an error that hasn't arisen from the gorpc package. + nonRPCErr responseErr = iota + // serverErr is an error that has arisen on the server side. + serverErr + // clientErr is an error that has arisen on the client side. + clientErr +) + +// serverError indicates that error originated in server +// specific code. +type serverError struct { + msg string +} + +func (s *serverError) Error() string { + return s.msg +} + +// newServerError wraps an error in the serverError type. +func newServerError(err error) error { + return &serverError{err.Error()} +} + +// clientError indicates that error originated in client +// specific code. +type clientError struct { + msg string +} + +func (c *clientError) Error() string { + return c.msg +} + +// newClientError wraps an error in the clientError type. +func newClientError(err error) error { + return &clientError{err.Error()} +} + +// responseError converts an responseErr and error message string +// into the appropriate error type. +func responseError(errType responseErr, errMsg string) error { + switch errType { + case serverErr: + return &serverError{errMsg} + case clientErr: + return &clientError{errMsg} + default: + return errors.New(errMsg) + } +} + +// responseErrorType determines whether an error is of either +// serverError or clientError type and returns the appropriate +// responseErr value. +func responseErrorType(err error) responseErr { + switch err.(type) { + case *serverError: + return serverErr + case *clientError: + return clientErr + default: + return nonRPCErr + } +} + +// IsRPCError returns whether an error is either a serverError +// or clientError. +func IsRPCError(err error) bool { + switch err.(type) { + case *serverError, *clientError: + return true + default: + return false + } +} + +// IsServerError returns whether an error is serverError. +func IsServerError(err error) bool { + switch err.(type) { + case *serverError: + return true + default: + return false + } +} + +// IsClientError returns whether an error is clientError. +func IsClientError(err error) bool { + switch err.(type) { + case *clientError: + return true + default: + return false + } +} diff --git a/server.go b/server.go index 07253ba..fade0cf 100644 --- a/server.go +++ b/server.go @@ -96,6 +96,7 @@ type ServiceID struct { type Response struct { Service ServiceID Error string // error, if any. + ErrType responseErr } // Server is an LibP2P RPC server. It can register services which comply to the @@ -128,7 +129,7 @@ func NewServer(h host.Host, p protocol.ID) *Server { err := s.handle(sWrap) if err != nil { logger.Error("error handling RPC:", err) - resp := &Response{ServiceID{}, err.Error()} + resp := &Response{ServiceID{}, err.Error(), responseErrorType(err)} sendResponse(sWrap, resp, nil) } }) @@ -152,14 +153,14 @@ func (server *Server) handle(s *streamWrap) error { err := s.dec.Decode(&svcID) if err != nil { - return err + return newServerError(err) } logger.Debugf("RPC ServiceID is %s.%s", svcID.Name, svcID.Method) service, mtype, err := server.getService(svcID) if err != nil { - return err + return newServerError(err) } // Decode the argument value. @@ -172,7 +173,7 @@ func (server *Server) handle(s *streamWrap) error { } // argv guaranteed to be a pointer now. if err = s.dec.Decode(argv.Interface()); err != nil { - return err + return newServerError(err) } if argIsValue { argv = argv.Elem() @@ -216,7 +217,7 @@ func (s *service) svcCall(sWrap *streamWrap, mtype *methodType, svcID ServiceID, if errInter != nil { errmsg = errInter.(error).Error() } - resp := &Response{svcID, errmsg} + resp := &Response{svcID, errmsg, nonRPCErr} return sendResponse(sWrap, resp, replyv.Interface()) } @@ -245,7 +246,7 @@ func (server *Server) Call(call *Call) error { var argv, replyv reflect.Value service, mtype, err := server.getService(call.SvcID) if err != nil { - return err + return newServerError(err) } // Use the context value from the call directly @@ -257,7 +258,9 @@ func (server *Server) Call(call *Call) error { if reflect.TypeOf(call.Args).Kind() != reflect.Ptr { return fmt.Errorf( "%s.%s is being called with the wrong arg type", - call.SvcID.Name, call.SvcID.Method) + call.SvcID.Name, + call.SvcID.Method, + ) } argv = reflect.New(mtype.ArgType.Elem()) argv.Elem().Set(reflect.ValueOf(call.Args).Elem()) @@ -265,7 +268,9 @@ func (server *Server) Call(call *Call) error { if reflect.TypeOf(call.Args).Kind() == reflect.Ptr { return fmt.Errorf( "%s.%s is being called with the wrong arg type", - call.SvcID.Name, call.SvcID.Method) + call.SvcID.Name, + call.SvcID.Method, + ) } argv = reflect.New(mtype.ArgType) argv.Elem().Set(reflect.ValueOf(call.Args)) @@ -282,11 +287,14 @@ func (server *Server) Call(call *Call) error { // Call service and respond function := mtype.method.Func // Invoke the method, providing a new value for the reply. - returnValues := function.Call([]reflect.Value{ - service.rcvr, - ctxv, // context - argv, // argument - replyv}) // reply + returnValues := function.Call( + []reflect.Value{ + service.rcvr, + ctxv, // context + argv, // argument + replyv, + }, + ) // reply creplyv := reflect.ValueOf(call.Reply) creplyv.Elem().Set(replyv.Elem()) @@ -306,12 +314,12 @@ func (server *Server) getService(id ServiceID) (*service, *methodType, error) { server.mu.RUnlock() if service == nil { err := errors.New("rpc: can't find service " + id.Name) - return nil, nil, err + return nil, nil, newServerError(err) } mtype := service.method[id.Method] if mtype == nil { err := errors.New("rpc: can't find method " + id.Method) - return nil, nil, err + return nil, nil, newServerError(err) } return service, mtype, nil } diff --git a/server_test.go b/server_test.go index 68ad4f4..af6473e 100644 --- a/server_test.go +++ b/server_test.go @@ -210,26 +210,87 @@ func TestErrorResponse(t *testing.T) { var arith Arith s.Register(&arith) - var r int - // test remote - c := NewClientWithServer(h2, "rpc", s) - err := c.Call(h1.ID(), "Arith", "GimmeError", &Args{1, 2}, &r) - if err == nil || err.Error() != "an error" { - t.Error("expected different error") - } - if r != 42 { - t.Error("response should be set even on error") - } + t.Run("remote", func(t *testing.T) { + var r int + c := NewClientWithServer(h2, "rpc", s) + err := c.Call(h1.ID(), "Arith", "GimmeError", &Args{1, 2}, &r) + if err == nil || err.Error() != "an error" { + t.Error("expected different error") + } + if r != 42 { + t.Error("response should be set even on error") + } + }) - // test local - c = NewClientWithServer(h1, "rpc", s) - err = c.Call(h1.ID(), "Arith", "GimmeError", &Args{1, 2}, &r) - if err == nil || err.Error() != "an error" { - t.Error("expected different error") - } - if r != 42 { - t.Error("response should be set even on error") - } + t.Run("local", func(t *testing.T) { + var r int + c := NewClientWithServer(h1, "rpc", s) + err := c.Call(h1.ID(), "Arith", "GimmeError", &Args{1, 2}, &r) + if err == nil || err.Error() != "an error" { + t.Error("expected different error") + } + if r != 42 { + t.Error("response should be set even on error") + } + }) +} + +func TestNonRPCError(t *testing.T) { + h1, h2 := makeRandomNodes() + defer h1.Close() + defer h2.Close() + + s := NewServer(h1, "rpc") + var arith Arith + s.Register(&arith) + + t.Run("local non rpc error", func(t *testing.T) { + var r int + c := NewClientWithServer(h1, "rpc", s) + err := c.Call(h1.ID(), "Arith", "GimmeError", &Args{1, 2}, &r) + if err != nil { + if IsRPCError(err) { + t.Log(err) + t.Error("expected non rpc error") + } + } + }) + + t.Run("local rpc error", func(t *testing.T) { + var r int + c := NewClientWithServer(h1, "rpc", s) + err := c.Call(h1.ID(), "Arith", "ThisIsNotAMethod", &Args{1, 2}, &r) + if err != nil { + if !IsRPCError(err) { + t.Log(err) + t.Error("expected rpc error") + } + } + }) + + t.Run("remote non rpc error", func(t *testing.T) { + var r int + c := NewClientWithServer(h2, "rpc", s) + err := c.Call(h1.ID(), "Arith", "GimmeError", &Args{1, 2}, &r) + if err != nil { + if IsRPCError(err) { + t.Log(err) + t.Error("expected non rpc error") + } + } + }) + + t.Run("remote rpc error", func(t *testing.T) { + var r int + c := NewClientWithServer(h2, "rpc", s) + err := c.Call(h1.ID(), "Arith", "ThisIsNotAMethod", &Args{1, 2}, &r) + if err != nil { + if !IsRPCError(err) { + t.Log(err) + t.Error("expected rpc error") + } + } + }) } func testCallContext(t *testing.T, servHost, clientHost host.Host, dest peer.ID) {