diff --git a/td_all.go b/td_all.go index 2ab5242e..f2235563 100644 --- a/td_all.go +++ b/td_all.go @@ -62,56 +62,5 @@ func (a *tdAll) Match(ctx ctxerr.Context, got reflect.Value) (err *ctxerr.Error) } func (a *tdAll) TypeBehind() reflect.Type { - var ( - lastIfType, lastType, curType reflect.Type - severalIfTypes bool - ) - - // - for _, item := range a.items { - if !item.IsValid() { - return nil // no need to go further - } - - if item.Type().Implements(testDeeper) { - curType = item.Interface().(TestDeep).TypeBehind() - - // Ignore unknown TypeBehind - if curType == nil { - continue - } - - // Ignore interface pointers too (see Isa), but keep them in - // mind in case we encounter always the same interface pointer - if curType.Kind() == reflect.Ptr && - curType.Elem().Kind() == reflect.Interface { - if lastIfType == nil { - lastIfType = curType - } else if lastIfType != curType { - severalIfTypes = true - } - continue - } - } else { - curType = item.Type() - } - - if lastType != curType { - if lastType != nil { - return nil - } - lastType = curType - } - } - - // Only one type found - if lastType != nil { - return lastType - } - - // Only one interface type found - if lastIfType != nil && !severalIfTypes { - return lastIfType - } - return nil + return a.uniqTypeBehind() } diff --git a/td_any.go b/td_any.go index ff6a9d0a..34efca92 100644 --- a/td_any.go +++ b/td_any.go @@ -22,6 +22,10 @@ var _ TestDeep = &tdAny{} // Any operator compares data against several expected values. During // a match, at least one of them has to match to succeed. +// +// TypeBehind method can return a non-nil reflect.Type if all items +// known non-interface types are equal, or if only interface types +// are found (mostly issued from Isa()) and they are equal. func Any(expectedValues ...interface{}) TestDeep { return &tdAny{ tdList: newList(expectedValues...), @@ -44,3 +48,7 @@ func (a *tdAny) Match(ctx ctxerr.Context, got reflect.Value) *ctxerr.Error { Expected: a, }) } + +func (a *tdAny) TypeBehind() reflect.Type { + return a.uniqTypeBehind() +} diff --git a/td_any_test.go b/td_any_test.go index 0befacf1..1d1e85e7 100644 --- a/td_any_test.go +++ b/td_any_test.go @@ -7,6 +7,7 @@ package testdeep_test import ( + "fmt" "testing" "github.com/maxatome/go-testdeep" @@ -48,5 +49,37 @@ func TestAny(t *testing.T) { } func TestAnyTypeBehind(t *testing.T) { - equalTypes(t, testdeep.Any(6), nil) + equalTypes(t, testdeep.Any(6, nil), nil) + equalTypes(t, testdeep.Any(6, "toto"), nil) + + equalTypes(t, testdeep.Any(6, testdeep.Zero(), 7, 8), 26) + + // Always the same non-interface type (even if we encounter several + // interface types) + equalTypes(t, + testdeep.Any( + testdeep.Empty(), + 5, + testdeep.Isa((*error)(nil)), // interface type (in fact pointer to ...) + testdeep.Any(6, 7), + testdeep.Isa((*fmt.Stringer)(nil)), // interface type + 8), + 42) + + // Only one interface type + equalTypes(t, + testdeep.Any( + testdeep.Isa((*error)(nil)), + testdeep.Isa((*error)(nil)), + testdeep.Isa((*error)(nil)), + ), + (*error)(nil)) + + // Several interface types, cannot be sure + equalTypes(t, + testdeep.Any( + testdeep.Isa((*error)(nil)), + testdeep.Isa((*fmt.Stringer)(nil)), + ), + nil) } diff --git a/td_list.go b/td_list.go index cd8d456b..f7056736 100644 --- a/td_list.go +++ b/td_list.go @@ -32,3 +32,58 @@ func (l *tdList) String() string { return util.SliceToBuffer(bytes.NewBufferString(l.GetLocation().Func), l.items). String() } + +func (l *tdList) uniqTypeBehind() reflect.Type { + var ( + lastIfType, lastType, curType reflect.Type + severalIfTypes bool + ) + + // + for _, item := range l.items { + if !item.IsValid() { + return nil // no need to go further + } + + if item.Type().Implements(testDeeper) { + curType = item.Interface().(TestDeep).TypeBehind() + + // Ignore unknown TypeBehind + if curType == nil { + continue + } + + // Ignore interface pointers too (see Isa), but keep them in + // mind in case we encounter always the same interface pointer + if curType.Kind() == reflect.Ptr && + curType.Elem().Kind() == reflect.Interface { + if lastIfType == nil { + lastIfType = curType + } else if lastIfType != curType { + severalIfTypes = true + } + continue + } + } else { + curType = item.Type() + } + + if lastType != curType { + if lastType != nil { + return nil + } + lastType = curType + } + } + + // Only one type found + if lastType != nil { + return lastType + } + + // Only one interface type found + if lastIfType != nil && !severalIfTypes { + return lastIfType + } + return nil +}