diff --git a/packets.go b/packets.go index a263a06e7..b570a3e1c 100644 --- a/packets.go +++ b/packets.go @@ -24,54 +24,56 @@ import ( // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() ([]byte, error) { - // Read packet header - data, err := mc.buf.readNext(4) - if err != nil { - errLog.Print(err.Error()) - mc.Close() - return nil, driver.ErrBadConn - } + var payload []byte + for { + // Read packet header + data, err := mc.buf.readNext(4) + if err != nil { + errLog.Print(err.Error()) + mc.Close() + return nil, driver.ErrBadConn + } - // Packet Length [24 bit] - pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) + // Packet Length [24 bit] + pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) - if pktLen < 1 { - errLog.Print(errMalformPkt.Error()) - mc.Close() - return nil, driver.ErrBadConn - } + if pktLen < 1 { + errLog.Print(errMalformPkt.Error()) + mc.Close() + return nil, driver.ErrBadConn + } - // Check Packet Sync [8 bit] - if data[3] != mc.sequence { - if data[3] > mc.sequence { - return nil, errPktSyncMul - } else { - return nil, errPktSync + // Check Packet Sync [8 bit] + if data[3] != mc.sequence { + if data[3] > mc.sequence { + return nil, errPktSyncMul + } else { + return nil, errPktSync + } } - } - mc.sequence++ + mc.sequence++ + + // Read packet body [pktLen bytes] + data, err = mc.buf.readNext(pktLen) + if err != nil { + errLog.Print(err.Error()) + mc.Close() + return nil, driver.ErrBadConn + } + + isLastPacket := (pktLen < maxPacketSize) - // Read packet body [pktLen bytes] - if data, err = mc.buf.readNext(pktLen); err == nil { - if pktLen < maxPacketSize { + // Zero allocations for non-splitting packets + if isLastPacket && payload == nil { return data, nil } - // Make a copy since data becomes invalid with the next read - buf := make([]byte, len(data)) - copy(buf, data) + payload = append(payload, data...) - // More data - data, err = mc.readPacket() - if err == nil { - return append(buf, data...), nil + if isLastPacket { + return payload, nil } } - - // err case - mc.Close() - errLog.Print(err.Error()) - return nil, driver.ErrBadConn } // Write packet buffer 'data'