Skip to content
This repository has been archived by the owner on Jun 27, 2023. It is now read-only.

Commit

Permalink
Add DoAndReturn as a complement to Do that does not ignore its return…
Browse files Browse the repository at this point in the history
… values.

Refactored how all the call actions work to make this easier.
  • Loading branch information
balshetzer committed Nov 9, 2017
1 parent 61503c5 commit 8321731
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 107 deletions.
131 changes: 84 additions & 47 deletions gomock/call.go
Expand Up @@ -25,12 +25,11 @@ import (
type Call struct {
t TestReporter // for triggering test failures on invalid call setup

receiver interface{} // the receiver of the method call
method string // the name of the method
methodType reflect.Type // the type of the method
args []Matcher // the args
rets []interface{} // the return values (if any)
origin string // file and line number of call setup
receiver interface{} // the receiver of the method call
method string // the name of the method
methodType reflect.Type // the type of the method
args []Matcher // the args
origin string // file and line number of call setup

preReqs []*Call // prerequisite calls

Expand All @@ -39,9 +38,10 @@ type Call struct {

numCalls int // actual number made

// Actions
doFunc reflect.Value
setArgs map[int]reflect.Value
// actions are called when this Call is called. Each action gets the args and
// can set the return values by returning a non-nil slice. Actions run in the
// order they are created.
actions []func([]interface{}) []interface{}
}

// AnyTimes allows the expectation to be called 0 or more times
Expand Down Expand Up @@ -70,11 +70,56 @@ func (c *Call) MaxTimes(n int) *Call {
return c
}

// Do declares the action to run when the call is matched.
// DoAndReturn declares the action to run when the call is matched.
// The return values from this function are returned by the mocked function.
// It takes an interface{} argument to support n-arity functions.
func (c *Call) DoAndReturn(f interface{}) *Call {
// TODO: Check arity and types here, rather than dying badly elsewhere.
v := reflect.ValueOf(f)

c.addAction(func(args []interface{}) []interface{} {
vargs := make([]reflect.Value, len(args))
ft := v.Type()
for i := 0; i < len(args); i++ {
if args[i] != nil {
vargs[i] = reflect.ValueOf(args[i])
} else {
// Use the zero value for the arg.
vargs[i] = reflect.Zero(ft.In(i))
}
}
vrets := v.Call(vargs)
rets := make([]interface{}, len(vrets))
for i, ret := range vrets {
rets[i] = ret.Interface()
}
return rets
})
return c
}

// Do declares the action to run when the call is matched. The function's
// return values are ignored to retain backward compatibility. To use the
// return values call DoAndReturn.
// It takes an interface{} argument to support n-arity functions.
func (c *Call) Do(f interface{}) *Call {
// TODO: Check arity and types here, rather than dying badly elsewhere.
c.doFunc = reflect.ValueOf(f)
v := reflect.ValueOf(f)

c.addAction(func(args []interface{}) []interface{} {
vargs := make([]reflect.Value, len(args))
ft := v.Type()
for i := 0; i < len(args); i++ {
if args[i] != nil {
vargs[i] = reflect.ValueOf(args[i])
} else {
// Use the zero value for the arg.
vargs[i] = reflect.Zero(ft.In(i))
}
v.Call(vargs)
}
return nil
})
return c
}

Expand Down Expand Up @@ -113,7 +158,10 @@ func (c *Call) Return(rets ...interface{}) *Call {
}
}

c.rets = rets
c.addAction(func([]interface{}) []interface{} {
return rets
})

return c
}

Expand All @@ -131,9 +179,6 @@ func (c *Call) SetArg(n int, value interface{}) *Call {
h.Helper()
}

if c.setArgs == nil {
c.setArgs = make(map[int]reflect.Value)
}
mt := c.methodType
// TODO: This will break on variadic methods.
// We will need to check those at invocation time.
Expand All @@ -159,7 +204,17 @@ func (c *Call) SetArg(n int, value interface{}) *Call {
c.t.Fatalf("SetArg(%d, ...) referring to argument of non-pointer non-interface non-slice type %v [%s]",
n, at, c.origin)
}
c.setArgs[n] = reflect.ValueOf(value)

c.addAction(func(args []interface{}) []interface{} {
v := reflect.ValueOf(value)
switch reflect.TypeOf(args[n]).Kind() {
case reflect.Slice:
setSlice(args[n], v)
default:
reflect.ValueOf(args[n]).Elem().Set(v)
}
return nil
})
return c
}

Expand Down Expand Up @@ -296,43 +351,21 @@ func (c *Call) dropPrereqs() (preReqs []*Call) {
return
}

func (c *Call) call(args []interface{}) (rets []interface{}, action func()) {
c.numCalls++

// Actions
if c.doFunc.IsValid() {
doArgs := make([]reflect.Value, len(args))
ft := c.doFunc.Type()
for i := 0; i < len(args); i++ {
if args[i] != nil {
doArgs[i] = reflect.ValueOf(args[i])
} else {
// Use the zero value for the arg.
doArgs[i] = reflect.Zero(ft.In(i))
}
}
action = func() { c.doFunc.Call(doArgs) }
}
for n, v := range c.setArgs {
switch reflect.TypeOf(args[n]).Kind() {
case reflect.Slice:
setSlice(args[n], v)
default:
reflect.ValueOf(args[n]).Elem().Set(v)
}
}

rets = c.rets
if rets == nil {
func (c *Call) defaultActions() []func([]interface{}) []interface{} {
return []func([]interface{}) []interface{}{func([]interface{}) []interface{} {
// Synthesize the zero value for each of the return args' types.
mt := c.methodType
rets = make([]interface{}, mt.NumOut())
rets := make([]interface{}, mt.NumOut())
for i := 0; i < mt.NumOut(); i++ {
rets[i] = reflect.Zero(mt.Out(i)).Interface()
}
}
return rets
}}
}

return
func (c *Call) call(args []interface{}) []func([]interface{}) []interface{} {
c.numCalls++
return c.actions
}

// InOrder declares that the given calls should occur in order.
Expand All @@ -348,3 +381,7 @@ func setSlice(arg interface{}, v reflect.Value) {
va.Index(i).Set(v.Index(i))
}
}

func (c *Call) addAction(action func([]interface{}) []interface{}) {
c.actions = append(c.actions, action)
}
32 changes: 0 additions & 32 deletions gomock/call_test.go
@@ -1,7 +1,6 @@
package gomock

import (
"reflect"
"testing"
)

Expand Down Expand Up @@ -50,34 +49,3 @@ func TestCall_After(t *testing.T) {
}
})
}

func TestCall_SetArg(t *testing.T) {
t.Run("SetArgSlice", func(t *testing.T) {
c := &Call{
methodType: reflect.TypeOf(func([]byte) {}),
t: &mockTestReporter{},
}
c.SetArg(0, []byte{1, 2, 3})

in := []byte{4, 5, 6}
c.call([]interface{}{in})

if in[0] != 1 || in[1] != 2 || in[2] != 3 {
t.Error("Expected SetArg() to modify input slice argument")
}
})

t.Run("SetArgPointer", func(t *testing.T) {
c := &Call{
methodType: reflect.TypeOf(func(*int) {}),
t: &mockTestReporter{},
}
c.SetArg(0, 42)

in := 43
c.call([]interface{}{&in})
if in != 42 {
t.Error("Expected SetArg() to modify value pointed to by argument")
}
})
}
57 changes: 29 additions & 28 deletions gomock/controller.go
Expand Up @@ -116,8 +116,7 @@ func (ctrl *Controller) RecordCall(receiver interface{}, method string, args ...
}
}
ctrl.t.Fatalf("gomock: failed finding method %s on %T", method, receiver)
// In case t.Fatalf does not panic.
panic(fmt.Sprintf("gomock: failed finding method %s on %T", method, receiver))
panic("unreachable")
}

func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {
Expand All @@ -140,6 +139,7 @@ func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method st

origin := callerInfo(2)
call := &Call{t: ctrl.t, receiver: receiver, method: method, methodType: methodType, args: margs, origin: origin, minCalls: 1, maxCalls: 1}
call.actions = call.defaultActions()

ctrl.expectedCalls.Add(call)
return call
Expand All @@ -150,36 +150,37 @@ func (ctrl *Controller) Call(receiver interface{}, method string, args ...interf
h.Helper()
}

ctrl.mu.Lock()
defer ctrl.mu.Unlock()
// Nest this code so we can use defer to make sure the lock is released.
actions := func() []func([]interface{}) []interface{} {
ctrl.mu.Lock()
defer ctrl.mu.Unlock()

expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args)
if err != nil {
origin := callerInfo(2)
ctrl.t.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, args, origin, err)
}
expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args)
if err != nil {
origin := callerInfo(2)
ctrl.t.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, args, origin, err)
}

// Two things happen here:
// * the matching call no longer needs to check prerequite calls,
// * and the prerequite calls are no longer expected, so remove them.
preReqCalls := expected.dropPrereqs()
for _, preReqCall := range preReqCalls {
ctrl.expectedCalls.Remove(preReqCall)
}
// Two things happen here:
// * the matching call no longer needs to check prerequite calls,
// * and the prerequite calls are no longer expected, so remove them.
preReqCalls := expected.dropPrereqs()
for _, preReqCall := range preReqCalls {
ctrl.expectedCalls.Remove(preReqCall)
}

rets, action := expected.call(args)
if expected.exhausted() {
ctrl.expectedCalls.Remove(expected)
}
actions := expected.call(args)
if expected.exhausted() {
ctrl.expectedCalls.Remove(expected)
}
return actions
}()

// Don't hold the lock while doing the call's action (if any)
// so that actions may execute concurrently.
// We use the deferred Unlock to capture any panics that happen above;
// here we add a deferred Lock to balance it.
ctrl.mu.Unlock()
defer ctrl.mu.Lock()
if action != nil {
action()
var rets []interface{}
for _, action := range actions {
if r := action(args); r != nil {
rets = r
}
}

return rets
Expand Down

0 comments on commit 8321731

Please sign in to comment.