Skip to content

Commit

Permalink
Make logger configurable per connection
Browse files Browse the repository at this point in the history
  • Loading branch information
frozenbonito committed Apr 7, 2023
1 parent d83ecdc commit 4987352
Show file tree
Hide file tree
Showing 14 changed files with 232 additions and 32 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ Stan Putrya <root.vagner at gmail.com>
Stanley Gunawan <gunawan.stanley at gmail.com>
Steven Hartland <steven.hartland at multiplay.co.uk>
Tan Jinhua <312841925 at qq.com>
Tetsuro Aoki <t.aoki1130 at gmail.com>
Thomas Wodarek <wodarekwebpage at gmail.com>
Tim Ruffles <timruffles at gmail.com>
Tom Jenkinson <tom at tjenkinson.me>
Expand Down
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,16 @@ Note that this sets the location for time.Time values but does not change MySQL'

Please keep in mind, that param values must be [url.QueryEscape](https://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`.

##### `logging`

```
Type: bool / string
Valid Values: true, false, <name>
Default: true
```

`logging=false` disables logging. You can use a custom logger after registering it with [`mysql.RegisterLogger`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterLogger).

##### `maxAllowedPacket`
```
Type: decimal number
Expand Down
2 changes: 1 addition & 1 deletion auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
return enc, err

default:
errLog.Print("unknown auth plugin:", plugin)
mc.errLog().Print("unknown auth plugin:", plugin)
return nil, ErrUnknownPlugin
}
}
Expand Down
23 changes: 15 additions & 8 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {

func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
mc.errLog().Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
var q string
Expand Down Expand Up @@ -147,7 +147,7 @@ func (mc *mysqlConn) cleanup() {
return
}
if err := mc.netConn.Close(); err != nil {
errLog.Print(err)
mc.errLog().Print(err)
}
}

Expand All @@ -163,14 +163,14 @@ func (mc *mysqlConn) error() error {

func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
mc.errLog().Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := mc.writeCommandPacketStr(comStmtPrepare, query)
if err != nil {
// STMT_PREPARE is safe to retry. So we can return ErrBadConn here.
errLog.Print(err)
mc.errLog().Print(err)
return nil, driver.ErrBadConn
}

Expand Down Expand Up @@ -204,7 +204,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
buf, err := mc.buf.takeCompleteBuffer()
if err != nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(err)
mc.errLog().Print(err)
return "", ErrInvalidConn
}
buf = buf[:0]
Expand Down Expand Up @@ -296,7 +296,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin

func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
mc.errLog().Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
Expand Down Expand Up @@ -357,7 +357,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro

func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
mc.errLog().Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
Expand Down Expand Up @@ -451,7 +451,7 @@ func (mc *mysqlConn) finish() {
// Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
mc.errLog().Print(ErrInvalidConn)
return driver.ErrBadConn
}

Expand Down Expand Up @@ -648,3 +648,10 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {
func (mc *mysqlConn) IsValid() bool {
return !mc.closed.Load()
}

func (mc *mysqlConn) errLog() Logger {
if mc.cfg.Logger != nil {
return mc.cfg.Logger
}
return defaultLogger
}
1 change: 1 addition & 0 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ func TestPingErrInvalidConn(t *testing.T) {
buf: newBuffer(nc),
maxAllowedPacket: defaultMaxAllowedPacket,
closech: make(chan struct{}),
cfg: NewConfig(),
}

err := ms.Ping(context.Background())
Expand Down
9 changes: 8 additions & 1 deletion connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
authResp, err := mc.auth(authData, plugin)
if err != nil {
// try the default auth plugin, if using the requested plugin failed
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
c.errLog().Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
plugin = defaultAuthPlugin
authResp, err = mc.auth(authData, plugin)
if err != nil {
Expand Down Expand Up @@ -144,3 +144,10 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
func (c *connector) Driver() driver.Driver {
return &MySQLDriver{}
}

func (c *connector) errLog() Logger {
if c.cfg.Logger != nil {
return c.cfg.Logger
}
return defaultLogger
}
2 changes: 1 addition & 1 deletion driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1995,7 +1995,7 @@ func TestInsertRetrieveEscapedData(t *testing.T) {
func TestUnixSocketAuthFail(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
// Save the current logger so we can restore it.
oldLogger := errLog
oldLogger := defaultLogger

// Set a new logger so we can capture its output.
buffer := bytes.NewBuffer(make([]byte, 0, 64))
Expand Down
37 changes: 37 additions & 0 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ type Config struct {
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
LoggingConfig string // Logging configuration
Logger Logger // Logger

AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
Expand Down Expand Up @@ -153,6 +155,20 @@ func (cfg *Config) normalize() error {
}
}

if cfg.Logger == nil {
switch cfg.LoggingConfig {
case "true", "":
// use default logger
case "false":
cfg.Logger = defaultNopLogger
default:
cfg.Logger = getLogger(cfg.LoggingConfig)
if cfg.Logger == nil {
return errors.New("invalid value / unknown logger name: " + cfg.LoggingConfig)
}
}
}

return nil
}

Expand Down Expand Up @@ -282,6 +298,10 @@ func (cfg *Config) FormatDSN() string {
writeDSNParam(&buf, &hasParam, "maxAllowedPacket", strconv.Itoa(cfg.MaxAllowedPacket))
}

if len(cfg.LoggingConfig) > 0 {
writeDSNParam(&buf, &hasParam, "logging", url.QueryEscape(cfg.LoggingConfig))
}

// other params
if cfg.Params != nil {
var params []string
Expand Down Expand Up @@ -554,6 +574,23 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if err != nil {
return
}

case "logging":
boolValue, isBool := readBool(value)
if isBool {
if boolValue {
cfg.LoggingConfig = "true"
} else {
cfg.LoggingConfig = "false"
}
} else {
name, err := url.QueryUnescape(value)
if err != nil {
return fmt.Errorf("invalid value for logger name: %v", err)
}
cfg.LoggingConfig = name
}

default:
// lazy init
if cfg.Params == nil {
Expand Down
89 changes: 89 additions & 0 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ package mysql
import (
"crypto/tls"
"fmt"
"log"
"net/url"
"os"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -268,6 +270,93 @@ func TestDSNWithCustomTLSQueryEscape(t *testing.T) {
}
}

func TestDSNWithCustomLogger(t *testing.T) {
baseDSN := "User:password@tcp(localhost:5555)/dbname?logging="

t.Run("custom logger is registered", func(tt *testing.T) {
logger := log.New(os.Stderr, "", log.LstdFlags)

RegisterLogger("testKey", logger)
defer DeregisterLogger("testKey")

tst := baseDSN + "testKey"

cfg, err := ParseDSN(tst)
if err != nil {
tt.Fatal(err.Error())
}

if cfg.LoggingConfig != "testKey" {
tt.Errorf("unexpected cfg.LoggingConfig value: %q", cfg.LoggingConfig)
}
if cfg.Logger != logger {
tt.Error("logger pointer doesn't match")
}
})

t.Run("custom logger is missing", func(tt *testing.T) {
tst := baseDSN + "invalid_name"

cfg, err := ParseDSN(tst)
if err == nil {
tt.Errorf("invalid name in DSN (%s) but did not error. Got config: %#v", tst, cfg)
}
})
}

func TestDSNLoggingConfig(t *testing.T) {
t.Run("logging=true", func(tt *testing.T) {
dsn := "User:password@tcp(localhost:5555)/dbname?logging=true"

cfg, err := ParseDSN(dsn)
if err != nil {
tt.Fatal(err.Error())
}

if cfg.LoggingConfig != "true" {
tt.Errorf("unexpected cfg.LoggingConfig value: %q", cfg.LoggingConfig)
}
if cfg.Logger != nil {
tt.Error("cfg.Logger should be nil")
}
})

t.Run("logging=false", func(tt *testing.T) {
dsn := "User:password@tcp(localhost:5555)/dbname?logging=false"

cfg, err := ParseDSN(dsn)
if err != nil {
tt.Fatal(err.Error())
}

if cfg.LoggingConfig != "false" {
tt.Errorf("unexpected cfg.LoggingConfig value: %q", cfg.LoggingConfig)
}
if cfg.Logger != defaultNopLogger {
tt.Error("logger pointer doesn't match")
}
})
}

func TestDSNWithCustomLoggerQueryEscape(t *testing.T) {
const name = "&%!:"
dsn := "User:password@tcp(localhost:5555)/dbname?logging=" + url.QueryEscape(name)

logger := log.New(os.Stderr, "", log.LstdFlags)

RegisterLogger(name, logger)
defer DeregisterTLSConfig(name)

cfg, err := ParseDSN(dsn)
if err != nil {
t.Fatal(err.Error())
}

if cfg.Logger != logger {
t.Error("logger pointer doesn't match")
}
}

func TestDSNUnsafeCollation(t *testing.T) {
_, err := ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true")
if err != errInvalidDSNUnsafeCollation {
Expand Down
Loading

0 comments on commit 4987352

Please sign in to comment.