From 98d72897bab37633105da6dce698ce074fd19995 Mon Sep 17 00:00:00 2001 From: Jason Ng Date: Thu, 23 Nov 2023 21:01:24 +0800 Subject: [PATCH] Add default connection attribute '_server_host' (#1506) The `_server_host` connection attribute is supported in MariaDB (Connector/C) https://mariadb.com/kb/en/mysql_optionsv/#connection-attribute-options --- AUTHORS | 2 ++ connector.go | 21 +++++++------- connector_test.go | 7 ++--- const.go | 1 + driver.go | 9 ++---- driver_test.go | 71 ++++++++++++++++++++++++----------------------- packets.go | 16 +++++------ packets_test.go | 5 +--- 8 files changed, 64 insertions(+), 68 deletions(-) diff --git a/AUTHORS b/AUTHORS index c7e15960..2caa7d70 100644 --- a/AUTHORS +++ b/AUTHORS @@ -50,6 +50,7 @@ INADA Naoki Jacek Szwec James Harr Janek Vedock +Jason Ng Jean-Yves Pellé Jeff Hodges Jeffrey Charles @@ -131,6 +132,7 @@ Multiplay Ltd. Percona LLC PingCAP Inc. Pivotal Inc. +Shattered Silicon Ltd. Stripe Inc. Zendesk Inc. Dolthub Inc. diff --git a/connector.go b/connector.go index ba3be71e..3cef7963 100644 --- a/connector.go +++ b/connector.go @@ -11,7 +11,6 @@ package mysql import ( "context" "database/sql/driver" - "fmt" "net" "os" "strconv" @@ -23,8 +22,8 @@ type connector struct { encodedAttributes string // Encoded connection attributes. } -func encodeConnectionAttributes(textAttributes string) string { - connAttrsBuf := make([]byte, 0, 251) +func encodeConnectionAttributes(cfg *Config) string { + connAttrsBuf := make([]byte, 0) // default connection attributes connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName) @@ -35,9 +34,14 @@ func encodeConnectionAttributes(textAttributes string) string { connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue) connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid) connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid())) + serverHost, _, _ := net.SplitHostPort(cfg.Addr) + if serverHost != "" { + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrServerHost) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, serverHost) + } // user-defined connection attributes - for _, connAttr := range strings.Split(textAttributes, ",") { + for _, connAttr := range strings.Split(cfg.ConnectionAttributes, ",") { k, v, found := strings.Cut(connAttr, ":") if !found { continue @@ -49,15 +53,12 @@ func encodeConnectionAttributes(textAttributes string) string { return string(connAttrsBuf) } -func newConnector(cfg *Config) (*connector, error) { - encodedAttributes := encodeConnectionAttributes(cfg.ConnectionAttributes) - if len(encodedAttributes) > 250 { - return nil, fmt.Errorf("connection attributes are longer than 250 bytes: %dbytes (%q)", len(encodedAttributes), cfg.ConnectionAttributes) - } +func newConnector(cfg *Config) *connector { + encodedAttributes := encodeConnectionAttributes(cfg) return &connector{ cfg: cfg, encodedAttributes: encodedAttributes, - }, nil + } } // Connect implements driver.Connector interface. diff --git a/connector_test.go b/connector_test.go index bedb44ce..82d8c598 100644 --- a/connector_test.go +++ b/connector_test.go @@ -8,16 +8,13 @@ import ( ) func TestConnectorReturnsTimeout(t *testing.T) { - connector, err := newConnector(&Config{ + connector := newConnector(&Config{ Net: "tcp", Addr: "1.1.1.1:1234", Timeout: 10 * time.Millisecond, }) - if err != nil { - t.Fatal(err) - } - _, err = connector.Connect(context.Background()) + _, err := connector.Connect(context.Background()) if err == nil { t.Fatal("error expected") } diff --git a/const.go b/const.go index 0f2621a6..22526e03 100644 --- a/const.go +++ b/const.go @@ -26,6 +26,7 @@ const ( connAttrPlatform = "_platform" connAttrPlatformValue = runtime.GOARCH connAttrPid = "_pid" + connAttrServerHost = "_server_host" ) // MySQL constants documentation: diff --git a/driver.go b/driver.go index 45528b92..105316b8 100644 --- a/driver.go +++ b/driver.go @@ -83,10 +83,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { if err != nil { return nil, err } - c, err := newConnector(cfg) - if err != nil { - return nil, err - } + c := newConnector(cfg) return c.Connect(context.Background()) } @@ -108,7 +105,7 @@ func NewConnector(cfg *Config) (driver.Connector, error) { if err := cfg.normalize(); err != nil { return nil, err } - return newConnector(cfg) + return newConnector(cfg), nil } // OpenConnector implements driver.DriverContext. @@ -117,5 +114,5 @@ func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) { if err != nil { return nil, err } - return newConnector(cfg) + return newConnector(cfg), nil } diff --git a/driver_test.go b/driver_test.go index ab780f04..efbff179 100644 --- a/driver_test.go +++ b/driver_test.go @@ -24,6 +24,7 @@ import ( "os" "reflect" "runtime" + "strconv" "strings" "sync" "sync/atomic" @@ -3377,12 +3378,30 @@ func TestConnectionAttributes(t *testing.T) { t.Skipf("MySQL server not running on %s", netAddr) } - attr1 := "attr1" - value1 := "value1" - attr2 := "fo/o" - value2 := "bo/o" - dsn += "&connectionAttributes=" + url.QueryEscape(fmt.Sprintf("%s:%s,%s:%s", attr1, value1, attr2, value2)) + defaultAttrs := []string{ + connAttrClientName, + connAttrOS, + connAttrPlatform, + connAttrPid, + connAttrServerHost, + } + host, _, _ := net.SplitHostPort(addr) + defaultAttrValues := []string{ + connAttrClientNameValue, + connAttrOSValue, + connAttrPlatformValue, + strconv.Itoa(os.Getpid()), + host, + } + + customAttrs := []string{"attr1", "fo/o"} + customAttrValues := []string{"value1", "bo/o"} + customAttrStrs := make([]string, len(customAttrs)) + for i := range customAttrs { + customAttrStrs[i] = fmt.Sprintf("%s:%s", customAttrs[i], customAttrValues[i]) + } + dsn += "&connectionAttributes=" + url.QueryEscape(strings.Join(customAttrStrs, ",")) var db *sql.DB if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { @@ -3395,40 +3414,24 @@ func TestConnectionAttributes(t *testing.T) { dbt := &DBTest{t, db} - var attrValue string - queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?" - rows := dbt.mustQuery(queryString, connAttrClientName) - if rows.Next() { - rows.Scan(&attrValue) - if attrValue != connAttrClientNameValue { - dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue) - } - } else { - dbt.Errorf("no data") - } - rows.Close() + queryString := "SELECT ATTR_NAME, ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID()" + rows := dbt.mustQuery(queryString) + defer rows.Close() - rows = dbt.mustQuery(queryString, attr1) - if rows.Next() { - rows.Scan(&attrValue) - if attrValue != value1 { - dbt.Errorf("expected %q, got %q", value1, attrValue) - } - } else { - dbt.Errorf("no data") + rowsMap := make(map[string]string) + for rows.Next() { + var attrName, attrValue string + rows.Scan(&attrName, &attrValue) + rowsMap[attrName] = attrValue } - rows.Close() - rows = dbt.mustQuery(queryString, attr2) - if rows.Next() { - rows.Scan(&attrValue) - if attrValue != value2 { - dbt.Errorf("expected %q, got %q", value2, attrValue) + connAttrs := append(append([]string{}, defaultAttrs...), customAttrs...) + expectedAttrValues := append(append([]string{}, defaultAttrValues...), customAttrValues...) + for i := range connAttrs { + if gotValue := rowsMap[connAttrs[i]]; gotValue != expectedAttrValues[i] { + dbt.Errorf("expected %q, got %q", expectedAttrValues[i], gotValue) } - } else { - dbt.Errorf("no data") } - rows.Close() } func TestErrorInMultiResult(t *testing.T) { diff --git a/packets.go b/packets.go index 0127232e..49e6bb05 100644 --- a/packets.go +++ b/packets.go @@ -292,15 +292,14 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string pktLen += n + 1 } - // 1 byte to store length of all key-values - // NOTE: Actually, this is length encoded integer. - // But we support only len(connAttrBuf) < 251 for now because takeSmallBuffer - // doesn't support buffer size more than 4096 bytes. - // TODO(methane): Rewrite buffer management. - pktLen += 1 + len(mc.connector.encodedAttributes) + // encode length of the connection attributes + var connAttrsLEIBuf [9]byte + connAttrsLen := len(mc.connector.encodedAttributes) + connAttrsLEI := appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen)) + pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes) // Calculate packet length and get buffer with that size - data, err := mc.buf.takeSmallBuffer(pktLen + 4) + data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection mc.cfg.Logger.Print(err) @@ -380,8 +379,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string pos++ // Connection Attributes - data[pos] = byte(len(mc.connector.encodedAttributes)) - pos++ + pos += copy(data[pos:], connAttrsLEI) pos += copy(data[pos:], []byte(mc.connector.encodedAttributes)) // Send Auth packet diff --git a/packets_test.go b/packets_test.go index e86ec584..fa4683ea 100644 --- a/packets_test.go +++ b/packets_test.go @@ -96,10 +96,7 @@ var _ net.Conn = new(mockConn) func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { conn := new(mockConn) - connector, err := newConnector(NewConfig()) - if err != nil { - panic(err) - } + connector := newConnector(NewConfig()) mc := &mysqlConn{ buf: newBuffer(conn), cfg: connector.cfg,