diff --git a/client/conn.go b/client/conn.go index dd81e5ebd..d864fb970 100644 --- a/client/conn.go +++ b/client/conn.go @@ -268,10 +268,6 @@ func (c *Conn) SetTLSConfig(config *tls.Config) { } func (c *Conn) UseDB(dbName string) error { - if c.db == dbName { - return nil - } - if err := c.writeCommandStr(mysql.COM_INIT_DB, dbName); err != nil { return errors.Trace(err) } diff --git a/client/conn_test.go b/client/conn_test.go index c82010757..472558800 100644 --- a/client/conn_test.go +++ b/client/conn_test.go @@ -210,3 +210,30 @@ func (s *connTestSuite) TestSetQueryAttributes() { } require.Equal(s.T(), expected, s.c.queryAttributes) } + +func (s *connTestSuite) TestUseDB() { + _, err := s.c.Execute("create database if not exists proxier;") + require.NoError(s.T(), err) + err = s.c.UseDB("proxier") + require.NoError(s.T(), err) + result, err := s.c.Execute("select database();") + require.NoError(s.T(), err) + value, err := result.GetString(0, 0) + require.NoError(s.T(), err) + require.Equal(s.T(), "proxier", value) + _, err = s.c.Execute("drop database proxier;") + require.NoError(s.T(), err) + _, err = s.c.Execute("create database proxier;") + require.NoError(s.T(), err) + err = s.c.UseDB("proxier") + require.NoError(s.T(), err) + result, err = s.c.Execute("select database();") + require.NoError(s.T(), err) + value, err = result.GetString(0, 0) + require.NoError(s.T(), err) + require.Equal(s.T(), "proxier", value) + _, err = s.c.Execute("drop database proxier;") + require.NoError(s.T(), err) + err = s.c.UseDB("test") + require.NoError(s.T(), err) +}