Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to change (or disable) the default driver name for registration #1499

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions benchmark_test.go
Expand Up @@ -48,7 +48,7 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt {

func initDB(b *testing.B, queries ...string) *sql.DB {
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
db := tb.checkDB(sql.Open(driverNameTest, dsn))
for _, query := range queries {
if _, err := db.Exec(query); err != nil {
b.Fatalf("error on %q: %v", query, err)
Expand Down Expand Up @@ -105,7 +105,7 @@ func BenchmarkExec(b *testing.B) {
tb := (*TB)(b)
b.StopTimer()
b.ReportAllocs()
db := tb.checkDB(sql.Open("mysql", dsn))
db := tb.checkDB(sql.Open(driverNameTest, dsn))
db.SetMaxIdleConns(concurrencyLevel)
defer db.Close()

Expand Down Expand Up @@ -151,7 +151,7 @@ func BenchmarkRoundtripTxt(b *testing.B) {
sampleString := string(sample)
b.ReportAllocs()
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
db := tb.checkDB(sql.Open(driverNameTest, dsn))
defer db.Close()
b.StartTimer()
var result string
Expand Down Expand Up @@ -184,7 +184,7 @@ func BenchmarkRoundtripBin(b *testing.B) {
sample, min, max := initRoundtripBenchmarks()
b.ReportAllocs()
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
db := tb.checkDB(sql.Open(driverNameTest, dsn))
defer db.Close()
stmt := tb.checkStmt(db.Prepare("SELECT ?"))
defer stmt.Close()
Expand Down
8 changes: 7 additions & 1 deletion driver.go
Expand Up @@ -90,8 +90,14 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
return c.Connect(context.Background())
}

// This variable can be replaced with -ldflags like below:
// go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName=custom"
var driverName = "mysql"

func init() {
sql.Register("mysql", &MySQLDriver{})
if driverName != "" {
sql.Register(driverName, &MySQLDriver{})
}
}

// NewConnector returns new driver.Connector.
Expand Down
28 changes: 19 additions & 9 deletions driver_test.go
Expand Up @@ -31,6 +31,16 @@ import (
"time"
)

// This variable can be replaced with -ldflags like below:
// go test "-ldflags=-X github.com/go-sql-driver/mysql.driverNameTest=custom"
var driverNameTest string

func init() {
if driverNameTest == "" {
driverNameTest = driverName
}
}

// Ensure that all the driver interfaces are implemented
var (
_ driver.Rows = &binaryRows{}
Expand Down Expand Up @@ -111,7 +121,7 @@ func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBT
dsn += "&multiStatements=true"
var db *sql.DB
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
db, err = sql.Open("mysql", dsn)
db, err = sql.Open(driverNameTest, dsn)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand All @@ -130,7 +140,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
t.Skipf("MySQL server not running on %s", netAddr)
}

db, err := sql.Open("mysql", dsn)
db, err := sql.Open(driverNameTest, dsn)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand All @@ -141,7 +151,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
dsn2 := dsn + "&interpolateParams=true"
var db2 *sql.DB
if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation {
db2, err = sql.Open("mysql", dsn2)
db2, err = sql.Open(driverNameTest, dsn2)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand Down Expand Up @@ -1917,7 +1927,7 @@ func testDialError(t *testing.T, dialErr error, expectErr error) {
return nil, dialErr
})

db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand Down Expand Up @@ -1956,7 +1966,7 @@ func TestCustomDial(t *testing.T) {
return d.DialContext(ctx, prot, addr)
})

db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand Down Expand Up @@ -2054,7 +2064,7 @@ func TestUnixSocketAuthFail(t *testing.T) {
}
t.Logf("socket: %s", socket)
badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname)
db, err := sql.Open("mysql", badDSN)
db, err := sql.Open(driverNameTest, badDSN)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand Down Expand Up @@ -2243,7 +2253,7 @@ func TestEmptyPassword(t *testing.T) {
}

dsn := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, "", netAddr, dbname)
db, err := sql.Open("mysql", dsn)
db, err := sql.Open(driverNameTest, dsn)
if err == nil {
defer db.Close()
err = db.Ping()
Expand Down Expand Up @@ -3210,7 +3220,7 @@ func TestConnectorObeysDialTimeouts(t *testing.T) {
return d.DialContext(ctx, prot, addr)
})

db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname))
db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname))
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand Down Expand Up @@ -3375,7 +3385,7 @@ func TestConnectionAttributes(t *testing.T) {

var db *sql.DB
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
db, err = sql.Open("mysql", dsn)
db, err = sql.Open(driverNameTest, dsn)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand Down