Skip to content

Commit

Permalink
Fix query cancellation collateralizing future queries using the same …
Browse files Browse the repository at this point in the history
…connection
  • Loading branch information
kahuang committed Nov 23, 2020
1 parent 083382b commit f0b8e15
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 56 deletions.
89 changes: 53 additions & 36 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"path/filepath"
"strconv"
"strings"
"sync/atomic"
"time"
"unicode"

Expand Down Expand Up @@ -136,7 +137,7 @@ type conn struct {

// If true, this connection is bad and all public-facing functions should
// return ErrBadConn.
bad bool
bad *atomic.Value

// If set, this connection should never use the binary format when
// receiving query results from prepared statements. Only provided for
Expand Down Expand Up @@ -294,9 +295,12 @@ func (c *Connector) open(ctx context.Context) (cn *conn, err error) {

o := c.opts

bad := &atomic.Value{}
bad.Store(false)
cn = &conn{
opts: o,
dialer: c.dialer,
bad: bad,
}
err = cn.handleDriverSettings(o)
if err != nil {
Expand Down Expand Up @@ -501,9 +505,22 @@ func (cn *conn) isInTransaction() bool {
cn.txnStatus == txnStatusInFailedTransaction
}

func (cn *conn) setBad() {
if cn.bad != nil {
cn.bad.Store(true)
}
}

func (cn *conn) getBad() bool {
if cn.bad != nil {
return cn.bad.Load().(bool)
}
return false
}

func (cn *conn) checkIsInTransaction(intxn bool) {
if cn.isInTransaction() != intxn {
cn.bad = true
cn.setBad()
errorf("unexpected transaction status %v", cn.txnStatus)
}
}
Expand All @@ -513,7 +530,7 @@ func (cn *conn) Begin() (_ driver.Tx, err error) {
}

func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
if cn.bad {
if cn.getBad() {
return nil, driver.ErrBadConn
}
defer cn.errRecover(&err)
Expand All @@ -524,11 +541,11 @@ func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
return nil, err
}
if commandTag != "BEGIN" {
cn.bad = true
cn.setBad()
return nil, fmt.Errorf("unexpected command tag %s", commandTag)
}
if cn.txnStatus != txnStatusIdleInTransaction {
cn.bad = true
cn.setBad()
return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
}
return cn, nil
Expand All @@ -542,7 +559,7 @@ func (cn *conn) closeTxn() {

func (cn *conn) Commit() (err error) {
defer cn.closeTxn()
if cn.bad {
if cn.getBad() {
return driver.ErrBadConn
}
defer cn.errRecover(&err)
Expand All @@ -564,12 +581,12 @@ func (cn *conn) Commit() (err error) {
_, commandTag, err := cn.simpleExec("COMMIT")
if err != nil {
if cn.isInTransaction() {
cn.bad = true
cn.setBad()
}
return err
}
if commandTag != "COMMIT" {
cn.bad = true
cn.setBad()
return fmt.Errorf("unexpected command tag %s", commandTag)
}
cn.checkIsInTransaction(false)
Expand All @@ -578,7 +595,7 @@ func (cn *conn) Commit() (err error) {

func (cn *conn) Rollback() (err error) {
defer cn.closeTxn()
if cn.bad {
if cn.getBad() {
return driver.ErrBadConn
}
defer cn.errRecover(&err)
Expand All @@ -590,7 +607,7 @@ func (cn *conn) rollback() (err error) {
_, commandTag, err := cn.simpleExec("ROLLBACK")
if err != nil {
if cn.isInTransaction() {
cn.bad = true
cn.setBad()
}
return err
}
Expand Down Expand Up @@ -630,7 +647,7 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err
case 'T', 'D':
// ignore any results
default:
cn.bad = true
cn.setBad()
errorf("unknown response for simple query: %q", t)
}
}
Expand All @@ -652,7 +669,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
// the user can close, though, to avoid connections from being
// leaked. A "rows" with done=true works fine for that purpose.
if err != nil {
cn.bad = true
cn.setBad()
errorf("unexpected message %q in simple query execution", t)
}
if res == nil {
Expand All @@ -676,7 +693,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
err = parseError(r)
case 'D':
if res == nil {
cn.bad = true
cn.setBad()
errorf("unexpected DataRow in simple query execution")
}
// the query didn't fail; kick off to Next
Expand All @@ -691,7 +708,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
// To work around a bug in QueryRow in Go 1.2 and earlier, wait
// until the first DataRow has been received.
default:
cn.bad = true
cn.setBad()
errorf("unknown response for simple query: %q", t)
}
}
Expand Down Expand Up @@ -784,7 +801,7 @@ func (cn *conn) prepareTo(q, stmtName string) *stmt {
}

func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
if cn.bad {
if cn.getBad() {
return nil, driver.ErrBadConn
}
defer cn.errRecover(&err)
Expand Down Expand Up @@ -823,7 +840,7 @@ func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
}

func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
if cn.bad {
if cn.getBad() {
return nil, driver.ErrBadConn
}
if cn.inCopy {
Expand Down Expand Up @@ -857,7 +874,7 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {

// Implement the optional "Execer" interface for one-shot queries
func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
if cn.bad {
if cn.getBad() {
return nil, driver.ErrBadConn
}
defer cn.errRecover(&err)
Expand Down Expand Up @@ -918,7 +935,7 @@ func (cn *conn) sendSimpleMessage(typ byte) (err error) {
// the message yourself.
func (cn *conn) saveMessage(typ byte, buf *readBuf) {
if cn.saveMessageType != 0 {
cn.bad = true
cn.setBad()
errorf("unexpected saveMessageType %d", cn.saveMessageType)
}
cn.saveMessageType = typ
Expand Down Expand Up @@ -1288,7 +1305,7 @@ func (st *stmt) Close() (err error) {
if st.closed {
return nil
}
if st.cn.bad {
if st.cn.getBad() {
return driver.ErrBadConn
}
defer st.cn.errRecover(&err)
Expand All @@ -1302,14 +1319,14 @@ func (st *stmt) Close() (err error) {

t, _ := st.cn.recv1()
if t != '3' {
st.cn.bad = true
st.cn.setBad()
errorf("unexpected close response: %q", t)
}
st.closed = true

t, r := st.cn.recv1()
if t != 'Z' {
st.cn.bad = true
st.cn.setBad()
errorf("expected ready for query, but got: %q", t)
}
st.cn.processReadyForQuery(r)
Expand All @@ -1318,7 +1335,7 @@ func (st *stmt) Close() (err error) {
}

func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
if st.cn.bad {
if st.cn.getBad() {
return nil, driver.ErrBadConn
}
defer st.cn.errRecover(&err)
Expand All @@ -1331,7 +1348,7 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
}

func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
if st.cn.bad {
if st.cn.getBad() {
return nil, driver.ErrBadConn
}
defer st.cn.errRecover(&err)
Expand Down Expand Up @@ -1418,7 +1435,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
parts := strings.Split(commandTag, " ")
if len(parts) != 3 {
cn.bad = true
cn.setBad()
errorf("unexpected INSERT command tag %s", commandTag)
}
affectedRows = &parts[len(parts)-1]
Expand All @@ -1430,7 +1447,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
}
n, err := strconv.ParseInt(*affectedRows, 10, 64)
if err != nil {
cn.bad = true
cn.setBad()
errorf("could not parse commandTag: %s", err)
}
return driver.RowsAffected(n), commandTag
Expand Down Expand Up @@ -1497,7 +1514,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
}

conn := rs.cn
if conn.bad {
if conn.getBad() {
return driver.ErrBadConn
}
defer conn.errRecover(&err)
Expand All @@ -1522,7 +1539,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
case 'D':
n := rs.rb.int16()
if err != nil {
conn.bad = true
conn.setBad()
errorf("unexpected DataRow after error %s", err)
}
if n < len(dest) {
Expand Down Expand Up @@ -1717,7 +1734,7 @@ func (cn *conn) readReadyForQuery() {
cn.processReadyForQuery(r)
return
default:
cn.bad = true
cn.setBad()
errorf("unexpected message %q; expected ReadyForQuery", t)
}
}
Expand All @@ -1737,7 +1754,7 @@ func (cn *conn) readParseResponse() {
cn.readReadyForQuery()
panic(err)
default:
cn.bad = true
cn.setBad()
errorf("unexpected Parse response %q", t)
}
}
Expand All @@ -1762,7 +1779,7 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [
cn.readReadyForQuery()
panic(err)
default:
cn.bad = true
cn.setBad()
errorf("unexpected Describe statement response %q", t)
}
}
Expand All @@ -1780,7 +1797,7 @@ func (cn *conn) readPortalDescribeResponse() rowsHeader {
cn.readReadyForQuery()
panic(err)
default:
cn.bad = true
cn.setBad()
errorf("unexpected Describe response %q", t)
}
panic("not reached")
Expand All @@ -1796,7 +1813,7 @@ func (cn *conn) readBindResponse() {
cn.readReadyForQuery()
panic(err)
default:
cn.bad = true
cn.setBad()
errorf("unexpected Bind response %q", t)
}
}
Expand All @@ -1823,7 +1840,7 @@ func (cn *conn) postExecuteWorkaround() {
cn.saveMessage(t, r)
return
default:
cn.bad = true
cn.setBad()
errorf("unexpected message during extended query execution: %q", t)
}
}
Expand All @@ -1836,7 +1853,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co
switch t {
case 'C':
if err != nil {
cn.bad = true
cn.setBad()
errorf("unexpected CommandComplete after error %s", err)
}
res, commandTag = cn.parseComplete(r.string())
Expand All @@ -1850,15 +1867,15 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co
err = parseError(r)
case 'T', 'D', 'I':
if err != nil {
cn.bad = true
cn.setBad()
errorf("unexpected %q after error %s", t, err)
}
if t == 'I' {
res = emptyRows
}
// ignore any results
default:
cn.bad = true
cn.setBad()
errorf("unknown %s response: %q", protocolState, t)
}
}
Expand Down
Loading

0 comments on commit f0b8e15

Please sign in to comment.