Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BeforeConnect callback to configuration object #1469

Merged
merged 6 commits into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ GitHub Inc.
Google Inc.
InfoSum Ltd.
Keybase Inc.
Microsoft Corp.
Multiplay Ltd.
Percona LLC
PingCAP Inc.
Expand Down
12 changes: 11 additions & 1 deletion connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,22 @@ func newConnector(cfg *Config) *connector {
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
var err error

// Invoke beforeConnect if present, with a copy of the configuration
cfg := c.cfg
if c.cfg.beforeConnect != nil {
cfg = c.cfg.Clone()
err = c.cfg.beforeConnect(ctx, cfg)
if err != nil {
return nil, err
}
}

// New mysqlConn
mc := &mysqlConn{
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
cfg: c.cfg,
cfg: cfg,
connector: c,
}
mc.parseTime = mc.cfg.ParseTime
Expand Down
34 changes: 34 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2044,6 +2044,40 @@ func TestCustomDial(t *testing.T) {
}
}

func TestBeforeConnect(t *testing.T) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
}

// dbname is set in the BeforeConnect handle
cfg, err := ParseDSN(fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, "_"))
if err != nil {
t.Fatalf("error parsing DSN: %v", err)
}

cfg.Apply(BeforeConnect(func(ctx context.Context, c *Config) error {
c.DBName = dbname
return nil
}))

connector, err := NewConnector(cfg)
if err != nil {
t.Fatalf("error creating connector: %v", err)
}

db := sql.OpenDB(connector)
defer db.Close()

var connectedDb string
err = db.QueryRow("SELECT DATABASE();").Scan(&connectedDb)
if err != nil {
t.Fatalf("error executing query: %v", err)
}
if connectedDb != dbname {
t.Fatalf("expected to connect to DB %s, but connected to %s instead", dbname, connectedDb)
}
}

func TestSQLInjection(t *testing.T) {
createTest := func(arg string) func(dbt *DBTest) {
return func(dbt *DBTest) {
Expand Down
14 changes: 12 additions & 2 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package mysql

import (
"bytes"
"context"
"crypto/rsa"
"crypto/tls"
"errors"
Expand Down Expand Up @@ -71,8 +72,9 @@ type Config struct {

// unexported fields. new options should be come here

pubKey *rsa.PublicKey // Server public key
timeTruncate time.Duration // Truncate time.Time values to the specified duration
beforeConnect func(context.Context, *Config) error // Invoked before a connection is established
pubKey *rsa.PublicKey // Server public key
timeTruncate time.Duration // Truncate time.Time values to the specified duration
Comment on lines +75 to +77
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The beforeConnect field is added to the Config struct to support the BeforeConnect callback functionality. Ensure that this field is correctly exposed and accessible from other parts of the code, such as in connector.go, where it is directly accessed. Consider making it public or providing a getter method if necessary to maintain encapsulation.

}

// Functional Options Pattern
Expand Down Expand Up @@ -112,6 +114,14 @@ func TimeTruncate(d time.Duration) Option {
}
}

// BeforeConnect sets the function to be invoked before a connection is established.
func BeforeConnect(fn func(context.Context, *Config) error) Option {
return func(cfg *Config) error {
cfg.beforeConnect = fn
return nil
}
}

func (cfg *Config) Clone() *Config {
cp := *cfg
if cp.TLS != nil {
Expand Down
Loading