Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mysql8 auth compatibility #781

Merged
merged 11 commits into from
Apr 12, 2023
72 changes: 49 additions & 23 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ func authPluginAllowed(pluginName string) bool {
return false
}

// See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
// See:
// - https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
// - https://github.com/alibaba/canal/blob/0ec46991499a22870dde4ae736b2586cbcbfea94/driver/src/main/java/com/alibaba/otter/canal/parse/driver/mysql/packets/server/HandshakeInitializationPacket.java#L89
// - https://github.com/vapor/mysql-nio/blob/main/Sources/MySQLNIO/Protocol/MySQLProtocol%2BHandshakeV10.swift
// - https://github.com/github/vitess-gh/blob/70ae1a2b3a116ff6411b0f40852d6e71382f6e07/go/mysql/client.go
func (c *Conn) readInitialHandshake() error {
data, err := c.ReadPacket()
if err != nil {
Expand All @@ -40,24 +44,28 @@ func (c *Conn) readInitialHandshake() error {
if data[0] < MinProtocolVersion {
return errors.Errorf("invalid protocol version %d, must >= 10", data[0])
}
pos := 1

// skip mysql version
// mysql version end with 0x00
version := data[1 : bytes.IndexByte(data[1:], 0x00)+1]
version := data[pos : bytes.IndexByte(data[pos:], 0x00)+1]
c.serverVersion = string(version)
pos := 1 + len(version)
pos += len(version) + 1 /*trailing zero byte*/

// connection id length is 4
c.connectionID = binary.LittleEndian.Uint32(data[pos : pos+4])
pos += 4

c.salt = []byte{}
c.salt = append(c.salt, data[pos:pos+8]...)
// first 8 bytes of the plugin provided data (scramble)
c.salt = append(c.salt[:0], data[pos:pos+8]...)
pos += 8

// skip filter
pos += 8 + 1
if data[pos] != 0 { // 0x00 byte, terminating the first part of a scramble
return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos]))
}
pos++

// capability lower 2 bytes
// The lower 2 bytes of the Capabilities Flags
c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
// check protocol
if c.capability&CLIENT_PROTOCOL_41 == 0 {
Expand All @@ -69,35 +77,53 @@ func (c *Conn) readInitialHandshake() error {
pos += 2

if len(data) > pos {
// skip server charset
// default server a_protocol_character_set, only the lower 8-bits
// c.charset = data[pos]
pos += 1

c.status = binary.LittleEndian.Uint16(data[pos : pos+2])
pos += 2
// capability flags (upper 2 bytes)

// The upper 2 bytes of the Capabilities Flags
c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
pos += 2

// auth_data is end with 0x00, min data length is 13 + 8 = 21
// ref to https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
maxAuthDataLen := 21
if c.capability&CLIENT_PLUGIN_AUTH != 0 && int(data[pos]) > maxAuthDataLen {
maxAuthDataLen = int(data[pos])
// length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0
authPluginDataLen := data[pos]
if (c.capability&CLIENT_PLUGIN_AUTH == 0) && (authPluginDataLen > 0) {
return errors.Errorf("invalid auth plugin data filler %d", authPluginDataLen)
}
pos++

// skip reserved (all [00])
pos += 10 + 1
// skip reserved (all [00] ?)
pos += 10

// auth_data is end with 0x00, so we need to trim 0x00
resetOfAuthDataEndPos := pos + maxAuthDataLen - 8 - 1
c.salt = append(c.salt, data[pos:resetOfAuthDataEndPos]...)
if c.capability&CLIENT_SECURE_CONNECTION != 0 {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just remind that this is a deprecated flag https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html#ga8be684cc38eeca913698414efec06933 , we can skip this if-check or leave it because no test is broken.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it seems always set.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about merging this PR as is? I added an issue.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

// Rest of the plugin provided data (scramble)

// skip reset of end pos
pos = resetOfAuthDataEndPos + 1
// $len=MAX(13, length of auth-plugin-data - 8)
rest := int(authPluginDataLen) - 8
if max := 13; rest > max {
atercattus marked this conversation as resolved.
Show resolved Hide resolved
rest = max
}
if data[pos+rest-1] != 0 {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if data[pos+rest-1] != 0 {
if data[pos+rest] != 0 {

BTW, im not familiar with this part so double check that why there's NULL after scramble? In the MySQL protocol doc I think it's a fixed length string, rather than a NULL-terminated string.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found the origin of it:

n - rest of the plugin provided data (at least 12 bytes)
1 - \0 byte, terminating the second part of a scramble

And sha256_password auth implementation follows this rule:

Native authentication sent 20 bytes + '\0' character = 21 bytes.
This plugin must do the same to stay consistent with historical behavior

Also, alibaba channel contains the same description:

Packet规定最后13个byte是剩下的scrumble, 但实际上最后一个字节是0, 不应该包含在scrumble中.

I used a translator for this.

I can remove this check by \0, but it looks like a historic standard convention.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's about 20 bytes len I found this:

the first packet must have at least 20 bytes of a scramble.
if a plugin provided less, we pad it to 20 with zeros

Right now I don't see places where it would be more than 20 bytes, but let it be for the future.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked this logic for mysql 8.0 and 5.7. Unfortunately, compatibility with 5.6- was broken since we added json fields in our tests in 2021.

return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos]))
}

authPluginDataPart2 := data[pos : pos+rest-1]
pos += rest

c.salt = append(c.salt, authPluginDataPart2...)
}

if c.capability&CLIENT_PLUGIN_AUTH != 0 {
c.authPluginName = string(data[pos : len(data)-1])
c.authPluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)])
pos += len(c.authPluginName)

if data[pos] != 0 {
return errors.Errorf("expect 0x00 after authPluginName, got %q", rune(data[pos]))
}
// pos++ // ineffectual
}
}

Expand Down
132 changes: 129 additions & 3 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,18 @@ func (c *Conn) handshake() error {
var err error
if err = c.readInitialHandshake(); err != nil {
c.Close()
return errors.Trace(err)
return errors.Trace(fmt.Errorf("readInitialHandshake: %w", err))
}

if err := c.writeAuthHandshake(); err != nil {
c.Close()

return errors.Trace(err)
return errors.Trace(fmt.Errorf("writeAuthHandshake: %w", err))
}

if err := c.handleAuthResult(); err != nil {
c.Close()
return errors.Trace(err)
return errors.Trace(fmt.Errorf("handleAuthResult: %w", err))
}

return nil
Expand Down Expand Up @@ -408,3 +408,129 @@ func (c *Conn) exec(query string) (*Result, error) {

return c.readResult(false)
}

func (c *Conn) CapabilityString() string {
var caps []string
capability := c.capability
for i := 0; capability != 0; i++ {
field := uint32(1 << i)
if capability&field == 0 {
lance6716 marked this conversation as resolved.
Show resolved Hide resolved
continue
}
capability ^= field

switch field {
case CLIENT_LONG_PASSWORD:
caps = append(caps, "CLIENT_LONG_PASSWORD")
case CLIENT_FOUND_ROWS:
caps = append(caps, "CLIENT_FOUND_ROWS")
case CLIENT_LONG_FLAG:
caps = append(caps, "CLIENT_LONG_FLAG")
case CLIENT_CONNECT_WITH_DB:
caps = append(caps, "CLIENT_CONNECT_WITH_DB")
case CLIENT_NO_SCHEMA:
caps = append(caps, "CLIENT_NO_SCHEMA")
case CLIENT_COMPRESS:
caps = append(caps, "CLIENT_COMPRESS")
case CLIENT_ODBC:
caps = append(caps, "CLIENT_ODBC")
case CLIENT_LOCAL_FILES:
caps = append(caps, "CLIENT_LOCAL_FILES")
case CLIENT_IGNORE_SPACE:
caps = append(caps, "CLIENT_IGNORE_SPACE")
case CLIENT_PROTOCOL_41:
caps = append(caps, "CLIENT_PROTOCOL_41")
case CLIENT_INTERACTIVE:
caps = append(caps, "CLIENT_INTERACTIVE")
case CLIENT_SSL:
caps = append(caps, "CLIENT_SSL")
case CLIENT_IGNORE_SIGPIPE:
caps = append(caps, "CLIENT_IGNORE_SIGPIPE")
case CLIENT_TRANSACTIONS:
caps = append(caps, "CLIENT_TRANSACTIONS")
case CLIENT_RESERVED:
caps = append(caps, "CLIENT_RESERVED")
case CLIENT_SECURE_CONNECTION:
caps = append(caps, "CLIENT_SECURE_CONNECTION")
case CLIENT_MULTI_STATEMENTS:
caps = append(caps, "CLIENT_MULTI_STATEMENTS")
case CLIENT_MULTI_RESULTS:
caps = append(caps, "CLIENT_MULTI_RESULTS")
case CLIENT_PS_MULTI_RESULTS:
caps = append(caps, "CLIENT_PS_MULTI_RESULTS")
case CLIENT_PLUGIN_AUTH:
caps = append(caps, "CLIENT_PLUGIN_AUTH")
case CLIENT_CONNECT_ATTRS:
caps = append(caps, "CLIENT_CONNECT_ATTRS")
case CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA:
caps = append(caps, "CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA")
case CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS:
caps = append(caps, "CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS")
case CLIENT_SESSION_TRACK:
caps = append(caps, "CLIENT_SESSION_TRACK")
case CLIENT_DEPRECATE_EOF:
caps = append(caps, "CLIENT_DEPRECATE_EOF")
case CLIENT_OPTIONAL_RESULTSET_METADATA:
caps = append(caps, "CLIENT_OPTIONAL_RESULTSET_METADATA")
case CLIENT_ZSTD_COMPRESSION_ALGORITHM:
caps = append(caps, "CLIENT_ZSTD_COMPRESSION_ALGORITHM")
case CLIENT_QUERY_ATTRIBUTES:
caps = append(caps, "CLIENT_QUERY_ATTRIBUTES")
case MULTI_FACTOR_AUTHENTICATION:
caps = append(caps, "MULTI_FACTOR_AUTHENTICATION")
case CLIENT_CAPABILITY_EXTENSION:
caps = append(caps, "CLIENT_CAPABILITY_EXTENSION")
case CLIENT_SSL_VERIFY_SERVER_CERT:
caps = append(caps, "CLIENT_SSL_VERIFY_SERVER_CERT")
case CLIENT_REMEMBER_OPTIONS:
caps = append(caps, "CLIENT_REMEMBER_OPTIONS")
default:
caps = append(caps, fmt.Sprintf("(%d)", field))
}
}

return strings.Join(caps, "|")
}

func (c *Conn) StatusString() string {
var stats []string
status := c.status
for i := 0; status != 0; i++ {
field := uint16(1 << i)
if status&field == 0 {
continue
}
status ^= field

switch field {
case SERVER_STATUS_IN_TRANS:
stats = append(stats, "SERVER_STATUS_IN_TRANS")
case SERVER_STATUS_AUTOCOMMIT:
stats = append(stats, "SERVER_STATUS_AUTOCOMMIT")
case SERVER_MORE_RESULTS_EXISTS:
stats = append(stats, "SERVER_MORE_RESULTS_EXISTS")
case SERVER_STATUS_NO_GOOD_INDEX_USED:
stats = append(stats, "SERVER_STATUS_NO_GOOD_INDEX_USED")
case SERVER_STATUS_NO_INDEX_USED:
stats = append(stats, "SERVER_STATUS_NO_INDEX_USED")
case SERVER_STATUS_CURSOR_EXISTS:
stats = append(stats, "SERVER_STATUS_CURSOR_EXISTS")
case SERVER_STATUS_LAST_ROW_SEND:
stats = append(stats, "SERVER_STATUS_LAST_ROW_SEND")
case SERVER_STATUS_DB_DROPPED:
stats = append(stats, "SERVER_STATUS_DB_DROPPED")
case SERVER_STATUS_NO_BACKSLASH_ESCAPED:
stats = append(stats, "SERVER_STATUS_NO_BACKSLASH_ESCAPED")
case SERVER_STATUS_METADATA_CHANGED:
stats = append(stats, "SERVER_STATUS_METADATA_CHANGED")
case SERVER_QUERY_WAS_SLOW:
stats = append(stats, "SERVER_QUERY_WAS_SLOW")
case SERVER_PS_OUT_PARAMS:
stats = append(stats, "SERVER_PS_OUT_PARAMS")
default:
stats = append(stats, fmt.Sprintf("(%d)", field))
}
}

return strings.Join(stats, "|")
}
14 changes: 13 additions & 1 deletion mysql/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ const (
)

const (
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html

CLIENT_LONG_PASSWORD uint32 = 1 << iota
CLIENT_FOUND_ROWS
CLIENT_LONG_FLAG
Expand All @@ -98,6 +100,16 @@ const (
CLIENT_PLUGIN_AUTH
CLIENT_CONNECT_ATTRS
CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS
CLIENT_SESSION_TRACK
CLIENT_DEPRECATE_EOF
CLIENT_OPTIONAL_RESULTSET_METADATA
CLIENT_ZSTD_COMPRESSION_ALGORITHM
CLIENT_QUERY_ATTRIBUTES
MULTI_FACTOR_AUTHENTICATION
CLIENT_CAPABILITY_EXTENSION
CLIENT_SSL_VERIFY_SERVER_CERT
CLIENT_REMEMBER_OPTIONS
)

const (
Expand All @@ -119,7 +131,7 @@ const (
MYSQL_TYPE_VARCHAR
MYSQL_TYPE_BIT

//mysql 5.6
// mysql 5.6
MYSQL_TYPE_TIMESTAMP2
MYSQL_TYPE_DATETIME2
MYSQL_TYPE_TIME2
Expand Down