Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support tlcp connection #3

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type mysqlConn struct {
maxWriteSize int
writeTimeout time.Duration
flags clientFlag
extensionFlag clientExtensionFlag
status statusFlag
sequence uint8
parseTime bool
Expand Down
12 changes: 11 additions & 1 deletion const.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,17 @@ const (
clientCanHandleExpiredPasswords
clientSessionTrack
clientDeprecateEOF
clientTLCP
clientOptionalResultSetMetadata
clientZSTDCompressionAlgorithm
clientQueryAttributes
multiFactorAuthentication
clientCapabilityExtension // use CLIENT_CAPABILITY_EXTENSION support TLCP
)

type clientExtensionFlag uint8

const (
clientTLCP clientExtensionFlag = 1 << iota
)

const (
Expand Down
14 changes: 9 additions & 5 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"net"
"net/url"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -149,17 +150,16 @@ func (cfg *Config) normalize() error {
}

if cfg.TLCP == nil {
pool := loadTLCPCert(cfg.TLCPCaPath)
switch cfg.TLCPConfig {
case "false", "":
// don't set anything
case "true":
// todo: CA cert
pool := loadTLCPCert(cfg.TLCPCaPath)
cfg.TLCP = &tlcp.Config{InsecureSkipVerify: false, RootCAs: pool}
case "skip-verify":
cfg.TLCP = &tlcp.Config{InsecureSkipVerify: true}
cfg.TLCP = &tlcp.Config{InsecureSkipVerify: true, RootCAs: pool}
case "preferred":
cfg.TLCP = &tlcp.Config{InsecureSkipVerify: true}
cfg.TLCP = &tlcp.Config{InsecureSkipVerify: true, RootCAs: pool}
cfg.AllowFallbackToPlaintext = true
default:
return errors.New("invalid value / unknown tlcp config " + cfg.TLSConfig)
Expand Down Expand Up @@ -614,7 +614,11 @@ func parseDSNParams(cfg *Config, params string) (err error) {

// TLCP-CA-Path
case "tlcpCaPath":
cfg.TLCPCaPath = append(cfg.TLCPCaPath, strings.ToLower(value))
homeDir, err := os.UserHomeDir()
if err != nil {
return err
}
cfg.TLCPCaPath = append(cfg.TLCPCaPath, filepath.Join(homeDir, value))

// I/O write Timeout
case "writeTimeout":
Expand Down
1 change: 1 addition & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ var (
ErrInvalidConn = errors.New("invalid connection")
ErrMalformPkt = errors.New("malformed packet")
ErrNoTLS = errors.New("TLS requested but server does not support TLS")
ErrNoTLCP = errors.New("TLCP requested but server does not support TLCP")
ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
ErrNativePassword = errors.New("this user requires mysql native password authentication")
ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
Expand Down
23 changes: 21 additions & 2 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,19 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
}
}
pos += 2
// TLCP capability flags is in the higher 2 bytes
higher2Bytes := clientFlag(binary.LittleEndian.Uint16(data[pos+3 : pos+5]))
if higher2Bytes&(clientCapabilityExtension>>16) != 0 {
mc.flags |= clientCapabilityExtension
mc.extensionFlag |= clientExtensionFlag(data[39])
}
if mc.extensionFlag&clientTLCP == 0 && mc.cfg.TLCP != nil {
if mc.cfg.AllowFallbackToPlaintext {
mc.cfg.TLCP = nil
} else {
return nil, "", ErrNoTLCP
}
}

if len(data) > pos {
// character set [1 byte]
Expand Down Expand Up @@ -297,8 +310,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
// To enable TLS / SSL or TLCP
if mc.cfg.TLS != nil {
clientFlags |= clientSSL
} else if mc.cfg.TLCP != nil {
clientFlags |= clientTLCP
}
var tlcpFlag clientExtensionFlag = 0
if mc.cfg.TLCP != nil {
clientFlags |= clientCapabilityExtension
tlcpFlag |= clientTLCP
}

if mc.cfg.MultiStatements {
Expand Down Expand Up @@ -370,6 +386,9 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
data[pos] = 0
}

// TLCP flag in data[13]
data[13] = byte(tlcpFlag)

// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
if mc.cfg.TLS != nil {
Expand Down