diff --git a/conn.go b/conn.go index e050d535..2753ddb7 100644 --- a/conn.go +++ b/conn.go @@ -867,12 +867,20 @@ func (cn *conn) Close() (err error) { return cn.sendSimpleMessage('X') } +func toNamedValue(v []driver.Value) []driver.NamedValue { + v2 := make([]driver.NamedValue, len(v)) + for i := range v { + v2[i] = driver.NamedValue{Value: v[i]} + } + return v2 +} + // Implement the "Queryer" interface func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { - return cn.query(query, args) + return cn.query(query, toNamedValue(args)) } -func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { +func (cn *conn) query(query string, args []driver.NamedValue) (_ *rows, err error) { if err := cn.err.get(); err != nil { return nil, err } @@ -921,7 +929,7 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err } if cn.binaryParameters { - cn.sendBinaryModeQuery(query, args) + cn.sendBinaryModeQuery(query, toNamedValue(args)) cn.readParseResponse() cn.readBindResponse() @@ -1379,10 +1387,10 @@ func (st *stmt) Close() (err error) { } func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { - return st.query(v) + return st.query(toNamedValue(v)) } -func (st *stmt) query(v []driver.Value) (r *rows, err error) { +func (st *stmt) query(v []driver.NamedValue) (r *rows, err error) { if err := st.cn.err.get(); err != nil { return nil, err } @@ -1395,18 +1403,11 @@ func (st *stmt) query(v []driver.Value) (r *rows, err error) { }, nil } -func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { - if err := st.cn.err.get(); err != nil { - return nil, err - } - defer st.cn.errRecover(&err) - - st.exec(v) - res, _, err = st.cn.readExecuteResponse("simple query") - return res, err +func (st *stmt) Exec(v []driver.Value) (driver.Result, error) { + return st.ExecContext(context.Background(), toNamedValue(v)) } -func (st *stmt) exec(v []driver.Value) { +func (st *stmt) exec(v []driver.NamedValue) { if len(v) >= 65536 { errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v)) } @@ -1425,10 +1426,10 @@ func (st *stmt) exec(v []driver.Value) { w.int16(0) w.int16(len(v)) for i, x := range v { - if x == nil { + if x.Value == nil { w.int32(-1) } else { - b := encode(&cn.parameterStatus, x, st.paramTyps[i]) + b := encode(&cn.parameterStatus, x.Value, st.paramTyps[i]) w.int32(len(b)) w.bytes(b) } @@ -1684,13 +1685,13 @@ func md5s(s string) string { return fmt.Sprintf("%x", h.Sum(nil)) } -func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) { +func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.NamedValue) { // Do one pass over the parameters to see if we're going to send any of // them over in binary. If we are, create a paramFormats array at the // same time. var paramFormats []int for i, x := range args { - _, ok := x.([]byte) + _, ok := x.Value.([]byte) if ok { if paramFormats == nil { paramFormats = make([]int, len(args)) @@ -1709,17 +1710,17 @@ func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) { b.int16(len(args)) for _, x := range args { - if x == nil { + if x.Value == nil { b.int32(-1) } else { - datum := binaryEncode(&cn.parameterStatus, x) + datum := binaryEncode(&cn.parameterStatus, x.Value) b.int32(len(datum)) b.bytes(datum) } } } -func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { +func (cn *conn) sendBinaryModeQuery(query string, args []driver.NamedValue) { if len(args) >= 65536 { errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) } diff --git a/conn_go18.go b/conn_go18.go index 63d4ca6a..5ec12f63 100644 --- a/conn_go18.go +++ b/conn_go18.go @@ -16,12 +16,8 @@ const ( // Implement the "QueryerContext" interface func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - list := make([]driver.Value, len(args)) - for i, nv := range args { - list[i] = nv.Value - } finish := cn.watchCancel(ctx) - r, err := cn.query(query, list) + r, err := cn.query(query, args) if err != nil { if finish != nil { finish() @@ -183,12 +179,8 @@ func (cn *conn) cancel(ctx context.Context) error { // Implement the "StmtQueryContext" interface func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - list := make([]driver.Value, len(args)) - for i, nv := range args { - list[i] = nv.Value - } finish := st.watchCancel(ctx) - r, err := st.query(list) + r, err := st.query(args) if err != nil { if finish != nil { finish() @@ -200,17 +192,19 @@ func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (dri } // Implement the "StmtExecContext" interface -func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - list := make([]driver.Value, len(args)) - for i, nv := range args { - list[i] = nv.Value - } - +func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) { if finish := st.watchCancel(ctx); finish != nil { defer finish() } - return st.Exec(list) + if err := st.cn.err.get(); err != nil { + return nil, err + } + defer st.cn.errRecover(&err) + + st.exec(args) + res, _, err = st.cn.readExecuteResponse("simple query") + return res, err } // watchCancel is implemented on stmt in order to not mark the parent conn as bad