Skip to content

Commit

Permalink
database/sql: guard against panics in driver.Conn
Browse files Browse the repository at this point in the history
The existing implementation may deadlock if the driver.Conn
implementation panics, so use defers to ensure mutexes are unlocked
even if the driver panics.

For golang#13677, but there is more to do.
  • Loading branch information
jbowens committed Aug 29, 2016
1 parent 5a6f973 commit a9c7a27
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 69 deletions.
23 changes: 16 additions & 7 deletions src/database/sql/fakedb_test.go
Expand Up @@ -33,8 +33,8 @@ var _ = log.Printf
// INSERT|<tablename>|col=val,col2=val2,col3=?
// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
//
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
// Any of these can be preceded by PANIC|<type>.<method>| to cause
// the named method on fakeStmt or fakeConn to panic.
//
// When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only.
Expand Down Expand Up @@ -347,6 +347,9 @@ func checkSubsetTypes(args []driver.Value) error {
}

func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if strings.HasPrefix(query, "PANIC|fakeConn.Exec") {
panic("fakeConn.Exec")
}
// This is an optional interface, but it's implemented here
// just to check that all the args are of the proper types.
// ErrSkip is returned so the caller acts as if we didn't
Expand All @@ -359,6 +362,9 @@ func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error
}

func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
if strings.HasPrefix(query, "PANIC|fakeConn.Query") {
panic("fakeConn.Query")
}
// This is an optional interface, but it's implemented here
// just to check that all the args are of the proper types.
// ErrSkip is returned so the caller acts as if we didn't
Expand Down Expand Up @@ -483,6 +489,9 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
var hookPrepareBadConn func() bool

func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
if strings.HasPrefix(query, "PANIC|fakeConn.Prepare") {
panic("fakeConn.Prepare")
}
c.numPrepare++
if c.db == nil {
panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
Expand Down Expand Up @@ -527,7 +536,7 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
}

func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
if s.panic == "ColumnConverter" {
if s.panic == "fakeStmt.ColumnConverter" {
panic(s.panic)
}
if len(s.placeholderConverter) == 0 {
Expand All @@ -537,7 +546,7 @@ func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
}

func (s *fakeStmt) Close() error {
if s.panic == "Close" {
if s.panic == "fakeStmt.Close" {
panic(s.panic)
}
if s.c == nil {
Expand All @@ -559,7 +568,7 @@ var errClosed = errors.New("fakedb: statement has been closed")
var hookExecBadConn func() bool

func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
if s.panic == "Exec" {
if s.panic == "fakeStmt.Exec" {
panic(s.panic)
}
if s.closed {
Expand Down Expand Up @@ -646,7 +655,7 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result
var hookQueryBadConn func() bool

func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
if s.panic == "Query" {
if s.panic == "fakeStmt.Query" {
panic(s.panic)
}
if s.closed {
Expand Down Expand Up @@ -731,7 +740,7 @@ rows:
}

func (s *fakeStmt) NumInput() int {
if s.panic == "NumInput" {
if s.panic == "fakeStmt.NumInput" {
panic(s.panic)
}
return s.placeholders
Expand Down
92 changes: 42 additions & 50 deletions src/database/sql/sql.go
Expand Up @@ -297,7 +297,9 @@ func (dc *driverConn) expired(timeout time.Duration) bool {
return dc.createdAt.Add(timeout).Before(nowFunc())
}

func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) {
func (dc *driverConn) prepare(query string) (driver.Stmt, error) {
dc.Lock()
defer dc.Unlock()
si, err := dc.ci.Prepare(query)
if err == nil {
// Track each driverConn's open statements, so we can close them
Expand Down Expand Up @@ -983,9 +985,7 @@ func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) {
if err != nil {
return nil, err
}
dc.Lock()
si, err := dc.prepareLocked(query)
dc.Unlock()
si, err := dc.prepare(query)
if err != nil {
db.putConn(dc, err)
return nil, err
Expand Down Expand Up @@ -1018,6 +1018,32 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
return res, err
}

func driverExec(dc *driverConn, query string, args []interface{}) (res Result, err error) {
execer, ok := dc.ci.(driver.Execer)
if !ok {
return res, driver.ErrSkip
}

dargs, err := driverArgs(nil, args)
if err != nil {
return nil, err
}

dc.Lock()
defer dc.Unlock()
resi, err := execer.Exec(query, dargs)
if err != nil {
return res, err
}
return driverResult{dc, resi}, nil
}

func driverPrepare(dc *driverConn, query string) (si driver.Stmt, err error) {
dc.Lock()
defer dc.Unlock()
return dc.ci.Prepare(query)
}

func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
dc, err := db.conn(strategy)
if err != nil {
Expand All @@ -1027,25 +1053,12 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy)
db.putConn(dc, err)
}()

if execer, ok := dc.ci.(driver.Execer); ok {
dargs, err := driverArgs(nil, args)
if err != nil {
return nil, err
}
dc.Lock()
resi, err := execer.Exec(query, dargs)
dc.Unlock()
if err != driver.ErrSkip {
if err != nil {
return nil, err
}
return driverResult{dc, resi}, nil
}
res, err = driverExec(dc, query, args)
if err != driver.ErrSkip {
return res, err
}

dc.Lock()
si, err := dc.ci.Prepare(query)
dc.Unlock()
si, err := driverPrepare(dc, query)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1107,9 +1120,7 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a
}
}

dc.Lock()
si, err := dc.ci.Prepare(query)
dc.Unlock()
si, err := driverPrepare(dc, query)
if err != nil {
releaseConn(err)
return nil, err
Expand Down Expand Up @@ -1299,9 +1310,7 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
return nil, err
}

dc.Lock()
si, err := dc.ci.Prepare(query)
dc.Unlock()
si, err := driverPrepare(dc, query)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1346,9 +1355,7 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
if err != nil {
return &Stmt{stickyErr: err}
}
dc.Lock()
si, err := dc.ci.Prepare(stmt.query)
dc.Unlock()
si, err := driverPrepare(dc, stmt.query)
txs := &Stmt{
db: tx.db,
tx: tx,
Expand All @@ -1373,25 +1380,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
return nil, err
}

if execer, ok := dc.ci.(driver.Execer); ok {
dargs, err := driverArgs(nil, args)
if err != nil {
return nil, err
}
dc.Lock()
resi, err := execer.Exec(query, dargs)
dc.Unlock()
if err == nil {
return driverResult{dc, resi}, nil
}
if err != driver.ErrSkip {
return nil, err
}
res, err := driverExec(dc, query, args)
if err != driver.ErrSkip {
return res, err
}

dc.Lock()
si, err := dc.ci.Prepare(query)
dc.Unlock()
si, err := driverPrepare(dc, query)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1578,9 +1572,7 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
s.mu.Unlock()

// No luck; we need to prepare the statement on this connection
dc.Lock()
si, err = dc.prepareLocked(s.query)
dc.Unlock()
si, err = dc.prepare(s.query)
if err != nil {
s.db.putConn(dc, err)
return nil, nil, nil, err
Expand Down
39 changes: 27 additions & 12 deletions src/database/sql/sql_test.go
Expand Up @@ -84,28 +84,43 @@ func TestDriverPanic(t *testing.T) {
f()
}

expectPanic("Exec Exec", func() { db.Exec("PANIC|Exec|WIPE") })
// db.Exec - panics in driver.Stmt
expectPanic("Exec fakeStmt.Exec", func() { db.Exec("PANIC|fakeStmt.Exec|WIPE") })
exec(t, db, "WIPE") // check not deadlocked
expectPanic("Exec NumInput", func() { db.Exec("PANIC|NumInput|WIPE") })
expectPanic("Exec fakeStmt.NumInput", func() { db.Exec("PANIC|fakeStmt.NumInput|WIPE") })
exec(t, db, "WIPE") // check not deadlocked
expectPanic("Exec Close", func() { db.Exec("PANIC|Close|WIPE") })
exec(t, db, "WIPE") // check not deadlocked
exec(t, db, "PANIC|Query|WIPE") // should run successfully: Exec does not call Query
exec(t, db, "WIPE") // check not deadlocked
expectPanic("Exec fakeStmt.Close", func() { db.Exec("PANIC|fakeStmt.Close|WIPE") })
exec(t, db, "WIPE") // check not deadlocked
exec(t, db, "PANIC|fakeStmt.Query|WIPE") // should run successfully: Exec does not call Query
exec(t, db, "WIPE") // check not deadlocked

// db.Exec - panics in driver.Conn
expectPanic("Exec fakeConn.Exec", func() { db.Exec("PANIC|fakeConn.Exec|WIPE") })
exec(t, db, "WIPE") // check not deadlocked
expectPanic("Exec fakeConn.Prepare", func() { db.Exec("PANIC|fakeConn.Prepare|WIPE") })
exec(t, db, "WIPE") // check not deadlocked
exec(t, db, "PANIC|fakeConn.Query|WIPE") // should run successfully: Exec does not call Query
exec(t, db, "WIPE") // check not deadlocked

exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime")

expectPanic("Query Query", func() { db.Query("PANIC|Query|SELECT|people|age,name|") })
expectPanic("Query NumInput", func() { db.Query("PANIC|NumInput|SELECT|people|age,name|") })
expectPanic("Query Close", func() {
rows, err := db.Query("PANIC|Close|SELECT|people|age,name|")
// db.Query - panics in driver.Stmt
expectPanic("Query fakeStmt.Query", func() { db.Query("PANIC|fakeStmt.Query|SELECT|people|age,name|") })
expectPanic("Query fakeStmt.NumInput", func() { db.Query("PANIC|fakeStmt.NumInput|SELECT|people|age,name|") })
expectPanic("Query fakeStmt.Close", func() {
rows, err := db.Query("PANIC|fakeStmt.Close|SELECT|people|age,name|")
if err != nil {
t.Fatal(err)
}
rows.Close()
})
db.Query("PANIC|Exec|SELECT|people|age,name|") // should run successfully: Query does not call Exec
exec(t, db, "WIPE") // check not deadlocked
expectPanic("Query fakeConn.Query", func() { db.Query("PANIC|fakeConn.Query|SELECT|people|age,name|") })
exec(t, db, "WIPE") // check not deadlocked
expectPanic("Query fakeConn.Prepare", func() { db.Exec("PANIC|fakeConn.Prepare|WIPE") })
exec(t, db, "WIPE") // check not deadlocked

db.Query("PANIC|fakeStmt.Exec|SELECT|people|age,name|") // should run successfully: Query does not call Exec
exec(t, db, "WIPE") // check not deadlocked
}

func exec(t testing.TB, db *DB, query string, args ...interface{}) {
Expand Down

0 comments on commit a9c7a27

Please sign in to comment.