Skip to content

Commit

Permalink
Merge f4d75b0 into ae156c4
Browse files Browse the repository at this point in the history
  • Loading branch information
lanzafame committed Jun 27, 2018
2 parents ae156c4 + f4d75b0 commit 962492b
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 53 deletions.
1 change: 0 additions & 1 deletion call.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ type Call struct {

errorMu sync.Mutex
Error error // After completion, the error status.

}

func newCall(ctx context.Context, dest peer.ID, svcName, svcMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
Expand Down
47 changes: 29 additions & 18 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package rpc

import (
"context"
"errors"
"io"
"sync"

Expand Down Expand Up @@ -227,17 +226,21 @@ func checkMatchingLengths(l ...int) bool {
// makeCall decides if a call can be performed. If it's a local
// call it will use the configured server if set.
func (c *Client) makeCall(call *Call) {
logger.Debugf("makeCall: %s.%s",
logger.Debugf(
"makeCall: %s.%s",
call.SvcID.Name,
call.SvcID.Method)
call.SvcID.Method,
)

// Handle local RPC calls
if call.Dest == "" || call.Dest == c.host.ID() {
logger.Debugf("local call: %s.%s",
call.SvcID.Name, call.SvcID.Method)
logger.Debugf(
"local call: %s.%s",
call.SvcID.Name,
call.SvcID.Method,
)
if c.server == nil {
err := errors.New(
"Cannot make local calls: server not set")
err := &clientError{"Cannot make local calls: server not set"}
call.doneWithError(err)
return
}
Expand All @@ -263,29 +266,33 @@ func (c *Client) send(call *Call) {

s, err := c.host.NewStream(call.ctx, call.Dest, c.protocol)
if err != nil {
call.doneWithError(err)
call.doneWithError(newClientError(err))
return
}
defer s.Close()
go call.watchContextWithStream(s)

sWrap := wrapStream(s)

logger.Debugf("sending RPC %s.%s to %s", call.SvcID.Name,
call.SvcID.Method, call.Dest)
logger.Debugf(
"sending RPC %s.%s to %s",
call.SvcID.Name,
call.SvcID.Method,
call.Dest,
)
if err := sWrap.enc.Encode(call.SvcID); err != nil {
call.doneWithError(err)
call.doneWithError(newClientError(err))
s.Reset()
return
}
if err := sWrap.enc.Encode(call.Args); err != nil {
call.doneWithError(err)
call.doneWithError(newClientError(err))
s.Reset()
return
}

if err := sWrap.w.Flush(); err != nil {
call.doneWithError(err)
call.doneWithError(newClientError(err))
s.Reset()
return
}
Expand All @@ -294,24 +301,28 @@ func (c *Client) send(call *Call) {

// receiveResponse reads a response to an RPC call
func receiveResponse(s *streamWrap, call *Call) {
logger.Debugf("waiting response for %s.%s to %s", call.SvcID.Name,
call.SvcID.Method, call.Dest)
logger.Debugf(
"waiting response for %s.%s to %s",
call.SvcID.Name,
call.SvcID.Method,
call.Dest,
)
var resp Response
if err := s.dec.Decode(&resp); err != nil {
call.doneWithError(err)
call.doneWithError(newClientError(err))
s.stream.Reset()
return
}

defer call.done()
if e := resp.Error; e != "" {
call.setError(errors.New(e))
call.setError(responseError(resp.ErrType, e))
}

// Even on error we sent the reply so it needs to be
// read
if err := s.dec.Decode(call.Reply); err != nil && err != io.EOF {
call.setError(err)
call.setError(newClientError(err))
}
return
}
104 changes: 104 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package rpc

import "errors"

// responseErr is an enum type for providing error type
// information over the wire between rpc server and client.
type responseErr int

const (
// nonRPCErr is an error that hasn't arisen from the gorpc package.
nonRPCErr responseErr = iota
// serverErr is an error that has arisen on the server side.
serverErr
// clientErr is an error that has arisen on the client side.
clientErr
)

// serverError indicates that error originated in server
// specific code.
type serverError struct {
msg string
}

func (s *serverError) Error() string {
return s.msg
}

// newServerError wraps an error in the serverError type.
func newServerError(err error) error {
return &serverError{err.Error()}
}

// clientError indicates that error originated in client
// specific code.
type clientError struct {
msg string
}

func (c *clientError) Error() string {
return c.msg
}

// newClientError wraps an error in the clientError type.
func newClientError(err error) error {
return &clientError{err.Error()}
}

// responseError converts an responseErr and error message string
// into the appropriate error type.
func responseError(errType responseErr, errMsg string) error {
switch errType {
case serverErr:
return &serverError{errMsg}
case clientErr:
return &clientError{errMsg}
default:
return errors.New(errMsg)
}
}

// responseErrorType determines whether an error is of either
// serverError or clientError type and returns the appropriate
// responseErr value.
func responseErrorType(err error) responseErr {
switch err.(type) {
case *serverError:
return serverErr
case *clientError:
return clientErr
default:
return nonRPCErr
}
}

// IsRPCError returns whether an error is either a serverError
// or clientError.
func IsRPCError(err error) bool {
switch err.(type) {
case *serverError, *clientError:
return true
default:
return false
}
}

// IsServerError returns whether an error is serverError.
func IsServerError(err error) bool {
switch err.(type) {
case *serverError:
return true
default:
return false
}
}

// IsClientError returns whether an error is clientError.
func IsClientError(err error) bool {
switch err.(type) {
case *clientError:
return true
default:
return false
}
}
38 changes: 23 additions & 15 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ type ServiceID struct {
type Response struct {
Service ServiceID
Error string // error, if any.
ErrType responseErr
}

// Server is an LibP2P RPC server. It can register services which comply to the
Expand Down Expand Up @@ -128,7 +129,7 @@ func NewServer(h host.Host, p protocol.ID) *Server {
err := s.handle(sWrap)
if err != nil {
logger.Error("error handling RPC:", err)
resp := &Response{ServiceID{}, err.Error()}
resp := &Response{ServiceID{}, err.Error(), responseErrorType(err)}
sendResponse(sWrap, resp, nil)
}
})
Expand All @@ -152,14 +153,14 @@ func (server *Server) handle(s *streamWrap) error {

err := s.dec.Decode(&svcID)
if err != nil {
return err
return newServerError(err)
}

logger.Debugf("RPC ServiceID is %s.%s", svcID.Name, svcID.Method)

service, mtype, err := server.getService(svcID)
if err != nil {
return err
return newServerError(err)
}

// Decode the argument value.
Expand All @@ -172,7 +173,7 @@ func (server *Server) handle(s *streamWrap) error {
}
// argv guaranteed to be a pointer now.
if err = s.dec.Decode(argv.Interface()); err != nil {
return err
return newServerError(err)
}
if argIsValue {
argv = argv.Elem()
Expand Down Expand Up @@ -216,7 +217,7 @@ func (s *service) svcCall(sWrap *streamWrap, mtype *methodType, svcID ServiceID,
if errInter != nil {
errmsg = errInter.(error).Error()
}
resp := &Response{svcID, errmsg}
resp := &Response{svcID, errmsg, nonRPCErr}

return sendResponse(sWrap, resp, replyv.Interface())
}
Expand Down Expand Up @@ -245,7 +246,7 @@ func (server *Server) Call(call *Call) error {
var argv, replyv reflect.Value
service, mtype, err := server.getService(call.SvcID)
if err != nil {
return err
return newServerError(err)
}

// Use the context value from the call directly
Expand All @@ -257,15 +258,19 @@ func (server *Server) Call(call *Call) error {
if reflect.TypeOf(call.Args).Kind() != reflect.Ptr {
return fmt.Errorf(
"%s.%s is being called with the wrong arg type",
call.SvcID.Name, call.SvcID.Method)
call.SvcID.Name,
call.SvcID.Method,
)
}
argv = reflect.New(mtype.ArgType.Elem())
argv.Elem().Set(reflect.ValueOf(call.Args).Elem())
} else {
if reflect.TypeOf(call.Args).Kind() == reflect.Ptr {
return fmt.Errorf(
"%s.%s is being called with the wrong arg type",
call.SvcID.Name, call.SvcID.Method)
call.SvcID.Name,
call.SvcID.Method,
)
}
argv = reflect.New(mtype.ArgType)
argv.Elem().Set(reflect.ValueOf(call.Args))
Expand All @@ -282,11 +287,14 @@ func (server *Server) Call(call *Call) error {
// Call service and respond
function := mtype.method.Func
// Invoke the method, providing a new value for the reply.
returnValues := function.Call([]reflect.Value{
service.rcvr,
ctxv, // context
argv, // argument
replyv}) // reply
returnValues := function.Call(
[]reflect.Value{
service.rcvr,
ctxv, // context
argv, // argument
replyv,
},
) // reply

creplyv := reflect.ValueOf(call.Reply)
creplyv.Elem().Set(replyv.Elem())
Expand All @@ -306,12 +314,12 @@ func (server *Server) getService(id ServiceID) (*service, *methodType, error) {
server.mu.RUnlock()
if service == nil {
err := errors.New("rpc: can't find service " + id.Name)
return nil, nil, err
return nil, nil, newServerError(err)
}
mtype := service.method[id.Method]
if mtype == nil {
err := errors.New("rpc: can't find method " + id.Method)
return nil, nil, err
return nil, nil, newServerError(err)
}
return service, mtype, nil
}
Expand Down

0 comments on commit 962492b

Please sign in to comment.