Skip to content

Commit

Permalink
Support for sub fiber's error handlers (#1560)
Browse files Browse the repository at this point in the history
- Mounted fiber and its sub apps error handlers are now saved a new
  errorHandlers map in App
- New public App.ErrorHandler method that wraps the logic for which
  error handler to user on any given context
- Error handler match logic based on request path <=> prefix accuracy
- Typo fixes
- Tests
  • Loading branch information
josebalius committed Oct 5, 2021
1 parent 2c6ffb7 commit 587f3ae
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 28 deletions.
102 changes: 76 additions & 26 deletions app.go
Expand Up @@ -109,6 +109,8 @@ type App struct {
getBytes func(s string) (b []byte)
// Converts byte slice to a string
getString func(b []byte) string
// mount prefix -> error handler
errorHandlers map[string]ErrorHandler
}

// Config is a struct holding the server settings.
Expand Down Expand Up @@ -426,9 +428,10 @@ func New(config ...Config) *App {
},
},
// Create config
config: Config{},
getBytes: utils.UnsafeBytes,
getString: utils.UnsafeString,
config: Config{},
getBytes: utils.UnsafeBytes,
getString: utils.UnsafeString,
errorHandlers: make(map[string]ErrorHandler),
}
// Override config if provided
if len(config) > 0 {
Expand Down Expand Up @@ -460,9 +463,11 @@ func New(config ...Config) *App {
if app.config.Immutable {
app.getBytes, app.getString = getBytesImmutable, getStringImmutable
}

if app.config.ErrorHandler == nil {
app.config.ErrorHandler = DefaultErrorHandler
}

if app.config.JSONEncoder == nil {
app.config.JSONEncoder = json.Marshal
}
Expand All @@ -487,7 +492,9 @@ func New(config ...Config) *App {

// Mount attaches another app instance as a sub-router along a routing path.
// It's very useful to split up a large API as many independent routers and
// compose them as a single service using Mount.
// compose them as a single service using Mount. The fiber's error handler and
// any of the fiber's sub apps are added to the application's error handlers
// to be invoked on errors that happen within the prefix route.
func (app *App) Mount(prefix string, fiber *App) Router {
stack := fiber.Stack()
for m := range stack {
Expand All @@ -497,6 +504,15 @@ func (app *App) Mount(prefix string, fiber *App) Router {
}
}

// Save the fiber's error handler and its sub apps
prefix = strings.TrimRight(prefix, "/")
if fiber.config.ErrorHandler != nil {
app.errorHandlers[prefix] = fiber.config.ErrorHandler
}
for mountedPrefixes, errHandler := range fiber.errorHandlers {
app.errorHandlers[prefix+mountedPrefixes] = errHandler
}

atomic.AddUint32(&app.handlerCount, fiber.handlerCount)

return app
Expand Down Expand Up @@ -822,7 +838,7 @@ func (app *App) init() *App {
// lock application
app.mutex.Lock()

// Only load templates if an view engine is specified
// Only load templates if a view engine is specified
if app.config.Views != nil {
if err := app.config.Views.Load(); err != nil {
fmt.Printf("views: %v\n", err)
Expand All @@ -833,26 +849,7 @@ func (app *App) init() *App {
app.server = &fasthttp.Server{
Logger: &disableLogger{},
LogAllErrors: false,
ErrorHandler: func(fctx *fasthttp.RequestCtx, err error) {
c := app.AcquireCtx(fctx)
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
err = ErrRequestHeaderFieldsTooLarge
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
err = ErrRequestTimeout
} else if err == fasthttp.ErrBodyTooLarge {
err = ErrRequestEntityTooLarge
} else if err == fasthttp.ErrGetOnly {
err = ErrMethodNotAllowed
} else if strings.Contains(err.Error(), "timeout") {
err = ErrRequestTimeout
} else {
err = ErrBadRequest
}
if catch := app.config.ErrorHandler(c, err); catch != nil {
_ = c.SendStatus(StatusInternalServerError)
}
app.ReleaseCtx(c)
},
ErrorHandler: app.serverErrorHandler,
}

// fasthttp server settings
Expand Down Expand Up @@ -880,6 +877,60 @@ func (app *App) init() *App {
return app
}

// ErrorHandler is the application's method in charge of finding the
// appropiate handler for the given request. It searches any mounted
// sub fibers by their prefixes and if it finds a match, it uses that
// error handler. Otherwise it uses the configured error handler for
// the app, which if not set is the DefaultErrorHandler.
func (app *App) ErrorHandler(ctx *Ctx, err error) error {
var (
mountedErrHandler ErrorHandler
mountedPrefixParts int
)

for prefix, errHandler := range app.errorHandlers {
if strings.HasPrefix(ctx.path, prefix) {
parts := len(strings.Split(prefix, "/"))
if mountedPrefixParts <= parts {
mountedErrHandler = errHandler
mountedPrefixParts = parts
}
}
}

if mountedErrHandler != nil {
return mountedErrHandler(ctx, err)
}

return app.config.ErrorHandler(ctx, err)
}

// serverErrorHandler is a wrapper around the application's error handler method
// user for the fasthttp server configuration. It maps a set of fasthttp errors to fiber
// errors before calling the application's error handler method.
func (app *App) serverErrorHandler(fctx *fasthttp.RequestCtx, err error) {
c := app.AcquireCtx(fctx)
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
err = ErrRequestHeaderFieldsTooLarge
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
err = ErrRequestTimeout
} else if err == fasthttp.ErrBodyTooLarge {
err = ErrRequestEntityTooLarge
} else if err == fasthttp.ErrGetOnly {
err = ErrMethodNotAllowed
} else if strings.Contains(err.Error(), "timeout") {
err = ErrRequestTimeout
} else {
err = ErrBadRequest
}

if catch := app.ErrorHandler(c, err); catch != nil {
_ = c.SendStatus(StatusInternalServerError)
}

app.ReleaseCtx(c)
}

// startupProcess Is the method which executes all the necessary processes just before the start of the server.
func (app *App) startupProcess() *App {
app.mutex.Lock()
Expand Down Expand Up @@ -961,7 +1012,6 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
}
}


scheme := "http"
if tls {
scheme = "https"
Expand Down
77 changes: 77 additions & 0 deletions app_test.go
Expand Up @@ -1439,3 +1439,80 @@ func Test_App_DisablePreParseMultipartForm(t *testing.T) {

utils.AssertEqual(t, testString, string(body))
}

func Test_App_UseMountedErrorHandler(t *testing.T) {
app := New()

fiber := New(Config{
ErrorHandler: func(ctx *Ctx, err error) error {
return ctx.Status(200).SendString("hi, i'm a custom error")
},
})
fiber.Get("/", func(c *Ctx) error {
return errors.New("something happened")
})

app.Mount("/api", fiber)

resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")

b, err := ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err, "iotuil.ReadAll()")
utils.AssertEqual(t, "hi, i'm a custom error", string(b), "Response body")
}

func Test_App_UseMountedErrorHandlerForBestPrefixMatch(t *testing.T) {
app := New()

tsf := func(ctx *Ctx, err error) error {
return ctx.Status(200).SendString("hi, i'm a custom sub sub fiber error")
}
tripleSubFiber := New(Config{
ErrorHandler: tsf,
})
tripleSubFiber.Get("/", func(c *Ctx) error {
return errors.New("something happened")
})

sf := func(ctx *Ctx, err error) error {
return ctx.Status(200).SendString("hi, i'm a custom sub fiber error")
}
subfiber := New(Config{
ErrorHandler: sf,
})
subfiber.Get("/", func(c *Ctx) error {
return errors.New("something happened")
})
subfiber.Mount("/third", tripleSubFiber)

f := func(ctx *Ctx, err error) error {
return ctx.Status(200).SendString("hi, i'm a custom error")
}
fiber := New(Config{
ErrorHandler: f,
})
fiber.Get("/", func(c *Ctx) error {
return errors.New("something happened")
})
fiber.Mount("/sub", subfiber)

app.Mount("/api", fiber)

resp, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub", nil))
utils.AssertEqual(t, nil, err, "/api/sub req")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")

b, err := ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err, "iotuil.ReadAll()")
utils.AssertEqual(t, "hi, i'm a custom sub fiber error", string(b), "Response body")

resp2, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub/third", nil))
utils.AssertEqual(t, nil, err, "/api/sub/third req")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")

b, err = ioutil.ReadAll(resp2.Body)
utils.AssertEqual(t, nil, err, "iotuil.ReadAll()")
utils.AssertEqual(t, "hi, i'm a custom sub sub fiber error", string(b), "Third fiber Response body")
}
2 changes: 1 addition & 1 deletion middleware/logger/logger.go
Expand Up @@ -145,7 +145,7 @@ func New(config ...Config) fiber.Handler {
}
}
// override error handler
errHandler = c.App().Config().ErrorHandler
errHandler = c.App().ErrorHandler
})

var start, stop time.Time
Expand Down
2 changes: 1 addition & 1 deletion router.go
Expand Up @@ -154,7 +154,7 @@ func (app *App) handler(rctx *fasthttp.RequestCtx) {
// Find match in stack
match, err := app.next(c)
if err != nil {
if catch := c.app.config.ErrorHandler(c, err); catch != nil {
if catch := c.app.ErrorHandler(c, err); catch != nil {
_ = c.SendStatus(StatusInternalServerError)
}
}
Expand Down

0 comments on commit 587f3ae

Please sign in to comment.