Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue91 #93

Merged
merged 5 commits into from Dec 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions go.mod
Expand Up @@ -4,5 +4,7 @@ go 1.12

require (
github.com/go-pkgz/expirable-cache v0.0.3
github.com/kr/pretty v0.1.0 // indirect
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
)
7 changes: 7 additions & 0 deletions go.sum
Expand Up @@ -2,6 +2,11 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-pkgz/expirable-cache v0.0.3 h1:rTh6qNPp78z0bQE6HDhXBHUwqnV9i09Vm6dksJLXQDc=
github.com/go-pkgz/expirable-cache v0.0.3/go.mod h1:+IauqN00R2FqNRLCLA+X5YljQJrwB179PfiAoMPlTlQ=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
Expand All @@ -13,5 +18,7 @@ golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 h1:NusfzzA6yGQ+ua51ck7E3omN
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
169 changes: 150 additions & 19 deletions tollbooth.go
Expand Up @@ -43,17 +43,142 @@ func LimitByKeys(lmt *limiter.Limiter, keys []string) *errors.HTTPError {
return nil
}

// ShouldSkipLimiter is a series of filter that decides if request should be limited or not.
func ShouldSkipLimiter(lmt *limiter.Limiter, r *http.Request) bool {
// ---------------------------------
// Filter by remote ip
// If we are unable to find remoteIP, skip limiter
remoteIP := libstring.RemoteIP(lmt.GetIPLookups(), lmt.GetForwardedForIndexFromBehind(), r)
if remoteIP == "" {
return true
}

// ---------------------------------
// Filter by request method
lmtMethods := lmt.GetMethods()
lmtMethodsIsSet := len(lmtMethods) > 0

if lmtMethodsIsSet {
// If request does not contain all of the methods in limiter,
// skip limiter
requestMethodDefinedInLimiter := libstring.StringInSlice(lmtMethods, r.Method)

if !requestMethodDefinedInLimiter {
return true
}
}

// ---------------------------------
// Filter by request headers
lmtHeaders := lmt.GetHeaders()
lmtHeadersIsSet := len(lmtHeaders) > 0

if lmtHeadersIsSet {
// If request does not contain all of the headers in limiter,
// skip limiter
requestHeadersDefinedInLimiter := false

for headerKey := range lmtHeaders {
reqHeaderValue := r.Header.Get(headerKey)
if reqHeaderValue != "" {
requestHeadersDefinedInLimiter = true
break
}
}

if !requestHeadersDefinedInLimiter {
return true
}

// ------------------------------
// If request contains the header key but not the values,
// skip limiter
requestHeadersDefinedInLimiter = false

for headerKey, headerValues := range lmtHeaders {
for _, headerValue := range headerValues {
if r.Header.Get(headerKey) == headerValue {
requestHeadersDefinedInLimiter = true
break
}
}
}

if !requestHeadersDefinedInLimiter {
return true
}
}

// ---------------------------------
// Filter by context values
lmtContextValues := lmt.GetContextValues()
lmtContextValuesIsSet := len(lmtContextValues) > 0

if lmtContextValuesIsSet {
// If request does not contain all of the contexts in limiter,
// skip limiter
requestContextValuesDefinedInLimiter := false

for contextKey := range lmtContextValues {
reqContextValue := fmt.Sprintf("%v", r.Context().Value(contextKey))
if reqContextValue != "" {
requestContextValuesDefinedInLimiter = true
break
}
}

if !requestContextValuesDefinedInLimiter {
return true
}

// ------------------------------
// If request contains the context key but not the values,
// skip limiter
requestContextValuesDefinedInLimiter = false

for contextKey, contextValues := range lmtContextValues {
for _, contextValue := range contextValues {
if r.Header.Get(contextKey) == contextValue {
requestContextValuesDefinedInLimiter = true
break
}
}
}

if !requestContextValuesDefinedInLimiter {
return true
}
}

// ---------------------------------
// Filter by basic auth usernames
lmtBasicAuthUsers := lmt.GetBasicAuthUsers()
lmtBasicAuthUsersIsSet := len(lmtBasicAuthUsers) > 0

if lmtBasicAuthUsersIsSet {
// If request does not contain all of the basic auth users in limiter,
// skip limiter
requestAuthUsernameDefinedInLimiter := false

username, _, ok := r.BasicAuth()
if ok && libstring.StringInSlice(lmtBasicAuthUsers, username) {
requestAuthUsernameDefinedInLimiter = true
}

if !requestAuthUsernameDefinedInLimiter {
return true
}
}

return false
}

// BuildKeys generates a slice of keys to rate-limit by given limiter and request structs.
func BuildKeys(lmt *limiter.Limiter, r *http.Request) [][]string {
remoteIP := libstring.RemoteIP(lmt.GetIPLookups(), lmt.GetForwardedForIndexFromBehind(), r)
path := r.URL.Path
sliceKeys := make([][]string, 0)

// Don't BuildKeys if remoteIP is blank.
if remoteIP == "" {
return sliceKeys
}

lmtMethods := lmt.GetMethods()
lmtHeaders := lmt.GetHeaders()
lmtContextValues := lmt.GetContextValues()
Expand All @@ -63,11 +188,6 @@ func BuildKeys(lmt *limiter.Limiter, r *http.Request) [][]string {
lmtContextValuesIsSet := len(lmtContextValues) > 0
lmtBasicAuthUsersIsSet := len(lmtBasicAuthUsers) > 0

method := ""
if lmtMethods != nil && libstring.StringInSlice(lmtMethods, r.Method) {
method = r.Method
}

usernameToLimit := ""
if lmtBasicAuthUsersIsSet {
username, _, ok := r.BasicAuth()
Expand Down Expand Up @@ -98,8 +218,6 @@ func BuildKeys(lmt *limiter.Limiter, r *http.Request) [][]string {
}
}
}
} else {
headerValuesToLimit = append(headerValuesToLimit, []string{"", ""})
}

contextValuesToLimit := [][]string{}
Expand All @@ -111,11 +229,11 @@ func BuildKeys(lmt *limiter.Limiter, r *http.Request) [][]string {
}

if len(contextValues) == 0 {
// If header values are empty, rate-limit all request containing headerKey.
// If context values are empty, rate-limit all request containing contextKey.
contextValuesToLimit = append(contextValuesToLimit, []string{contextKey, reqContextValue})

} else {
// If header values are not empty, rate-limit all request with headerKey and headerValues.
// If context values are not empty, rate-limit all request with contextKey and contextValues.
for _, contextValue := range contextValues {
if reqContextValue == contextValue {
contextValuesToLimit = append(contextValuesToLimit, []string{contextKey, contextValue})
Expand All @@ -124,16 +242,24 @@ func BuildKeys(lmt *limiter.Limiter, r *http.Request) [][]string {
}
}
}
} else {
contextValuesToLimit = append(contextValuesToLimit, []string{"", ""})
}

sliceKey := []string{remoteIP, path}

sliceKey = append(sliceKey, lmtMethods...)

for _, header := range headerValuesToLimit {
for _, contextValue := range contextValuesToLimit {
sliceKeys = append(sliceKeys, []string{remoteIP, path, method, header[0], header[1], contextValue[0], contextValue[1], usernameToLimit})
}
sliceKey = append(sliceKey, header[0], header[1])
}

for _, contextValue := range contextValuesToLimit {
sliceKey = append(sliceKey, contextValue[0], contextValue[1])
}

sliceKey = append(sliceKey, usernameToLimit)

sliceKeys = append(sliceKeys, sliceKey)

return sliceKeys
}

Expand All @@ -142,6 +268,11 @@ func BuildKeys(lmt *limiter.Limiter, r *http.Request) [][]string {
func LimitByRequest(lmt *limiter.Limiter, w http.ResponseWriter, r *http.Request) *errors.HTTPError {
setResponseHeaders(lmt, w, r)

shouldSkip := ShouldSkipLimiter(lmt, r)
if shouldSkip {
return nil
}

sliceKeys := BuildKeys(lmt, r)

// Loop sliceKeys and check if one of them has error.
Expand Down
75 changes: 74 additions & 1 deletion tollbooth_bug_report_test.go
Expand Up @@ -107,7 +107,7 @@ limiter.headers: %v`,
// 2nd, 429
response, _ = client.Do(request)
if response.StatusCode != http.StatusTooManyRequests {
t.Fatalf(`Both customer must pass rate limiter.
t.Fatalf(`Both customer must fail rate limiter.
Expected to receive: %v status code. Got: %v`,
http.StatusTooManyRequests, response.StatusCode)
}
Expand All @@ -131,3 +131,76 @@ Expected to receive: %v status code. Got: %v`,
}
}
}

func Test_Issue91_BrokenSetMethod_DontBlockGet(t *testing.T) {
requestsPerSecond := float64(1)

lmt := NewLimiter(requestsPerSecond, nil)
lmt.SetMethods([]string{"POST"})

methods := lmt.GetMethods()
if methods[0] != "POST" {
t.Fatalf("Failed to set methods correctly. Expected: POST Got: %v", methods[0])
}

// -------------------------------------------------------------------

handler := LimitHandler(lmt, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`hello world`))
}))

// Create GET HTTP request
req, _ := http.NewRequest("GET", "/doesntmatter", nil)
req.RemoteAddr = "127.0.0.1"

// We should never reach the limit because we are sending 10 GET requests and
// we are only limiting POST requests.
for i := 0; i < 10; i++ {
start := time.Now()

rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

if rr.Code != http.StatusOK {
t.Fatalf("Should be able to handle %v reqs/second. HTTP status: %v. Expected HTTP status: %v. Failed in %v microseconds", requestsPerSecond, rr.Code, http.StatusOK, time.Since(start).Microseconds())
}
}
}

func Test_Issue91_BrokenSetMethod_BlockPost(t *testing.T) {
requestsPerSecond := float64(1)

lmt := NewLimiter(requestsPerSecond, nil)
lmt.SetMethods([]string{"POST"})

limitReachedCounter := 0
lmt.SetOnLimitReached(func(w http.ResponseWriter, r *http.Request) {
limitReachedCounter++
})

methods := lmt.GetMethods()
if methods[0] != "POST" {
t.Fatalf("Failed to set methods correctly. Expected: POST Got: %v", methods[0])
}

// -------------------------------------------------------------------

handler := LimitHandler(lmt, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`hello world`))
}))

// Create POST HTTP request
req, _ := http.NewRequest("POST", "/blockmeafter2", nil)
req.RemoteAddr = "127.0.0.1"

// We should reach the limit because we are sending 2 POST requests and
// our limiter is 1 POST per second.
for i := 0; i < 2; i++ {
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
}

if limitReachedCounter == 0 {
t.Fatalf("Should have reached limit. Limit reached counter: %d", limitReachedCounter)
}
}