Skip to content

Commit

Permalink
Merge pull request #120 from hanchuanchuan/update-real-rowcount-test
Browse files Browse the repository at this point in the history
update: 优化DML转select逻辑,并完善相应测试用例
  • Loading branch information
hanchuanchuan committed Nov 20, 2019
2 parents 3378920 + cf92e3c commit a4de898
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 26 deletions.
29 changes: 15 additions & 14 deletions session/rewrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ import (

// Rewrite 用于重写SQL
type Rewrite struct {
SQL string
NewSQL string
Stmt sqlparser.Statement
SQL string
Stmt sqlparser.Statement
}

// NewRewrite 返回一个*Rewrite对象,如果SQL无法被正常解析,将错误输出到日志中,返回一个nil
Expand All @@ -42,30 +41,28 @@ func NewRewrite(sql string) (*Rewrite, error) {
}

// Rewrite 入口函数
func (rw *Rewrite) Rewrite() (*Rewrite, error) {
func (rw *Rewrite) Rewrite() error {
return rw.RewriteDML2Select()
}

// RewriteDML2Select dml2select: DML 转成 SELECT,兼容低版本的 EXPLAIN
func (rw *Rewrite) RewriteDML2Select() (*Rewrite, error) {
func (rw *Rewrite) RewriteDML2Select() error {
if rw.Stmt == nil {
return rw, nil
return nil
}

switch stmt := rw.Stmt.(type) {
case *sqlparser.Select:
rw.NewSQL = rw.SQL
return rw, nil
return nil
case *sqlparser.Delete: // Multi DELETE not support yet.
rw.NewSQL = delete2Select(stmt)
rw.SQL = delete2Select(stmt)
case *sqlparser.Insert:
rw.NewSQL = insert2Select(stmt)
rw.SQL = insert2Select(stmt)
case *sqlparser.Update: // Multi UPDATE not support yet.
rw.NewSQL = update2Select(stmt)
rw.SQL = update2Select(stmt)
}
var err error
rw.Stmt, err = sqlparser.Parse(rw.NewSQL)
return rw, err
rw.Stmt, err = sqlparser.Parse(rw.SQL)
return err
}

// delete2Select 将 Delete 语句改写成 Select
Expand Down Expand Up @@ -107,6 +104,10 @@ func insert2Select(stmt *sqlparser.Insert) string {
return "select 1 from DUAL"
}

func (rw *Rewrite) TestSelect2Count() string {
return rw.select2Count()
}

// select2Count : SELECT 转成 COUNT语句
func (rw *Rewrite) select2Count() string {
if rw.Stmt == nil {
Expand Down
17 changes: 5 additions & 12 deletions session/session_inception.go
Original file line number Diff line number Diff line change
Expand Up @@ -6488,20 +6488,13 @@ func (s *session) explainOrAnalyzeSql(sql string) {
log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
s.AppendErrorMessage(err.Error())
} else {
rw, err = rw.RewriteDML2Select()
err = rw.RewriteDML2Select()
if err != nil {
log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
s.AppendErrorMessage(err.Error())
} else {
stmt, err := NewRewrite(rw.NewSQL)
if err != nil {
log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
s.AppendErrorMessage(err.Error())
} else {
sql = stmt.select2Count()
// log.Info(sql)
s.getRealRowCount(sql, sqlId)
}
sql = rw.select2Count()
s.getRealRowCount(sql, sqlId)
}
}
return
Expand All @@ -6512,12 +6505,12 @@ func (s *session) explainOrAnalyzeSql(sql string) {
log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
s.AppendErrorMessage(err.Error())
} else {
rw, err = rw.RewriteDML2Select()
err = rw.RewriteDML2Select()
if err != nil {
log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
s.AppendErrorMessage(err.Error())
} else {
sql = rw.NewSQL
sql = rw.SQL
if sql == "" {
return
}
Expand Down
95 changes: 95 additions & 0 deletions session/session_inception_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2564,3 +2564,98 @@ func (s *testSessionIncSuite) TestMergeAlterTable(c *C) {
s.testErrorCode(c, sql,
session.NewErr(session.ER_ALTER_TABLE_ONCE, "t1"))
}

func (s *testSessionIncSuite) TestNewRewrite(c *C) {
var (
newSql string
rw *session.Rewrite
)

sqls := []struct {
sql string
selectSql string
countSql string
}{
{
"insert into t2 select * from t1 where id >0;",
"select * from t1 where id > 0",
"select count(*) from t1 where id > 0",
},
{
"insert into t2 select * from t1 where id >0 limit 10;",
"select * from t1 where id > 0 limit 10",
"SELECT COUNT(1) FROM (select * from t1 where id > 0 limit 10)t",
},
{
"insert into t2 select * from t1 where id >0 order by c1 desc limit 10;",
"select * from t1 where id > 0 order by c1 desc limit 10",
"SELECT COUNT(1) FROM (select * from t1 where id > 0 limit 10)t",
},
{
"insert into t2 select distinct id from t1 where id >0 order by c1 desc limit 10;",
"select distinct id from t1 where id > 0 order by c1 desc limit 10",
"SELECT COUNT(1) FROM (select distinct id from t1 where id > 0 order by c1 desc limit 10)t",
},
{
"insert into t2 select c1,count(1) as cnt from t1 where id >0 group by c1 limit 10;",
"select c1, count(1) as cnt from t1 where id > 0 group by c1 limit 10",
"SELECT COUNT(1) FROM (select c1, count(1) as cnt from t1 where id > 0 group by c1 limit 10)t",
},

{
"delete from t1 where id >0;",
"select * from t1 where id > 0",
"select count(*) from t1 where id > 0",
},
{
"delete from t1 where id >0 limit 10;",
"select * from t1 where id > 0 limit 10",
"SELECT COUNT(1) FROM (select * from t1 where id > 0 limit 10)t",
},
{
"delete from t1 where id >0 order by c1 desc limit 10;",
"select * from t1 where id > 0 order by c1 desc limit 10",
"SELECT COUNT(1) FROM (select * from t1 where id > 0 limit 10)t",
},

{
"update t1 set c1=1 where id >0;",
"select * from t1 where id > 0",
"select count(*) from t1 where id > 0",
},
{
"update t1 set c1=1 where id >0 limit 10;",
"select * from t1 where id > 0 limit 10",
"SELECT COUNT(1) FROM (select * from t1 where id > 0 limit 10)t",
},
{
"update t1 set c1=1 where id >0 order by c1 desc limit 10;",
"select * from t1 where id > 0 order by c1 desc limit 10",
"SELECT COUNT(1) FROM (select * from t1 where id > 0 limit 10)t",
},
{
"update t1 inner join t2 on t1.id=t2.id2 set t1.c1=t2.c1 where c11=1;",
"select * from t1 join t2 on t1.id = t2.id2 where c11 = 1",
"select count(*) from t1 join t2 on t1.id = t2.id2 where c11 = 1",
},
{
"update t1,t2 set t1.c1=t2.c1 where t1.id=t2.id2 and c11=1;",
"select * from t1, t2 where t1.id = t2.id2 and c11 = 1",
"select count(*) from t1, t2 where t1.id = t2.id2 and c11 = 1",
},
{
"update t1,t2 set t1.c1=t2.c1 where t1.id=t2.id2 and c11=1 limit 10;",
"select * from t1, t2 where t1.id = t2.id2 and c11 = 1 limit 10",
"SELECT COUNT(1) FROM (select * from t1, t2 where t1.id = t2.id2 and c11 = 1 limit 10)t",
},
}

for _, row := range sqls {
rw, _ = session.NewRewrite(row.sql)
rw.RewriteDML2Select()
c.Assert(rw.SQL, Equals, row.selectSql)

newSql = rw.TestSelect2Count()
c.Assert(newSql, Equals, row.countSql)
}
}

0 comments on commit a4de898

Please sign in to comment.