From 7fc38ac71ca846c74943ad2989853b9f431b2397 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Mon, 24 Jan 2022 10:01:41 -0600 Subject: [PATCH] Update nil signing key handling - bypass setting ExpireAt if signing key is nil in rotate - return err if singing key is nil in signPayload --- vault/identity_store_oidc.go | 53 ++++---- vault/identity_store_oidc_test.go | 206 +++++++++++++++++++++--------- 2 files changed, 175 insertions(+), 84 deletions(-) diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index c2f93f2ccc406..3a3e6b8749304 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -1042,6 +1042,9 @@ func (k *namedKey) generateAndSetNextKey(ctx context.Context, logger hclog.Logge } func (k *namedKey) signPayload(payload []byte) (string, error) { + if k.SigningKey == nil { + return "", fmt.Errorf("signing key is nil") + } signingKey := jose.SigningKey{Key: k.SigningKey, Algorithm: jose.SignatureAlgorithm(k.Algorithm)} signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{}) if err != nil { @@ -1491,12 +1494,22 @@ func (i *IdentityStore) pathOIDCIntrospect(ctx context.Context, req *logical.Req // namedKey.rotate(overrides) performs a key rotation on a namedKey. // verification_ttl can be overridden with an overrideVerificationTTL value >= 0 func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.Storage, overrideVerificationTTL time.Duration) error { - if k.SigningKey == nil { - logger.Debug("nil signing key detected on rotation") - err := k.generateAndSetKey(ctx, logger, s) - if err != nil { - return err + verificationTTL := k.VerificationTTL + if overrideVerificationTTL >= 0 { + verificationTTL = overrideVerificationTTL + } + + now := time.Now() + if k.SigningKey != nil { + // set the previous public key's expiry time + for _, key := range k.KeyRing { + if key.KeyID == k.SigningKey.KeyID { + key.ExpireAt = now.Add(verificationTTL) + break + } } + } else { + logger.Debug("nil signing key detected on rotation") } if k.NextSigningKey == nil { @@ -1509,20 +1522,6 @@ func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.St } } - verificationTTL := k.VerificationTTL - if overrideVerificationTTL >= 0 { - verificationTTL = overrideVerificationTTL - } - - now := time.Now() - // set the previous public key's expiry time - for _, key := range k.KeyRing { - if key.KeyID == k.SigningKey.KeyID { - key.ExpireAt = now.Add(verificationTTL) - break - } - } - // do the rotation k.SigningKey = k.NextSigningKey k.NextRotation = now.Add(k.RotationPeriod) @@ -1714,21 +1713,21 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor return now, err } - namedKeys, err := s.List(ctx, namedKeyConfigPath) + keyNames, err := s.List(ctx, namedKeyConfigPath) if err != nil { return now, err } usedKeys := make([]string, 0) - for _, k := range namedKeys { - entry, err := s.Get(ctx, namedKeyConfigPath+k) + for _, keyName := range keyNames { + entry, err := s.Get(ctx, namedKeyConfigPath+keyName) if err != nil { return now, err } if entry == nil { - i.Logger().Warn("could not find key to update", "key", k) + i.Logger().Warn("could not find key to update", "key", keyName) continue } @@ -1741,14 +1740,14 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor keyRing := key.KeyRing var keyringUpdated bool - for i := 0; i < len(keyRing); i++ { - k := keyRing[i] + for j := 0; j < len(keyRing); j++ { + k := keyRing[j] if !k.ExpireAt.IsZero() && k.ExpireAt.Before(now) { - keyRing[i] = keyRing[len(keyRing)-1] + keyRing[j] = keyRing[len(keyRing)-1] keyRing = keyRing[:len(keyRing)-1] keyringUpdated = true - i-- + j-- continue } diff --git a/vault/identity_store_oidc_test.go b/vault/identity_store_oidc_test.go index 35f9c389d9ee1..1813a36841ebd 100644 --- a/vault/identity_store_oidc_test.go +++ b/vault/identity_store_oidc_test.go @@ -2,8 +2,6 @@ package vault import ( "context" - "crypto/rand" - "crypto/rsa" "encoding/json" "strconv" "strings" @@ -11,7 +9,7 @@ import ( "time" "github.com/go-test/deep" - uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/helper/identity" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/framework" @@ -893,6 +891,79 @@ func TestOIDC_SignIDToken(t *testing.T) { } } +// TestOIDC_SignIDToken_NilSigningKey +func TestOIDC_SignIDToken_NilSigningKey(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + ctx := namespace.RootContext(nil) + // storage := &logical.InmemStorage{} + + // Create and load an entity, an entity is required to generate an ID token + testEntity := &identity.Entity{ + Name: "test-entity-name", + ID: "test-entity-id", + BucketKey: "test-entity-bucket-key", + } + + txn := c.identityStore.db.Txn(true) + defer txn.Abort() + err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true) + if err != nil { + t.Fatal(err) + } + txn.Commit() + + // Create a test key "test-key" with a nil SigningKey + namedKey := &namedKey{ + name: "test-key", + AllowedClientIDs: []string{"*"}, + Algorithm: "RS256", + VerificationTTL: 60 * time.Second, + RotationPeriod: 60 * time.Second, + KeyRing: nil, + SigningKey: nil, + NextSigningKey: nil, + NextRotation: time.Now(), + } + s := c.router.MatchingStorageByAPIPath(ctx, "identity/oidc") + if err := namedKey.generateAndSetNextKey(ctx, hclog.NewNullLogger(), s); err != nil { + t.Fatalf("failed to set next signing key") + } + // Store namedKey + entry, _ := logical.StorageEntryJSON(namedKeyConfigPath+namedKey.name, namedKey) + if err := s.Put(ctx, entry); err != nil { + t.Fatalf("writing to in mem storage failed") + } + + // Create a test role "test-role" -- expect no warning + resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/role/test-role", + Operation: logical.CreateOperation, + Data: map[string]interface{}{ + "key": "test-key", + "ttl": "1m", + }, + Storage: s, + }) + expectSuccess(t, resp, err) + if resp != nil { + t.Fatalf("was expecting a nil response but instead got: %#v", resp) + } + + // Generate a token against the role "test-role" -- should fail + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/token/test-role", + Operation: logical.ReadOperation, + Storage: s, + EntityID: "test-entity-id", + }) + expectError(t, resp, err) + // validate error message + expectedStrings := map[string]interface{}{ + "error signing OIDC token: signing key is nil": true, + } + expectStrings(t, []string{err.Error()}, expectedStrings) +} + // TestOIDC_PeriodicFunc tests timing logic for running key // rotations and expiration actions. func TestOIDC_PeriodicFunc(t *testing.T) { @@ -900,84 +971,105 @@ func TestOIDC_PeriodicFunc(t *testing.T) { c, _, _ := TestCoreUnsealed(t) ctx := namespace.RootContext(nil) - // Prepare a dummy signing key - key, _ := rsa.GenerateKey(rand.Reader, 2048) - id, _ := uuid.GenerateUUID() - jwk := &jose.JSONWebKey{ - Key: key, - KeyID: id, - Algorithm: "RS256", - Use: "sig", - } - cyclePeriod := 2 * time.Second testSets := []struct { - namedKey *namedKey - testCases []struct { - cycle int - numKeys int - numPublicKeys int - } + namedKey *namedKey + expectedKeyCount int + setSigningKey bool + setNextSigningKey bool + cycle []int }{ { - // don't set NextSigningKey to ensure its non-existence can be handled - &namedKey{ + namedKey: &namedKey{ name: "test-key", Algorithm: "RS256", VerificationTTL: 1 * cyclePeriod, RotationPeriod: 1 * cyclePeriod, KeyRing: nil, - SigningKey: jwk, + SigningKey: nil, + NextSigningKey: nil, NextRotation: time.Now(), }, - []struct { - cycle int - numKeys int - numPublicKeys int - }{ - {1, 2, 2}, - {2, 2, 2}, - {3, 2, 2}, - {4, 2, 2}, - {5, 2, 2}, - {6, 2, 2}, - {7, 2, 2}, - }, + expectedKeyCount: 3, + setSigningKey: true, + setNextSigningKey: true, + cycle: []int{1, 2, 3, 4}, }, { // don't set SigningKey to ensure its non-existence can be handled - &namedKey{ + namedKey: &namedKey{ name: "test-key-nil-signing-key", Algorithm: "RS256", VerificationTTL: 1 * cyclePeriod, RotationPeriod: 1 * cyclePeriod, - KeyRing: append([]*expireableKey{}, &expireableKey{KeyID: id}), + KeyRing: nil, SigningKey: nil, - NextSigningKey: jwk, + NextSigningKey: nil, NextRotation: time.Now(), }, - []struct { - cycle int - numKeys int - numPublicKeys int - }{ - {1, 2, 2}, + expectedKeyCount: 2, + setSigningKey: false, + setNextSigningKey: true, + cycle: []int{1, 2}, + }, + { + // don't set NextSigningKey to ensure its non-existence can be handled + namedKey: &namedKey{ + name: "test-key-nil-next-signing-key", + Algorithm: "RS256", + VerificationTTL: 1 * cyclePeriod, + RotationPeriod: 1 * cyclePeriod, + KeyRing: nil, + SigningKey: nil, + NextSigningKey: nil, + NextRotation: time.Now(), }, + expectedKeyCount: 2, + setSigningKey: true, + setNextSigningKey: false, + cycle: []int{1, 2}, + }, + { + // don't set keys to ensure non-existence can be handled + namedKey: &namedKey{ + name: "test-key-nil-signing-and-next-signing-key", + Algorithm: "RS256", + VerificationTTL: 1 * cyclePeriod, + RotationPeriod: 1 * cyclePeriod, + KeyRing: nil, + SigningKey: nil, + NextSigningKey: nil, + NextRotation: time.Now(), + }, + expectedKeyCount: 2, + setSigningKey: false, + setNextSigningKey: false, + cycle: []int{1, 2}, }, } for _, testSet := range testSets { - // Store namedKey storage := c.router.MatchingStorageByAPIPath(ctx, "identity/oidc") + if testSet.setSigningKey { + if err := testSet.namedKey.generateAndSetKey(ctx, hclog.NewNullLogger(), storage); err != nil { + t.Fatalf("failed to set signing key") + } + } + if testSet.setNextSigningKey { + if err := testSet.namedKey.generateAndSetNextKey(ctx, hclog.NewNullLogger(), storage); err != nil { + t.Fatalf("failed to set next signing key") + } + } + // Store namedKey entry, _ := logical.StorageEntryJSON(namedKeyConfigPath+testSet.namedKey.name, testSet.namedKey) if err := storage.Put(ctx, entry); err != nil { t.Fatalf("writing to in mem storage failed") } currentCycle := 1 - numCases := len(testSet.testCases) - lastCycle := testSet.testCases[numCases-1].cycle + numCases := len(testSet.cycle) + lastCycle := testSet.cycle[numCases-1] namedKeySamples := make([]*logical.StorageEntry, numCases) publicKeysSamples := make([][]string, numCases) @@ -985,7 +1077,7 @@ func TestOIDC_PeriodicFunc(t *testing.T) { // var start time.Time for currentCycle <= lastCycle { c.identityStore.oidcPeriodicFunc(ctx) - if currentCycle == testSet.testCases[i].cycle { + if currentCycle == testSet.cycle[i] { namedKeyEntry, _ := storage.Get(ctx, namedKeyConfigPath+testSet.namedKey.name) publicKeysEntry, _ := storage.List(ctx, publicKeysConfigPath) namedKeySamples[i] = namedKeyEntry @@ -1005,25 +1097,25 @@ func TestOIDC_PeriodicFunc(t *testing.T) { } // measure collected samples - for i, tc := range testSet.testCases { + for i, cycle := range testSet.cycle { namedKeySamples[i].DecodeJSON(&testSet.namedKey) actualKeyRingLen := len(testSet.namedKey.KeyRing) - if actualKeyRingLen < tc.numKeys { - t.Fatalf( + if actualKeyRingLen < testSet.expectedKeyCount { + t.Errorf( "For key: %s at cycle: %d expected namedKey's KeyRing to be at least of length %d but was: %d", testSet.namedKey.name, - tc.cycle, - tc.numKeys, + cycle, + testSet.expectedKeyCount, actualKeyRingLen, ) } actualPubKeysLen := len(publicKeysSamples[i]) - if actualPubKeysLen < tc.numPublicKeys { - t.Fatalf( + if actualPubKeysLen < testSet.expectedKeyCount { + t.Errorf( "For key: %s at cycle: %d expected public keys to be at least of length %d but was: %d", testSet.namedKey.name, - tc.cycle, - tc.numPublicKeys, + cycle, + testSet.expectedKeyCount, actualPubKeysLen, ) }