diff --git a/AUTHORS b/AUTHORS index 50afa2c85..94a92d5e9 100644 --- a/AUTHORS +++ b/AUTHORS @@ -50,6 +50,7 @@ Jeffrey Charles Jerome Meyer Jiajia Zhong Jian Zhen +Joseph Boudou Joshua Prunier Julien Lefevre Julien Schmidt diff --git a/README.md b/README.md index 0b13154fc..ca638593e 100644 --- a/README.md +++ b/README.md @@ -175,6 +175,14 @@ Default: false ``` `allowOldPasswords=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords). +##### `autoReprepare` + +``` +Type: decimal number +Default: 0 +``` +When `autoReprepare` is greater than zero, the driver will re-prepare statements when error 1615 is received from the database. Some known bugs of MySQL and MariaDB spuriously invalidate prepared statements, resulting in this error being sent. This parameter is meant to workaround these bugs. More precisely, the value of `autoReprepare` indicates how many successive errors 1615 are handled before the execution of the statement fails; hence, it should not be greater than one. + ##### `charset` ``` diff --git a/connection.go b/connection.go index 835f89729..de2bf8793 100644 --- a/connection.go +++ b/connection.go @@ -177,6 +177,9 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { stmt := &mysqlStmt{ mc: mc, } + if stmt.mc.cfg.AutoReprepare > 0 { + stmt.queryStr = query + } // Read Result columnCount, err := stmt.readPrepareResultPacket() diff --git a/dsn.go b/dsn.go index 93f3548cb..4b46ae45f 100644 --- a/dsn.go +++ b/dsn.go @@ -55,6 +55,7 @@ type Config struct { AllowCleartextPasswords bool // Allows the cleartext client side plugin AllowNativePasswords bool // Allows the native password authentication method AllowOldPasswords bool // Allows the old insecure password method + AutoReprepare int // Automatically reprepare statements when receiving error 1615 CheckConnLiveness bool // Check connections for liveness before using them ClientFoundRows bool // Return number of matching rows instead of rows changed ColumnsWithAlias bool // Prepend table alias to column names @@ -212,6 +213,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "allowOldPasswords", "true") } + if cfg.AutoReprepare > 0 { + writeDSNParam(&buf, &hasParam, "autoReprepare", strconv.Itoa(cfg.AutoReprepare)) + } + if !cfg.CheckConnLiveness { writeDSNParam(&buf, &hasParam, "checkConnLiveness", "false") } @@ -407,6 +412,13 @@ func parseDSNParams(cfg *Config, params string) (err error) { return errors.New("invalid bool value: " + value) } + // Reprepare statement on error 1615 + case "autoReprepare": + cfg.AutoReprepare, err = strconv.Atoi(value) + if err != nil { + return err + } + // Check connections for Liveness before using them case "checkConnLiveness": var isBool bool diff --git a/dsn_test.go b/dsn_test.go index 89815b341..de194be8a 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -44,6 +44,9 @@ var testDSNs = []struct { }, { "user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0", &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false, CheckConnLiveness: false}, +}, { + "username:password@protocol(address)/dbname?autoReprepare=2", + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, AutoReprepare: 2, CheckConnLiveness: true}, }, { "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, diff --git a/statement.go b/statement.go index 18a3ae498..1a42dd90d 100644 --- a/statement.go +++ b/statement.go @@ -20,6 +20,8 @@ type mysqlStmt struct { mc *mysqlConn id uint32 paramCount int + queryStr string + reprepared int // How many times the statement has been reprepared since last execution. } func (stmt *mysqlStmt) Close() error { @@ -68,8 +70,16 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { // Read Result resLen, err := mc.readResultSetHeaderPacket() if err != nil { + if mysqlErr, ok := err.(*MySQLError); ok && stmt.reprepared < stmt.mc.cfg.AutoReprepare && + mysqlErr.Number == 1615 { + err = stmt.reprepare() + if err == nil { + return stmt.Exec(args) + } + } return nil, err } + stmt.reprepared = 0 if resLen > 0 { // Columns @@ -113,8 +123,16 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { // Read Result resLen, err := mc.readResultSetHeaderPacket() if err != nil { + if mysqlErr, ok := err.(*MySQLError); ok && stmt.reprepared < stmt.mc.cfg.AutoReprepare && + mysqlErr.Number == 1615 { + err = stmt.reprepare() + if err == nil { + return stmt.query(args) + } + } return nil, err } + stmt.reprepared = 0 rows := new(binaryRows) @@ -135,6 +153,40 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { return rows, err } +func (stmt *mysqlStmt) reprepare() error { + stmt.reprepared += 1 + + // Close + err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) + if err != nil { + return err + } + stmt.id = 0 + stmt.paramCount = 0 + + // Send prepare + err = stmt.mc.writeCommandPacketStr(comStmtPrepare, stmt.queryStr) + if err != nil { + return err + } + + // Read Result + columnCount, err := stmt.readPrepareResultPacket() + if err == nil { + if stmt.paramCount > 0 { + if err = stmt.mc.readUntilEOF(); err != nil { + return err + } + } + + if columnCount > 0 { + err = stmt.mc.readUntilEOF() + } + } + + return err +} + var jsonType = reflect.TypeOf(json.RawMessage{}) type converter struct{}