diff --git a/call.go b/call.go index 6c17622..e86fbc1 100644 --- a/call.go +++ b/call.go @@ -73,18 +73,11 @@ func (call *Call) watchContextWithStream(s inet.Stream) { select { case <-call.ctx.Done(): if !call.isFinished() { // context was cancelled not by us - s.Reset() - call.doneWithError(call.ctx.Err()) - } - } -} - -// watch context will wait for a context cancellation -// and close the stream. -func (call *Call) watchContext() { - select { - case <-call.ctx.Done(): - if !call.isFinished() { // context was cancelled not by us + logger.Debug("call context is done before finishing") + // Close() instead of Reset(). This let's the other + // write to the stream without printing errors to + // the console (graceful fail). + s.Close() call.doneWithError(call.ctx.Err()) } } diff --git a/client.go b/client.go index c02b8ac..f244906 100644 --- a/client.go +++ b/client.go @@ -137,10 +137,10 @@ func (c *Client) makeCall(call *Call) { // destination and waiting for a response. func (c *Client) send(call *Call) { logger.Debug("sending remote call") + s, err := c.host.NewStream(call.ctx, call.Dest, c.protocol) if err != nil { - call.Error = err - call.Done <- call + call.doneWithError(err) return } defer s.Close() diff --git a/server.go b/server.go index 696151d..8950758 100644 --- a/server.go +++ b/server.go @@ -1,6 +1,7 @@ /* Package rpc is heavily inspired by Go standard net/rpc package. It aims to do -the same function, except it uses Libp2p for communication. +the same function, except it uses Libp2p for communication and provides +context support for cancelling operations. A server registers an object, making it visible as a service with the name of the type of the object. After registration, exported methods of the object @@ -12,13 +13,15 @@ Only methods that satisfy these criteria will be made available for remote access; other methods will be ignored: - the method's type is exported. - the method is exported. - - the method has two arguments, both exported (or builtin) types. + - the method has 3 arguments. + - the method's first argument is a context. + - the method's second are third arguments are both exported (or builtin) types. - the method's second argument is a pointer. - the method has return type error. In effect, the method must look schematically like - func (t *T) MethodName(argType T1, replyType *T2) error + func (t *T) MethodName(ctx context.Context, argType T1, replyType *T2) error where T1 and T2 can be marshaled by encoding/gob. @@ -26,17 +29,25 @@ The method's first argument represents the arguments provided by the caller; the second argument represents the result parameters to be returned to the caller. The method's return value, if non-nil, is passed back as a string that the client sees as if created by errors.New. If an error is returned, -the reply parameter will not be sent back to the client. +the reply parameter may not be sent back to the client. In order to use this package, a ready-to-go LibP2P Host must be provided to clients and servers, along with a protocol.ID. rpc will add a stream handler for the given protocol. Hosts must be ready to speak to clients, that is, peers must be part of the peerstore along with keys if secio communication is required. + +Since version 2.0.0, contexts are supported and honored. On the server side, +methods must take a context. A closure or reset of the libp2p stream will +trigger a cancellation of the context received by the functions. +On the client side, the user can optionally provide a context. +Cancelling the client's context will cancel the operation both on the +client and on the server side (by closing the associated stream). */ package rpc import ( + "context" "errors" "fmt" "log" @@ -121,6 +132,7 @@ func NewServer(h host.Host, p protocol.ID) *Server { sendResponse(sWrap, resp, nil) } }) + } return s } @@ -168,15 +180,36 @@ func (server *Server) handle(s *streamWrap) error { replyv = reflect.New(mtype.ReplyType.Elem()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ctxv := reflect.ValueOf(ctx) + + // This is a connection watchdog. We do not + // need to read from this stream anymore. + // However we'd like to know if the other side is closed + // (or reset). In that case, we need to cancel our + // context. Note this will also happen at the end + // of a successful operation when we close the stream + // on our side. + go func() { + p := make([]byte, 1) + _, err := s.stream.Read(p) + if err != nil { + cancel() + } + }() + // Call service and respond - return service.svcCall(s, mtype, svcID, argv, replyv) + return service.svcCall(s, mtype, svcID, ctxv, argv, replyv) } // svcCall calls the actual method associated -func (s *service) svcCall(sWrap *streamWrap, mtype *methodType, svcID ServiceID, argv, replyv reflect.Value) error { +func (s *service) svcCall(sWrap *streamWrap, mtype *methodType, svcID ServiceID, ctxv, argv, replyv reflect.Value) error { function := mtype.method.Func + // Invoke the method, providing a new value for the reply. - returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv}) + returnValues := function.Call([]reflect.Value{s.rcvr, ctxv, argv, replyv}) // The return value for the method is an error. errInter := returnValues[0].Interface() errmsg := "" @@ -215,6 +248,9 @@ func (server *Server) Call(call *Call) error { return err } + // Use the context value from the call directly + ctxv := reflect.ValueOf(call.ctx) + // Decode the argument value. argIsValue := false // if true, need to indirect before calling. if mtype.ArgType.Kind() == reflect.Ptr { @@ -248,8 +284,9 @@ func (server *Server) Call(call *Call) error { // Invoke the method, providing a new value for the reply. returnValues := function.Call([]reflect.Value{ service.rcvr, - argv, - replyv}) + ctxv, // context + argv, // argument + replyv}) // reply creplyv := reflect.ValueOf(call.Reply) creplyv.Elem().Set(replyv.Elem()) @@ -379,23 +416,34 @@ func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType { if method.PkgPath != "" { continue } - // Method needs three ins: receiver, *args, *reply. - if mtype.NumIn() != 3 { + // Method needs four ins: receiver, context.Context, *args, *reply. + if mtype.NumIn() != 4 { if reportErr { log.Println("method", mname, "has wrong number of ins:", mtype.NumIn()) } continue } - // First arg need not be a pointer. - argType := mtype.In(1) + + // First argument needs to be a context + ctxType := mtype.In(1) + ctxIntType := reflect.TypeOf((*context.Context)(nil)).Elem() + if !ctxType.Implements(ctxIntType) { + if reportErr { + log.Println(mname, "first argument is not a context.Context:", ctxType) + } + continue + } + + // Second arg need not be a pointer so that's not checked. + argType := mtype.In(2) if !isExportedOrBuiltinType(argType) { if reportErr { log.Println(mname, "argument type not exported:", argType) } continue } - // Second arg must be a pointer. - replyType := mtype.In(2) + // Third arg must be a pointer. + replyType := mtype.In(3) if replyType.Kind() != reflect.Ptr { if reportErr { log.Println("method", mname, "reply type not a pointer:", replyType) diff --git a/server_test.go b/server_test.go index d130032..28deb98 100644 --- a/server_test.go +++ b/server_test.go @@ -30,20 +30,22 @@ type Quotient struct { Quo, Rem int } -type Arith int +type Arith struct { + ctxCancelled bool +} -func (t *Arith) Multiply(args *Args, reply *int) error { +func (t *Arith) Multiply(ctx context.Context, args *Args, reply *int) error { *reply = args.A * args.B return nil } // This uses non pointer args -func (t *Arith) Add(args Args, reply *int) error { +func (t *Arith) Add(ctx context.Context, args Args, reply *int) error { *reply = args.A + args.B return nil } -func (t *Arith) Divide(args *Args, quo *Quotient) error { +func (t *Arith) Divide(ctx context.Context, args *Args, quo *Quotient) error { if args.B == 0 { return errors.New("divide by zero") } @@ -52,13 +54,20 @@ func (t *Arith) Divide(args *Args, quo *Quotient) error { return nil } -func (t *Arith) GimmeError(args *Args, r *int) error { +func (t *Arith) GimmeError(ctx context.Context, args *Args, r *int) error { *r = 42 return errors.New("an error") } -func (t *Arith) Sleep(secs int, res *struct{}) error { - time.Sleep(time.Duration(secs) * time.Second) +func (t *Arith) Sleep(ctx context.Context, secs int, res *struct{}) error { + tim := time.NewTimer(time.Duration(secs) * time.Second) + select { + case <-ctx.Done(): + t.ctxCancelled = true + return ctx.Err() + case <-tim.C: + return nil + } return nil } @@ -234,7 +243,7 @@ func TestErrorResponse(t *testing.T) { } } -func TestCallWithContext(t *testing.T) { +func TestCallWithContextLocal(t *testing.T) { h1, h2 := makeRandomNodes() defer h1.Close() defer h2.Close() @@ -243,8 +252,37 @@ func TestCallWithContext(t *testing.T) { var arith Arith s.Register(&arith) + // Local ctx, cancel := context.WithTimeout(context.Background(), time.Second/2) defer cancel() + err := c.CallWithContext(ctx, h2.ID(), "Arith", "Sleep", 5, &struct{}{}) + if err == nil { + t.Fatal("expected an error") + } + + if !strings.Contains(err.Error(), "context") { + t.Error("expected a context error:", err) + } + + time.Sleep(200 * time.Millisecond) + + if !arith.ctxCancelled { + t.Error("expected ctx cancellation in the function") + } +} + +func TestCallWithContextRemote(t *testing.T) { + h1, h2 := makeRandomNodes() + defer h1.Close() + defer h2.Close() + s := NewServer(h1, "rpc") + c := NewClient(h2, "rpc") + var arith Arith + s.Register(&arith) + + // Local + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() err := c.CallWithContext(ctx, h1.ID(), "Arith", "Sleep", 5, &struct{}{}) if err == nil { t.Fatal("expected an error") @@ -253,6 +291,12 @@ func TestCallWithContext(t *testing.T) { if !strings.Contains(err.Error(), "context") { t.Error("expected a context error:", err) } + + time.Sleep(200 * time.Millisecond) + + if !arith.ctxCancelled { + t.Error("expected ctx cancellation in the function") + } } func TestGoWithContext(t *testing.T) {