From cf724b8742be411565e380a35b22814aea525e78 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 29 Sep 2017 23:21:21 +0200 Subject: [PATCH 1/2] Remove strict mode Fixes #556 #602 #635 Closes #609 --- README.md | 16 ++------- benchmark_test.go | 6 +--- connection.go | 1 - driver.go | 1 - driver_test.go | 82 ++--------------------------------------------- dsn.go | 16 +-------- errors.go | 73 ----------------------------------------- packets.go | 14 -------- 8 files changed, 7 insertions(+), 202 deletions(-) diff --git a/README.md b/README.md index 779ada5ba..b5882e6c8 100644 --- a/README.md +++ b/README.md @@ -294,20 +294,6 @@ supposed to happen, setting this on some MySQL providers (such as AWS Aurora) is safer for failovers. -##### `strict` - -``` -Type: bool -Valid Values: true, false -Default: false -``` - -`strict=true` enables a driver-side strict mode in which MySQL warnings are treated as errors. This mode should not be used in production as it may lead to data corruption in certain situations. - -A server-side strict mode, which is safe for production use, can be set via the [`sql_mode`](https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html) system variable. - -By default MySQL also treats notes as warnings. Use [`sql_notes=false`](http://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_sql_notes) to ignore notes. - ##### `timeout` ``` @@ -317,6 +303,7 @@ Default: OS default Timeout for establishing connections, aka dial timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. + ##### `tls` ``` @@ -327,6 +314,7 @@ Default: false `tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). + ##### `writeTimeout` ``` diff --git a/benchmark_test.go b/benchmark_test.go index 7da833a2a..c1de8672b 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -48,11 +48,7 @@ func initDB(b *testing.B, queries ...string) *sql.DB { db := tb.checkDB(sql.Open("mysql", dsn)) for _, query := range queries { if _, err := db.Exec(query); err != nil { - if w, ok := err.(MySQLWarnings); ok { - b.Logf("warning on %q: %v", query, w) - } else { - b.Fatalf("error on %q: %v", query, err) - } + b.Fatalf("error on %q: %v", query, err) } } return db diff --git a/connection.go b/connection.go index 948a59561..58ae29988 100644 --- a/connection.go +++ b/connection.go @@ -40,7 +40,6 @@ type mysqlConn struct { status statusFlag sequence uint8 parseTime bool - strict bool // for context support (Go 1.8+) watching bool diff --git a/driver.go b/driver.go index c341b6680..d42ce7a3d 100644 --- a/driver.go +++ b/driver.go @@ -64,7 +64,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { return nil, err } mc.parseTime = mc.cfg.ParseTime - mc.strict = mc.cfg.Strict // Connect to Server if dial, ok := dials[mc.cfg.Net]; ok { diff --git a/driver_test.go b/driver_test.go index 206e07cc9..bc0386a09 100644 --- a/driver_test.go +++ b/driver_test.go @@ -63,7 +63,7 @@ func init() { addr = env("MYSQL_TEST_ADDR", "localhost:3306") dbname = env("MYSQL_TEST_DBNAME", "gotest") netAddr = fmt.Sprintf("%s(%s)", prot, addr) - dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname) + dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, dbname) c, err := net.Dial(prot, addr) if err == nil { available = true @@ -1170,82 +1170,6 @@ func TestFoundRows(t *testing.T) { }) } -func TestStrict(t *testing.T) { - // ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors - relaxedDsn := dsn + "&sql_mode='ALLOW_INVALID_DATES,NO_AUTO_CREATE_USER'" - // make sure the MySQL version is recent enough with a separate connection - // before running the test - conn, err := MySQLDriver{}.Open(relaxedDsn) - if conn != nil { - conn.Close() - } - // Error 1231: Variable 'sql_mode' can't be set to the value of - // 'ALLOW_INVALID_DATES' => skip test, MySQL server version is too old - maybeSkip(t, err, 1231) - runTests(t, relaxedDsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))") - - var queries = [...]struct { - in string - codes []string - }{ - {"DROP TABLE IF EXISTS no_such_table", []string{"1051"}}, - {"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')", []string{"1265", "1048", "1264", "1265"}}, - } - var err error - - var checkWarnings = func(err error, mode string, idx int) { - if err == nil { - dbt.Errorf("expected STRICT error on query [%s] %s", mode, queries[idx].in) - } - - if warnings, ok := err.(MySQLWarnings); ok { - var codes = make([]string, len(warnings)) - for i := range warnings { - codes[i] = warnings[i].Code - } - if len(codes) != len(queries[idx].codes) { - dbt.Errorf("unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes) - } - - for i := range warnings { - if codes[i] != queries[idx].codes[i] { - dbt.Errorf("unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes) - return - } - } - - } else { - dbt.Errorf("unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error()) - } - } - - // text protocol - for i := range queries { - _, err = dbt.db.Exec(queries[i].in) - checkWarnings(err, "text", i) - } - - var stmt *sql.Stmt - - // binary protocol - for i := range queries { - stmt, err = dbt.db.Prepare(queries[i].in) - if err != nil { - dbt.Errorf("error on preparing query %s: %s", queries[i].in, err.Error()) - } - - _, err = stmt.Exec() - checkWarnings(err, "binary", i) - - err = stmt.Close() - if err != nil { - dbt.Errorf("error on closing stmt for query %s: %s", queries[i].in, err.Error()) - } - } - }) -} - func TestTLS(t *testing.T) { tlsTest := func(dbt *DBTest) { if err := dbt.db.Ping(); err != nil { @@ -1762,7 +1686,7 @@ func TestCustomDial(t *testing.T) { return net.Dial(prot, addr) }) - db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s&strict=true", user, pass, addr, dbname)) + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -1859,7 +1783,7 @@ func TestUnixSocketAuthFail(t *testing.T) { } } t.Logf("socket: %s", socket) - badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s&strict=true", user, badPass, socket, dbname) + badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname) db, err := sql.Open("mysql", badDSN) if err != nil { t.Fatalf("error connecting: %s", err.Error()) diff --git a/dsn.go b/dsn.go index 432ca43b8..0cee78998 100644 --- a/dsn.go +++ b/dsn.go @@ -55,7 +55,6 @@ type Config struct { MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections - Strict bool // Return warnings as errors } // FormatDSN formats the given Config into a DSN string which can be passed to @@ -206,15 +205,6 @@ func (cfg *Config) FormatDSN() string { } } - if cfg.Strict { - if hasParam { - buf.WriteString("&strict=true") - } else { - hasParam = true - buf.WriteString("?strict=true") - } - } - if cfg.Timeout > 0 { if hasParam { buf.WriteString("&timeout=") @@ -502,11 +492,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Strict mode case "strict": - var isBool bool - cfg.Strict, isBool = readBool(value) - if !isBool { - return errors.New("invalid bool value: " + value) - } + return errors.New("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") // Dial Timeout case "timeout": diff --git a/errors.go b/errors.go index d0d0d2e11..760782ff2 100644 --- a/errors.go +++ b/errors.go @@ -9,10 +9,8 @@ package mysql import ( - "database/sql/driver" "errors" "fmt" - "io" "log" "os" ) @@ -65,74 +63,3 @@ type MySQLError struct { func (me *MySQLError) Error() string { return fmt.Sprintf("Error %d: %s", me.Number, me.Message) } - -// MySQLWarnings is an error type which represents a group of one or more MySQL -// warnings -type MySQLWarnings []MySQLWarning - -func (mws MySQLWarnings) Error() string { - var msg string - for i, warning := range mws { - if i > 0 { - msg += "\r\n" - } - msg += fmt.Sprintf( - "%s %s: %s", - warning.Level, - warning.Code, - warning.Message, - ) - } - return msg -} - -// MySQLWarning is an error type which represents a single MySQL warning. -// Warnings are returned in groups only. See MySQLWarnings -type MySQLWarning struct { - Level string - Code string - Message string -} - -func (mc *mysqlConn) getWarnings() (err error) { - rows, err := mc.Query("SHOW WARNINGS", nil) - if err != nil { - return - } - - var warnings = MySQLWarnings{} - var values = make([]driver.Value, 3) - - for { - err = rows.Next(values) - switch err { - case nil: - warning := MySQLWarning{} - - if raw, ok := values[0].([]byte); ok { - warning.Level = string(raw) - } else { - warning.Level = fmt.Sprintf("%s", values[0]) - } - if raw, ok := values[1].([]byte); ok { - warning.Code = string(raw) - } else { - warning.Code = fmt.Sprintf("%s", values[1]) - } - if raw, ok := values[2].([]byte); ok { - warning.Message = string(raw) - } else { - warning.Message = fmt.Sprintf("%s", values[0]) - } - - warnings = append(warnings, warning) - - case io.EOF: - return warnings - - default: - rows.Close() - return - } - } -} diff --git a/packets.go b/packets.go index 79648d572..1887467df 100644 --- a/packets.go +++ b/packets.go @@ -624,14 +624,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { } // warning count [2 bytes] - if !mc.strict { - return nil - } - pos := 1 + n + m + 2 - if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { - return mc.getWarnings() - } return nil } @@ -843,14 +836,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { // Reserved [8 bit] // Warning count [16 bit uint] - if !stmt.mc.strict { - return columnCount, nil - } - // Check for warnings count > 0, only available in MySQL > 4.1 - if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { - return columnCount, stmt.mc.getWarnings() - } return columnCount, nil } return 0, err From c7c84f17fc8bb9d77acb2fe3e0dd100c67d187c3 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Sat, 30 Sep 2017 14:44:03 +0200 Subject: [PATCH 2/2] dsn: panic in case of strict mode --- dsn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dsn.go b/dsn.go index 0cee78998..6ce5cc020 100644 --- a/dsn.go +++ b/dsn.go @@ -492,7 +492,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Strict mode case "strict": - return errors.New("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") + panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") // Dial Timeout case "timeout":