Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

release-23.1.17-rc: release-23.1: CRDB-28040 : JWKS fetch from jwks_uri #120063

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions pkg/ccl/jwtauthccl/BUILD.bazel
Expand Up @@ -16,6 +16,7 @@ go_library(
"//pkg/settings/cluster",
"//pkg/sql/pgwire",
"//pkg/sql/pgwire/identmap",
"//pkg/util/httputil",
"//pkg/util/log",
"//pkg/util/syncutil",
"//pkg/util/uuid",
Expand All @@ -34,6 +35,7 @@ go_test(
"settings_test.go",
],
args = ["-test.timeout=295s"],
data = glob(["testdata/**"]),
embed = [":jwtauthccl"],
deps = [
"//pkg/base",
Expand All @@ -43,12 +45,15 @@ go_test(
"//pkg/security/username",
"//pkg/server",
"//pkg/sql/pgwire/identmap",
"//pkg/testutils",
"//pkg/testutils/serverutils",
"//pkg/testutils/testcluster",
"//pkg/util/leaktest",
"//pkg/util/log",
"//pkg/util/randutil",
"//pkg/util/timeutil",
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_errors//oserror",
"@com_github_lestrrat_go_jwx//jwa",
"@com_github_lestrrat_go_jwx//jwk",
"@com_github_lestrrat_go_jwx//jwt",
Expand Down
120 changes: 106 additions & 14 deletions pkg/ccl/jwtauthccl/authentication_jwt.go
Expand Up @@ -10,14 +10,18 @@ package jwtauthccl

import (
"context"
"encoding/json"
"fmt"
"io"
"strings"

"github.com/cockroachdb/cockroach/pkg/ccl/utilccl"
"github.com/cockroachdb/cockroach/pkg/security/username"
"github.com/cockroachdb/cockroach/pkg/server/telemetry"
"github.com/cockroachdb/cockroach/pkg/settings/cluster"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/identmap"
"github.com/cockroachdb/cockroach/pkg/util/httputil"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/syncutil"
"github.com/cockroachdb/cockroach/pkg/util/uuid"
Expand Down Expand Up @@ -61,11 +65,12 @@ type jwtAuthenticator struct {
// jwtAuthenticatorConf contains all the values to configure JWT authentication. These values are copied from
// the matching cluster settings.
type jwtAuthenticatorConf struct {
audience []string
enabled bool
issuers []string
jwks jwk.Set
claim string
audience []string
enabled bool
issuers []string
jwks jwk.Set
claim string
jwksAutoFetchEnabled bool
}

// reloadConfig locks mutex and then refreshes the values in conf from the cluster settings.
Expand All @@ -80,11 +85,12 @@ func (authenticator *jwtAuthenticator) reloadConfigLocked(
ctx context.Context, st *cluster.Settings,
) {
conf := jwtAuthenticatorConf{
audience: mustParseValueOrArray(JWTAuthAudience.Get(&st.SV)),
enabled: JWTAuthEnabled.Get(&st.SV),
issuers: mustParseValueOrArray(JWTAuthIssuers.Get(&st.SV)),
jwks: mustParseJWKS(JWTAuthJWKS.Get(&st.SV)),
claim: JWTAuthClaim.Get(&st.SV),
audience: mustParseValueOrArray(JWTAuthAudience.Get(&st.SV)),
enabled: JWTAuthEnabled.Get(&st.SV),
issuers: mustParseValueOrArray(JWTAuthIssuers.Get(&st.SV)),
jwks: mustParseJWKS(JWTAuthJWKS.Get(&st.SV)),
claim: JWTAuthClaim.Get(&st.SV),
jwksAutoFetchEnabled: JWKSAutoFetchEnabled.Get(&st.SV),
}

if !authenticator.mu.conf.enabled && conf.enabled {
Expand Down Expand Up @@ -121,7 +127,11 @@ func (authenticator *jwtAuthenticator) mapUsername(
// * the issuer field is one of the values in the issuer cluster setting.
// * the cluster has an enterprise license.
func (authenticator *jwtAuthenticator) ValidateJWTLogin(
st *cluster.Settings, user username.SQLUsername, tokenBytes []byte, identMap *identmap.Conf,
ctx context.Context,
st *cluster.Settings,
user username.SQLUsername,
tokenBytes []byte,
identMap *identmap.Conf,
) error {
authenticator.mu.Lock()
defer authenticator.mu.Unlock()
Expand All @@ -132,22 +142,44 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(

telemetry.Inc(beginAuthUseCounter)

parsedToken, err := jwt.Parse(tokenBytes, jwt.WithKeySet(authenticator.mu.conf.jwks), jwt.WithValidate(true), jwt.InferAlgorithmFromKey(true))
// Just parse the token to check the format is valid and issuer is present.
// The token will be parsed again later to actually verify the signature.
unverifiedToken, err := jwt.Parse(tokenBytes)
if err != nil {
return errors.Newf("JWT authentication: invalid token")
}

// Check for issuer match against configured issuers.
issuerUrl := ""
issuerMatch := false
for _, issuer := range authenticator.mu.conf.issuers {
if issuer == parsedToken.Issuer() {
if issuer == unverifiedToken.Issuer() {
issuerMatch = true
issuerUrl = issuer
break
}
}
if !issuerMatch {
return errors.WithDetailf(
errors.Newf("JWT authentication: invalid issuer"),
"token issued by %s", parsedToken.Issuer())
"token issued by %s", unverifiedToken.Issuer())
}

var jwkSet jwk.Set
// If auto-fetch is enabled, fetch the JWKS remotely from the issuer's well known jwks url.
if authenticator.mu.conf.jwksAutoFetchEnabled {
jwkSet, err = remoteFetchJWKS(ctx, issuerUrl)
if err != nil {
return errors.Newf("JWT authentication: unable to validate token")
}
} else {
jwkSet = authenticator.mu.conf.jwks
}

// Now that both the issuer and key-id are matched, parse the token again to validate the signature.
parsedToken, err := jwt.Parse(tokenBytes, jwt.WithKeySet(jwkSet), jwt.WithValidate(true), jwt.InferAlgorithmFromKey(true))
if err != nil {
return errors.Newf("JWT authentication: invalid token")
}

// Extract all requested principals from the token. By default, we take it from the subject unless they specify
Expand Down Expand Up @@ -236,6 +268,63 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(
return nil
}

// remoteFetchJWKS fetches the JWKS from the provided URI.
func remoteFetchJWKS(ctx context.Context, issuerUrl string) (jwk.Set, error) {
jwksUrl, err := getJWKSUrl(ctx, issuerUrl)
if err != nil {
return nil, err
}
body, err := getHttpResponse(ctx, jwksUrl)
if err != nil {
return nil, err
}
jwkSet, err := jwk.Parse(body)
if err != nil {
return nil, err
}
return jwkSet, nil
}

// getJWKSUrl returns the JWKS URI from the OpenID configuration endpoint.
func getJWKSUrl(ctx context.Context, issuerUrl string) (string, error) {
type OIDCConfigResponse struct {
JWKSUri string `json:"jwks_uri"`
}
openIdConfigEndpoint := getOpenIdConfigEndpoint(issuerUrl)
body, err := getHttpResponse(ctx, openIdConfigEndpoint)
if err != nil {
return "", err
}
var config OIDCConfigResponse
if err = json.Unmarshal(body, &config); err != nil {
return "", err
}
if config.JWKSUri == "" {
return "", errors.Newf("no JWKS URI found in OpenID configuration")
}
return config.JWKSUri, nil
}

// getOpenIdConfigEndpoint returns the OpenID configuration endpoint by appending standard open-id url.
func getOpenIdConfigEndpoint(issuerUrl string) string {
openIdConfigEndpoint := strings.TrimSuffix(issuerUrl, "/") + "/.well-known/openid-configuration"
return openIdConfigEndpoint
}

var getHttpResponse = func(ctx context.Context, url string) ([]byte, error) {
resp, err := httputil.Get(ctx, url)
if err != nil {
return nil, err
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
return body, nil
}

// ConfigureJWTAuth initializes and returns a jwtAuthenticator. It also sets up listeners so
// that the jwtAuthenticator's config is updated when the cluster settings values change.
var ConfigureJWTAuth = func(
Expand All @@ -262,6 +351,9 @@ var ConfigureJWTAuth = func(
JWTAuthClaim.SetOnChange(&st.SV, func(ctx context.Context) {
authenticator.reloadConfig(ambientCtx.AnnotateCtx(ctx), st)
})
JWKSAutoFetchEnabled.SetOnChange(&st.SV, func(ctx context.Context) {
authenticator.reloadConfig(ambientCtx.AnnotateCtx(ctx), st)
})
return &authenticator
}

Expand Down