Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support caching sha2 password #794

Merged
merged 4 commits into from
May 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Egor Smolyakov <egorsmkv at gmail.com>
Evan Shaw <evan at vendhq.com>
Frederick Mayle <frederickmayle at gmail.com>
Gustavo Kristic <gkristic at gmail.com>
Hajime Nakagami <nakagami at gmail.com>
Hanno Braun <mail at hannobraun.com>
Henri Yandell <flamefew at gmail.com>
Hirotaka Yamamoto <ymmt2005 at gmail.com>
Expand Down
6 changes: 6 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,9 @@ const (
statusInTransReadonly
statusSessionStateChanged
)

const (
cachingSha2PasswordRequestPublicKey = 2
cachingSha2PasswordFastAuthSuccess = 3
cachingSha2PasswordPerformFullAuthentication = 4
)
28 changes: 24 additions & 4 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,20 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
mc.writeTimeout = mc.cfg.WriteTimeout

// Reading Handshake Initialization Packet
cipher, err := mc.readInitPacket()
cipher, pluginName, err := mc.readInitPacket()
if err != nil {
mc.cleanup()
return nil, err
}

// Send Client Authentication Packet
if err = mc.writeAuthPacket(cipher); err != nil {
if err = mc.writeAuthPacket(cipher, pluginName); err != nil {
mc.cleanup()
return nil, err
}

// Handle response to auth packet, switch methods if possible
if err = handleAuthResult(mc, cipher); err != nil {
if err = handleAuthResult(mc, cipher, pluginName); err != nil {
// Authentication failed and MySQL has already closed the connection
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
// Do not send COM_QUIT, just cleanup and return the error.
Expand Down Expand Up @@ -153,7 +153,27 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
return mc, nil
}

func handleAuthResult(mc *mysqlConn, oldCipher []byte) error {
func handleAuthResult(mc *mysqlConn, oldCipher []byte, pluginName string) error {

// handle caching_sha2_password
if pluginName == "caching_sha2_password" {
auth, err := mc.readCachingSha2PasswordAuthResult()
if err != nil {
return err
}
if auth == cachingSha2PasswordPerformFullAuthentication {
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
if err = mc.writeClearAuthPacket(); err != nil {
return err
}
} else {
if err = mc.writePublicKeyAuthPacket(oldCipher); err != nil {
return err
}
}
}
}

// Read Result Packet
cipher, err := mc.readResultOK()
if err == nil {
Expand Down
4 changes: 2 additions & 2 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1842,7 +1842,7 @@ func TestSQLInjection(t *testing.T) {

dsns := []string{
dsn,
dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a required change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NO_AUTO_CREATE_USER SQL mode seems removed at MySQL 8.0
https://dev.mysql.com/doc/refman/8.0/en/mysql-nutshell.html

My linux box show that message

--- FAIL: TestSQLInjection (0.92s)
        driver_test.go:161: error on exec CREATE TABLE test (v INTEGER): Error 1231: Variable 'sql_mode' can't be set to the value of 'NO_AUTO_CREATE_USER'

dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'",
}
for _, testdsn := range dsns {
runTests(t, testdsn, createTest("1 OR 1=1"))
Expand Down Expand Up @@ -1872,7 +1872,7 @@ func TestInsertRetrieveEscapedData(t *testing.T) {

dsns := []string{
dsn,
dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'",
}
for _, testdsn := range dsns {
runTests(t, testdsn, testData)
Expand Down
84 changes: 72 additions & 12 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@ package mysql

import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/tls"
"crypto/x509"
"database/sql/driver"
"encoding/binary"
"encoding/pem"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -154,24 +159,24 @@ func (mc *mysqlConn) writePacket(data []byte) error {

// Handshake Initialization Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
func (mc *mysqlConn) readInitPacket() ([]byte, error) {
func (mc *mysqlConn) readInitPacket() ([]byte, string, error) {
data, err := mc.readPacket()
if err != nil {
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
// in connection initialization we don't risk retrying non-idempotent actions.
if err == ErrInvalidConn {
return nil, driver.ErrBadConn
return nil, "", driver.ErrBadConn
}
return nil, err
return nil, "", err
}

if data[0] == iERR {
return nil, mc.handleErrorPacket(data)
return nil, "", mc.handleErrorPacket(data)
}

// protocol version [1 byte]
if data[0] < minProtocolVersion {
return nil, fmt.Errorf(
return nil, "", fmt.Errorf(
"unsupported protocol version %d. Version %d or higher is required",
data[0],
minProtocolVersion,
Expand All @@ -191,13 +196,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
// capability flags (lower 2 bytes) [2 bytes]
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
if mc.flags&clientProtocol41 == 0 {
return nil, ErrOldProtocol
return nil, "", ErrOldProtocol
}
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
return nil, ErrNoTLS
return nil, "", ErrNoTLS
}
pos += 2

pluginName := ""
if len(data) > pos {
// character set [1 byte]
// status flags [2 bytes]
Expand All @@ -219,6 +225,8 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
// The official Python library uses the fixed length 12
// which seems to work but technically could have a hidden bug.
cipher = append(cipher, data[pos:pos+12]...)
pos += 13
pluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)])

// TODO: Verify string termination
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
Expand All @@ -232,18 +240,22 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
// make a memory safe copy of the cipher slice
var b [20]byte
copy(b[:], cipher)
return b[:], nil
return b[:], pluginName, nil
}

// make a memory safe copy of the cipher slice
var b [8]byte
copy(b[:], cipher)
return b[:], nil
return b[:], pluginName, nil
}

// Client Authentication Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
func (mc *mysqlConn) writeAuthPacket(cipher []byte, pluginName string) error {
if pluginName != "mysql_native_password" && pluginName != "caching_sha2_password" {
return fmt.Errorf("unknown authentication plugin name '%s'", pluginName)
}

// Adjust client flags based on server support
clientFlags := clientProtocol41 |
clientSecureConn |
Expand All @@ -268,7 +280,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
}

// User Password
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
var scrambleBuff []byte
switch pluginName {
case "mysql_native_password":
scrambleBuff = scramblePassword(cipher, []byte(mc.cfg.Passwd))
case "caching_sha2_password":
scrambleBuff = scrambleCachingSha2Password(cipher, []byte(mc.cfg.Passwd))
}

pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1

Expand Down Expand Up @@ -350,7 +368,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
}

// Assume native client during response
pos += copy(data[pos:], "mysql_native_password")
pos += copy(data[pos:], pluginName)
data[pos] = 0x00

// Send Auth packet
Expand Down Expand Up @@ -422,6 +440,38 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
return mc.writePacket(data)
}

// Caching sha2 authentication. Public key request and send encrypted password
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writePublicKeyAuthPacket(cipher []byte) error {
// request public key
data := mc.buf.takeSmallBuffer(4 + 1)
data[4] = cachingSha2PasswordRequestPublicKey
mc.writePacket(data)

data, err := mc.readPacket()
if err != nil {
return err
}

block, _ := pem.Decode(data[1:])
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return err
}

plain := make([]byte, len(mc.cfg.Passwd)+1)
copy(plain, mc.cfg.Passwd)
for i := range plain {
j := i % len(cipher)
plain[i] ^= cipher[j]
}
sha1 := sha1.New()
enc, _ := rsa.EncryptOAEP(sha1, rand.Reader, pub.(*rsa.PublicKey), plain, nil)
data = mc.buf.takeSmallBuffer(4 + len(enc))
copy(data[4:], enc)
return mc.writePacket(data)
}

/******************************************************************************
* Command Packets *
******************************************************************************/
Expand Down Expand Up @@ -535,6 +585,16 @@ func (mc *mysqlConn) readResultOK() ([]byte, error) {
return nil, err
}

func (mc *mysqlConn) readCachingSha2PasswordAuthResult() (int, error) {
data, err := mc.readPacket()
if err == nil {
if data[0] != 1 {
return 0, ErrMalformPkt
}
}
return int(data[1]), err
}

// Result Set Header Packet
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
Expand Down
29 changes: 29 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package mysql

import (
"crypto/sha1"
"crypto/sha256"
"crypto/tls"
"database/sql/driver"
"encoding/binary"
Expand Down Expand Up @@ -211,6 +212,34 @@ func scrambleOldPassword(scramble, password []byte) []byte {
return out[:]
}

// Encrypt password using 8.0 default method
func scrambleCachingSha2Password(scramble, password []byte) []byte {
if len(password) == 0 {
return nil
}

// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))

crypt := sha256.New()
crypt.Write(password)
message1 := crypt.Sum(nil)

crypt.Reset()
crypt.Write(message1)
message1Hash := crypt.Sum(nil)

crypt.Reset()
crypt.Write(message1Hash)
crypt.Write(scramble)
message2 := crypt.Sum(nil)

for i := range message1 {
message1[i] ^= message2[i]
}

return message1
}

/******************************************************************************
* Time related utils *
******************************************************************************/
Expand Down
18 changes: 18 additions & 0 deletions utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,24 @@ func TestOldPass(t *testing.T) {
}
}

func TestCachingSha2Pass(t *testing.T) {
scramble := []byte{10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21}
vectors := []struct {
pass string
out string
}{
{"secret", "f490e76f66d9d86665ce54d98c78d0acfe2fb0b08b423da807144873d30b312c"},
{"secret2", "abc3934a012cf342e876071c8ee202de51785b430258a7a0138bc79c4d800bc6"},
}
for _, tuple := range vectors {
ours := scrambleCachingSha2Password(scramble, []byte(tuple.pass))
if tuple.out != fmt.Sprintf("%x", ours) {
t.Errorf("Failed caching sha2 password %q", tuple.pass)
}
}

}

func TestFormatBinaryDateTime(t *testing.T) {
rawDate := [11]byte{}
binary.LittleEndian.PutUint16(rawDate[:2], 1978) // years
Expand Down