Skip to content

Commit

Permalink
Fix encode driver.Valuer on nil-able non-pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
jackc committed May 18, 2024
1 parent fec45c8 commit 13beb38
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 20 deletions.
5 changes: 5 additions & 0 deletions extended_query_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ func (eqb *ExtendedQueryBuilder) oidAndArgForQueryExecModeExec(m *pgtype.Map, ar
if err != nil {
return 0, nil, err
}

if v == nil {
return 0, nil, nil
}

if dt, ok := m.TypeForValue(v); ok {
return dt.OID, v, nil
}
Expand Down
20 changes: 9 additions & 11 deletions internal/anynil/anynil.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@ import (
// var valuerReflectType = reflect.TypeFor[driver.Valuer]()
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()

// Is returns true if value is any type of nil except a pointer that directly implements driver.Valuer. e.g. nil,
// []byte(nil), and a *T where T implements driver.Valuer get normalized to nil but a *T where *T implements
// driver.Valuer does not.
// Is returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement
// driver.Valuer if it is only implemented by T.
func Is(value any) bool {
if value == nil {
return true
Expand All @@ -30,14 +29,13 @@ func Is(value any) bool {
return false
}

if kind == reflect.Ptr {
if _, ok := value.(driver.Valuer); ok {
// The pointer will be considered to implement driver.Valuer even if it is actually implemented on the value.
// But we only want to consider it nil if it is implemented on the pointer. So check if what the pointer points
// to implements driver.Valuer.
if !refVal.Type().Elem().Implements(valuerReflectType) {
return false
}
if _, ok := value.(driver.Valuer); ok {
if kind == reflect.Ptr {
// The type assertion will succeed if driver.Valuer is implemented on T or *T. Check if it is implemented on T
// to see if it is not implemented on *T.
return refVal.Type().Elem().Implements(valuerReflectType)
} else {
return false
}
}

Expand Down
6 changes: 3 additions & 3 deletions pgtype/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,10 @@ Encoding Typed Nils
pgtype normalizes typed nils (e.g. []byte(nil)) into nil. nil is always encoded is the SQL NULL value without going
through the Codec system. This means that Codecs and other encoding logic does not have to handle nil or *T(nil).
However, database/sql compatibility requires Value to be called on a pointer that implements driver.Valuer. See
However, database/sql compatibility requires Value to be called on T(nil) when T implements driver.Valuer. Therefore,
driver.Valuer values are not normalized to nil unless it is a *T(nil) where driver.Valuer is implemented on T. See
https://github.com/golang/go/issues/8415 and
https://github.com/golang/go/commit/0ce1d79a6a771f7449ec493b993ed2a720917870. Therefore, pointers that implement
driver.Valuer are not normalized to nil.
https://github.com/golang/go/commit/0ce1d79a6a771f7449ec493b993ed2a720917870.
Child Records
Expand Down
119 changes: 113 additions & 6 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1173,12 +1173,12 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes
ensureConnValid(t, conn)
}

type nilAsEmptyJSONObject struct {
type nilPointerAsEmptyJSONObject struct {
ID string
Name string
}

func (v *nilAsEmptyJSONObject) Value() (driver.Value, error) {
func (v *nilPointerAsEmptyJSONObject) Value() (driver.Value, error) {
if v == nil {
return "{}", nil
}
Expand All @@ -1187,15 +1187,15 @@ func (v *nilAsEmptyJSONObject) Value() (driver.Value, error) {
}

// https://github.com/jackc/pgx/issues/1566
func TestConnQueryDatabaseSQLDriverValuerCalledOnPointerImplementers(t *testing.T) {
func TestConnQueryDatabaseSQLDriverValuerCalledOnNilPointerImplementers(t *testing.T) {
t.Parallel()

conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)

mustExec(t, conn, "create temporary table t(v json not null)")

var v *nilAsEmptyJSONObject
var v *nilPointerAsEmptyJSONObject
commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
require.NoError(t, err)
require.Equal(t, "INSERT 0 1", commandTag.String())
Expand All @@ -1208,12 +1208,119 @@ func TestConnQueryDatabaseSQLDriverValuerCalledOnPointerImplementers(t *testing.
_, err = conn.Exec(context.Background(), `delete from t`)
require.NoError(t, err)

v = &nilAsEmptyJSONObject{ID: "1", Name: "foo"}
v = &nilPointerAsEmptyJSONObject{ID: "1", Name: "foo"}
commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
require.NoError(t, err)
require.Equal(t, "INSERT 0 1", commandTag.String())

var v2 *nilAsEmptyJSONObject
var v2 *nilPointerAsEmptyJSONObject
err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2)
require.NoError(t, err)
require.Equal(t, v, v2)

ensureConnValid(t, conn)
}

type nilSliceAsEmptySlice []byte

func (j nilSliceAsEmptySlice) Value() (driver.Value, error) {
if len(j) == 0 {
return []byte("[]"), nil
}

return []byte(j), nil
}

func (j *nilSliceAsEmptySlice) UnmarshalJSON(data []byte) error {
*j = bytes.Clone(data)
return nil
}

// https://github.com/jackc/pgx/issues/1860
func TestConnQueryDatabaseSQLDriverValuerCalledOnNilSliceImplementers(t *testing.T) {
t.Parallel()

conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)

mustExec(t, conn, "create temporary table t(v json not null)")

var v nilSliceAsEmptySlice
commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
require.NoError(t, err)
require.Equal(t, "INSERT 0 1", commandTag.String())

var s string
err = conn.QueryRow(context.Background(), "select v from t").Scan(&s)
require.NoError(t, err)
require.Equal(t, "[]", s)

_, err = conn.Exec(context.Background(), `delete from t`)
require.NoError(t, err)

v = nilSliceAsEmptySlice(`{"name": "foo"}`)
commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
require.NoError(t, err)
require.Equal(t, "INSERT 0 1", commandTag.String())

var v2 nilSliceAsEmptySlice
err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2)
require.NoError(t, err)
require.Equal(t, v, v2)

ensureConnValid(t, conn)
}

type nilMapAsEmptyObject map[string]any

func (j nilMapAsEmptyObject) Value() (driver.Value, error) {
if j == nil {
return []byte("{}"), nil
}

return json.Marshal(j)
}

func (j *nilMapAsEmptyObject) UnmarshalJSON(data []byte) error {
var m map[string]any
err := json.Unmarshal(data, &m)
if err != nil {
return err
}

*j = m

return nil
}

// https://github.com/jackc/pgx/pull/2019#discussion_r1605806751
func TestConnQueryDatabaseSQLDriverValuerCalledOnNilMapImplementers(t *testing.T) {
t.Parallel()

conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)

mustExec(t, conn, "create temporary table t(v json not null)")

var v nilMapAsEmptyObject
commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
require.NoError(t, err)
require.Equal(t, "INSERT 0 1", commandTag.String())

var s string
err = conn.QueryRow(context.Background(), "select v from t").Scan(&s)
require.NoError(t, err)
require.Equal(t, "{}", s)

_, err = conn.Exec(context.Background(), `delete from t`)
require.NoError(t, err)

v = nilMapAsEmptyObject{"name": "foo"}
commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
require.NoError(t, err)
require.Equal(t, "INSERT 0 1", commandTag.String())

var v2 nilMapAsEmptyObject
err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2)
require.NoError(t, err)
require.Equal(t, v, v2)
Expand Down

0 comments on commit 13beb38

Please sign in to comment.