forked from pingcap/tidb-tools
/
variable.go
128 lines (110 loc) · 3.27 KB
/
variable.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package dbutil
import (
"context"
"database/sql"
"fmt"
"strconv"
"strings"
"github.com/pingcap/errors"
"github.com/pingcap/parser"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/auth"
)
// ShowVersion returns the value of the server's global 'version' variable,
// as read through ShowMySQLVariable.
func ShowVersion(ctx context.Context, db *sql.DB) (value string, err error) {
	value, err = ShowMySQLVariable(ctx, db, "version")
	return value, err
}
// ShowLogBin returns the value of the server's global 'log_bin' variable,
// as read through ShowMySQLVariable.
func ShowLogBin(ctx context.Context, db *sql.DB) (value string, err error) {
	value, err = ShowMySQLVariable(ctx, db, "log_bin")
	return value, err
}
// ShowBinlogFormat returns the value of the server's global 'binlog_format'
// variable, as read through ShowMySQLVariable.
func ShowBinlogFormat(ctx context.Context, db *sql.DB) (value string, err error) {
	value, err = ShowMySQLVariable(ctx, db, "binlog_format")
	return value, err
}
// ShowBinlogRowImage returns the value of the server's global
// 'binlog_row_image' variable, as read through ShowMySQLVariable.
func ShowBinlogRowImage(ctx context.Context, db *sql.DB) (value string, err error) {
	value, err = ShowMySQLVariable(ctx, db, "binlog_row_image")
	return value, err
}
// ShowServerID queries the global variable 'server_id' and returns it parsed
// as a base-10 unsigned 64-bit integer.
//
// On any failure (query error or a value that does not parse as uint64) it
// returns 0 together with the error, never a partially parsed value.
func ShowServerID(ctx context.Context, db *sql.DB) (serverID uint64, err error) {
	value, err := ShowMySQLVariable(ctx, db, "server_id")
	if err != nil {
		return 0, errors.Trace(err)
	}
	serverID, err = strconv.ParseUint(value, 10, 64)
	if err != nil {
		// ParseUint yields math.MaxUint64 on overflow; don't hand that to callers.
		return 0, errors.Annotatef(err, "parse server_id %s failed", value)
	}
	return serverID, nil
}
// ShowMySQLVariable runs SHOW GLOBAL VARIABLES LIKE '<variable>' and returns
// the value column of the first row the server reports.
//
// If no row matches, the (traced) error from Row.Scan — sql.ErrNoRows — is
// returned with an empty value.
//
// NOTE(review): variable is interpolated into the statement without escaping
// and is used as a LIKE pattern (so '_' and '%' act as wildcards); it must be
// a trusted variable name, never user-supplied input.
func ShowMySQLVariable(ctx context.Context, db *sql.DB, variable string) (value string, err error) {
	row := db.QueryRowContext(ctx, fmt.Sprintf("SHOW GLOBAL VARIABLES LIKE '%s';", variable))

	// The first column is the variable name; only the value is returned.
	var name string
	if err = row.Scan(&name, &value); err != nil {
		return "", errors.Trace(err)
	}
	return value, nil
}
// ShowGrants returns the grant statements for the MySQL account user@host, as
// reported by SHOW GRANTS. An empty user means the current session user; an
// empty host defaults to "%".
//
// For MySQL 8.0, every row of the first result is parsed with the TiDB parser
// and the roles from any GRANT <role> TO ... statements are collected. If at
// least one role was found, SHOW GRANTS is re-issued with a USING clause that
// lists those roles, so privileges inherited through them are included. In
// that case only the result of the second query is returned.
//
// A row the parser cannot parse makes ShowGrants fail.
//
// NOTE(review): user and host are interpolated into the statement without
// escaping, so they must not come from untrusted input.
func ShowGrants(ctx context.Context, db *sql.DB, user, host string) ([]string, error) {
	if host == "" {
		host = "%"
	}

	var query string
	if user == "" {
		// For the current user.
		query = "SHOW GRANTS"
	} else {
		query = fmt.Sprintf("SHOW GRANTS FOR '%s'@'%s'", user, host)
	}

	// readGrants runs q and collects the single string column of every row.
	readGrants := func(q string) ([]string, error) {
		rows, err := db.QueryContext(ctx, q)
		if err != nil {
			return nil, errors.Trace(err)
		}
		defer rows.Close()

		grants := make([]string, 0, 8)
		for rows.Next() {
			var grant string
			if err := rows.Scan(&grant); err != nil {
				return nil, errors.Trace(err)
			}
			grants = append(grants, grant)
		}
		// rows.Err reports errors that ended iteration early; it must be
		// returned itself, not the (nil) error of the last successful Scan.
		if err := rows.Err(); err != nil {
			return nil, errors.Trace(err)
		}
		return grants, nil
	}

	grants, err := readGrants(query)
	if err != nil {
		return nil, errors.Trace(err)
	}

	// For MySQL 8.0, collect the roles granted to the account.
	var roles []*auth.RoleIdentity
	p := parser.New()
	for _, grant := range grants {
		node, err := p.ParseOneStmt(grant, "", "")
		if err != nil {
			return nil, errors.Trace(err)
		}
		if grantRoleStmt, ok := node.(*ast.GrantRoleStmt); ok {
			roles = append(roles, grantRoleStmt.Roles...)
		}
	}
	if len(roles) == 0 {
		return grants, nil
	}

	// MySQL only accepts USING after "FOR <user>", so the bare "SHOW GRANTS"
	// form used for the current user must name it explicitly.
	var s strings.Builder
	if user == "" {
		s.WriteString("SHOW GRANTS FOR CURRENT_USER()")
	} else {
		s.WriteString(query)
	}
	s.WriteString(" USING ")
	for i, role := range roles {
		if i > 0 {
			s.WriteString(", ")
		}
		s.WriteString(role.String())
	}
	return readGrants(s.String())
}