diff --git a/append.go b/append.go index dae7a2fe..32a9695d 100644 --- a/append.go +++ b/append.go @@ -77,30 +77,15 @@ func appendIface(dst []byte, srci interface{}) []byte { dst = append(dst, '\'') return dst case []string: - dst = append(dst, '\'') - dst = appendStringSlice(dst, src, false) - dst = append(dst, '\'') - return dst + return appendStringSlice(dst, src) case []int: - dst = append(dst, '\'') - dst = appendIntSlice(dst, src) - dst = append(dst, '\'') - return dst + return appendIntSlice(dst, src) case []int64: - dst = append(dst, '\'') - dst = appendInt64Slice(dst, src) - dst = append(dst, '\'') - return dst + return appendInt64Slice(dst, src) case []float64: - dst = append(dst, '\'') - dst = appendFloat64Slice(dst, src) - dst = append(dst, '\'') - return dst + return appendFloat64Slice(dst, src) case map[string]string: - dst = append(dst, '\'') - dst = appendStringStringMap(dst, src, false) - dst = append(dst, '\'') - return dst + return appendStringStringMap(dst, src, false) case QueryAppender: return src.AppendQuery(dst) case driver.Valuer: @@ -165,15 +150,15 @@ func appendIfaceRaw(dst []byte, srci interface{}) []byte { case []byte: return appendBytes(dst, src) case []string: - return appendStringSlice(dst, src, true) + return appendStringSliceRaw(dst, src, true) case []int: - return appendIntSlice(dst, src) + return appendIntSliceRaw(dst, src) case []int64: - return appendInt64Slice(dst, src) + return appendInt64SliceRaw(dst, src) case []float64: - return appendFloat64Slice(dst, src) + return appendFloat64SliceRaw(dst, src) case map[string]string: - return appendStringStringMap(dst, src, true) + return appendStringStringMapRaw(dst, src, true) case RawQueryAppender: return src.AppendRawQuery(dst) case driver.Valuer: @@ -295,12 +280,25 @@ func appendBytes(dst []byte, v []byte) []byte { return dst } -func appendStringStringMap(dst []byte, v map[string]string, raw bool) []byte { - if len(v) == 0 { +func appendStringStringMap(dst []byte, m map[string]string, raw bool) []byte { + if m == nil { + return appendNull(dst) + } + dst = append(dst, '\'') + dst = appendStringStringMapRaw(dst, m, false) + dst = append(dst, '\'') + return dst +} + +func appendStringStringMapRaw(dst []byte, m map[string]string, raw bool) []byte { + if m == nil { + return nil + } + if len(m) == 0 { return dst } - for key, value := range v { + for key, value := range m { dst = appendSubstring(dst, key, raw) dst = append(dst, '=', '>') dst = appendSubstring(dst, value, raw) @@ -310,13 +308,26 @@ func appendStringStringMap(dst []byte, v map[string]string, raw bool) []byte { return dst } -func appendStringSlice(dst []byte, v []string, raw bool) []byte { - if len(v) == 0 { +func appendStringSlice(dst []byte, ss []string) []byte { + if ss == nil { + return appendNull(dst) + } + dst = append(dst, '\'') + dst = appendStringSliceRaw(dst, ss, false) + dst = append(dst, '\'') + return dst +} + +func appendStringSliceRaw(dst []byte, ss []string, raw bool) []byte { + if ss == nil { + return nil + } + if len(ss) == 0 { return append(dst, "{}"...) } dst = append(dst, '{') - for _, s := range v { + for _, s := range ss { dst = appendSubstring(dst, s, raw) dst = append(dst, ',') } @@ -324,13 +335,26 @@ func appendStringSlice(dst []byte, v []string, raw bool) []byte { return dst } -func appendIntSlice(dst []byte, v []int) []byte { - if len(v) == 0 { +func appendIntSlice(dst []byte, ints []int) []byte { + if ints == nil { + return appendNull(dst) + } + dst = append(dst, '\'') + dst = appendIntSliceRaw(dst, ints) + dst = append(dst, '\'') + return dst +} + +func appendIntSliceRaw(dst []byte, ints []int) []byte { + if ints == nil { + return nil + } + if len(ints) == 0 { return append(dst, "{}"...) } dst = append(dst, '{') - for _, n := range v { + for _, n := range ints { dst = strconv.AppendInt(dst, int64(n), 10) dst = append(dst, ',') } @@ -338,13 +362,26 @@ func appendIntSlice(dst []byte, v []int) []byte { return dst } -func appendInt64Slice(dst []byte, v []int64) []byte { - if len(v) == 0 { +func appendInt64Slice(dst []byte, ints []int64) []byte { + if ints == nil { + return appendNull(dst) + } + dst = append(dst, '\'') + dst = appendInt64SliceRaw(dst, ints) + dst = append(dst, '\'') + return dst +} + +func appendInt64SliceRaw(dst []byte, ints []int64) []byte { + if ints == nil { + return nil + } + if len(ints) == 0 { return append(dst, "{}"...) } dst = append(dst, "{"...) - for _, n := range v { + for _, n := range ints { dst = strconv.AppendInt(dst, n, 10) dst = append(dst, ',') } @@ -352,13 +389,26 @@ func appendInt64Slice(dst []byte, v []int64) []byte { return dst } -func appendFloat64Slice(dst []byte, v []float64) []byte { - if len(v) == 0 { +func appendFloat64Slice(dst []byte, floats []float64) []byte { + if floats == nil { + return appendNull(dst) + } + dst = append(dst, '\'') + dst = appendFloat64SliceRaw(dst, floats) + dst = append(dst, '\'') + return dst +} + +func appendFloat64SliceRaw(dst []byte, floats []float64) []byte { + if floats == nil { + return nil + } + if len(floats) == 0 { return append(dst, "{}"...) } dst = append(dst, "{"...) - for _, n := range v { + for _, n := range floats { dst = appendFloat(dst, n) dst = append(dst, ',') } diff --git a/decode_value.go b/decode_value.go index b736412a..4d12ab63 100644 --- a/decode_value.go +++ b/decode_value.go @@ -198,7 +198,8 @@ func decodeMapValue(v reflect.Value, b []byte) error { func decodeNullValue(v reflect.Value) error { kind := v.Kind() - if kind == reflect.Interface { + switch kind { + case reflect.Interface: return decodeNullValue(v.Elem()) } if v.CanSet() { diff --git a/types_test.go b/types_test.go index 063e8fb5..620e22f0 100644 --- a/types_test.go +++ b/types_test.go @@ -40,6 +40,7 @@ type JSONRecord2 struct { } type conversionTest struct { + i int src, dst, wanted interface{} pgtype string @@ -64,7 +65,7 @@ func zero(v interface{}) interface{} { } func (test *conversionTest) String() string { - return fmt.Sprintf("src=%#v dst=%#v", test.src, test.dst) + return fmt.Sprintf("#%d src=%#v dst=%#v", test.i, test.src, test.dst) } func (test *conversionTest) Fatalf(t *testing.T, s interface{}, args ...interface{}) { @@ -182,15 +183,35 @@ func TestConversion(t *testing.T) { {src: float64(math.SmallestNonzeroFloat64), dst: new(float64), pgtype: "decimal"}, {src: nil, dst: new([]int), pgtype: "int[]", wantnil: true}, + {src: []int(nil), dst: new([]int), pgtype: "int[]", wantnil: true}, {src: []int{}, dst: new([]int), pgtype: "int[]", wantzero: true}, {src: []int{1, 2, 3}, dst: new([]int), pgtype: "int[]"}, + + {src: nil, dst: new([]int64), pgtype: "bigint[]", wantnil: true}, + {src: []int64(nil), dst: new([]int64), pgtype: "bigint[]", wantnil: true}, {src: []int64{1, 2, 3}, dst: new([]int64), pgtype: "bigint[]"}, + + {src: nil, dst: new([]float64), pgtype: "double precision[]", wantnil: true}, + {src: []float64(nil), dst: new([]float64), pgtype: "double precision[]", wantnil: true}, {src: []float64{1.1, 2.22, 3.333}, dst: new([]float64), pgtype: "double precision[]"}, {src: nil, dst: new([]string), pgtype: "text[]", wantnil: true}, + {src: []string(nil), dst: new([]string), pgtype: "text[]", wantnil: true}, {src: []string{}, dst: new([]string), pgtype: "text[]", wantzero: true}, {src: []string{"foo\n", "bar {}", "'\\\""}, dst: new([]string), pgtype: "text[]"}, + { + src: nil, + dst: new(map[string]string), + pgtype: "hstore", + wantnil: true, + }, + { + src: map[string]string(nil), + dst: new(map[string]string), + pgtype: "hstore", + wantnil: true, + }, { src: map[string]string{"foo\n =>": "bar\n =>", "'\\\"": "'\\\""}, dst: new(map[string]string), @@ -246,7 +267,9 @@ func TestConversion(t *testing.T) { db.Exec("CREATE EXTENSION hstore") defer db.Exec("DROP EXTENSION hstore") - for _, test := range conversionTests { + for i, test := range conversionTests { + test.i = i + var err error if _, ok := test.dst.(pg.ColumnLoader); ok { _, err = db.QueryOne(test.dst, "SELECT (?) AS dst", test.src) @@ -257,7 +280,9 @@ func TestConversion(t *testing.T) { test.Assert(t, err) } - for _, test := range conversionTests { + for i, test := range conversionTests { + test.i = i + if test.pgtype == "" { continue } @@ -280,7 +305,9 @@ func TestConversion(t *testing.T) { } } - for _, test := range conversionTests { + for i, test := range conversionTests { + test.i = i + if _, ok := test.dst.(pg.ColumnLoader); ok { continue } @@ -289,7 +316,9 @@ func TestConversion(t *testing.T) { test.Assert(t, err) } - for _, test := range conversionTests { + for i, test := range conversionTests { + test.i = i + if _, ok := test.dst.(pg.ColumnLoader); ok { continue }