Skip to content

Commit 558ed11

Browse files
authored
Don't store plaintext passwords (#1040)
1 parent 4dea9fe commit 558ed11

File tree

12 files changed

+355
-183
lines changed

12 files changed

+355
-183
lines changed

client/auth.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,9 @@ func (c *Conn) genAuthResponse(authData []byte) ([]byte, bool, error) {
154154
// password hashing
155155
switch c.authPluginName {
156156
case mysql.AUTH_NATIVE_PASSWORD:
157-
return mysql.CalcPassword(authData[:20], []byte(c.password)), false, nil
157+
return mysql.CalcNativePassword(authData[:20], []byte(c.password)), false, nil
158158
case mysql.AUTH_CACHING_SHA2_PASSWORD:
159-
return mysql.CalcCachingSha2Password(authData, c.password), false, nil
159+
return mysql.CalcCachingSha2Password(authData, []byte(c.password)), false, nil
160160
case mysql.AUTH_CLEAR_PASSWORD:
161161
return []byte(c.password), true, nil
162162
case mysql.AUTH_SHA256_PASSWORD:

driver/driver_options_test.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ type mockHandler struct {
3939
}
4040

4141
func TestDriverOptions_SetRetriesOn(t *testing.T) {
42-
srv := CreateMockServer(t)
42+
srv := createMockServer(t)
4343
defer srv.Stop()
4444
var wg sync.WaitGroup
4545
srv.handler.modifier = &wg
@@ -64,7 +64,7 @@ func TestDriverOptions_SetRetriesOn(t *testing.T) {
6464
}
6565

6666
func TestDriverOptions_SetRetriesOff(t *testing.T) {
67-
srv := CreateMockServer(t)
67+
srv := createMockServer(t)
6868
defer srv.Stop()
6969
var wg sync.WaitGroup
7070
srv.handler.modifier = &wg
@@ -114,7 +114,7 @@ func TestDriverOptions_SetCompression(t *testing.T) {
114114
}
115115

116116
func TestDriverOptions_ConnectTimeout(t *testing.T) {
117-
srv := CreateMockServer(t)
117+
srv := createMockServer(t)
118118
defer srv.Stop()
119119

120120
conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?timeout=1s")
@@ -131,7 +131,7 @@ func TestDriverOptions_ConnectTimeout(t *testing.T) {
131131
}
132132

133133
func TestDriverOptions_BufferSize(t *testing.T) {
134-
srv := CreateMockServer(t)
134+
srv := createMockServer(t)
135135
defer srv.Stop()
136136

137137
SetDSNOptions(map[string]DriverOption{
@@ -156,7 +156,7 @@ func TestDriverOptions_BufferSize(t *testing.T) {
156156
}
157157

158158
func TestDriverOptions_ReadTimeout(t *testing.T) {
159-
srv := CreateMockServer(t)
159+
srv := createMockServer(t)
160160
defer srv.Stop()
161161

162162
conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?readTimeout=100ms")
@@ -177,7 +177,7 @@ func TestDriverOptions_ReadTimeout(t *testing.T) {
177177
}
178178

179179
func TestDriverOptions_writeTimeout(t *testing.T) {
180-
srv := CreateMockServer(t)
180+
srv := createMockServer(t)
181181
defer srv.Stop()
182182

183183
// use a writeTimeout that will fail parsing by ParseDuration resulting
@@ -224,7 +224,7 @@ func TestDriverOptions_namedValueChecker(t *testing.T) {
224224
return nil
225225
})
226226

227-
srv := CreateMockServer(t)
227+
srv := createMockServer(t)
228228
defer srv.Stop()
229229
conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?writeTimeout=1s")
230230
defer func() {
@@ -265,9 +265,9 @@ func TestDriverOptions_namedValueChecker(t *testing.T) {
265265
require.True(t, math.MaxUint64 == a)
266266
}
267267

268-
func CreateMockServer(t *testing.T) *testServer {
268+
func createMockServer(t *testing.T) *testServer {
269269
inMemProvider := server.NewInMemoryProvider()
270-
inMemProvider.AddUser(*testUser, *testPassword)
270+
require.NoError(t, inMemProvider.AddUser(*testUser, *testPassword))
271271
defaultServer := server.NewDefaultServer()
272272

273273
l, err := net.Listen("tcp", "127.0.0.1:3307")

mysql/mysql_gtid.go

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -111,20 +111,6 @@ func (s IntervalSlice) Normalize() IntervalSlice {
111111
return n
112112
}
113113

114-
func min(a, b int64) int64 {
115-
if a < b {
116-
return a
117-
}
118-
return b
119-
}
120-
121-
func max(a, b int64) int64 {
122-
if a > b {
123-
return a
124-
}
125-
return b
126-
}
127-
128114
func (s *IntervalSlice) InsertInterval(interval Interval) {
129115
var (
130116
count int

mysql/util.go

Lines changed: 163 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ import (
99
"crypto/sha1"
1010
"crypto/sha256"
1111
"crypto/sha512"
12+
"crypto/subtle"
1213
"encoding/binary"
14+
"encoding/hex"
1315
"fmt"
1416
"io"
1517
mrand "math/rand"
@@ -29,7 +31,7 @@ func Pstack() string {
2931
return string(buf[0:n])
3032
}
3133

32-
func CalcPassword(scramble, password []byte) []byte {
34+
func CalcNativePassword(scramble, password []byte) []byte {
3335
if len(password) == 0 {
3436
return nil
3537
}
@@ -39,35 +41,100 @@ func CalcPassword(scramble, password []byte) []byte {
3941
crypt.Write(password)
4042
stage1 := crypt.Sum(nil)
4143

42-
// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
43-
// inner Hash
44+
// stage2Hash = SHA1(stage1Hash)
4445
crypt.Reset()
4546
crypt.Write(stage1)
46-
hash := crypt.Sum(nil)
47+
stage2 := crypt.Sum(nil)
4748

48-
// outer Hash
49+
// scrambleHash = SHA1(scramble + stage2Hash)
4950
crypt.Reset()
5051
crypt.Write(scramble)
51-
crypt.Write(hash)
52-
scramble = crypt.Sum(nil)
52+
crypt.Write(stage2)
53+
scrambleHash := crypt.Sum(nil)
5354

5455
// token = scrambleHash XOR stage1Hash
55-
for i := range scramble {
56-
scramble[i] ^= stage1[i]
56+
return Xor(scrambleHash, stage1)
57+
}
58+
59+
// Xor modifies hash1 in-place with XOR against hash2
60+
func Xor(hash1 []byte, hash2 []byte) []byte {
61+
l := min(len(hash1), len(hash2))
62+
for i := range l {
63+
hash1[i] ^= hash2[i]
64+
}
65+
return hash1
66+
}
67+
68+
// hash_stage1 = xor(reply, sha1(public_seed, hash_stage2))
69+
func stage1FromReply(scramble []byte, seed []byte, stage2 []byte) []byte {
70+
crypt := sha1.New()
71+
crypt.Write(seed)
72+
crypt.Write(stage2)
73+
seededHash := crypt.Sum(nil)
74+
75+
return Xor(scramble, seededHash)
76+
}
77+
78+
// DecodePasswordHex decodes the standard format used by MySQL
79+
// Password hashes in the 4.1 format always begin with a * character
80+
// see https://dev.mysql.com/doc/mysql-security-excerpt/5.7/en/password-hashing.html
81+
// ref vitess.io/vitess/go/mysql/auth_server.go
82+
func DecodePasswordHex(hexEncodedPassword string) ([]byte, error) {
83+
if hexEncodedPassword[0] == '*' {
84+
hexEncodedPassword = hexEncodedPassword[1:]
85+
}
86+
return hex.DecodeString(hexEncodedPassword)
87+
}
88+
89+
// EncodePasswordHex encodes to the standard format used by MySQL
90+
// adds the optionally leading * to the hashed password
91+
func EncodePasswordHex(passwordHash []byte) string {
92+
hexstr := strings.ToUpper(hex.EncodeToString(passwordHash))
93+
return "*" + hexstr
94+
}
95+
96+
// NativePasswordHash = sha1(sha1(password))
97+
func NativePasswordHash(password []byte) []byte {
98+
if len(password) == 0 {
99+
return nil
57100
}
58-
return scramble
101+
102+
// stage1Hash = SHA1(password)
103+
crypt := sha1.New()
104+
crypt.Write(password)
105+
stage1 := crypt.Sum(nil)
106+
107+
// stage2Hash = SHA1(stage1Hash)
108+
crypt.Reset()
109+
crypt.Write(stage1)
110+
return crypt.Sum(stage1[:0])
111+
}
112+
113+
func CompareNativePassword(reply []byte, stored []byte, seed []byte) bool {
114+
if len(stored) == 0 {
115+
return false
116+
}
117+
118+
// hash_stage1 = xor(reply, sha1(public_seed, hash_stage2))
119+
stage1 := stage1FromReply(reply, seed, stored)
120+
// andidate_hash2 = sha1(hash_stage1)
121+
stage2 := sha1.Sum(stage1)
122+
123+
// check(candidate_hash2 == hash_stage2)
124+
// use ConstantTimeCompare to mitigate timing based attacks
125+
return subtle.ConstantTimeCompare(stage2[:], stored) == 1
59126
}
60127

61128
// CalcCachingSha2Password: Hash password using MySQL 8+ method (SHA256)
62-
func CalcCachingSha2Password(scramble []byte, password string) []byte {
129+
func CalcCachingSha2Password(scramble []byte, password []byte) []byte {
63130
if len(password) == 0 {
64131
return nil
65132
}
66133

67134
// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
68135

69136
crypt := sha256.New()
70-
crypt.Write([]byte(password))
137+
crypt.Write(password)
71138
message1 := crypt.Sum(nil)
72139

73140
crypt.Reset()
@@ -79,11 +146,7 @@ func CalcCachingSha2Password(scramble []byte, password string) []byte {
79146
crypt.Write(scramble)
80147
message2 := crypt.Sum(nil)
81148

82-
for i := range message1 {
83-
message1[i] ^= message2[i]
84-
}
85-
86-
return message1
149+
return Xor(message1, message2)
87150
}
88151

89152
// Taken from https://github.com/go-sql-driver/mysql/pull/1518
@@ -135,6 +198,89 @@ func EncryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte,
135198
return rsa.EncryptOAEP(sha1v, rand.Reader, pub, plain, nil)
136199
}
137200

201+
const (
202+
SALT_LENGTH = 16
203+
ITERATION_MULTIPLIER = 1000
204+
SHA256_PASSWORD_ITERATIONS = 5
205+
)
206+
207+
// generateUserSalt generate salt of given length for sha256_password hash
208+
func generateUserSalt(length int) ([]byte, error) {
209+
// Generate a random salt of the given length
210+
// Implement this function for your project
211+
salt := make([]byte, length)
212+
_, err := rand.Read(salt)
213+
if err != nil {
214+
return []byte(""), err
215+
}
216+
217+
// Restrict to 7-bit to avoid multi-byte UTF-8
218+
for i := range salt {
219+
salt[i] = salt[i] &^ 128
220+
for salt[i] == 36 || salt[i] == 0 { // '$' or NUL
221+
newval := make([]byte, 1)
222+
_, err := rand.Read(newval)
223+
if err != nil {
224+
return []byte(""), err
225+
}
226+
salt[i] = newval[0] &^ 128
227+
}
228+
}
229+
return salt, nil
230+
}
231+
232+
// hashCrypt256 salt and hash a password the given number of iterations
233+
func hashCrypt256(source, salt string, iterations uint64) (string, error) {
234+
actualIterations := iterations * ITERATION_MULTIPLIER
235+
hashInput := []byte(source + salt)
236+
var hash [32]byte
237+
for i := uint64(0); i < actualIterations; i++ {
238+
hash = sha256.Sum256(hashInput)
239+
hashInput = hash[:]
240+
}
241+
242+
hashHex := hex.EncodeToString(hash[:])
243+
digest := fmt.Sprintf("$%d$%s$%s", iterations, salt, hashHex)
244+
return digest, nil
245+
}
246+
247+
// Check256HashingPassword compares a password to a hash for sha256_password
248+
// rather than trying to recreate just the hash we recreate the full hash
249+
// and use that for comparison
250+
func Check256HashingPassword(pwhash []byte, password string) (bool, error) {
251+
pwHashParts := bytes.Split(pwhash, []byte("$"))
252+
if len(pwHashParts) != 4 {
253+
return false, errors.New("failed to decode hash parts")
254+
}
255+
256+
iterationsPart := pwHashParts[1]
257+
if len(iterationsPart) == 0 {
258+
return false, errors.New("iterations part is empty")
259+
}
260+
261+
iterations, err := strconv.ParseUint(string(iterationsPart), 10, 64)
262+
if err != nil {
263+
return false, errors.New("failed to decode iterations")
264+
}
265+
salt := pwHashParts[2][:SALT_LENGTH]
266+
267+
newHash, err := hashCrypt256(password, string(salt), iterations)
268+
if err != nil {
269+
return false, err
270+
}
271+
272+
return subtle.ConstantTimeCompare(pwhash, []byte(newHash)) == 1, nil
273+
}
274+
275+
// NewSha256PasswordHash creates a new password hash for sha256_password
276+
func NewSha256PasswordHash(pwd string) (string, error) {
277+
salt, err := generateUserSalt(SALT_LENGTH)
278+
if err != nil {
279+
return "", err
280+
}
281+
return hashCrypt256(pwd, string(salt), SHA256_PASSWORD_ITERATIONS)
282+
}
283+
138284
func DecompressMariadbData(data []byte) ([]byte, error) {
139285
// algorithm always 0=zlib
140286
// algorithm := (data[pos] & 0x07) >> 4

0 commit comments

Comments
 (0)