Skip to content

Commit

Permalink
feat: Custom authorization failure handler (#283)
Browse files Browse the repository at this point in the history
Added an option to the http.WithHeaderAuthorization middleware to modify
the default response in case of authentication failure.
  • Loading branch information
gkats committed Apr 16, 2024
1 parent 195ed86 commit 993c45a
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 8 deletions.
27 changes: 25 additions & 2 deletions http/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ func WithHeaderAuthorization(opts ...AuthorizationOption) func(http.Handler) htt
if params.Clock == nil {
params.Clock = clerk.NewClock()
}
if params.AuthorizationFailureHandler == nil {
params.AuthorizationFailureHandler = http.HandlerFunc(defaultAuthorizationFailureHandler)
}

authorization := strings.TrimSpace(r.Header.Get("Authorization"))
if authorization == "" {
Expand All @@ -65,14 +68,14 @@ func WithHeaderAuthorization(opts ...AuthorizationOption) func(http.Handler) htt
if params.JWK == nil {
params.JWK, err = getJWK(r.Context(), params.JWKSClient, decoded.KeyID, params.Clock)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
params.AuthorizationFailureHandler.ServeHTTP(w, r)
return
}
}
params.Token = token
claims, err := jwt.Verify(r.Context(), &params.VerifyParams)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
params.AuthorizationFailureHandler.ServeHTTP(w, r)
return
}

Expand All @@ -83,6 +86,10 @@ func WithHeaderAuthorization(opts ...AuthorizationOption) func(http.Handler) htt
}
}

func defaultAuthorizationFailureHandler(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}

// Retrieve the JSON web key for the provided token from the JWKS set.
// Tries a cached value first, but if there's no value or the entry
// has expired, it will fetch the JWK set from the API and cache the
Expand All @@ -109,6 +116,11 @@ func getJWK(ctx context.Context, jwksClient *jwks.Client, kid string, clock cler

type AuthorizationParams struct {
jwt.VerifyParams
// AuthorizationFailureHandler gets executed when request authorization
// fails. Pass a custom http.Handler to control the http.Response for
// invalid authorization. The default is a Response with an empty body
// and 401 Unauthorized status.
AuthorizationFailureHandler http.Handler
// JWKSClient is the jwks.Client that will be used to fetch the
// JSON Web Key Set. A default client will be used if none is
// provided.
Expand All @@ -119,6 +131,17 @@ type AuthorizationParams struct {
// authorization options.
type AuthorizationOption func(*AuthorizationParams) error

// AuthorizationFailureHandler allows to provide a handler that
// writes the response in case of authorization failures.
// The default behavior is a response with an empty body and 401
// Unauthorized status.
func AuthorizationFailureHandler(h http.Handler) AuthorizationOption {
return func(params *AuthorizationParams) error {
params.AuthorizationFailureHandler = h
return nil
}
}

// AuthorizedParty allows to provide a handler that accepts the
// 'azp' claim.
// The handler can be used to perform validations on the azp claim
Expand Down
98 changes: 92 additions & 6 deletions http/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,29 @@ import (
)

func TestWithHeaderAuthorization_InvalidAuthorization(t *testing.T) {
kid := "kid-" + t.Name()
// Mock the Clerk API server. We expect requests to GET /jwks.
clerkAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/jwks" && r.Method == http.MethodGet {
_, err := w.Write([]byte(
fmt.Sprintf(
`{"keys":[{"use":"sig","kty":"RSA","kid":"%s","alg":"RS256","n":"ypsS9Iq26F71B3lPjT_IMtglDXo8Dko9h5UBmrvkWo6pdH_4zmMjeghozaHY1aQf1dHUBLsov_XvG_t-1yf7tFfO_ImC1JqSQwdSjrXZp3oMNFHwdwAknvtlBg3sBxJ8nM1WaCWaTlb2JhEmczIji15UG6V0M2cAp2VK_brcylQROaJLC2zVa4usGi4AHzAHaRUTv6XB9bGYMvkM-ZniuXgp9dPurisIIWg25DGrTaH-kg8LPaqGwa54eLEnvfAe0ZH_MvA4_bn_u_iDkQ9ZI_CD1vwf0EDnzLgd9ZG1khGsqmXY_4WiLRGsPqZe90HzaBJma9sAxXB4qj_aNnwD5w","e":"AQAB"}]}`,
kid,
),
))
require.NoError(t, err)
return
}
}))
defer clerkAPI.Close()

// Mock the clerk backend
clerk.SetBackend(clerk.NewBackend(&clerk.BackendConfig{
HTTPClient: clerkAPI.Client(),
URL: &clerkAPI.URL,
}))

// This is the user's server, guarded by Clerk's middleware.
ts := httptest.NewServer(WithHeaderAuthorization()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := clerk.SessionClaimsFromContext(r.Context())
require.False(t, ok)
Expand All @@ -21,11 +44,6 @@ func TestWithHeaderAuthorization_InvalidAuthorization(t *testing.T) {
})))
defer ts.Close()

clerk.SetBackend(clerk.NewBackend(&clerk.BackendConfig{
HTTPClient: ts.Client(),
URL: &ts.URL,
}))

// Request without Authorization header
req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
require.NoError(t, err)
Expand All @@ -38,6 +56,19 @@ func TestWithHeaderAuthorization_InvalidAuthorization(t *testing.T) {
res, err = ts.Client().Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)

// Request with unverifiable Bearer token
tokenClaims := map[string]any{
"sid": "sess_123",
}
token, _ := clerktest.GenerateJWT(t, tokenClaims, kid)
req, err = http.NewRequest(http.MethodGet, ts.URL, nil)
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+token)
require.NoError(t, err)
res, err = ts.Client().Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
}

func TestRequireHeaderAuthorization_InvalidAuthorization(t *testing.T) {
Expand Down Expand Up @@ -67,7 +98,7 @@ func TestRequireHeaderAuthorization_InvalidAuthorization(t *testing.T) {
}

func TestWithHeaderAuthorization_Caching(t *testing.T) {
kid := "kid"
kid := "kid-" + t.Name()
clock := clerktest.NewClockAt(time.Now().UTC())

// Mock the Clerk API server. We expect requests to GET /jwks.
Expand Down Expand Up @@ -134,6 +165,61 @@ func TestWithHeaderAuthorization_Caching(t *testing.T) {
require.Equal(t, 2, totalJWKSRequests)
}

func TestWithHeaderAuthorization_CustomFailureHandler(t *testing.T) {
kid := "kid-" + t.Name()
// Mock the Clerk API server. We expect requests to GET /jwks.
clerkAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/jwks" && r.Method == http.MethodGet {
_, err := w.Write([]byte(
fmt.Sprintf(
`{"keys":[{"use":"sig","kty":"RSA","kid":"%s","alg":"RS256","n":"ypsS9Iq26F71B3lPjT_IMtglDXo8Dko9h5UBmrvkWo6pdH_4zmMjeghozaHY1aQf1dHUBLsov_XvG_t-1yf7tFfO_ImC1JqSQwdSjrXZp3oMNFHwdwAknvtlBg3sBxJ8nM1WaCWaTlb2JhEmczIji15UG6V0M2cAp2VK_brcylQROaJLC2zVa4usGi4AHzAHaRUTv6XB9bGYMvkM-ZniuXgp9dPurisIIWg25DGrTaH-kg8LPaqGwa54eLEnvfAe0ZH_MvA4_bn_u_iDkQ9ZI_CD1vwf0EDnzLgd9ZG1khGsqmXY_4WiLRGsPqZe90HzaBJma9sAxXB4qj_aNnwD5w","e":"AQAB"}]}`,
kid,
),
))
require.NoError(t, err)
return
}
}))
defer clerkAPI.Close()

// Define a custom failure handler which returns a custom HTTP
// status code.
customFailureHandler := func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusTeapot)
}

// Apply the custom failure handler to the WithHeaderAuthorization
// middleware.
middleware := WithHeaderAuthorization(
AuthorizationFailureHandler(http.HandlerFunc(customFailureHandler)),
)
// This is the user's server, guarded by Clerk's http middleware.
ts := httptest.NewServer(middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := clerk.SessionClaimsFromContext(r.Context())
require.False(t, ok)
_, err := w.Write([]byte("{}"))
require.NoError(t, err)
})))
defer ts.Close()

clerk.SetBackend(clerk.NewBackend(&clerk.BackendConfig{
HTTPClient: clerkAPI.Client(),
URL: &clerkAPI.URL,
}))

tokenClaims := map[string]any{
"sid": "sess_123",
}
token, _ := clerktest.GenerateJWT(t, tokenClaims, kid)
// Request with invalid Authorization header
req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+token)
res, err := ts.Client().Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusTeapot, res.StatusCode)
}

func TestAuthorizedPartyFunc(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
Expand Down

0 comments on commit 993c45a

Please sign in to comment.