Navigation Menu

Skip to content

Commit

Permalink
Added *CSRFHandler.ExemptFunc, for matching on more complex rules
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbellamy committed Jul 29, 2014
1 parent ec154df commit 73d8fde
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 21 deletions.
16 changes: 13 additions & 3 deletions exempt.go
Expand Up @@ -2,15 +2,21 @@ package nosurf

import (
"fmt"
"net/http"
pathModule "path"
"reflect"
"regexp"
)

// Checks if the given path is exempt from CSRF checks.
// The function checks the exact paths first,
// Checks if the given request is exempt from CSRF checks.
// It checks the ExemptFunc first, then the exact paths,
// then the globs and finally the regexps.
func (h *CSRFHandler) IsExempt(path string) bool {
func (h *CSRFHandler) IsExempt(r *http.Request) bool {
if h.exemptFunc != nil && h.exemptFunc(r) {
return true
}

path := r.URL.Path
if sContains(h.exemptPaths, path) {
return true
}
Expand Down Expand Up @@ -96,3 +102,7 @@ func (h *CSRFHandler) ExemptRegexps(res ...interface{}) {
h.ExemptRegexp(v)
}
}

func (h *CSRFHandler) ExemptFunc(fn func(r *http.Request) bool) {
h.exemptFunc = fn
}
56 changes: 39 additions & 17 deletions exempt_test.go
@@ -1,6 +1,7 @@
package nosurf

import (
"net/http"
"regexp"
"testing"
)
Expand All @@ -9,15 +10,16 @@ func TestExemptPath(t *testing.T) {
// the handler doesn't matter here, let's use nil
hand := New(nil)
path := "/home"
exempt, _ := http.NewRequest("GET", path, nil)

hand.ExemptPath(path)
if !hand.IsExempt(path) {
t.Errorf("%v is not exempt, but it should be", path)
if !hand.IsExempt(exempt) {
t.Errorf("%v is not exempt, but it should be", exempt.URL.Path)
}

other := "/faq"
other, _ := http.NewRequest("GET", "/faq", nil)
if hand.IsExempt(other) {
t.Errorf("%v is exempt, but it shouldn't be", other)
t.Errorf("%v is exempt, but it shouldn't be", other.URL.Path)
}
}

Expand All @@ -27,13 +29,13 @@ func TestExemptPaths(t *testing.T) {
hand.ExemptPaths(paths...)

for _, v := range paths {
if !hand.IsExempt(v) {
request, _ := http.NewRequest("GET", v, nil)
if !hand.IsExempt(request) {
t.Errorf("%v should be exempt, but it isn't", v)
}
}

other := "/accounts"

other, _ := http.NewRequest("GET", "/accounts", nil)
if hand.IsExempt(other) {
t.Errorf("%v is exempt, but it shouldn't be", other)
}
Expand All @@ -45,22 +47,22 @@ func TestExemptGlob(t *testing.T) {

hand.ExemptGlob(glob)

test := "/mail"
test, _ := http.NewRequest("GET", "/mail", nil)
if !hand.IsExempt(test) {
t.Errorf("%v should be exempt, but it isn't.", test)
}

test = "/nail"
test, _ = http.NewRequest("GET", "/nail", nil)
if !hand.IsExempt(test) {
t.Errorf("%v should be exempt, but it isn't.", test)
}

test = "/snail"
test, _ = http.NewRequest("GET", "/snail", nil)
if hand.IsExempt(test) {
t.Errorf("%v should not be exempt, but it is.", test)
}

test = "/mail/outbox"
test, _ = http.NewRequest("GET", "/mail/outbox", nil)
if hand.IsExempt(test) {
t.Errorf("%v should not be exempt, but it is.", test)
}
Expand All @@ -80,13 +82,15 @@ func TestExemptGlobs(t *testing.T) {
hand.ExemptGlobs(slice...)

for _, v := range matching {
if !hand.IsExempt(v) {
test, _ := http.NewRequest("GET", v, nil)
if !hand.IsExempt(test) {
t.Errorf("%v should be exempt, but it isn't.", v)
}
}

for _, v := range nonMatching {
if hand.IsExempt(v) {
test, _ := http.NewRequest("GET", v, nil)
if hand.IsExempt(test) {
t.Errorf("%v shouldn't be exempt, but it is", v)
}
}
Expand Down Expand Up @@ -191,23 +195,41 @@ func TestExemptRegexpMatching(t *testing.T) {
hand.ExemptRegexp(re)

// valid
test := "/mail"
test, _ := http.NewRequest("GET", "/mail", nil)
if !hand.IsExempt(test) {
t.Errorf("%v should be exempt, but it isn't.", test)
}

test = "/nail"
test, _ = http.NewRequest("GET", "/nail", nil)
if !hand.IsExempt(test) {
t.Errorf("%v should be exempt, but it isn't.", test)
}

test = "/mail/outbox"
test, _ = http.NewRequest("GET", "/mail/outbox", nil)
if hand.IsExempt(test) {
t.Errorf("%v shouldn't be exempt, but it is.", test)
}

test = "/snail"
test, _ = http.NewRequest("GET", "/snail", nil)
if hand.IsExempt(test) {
t.Errorf("%v shouldn't be exempt, but it is.", test)
}
}

func TestExemptFunc(t *testing.T) {
// the handler doesn't matter here, let's use nil
hand := New(nil)
hand.ExemptFunc(func(r *http.Request) bool {
return r.Method == "GET"
})

test, _ := http.NewRequest("GET", "/path", nil)
if !hand.IsExempt(test) {
t.Errorf("%v is not exempt, but it should be", test)
}

other, _ := http.NewRequest("POST", "/path", nil)
if hand.IsExempt(other) {
t.Errorf("%v is exempt, but it shouldn't be", other)
}
}
4 changes: 3 additions & 1 deletion handler.go
Expand Up @@ -52,6 +52,8 @@ type CSRFHandler struct {
exemptRegexps []*regexp.Regexp
// ...or a glob (as used by path.Match()).
exemptGlobs []string
// ...or a custom matcher function
exemptFunc func(r *http.Request) bool

// All of those will be matched against Request.URL.Path,
// So they should take the leading slash into account
Expand Down Expand Up @@ -111,7 +113,7 @@ func (h *CSRFHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

if h.IsExempt(r.URL.Path) {
if h.IsExempt(r) {
h.handleSuccess(w, r)
return
}
Expand Down

0 comments on commit 73d8fde

Please sign in to comment.