Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Request binary format for bytea row data #359

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,23 +325,23 @@ var testIntBytes = []byte("1234")

func BenchmarkDecodeInt64(b *testing.B) {
for i := 0; i < b.N; i++ {
decode(&parameterStatus{}, testIntBytes, oid.T_int8)
decode(&parameterStatus{}, testIntBytes, oid.T_int8, formatText)
}
}

var testFloatBytes = []byte("3.14159")

func BenchmarkDecodeFloat64(b *testing.B) {
for i := 0; i < b.N; i++ {
decode(&parameterStatus{}, testFloatBytes, oid.T_float8)
decode(&parameterStatus{}, testFloatBytes, oid.T_float8, formatText)
}
}

var testBoolBytes = []byte{'t'}

func BenchmarkDecodeBool(b *testing.B) {
for i := 0; i < b.N; i++ {
decode(&parameterStatus{}, testBoolBytes, oid.T_bool)
decode(&parameterStatus{}, testBoolBytes, oid.T_bool, formatText)
}
}

Expand All @@ -358,7 +358,7 @@ var testTimestamptzBytes = []byte("2013-09-17 22:15:32.360754-07")

func BenchmarkDecodeTimestamptz(b *testing.B) {
for i := 0; i < b.N; i++ {
decode(&parameterStatus{}, testTimestamptzBytes, oid.T_timestamptz)
decode(&parameterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText)
}
}

Expand All @@ -371,7 +371,7 @@ func BenchmarkDecodeTimestamptzMultiThread(b *testing.B) {
f := func(wg *sync.WaitGroup, loops int) {
defer wg.Done()
for i := 0; i < loops; i++ {
decode(&parameterStatus{}, testTimestamptzBytes, oid.T_timestamptz)
decode(&parameterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText)
}
}

Expand Down
29 changes: 23 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) {
// res might be non-nil here if we received a previous
// CommandComplete, but that's fine; just overwrite it
res = &rows{st: st}
st.cols, st.rowTyps = parseMeta(r)
st.cols, st.rowFmts, st.rowTyps = parseMeta(r)

// To work around a bug in QueryRow in Go 1.2 and earlier, wait
// until the first DataRow has been received.
Expand Down Expand Up @@ -585,7 +585,13 @@ func (cn *conn) prepareTo(q, stmtName string) (_ *stmt, err error) {
st.paramTyps[i] = r.oid()
}
case 'T':
st.cols, st.rowTyps = parseMeta(r)
st.cols, st.rowFmts, st.rowTyps = parseMeta(r)

for i, o := range st.rowTyps {
if o == oid.T_bytea {
st.rowFmts[i] = formatBinary
}
}
case 'n':
// no data
case 'Z':
Expand Down Expand Up @@ -1049,10 +1055,16 @@ func (cn *conn) auth(r *readBuf, o values) {
}
}

type format int

const formatText format = 0
const formatBinary format = 1

type stmt struct {
cn *conn
name string
cols []string
rowFmts []format
rowTyps []oid.Oid
paramTyps []oid.Oid
closed bool
Expand Down Expand Up @@ -1151,7 +1163,10 @@ func (st *stmt) exec(v []driver.Value) {
w.bytes(b)
}
}
w.int16(0)
w.int16(len(st.rowFmts))
for _, f := range st.rowFmts {
w.int16(int(f))
}
st.cn.send(w)

w = st.cn.writeBuf('E')
Expand Down Expand Up @@ -1330,7 +1345,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
dest[i] = nil
continue
}
dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.st.rowTyps[i])
dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.st.rowTyps[i], rs.st.rowFmts[i])
}
return
default:
Expand Down Expand Up @@ -1392,15 +1407,17 @@ func (c *conn) processReadyForQuery(r *readBuf) {
c.txnStatus = transactionStatus(r.byte())
}

func parseMeta(r *readBuf) (cols []string, rowTyps []oid.Oid) {
func parseMeta(r *readBuf) (cols []string, rowFmts []format, rowTyps []oid.Oid) {
n := r.int16()
cols = make([]string, n)
rowFmts = make([]format, n)
rowTyps = make([]oid.Oid, n)
for i := range cols {
cols[i] = r.string()
r.next(6)
rowTyps[i] = r.oid()
r.next(8)
r.next(6)
rowFmts[i] = format(r.int16())
}
return
}
Expand Down
5 changes: 4 additions & 1 deletion encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,12 @@ func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) [
panic("not reached")
}

func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} {
switch typ {
case oid.T_bytea:
if f == formatBinary {
return s
}
return parseBytea(s)
case oid.T_timestamptz:
return parseTs(parameterStatus.currentLocation, string(s))
Expand Down
29 changes: 24 additions & 5 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pq

import (
"bytes"
"database/sql"
"fmt"
"testing"
"time"
Expand Down Expand Up @@ -460,7 +461,7 @@ func TestByteaOutputFormats(t *testing.T) {
return
}

testByteaOutputFormat := func(f string) {
testByteaOutputFormat := func(f string, s bool) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like the name "s" here; it's not at all clear what it means. Maybe "prepare" or "prepared" or "usePrepared" would be more clear?

expectedData := []byte("\x5c\x78\x00\xff\x61\x62\x63\x01\x08")
sqlQuery := "SELECT decode('5c7800ff6162630108', 'hex')"

Expand All @@ -477,8 +478,18 @@ func TestByteaOutputFormats(t *testing.T) {
if err != nil {
t.Fatal(err)
}
// use Query; QueryRow would hide the actual error
rows, err := txn.Query(sqlQuery)
var rows *sql.Rows
var stmt *sql.Stmt
if s {
stmt, err = txn.Prepare(sqlQuery)
if err != nil {
t.Fatal(err)
}
rows, err = stmt.Query()
} else {
// use Query; QueryRow would hide the actual error
rows, err = txn.Query(sqlQuery)
}
if err != nil {
t.Fatal(err)
}
Expand All @@ -496,13 +507,21 @@ func TestByteaOutputFormats(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if s {
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
}
if !bytes.Equal(data, expectedData) {
t.Errorf("unexpected bytea value %v for format %s; expected %v", data, f, expectedData)
}
}

testByteaOutputFormat("hex")
testByteaOutputFormat("escape")
testByteaOutputFormat("hex", false)
testByteaOutputFormat("escape", false)
testByteaOutputFormat("hex", true)
testByteaOutputFormat("escape", true)
}

func TestAppendEncodedText(t *testing.T) {
Expand Down