forked from pingcap/tidb-tools
/
executable_checker.go
160 lines (145 loc) · 4.57 KB
/
executable_checker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package checker
import (
	"context"
	"fmt"
	"strings"

	"github.com/pingcap/errors"
	"github.com/pingcap/log"
	"github.com/pingcap/parser"
	"github.com/pingcap/parser/ast"
	"github.com/pingcap/tidb/session"
	"github.com/pingcap/tidb/store/mockstore"
	"github.com/pingcap/tidb/util/logutil"
	"go.uber.org/atomic"
)
// ExecutableChecker is a part of TiDB to check the sql's executability.
//
// It bundles the session that statements are executed on, the parser used to
// turn SQL text into AST nodes, and a flag recording whether Close has already
// been called.
type ExecutableChecker struct {
	// session runs the statements passed to Execute and IsTableExist, and
	// supplies the charset/collation used by Parse.
	session session.Session
	// parser is the SQL parser used by Parse.
	parser *parser.Parser
	// isClosed is flipped from false to true by the first call to Close.
	isClosed *atomic.Bool
}
// NewExecutableChecker creates a new ExecutableChecker.
//
// It sets the TiDB logger level to "error", creates a mock store, bootstraps
// a session environment on that store and then opens a test session on it.
// Any failure is returned wrapped with errors.Trace. When a step after the
// store has been created fails, whatever was already started is shut down
// again before returning, so a failed construction does not leave it behind.
func NewExecutableChecker() (*ExecutableChecker, error) {
	logutil.InitZapLogger(&logutil.LogConfig{
		Config: log.Config{
			Level: "error",
		},
	})

	store, err := mockstore.NewMockStore()
	if err != nil {
		return nil, errors.Trace(err)
	}

	dom, err := session.BootstrapSession(store)
	if err != nil {
		// Best-effort cleanup: the bootstrap error is the one worth reporting,
		// so a secondary error from closing the store is deliberately dropped.
		_ = store.Close()
		return nil, errors.Trace(err)
	}

	// Named se rather than session so the imported session package is not
	// shadowed for the rest of the function.
	se, err := session.CreateSession4Test(store)
	if err != nil {
		// Shut down the bootstrapped domain before the store it runs on.
		dom.Close()
		_ = store.Close()
		return nil, errors.Trace(err)
	}

	return &ExecutableChecker{
		session:  se,
		parser:   parser.New(),
		isClosed: atomic.NewBool(false),
	}, nil
}
// Execute executes the sql to check its executability.
//
// The statement is run on the checker's session. Only success or failure is
// of interest, so every result set the session hands back is closed without
// being read. The returned error is the execution error, or, if execution
// succeeded, the first error hit while closing a result set; either is wrapped
// with errors.Trace. nil means the statement ran and its results were released
// cleanly.
func (ec *ExecutableChecker) Execute(context context.Context, sql string) error {
	recordSets, err := ec.session.Execute(context, sql)

	// Close whatever came back even when err is set, so that nothing returned
	// by the session is left open.
	for _, rs := range recordSets {
		if rs == nil {
			continue
		}
		if closeErr := rs.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}

	if err != nil {
		return errors.Trace(err)
	}
	return nil
}
// IsTableExist returns whether the table with the specified name exists.
//
// It runs "select 0 from `<tableName>` limit 1" on the checker's session and
// returns true only when that statement executes without error; any failure,
// not just a missing table, is therefore reported as false. context is
// dereferenced unconditionally and must not be nil.
func (ec *ExecutableChecker) IsTableExist(context *context.Context, tableName string) bool {
	// Double any embedded backtick so the name stays a single quoted
	// identifier instead of terminating the quoting early.
	quoted := strings.Replace(tableName, "`", "``", -1)

	recordSets, err := ec.session.Execute(*context,
		fmt.Sprintf("select 0 from `%s` limit 1", quoted))

	// The rows are never read, so close each returned result set right away.
	for _, rs := range recordSets {
		if rs != nil {
			// Best-effort: a close failure does not change whether the
			// table could be queried, which is all this method reports.
			_ = rs.Close()
		}
	}
	return err == nil
}
// CreateTable creates a new table by running the given statement.
//
// sql is handed to Execute unchanged; a failure is returned wrapped with
// errors.Trace and success yields nil.
func (ec *ExecutableChecker) CreateTable(context context.Context, sql string) error {
	if execErr := ec.Execute(context, sql); execErr != nil {
		return errors.Trace(execErr)
	}
	return nil
}
// DropTable drops the specified table.
//
// It runs "drop table if exists `<tableName>`" through Execute. A failure is
// returned wrapped with errors.Trace; success yields nil.
func (ec *ExecutableChecker) DropTable(context context.Context, tableName string) error {
	// Double any embedded backtick so the name stays a single quoted
	// identifier instead of terminating the quoting early.
	quoted := strings.Replace(tableName, "`", "``", -1)

	err := ec.Execute(context, fmt.Sprintf("drop table if exists `%s`", quoted))
	if err != nil {
		return errors.Trace(err)
	}
	return nil
}
// Close closes the ExecutableChecker.
//
// The closed flag is switched from false to true with a compare-and-swap.
// If the swap does not happen because the flag was already true, an error is
// returned and the session is left alone; otherwise the session is closed and
// nil is returned.
func (ec *ExecutableChecker) Close() error {
	alreadyClosed := !ec.isClosed.CAS(false, true)
	if alreadyClosed {
		return errors.New("ExecutableChecker is already closed")
	}

	ec.session.Close()
	return nil
}
// Parse parses a query and returns an ast.StmtNode.
//
// The charset and collation passed to the parser are taken from the session
// variables of the checker's session. Whatever ParseOneStmt returns, statement
// and error alike, is passed straight back to the caller.
func (ec *ExecutableChecker) Parse(sql string) (stmt ast.StmtNode, err error) {
	sessionVars := ec.session.GetSessionVars()
	charset, collation := sessionVars.GetCharsetInfo()
	return ec.parser.ParseOneStmt(sql, charset, collation)
}
// GetTablesNeededExist reports the names of the tables that must exist before
// stmt is executed. The specified ast.StmtNode must be a DDLNode.
//
//   - TRUNCATE TABLE, CREATE INDEX, DROP INDEX and ALTER TABLE yield the single
//     table they operate on.
//   - DROP TABLE yields every table it lists.
//   - RENAME TABLE yields the source table of every rename pair, except a
//     source that is the target of an earlier pair in the same statement
//     (e.g. the second "b" in "a TO b, b TO c"), because the statement itself
//     creates that name.
//   - Any other DDLNode yields an empty, non-nil slice.
//
// Only the bare table name is returned; a schema qualifier is not included.
// A statement that is not a DDLNode yields a nil slice and an error.
func GetTablesNeededExist(stmt ast.StmtNode) ([]string, error) {
	switch x := stmt.(type) {
	case *ast.TruncateTableStmt:
		return []string{x.Table.Name.String()}, nil
	case *ast.CreateIndexStmt:
		return []string{x.Table.Name.String()}, nil
	case *ast.DropTableStmt:
		tablesName := make([]string, len(x.Tables))
		for i, table := range x.Tables {
			tablesName[i] = table.Name.String()
		}
		return tablesName, nil
	case *ast.DropIndexStmt:
		return []string{x.Table.Name.String()}, nil
	case *ast.AlterTableStmt:
		return []string{x.Table.Name.String()}, nil
	case *ast.RenameTableStmt:
		tablesName := make([]string, 0, len(x.TableToTables))
		// produced holds names created by earlier pairs of this statement.
		// NOTE(review): names are compared lower-cased on the assumption that
		// the executing TiDB treats table names case-insensitively; confirm.
		produced := make(map[string]struct{}, len(x.TableToTables))
		for _, pair := range x.TableToTables {
			oldName := pair.OldTable.Name.String()
			if _, ok := produced[strings.ToLower(oldName)]; !ok {
				tablesName = append(tablesName, oldName)
			}
			produced[strings.ToLower(pair.NewTable.Name.String())] = struct{}{}
		}
		return tablesName, nil
	case ast.DDLNode:
		return []string{}, nil
	default:
		return nil, errors.New("stmt is not a DDLNode")
	}
}
// GetTablesNeededNonExist reports the names of the tables that conflict with
// stmt, i.e. that must not exist before it is executed. The specified
// ast.StmtNode must be a DDLNode.
//
//   - CREATE TABLE yields the table being created.
//   - RENAME TABLE yields the target table of every rename pair, except a
//     target that is the source of an earlier pair in the same statement
//     (e.g. the second "a" in "a TO tmp, b TO a"), because the statement
//     itself has already moved that name out of the way.
//   - Any other DDLNode yields an empty, non-nil slice.
//
// Only the bare table name is returned; a schema qualifier is not included.
// A statement that is not a DDLNode yields a nil slice and an error.
func GetTablesNeededNonExist(stmt ast.StmtNode) ([]string, error) {
	switch x := stmt.(type) {
	case *ast.CreateTableStmt:
		return []string{x.Table.Name.String()}, nil
	case *ast.RenameTableStmt:
		tablesName := make([]string, 0, len(x.TableToTables))
		// vacated holds names freed by earlier pairs of this statement.
		// NOTE(review): names are compared lower-cased on the assumption that
		// the executing TiDB treats table names case-insensitively; confirm.
		vacated := make(map[string]struct{}, len(x.TableToTables))
		for _, pair := range x.TableToTables {
			newName := pair.NewTable.Name.String()
			if _, ok := vacated[strings.ToLower(newName)]; !ok {
				tablesName = append(tablesName, newName)
			}
			vacated[strings.ToLower(pair.OldTable.Name.String())] = struct{}{}
		}
		return tablesName, nil
	case ast.DDLNode:
		return []string{}, nil
	default:
		return nil, errors.New("stmt is not a DDLNode")
	}
}
// IsDDL reports whether stmt is a DDL statement, that is, whether its dynamic
// type implements ast.DDLNode.
func IsDDL(stmt ast.StmtNode) bool {
	switch stmt.(type) {
	case ast.DDLNode:
		return true
	default:
		return false
	}
}