Skip to content

Commit

Permalink
feat: Custom authorization failure handler
Browse files Browse the repository at this point in the history
Added an option to the http.WithHeaderAuthorization middleware to modify
the default response in case of authentication failure.
  • Loading branch information
gkats committed Apr 15, 2024
1 parent 195ed86 commit c5ecf74
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 8 deletions.
27 changes: 25 additions & 2 deletions http/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ func WithHeaderAuthorization(opts ...AuthorizationOption) func(http.Handler) htt
if params.Clock == nil {
params.Clock = clerk.NewClock()
}
if params.AuthorizationFailureHandler == nil {
params.AuthorizationFailureHandler = http.HandlerFunc(defaultAuthorizationFailureHandler)
}

authorization := strings.TrimSpace(r.Header.Get("Authorization"))
if authorization == "" {
Expand All @@ -65,14 +68,14 @@ func WithHeaderAuthorization(opts ...AuthorizationOption) func(http.Handler) htt
if params.JWK == nil {
params.JWK, err = getJWK(r.Context(), params.JWKSClient, decoded.KeyID, params.Clock)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
params.AuthorizationFailureHandler.ServeHTTP(w, r)
return
}
}
params.Token = token
claims, err := jwt.Verify(r.Context(), &params.VerifyParams)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
params.AuthorizationFailureHandler.ServeHTTP(w, r)
return
}

Expand All @@ -83,6 +86,10 @@ func WithHeaderAuthorization(opts ...AuthorizationOption) func(http.Handler) htt
}
}

func defaultAuthorizationFailureHandler(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}

// Retrieve the JSON web key for the provided token from the JWKS set.
// Tries a cached value first, but if there's no value or the entry
// has expired, it will fetch the JWK set from the API and cache the
Expand All @@ -109,6 +116,11 @@ func getJWK(ctx context.Context, jwksClient *jwks.Client, kid string, clock cler

type AuthorizationParams struct {
jwt.VerifyParams
// AuthorizationFailureHandler gets executed when request authorization
// fails. Pass a custom http.Handler to control the http.Response for
// invalid authorization. The default is a Response with an empty body
// and 401 Unauthorized status.
AuthorizationFailureHandler http.Handler
// JWKSClient is the jwks.Client that will be used to fetch the
// JSON Web Key Set. A default client will be used if none is
// provided.
Expand All @@ -119,6 +131,17 @@ type AuthorizationParams struct {
// authorization options.
type AuthorizationOption func(*AuthorizationParams) error

// AuthorizationFailureHandler allows to provide a handler that
// writes the response in case of authorization failures.
// The default behavior is a response with an empty body and 401
// Unauthorized status.
func AuthorizationFailureHandler(h http.Handler) AuthorizationOption {
return func(params *AuthorizationParams) error {
params.AuthorizationFailureHandler = h
return nil
}
}

// AuthorizedParty allows to provide a handler that accepts the
// 'azp' claim.
// The handler can be used to perform validations on the azp claim
Expand Down
97 changes: 91 additions & 6 deletions http/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,29 @@ import (
)

func TestWithHeaderAuthorization_InvalidAuthorization(t *testing.T) {
kid := "kid-" + t.Name()
// Mock the Clerk API server. We expect requests to GET /jwks.
clerkAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/jwks" && r.Method == http.MethodGet {
_, err := w.Write([]byte(
fmt.Sprintf(
`{"keys":[{"use":"sig","kty":"RSA","kid":"%s","alg":"RS256","n":"ypsS9Iq26F71B3lPjT_IMtglDXo8Dko9h5UBmrvkWo6pdH_4zmMjeghozaHY1aQf1dHUBLsov_XvG_t-1yf7tFfO_ImC1JqSQwdSjrXZp3oMNFHwdwAknvtlBg3sBxJ8nM1WaCWaTlb2JhEmczIji15UG6V0M2cAp2VK_brcylQROaJLC2zVa4usGi4AHzAHaRUTv6XB9bGYMvkM-ZniuXgp9dPurisIIWg25DGrTaH-kg8LPaqGwa54eLEnvfAe0ZH_MvA4_bn_u_iDkQ9ZI_CD1vwf0EDnzLgd9ZG1khGsqmXY_4WiLRGsPqZe90HzaBJma9sAxXB4qj_aNnwD5w","e":"AQAB"}]}`,
kid,
),
))
require.NoError(t, err)
return
}
}))
defer clerkAPI.Close()

// Mock the clerk backend
clerk.SetBackend(clerk.NewBackend(&clerk.BackendConfig{
HTTPClient: clerkAPI.Client(),
URL: &clerkAPI.URL,
}))

// This is the user's server, guarded by Clerk's middleware.
ts := httptest.NewServer(WithHeaderAuthorization()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := clerk.SessionClaimsFromContext(r.Context())
require.False(t, ok)
Expand All @@ -21,11 +44,6 @@ func TestWithHeaderAuthorization_InvalidAuthorization(t *testing.T) {
})))
defer ts.Close()

clerk.SetBackend(clerk.NewBackend(&clerk.BackendConfig{
HTTPClient: ts.Client(),
URL: &ts.URL,
}))

// Request without Authorization header
req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
require.NoError(t, err)
Expand All @@ -38,6 +56,19 @@ func TestWithHeaderAuthorization_InvalidAuthorization(t *testing.T) {
res, err = ts.Client().Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)

// Request with unverifiable Bearer token
tokenClaims := map[string]any{
"sid": "sess_123",
}
token, _ := clerktest.GenerateJWT(t, tokenClaims, kid)
req, err = http.NewRequest(http.MethodGet, ts.URL, nil)
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+token)
require.NoError(t, err)
res, err = ts.Client().Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
}

func TestRequireHeaderAuthorization_InvalidAuthorization(t *testing.T) {
Expand Down Expand Up @@ -67,7 +98,7 @@ func TestRequireHeaderAuthorization_InvalidAuthorization(t *testing.T) {
}

func TestWithHeaderAuthorization_Caching(t *testing.T) {
kid := "kid"
kid := "kid-" + t.Name()
clock := clerktest.NewClockAt(time.Now().UTC())

// Mock the Clerk API server. We expect requests to GET /jwks.
Expand Down Expand Up @@ -134,6 +165,60 @@ func TestWithHeaderAuthorization_Caching(t *testing.T) {
require.Equal(t, 2, totalJWKSRequests)
}

func TestWithHeaderAuthorization_CustomFailureHandler(t *testing.T) {
kid := "kid-" + t.Name()
// Mock the Clerk API server. We expect requests to GET /jwks.
clerkAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/jwks" && r.Method == http.MethodGet {
_, err := w.Write([]byte(
fmt.Sprintf(
`{"keys":[{"use":"sig","kty":"RSA","kid":"%s","alg":"RS256","n":"ypsS9Iq26F71B3lPjT_IMtglDXo8Dko9h5UBmrvkWo6pdH_4zmMjeghozaHY1aQf1dHUBLsov_XvG_t-1yf7tFfO_ImC1JqSQwdSjrXZp3oMNFHwdwAknvtlBg3sBxJ8nM1WaCWaTlb2JhEmczIji15UG6V0M2cAp2VK_brcylQROaJLC2zVa4usGi4AHzAHaRUTv6XB9bGYMvkM-ZniuXgp9dPurisIIWg25DGrTaH-kg8LPaqGwa54eLEnvfAe0ZH_MvA4_bn_u_iDkQ9ZI_CD1vwf0EDnzLgd9ZG1khGsqmXY_4WiLRGsPqZe90HzaBJma9sAxXB4qj_aNnwD5w","e":"AQAB"}]}`,
kid,
),
))
require.NoError(t, err)
return
}
}))
defer clerkAPI.Close()

// Define a custom failure handler which returns a custom HTTP
// status code.
customFailureHandler := func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusTeapot)
}

// Apply the custom failure handler to the WithHeaderAuthorization
// middleware.
middleware := WithHeaderAuthorization(
AuthorizationFailureHandler(http.HandlerFunc(customFailureHandler)),
)
// This is the user's server, guarded by Clerk's http middleware.
ts := httptest.NewServer(middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := clerk.SessionClaimsFromContext(r.Context())
require.False(t, ok)
_, err := w.Write([]byte("{}"))
require.NoError(t, err)
})))
defer ts.Close()

clerk.SetBackend(clerk.NewBackend(&clerk.BackendConfig{
HTTPClient: clerkAPI.Client(),
URL: &clerkAPI.URL,
}))

tokenClaims := map[string]any{
"sid": "sess_123",
}
token, _ := clerktest.GenerateJWT(t, tokenClaims, kid)
// Request with invalid Authorization header
req, err := http.NewRequest(http.MethodGet, ts.URL, nil)

Check failure on line 215 in http/middleware_test.go

View workflow job for this annotation

GitHub Actions / Lint

ineffectual assignment to err (ineffassign)
req.Header.Set("Authorization", "Bearer "+token)
res, err := ts.Client().Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusTeapot, res.StatusCode)
}

func TestAuthorizedPartyFunc(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
Expand Down

0 comments on commit c5ecf74

Please sign in to comment.