Skip to content

Commit

Permalink
add all known capability flags and fix readInitialHandshake (mysql8 c…
Browse files Browse the repository at this point in the history
…ompatibility)
  • Loading branch information
atercattus committed Apr 5, 2023
1 parent e320144 commit 136935a
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 27 deletions.
64 changes: 42 additions & 22 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func authPluginAllowed(pluginName string) bool {
return false
}

// See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
// See: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
func (c *Conn) readInitialHandshake() error {
data, err := c.ReadPacket()
if err != nil {
Expand All @@ -40,24 +40,28 @@ func (c *Conn) readInitialHandshake() error {
if data[0] < MinProtocolVersion {
return errors.Errorf("invalid protocol version %d, must >= 10", data[0])
}
pos := 1

// skip mysql version
// mysql version end with 0x00
version := data[1 : bytes.IndexByte(data[1:], 0x00)+1]
version := data[pos : bytes.IndexByte(data[pos:], 0x00)+1]
c.serverVersion = string(version)
pos := 1 + len(version)
pos += len(version) + 1 /*trailing zero byte*/

// connection id length is 4
c.connectionID = binary.LittleEndian.Uint32(data[pos : pos+4])
pos += 4

c.salt = []byte{}
c.salt = append(c.salt, data[pos:pos+8]...)
// first 8 bytes of the plugin provided data (scramble)
c.salt = append(c.salt[:0], data[pos:pos+8]...)
pos += 8

// skip filter
pos += 8 + 1
if data[pos] != 0 { // 0x00 byte, terminating the first part of a scramble
return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos]))
}
pos++

// capability lower 2 bytes
// The lower 2 bytes of the Capabilities Flags
c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
// check protocol
if c.capability&CLIENT_PROTOCOL_41 == 0 {
Expand All @@ -69,35 +73,51 @@ func (c *Conn) readInitialHandshake() error {
pos += 2

if len(data) > pos {
// skip server charset
// default server a_protocol_character_set, only the lower 8-bits
// c.charset = data[pos]
pos += 1

c.status = binary.LittleEndian.Uint16(data[pos : pos+2])
pos += 2
// capability flags (upper 2 bytes)

// The upper 2 bytes of the Capabilities Flags
c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
pos += 2

// auth_data is end with 0x00, min data length is 13 + 8 = 21
// ref to https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
maxAuthDataLen := 21
if c.capability&CLIENT_PLUGIN_AUTH != 0 && int(data[pos]) > maxAuthDataLen {
maxAuthDataLen = int(data[pos])
// length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0
authPluginDataLen := data[pos]
if (c.capability&CLIENT_PLUGIN_AUTH == 0) && (authPluginDataLen > 0) {
return errors.Errorf("invalid auth plugin data filler %d", authPluginDataLen)
}
pos++

// skip reserved (all [00])
pos += 10 + 1
pos += 6

// auth_data is end with 0x00, so we need to trim 0x00
resetOfAuthDataEndPos := pos + maxAuthDataLen - 8 - 1
c.salt = append(c.salt, data[pos:resetOfAuthDataEndPos]...)
// https://github.com/vapor/mysql-nio/blob/main/Sources/MySQLNIO/Protocol/MySQLProtocol%2BHandshakeV10.swift
if c.capability&CLIENT_LONG_PASSWORD != 0 {
// skip reserved (all [00])
pos += 4
} else {
// unknown
pos += 4
}

if rest := int(authPluginDataLen) - 8; rest > 0 {
authPluginDataPart2 := data[pos : pos+rest]
pos += rest

// skip reset of end pos
pos = resetOfAuthDataEndPos + 1
c.salt = append(c.salt, authPluginDataPart2...)
}

if c.capability&CLIENT_PLUGIN_AUTH != 0 {
c.authPluginName = string(data[pos : len(data)-1])
c.authPluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)])
pos += len(c.authPluginName)

if data[pos] != 0 {
return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos]))
}
// pos++ // ineffectual
}
}

Expand Down
139 changes: 135 additions & 4 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import (
"strings"
"time"

"github.com/pingcap/errors"

. "github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/packet"
"github.com/go-mysql-org/go-mysql/utils"
"github.com/pingcap/errors"
)

type Conn struct {
Expand Down Expand Up @@ -118,18 +119,18 @@ func (c *Conn) handshake() error {
var err error
if err = c.readInitialHandshake(); err != nil {
c.Close()
return errors.Trace(err)
return errors.Trace(fmt.Errorf("readInitialHandshake: %w", err))
}

if err := c.writeAuthHandshake(); err != nil {
c.Close()

return errors.Trace(err)
return errors.Trace(fmt.Errorf("writeAuthHandshake: %w", err))
}

if err := c.handleAuthResult(); err != nil {
c.Close()
return errors.Trace(err)
return errors.Trace(fmt.Errorf("handleAuthResult: %w", err))
}

return nil
Expand Down Expand Up @@ -198,6 +199,10 @@ func (c *Conn) GetServerVersion() string {
return c.serverVersion
}

func (c *Conn) CompareServerVersion(v string) (int, error) {
return CompareServerVersions(c.serverVersion, v)
}

func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) {
if len(args) == 0 {
return c.exec(command)
Expand Down Expand Up @@ -403,3 +408,129 @@ func (c *Conn) exec(query string) (*Result, error) {

return c.readResult(false)
}

func (c *Conn) CapabilityString() string {
var caps []string
capability := c.capability
for i := 0; capability != 0; i++ {
field := uint32(1 << i)
if capability&field == 0 {
continue
}
capability ^= field

switch field {
case CLIENT_LONG_PASSWORD:
caps = append(caps, "CLIENT_LONG_PASSWORD")
case CLIENT_FOUND_ROWS:
caps = append(caps, "CLIENT_FOUND_ROWS")
case CLIENT_LONG_FLAG:
caps = append(caps, "CLIENT_LONG_FLAG")
case CLIENT_CONNECT_WITH_DB:
caps = append(caps, "CLIENT_CONNECT_WITH_DB")
case CLIENT_NO_SCHEMA:
caps = append(caps, "CLIENT_NO_SCHEMA")
case CLIENT_COMPRESS:
caps = append(caps, "CLIENT_COMPRESS")
case CLIENT_ODBC:
caps = append(caps, "CLIENT_ODBC")
case CLIENT_LOCAL_FILES:
caps = append(caps, "CLIENT_LOCAL_FILES")
case CLIENT_IGNORE_SPACE:
caps = append(caps, "CLIENT_IGNORE_SPACE")
case CLIENT_PROTOCOL_41:
caps = append(caps, "CLIENT_PROTOCOL_41")
case CLIENT_INTERACTIVE:
caps = append(caps, "CLIENT_INTERACTIVE")
case CLIENT_SSL:
caps = append(caps, "CLIENT_SSL")
case CLIENT_IGNORE_SIGPIPE:
caps = append(caps, "CLIENT_IGNORE_SIGPIPE")
case CLIENT_TRANSACTIONS:
caps = append(caps, "CLIENT_TRANSACTIONS")
case CLIENT_RESERVED:
caps = append(caps, "CLIENT_RESERVED")
case CLIENT_SECURE_CONNECTION:
caps = append(caps, "CLIENT_SECURE_CONNECTION")
case CLIENT_MULTI_STATEMENTS:
caps = append(caps, "CLIENT_MULTI_STATEMENTS")
case CLIENT_MULTI_RESULTS:
caps = append(caps, "CLIENT_MULTI_RESULTS")
case CLIENT_PS_MULTI_RESULTS:
caps = append(caps, "CLIENT_PS_MULTI_RESULTS")
case CLIENT_PLUGIN_AUTH:
caps = append(caps, "CLIENT_PLUGIN_AUTH")
case CLIENT_CONNECT_ATTRS:
caps = append(caps, "CLIENT_CONNECT_ATTRS")
case CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA:
caps = append(caps, "CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA")
case CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS:
caps = append(caps, "CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS")
case CLIENT_SESSION_TRACK:
caps = append(caps, "CLIENT_SESSION_TRACK")
case CLIENT_DEPRECATE_EOF:
caps = append(caps, "CLIENT_DEPRECATE_EOF")
case CLIENT_OPTIONAL_RESULTSET_METADATA:
caps = append(caps, "CLIENT_OPTIONAL_RESULTSET_METADATA")
case CLIENT_ZSTD_COMPRESSION_ALGORITHM:
caps = append(caps, "CLIENT_ZSTD_COMPRESSION_ALGORITHM")
case CLIENT_QUERY_ATTRIBUTES:
caps = append(caps, "CLIENT_QUERY_ATTRIBUTES")
case MULTI_FACTOR_AUTHENTICATION:
caps = append(caps, "MULTI_FACTOR_AUTHENTICATION")
case CLIENT_CAPABILITY_EXTENSION:
caps = append(caps, "CLIENT_CAPABILITY_EXTENSION")
case CLIENT_SSL_VERIFY_SERVER_CERT:
caps = append(caps, "CLIENT_SSL_VERIFY_SERVER_CERT")
case CLIENT_REMEMBER_OPTIONS:
caps = append(caps, "CLIENT_REMEMBER_OPTIONS")
default:
caps = append(caps, fmt.Sprintf("(%d)", field))
}
}

return strings.Join(caps, "|")
}

func (c *Conn) StatusString() string {
var stats []string
status := c.status
for i := 0; status != 0; i++ {
field := uint16(1 << i)
if status&field == 0 {
continue
}
status ^= field

switch field {
case SERVER_STATUS_IN_TRANS:
stats = append(stats, "SERVER_STATUS_IN_TRANS")
case SERVER_STATUS_AUTOCOMMIT:
stats = append(stats, "SERVER_STATUS_AUTOCOMMIT")
case SERVER_MORE_RESULTS_EXISTS:
stats = append(stats, "SERVER_MORE_RESULTS_EXISTS")
case SERVER_STATUS_NO_GOOD_INDEX_USED:
stats = append(stats, "SERVER_STATUS_NO_GOOD_INDEX_USED")
case SERVER_STATUS_NO_INDEX_USED:
stats = append(stats, "SERVER_STATUS_NO_INDEX_USED")
case SERVER_STATUS_CURSOR_EXISTS:
stats = append(stats, "SERVER_STATUS_CURSOR_EXISTS")
case SERVER_STATUS_LAST_ROW_SEND:
stats = append(stats, "SERVER_STATUS_LAST_ROW_SEND")
case SERVER_STATUS_DB_DROPPED:
stats = append(stats, "SERVER_STATUS_DB_DROPPED")
case SERVER_STATUS_NO_BACKSLASH_ESCAPED:
stats = append(stats, "SERVER_STATUS_NO_BACKSLASH_ESCAPED")
case SERVER_STATUS_METADATA_CHANGED:
stats = append(stats, "SERVER_STATUS_METADATA_CHANGED")
case SERVER_QUERY_WAS_SLOW:
stats = append(stats, "SERVER_QUERY_WAS_SLOW")
case SERVER_PS_OUT_PARAMS:
stats = append(stats, "SERVER_PS_OUT_PARAMS")
default:
stats = append(stats, fmt.Sprintf("(%d)", field))
}
}

return strings.Join(stats, "|")
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.16
require (
github.com/BurntSushi/toml v0.3.1
github.com/DataDog/zstd v1.5.2
github.com/Masterminds/semver v1.5.0
github.com/go-sql-driver/mysql v1.6.0
github.com/google/uuid v1.3.0
github.com/jmoiron/sqlx v1.3.3
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DataDog/zstd v1.5.2 h1:vUG4lAyuPCXO0TLbXvPv7EB7cNK1QV/luu55UHLrrn8=
github.com/DataDog/zstd v1.5.2/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw=
github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww=
github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548/go.mod h1:e6NPNENfs9mPDVNRekM7lKScauxd5kXTr1Mfyig6TDM=
Expand Down
14 changes: 13 additions & 1 deletion mysql/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ const (
)

const (
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html

CLIENT_LONG_PASSWORD uint32 = 1 << iota
CLIENT_FOUND_ROWS
CLIENT_LONG_FLAG
Expand All @@ -98,6 +100,16 @@ const (
CLIENT_PLUGIN_AUTH
CLIENT_CONNECT_ATTRS
CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS
CLIENT_SESSION_TRACK
CLIENT_DEPRECATE_EOF
CLIENT_OPTIONAL_RESULTSET_METADATA
CLIENT_ZSTD_COMPRESSION_ALGORITHM
CLIENT_QUERY_ATTRIBUTES
MULTI_FACTOR_AUTHENTICATION
CLIENT_CAPABILITY_EXTENSION
CLIENT_SSL_VERIFY_SERVER_CERT
CLIENT_REMEMBER_OPTIONS
)

const (
Expand All @@ -119,7 +131,7 @@ const (
MYSQL_TYPE_VARCHAR
MYSQL_TYPE_BIT

//mysql 5.6
// mysql 5.6
MYSQL_TYPE_TIMESTAMP2
MYSQL_TYPE_DATETIME2
MYSQL_TYPE_TIME2
Expand Down
18 changes: 18 additions & 0 deletions mysql/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"strings"
"time"

"github.com/Masterminds/semver"
"github.com/pingcap/errors"
"github.com/siddontang/go/hack"
)
Expand Down Expand Up @@ -379,6 +380,23 @@ func ErrorEqual(err1, err2 error) bool {
return e1.Error() == e2.Error()
}

func CompareServerVersions(a, b string) (int, error) {
var (
aVer, bVer *semver.Version
err error
)

if aVer, err = semver.NewVersion(a); err != nil {
return 0, fmt.Errorf("cannot parse %q as semver: %w", a, err)
}

if bVer, err = semver.NewVersion(b); err != nil {
return 0, fmt.Errorf("cannot parse %q as semver: %w", b, err)
}

return aVer.Compare(bVer), nil
}

var encodeRef = map[byte]byte{
'\x00': '0',
'\'': '\'',
Expand Down

0 comments on commit 136935a

Please sign in to comment.