Skip to content

Commit

Permalink
New DSN parser
Browse files Browse the repository at this point in the history
+ Set right default addr for net=unix

Go 1.2RC1
BenchmarkParseDSN_new     200000             10545 ns/op            4039
B/op      42 allocs/op
BenchmarkParseDSN_old      10000            233313 ns/op            7588
B/op      91 allocs/op

Go 1.1
BenchmarkParseDSN_new     200000              7940 ns/op            4204
B/op      42 allocs/op
BenchmarkParseDSN_old      10000            264115 ns/op            8083
B/op      91 allocs/op
  • Loading branch information
julienschmidt committed Oct 16, 2013
1 parent 0792be4 commit dc02949
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 107 deletions.
232 changes: 141 additions & 91 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,24 @@ import (
"crypto/tls"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"os"
"regexp"
"strings"
"time"
)

var (
errLog *log.Logger // Error Logger
dsnPattern *regexp.Regexp // Data Source Name Parser
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs

errInvalidDSN = errors.New("Invalid DSN")
)

func init() {
errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)

dsnPattern = regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]

tlsConfigRegister = make(map[string]*tls.Config)
}

Expand Down Expand Up @@ -79,96 +73,69 @@ func DeregisterTLSConfig(key string) {

func parseDSN(dsn string) (cfg *config, err error) {
cfg = new(config)
cfg.params = make(map[string]string)

matches := dsnPattern.FindStringSubmatch(dsn)
names := dsnPattern.SubexpNames()

for i, match := range matches {
switch names[i] {
case "user":
cfg.user = match
case "passwd":
cfg.passwd = match
case "net":
cfg.net = match
case "addr":
cfg.addr = match
case "dbname":
cfg.dbname = match
case "params":
for _, v := range strings.Split(match, "&") {
param := strings.SplitN(v, "=", 2)
if len(param) != 2 {
continue
}

// cfg params
switch value := param[1]; param[0] {

// Disable INFILE whitelist / enable all files
case "allowAllFiles":
var isBool bool
cfg.allowAllFiles, isBool = readBool(value)
if !isBool {
err = fmt.Errorf("Invalid Bool value: %s", value)
return
}
// TODO: use strings.IndexByte when we can depend on Go 1.2

// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
// Find the last '/'
for i := len(dsn) - 1; i >= 0; i-- {
if dsn[i] == '/' {
var j int

// left part is empty if i <= 0
if i > 0 {
// [username[:password]@][protocol[(address)]]
// Find the last '@' in dsn[:i]
for j = i; j >= 0; j-- {
if dsn[j] == '@' {
// username[:password]
// Find the first ':' in dsn[:j]
var k int
for k = 0; k < j; k++ {
if dsn[k] == ':' {
cfg.passwd = dsn[k+1 : j]
break
}
}
cfg.user = dsn[:k]

// [protocol[(address)]]
// Find the first '(' in dsn[j+1:i]
for k = j + 1; k < i; k++ {
if dsn[k] == '(' {
// dsn[i-1] must be == ')' if an adress is specified
if dsn[i-1] != ')' {
return nil, errInvalidDSN
}
cfg.addr = dsn[k+1 : i-1]
break
}
}
cfg.net = dsn[j+1 : k]

// Switch "rowsAffected" mode
case "clientFoundRows":
var isBool bool
cfg.clientFoundRows, isBool = readBool(value)
if !isBool {
err = fmt.Errorf("Invalid Bool value: %s", value)
return
}

// Use old authentication mode (pre MySQL 4.1)
case "allowOldPasswords":
var isBool bool
cfg.allowOldPasswords, isBool = readBool(value)
if !isBool {
err = fmt.Errorf("Invalid Bool value: %s", value)
return
break
}
}

// Time Location
case "loc":
cfg.loc, err = time.LoadLocation(value)
if err != nil {
return
}
// non-empty left part must contain an '@'
if j < 0 {
return nil, errInvalidDSN
}
}

// Dial Timeout
case "timeout":
cfg.timeout, err = time.ParseDuration(value)
if err != nil {
// dbname[?param1=value1&...&paramN=valueN]
// Find the first '?' in dsn[i+1:]
for j = i + 1; j < len(dsn); j++ {
if dsn[j] == '?' {
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
return
}

// TLS-Encryption
case "tls":
boolValue, isBool := readBool(value)
if isBool {
if boolValue {
cfg.tls = &tls.Config{}
}
} else {
if strings.ToLower(value) == "skip-verify" {
cfg.tls = &tls.Config{InsecureSkipVerify: true}
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
cfg.tls = tlsConfig
} else {
err = fmt.Errorf("Invalid value / unknown config name: %s", value)
return
}
}

default:
cfg.params[param[0]] = value
break
}
}
cfg.dbname = dsn[i+1 : j]

break
}
}

Expand All @@ -179,7 +146,15 @@ func parseDSN(dsn string) (cfg *config, err error) {

// Set default adress if empty
if cfg.addr == "" {
cfg.addr = "127.0.0.1:3306"
switch cfg.net {
case "tcp":
cfg.addr = "127.0.0.1:3306"
case "unix":
cfg.addr = "/tmp/mysql.sock"
default:
return nil, errors.New("Default addr for network '" + cfg.net + "' unknown")
}

}

// Set default location if not set
Expand All @@ -190,6 +165,81 @@ func parseDSN(dsn string) (cfg *config, err error) {
return
}

func parseDSNParams(cfg *config, params string) (err error) {
cfg.params = make(map[string]string)

for _, v := range strings.Split(params, "&") {
param := strings.SplitN(v, "=", 2)
if len(param) != 2 {
continue
}

// cfg params
switch value := param[1]; param[0] {

// Disable INFILE whitelist / enable all files
case "allowAllFiles":
var isBool bool
cfg.allowAllFiles, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}

// Switch "rowsAffected" mode
case "clientFoundRows":
var isBool bool
cfg.clientFoundRows, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}

// Use old authentication mode (pre MySQL 4.1)
case "allowOldPasswords":
var isBool bool
cfg.allowOldPasswords, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}

// Time Location
case "loc":
cfg.loc, err = time.LoadLocation(value)
if err != nil {
return
}

// Dial Timeout
case "timeout":
cfg.timeout, err = time.ParseDuration(value)
if err != nil {
return
}

// TLS-Encryption
case "tls":
boolValue, isBool := readBool(value)
if isBool {
if boolValue {
cfg.tls = &tls.Config{}
}
} else {
if strings.ToLower(value) == "skip-verify" {
cfg.tls = &tls.Config{InsecureSkipVerify: true}
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
cfg.tls = tlsConfig
} else {
return fmt.Errorf("Invalid value / unknown config name: %s", value)
}
}

default:
cfg.params[param[0]] = value
}
}

return
}

// Returns the bool value of the input.
// The 2nd return value indicates if the input was a valid bool value
func readBool(input string) (value bool, valid bool) {
Expand Down
60 changes: 44 additions & 16 deletions utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,26 @@ import (
"time"
)

func TestDSNParser(t *testing.T) {
var testDSNs = []struct {
in string
out string
loc *time.Location
}{
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC},
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local},
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
}
var testDSNs = []struct {
in string
out string
loc *time.Location
}{
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC},
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local},
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"@unix/", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
}

func TestDSNParser(t *testing.T) {
var cfg *config
var err error
var res string
Expand All @@ -51,6 +54,31 @@ func TestDSNParser(t *testing.T) {
}
}

func TestDSNParserInvalid(t *testing.T) {
var invalidDSNs = []string{
"asdf/dbname",
//"/dbname?arg=/some/unescaped/path",
}

for i, tst := range invalidDSNs {
if _, err := parseDSN(tst); err == nil {
t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst)
}
}
}

func BenchmarkParseDSN(b *testing.B) {
b.ReportAllocs()

for i := 0; i < b.N; i++ {
for _, tst := range testDSNs {
if _, err := parseDSN(tst.in); err != nil {
b.Error(err.Error())
}
}
}
}

func TestScanNullTime(t *testing.T) {
var scanTests = []struct {
in interface{}
Expand Down

0 comments on commit dc02949

Please sign in to comment.