diff --git a/session/rewrite.go b/session/rewrite.go index d4703f2d..288bec19 100644 --- a/session/rewrite.go +++ b/session/rewrite.go @@ -23,9 +23,8 @@ import ( // Rewrite 用于重写SQL type Rewrite struct { - SQL string - NewSQL string - Stmt sqlparser.Statement + SQL string + Stmt sqlparser.Statement } // NewRewrite 返回一个*Rewrite对象,如果SQL无法被正常解析,将错误输出到日志中,返回一个nil @@ -42,30 +41,28 @@ func NewRewrite(sql string) (*Rewrite, error) { } // Rewrite 入口函数 -func (rw *Rewrite) Rewrite() (*Rewrite, error) { +func (rw *Rewrite) Rewrite() error { return rw.RewriteDML2Select() } // RewriteDML2Select dml2select: DML 转成 SELECT,兼容低版本的 EXPLAIN -func (rw *Rewrite) RewriteDML2Select() (*Rewrite, error) { +func (rw *Rewrite) RewriteDML2Select() error { if rw.Stmt == nil { - return rw, nil + return nil } - switch stmt := rw.Stmt.(type) { case *sqlparser.Select: - rw.NewSQL = rw.SQL - return rw, nil + return nil case *sqlparser.Delete: // Multi DELETE not support yet. - rw.NewSQL = delete2Select(stmt) + rw.SQL = delete2Select(stmt) case *sqlparser.Insert: - rw.NewSQL = insert2Select(stmt) + rw.SQL = insert2Select(stmt) case *sqlparser.Update: // Multi UPDATE not support yet. - rw.NewSQL = update2Select(stmt) + rw.SQL = update2Select(stmt) } var err error - rw.Stmt, err = sqlparser.Parse(rw.NewSQL) - return rw, err + rw.Stmt, err = sqlparser.Parse(rw.SQL) + return err } // delete2Select 将 Delete 语句改写成 Select @@ -107,6 +104,10 @@ func insert2Select(stmt *sqlparser.Insert) string { return "select 1 from DUAL" } +func (rw *Rewrite) TestSelect2Count() string { + return rw.select2Count() +} + // select2Count : SELECT 转成 COUNT语句 func (rw *Rewrite) select2Count() string { if rw.Stmt == nil { diff --git a/session/session_inception.go b/session/session_inception.go index ba87e701..1ea39476 100644 --- a/session/session_inception.go +++ b/session/session_inception.go @@ -6488,20 +6488,13 @@ func (s *session) explainOrAnalyzeSql(sql string) { log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err) s.AppendErrorMessage(err.Error()) } else { - rw, err = rw.RewriteDML2Select() + err = rw.RewriteDML2Select() if err != nil { log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err) s.AppendErrorMessage(err.Error()) } else { - stmt, err := NewRewrite(rw.NewSQL) - if err != nil { - log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err) - s.AppendErrorMessage(err.Error()) - } else { - sql = stmt.select2Count() - // log.Info(sql) - s.getRealRowCount(sql, sqlId) - } + sql = rw.select2Count() + s.getRealRowCount(sql, sqlId) } } return @@ -6512,12 +6505,12 @@ func (s *session) explainOrAnalyzeSql(sql string) { log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err) s.AppendErrorMessage(err.Error()) } else { - rw, err = rw.RewriteDML2Select() + err = rw.RewriteDML2Select() if err != nil { log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err) s.AppendErrorMessage(err.Error()) } else { - sql = rw.NewSQL + sql = rw.SQL if sql == "" { return } diff --git a/session/session_inception_test.go b/session/session_inception_test.go index f8264171..e8e99a76 100644 --- a/session/session_inception_test.go +++ b/session/session_inception_test.go @@ -2564,3 +2564,98 @@ func (s *testSessionIncSuite) TestMergeAlterTable(c *C) { s.testErrorCode(c, sql, session.NewErr(session.ER_ALTER_TABLE_ONCE, "t1")) } + +func (s *testSessionIncSuite) TestNewRewrite(c *C) { + var ( + newSql string + rw *session.Rewrite + ) + + sqls := []struct { + sql string + selectSql string + countSql string + }{ + { + "insert into t2 select * from t1 where id >0;", + "select * from t1 where id > 0", + "select count(*) from t1 where id > 0", + }, + { + "insert into t2 select * from t1 where id >0 limit 10;", + "select * from t1 where id > 0 limit 10", + "SELECT COUNT(1) FROM (select * from t1 where id > 0 limit 10)t", + }, + { + "insert into t2 select * from t1 where id >0 order by c1 desc limit 10;", + "select * from t1 where id > 0 order by c1 desc limit 10", + "SELECT COUNT(1) FROM (select * from t1 where id > 0 limit 10)t", + }, + { + "insert into t2 select distinct id from t1 where id >0 order by c1 desc limit 10;", + "select distinct id from t1 where id > 0 order by c1 desc limit 10", + "SELECT COUNT(1) FROM (select distinct id from t1 where id > 0 order by c1 desc limit 10)t", + }, + { + "insert into t2 select c1,count(1) as cnt from t1 where id >0 group by c1 limit 10;", + "select c1, count(1) as cnt from t1 where id > 0 group by c1 limit 10", + "SELECT COUNT(1) FROM (select c1, count(1) as cnt from t1 where id > 0 group by c1 limit 10)t", + }, + + { + "delete from t1 where id >0;", + "select * from t1 where id > 0", + "select count(*) from t1 where id > 0", + }, + { + "delete from t1 where id >0 limit 10;", + "select * from t1 where id > 0 limit 10", + "SELECT COUNT(1) FROM (select * from t1 where id > 0 limit 10)t", + }, + { + "delete from t1 where id >0 order by c1 desc limit 10;", + "select * from t1 where id > 0 order by c1 desc limit 10", + "SELECT COUNT(1) FROM (select * from t1 where id > 0 limit 10)t", + }, + + { + "update t1 set c1=1 where id >0;", + "select * from t1 where id > 0", + "select count(*) from t1 where id > 0", + }, + { + "update t1 set c1=1 where id >0 limit 10;", + "select * from t1 where id > 0 limit 10", + "SELECT COUNT(1) FROM (select * from t1 where id > 0 limit 10)t", + }, + { + "update t1 set c1=1 where id >0 order by c1 desc limit 10;", + "select * from t1 where id > 0 order by c1 desc limit 10", + "SELECT COUNT(1) FROM (select * from t1 where id > 0 limit 10)t", + }, + { + "update t1 inner join t2 on t1.id=t2.id2 set t1.c1=t2.c1 where c11=1;", + "select * from t1 join t2 on t1.id = t2.id2 where c11 = 1", + "select count(*) from t1 join t2 on t1.id = t2.id2 where c11 = 1", + }, + { + "update t1,t2 set t1.c1=t2.c1 where t1.id=t2.id2 and c11=1;", + "select * from t1, t2 where t1.id = t2.id2 and c11 = 1", + "select count(*) from t1, t2 where t1.id = t2.id2 and c11 = 1", + }, + { + "update t1,t2 set t1.c1=t2.c1 where t1.id=t2.id2 and c11=1 limit 10;", + "select * from t1, t2 where t1.id = t2.id2 and c11 = 1 limit 10", + "SELECT COUNT(1) FROM (select * from t1, t2 where t1.id = t2.id2 and c11 = 1 limit 10)t", + }, + } + + for _, row := range sqls { + rw, _ = session.NewRewrite(row.sql) + rw.RewriteDML2Select() + c.Assert(rw.SQL, Equals, row.selectSql) + + newSql = rw.TestSelect2Count() + c.Assert(newSql, Equals, row.countSql) + } +}