diff --git a/buffer.go b/buffer.go index ca790d32f..6896ca805 100644 --- a/buffer.go +++ b/buffer.go @@ -56,7 +56,6 @@ func (b *buffer) fill(need int) (err error) { } return // err } - return } // returns next N bytes from buffer. @@ -72,3 +71,65 @@ func (b *buffer) readNext(need int) (p []byte, err error) { b.length -= need return } + +// various allocation pools + +var bytesPool = make(chan []byte, 16) + +// may return unzeroed bytes +func getBytes(n int) []byte { + select { + case s := <-bytesPool: + if cap(s) >= n { + return s[:n] + } + default: + } + return make([]byte, n) +} + +func putBytes(s []byte) { + select { + case bytesPool <- s: + default: + } +} + +var fieldPool = make(chan []mysqlField, 16) + +func getMysqlFields(n int) []mysqlField { + select { + case f := <-fieldPool: + if cap(f) >= n { + return f[:n] + } + default: + } + return make([]mysqlField, n) +} + +func putMysqlFields(f []mysqlField) { + select { + case fieldPool <- f: + default: + } +} + +var rowsPool = make(chan *mysqlRows, 16) + +func getMysqlRows() *mysqlRows { + select { + case r := <-rowsPool: + return r + default: + } + return new(mysqlRows) +} + +func putMysqlRows(r *mysqlRows) { + *r = mysqlRows{} // zero it + select { + case rowsPool <- r: + default: + } +} diff --git a/connection.go b/connection.go index a62db429a..d5a7b7c81 100644 --- a/connection.go +++ b/connection.go @@ -190,13 +190,14 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro var resLen int resLen, err = mc.readResultSetHeaderPacket() if err == nil { - rows := &mysqlRows{mc, false, nil, false} + rows := getMysqlRows() + rows.mc = mc if resLen > 0 { // Columns rows.columns, err = mc.readColumns(resLen) } - return rows, err + return &mysqlRowsI{rows}, err } } @@ -217,7 +218,8 @@ func (mc *mysqlConn) getSystemVar(name string) (val []byte, err error) { var resLen int resLen, err = mc.readResultSetHeaderPacket() if err == nil { - rows := &mysqlRows{mc, false, nil, false} + rows := getMysqlRows() + rows.mc = mc if resLen > 0 { // Columns diff --git a/errors.go b/errors.go index 20003e086..04c423628 100644 --- a/errors.go +++ b/errors.go @@ -24,6 +24,7 @@ var ( errPktSync = errors.New("Commands out of sync. You can't run this command now") errPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?") errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.") + errInvConn = errors.New("Invalid Connection") ) // error type which represents a single MySQL error diff --git a/packets.go b/packets.go index b26caad4a..0ef57a678 100644 --- a/packets.go +++ b/packets.go @@ -14,6 +14,7 @@ import ( "crypto/tls" "database/sql/driver" "encoding/binary" + "errors" "fmt" "io" "math" @@ -328,18 +329,20 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { }) } -func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { +func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) (err error) { // Reset Packet Sequence mc.sequence = 0 pktLen := 1 + len(arg) - data := make([]byte, pktLen+4) + + // get byte slice from pool + data := getBytes(pktLen + 4) // Add the packet header [24bit length + 1 byte sequence] data[0] = byte(pktLen) data[1] = byte(pktLen >> 8) data[2] = byte(pktLen >> 16) - //data[3] = mc.sequence + data[3] = 0x00 // Add command byte data[4] = command @@ -348,7 +351,12 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { copy(data[5:], arg) // Send CMD packet - return mc.writePacket(data) + err = mc.writePacket(data) + + // Return byte slice to pool + putBytes(data) + + return } func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { @@ -489,7 +497,7 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) { var i, pos, n int var name []byte - columns = make([]mysqlField, count) + columns = getMysqlFields(count) for { data, err = mc.readPacket() @@ -537,7 +545,7 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) { if err != nil { return } - columns[i].name = string(name) + columns[i].name = string(name) // TODO(bradfitz): garbage. intern these. pos += n // Original name [len coded string] @@ -730,111 +738,98 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) (err error) // Execute Prepared Statement // http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-execute func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { - if len(args) != stmt.paramCount { + argsCount := len(args) + if argsCount != stmt.paramCount { return fmt.Errorf( "Arguments count mismatch (Got: %d Has: %d)", - len(args), - stmt.paramCount) + argsCount, + stmt.paramCount, + ) } - // Reset packet-sequence stmt.mc.sequence = 0 - pktLen := 1 + 4 + 1 + 4 + ((stmt.paramCount + 7) >> 3) + 1 + (stmt.paramCount << 1) - paramValues := make([][]byte, stmt.paramCount) - paramTypes := make([]byte, (stmt.paramCount << 1)) - bitMask := uint64(0) - var i int + pktLen := 1 + 4 + 1 + 4 + ((argsCount + 7) >> 3) + 1 + (argsCount << 1) + + // bitmask for params send as long data packet + longDataMask := uint(0) - for i = range args { - // build NULL-bitmap + // 2-PASS packing + + // PASS 1 - get length + for i := range args { if args[i] == nil { - bitMask += 1 << uint(i) - paramTypes[i<<1] = fieldTypeNULL continue } - // cache types and values switch v := args[i].(type) { - case int64: - paramTypes[i<<1] = fieldTypeLongLong - paramValues[i] = uint64ToBytes(uint64(v)) - pktLen += 8 - continue - - case float64: - paramTypes[i<<1] = fieldTypeDouble - paramValues[i] = uint64ToBytes(math.Float64bits(v)) + case int64, float64: pktLen += 8 continue case bool: - paramTypes[i<<1] = fieldTypeTiny pktLen++ - if v { - paramValues[i] = []byte{0x01} - } else { - paramValues[i] = []byte{0x00} - } continue case []byte: - paramTypes[i<<1] = fieldTypeString - if len(v) < stmt.mc.maxPacketAllowed-pktLen-(stmt.paramCount-(i+1))*64 { - paramValues[i] = append( - lengthEncodedIntegerToBytes(uint64(len(v))), - v..., - ) - pktLen += len(paramValues[i]) - continue - } else { - err := stmt.writeCommandLongData(i, v) - if err == nil { + if n := len(v); n < stmt.mc.maxPacketAllowed-pktLen-(argsCount-(i+1))*64 { + switch { + case n <= 250: + pktLen += n + 1 + continue + + case n <= 0xffff: + pktLen += n + 3 continue + + case n <= 0xffffff: + pktLen += n + 4 + continue + + default: + return errors.New("Invalid length") } - return err + } else { + longDataMask |= 1 << uint(i) } case string: - paramTypes[i<<1] = fieldTypeString - if len(v) < stmt.mc.maxPacketAllowed-pktLen-(stmt.paramCount-(i+1))*64 { - paramValues[i] = append( - lengthEncodedIntegerToBytes(uint64(len(v))), - []byte(v)..., - ) - pktLen += len(paramValues[i]) - continue - } else { - err := stmt.writeCommandLongData(i, []byte(v)) - if err == nil { + if n := len(v); n < stmt.mc.maxPacketAllowed-pktLen-(argsCount-(i+1))*64 { + switch { + case n <= 250: + pktLen += n + 1 + continue + + case n <= 0xffff: + pktLen += n + 3 + continue + + case n <= 0xffffff: + pktLen += n + 4 continue + + default: + return errors.New("Invalid length") } - return err + } else { + longDataMask |= 1 << uint(i) } case time.Time: - paramTypes[i<<1] = fieldTypeString - - var val []byte - if v.IsZero() { - val = []byte("0000-00-00") + if !v.IsZero() { + pktLen += 1 + 19 } else { - val = []byte(v.Format(timeFormat)) + pktLen += 1 + 10 } - paramValues[i] = append( - lengthEncodedIntegerToBytes(uint64(len(val))), - val..., - ) - pktLen += len(paramValues[i]) - continue - default: return fmt.Errorf("Can't convert type: %T", args[i]) } } - data := make([]byte, pktLen+4) + // get data buffer + data := getBytes(pktLen + 4) + defer putBytes(data) // packet header [4 bytes] data[0] = byte(pktLen) @@ -852,32 +847,112 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data[8] = byte(stmt.id >> 24) // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte] - //data[9] = 0x00 + // 0 must be set explicitly since the slice might not be zeroed + data[9] = 0x00 // iteration_count (uint32(1)) [4 bytes] + // 0 must be set explicitly since the slice might not be zeroed data[10] = 0x01 - //data[11] = 0x00 - //data[12] = 0x00 - //data[13] = 0x00 + data[11] = 0x00 + data[12] = 0x00 + data[13] = 0x00 - if stmt.paramCount > 0 { + if argsCount > 0 { // NULL-bitmap [(param_count+7)/8 bytes] - pos := 14 + ((stmt.paramCount + 7) >> 3) - // Convert bitMask to bytes - for i = 14; i < pos; i++ { - data[i] = byte(bitMask >> uint((i-14)<<3)) - } + // deferred + bitMask := uint(0) + pos := 14 + ((argsCount + 7) >> 3) // newParameterBoundFlag 1 [1 byte] data[pos] = 0x01 pos++ - // type of parameters [param_count*2 bytes] - pos += copy(data[pos:], paramTypes) + // type of the parameters [param_count*2 bytes] + paramTypes := data[pos:] + + // values of the parameters [n bytes] + paramValues := data[pos+(argsCount<<1):] + pos = 0 + + // PASS 2 - copy data + for i := range args { + paramTypes[(i<<1)+1] = 0x00 - // values for the parameters [n bytes] - for i = range paramValues { - pos += copy(data[pos:], paramValues[i]) + // build NULL-bitmap + if args[i] == nil { + bitMask |= 1 << uint(i) + paramTypes[i<<1] = fieldTypeNULL + continue + } + + // cache types and values + switch v := args[i].(type) { + case int64: + paramTypes[i<<1] = fieldTypeLongLong + binary.LittleEndian.PutUint64(paramValues[pos:], uint64(v)) + pos += 8 + continue + + case float64: + paramTypes[i<<1] = fieldTypeDouble + binary.LittleEndian.PutUint64(paramValues[pos:], math.Float64bits(v)) + pos += 8 + continue + + case bool: + paramTypes[i<<1] = fieldTypeTiny + if v { + paramValues[pos] = 0x01 + } else { + paramValues[pos] = 0x00 + } + pos++ + continue + + case []byte: + paramTypes[i<<1] = fieldTypeString + if longDataMask&(1<> 3); i < max; i++ { + data[14+i] = byte(bitMask >> uint(i<<3)) } } diff --git a/rows.go b/rows.go index fda75998b..0295a9857 100644 --- a/rows.go +++ b/rows.go @@ -11,7 +11,6 @@ package mysql import ( "database/sql/driver" - "errors" "io" ) @@ -21,14 +20,32 @@ type mysqlField struct { flags fieldFlag } +// mysqlRows is the driver-internal Rows struct that is never given to +// the database/sql package. This struct is 40 bytes on 64-bit +// machines and is recycled. Its size isn't very relevant, since we +// recycle it. +// +// Allocate with newMysqlRows (from buffer.go) and return with +// putMySQLRows. See also: mysqlRowsI. type mysqlRows struct { mc *mysqlConn - binary bool columns []mysqlField + binary bool // Note: packing small bool fields at the end eof bool } +// mysqlRowsI implements driver.Rows. Its wrapped *mysqlRows pointer +// becomes nil and recycled on Close. This struct is kept small (8 +// bytes) to minimize garbage creation. +type mysqlRowsI struct { + *mysqlRows +} + func (rows *mysqlRows) Columns() (columns []string) { + if rows == nil { + errLog.Print("mysqlRows.Columns called with nil receiver") + return nil + } columns = make([]string, len(rows.columns)) for i := range columns { columns[i] = rows.columns[i].name @@ -36,15 +53,28 @@ func (rows *mysqlRows) Columns() (columns []string) { return } -func (rows *mysqlRows) Close() (err error) { +func (ri *mysqlRowsI) Close() (err error) { + if ri.mysqlRows == nil { + errLog.Print("mysqlRows.Close called twice? sql package fail?") + return errInvConn + } + err = ri.mysqlRows.close() + putMysqlRows(ri.mysqlRows) + ri.mysqlRows = nil + return err +} + +func (rows *mysqlRows) close() (err error) { defer func() { rows.mc = nil + putMysqlFields(rows.columns) + rows.columns = nil }() // Remove unread packets from stream if !rows.eof { if rows.mc == nil { - return errors.New("Invalid Connection") + return errInvConn } err = rows.mc.readUntilEOF() @@ -54,12 +84,16 @@ func (rows *mysqlRows) Close() (err error) { } func (rows *mysqlRows) Next(dest []driver.Value) error { + if rows == nil { + errLog.Print("mysqlRows.Next called with nil receiver") + return errInvConn + } if rows.eof { return io.EOF } if rows.mc == nil { - return errors.New("Invalid Connection") + return errInvConn } // Fetch next row from stream diff --git a/statement.go b/statement.go index faa1ad032..e0fdf40e8 100644 --- a/statement.go +++ b/statement.go @@ -79,7 +79,9 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { return nil, err } - rows := &mysqlRows{stmt.mc, true, nil, false} + rows := getMysqlRows() + rows.mc = stmt.mc + rows.binary = true if resLen > 0 { // Columns @@ -89,5 +91,5 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { } } - return rows, err + return &mysqlRowsI{rows}, err } diff --git a/utils.go b/utils.go index 097ecd0aa..eea6631c6 100644 --- a/utils.go +++ b/utils.go @@ -340,19 +340,6 @@ func readBool(value string) bool { * Convert from and to bytes * ******************************************************************************/ -func uint64ToBytes(n uint64) []byte { - return []byte{ - byte(n), - byte(n >> 8), - byte(n >> 16), - byte(n >> 24), - byte(n >> 32), - byte(n >> 40), - byte(n >> 48), - byte(n >> 56), - } -} - func uint64ToString(n uint64) []byte { var a [20]byte i := 20 @@ -453,16 +440,24 @@ func readLengthEncodedInteger(b []byte) (num uint64, isNull bool, n int) { return } -func lengthEncodedIntegerToBytes(n uint64) []byte { +// Does NOT make bounds check +func lengthEncodedIntegerToBytes(b []byte, i uint32) int { switch { - case n <= 250: - return []byte{byte(n)} - - case n <= 0xffff: - return []byte{0xfc, byte(n), byte(n >> 8)} - - case n <= 0xffffff: - return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)} + case i <= 250: + b[0] = byte(i) + return 1 + + case i <= 0xffff: + b[0] = 0xfc + b[1] = byte(i) + b[2] = byte(i >> 8) + return 3 + + default: //i <= 0xffffff + b[0] = 0xfd + b[1] = byte(i) + b[2] = byte(i >> 8) + b[3] = byte(i >> 16) + return 4 } - return nil }