Skip to content

Commit

Permalink
fix race condition in statements
Browse files Browse the repository at this point in the history
  • Loading branch information
bgaifullin committed Aug 8, 2017
1 parent b0e2191 commit 45da150
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 58 deletions.
2 changes: 1 addition & 1 deletion clickhouse_test.go
Expand Up @@ -49,7 +49,7 @@ type chSuite struct {
func (s *chSuite) SetupSuite() {
dsn := os.Getenv("TEST_CLICKHOUSE_DSN")
if len(dsn) == 0 {
dsn = "http://localhost:8123/test"
dsn = "http://localhost:8123/default"
}
conn, err := sql.Open("clickhouse", dsn)
s.Require().NoError(err)
Expand Down
33 changes: 6 additions & 27 deletions conn.go
Expand Up @@ -88,22 +88,20 @@ func (c *conn) Commit() (err error) {
if c.txCtx == nil {
return sql.ErrTxDone
}
ctx := c.txCtx
stmts := c.stmts
c.txCtx = nil
c.stmts = stmts[:0]

if len(stmts) == 0 {
return nil
}
for _, stmt := range stmts {
c.log("commit statement: ", stmt.prefix, stmt.pattern)
if err == nil {
if err = stmt.commit(c.txCtx); err != nil {
break
}
if err = stmt.commit(ctx); err != nil {
break
}
stmt.clean()
}
c.txCtx = nil
return
}

Expand All @@ -117,13 +115,10 @@ func (c *conn) Rollback() error {
c.stmts = stmts[:0]

if len(stmts) == 0 {
// there is no statements, so nothing to rollback
return sql.ErrTxDone
}
for _, stmt := range stmts {
c.log("discard statement: ", stmt.prefix, stmt.pattern)
stmt.clean()
}
c.txCtx = nil
// the statements will be closed by sql.Tx
return nil
}

Expand Down Expand Up @@ -218,19 +213,3 @@ func (c *conn) prepare(query string) (*stmt, error) {
}
return s, nil
}

func (c *conn) closeStmt(s *stmt) {
c.log("close statement: ", s.prefix, s.pattern)
if len(c.stmts) == 0 {
return
}
newstmts := make([]*stmt, len(c.stmts))
j := 0
for _, st := range c.stmts {
if st != s {
newstmts[j] = st
j++
}
}
c.stmts = newstmts[:j]
}
1 change: 1 addition & 0 deletions conn_go18_test.go
Expand Up @@ -40,6 +40,7 @@ func (s *connSuite) TestPing() {
func (s *connSuite) TestColumnTypes() {
rows, err := s.conn.Query("SELECT * FROM data LIMIT 1")
s.Require().NoError(err)
defer rows.Close()
types, err := rows.ColumnTypes()
s.Require().NoError(err)
expected := []string{
Expand Down
3 changes: 3 additions & 0 deletions conn_test.go
Expand Up @@ -42,6 +42,7 @@ func (s *connSuite) TestQuery() {
[][]interface{}{{int64(-3), int64(1)}, {int64(-2), int64(1)}, {int64(-1), int64(1)}, {int64(0), int64(3)}},
},
}

for _, tc := range testCases {
rows, err := s.conn.Query(tc.query, tc.args...)
if !s.NoError(err) {
Expand All @@ -56,6 +57,7 @@ func (s *connSuite) TestQuery() {
s.Equal(tc.expected, v)
}
}
s.NoError(rows.Close())
}
}

Expand Down Expand Up @@ -102,6 +104,7 @@ func (s *connSuite) TestExec() {
if s.NoError(err) {
s.Equal([][]interface{}{tc.args}, v)
}
s.NoError(rows.Close())
}
}

Expand Down
1 change: 0 additions & 1 deletion errors.go
Expand Up @@ -14,7 +14,6 @@ var (
ErrMalformed = errors.New("clickhouse: response is malformed")
ErrNoLastInsertID = errors.New("no LastInsertId available")
ErrNoRowsAffected = errors.New("no RowsAffected available")
ErrNoNil = errors.New("nil value is not supported")
)

var errorRe = regexp.MustCompile(`Code: (\d+),.+DB::Exception: (.+),.*`)
Expand Down
60 changes: 31 additions & 29 deletions stmt.go
Expand Up @@ -6,6 +6,7 @@ import (
"database/sql/driver"
"regexp"
"strings"
"sync/atomic"
)

var (
Expand All @@ -14,6 +15,7 @@ var (

type stmt struct {
c *conn
closed int32
prefix string
pattern string
index []int
Expand Down Expand Up @@ -48,10 +50,8 @@ func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {

// Close closes the statement.
func (s *stmt) Close() error {
if s.c != nil {
// make close idempotent
s.c.closeStmt(s)
s.clean()
if atomic.CompareAndSwapInt32(&s.closed, 0, 1) {
s.c = nil
}
return nil
}
Expand All @@ -66,6 +66,7 @@ func (s *stmt) query(ctx context.Context, args []driver.Value) (driver.Rows, err
if err != nil {
return nil, err
}
// sql.Stmt already checks that statements is not closed
return s.c.query(ctx, s.prefix+q, nil)
}

Expand All @@ -78,36 +79,37 @@ func (s *stmt) exec(ctx context.Context, args []driver.Value) (driver.Result, er
if err != nil {
return nil, err
}
// sql.Stmt already checks that statements is not closed
return s.c.exec(ctx, s.prefix+q, nil)
}

func (s *stmt) commit(ctx context.Context) error {
if s.c == nil {
// statement has been closed
return nil
}
if len(s.args) == 0 {
return nil
}
buf := bytes.NewBufferString(s.prefix)
var (
p string
err error
)
for i, args := range s.args {
if i > 0 {
buf.WriteString(", ")
if atomic.CompareAndSwapInt32(&s.closed, 0, 1) {
// statement is not usable after commit
// this code will not run if statement has been closed
args := s.args
con := s.c
s.args = nil
s.c = nil
if len(args) == 0 {
return nil
}
if p, err = interpolateParams(s.pattern, args); err != nil {
return err
buf := bytes.NewBufferString(s.prefix)
var (
p string
err error
)
for i, arg := range args {
if i > 0 {
buf.WriteString(", ")
}
if p, err = interpolateParams(s.pattern, arg); err != nil {
return err
}
buf.WriteString(p)
}
buf.WriteString(p)
_, err = con.exec(ctx, buf.String(), nil)
return err
}
_, err = s.c.exec(ctx, buf.String(), nil)
s.args = s.args[0:0]
return err
}

func (s *stmt) clean() {
s.c = nil
return nil
}
8 changes: 8 additions & 0 deletions stmt_test.go
Expand Up @@ -59,6 +59,7 @@ func (s *stmtSuite) TestQuery() {
s.Equal([][]interface{}{expected}, v)
}
}
s.NoError(rows.Close())
}
s.NoError(st.Close())
_, err = st.Query(tc.args[0]...)
Expand Down Expand Up @@ -103,6 +104,7 @@ func (s *stmtSuite) TestExec() {
if s.NoError(err) {
s.Equal([][]interface{}{args}, v)
}
s.NoError(rows.Close())
}
s.NoError(st.Close())
_, err = st.Exec(tc.args[0]...)
Expand All @@ -121,12 +123,14 @@ func (s *stmtSuite) TestExecMulti() {
st.Exec(22)
rows, err := s.conn.Query("SELECT i64 FROM data WHERE i64=21")
s.False(rows.Next())
s.NoError(rows.Close())
require.NoError(tx.Commit())
s.NoError(st.Close())
rows, err = s.conn.Query("SELECT i64 FROM data WHERE i64>20")
require.NoError(err)
expected := [][]interface{}{{int64(21)}, {int64(22)}}
v, err := scanValues(rows, expected[0])
s.NoError(rows.Close())
require.NoError(err)
s.Equal(expected, v)
}
Expand All @@ -141,11 +145,13 @@ func (s *stmtSuite) TestExecMultiRollback() {
st.Exec(32)
rows, err := s.conn.Query("SELECT i64 FROM data WHERE i64=31")
s.False(rows.Next())
s.NoError(rows.Close())
require.NoError(tx.Rollback())
s.NoError(st.Close())
rows, err = s.conn.Query("SELECT i64 FROM data WHERE i64>30")
require.NoError(err)
s.False(rows.Next())
s.NoError(rows.Close())
}

func (s *stmtSuite) TestExecMultiInterrupt() {
Expand All @@ -160,12 +166,14 @@ func (s *stmtSuite) TestExecMultiInterrupt() {
st.Exec(32)
rows, err := s.conn.Query("SELECT i64 FROM data WHERE i64=31")
s.False(rows.Next())
s.NoError(rows.Close())
require.NoError(st.Close())
require.NoError(tx.Commit())
require.NoError(st2.Close())
rows, err = s.conn.Query("SELECT i64 FROM data WHERE i64>30")
require.NoError(err)
s.False(rows.Next())
s.NoError(rows.Close())
}

func TestStmt(t *testing.T) {
Expand Down

0 comments on commit 45da150

Please sign in to comment.