diff --git a/connection.go b/connection.go index a7da9e7e2..67cea1fcb 100644 --- a/connection.go +++ b/connection.go @@ -27,6 +27,7 @@ type mysqlConn struct { affectedRows uint64 insertId uint64 cfg *Config + connector *connector maxAllowedPacket int maxWriteSize int writeTimeout time.Duration diff --git a/connector.go b/connector.go index a5c988e13..6acf3dd50 100644 --- a/connector.go +++ b/connector.go @@ -11,11 +11,54 @@ package mysql import ( "context" "database/sql/driver" + "fmt" "net" + "os" + "strconv" + "strings" ) type connector struct { - cfg *Config // immutable private copy. + cfg *Config // immutable private copy. + encodedAttributes string // Encoded connection attributes. +} + +func encodeConnectionAttributes(textAttributes string) string { + connAttrsBuf := make([]byte, 0, 251) + + // default connection attributes + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid())) + + // user-defined connection attributes + for _, connAttr := range strings.Split(textAttributes, ",") { + attr := strings.SplitN(connAttr, ":", 2) + if len(attr) != 2 { + continue + } + for _, v := range attr { + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v) + } + } + + return string(connAttrsBuf) +} + +func newConnector(cfg *Config) (*connector, error) { + encodedAttributes := encodeConnectionAttributes(cfg.ConnectionAttributes) + if len(encodedAttributes) > 250 { + return nil, fmt.Errorf("connection attributes are longer than 250 bytes: %dbytes (%q)", len(encodedAttributes), cfg.ConnectionAttributes) + } + return &connector{ + cfg: cfg, + encodedAttributes: encodedAttributes, + }, nil } // Connect implements driver.Connector interface. @@ -29,6 +72,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { maxWriteSize: maxPacketSize - 1, closech: make(chan struct{}), cfg: c.cfg, + connector: c, } mc.parseTime = mc.cfg.ParseTime diff --git a/connector_test.go b/connector_test.go index 976903c5b..6dd983867 100644 --- a/connector_test.go +++ b/connector_test.go @@ -8,11 +8,14 @@ import ( ) func TestConnectorReturnsTimeout(t *testing.T) { - connector := &connector{&Config{ + connector, err := newConnector(&Config{ Net: "tcp", Addr: "1.1.1.1:1234", Timeout: 10 * time.Millisecond, - }} + }) + if err != nil { + t.Fatal(err) + } _, err := connector.Connect(context.Background()) if err == nil { diff --git a/driver.go b/driver.go index 8b0c3ec0a..c19e04207 100644 --- a/driver.go +++ b/driver.go @@ -85,8 +85,9 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { if err != nil { return nil, err } - c := &connector{ - cfg: cfg, + c, err := newConnector(cfg) + if err != nil { + return nil, err } return c.Connect(context.Background()) } @@ -103,7 +104,7 @@ func NewConnector(cfg *Config) (driver.Connector, error) { if err := cfg.normalize(); err != nil { return nil, err } - return &connector{cfg: cfg}, nil + return newConnector(cfg) } // OpenConnector implements driver.DriverContext. @@ -112,7 +113,5 @@ func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) { if err != nil { return nil, err } - return &connector{ - cfg: cfg, - }, nil + return newConnector(cfg) } diff --git a/packets.go b/packets.go index 200431c5b..3475dfe83 100644 --- a/packets.go +++ b/packets.go @@ -18,9 +18,6 @@ import ( "fmt" "io" "math" - "os" - "strconv" - "strings" "time" ) @@ -322,31 +319,12 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string pktLen += n + 1 } - connAttrsBuf := make([]byte, 0, 100) - - // default connection attributes - connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName) - connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue) - connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS) - connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue) - connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform) - connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue) - connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid) - connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid())) - - // user-defined connection attributes - for _, connAttr := range strings.Split(mc.cfg.ConnectionAttributes, ",") { - attr := strings.Split(connAttr, ":") - if len(attr) != 2 { - continue - } - for _, v := range attr { - connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v) - } - } - // 1 byte to store length of all key-values - pktLen += len(connAttrsBuf) + 1 + // NOTE: Actually, this is length encoded integer. + // But we support only len(connAttrBuf) < 251 for now because takeSmallBuffer + // doesn't support buffer size more than 4096 bytes. + // TODO(methane): Rewrite buffer management. + pktLen += 1 + len(mc.connector.encodedAttributes) // Calculate packet length and get buffer with that size data, err := mc.buf.takeSmallBuffer(pktLen + 4)