Skip to content

Commit

Permalink
database/sql: don't close a driver.Conn until its Stmts are closed
Browse files Browse the repository at this point in the history
Fixes #5046

R=golang-dev, r
CC=golang-dev
https://golang.org/cl/8016044
  • Loading branch information
bradfitz committed Mar 25, 2013
1 parent f20f8b8 commit 209f6b1
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 20 deletions.
21 changes: 20 additions & 1 deletion src/pkg/database/sql/fakedb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,26 @@ func (c *fakeConn) Begin() (driver.Tx, error) {
return c.currTx, nil
}

func (c *fakeConn) Close() error {
var hookPostCloseConn struct {
sync.Mutex
fn func(*fakeConn, error)
}

func setHookpostCloseConn(fn func(*fakeConn, error)) {
hookPostCloseConn.Lock()
defer hookPostCloseConn.Unlock()
hookPostCloseConn.fn = fn
}

func (c *fakeConn) Close() (err error) {
defer func() {
hookPostCloseConn.Lock()
fn := hookPostCloseConn.fn
hookPostCloseConn.Unlock()
if fn != nil {
fn(c, err)
}
}()
if c.currTx != nil {
return errors.New("can't close fakeConn; in a Transaction")
}
Expand Down
82 changes: 63 additions & 19 deletions src/pkg/database/sql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,42 @@ type DB struct {
// interfaces returned via that Conn, such as calls on Tx, Stmt,
// Result, Rows)
type driverConn struct {
sync.Mutex
ci driver.Conn
db *DB

sync.Mutex // guards following
ci driver.Conn
closed bool
}

// the dc.db's Mutex is held.
func (dc *driverConn) closeDBLocked() error {
dc.Lock()
if dc.closed {
dc.Unlock()
return errors.New("sql: duplicate driverConn close")
}
dc.closed = true
dc.Unlock() // not defer; removeDep finalClose calls may need to lock
return dc.db.removeDepLocked(dc, dc)()
}

func (dc *driverConn) Close() error {
dc.Lock()
if dc.closed {
dc.Unlock()
return errors.New("sql: duplicate driverConn close")
}
dc.closed = true
dc.Unlock() // not defer; removeDep finalClose calls may need to lock
return dc.db.removeDep(dc, dc)
}

func (dc *driverConn) finalClose() error {
dc.Lock()
err := dc.ci.Close()
dc.ci = nil
dc.Unlock()
return err
}

// driverStmt associates a driver.Stmt with the
Expand Down Expand Up @@ -238,6 +272,10 @@ func (db *DB) addDep(x finalCloser, dep interface{}) {
//println(fmt.Sprintf("addDep(%T %p, %T %p)", x, x, dep, dep))
db.mu.Lock()
defer db.mu.Unlock()
db.addDepLocked(x, dep)
}

func (db *DB) addDepLocked(x finalCloser, dep interface{}) {
if db.dep == nil {
db.dep = make(map[finalCloser]depSet)
}
Expand All @@ -254,10 +292,16 @@ func (db *DB) addDep(x finalCloser, dep interface{}) {
// If x no longer has any dependencies, its finalClose method will be
// called and its error value will be returned.
func (db *DB) removeDep(x finalCloser, dep interface{}) error {
db.mu.Lock()
fn := db.removeDepLocked(x, dep)
db.mu.Unlock()
return fn()
}

func (db *DB) removeDepLocked(x finalCloser, dep interface{}) func() error {
//println(fmt.Sprintf("removeDep(%T %p, %T %p)", x, x, dep, dep))
done := false

db.mu.Lock()
xdep := db.dep[x]
if xdep != nil {
delete(xdep, dep)
Expand All @@ -266,13 +310,14 @@ func (db *DB) removeDep(x finalCloser, dep interface{}) error {
done = true
}
}
db.mu.Unlock()

if !done {
return nil
return func() error { return nil }
}
return func() error {
//println(fmt.Sprintf("calling final close on %T %v (%#v)", x, x, x))
return x.finalClose()
}
//println(fmt.Sprintf("calling final close on %T %v (%#v)", x, x, x))
return x.finalClose()
}

// Open opens a database specified by its database driver name and a
Expand Down Expand Up @@ -320,9 +365,7 @@ func (db *DB) Close() error {
defer db.mu.Unlock()
var err error
for _, dc := range db.freeConn {
dc.Lock()
err1 := dc.ci.Close()
dc.Unlock()
err1 := dc.closeDBLocked()
if err1 != nil {
err = err1
}
Expand Down Expand Up @@ -365,11 +408,7 @@ func (db *DB) SetMaxIdleConns(n int) {
dc := db.freeConn[nfree-1]
db.freeConn[nfree-1] = nil
db.freeConn = db.freeConn[:nfree-1]
go func() {
dc.Lock()
dc.ci.Close()
dc.Unlock()
}()
go dc.Close()
}
}

Expand All @@ -393,8 +432,12 @@ func (db *DB) conn() (*driverConn, error) {
if err != nil {
return nil, err
}
dc := &driverConn{ci: ci}
dc := &driverConn{
db: db,
ci: ci,
}
db.mu.Lock()
db.addDepLocked(dc, dc)
db.outConn[dc] = true
db.mu.Unlock()
return dc, nil
Expand Down Expand Up @@ -484,9 +527,7 @@ func (db *DB) putConn(dc *driverConn, err error) {
// statements which are still active?
db.mu.Unlock()

dc.Lock()
dc.ci.Close()
dc.Unlock()
dc.Close()
}

// Prepare creates a prepared statement for later queries or executions.
Expand Down Expand Up @@ -528,6 +569,7 @@ func (db *DB) prepare(query string) (*Stmt, error) {
css: []connStmt{{dc, si}},
}
db.addDep(stmt, stmt)
db.addDep(dc, stmt)
db.putConn(dc, nil)
return stmt, nil
}
Expand Down Expand Up @@ -1031,6 +1073,7 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
if err != nil {
return nil, nil, nil, err
}
s.db.addDep(dc, s)
s.mu.Lock()
cs = connStmt{dc, si}
s.css = append(s.css, cs)
Expand Down Expand Up @@ -1149,6 +1192,7 @@ func (s *Stmt) Close() error {
func (s *Stmt) finalClose() error {
for _, v := range s.css {
s.db.noteUnusedDriverStatement(v.dc, v.si)
s.db.removeDep(v.dc, s)
}
s.css = nil
return nil
Expand Down
54 changes: 54 additions & 0 deletions src/pkg/database/sql/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ func closeDB(t *testing.T, db *DB) {
fmt.Printf("Panic: %v\n", e)
panic(e)
}
defer setHookpostCloseConn(nil)
setHookpostCloseConn(func(_ *fakeConn, err error) {
if err != nil {
t.Errorf("Error closing fakeConn: %v", err)
}
})
err := db.Close()
if err != nil {
t.Fatalf("error closing DB: %v", err)
Expand Down Expand Up @@ -790,3 +796,51 @@ func TestMaxIdleConns(t *testing.T) {
t.Errorf("freeConns = %d; want 0", got)
}
}

// golang.org/issue/5046
func TestCloseConnBeforeStmts(t *testing.T) {
defer setHookpostCloseConn(nil)
setHookpostCloseConn(func(_ *fakeConn, err error) {
if err != nil {
t.Errorf("Error closing fakeConn: %v", err)
}
})

db := newTestDB(t, "people")

stmt, err := db.Prepare("SELECT|people|name|")
if err != nil {
t.Fatal(err)
}

if len(db.freeConn) != 1 {
t.Fatalf("expected 1 freeConn; got %d", len(db.freeConn))
}
dc := db.freeConn[0]
if dc.closed {
t.Errorf("conn shouldn't be closed")
}

err = db.Close()
if err != nil {
t.Errorf("db Close = %v", err)
}
if !dc.closed {
t.Errorf("after db.Close, driverConn should be closed")
}
if dc.ci == nil {
t.Errorf("after db.Close, driverConn should still have its Conn interface")
}

err = stmt.Close()
if err != nil {
t.Errorf("Stmt close = %v", err)
}

if !dc.closed {
t.Errorf("conn should be closed")
}
if dc.ci != nil {
t.Errorf("after Stmt Close, driverConn's Conn interface should be nil")
}
}

0 comments on commit 209f6b1

Please sign in to comment.