From 75a33698bb1bc68ea2d739a4e7794e8b2b9abba1 Mon Sep 17 00:00:00 2001 From: Brady Catherman Date: Wed, 30 Nov 2016 13:02:24 -0700 Subject: [PATCH] Make sure that Interface() is not called on private values. --- equal.go | 20 ++++++++++++++++---- equal_test.go | 10 ++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/equal.go b/equal.go index 8189e6f..dd0ab13 100644 --- a/equal.go +++ b/equal.go @@ -147,6 +147,18 @@ func (t *T) isNil(obj interface{}) bool { return v.IsNil() } +// Returns a string representation of a Value that is sanitized based on what +// we are allowed to see. Private fields can not be exposed via a call to +// Interface() and trying will cause a panic, so we must use String() +// in that case. +func stringValue(v reflect.Value) string { + if v.CanInterface() { + return fmt.Sprintf("%#v", v.Interface()) + } else { + return v.String() + } +} + // Deep comparison. This is based on golang 1.2's reflect.Equal functionality. func (t *T) deepEqual( desc string, have, want reflect.Value, ignores []string, @@ -206,13 +218,13 @@ func (t *T) deepEqual( checkNil := func() bool { if want.IsNil() && !have.IsNil() { diffs = append(diffs, fmt.Sprintf("%s: not equal.", desc)) - diffs = append(diffs, fmt.Sprintf(" have: %#v", have.Interface())) + diffs = append(diffs, fmt.Sprintf(" have: %s", stringValue(have))) diffs = append(diffs, " want: nil") return true } else if !want.IsNil() && have.IsNil() { diffs = append(diffs, fmt.Sprintf("%s: not equal.", desc)) diffs = append(diffs, " have: nil") - diffs = append(diffs, fmt.Sprintf(" want: %#v", want.Interface())) + diffs = append(diffs, fmt.Sprintf(" want: %s", stringValue(want))) return true } return false @@ -224,8 +236,8 @@ func (t *T) deepEqual( diffs = append(diffs, fmt.Sprintf( "%s: (len(have): %d, len(want): %d)", desc, have.Len(), want.Len())) - diffs = append(diffs, fmt.Sprintf(" have: %#v", have.Interface())) - diffs = append(diffs, fmt.Sprintf(" want: %#v", want.Interface())) + diffs = append(diffs, fmt.Sprintf(" have: %s", stringValue(have))) + diffs = append(diffs, fmt.Sprintf(" want: %s", stringValue(want))) return true } return false diff --git a/equal_test.go b/equal_test.go index 23b94ef..c0da125 100644 --- a/equal_test.go +++ b/equal_test.go @@ -15,6 +15,7 @@ package testlib import ( + "bytes" "fmt" "math/rand" "os" @@ -259,6 +260,15 @@ func TestT_EqualAndNotEqual(t *testing.T) { []interface{}{sCust1, sCust2}, []interface{}{dCust1}) + // Structures with private, unexported fields. + sBuff1 := bytes.NewBuffer([]byte{1, 2, 3}) + sBuff2 := bytes.NewBuffer([]byte{1, 2, 3}) + dBuff1 := bytes.NewBuffer([]byte{1, 2, 3, 4}) + dBuff2 := bytes.NewBuffer([]byte{1}) + runTest( + []interface{}{sBuff1, sBuff2}, + []interface{}{dBuff1, dBuff2}) + // Structures in a slice. sCustSlice1 := []testEqualCustomStruct{sCust1, sCust2} sCustSlice2 := []testEqualCustomStruct{sCust1, sCust2}