/
mysql.go
138 lines (121 loc) · 3.19 KB
/
mysql.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package dsn
import (
"errors"
"net/url"
"strings"
"github.com/ego-component/egorm/manager"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/schema"
)
// Sentinel errors reported when a MySQL DSN cannot be parsed.
var (
	// errInvalidDSNUnescaped: a ')' follows the address's '(' but is not the
	// last character before the final '/', which typically means a '/' was
	// left unescaped somewhere after the address (e.g. in a parameter value).
	errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value")
	// errInvalidDSNAddr: an address was opened with '(' but never closed.
	errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")
	// errInvalidDSNNoSlash: a non-empty DSN contains no '/' before the
	// database name.
	errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")
)

// Compile-time proof that *MysqlDSNParser satisfies manager.DSNParser.
var _ manager.DSNParser = (*MysqlDSNParser)(nil)
// MysqlDSNParser parses MySQL DSNs and builds gorm dialectors for them.
// It holds no state, so its zero value is ready to use.
type MysqlDSNParser struct{}
// init registers a MysqlDSNParser with the manager when this package is
// imported.
func init() {
	manager.Register(new(MysqlDSNParser))
}
// Scheme returns "mysql", the scheme identifier this parser exposes through
// the manager.DSNParser interface.
func (p *MysqlDSNParser) Scheme() string {
	const scheme = "mysql"
	return scheme
}
// NamingStrategy returns a nil schema.Namer because this parser supplies no
// custom naming strategy.
// NOTE(review): presumably the caller falls back to gorm's default naming
// when nil is returned; confirm in manager.
func (p *MysqlDSNParser) NamingStrategy() schema.Namer {
	var namer schema.Namer // nil interface value, not a typed nil
	return namer
}
// GetDialector returns the gorm MySQL dialector for dsn. The DSN string is
// handed to mysql.Open unchanged; it is not validated here.
func (p *MysqlDSNParser) GetDialector(dsn string) gorm.Dialector {
	return mysql.Open(dsn)
}
// ParseDSN parses a MySQL DSN of the form
//
//	[user[:password]@][net[(addr)]]/dbname[?param1=value1&...&paramN=valueN]
//
// into a manager.DSN.
//
// The last '/' in dsn separates the connection prefix from the database name,
// so '/' may appear in the password or address but must be percent-escaped in
// parameter values. Within the prefix, the last '@' separates the credentials
// from the protocol/address, so the password may itself contain '@'. The
// prefix may omit the credentials entirely (e.g. "tcp(127.0.0.1:3306)/db").
//
// An empty dsn yields an empty, non-nil *manager.DSN and a nil error. A
// non-empty dsn without any '/' yields errInvalidDSNNoSlash. A malformed
// address yields errInvalidDSNAddr or errInvalidDSNUnescaped. A parameter
// value that fails percent-decoding yields the decoding error. Whenever an
// error is returned, the *manager.DSN is nil.
func (p *MysqlDSNParser) ParseDSN(dsn string) (cfg *manager.DSN, err error) {
	// Start from the zero value; no defaults are filled in here.
	cfg = new(manager.DSN)

	slash := strings.LastIndexByte(dsn, '/')
	if slash < 0 {
		if len(dsn) > 0 {
			return nil, errInvalidDSNNoSlash
		}
		return cfg, nil
	}

	// [user[:password]@][net[(addr)]]: empty when the DSN starts with '/'.
	if prefix := dsn[:slash]; len(prefix) > 0 {
		addrNet := prefix
		// Credentials are optional; without an '@' the whole prefix is the
		// protocol/address part.
		if at := strings.LastIndexByte(prefix, '@'); at >= 0 {
			parseUsernamePassword(cfg, prefix[:at])
			addrNet = prefix[at+1:]
		}
		if err = parseAddrNet(cfg, addrNet); err != nil {
			return nil, err
		}
	}

	// dbname[?param1=value1&...&paramN=valueN]; the first '?' starts the params.
	dbName := dsn[slash+1:]
	if q := strings.IndexByte(dbName, '?'); q >= 0 {
		if err = parseDSNParams(cfg, dbName[q+1:]); err != nil {
			return nil, err
		}
		dbName = dbName[:q]
	}
	cfg.DBName = dbName
	return cfg, nil
}

// parseUsernamePassword splits "username[:password]" at its first ':' and
// stores the parts in cfg.User and cfg.Password. Without a ':', the whole
// string is the user name and cfg.Password is left unchanged.
func parseUsernamePassword(cfg *manager.DSN, userPassStr string) {
	colon := strings.IndexByte(userPassStr, ':')
	if colon < 0 {
		cfg.User = userPassStr
		return
	}
	cfg.User = userPassStr[:colon]
	cfg.Password = userPassStr[colon+1:]
}

// parseAddrNet parses "protocol[(address)]" into cfg.Net and cfg.Addr. This is
// the part of the DSN between the credentials' '@' (if any) and the final '/'.
// Without a '(', the whole string is the protocol and cfg.Addr is left
// unchanged. When a '(' is present, the string must end with ')'. Otherwise it
// returns errInvalidDSNUnescaped if some ')' still follows the '(' (typically
// an unescaped '/' later in the DSN moved the split point), or
// errInvalidDSNAddr if the address is simply unterminated.
func parseAddrNet(cfg *manager.DSN, addrNetStr string) error {
	open := strings.IndexByte(addrNetStr, '(')
	if open < 0 {
		cfg.Net = addrNetStr
		return nil
	}
	if !strings.HasSuffix(addrNetStr, ")") {
		if strings.ContainsRune(addrNetStr[open+1:], ')') {
			return errInvalidDSNUnescaped
		}
		return errInvalidDSNAddr
	}
	cfg.Net = addrNetStr[:open]
	cfg.Addr = addrNetStr[open+1 : len(addrNetStr)-1]
	return nil
}
// parseDSNParams decodes the query part of a DSN,
// "param1=value1&...&paramN=valueN", into cfg.Params.
//
// Segments are separated by '&', and each is split at its first '='. Segments
// without '=' are skipped. Values are percent-decoded with url.QueryUnescape;
// keys are stored as-is, and a repeated key keeps its last value. cfg.Params is
// allocated only once a pair is actually stored. If decoding fails, the key is
// still assigned QueryUnescape's result before the error is returned.
func parseDSNParams(cfg *manager.DSN, params string) (err error) {
	for _, pair := range strings.Split(params, "&") {
		eq := strings.IndexByte(pair, '=')
		if eq < 0 {
			continue
		}
		if cfg.Params == nil {
			cfg.Params = make(map[string]string)
		}
		var decoded string
		decoded, err = url.QueryUnescape(pair[eq+1:])
		cfg.Params[pair[:eq]] = decoded
		if err != nil {
			return err
		}
	}
	return nil
}