-
-
Notifications
You must be signed in to change notification settings - Fork 546
/
rewrite.go
146 lines (131 loc) · 3.74 KB
/
rewrite.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
/*
* Copyright 2018 Xiaomi, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package session
import (
"fmt"
"vitess.io/vitess/go/vt/sqlparser"
)
// Rewrite carries one SQL statement through the rewrite passes in this file.
type Rewrite struct {
	SQL, NewSQL string              // original statement text, and the text produced by a rewrite pass
	Stmt        sqlparser.Statement // parse tree of the statement; RewriteDML2Select may replace it with the parse of NewSQL
}
// NewRewrite parses sql and returns a *Rewrite holding both the original text
// and its parse tree.
//
// If sql cannot be parsed, it returns a nil *Rewrite together with the parser
// error. Nothing is logged here; reporting the error is the caller's job.
func NewRewrite(sql string) (*Rewrite, error) {
	stmt, err := sqlparser.Parse(sql)
	if err != nil {
		return nil, err
	}
	return &Rewrite{
		SQL:  sql,
		Stmt: stmt,
	}, nil
}
// Rewrite is the entry point of the rewrite pipeline. Currently the only pass
// it runs is RewriteDML2Select, whose result and error it hands back unchanged.
func (rw *Rewrite) Rewrite() (*Rewrite, error) {
	out, err := rw.RewriteDML2Select()
	return out, err
}
// RewriteDML2Select (dml2select) rewrites a DML statement into a SELECT so it
// can be EXPLAINed on older MySQL versions.
//
// Behavior by statement type:
//   - nil Stmt: returned as-is with no error.
//   - SELECT: passed through; NewSQL is set to SQL and Stmt is unchanged.
//   - DELETE, INSERT, UPDATE: converted to SELECT text stored in NewSQL, which
//     is then parsed and stored in Stmt.
//   - anything else: an error is returned.
//
// On any error Stmt keeps its previous parse tree rather than being replaced.
func (rw *Rewrite) RewriteDML2Select() (*Rewrite, error) {
	if rw.Stmt == nil {
		return rw, nil
	}
	switch stmt := rw.Stmt.(type) {
	case *sqlparser.Select:
		rw.NewSQL = rw.SQL
		return rw, nil
	case *sqlparser.Delete: // Multi-table DELETE is not supported yet.
		rw.NewSQL = delete2Select(stmt)
	case *sqlparser.Insert:
		rw.NewSQL = insert2Select(stmt)
	case *sqlparser.Update: // Multi-table UPDATE is not supported yet.
		rw.NewSQL = update2Select(stmt)
	default:
		// Without this case NewSQL would be left empty or stale, and the parse
		// below would fail with an unrelated error and clobber Stmt.
		return rw, fmt.Errorf("dml2select: unsupported statement type %T", stmt)
	}
	// Only replace Stmt once the rewritten SQL parses, so a failure does not
	// discard the original parse tree.
	newStmt, err := sqlparser.Parse(rw.NewSQL)
	if err != nil {
		return rw, err
	}
	rw.Stmt = newStmt
	return rw, nil
}
// delete2Select builds a "SELECT *" that reads the rows a DELETE would remove.
// It reuses the DELETE's table expressions, WHERE, ORDER BY and LIMIT, and
// returns the rendered SQL text.
//
// LIMIT is carried over, as update2Select does, so that a single-table
// "DELETE ... ORDER BY ... LIMIT n" is explained with the same row bound.
func delete2Select(stmt *sqlparser.Delete) string {
	newSQL := &sqlparser.Select{
		SelectExprs: []sqlparser.SelectExpr{
			new(sqlparser.StarExpr),
		},
		From:    stmt.TableExprs,
		Where:   stmt.Where,
		OrderBy: stmt.OrderBy,
		Limit:   stmt.Limit,
	}
	return sqlparser.String(newSQL)
}
// update2Select builds a "SELECT *" that reads the rows an UPDATE would touch.
// The SET list is dropped. The table expressions, WHERE, ORDER BY and LIMIT
// are reused as-is, and the rendered SQL text is returned.
func update2Select(stmt *sqlparser.Update) string {
	sel := new(sqlparser.Select)
	sel.SelectExprs = []sqlparser.SelectExpr{&sqlparser.StarExpr{}}
	sel.From = stmt.TableExprs
	sel.Where = stmt.Where
	sel.OrderBy = stmt.OrderBy
	sel.Limit = stmt.Limit
	return sqlparser.String(sel)
}
// insert2Select returns the SELECT to explain in place of an INSERT.
//
// When the INSERT takes its rows from a query (SELECT, UNION or parenthesized
// SELECT), only that subtree needs explaining, so its text is returned. Any
// other row source, such as a VALUES list, yields a trivial placeholder query.
func insert2Select(stmt *sqlparser.Insert) string {
	switch stmt.Rows.(type) {
	case *sqlparser.Select, *sqlparser.Union, *sqlparser.ParenSelect:
		return sqlparser.String(stmt.Rows)
	default:
		return "select 1 from DUAL"
	}
}
// select2Count returns SQL that counts the rows the current statement returns.
//
// A plain SELECT (no DISTINCT, GROUP BY, HAVING or LIMIT) is rewritten in place
// with count(*) as its only select expression. Only its FROM and WHERE clauses
// are kept. Every other case (nil Stmt, a non-SELECT Stmt, or a SELECT using
// one of those clauses) wraps the original text in a derived table:
// "SELECT COUNT(1) FROM (<SQL>)t".
//
// LIMIT forces the wrapped form. Once the select list is count(*), the query
// yields a single row and LIMIT applies to that row, not to the rows being
// counted. The count would therefore ignore the limit, and any non-zero OFFSET
// would return no row at all.
//
// NOTE(review): the wrapped form always uses rw.SQL, even if RewriteDML2Select
// has replaced rw.Stmt with a SELECT derived from a DML statement; confirm
// callers only invoke this on statements that were SELECTs to begin with.
// NOTE(review): a SELECT with aggregate functions but no GROUP BY (e.g.
// "select max(id) from t") returns one row, while the in-place rewrite counts
// the underlying rows; confirm whether callers depend on that case.
func (rw *Rewrite) select2Count() string {
	stmt, ok := rw.Stmt.(*sqlparser.Select)
	if !ok || stmt.Distinct != "" || stmt.GroupBy != nil || stmt.Having != nil || stmt.Limit != nil {
		return fmt.Sprintf("SELECT COUNT(1) FROM (%s)t", rw.SQL)
	}
	countStmt := &sqlparser.Select{
		SelectExprs: []sqlparser.SelectExpr{
			&sqlparser.AliasedExpr{
				Expr: &sqlparser.FuncExpr{
					Name:  sqlparser.NewColIdent("count"),
					Exprs: []sqlparser.SelectExpr{new(sqlparser.StarExpr)},
				},
			},
		},
		From:  stmt.From,
		Where: stmt.Where,
	}
	return sqlparser.String(countStmt)
}