diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index 6ff810c750786..c2f93f2ccc406 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -548,19 +548,11 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica // generate current and next keys if creating a new key or changing algorithms if key.Algorithm != prevAlgorithm { - signingKey, err := generateKeys(key.Algorithm) + err = key.generateAndSetKey(ctx, i.Logger(), req.Storage) if err != nil { return nil, err } - key.SigningKey = signingKey - key.KeyRing = append(key.KeyRing, &expireableKey{KeyID: signingKey.Public().KeyID}) - - if err := saveOIDCPublicKey(ctx, req.Storage, signingKey.Public()); err != nil { - return nil, err - } - i.Logger().Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID) - err = key.generateAndSetNextKey(ctx, i.Logger(), req.Storage) if err != nil { return nil, err @@ -1013,6 +1005,24 @@ func mergeJSONTemplates(logger hclog.Logger, output map[string]interface{}, temp return nil } +// generateAndSetKey will generate new signing and public key pairs and set +// them as the SigningKey. +func (k *namedKey) generateAndSetKey(ctx context.Context, logger hclog.Logger, s logical.Storage) error { + signingKey, err := generateKeys(k.Algorithm) + if err != nil { + return err + } + + k.SigningKey = signingKey + k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: signingKey.Public().KeyID}) + + if err := saveOIDCPublicKey(ctx, s, signingKey.Public()); err != nil { + return err + } + logger.Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID) + return nil +} + // generateAndSetNextKey will generate new signing and public key pairs and set // them as the NextSigningKey. func (k *namedKey) generateAndSetNextKey(ctx context.Context, logger hclog.Logger, s logical.Storage) error { @@ -1481,8 +1491,25 @@ func (i *IdentityStore) pathOIDCIntrospect(ctx context.Context, req *logical.Req // namedKey.rotate(overrides) performs a key rotation on a namedKey. // verification_ttl can be overridden with an overrideVerificationTTL value >= 0 func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.Storage, overrideVerificationTTL time.Duration) error { - verificationTTL := k.VerificationTTL + if k.SigningKey == nil { + logger.Debug("nil signing key detected on rotation") + err := k.generateAndSetKey(ctx, logger, s) + if err != nil { + return err + } + } + + if k.NextSigningKey == nil { + logger.Debug("nil next signing key detected on rotation") + // keys will not have a NextSigningKey if they were generated before + // vault 1.9 + err := k.generateAndSetNextKey(ctx, logger, s) + if err != nil { + return err + } + } + verificationTTL := k.VerificationTTL if overrideVerificationTTL >= 0 { verificationTTL = overrideVerificationTTL } @@ -1496,14 +1523,6 @@ func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.St } } - if k.NextSigningKey == nil { - // keys will not have a NextSigningKey if they were generated before - // vault 1.9 - err := k.generateAndSetNextKey(ctx, logger, s) - if err != nil { - return err - } - } // do the rotation k.SigningKey = k.NextSigningKey k.NextRotation = now.Add(k.RotationPeriod) diff --git a/vault/identity_store_oidc_test.go b/vault/identity_store_oidc_test.go index d52ae7a14c760..35f9c389d9ee1 100644 --- a/vault/identity_store_oidc_test.go +++ b/vault/identity_store_oidc_test.go @@ -937,12 +937,32 @@ func TestOIDC_PeriodicFunc(t *testing.T) { numPublicKeys int }{ {1, 2, 2}, - {2, 3, 3}, - {3, 3, 3}, - {4, 3, 3}, - {5, 3, 3}, - {6, 3, 3}, - {7, 3, 3}, + {2, 2, 2}, + {3, 2, 2}, + {4, 2, 2}, + {5, 2, 2}, + {6, 2, 2}, + {7, 2, 2}, + }, + }, + { + // don't set SigningKey to ensure its non-existence can be handled + &namedKey{ + name: "test-key-nil-signing-key", + Algorithm: "RS256", + VerificationTTL: 1 * cyclePeriod, + RotationPeriod: 1 * cyclePeriod, + KeyRing: append([]*expireableKey{}, &expireableKey{KeyID: id}), + SigningKey: nil, + NextSigningKey: jwk, + NextRotation: time.Now(), + }, + []struct { + cycle int + numKeys int + numPublicKeys int + }{ + {1, 2, 2}, }, }, } @@ -985,15 +1005,33 @@ func TestOIDC_PeriodicFunc(t *testing.T) { } // measure collected samples - for i := range testSet.testCases { + for i, tc := range testSet.testCases { namedKeySamples[i].DecodeJSON(&testSet.namedKey) - if len(testSet.namedKey.KeyRing) != testSet.testCases[i].numKeys { - t.Fatalf("At cycle: %d expected namedKey's KeyRing to be of length %d but was: %d", testSet.testCases[i].cycle, testSet.testCases[i].numKeys, len(testSet.namedKey.KeyRing)) + actualKeyRingLen := len(testSet.namedKey.KeyRing) + if actualKeyRingLen < tc.numKeys { + t.Fatalf( + "For key: %s at cycle: %d expected namedKey's KeyRing to be at least of length %d but was: %d", + testSet.namedKey.name, + tc.cycle, + tc.numKeys, + actualKeyRingLen, + ) } - if len(publicKeysSamples[i]) != testSet.testCases[i].numPublicKeys { - t.Fatalf("At cycle: %d expected public keys to be of length %d but was: %d", testSet.testCases[i].cycle, testSet.testCases[i].numPublicKeys, len(publicKeysSamples[i])) + actualPubKeysLen := len(publicKeysSamples[i]) + if actualPubKeysLen < tc.numPublicKeys { + t.Fatalf( + "For key: %s at cycle: %d expected public keys to be at least of length %d but was: %d", + testSet.namedKey.name, + tc.cycle, + tc.numPublicKeys, + actualPubKeysLen, + ) } } + + if err := storage.Delete(ctx, namedKeyConfigPath+testSet.namedKey.name); err != nil { + t.Fatalf("deleting from in mem storage failed") + } } }