diff --git a/client/auth.go b/client/auth.go index 2d3c15e27..f41f74f6a 100644 --- a/client/auth.go +++ b/client/auth.go @@ -14,6 +14,19 @@ import ( const defaultAuthPluginName = mysql.AUTH_NATIVE_PASSWORD +var optionalCapabilities = []uint32{ + mysql.CLIENT_FOUND_ROWS, + mysql.CLIENT_IGNORE_SPACE, + mysql.CLIENT_MULTI_STATEMENTS, + mysql.CLIENT_MULTI_RESULTS, + mysql.CLIENT_PS_MULTI_RESULTS, + mysql.CLIENT_CONNECT_ATTRS, + mysql.CLIENT_COMPRESS, + mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM, + mysql.CLIENT_LOCAL_FILES, + mysql.CLIENT_SESSION_TRACK, +} + // defines the supported auth plugins var supportedAuthPlugins = []string{mysql.AUTH_NATIVE_PASSWORD, mysql.AUTH_SHA256_PASSWORD, mysql.AUTH_CACHING_SHA2_PASSWORD, mysql.AUTH_MARIADB_ED25519} @@ -214,11 +227,9 @@ func (c *Conn) writeAuthHandshake() error { // Adjust client capability flags on specific client requests // Only flags that would make any sense setting and aren't handled elsewhere // in the library are supported here - capability |= c.ccaps&mysql.CLIENT_FOUND_ROWS | c.ccaps&mysql.CLIENT_IGNORE_SPACE | - c.ccaps&mysql.CLIENT_MULTI_STATEMENTS | c.ccaps&mysql.CLIENT_MULTI_RESULTS | - c.ccaps&mysql.CLIENT_PS_MULTI_RESULTS | c.ccaps&mysql.CLIENT_CONNECT_ATTRS | - c.ccaps&mysql.CLIENT_COMPRESS | c.ccaps&mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM | - c.ccaps&mysql.CLIENT_LOCAL_FILES | c.ccaps&mysql.CLIENT_SESSION_TRACK + for _, optionalCap := range optionalCapabilities { + capability |= c.ccaps & optionalCap + } capability &^= c.clientExplicitOffCaps diff --git a/client/client_test.go b/client/client_test.go index 05c8089df..5b4e3b3c2 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -92,8 +92,7 @@ func (s *clientTestSuite) TestConn_Ping() { func (s *clientTestSuite) TestConn_Compress() { addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error { - conn.SetCapability(mysql.CLIENT_COMPRESS) - return nil + return conn.SetCapability(mysql.CLIENT_COMPRESS) }) require.NoError(s.T(), err) @@ -115,38 +114,28 @@ func (s *clientTestSuite) TestConn_NoDeprecateEOF() { func (s *clientTestSuite) TestConn_SetCapability() { caps := []uint32{ - mysql.CLIENT_LONG_PASSWORD, mysql.CLIENT_FOUND_ROWS, - mysql.CLIENT_LONG_FLAG, - mysql.CLIENT_CONNECT_WITH_DB, - mysql.CLIENT_NO_SCHEMA, - mysql.CLIENT_COMPRESS, - mysql.CLIENT_ODBC, - mysql.CLIENT_LOCAL_FILES, mysql.CLIENT_IGNORE_SPACE, - mysql.CLIENT_PROTOCOL_41, - mysql.CLIENT_INTERACTIVE, - mysql.CLIENT_SSL, - mysql.CLIENT_IGNORE_SIGPIPE, - mysql.CLIENT_TRANSACTIONS, - mysql.CLIENT_RESERVED, - mysql.CLIENT_SECURE_CONNECTION, mysql.CLIENT_MULTI_STATEMENTS, mysql.CLIENT_MULTI_RESULTS, mysql.CLIENT_PS_MULTI_RESULTS, - mysql.CLIENT_PLUGIN_AUTH, mysql.CLIENT_CONNECT_ATTRS, - mysql.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA, - mysql.CLIENT_DEPRECATE_EOF, + mysql.CLIENT_COMPRESS, + mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM, + mysql.CLIENT_LOCAL_FILES, } for _, capI := range caps { require.False(s.T(), s.c.ccaps&capI > 0) - s.c.SetCapability(capI) + err := s.c.SetCapability(capI) + require.NoError(s.T(), err, "capability: %d", capI) require.True(s.T(), s.c.ccaps&capI > 0) s.c.UnsetCapability(capI) require.False(s.T(), s.c.ccaps&capI > 0) } + + err := s.c.SetCapability(mysql.CLIENT_REMEMBER_OPTIONS + 10) + require.Error(s.T(), err) } // NOTE for MySQL 5.5 and 5.6, server side has to config SSL to pass the TLS test, otherwise, it will throw error that diff --git a/client/conn.go b/client/conn.go index f32fb9dda..43da7d63e 100644 --- a/client/conn.go +++ b/client/conn.go @@ -9,6 +9,7 @@ import ( "net" "runtime" "runtime/debug" + "slices" "strings" "time" @@ -238,9 +239,13 @@ func (c *Conn) Ping() error { } // SetCapability marks the specified flag as explicitly enabled by the client. -func (c *Conn) SetCapability(cap uint32) { +func (c *Conn) SetCapability(cap uint32) error { + if !slices.Contains(optionalCapabilities, cap) { + return errors.New("unsupported or unknown capability") + } c.ccaps |= cap c.clientExplicitOffCaps &^= cap + return nil } // UnsetCapability marks the specified flag as explicitly disabled by the client. diff --git a/client/conn_test.go b/client/conn_test.go index 472558800..2f7bed9a3 100644 --- a/client/conn_test.go +++ b/client/conn_test.go @@ -30,9 +30,9 @@ func (s *connTestSuite) SetupSuite() { addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) s.c, err = Connect(addr, *testUser, *testPassword, "", func(c *Conn) error { // required for the ExecuteMultiple test - c.SetCapability(mysql.CLIENT_MULTI_STATEMENTS) + err = c.SetCapability(mysql.CLIENT_MULTI_STATEMENTS) c.SetAttributes(map[string]string{"attrtest": "attrvalue"}) - return nil + return err }) require.NoError(s.T(), err) diff --git a/driver/driver_options.go b/driver/driver_options.go index 1a3ccc9b6..a9dc6fb3e 100644 --- a/driver/driver_options.go +++ b/driver/driver_options.go @@ -38,9 +38,9 @@ func WriteTimeoutOption(c *client.Conn, value string) error { func CompressOption(c *client.Conn, value string) error { switch value { case "zlib": - c.SetCapability(mysql.CLIENT_COMPRESS) + _ = c.SetCapability(mysql.CLIENT_COMPRESS) case "zstd": - c.SetCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM) + _ = c.SetCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM) case "uncompressed": c.UnsetCapability(mysql.CLIENT_COMPRESS) c.UnsetCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM)