From d24135cb8039bf8c8e9289a443fdab6a4d7ba88c Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Thu, 27 Apr 2023 18:30:30 -0400 Subject: [PATCH] Add support for NamedValueChecker interface Added in 1.9: https://go.dev/doc/go1.9#minor_library_changes. --- array.go | 7 ++--- conn_go19.go | 35 +++++++++++++++++++++++++ conn_go19_test.go | 66 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 3 deletions(-) create mode 100644 conn_go19.go create mode 100644 conn_go19_test.go diff --git a/array.go b/array.go index 39c8f7e2..9957c048 100644 --- a/array.go +++ b/array.go @@ -19,10 +19,11 @@ var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() // slice of any dimension. // // For example: -// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) // -// var x []sql.NullInt64 -// db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x)) +// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) +// +// var x []sql.NullInt64 +// db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x)) // // Scanning multi-dimensional arrays is not supported. Arrays where the lower // bound is not one (such as `[0:0]={1}') are not supported. diff --git a/conn_go19.go b/conn_go19.go new file mode 100644 index 00000000..e34705e1 --- /dev/null +++ b/conn_go19.go @@ -0,0 +1,35 @@ +//go:build go1.9 +// +build go1.9 + +package pq + +import ( + "database/sql/driver" + "reflect" +) + +var _ driver.NamedValueChecker = (*conn)(nil) + +func (c *conn) CheckNamedValue(nv *driver.NamedValue) error { + if _, ok := nv.Value.(driver.Valuer); ok { + // Ignore Valuer, for backward compatiblity with pq.Array() + return driver.ErrSkip + } + + // Ignoring []byte / []uint8 + if _, ok := nv.Value.([]uint8); ok { + return driver.ErrSkip + } + + v := reflect.ValueOf(nv.Value) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() == reflect.Slice { + var err error + nv.Value, err = Array(nv.Value).Value() + return err + } + + return driver.ErrSkip +} diff --git a/conn_go19_test.go b/conn_go19_test.go new file mode 100644 index 00000000..1552261b --- /dev/null +++ b/conn_go19_test.go @@ -0,0 +1,66 @@ +//go:build go1.9 +// +build go1.9 + +package pq + +import ( + "fmt" + "reflect" + "testing" +) + +func TestArrayArg(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + for _, tc := range []struct { + in, out interface{} + }{ + { + in: []int{245, 231}, + out: []int64{245, 231}, + }, + { + in: &[]int{245, 231}, + out: []int64{245, 231}, + }, + { + in: []int64{245, 231}, + }, + { + in: &[]int64{245, 231}, + out: []int64{245, 231}, + }, + } { + t.Run(fmt.Sprintf("%#v", tc.in), func(t *testing.T) { + r, err := db.Query("SELECT $1::int[]", tc.in) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if !r.Next() { + if r.Err() != nil { + t.Fatal(r.Err()) + } + t.Fatal("expected row") + } + + defer func() { + if r.Next() { + t.Fatal("unexpected row") + } + }() + + got := reflect.New(reflect.TypeOf(tc.out)).Elem() + if err := r.Scan(got.Interface()); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(tc.out, got.Interface()) { + t.Errorf("got %v, want %v", got, tc.out) + } + }) + } + +}