From e5b8b965fc59a1e0ffaeaf14d85a89bf9f1a7aae Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 16:00:02 +0200 Subject: [PATCH 01/18] Fix accidental octal values --- driver_test.go | 14 +++++++------- utils_test.go | 18 +++++++++--------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/driver_test.go b/driver_test.go index 34b476ed3..6c26c3a28 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1557,7 +1557,7 @@ func TestTimezoneConversion(t *testing.T) { // Insert local time into database (should be converted) usCentral, _ := time.LoadLocation("US/Central") - reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral) + reftime := time.Date(2014, 5, 30, 18, 3, 17, 0, time.UTC).In(usCentral) dbt.mustExec("INSERT INTO test VALUE (?)", reftime) // Retrieve time from DB @@ -2758,12 +2758,12 @@ func TestRowsColumnTypes(t *testing.T) { nfNULL := sql.NullFloat64{Float64: 0.0, Valid: false} nf0 := sql.NullFloat64{Float64: 0.0, Valid: true} nf1337 := sql.NullFloat64{Float64: 13.37, Valid: true} - nt0 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true} - nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true} - nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true} - nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true} - nd1 := NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true} - nd2 := NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true} + nt0 := NullTime{Time: time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC), Valid: true} + nt1 := NullTime{Time: time.Date(2006, 1, 2, 15, 4, 5, 100000000, time.UTC), Valid: true} + nt2 := NullTime{Time: time.Date(2006, 1, 2, 15, 4, 5, 110000000, time.UTC), Valid: true} + nt6 := NullTime{Time: time.Date(2006, 1, 2, 15, 4, 5, 111111000, time.UTC), Valid: true} + nd1 := NullTime{Time: time.Date(2006, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true} + nd2 := NullTime{Time: time.Date(2006, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true} ndNULL := NullTime{Time: time.Time{}, Valid: false} rbNULL := sql.RawBytes(nil) rb0 := sql.RawBytes("0") diff --git a/utils_test.go b/utils_test.go index e3619e7a7..0f634e750 100644 --- a/utils_test.go +++ b/utils_test.go @@ -299,39 +299,39 @@ func TestAppendDateTime(t *testing.T) { str string }{ { - t: time.Date(2020, 05, 30, 0, 0, 0, 0, time.UTC), + t: time.Date(2020, 5, 30, 0, 0, 0, 0, time.UTC), str: "2020-05-30", }, { - t: time.Date(2020, 05, 30, 22, 0, 0, 0, time.UTC), + t: time.Date(2020, 5, 30, 22, 0, 0, 0, time.UTC), str: "2020-05-30 22:00:00", }, { - t: time.Date(2020, 05, 30, 22, 33, 0, 0, time.UTC), + t: time.Date(2020, 5, 30, 22, 33, 0, 0, time.UTC), str: "2020-05-30 22:33:00", }, { - t: time.Date(2020, 05, 30, 22, 33, 44, 0, time.UTC), + t: time.Date(2020, 5, 30, 22, 33, 44, 0, time.UTC), str: "2020-05-30 22:33:44", }, { - t: time.Date(2020, 05, 30, 22, 33, 44, 550000000, time.UTC), + t: time.Date(2020, 5, 30, 22, 33, 44, 550000000, time.UTC), str: "2020-05-30 22:33:44.550000", }, { - t: time.Date(2020, 05, 30, 22, 33, 44, 550000499, time.UTC), + t: time.Date(2020, 5, 30, 22, 33, 44, 550000499, time.UTC), str: "2020-05-30 22:33:44.550000", }, { - t: time.Date(2020, 05, 30, 22, 33, 44, 550000500, time.UTC), + t: time.Date(2020, 5, 30, 22, 33, 44, 550000500, time.UTC), str: "2020-05-30 22:33:44.550001", }, { - t: time.Date(2020, 05, 30, 22, 33, 44, 550000567, time.UTC), + t: time.Date(2020, 5, 30, 22, 33, 44, 550000567, time.UTC), str: "2020-05-30 22:33:44.550001", }, { - t: time.Date(2020, 05, 30, 22, 33, 44, 999999567, time.UTC), + t: time.Date(2020, 5, 30, 22, 33, 44, 999999567, time.UTC), str: "2020-05-30 22:33:45", }, } From 325543c3043617e9df2a5384a37deab71cb27498 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 16:05:54 +0200 Subject: [PATCH 02/18] Fix for-range scope issues --- benchmark_test.go | 13 ++++++++----- dsn_test.go | 1 + utils_test.go | 3 +++ 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 3e25a3bf2..d9f5c3d36 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -276,6 +276,7 @@ func BenchmarkQueryContext(b *testing.B) { ) defer db.Close() for _, p := range []int{1, 2, 3, 4} { + p := p b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { benchmarkQueryContext(b, db, p) }) @@ -312,6 +313,7 @@ func BenchmarkExecContext(b *testing.B) { ) defer db.Close() for _, p := range []int{1, 2, 3, 4} { + p := p b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { benchmarkQueryContext(b, db, p) }) @@ -339,15 +341,16 @@ func BenchmarkQueryRawBytes(b *testing.B) { } } - for _, s := range sizes { - b.Run(fmt.Sprintf("size=%v", s), func(b *testing.B) { + for _, size := range sizes { + size := size + b.Run(fmt.Sprintf("size=%v", size), func(b *testing.B) { db.SetMaxIdleConns(0) db.SetMaxIdleConns(1) b.ReportAllocs() b.ResetTimer() for j := 0; j < b.N; j++ { - rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", s) + rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", size) if err != nil { b.Fatal(err) } @@ -358,8 +361,8 @@ func BenchmarkQueryRawBytes(b *testing.B) { if err != nil { b.Fatal(err) } - if len(buf) != s { - b.Fatalf("size mismatch: expected %v, got %v", s, len(buf)) + if len(buf) != size { + b.Fatalf("size mismatch: expected %v, got %v", size, len(buf)) } nrows++ } diff --git a/dsn_test.go b/dsn_test.go index 89815b341..79994ff71 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -375,6 +375,7 @@ func TestNormalizeTLSConfig(t *testing.T) { defer func() { DeregisterTLSConfig("test_tls_config") }() for _, tc := range tt { + tc := tc t.Run(tc.tlsConfig, func(t *testing.T) { cfg := &Config{ Addr: "myserver:3306", diff --git a/utils_test.go b/utils_test.go index 0f634e750..159bb01a7 100644 --- a/utils_test.go +++ b/utils_test.go @@ -416,7 +416,9 @@ func TestParseDateTime(t *testing.T) { time.UTC, time.FixedZone("test", 8*60*60), } { + loc := loc for _, cc := range cases { + cc := cc t.Run(cc.name+"-"+loc.String(), func(t *testing.T) { var want time.Time if cc.str != sDate0 && cc.str != sDateTime0 { @@ -493,6 +495,7 @@ func TestParseDateTimeFail(t *testing.T) { } for _, cc := range cases { + cc := cc t.Run(cc.name, func(t *testing.T) { got, err := parseDateTime([]byte(cc.str), time.UTC) if err == nil { From 1798d05cae02571358acd8a269af5c37b40311f1 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 16:08:50 +0200 Subject: [PATCH 03/18] Some simplifications --- auth.go | 8 ++--- buffer.go | 2 +- connection.go | 5 +-- dsn.go | 90 ++++++++++++++++++++++++----------------------- dsn_test.go | 2 +- packets.go | 21 ++++------- statement_test.go | 2 +- utils.go | 8 ++--- 8 files changed, 64 insertions(+), 74 deletions(-) diff --git a/auth.go b/auth.go index fec7040d4..aa383d907 100644 --- a/auth.go +++ b/auth.go @@ -136,7 +136,7 @@ func pwHash(password []byte) (result [2]uint32) { // Hash password using insecure pre 4.1 method func scrambleOldPassword(scramble []byte, password string) []byte { - if len(password) == 0 { + if password == "" { return nil } @@ -162,7 +162,7 @@ func scrambleOldPassword(scramble []byte, password string) []byte { // Hash password using 4.1+ method (SHA1) func scramblePassword(scramble []byte, password string) []byte { - if len(password) == 0 { + if password == "" { return nil } @@ -192,7 +192,7 @@ func scramblePassword(scramble []byte, password string) []byte { // Hash password using MySQL 8+ method (SHA256) func scrambleSHA256Password(scramble []byte, password string) []byte { - if len(password) == 0 { + if password == "" { return nil } @@ -271,7 +271,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { return authResp, nil case "sha256_password": - if len(mc.cfg.Passwd) == 0 { + if mc.cfg.Passwd == "" { return []byte{0}, nil } if mc.cfg.tls != nil || mc.cfg.Net == "unix" { diff --git a/buffer.go b/buffer.go index 0774c5c8c..09c9afa61 100644 --- a/buffer.go +++ b/buffer.go @@ -47,7 +47,7 @@ func newBuffer(nc net.Conn) buffer { // this is a delayed flip that simply increases the buffer counter; // the actual flip will be performed the next time we call `buffer.fill` func (b *buffer) flip() { - b.flipcnt += 1 + b.flipcnt++ } // fill reads into the buffer until at least _need_ bytes are in it diff --git a/connection.go b/connection.go index 90aec6439..1a04b832b 100644 --- a/connection.go +++ b/connection.go @@ -80,10 +80,7 @@ func (mc *mysqlConn) handleParams() (err error) { } if cmdSet.Len() > 0 { - err = mc.exec(cmdSet.String()) - if err != nil { - return - } + return mc.exec(cmdSet.String()) } return diff --git a/dsn.go b/dsn.go index 93f3548cb..a3c377684 100644 --- a/dsn.go +++ b/dsn.go @@ -296,62 +296,64 @@ func ParseDSN(dsn string) (cfg *Config, err error) { // Find the last '/' (since the password or the net addr might contain a '/') foundSlash := false for i := len(dsn) - 1; i >= 0; i-- { - if dsn[i] == '/' { - foundSlash = true - var j, k int - - // left part is empty if i <= 0 - if i > 0 { - // [username[:password]@][protocol[(address)]] - // Find the last '@' in dsn[:i] - for j = i; j >= 0; j-- { - if dsn[j] == '@' { - // username[:password] - // Find the first ':' in dsn[:j] - for k = 0; k < j; k++ { - if dsn[k] == ':' { - cfg.Passwd = dsn[k+1 : j] - break - } - } - cfg.User = dsn[:k] - - break - } - } + if dsn[i] != '/' { + continue + } - // [protocol[(address)]] - // Find the first '(' in dsn[j+1:i] - for k = j + 1; k < i; k++ { - if dsn[k] == '(' { - // dsn[i-1] must be == ')' if an address is specified - if dsn[i-1] != ')' { - if strings.ContainsRune(dsn[k+1:i], ')') { - return nil, errInvalidDSNUnescaped - } - return nil, errInvalidDSNAddr + foundSlash = true + var j, k int + + // left part is empty if i <= 0 + if i > 0 { + // [username[:password]@][protocol[(address)]] + // Find the last '@' in dsn[:i] + for j = i; j >= 0; j-- { + if dsn[j] == '@' { + // username[:password] + // Find the first ':' in dsn[:j] + for k = 0; k < j; k++ { + if dsn[k] == ':' { + cfg.Passwd = dsn[k+1 : j] + break } - cfg.Addr = dsn[k+1 : i-1] - break } + cfg.User = dsn[:k] + + break } - cfg.Net = dsn[j+1 : k] } - // dbname[?param1=value1&...¶mN=valueN] - // Find the first '?' in dsn[i+1:] - for j = i + 1; j < len(dsn); j++ { - if dsn[j] == '?' { - if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { - return + // [protocol[(address)]] + // Find the first '(' in dsn[j+1:i] + for k = j + 1; k < i; k++ { + if dsn[k] == '(' { + // dsn[i-1] must be == ')' if an address is specified + if dsn[i-1] != ')' { + if strings.ContainsRune(dsn[k+1:i], ')') { + return nil, errInvalidDSNUnescaped + } + return nil, errInvalidDSNAddr } + cfg.Addr = dsn[k+1 : i-1] break } } - cfg.DBName = dsn[i+1 : j] + cfg.Net = dsn[j+1 : k] + } - break + // dbname[?param1=value1&...¶mN=valueN] + // Find the first '?' in dsn[i+1:] + for j = i + 1; j < len(dsn); j++ { + if dsn[j] == '?' { + if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { + return + } + break + } } + cfg.DBName = dsn[i+1 : j] + + break } if !foundSlash && len(dsn) > 0 { diff --git a/dsn_test.go b/dsn_test.go index 79994ff71..fb2be3318 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -372,7 +372,7 @@ func TestNormalizeTLSConfig(t *testing.T) { } RegisterTLSConfig("test_tls_config", &tls.Config{ServerName: "myServerName"}) - defer func() { DeregisterTLSConfig("test_tls_config") }() + defer DeregisterTLSConfig("test_tls_config") for _, tc := range tt { tc := tc diff --git a/packets.go b/packets.go index 6664e5ae5..11aa8b0e0 100644 --- a/packets.go +++ b/packets.go @@ -1031,14 +1031,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if cap(paramValues)-len(paramValues)-8 >= 0 { paramValues = paramValues[:len(paramValues)+8] - binary.LittleEndian.PutUint64( - paramValues[len(paramValues)-8:], - uint64(v), - ) + binary.LittleEndian.PutUint64(paramValues[len(paramValues)-8:], v) } else { - paramValues = append(paramValues, - uint64ToBytes(uint64(v))..., - ) + paramValues = append(paramValues, uint64ToBytes(v)...) } case float64: @@ -1078,10 +1073,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { uint64(len(v)), ) paramValues = append(paramValues, v...) - } else { - if err := stmt.writeCommandLongData(i, v); err != nil { - return err - } + } else if err = stmt.writeCommandLongData(i, v); err != nil { + return err } continue } @@ -1100,10 +1093,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { uint64(len(v)), ) paramValues = append(paramValues, v...) - } else { - if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { - return err - } + } else if err = stmt.writeCommandLongData(i, []byte(v)); err != nil { + return err } case time.Time: diff --git a/statement_test.go b/statement_test.go index ac6b92de9..2563ece55 100644 --- a/statement_test.go +++ b/statement_test.go @@ -36,7 +36,7 @@ func TestConvertDerivedByteSlice(t *testing.T) { t.Fatal("Byte slice not convertible", err) } - if bytes.Compare(output.([]byte), []byte("value")) != 0 { + if !bytes.Equal(output.([]byte), []byte("value")) { t.Fatalf("Byte slice not converted, got %#v %T", output, output) } } diff --git a/utils.go b/utils.go index b0c6e9ca3..5274dfd5c 100644 --- a/utils.go +++ b/utils.go @@ -56,7 +56,7 @@ var ( // db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") // func RegisterTLSConfig(key string, config *tls.Config) error { - if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" { + if _, isBool := readBool(key); isBool || strings.EqualFold(key, "skip-verify") || strings.EqualFold(key, "preferred") { return fmt.Errorf("key '%s' is reserved", key) } @@ -90,7 +90,7 @@ func getTLSConfigClone(key string) (config *tls.Config) { // Returns the bool value of the input. // The 2nd return value indicates if the input was a valid bool value -func readBool(input string) (value bool, valid bool) { +func readBool(input string) (value, valid bool) { switch input { case "1", "true", "TRUE", "True": return true, true @@ -199,7 +199,7 @@ func parseByteYear(b []byte) (int, error) { return 0, err } year += v * n - n = n / 10 + n /= 10 } return year, nil } @@ -414,7 +414,7 @@ func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) { // start with the date year := binary.LittleEndian.Uint16(src[:2]) pt := year / 100 - p1 = byte(year - 100*uint16(pt)) + p1 = byte(year - 100*pt) p2, p3 = src[2], src[3] dst = append(dst, digits10[pt], digits01[pt], From f52e7355bd2e64bcf694d345a8df8baab0550217 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 16:20:11 +0200 Subject: [PATCH 04/18] Add missing err checks --- auth.go | 32 +++++++---- benchmark_test.go | 12 +++- driver_test.go | 142 ++++++++++++++++++++++++++++++++++++++++++++-- infile.go | 11 ++-- 4 files changed, 175 insertions(+), 22 deletions(-) diff --git a/auth.go b/auth.go index aa383d907..85b621654 100644 --- a/auth.go +++ b/auth.go @@ -15,6 +15,7 @@ import ( "crypto/sha256" "crypto/x509" "encoding/pem" + "errors" "sync" ) @@ -168,19 +169,19 @@ func scramblePassword(scramble []byte, password string) []byte { // stage1Hash = SHA1(password) crypt := sha1.New() - crypt.Write([]byte(password)) + _, _ = crypt.Write([]byte(password)) stage1 := crypt.Sum(nil) // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) // inner Hash crypt.Reset() - crypt.Write(stage1) + _, _ = crypt.Write(stage1) hash := crypt.Sum(nil) // outer Hash crypt.Reset() - crypt.Write(scramble) - crypt.Write(hash) + _, _ = crypt.Write(scramble) + _, _ = crypt.Write(hash) scramble = crypt.Sum(nil) // token = scrambleHash XOR stage1Hash @@ -199,16 +200,16 @@ func scrambleSHA256Password(scramble []byte, password string) []byte { // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) crypt := sha256.New() - crypt.Write([]byte(password)) + _, _ = crypt.Write([]byte(password)) message1 := crypt.Sum(nil) crypt.Reset() - crypt.Write(message1) + _, _ = crypt.Write(message1) message1Hash := crypt.Sum(nil) crypt.Reset() - crypt.Write(message1Hash) - crypt.Write(scramble) + _, _ = crypt.Write(message1Hash) + _, _ = crypt.Write(scramble) message2 := crypt.Sum(nil) for i := range message1 { @@ -365,7 +366,11 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { return err } data[4] = cachingSha2PasswordRequestPublicKey - mc.writePacket(data) + + err = mc.writePacket(data) + if err != nil { + return err + } // parse public key if data, err = mc.readPacket(); err != nil { @@ -373,11 +378,16 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { } block, _ := pem.Decode(data[1:]) - pkix, err := x509.ParsePKIXPublicKey(block.Bytes) + var pkix interface{} + pkix, err = x509.ParsePKIXPublicKey(block.Bytes) if err != nil { return err } - pubKey = pkix.(*rsa.PublicKey) + var ok bool + pubKey, ok = pkix.(*rsa.PublicKey) + if !ok { + return errors.New("invalid public key") + } } // send encrypted password diff --git a/benchmark_test.go b/benchmark_test.go index d9f5c3d36..34cdb6f09 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -41,6 +41,10 @@ func (tb *TB) checkRows(rows *sql.Rows, err error) *sql.Rows { return rows } +func (tb *TB) checkRowsErr(rows *sql.Rows) { + tb.check(rows.Err()) +} + func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { tb.check(err) return stmt @@ -165,6 +169,7 @@ func BenchmarkRoundtripTxt(b *testing.B) { rows.Close() b.Fatalf("crashed") } + tb.checkRowsErr(rows) err := rows.Scan(&result) if err != nil { rows.Close() @@ -200,6 +205,7 @@ func BenchmarkRoundtripBin(b *testing.B) { rows.Close() b.Fatalf("crashed") } + tb.checkRowsErr(rows) err := rows.Scan(&result) if err != nil { rows.Close() @@ -357,8 +363,7 @@ func BenchmarkQueryRawBytes(b *testing.B) { nrows := 0 for rows.Next() { var buf sql.RawBytes - err := rows.Scan(&buf) - if err != nil { + if err = rows.Scan(&buf); err != nil { b.Fatal(err) } if len(buf) != size { @@ -366,6 +371,9 @@ func BenchmarkQueryRawBytes(b *testing.B) { } nrows++ } + if err = rows.Err(); rows != nil { + b.Fatal(err) + } rows.Close() if nrows != 100 { b.Fatalf("numbers of rows mismatch: expected %v, got %v", 100, nrows) diff --git a/driver_test.go b/driver_test.go index 6c26c3a28..32f563127 100644 --- a/driver_test.go +++ b/driver_test.go @@ -218,6 +218,9 @@ func TestEmptyQuery(t *testing.T) { if rows.Next() { dbt.Errorf("next on rows must be false") } + if err := rows.Err(); err != nil { + dbt.Error(err) + } }) } @@ -232,6 +235,9 @@ func TestCRUD(t *testing.T) { if rows.Next() { dbt.Error("unexpected data in empty table") } + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } rows.Close() // Create Data @@ -266,6 +272,9 @@ func TestCRUD(t *testing.T) { } else { dbt.Error("no data") } + if err = rows.Err(); err != nil { + dbt.Fatal(err) + } rows.Close() // Update @@ -292,6 +301,9 @@ func TestCRUD(t *testing.T) { } else { dbt.Error("no data") } + if err = rows.Err(); err != nil { + dbt.Fatal(err) + } rows.Close() // Delete @@ -356,6 +368,9 @@ func TestMultiQuery(t *testing.T) { } else { dbt.Error("no data") } + if err = rows.Err(); err != nil { + dbt.Fatal(err) + } rows.Close() }) @@ -383,6 +398,9 @@ func TestInt(t *testing.T) { } else { dbt.Errorf("%s: no data", v) } + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } rows.Close() dbt.mustExec("DROP TABLE IF EXISTS test") @@ -403,6 +421,9 @@ func TestInt(t *testing.T) { } else { dbt.Errorf("%s ZEROFILL: no data", v) } + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } rows.Close() dbt.mustExec("DROP TABLE IF EXISTS test") @@ -428,6 +449,9 @@ func TestFloat32(t *testing.T) { } else { dbt.Errorf("%s: no data", v) } + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } rows.Close() dbt.mustExec("DROP TABLE IF EXISTS test") } @@ -452,6 +476,9 @@ func TestFloat64(t *testing.T) { } else { dbt.Errorf("%s: no data", v) } + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } rows.Close() dbt.mustExec("DROP TABLE IF EXISTS test") } @@ -476,6 +503,9 @@ func TestFloat64Placeholder(t *testing.T) { } else { dbt.Errorf("%s: no data", v) } + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } rows.Close() dbt.mustExec("DROP TABLE IF EXISTS test") } @@ -503,6 +533,9 @@ func TestString(t *testing.T) { } else { dbt.Errorf("%s: no data", v) } + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } rows.Close() dbt.mustExec("DROP TABLE IF EXISTS test") @@ -557,6 +590,9 @@ func TestRawBytes(t *testing.T) { } else { dbt.Errorf("no data") } + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } }) } @@ -580,6 +616,9 @@ func TestRawMessage(t *testing.T) { } else { dbt.Errorf("no data") } + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } }) } @@ -608,6 +647,9 @@ func TestValuer(t *testing.T) { } else { dbt.Errorf("Valuer: no data") } + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } rows.Close() dbt.mustExec("DROP TABLE IF EXISTS test") @@ -646,6 +688,9 @@ func TestValuerWithValidation(t *testing.T) { } else { dbt.Errorf("Valuer: no data") } + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", testValuerWithValidation{""}); err == nil { dbt.Errorf("Failed to check valuer error") @@ -850,11 +895,17 @@ func TestDateTime(t *testing.T) { rows, err = dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`) if err == nil { rows.Scan(µsecsSupported) + if err = rows.Err(); err != nil { + dbt.Fatal(err) + } rows.Close() } rows, err = dbt.db.Query(`SELECT cast("0000-00-00" as DATE) = "0000-00-00"`) if err == nil { rows.Scan(&zeroDateSupported) + if err = rows.Err(); err != nil { + dbt.Fatal(err) + } rows.Close() } for _, setups := range testcases { @@ -901,6 +952,9 @@ func TestTimestampMicros(t *testing.T) { microsecsSupported := false if rows, err := dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`); err == nil { rows.Scan(µsecsSupported) + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } rows.Close() } if !microsecsSupported { @@ -925,6 +979,9 @@ func TestTimestampMicros(t *testing.T) { if !rows.Next() { dbt.Errorf("test contained no selectable values") } + if err = rows.Err(); err != nil { + dbt.Fatal(err) + } err = rows.Scan(&res0, &res1, &res6) if err != nil { dbt.Error(err) @@ -1089,6 +1146,9 @@ func TestNULL(t *testing.T) { } else { dbt.Error("no data") } + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } }) } @@ -1172,6 +1232,9 @@ func TestLongData(t *testing.T) { } else { dbt.Fatalf("LONGBLOB: no data") } + if err = rows.Err(); err != nil { + dbt.Fatal(err) + } // Empty table dbt.mustExec("TRUNCATE TABLE test") @@ -1369,6 +1432,9 @@ func TestTLS(t *testing.T) { dbt.Logf("Cipher: %s", *value) } } + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } } tlsTestOpt := func(dbt *DBTest) { if err := dbt.db.Ping(); err != nil { @@ -1508,6 +1574,9 @@ func TestCollation(t *testing.T) { func TestColumnsWithAlias(t *testing.T) { runTests(t, dsn+"&columnsWithAlias=true", func(dbt *DBTest) { rows := dbt.mustQuery("SELECT 1 AS A") + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } defer rows.Close() cols, _ := rows.Columns() if len(cols) != 1 { @@ -1518,6 +1587,9 @@ func TestColumnsWithAlias(t *testing.T) { } rows = dbt.mustQuery("SELECT * FROM (SELECT 1 AS one) AS A") + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } defer rows.Close() cols, _ = rows.Columns() if len(cols) != 1 { @@ -1539,6 +1611,9 @@ func TestRawBytesResultExceedsBuffer(t *testing.T) { if !rows.Next() { dbt.Error("expected result, got none") } + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } var result sql.RawBytes rows.Scan(&result) if expected != string(result) { @@ -1566,6 +1641,9 @@ func TestTimezoneConversion(t *testing.T) { if !rows.Next() { dbt.Fatal("did not get any rows out") } + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } var dbTime time.Time err := rows.Scan(&dbTime) @@ -1786,6 +1864,9 @@ func TestPreparedManyCols(t *testing.T) { if err != nil { dbt.Fatal(err) } + if err = rows.Err(); err != nil { + dbt.Fatal(err) + } rows.Close() // Create 0byte string which we can't send via STMT_LONG_DATA. @@ -1796,6 +1877,9 @@ func TestPreparedManyCols(t *testing.T) { if err != nil { dbt.Fatal(err) } + if err = rows.Err(); err != nil { + dbt.Fatal(err) + } rows.Close() }) } @@ -2061,6 +2145,9 @@ func TestInterruptBySignal(t *testing.T) { dbt.Errorf("expected val to be 42") } } + if err = rows.Err(); rows != nil { + dbt.Fatal(err) + } rows.Close() // binary protocol @@ -2075,6 +2162,9 @@ func TestInterruptBySignal(t *testing.T) { dbt.Errorf("expected val to be 42") } } + if err = rows.Err(); rows != nil { + dbt.Fatal(err) + } rows.Close() }) } @@ -2246,6 +2336,9 @@ func TestMultiResultSet(t *testing.T) { } res1.values = append(res1.values, res[:]) } + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } cols, err := rows.Columns() if err != nil { @@ -2281,6 +2374,9 @@ func TestMultiResultSet(t *testing.T) { } res2.values = append(res2.values, res[:]) } + if err = rows.Err(); err != nil { + dbt.Fatal(err) + } if !reflect.DeepEqual(expected[1], res2) { dbt.Error(desc, "want =", expected[1], "got =", res2) @@ -2301,6 +2397,9 @@ func TestMultiResultSet(t *testing.T) { DO 1; SELECT 0 UNION SELECT 1; SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`) + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } defer rows.Close() checkRows("query: ", rows, dbt) }) @@ -2346,6 +2445,10 @@ func TestMultiResultSet(t *testing.T) { dbt.Fatalf("%v (i=%d) (j=%d)", err, i, j) } checkRows(fmt.Sprintf("prepared stmt query (i=%d) (j=%d): ", i, j), rows, dbt) + if err := rows.Err(); err != nil { + dbt.Fatal(err) + } + rows.Close() } } }) @@ -2457,7 +2560,14 @@ func TestContextCancelQuery(t *testing.T) { // This query will be canceled. startTime := time.Now() - if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + rows, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + if rows != nil { + if err = rows.Err(); err != nil { + dbt.Fatal(err) + } + rows.Close() + } + if err != context.Canceled { dbt.Errorf("expected context.Canceled, got %v", err) } if d := time.Since(startTime); d > 500*time.Millisecond { @@ -2477,7 +2587,14 @@ func TestContextCancelQuery(t *testing.T) { } // Context is already canceled, so error should come before execution. - if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (1)"); err != context.Canceled { + rows, err = dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (1)") + if rows != nil { + if err = rows.Err(); err != nil { + dbt.Fatal(err) + } + rows.Close() + } + if err != context.Canceled { dbt.Errorf("expected context.Canceled, got %v", err) } @@ -2528,7 +2645,11 @@ func TestContextCancelPrepare(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { ctx, cancel := context.WithCancel(context.Background()) cancel() - if _, err := dbt.db.PrepareContext(ctx, "SELECT 1"); err != context.Canceled { + rows, err := dbt.db.PrepareContext(ctx, "SELECT 1") + if rows != nil { + rows.Close() + } + if err != context.Canceled { dbt.Errorf("expected context.Canceled, got %v", err) } }) @@ -2583,7 +2704,14 @@ func TestContextCancelStmtQuery(t *testing.T) { // This query will be canceled. startTime := time.Now() - if _, err := stmt.QueryContext(ctx); err != context.Canceled { + rows, err := stmt.QueryContext(ctx) + if rows != nil { + if err = rows.Err(); err != nil { + dbt.Fatal(err) + } + rows.Close() + } + if err != context.Canceled { dbt.Errorf("expected context.Canceled, got %v", err) } if d := time.Since(startTime); d > 500*time.Millisecond { @@ -2963,6 +3091,9 @@ func TestRowsColumnTypes(t *testing.T) { if i != 3 { t.Errorf("expected 3 rows, got %d", i) } + if err = rows.Err(); rows != nil { + dbt.Fatal(err) + } if err := rows.Close(); err != nil { t.Errorf("error closing rows: %s", err) @@ -3028,6 +3159,9 @@ func TestRawBytesAreNotModified(t *testing.T) { t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) } } + if err = rows.Err(); rows != nil { + dbt.Fatal(err) + } rows.Close() }() } diff --git a/infile.go b/infile.go index 60effdfc2..68f5d0166 100644 --- a/infile.go +++ b/infile.go @@ -172,11 +172,12 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { return ioErr } - // read OK packet - if err == nil { - return mc.readResultOK() + if err != nil { + // we already have an error, ignore return values + _, _ = mc.readPacket() + return err } - mc.readPacket() - return err + // read OK packet + return mc.readResultOK() } From cf8b3e552b06f1b79352805c405bc1a65dd740f6 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 16:21:54 +0200 Subject: [PATCH 05/18] Add missing return names --- benchmark_test.go | 2 +- packets.go | 2 +- rows.go | 2 +- utils.go | 8 ++++---- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 34cdb6f09..f49a2544e 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -141,7 +141,7 @@ func BenchmarkExec(b *testing.B) { // data, but no db writes var roundtripSample []byte -func initRoundtripBenchmarks() ([]byte, int, int) { +func initRoundtripBenchmarks() (sample []byte, min, max int) { if roundtripSample == nil { roundtripSample = []byte(strings.Repeat("0123456789abcdef", 1024*1024)) } diff --git a/packets.go b/packets.go index 11aa8b0e0..9c6d1787f 100644 --- a/packets.go +++ b/packets.go @@ -486,7 +486,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { * Result Packets * ******************************************************************************/ -func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { +func (mc *mysqlConn) readAuthResult() (authData []byte, plugin string, err error) { data, err := mc.readPacket() if err != nil { return nil, "", err diff --git a/rows.go b/rows.go index 888bdb5f0..48c99e90b 100644 --- a/rows.go +++ b/rows.go @@ -71,7 +71,7 @@ func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) { return rows.rs.columns[i].flags&flagNotNULL == 0, true } -func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) { +func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (precision, scale int64, ok bool) { column := rows.rs.columns[i] decimals := int64(column.decimals) diff --git a/utils.go b/utils.go index 5274dfd5c..dcb095f02 100644 --- a/utils.go +++ b/utils.go @@ -533,10 +533,10 @@ func stringToInt(b []byte) int { return val } -// returns the string read as a bytes slice, wheter the value is NULL, +// returns the string read as a bytes slice, whether the value is NULL, // the number of bytes read and an error, in case the string is longer than // the input slice -func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { +func readLengthEncodedString(b []byte) (str []byte, isNull bool, n int, err error) { // Get length num, isNull, n := readLengthEncodedInteger(b) if num < 1 { @@ -554,7 +554,7 @@ func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { // returns the number of bytes skipped and an error, in case the string is // longer than the input slice -func skipLengthEncodedString(b []byte) (int, error) { +func skipLengthEncodedString(b []byte) (n int, err error) { // Get length num, _, n := readLengthEncodedInteger(b) if num < 1 { @@ -571,7 +571,7 @@ func skipLengthEncodedString(b []byte) (int, error) { } // returns the number read, whether the value is NULL and the number of bytes read -func readLengthEncodedInteger(b []byte) (uint64, bool, int) { +func readLengthEncodedInteger(b []byte) (value uint64, isNull bool, n int) { // See issue #349 if len(b) == 0 { return 0, true, 1 From ad5cc3558c6237f2096d748ba70c4acdbb99f513 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 16:24:06 +0200 Subject: [PATCH 06/18] Fix bare returns --- auth.go | 2 +- connection.go | 5 ++--- dsn.go | 4 ++-- packets_test.go | 4 ++-- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/auth.go b/auth.go index 85b621654..36b9df577 100644 --- a/auth.go +++ b/auth.go @@ -132,7 +132,7 @@ func pwHash(password []byte) (result [2]uint32) { result[0] &= 0x7FFFFFFF result[1] &= 0x7FFFFFFF - return + return result } // Hash password using insecure pre 4.1 method diff --git a/connection.go b/connection.go index 1a04b832b..9ed1fcd7b 100644 --- a/connection.go +++ b/connection.go @@ -61,7 +61,7 @@ func (mc *mysqlConn) handleParams() (err error) { } } if err != nil { - return + return err } // Other system vars accumulated in a single SET command @@ -82,8 +82,7 @@ func (mc *mysqlConn) handleParams() (err error) { if cmdSet.Len() > 0 { return mc.exec(cmdSet.String()) } - - return + return err } func (mc *mysqlConn) markBadConn(err error) error { diff --git a/dsn.go b/dsn.go index a3c377684..64a508786 100644 --- a/dsn.go +++ b/dsn.go @@ -363,7 +363,7 @@ func ParseDSN(dsn string) (cfg *Config, err error) { if err = cfg.normalize(); err != nil { return nil, err } - return + return cfg, nil } // parseDSNParams parses the DSN "query string" @@ -551,7 +551,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { } } - return + return err } func ensureHavePort(addr string) string { diff --git a/packets_test.go b/packets_test.go index b61e4dbf7..a33aa98e1 100644 --- a/packets_test.go +++ b/packets_test.go @@ -50,7 +50,7 @@ func (m *mockConn) Read(b []byte) (n int, err error) { n = copy(b, m.data) m.read += n m.data = m.data[n:] - return + return n, err } func (m *mockConn) Write(b []byte) (n int, err error) { if m.closed { @@ -69,7 +69,7 @@ func (m *mockConn) Write(b []byte) (n int, err error) { m.data = m.queuedReplies[0] m.queuedReplies = m.queuedReplies[1:] } - return + return n, err } func (m *mockConn) Close() error { m.closed = true From 5a1a1ff07c2edacb942bc1a8b81d48545f871b40 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 16:27:35 +0200 Subject: [PATCH 07/18] Move global test env vars to env struct --- benchmark_test.go | 8 +- conncheck_test.go | 2 +- driver_test.go | 224 +++++++++++++++++++++++----------------------- errors_test.go | 2 +- 4 files changed, 118 insertions(+), 118 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index f49a2544e..319ef6e0b 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -52,7 +52,7 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { func initDB(b *testing.B, queries ...string) *sql.DB { tb := (*TB)(b) - db := tb.checkDB(sql.Open("mysql", dsn)) + db := tb.checkDB(sql.Open("mysql", env.dsn)) for _, query := range queries { if _, err := db.Exec(query); err != nil { b.Fatalf("error on %q: %v", query, err) @@ -109,7 +109,7 @@ func BenchmarkExec(b *testing.B) { tb := (*TB)(b) b.StopTimer() b.ReportAllocs() - db := tb.checkDB(sql.Open("mysql", dsn)) + db := tb.checkDB(sql.Open("mysql", env.dsn)) db.SetMaxIdleConns(concurrencyLevel) defer db.Close() @@ -154,7 +154,7 @@ func BenchmarkRoundtripTxt(b *testing.B) { sampleString := string(sample) b.ReportAllocs() tb := (*TB)(b) - db := tb.checkDB(sql.Open("mysql", dsn)) + db := tb.checkDB(sql.Open("mysql", env.dsn)) defer db.Close() b.StartTimer() var result string @@ -188,7 +188,7 @@ func BenchmarkRoundtripBin(b *testing.B) { sample, min, max := initRoundtripBenchmarks() b.ReportAllocs() tb := (*TB)(b) - db := tb.checkDB(sql.Open("mysql", dsn)) + db := tb.checkDB(sql.Open("mysql", env.dsn)) defer db.Close() stmt := tb.checkStmt(db.Prepare("SELECT ?")) defer stmt.Close() diff --git a/conncheck_test.go b/conncheck_test.go index 53995517b..9fbf81140 100644 --- a/conncheck_test.go +++ b/conncheck_test.go @@ -16,7 +16,7 @@ import ( ) func TestStaleConnectionChecks(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { dbt.mustExec("SET @@SESSION.wait_timeout = 2") if err := dbt.db.Ping(); err != nil { diff --git a/driver_test.go b/driver_test.go index 32f563127..1535fb1a1 100644 --- a/driver_test.go +++ b/driver_test.go @@ -37,7 +37,7 @@ var ( _ driver.Rows = &textRows{} ) -var ( +var env struct { user string pass string prot string @@ -46,7 +46,7 @@ var ( dsn string netAddr string available bool -) +} var ( tDate = time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC) @@ -61,22 +61,22 @@ var ( // See https://github.com/go-sql-driver/mysql/wiki/Testing func init() { // get environment variables - env := func(key, defaultValue string) string { + getEnv := func(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value } return defaultValue } - user = env("MYSQL_TEST_USER", "root") - pass = env("MYSQL_TEST_PASS", "") - prot = env("MYSQL_TEST_PROT", "tcp") - addr = env("MYSQL_TEST_ADDR", "localhost:3306") - dbname = env("MYSQL_TEST_DBNAME", "gotest") - netAddr = fmt.Sprintf("%s(%s)", prot, addr) - dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, dbname) - c, err := net.Dial(prot, addr) + env.user = getEnv("MYSQL_TEST_USER", "root") + env.pass = getEnv("MYSQL_TEST_PASS", "") + env.prot = getEnv("MYSQL_TEST_PROT", "tcp") + env.addr = getEnv("MYSQL_TEST_ADDR", "localhost:3306") + env.dbname = getEnv("MYSQL_TEST_DBNAME", "gotest") + env.netAddr = fmt.Sprintf("%s(%s)", env.prot, env.addr) + env.dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s", env.user, env.pass, env.netAddr, env.dbname) + c, err := net.Dial(env.prot, env.addr) if err == nil { - available = true + env.available = true c.Close() } } @@ -104,8 +104,8 @@ func (e netErrorMock) Error() string { } func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { - if !available { - t.Skipf("MySQL server not running on %s", netAddr) + if !env.available { + t.Skipf("MySQL server not running on %s", env.netAddr) } dsn += "&multiStatements=true" @@ -126,8 +126,8 @@ func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBT } func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { - if !available { - t.Skipf("MySQL server not running on %s", netAddr) + if !env.available { + t.Skipf("MySQL server not running on %s", env.netAddr) } db, err := sql.Open("mysql", dsn) @@ -210,7 +210,7 @@ func maybeSkip(t *testing.T, err error, skipErrno uint16) { } func TestEmptyQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { // just a comment, no query rows := dbt.mustQuery("--") defer rows.Close() @@ -225,7 +225,7 @@ func TestEmptyQuery(t *testing.T) { } func TestCRUD(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { // Create Table dbt.mustExec("CREATE TABLE test (value BOOL)") @@ -329,7 +329,7 @@ func TestCRUD(t *testing.T) { } func TestMultiQuery(t *testing.T) { - runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + runTestsWithMultiStatement(t, env.dsn, func(dbt *DBTest) { // Create Table dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ") @@ -377,7 +377,7 @@ func TestMultiQuery(t *testing.T) { } func TestInt(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"} in := int64(42) var out int64 @@ -432,7 +432,7 @@ func TestInt(t *testing.T) { } func TestFloat32(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { types := [2]string{"FLOAT", "DOUBLE"} in := float32(42.23) var out float32 @@ -459,7 +459,7 @@ func TestFloat32(t *testing.T) { } func TestFloat64(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { types := [2]string{"FLOAT", "DOUBLE"} var expected float64 = 42.23 var out float64 @@ -486,7 +486,7 @@ func TestFloat64(t *testing.T) { } func TestFloat64Placeholder(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { types := [2]string{"FLOAT", "DOUBLE"} var expected float64 = 42.23 var out float64 @@ -513,7 +513,7 @@ func TestFloat64Placeholder(t *testing.T) { } func TestString(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"} in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย" var out string @@ -565,7 +565,7 @@ func TestString(t *testing.T) { } func TestRawBytes(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { v1 := []byte("aaa") v2 := []byte("bbb") rows := dbt.mustQuery("SELECT ?, ?", v1, v2) @@ -597,7 +597,7 @@ func TestRawBytes(t *testing.T) { } func TestRawMessage(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { v1 := json.RawMessage("{}") v2 := json.RawMessage("[]") rows := dbt.mustQuery("SELECT ?, ?", v1, v2) @@ -631,7 +631,7 @@ func (tv testValuer) Value() (driver.Value, error) { } func TestValuer(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { in := testValuer{"a_value"} var out string var rows *sql.Rows @@ -669,7 +669,7 @@ func (tv testValuerWithValidation) Value() (driver.Value, error) { } func TestValuerWithValidation(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { in := testValuerWithValidation{"a_value"} var out string var rows *sql.Rows @@ -883,8 +883,8 @@ func TestDateTime(t *testing.T) { }}, } dsns := []string{ - dsn + "&parseTime=true", - dsn + "&parseTime=false", + env.dsn + "&parseTime=true", + env.dsn + "&parseTime=false", } for _, testdsn := range dsns { runTests(t, testdsn, func(dbt *DBTest) { @@ -944,7 +944,7 @@ func TestTimestampMicros(t *testing.T) { f0 := format[:19] f1 := format[:21] f6 := format[:26] - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { // check if microseconds are supported. // Do not use timestamp(x) for that check - before 5.5.6, x would mean display width // and not precision. @@ -999,7 +999,7 @@ func TestTimestampMicros(t *testing.T) { } func TestNULL(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { nullStmt, err := dbt.db.Prepare("SELECT NULL") if err != nil { dbt.Fatal(err) @@ -1163,7 +1163,7 @@ func TestUint64(t *testing.T) { shigh = int64(uhigh) stop = ^shigh ) - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { stmt, err := dbt.db.Prepare(`SELECT ?, ?, ? ,?, ?, ?, ?, ?`) if err != nil { dbt.Fatal(err) @@ -1196,7 +1196,7 @@ func TestUint64(t *testing.T) { } func TestLongData(t *testing.T) { - runTests(t, dsn+"&maxAllowedPacket=0", func(dbt *DBTest) { + runTests(t, env.dsn+"&maxAllowedPacket=0", func(dbt *DBTest) { var maxAllowedPacketSize int err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize) if err != nil { @@ -1262,7 +1262,7 @@ func TestLongData(t *testing.T) { } func TestLoadData(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { verifyLoadDataResult := func() { rows, err := dbt.db.Query("SELECT * FROM test") if err != nil { @@ -1363,7 +1363,7 @@ func TestLoadData(t *testing.T) { } func TestFoundRows(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") @@ -1384,7 +1384,7 @@ func TestFoundRows(t *testing.T) { dbt.Fatalf("Expected 2 affected rows, got %d", count) } }) - runTests(t, dsn+"&clientFoundRows=true", func(dbt *DBTest) { + runTests(t, env.dsn+"&clientFoundRows=true", func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") @@ -1442,24 +1442,24 @@ func TestTLS(t *testing.T) { } } - runTests(t, dsn+"&tls=preferred", tlsTestOpt) - runTests(t, dsn+"&tls=skip-verify", tlsTestReq) + runTests(t, env.dsn+"&tls=preferred", tlsTestOpt) + runTests(t, env.dsn+"&tls=skip-verify", tlsTestReq) // Verify that registering / using a custom cfg works RegisterTLSConfig("custom-skip-verify", &tls.Config{ InsecureSkipVerify: true, }) - runTests(t, dsn+"&tls=custom-skip-verify", tlsTestReq) + runTests(t, env.dsn+"&tls=custom-skip-verify", tlsTestReq) } func TestReuseClosedConnection(t *testing.T) { // this test does not use sql.database, it uses the driver directly - if !available { - t.Skipf("MySQL server not running on %s", netAddr) + if !env.available { + t.Skipf("MySQL server not running on %s", env.netAddr) } md := &MySQLDriver{} - conn, err := md.Open(dsn) + conn, err := md.Open(env.dsn) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -1489,12 +1489,12 @@ func TestReuseClosedConnection(t *testing.T) { } func TestCharset(t *testing.T) { - if !available { - t.Skipf("MySQL server not running on %s", netAddr) + if !env.available { + t.Skipf("MySQL server not running on %s", env.netAddr) } mustSetCharset := func(charsetParam, expected string) { - runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) { + runTests(t, env.dsn+"&"+charsetParam, func(dbt *DBTest) { rows := dbt.mustQuery("SELECT @@character_set_connection") defer rows.Close() @@ -1523,7 +1523,7 @@ func TestCharset(t *testing.T) { } func TestFailingCharset(t *testing.T) { - runTests(t, dsn+"&charset=none", func(dbt *DBTest) { + runTests(t, env.dsn+"&charset=none", func(dbt *DBTest) { // run query to really establish connection... _, err := dbt.db.Exec("SELECT 1") if err == nil { @@ -1534,8 +1534,8 @@ func TestFailingCharset(t *testing.T) { } func TestCollation(t *testing.T) { - if !available { - t.Skipf("MySQL server not running on %s", netAddr) + if !env.available { + t.Skipf("MySQL server not running on %s", env.netAddr) } defaultCollation := "utf8mb4_general_ci" @@ -1551,10 +1551,10 @@ func TestCollation(t *testing.T) { for _, collation := range testCollations { var expected, tdsn string if collation != "" { - tdsn = dsn + "&collation=" + collation + tdsn = env.dsn + "&collation=" + collation expected = collation } else { - tdsn = dsn + tdsn = env.dsn expected = defaultCollation } @@ -1572,7 +1572,7 @@ func TestCollation(t *testing.T) { } func TestColumnsWithAlias(t *testing.T) { - runTests(t, dsn+"&columnsWithAlias=true", func(dbt *DBTest) { + runTests(t, env.dsn+"&columnsWithAlias=true", func(dbt *DBTest) { rows := dbt.mustQuery("SELECT 1 AS A") if err := rows.Err(); err != nil { dbt.Fatal(err) @@ -1602,7 +1602,7 @@ func TestColumnsWithAlias(t *testing.T) { } func TestRawBytesResultExceedsBuffer(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { // defaultBufSize from buffer.go expected := strings.Repeat("abc", defaultBufSize) @@ -1660,14 +1660,14 @@ func TestTimezoneConversion(t *testing.T) { } for _, tz := range zones { - runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest) + runTests(t, env.dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest) } } // Special cases func TestRowsClose(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { rows, err := dbt.db.Query("SELECT 1") if err != nil { dbt.Fatal(err) @@ -1692,7 +1692,7 @@ func TestRowsClose(t *testing.T) { // dangling statements // http://code.google.com/p/go/issues/detail?id=3865 func TestCloseStmtBeforeRows(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { stmt, err := dbt.db.Prepare("SELECT 1") if err != nil { dbt.Fatal(err) @@ -1733,7 +1733,7 @@ func TestCloseStmtBeforeRows(t *testing.T) { // It is valid to have multiple Rows for the same Stmt // http://code.google.com/p/go/issues/detail?id=3734 func TestStmtMultiRows(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0") if err != nil { dbt.Fatal(err) @@ -1849,7 +1849,7 @@ func TestStmtMultiRows(t *testing.T) { // * parameters * 64 > max_allowed_packet (issue 734) func TestPreparedManyCols(t *testing.T) { numParams := 65535 - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { query := "SELECT ?" + strings.Repeat(",?", numParams-1) stmt, err := dbt.db.Prepare(query) if err != nil { @@ -1889,7 +1889,7 @@ func TestConcurrent(t *testing.T) { t.Skip("MYSQL_TEST_CONCURRENT env var not set") } - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { var max int err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max) if err != nil { @@ -1953,12 +1953,12 @@ func TestConcurrent(t *testing.T) { }) } -func testDialError(t *testing.T, dialErr error, expectErr error) { +func testDialError(t *testing.T, dialErr, expectErr error) { RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) { return nil, dialErr }) - db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", env.user, env.pass, env.addr, env.dbname)) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -1987,17 +1987,17 @@ func TestDialTemporaryNetErr(t *testing.T) { // Tests custom dial functions func TestCustomDial(t *testing.T) { - if !available { - t.Skipf("MySQL server not running on %s", netAddr) + if !env.available { + t.Skipf("MySQL server not running on %s", env.netAddr) } // our custom dial function which justs wraps net.Dial here RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) { var d net.Dialer - return d.DialContext(ctx, prot, addr) + return d.DialContext(ctx, env.prot, env.addr) }) - db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", env.user, env.pass, env.addr, env.dbname)) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -2029,8 +2029,8 @@ func TestSQLInjection(t *testing.T) { } dsns := []string{ - dsn, - dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'", + env.dsn, + env.dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'", } for _, testdsn := range dsns { runTests(t, testdsn, createTest("1 OR 1=1")) @@ -2059,8 +2059,8 @@ func TestInsertRetrieveEscapedData(t *testing.T) { } dsns := []string{ - dsn, - dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'", + env.dsn, + env.dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'", } for _, testdsn := range dsns { runTests(t, testdsn, testData) @@ -2068,7 +2068,7 @@ func TestInsertRetrieveEscapedData(t *testing.T) { } func TestUnixSocketAuthFail(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { // Save the current logger so we can restore it. oldLogger := errLog @@ -2082,10 +2082,10 @@ func TestUnixSocketAuthFail(t *testing.T) { // Make a new DSN that uses the MySQL socket file and a bad password, which // we can make by simply appending any character to the real password. - badPass := pass + "x" + badPass := env.pass + "x" socket := "" - if prot == "unix" { - socket = addr + if env.prot == "unix" { + socket = env.addr } else { // Get socket file from MySQL. err := dbt.db.QueryRow("SELECT @@socket").Scan(&socket) @@ -2094,7 +2094,7 @@ func TestUnixSocketAuthFail(t *testing.T) { } } t.Logf("socket: %s", socket) - badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname) + badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", env.user, badPass, socket, env.dbname) db, err := sql.Open("mysql", badDSN) if err != nil { t.Fatalf("error connecting: %s", err.Error()) @@ -2116,7 +2116,7 @@ func TestUnixSocketAuthFail(t *testing.T) { // See Issue #422 func TestInterruptBySignal(t *testing.T) { - runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + runTestsWithMultiStatement(t, env.dsn, func(dbt *DBTest) { dbt.mustExec(` DROP PROCEDURE IF EXISTS test_signal; CREATE PROCEDURE test_signal(ret INT) @@ -2203,7 +2203,7 @@ func TestColumnsReusesSlice(t *testing.T) { } func TestRejectReadOnly(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { // Create Table dbt.mustExec("CREATE TABLE test (value BOOL)") // Set the session to read-only. We didn't set the `rejectReadOnly` @@ -2222,7 +2222,7 @@ func TestRejectReadOnly(t *testing.T) { }) // Enable the `rejectReadOnly` option. - runTests(t, dsn+"&rejectReadOnly=true", func(dbt *DBTest) { + runTests(t, env.dsn+"&rejectReadOnly=true", func(dbt *DBTest) { // Create Table dbt.mustExec("CREATE TABLE test (value BOOL)") // Set the session to read only. Any writes after this should error on @@ -2236,7 +2236,7 @@ func TestRejectReadOnly(t *testing.T) { } func TestPing(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { if err := dbt.db.Ping(); err != nil { dbt.fail("Ping", "Ping", err) } @@ -2245,18 +2245,18 @@ func TestPing(t *testing.T) { // See Issue #799 func TestEmptyPassword(t *testing.T) { - if !available { - t.Skipf("MySQL server not running on %s", netAddr) + if !env.available { + t.Skipf("MySQL server not running on %s", env.netAddr) } - dsn := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, "", netAddr, dbname) + dsn := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", env.user, "", env.netAddr, env.dbname) db, err := sql.Open("mysql", dsn) if err == nil { defer db.Close() err = db.Ping() } - if pass == "" { + if env.pass == "" { if err != nil { t.Fatal(err.Error()) } @@ -2391,7 +2391,7 @@ func TestMultiResultSet(t *testing.T) { } } - runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + runTestsWithMultiStatement(t, env.dsn, func(dbt *DBTest) { rows := dbt.mustQuery(`DO 1; SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; DO 1; @@ -2404,7 +2404,7 @@ func TestMultiResultSet(t *testing.T) { checkRows("query: ", rows, dbt) }) - runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + runTestsWithMultiStatement(t, env.dsn, func(dbt *DBTest) { queries := []string{ ` DROP PROCEDURE IF EXISTS test_mrss; @@ -2455,7 +2455,7 @@ func TestMultiResultSet(t *testing.T) { } func TestMultiResultSetNoSelect(t *testing.T) { - runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + runTestsWithMultiStatement(t, env.dsn, func(dbt *DBTest) { rows := dbt.mustQuery("DO 1; DO 2;") defer rows.Close() @@ -2476,7 +2476,7 @@ func TestMultiResultSetNoSelect(t *testing.T) { // tests if rows are set in a proper state if some results were ignored before // calling rows.NextResultSet. func TestSkipResults(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { rows := dbt.mustQuery("SELECT 1, 2") defer rows.Close() @@ -2495,7 +2495,7 @@ func TestSkipResults(t *testing.T) { } func TestPingContext(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { ctx, cancel := context.WithCancel(context.Background()) cancel() if err := dbt.db.PingContext(ctx); err != context.Canceled { @@ -2505,7 +2505,7 @@ func TestPingContext(t *testing.T) { } func TestContextCancelExec(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) @@ -2551,7 +2551,7 @@ func TestContextCancelExec(t *testing.T) { } func TestContextCancelQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) @@ -2609,7 +2609,7 @@ func TestContextCancelQuery(t *testing.T) { } func TestContextCancelQueryRow(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") dbt.mustExec("INSERT INTO test VALUES (1), (2), (3)") ctx, cancel := context.WithCancel(context.Background()) @@ -2642,7 +2642,7 @@ func TestContextCancelQueryRow(t *testing.T) { } func TestContextCancelPrepare(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { ctx, cancel := context.WithCancel(context.Background()) cancel() rows, err := dbt.db.PrepareContext(ctx, "SELECT 1") @@ -2656,7 +2656,7 @@ func TestContextCancelPrepare(t *testing.T) { } func TestContextCancelStmtExec(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") @@ -2691,7 +2691,7 @@ func TestContextCancelStmtExec(t *testing.T) { } func TestContextCancelStmtQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") @@ -2733,7 +2733,7 @@ func TestContextCancelStmtQuery(t *testing.T) { } func TestContextCancelBegin(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) conn, err := dbt.db.Conn(ctx) @@ -2789,7 +2789,7 @@ func TestContextCancelBegin(t *testing.T) { } func TestContextBeginIsolationLevel(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -2841,7 +2841,7 @@ func TestContextBeginIsolationLevel(t *testing.T) { } func TestContextBeginReadOnly(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -2973,8 +2973,8 @@ func TestRowsColumnTypes(t *testing.T) { values3 = values3[:len(values3)-2] dsns := []string{ - dsn + "&parseTime=true", - dsn + "&parseTime=false", + env.dsn + "&parseTime=true", + env.dsn + "&parseTime=false", } for _, testdsn := range dsns { runTests(t, testdsn, func(dbt *DBTest) { @@ -3103,7 +3103,7 @@ func TestRowsColumnTypes(t *testing.T) { } func TestValuerWithValueReceiverGivenNilValue(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (value VARCHAR(255))") dbt.db.Exec("INSERT INTO test VALUES (?)", (*testValuer)(nil)) // This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value() @@ -3126,7 +3126,7 @@ func TestRawBytesAreNotModified(t *testing.T) { strings.Repeat(strings.ToUpper(blob), blobSize/len(blob)), } - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, env.dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") for i := 0; i < insertRows; i++ { dbt.mustExec("INSERT INTO test VALUES (?, ?)", i+1, sqlBlobs[i&1]) @@ -3173,8 +3173,8 @@ var _ driver.DriverContext = &MySQLDriver{} type dialCtxKey struct{} func TestConnectorObeysDialTimeouts(t *testing.T) { - if !available { - t.Skipf("MySQL server not running on %s", netAddr) + if !env.available { + t.Skipf("MySQL server not running on %s", env.netAddr) } RegisterDialContext("dialctxtest", func(ctx context.Context, addr string) (net.Conn, error) { @@ -3182,10 +3182,10 @@ func TestConnectorObeysDialTimeouts(t *testing.T) { if !ctx.Value(dialCtxKey{}).(bool) { return nil, fmt.Errorf("test error: query context is not propagated to our dialer") } - return d.DialContext(ctx, prot, addr) + return d.DialContext(ctx, env.prot, env.addr) }) - db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname)) + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", env.user, env.pass, env.addr, env.dbname)) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -3200,16 +3200,16 @@ func TestConnectorObeysDialTimeouts(t *testing.T) { } func configForTests(t *testing.T) *Config { - if !available { - t.Skipf("MySQL server not running on %s", netAddr) + if !env.available { + t.Skipf("MySQL server not running on %s", env.netAddr) } mycnf := NewConfig() - mycnf.User = user - mycnf.Passwd = pass - mycnf.Addr = addr - mycnf.Net = prot - mycnf.DBName = dbname + mycnf.User = env.user + mycnf.Passwd = env.pass + mycnf.Addr = env.addr + mycnf.Net = env.prot + mycnf.DBName = env.dbname return mycnf } @@ -3252,7 +3252,7 @@ func (cw *connectorHijack) Connect(ctx context.Context) (driver.Conn, error) { func TestConnectorTimeoutsDuringOpen(t *testing.T) { RegisterDialContext("slowconn", func(ctx context.Context, addr string) (net.Conn, error) { var d net.Dialer - conn, err := d.DialContext(ctx, prot, addr) + conn, err := d.DialContext(ctx, env.prot, env.addr) if err != nil { return nil, err } diff --git a/errors_test.go b/errors_test.go index 96f9126d6..04f7862bf 100644 --- a/errors_test.go +++ b/errors_test.go @@ -36,7 +36,7 @@ func TestErrorsSetLogger(t *testing.T) { } func TestErrorsStrictIgnoreNotes(t *testing.T) { - runTests(t, dsn+"&sql_notes=false", func(dbt *DBTest) { + runTests(t, env.dsn+"&sql_notes=false", func(dbt *DBTest) { dbt.mustExec("DROP TABLE IF EXISTS does_not_exist") }) } From 0ea749dd6c457e68d2dc897d09940d4a9f95668f Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 16:30:01 +0200 Subject: [PATCH 08/18] Fix shadowing issues --- auth.go | 9 ++++++--- conncheck.go | 8 ++++---- connection.go | 8 ++++---- connector.go | 9 +++++---- driver_test.go | 16 ++++++++-------- dsn.go | 6 ++++-- infile.go | 2 +- packets.go | 5 +++-- rows.go | 2 +- statement.go | 2 +- 10 files changed, 37 insertions(+), 30 deletions(-) diff --git a/auth.go b/auth.go index 36b9df577..f0f100333 100644 --- a/auth.go +++ b/auth.go @@ -316,7 +316,8 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { plugin = newPlugin - authResp, err := mc.auth(authData, plugin) + var authResp []byte + authResp, err = mc.auth(authData, plugin) if err != nil { return err } @@ -361,7 +362,8 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { pubKey := mc.cfg.pubKey if pubKey == nil { // request public key from server - data, err := mc.buf.takeSmallBuffer(4 + 1) + var data []byte + data, err = mc.buf.takeSmallBuffer(4 + 1) if err != nil { return err } @@ -411,7 +413,8 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { return nil // auth successful default: block, _ := pem.Decode(authData) - pub, err := x509.ParsePKIXPublicKey(block.Bytes) + var pub interface{} + pub, err = x509.ParsePKIXPublicKey(block.Bytes) if err != nil { return err } diff --git a/conncheck.go b/conncheck.go index 024eb2858..178ade6b8 100644 --- a/conncheck.go +++ b/conncheck.go @@ -33,16 +33,16 @@ func connCheck(conn net.Conn) error { err = rawConn.Read(func(fd uintptr) bool { var buf [1]byte - n, err := syscall.Read(int(fd), buf[:]) + n, readErr := syscall.Read(int(fd), buf[:]) switch { - case n == 0 && err == nil: + case n == 0 && readErr == nil: sysErr = io.EOF case n > 0: sysErr = errUnexpectedRead - case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK: + case readErr == syscall.EAGAIN, readErr == syscall.EWOULDBLOCK: sysErr = nil default: - sysErr = err + sysErr = readErr } return true }) diff --git a/connection.go b/connection.go index 9ed1fcd7b..759a6bb26 100644 --- a/connection.go +++ b/connection.go @@ -380,7 +380,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) if resLen == 0 { rows.rs.done = true - switch err := rows.NextResultSet(); err { + switch err = rows.NextResultSet(); err { case nil, io.EOF: return rows, nil default: @@ -413,7 +413,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { if resLen > 0 { // Columns - if err := mc.readUntilEOF(); err != nil { + if err = mc.readUntilEOF(); err != nil { return nil, err } } @@ -494,7 +494,7 @@ func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driv return nil, err } - if err := mc.watchCancel(ctx); err != nil { + if err = mc.watchCancel(ctx); err != nil { return nil, err } @@ -547,7 +547,7 @@ func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValu return nil, err } - if err := stmt.mc.watchCancel(ctx); err != nil { + if err = stmt.mc.watchCancel(ctx); err != nil { return nil, err } diff --git a/connector.go b/connector.go index d567b4e4f..49dd022d4 100644 --- a/connector.go +++ b/connector.go @@ -55,7 +55,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { // Enable TCP Keepalives on TCP connections if tc, ok := mc.netConn.(*net.TCPConn); ok { - if err := tc.SetKeepAlive(true); err != nil { + if err = tc.SetKeepAlive(true); err != nil { // Don't send COM_QUIT before handshake. mc.netConn.Close() mc.netConn = nil @@ -65,7 +65,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { // Call startWatcher for context support (From Go 1.8) mc.startWatcher() - if err := mc.watchCancel(ctx); err != nil { + if err = mc.watchCancel(ctx); err != nil { mc.cleanup() return nil, err } @@ -118,12 +118,13 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket } else { // Get max allowed packet size - maxap, err := mc.getSystemVar("max_allowed_packet") + var maxAP []byte + maxAP, err = mc.getSystemVar("max_allowed_packet") if err != nil { mc.Close() return nil, err } - mc.maxAllowedPacket = stringToInt(maxap) - 1 + mc.maxAllowedPacket = stringToInt(maxAP) - 1 } if mc.maxAllowedPacket < maxPacketSize { mc.maxWriteSize = mc.maxAllowedPacket diff --git a/driver_test.go b/driver_test.go index 1535fb1a1..903f76700 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2139,7 +2139,7 @@ func TestInterruptBySignal(t *testing.T) { dbt.Fatalf("error on text query: %s", err.Error()) } for rows.Next() { - if err := rows.Scan(&val); err != nil { + if err = rows.Scan(&val); err != nil { dbt.Error(err) } else if val != 42 { dbt.Errorf("expected val to be 42") @@ -2156,7 +2156,7 @@ func TestInterruptBySignal(t *testing.T) { dbt.Fatalf("error on binary query: %s", err.Error()) } for rows.Next() { - if err := rows.Scan(&val); err != nil { + if err = rows.Scan(&val); err != nil { dbt.Error(err) } else if val != 42 { dbt.Errorf("expected val to be 42") @@ -2369,7 +2369,7 @@ func TestMultiResultSet(t *testing.T) { for rows.Next() { var res [3]int - if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil { + if err = rows.Scan(&res[0], &res[1], &res[2]); err != nil { dbt.Fatal(desc, err) } res2.values = append(res2.values, res[:]) @@ -2579,7 +2579,7 @@ func TestContextCancelQuery(t *testing.T) { // Check how many times the query is executed. var v int - if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + if err = dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } if v != 1 { // TODO: need to kill the query, and v should be 0. @@ -2751,7 +2751,7 @@ func TestContextCancelBegin(t *testing.T) { // This query will be canceled. startTime := time.Now() - if _, err := tx.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + if _, err = tx.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { dbt.Errorf("expected context.Canceled, got %v", err) } if d := time.Since(startTime); d > 500*time.Millisecond { @@ -2759,7 +2759,7 @@ func TestContextCancelBegin(t *testing.T) { } // Transaction is canceled, so expect an error. - switch err := tx.Commit(); err { + switch err = tx.Commit(); err { case sql.ErrTxDone: // because the transaction has already been rollbacked. // the database/sql package watches ctx @@ -2815,7 +2815,7 @@ func TestContextBeginIsolationLevel(t *testing.T) { var v int row := tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") - if err := row.Scan(&v); err != nil { + if err = row.Scan(&v); err != nil { dbt.Fatal(err) } // Because writer transaction wasn't commited yet, it should be available @@ -3145,7 +3145,7 @@ func TestRawBytesAreNotModified(t *testing.T) { var b int var raw sql.RawBytes for rows.Next() { - if err := rows.Scan(&b, &raw); err != nil { + if err = rows.Scan(&b, &raw); err != nil { t.Fatal(err) } diff --git a/dsn.go b/dsn.go index 64a508786..89aa21ee3 100644 --- a/dsn.go +++ b/dsn.go @@ -492,7 +492,8 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Server public key case "serverPubKey": - name, err := url.QueryUnescape(value) + var name string + name, err = url.QueryUnescape(value) if err != nil { return fmt.Errorf("invalid value for server pub key name: %v", err) } @@ -521,7 +522,8 @@ func parseDSNParams(cfg *Config, params string) (err error) { } else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" { cfg.TLSConfig = vl } else { - name, err := url.QueryUnescape(value) + var name string + name, err = url.QueryUnescape(value) if err != nil { return fmt.Errorf("invalid value for TLS config name: %v", err) } diff --git a/infile.go b/infile.go index 68f5d0166..0a0fa7e9d 100644 --- a/infile.go +++ b/infile.go @@ -149,7 +149,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { // send content packets // if packetSize == 0, the Reader contains no data if err == nil && packetSize > 0 { - data := make([]byte, 4+packetSize) + data = make([]byte, 4+packetSize) var n int for err == nil { n, err = rdr.Read(data[4:]) diff --git a/packets.go b/packets.go index 9c6d1787f..b9443c8a8 100644 --- a/packets.go +++ b/packets.go @@ -511,7 +511,7 @@ func (mc *mysqlConn) readAuthResult() (authData []byte, plugin string, err error return nil, "", ErrMalformPkt } plugin := string(data[1:pluginEndIndex]) - authData := data[pluginEndIndex+1:] + authData = data[pluginEndIndex+1:] return authData, plugin, nil default: // Error otherwise @@ -665,7 +665,8 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { // Table [len coded string] if mc.cfg.ColumnsWithAlias { - tableName, _, n, err := readLengthEncodedString(data[pos:]) + var tableName []byte + tableName, _, n, err = readLengthEncodedString(data[pos:]) if err != nil { return nil, err } diff --git a/rows.go b/rows.go index 48c99e90b..d5ecf5059 100644 --- a/rows.go +++ b/rows.go @@ -107,7 +107,7 @@ func (rows *mysqlRows) Close() (err error) { if mc == nil { return nil } - if err := mc.error(); err != nil { + if err = mc.error(); err != nil { return err } diff --git a/statement.go b/statement.go index 18a3ae498..4bd310b0b 100644 --- a/statement.go +++ b/statement.go @@ -124,7 +124,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { } else { rows.rs.done = true - switch err := rows.NextResultSet(); err { + switch err = rows.NextResultSet(); err { case nil, io.EOF: return rows, nil default: From bb5ce9f036eb5e79e17b8b7e650e511732b9b3e0 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 16:33:51 +0200 Subject: [PATCH 09/18] Style fixes --- auth.go | 5 +- collations.go | 162 ++++++++++++++++++++++++------------------------- const.go | 4 +- driver_test.go | 19 +++--- dsn.go | 1 - dsn_test.go | 9 +-- errors.go | 2 +- nulltime.go | 2 +- packets.go | 19 ++---- rows.go | 4 +- statement.go | 7 +-- 11 files changed, 110 insertions(+), 124 deletions(-) diff --git a/auth.go b/auth.go index f0f100333..51b3202b0 100644 --- a/auth.go +++ b/auth.go @@ -338,9 +338,8 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { } switch plugin { - - // https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ case "caching_sha2_password": + // https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ switch len(authData) { case 0: return nil // auth successful @@ -406,7 +405,6 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { default: return ErrMalformPkt } - case "sha256_password": switch len(authData) { case 0: @@ -426,7 +424,6 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { } return mc.readResultOK() } - default: return nil // auth successful } diff --git a/collations.go b/collations.go index 326a9f7fa..5733c7ec6 100644 --- a/collations.go +++ b/collations.go @@ -54,7 +54,7 @@ var collations = map[string]byte{ "armscii8_general_ci": 32, "utf8_general_ci": 33, "cp1250_czech_cs": 34, - //"ucs2_general_ci": 35, + // "ucs2_general_ci": 35, "cp866_general_ci": 36, "keybcs2_general_ci": 37, "macce_general_ci": 38, @@ -73,15 +73,15 @@ var collations = map[string]byte{ "cp1251_general_ci": 51, "cp1251_general_cs": 52, "macroman_bin": 53, - //"utf16_general_ci": 54, - //"utf16_bin": 55, - //"utf16le_general_ci": 56, + // "utf16_general_ci": 54, + // "utf16_bin": 55, + // "utf16le_general_ci": 56, "cp1256_general_ci": 57, "cp1257_bin": 58, "cp1257_general_ci": 59, - //"utf32_general_ci": 60, - //"utf32_bin": 61, - //"utf16le_bin": 62, + // "utf32_general_ci": 60, + // "utf32_bin": 61, + // "utf16le_bin": 62, "binary": 63, "armscii8_bin": 64, "ascii_bin": 65, @@ -109,7 +109,7 @@ var collations = map[string]byte{ "gbk_bin": 87, "sjis_bin": 88, "tis620_bin": 89, - //"ucs2_bin": 90, + // "ucs2_bin": 90, "ujis_bin": 91, "geostd8_general_ci": 92, "geostd8_bin": 93, @@ -119,79 +119,79 @@ var collations = map[string]byte{ "eucjpms_japanese_ci": 97, "eucjpms_bin": 98, "cp1250_polish_ci": 99, - //"utf16_unicode_ci": 101, - //"utf16_icelandic_ci": 102, - //"utf16_latvian_ci": 103, - //"utf16_romanian_ci": 104, - //"utf16_slovenian_ci": 105, - //"utf16_polish_ci": 106, - //"utf16_estonian_ci": 107, - //"utf16_spanish_ci": 108, - //"utf16_swedish_ci": 109, - //"utf16_turkish_ci": 110, - //"utf16_czech_ci": 111, - //"utf16_danish_ci": 112, - //"utf16_lithuanian_ci": 113, - //"utf16_slovak_ci": 114, - //"utf16_spanish2_ci": 115, - //"utf16_roman_ci": 116, - //"utf16_persian_ci": 117, - //"utf16_esperanto_ci": 118, - //"utf16_hungarian_ci": 119, - //"utf16_sinhala_ci": 120, - //"utf16_german2_ci": 121, - //"utf16_croatian_ci": 122, - //"utf16_unicode_520_ci": 123, - //"utf16_vietnamese_ci": 124, - //"ucs2_unicode_ci": 128, - //"ucs2_icelandic_ci": 129, - //"ucs2_latvian_ci": 130, - //"ucs2_romanian_ci": 131, - //"ucs2_slovenian_ci": 132, - //"ucs2_polish_ci": 133, - //"ucs2_estonian_ci": 134, - //"ucs2_spanish_ci": 135, - //"ucs2_swedish_ci": 136, - //"ucs2_turkish_ci": 137, - //"ucs2_czech_ci": 138, - //"ucs2_danish_ci": 139, - //"ucs2_lithuanian_ci": 140, - //"ucs2_slovak_ci": 141, - //"ucs2_spanish2_ci": 142, - //"ucs2_roman_ci": 143, - //"ucs2_persian_ci": 144, - //"ucs2_esperanto_ci": 145, - //"ucs2_hungarian_ci": 146, - //"ucs2_sinhala_ci": 147, - //"ucs2_german2_ci": 148, - //"ucs2_croatian_ci": 149, - //"ucs2_unicode_520_ci": 150, - //"ucs2_vietnamese_ci": 151, - //"ucs2_general_mysql500_ci": 159, - //"utf32_unicode_ci": 160, - //"utf32_icelandic_ci": 161, - //"utf32_latvian_ci": 162, - //"utf32_romanian_ci": 163, - //"utf32_slovenian_ci": 164, - //"utf32_polish_ci": 165, - //"utf32_estonian_ci": 166, - //"utf32_spanish_ci": 167, - //"utf32_swedish_ci": 168, - //"utf32_turkish_ci": 169, - //"utf32_czech_ci": 170, - //"utf32_danish_ci": 171, - //"utf32_lithuanian_ci": 172, - //"utf32_slovak_ci": 173, - //"utf32_spanish2_ci": 174, - //"utf32_roman_ci": 175, - //"utf32_persian_ci": 176, - //"utf32_esperanto_ci": 177, - //"utf32_hungarian_ci": 178, - //"utf32_sinhala_ci": 179, - //"utf32_german2_ci": 180, - //"utf32_croatian_ci": 181, - //"utf32_unicode_520_ci": 182, - //"utf32_vietnamese_ci": 183, + // "utf16_unicode_ci": 101, + // "utf16_icelandic_ci": 102, + // "utf16_latvian_ci": 103, + // "utf16_romanian_ci": 104, + // "utf16_slovenian_ci": 105, + // "utf16_polish_ci": 106, + // "utf16_estonian_ci": 107, + // "utf16_spanish_ci": 108, + // "utf16_swedish_ci": 109, + // "utf16_turkish_ci": 110, + // "utf16_czech_ci": 111, + // "utf16_danish_ci": 112, + // "utf16_lithuanian_ci": 113, + // "utf16_slovak_ci": 114, + // "utf16_spanish2_ci": 115, + // "utf16_roman_ci": 116, + // "utf16_persian_ci": 117, + // "utf16_esperanto_ci": 118, + // "utf16_hungarian_ci": 119, + // "utf16_sinhala_ci": 120, + // "utf16_german2_ci": 121, + // "utf16_croatian_ci": 122, + // "utf16_unicode_520_ci": 123, + // "utf16_vietnamese_ci": 124, + // "ucs2_unicode_ci": 128, + // "ucs2_icelandic_ci": 129, + // "ucs2_latvian_ci": 130, + // "ucs2_romanian_ci": 131, + // "ucs2_slovenian_ci": 132, + // "ucs2_polish_ci": 133, + // "ucs2_estonian_ci": 134, + // "ucs2_spanish_ci": 135, + // "ucs2_swedish_ci": 136, + // "ucs2_turkish_ci": 137, + // "ucs2_czech_ci": 138, + // "ucs2_danish_ci": 139, + // "ucs2_lithuanian_ci": 140, + // "ucs2_slovak_ci": 141, + // "ucs2_spanish2_ci": 142, + // "ucs2_roman_ci": 143, + // "ucs2_persian_ci": 144, + // "ucs2_esperanto_ci": 145, + // "ucs2_hungarian_ci": 146, + // "ucs2_sinhala_ci": 147, + // "ucs2_german2_ci": 148, + // "ucs2_croatian_ci": 149, + // "ucs2_unicode_520_ci": 150, + // "ucs2_vietnamese_ci": 151, + // "ucs2_general_mysql500_ci": 159, + // "utf32_unicode_ci": 160, + // "utf32_icelandic_ci": 161, + // "utf32_latvian_ci": 162, + // "utf32_romanian_ci": 163, + // "utf32_slovenian_ci": 164, + // "utf32_polish_ci": 165, + // "utf32_estonian_ci": 166, + // "utf32_spanish_ci": 167, + // "utf32_swedish_ci": 168, + // "utf32_turkish_ci": 169, + // "utf32_czech_ci": 170, + // "utf32_danish_ci": 171, + // "utf32_lithuanian_ci": 172, + // "utf32_slovak_ci": 173, + // "utf32_spanish2_ci": 174, + // "utf32_roman_ci": 175, + // "utf32_persian_ci": 176, + // "utf32_esperanto_ci": 177, + // "utf32_hungarian_ci": 178, + // "utf32_sinhala_ci": 179, + // "utf32_german2_ci": 180, + // "utf32_croatian_ci": 181, + // "utf32_unicode_520_ci": 182, + // "utf32_vietnamese_ci": 183, "utf8_unicode_ci": 192, "utf8_icelandic_ci": 193, "utf8_latvian_ci": 194, diff --git a/const.go b/const.go index b1e6b85ef..31e790737 100644 --- a/const.go +++ b/const.go @@ -158,11 +158,11 @@ const ( statusNoIndexUsed statusCursorExists statusLastRowSent - statusDbDropped + statusDBDropped statusNoBackslashEscapes statusMetadataChanged statusQueryWasSlow - statusPsOutParams + statusPSOutParams statusInTransReadonly statusSessionStateChanged ) diff --git a/driver_test.go b/driver_test.go index 903f76700..ea7660521 100644 --- a/driver_test.go +++ b/driver_test.go @@ -358,8 +358,8 @@ func TestMultiQuery(t *testing.T) { rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;") if rows.Next() { rows.Scan(&out) - if 5 != out { - dbt.Errorf("5 != %d", out) + if out != 5 { + dbt.Errorf("%d != 5", out) } if rows.Next() { @@ -372,7 +372,6 @@ func TestMultiQuery(t *testing.T) { dbt.Fatal(err) } rows.Close() - }) } @@ -661,10 +660,9 @@ type testValuerWithValidation struct { } func (tv testValuerWithValidation) Value() (driver.Value, error) { - if len(tv.value) == 0 { + if tv.value == "" { return nil, fmt.Errorf("Invalid string valuer. Value must not be empty") } - return tv.value, nil } @@ -737,8 +735,9 @@ func (t timeMode) Binary() bool { switch t { case binaryString, binaryTime: return true + default: + return false } - return false } const ( @@ -2018,11 +2017,12 @@ func TestSQLInjection(t *testing.T) { // NULL can't be equal to anything, the idea here is to inject query so it returns row // This test verifies that escapeQuotes and escapeBackslash are working properly err := dbt.db.QueryRow("SELECT v FROM test WHERE NULL = ?", arg).Scan(&v) - if err == sql.ErrNoRows { + switch err { + case sql.ErrNoRows: return // success, sql injection failed - } else if err == nil { + case nil: dbt.Errorf("sql injection successful with arg: %s", arg) - } else { + default: dbt.Errorf("error running query with arg: %s; err: %s", arg, err.Error()) } } @@ -2385,7 +2385,6 @@ func TestMultiResultSet(t *testing.T) { if rows.NextResultSet() { dbt.Error(desc, "unexpected next result set") } - if err := rows.Err(); err != nil { dbt.Error(desc, err) } diff --git a/dsn.go b/dsn.go index 89aa21ee3..78957b3bf 100644 --- a/dsn.go +++ b/dsn.go @@ -428,7 +428,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Collation case "collation": cfg.Collation = value - break case "columnsWithAlias": var isBool bool diff --git a/dsn_test.go b/dsn_test.go index fb2be3318..1beef81a0 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -99,7 +99,7 @@ func TestDSNParserInvalid(t *testing.T) { "net(addr)//", // unescaped "User:pass@tcp(1.2.3.4:3306)", // no trailing slash "net()/", // unknown default addr - //"/dbname?arg=/some/unescaped/path", + // "/dbname?arg=/some/unescaped/path", } for i, tst := range invalidDSNs { @@ -211,11 +211,12 @@ func TestDSNWithCustomTLS(t *testing.T) { tlsCfg.ServerName = "" cfg, err = ParseDSN(tst) - if err != nil { + switch { + case err != nil: t.Error(err.Error()) - } else if cfg.tls.ServerName != name { + case cfg.tls.ServerName != name: t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst) - } else if tlsCfg.ServerName != "" { + case tlsCfg.ServerName != "": t.Errorf("tlsCfg was mutated ServerName (%s) should be empty parsing DSN (%s).", name, tst) } } diff --git a/errors.go b/errors.go index 760782ff2..c992d365a 100644 --- a/errors.go +++ b/errors.go @@ -21,7 +21,7 @@ var ( ErrMalformPkt = errors.New("malformed packet") ErrNoTLS = errors.New("TLS requested but server does not support TLS") ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN") - ErrNativePassword = errors.New("this user requires mysql native password authentication.") + ErrNativePassword = errors.New("this user requires mysql native password authentication") ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") ErrUnknownPlugin = errors.New("this authentication plugin is not supported") ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+") diff --git a/nulltime.go b/nulltime.go index 651723a96..0092abf93 100644 --- a/nulltime.go +++ b/nulltime.go @@ -38,7 +38,7 @@ func (nt *NullTime) Scan(value interface{}) (err error) { } nt.Valid = false - return fmt.Errorf("Can't convert %T to time.Time", value) + return fmt.Errorf("can't convert %T to time.Time", value) } // Value implements the driver Valuer interface. diff --git a/packets.go b/packets.go index b9443c8a8..7bd03c595 100644 --- a/packets.go +++ b/packets.go @@ -494,13 +494,10 @@ func (mc *mysqlConn) readAuthResult() (authData []byte, plugin string, err error // packet indicator switch data[0] { - case iOK: return nil, "", mc.handleOkPacket(data) - case iAuthMoreData: return data[1:], "", err - case iEOF: if len(data) == 1 { // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest @@ -513,7 +510,6 @@ func (mc *mysqlConn) readAuthResult() (authData []byte, plugin string, err error plugin := string(data[1:pluginEndIndex]) authData = data[pluginEndIndex+1:] return authData, plugin, nil - default: // Error otherwise return nil, "", mc.handleErrorPacket(data) } @@ -538,13 +534,10 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { data, err := mc.readPacket() if err == nil { switch data[0] { - case iOK: return 0, mc.handleOkPacket(data) - case iERR: return 0, mc.handleErrorPacket(data) - case iLocalInFile: return 0, mc.handleInFileRequest(string(data[1:])) } @@ -592,7 +585,7 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { // SQL State [optional: # + 5bytes string] if data[3] == 0x23 { - //sqlstate := string(data[4 : 4+5]) + // sqlstate := string(data[4 : 4+5]) pos = 9 } @@ -723,12 +716,12 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { // Decimals [uint8] columns[i].decimals = data[pos] - //pos++ + // pos++ // Default value [len coded binary] - //if pos < len(data) { - // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) - //} + // if pos < len(data) { + // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) + // } } } @@ -789,7 +782,6 @@ func (rows *textRows) readRow(dest []driver.Value) error { continue } } - } else { dest[i] = nil continue @@ -897,7 +889,6 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { continue } return err - } // Reset Packet Sequence diff --git a/rows.go b/rows.go index d5ecf5059..67dc2de70 100644 --- a/rows.go +++ b/rows.go @@ -88,9 +88,9 @@ func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (precision, scale int64, return math.MaxInt64, math.MaxInt64, true } return math.MaxInt64, decimals, true + default: + return 0, 0, false } - - return 0, 0, false } func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type { diff --git a/statement.go b/statement.go index 4bd310b0b..06da0ee12 100644 --- a/statement.go +++ b/statement.go @@ -27,7 +27,6 @@ func (stmt *mysqlStmt) Close() error { // driver.Stmt.Close can be called more than once, thus this function // has to be idempotent. // See also Issue #450 and golang/go#16019. - //errLog.Print(ErrInvalidConn) return driver.ErrBadConn } @@ -171,9 +170,8 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { // indirect pointers if rv.IsNil() { return nil, nil - } else { - return c.ConvertValue(rv.Elem().Interface()) } + return c.ConvertValue(rv.Elem().Interface()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return rv.Int(), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: @@ -193,8 +191,9 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { } case reflect.String: return rv.String(), nil + default: + return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) } - return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) } var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() From f428407d83e9d9b3b8c7c9bc213e28fafa4aea29 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 16:34:03 +0200 Subject: [PATCH 10/18] Fix calling of wrong benchmark func --- benchmark_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark_test.go b/benchmark_test.go index 319ef6e0b..fb0bdbf7c 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -321,7 +321,7 @@ func BenchmarkExecContext(b *testing.B) { for _, p := range []int{1, 2, 3, 4} { p := p b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { - benchmarkQueryContext(b, db, p) + benchmarkExecContext(b, db, p) }) } } From 18e2ca444fb06f9f7cb4e15d69d465b311894a2f Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 16:34:17 +0200 Subject: [PATCH 11/18] Fix typos --- driver_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/driver_test.go b/driver_test.go index ea7660521..83075c4a8 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2817,7 +2817,7 @@ func TestContextBeginIsolationLevel(t *testing.T) { if err = row.Scan(&v); err != nil { dbt.Fatal(err) } - // Because writer transaction wasn't commited yet, it should be available + // Because writer transaction wasn't committed yet, it should be available if v != 0 { dbt.Errorf("expected val to be 0, got %d", v) } @@ -2831,7 +2831,7 @@ func TestContextBeginIsolationLevel(t *testing.T) { if err := row.Scan(&v); err != nil { dbt.Fatal(err) } - // Data written by writer transaction is already commited, it should be selectable + // Data written by writer transaction is already committed, it should be selectable if v != 1 { dbt.Errorf("expected val to be 1, got %d", v) } From 8ba16dc253963ec1b16daee324b4ece148d31ec9 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 16:34:33 +0200 Subject: [PATCH 12/18] Add missing rows/stmt.Close calls --- driver_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/driver_test.go b/driver_test.go index 83075c4a8..09ccdec80 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1267,6 +1267,7 @@ func TestLoadData(t *testing.T) { if err != nil { dbt.Fatal(err.Error()) } + defer rows.Close() i := 0 values := [4]string{ @@ -2617,6 +2618,7 @@ func TestContextCancelQueryRow(t *testing.T) { if err != nil { dbt.Fatalf("%s", err.Error()) } + defer rows.Close() // the first row will be succeed. var v int @@ -2662,6 +2664,7 @@ func TestContextCancelStmtExec(t *testing.T) { if err != nil { dbt.Fatalf("unexpected error: %v", err) } + defer stmt.Close() // Delay execution for just a bit until db.ExecContext has begun. defer time.AfterFunc(250*time.Millisecond, cancel).Stop() @@ -2697,6 +2700,7 @@ func TestContextCancelStmtQuery(t *testing.T) { if err != nil { dbt.Fatalf("unexpected error: %v", err) } + defer stmt.Close() // Delay execution for just a bit until db.ExecContext has begun. defer time.AfterFunc(250*time.Millisecond, cancel).Stop() From 715811558c2c849a3cd611c7c107fc5b3aa5f95a Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 16:45:57 +0200 Subject: [PATCH 13/18] Fix broken rows.Err checks --- driver_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/driver_test.go b/driver_test.go index 09ccdec80..f3cd43500 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2146,7 +2146,7 @@ func TestInterruptBySignal(t *testing.T) { dbt.Errorf("expected val to be 42") } } - if err = rows.Err(); rows != nil { + if err = rows.Err(); err != nil { dbt.Fatal(err) } rows.Close() @@ -2163,7 +2163,7 @@ func TestInterruptBySignal(t *testing.T) { dbt.Errorf("expected val to be 42") } } - if err = rows.Err(); rows != nil { + if err = rows.Err(); err != nil { dbt.Fatal(err) } rows.Close() @@ -3094,7 +3094,7 @@ func TestRowsColumnTypes(t *testing.T) { if i != 3 { t.Errorf("expected 3 rows, got %d", i) } - if err = rows.Err(); rows != nil { + if err = rows.Err(); err != nil { dbt.Fatal(err) } @@ -3162,7 +3162,7 @@ func TestRawBytesAreNotModified(t *testing.T) { t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) } } - if err = rows.Err(); rows != nil { + if err = rows.Err(); err != nil { dbt.Fatal(err) } rows.Close() From 2ff6384ec67d6331661ace3bbbacbd6a80956af4 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 17:22:59 +0200 Subject: [PATCH 14/18] Fix context canceled check --- driver_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/driver_test.go b/driver_test.go index f3cd43500..07138c5e5 100644 --- a/driver_test.go +++ b/driver_test.go @@ -3162,7 +3162,10 @@ func TestRawBytesAreNotModified(t *testing.T) { t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) } } - if err = rows.Err(); err != nil { + if err = rows.Err(); err != context.Canceled { + if err == nil { + t.Fatal("expected 'context canceled' error, but got none") + } dbt.Fatal(err) } rows.Close() From 8a00487587cef79458d4ffd822a43d0a27fbf4bf Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 19:22:36 +0200 Subject: [PATCH 15/18] Fix BenchmarkExec conccurent err reporting --- benchmark_test.go | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index fb0bdbf7c..bfd94eccf 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -119,9 +119,10 @@ func BenchmarkExec(b *testing.B) { remain := int64(b.N) var wg sync.WaitGroup wg.Add(concurrencyLevel) - defer wg.Wait() - b.StartTimer() + errChan := make(chan error, 1) + + b.StartTimer() for i := 0; i < concurrencyLevel; i++ { go func() { for { @@ -131,11 +132,26 @@ func BenchmarkExec(b *testing.B) { } if _, err := stmt.Exec(); err != nil { - b.Fatal(err.Error()) + // attempt to report error back via errChan without blocking. + // only the first goroutine attempting to report an error will be successful. + select { + case errChan <- err: + default: + } + return } } }() } + wg.Wait() + + // check if an error was reported by a goroutine + select { + case err := <-errChan: + b.Fatal(err) + default: + } + close(errChan) } // data, but no db writes From 885e24c24166de6c6338d0304e4ef5c3d0b3a6cc Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 19:27:43 +0200 Subject: [PATCH 16/18] Fix mysqlConn alignment Reduced struct of size from 264 bytes to 256 bytes --- connection.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/connection.go b/connection.go index 759a6bb26..cf238a9d6 100644 --- a/connection.go +++ b/connection.go @@ -38,11 +38,11 @@ type mysqlConn struct { // for context support (Go 1.8+) watching bool + closed atomicBool // set when conn is closed, before closech is closed + canceled atomicError // set non-nil if conn is canceled watcher chan<- context.Context closech chan struct{} finished chan<- struct{} - canceled atomicError // set non-nil if conn is canceled - closed atomicBool // set when conn is closed, before closech is closed } // Handles parameters set in DSN after the connection is established From 50079f6846fb7dd5c0095defe4444656a84b992c Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 19:40:52 +0200 Subject: [PATCH 17/18] auth: replace identifier strings with constants --- auth.go | 14 +++++----- auth_test.go | 72 ++++++++++++++++++++++++------------------------- const.go | 10 ++++++- packets.go | 2 +- packets_test.go | 4 +-- 5 files changed, 55 insertions(+), 47 deletions(-) diff --git a/auth.go b/auth.go index 51b3202b0..0b91873a8 100644 --- a/auth.go +++ b/auth.go @@ -240,11 +240,11 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { switch plugin { - case "caching_sha2_password": + case authCachingSHA2: authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) return authResp, nil - case "mysql_old_password": + case authOldPassword: if !mc.cfg.AllowOldPasswords { return nil, ErrOldPassword } @@ -254,7 +254,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0) return authResp, nil - case "mysql_clear_password": + case authCleartextPassword: if !mc.cfg.AllowCleartextPasswords { return nil, ErrCleartextPassword } @@ -262,7 +262,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html return append([]byte(mc.cfg.Passwd), 0), nil - case "mysql_native_password": + case authNativePassword: if !mc.cfg.AllowNativePasswords { return nil, ErrNativePassword } @@ -271,7 +271,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { authResp := scramblePassword(authData[:20], mc.cfg.Passwd) return authResp, nil - case "sha256_password": + case authSHA256Password: if mc.cfg.Passwd == "" { return []byte{0}, nil } @@ -338,7 +338,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { } switch plugin { - case "caching_sha2_password": + case authCachingSHA2: // https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ switch len(authData) { case 0: @@ -405,7 +405,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { default: return ErrMalformPkt } - case "sha256_password": + case authSHA256Password: switch len(authData) { case 0: return nil // auth successful diff --git a/auth_test.go b/auth_test.go index 1920ef39f..33b1c399f 100644 --- a/auth_test.go +++ b/auth_test.go @@ -82,7 +82,7 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69, 22, 41, 84, 32, 123, 43, 118} - plugin := "caching_sha2_password" + plugin := authCachingSHA2 // Send Client Authentication Packet authResp, err := mc.auth(authData, plugin) @@ -127,7 +127,7 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69, 22, 41, 84, 32, 123, 43, 118} - plugin := "caching_sha2_password" + plugin := authCachingSHA2 // Send Client Authentication Packet authResp, err := mc.auth(authData, plugin) @@ -169,7 +169,7 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, 62, 94, 83, 80, 52, 85} - plugin := "caching_sha2_password" + plugin := authCachingSHA2 // Send Client Authentication Packet authResp, err := mc.auth(authData, plugin) @@ -225,7 +225,7 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, 62, 94, 83, 80, 52, 85} - plugin := "caching_sha2_password" + plugin := authCachingSHA2 // Send Client Authentication Packet authResp, err := mc.auth(authData, plugin) @@ -277,7 +277,7 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, 62, 94, 83, 80, 52, 85} - plugin := "caching_sha2_password" + plugin := authCachingSHA2 // Send Client Authentication Packet authResp, err := mc.auth(authData, plugin) @@ -333,7 +333,7 @@ func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) { authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, 103, 26, 95, 81, 17, 24, 21} - plugin := "mysql_clear_password" + plugin := authCleartextPassword // Send Client Authentication Packet _, err := mc.auth(authData, plugin) @@ -350,7 +350,7 @@ func TestAuthFastCleartextPassword(t *testing.T) { authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, 103, 26, 95, 81, 17, 24, 21} - plugin := "mysql_clear_password" + plugin := authCleartextPassword // Send Client Authentication Packet authResp, err := mc.auth(authData, plugin) @@ -393,7 +393,7 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, 103, 26, 95, 81, 17, 24, 21} - plugin := "mysql_clear_password" + plugin := authCleartextPassword // Send Client Authentication Packet authResp, err := mc.auth(authData, plugin) @@ -436,7 +436,7 @@ func TestAuthFastNativePasswordNotAllowed(t *testing.T) { authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, 103, 26, 95, 81, 17, 24, 21} - plugin := "mysql_native_password" + plugin := authNativePassword // Send Client Authentication Packet _, err := mc.auth(authData, plugin) @@ -452,7 +452,7 @@ func TestAuthFastNativePassword(t *testing.T) { authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, 103, 26, 95, 81, 17, 24, 21} - plugin := "mysql_native_password" + plugin := authNativePassword // Send Client Authentication Packet authResp, err := mc.auth(authData, plugin) @@ -495,7 +495,7 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) { authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, 103, 26, 95, 81, 17, 24, 21} - plugin := "mysql_native_password" + plugin := authNativePassword // Send Client Authentication Packet authResp, err := mc.auth(authData, plugin) @@ -537,7 +537,7 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, 62, 94, 83, 80, 52, 85} - plugin := "sha256_password" + plugin := authSHA256Password // Send Client Authentication Packet authResp, err := mc.auth(authData, plugin) @@ -585,7 +585,7 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, 62, 94, 83, 80, 52, 85} - plugin := "sha256_password" + plugin := authSHA256Password // Send Client Authentication Packet authResp, err := mc.auth(authData, plugin) @@ -634,7 +634,7 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, 62, 94, 83, 80, 52, 85} - plugin := "sha256_password" + plugin := authSHA256Password // Send Client Authentication Packet authResp, err := mc.auth(authData, plugin) @@ -667,7 +667,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, 62, 94, 83, 80, 52, 85} - plugin := "sha256_password" + plugin := authSHA256Password // send Client Authentication Packet authResp, err := mc.auth(authData, plugin) @@ -726,7 +726,7 @@ func TestAuthSwitchCachingSHA256PasswordCached(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" + plugin := authNativePassword if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) @@ -759,7 +759,7 @@ func TestAuthSwitchCachingSHA256PasswordEmpty(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" + plugin := authNativePassword if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) @@ -795,7 +795,7 @@ func TestAuthSwitchCachingSHA256PasswordFullRSA(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" + plugin := authNativePassword if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) @@ -840,7 +840,7 @@ func TestAuthSwitchCachingSHA256PasswordFullRSAWithKey(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" + plugin := authNativePassword if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) @@ -883,7 +883,7 @@ func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" + plugin := authNativePassword if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) @@ -911,7 +911,7 @@ func TestAuthSwitchCleartextPasswordNotAllowed(t *testing.T) { conn.maxReads = 1 authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" + plugin := authNativePassword err := mc.handleAuthResult(authData, plugin) if err != ErrCleartextPassword { t.Errorf("expected ErrCleartextPassword, got %v", err) @@ -933,7 +933,7 @@ func TestAuthSwitchCleartextPassword(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" + plugin := authNativePassword if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) @@ -960,7 +960,7 @@ func TestAuthSwitchCleartextPasswordEmpty(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" + plugin := authNativePassword if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) @@ -983,7 +983,7 @@ func TestAuthSwitchNativePasswordNotAllowed(t *testing.T) { conn.maxReads = 1 authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, 31} - plugin := "caching_sha2_password" + plugin := authCachingSHA2 err := mc.handleAuthResult(authData, plugin) if err != ErrNativePassword { t.Errorf("expected ErrNativePassword, got %v", err) @@ -1007,7 +1007,7 @@ func TestAuthSwitchNativePassword(t *testing.T) { authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, 31} - plugin := "caching_sha2_password" + plugin := authCachingSHA2 if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) @@ -1037,7 +1037,7 @@ func TestAuthSwitchNativePasswordEmpty(t *testing.T) { authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, 31} - plugin := "caching_sha2_password" + plugin := authCachingSHA2 if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) @@ -1058,7 +1058,7 @@ func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) { conn.maxReads = 1 authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" + plugin := authNativePassword err := mc.handleAuthResult(authData, plugin) if err != ErrOldPassword { t.Errorf("expected ErrOldPassword, got %v", err) @@ -1074,7 +1074,7 @@ func TestOldAuthSwitchNotAllowed(t *testing.T) { conn.maxReads = 1 authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" + plugin := authNativePassword err := mc.handleAuthResult(authData, plugin) if err != ErrOldPassword { t.Errorf("expected ErrOldPassword, got %v", err) @@ -1097,7 +1097,7 @@ func TestAuthSwitchOldPassword(t *testing.T) { authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" + plugin := authNativePassword if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) @@ -1124,7 +1124,7 @@ func TestOldAuthSwitch(t *testing.T) { authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" + plugin := authNativePassword if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) @@ -1151,7 +1151,7 @@ func TestAuthSwitchOldPasswordEmpty(t *testing.T) { authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" + plugin := authNativePassword if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) @@ -1178,7 +1178,7 @@ func TestOldAuthSwitchPasswordEmpty(t *testing.T) { authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" + plugin := authNativePassword if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) @@ -1207,7 +1207,7 @@ func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" + plugin := authNativePassword if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) @@ -1242,7 +1242,7 @@ func TestAuthSwitchSHA256PasswordRSA(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" + plugin := authNativePassword if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) @@ -1278,7 +1278,7 @@ func TestAuthSwitchSHA256PasswordRSAWithKey(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" + plugin := authNativePassword if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) @@ -1314,7 +1314,7 @@ func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" + plugin := authNativePassword if err := mc.handleAuthResult(authData, plugin); err != nil { t.Errorf("got error: %v", err) diff --git a/const.go b/const.go index 31e790737..ab0afcc66 100644 --- a/const.go +++ b/const.go @@ -9,13 +9,21 @@ package mysql const ( - defaultAuthPlugin = "mysql_native_password" + defaultAuthPlugin = authNativePassword defaultMaxAllowedPacket = 4 << 20 // 4 MiB minProtocolVersion = 10 maxPacketSize = 1<<24 - 1 timeFormat = "2006-01-02 15:04:05.999999" ) +const ( + authNativePassword = "mysql_native_password" + authCleartextPassword = "mysql_clear_password" + authCachingSHA2 = "caching_sha2_password" + authSHA256Password = "sha256_password" + authOldPassword = "mysql_old_password" +) + // MySQL constants documentation: // http://dev.mysql.com/doc/internals/en/client-server-protocol.html diff --git a/packets.go b/packets.go index 7bd03c595..25a8c284f 100644 --- a/packets.go +++ b/packets.go @@ -501,7 +501,7 @@ func (mc *mysqlConn) readAuthResult() (authData []byte, plugin string, err error case iEOF: if len(data) == 1 { // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest - return nil, "mysql_old_password", nil + return nil, authOldPassword, nil } pluginEndIndex := bytes.IndexByte(data, 0x00) if pluginEndIndex < 0 { diff --git a/packets_test.go b/packets_test.go index a33aa98e1..c64f033a9 100644 --- a/packets_test.go +++ b/packets_test.go @@ -324,8 +324,8 @@ func TestRegression801(t *testing.T) { t.Fatalf("got error: %v", err) } - if pluginName != "mysql_native_password" { - t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName) + if pluginName != authNativePassword { + t.Errorf("expected plugin name '%s', got '%s'", authNativePassword, pluginName) } expectedAuthData := []byte{60, 70, 63, 58, 68, 104, 34, 97, 98, 120, 114, From 35f1118c8f0bb91ccce655694d875c7f0f5a5217 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 14 Aug 2020 19:41:28 +0200 Subject: [PATCH 18/18] infile: reword errors --- driver_test.go | 2 +- infile.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/driver_test.go b/driver_test.go index 07138c5e5..b656d2a6c 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1356,7 +1356,7 @@ func TestLoadData(t *testing.T) { _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test") if err == nil { dbt.Fatal("load non-existent Reader didn't fail") - } else if err.Error() != "Reader 'doesnotexist' is not registered" { + } else if err.Error() != "unknown Reader 'doesnotexist'" { dbt.Fatal(err.Error()) } }) diff --git a/infile.go b/infile.go index 0a0fa7e9d..70c40fca9 100644 --- a/infile.go +++ b/infile.go @@ -116,10 +116,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { defer deferredClose(&err, cl) } } else { - err = fmt.Errorf("Reader '%s' is ", name) + err = fmt.Errorf("nil value for Reader '%s'", name) } } else { - err = fmt.Errorf("Reader '%s' is not registered", name) + err = fmt.Errorf("unknown Reader '%s'", name) } } else { // File name = strings.Trim(name, `"`)