diff --git a/buf.go b/buf.go index fd966c39..e7ff5777 100644 --- a/buf.go +++ b/buf.go @@ -47,28 +47,44 @@ func (b *readBuf) byte() byte { return b.next(1)[0] } -type writeBuf []byte +type writeBuf struct { + buf []byte + pos int +} func (b *writeBuf) int32(n int) { x := make([]byte, 4) binary.BigEndian.PutUint32(x, uint32(n)) - *b = append(*b, x...) + b.buf = append(b.buf, x...) } func (b *writeBuf) int16(n int) { x := make([]byte, 2) binary.BigEndian.PutUint16(x, uint16(n)) - *b = append(*b, x...) + b.buf = append(b.buf, x...) } func (b *writeBuf) string(s string) { - *b = append(*b, (s + "\000")...) + b.buf = append(b.buf, (s + "\000")...) } func (b *writeBuf) byte(c byte) { - *b = append(*b, c) + b.buf = append(b.buf, c) } func (b *writeBuf) bytes(v []byte) { - *b = append(*b, v...) + b.buf = append(b.buf, v...) +} + +func (b *writeBuf) wrap() []byte { + p := b.buf[b.pos:] + binary.BigEndian.PutUint32(p, uint32(len(p))) + return b.buf +} + +func (b *writeBuf) next(c byte) { + p := b.buf[b.pos:] + binary.BigEndian.PutUint32(p, uint32(len(p))) + b.pos = len(b.buf) + 1 + b.buf = append(b.buf, c, 0, 0, 0, 0) } diff --git a/conn.go b/conn.go index 1a1f09a1..3071bbd5 100644 --- a/conn.go +++ b/conn.go @@ -106,12 +106,16 @@ type conn struct { // If true, this connection is bad and all public-facing functions should // return ErrBadConn. bad bool + + binary_mode bool } func (c *conn) writeBuf(b byte) *writeBuf { c.scratch[0] = b - w := writeBuf(c.scratch[:5]) - return &w + return &writeBuf{ + buf: c.scratch[:5], + pos: 1, + } } func Open(name string) (_ driver.Conn, err error) { @@ -216,6 +220,16 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" { err = cn.c.SetDeadline(time.Time{}) } + // set binary_mode + if binary_mode := o.Get("binary_mode"); binary_mode != "" { + if binary_mode == "on" { + cn.binary_mode = true + } else if binary_mode == "off" { + cn.binary_mode = false + } else { + return nil, err + } + } return cn, err } @@ -476,11 +490,91 @@ func (cn *conn) gname() string { return strconv.FormatInt(int64(cn.namei), 10) } -func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) { - b := cn.writeBuf('Q') - b.string(q) - cn.send(b) +func (cn *conn) awaitSynchronizationPoint() { + t, r := cn.recv1() + switch t { + case 'Z': + cn.processReadyForQuery(r) + return + default: + cn.bad = true + errorf("unexpected message %q while waiting for synchronization point") + } +} + +func (cn *conn) readParseResponse() { + t, r := cn.recv1() + switch t { + case '1': + return + case 'E': + err := parseError(r) + cn.awaitSynchronizationPoint() + panic(err) + default: + cn.bad = true + errorf("unexpected Parse response %q", t) + } +} + +func (cn *conn) readBindResponse() { + t, r := cn.recv1() + switch t { + case '2': + return + case 'E': + err := parseError(r) + cn.awaitSynchronizationPoint() + panic(err) + default: + cn.bad = true + errorf("unexpected Bind response %q", t) + } +} + +func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, cols []string, rowTyps []oid.Oid) { + t, r := cn.recv1() + switch t { + case 't': + nparams := r.int16() + paramTyps = make([]oid.Oid, nparams) + for i := range paramTyps { + paramTyps[i] = r.oid() + } + case 'E': + err := parseError(r) + cn.awaitSynchronizationPoint() + panic(err) + default: + cn.bad = true + errorf("unexpected Describe response %q", t) + } + + // Cheat a bit since the result should be exactly the same as when + // describing a portal. + cols, rowTyps = cn.readPortalDescribeResponse() + return paramTyps, cols, rowTyps +} + +func (cn *conn) readPortalDescribeResponse() (cols []string, rowTyps []oid.Oid) { + t, r := cn.recv1() + switch t { + case 'T': + return parseMeta(r) + case 'n': + return nil, nil + case 'E': + err := parseError(r) + cn.awaitSynchronizationPoint() + panic(err) + default: + cn.bad = true + errorf("unexpected Describe response %q", t) + } + panic("not reached") +} +func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) { for { t, r := cn.recv1() switch t { @@ -488,24 +582,29 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err res, commandTag = cn.parseComplete(r.string()) case 'Z': cn.processReadyForQuery(r) - // done - return + return res, commandTag, err case 'E': err = parseError(r) case 'T', 'D', 'I': // ignore any results default: cn.bad = true - errorf("unknown response for simple query: %q", t) + errorf("unknown %s response: %q", protocolState, t) } } } + +func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) { + b := cn.writeBuf('Q') + b.string(q) + cn.send(b) + return cn.readExecuteResponse("simple query") +} + func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { defer cn.errRecover(&err) - st := &stmt{cn: cn, name: ""} - b := cn.writeBuf('Q') b.string(q) cn.send(b) @@ -522,7 +621,7 @@ func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { cn.bad = true errorf("unexpected message %q in simple query execution", t) } - res = &rows{st: st, done: true} + res = &rows{cn: cn, done: true} case 'Z': cn.processReadyForQuery(r) // done @@ -541,8 +640,9 @@ func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { case 'T': // res might be non-nil here if we received a previous // CommandComplete, but that's fine; just overwrite it - res = &rows{st: st} - st.cols, st.rowTyps = parseMeta(r) + rs := &rows{cn: cn} + rs.cols, rs.rowTyps = parseMeta(r) + res = rs // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. @@ -560,34 +660,28 @@ func (cn *conn) prepareTo(q, stmtName string) (_ *stmt, err error) { b.string(st.name) b.string(q) b.int16(0) - cn.send(b) - b = cn.writeBuf('D') + b.next('D') b.byte('S') b.string(st.name) + + b.next('S') cn.send(b) - cn.send(cn.writeBuf('S')) + cn.readParseResponse() + st.paramTyps, st.cols, st.rowTyps = cn.readStatementDescribeResponse() for { t, r := cn.recv1() switch t { - case '1': - case 't': - nparams := r.int16() - st.paramTyps = make([]oid.Oid, nparams) - - for i := range st.paramTyps { - st.paramTyps[i] = r.oid() - } - case 'T': - st.cols, st.rowTyps = parseMeta(r) - case 'n': - // no data case 'Z': cn.processReadyForQuery(r) return st, err case 'E': + if err != nil { + cn.bad = true + errorf("unexpected ErrorResponse during extended query") + } err = parseError(r) default: cn.bad = true @@ -637,17 +731,121 @@ func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err err return cn.simpleQuery(query) } - st, err := cn.prepareTo(query, "") - if err != nil { - panic(err) + if cn.binary_mode { + cn.sendBinaryModeQuery(query, args) + + cn.readParseResponse() + cn.readBindResponse() + rows := &rows{cn: cn} + rows.cols, rows.rowTyps = cn.readPortalDescribeResponse() + cn.postExecute() + return rows, nil + } else { + st, err := cn.prepareTo(query, "") + if err != nil { + panic(err) + } + + st.exec(args) + return &rows{ + cn: cn, + cols: st.cols, + rowTyps: st.rowTyps, + }, nil } + panic("not reached") +} - st.exec(args) - return &rows{st: st}, nil +func (cn *conn) postExecute() { + // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores + // any errors from rows.Next, which masks errors that happened during the + // execution of the query. To avoid the problem in common cases, we wait + // here for one more message from the database. If it's not an error the + // query will likely succeed (or perhaps has already, if it's a + // CommandComplete), so we push the message into the conn struct; recv1 + // will return it as the next message for rows.Next or rows.Close. + // However, if it's an error, we wait until ReadyForQuery and then return + // the error to our caller. + var err error + for { + t, r := cn.recv1() + switch t { + case 'E': + err = parseError(r) + cn.awaitSynchronizationPoint() + panic(err) + case 'C', 'D', 'I': + // the query didn't fail, but we can't process this message + cn.saveMessage(t, r) + return + default: + cn.bad = true + errorf("unexpected message during extended query execution: %q", t) + } + } +} + +func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { + if len(args) >= 65536 { + errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) + } + + b := cn.writeBuf('P') + b.byte(0) // unnamed statement + b.string(query) + b.int16(0) + + b.next('B') + b.int16(0) // unnamed portal and statement + + // Do one pass over the parameters to see if we're going to send any of + // them over in binary. If we are, create a paramFormats array at the + // same time. + var paramFormats []int + for i, x := range args { + _, ok := x.([]byte) + if ok { + if paramFormats == nil { + paramFormats = make([]int, len(args)) + } + paramFormats[i] = 1 + } + } + if paramFormats == nil { + b.int16(0) + } else { + b.int16(len(paramFormats)) + for _, x := range paramFormats { + b.int16(x) + } + } + + b.int16(len(args)) + for _, x := range args { + if x == nil { + b.int32(-1) + } else { + datum := binaryEncode(&cn.parameterStatus, x) + b.int32(len(datum)) + b.bytes(datum) + } + } + b.int16(0) + + b.next('D') + b.byte('P') + b.byte(0) // unnamed statement + + b.next('E') + b.byte(0) + b.int32(0) + + b.next('S') + cn.send(b) } // Implement the optional "Execer" interface for one-shot queries -func (cn *conn) Exec(query string, args []driver.Value) (_ driver.Result, err error) { +func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { if cn.bad { return nil, driver.ErrBadConn } @@ -657,36 +855,47 @@ func (cn *conn) Exec(query string, args []driver.Value) (_ driver.Result, err er // *much* faster than going through prepare/exec if len(args) == 0 { // ignore commandTag, our caller doesn't care - r, _, err := cn.simpleExec(query) - return r, err + res, _, err = cn.simpleExec(query) + return res, err } - // Use the unnamed statement to defer planning until bind - // time, or else value-based selectivity estimates cannot be - // used. - st, err := cn.prepareTo(query, "") - if err != nil { - panic(err) + if cn.binary_mode { + cn.sendBinaryModeQuery(query, args) + + cn.readParseResponse() + cn.readBindResponse() + cn.readPortalDescribeResponse() + cn.postExecute() + res, _, err = cn.readExecuteResponse("Execute") + return res, err + } else { + // Not in binary mode; we need to do a full prepare/describe round-trip + // to the server to know the types of the parameter. We can still use + // the unnamed statement to defer planning until bind time so that + // value-based selectivity estimates can be used. + st, err := cn.prepareTo(query, "") + if err != nil { + panic(err) + } + return st.Exec(args) } + panic("not reached") +} - r, err := st.Exec(args) +func (cn *conn) send(m *writeBuf) { + _, err := cn.c.Write(m.wrap()) if err != nil { panic(err) } - - return r, err } -// Assumes len(*m) is > 5 -func (cn *conn) send(m *writeBuf) { - b := (*m)[1:] - binary.BigEndian.PutUint32(b, uint32(len(b))) - - if (*m)[0] == 0 { - *m = b +func (cn *conn) sendStartupPacket(m *writeBuf) { + // sanity check + if m.buf[0] != 0 { + panic("oops") } - _, err := cn.c.Write(*m) + _, err := cn.c.Write((m.wrap())[1:]) if err != nil { panic(err) } @@ -826,7 +1035,7 @@ func (cn *conn) ssl(o values) { w := cn.writeBuf(0) w.int32(80877103) - cn.send(w) + cn.sendStartupPacket(w) b := cn.scratch[:1] _, err := io.ReadFull(cn.c, b) @@ -963,6 +1172,8 @@ func isDriverSetting(key string) bool { return true case "connect_timeout": return true + case "binary_mode": + return true default: return false @@ -990,7 +1201,7 @@ func (cn *conn) startup(o values) { w.string(v) } w.string("") - cn.send(w) + cn.sendStartupPacket(w) for { t, r := cn.recv() @@ -1094,7 +1305,11 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { defer st.cn.errRecover(&err) st.exec(v) - return &rows{st: st}, nil + return &rows{ + cn: st.cn, + cols: st.cols, + rowTyps: st.rowTyps, + }, nil } func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { @@ -1104,25 +1319,8 @@ func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { defer st.cn.errRecover(&err) st.exec(v) - - for { - t, r := st.cn.recv1() - switch t { - case 'E': - err = parseError(r) - case 'C': - res, _ = st.cn.parseComplete(r.string()) - case 'Z': - st.cn.processReadyForQuery(r) - // done - return - case 'T', 'D', 'I': - // ignore any results - default: - st.cn.bad = true - errorf("unknown exec response: %q", t) - } - } + res, _, err = st.cn.readExecuteResponse("simple query") + return res, err } func (st *stmt) exec(v []driver.Value) { @@ -1134,7 +1332,7 @@ func (st *stmt) exec(v []driver.Value) { } w := st.cn.writeBuf('B') - w.string("") + w.byte(0) // unnamed portal w.string(st.name) w.int16(0) w.int16(len(v)) @@ -1148,69 +1346,16 @@ func (st *stmt) exec(v []driver.Value) { } } w.int16(0) - st.cn.send(w) - w = st.cn.writeBuf('E') - w.string("") + w.next('E') + w.byte(0) w.int32(0) - st.cn.send(w) - - st.cn.send(st.cn.writeBuf('S')) - var err error - for { - t, r := st.cn.recv1() - switch t { - case 'E': - err = parseError(r) - case '2': - if err != nil { - panic(err) - } - goto workaround - case 'Z': - st.cn.processReadyForQuery(r) - if err != nil { - panic(err) - } - return - default: - st.cn.bad = true - errorf("unexpected bind response: %q", t) - } - } + w.next('S') + st.cn.send(w) - // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores - // any errors from rows.Next, which masks errors that happened during the - // execution of the query. To avoid the problem in common cases, we wait - // here for one more message from the database. If it's not an error the - // query will likely succeed (or perhaps has already, if it's a - // CommandComplete), so we push the message into the conn struct; recv1 - // will return it as the next message for rows.Next or rows.Close. - // However, if it's an error, we wait until ReadyForQuery and then return - // the error to our caller. -workaround: - for { - t, r := st.cn.recv1() - switch t { - case 'E': - err = parseError(r) - case 'C', 'D', 'I': - // the query didn't fail, but we can't process this message - st.cn.saveMessage(t, r) - return - case 'Z': - if err == nil { - st.cn.bad = true - errorf("unexpected ReadyForQuery during extended query execution") - } - st.cn.processReadyForQuery(r) - panic(err) - default: - st.cn.bad = true - errorf("unexpected message during query execution: %q", t) - } - } + st.cn.readBindResponse() + st.cn.postExecute() } func (st *stmt) NumInput() int { @@ -1267,7 +1412,9 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { } type rows struct { - st *stmt + cn *conn + cols []string + rowTyps []oid.Oid done bool rb readBuf } @@ -1287,7 +1434,7 @@ func (rs *rows) Close() error { } func (rs *rows) Columns() []string { - return rs.st.cols + return rs.cols } func (rs *rows) Next(dest []driver.Value) (err error) { @@ -1295,7 +1442,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { return io.EOF } - conn := rs.st.cn + conn := rs.cn if conn.bad { return driver.ErrBadConn } @@ -1326,7 +1473,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { dest[i] = nil continue } - dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.st.rowTyps[i]) + dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.rowTyps[i]) } return default: diff --git a/encode.go b/encode.go index 556986a4..e706d31e 100644 --- a/encode.go +++ b/encode.go @@ -14,6 +14,16 @@ import ( "github.com/lib/pq/oid" ) +func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte { + switch v := x.(type) { + case []byte: + return v + default: + return encode(parameterStatus, x, oid.T_unknown) + } + panic("not reached") +} + func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte { switch v := x.(type) { case int64: diff --git a/notify.go b/notify.go index e3b08d59..1070ed6e 100644 --- a/notify.go +++ b/notify.go @@ -247,8 +247,10 @@ func (l *ListenerConn) sendSimpleQuery(q string) (err error) { // Can't use l.cn.writeBuf here because it uses the scratch buffer which // might get overwritten by listenerConnLoop. - data := writeBuf([]byte("Q\x00\x00\x00\x00")) - b := &data + b := &writeBuf{ + buf: []byte("Q\x00\x00\x00\x00"), + pos: 1, + } b.string(q) l.cn.send(b)