Skip to content

Commit

Permalink
Add compatibility with database/sql custom types
Browse files Browse the repository at this point in the history
Support database/sql.Scanner
Support database/sql/driver.Valuer
  • Loading branch information
jackc committed Dec 31, 2015
1 parent 029bd49 commit 9f9a977
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Tip

* Add support for database/sql.Scanner and database/sql/driver.Valuer interfaces
* Go float64 can no longer be encoded to a PostgreSQL float4
* Add ConnPool.Reset method
* []byte skips encoding/decoding
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Pgx supports many additional features beyond what is available through database/
* Maps inet and cidr PostgreSQL types to net.IPNet
* Large object support
* Null mapping to Null* struct or pointer to pointer.
* Supports database/sql.Scanner and database/sql/driver/Valuer interfaces for custom types

## Performance

Expand Down
8 changes: 7 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"crypto/md5"
"crypto/tls"
"database/sql/driver"
"encoding/binary"
"encoding/hex"
"errors"
Expand Down Expand Up @@ -851,15 +852,20 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}

wbuf.WriteInt16(int16(len(arguments)))
for i, oid := range ps.ParameterOids {
encode:
if arguments[i] == nil {
wbuf.WriteInt32(-1)
continue
}

encode:
switch arg := arguments[i].(type) {
case Encoder:
err = arg.Encode(wbuf, oid)
case driver.Valuer:
arguments[i], err = arg.Value()
if err == nil {
goto encode
}
case string:
err = encodeText(wbuf, arguments[i])
case []byte:
Expand Down
3 changes: 3 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ Conn.PgTypes.
See example_custom_type_test.go for an example of a custom type for the
PostgreSQL point type.
pgx also includes support for custom types implementing the database/sql.Scanner
and database/sql/driver.Valuer interfaces.
Raw Bytes Mapping
[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified
Expand Down
35 changes: 35 additions & 0 deletions query.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pgx

import (
"database/sql"
"errors"
"fmt"
"net"
Expand Down Expand Up @@ -255,6 +256,40 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
if err != nil {
rows.Fatal(scanArgError{col: i, err: err})
}
} else if s, ok := d.(sql.Scanner); ok {
var val interface{}
if 0 <= vr.Len() {
switch vr.Type().DataType {
case BoolOid:
val = decodeBool(vr)
case Int8Oid:
val = int64(decodeInt8(vr))
case Int2Oid:
val = int64(decodeInt2(vr))
case Int4Oid:
val = int64(decodeInt4(vr))
case TextOid, VarcharOid:
val = decodeText(vr)
case OidOid:
val = int64(decodeOid(vr))
case Float4Oid:
val = float64(decodeFloat4(vr))
case Float8Oid:
val = decodeFloat8(vr)
case DateOid:
val = decodeDate(vr)
case TimestampOid:
val = decodeTimestamp(vr)
case TimestampTzOid:
val = decodeTimestampTz(vr)
default:
val = vr.ReadBytes(vr.Len())
}
}
err = s.Scan(val)
if err != nil {
rows.Fatal(scanArgError{col: i, err: err})
}
} else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid {
decodeJson(vr, &d)
} else {
Expand Down
113 changes: 113 additions & 0 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package pgx_test

import (
"bytes"
"database/sql"
"github.com/jackc/pgx"
"strings"
"testing"
"time"

"github.com/shopspring/decimal"
)

func TestConnQueryScan(t *testing.T) {
Expand Down Expand Up @@ -904,3 +907,113 @@ func TestReadingNullByteArrays(t *testing.T) {
t.Errorf("Expected to read 2 rows, read: ", count)
}
}

// Use github.com/shopspring/decimal as real-world database/sql custom type
// to test against.
func TestConnQueryDatabaseSQLScanner(t *testing.T) {
t.Parallel()

conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)

var num decimal.Decimal

err := conn.QueryRow("select '1234.567'::decimal").Scan(&num)
if err != nil {
t.Fatalf("Scan failed: %v", err)
}

expected, err := decimal.NewFromString("1234.567")
if err != nil {
t.Fatal(err)
}

if !num.Equals(expected) {
t.Errorf("Expected num to be %v, but it was %v", expected, num)
}

ensureConnValid(t, conn)
}

// Use github.com/shopspring/decimal as real-world database/sql custom type
// to test against.
func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) {
t.Parallel()

conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)

expected, err := decimal.NewFromString("1234.567")
if err != nil {
t.Fatal(err)
}
var num decimal.Decimal

err = conn.QueryRow("select $1::decimal", expected).Scan(&num)
if err != nil {
t.Fatalf("Scan failed: %v", err)
}

if !num.Equals(expected) {
t.Errorf("Expected num to be %v, but it was %v", expected, num)
}

ensureConnValid(t, conn)
}

func TestConnQueryDatabaseSQLNullX(t *testing.T) {
t.Parallel()

conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)

type row struct {
boolValid sql.NullBool
boolNull sql.NullBool
int64Valid sql.NullInt64
int64Null sql.NullInt64
float64Valid sql.NullFloat64
float64Null sql.NullFloat64
stringValid sql.NullString
stringNull sql.NullString
}

expected := row{
boolValid: sql.NullBool{Bool: true, Valid: true},
int64Valid: sql.NullInt64{Int64: 123, Valid: true},
float64Valid: sql.NullFloat64{Float64: 3.14, Valid: true},
stringValid: sql.NullString{String: "pgx", Valid: true},
}

var actual row

err := conn.QueryRow(
"select $1::bool, $2::bool, $3::int8, $4::int8, $5::float8, $6::float8, $7::text, $8::text",
expected.boolValid,
expected.boolNull,
expected.int64Valid,
expected.int64Null,
expected.float64Valid,
expected.float64Null,
expected.stringValid,
expected.stringNull,
).Scan(
&actual.boolValid,
&actual.boolNull,
&actual.int64Valid,
&actual.int64Null,
&actual.float64Valid,
&actual.float64Null,
&actual.stringValid,
&actual.stringNull,
)
if err != nil {
t.Fatalf("Scan failed: %v", err)
}

if expected != actual {
t.Errorf("Expected %v, but got %v", expected, actual)
}

ensureConnValid(t, conn)
}

0 comments on commit 9f9a977

Please sign in to comment.