Skip to content

Commit

Permalink
Change LogLevelSetter to LogErrorFunc
Browse files Browse the repository at this point in the history
LogErrorFunc provides more general interface to handle errors
in the recover middleware.
  • Loading branch information
ant1k9 committed Jan 17, 2022
1 parent d4ad69d commit 4730deb
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 30 deletions.
31 changes: 18 additions & 13 deletions middleware/recover.go
Expand Up @@ -9,8 +9,8 @@ import (
)

type (
// LogLevelSetter defines a function to get log level for the recovered value.
LogLevelSetter func(value interface{}) log.Lvl
// LogErrorFunc defines a function for custom logging in the middleware.
LogErrorFunc func(c echo.Context, err error, stack []byte) error

// RecoverConfig defines the config for Recover middleware.
RecoverConfig struct {
Expand All @@ -34,9 +34,9 @@ type (
// Optional. Default value 0 (Print).
LogLevel log.Lvl

// LogLevelSetter defines a function to get log level for the recovered value.
// LogLevelSetter has higher priority than LogLevel when it's set.
LogLevelSetter LogLevelSetter
// LogErrorFunc defines a function for custom logging in the middleware.
// If it's set you don't need to provide LogLevel for config.
LogErrorFunc LogErrorFunc
}
)

Expand All @@ -48,7 +48,7 @@ var (
DisableStackAll: false,
DisablePrintStack: false,
LogLevel: 0,
LogLevelSetter: nil,
LogErrorFunc: nil,
}
)

Expand Down Expand Up @@ -81,15 +81,20 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
if !ok {
err = fmt.Errorf("%v", r)
}
logLevel := config.LogLevel
if config.LogLevelSetter != nil {
logLevel = config.LogLevelSetter(r)
}
stack := make([]byte, config.StackSize)
length := runtime.Stack(stack, !config.DisableStackAll)
var stack []byte
var length int

if !config.DisablePrintStack {
stack = make([]byte, config.StackSize)
length = runtime.Stack(stack, !config.DisableStackAll)
stack = stack[:length]
}

if config.LogErrorFunc != nil {
err = config.LogErrorFunc(c, err, stack)
} else if !config.DisablePrintStack {
msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length])
switch logLevel {
switch config.LogLevel {
case log.DEBUG:
c.Logger().Debug(msg)
case log.INFO:
Expand Down
54 changes: 37 additions & 17 deletions middleware/recover_test.go
Expand Up @@ -2,6 +2,7 @@ package middleware

import (
"bytes"
"errors"
"fmt"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -82,7 +83,7 @@ func TestRecoverWithConfig_LogLevel(t *testing.T) {
}
}

func TestRecoverWithConfig_LogLevelSetter(t *testing.T) {
func TestRecoverWithConfig_LogErrorFunc(t *testing.T) {
e := echo.New()
e.Logger.SetLevel(log.DEBUG)

Expand All @@ -93,24 +94,43 @@ func TestRecoverWithConfig_LogLevelSetter(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

testError := errors.New("test")
config := DefaultRecoverConfig
config.LogLevelSetter = func(value interface{}) log.Lvl {
if s, ok := value.(string); ok {
if s == "test" {
return log.DEBUG
}
config.LogErrorFunc = func(c echo.Context, err error, stack []byte) error {
msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack)
if errors.Is(err, testError) {
c.Logger().Debug(msg)
} else {
c.Logger().Error(msg)
}
return log.ERROR
return err
}
h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
panic("test")
}))

h(c)

assert.Equal(t, http.StatusInternalServerError, rec.Code)

output := buf.String()
assert.Contains(t, output, "PANIC RECOVER")
assert.Contains(t, output, `"level":"DEBUG"`)
t.Run("first branch case for LogErrorFunc", func(t *testing.T) {
buf.Reset()
h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
panic(testError)
}))

h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Code)

output := buf.String()
assert.Contains(t, output, "PANIC RECOVER")
assert.Contains(t, output, `"level":"DEBUG"`)
})

t.Run("else branch case for LogErrorFunc", func(t *testing.T) {
buf.Reset()
h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
panic("other")
}))

h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Code)

output := buf.String()
assert.Contains(t, output, "PANIC RECOVER")
assert.Contains(t, output, `"level":"ERROR"`)
})
}

0 comments on commit 4730deb

Please sign in to comment.