Skip to content

Commit

Permalink
Merge pull request #79 from go-pg/fix/append-nil-slice-map
Browse files Browse the repository at this point in the history
Properly encode nil slices and maps. Fixes #76.
  • Loading branch information
vmihailenco committed Oct 13, 2015
2 parents 81b9866 + 7d800ed commit f02b4d2
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 46 deletions.
130 changes: 90 additions & 40 deletions append.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,30 +77,15 @@ func appendIface(dst []byte, srci interface{}) []byte {
dst = append(dst, '\'')
return dst
case []string:
dst = append(dst, '\'')
dst = appendStringSlice(dst, src, false)
dst = append(dst, '\'')
return dst
return appendStringSlice(dst, src)
case []int:
dst = append(dst, '\'')
dst = appendIntSlice(dst, src)
dst = append(dst, '\'')
return dst
return appendIntSlice(dst, src)
case []int64:
dst = append(dst, '\'')
dst = appendInt64Slice(dst, src)
dst = append(dst, '\'')
return dst
return appendInt64Slice(dst, src)
case []float64:
dst = append(dst, '\'')
dst = appendFloat64Slice(dst, src)
dst = append(dst, '\'')
return dst
return appendFloat64Slice(dst, src)
case map[string]string:
dst = append(dst, '\'')
dst = appendStringStringMap(dst, src, false)
dst = append(dst, '\'')
return dst
return appendStringStringMap(dst, src, false)
case QueryAppender:
return src.AppendQuery(dst)
case driver.Valuer:
Expand Down Expand Up @@ -165,15 +150,15 @@ func appendIfaceRaw(dst []byte, srci interface{}) []byte {
case []byte:
return appendBytes(dst, src)
case []string:
return appendStringSlice(dst, src, true)
return appendStringSliceRaw(dst, src, true)
case []int:
return appendIntSlice(dst, src)
return appendIntSliceRaw(dst, src)
case []int64:
return appendInt64Slice(dst, src)
return appendInt64SliceRaw(dst, src)
case []float64:
return appendFloat64Slice(dst, src)
return appendFloat64SliceRaw(dst, src)
case map[string]string:
return appendStringStringMap(dst, src, true)
return appendStringStringMapRaw(dst, src, true)
case RawQueryAppender:
return src.AppendRawQuery(dst)
case driver.Valuer:
Expand Down Expand Up @@ -295,12 +280,25 @@ func appendBytes(dst []byte, v []byte) []byte {
return dst
}

func appendStringStringMap(dst []byte, v map[string]string, raw bool) []byte {
if len(v) == 0 {
func appendStringStringMap(dst []byte, m map[string]string, raw bool) []byte {
if m == nil {
return appendNull(dst)
}
dst = append(dst, '\'')
dst = appendStringStringMapRaw(dst, m, false)
dst = append(dst, '\'')
return dst
}

func appendStringStringMapRaw(dst []byte, m map[string]string, raw bool) []byte {
if m == nil {
return nil
}
if len(m) == 0 {
return dst
}

for key, value := range v {
for key, value := range m {
dst = appendSubstring(dst, key, raw)
dst = append(dst, '=', '>')
dst = appendSubstring(dst, value, raw)
Expand All @@ -310,55 +308,107 @@ func appendStringStringMap(dst []byte, v map[string]string, raw bool) []byte {
return dst
}

func appendStringSlice(dst []byte, v []string, raw bool) []byte {
if len(v) == 0 {
func appendStringSlice(dst []byte, ss []string) []byte {
if ss == nil {
return appendNull(dst)
}
dst = append(dst, '\'')
dst = appendStringSliceRaw(dst, ss, false)
dst = append(dst, '\'')
return dst
}

func appendStringSliceRaw(dst []byte, ss []string, raw bool) []byte {
if ss == nil {
return nil
}
if len(ss) == 0 {
return append(dst, "{}"...)
}

dst = append(dst, '{')
for _, s := range v {
for _, s := range ss {
dst = appendSubstring(dst, s, raw)
dst = append(dst, ',')
}
dst[len(dst)-1] = '}' // Replace trailing comma.
return dst
}

func appendIntSlice(dst []byte, v []int) []byte {
if len(v) == 0 {
func appendIntSlice(dst []byte, ints []int) []byte {
if ints == nil {
return appendNull(dst)
}
dst = append(dst, '\'')
dst = appendIntSliceRaw(dst, ints)
dst = append(dst, '\'')
return dst
}

func appendIntSliceRaw(dst []byte, ints []int) []byte {
if ints == nil {
return nil
}
if len(ints) == 0 {
return append(dst, "{}"...)
}

dst = append(dst, '{')
for _, n := range v {
for _, n := range ints {
dst = strconv.AppendInt(dst, int64(n), 10)
dst = append(dst, ',')
}
dst[len(dst)-1] = '}' // Replace trailing comma.
return dst
}

func appendInt64Slice(dst []byte, v []int64) []byte {
if len(v) == 0 {
func appendInt64Slice(dst []byte, ints []int64) []byte {
if ints == nil {
return appendNull(dst)
}
dst = append(dst, '\'')
dst = appendInt64SliceRaw(dst, ints)
dst = append(dst, '\'')
return dst
}

func appendInt64SliceRaw(dst []byte, ints []int64) []byte {
if ints == nil {
return nil
}
if len(ints) == 0 {
return append(dst, "{}"...)
}

dst = append(dst, "{"...)
for _, n := range v {
for _, n := range ints {
dst = strconv.AppendInt(dst, n, 10)
dst = append(dst, ',')
}
dst[len(dst)-1] = '}' // Replace trailing comma.
return dst
}

func appendFloat64Slice(dst []byte, v []float64) []byte {
if len(v) == 0 {
func appendFloat64Slice(dst []byte, floats []float64) []byte {
if floats == nil {
return appendNull(dst)
}
dst = append(dst, '\'')
dst = appendFloat64SliceRaw(dst, floats)
dst = append(dst, '\'')
return dst
}

func appendFloat64SliceRaw(dst []byte, floats []float64) []byte {
if floats == nil {
return nil
}
if len(floats) == 0 {
return append(dst, "{}"...)
}

dst = append(dst, "{"...)
for _, n := range v {
for _, n := range floats {
dst = appendFloat(dst, n)
dst = append(dst, ',')
}
Expand Down
3 changes: 2 additions & 1 deletion decode_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ func decodeMapValue(v reflect.Value, b []byte) error {

func decodeNullValue(v reflect.Value) error {
kind := v.Kind()
if kind == reflect.Interface {
switch kind {
case reflect.Interface:
return decodeNullValue(v.Elem())
}
if v.CanSet() {
Expand Down
39 changes: 34 additions & 5 deletions types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type JSONRecord2 struct {
}

type conversionTest struct {
i int
src, dst, wanted interface{}
pgtype string

Expand All @@ -64,7 +65,7 @@ func zero(v interface{}) interface{} {
}

func (test *conversionTest) String() string {
return fmt.Sprintf("src=%#v dst=%#v", test.src, test.dst)
return fmt.Sprintf("#%d src=%#v dst=%#v", test.i, test.src, test.dst)
}

func (test *conversionTest) Fatalf(t *testing.T, s interface{}, args ...interface{}) {
Expand Down Expand Up @@ -182,15 +183,35 @@ func TestConversion(t *testing.T) {
{src: float64(math.SmallestNonzeroFloat64), dst: new(float64), pgtype: "decimal"},

{src: nil, dst: new([]int), pgtype: "int[]", wantnil: true},
{src: []int(nil), dst: new([]int), pgtype: "int[]", wantnil: true},
{src: []int{}, dst: new([]int), pgtype: "int[]", wantzero: true},
{src: []int{1, 2, 3}, dst: new([]int), pgtype: "int[]"},

{src: nil, dst: new([]int64), pgtype: "bigint[]", wantnil: true},
{src: []int64(nil), dst: new([]int64), pgtype: "bigint[]", wantnil: true},
{src: []int64{1, 2, 3}, dst: new([]int64), pgtype: "bigint[]"},

{src: nil, dst: new([]float64), pgtype: "double precision[]", wantnil: true},
{src: []float64(nil), dst: new([]float64), pgtype: "double precision[]", wantnil: true},
{src: []float64{1.1, 2.22, 3.333}, dst: new([]float64), pgtype: "double precision[]"},

{src: nil, dst: new([]string), pgtype: "text[]", wantnil: true},
{src: []string(nil), dst: new([]string), pgtype: "text[]", wantnil: true},
{src: []string{}, dst: new([]string), pgtype: "text[]", wantzero: true},
{src: []string{"foo\n", "bar {}", "'\\\""}, dst: new([]string), pgtype: "text[]"},

{
src: nil,
dst: new(map[string]string),
pgtype: "hstore",
wantnil: true,
},
{
src: map[string]string(nil),
dst: new(map[string]string),
pgtype: "hstore",
wantnil: true,
},
{
src: map[string]string{"foo\n =>": "bar\n =>", "'\\\"": "'\\\""},
dst: new(map[string]string),
Expand Down Expand Up @@ -246,7 +267,9 @@ func TestConversion(t *testing.T) {
db.Exec("CREATE EXTENSION hstore")
defer db.Exec("DROP EXTENSION hstore")

for _, test := range conversionTests {
for i, test := range conversionTests {
test.i = i

var err error
if _, ok := test.dst.(pg.ColumnLoader); ok {
_, err = db.QueryOne(test.dst, "SELECT (?) AS dst", test.src)
Expand All @@ -257,7 +280,9 @@ func TestConversion(t *testing.T) {
test.Assert(t, err)
}

for _, test := range conversionTests {
for i, test := range conversionTests {
test.i = i

if test.pgtype == "" {
continue
}
Expand All @@ -280,7 +305,9 @@ func TestConversion(t *testing.T) {
}
}

for _, test := range conversionTests {
for i, test := range conversionTests {
test.i = i

if _, ok := test.dst.(pg.ColumnLoader); ok {
continue
}
Expand All @@ -289,7 +316,9 @@ func TestConversion(t *testing.T) {
test.Assert(t, err)
}

for _, test := range conversionTests {
for i, test := range conversionTests {
test.i = i

if _, ok := test.dst.(pg.ColumnLoader); ok {
continue
}
Expand Down

0 comments on commit f02b4d2

Please sign in to comment.