Skip to content

Commit

Permalink
Merge pull request #57 from maxatome/sort-reflect-value
Browse files Browse the repository at this point in the history
Sort []reflect.Value
  • Loading branch information
maxatome committed Feb 16, 2019
2 parents 565d2ec + 4825131 commit bb055aa
Show file tree
Hide file tree
Showing 15 changed files with 500 additions and 51 deletions.
2 changes: 1 addition & 1 deletion cmp_funcs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1934,7 +1934,7 @@ func ExampleCmpSmuggle_interface() {
t.Fatal(err)
}

// Do not check the struct itself, but it stringified form
// Do not check the struct itself, but its stringified form
ok := CmpSmuggle(t, gotTime, func(s fmt.Stringer) string {
return s.String()
}, "2018-05-23 12:13:14 +0000 UTC")
Expand Down
19 changes: 12 additions & 7 deletions equal.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,15 @@ func deepValueEqual(ctx ctxerr.Context, got, expected reflect.Value) (err *ctxer
expectedLen = expected.Len()
)

// Shortcut in boolean context
if ctx.BooleanError && gotLen != expectedLen {
return ctxerr.BooleanError
}

if got.Pointer() == expected.Pointer() {
return
if gotLen != expectedLen {
// Shortcut in boolean context
if ctx.BooleanError {
return ctxerr.BooleanError
}
} else {
if got.Pointer() == expected.Pointer() {
return
}
}

var maxLen int
Expand All @@ -193,6 +195,7 @@ func deepValueEqual(ctx ctxerr.Context, got, expected reflect.Value) (err *ctxer
if gotLen != expectedLen {
res := tdSetResult{
Kind: itemsSetResult,
// do not sort Extra/Mising here
}

if gotLen > expectedLen {
Expand Down Expand Up @@ -295,6 +298,7 @@ func deepValueEqual(ctx ctxerr.Context, got, expected reflect.Value) (err *ctxer
Summary: (tdSetResult{
Kind: keysSetResult,
Missing: notFoundKeys,
Sort: true,
}).Summary(),
})
}
Expand All @@ -308,6 +312,7 @@ func deepValueEqual(ctx ctxerr.Context, got, expected reflect.Value) (err *ctxer
Kind: keysSetResult,
Missing: notFoundKeys,
Extra: make([]reflect.Value, 0, got.Len()-len(foundKeys)),
Sort: true,
}

for _, vkey := range got.MapKeys() {
Expand Down
23 changes: 18 additions & 5 deletions equal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,20 @@ func TestEqualSlice(t *testing.T) {
checkOK(t, []int{1, 2}, []int{1, 2})

// Same pointer
array := [2]int{1, 2}
array := [...]int{2, 1, 4, 3}
checkOK(t, array[:], array[:])
checkOK(t, ([]int)(nil), ([]int)(nil))

// Same pointer, but not same len
checkError(t, array[:2], array[:],
expectedError{
Message: mustBe("comparing slices, from index #2"),
Path: mustBe("DATA"),
// Missing items are not sorted
Summary: mustBe(`Missing 2 items: (4,
3)`),
})

checkError(t, []int{1, 2}, []int{1, 2, 3},
expectedError{
Message: mustBe("comparing slices, from index #2"),
Expand Down Expand Up @@ -366,13 +376,16 @@ func TestEqualMap(t *testing.T) {
Summary: mustMatch(`Missing key:[^"]+"test"`),
})

checkError(t, map[string]int{"foo": 1, "bar": 4, "test+": 12},
map[string]int{"foo": 1, "bar": 4, "test-": 12},
// Extra and missing keys are sorted
checkError(t, map[string]int{"foo": 1, "bar": 4, "test1+": 12, "test2+": 13},
map[string]int{"foo": 1, "bar": 4, "test1-": 12, "test2-": 13},
expectedError{
Message: mustBe("comparing map"),
Path: mustBe("DATA"),
Summary: mustMatch(`Missing key:[^"]+"test-".*
Extra key:[^"]+"test\+"`),
Summary: mustBe(`Missing 2 keys: ("test1-",
"test2-")
Extra 2 keys: ("test1+",
"test2+")`),
})
}

Expand Down
2 changes: 1 addition & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2028,7 +2028,7 @@ func ExampleSmuggle_interface() {
t.Fatal(err)
}

// Do not check the struct itself, but it stringified form
// Do not check the struct itself, but its stringified form
ok := CmpDeeply(t, gotTime,
Smuggle(func(s fmt.Stringer) string {
return s.String()
Expand Down
218 changes: 218 additions & 0 deletions helpers/tdutil/sort.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
// Copyright (c) 2019, Maxime Soulé
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

package tdutil

import (
"math"
"reflect"
"sort"

"github.com/maxatome/go-testdeep/internal/visited"
)

// SortableValues is used to allow the sorting of a []reflect.Value
// slice. It is used with the standard sort package:
//
// vals := []reflect.Value{a, b, c, d}
// sort.Sort(SortableValues(vals))
// // vals contents now sorted
//
// Replace sort.Sort by sort.Stable for a stable sort. See sort documentation.
//
// Sorting rules are as follows:
// - nil is always lower
// - different types are sorted by their name
// - false is lesser than true
// - float and int numbers are sorted by their value
// - complex numbers are sorted by their real, then by their imaginary parts
// - strings are sorted by their value
// - map: shorter length is lesser, then sorted by address
// - functions, channels and unsafe pointer are sorted by their address
// - struct: comparison is spread to each field
// - pointer: comparison is spred to the pointed value
// - arrays: comparison is spread to each item
// - slice: comparison is spread to each item, then shorter length is lesser
// - interface: comparison is spread to the value
//
// Cyclic references are correctly handled.
func SortableValues(s []reflect.Value) sort.Interface {
r := &rValues{
Slice: s,
}
if len(s) > 1 {
r.Visited = visited.NewVisited()
}
return r
}

type rValues struct {
Visited visited.Visited
Slice []reflect.Value
}

func (v *rValues) Len() int {
return len(v.Slice)
}

func (v *rValues) Less(i, j int) bool {
return cmp(v.Visited, v.Slice[i], v.Slice[j]) < 0
}

func (v *rValues) Swap(i, j int) {
v.Slice[i], v.Slice[j] = v.Slice[j], v.Slice[i]
}

func cmpRet(less, gt bool) int {
if less {
return -1
}
if gt {
return 1
}
return 0
}

func cmpFloat(a, b float64) int {
if math.IsNaN(a) {
return -1
}
if math.IsNaN(b) {
return 1
}
return cmpRet(a < b, a > b)
}

// cmp returns -1 if a < b, 1 if a > b, 0 if a == b.
func cmp(v visited.Visited, a, b reflect.Value) int {
if !a.IsValid() {
if !b.IsValid() {
return 0
}
return -1
}
if !b.IsValid() {
return 1
}

if at, bt := a.Type(), b.Type(); at != bt {
sat, sbt := at.String(), bt.String()
return cmpRet(sat < sbt, sat > sbt)
}

// Avoid looping forever on cyclic references
if v.Record(a, b) {
return 0
}

switch a.Kind() {
case reflect.Bool:
if a.Bool() {
if b.Bool() {
return 0
}
return 1
}
if b.Bool() {
return -1
}
return 0

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
na, nb := a.Int(), b.Int()
return cmpRet(na < nb, na > nb)

case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64, reflect.Uintptr:
na, nb := a.Uint(), b.Uint()
return cmpRet(na < nb, na > nb)

case reflect.Float32, reflect.Float64:
return cmpFloat(a.Float(), b.Float())

case reflect.Complex64, reflect.Complex128:
na, nb := a.Complex(), b.Complex()
fa, fb := real(na), real(nb)
if r := cmpFloat(fa, fb); r != 0 {
return r
}
return cmpFloat(imag(na), imag(nb))

case reflect.String:
sa, sb := a.String(), b.String()
return cmpRet(sa < sb, sa > sb)

case reflect.Array:
for i := 0; i < a.Len(); i++ {
if r := cmp(v, a.Index(i), b.Index(i)); r != 0 {
return r
}
}
return 0

case reflect.Slice:
al, bl := a.Len(), b.Len()
maxl := al
if al > bl {
maxl = bl
}
for i := 0; i < maxl; i++ {
if r := cmp(v, a.Index(i), b.Index(i)); r != 0 {
return r
}
}
return cmpRet(al < bl, al > bl)

case reflect.Interface:
if a.IsNil() {
if b.IsNil() {
return 0
}
return -1
}
if b.IsNil() {
return 1
}
return cmp(v, a.Elem(), b.Elem())

case reflect.Struct:
for i, m := 0, a.NumField(); i < m; i++ {
if r := cmp(v, a.Field(i), b.Field(i)); r != 0 {
return r
}
}
return 0

case reflect.Ptr:
if a.Pointer() == b.Pointer() {
return 0
}
if a.IsNil() {
return -1
}
if b.IsNil() {
return 1
}
return cmp(v, a.Elem(), b.Elem())

case reflect.Map:
// consider shorter maps are before longer ones
al, bl := a.Len(), b.Len()
if r := cmpRet(al < bl, al > bl); r != 0 {
return r
}
// then fallback on pointers comparison. How to say a map is
// before another one otherwise?
fallthrough

case reflect.Func, reflect.Chan, reflect.UnsafePointer:
pa, pb := a.Pointer(), b.Pointer()
return cmpRet(pa < pb, pa > pb)

default:
panic("don't know how to compare " + a.Kind().String())
}
}

0 comments on commit bb055aa

Please sign in to comment.