Skip to content

Commit

Permalink
Add *Conn.SetLogLevel
Browse files Browse the repository at this point in the history
Allow changing log level after connection is established. Because
log level and loggers can be set independently, it is now possible
to have a log level above none when there is a nil logger. This
means all log statements need to check for nil logger and an
appropriate log level. This check has been factored out into
*Conn.shouldLog.
  • Loading branch information
jackc committed Feb 13, 2016
1 parent cffae7f commit 0f7bf19
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 36 deletions.
49 changes: 32 additions & 17 deletions conn.go
Expand Up @@ -106,6 +106,7 @@ var ErrNotificationTimeout = errors.New("notification timeout")
var ErrDeadConn = errors.New("conn is dead")
var ErrTLSRefused = errors.New("server refused TLS connection")
var ErrConnBusy = errors.New("conn is busy")
var ErrInvalidLogLevel = errors.New("invalid log level")

type ProtocolError string

Expand All @@ -128,26 +129,23 @@ func Connect(config ConnConfig) (c *Conn, err error) {
c.logLevel = LogLevelDebug
}
c.logger = c.config.Logger
if c.logger == nil {
c.logLevel = LogLevelNone
}
c.mr.log = c.log
c.mr.logLevel = &c.logLevel
c.mr.shouldLog = c.shouldLog

if c.config.User == "" {
user, err := user.Current()
if err != nil {
return nil, err
}
c.config.User = user.Username
if c.logLevel >= LogLevelDebug {
if c.shouldLog(LogLevelDebug) {
c.log(LogLevelDebug, "Using default connection config", "User", c.config.User)
}
}

if c.config.Port == 0 {
c.config.Port = 5432
if c.logLevel >= LogLevelDebug {
if c.shouldLog(LogLevelDebug) {
c.log(LogLevelDebug, "Using default connection config", "Port", c.config.Port)
}
}
Expand Down Expand Up @@ -180,12 +178,12 @@ func Connect(config ConnConfig) (c *Conn, err error) {
}

func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) {
if c.logLevel >= LogLevelInfo {
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
}
c.conn, err = c.config.Dial(network, address)
if err != nil {
if c.logLevel >= LogLevelError {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, fmt.Sprintf("Connection failed: %v", err))
}
return err
Expand All @@ -194,7 +192,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
if c != nil && err != nil {
c.conn.Close()
c.alive = false
if c.logLevel >= LogLevelError {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, err.Error())
}
}
Expand All @@ -207,11 +205,11 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
c.lastActivityTime = time.Now()

if tlsConfig != nil {
if c.logLevel >= LogLevelDebug {
if c.shouldLog(LogLevelDebug) {
c.log(LogLevelDebug, "Starting TLS handshake")
}
if err := c.startTLS(tlsConfig); err != nil {
if c.logLevel >= LogLevelError {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, fmt.Sprintf("TLS failed: %v", err))
}
return err
Expand Down Expand Up @@ -262,7 +260,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
}
case readyForQuery:
c.rxReadyForQuery(r)
if c.logLevel >= LogLevelInfo {
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "Connection established")
}

Expand Down Expand Up @@ -338,7 +336,7 @@ func (c *Conn) Close() (err error) {
_, err = c.conn.Write(wbuf.buf)

c.die(errors.New("Closed"))
if c.logLevel >= LogLevelInfo {
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "Closed connection")
}
return err
Expand Down Expand Up @@ -548,7 +546,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
}
}

if c.logLevel >= LogLevelError {
if c.shouldLog(LogLevelError) {
defer func() {
if err != nil {
c.log(LogLevelError, fmt.Sprintf("Prepare `%s` as `%s` failed: %v", name, sql, err))
Expand Down Expand Up @@ -975,12 +973,12 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag

defer func() {
if err == nil {
if c.logLevel >= LogLevelInfo {
if c.shouldLog(LogLevelInfo) {
endTime := time.Now()
c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag)
}
} else {
if c.logLevel >= LogLevelError {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err)
}
}
Expand Down Expand Up @@ -1055,7 +1053,7 @@ func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {

c.lastActivityTime = time.Now()

if c.logLevel >= LogLevelTrace {
if c.shouldLog(LogLevelTrace) {
c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBytesRemaining", c.mr.msgBytesRemaining)
}

Expand Down Expand Up @@ -1252,6 +1250,10 @@ func (c *Conn) unlock() error {
return nil
}

func (c *Conn) shouldLog(lvl int) bool {
return c.logger != nil && c.logLevel >= lvl
}

func (c *Conn) log(lvl int, msg string, ctx ...interface{}) {
if c.Pid != 0 {
ctx = append(ctx, "pid", c.Pid)
Expand All @@ -1277,3 +1279,16 @@ func (c *Conn) SetLogger(logger Logger) Logger {
c.logger = logger
return oldLogger
}

// SetLogLevel replaces the current log level and returns the previous log
// level.
func (c *Conn) SetLogLevel(lvl int) (int, error) {
oldLvl := c.logLevel

if lvl < LogLevelNone || lvl > LogLevelTrace {
return oldLvl, ErrInvalidLogLevel
}

c.logLevel = lvl
return lvl, nil
}
79 changes: 74 additions & 5 deletions conn_test.go
Expand Up @@ -1345,12 +1345,28 @@ func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) {
}
}

type testLogger struct{}
type testLog struct {
lvl int
msg string
ctx []interface{}
}

func (l *testLogger) Debug(msg string, ctx ...interface{}) {}
func (l *testLogger) Info(msg string, ctx ...interface{}) {}
func (l *testLogger) Warn(msg string, ctx ...interface{}) {}
func (l *testLogger) Error(msg string, ctx ...interface{}) {}
type testLogger struct {
logs []testLog
}

func (l *testLogger) Debug(msg string, ctx ...interface{}) {
l.logs = append(l.logs, testLog{lvl: pgx.LogLevelDebug, msg: msg, ctx: ctx})
}
func (l *testLogger) Info(msg string, ctx ...interface{}) {
l.logs = append(l.logs, testLog{lvl: pgx.LogLevelInfo, msg: msg, ctx: ctx})
}
func (l *testLogger) Warn(msg string, ctx ...interface{}) {
l.logs = append(l.logs, testLog{lvl: pgx.LogLevelWarn, msg: msg, ctx: ctx})
}
func (l *testLogger) Error(msg string, ctx ...interface{}) {
l.logs = append(l.logs, testLog{lvl: pgx.LogLevelError, msg: msg, ctx: ctx})
}

func TestSetLogger(t *testing.T) {
t.Parallel()
Expand All @@ -1364,10 +1380,63 @@ func TestSetLogger(t *testing.T) {
t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", nil, oldLogger)
}

if err := conn.Listen("foo"); err != nil {
t.Fatal(err)
}

if len(l1.logs) == 0 {
t.Fatal("Expected new logger l1 to be called, but it wasn't")
}

l2 := &testLogger{}
oldLogger = conn.SetLogger(l2)
if oldLogger != l1 {
t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", l1, oldLogger)
}

if err := conn.Listen("bar"); err != nil {
t.Fatal(err)
}

if len(l2.logs) == 0 {
t.Fatal("Expected new logger l2 to be called, but it wasn't")
}
}

func TestSetLogLevel(t *testing.T) {
t.Parallel()

conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)

logger := &testLogger{}
conn.SetLogger(logger)

if _, err := conn.SetLogLevel(0); err != pgx.ErrInvalidLogLevel {
t.Fatal("SetLogLevel with invalid level did not return error")
}

if _, err := conn.SetLogLevel(pgx.LogLevelNone); err != nil {
t.Fatal(err)
}

if err := conn.Listen("foo"); err != nil {
t.Fatal(err)
}

if len(logger.logs) != 0 {
t.Fatalf("Expected logger not to be called, but it was: %v", logger.logs)
}

if _, err := conn.SetLogLevel(pgx.LogLevelTrace); err != nil {
t.Fatal(err)
}

if err := conn.Listen("bar"); err != nil {
t.Fatal(err)
}

if len(logger.logs) == 0 {
t.Fatal("Expected logger to be called, but it wasn't")
}
}
20 changes: 10 additions & 10 deletions msg_reader.go
Expand Up @@ -15,7 +15,7 @@ type msgReader struct {
msgBytesRemaining int32
err error
log func(lvl int, msg string, ctx ...interface{})
logLevel *int
shouldLog func(lvl int) bool
}

// Err returns any error that the msgReader has experienced
Expand All @@ -25,7 +25,7 @@ func (r *msgReader) Err() error {

// fatal tells r that a Fatal error has occurred
func (r *msgReader) fatal(err error) {
if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining)
}
r.err = err
Expand All @@ -38,7 +38,7 @@ func (r *msgReader) rxMsg() (byte, error) {
}

if r.msgBytesRemaining > 0 {
if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
}

Expand Down Expand Up @@ -68,7 +68,7 @@ func (r *msgReader) readByte() byte {
return 0
}

if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining)
}

Expand All @@ -95,7 +95,7 @@ func (r *msgReader) readInt16() int16 {

n := int16(binary.BigEndian.Uint16(b))

if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
}

Expand All @@ -122,7 +122,7 @@ func (r *msgReader) readInt32() int32 {

n := int32(binary.BigEndian.Uint32(b))

if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
}

Expand All @@ -149,7 +149,7 @@ func (r *msgReader) readInt64() int64 {

n := int64(binary.BigEndian.Uint64(b))

if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
}

Expand Down Expand Up @@ -180,7 +180,7 @@ func (r *msgReader) readCString() string {

s := string(b[0 : len(b)-1])

if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
}

Expand Down Expand Up @@ -214,7 +214,7 @@ func (r *msgReader) readString(count int32) string {

s := string(b)

if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
}

Expand All @@ -241,7 +241,7 @@ func (r *msgReader) readBytes(count int32) []byte {
return nil
}

if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining)
}

Expand Down
8 changes: 4 additions & 4 deletions query.go
Expand Up @@ -52,7 +52,7 @@ type Rows struct {
sql string
args []interface{}
log func(lvl int, msg string, ctx ...interface{})
logLevel *int
shouldLog func(lvl int) bool
unlockConn bool
}

Expand All @@ -78,11 +78,11 @@ func (rows *Rows) close() {
rows.closed = true

if rows.err == nil {
if *rows.logLevel >= LogLevelInfo {
if rows.shouldLog(LogLevelInfo) {
endTime := time.Now()
rows.log(LogLevelInfo, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args), "time", endTime.Sub(rows.startTime), "rowCount", rows.rowCount)
}
} else if *rows.logLevel >= LogLevelError {
} else if rows.shouldLog(LogLevelError) {
rows.log(LogLevelError, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args))
}
}
Expand Down Expand Up @@ -474,7 +474,7 @@ func (rows *Rows) Values() ([]interface{}, error) {
// from Query and handle it in *Rows.
func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
c.lastActivityTime = time.Now()
rows := &Rows{conn: c, startTime: c.lastActivityTime, sql: sql, args: args, log: c.log, logLevel: &c.logLevel}
rows := &Rows{conn: c, startTime: c.lastActivityTime, sql: sql, args: args, log: c.log, shouldLog: c.shouldLog}

if err := c.lock(); err != nil {
rows.abort(err)
Expand Down

0 comments on commit 0f7bf19

Please sign in to comment.