Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Improve buffer handling (#890)
* Eliminate redundant size test in takeBuffer.
* Change buffer takeXXX functions to return an error to make it explicit that they can fail.
* Add missing error check in handleAuthResult.
* Add buffer.store(..) method which can be used by external buffer consumers to update the raw buffer.
* Fix some typos and unnecessary UTF-8 characters in comments.
* Improve buffer function docs.
* Add comments to explain some non-obvious behavior around buffer handling.
  • Loading branch information
stevenh authored and methane committed Nov 16, 2018
1 parent 369b5d6 commit 6be42e0
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 49 deletions.
2 changes: 2 additions & 0 deletions AUTHORS
Expand Up @@ -73,6 +73,7 @@ Shuode Li <elemount at qq.com>
Soroush Pour <me at soroushjp.com>
Stan Putrya <root.vagner at gmail.com>
Stanley Gunawan <gunawan.stanley at gmail.com>
Steven Hartland <steven.hartland at multiplay.co.uk>
Thomas Wodarek <wodarekwebpage at gmail.com>
Tom Jenkinson <tom at tjenkinson.me>
Xiangyu Hu <xiangyu.hu at outlook.com>
Expand All @@ -90,3 +91,4 @@ Keybase Inc.
Percona LLC
Pivotal Inc.
Stripe Inc.
Multiplay Ltd.
8 changes: 5 additions & 3 deletions auth.go
Expand Up @@ -360,13 +360,15 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
pubKey := mc.cfg.pubKey
if pubKey == nil {
// request public key from server
data := mc.buf.takeSmallBuffer(4 + 1)
data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
return err
}
data[4] = cachingSha2PasswordRequestPublicKey
mc.writePacket(data)

// parse public key
data, err := mc.readPacket()
if err != nil {
if data, err = mc.readPacket(); err != nil {
return err
}

Expand Down
49 changes: 31 additions & 18 deletions buffer.go
Expand Up @@ -22,17 +22,17 @@ const defaultBufSize = 4096
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
// Also highly optimized for this particular use case.
type buffer struct {
buf []byte
buf []byte // buf is a byte buffer who's length and capacity are equal.
nc net.Conn
idx int
length int
timeout time.Duration
}

// newBuffer allocates and returns a new buffer.
func newBuffer(nc net.Conn) buffer {
var b [defaultBufSize]byte
return buffer{
buf: b[:],
buf: make([]byte, defaultBufSize),
nc: nc,
}
}
Expand Down Expand Up @@ -105,43 +105,56 @@ func (b *buffer) readNext(need int) ([]byte, error) {
return b.buf[offset:b.idx], nil
}

// returns a buffer with the requested size.
// takeBuffer returns a buffer with the requested size.
// If possible, a slice from the existing buffer is returned.
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) []byte {
func (b *buffer) takeBuffer(length int) ([]byte, error) {
if b.length > 0 {
return nil
return nil, ErrBusyBuffer
}

// test (cheap) general case first
if length <= defaultBufSize || length <= cap(b.buf) {
return b.buf[:length]
if length <= cap(b.buf) {
return b.buf[:length], nil
}

if length < maxPacketSize {
b.buf = make([]byte, length)
return b.buf
return b.buf, nil
}
return make([]byte, length)

// buffer is larger than we want to store.
return make([]byte, length), nil
}

// shortcut which can be used if the requested buffer is guaranteed to be
// smaller than defaultBufSize
// takeSmallBuffer is shortcut which can be used if length is
// known to be smaller than defaultBufSize.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) []byte {
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
if b.length > 0 {
return nil
return nil, ErrBusyBuffer
}
return b.buf[:length]
return b.buf[:length], nil
}

// takeCompleteBuffer returns the complete existing buffer.
// This can be used if the necessary buffer size is unknown.
// cap and len of the returned buffer will be equal.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeCompleteBuffer() []byte {
func (b *buffer) takeCompleteBuffer() ([]byte, error) {
if b.length > 0 {
return nil, ErrBusyBuffer
}
return b.buf, nil
}

// store stores buf, an updated buffer, if its suitable to do so.
func (b *buffer) store(buf []byte) error {
if b.length > 0 {
return nil
return ErrBusyBuffer
} else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
b.buf = buf[:cap(buf)]
}
return b.buf
return nil
}
6 changes: 3 additions & 3 deletions connection.go
Expand Up @@ -182,10 +182,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
return "", driver.ErrSkip
}

buf := mc.buf.takeCompleteBuffer()
if buf == nil {
buf, err := mc.buf.takeCompleteBuffer()
if err != nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return "", ErrInvalidConn
}
buf = buf[:0]
Expand Down
2 changes: 1 addition & 1 deletion driver.go
Expand Up @@ -50,7 +50,7 @@ func RegisterDial(net string, dial DialFunc) {

// Open new Connection.
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
// the DSN string is formated
// the DSN string is formatted
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
var err error

Expand Down
54 changes: 30 additions & 24 deletions packets.go
Expand Up @@ -51,7 +51,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
mc.sequence++

// packets with length 0 terminate a previous packet which is a
// multiple of (2^24)1 bytes long
// multiple of (2^24)-1 bytes long
if pktLen == 0 {
// there was no previous packet
if prevData == nil {
Expand Down Expand Up @@ -286,10 +286,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
}

// Calculate packet length and get buffer with that size
data := mc.buf.takeSmallBuffer(pktLen + 4)
if data == nil {
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}

Expand Down Expand Up @@ -367,10 +367,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
pktLen := 4 + len(authData)
data := mc.buf.takeSmallBuffer(pktLen)
if data == nil {
data, err := mc.buf.takeSmallBuffer(pktLen)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}

Expand All @@ -387,10 +387,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
// Reset Packet Sequence
mc.sequence = 0

data := mc.buf.takeSmallBuffer(4 + 1)
if data == nil {
data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}

Expand All @@ -406,10 +406,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
mc.sequence = 0

pktLen := 1 + len(arg)
data := mc.buf.takeBuffer(pktLen + 4)
if data == nil {
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}

Expand All @@ -427,10 +427,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
// Reset Packet Sequence
mc.sequence = 0

data := mc.buf.takeSmallBuffer(4 + 1 + 4)
if data == nil {
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}

Expand Down Expand Up @@ -883,7 +883,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
const minPktLen = 4 + 1 + 4 + 1 + 4
mc := stmt.mc

// Determine threshould dynamically to avoid packet size shortage.
// Determine threshold dynamically to avoid packet size shortage.
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
if longDataSize < 64 {
longDataSize = 64
Expand All @@ -893,15 +893,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
mc.sequence = 0

var data []byte
var err error

if len(args) == 0 {
data = mc.buf.takeBuffer(minPktLen)
data, err = mc.buf.takeBuffer(minPktLen)
} else {
data = mc.buf.takeCompleteBuffer()
data, err = mc.buf.takeCompleteBuffer()
// In this case the len(data) == cap(data) which is used to optimise the flow below.
}
if data == nil {
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}

Expand All @@ -927,7 +929,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
pos := minPktLen

var nullMask []byte
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
// buffer has to be extended but we don't know by how much so
// we depend on append after all data with known sizes fit.
// We stop at that because we deal with a lot of columns here
Expand All @@ -936,10 +938,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
copy(tmp[:pos], data[:pos])
data = tmp
nullMask = data[pos : pos+maskLen]
// No need to clean nullMask as make ensures that.
pos += maskLen
} else {
nullMask = data[pos : pos+maskLen]
for i := 0; i < maskLen; i++ {
for i := range nullMask {
nullMask[i] = 0
}
pos += maskLen
Expand Down Expand Up @@ -1076,7 +1079,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// In that case we must build the data packet with the new values buffer
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
mc.buf.buf = data
if err = mc.buf.store(data); err != nil {
errLog.Print(err)
return errBadConnNoWrite
}
}

pos += len(paramValues)
Expand Down

1 comment on commit 6be42e0

@Synaxis
Copy link

@Synaxis Synaxis commented on 6be42e0 Dec 27, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this apply to a single query case ? ( specific for mysql)
or this code can be used in another case ?
it's beautifull

Please sign in to comment.