From b9a994a8f6a5945800427435c904865e76dd6d42 Mon Sep 17 00:00:00 2001 From: Caleb Spare Date: Wed, 17 Feb 2016 00:48:21 -0800 Subject: [PATCH] Tweak and test for NotFound handler Apply NotFound whenever the pattern doesn't match; don't apply the Allowed header convenience behavior in that case. --- mux.go | 19 +++++++++++++------ mux_test.go | 16 ++++++++++++++++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/mux.go b/mux.go index c5e0edc..035f9c2 100644 --- a/mux.go +++ b/mux.go @@ -89,8 +89,14 @@ import ( // convenience, PatternServeMux will add the Allow header for requests that // match a pattern for a method other than the method requested and set the // Status to "405 Method Not Allowed". +// +// If the NotFound handler is set, then it is used whenever the pattern doesn't +// match the request path for the current method (and the Allow header is not +// altered). type PatternServeMux struct { - // NotFound allows you to register a custom not found handler + // NotFound, if set, is used whenever the request doesn't match any + // pattern for its method. NotFound should be set before serving any + // requests. NotFound http.Handler handlers map[string][]*patHandler } @@ -113,6 +119,11 @@ func (p *PatternServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } + if p.NotFound != nil { + p.NotFound.ServeHTTP(w, r) + return + } + allowed := make([]string, 0, len(p.handlers)) for meth, handlers := range p.handlers { if meth == r.Method { @@ -127,11 +138,7 @@ func (p *PatternServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if len(allowed) == 0 { - if p.NotFound != nil { - p.NotFound.ServeHTTP(w, r) - } else { - http.NotFound(w, r) - } + http.NotFound(w, r) return } diff --git a/mux_test.go b/mux_test.go index a795c4a..eeb5a96 100644 --- a/mux_test.go +++ b/mux_test.go @@ -200,6 +200,22 @@ func TestTail(t *testing.T) { } } +func TestNotFound(t *testing.T) { + p := New() + p.NotFound = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(123) + }) + p.Post("/bar", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + for _, path := range []string{"/foo", "/bar"} { + res := httptest.NewRecorder() + p.ServeHTTP(res, newRequest("GET", path, nil)) + if res.Code != 123 { + t.Errorf("for path %q: got code %d; want 123", path, res.Code) + } + } +} + func newRequest(method, urlStr string, body io.Reader) *http.Request { req, err := http.NewRequest(method, urlStr, body) if err != nil {