Skip to content

Commit

Permalink
Make BeforeConnect a functional option
Browse files Browse the repository at this point in the history
  • Loading branch information
methane committed Mar 6, 2024
1 parent 078d1fc commit 6a4e24e
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 25 deletions.
6 changes: 3 additions & 3 deletions connector.go
Expand Up @@ -66,11 +66,11 @@ func newConnector(cfg *Config) *connector {
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
var err error

// Invoke BeforeConnect if present, with a copy of the configuration
// Invoke beforeConnect if present, with a copy of the configuration
cfg := c.cfg
if c.cfg.BeforeConnect != nil {
if c.cfg.beforeConnect != nil {
cfg = c.cfg.Clone()
err = c.cfg.BeforeConnect(ctx, cfg)
err = c.cfg.beforeConnect(ctx, cfg)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions driver_test.go
Expand Up @@ -2055,10 +2055,10 @@ func TestBeforeConnect(t *testing.T) {
t.Fatalf("error parsing DSN: %v", err)
}

cfg.BeforeConnect = func(ctx context.Context, c *Config) error {
cfg.Apply(BeforeConnect(func(ctx context.Context, c *Config) error {
c.DBName = dbname
return nil
}
}))

connector, err := NewConnector(cfg)
if err != nil {
Expand Down
48 changes: 28 additions & 20 deletions dsn.go
Expand Up @@ -37,24 +37,23 @@ var (
type Config struct {
// non boolean fields

User string // Username
Passwd string // Password (requires User)
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix")
DBName string // Database name
Params map[string]string // Connection parameters
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger
BeforeConnect func(context.Context, *Config) error // Invoked before a connection is established
User string // Username
Passwd string // Password (requires User)
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix")
DBName string // Database name
Params map[string]string // Connection parameters
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger

// boolean fields

Expand All @@ -73,8 +72,9 @@ type Config struct {

// unexported fields. new options should be come here

pubKey *rsa.PublicKey // Server public key
timeTruncate time.Duration // Truncate time.Time values to the specified duration
beforeConnect func(context.Context, *Config) error // Invoked before a connection is established
pubKey *rsa.PublicKey // Server public key
timeTruncate time.Duration // Truncate time.Time values to the specified duration
}

// Functional Options Pattern
Expand Down Expand Up @@ -114,6 +114,14 @@ func TimeTruncate(d time.Duration) Option {
}
}

// BeforeConnect sets the function to be invoked before a connection is established.
func BeforeConnect(fn func(context.Context, *Config) error) Option {
return func(cfg *Config) error {
cfg.beforeConnect = fn
return nil
}
}

func (cfg *Config) Clone() *Config {
cp := *cfg
if cp.TLS != nil {
Expand Down

0 comments on commit 6a4e24e

Please sign in to comment.