Skip to content

Commit

Permalink
Merge c7c84f1 into 7785c74
Browse files Browse the repository at this point in the history
  • Loading branch information
julienschmidt committed Sep 30, 2017
2 parents 7785c74 + c7c84f1 commit 20c871b
Show file tree
Hide file tree
Showing 8 changed files with 7 additions and 202 deletions.
16 changes: 2 additions & 14 deletions README.md
Expand Up @@ -294,20 +294,6 @@ supposed to happen, setting this on some MySQL providers (such as AWS Aurora)
is safer for failovers.


##### `strict`

```
Type: bool
Valid Values: true, false
Default: false
```

`strict=true` enables a driver-side strict mode in which MySQL warnings are treated as errors. This mode should not be used in production as it may lead to data corruption in certain situations.

A server-side strict mode, which is safe for production use, can be set via the [`sql_mode`](https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html) system variable.

By default MySQL also treats notes as warnings. Use [`sql_notes=false`](http://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_sql_notes) to ignore notes.

##### `timeout`

```
Expand All @@ -317,6 +303,7 @@ Default: OS default

Timeout for establishing connections, aka dial timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.


##### `tls`

```
Expand All @@ -327,6 +314,7 @@ Default: false

`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).


##### `writeTimeout`

```
Expand Down
6 changes: 1 addition & 5 deletions benchmark_test.go
Expand Up @@ -48,11 +48,7 @@ func initDB(b *testing.B, queries ...string) *sql.DB {
db := tb.checkDB(sql.Open("mysql", dsn))
for _, query := range queries {
if _, err := db.Exec(query); err != nil {
if w, ok := err.(MySQLWarnings); ok {
b.Logf("warning on %q: %v", query, w)
} else {
b.Fatalf("error on %q: %v", query, err)
}
b.Fatalf("error on %q: %v", query, err)
}
}
return db
Expand Down
1 change: 0 additions & 1 deletion connection.go
Expand Up @@ -40,7 +40,6 @@ type mysqlConn struct {
status statusFlag
sequence uint8
parseTime bool
strict bool

// for context support (Go 1.8+)
watching bool
Expand Down
1 change: 0 additions & 1 deletion driver.go
Expand Up @@ -64,7 +64,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
return nil, err
}
mc.parseTime = mc.cfg.ParseTime
mc.strict = mc.cfg.Strict

// Connect to Server
if dial, ok := dials[mc.cfg.Net]; ok {
Expand Down
82 changes: 3 additions & 79 deletions driver_test.go
Expand Up @@ -63,7 +63,7 @@ func init() {
addr = env("MYSQL_TEST_ADDR", "localhost:3306")
dbname = env("MYSQL_TEST_DBNAME", "gotest")
netAddr = fmt.Sprintf("%s(%s)", prot, addr)
dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname)
dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, dbname)
c, err := net.Dial(prot, addr)
if err == nil {
available = true
Expand Down Expand Up @@ -1170,82 +1170,6 @@ func TestFoundRows(t *testing.T) {
})
}

func TestStrict(t *testing.T) {
// ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors
relaxedDsn := dsn + "&sql_mode='ALLOW_INVALID_DATES,NO_AUTO_CREATE_USER'"
// make sure the MySQL version is recent enough with a separate connection
// before running the test
conn, err := MySQLDriver{}.Open(relaxedDsn)
if conn != nil {
conn.Close()
}
// Error 1231: Variable 'sql_mode' can't be set to the value of
// 'ALLOW_INVALID_DATES' => skip test, MySQL server version is too old
maybeSkip(t, err, 1231)
runTests(t, relaxedDsn, func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))")

var queries = [...]struct {
in string
codes []string
}{
{"DROP TABLE IF EXISTS no_such_table", []string{"1051"}},
{"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')", []string{"1265", "1048", "1264", "1265"}},
}
var err error

var checkWarnings = func(err error, mode string, idx int) {
if err == nil {
dbt.Errorf("expected STRICT error on query [%s] %s", mode, queries[idx].in)
}

if warnings, ok := err.(MySQLWarnings); ok {
var codes = make([]string, len(warnings))
for i := range warnings {
codes[i] = warnings[i].Code
}
if len(codes) != len(queries[idx].codes) {
dbt.Errorf("unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
}

for i := range warnings {
if codes[i] != queries[idx].codes[i] {
dbt.Errorf("unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
return
}
}

} else {
dbt.Errorf("unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error())
}
}

// text protocol
for i := range queries {
_, err = dbt.db.Exec(queries[i].in)
checkWarnings(err, "text", i)
}

var stmt *sql.Stmt

// binary protocol
for i := range queries {
stmt, err = dbt.db.Prepare(queries[i].in)
if err != nil {
dbt.Errorf("error on preparing query %s: %s", queries[i].in, err.Error())
}

_, err = stmt.Exec()
checkWarnings(err, "binary", i)

err = stmt.Close()
if err != nil {
dbt.Errorf("error on closing stmt for query %s: %s", queries[i].in, err.Error())
}
}
})
}

func TestTLS(t *testing.T) {
tlsTest := func(dbt *DBTest) {
if err := dbt.db.Ping(); err != nil {
Expand Down Expand Up @@ -1762,7 +1686,7 @@ func TestCustomDial(t *testing.T) {
return net.Dial(prot, addr)
})

db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s&strict=true", user, pass, addr, dbname))
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand Down Expand Up @@ -1859,7 +1783,7 @@ func TestUnixSocketAuthFail(t *testing.T) {
}
}
t.Logf("socket: %s", socket)
badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s&strict=true", user, badPass, socket, dbname)
badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname)
db, err := sql.Open("mysql", badDSN)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
Expand Down
16 changes: 1 addition & 15 deletions dsn.go
Expand Up @@ -55,7 +55,6 @@ type Config struct {
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections
Strict bool // Return warnings as errors
}

// FormatDSN formats the given Config into a DSN string which can be passed to
Expand Down Expand Up @@ -206,15 +205,6 @@ func (cfg *Config) FormatDSN() string {
}
}

if cfg.Strict {
if hasParam {
buf.WriteString("&strict=true")
} else {
hasParam = true
buf.WriteString("?strict=true")
}
}

if cfg.Timeout > 0 {
if hasParam {
buf.WriteString("&timeout=")
Expand Down Expand Up @@ -502,11 +492,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {

// Strict mode
case "strict":
var isBool bool
cfg.Strict, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")

// Dial Timeout
case "timeout":
Expand Down
73 changes: 0 additions & 73 deletions errors.go
Expand Up @@ -9,10 +9,8 @@
package mysql

import (
"database/sql/driver"
"errors"
"fmt"
"io"
"log"
"os"
)
Expand Down Expand Up @@ -65,74 +63,3 @@ type MySQLError struct {
func (me *MySQLError) Error() string {
return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
}

// MySQLWarnings is an error type which represents a group of one or more MySQL
// warnings
type MySQLWarnings []MySQLWarning

func (mws MySQLWarnings) Error() string {
var msg string
for i, warning := range mws {
if i > 0 {
msg += "\r\n"
}
msg += fmt.Sprintf(
"%s %s: %s",
warning.Level,
warning.Code,
warning.Message,
)
}
return msg
}

// MySQLWarning is an error type which represents a single MySQL warning.
// Warnings are returned in groups only. See MySQLWarnings
type MySQLWarning struct {
Level string
Code string
Message string
}

func (mc *mysqlConn) getWarnings() (err error) {
rows, err := mc.Query("SHOW WARNINGS", nil)
if err != nil {
return
}

var warnings = MySQLWarnings{}
var values = make([]driver.Value, 3)

for {
err = rows.Next(values)
switch err {
case nil:
warning := MySQLWarning{}

if raw, ok := values[0].([]byte); ok {
warning.Level = string(raw)
} else {
warning.Level = fmt.Sprintf("%s", values[0])
}
if raw, ok := values[1].([]byte); ok {
warning.Code = string(raw)
} else {
warning.Code = fmt.Sprintf("%s", values[1])
}
if raw, ok := values[2].([]byte); ok {
warning.Message = string(raw)
} else {
warning.Message = fmt.Sprintf("%s", values[0])
}

warnings = append(warnings, warning)

case io.EOF:
return warnings

default:
rows.Close()
return
}
}
}
14 changes: 0 additions & 14 deletions packets.go
Expand Up @@ -624,14 +624,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
}

// warning count [2 bytes]
if !mc.strict {
return nil
}

pos := 1 + n + m + 2
if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
return mc.getWarnings()
}
return nil
}

Expand Down Expand Up @@ -843,14 +836,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
// Reserved [8 bit]

// Warning count [16 bit uint]
if !stmt.mc.strict {
return columnCount, nil
}

// Check for warnings count > 0, only available in MySQL > 4.1
if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 {
return columnCount, stmt.mc.getWarnings()
}
return columnCount, nil
}
return 0, err
Expand Down

0 comments on commit 20c871b

Please sign in to comment.