Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add connection attributes #1389

Merged
merged 9 commits into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ jobs:
; TestConcurrent fails if max_connections is too large
max_connections=50
local_infile=1
performance_schema=on
- name: setup database
run: |
mysql --user 'root' --host '127.0.0.1' -e 'create database gotest;'
Expand Down
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,15 @@ Default: 0

I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.

##### `connectionAttributes`

```
Type: comma-delimited string of user-defined "key:value" pairs
Valid Values: (<name1>:<value1>,<name2>:<value2>,...)
Default: none
```

[Connection attributes](https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html) are key-value pairs that application programs can pass to the server at connect time.

##### System Variables

Expand Down
1 change: 1 addition & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type mysqlConn struct {
affectedRows uint64
insertId uint64
cfg *Config
connector *connector
maxAllowedPacket int
maxWriteSize int
writeTimeout time.Duration
Expand Down
46 changes: 45 additions & 1 deletion connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,54 @@ package mysql
import (
"context"
"database/sql/driver"
"fmt"
"net"
"os"
"strconv"
"strings"
)

type connector struct {
cfg *Config // immutable private copy.
cfg *Config // immutable private copy.
encodedAttributes string // Encoded connection attributes.
}

func encodeConnectionAttributes(textAttributes string) string {
connAttrsBuf := make([]byte, 0, 251)

// default connection attributes
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid()))

// user-defined connection attributes
for _, connAttr := range strings.Split(textAttributes, ",") {
attr := strings.SplitN(connAttr, ":", 2)
if len(attr) != 2 {
continue
}
for _, v := range attr {
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v)
}
}

return string(connAttrsBuf)
}

func newConnector(cfg *Config) (*connector, error) {
encodedAttributes := encodeConnectionAttributes(cfg.ConnectionAttributes)
if len(encodedAttributes) > 250 {
return nil, fmt.Errorf("connection attributes are longer than 250 bytes: %dbytes (%q)", len(encodedAttributes), cfg.ConnectionAttributes)
}
return &connector{
cfg: cfg,
encodedAttributes: encodedAttributes,
}, nil
}

// Connect implements driver.Connector interface.
Expand All @@ -29,6 +72,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
cfg: c.cfg,
connector: c,
}
mc.parseTime = mc.cfg.ParseTime

Expand Down
9 changes: 6 additions & 3 deletions connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ import (
)

func TestConnectorReturnsTimeout(t *testing.T) {
connector := &connector{&Config{
connector, err := newConnector(&Config{
Net: "tcp",
Addr: "1.1.1.1:1234",
Timeout: 10 * time.Millisecond,
}}
})
if err != nil {
t.Fatal(err)
}

_, err := connector.Connect(context.Background())
_, err = connector.Connect(context.Background())
if err == nil {
t.Fatal("error expected")
}
Expand Down
12 changes: 12 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,24 @@

package mysql

import "runtime"

const (
defaultAuthPlugin = "mysql_native_password"
defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355
minProtocolVersion = 10
maxPacketSize = 1<<24 - 1
timeFormat = "2006-01-02 15:04:05.999999"

// Connection attributes
// See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available
connAttrClientName = "_client_name"
methane marked this conversation as resolved.
Show resolved Hide resolved
connAttrClientNameValue = "Go-MySQL-Driver"
connAttrOS = "_os"
connAttrOSValue = runtime.GOOS
connAttrPlatform = "_platform"
connAttrPlatformValue = runtime.GOARCH
connAttrPid = "_pid"
)

// MySQL constants documentation:
Expand Down
11 changes: 5 additions & 6 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
if err != nil {
return nil, err
}
c := &connector{
cfg: cfg,
c, err := newConnector(cfg)
if err != nil {
return nil, err
}
return c.Connect(context.Background())
}
Expand All @@ -103,7 +104,7 @@ func NewConnector(cfg *Config) (driver.Connector, error) {
if err := cfg.normalize(); err != nil {
return nil, err
}
return &connector{cfg: cfg}, nil
return newConnector(cfg)
}

// OpenConnector implements driver.DriverContext.
Expand All @@ -112,7 +113,5 @@ func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
if err != nil {
return nil, err
}
return &connector{
cfg: cfg,
}, nil
return newConnector(cfg)
}
47 changes: 47 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3214,3 +3214,50 @@ func TestConnectorTimeoutsWatchCancel(t *testing.T) {
t.Errorf("connection not closed")
}
}

func TestConnectionAttributes(t *testing.T) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
}

attr1 := "attr1"
value1 := "value1"
attr2 := "foo"
value2 := "boo"
dsn += fmt.Sprintf("&connectionAttributes=%s:%s,%s:%s", attr1, value1, attr2, value2)

var db *sql.DB
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
db, err = sql.Open("mysql", dsn)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
defer db.Close()
}

dbt := &DBTest{t, db}

var attrValue string
methane marked this conversation as resolved.
Show resolved Hide resolved
queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?"
rows := dbt.mustQuery(queryString, connAttrClientName)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != connAttrClientNameValue {
dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue)
}
} else {
dbt.Errorf("no data")
}
rows.Close()

rows = dbt.mustQuery(queryString, attr2)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != value2 {
dbt.Errorf("expected %q, got %q", value2, attrValue)
}
} else {
dbt.Errorf("no data")
}
rows.Close()
}
40 changes: 23 additions & 17 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,24 @@ var (
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type Config struct {
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger

AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
Expand Down Expand Up @@ -560,6 +561,11 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if err != nil {
return
}

// Connection attributes
case "connectionAttributes":
cfg.ConnectionAttributes = value

default:
// lazy init
if cfg.Params == nil {
Expand Down
13 changes: 13 additions & 0 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientLocalFiles |
clientPluginAuth |
clientMultiResults |
clientConnectAttrs |
mc.flags&clientLongFlag

if mc.cfg.ClientFoundRows {
Expand Down Expand Up @@ -318,6 +319,13 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pktLen += n + 1
}

// 1 byte to store length of all key-values
// NOTE: Actually, this is length encoded integer.
// But we support only len(connAttrBuf) < 251 for now because takeSmallBuffer
// doesn't support buffer size more than 4096 bytes.
// TODO(methane): Rewrite buffer management.
pktLen += 1 + len(mc.connector.encodedAttributes)

// Calculate packet length and get buffer with that size
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
if err != nil {
Expand Down Expand Up @@ -394,6 +402,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
data[pos] = 0x00
pos++

// Connection Attributes
data[pos] = byte(len(mc.connector.encodedAttributes))
pos++
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))

// Send Auth packet
return mc.writePacket(data[:pos])
}
Expand Down
7 changes: 6 additions & 1 deletion packets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,14 @@ var _ net.Conn = new(mockConn)

func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
conn := new(mockConn)
connector, err := newConnector(NewConfig())
if err != nil {
panic(err)
}
mc := &mysqlConn{
buf: newBuffer(conn),
cfg: NewConfig(),
cfg: connector.cfg,
connector: connector,
netConn: conn,
closech: make(chan struct{}),
maxAllowedPacket: defaultMaxAllowedPacket,
Expand Down
5 changes: 5 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,11 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
}

func appendLengthEncodedString(b []byte, s string) []byte {
b = appendLengthEncodedInteger(b, uint64(len(s)))
return append(b, s...)
}

// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.
// If cap(buf) is not enough, reallocate new buffer.
func reserveBuffer(buf []byte, appendSize int) []byte {
Expand Down