diff --git a/types_test.go b/types_test.go index 0b8cbd1d..063e8fb5 100644 --- a/types_test.go +++ b/types_test.go @@ -84,8 +84,21 @@ func (test *conversionTest) Assert(t *testing.T, err error) { test.Fatalf(t, "error is %s, wanted nil", err) } + // test.dst is a pointer to the value. + dst := reflect.ValueOf(test.dst).Elem().Interface() + + if test.wantnil { + dstValue := reflect.ValueOf(dst) + if dstValue.IsNil() { + return + } + test.Fatalf(t, "got %#v, wanted nil", dst) + return + } + + // Remove any intermediate pointers to compare values, not pointers. + dst = deref(dst) src := deref(test.src) - dst := deref(test.dst) if test.wantzero { dstValue := reflect.ValueOf(dst) @@ -105,17 +118,6 @@ func (test *conversionTest) Assert(t *testing.T, err error) { return } - if test.wantnil { - if dst == nil { - return - } - if reflect.ValueOf(dst).IsNil() { - return - } - test.Fatalf(t, "got %#v, wanted nil", dst) - return - } - if dstTime, ok := dst.(time.Time); ok { srcTime := src.(time.Time) if dstTime.Unix() != srcTime.Unix() { @@ -179,14 +181,15 @@ func TestConversion(t *testing.T) { {src: float64(math.MaxFloat64), dst: new(float64), pgtype: "decimal"}, {src: float64(math.SmallestNonzeroFloat64), dst: new(float64), pgtype: "decimal"}, - {src: []int{}, dst: new([]int), pgtype: "int[]"}, + {src: nil, dst: new([]int), pgtype: "int[]", wantnil: true}, + {src: []int{}, dst: new([]int), pgtype: "int[]", wantzero: true}, {src: []int{1, 2, 3}, dst: new([]int), pgtype: "int[]"}, {src: []int64{1, 2, 3}, dst: new([]int64), pgtype: "bigint[]"}, {src: []float64{1.1, 2.22, 3.333}, dst: new([]float64), pgtype: "double precision[]"}, - {src: []string{"foo\n", "bar {}", "'\\\""}, dst: new([]string), pgtype: "text[]"}, - {src: []string{}, dst: new([]string), pgtype: "text[]", wantzero: true}, {src: nil, dst: new([]string), pgtype: "text[]", wantnil: true}, + {src: []string{}, dst: new([]string), pgtype: "text[]", wantzero: true}, + {src: []string{"foo\n", "bar {}", "'\\\""}, dst: new([]string), pgtype: "text[]"}, { src: map[string]string{"foo\n =>": "bar\n =>", "'\\\"": "'\\\""}, @@ -210,8 +213,8 @@ func TestConversion(t *testing.T) { {src: &sql.NullFloat64{Valid: true}, dst: &sql.NullFloat64{}, pgtype: "decimal"}, {src: &sql.NullFloat64{Valid: true, Float64: math.MaxFloat64}, dst: &sql.NullFloat64{}, pgtype: "decimal"}, - {src: customStrSlice{}, dst: &customStrSlice{}, wantzero: true}, {src: nil, dst: &customStrSlice{}, wantnil: true}, + {src: customStrSlice{}, dst: &customStrSlice{}, wantzero: true}, {src: customStrSlice{"one", "two"}, dst: &customStrSlice{}}, {src: time.Time{}, dst: &time.Time{}, pgtype: "timestamp"},