Skip to content

Commit

Permalink
Merge pull request #185 from hanchuanchuan/fix-trans-mix-ddl-dml
Browse files Browse the repository at this point in the history
fix: 修复在事务中DDL和DML混合执行时可能出错的问题 (#182)
  • Loading branch information
hanchuanchuan committed Apr 3, 2020
2 parents ec01295 + c165a35 commit 450a18f
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 23 deletions.
80 changes: 79 additions & 1 deletion session/conn.go
Expand Up @@ -87,7 +87,7 @@ func (s *session) Raw(sqlStr string) (rows *sql.Rows, err error) {
return
}

// Raw 执行sql语句,连接失败时自动重连,自动重置当前数据库
// Exec 执行sql语句,连接失败时自动重连,自动重置当前数据库
func (s *session) Exec(sqlStr string, retry bool) (res sql.Result, err error) {
// 连接断开无效时,自动重试
for i := 0; i < maxBadConnRetries; i++ {
Expand All @@ -114,6 +114,33 @@ func (s *session) Exec(sqlStr string, retry bool) (res sql.Result, err error) {
return
}

// ExecDDL 执行sql语句,连接失败时自动重连,自动重置当前数据库
func (s *session) ExecDDL(sqlStr string, retry bool) (res sql.Result, err error) {
// 连接断开无效时,自动重试
for i := 0; i < maxBadConnRetries; i++ {
res, err = s.ddlDB.DB().Exec(sqlStr)
if err == nil {
return
} else {
log.Errorf("con:%d %v sql:%s", s.sessionVars.ConnectionID, err, sqlStr)
if err == mysqlDriver.ErrInvalidConn {
err1 := s.initConnection()
if err1 != nil {
return res, err1
}
if retry {
s.AppendErrorMessage(mysqlDriver.ErrInvalidConn.Error())
continue
} else {
return
}
}
return
}
}
return
}

// Raw 执行sql语句,连接失败时自动重连,自动重置当前数据库
func (s *session) RawScan(sqlStr string, dest interface{}) (err error) {
// 连接断开无效时,自动重试
Expand Down Expand Up @@ -180,3 +207,54 @@ func (s *session) initConnection() (err error) {
}
return
}

// // SwitchDatabase USE切换到当前数据库. (避免连接断开后当前数据库置空)
// func (s *session) SwitchDatabase(db *gorm.DB) error {
// name := s.DBName
// if name == "" {
// name = s.opt.db
// }
// if name == "" {
// return nil
// }

// // log.Infof("SwitchDatabase: %v", name)
// _, err := db.DB().Exec(fmt.Sprintf("USE `%s`", name))
// if err != nil {
// log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
// if myErr, ok := err.(*mysqlDriver.MySQLError); ok {
// s.AppendErrorMessage(myErr.Message)
// } else {
// s.AppendErrorMessage(err.Error())
// }
// }
// return err
// }

// // GetDatabase 获取当前数据库
// func (s *session) GetDatabase() string {
// log.Debug("GetDatabase")

// var value string
// sql := "select database();"

// rows, err := s.Raw(sql)
// if rows != nil {
// defer rows.Close()
// }

// if err != nil {
// log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
// if myErr, ok := err.(*mysqlDriver.MySQLError); ok {
// s.AppendErrorMessage(myErr.Message)
// } else {
// s.AppendErrorMessage(err.Error())
// }
// } else {
// for rows.Next() {
// rows.Scan(&value)
// }
// }

// return value
// }
3 changes: 3 additions & 0 deletions session/session.go
Expand Up @@ -169,6 +169,9 @@ type session struct {
db *gorm.DB
backupdb *gorm.DB

// 执行DDL操作的数据库连接. 仅用于事务功能
ddlDB *gorm.DB

DBName string

myRecord *Record
Expand Down
44 changes: 35 additions & 9 deletions session/session_inception.go
Expand Up @@ -21,6 +21,7 @@ import (
"bytes"
"crypto/tls"
"crypto/x509"
"database/sql"
"database/sql/driver"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -525,6 +526,9 @@ func (s *session) executeInc(ctx context.Context, sql string) (recordSets []sqle
if s.db != nil {
defer s.db.Close()
}
if s.ddlDB != nil {
defer s.ddlDB.Close()
}
if s.backupdb != nil {
defer s.backupdb.Close()
}
Expand Down Expand Up @@ -1655,6 +1659,8 @@ func (s *session) executeAllStatement(ctx context.Context) {
trans = make([]*Record, 0, s.opt.tranBatch)
}

// 用于事务. 判断是否为DML语句
// lastIsDMLTrans := false
for i, record := range s.recordSets.All() {

// 忽略不需要备份的类型
Expand Down Expand Up @@ -1684,11 +1690,13 @@ func (s *session) executeAllStatement(ctx context.Context) {
}
}
}

// lastIsDMLTrans = true
case *ast.UseStmt, *ast.SetStmt:
// 环境命令
// 事务内部和非事务均需要执行
// log.Infof("1111: [%s] [%d] %s,RowsAffected: %d", s.DBName, s.fetchThreadID(), record.Sql, record.AffectedRows)
_, err := s.Exec(record.Sql, true)
_, err := s.ExecDDL(record.Sql, true)
if err != nil {
// log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
if myErr, ok := err.(*mysqlDriver.MySQLError); ok {
Expand Down Expand Up @@ -1716,7 +1724,14 @@ func (s *session) executeAllStatement(ctx context.Context) {
trans = nil
}

s.executeRemoteCommand(record)
// 如果前端是DML语句,则在执行DDL前切换一次数据库
// log.Infof("lastIsDMLTrans: %v", lastIsDMLTrans)
// if lastIsDMLTrans {
// s.SwitchDatabase(s.ddlDB)
// lastIsDMLTrans = false
// }

s.executeRemoteCommand(record, true)

// trans = append(trans, record)
// s.executeTransaction(trans)
Expand All @@ -1731,7 +1746,7 @@ func (s *session) executeAllStatement(ctx context.Context) {
}
}
} else {
s.executeRemoteCommand(record)
s.executeRemoteCommand(record, false)
}

if s.hasErrorBefore() {
Expand Down Expand Up @@ -1946,7 +1961,7 @@ func (s *session) executeTransaction(records []*Record) int {
return 0
}

func (s *session) executeRemoteCommand(record *Record) int {
func (s *session) executeRemoteCommand(record *Record, isTran bool) int {

s.myRecord = record
record.Stage = StageExec
Expand All @@ -1972,7 +1987,7 @@ func (s *session) executeRemoteCommand(record *Record) int {
*ast.SetStmt,
*ast.DropIndexStmt:

s.executeRemoteStatement(record)
s.executeRemoteStatement(record, isTran)

default:
log.Infof("无匹配类型: %T\n", node)
Expand Down Expand Up @@ -2181,10 +2196,10 @@ func statisticsTableSQL() string {
return buf.String()
}

func (s *session) executeRemoteStatement(record *Record) {
func (s *session) executeRemoteStatement(record *Record, isTran bool) {
log.Debug("executeRemoteStatement")

sql := record.Sql
sqlStmt := record.Sql

start := time.Now()

Expand All @@ -2205,7 +2220,13 @@ func (s *session) executeRemoteStatement(record *Record) {

return
} else {
res, err := s.Exec(sql, false)
var res sql.Result
var err error
if isTran {
res, err = s.ExecDDL(sqlStmt, false)
} else {
res, err = s.Exec(sqlStmt, false)
}

record.ExecTime = fmt.Sprintf("%.3f", time.Since(start).Seconds())
record.ExecTimestamp = time.Now().Unix()
Expand Down Expand Up @@ -2295,7 +2316,7 @@ func (s *session) executeRemoteStatementAndBackup(record *Record) {
return
}

s.executeRemoteStatement(record)
s.executeRemoteStatement(record, false)

if !s.hasError() || record.ExecComplete {
if s.opt.backup {
Expand Down Expand Up @@ -2906,6 +2927,11 @@ func (s *session) parseOptions(sql string) {
return
}

if s.opt.tranBatch > 1 {
s.ddlDB, _ = gorm.Open("mysql", fmt.Sprintf("%s&autocommit=1", addr))
s.ddlDB.LogMode(false)
}

// 禁用日志记录器,不显示任何日志
db.LogMode(false)

Expand Down
26 changes: 13 additions & 13 deletions session/session_inception_common_test.go
Expand Up @@ -491,12 +491,12 @@ func (s *testCommon) assertRows(c *C, rows [][]interface{}, rollbackSqls ...stri
if tableName == "" {
sql := "select tablename from `%s`.`%s` where opid_time = ?"
sql = fmt.Sprintf(sql, backupDBName, s.remoteBackupTable)
rows, err := s.db.Raw(sql, opid).Rows()
tableRows, err := s.db.Raw(sql, opid).Rows()
c.Assert(err, IsNil)
for rows.Next() {
rows.Scan(&tableName)
for tableRows.Next() {
tableRows.Scan(&tableName)
}
rows.Close()
tableRows.Close()
}
c.Assert(tableName, Not(Equals), "", Commentf("%v", row))

Expand All @@ -507,10 +507,9 @@ func (s *testCommon) assertRows(c *C, rows [][]interface{}, rollbackSqls ...stri

// 如果表改变了,或者超过500行了
if lastTable != currentTable || len(ids) >= 500 {
lastTable = currentTable
if len(ids) > 0 {
sql := "select rollback_statement from %s where opid_time in (?) order by opid_time,id;"
sql = fmt.Sprintf(sql, currentTable)
sql = fmt.Sprintf(sql, lastTable)
rows, err := s.db.Raw(sql, ids).Rows()
c.Assert(err, IsNil)

Expand All @@ -522,11 +521,12 @@ func (s *testCommon) assertRows(c *C, rows [][]interface{}, rollbackSqls ...stri
}
rows.Close()

c.Assert(len(result1), Not(Equals), 0, Commentf("-----------: %v", sql))
c.Assert(len(result1), Not(Equals), 0, Commentf("-----------: %v,%v", sql, ids))
result = append(result, result1...)

ids = nil
}
lastTable = currentTable
}

ids = append(ids, opid)
Expand All @@ -536,22 +536,22 @@ func (s *testCommon) assertRows(c *C, rows [][]interface{}, rollbackSqls ...stri
if len(ids) > 0 {
sql := "select rollback_statement from %s where opid_time in (?) order by opid_time,id;"
sql = fmt.Sprintf(sql, currentTable)
rows, err := s.db.Raw(sql, ids).Rows()
rollbackRows, err := s.db.Raw(sql, ids).Rows()
c.Assert(err, IsNil)

str := ""
result1 := []string{}
for rows.Next() {
rows.Scan(&str)
for rollbackRows.Next() {
rollbackRows.Scan(&str)
result1 = append(result1, s.trim(str))
}
rows.Close()
rollbackRows.Close()

c.Assert(len(result1), Not(Equals), 0, Commentf("------2-----: %v", sql))
c.Assert(len(result1), Not(Equals), 0, Commentf("------2-----: %v", rows))
result = append(result, result1...)
}

c.Assert(len(result), Equals, len(rollbackSqls), Commentf("%v", rows))
c.Assert(len(result), Equals, len(rollbackSqls), Commentf("%v", result))

// 如果是UPDATE多表操作,此时回滚的SQL可能是无序的
if len(result) > 1 && strings.HasPrefix(result[0], "UPDATE") {
Expand Down
53 changes: 53 additions & 0 deletions session/session_inception_tran_test.go
Expand Up @@ -444,3 +444,56 @@ func (s *testSessionIncTranSuite) TestDelete(c *C) {
c.Assert(backup, Equals, "INSERT INTO `test_inc`.`t1`(`id`,`c1`) VALUES(1,'😁😄🙂👩');", Commentf("%v", res.Rows()))

}

func (s *testSessionIncTranSuite) TestCreateTable(c *C) {
saved := config.GetGlobalConfig().Inc
defer func() {
config.GetGlobalConfig().Inc = saved
}()

var (
res *testkit.Result
// row []interface{}
// backup string
)

res = s.mustRunBackupTran(c, `DROP TABLE IF EXISTS t1,t2;
CREATE TABLE t1 (id int(11) NOT NULL,
c1 int(11) DEFAULT NULL,
c2 int(11) DEFAULT NULL,
PRIMARY KEY (id));
INSERT INTO t1 VALUES (1, 1, 1);
CREATE TABLE t2 (id int(11) NOT NULL,
c1 int(11) DEFAULT NULL,
c2 int(11) DEFAULT NULL,
PRIMARY KEY (id))`)
s.assertRows(c, res.Rows()[2:],
"DROP TABLE `test_inc`.`t1`;",
"DELETE FROM `test_inc`.`t1` WHERE `id`=1;",
"DROP TABLE `test_inc`.`t2`;")

res = s.mustRunBackupTran(c, `DROP TABLE IF EXISTS t1,t2;
create table t1(id int primary key,c1 int);
insert into t1 values(1,1),(2,2);
delete from t1 where id=1;
alter table t1 add column c2 int;
insert into t1 values(3,3,3);
delete from t1 where id>0;
create table t2(id int primary key,c1 int);
insert into t2 values(3,3);`)
s.assertRows(c, res.Rows()[2:],
"DROP TABLE `test_inc`.`t1`;",
"DELETE FROM `test_inc`.`t1` WHERE `id`=1;",
"DELETE FROM `test_inc`.`t1` WHERE `id`=2;",
"INSERT INTO `test_inc`.`t1`(`id`,`c1`) VALUES(1,1);",
"ALTER TABLE `test_inc`.`t1` DROP COLUMN `c2`;",
"DELETE FROM `test_inc`.`t1` WHERE `id`=3;",
"INSERT INTO `test_inc`.`t1`(`id`,`c1`,`c2`) VALUES(2,2,NULL);",
"INSERT INTO `test_inc`.`t1`(`id`,`c1`,`c2`) VALUES(3,3,3);",
"DROP TABLE `test_inc`.`t2`;",
"DELETE FROM `test_inc`.`t2` WHERE `id`=3;")

}

0 comments on commit 450a18f

Please sign in to comment.