Skip to content

Commit

Permalink
Tweak and test for NotFound handler
Browse files Browse the repository at this point in the history
Apply NotFound whenever the pattern doesn't match; don't apply the
Allowed header convenience behavior in that case.
  • Loading branch information
cespare committed Feb 17, 2016
1 parent ac52254 commit b9a994a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 6 deletions.
19 changes: 13 additions & 6 deletions mux.go
Expand Up @@ -89,8 +89,14 @@ import (
// convenience, PatternServeMux will add the Allow header for requests that
// match a pattern for a method other than the method requested and set the
// Status to "405 Method Not Allowed".
//
// If the NotFound handler is set, then it is used whenever the pattern doesn't
// match the request path for the current method (and the Allow header is not
// altered).
type PatternServeMux struct {
// NotFound allows you to register a custom not found handler
// NotFound, if set, is used whenever the request doesn't match any
// pattern for its method. NotFound should be set before serving any
// requests.
NotFound http.Handler
handlers map[string][]*patHandler
}
Expand All @@ -113,6 +119,11 @@ func (p *PatternServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}

if p.NotFound != nil {
p.NotFound.ServeHTTP(w, r)
return
}

allowed := make([]string, 0, len(p.handlers))
for meth, handlers := range p.handlers {
if meth == r.Method {
Expand All @@ -127,11 +138,7 @@ func (p *PatternServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

if len(allowed) == 0 {
if p.NotFound != nil {
p.NotFound.ServeHTTP(w, r)
} else {
http.NotFound(w, r)
}
http.NotFound(w, r)
return
}

Expand Down
16 changes: 16 additions & 0 deletions mux_test.go
Expand Up @@ -200,6 +200,22 @@ func TestTail(t *testing.T) {
}
}

func TestNotFound(t *testing.T) {
p := New()
p.NotFound = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(123)
})
p.Post("/bar", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

for _, path := range []string{"/foo", "/bar"} {
res := httptest.NewRecorder()
p.ServeHTTP(res, newRequest("GET", path, nil))
if res.Code != 123 {
t.Errorf("for path %q: got code %d; want 123", path, res.Code)
}
}
}

func newRequest(method, urlStr string, body io.Reader) *http.Request {
req, err := http.NewRequest(method, urlStr, body)
if err != nil {
Expand Down

0 comments on commit b9a994a

Please sign in to comment.