diff --git a/echo.go b/echo.go index 2d66ec5b4..a2110833e 100644 --- a/echo.go +++ b/echo.go @@ -133,7 +133,8 @@ const ( Location = "Location" Upgrade = "Upgrade" Vary = "Vary" - + Basic = "Basic" + WWWAuthenticate = "WWW-Authenticate" //----------- // Protocols //----------- diff --git a/middleware/auth.go b/middleware/auth.go index 1dc5c4ea9..ac2a71a6c 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -11,14 +11,9 @@ type ( BasicValidateFunc func(string, string) bool ) -const ( - Basic = "Basic" -) - // BasicAuth returns an HTTP basic authentication middleware. // // For valid credentials it calls the next handler. -// For invalid Authorization header it sends "404 - Bad Request" response. // For invalid credentials, it sends "401 - Unauthorized" response. func BasicAuth(fn BasicValidateFunc) echo.HandlerFunc { return func(c *echo.Context) error { @@ -28,10 +23,10 @@ func BasicAuth(fn BasicValidateFunc) echo.HandlerFunc { } auth := c.Request().Header.Get(echo.Authorization) - l := len(Basic) - he := echo.NewHTTPError(http.StatusBadRequest) + l := len(echo.Basic) + he := echo.NewHTTPError(http.StatusUnauthorized) - if len(auth) > l+1 && auth[:l] == Basic { + if len(auth) > l+1 && auth[:l] == echo.Basic { b, err := base64.StdEncoding.DecodeString(auth[l+1:]) if err == nil { cred := string(b) @@ -41,11 +36,12 @@ func BasicAuth(fn BasicValidateFunc) echo.HandlerFunc { if fn(cred[:i], cred[i+1:]) { return nil } - he.SetCode(http.StatusUnauthorized) } } } } + + c.Response().Header().Add(echo.WWWAuthenticate, echo.Basic) return he } } diff --git a/middleware/auth_test.go b/middleware/auth_test.go index c953d9278..7c864d5aa 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -23,7 +23,7 @@ func TestBasicAuth(t *testing.T) { ba := BasicAuth(fn) // Valid credentials - auth := Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) + auth := echo.Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) req.Header.Set(echo.Authorization, auth) assert.NoError(t, ba(c)) @@ -32,7 +32,7 @@ func TestBasicAuth(t *testing.T) { //--------------------- // Incorrect password - auth = Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password")) + auth = echo.Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password")) req.Header.Set(echo.Authorization, auth) he := ba(c).(*echo.HTTPError) assert.Equal(t, http.StatusUnauthorized, he.Code()) @@ -40,13 +40,13 @@ func TestBasicAuth(t *testing.T) { // Empty Authorization header req.Header.Set(echo.Authorization, "") he = ba(c).(*echo.HTTPError) - assert.Equal(t, http.StatusBadRequest, he.Code()) + assert.Equal(t, http.StatusUnauthorized, he.Code()) // Invalid Authorization header auth = base64.StdEncoding.EncodeToString([]byte("invalid")) req.Header.Set(echo.Authorization, auth) he = ba(c).(*echo.HTTPError) - assert.Equal(t, http.StatusBadRequest, he.Code()) + assert.Equal(t, http.StatusUnauthorized, he.Code()) // WebSocket c.Request().Header.Set(echo.Upgrade, echo.WebSocket)