Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default connection attribute '_server_host' #1506

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ INADA Naoki <songofacandy at gmail.com>
Jacek Szwec <szwec.jacek at gmail.com>
James Harr <james.harr at gmail.com>
Janek Vedock <janekvedock at comcast.net>
Jason Ng <oblitorum at gmail.com>
Jean-Yves Pellé <jy at pelle.link>
Jeff Hodges <jeff at somethingsimilar.com>
Jeffrey Charles <jeffreycharles at gmail.com>
Expand Down Expand Up @@ -128,6 +129,7 @@ Keybase Inc.
Multiplay Ltd.
Percona LLC
Pivotal Inc.
Shattered Silicon Ltd.
Stripe Inc.
Zendesk Inc.
Dolthub Inc.
19 changes: 9 additions & 10 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ package mysql
import (
"context"
"database/sql/driver"
"fmt"
"net"
"os"
"strconv"
Expand All @@ -23,8 +22,8 @@ type connector struct {
encodedAttributes string // Encoded connection attributes.
}

func encodeConnectionAttributes(textAttributes string) string {
connAttrsBuf := make([]byte, 0, 251)
func encodeConnectionAttributes(cfg *Config) string {
connAttrsBuf := make([]byte, 0)

// default connection attributes
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
Expand All @@ -35,9 +34,12 @@ func encodeConnectionAttributes(textAttributes string) string {
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid()))
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrServerHost)
serverHost, _, _ := net.SplitHostPort(cfg.Addr)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, serverHost)
methane marked this conversation as resolved.
Show resolved Hide resolved

// user-defined connection attributes
for _, connAttr := range strings.Split(textAttributes, ",") {
for _, connAttr := range strings.Split(cfg.ConnectionAttributes, ",") {
k, v, found := strings.Cut(connAttr, ":")
if !found {
continue
Expand All @@ -49,15 +51,12 @@ func encodeConnectionAttributes(textAttributes string) string {
return string(connAttrsBuf)
}

func newConnector(cfg *Config) (*connector, error) {
encodedAttributes := encodeConnectionAttributes(cfg.ConnectionAttributes)
if len(encodedAttributes) > 250 {
return nil, fmt.Errorf("connection attributes are longer than 250 bytes: %dbytes (%q)", len(encodedAttributes), cfg.ConnectionAttributes)
}
func newConnector(cfg *Config) *connector {
encodedAttributes := encodeConnectionAttributes(cfg)
return &connector{
cfg: cfg,
encodedAttributes: encodedAttributes,
}, nil
}
}

// Connect implements driver.Connector interface.
Expand Down
7 changes: 2 additions & 5 deletions connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,13 @@ import (
)

func TestConnectorReturnsTimeout(t *testing.T) {
connector, err := newConnector(&Config{
connector := newConnector(&Config{
Net: "tcp",
Addr: "1.1.1.1:1234",
Timeout: 10 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}

_, err = connector.Connect(context.Background())
_, err := connector.Connect(context.Background())
if err == nil {
t.Fatal("error expected")
}
Expand Down
1 change: 1 addition & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const (
connAttrPlatform = "_platform"
connAttrPlatformValue = runtime.GOARCH
connAttrPid = "_pid"
connAttrServerHost = "_server_host"
)

// MySQL constants documentation:
Expand Down
9 changes: 3 additions & 6 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
if err != nil {
return nil, err
}
c, err := newConnector(cfg)
if err != nil {
return nil, err
}
c := newConnector(cfg)
return c.Connect(context.Background())
}

Expand All @@ -108,7 +105,7 @@ func NewConnector(cfg *Config) (driver.Connector, error) {
if err := cfg.normalize(); err != nil {
return nil, err
}
return newConnector(cfg)
return newConnector(cfg), nil
}

// OpenConnector implements driver.DriverContext.
Expand All @@ -117,5 +114,5 @@ func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
if err != nil {
return nil, err
}
return newConnector(cfg)
return newConnector(cfg), nil
}
64 changes: 40 additions & 24 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"os"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -3377,11 +3378,31 @@ func TestConnectionAttributes(t *testing.T) {
t.Skipf("MySQL server not running on %s", netAddr)
}

attr1 := "attr1"
value1 := "value1"
attr2 := "foo"
value2 := "boo"
dsn += fmt.Sprintf("&connectionAttributes=%s:%s,%s:%s", attr1, value1, attr2, value2)
defaultAttrs := []string{
connAttrClientName,
connAttrOS,
connAttrPlatform,
connAttrPid,
connAttrServerHost,
}
host, _, _ := net.SplitHostPort(addr)
defaultAttrValues := []string{
connAttrClientNameValue,
connAttrOSValue,
connAttrPlatformValue,
strconv.Itoa(os.Getpid()),
host,
}

customAttrs := []string{"attr1", "attr2"}
customAttrValues := []string{"foo", "bar"}

customAttrStrs := make([]string, len(customAttrs))
for i := range customAttrs {
customAttrStrs[i] = fmt.Sprintf("%s:%s", customAttrs[i], customAttrValues[i])
}

dsn += fmt.Sprintf("&connectionAttributes=%s", strings.Join(customAttrStrs, ","))

var db *sql.DB
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
Expand All @@ -3394,27 +3415,22 @@ func TestConnectionAttributes(t *testing.T) {

dbt := &DBTest{t, db}

var attrValue string
queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?"
rows := dbt.mustQuery(queryString, connAttrClientName)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != connAttrClientNameValue {
dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue)
}
} else {
dbt.Errorf("no data")
queryString := "SELECT ATTR_NAME, ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID()"
rows := dbt.mustQuery(queryString)
defer rows.Close()

rowsMap := make(map[string]string)
for rows.Next() {
var attrName, attrValue string
rows.Scan(&attrName, &attrValue)
rowsMap[attrName] = attrValue
}
rows.Close()

rows = dbt.mustQuery(queryString, attr2)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != value2 {
dbt.Errorf("expected %q, got %q", value2, attrValue)
connAttrs := append(append([]string{}, defaultAttrs...), customAttrs...)
expectedAttrValues := append(append([]string{}, defaultAttrValues...), customAttrValues...)
for i := range connAttrs {
if gotValue := rowsMap[connAttrs[i]]; gotValue != expectedAttrValues[i] {
dbt.Errorf("expected %s, got %s", expectedAttrValues[i], gotValue)
}
} else {
dbt.Errorf("no data")
}
rows.Close()
}
16 changes: 7 additions & 9 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,15 +292,14 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pktLen += n + 1
}

// 1 byte to store length of all key-values
// NOTE: Actually, this is length encoded integer.
// But we support only len(connAttrBuf) < 251 for now because takeSmallBuffer
// doesn't support buffer size more than 4096 bytes.
// TODO(methane): Rewrite buffer management.
pktLen += 1 + len(mc.connector.encodedAttributes)
// encode length of the connection attributes
var connAttrsLEIBuf [9]byte
connAttrsLen := len(mc.connector.encodedAttributes)
connAttrsLEI := appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen))
pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes)

// Calculate packet length and get buffer with that size
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.cfg.Logger.Print(err)
Expand Down Expand Up @@ -380,8 +379,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pos++

// Connection Attributes
data[pos] = byte(len(mc.connector.encodedAttributes))
pos++
pos += copy(data[pos:], connAttrsLEI)
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))

// Send Auth packet
Expand Down
5 changes: 1 addition & 4 deletions packets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,7 @@ var _ net.Conn = new(mockConn)

func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
conn := new(mockConn)
connector, err := newConnector(NewConfig())
if err != nil {
panic(err)
}
connector := newConnector(NewConfig())
mc := &mysqlConn{
buf: newBuffer(conn),
cfg: connector.cfg,
Expand Down
Loading