Skip to content

Commit

Permalink
ltree: fix DecodeBinary by correctly checking the 1st byte
Browse files Browse the repository at this point in the history
Start adding tests
  • Loading branch information
AmineChikhaoui authored and jackc committed Nov 5, 2022
1 parent 0f1512e commit fcee893
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 3 deletions.
16 changes: 15 additions & 1 deletion ltree.go
Expand Up @@ -2,6 +2,7 @@ package pgtype

import (
"database/sql/driver"
"fmt"
)

type Ltree Text
Expand Down Expand Up @@ -42,7 +43,20 @@ func (dst *Ltree) DecodeText(ci *ConnInfo, src []byte) error {
}

func (dst *Ltree) DecodeBinary(ci *ConnInfo, src []byte) error {
return (*Text)(dst).DecodeBinary(ci, src)
if src == nil {
*dst = Ltree{Status: Null}
return nil
}

// Get Ltree version, only 1 is allowed
version := src[0]
if version != 1 {
return fmt.Errorf("unsupported ltree version %d", version)
}

ltreeStr := string(src[1:])
*dst = Ltree{String: ltreeStr, Status: Present}
return nil
}

func (Ltree) PreferredParamFormat() int16 {
Expand Down
50 changes: 50 additions & 0 deletions ltree_test.go
@@ -0,0 +1,50 @@
package pgtype_test

import (
"reflect"
"testing"

"github.com/jackc/pgtype"
"github.com/jackc/pgtype/testutil"
)

func TestLtreeTranscode(t *testing.T) {
values := []interface{}{
&pgtype.Ltree{String: "", Status: pgtype.Present},
&pgtype.Ltree{String: "All.foo.one", Status: pgtype.Present},
&pgtype.Ltree{Status: pgtype.Null},
}

testutil.TestSuccessfulTranscodeEqFunc(
t, "ltree", values, func(ai, bi interface{}) bool {
a := ai.(pgtype.Ltree)
b := bi.(pgtype.Ltree)

if a.String != b.String || a.Status != b.Status {
return false
}
return true
},
)

}

func TestLtreeSet(t *testing.T) {
successfulTests := []struct {
src interface{}
result pgtype.Ltree
}{
{src: "All.foo.bar", result: pgtype.Ltree{String: "All.foo.bar", Status: pgtype.Present}},
{src: (*string)(nil), result: pgtype.Ltree{Status: pgtype.Null}},
}
for i, tt := range successfulTests {
var dst pgtype.Ltree
err := dst.Set(tt.src)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if !reflect.DeepEqual(dst, tt.result) {
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst)
}
}
}
2 changes: 0 additions & 2 deletions pgtype.go
Expand Up @@ -84,7 +84,6 @@ const (
TstzrangeArrayOID = 3911
Int8rangeOID = 3926
Int8multirangeOID = 4536
LtreeOID = 16407
)

type Status byte
Expand Down Expand Up @@ -328,7 +327,6 @@ func NewConnInfo() *ConnInfo {
ci.RegisterDataType(DataType{Value: &Varbit{}, Name: "varbit", OID: VarbitOID})
ci.RegisterDataType(DataType{Value: &Varchar{}, Name: "varchar", OID: VarcharOID})
ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID})
ci.RegisterDataType(DataType{Value: &Ltree{}, Name: "ltree", OID: LtreeOID})

registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) {
ci.RegisterDefaultPgType(value, name)
Expand Down

0 comments on commit fcee893

Please sign in to comment.