Skip to content

Commit d7426e2

Browse files
committed
client: Have SetCapability() return an error for unsupported caps
1 parent 0d3c2e3 commit d7426e2

File tree

5 files changed

+35
-30
lines changed

5 files changed

+35
-30
lines changed

client/auth.go

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,19 @@ import (
1414

1515
const defaultAuthPluginName = mysql.AUTH_NATIVE_PASSWORD
1616

17+
var optionalCapabilities = []uint32{
18+
mysql.CLIENT_FOUND_ROWS,
19+
mysql.CLIENT_IGNORE_SPACE,
20+
mysql.CLIENT_MULTI_STATEMENTS,
21+
mysql.CLIENT_MULTI_RESULTS,
22+
mysql.CLIENT_PS_MULTI_RESULTS,
23+
mysql.CLIENT_CONNECT_ATTRS,
24+
mysql.CLIENT_COMPRESS,
25+
mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM,
26+
mysql.CLIENT_LOCAL_FILES,
27+
mysql.CLIENT_SESSION_TRACK,
28+
}
29+
1730
// defines the supported auth plugins
1831
var supportedAuthPlugins = []string{mysql.AUTH_NATIVE_PASSWORD, mysql.AUTH_SHA256_PASSWORD, mysql.AUTH_CACHING_SHA2_PASSWORD, mysql.AUTH_MARIADB_ED25519}
1932

@@ -214,11 +227,9 @@ func (c *Conn) writeAuthHandshake() error {
214227
// Adjust client capability flags on specific client requests
215228
// Only flags that would make any sense setting and aren't handled elsewhere
216229
// in the library are supported here
217-
capability |= c.ccaps&mysql.CLIENT_FOUND_ROWS | c.ccaps&mysql.CLIENT_IGNORE_SPACE |
218-
c.ccaps&mysql.CLIENT_MULTI_STATEMENTS | c.ccaps&mysql.CLIENT_MULTI_RESULTS |
219-
c.ccaps&mysql.CLIENT_PS_MULTI_RESULTS | c.ccaps&mysql.CLIENT_CONNECT_ATTRS |
220-
c.ccaps&mysql.CLIENT_COMPRESS | c.ccaps&mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM |
221-
c.ccaps&mysql.CLIENT_LOCAL_FILES | c.ccaps&mysql.CLIENT_SESSION_TRACK
230+
for _, optionalCap := range optionalCapabilities {
231+
capability |= c.ccaps & optionalCap
232+
}
222233

223234
capability &^= c.clientExplicitOffCaps
224235

client/client_test.go

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,7 @@ func (s *clientTestSuite) TestConn_Ping() {
9292
func (s *clientTestSuite) TestConn_Compress() {
9393
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
9494
conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error {
95-
conn.SetCapability(mysql.CLIENT_COMPRESS)
96-
return nil
95+
return conn.SetCapability(mysql.CLIENT_COMPRESS)
9796
})
9897
require.NoError(s.T(), err)
9998

@@ -115,38 +114,28 @@ func (s *clientTestSuite) TestConn_NoDeprecateEOF() {
115114

116115
func (s *clientTestSuite) TestConn_SetCapability() {
117116
caps := []uint32{
118-
mysql.CLIENT_LONG_PASSWORD,
119117
mysql.CLIENT_FOUND_ROWS,
120-
mysql.CLIENT_LONG_FLAG,
121-
mysql.CLIENT_CONNECT_WITH_DB,
122-
mysql.CLIENT_NO_SCHEMA,
123-
mysql.CLIENT_COMPRESS,
124-
mysql.CLIENT_ODBC,
125-
mysql.CLIENT_LOCAL_FILES,
126118
mysql.CLIENT_IGNORE_SPACE,
127-
mysql.CLIENT_PROTOCOL_41,
128-
mysql.CLIENT_INTERACTIVE,
129-
mysql.CLIENT_SSL,
130-
mysql.CLIENT_IGNORE_SIGPIPE,
131-
mysql.CLIENT_TRANSACTIONS,
132-
mysql.CLIENT_RESERVED,
133-
mysql.CLIENT_SECURE_CONNECTION,
134119
mysql.CLIENT_MULTI_STATEMENTS,
135120
mysql.CLIENT_MULTI_RESULTS,
136121
mysql.CLIENT_PS_MULTI_RESULTS,
137-
mysql.CLIENT_PLUGIN_AUTH,
138122
mysql.CLIENT_CONNECT_ATTRS,
139-
mysql.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA,
140-
mysql.CLIENT_DEPRECATE_EOF,
123+
mysql.CLIENT_COMPRESS,
124+
mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM,
125+
mysql.CLIENT_LOCAL_FILES,
141126
}
142127

143128
for _, capI := range caps {
144129
require.False(s.T(), s.c.ccaps&capI > 0)
145-
s.c.SetCapability(capI)
130+
err := s.c.SetCapability(capI)
131+
require.NoError(s.T(), err, "capability: %d", capI)
146132
require.True(s.T(), s.c.ccaps&capI > 0)
147133
s.c.UnsetCapability(capI)
148134
require.False(s.T(), s.c.ccaps&capI > 0)
149135
}
136+
137+
err := s.c.SetCapability(mysql.CLIENT_REMEMBER_OPTIONS + 10)
138+
require.Error(s.T(), err)
150139
}
151140

152141
// NOTE for MySQL 5.5 and 5.6, server side has to config SSL to pass the TLS test, otherwise, it will throw error that

client/conn.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net"
1010
"runtime"
1111
"runtime/debug"
12+
"slices"
1213
"strings"
1314
"time"
1415

@@ -238,9 +239,13 @@ func (c *Conn) Ping() error {
238239
}
239240

240241
// SetCapability marks the specified flag as explicitly enabled by the client.
241-
func (c *Conn) SetCapability(cap uint32) {
242+
func (c *Conn) SetCapability(cap uint32) error {
243+
if !slices.Contains(optionalCapabilities, cap) {
244+
return errors.New("unsupported or unknown capability")
245+
}
242246
c.ccaps |= cap
243247
c.clientExplicitOffCaps &^= cap
248+
return nil
244249
}
245250

246251
// UnsetCapability marks the specified flag as explicitly disabled by the client.

client/conn_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ func (s *connTestSuite) SetupSuite() {
3030
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
3131
s.c, err = Connect(addr, *testUser, *testPassword, "", func(c *Conn) error {
3232
// required for the ExecuteMultiple test
33-
c.SetCapability(mysql.CLIENT_MULTI_STATEMENTS)
33+
err = c.SetCapability(mysql.CLIENT_MULTI_STATEMENTS)
3434
c.SetAttributes(map[string]string{"attrtest": "attrvalue"})
35-
return nil
35+
return err
3636
})
3737
require.NoError(s.T(), err)
3838

driver/driver_options.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ func WriteTimeoutOption(c *client.Conn, value string) error {
3838
func CompressOption(c *client.Conn, value string) error {
3939
switch value {
4040
case "zlib":
41-
c.SetCapability(mysql.CLIENT_COMPRESS)
41+
_ = c.SetCapability(mysql.CLIENT_COMPRESS)
4242
case "zstd":
43-
c.SetCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM)
43+
_ = c.SetCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM)
4444
case "uncompressed":
4545
c.UnsetCapability(mysql.CLIENT_COMPRESS)
4646
c.UnsetCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM)

0 commit comments

Comments
 (0)