Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@ import (

const defaultAuthPluginName = mysql.AUTH_NATIVE_PASSWORD

var optionalCapabilities = []uint32{
mysql.CLIENT_FOUND_ROWS,
mysql.CLIENT_IGNORE_SPACE,
mysql.CLIENT_MULTI_STATEMENTS,
mysql.CLIENT_MULTI_RESULTS,
mysql.CLIENT_PS_MULTI_RESULTS,
mysql.CLIENT_CONNECT_ATTRS,
mysql.CLIENT_COMPRESS,
mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM,
mysql.CLIENT_LOCAL_FILES,
mysql.CLIENT_SESSION_TRACK,
}

// defines the supported auth plugins
var supportedAuthPlugins = []string{mysql.AUTH_NATIVE_PASSWORD, mysql.AUTH_SHA256_PASSWORD, mysql.AUTH_CACHING_SHA2_PASSWORD, mysql.AUTH_MARIADB_ED25519}

Expand Down Expand Up @@ -214,11 +227,9 @@ func (c *Conn) writeAuthHandshake() error {
// Adjust client capability flags on specific client requests
// Only flags that would make any sense setting and aren't handled elsewhere
// in the library are supported here
capability |= c.ccaps&mysql.CLIENT_FOUND_ROWS | c.ccaps&mysql.CLIENT_IGNORE_SPACE |
c.ccaps&mysql.CLIENT_MULTI_STATEMENTS | c.ccaps&mysql.CLIENT_MULTI_RESULTS |
c.ccaps&mysql.CLIENT_PS_MULTI_RESULTS | c.ccaps&mysql.CLIENT_CONNECT_ATTRS |
c.ccaps&mysql.CLIENT_COMPRESS | c.ccaps&mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM |
c.ccaps&mysql.CLIENT_LOCAL_FILES | c.ccaps&mysql.CLIENT_SESSION_TRACK
for _, optionalCap := range optionalCapabilities {
capability |= c.ccaps & optionalCap
}

capability &^= c.clientExplicitOffCaps

Expand Down
29 changes: 9 additions & 20 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ func (s *clientTestSuite) TestConn_Ping() {
func (s *clientTestSuite) TestConn_Compress() {
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error {
conn.SetCapability(mysql.CLIENT_COMPRESS)
return nil
return conn.SetCapability(mysql.CLIENT_COMPRESS)
})
require.NoError(s.T(), err)

Expand All @@ -115,38 +114,28 @@ func (s *clientTestSuite) TestConn_NoDeprecateEOF() {

func (s *clientTestSuite) TestConn_SetCapability() {
caps := []uint32{
mysql.CLIENT_LONG_PASSWORD,
mysql.CLIENT_FOUND_ROWS,
mysql.CLIENT_LONG_FLAG,
mysql.CLIENT_CONNECT_WITH_DB,
mysql.CLIENT_NO_SCHEMA,
mysql.CLIENT_COMPRESS,
mysql.CLIENT_ODBC,
mysql.CLIENT_LOCAL_FILES,
mysql.CLIENT_IGNORE_SPACE,
mysql.CLIENT_PROTOCOL_41,
mysql.CLIENT_INTERACTIVE,
mysql.CLIENT_SSL,
mysql.CLIENT_IGNORE_SIGPIPE,
mysql.CLIENT_TRANSACTIONS,
mysql.CLIENT_RESERVED,
mysql.CLIENT_SECURE_CONNECTION,
mysql.CLIENT_MULTI_STATEMENTS,
mysql.CLIENT_MULTI_RESULTS,
mysql.CLIENT_PS_MULTI_RESULTS,
mysql.CLIENT_PLUGIN_AUTH,
mysql.CLIENT_CONNECT_ATTRS,
mysql.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA,
mysql.CLIENT_DEPRECATE_EOF,
mysql.CLIENT_COMPRESS,
mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM,
mysql.CLIENT_LOCAL_FILES,
}

for _, capI := range caps {
require.False(s.T(), s.c.ccaps&capI > 0)
s.c.SetCapability(capI)
err := s.c.SetCapability(capI)
require.NoError(s.T(), err, "capability: %d", capI)
require.True(s.T(), s.c.ccaps&capI > 0)
s.c.UnsetCapability(capI)
require.False(s.T(), s.c.ccaps&capI > 0)
}

err := s.c.SetCapability(mysql.CLIENT_REMEMBER_OPTIONS + 10)
require.Error(s.T(), err)
}

// NOTE for MySQL 5.5 and 5.6, server side has to config SSL to pass the TLS test, otherwise, it will throw error that
Expand Down
7 changes: 6 additions & 1 deletion client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net"
"runtime"
"runtime/debug"
"slices"
"strings"
"time"

Expand Down Expand Up @@ -238,9 +239,13 @@ func (c *Conn) Ping() error {
}

// SetCapability marks the specified flag as explicitly enabled by the client.
func (c *Conn) SetCapability(cap uint32) {
func (c *Conn) SetCapability(cap uint32) error {
if !slices.Contains(optionalCapabilities, cap) {
return errors.New("unsupported or unknown capability")
}
c.ccaps |= cap
c.clientExplicitOffCaps &^= cap
return nil
}

// UnsetCapability marks the specified flag as explicitly disabled by the client.
Expand Down
4 changes: 2 additions & 2 deletions client/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ func (s *connTestSuite) SetupSuite() {
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
s.c, err = Connect(addr, *testUser, *testPassword, "", func(c *Conn) error {
// required for the ExecuteMultiple test
c.SetCapability(mysql.CLIENT_MULTI_STATEMENTS)
err = c.SetCapability(mysql.CLIENT_MULTI_STATEMENTS)
c.SetAttributes(map[string]string{"attrtest": "attrvalue"})
return nil
return err
})
require.NoError(s.T(), err)

Expand Down
4 changes: 2 additions & 2 deletions driver/driver_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ func WriteTimeoutOption(c *client.Conn, value string) error {
func CompressOption(c *client.Conn, value string) error {
switch value {
case "zlib":
c.SetCapability(mysql.CLIENT_COMPRESS)
_ = c.SetCapability(mysql.CLIENT_COMPRESS)
case "zstd":
c.SetCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM)
_ = c.SetCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM)
case "uncompressed":
c.UnsetCapability(mysql.CLIENT_COMPRESS)
c.UnsetCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM)
Expand Down
Loading