Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DetectCycles to detect cycles in compared structs #78

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 70 additions & 3 deletions cmp/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ type state struct {
// These fields, once set by processOption, will not change.
exporters map[reflect.Type]bool // Set of structs with unexported field visibility
opts Options // List of all fundamental and filter options

// detectCycles is set by the DetectCycles option, and it indicates to
// enable the cycles check when a pointer is compared.
detectCycles bool
}

func newState(opts []Option) *state {
Expand Down Expand Up @@ -156,6 +160,8 @@ func (s *state) processOption(opt Option) {
panic("difference reporter already registered")
}
s.reporter = opt
case detectCycles:
s.detectCycles = true
default:
panic(fmt.Sprintf("unknown option %T", opt))
}
Expand All @@ -180,8 +186,6 @@ func (s *state) statelessCompare(vx, vy reflect.Value) diff.Result {
}

func (s *state) compareAny(vx, vy reflect.Value) {
// TODO: Support cyclic data structures.

// Rule 0: Differing types are never equal.
if !vx.IsValid() || !vy.IsValid() {
s.report(vx.IsValid() == vy.IsValid(), vx, vy)
Expand Down Expand Up @@ -239,8 +243,21 @@ func (s *state) compareAny(vx, vy reflect.Value) {
s.report(vx.IsNil() && vy.IsNil(), vx, vy)
return
}
s.curPath.push(&indirect{pathStep{t.Elem()}})
s.curPath.push(&indirect{
pathStep: pathStep{t.Elem()},
xAddr: vx.Elem().UnsafeAddr(),
yAddr: vy.Elem().UnsafeAddr(),
})
defer s.curPath.pop()

// If detectCycles is enabled, look in the path stack and search for
// pointer cycle. If one is find, compare the cycles.
if s.detectCycles {
if cyclesEqual, ok := compareCycles(s.curPath); ok {
s.report(cyclesEqual, vx, vy)
return
}
}
s.compareAny(vx.Elem(), vy.Elem())
return
case reflect.Interface:
Expand Down Expand Up @@ -552,3 +569,53 @@ func makeAddressable(v reflect.Value) reflect.Value {
vc.Set(v)
return vc
}

// compareCycles calculates the cyclic pointer chain length by going
// over the path stack.
// It calculates both a cyclic chain of x values and for y values, and
// compare them.
// It returns:
// equals if detected cycles are equal.
// ok as true if cycles were detected.
func compareCycles(p Path) (equals bool, ok bool) {
if len(p) == 0 {
return false, false
}

// Check the current path step for the address values.
// If it is not a cycleStep, there is no point to check the
// chains lengths
curStep, ok := p[len(p)-1].(*indirect)
if !ok || (curStep.xAddr == 0 && curStep.yAddr == 0) {
return false, false
}

// Find the next occurrence in the chain path of either xAddr or yAddr
// of the curStep.
var xLen, yLen, length int
for i := len(p) - 2; i > 0; i-- {
cs, ok := p[i].(*indirect)
if !ok || (curStep.xAddr == 0 && curStep.yAddr == 0) {
continue
}

// since this step is a pointer step, increase the chain length
length += 1

// Check for the same address. We want to return the smallest cycle,
// so we update only if the current length value is 0
if xLen == 0 && cs.xAddr == curStep.xAddr {
xLen = length
}
if yLen == 0 && cs.yAddr == curStep.yAddr {
yLen = length
}

// If we found lengths for both x and y, we can return, there is no
// need to go over the whole stack.
if xLen != 0 && yLen != 0 {
break
}
}
return xLen == yLen, yLen != 0 && xLen != 0
}
93 changes: 93 additions & 0 deletions cmp/compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (

var now = time.Now()

var reAddress = regexp.MustCompile(`\(0x[0-9a-f]+\)`)

func intPtr(n int) *int { return &n }

type test struct {
Expand All @@ -44,6 +46,7 @@ func TestDiff(t *testing.T) {
tests = append(tests, transformerTests()...)
tests = append(tests, embeddedTests()...)
tests = append(tests, methodTests()...)
tests = append(tests, detectCyclesTest()...)
tests = append(tests, project1Tests()...)
tests = append(tests, project2Tests()...)
tests = append(tests, project3Tests()...)
Expand All @@ -63,12 +66,21 @@ func TestDiff(t *testing.T) {
}
}
}()

// Add to all test the DetectCycles options: see if it breaks any test, and see that
// it works for the detectCycleTest
tt.opts = append(tt.opts, cmp.DetectCycles())

gotDiff = cmp.Diff(tt.x, tt.y, tt.opts...)
}()
if tt.wantPanic == "" {
if gotPanic != "" {
t.Fatalf("unexpected panic message: %s", gotPanic)
}

// change all addresses in the diff to be 0x00, so they could be expected
gotDiff = reAddress.ReplaceAllString(gotDiff, "(0x00)")

if got, want := strings.TrimSpace(gotDiff), strings.TrimSpace(tt.wantDiff); got != want {
t.Fatalf("difference message:\ngot:\n%s\n\nwant:\n%s", got, want)
}
Expand Down Expand Up @@ -1966,6 +1978,87 @@ func project4Tests() []test {
}}
}

func detectCyclesTest() []test {
const label = "DetectCycles/"

type node struct {
Value string
Next *node
}

var a = node{Value: "a"}
a.Next = &a

var anotherA = node{Value: "a"}
anotherA.Next = &anotherA

var b = node{Value: "b"}
b.Next = &b

// a cyclic link list in length 2
var len21, len22 node
len21.Next = &len22
len22.Next = &len21

// a cyclic link list in length 3
var len31, len32, len33 node
len31.Next = &len32
len32.Next = &len33
len33.Next = &len31

var insideA1, insideA2, insideA3 node
insideA1.Next = &insideA2
insideA2.Next = &insideA3
insideA3.Next = &insideA1
insideA2.Value = "a"

var insideB1, insideB2, insideB3 node
insideB1.Next = &insideB2
insideB2.Next = &insideB3
insideB3.Next = &insideB1
insideB2.Value = "b"

return []test{{
label: label + "simple cycle/different",
x: a,
y: b,
wantDiff: `
{cmp_test.node}.Value:
-: "a"
+: "b"
{cmp_test.node}.Next.Value:
-: "a"
+: "b"
`,
}, {
label: label + "simple cycle/equal",
x: a,
y: anotherA,
}, {
label: label + "simple cycle/equal identity",
x: a,
y: a,
}, {
label: label + "different size cycles",
x: len21,
y: len31,
wantDiff: `
*{cmp_test.node}.Next.Next.Next.Next:
-: &cmp_test.node{Next: &cmp_test.node{Next: (*cmp_test.node)(0x00)}}
+: &cmp_test.node{Next: &cmp_test.node{Next: &cmp_test.node{Next: (*cmp_test.node)(0x00)}}}
`,
}, {
label: label + "value inside an equal cycle is different",
x: insideA1,
y: insideB1,
wantDiff: `
{cmp_test.node}.Next.Value:
-: "a"
+: "b"
`,
}}
}

// TODO: Delete this hack when we drop Go1.6 support.
func tRunParallel(t *testing.T, name string, f func(t *testing.T)) {
type runner interface {
Expand Down
14 changes: 14 additions & 0 deletions cmp/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,3 +451,17 @@ func getFuncName(p uintptr) string {
}
return name
}

// DetectCycles is an option that prevents infinite searching in cyclic data structure.
// It iterates the whole path stack every time a pointer is tested to check for
// cycles thus it reduces the runtime performance.
// It finds cycles by adding information to the path stack any time a pointer is tested.
func DetectCycles() Option { return detectCycles{} }

type detectCycles struct{}

func (c detectCycles) filter(s *state, vx, vy reflect.Value, t reflect.Type) applicableOption {
return nil
}

func (c detectCycles) String() string { return "DetectCycles()" }
3 changes: 3 additions & 0 deletions cmp/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ type (
}
indirect struct {
pathStep
// xAddr and yAddr are set if the kind of the current step is Ptr.
// They contain the addresses pointing by vx an vy.
xAddr, yAddr uintptr
}
transform struct {
pathStep
Expand Down