From 73d8fde4a9148cff9db6f60aacf6c9c22dda1fad Mon Sep 17 00:00:00 2001 From: Paul Bellamy Date: Tue, 29 Jul 2014 18:43:44 +0100 Subject: [PATCH] Added *CSRFHandler.ExemptFunc, for matching on more complex rules --- exempt.go | 16 ++++++++++++--- exempt_test.go | 56 +++++++++++++++++++++++++++++++++++--------------- handler.go | 4 +++- 3 files changed, 55 insertions(+), 21 deletions(-) diff --git a/exempt.go b/exempt.go index 4f76940..f49a444 100644 --- a/exempt.go +++ b/exempt.go @@ -2,15 +2,21 @@ package nosurf import ( "fmt" + "net/http" pathModule "path" "reflect" "regexp" ) -// Checks if the given path is exempt from CSRF checks. -// The function checks the exact paths first, +// Checks if the given request is exempt from CSRF checks. +// It checks the ExemptFunc first, then the exact paths, // then the globs and finally the regexps. -func (h *CSRFHandler) IsExempt(path string) bool { +func (h *CSRFHandler) IsExempt(r *http.Request) bool { + if h.exemptFunc != nil && h.exemptFunc(r) { + return true + } + + path := r.URL.Path if sContains(h.exemptPaths, path) { return true } @@ -96,3 +102,7 @@ func (h *CSRFHandler) ExemptRegexps(res ...interface{}) { h.ExemptRegexp(v) } } + +func (h *CSRFHandler) ExemptFunc(fn func(r *http.Request) bool) { + h.exemptFunc = fn +} diff --git a/exempt_test.go b/exempt_test.go index 905d04d..9a12254 100644 --- a/exempt_test.go +++ b/exempt_test.go @@ -1,6 +1,7 @@ package nosurf import ( + "net/http" "regexp" "testing" ) @@ -9,15 +10,16 @@ func TestExemptPath(t *testing.T) { // the handler doesn't matter here, let's use nil hand := New(nil) path := "/home" + exempt, _ := http.NewRequest("GET", path, nil) hand.ExemptPath(path) - if !hand.IsExempt(path) { - t.Errorf("%v is not exempt, but it should be", path) + if !hand.IsExempt(exempt) { + t.Errorf("%v is not exempt, but it should be", exempt.URL.Path) } - other := "/faq" + other, _ := http.NewRequest("GET", "/faq", nil) if hand.IsExempt(other) { - t.Errorf("%v is exempt, but it shouldn't be", other) + t.Errorf("%v is exempt, but it shouldn't be", other.URL.Path) } } @@ -27,13 +29,13 @@ func TestExemptPaths(t *testing.T) { hand.ExemptPaths(paths...) for _, v := range paths { - if !hand.IsExempt(v) { + request, _ := http.NewRequest("GET", v, nil) + if !hand.IsExempt(request) { t.Errorf("%v should be exempt, but it isn't", v) } } - other := "/accounts" - + other, _ := http.NewRequest("GET", "/accounts", nil) if hand.IsExempt(other) { t.Errorf("%v is exempt, but it shouldn't be", other) } @@ -45,22 +47,22 @@ func TestExemptGlob(t *testing.T) { hand.ExemptGlob(glob) - test := "/mail" + test, _ := http.NewRequest("GET", "/mail", nil) if !hand.IsExempt(test) { t.Errorf("%v should be exempt, but it isn't.", test) } - test = "/nail" + test, _ = http.NewRequest("GET", "/nail", nil) if !hand.IsExempt(test) { t.Errorf("%v should be exempt, but it isn't.", test) } - test = "/snail" + test, _ = http.NewRequest("GET", "/snail", nil) if hand.IsExempt(test) { t.Errorf("%v should not be exempt, but it is.", test) } - test = "/mail/outbox" + test, _ = http.NewRequest("GET", "/mail/outbox", nil) if hand.IsExempt(test) { t.Errorf("%v should not be exempt, but it is.", test) } @@ -80,13 +82,15 @@ func TestExemptGlobs(t *testing.T) { hand.ExemptGlobs(slice...) for _, v := range matching { - if !hand.IsExempt(v) { + test, _ := http.NewRequest("GET", v, nil) + if !hand.IsExempt(test) { t.Errorf("%v should be exempt, but it isn't.", v) } } for _, v := range nonMatching { - if hand.IsExempt(v) { + test, _ := http.NewRequest("GET", v, nil) + if hand.IsExempt(test) { t.Errorf("%v shouldn't be exempt, but it is", v) } } @@ -191,23 +195,41 @@ func TestExemptRegexpMatching(t *testing.T) { hand.ExemptRegexp(re) // valid - test := "/mail" + test, _ := http.NewRequest("GET", "/mail", nil) if !hand.IsExempt(test) { t.Errorf("%v should be exempt, but it isn't.", test) } - test = "/nail" + test, _ = http.NewRequest("GET", "/nail", nil) if !hand.IsExempt(test) { t.Errorf("%v should be exempt, but it isn't.", test) } - test = "/mail/outbox" + test, _ = http.NewRequest("GET", "/mail/outbox", nil) if hand.IsExempt(test) { t.Errorf("%v shouldn't be exempt, but it is.", test) } - test = "/snail" + test, _ = http.NewRequest("GET", "/snail", nil) if hand.IsExempt(test) { t.Errorf("%v shouldn't be exempt, but it is.", test) } } + +func TestExemptFunc(t *testing.T) { + // the handler doesn't matter here, let's use nil + hand := New(nil) + hand.ExemptFunc(func(r *http.Request) bool { + return r.Method == "GET" + }) + + test, _ := http.NewRequest("GET", "/path", nil) + if !hand.IsExempt(test) { + t.Errorf("%v is not exempt, but it should be", test) + } + + other, _ := http.NewRequest("POST", "/path", nil) + if hand.IsExempt(other) { + t.Errorf("%v is exempt, but it shouldn't be", other) + } +} diff --git a/handler.go b/handler.go index cc184f2..14a26b0 100644 --- a/handler.go +++ b/handler.go @@ -52,6 +52,8 @@ type CSRFHandler struct { exemptRegexps []*regexp.Regexp // ...or a glob (as used by path.Match()). exemptGlobs []string + // ...or a custom matcher function + exemptFunc func(r *http.Request) bool // All of those will be matched against Request.URL.Path, // So they should take the leading slash into account @@ -111,7 +113,7 @@ func (h *CSRFHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - if h.IsExempt(r.URL.Path) { + if h.IsExempt(r) { h.handleSuccess(w, r) return }