From 34926016d9ddfaa34dbe3c64bd2c80cc6ac2cfff Mon Sep 17 00:00:00 2001 From: Adrian Lanzafame Date: Wed, 20 Jun 2018 21:42:13 +1000 Subject: [PATCH 1/2] add ability to distinguish rpc errors This allows for users of the package to determine whether a returned error is from the gorpc package or their code. --- call.go | 1 - client.go | 47 +++++++++++++++--------- errors.go | 84 ++++++++++++++++++++++++++++++++++++++++++ server.go | 38 +++++++++++-------- server_test.go | 99 ++++++++++++++++++++++++++++++++++++++++---------- 5 files changed, 216 insertions(+), 53 deletions(-) create mode 100644 errors.go diff --git a/call.go b/call.go index 4af07e6..1528f40 100644 --- a/call.go +++ b/call.go @@ -26,7 +26,6 @@ type Call struct { errorMu sync.Mutex Error error // After completion, the error status. - } func newCall(ctx context.Context, dest peer.ID, svcName, svcMethod string, args interface{}, reply interface{}, done chan *Call) *Call { diff --git a/client.go b/client.go index aea24de..184440c 100644 --- a/client.go +++ b/client.go @@ -2,7 +2,6 @@ package rpc import ( "context" - "errors" "io" "sync" @@ -227,17 +226,21 @@ func checkMatchingLengths(l ...int) bool { // makeCall decides if a call can be performed. If it's a local // call it will use the configured server if set. func (c *Client) makeCall(call *Call) { - logger.Debugf("makeCall: %s.%s", + logger.Debugf( + "makeCall: %s.%s", call.SvcID.Name, - call.SvcID.Method) + call.SvcID.Method, + ) // Handle local RPC calls if call.Dest == "" || call.Dest == c.host.ID() { - logger.Debugf("local call: %s.%s", - call.SvcID.Name, call.SvcID.Method) + logger.Debugf( + "local call: %s.%s", + call.SvcID.Name, + call.SvcID.Method, + ) if c.server == nil { - err := errors.New( - "Cannot make local calls: server not set") + err := &ClientError{"Cannot make local calls: server not set"} call.doneWithError(err) return } @@ -263,7 +266,7 @@ func (c *Client) send(call *Call) { s, err := c.host.NewStream(call.ctx, call.Dest, c.protocol) if err != nil { - call.doneWithError(err) + call.doneWithError(NewClientError(err)) return } defer s.Close() @@ -271,21 +274,25 @@ func (c *Client) send(call *Call) { sWrap := wrapStream(s) - logger.Debugf("sending RPC %s.%s to %s", call.SvcID.Name, - call.SvcID.Method, call.Dest) + logger.Debugf( + "sending RPC %s.%s to %s", + call.SvcID.Name, + call.SvcID.Method, + call.Dest, + ) if err := sWrap.enc.Encode(call.SvcID); err != nil { - call.doneWithError(err) + call.doneWithError(NewClientError(err)) s.Reset() return } if err := sWrap.enc.Encode(call.Args); err != nil { - call.doneWithError(err) + call.doneWithError(NewClientError(err)) s.Reset() return } if err := sWrap.w.Flush(); err != nil { - call.doneWithError(err) + call.doneWithError(NewClientError(err)) s.Reset() return } @@ -294,24 +301,28 @@ func (c *Client) send(call *Call) { // receiveResponse reads a response to an RPC call func receiveResponse(s *streamWrap, call *Call) { - logger.Debugf("waiting response for %s.%s to %s", call.SvcID.Name, - call.SvcID.Method, call.Dest) + logger.Debugf( + "waiting response for %s.%s to %s", + call.SvcID.Name, + call.SvcID.Method, + call.Dest, + ) var resp Response if err := s.dec.Decode(&resp); err != nil { - call.doneWithError(err) + call.doneWithError(NewClientError(err)) s.stream.Reset() return } defer call.done() if e := resp.Error; e != "" { - call.setError(errors.New(e)) + call.setError(ResponseError(resp.ErrType, e)) } // Even on error we sent the reply so it needs to be // read if err := s.dec.Decode(call.Reply); err != nil && err != io.EOF { - call.setError(err) + call.setError(NewClientError(err)) } return } diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..7cf4c6a --- /dev/null +++ b/errors.go @@ -0,0 +1,84 @@ +package rpc + +import "errors" + +// ResponseErr is an enum type for providing error type +// information over the wire between rpc server and client. +type ResponseErr int + +const ( + // NonRPCErr is an error that hasn't arisen from the gorpc package. + NonRPCErr ResponseErr = iota + // ServerErr is an error that has arisen on the server side. + ServerErr + // ClientErr is an error that has arisen on the client side. + ClientErr +) + +// ServerError indicates that error originated in server +// specific code. +type ServerError struct { + msg string +} + +func (s *ServerError) Error() string { + return s.msg +} + +// NewServerError wraps an error in the ServerError type. +func NewServerError(err error) error { + return &ServerError{err.Error()} +} + +// ClientError indicates that error originated in client +// specific code. +type ClientError struct { + msg string +} + +func (c *ClientError) Error() string { + return c.msg +} + +// NewClientError wraps an error in the ClientError type. +func NewClientError(err error) error { + return &ClientError{err.Error()} +} + +// ResponseError converts an ResponseErr and error message string +// into the appropriate error type. +func ResponseError(errType ResponseErr, errMsg string) error { + switch errType { + case ServerErr: + return &ServerError{errMsg} + case ClientErr: + return &ClientError{errMsg} + default: + return errors.New(errMsg) + } +} + +// ResponseErrorType determines whether an error is of either +// ServerError or ClientError type and returns the appropriate +// ResponseErr value. +func ResponseErrorType(err error) ResponseErr { + switch err.(type) { + case *ServerError: + return ServerErr + case *ClientError: + return ClientErr + default: + return NonRPCErr + } +} + +// IsRPCError returns whether an error is either a ServerError +// or ClientError. +func IsRPCError(err error) bool { + switch err.(type) { + case *ServerError, *ClientError: + return true + default: + return false + } +} diff --git a/server.go b/server.go index 07253ba..6571c5a 100644 --- a/server.go +++ b/server.go @@ -96,6 +96,7 @@ type ServiceID struct { type Response struct { Service ServiceID Error string // error, if any. + ErrType ResponseErr } // Server is an LibP2P RPC server. It can register services which comply to the @@ -128,7 +129,7 @@ func NewServer(h host.Host, p protocol.ID) *Server { err := s.handle(sWrap) if err != nil { logger.Error("error handling RPC:", err) - resp := &Response{ServiceID{}, err.Error()} + resp := &Response{ServiceID{}, err.Error(), ResponseErrorType(err)} sendResponse(sWrap, resp, nil) } }) @@ -152,14 +153,14 @@ func (server *Server) handle(s *streamWrap) error { err := s.dec.Decode(&svcID) if err != nil { - return err + return NewServerError(err) } logger.Debugf("RPC ServiceID is %s.%s", svcID.Name, svcID.Method) service, mtype, err := server.getService(svcID) if err != nil { - return err + return NewServerError(err) } // Decode the argument value. @@ -172,7 +173,7 @@ func (server *Server) handle(s *streamWrap) error { } // argv guaranteed to be a pointer now. if err = s.dec.Decode(argv.Interface()); err != nil { - return err + return NewServerError(err) } if argIsValue { argv = argv.Elem() @@ -216,7 +217,7 @@ func (s *service) svcCall(sWrap *streamWrap, mtype *methodType, svcID ServiceID, if errInter != nil { errmsg = errInter.(error).Error() } - resp := &Response{svcID, errmsg} + resp := &Response{svcID, errmsg, NonRPCErr} return sendResponse(sWrap, resp, replyv.Interface()) } @@ -245,7 +246,7 @@ func (server *Server) Call(call *Call) error { var argv, replyv reflect.Value service, mtype, err := server.getService(call.SvcID) if err != nil { - return err + return NewServerError(err) } // Use the context value from the call directly @@ -257,7 +258,9 @@ func (server *Server) Call(call *Call) error { if reflect.TypeOf(call.Args).Kind() != reflect.Ptr { return fmt.Errorf( "%s.%s is being called with the wrong arg type", - call.SvcID.Name, call.SvcID.Method) + call.SvcID.Name, + call.SvcID.Method, + ) } argv = reflect.New(mtype.ArgType.Elem()) argv.Elem().Set(reflect.ValueOf(call.Args).Elem()) @@ -265,7 +268,9 @@ func (server *Server) Call(call *Call) error { if reflect.TypeOf(call.Args).Kind() == reflect.Ptr { return fmt.Errorf( "%s.%s is being called with the wrong arg type", - call.SvcID.Name, call.SvcID.Method) + call.SvcID.Name, + call.SvcID.Method, + ) } argv = reflect.New(mtype.ArgType) argv.Elem().Set(reflect.ValueOf(call.Args)) @@ -282,11 +287,14 @@ func (server *Server) Call(call *Call) error { // Call service and respond function := mtype.method.Func // Invoke the method, providing a new value for the reply. - returnValues := function.Call([]reflect.Value{ - service.rcvr, - ctxv, // context - argv, // argument - replyv}) // reply + returnValues := function.Call( + []reflect.Value{ + service.rcvr, + ctxv, // context + argv, // argument + replyv, + }, + ) // reply creplyv := reflect.ValueOf(call.Reply) creplyv.Elem().Set(replyv.Elem()) @@ -306,12 +314,12 @@ func (server *Server) getService(id ServiceID) (*service, *methodType, error) { server.mu.RUnlock() if service == nil { err := errors.New("rpc: can't find service " + id.Name) - return nil, nil, err + return nil, nil, NewServerError(err) } mtype := service.method[id.Method] if mtype == nil { err := errors.New("rpc: can't find method " + id.Method) - return nil, nil, err + return nil, nil, NewServerError(err) } return service, mtype, nil } diff --git a/server_test.go b/server_test.go index 68ad4f4..af6473e 100644 --- a/server_test.go +++ b/server_test.go @@ -210,26 +210,87 @@ func TestErrorResponse(t *testing.T) { var arith Arith s.Register(&arith) - var r int - // test remote - c := NewClientWithServer(h2, "rpc", s) - err := c.Call(h1.ID(), "Arith", "GimmeError", &Args{1, 2}, &r) - if err == nil || err.Error() != "an error" { - t.Error("expected different error") - } - if r != 42 { - t.Error("response should be set even on error") - } + t.Run("remote", func(t *testing.T) { + var r int + c := NewClientWithServer(h2, "rpc", s) + err := c.Call(h1.ID(), "Arith", "GimmeError", &Args{1, 2}, &r) + if err == nil || err.Error() != "an error" { + t.Error("expected different error") + } + if r != 42 { + t.Error("response should be set even on error") + } + }) - // test local - c = NewClientWithServer(h1, "rpc", s) - err = c.Call(h1.ID(), "Arith", "GimmeError", &Args{1, 2}, &r) - if err == nil || err.Error() != "an error" { - t.Error("expected different error") - } - if r != 42 { - t.Error("response should be set even on error") - } + t.Run("local", func(t *testing.T) { + var r int + c := NewClientWithServer(h1, "rpc", s) + err := c.Call(h1.ID(), "Arith", "GimmeError", &Args{1, 2}, &r) + if err == nil || err.Error() != "an error" { + t.Error("expected different error") + } + if r != 42 { + t.Error("response should be set even on error") + } + }) +} + +func TestNonRPCError(t *testing.T) { + h1, h2 := makeRandomNodes() + defer h1.Close() + defer h2.Close() + + s := NewServer(h1, "rpc") + var arith Arith + s.Register(&arith) + + t.Run("local non rpc error", func(t *testing.T) { + var r int + c := NewClientWithServer(h1, "rpc", s) + err := c.Call(h1.ID(), "Arith", "GimmeError", &Args{1, 2}, &r) + if err != nil { + if IsRPCError(err) { + t.Log(err) + t.Error("expected non rpc error") + } + } + }) + + t.Run("local rpc error", func(t *testing.T) { + var r int + c := NewClientWithServer(h1, "rpc", s) + err := c.Call(h1.ID(), "Arith", "ThisIsNotAMethod", &Args{1, 2}, &r) + if err != nil { + if !IsRPCError(err) { + t.Log(err) + t.Error("expected rpc error") + } + } + }) + + t.Run("remote non rpc error", func(t *testing.T) { + var r int + c := NewClientWithServer(h2, "rpc", s) + err := c.Call(h1.ID(), "Arith", "GimmeError", &Args{1, 2}, &r) + if err != nil { + if IsRPCError(err) { + t.Log(err) + t.Error("expected non rpc error") + } + } + }) + + t.Run("remote rpc error", func(t *testing.T) { + var r int + c := NewClientWithServer(h2, "rpc", s) + err := c.Call(h1.ID(), "Arith", "ThisIsNotAMethod", &Args{1, 2}, &r) + if err != nil { + if !IsRPCError(err) { + t.Log(err) + t.Error("expected rpc error") + } + } + }) } func testCallContext(t *testing.T, servHost, clientHost host.Host, dest peer.ID) { From f4d75b00d4809fab2b376f19805b68672cef4116 Mon Sep 17 00:00:00 2001 From: Adrian Lanzafame Date: Wed, 27 Jun 2018 15:02:45 +1000 Subject: [PATCH 2/2] unexport internal types/funcs --- client.go | 16 +++++----- errors.go | 96 +++++++++++++++++++++++++++++++++---------------------- server.go | 18 +++++------ 3 files changed, 75 insertions(+), 55 deletions(-) diff --git a/client.go b/client.go index 184440c..2e4f551 100644 --- a/client.go +++ b/client.go @@ -240,7 +240,7 @@ func (c *Client) makeCall(call *Call) { call.SvcID.Method, ) if c.server == nil { - err := &ClientError{"Cannot make local calls: server not set"} + err := &clientError{"Cannot make local calls: server not set"} call.doneWithError(err) return } @@ -266,7 +266,7 @@ func (c *Client) send(call *Call) { s, err := c.host.NewStream(call.ctx, call.Dest, c.protocol) if err != nil { - call.doneWithError(NewClientError(err)) + call.doneWithError(newClientError(err)) return } defer s.Close() @@ -281,18 +281,18 @@ func (c *Client) send(call *Call) { call.Dest, ) if err := sWrap.enc.Encode(call.SvcID); err != nil { - call.doneWithError(NewClientError(err)) + call.doneWithError(newClientError(err)) s.Reset() return } if err := sWrap.enc.Encode(call.Args); err != nil { - call.doneWithError(NewClientError(err)) + call.doneWithError(newClientError(err)) s.Reset() return } if err := sWrap.w.Flush(); err != nil { - call.doneWithError(NewClientError(err)) + call.doneWithError(newClientError(err)) s.Reset() return } @@ -309,20 +309,20 @@ func receiveResponse(s *streamWrap, call *Call) { ) var resp Response if err := s.dec.Decode(&resp); err != nil { - call.doneWithError(NewClientError(err)) + call.doneWithError(newClientError(err)) s.stream.Reset() return } defer call.done() if e := resp.Error; e != "" { - call.setError(ResponseError(resp.ErrType, e)) + call.setError(responseError(resp.ErrType, e)) } // Even on error we sent the reply so it needs to be // read if err := s.dec.Decode(call.Reply); err != nil && err != io.EOF { - call.setError(NewClientError(err)) + call.setError(newClientError(err)) } return } diff --git a/errors.go b/errors.go index 7cf4c6a..6e87dd0 100644 --- a/errors.go +++ b/errors.go @@ -2,81 +2,101 @@ package rpc import "errors" -// ResponseErr is an enum type for providing error type +// responseErr is an enum type for providing error type // information over the wire between rpc server and client. -type ResponseErr int +type responseErr int const ( - // NonRPCErr is an error that hasn't arisen from the gorpc package. - NonRPCErr ResponseErr = iota - // ServerErr is an error that has arisen on the server side. - ServerErr - // ClientErr is an error that has arisen on the client side. - ClientErr + // nonRPCErr is an error that hasn't arisen from the gorpc package. + nonRPCErr responseErr = iota + // serverErr is an error that has arisen on the server side. + serverErr + // clientErr is an error that has arisen on the client side. + clientErr ) -// ServerError indicates that error originated in server +// serverError indicates that error originated in server // specific code. -type ServerError struct { +type serverError struct { msg string } -func (s *ServerError) Error() string { +func (s *serverError) Error() string { return s.msg } -// NewServerError wraps an error in the ServerError type. -func NewServerError(err error) error { - return &ServerError{err.Error()} +// newServerError wraps an error in the serverError type. +func newServerError(err error) error { + return &serverError{err.Error()} } -// ClientError indicates that error originated in client +// clientError indicates that error originated in client // specific code. -type ClientError struct { +type clientError struct { msg string } -func (c *ClientError) Error() string { +func (c *clientError) Error() string { return c.msg } -// NewClientError wraps an error in the ClientError type. -func NewClientError(err error) error { - return &ClientError{err.Error()} +// newClientError wraps an error in the clientError type. +func newClientError(err error) error { + return &clientError{err.Error()} } -// ResponseError converts an ResponseErr and error message string +// responseError converts an responseErr and error message string // into the appropriate error type. -func ResponseError(errType ResponseErr, errMsg string) error { +func responseError(errType responseErr, errMsg string) error { switch errType { - case ServerErr: - return &ServerError{errMsg} - case ClientErr: - return &ClientError{errMsg} + case serverErr: + return &serverError{errMsg} + case clientErr: + return &clientError{errMsg} default: return errors.New(errMsg) } } -// ResponseErrorType determines whether an error is of either -// ServerError or ClientError type and returns the appropriate -// ResponseErr value. -func ResponseErrorType(err error) ResponseErr { +// responseErrorType determines whether an error is of either +// serverError or clientError type and returns the appropriate +// responseErr value. +func responseErrorType(err error) responseErr { switch err.(type) { - case *ServerError: - return ServerErr - case *ClientError: - return ClientErr + case *serverError: + return serverErr + case *clientError: + return clientErr default: - return NonRPCErr + return nonRPCErr } } -// IsRPCError returns whether an error is either a ServerError -// or ClientError. +// IsRPCError returns whether an error is either a serverError +// or clientError. func IsRPCError(err error) bool { switch err.(type) { - case *ServerError, *ClientError: + case *serverError, *clientError: + return true + default: + return false + } +} + +// IsServerError returns whether an error is serverError. +func IsServerError(err error) bool { + switch err.(type) { + case *serverError: + return true + default: + return false + } +} + +// IsClientError returns whether an error is clientError. +func IsClientError(err error) bool { + switch err.(type) { + case *clientError: return true default: return false diff --git a/server.go b/server.go index 6571c5a..fade0cf 100644 --- a/server.go +++ b/server.go @@ -96,7 +96,7 @@ type ServiceID struct { type Response struct { Service ServiceID Error string // error, if any. - ErrType ResponseErr + ErrType responseErr } // Server is an LibP2P RPC server. It can register services which comply to the @@ -129,7 +129,7 @@ func NewServer(h host.Host, p protocol.ID) *Server { err := s.handle(sWrap) if err != nil { logger.Error("error handling RPC:", err) - resp := &Response{ServiceID{}, err.Error(), ResponseErrorType(err)} + resp := &Response{ServiceID{}, err.Error(), responseErrorType(err)} sendResponse(sWrap, resp, nil) } }) @@ -153,14 +153,14 @@ func (server *Server) handle(s *streamWrap) error { err := s.dec.Decode(&svcID) if err != nil { - return NewServerError(err) + return newServerError(err) } logger.Debugf("RPC ServiceID is %s.%s", svcID.Name, svcID.Method) service, mtype, err := server.getService(svcID) if err != nil { - return NewServerError(err) + return newServerError(err) } // Decode the argument value. @@ -173,7 +173,7 @@ func (server *Server) handle(s *streamWrap) error { } // argv guaranteed to be a pointer now. if err = s.dec.Decode(argv.Interface()); err != nil { - return NewServerError(err) + return newServerError(err) } if argIsValue { argv = argv.Elem() @@ -217,7 +217,7 @@ func (s *service) svcCall(sWrap *streamWrap, mtype *methodType, svcID ServiceID, if errInter != nil { errmsg = errInter.(error).Error() } - resp := &Response{svcID, errmsg, NonRPCErr} + resp := &Response{svcID, errmsg, nonRPCErr} return sendResponse(sWrap, resp, replyv.Interface()) } @@ -246,7 +246,7 @@ func (server *Server) Call(call *Call) error { var argv, replyv reflect.Value service, mtype, err := server.getService(call.SvcID) if err != nil { - return NewServerError(err) + return newServerError(err) } // Use the context value from the call directly @@ -314,12 +314,12 @@ func (server *Server) getService(id ServiceID) (*service, *methodType, error) { server.mu.RUnlock() if service == nil { err := errors.New("rpc: can't find service " + id.Name) - return nil, nil, NewServerError(err) + return nil, nil, newServerError(err) } mtype := service.method[id.Method] if mtype == nil { err := errors.New("rpc: can't find method " + id.Method) - return nil, nil, NewServerError(err) + return nil, nil, newServerError(err) } return service, mtype, nil }