Skip to content
This repository has been archived by the owner on Dec 23, 2023. It is now read-only.

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
razonyang committed Feb 12, 2020
1 parent d35ace8 commit c0cf2d1
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 4 deletions.
9 changes: 5 additions & 4 deletions router.go
Expand Up @@ -353,7 +353,7 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
err := route.handle(ctx)
if err != nil {
r.handleError(ctx, err)
r.HandleError(ctx, err)
}
r.putContext(ctx)
return
Expand Down Expand Up @@ -405,7 +405,7 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if r.MethodNotAllowed != nil {
r.MethodNotAllowed.ServeHTTP(w, req)
} else {
r.handleError(ctx, ErrMethodNotAllowed)
r.HandleError(ctx, ErrMethodNotAllowed)
}
return
}
Expand All @@ -415,11 +415,12 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if r.NotFound != nil {
r.NotFound.ServeHTTP(w, req)
} else {
r.handleError(ctx, ErrNotFound)
r.HandleError(ctx, ErrNotFound)
}
}

func (r *Router) handleError(ctx *Context, err error) {
// HandleError handles error.
func (r *Router) HandleError(ctx *Context, err error) {
if r.ErrorHandler != nil {
r.ErrorHandler.Handle(ctx, err)
return
Expand Down
53 changes: 53 additions & 0 deletions router_test.go
Expand Up @@ -928,3 +928,56 @@ func ExampleRouter_ServeFiles() {
// such as "/favicon.ico".
router.NotFound = http.FileServer(http.Dir("public"))
}

type testErrorHandler struct {
status int
}

func (eh testErrorHandler) Handle(ctx *Context, err error) {
ctx.Error(err.Error(), eh.status)
}

func TestRouter_ErrorHandler(t *testing.T) {
router := NewRouter()
router.ErrorHandler = &testErrorHandler{http.StatusInternalServerError}
router.Get("/error/:msg", func(ctx *Context) error {
return errors.New(ctx.Params.String("msg"))
})

msgs := []string{"foo", "bar"}
for _, msg := range msgs {
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/error/"+msg, nil)
router.ServeHTTP(w, req)
if w.Body.String() != fmt.Sprintln(msg) {
t.Errorf("expected error body %q, got %q", fmt.Sprintln(msg), w.Body)
}
if w.Code != http.StatusInternalServerError {
t.Errorf("expected error status code %d, got %d", http.StatusInternalServerError, w.Code)
}
}
}

func TestRouter_HandleError(t *testing.T) {
router := NewRouter()
tests := []struct {
err error
body string
code int
}{
{errors.New("foo"), "foo", http.StatusInternalServerError},
{ErrNotFound, ErrNotFound.Error(), ErrNotFound.Code},
{ErrMethodNotAllowed, ErrMethodNotAllowed.Error(), ErrMethodNotAllowed.Code},
}
for _, test := range tests {
w := httptest.NewRecorder()
ctx := newContext(w, nil)
router.HandleError(ctx, test.err)
if w.Body.String() != fmt.Sprintln(test.body) {
t.Errorf("expected error body %q, got %q", test.body, w.Body)
}
if w.Code != test.code {
t.Errorf("expected error status code %d, got %d", test.code, w.Code)
}
}
}

0 comments on commit c0cf2d1

Please sign in to comment.