Skip to content

Commit

Permalink
Add contexts on the server side
Browse files Browse the repository at this point in the history
This makes server calls to RPC methods context-aware.

In the case of regular RPC-via-libp2p calls, a context is
created and associated to the incoming stream. If the stream
is closed or reset, the context will be cancelled. It is
up to the method's implemented to decide what to do with that.

In the case of RPC-shortcut for local calls, the context optionally
provided by the client is used directly.
  • Loading branch information
hsanjuan committed Mar 6, 2018
1 parent 484fb58 commit e82c1a9
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 37 deletions.
17 changes: 5 additions & 12 deletions call.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,11 @@ func (call *Call) watchContextWithStream(s inet.Stream) {
select {
case <-call.ctx.Done():
if !call.isFinished() { // context was cancelled not by us
s.Reset()
call.doneWithError(call.ctx.Err())
}
}
}

// watch context will wait for a context cancellation
// and close the stream.
func (call *Call) watchContext() {
select {
case <-call.ctx.Done():
if !call.isFinished() { // context was cancelled not by us
logger.Debug("call context is done before finishing")
// Close() instead of Reset(). This let's the other
// write to the stream without printing errors to
// the console (graceful fail).
s.Close()
call.doneWithError(call.ctx.Err())
}
}
Expand Down
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ func (c *Client) makeCall(call *Call) {
// destination and waiting for a response.
func (c *Client) send(call *Call) {
logger.Debug("sending remote call")

s, err := c.host.NewStream(call.ctx, call.Dest, c.protocol)
if err != nil {
call.Error = err
call.Done <- call
call.doneWithError(err)
return
}
defer s.Close()
Expand Down
78 changes: 63 additions & 15 deletions server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*
Package rpc is heavily inspired by Go standard net/rpc package. It aims to do
the same function, except it uses Libp2p for communication.
the same function, except it uses Libp2p for communication and provides
context support for cancelling operations.
A server registers an object, making it visible as a service with the name of
the type of the object. After registration, exported methods of the object
Expand All @@ -12,31 +13,41 @@ Only methods that satisfy these criteria will be made available for remote
access; other methods will be ignored:
- the method's type is exported.
- the method is exported.
- the method has two arguments, both exported (or builtin) types.
- the method has 3 arguments.
- the method's first argument is a context.
- the method's second are third arguments are both exported (or builtin) types.
- the method's second argument is a pointer.
- the method has return type error.
In effect, the method must look schematically like
func (t *T) MethodName(argType T1, replyType *T2) error
func (t *T) MethodName(ctx context.Context, argType T1, replyType *T2) error
where T1 and T2 can be marshaled by encoding/gob.
The method's first argument represents the arguments provided by the caller;
the second argument represents the result parameters to be returned to the
caller. The method's return value, if non-nil, is passed back as a string
that the client sees as if created by errors.New. If an error is returned,
the reply parameter will not be sent back to the client.
the reply parameter may not be sent back to the client.
In order to use this package, a ready-to-go LibP2P Host must be provided
to clients and servers, along with a protocol.ID. rpc will add a stream
handler for the given protocol. Hosts must be ready to speak to clients,
that is, peers must be part of the peerstore along with keys if secio
communication is required.
Since version 2.0.0, contexts are supported and honored. On the server side,
methods must take a context. A closure or reset of the libp2p stream will
trigger a cancellation of the context received by the functions.
On the client side, the user can optionally provide a context.
Cancelling the client's context will cancel the operation both on the
client and on the server side (by closing the associated stream).
*/
package rpc

import (
"context"
"errors"
"fmt"
"log"
Expand Down Expand Up @@ -121,6 +132,7 @@ func NewServer(h host.Host, p protocol.ID) *Server {
sendResponse(sWrap, resp, nil)
}
})

}
return s
}
Expand Down Expand Up @@ -168,15 +180,36 @@ func (server *Server) handle(s *streamWrap) error {

replyv = reflect.New(mtype.ReplyType.Elem())

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ctxv := reflect.ValueOf(ctx)

// This is a connection watchdog. We do not
// need to read from this stream anymore.
// However we'd like to know if the other side is closed
// (or reset). In that case, we need to cancel our
// context. Note this will also happen at the end
// of a successful operation when we close the stream
// on our side.
go func() {
p := make([]byte, 1)
_, err := s.stream.Read(p)
if err != nil {
cancel()
}
}()

// Call service and respond
return service.svcCall(s, mtype, svcID, argv, replyv)
return service.svcCall(s, mtype, svcID, ctxv, argv, replyv)
}

// svcCall calls the actual method associated
func (s *service) svcCall(sWrap *streamWrap, mtype *methodType, svcID ServiceID, argv, replyv reflect.Value) error {
func (s *service) svcCall(sWrap *streamWrap, mtype *methodType, svcID ServiceID, ctxv, argv, replyv reflect.Value) error {
function := mtype.method.Func

// Invoke the method, providing a new value for the reply.
returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
returnValues := function.Call([]reflect.Value{s.rcvr, ctxv, argv, replyv})
// The return value for the method is an error.
errInter := returnValues[0].Interface()
errmsg := ""
Expand Down Expand Up @@ -215,6 +248,9 @@ func (server *Server) Call(call *Call) error {
return err
}

// Use the context value from the call directly
ctxv := reflect.ValueOf(call.ctx)

// Decode the argument value.
argIsValue := false // if true, need to indirect before calling.
if mtype.ArgType.Kind() == reflect.Ptr {
Expand Down Expand Up @@ -248,8 +284,9 @@ func (server *Server) Call(call *Call) error {
// Invoke the method, providing a new value for the reply.
returnValues := function.Call([]reflect.Value{
service.rcvr,
argv,
replyv})
ctxv, // context
argv, // argument
replyv}) // reply

creplyv := reflect.ValueOf(call.Reply)
creplyv.Elem().Set(replyv.Elem())
Expand Down Expand Up @@ -379,23 +416,34 @@ func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
if method.PkgPath != "" {
continue
}
// Method needs three ins: receiver, *args, *reply.
if mtype.NumIn() != 3 {
// Method needs four ins: receiver, context.Context, *args, *reply.
if mtype.NumIn() != 4 {
if reportErr {
log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
}
continue
}
// First arg need not be a pointer.
argType := mtype.In(1)

// First argument needs to be a context
ctxType := mtype.In(1)
ctxIntType := reflect.TypeOf((*context.Context)(nil)).Elem()
if !ctxType.Implements(ctxIntType) {
if reportErr {
log.Println(mname, "first argument is not a context.Context:", ctxType)
}
continue
}

// Second arg need not be a pointer so that's not checked.
argType := mtype.In(2)
if !isExportedOrBuiltinType(argType) {
if reportErr {
log.Println(mname, "argument type not exported:", argType)
}
continue
}
// Second arg must be a pointer.
replyType := mtype.In(2)
// Third arg must be a pointer.
replyType := mtype.In(3)
if replyType.Kind() != reflect.Ptr {
if reportErr {
log.Println("method", mname, "reply type not a pointer:", replyType)
Expand Down
60 changes: 52 additions & 8 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,22 @@ type Quotient struct {
Quo, Rem int
}

type Arith int
type Arith struct {
ctxCancelled bool
}

func (t *Arith) Multiply(args *Args, reply *int) error {
func (t *Arith) Multiply(ctx context.Context, args *Args, reply *int) error {
*reply = args.A * args.B
return nil
}

// This uses non pointer args
func (t *Arith) Add(args Args, reply *int) error {
func (t *Arith) Add(ctx context.Context, args Args, reply *int) error {
*reply = args.A + args.B
return nil
}

func (t *Arith) Divide(args *Args, quo *Quotient) error {
func (t *Arith) Divide(ctx context.Context, args *Args, quo *Quotient) error {
if args.B == 0 {
return errors.New("divide by zero")
}
Expand All @@ -52,13 +54,20 @@ func (t *Arith) Divide(args *Args, quo *Quotient) error {
return nil
}

func (t *Arith) GimmeError(args *Args, r *int) error {
func (t *Arith) GimmeError(ctx context.Context, args *Args, r *int) error {
*r = 42
return errors.New("an error")
}

func (t *Arith) Sleep(secs int, res *struct{}) error {
time.Sleep(time.Duration(secs) * time.Second)
func (t *Arith) Sleep(ctx context.Context, secs int, res *struct{}) error {
tim := time.NewTimer(time.Duration(secs) * time.Second)
select {
case <-ctx.Done():
t.ctxCancelled = true
return ctx.Err()
case <-tim.C:
return nil
}
return nil
}

Expand Down Expand Up @@ -234,7 +243,7 @@ func TestErrorResponse(t *testing.T) {
}
}

func TestCallWithContext(t *testing.T) {
func TestCallWithContextLocal(t *testing.T) {
h1, h2 := makeRandomNodes()
defer h1.Close()
defer h2.Close()
Expand All @@ -243,8 +252,37 @@ func TestCallWithContext(t *testing.T) {
var arith Arith
s.Register(&arith)

// Local
ctx, cancel := context.WithTimeout(context.Background(), time.Second/2)
defer cancel()
err := c.CallWithContext(ctx, h2.ID(), "Arith", "Sleep", 5, &struct{}{})
if err == nil {
t.Fatal("expected an error")
}

if !strings.Contains(err.Error(), "context") {
t.Error("expected a context error:", err)
}

time.Sleep(200 * time.Millisecond)

if !arith.ctxCancelled {
t.Error("expected ctx cancellation in the function")
}
}

func TestCallWithContextRemote(t *testing.T) {
h1, h2 := makeRandomNodes()
defer h1.Close()
defer h2.Close()
s := NewServer(h1, "rpc")
c := NewClient(h2, "rpc")
var arith Arith
s.Register(&arith)

// Local
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
err := c.CallWithContext(ctx, h1.ID(), "Arith", "Sleep", 5, &struct{}{})
if err == nil {
t.Fatal("expected an error")
Expand All @@ -253,6 +291,12 @@ func TestCallWithContext(t *testing.T) {
if !strings.Contains(err.Error(), "context") {
t.Error("expected a context error:", err)
}

time.Sleep(200 * time.Millisecond)

if !arith.ctxCancelled {
t.Error("expected ctx cancellation in the function")
}
}

func TestGoWithContext(t *testing.T) {
Expand Down

0 comments on commit e82c1a9

Please sign in to comment.