Skip to content

Commit

Permalink
Merge pull request #6 from d4l3k/recursive-structures
Browse files Browse the repository at this point in the history
Added support for recursive data structures
  • Loading branch information
d4l3k committed May 24, 2016
2 parents f96923b + 10334d6 commit 376d1c3
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 45 deletions.
108 changes: 77 additions & 31 deletions messagediff.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"reflect"
"sort"
"strings"
"unsafe"
)

// PrettyDiff does a deep comparison and returns the nicely formated results.
Expand All @@ -26,43 +27,87 @@ func PrettyDiff(a, b interface{}) (string, bool) {

// DeepDiff does a deep comparison and returns the results.
func DeepDiff(a, b interface{}) (*Diff, bool) {
d := newdiff()
return d, diff(a, b, nil, d)
d := newDiff()
return d, d.diff(reflect.ValueOf(a), reflect.ValueOf(b), nil)
}

func newdiff() *Diff {
func newDiff() *Diff {
return &Diff{
Added: make(map[*Path]interface{}),
Removed: make(map[*Path]interface{}),
Modified: make(map[*Path]interface{}),
visited: make(map[visit]bool),
}
}

func diff(a, b interface{}, path Path, d *Diff) bool {
aVal := reflect.ValueOf(a)
bVal := reflect.ValueOf(b)
func (d *Diff) diff(aVal, bVal reflect.Value, path Path) bool {
// Validity checks. Should only trigger if nil is one of the original arguments.
if !aVal.IsValid() && !bVal.IsValid() {
// Both are nil.
return true
}
if !aVal.IsValid() || !bVal.IsValid() {
// One is nil and the other isn't.
d.Modified[&path] = b
if !bVal.IsValid() {
d.Modified[&path] = nil
return false
} else if !aVal.IsValid() {
d.Modified[&path] = bVal.Interface()
return false
}

if aVal.Type() != bVal.Type() {
d.Modified[&path] = b
d.Modified[&path] = bVal.Interface()
return false
}
kind := aVal.Type().Kind()
kind := aVal.Kind()

// Borrowed from the reflect package to handle recursive data structures.
hard := func(k reflect.Kind) bool {
switch k {
case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
return true
}
return false
}

if aVal.CanAddr() && bVal.CanAddr() && hard(kind) {
addr1 := unsafe.Pointer(aVal.UnsafeAddr())
addr2 := unsafe.Pointer(bVal.UnsafeAddr())
if uintptr(addr1) > uintptr(addr2) {
// Canonicalize order to reduce number of entries in visited.
// Assumes non-moving garbage collector.
addr1, addr2 = addr2, addr1
}

// Short circuit if references are already seen.
typ := aVal.Type()
v := visit{addr1, addr2, typ}
if d.visited[v] {
return true
}

// Remember for later.
d.visited[v] = true
}
// End of borrowed code.

equal := true
switch kind {
case reflect.Array, reflect.Map, reflect.Ptr, reflect.Func, reflect.Chan, reflect.Slice:
if aVal.IsNil() && bVal.IsNil() {
return true
}
if aVal.IsNil() || bVal.IsNil() {
d.Modified[&path] = bVal.Interface()
return false
}
}

switch kind {
case reflect.Array, reflect.Slice:
aLen := aVal.Len()
bLen := bVal.Len()
for i := 0; i < min(aLen, bLen); i++ {
localPath := append(path, SliceIndex(i))
if eq := diff(aVal.Index(i).Interface(), bVal.Index(i).Interface(), localPath, d); !eq {
if eq := d.diff(aVal.Index(i), bVal.Index(i), localPath); !eq {
equal = false
}
}
Expand All @@ -87,7 +132,7 @@ func diff(a, b interface{}, path Path, d *Diff) bool {
if !bI.IsValid() {
d.Removed[&localPath] = aI.Interface()
equal = false
} else if eq := diff(aI.Interface(), bI.Interface(), localPath, d); !eq {
} else if eq := d.diff(aI, bI, localPath); !eq {
equal = false
}
}
Expand All @@ -106,30 +151,19 @@ func diff(a, b interface{}, path Path, d *Diff) bool {
index := []int{i}
field := typ.FieldByIndex(index)
localPath := append(path, StructField(field.Name))
aI := unsafeReflectValue(aVal.FieldByIndex(index)).Interface()
bI := unsafeReflectValue(bVal.FieldByIndex(index)).Interface()
if eq := diff(aI, bI, localPath, d); !eq {
aI := unsafeReflectValue(aVal.FieldByIndex(index))
bI := unsafeReflectValue(bVal.FieldByIndex(index))
if eq := d.diff(aI, bI, localPath); !eq {
equal = false
}
}
case reflect.Ptr:
aVal = aVal.Elem()
bVal = bVal.Elem()
if !aVal.IsValid() && !bVal.IsValid() {
// Both are nil.
equal = true
} else if !aVal.IsValid() || !bVal.IsValid() {
// One is nil and the other isn't.
d.Modified[&path] = b
equal = false
} else {
equal = diff(aVal.Interface(), bVal.Interface(), path, d)
}
equal = d.diff(aVal.Elem(), bVal.Elem(), path)
default:
if reflect.DeepEqual(a, b) {
if reflect.DeepEqual(aVal.Interface(), bVal.Interface()) {
equal = true
} else {
d.Modified[&path] = b
d.Modified[&path] = bVal.Interface()
equal = false
}
}
Expand All @@ -143,9 +177,21 @@ func min(a, b int) int {
return b
}

// During deepValueEqual, must keep track of checks that are
// in progress. The comparison algorithm assumes that all
// checks in progress are true when it reencounters them.
// Visited comparisons are stored in a map indexed by visit.
// This is borrowed from the reflect package.
type visit struct {
a1 unsafe.Pointer
a2 unsafe.Pointer
typ reflect.Type
}

// Diff represents a change in a struct.
type Diff struct {
Added, Removed, Modified map[*Path]interface{}
visited map[visit]bool
}

// Path represents a path to a changed datum.
Expand Down
75 changes: 61 additions & 14 deletions messagediff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,43 @@ type testStruct struct {
C []int
}

type RecursiveStruct struct {
Key int
Child *RecursiveStruct
}

func newRecursiveStruct(key int) *RecursiveStruct {
a := &RecursiveStruct{
Key: key,
}
b := &RecursiveStruct{
Key: key,
Child: a,
}
a.Child = b
return a
}

type testCase struct {
a, b interface{}
diff string
equal bool
}

func checkTestCases(t *testing.T, testData []testCase) {
for i, td := range testData {
diff, equal := PrettyDiff(td.a, td.b)
if diff != td.diff {
t.Errorf("%d. PrettyDiff(%#v, %#v) diff = %#v; not %#v", i, td.a, td.b, diff, td.diff)
}
if equal != td.equal {
t.Errorf("%d. PrettyDiff(%#v, %#v) equal = %#v; not %#v", i, td.a, td.b, equal, td.equal)
}
}
}

func TestPrettyDiff(t *testing.T) {
testData := []struct {
a, b interface{}
diff string
equal bool
}{
testData := []testCase{
{
true,
false,
Expand Down Expand Up @@ -71,11 +102,17 @@ func TestPrettyDiff(t *testing.T) {
true,
},
{
&time.Time{},
&struct{}{},
nil,
"modified: = <nil>\n",
false,
},
{
nil,
&struct{}{},
"modified: = &struct {}{}\n",
false,
},
{
time.Time{},
time.Time{},
Expand All @@ -89,15 +126,25 @@ func TestPrettyDiff(t *testing.T) {
false,
},
}
for i, td := range testData {
diff, equal := PrettyDiff(td.a, td.b)
if diff != td.diff {
t.Errorf("%d. PrettyDiff(%#v, %#v) diff = %#v; not %#v", i, td.a, td.b, diff, td.diff)
}
if equal != td.equal {
t.Errorf("%d. PrettyDiff(%#v, %#v) equal = %#v; not %#v", i, td.a, td.b, equal, td.equal)
}
checkTestCases(t, testData)
}

func TestPrettyDiffRecursive(t *testing.T) {
testData := []testCase{
{
newRecursiveStruct(1),
newRecursiveStruct(1),
"",
true,
},
{
newRecursiveStruct(1),
newRecursiveStruct(2),
"modified: .Child.Key = 2\nmodified: .Key = 2\n",
false,
},
}
checkTestCases(t, testData)
}

func TestPathString(t *testing.T) {
Expand Down

0 comments on commit 376d1c3

Please sign in to comment.