Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Daniël van Eeden <git at myname.nl>
Dave Protasowski <dprotaso at gmail.com>
DisposaBoy <disposaboy at dby.me>
Egor Smolyakov <egorsmkv at gmail.com>
Eli Pozniansky <elipoz at gmail.com>
Erwan Martin <hello at erwan.io>
Evan Shaw <evan at vendhq.com>
Frederick Mayle <frederickmayle at gmail.com>
Expand Down
1 change: 1 addition & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type mysqlConn struct {
rawConn net.Conn // underlying connection when netConn is TLS connection.
affectedRows uint64
insertId uint64
recvGtids string
cfg *Config
maxAllowedPacket int
maxWriteSize int
Expand Down
7 changes: 7 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,10 @@ const (
cachingSha2PasswordFastAuthSuccess = 3
cachingSha2PasswordPerformFullAuthentication = 4
)

const (
sessionTrackSystemVariables = 0
sessionTrackSchema = 1
sessionTrackStateChange = 2
sessionTrackGtids = 3
)
1 change: 1 addition & 0 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ type Config struct {
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections
TrackReceivedGtids bool // Track received gtids
}

// NewConfig creates a new Config and sets default values.
Expand Down
54 changes: 47 additions & 7 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientFlags |= clientMultiStatements
}

if mc.cfg.TrackReceivedGtids {
clientFlags |= clientSessionTrack
}

// encode length of the auth plugin data
var authRespLEIBuf [9]byte
authRespLen := len(authResp)
Expand Down Expand Up @@ -610,23 +614,59 @@ func readStatus(b []byte) statusFlag {
// Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
func (mc *mysqlConn) handleOkPacket(data []byte) error {
var n, m int
var c, n int

// 0x00 [1 byte]
c = 1

// Affected rows [Length Coded Binary]
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
mc.affectedRows, _, n = readLengthEncodedInteger(data[c:])
c += n

// Insert id [Length Coded Binary]
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
mc.insertId, _, n = readLengthEncodedInteger(data[c:])
c += n

// server_status [2 bytes]
mc.status = readStatus(data[1+n+m : 1+n+m+2])
if mc.status&statusMoreResultsExists != 0 {
return nil
}
mc.status = readStatus(data[c : c+2])
c += 2

// warning count [2 bytes]
c += 2

mc.recvGtids = ""

if mc.flags&clientSessionTrack != 0 && mc.status&statusSessionStateChanged != 0 {
// Human readable status information (ignored)
num, _, n := readLengthEncodedInteger(data[c:])
if num < 1 {
return io.EOF
}
c += n + int(num)

// Length of session state changes
num, _, n = readLengthEncodedInteger(data[c:])
if num < 1 {
return io.EOF
}
c += n
for t := 0; t < int(num); {
infoType := data[c]
c += 1
m, _, n := readLengthEncodedInteger(data[c:])
if m < 1 {
return io.EOF
}
c += n

if infoType == sessionTrackGtids {
mc.recvGtids = string(data[c : c+int(m)])
return nil
}
c += int(m)
t += 1 + n + int(m)
}
}

return nil
}
Expand Down
78 changes: 76 additions & 2 deletions packets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package mysql
import (
"bytes"
"errors"
"io"
"net"
"testing"
"time"
Expand Down Expand Up @@ -179,7 +180,7 @@ func TestReadPacketSplit(t *testing.T) {
data[4] = 0x11
data[maxPacketSize+3] = 0x22

// 2nd packet has payload length 0 and squence id 1
// 2nd packet has payload length 0 and sequence id 1
// 00 00 00 01
data[pkt2ofs+3] = 0x01

Expand Down Expand Up @@ -211,7 +212,7 @@ func TestReadPacketSplit(t *testing.T) {
data[pkt2ofs+4] = 0x33
data[pkt2ofs+maxPacketSize+3] = 0x44

// 3rd packet has payload length 0 and squence id 2
// 3rd packet has payload length 0 and sequence id 2
// 00 00 00 02
data[pkt3ofs+3] = 0x02

Expand Down Expand Up @@ -334,3 +335,76 @@ func TestRegression801(t *testing.T) {
t.Errorf("expected authData '%v', got '%v'", expectedAuthData, authData)
}
}

func TestReadOkPacketWithTrackReceivedGtids(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
flags: clientSessionTrack,
}

data := make([]byte, maxPacketSize)
conn.data = data

// https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
data[0] = 0x00
data[1] = 0x42 // affected rows
data[2] = 0x17 // insert id
data[3] = 0x00 // first byte of status
data[4] = byte(statusSessionStateChanged >> 8) // second byte of status
data[5] = 0x00 // warning count
data[6] = 0x00 // warning count
data[7] = 0x01 // Human readable status information length
data[8] = 0x00 // Human readable status information string
data[9] = 0x0A // Length of session_state_changes
data[10] = 0x02 // SESSION_TRACK_STATE_CHANGE == 0x02
data[11] = 0x02 // length
data[12] = 0x58 // 'X'
data[13] = 0x58 // 'X'
data[14] = 0x03 // SESSION_TRACK_GTIDS == 0x03
data[15] = 0x04 // GTIDs length
data[16] = 0x47 // 'G'
data[17] = 0x54 // 'T'
data[18] = 0x49 // 'I'
data[19] = 0x44 // 'D'

// Error 1
saved := data[7]
data[7] = 0x00
conn.data = data
err := mc.handleOkPacket(data)
if err != io.EOF {
t.Fatalf("got error: %v", err)
}
data[7] = saved

// Error 2
saved = data[9]
data[9] = 0x00
conn.data = data
err = mc.handleOkPacket(data)
if err != io.EOF {
t.Fatalf("got error: %v", err)
}
data[9] = saved

// Error 3
saved = data[11]
data[11] = 0x00
conn.data = data
err = mc.handleOkPacket(data)
if err != io.EOF {
t.Fatalf("got error: %v", err)
}
data[11] = saved

// Success
err = mc.handleOkPacket(data)
if err != nil {
t.Fatalf("got error: %v", err)
}

if mc.recvGtids != "GTID" {
t.Fatalf("could not parse GTIDs from session tracking. got: %v", mc.recvGtids)
}
}