Skip to content
29 changes: 27 additions & 2 deletions auth/api/iam/jar.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,15 @@ import (
cryptoNuts "github.com/nuts-foundation/nuts-node/crypto"
"github.com/nuts-foundation/nuts-node/vdr/resolver"
"net/url"
"time"
)

// jarMaxValidity bounds the lifetime of a JWT Authorization Request (RFC9101).
// Authorization requests are short-lived by nature (an end-user is interacting
// with the wallet/AS in real time), so a 5 minute window is generous and limits
// the replay window if a request is intercepted.
const jarMaxValidity = 5 * time.Minute

// requestObjectModifier is a function that modifies the Claims/params of an unsigned or signed (JWT) OAuth2 request
type requestObjectModifier func(claims map[string]string)

Expand All @@ -42,6 +49,12 @@ type jarRequest struct {
RequestURIMethod string `json:"request_uri_method"`
}

// jarProfile defines JWT validation rules for JWT Authorization Requests (JAR).
var jarProfile = &cryptoNuts.JWTProfile{
RequiredClaims: []string{jwt.IssuerKey, jwt.IssuedAtKey, jwt.ExpirationKey},
MaxValidity: jarMaxValidity,
}

var _ JAR = &jar{}

type jar struct {
Expand Down Expand Up @@ -92,6 +105,9 @@ func createJarRequest(client did.DID, clientID string, audience string, modifier
for k, v := range params {
oauthParams[k] = v
}
// iat/exp are intentionally NOT set here: the jarRequest is persisted (authzRequestObjectStore)
// and Sign may be called minutes later on a different request. Setting timestamps here would
// eat into the validity window. They are added in Sign() to reflect actual signing time.
return jarRequest{
Claims: oauthParams,
Client: clientID,
Expand All @@ -109,7 +125,16 @@ func (j jar) Sign(ctx context.Context, claims oauthParameters) (string, error) {
if err != nil {
return "", err
}
return j.jwtSigner.SignJWT(ctx, claims, nil, keyId)
// Copy claims so we don't mutate the caller's map (the jarRequest may still be in-memory in tests
// or the store). iat/exp reflect signing time to give the counterparty the full validity window.
signClaims := make(map[string]interface{}, len(claims)+2)
for k, v := range claims {
signClaims[k] = v
}
now := time.Now()
signClaims[jwt.IssuedAtKey] = now.Unix()
signClaims[jwt.ExpirationKey] = now.Add(jarMaxValidity).Unix()
return j.jwtSigner.SignJWT(ctx, signClaims, nil, keyId)
}

func (j jar) Parse(ctx context.Context, ownMetadata oauth.AuthorizationServerMetadata, q url.Values) (oauthParameters, error) {
Expand Down Expand Up @@ -154,7 +179,7 @@ func (j jar) validate(ctx context.Context, rawToken string, clientId string) (oa
signerKid = kid
publicKey, err = j.keyResolver.ResolveKeyByID(kid, nil, resolver.AssertionMethod)
return publicKey, err
}, jwt.WithValidate(true))
}, jarProfile, nil)
if err != nil {
return nil, oauth.OAuth2Error{Code: oauth.InvalidRequestObject, Description: "request signature validation failed", InternalError: err}
}
Expand Down
25 changes: 24 additions & 1 deletion auth/api/iam/jar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/nuts-foundation/nuts-node/crypto/storage/spi"
"github.com/nuts-foundation/nuts-node/test"
"testing"
"time"

"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jws"
Expand Down Expand Up @@ -55,6 +56,9 @@ func TestJar_Create(t *testing.T) {
assert.Equal(t, verifierDID.String(), req.Claims[jwt.IssuerKey])
assert.Equal(t, holderClientID, req.Claims[jwt.AudienceKey])
assert.Equal(t, "works", req.Claims["requestObjectModifier"])
// iat/exp are intentionally not set here; they're added in Sign() to reflect signing time.
assert.Nil(t, req.Claims[jwt.IssuedAtKey])
assert.Nil(t, req.Claims[jwt.ExpirationKey])
})
t.Run("request_uri_method=post", func(t *testing.T) {
modifier := func(claims map[string]string) {
Expand All @@ -77,12 +81,25 @@ func TestJar_Sign(t *testing.T) {
t.Run("ok", func(t *testing.T) {
ctx := newJarTestCtx(t)
ctx.keyResolver.EXPECT().ResolveKey(clientDID, nil, resolver.AssertionMethod).Return(keyID, nil, nil)
ctx.jwtSigner.EXPECT().SignJWT(context.Background(), claims, nil, keyID).Return("valid token", nil)
ctx.jwtSigner.EXPECT().SignJWT(context.Background(), gomock.Any(), nil, keyID).
DoAndReturn(func(_ context.Context, signed map[string]interface{}, _ map[string]interface{}, _ string) (string, error) {
// Sign augments the caller's claims with iat/exp reflecting signing time.
assert.Equal(t, clientID, signed[oauth.ClientIDParam])
assert.Equal(t, clientDID.String(), signed[jwt.IssuerKey])
iat := signed[jwt.IssuedAtKey].(int64)
exp := signed[jwt.ExpirationKey].(int64)
assert.InDelta(t, time.Now().Unix(), iat, 2)
assert.Equal(t, int64(jarMaxValidity/time.Second), exp-iat)
return "valid token", nil
})

token, err := ctx.jar.Sign(context.Background(), claims)

require.NoError(t, err)
assert.Equal(t, "valid token", token)
// Caller's map is untouched.
assert.Nil(t, claims[jwt.IssuedAtKey])
assert.Nil(t, claims[jwt.ExpirationKey])
})
t.Run("error - failed to sign JWT", func(t *testing.T) {
ctx := newJarTestCtx(t)
Expand Down Expand Up @@ -114,9 +131,12 @@ func TestJar_Parse(t *testing.T) {
jwkSet := jwk.NewSet()
_ = jwkSet.AddKey(jwkKey)

now := time.Now()
bytes, err := createSignedRequestObject(t, kid, privateKey, oauthParameters{
jwt.IssuerKey: holderDID.String(),
oauth.ClientIDParam: holderClientID,
jwt.IssuedAtKey: now.Unix(),
jwt.ExpirationKey: now.Add(jarMaxValidity).Unix(),
})
require.NoError(t, err)
token := string(bytes)
Expand Down Expand Up @@ -273,9 +293,12 @@ func TestJar_Parse(t *testing.T) {
})
t.Run("error - client_id does not match signer", func(t *testing.T) {
ctx := newJarTestCtx(t)
now := time.Now()
bytes, err := createSignedRequestObject(t, kid, privateKey, oauthParameters{
jwt.IssuerKey: verifierDID.String(),
oauth.ClientIDParam: verifierDID.String(),
jwt.IssuedAtKey: now.Unix(),
jwt.ExpirationKey: now.Add(jarMaxValidity).Unix(),
})
require.NoError(t, err)
ctx.keyResolver.EXPECT().ResolveKeyByID(kid, nil, resolver.AssertionMethod).Return(privateKey.Public(), nil)
Expand Down
66 changes: 22 additions & 44 deletions auth/services/oauth/authz_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,23 @@ func NewAuthorizationServer(
// BearerTokenMaxValidity is the number of seconds that a bearer token is valid
const BearerTokenMaxValidity = 5

// v1AccessTokenProfile defines JWT validation rules for v1 access tokens.
// The "service" claim holds the RFC003 purposeOfUse (see buildAccessToken).
var v1AccessTokenProfile = &nutsCrypto.JWTProfile{
Typ: "at+jwt",
RequiredClaims: []string{jwt.ExpirationKey, jwt.IssuedAtKey, jwt.IssuerKey, jwt.SubjectKey, "service"},
MaxValidity: secureAccessTokenLifeSpan,
Comment thread
stevenvegt marked this conversation as resolved.
Validators: []nutsCrypto.JWTValidator{nutsCrypto.IssuerKidValidator},
}

// v1BearerTokenProfile defines JWT validation rules for v1 JWT bearer tokens (RFC003 §5.2.1).
// aud is set by the builder (claimsFromRequest) to the authorization server endpoint and must
// be present; the actual endpoint comparison still runs in validateAudience post-parse.
var v1BearerTokenProfile = &nutsCrypto.JWTProfile{
Comment thread
stevenvegt marked this conversation as resolved.
RequiredClaims: []string{jwt.ExpirationKey, jwt.IssuedAtKey, jwt.IssuerKey, jwt.SubjectKey, jwt.AudienceKey},
MaxValidity: BearerTokenMaxValidity * time.Second,
}

// Configure the service
func (s *authzServer) Configure(clockSkewInMilliseconds int, secureMode bool) error {
s.clockSkew = time.Duration(clockSkewInMilliseconds) * time.Millisecond
Expand Down Expand Up @@ -232,16 +249,12 @@ func (s *authzServer) validateAccessTokenRequest(ctx context.Context, bearerToke

// extract the JwtBearerToken, validates according to RFC003 §5.2.1.1
// also check if used algorithms are according to spec (ES*** and PS***)
// and checks basic validity. Set jwtBearerTokenClaims in validationContext
// and checks basic validity including max validity (RFC003 §5.2.1.4).
// Set jwtBearerTokenClaims in validationContext
if err := s.parseAndValidateJwtBearerToken(validationCtx); err != nil {
return validationCtx, fmt.Errorf("jwt bearer token validation failed: %w", err)
}

// check the maximum validity, according to RFC003 §5.2.1.4
if validationCtx.jwtBearerToken.Expiration().Sub(validationCtx.jwtBearerToken.IssuedAt()).Seconds() > BearerTokenMaxValidity {
return validationCtx, errors.New("JWT validity too long")
}

// check the requester against the registry, according to RFC003 §5.2.1.3
// checks signing certificate and sets vendor, requesterName in validationContext
if err := s.validateIssuer(validationCtx); err != nil {
Expand Down Expand Up @@ -476,36 +489,21 @@ func (s *authzServer) validateAuthorizationCredentials(context *validationContex

// parseAndValidateJwtBearerToken validates the jwt signature and returns the containing claims
func (s *authzServer) parseAndValidateJwtBearerToken(context *validationContext) error {
var kidHdr string
token, err := nutsCrypto.ParseJWT(context.rawJwtBearerToken, func(kid string) (crypto.PublicKey, error) {
kidHdr = kid
context.kid = kid
return s.keyResolver.ResolveKeyByID(kid, nil, resolver.NutsSigningKeyType)
}, jwt.WithAcceptableSkew(s.clockSkew))
}, v1BearerTokenProfile.WithClockSkew(s.clockSkew), nil)
if err != nil {
return err
}

// this should be ok since it has already succeeded before
context.jwtBearerToken = token
context.kid = kidHdr
return nil
}

// IntrospectAccessToken fills the fields in NutsAccessToken from the given Jwt Access Token
func (s *authzServer) IntrospectAccessToken(ctx context.Context, accessToken string) (*services.NutsAccessToken, error) {
// Validate typ header before full parsing to reject non-access-token JWTs early
headers, err := nutsCrypto.ExtractProtectedHeaders(accessToken)
if err != nil {
return nil, fmt.Errorf("invalid access token headers: %w", err)
}
typ, _ := headers["typ"].(string)
if typ != "at+jwt" {
return nil, fmt.Errorf("invalid access token typ header (expected 'at+jwt', got '%s')", typ)
}

var kidHdr string
token, err := nutsCrypto.ParseJWT(accessToken, func(kid string) (crypto.PublicKey, error) {
kidHdr = kid
exists, err := s.privateKeyStore.Exists(ctx, kid)
if err != nil {
return nil, fmt.Errorf("could not check if JWT signing key exists: %w", err)
Expand All @@ -514,7 +512,7 @@ func (s *authzServer) IntrospectAccessToken(ctx context.Context, accessToken str
return nil, fmt.Errorf("JWT signing key not present on this node (kid=%s)", kid)
}
return s.keyResolver.ResolveKeyByID(kid, nil, resolver.NutsSigningKeyType)
}, jwt.WithAcceptableSkew(s.clockSkew))
}, v1AccessTokenProfile.WithMaxValidity(s.accessTokenLifeSpan).WithClockSkew(s.clockSkew), nil)
if err != nil {
return nil, err
}
Expand All @@ -530,26 +528,6 @@ func (s *authzServer) IntrospectAccessToken(ctx context.Context, accessToken str
result.IssuedAt = token.IssuedAt().Unix()
result.Expiration = token.Expiration().Unix()

// Validate required claims
if result.Issuer == "" {
return nil, errors.New("missing required 'iss' claim in access token")
}
if result.Subject == "" {
return nil, errors.New("missing required 'sub' claim in access token")
}
if result.Service == "" {
return nil, errors.New("missing required 'service' claim in access token")
}

// Validate issuer-to-kid binding: the DID in the kid header must match the iss claim
kidDID, err := resolver.GetDIDFromURL(kidHdr)
if err != nil {
return nil, fmt.Errorf("invalid kid header in access token: %w", err)
}
if kidDID.String() != result.Issuer {
return nil, fmt.Errorf("access token issuer (%s) does not match signing key DID (%s)", result.Issuer, kidDID.String())
}

return result, nil
}

Expand Down
Loading
Loading