diff --git a/AUTHORS b/AUTHORS index 6f7041c7a..096105255 100644 --- a/AUTHORS +++ b/AUTHORS @@ -106,6 +106,7 @@ Xuehong Chan Zhenye Xie Zhixin Wen Ziheng Lyu +Brian Hendriks # Organizations @@ -123,3 +124,4 @@ Percona LLC Pivotal Inc. Stripe Inc. Zendesk Inc. +Dolthub Inc. diff --git a/README.md b/README.md index 25de2e5aa..49c71802c 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,11 @@ This has the same effect as an empty DSN string: ``` +If your database name includes a slash, use the [URL encoding](https://en.wikipedia.org/wiki/Percent-encoding) `%2F`: +``` +/dbname%2Fwithslash +``` + Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct. #### Password diff --git a/dsn.go b/dsn.go index 4b71aaab0..f18e1c499 100644 --- a/dsn.go +++ b/dsn.go @@ -196,7 +196,8 @@ func (cfg *Config) FormatDSN() string { // /dbname buf.WriteByte('/') - buf.WriteString(cfg.DBName) + dbNameEncoded := url.QueryEscape(cfg.DBName) + buf.WriteString(dbNameEncoded) // [?param1=value1&...¶mN=valueN] hasParam := false @@ -358,7 +359,11 @@ func ParseDSN(dsn string) (cfg *Config, err error) { break } } - cfg.DBName = dsn[i+1 : j] + + dbName := dsn[i+1 : j] + if cfg.DBName, err = url.QueryUnescape(dbName); err != nil { + return nil, fmt.Errorf("invalid dbname '%s': %w", dbName, err) + } break } diff --git a/dsn_test.go b/dsn_test.go index 41a6a29fa..83ea8caa6 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -50,6 +50,9 @@ var testDSNs = []struct { }, { "/dbname", &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "/dbname%2Fwithslash", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname/withslash", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "@/", &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, @@ -76,17 +79,20 @@ var testDSNs = []struct { func TestDSNParser(t *testing.T) { for i, tst := range testDSNs { - cfg, err := ParseDSN(tst.in) - if err != nil { - t.Error(err.Error()) - } + t.Run(tst.in, func(t *testing.T) { + cfg, err := ParseDSN(tst.in) + if err != nil { + t.Error(err.Error()) + return + } - // pointer not static - cfg.TLS = nil + // pointer not static + cfg.TLS = nil - if !reflect.DeepEqual(cfg, tst.out) { - t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out) - } + if !reflect.DeepEqual(cfg, tst.out) { + t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out) + } + }) } } @@ -113,27 +119,29 @@ func TestDSNParserInvalid(t *testing.T) { func TestDSNReformat(t *testing.T) { for i, tst := range testDSNs { - dsn1 := tst.in - cfg1, err := ParseDSN(dsn1) - if err != nil { - t.Error(err.Error()) - continue - } - cfg1.TLS = nil // pointer not static - res1 := fmt.Sprintf("%+v", cfg1) - - dsn2 := cfg1.FormatDSN() - cfg2, err := ParseDSN(dsn2) - if err != nil { - t.Error(err.Error()) - continue - } - cfg2.TLS = nil // pointer not static - res2 := fmt.Sprintf("%+v", cfg2) + t.Run(tst.in, func(t *testing.T) { + dsn1 := tst.in + cfg1, err := ParseDSN(dsn1) + if err != nil { + t.Error(err.Error()) + return + } + cfg1.TLS = nil // pointer not static + res1 := fmt.Sprintf("%+v", cfg1) - if res1 != res2 { - t.Errorf("%d. %q does not match %q", i, res2, res1) - } + dsn2 := cfg1.FormatDSN() + cfg2, err := ParseDSN(dsn2) + if err != nil { + t.Error(err.Error()) + return + } + cfg2.TLS = nil // pointer not static + res2 := fmt.Sprintf("%+v", cfg2) + + if res1 != res2 { + t.Errorf("%d. %q does not match %q", i, res2, res1) + } + }) } }