Skip to content

Commit

Permalink
CRDB-28040 : JWKS fetch from jwks_uri
Browse files Browse the repository at this point in the history
This commit adds capability to fetch remote JWKS from issuer's jwks_uri endpoint. This will satisfy the requirement to have an ability to automatically fetch the new JWK when the existing JWK is rotated - without human intervention or custom scripts.

Changes include

1. The existing order of token signature verification first and rest of claims next is modified to get issuer first and then the token signature verification. This change is requied to determine the issuer for which the jwks has to be fetched remotely.

2. Introduction of a new cluster setting called `server.jwt_authentication.jwks_auto_fetch.enabled`

3. Depending on the value of `server.jwt_authentication.jwks_auto_fetch.enabled` use JWKS configured through cluster setting or remotely fetch JWKS from jwks_uri of the issuer

4. Modification to exiting test cases to match the new order of verification steps.

The change is backward compatible and no changes required in existing deployments and JWT Auth usage.
  • Loading branch information
BabuSrithar committed Dec 23, 2023
1 parent c2c2a20 commit 61776d5
Show file tree
Hide file tree
Showing 10 changed files with 479 additions and 84 deletions.
5 changes: 5 additions & 0 deletions pkg/ccl/jwtauthccl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ go_library(
"//pkg/settings/cluster",
"//pkg/sql/pgwire",
"//pkg/sql/pgwire/identmap",
"//pkg/util/httputil",
"//pkg/util/log",
"//pkg/util/syncutil",
"//pkg/util/uuid",
Expand All @@ -34,6 +35,7 @@ go_test(
"settings_test.go",
],
args = ["-test.timeout=295s"],
data = glob(["testdata/**"]),
embed = [":jwtauthccl"],
deps = [
"//pkg/base",
Expand All @@ -43,12 +45,15 @@ go_test(
"//pkg/security/username",
"//pkg/server",
"//pkg/sql/pgwire/identmap",
"//pkg/testutils",
"//pkg/testutils/serverutils",
"//pkg/testutils/testcluster",
"//pkg/util/leaktest",
"//pkg/util/log",
"//pkg/util/randutil",
"//pkg/util/timeutil",
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_errors//oserror",
"@com_github_lestrrat_go_jwx//jwa",
"@com_github_lestrrat_go_jwx//jwk",
"@com_github_lestrrat_go_jwx//jwt",
Expand Down
120 changes: 106 additions & 14 deletions pkg/ccl/jwtauthccl/authentication_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@ package jwtauthccl

import (
"context"
"encoding/json"
"fmt"
"io"
"strings"

"github.com/cockroachdb/cockroach/pkg/ccl/utilccl"
"github.com/cockroachdb/cockroach/pkg/security/username"
"github.com/cockroachdb/cockroach/pkg/server/telemetry"
"github.com/cockroachdb/cockroach/pkg/settings/cluster"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/identmap"
"github.com/cockroachdb/cockroach/pkg/util/httputil"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/syncutil"
"github.com/cockroachdb/cockroach/pkg/util/uuid"
Expand Down Expand Up @@ -61,11 +65,12 @@ type jwtAuthenticator struct {
// jwtAuthenticatorConf contains all the values to configure JWT authentication. These values are copied from
// the matching cluster settings.
type jwtAuthenticatorConf struct {
audience []string
enabled bool
issuers []string
jwks jwk.Set
claim string
audience []string
enabled bool
issuers []string
jwks jwk.Set
claim string
jwksAutoFetchEnabled bool
}

// reloadConfig locks mutex and then refreshes the values in conf from the cluster settings.
Expand All @@ -80,11 +85,12 @@ func (authenticator *jwtAuthenticator) reloadConfigLocked(
ctx context.Context, st *cluster.Settings,
) {
conf := jwtAuthenticatorConf{
audience: mustParseValueOrArray(JWTAuthAudience.Get(&st.SV)),
enabled: JWTAuthEnabled.Get(&st.SV),
issuers: mustParseValueOrArray(JWTAuthIssuers.Get(&st.SV)),
jwks: mustParseJWKS(JWTAuthJWKS.Get(&st.SV)),
claim: JWTAuthClaim.Get(&st.SV),
audience: mustParseValueOrArray(JWTAuthAudience.Get(&st.SV)),
enabled: JWTAuthEnabled.Get(&st.SV),
issuers: mustParseValueOrArray(JWTAuthIssuers.Get(&st.SV)),
jwks: mustParseJWKS(JWTAuthJWKS.Get(&st.SV)),
claim: JWTAuthClaim.Get(&st.SV),
jwksAutoFetchEnabled: JWKSAutoFetchEnabled.Get(&st.SV),
}

if !authenticator.mu.conf.enabled && conf.enabled {
Expand Down Expand Up @@ -121,7 +127,11 @@ func (authenticator *jwtAuthenticator) mapUsername(
// * the issuer field is one of the values in the issuer cluster setting.
// * the cluster has an enterprise license.
func (authenticator *jwtAuthenticator) ValidateJWTLogin(
st *cluster.Settings, user username.SQLUsername, tokenBytes []byte, identMap *identmap.Conf,
ctx context.Context,
st *cluster.Settings,
user username.SQLUsername,
tokenBytes []byte,
identMap *identmap.Conf,
) error {
authenticator.mu.Lock()
defer authenticator.mu.Unlock()
Expand All @@ -132,22 +142,44 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(

telemetry.Inc(beginAuthUseCounter)

parsedToken, err := jwt.Parse(tokenBytes, jwt.WithKeySet(authenticator.mu.conf.jwks), jwt.WithValidate(true), jwt.InferAlgorithmFromKey(true))
// Just parse the token to check the format is valid and issuer is present.
// The token will be parsed again later to actually verify the signature.
unverifiedToken, err := jwt.Parse(tokenBytes)
if err != nil {
return errors.Newf("JWT authentication: invalid token")
}

// Check for issuer match against configured issuers.
issuerUrl := ""
issuerMatch := false
for _, issuer := range authenticator.mu.conf.issuers {
if issuer == parsedToken.Issuer() {
if issuer == unverifiedToken.Issuer() {
issuerMatch = true
issuerUrl = issuer
break
}
}
if !issuerMatch {
return errors.WithDetailf(
errors.Newf("JWT authentication: invalid issuer"),
"token issued by %s", parsedToken.Issuer())
"token issued by %s", unverifiedToken.Issuer())
}

var jwkSet jwk.Set
// If auto-fetch is enabled, fetch the JWKS remotely from the issuer's well known jwks url.
if authenticator.mu.conf.jwksAutoFetchEnabled {
jwkSet, err = remoteFetchJWKS(ctx, issuerUrl)
if err != nil {
return errors.Newf("JWT authentication: unable to validate token")
}
} else {
jwkSet = authenticator.mu.conf.jwks
}

// Now that both the issuer and key-id are matched, parse the token again to validate the signature.
parsedToken, err := jwt.Parse(tokenBytes, jwt.WithKeySet(jwkSet), jwt.WithValidate(true), jwt.InferAlgorithmFromKey(true))
if err != nil {
return errors.Newf("JWT authentication: invalid token")
}

// Extract all requested principals from the token. By default, we take it from the subject unless they specify
Expand Down Expand Up @@ -236,6 +268,63 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(
return nil
}

// remoteFetchJWKS fetches the JWKS from the provided URI.
func remoteFetchJWKS(ctx context.Context, issuerUrl string) (jwk.Set, error) {
jwksUrl, err := getJWKSUrl(ctx, issuerUrl)
if err != nil {
return nil, err
}
body, err := getHttpResponse(ctx, jwksUrl)
if err != nil {
return nil, err
}
jwkSet, err := jwk.Parse(body)
if err != nil {
return nil, err
}
return jwkSet, nil
}

// getJWKSUrl returns the JWKS URI from the OpenID configuration endpoint.
func getJWKSUrl(ctx context.Context, issuerUrl string) (string, error) {
type OIDCConfigResponse struct {
JWKSUri string `json:"jwks_uri"`
}
openIdConfigEndpoint := getOpenIdConfigEndpoint(issuerUrl)
body, err := getHttpResponse(ctx, openIdConfigEndpoint)
if err != nil {
return "", err
}
var config OIDCConfigResponse
if err = json.Unmarshal(body, &config); err != nil {
return "", err
}
if config.JWKSUri == "" {
return "", errors.Newf("no JWKS URI found in OpenID configuration")
}
return config.JWKSUri, nil
}

// getOpenIdConfigEndpoint returns the OpenID configuration endpoint by appending standard open-id url.
func getOpenIdConfigEndpoint(issuerUrl string) string {
openIdConfigEndpoint := strings.TrimSuffix(issuerUrl, "/") + "/.well-known/openid-configuration"
return openIdConfigEndpoint
}

var getHttpResponse = func(ctx context.Context, url string) ([]byte, error) {
resp, err := httputil.Get(ctx, url)
if err != nil {
return nil, err
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
return body, nil
}

// ConfigureJWTAuth initializes and returns a jwtAuthenticator. It also sets up listeners so
// that the jwtAuthenticator's config is updated when the cluster settings values change.
var ConfigureJWTAuth = func(
Expand All @@ -262,6 +351,9 @@ var ConfigureJWTAuth = func(
JWTAuthClaim.SetOnChange(&st.SV, func(ctx context.Context) {
authenticator.reloadConfig(ambientCtx.AnnotateCtx(ctx), st)
})
JWKSAutoFetchEnabled.SetOnChange(&st.SV, func(ctx context.Context) {
authenticator.reloadConfig(ambientCtx.AnnotateCtx(ctx), st)
})
return &authenticator
}

Expand Down

0 comments on commit 61776d5

Please sign in to comment.