This repository has been archived by the owner on Feb 22, 2023. It is now read-only.
forked from go-mysql-org/go-mysql
/
auth.go
139 lines (103 loc) · 2.84 KB
/
auth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package client
import (
"bytes"
"encoding/binary"
"github.com/juju/errors"
. "github.com/siddontang/go-mysql/mysql"
)
// readInitialHandshake reads the server's initial handshake packet and
// records the connection id, the auth salt (scramble), the server capability
// flags and, when the extended section is present, the server status flags
// on c.
//
// The packet arrives from the network, so its length is validated before any
// fixed-size field is sliced. A truncated or malformed packet is reported as
// an error instead of crashing the client with an index-out-of-range panic.
// Validation happens before c is modified, so a rejected packet leaves c
// untouched.
func (c *Conn) readInitialHandshake() error {
	data, err := c.ReadPacket()
	if err != nil {
		return errors.Trace(err)
	}

	if len(data) == 0 {
		return errors.New("read initial handshake error: empty packet")
	}

	if data[0] == ERR_HEADER {
		return errors.New("read initial handshake error")
	}

	if data[0] < MinProtocolVersion {
		return errors.Errorf("invalid protocol version %d, must >= 10", data[0])
	}

	// Skip the server version: a NUL-terminated string right after the
	// protocol version byte. IndexByte returns -1 when the terminator is
	// missing, which must not be used as an offset.
	versionLen := bytes.IndexByte(data[1:], 0x00)
	if versionLen < 0 {
		return errors.New("invalid initial handshake: unterminated server version")
	}
	pos := 1 + versionLen + 1

	const (
		// connection id [4] + salt part 1 [8] + filler [1] + capability (lower) [2]
		basicLen = 4 + 8 + 1 + 2
		// charset [1] + status [2] + capability (upper) [2] +
		// auth data len [1] + reserved [10] + salt part 2 [12]
		extendedLen = 1 + 2 + 2 + 1 + 10 + 12
	)

	basicEnd := pos + basicLen
	if len(data) < basicEnd {
		return errors.Errorf("invalid initial handshake: packet too short (%d bytes)", len(data))
	}
	hasExtended := len(data) > basicEnd
	if hasExtended && len(data) < basicEnd+extendedLen {
		return errors.Errorf("invalid initial handshake: truncated extended section (%d bytes)", len(data))
	}

	// connection id length is 4
	c.connectionID = binary.LittleEndian.Uint32(data[pos : pos+4])
	pos += 4

	// The salt is 8 bytes, or 8+12 when the extended section is present.
	c.salt = make([]byte, 0, 8+12)
	c.salt = append(c.salt, data[pos:pos+8]...)
	// skip the salt part and the filler byte after it
	pos += 8 + 1

	// capability lower 2 bytes
	c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
	pos += 2

	if hasExtended {
		// skip server charset
		//c.charset = data[pos]
		pos++

		c.status = binary.LittleEndian.Uint16(data[pos : pos+2])
		pos += 2

		// capability upper 2 bytes
		c.capability |= uint32(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
		pos += 2

		// skip auth data len (or [00]) and the 10 reserved bytes (all [00])
		pos += 1 + 10

		// The documentation is ambiguous about the length.
		// The official Python library uses the fixed length 12
		// mysql-proxy also use 12
		// which is not documented but seems to work.
		c.salt = append(c.salt, data[pos:pos+12]...)
	}

	return nil
}
// writeAuthHandshake answers the server's initial handshake with the client's
// handshake response packet, which is laid out as:
//   - the negotiated capability flags (4 bytes, little-endian);
//   - the max packet size (4 zero bytes);
//   - the charset / collation byte;
//   - 23 filler bytes;
//   - the NUL-terminated user name;
//   - the password scrambled with c.salt, behind a one-byte length;
//   - when c.db is set, the NUL-terminated initial database.
//
// The negotiated capability set replaces c.capability. It is the requested
// flags masked by c.capability, plus CLIENT_CONNECT_WITH_DB when a database is
// given.
func (c *Conn) writeAuthHandshake() error {
	const (
		headerLen = 4              // left zero for the packet header (presumably filled by WritePacket)
		fixedLen  = 4 + 4 + 1 + 23 // capability, max packet size, charset, filler
		userOff   = headerLen + fixedLen
	)

	// Request only the features the server advertised.
	flags := (CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION |
		CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_LONG_FLAG) & c.capability

	// Only the secure-connection (scrambled password) format is implemented.
	scramble := CalcPassword(c.salt, []byte(c.password))

	// user name + NUL, then the 1-byte scramble length + scramble
	bodyLen := fixedLen + len(c.user) + 1 + 1 + len(scramble)

	withDB := len(c.db) != 0
	if withDB {
		flags |= CLIENT_CONNECT_WITH_DB
		bodyLen += len(c.db) + 1 // database name + NUL
	}
	c.capability = flags

	// make zero-fills the buffer, so the max packet size, the filler and
	// every NUL terminator are already in place.
	pkt := make([]byte, headerLen+bodyLen)

	binary.LittleEndian.PutUint32(pkt[headerLen:], flags)
	// The original note says the default collation id is 33 (utf-8).
	pkt[headerLen+8] = byte(DEFAULT_COLLATION_ID)

	off := userOff
	off += copy(pkt[off:], c.user)
	off++ // NUL terminator after the user name

	pkt[off] = byte(len(scramble))
	off++
	off += copy(pkt[off:], scramble)

	if withDB {
		// The database's trailing NUL is the last byte of pkt, already zero.
		copy(pkt[off:], c.db)
	}

	return c.WritePacket(pkt)
}