diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index 5b238bfc5cbcf..0db8552fdb294 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -33,8 +33,8 @@ var _ = log.Printf // INSERT||col=val,col2=val2,col3=? // SELECT||projectcol1,projectcol2|filtercol=?,filtercol2=? // -// Any of these can be preceded by PANIC||, to cause the -// named method on fakeStmt to panic. +// Any of these can be preceded by PANIC|.| to cause +// the named method on fakeStmt or fakeConn to panic. // // When opening a fakeDriver's database, it starts empty with no // tables. All tables and data are stored in memory only. @@ -347,6 +347,9 @@ func checkSubsetTypes(args []driver.Value) error { } func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { + if strings.HasPrefix(query, "PANIC|fakeConn.Exec") { + panic("fakeConn.Exec") + } // This is an optional interface, but it's implemented here // just to check that all the args are of the proper types. // ErrSkip is returned so the caller acts as if we didn't @@ -359,6 +362,9 @@ func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error } func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { + if strings.HasPrefix(query, "PANIC|fakeConn.Query") { + panic("fakeConn.Query") + } // This is an optional interface, but it's implemented here // just to check that all the args are of the proper types. // ErrSkip is returned so the caller acts as if we didn't @@ -483,6 +489,9 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e var hookPrepareBadConn func() bool func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { + if strings.HasPrefix(query, "PANIC|fakeConn.Prepare") { + panic("fakeConn.Prepare") + } c.numPrepare++ if c.db == nil { panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) @@ -527,7 +536,7 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { } func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { - if s.panic == "ColumnConverter" { + if s.panic == "fakeStmt.ColumnConverter" { panic(s.panic) } if len(s.placeholderConverter) == 0 { @@ -537,7 +546,7 @@ func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { } func (s *fakeStmt) Close() error { - if s.panic == "Close" { + if s.panic == "fakeStmt.Close" { panic(s.panic) } if s.c == nil { @@ -559,7 +568,7 @@ var errClosed = errors.New("fakedb: statement has been closed") var hookExecBadConn func() bool func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { - if s.panic == "Exec" { + if s.panic == "fakeStmt.Exec" { panic(s.panic) } if s.closed { @@ -646,7 +655,7 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result var hookQueryBadConn func() bool func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { - if s.panic == "Query" { + if s.panic == "fakeStmt.Query" { panic(s.panic) } if s.closed { @@ -731,7 +740,7 @@ rows: } func (s *fakeStmt) NumInput() int { - if s.panic == "NumInput" { + if s.panic == "fakeStmt.NumInput" { panic(s.panic) } return s.placeholders diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 09de1c34e826b..ef5d96a221790 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -297,7 +297,9 @@ func (dc *driverConn) expired(timeout time.Duration) bool { return dc.createdAt.Add(timeout).Before(nowFunc()) } -func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) { +func (dc *driverConn) prepare(query string) (driver.Stmt, error) { + dc.Lock() + defer dc.Unlock() si, err := dc.ci.Prepare(query) if err == nil { // Track each driverConn's open statements, so we can close them @@ -983,9 +985,7 @@ func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) { if err != nil { return nil, err } - dc.Lock() - si, err := dc.prepareLocked(query) - dc.Unlock() + si, err := dc.prepare(query) if err != nil { db.putConn(dc, err) return nil, err @@ -1018,6 +1018,32 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) { return res, err } +func driverExec(dc *driverConn, query string, args []interface{}) (res Result, err error) { + execer, ok := dc.ci.(driver.Execer) + if !ok { + return res, driver.ErrSkip + } + + dargs, err := driverArgs(nil, args) + if err != nil { + return nil, err + } + + dc.Lock() + defer dc.Unlock() + resi, err := execer.Exec(query, dargs) + if err != nil { + return res, err + } + return driverResult{dc, resi}, nil +} + +func driverPrepare(dc *driverConn, query string) (si driver.Stmt, err error) { + dc.Lock() + defer dc.Unlock() + return dc.ci.Prepare(query) +} + func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) { dc, err := db.conn(strategy) if err != nil { @@ -1027,25 +1053,12 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) db.putConn(dc, err) }() - if execer, ok := dc.ci.(driver.Execer); ok { - dargs, err := driverArgs(nil, args) - if err != nil { - return nil, err - } - dc.Lock() - resi, err := execer.Exec(query, dargs) - dc.Unlock() - if err != driver.ErrSkip { - if err != nil { - return nil, err - } - return driverResult{dc, resi}, nil - } + res, err = driverExec(dc, query, args) + if err != driver.ErrSkip { + return res, err } - dc.Lock() - si, err := dc.ci.Prepare(query) - dc.Unlock() + si, err := driverPrepare(dc, query) if err != nil { return nil, err } @@ -1107,9 +1120,7 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a } } - dc.Lock() - si, err := dc.ci.Prepare(query) - dc.Unlock() + si, err := driverPrepare(dc, query) if err != nil { releaseConn(err) return nil, err @@ -1299,9 +1310,7 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { return nil, err } - dc.Lock() - si, err := dc.ci.Prepare(query) - dc.Unlock() + si, err := driverPrepare(dc, query) if err != nil { return nil, err } @@ -1346,9 +1355,7 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { if err != nil { return &Stmt{stickyErr: err} } - dc.Lock() - si, err := dc.ci.Prepare(stmt.query) - dc.Unlock() + si, err := driverPrepare(dc, stmt.query) txs := &Stmt{ db: tx.db, tx: tx, @@ -1373,25 +1380,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { return nil, err } - if execer, ok := dc.ci.(driver.Execer); ok { - dargs, err := driverArgs(nil, args) - if err != nil { - return nil, err - } - dc.Lock() - resi, err := execer.Exec(query, dargs) - dc.Unlock() - if err == nil { - return driverResult{dc, resi}, nil - } - if err != driver.ErrSkip { - return nil, err - } + res, err := driverExec(dc, query, args) + if err != driver.ErrSkip { + return res, err } - dc.Lock() - si, err := dc.ci.Prepare(query) - dc.Unlock() + si, err := driverPrepare(dc, query) if err != nil { return nil, err } @@ -1578,9 +1572,7 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St s.mu.Unlock() // No luck; we need to prepare the statement on this connection - dc.Lock() - si, err = dc.prepareLocked(s.query) - dc.Unlock() + si, err = dc.prepare(s.query) if err != nil { s.db.putConn(dc, err) return nil, nil, nil, err diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 08df0c7666a35..72838ae1b5471 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -84,28 +84,43 @@ func TestDriverPanic(t *testing.T) { f() } - expectPanic("Exec Exec", func() { db.Exec("PANIC|Exec|WIPE") }) + // db.Exec - panics in driver.Stmt + expectPanic("Exec fakeStmt.Exec", func() { db.Exec("PANIC|fakeStmt.Exec|WIPE") }) exec(t, db, "WIPE") // check not deadlocked - expectPanic("Exec NumInput", func() { db.Exec("PANIC|NumInput|WIPE") }) + expectPanic("Exec fakeStmt.NumInput", func() { db.Exec("PANIC|fakeStmt.NumInput|WIPE") }) exec(t, db, "WIPE") // check not deadlocked - expectPanic("Exec Close", func() { db.Exec("PANIC|Close|WIPE") }) - exec(t, db, "WIPE") // check not deadlocked - exec(t, db, "PANIC|Query|WIPE") // should run successfully: Exec does not call Query - exec(t, db, "WIPE") // check not deadlocked + expectPanic("Exec fakeStmt.Close", func() { db.Exec("PANIC|fakeStmt.Close|WIPE") }) + exec(t, db, "WIPE") // check not deadlocked + exec(t, db, "PANIC|fakeStmt.Query|WIPE") // should run successfully: Exec does not call Query + exec(t, db, "WIPE") // check not deadlocked + + // db.Exec - panics in driver.Conn + expectPanic("Exec fakeConn.Exec", func() { db.Exec("PANIC|fakeConn.Exec|WIPE") }) + exec(t, db, "WIPE") // check not deadlocked + expectPanic("Exec fakeConn.Prepare", func() { db.Exec("PANIC|fakeConn.Prepare|WIPE") }) + exec(t, db, "WIPE") // check not deadlocked + exec(t, db, "PANIC|fakeConn.Query|WIPE") // should run successfully: Exec does not call Query + exec(t, db, "WIPE") // check not deadlocked exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime") - expectPanic("Query Query", func() { db.Query("PANIC|Query|SELECT|people|age,name|") }) - expectPanic("Query NumInput", func() { db.Query("PANIC|NumInput|SELECT|people|age,name|") }) - expectPanic("Query Close", func() { - rows, err := db.Query("PANIC|Close|SELECT|people|age,name|") + // db.Query - panics in driver.Stmt + expectPanic("Query fakeStmt.Query", func() { db.Query("PANIC|fakeStmt.Query|SELECT|people|age,name|") }) + expectPanic("Query fakeStmt.NumInput", func() { db.Query("PANIC|fakeStmt.NumInput|SELECT|people|age,name|") }) + expectPanic("Query fakeStmt.Close", func() { + rows, err := db.Query("PANIC|fakeStmt.Close|SELECT|people|age,name|") if err != nil { t.Fatal(err) } rows.Close() }) - db.Query("PANIC|Exec|SELECT|people|age,name|") // should run successfully: Query does not call Exec - exec(t, db, "WIPE") // check not deadlocked + expectPanic("Query fakeConn.Query", func() { db.Query("PANIC|fakeConn.Query|SELECT|people|age,name|") }) + exec(t, db, "WIPE") // check not deadlocked + expectPanic("Query fakeConn.Prepare", func() { db.Exec("PANIC|fakeConn.Prepare|WIPE") }) + exec(t, db, "WIPE") // check not deadlocked + + db.Query("PANIC|fakeStmt.Exec|SELECT|people|age,name|") // should run successfully: Query does not call Exec + exec(t, db, "WIPE") // check not deadlocked } func exec(t testing.TB, db *DB, query string, args ...interface{}) {