Skip to content

Commit

Permalink
Fix some race conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
hsanjuan committed Mar 7, 2018
1 parent 1408c7e commit 93c39c3
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 14 deletions.
18 changes: 16 additions & 2 deletions call.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ type Call struct {
SvcID ServiceID // The name of the service and method to call.
Args interface{} // The argument to the function (*struct).
Reply interface{} // The reply from the function (*struct).
Error error // After completion, the error status.
Done chan *Call // Strobes when call is complete.

errorMu sync.Mutex
Error error // After completion, the error status.

}

func newCall(ctx context.Context, dest peer.ID, svcName, svcMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
Expand Down Expand Up @@ -57,7 +60,10 @@ func (call *Call) done() {
}

func (call *Call) doneWithError(err error) {
call.Error = err
if err != nil {
logger.Error(err)
call.setError(err)
}
call.done()
}

Expand All @@ -82,3 +88,11 @@ func (call *Call) watchContextWithStream(s inet.Stream) {
}
}
}

func (call *Call) setError(err error) {
call.errorMu.Lock()
defer call.errorMu.Unlock()
if call.Error == nil {
call.Error = err
}
}
11 changes: 3 additions & 8 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,11 @@ func (c *Client) makeCall(call *Call) {
if c.server == nil {
err := errors.New(
"Cannot make local calls: server not set")
logger.Error(err)
call.doneWithError(err)
return
}
err := c.server.Call(call)
call.Error = err
if err != nil {
logger.Error(err)
}
call.done()
call.doneWithError(err)
return
}

Expand Down Expand Up @@ -185,13 +180,13 @@ func receiveResponse(s *streamWrap, call *Call) {

defer call.done()
if e := resp.Error; e != "" {
call.Error = errors.New(e)
call.setError(errors.New(e))
}

// Even on error we sent the reply so it needs to be
// read
if err := s.dec.Decode(call.Reply); err != nil && err != io.EOF {
call.Error = err
call.setError(err)
}
return
}
19 changes: 15 additions & 4 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -31,7 +32,15 @@ type Quotient struct {
}

type Arith struct {
ctxCancelled bool
sleepCancelledMu sync.Mutex
sleepCancelled bool
}

// helper to see if we cancelled the context in Sleep()
func (t *Arith) isSleepCancelled() bool {
t.sleepCancelledMu.Lock()
defer t.sleepCancelledMu.Unlock()
return t.sleepCancelled
}

func (t *Arith) Multiply(ctx context.Context, args *Args, reply *int) error {
Expand Down Expand Up @@ -63,7 +72,9 @@ func (t *Arith) Sleep(ctx context.Context, secs int, res *struct{}) error {
tim := time.NewTimer(time.Duration(secs) * time.Second)
select {
case <-ctx.Done():
t.ctxCancelled = true
t.sleepCancelledMu.Lock()
t.sleepCancelled = true
t.sleepCancelledMu.Unlock()
return ctx.Err()
case <-tim.C:
return nil
Expand Down Expand Up @@ -266,7 +277,7 @@ func TestCallContextLocal(t *testing.T) {

time.Sleep(200 * time.Millisecond)

if !arith.ctxCancelled {
if !arith.isSleepCancelled() {
t.Error("expected ctx cancellation in the function")
}
}
Expand Down Expand Up @@ -294,7 +305,7 @@ func TestCallContextRemote(t *testing.T) {

time.Sleep(200 * time.Millisecond)

if !arith.ctxCancelled {
if !arith.isSleepCancelled() {
t.Error("expected ctx cancellation in the function")
}
}
Expand Down

0 comments on commit 93c39c3

Please sign in to comment.