diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index f9e8caafe..88591cea7 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -1,6 +1,7 @@ package middleware import ( + "bytes" "encoding/base64" "net/http" "strconv" @@ -15,7 +16,8 @@ type ( // Skipper defines a function to skip middleware. Skipper Skipper - // Validator is a function to validate BasicAuth credentials. + // Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic + // auth headers this function would be called once for each header until first valid result is returned // Required. Validator BasicAuthValidator @@ -71,30 +73,36 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { return next(c) } - auth := c.Request().Header.Get(echo.HeaderAuthorization) + var lastError error l := len(basic) + for i, auth := range c.Request().Header[echo.HeaderAuthorization] { + if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) { + continue + } - if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { // Invalid base64 shouldn't be treated as error // instead should be treated as invalid client input - b, err := base64.StdEncoding.DecodeString(auth[l+1:]) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest).SetInternal(err) + b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:]) + if errDecode != nil { + lastError = echo.NewHTTPError(http.StatusBadRequest).WithInternal(errDecode) + continue } - - cred := string(b) - for i := 0; i < len(cred); i++ { - if cred[i] == ':' { - // Verify credentials - valid, err := config.Validator(cred[:i], cred[i+1:], c) - if err != nil { - return err - } else if valid { - return next(c) - } - break + idx := bytes.IndexByte(b, ':') + if idx >= 0 { + valid, errValidate := config.Validator(string(b[:idx]), string(b[idx+1:]), c) + if errValidate != nil { + lastError = errValidate + } else if valid { + return next(c) } } + if i >= headerCountLimit { // guard against attacker maliciously sending huge amount of invalid headers + break + } + } + + if lastError != nil { + return lastError } realm := defaultRealm diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 20e769214..9a9ecd7fd 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -2,6 +2,7 @@ package middleware import ( "encoding/base64" + "errors" "net/http" "net/http/httptest" "strings" @@ -11,11 +12,139 @@ import ( "github.com/stretchr/testify/assert" ) +func TestBasicAuthWithConfig(t *testing.T) { + validatorFunc := func(u, p string, c echo.Context) (bool, error) { + if u == "joe" && p == "secret" { + return true, nil + } + if u == "error" { + return false, errors.New(p) + } + return false, nil + } + defaultConfig := BasicAuthConfig{Validator: validatorFunc} + + // we can not add OK value here because ranging over map returns random order. We just try to trigger break + tooManyAuths := make([]string, 0) + for i := 0; i < extractorLimit+2; i++ { + tooManyAuths = append(tooManyAuths, basic+" "+base64.StdEncoding.EncodeToString([]byte("nope:nope"))) + } + + var testCases = []struct { + name string + givenConfig BasicAuthConfig + whenAuth []string + expectHeader string + expectErr string + }{ + { + name: "ok", + givenConfig: defaultConfig, + whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "ok, from multiple auth headers one is ok", + givenConfig: defaultConfig, + whenAuth: []string{ + "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), // different type + basic + " NOT_BASE64", // invalid basic auth + basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), // OK + }, + }, + { + name: "nok, invalid Authorization header", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + expectHeader: basic + ` realm=Restricted`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "nok, not base64 Authorization header", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"}, + expectErr: "code=400, message=Bad Request, internal=illegal base64 data at input byte 3", + }, + { + name: "nok, missing Authorization header", + givenConfig: defaultConfig, + expectHeader: basic + ` realm=Restricted`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "nok, too many invalid Authorization header", + givenConfig: defaultConfig, + whenAuth: tooManyAuths, + expectHeader: basic + ` realm=Restricted`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "ok, realm", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "ok, realm, case-insensitive header scheme", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "nok, realm, invalid Authorization header", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + expectHeader: basic + ` realm="someRealm"`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "nok, validator func returns an error", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))}, + expectErr: "my_error", + }, + { + name: "ok, skipped", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool { + return true + }}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + mw := BasicAuthWithConfig(tc.givenConfig) + + h := mw(func(c echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + + if len(tc.whenAuth) != 0 { + for _, a := range tc.whenAuth { + req.Header.Add(echo.HeaderAuthorization, a) + } + } + err := h(e.NewContext(req, res)) + + if tc.expectErr != "" { + assert.Equal(t, http.StatusOK, res.Code) + assert.EqualError(t, err, tc.expectErr) + } else { + assert.Equal(t, http.StatusTeapot, res.Code) + assert.NoError(t, err) + } + if tc.expectHeader != "" { + assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate)) + } + }) + } +} + func TestBasicAuth(t *testing.T) { e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - c := e.NewContext(req, res) f := func(u, p string, c echo.Context) (bool, error) { if u == "joe" && p == "secret" { return true, nil @@ -26,50 +155,11 @@ func TestBasicAuth(t *testing.T) { return c.String(http.StatusOK, "test") }) - // Valid credentials - auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(t, h(c)) - - h = BasicAuthWithConfig(BasicAuthConfig{ - Skipper: nil, - Validator: f, - Realm: "someRealm", - })(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - - // Valid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(t, h(c)) + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + c := e.NewContext(req, res) - // Case-insensitive header scheme - auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) + auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) req.Header.Set(echo.HeaderAuthorization, auth) assert.NoError(t, h(c)) - - // Invalid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")) - req.Header.Set(echo.HeaderAuthorization, auth) - he := h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) - assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate)) - - // Invalid base64 string - auth = basic + " invalidString" - req.Header.Set(echo.HeaderAuthorization, auth) - he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusBadRequest, he.Code) - - // Missing Authorization header - req.Header.Del(echo.HeaderAuthorization) - he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) - - // Invalid Authorization header - auth = base64.StdEncoding.EncodeToString([]byte("invalid")) - req.Header.Set(echo.HeaderAuthorization, auth) - he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) } diff --git a/middleware/middleware.go b/middleware/middleware.go index 664f71f45..943183f02 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -9,6 +9,12 @@ import ( "github.com/labstack/echo/v4" ) +const ( + // headerCountLimit is arbitrary number to limit number of headers processed. this limits possible resource exhaustion + // attack vector + headerCountLimit = 20 +) + type ( // Skipper defines a function to skip middleware. Returning true skips processing // the middleware.