diff --git a/connection.go b/connection.go index 768794f8e..eff978d93 100644 --- a/connection.go +++ b/connection.go @@ -152,11 +152,11 @@ func (mc *mysqlConn) cleanup() { // Makes cleanup idempotent close(mc.closech) - nc := mc.netConn - if nc == nil { + conn := mc.rawConn + if conn == nil { return } - if err := nc.Close(); err != nil { + if err := conn.Close(); err != nil { mc.log(err) } // This function can be called from multiple goroutines. diff --git a/connector.go b/connector.go index a0ee62839..b67077596 100644 --- a/connector.go +++ b/connector.go @@ -102,10 +102,10 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { nd := net.Dialer{Timeout: mc.cfg.Timeout} mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr) } - if err != nil { return nil, err } + mc.rawConn = mc.netConn // Enable TCP Keepalives on TCP connections if tc, ok := mc.netConn.(*net.TCPConn); ok { diff --git a/driver_test.go b/driver_test.go index 6b52650c2..4fd196d4b 100644 --- a/driver_test.go +++ b/driver_test.go @@ -20,6 +20,7 @@ import ( "io" "log" "math" + mrand "math/rand" "net" "net/url" "os" @@ -3577,3 +3578,35 @@ func runCallCommand(dbt *DBTest, query, name string) { } } } + +func TestIssue1567(t *testing.T) { + // enable TLS. + runTests(t, dsn+"&tls=skip-verify", func(dbt *DBTest) { + // disable connection pooling. + // data race happens when new connection is created. + dbt.db.SetMaxIdleConns(0) + + // estimate round trip time. + start := time.Now() + if err := dbt.db.PingContext(context.Background()); err != nil { + t.Fatal(err) + } + rtt := time.Since(start) + if rtt <= 0 { + // In some environments, rtt may become 0, so set it to at least 1ms. + rtt = time.Millisecond + } + + count := 1000 + if testing.Short() { + count = 10 + } + + for i := 0; i < count; i++ { + timeout := time.Duration(mrand.Int63n(int64(rtt))) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + dbt.db.PingContext(ctx) + cancel() + } + }) +} diff --git a/packets.go b/packets.go index d727f00fe..90a34728b 100644 --- a/packets.go +++ b/packets.go @@ -351,7 +351,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string if err := tlsConn.Handshake(); err != nil { return err } - mc.rawConn = mc.netConn mc.netConn = tlsConn mc.buf.nc = tlsConn }