Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add client_ed25519 authentication #1518

Merged
merged 5 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ Evan Elias <evan at skeema.net>
Evan Shaw <evan at vendhq.com>
Frederick Mayle <frederickmayle at gmail.com>
Gustavo Kristic <gkristic at gmail.com>
Gusted <postmaster at gusted.xyz>
Hajime Nakagami <nakagami at gmail.com>
Hanno Braun <mail at hannobraun.com>
Henri Yandell <flamefew at gmail.com>
Expand Down
47 changes: 47 additions & 0 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@ import (
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"encoding/pem"
"fmt"
"sync"

"filippo.io/edwards25519"
)

// server pub keys registry
Expand Down Expand Up @@ -225,6 +228,44 @@ func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte,
return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil)
}

// authEd25519 does ed25519 authentication used by MariaDB.
func authEd25519(scramble []byte, password string) ([]byte, error) {
// Derived from https://github.com/MariaDB/server/blob/d8e6bb00888b1f82c031938f4c8ac5d97f6874c3/plugin/auth_ed25519/ref10/sign.c
// Code style is from https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:src/crypto/ed25519/ed25519.go;l=207
h := sha512.Sum512([]byte(password))

s, err := edwards25519.NewScalar().SetBytesWithClamping(h[:32])
if err != nil {
return nil, err
}
A := (&edwards25519.Point{}).ScalarBaseMult(s)

mh := sha512.New()
mh.Write(h[32:])
mh.Write(scramble)
messageDigest := mh.Sum(nil)
r, err := edwards25519.NewScalar().SetUniformBytes(messageDigest)
if err != nil {
return nil, err
}

R := (&edwards25519.Point{}).ScalarBaseMult(r)

kh := sha512.New()
kh.Write(R.Bytes())
kh.Write(A.Bytes())
kh.Write(scramble)
hramDigest := kh.Sum(nil)
k, err := edwards25519.NewScalar().SetUniformBytes(hramDigest)
if err != nil {
return nil, err
}

S := k.MultiplyAdd(k, s, r)

return append(R.Bytes(), S.Bytes()...), nil
}

func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error {
enc, err := encryptPassword(mc.cfg.Passwd, seed, pub)
if err != nil {
Expand Down Expand Up @@ -290,6 +331,12 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
return enc, err

case "client_ed25519":
if len(authData) != 32 {
return nil, ErrMalformPkt
}
return authEd25519(authData, mc.cfg.Passwd)

default:
mc.cfg.Logger.Print("unknown auth plugin:", plugin)
return nil, ErrUnknownPlugin
Expand Down
51 changes: 51 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1328,3 +1328,54 @@ func TestAuthSwitchSHA256PasswordSecure(t *testing.T) {
t.Errorf("got unexpected data: %v", conn.written)
}
}

// Derived from https://github.com/MariaDB/server/blob/6b2287fff23fbdc362499501c562f01d0d2db52e/plugin/auth_ed25519/ed25519-t.c
func TestEd25519Auth(t *testing.T) {
conn, mc := newRWMockConn(1)
mc.cfg.User = "root"
mc.cfg.Passwd = "foobar"

authData := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
plugin := "client_ed25519"

// Send Client Authentication Packet
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}

// check written auth response
authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
authRespEnd := authRespStart + 1 + len(authResp)
writtenAuthRespLen := conn.written[authRespStart]
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
expectedAuthResp := []byte{
232, 61, 201, 63, 67, 63, 51, 53, 86, 73, 238, 35, 170, 117, 146,
214, 26, 17, 35, 9, 8, 132, 245, 141, 48, 99, 66, 58, 36, 228, 48,
84, 115, 254, 187, 168, 88, 162, 249, 57, 35, 85, 79, 238, 167, 106,
68, 117, 56, 135, 171, 47, 20, 14, 133, 79, 15, 229, 124, 160, 176,
100, 138, 14,
methane marked this conversation as resolved.
Show resolved Hide resolved
}
if writtenAuthRespLen != 64 {
t.Fatalf("expected 64 bytes from client, got %d", writtenAuthRespLen)
}
if !bytes.Equal(writtenAuthResp, expectedAuthResp) {
t.Fatalf("auth response did not match expected value:\n%v\n%v", writtenAuthResp, expectedAuthResp)
}
conn.written = nil

// auth response
conn.data = []byte{
7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK
}
conn.maxReads = 1

// Handle response to auth packet
if err := mc.handleAuthResult(authData, plugin); err != nil {
t.Errorf("got error: %v", err)
}
methane marked this conversation as resolved.
Show resolved Hide resolved
}
10 changes: 5 additions & 5 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,14 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
for _, test := range tests {
t.Run("default", func(t *testing.T) {
dbt := &DBTest{t, db}
defer dbt.db.Exec("DROP TABLE IF EXISTS test")
test(dbt)
dbt.db.Exec("DROP TABLE IF EXISTS test")
})
if db2 != nil {
t.Run("interpolateParams", func(t *testing.T) {
dbt2 := &DBTest{t, db2}
defer dbt2.db.Exec("DROP TABLE IF EXISTS test")
test(dbt2)
dbt2.db.Exec("DROP TABLE IF EXISTS test")
})
}
}
Expand Down Expand Up @@ -3181,14 +3181,14 @@ func TestRawBytesAreNotModified(t *testing.T) {

rows, err := dbt.db.QueryContext(ctx, `SELECT id, value FROM test`)
if err != nil {
t.Fatal(err)
dbt.Fatal(err)
}

var b int
var raw sql.RawBytes
for rows.Next() {
if err := rows.Scan(&b, &raw); err != nil {
t.Fatal(err)
dbt.Fatal(err)
}

before := string(raw)
Expand All @@ -3198,7 +3198,7 @@ func TestRawBytesAreNotModified(t *testing.T) {
after := string(raw)

if before != after {
t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i)
dbt.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i)
}
}
rows.Close()
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/go-sql-driver/mysql

go 1.18

require filippo.io/edwards25519 v1.1.0
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=