Skip to content

Commit

Permalink
check for nil signing key on rotation
Browse files Browse the repository at this point in the history
  • Loading branch information
fairclothjm committed Jan 19, 2022
1 parent b14f1ed commit f0fc03e
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 29 deletions.
55 changes: 37 additions & 18 deletions vault/identity_store_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -548,19 +548,11 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica

// generate current and next keys if creating a new key or changing algorithms
if key.Algorithm != prevAlgorithm {
signingKey, err := generateKeys(key.Algorithm)
err = key.generateAndSetKey(ctx, i.Logger(), req.Storage)
if err != nil {
return nil, err
}

key.SigningKey = signingKey
key.KeyRing = append(key.KeyRing, &expireableKey{KeyID: signingKey.Public().KeyID})

if err := saveOIDCPublicKey(ctx, req.Storage, signingKey.Public()); err != nil {
return nil, err
}
i.Logger().Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID)

err = key.generateAndSetNextKey(ctx, i.Logger(), req.Storage)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1013,6 +1005,24 @@ func mergeJSONTemplates(logger hclog.Logger, output map[string]interface{}, temp
return nil
}

// generateAndSetKey will generate new signing and public key pairs and set
// them as the SigningKey.
func (k *namedKey) generateAndSetKey(ctx context.Context, logger hclog.Logger, s logical.Storage) error {
signingKey, err := generateKeys(k.Algorithm)
if err != nil {
return err
}

k.SigningKey = signingKey
k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: signingKey.Public().KeyID})

if err := saveOIDCPublicKey(ctx, s, signingKey.Public()); err != nil {
return err
}
logger.Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID)
return nil
}

// generateAndSetNextKey will generate new signing and public key pairs and set
// them as the NextSigningKey.
func (k *namedKey) generateAndSetNextKey(ctx context.Context, logger hclog.Logger, s logical.Storage) error {
Expand Down Expand Up @@ -1481,8 +1491,25 @@ func (i *IdentityStore) pathOIDCIntrospect(ctx context.Context, req *logical.Req
// namedKey.rotate(overrides) performs a key rotation on a namedKey.
// verification_ttl can be overridden with an overrideVerificationTTL value >= 0
func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.Storage, overrideVerificationTTL time.Duration) error {
verificationTTL := k.VerificationTTL
if k.SigningKey == nil {
logger.Debug("nil signing key detected on rotation")
err := k.generateAndSetKey(ctx, logger, s)
if err != nil {
return err
}
}

if k.NextSigningKey == nil {
logger.Debug("nil next signing key detected on rotation")
// keys will not have a NextSigningKey if they were generated before
// vault 1.9
err := k.generateAndSetNextKey(ctx, logger, s)
if err != nil {
return err
}
}

verificationTTL := k.VerificationTTL
if overrideVerificationTTL >= 0 {
verificationTTL = overrideVerificationTTL
}
Expand All @@ -1496,14 +1523,6 @@ func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.St
}
}

if k.NextSigningKey == nil {
// keys will not have a NextSigningKey if they were generated before
// vault 1.9
err := k.generateAndSetNextKey(ctx, logger, s)
if err != nil {
return err
}
}
// do the rotation
k.SigningKey = k.NextSigningKey
k.NextRotation = now.Add(k.RotationPeriod)
Expand Down
60 changes: 49 additions & 11 deletions vault/identity_store_oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -937,12 +937,32 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
numPublicKeys int
}{
{1, 2, 2},
{2, 3, 3},
{3, 3, 3},
{4, 3, 3},
{5, 3, 3},
{6, 3, 3},
{7, 3, 3},
{2, 2, 2},
{3, 2, 2},
{4, 2, 2},
{5, 2, 2},
{6, 2, 2},
{7, 2, 2},
},
},
{
// don't set SigningKey to ensure its non-existence can be handled
&namedKey{
name: "test-key-nil-signing-key",
Algorithm: "RS256",
VerificationTTL: 1 * cyclePeriod,
RotationPeriod: 1 * cyclePeriod,
KeyRing: append([]*expireableKey{}, &expireableKey{KeyID: id}),
SigningKey: nil,
NextSigningKey: jwk,
NextRotation: time.Now(),
},
[]struct {
cycle int
numKeys int
numPublicKeys int
}{
{1, 2, 2},
},
},
}
Expand Down Expand Up @@ -985,15 +1005,33 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
}

// measure collected samples
for i := range testSet.testCases {
for i, tc := range testSet.testCases {
namedKeySamples[i].DecodeJSON(&testSet.namedKey)
if len(testSet.namedKey.KeyRing) != testSet.testCases[i].numKeys {
t.Fatalf("At cycle: %d expected namedKey's KeyRing to be of length %d but was: %d", testSet.testCases[i].cycle, testSet.testCases[i].numKeys, len(testSet.namedKey.KeyRing))
actualKeyRingLen := len(testSet.namedKey.KeyRing)
if actualKeyRingLen < tc.numKeys {
t.Fatalf(
"For key: %s at cycle: %d expected namedKey's KeyRing to be at least of length %d but was: %d",
testSet.namedKey.name,
tc.cycle,
tc.numKeys,
actualKeyRingLen,
)
}
if len(publicKeysSamples[i]) != testSet.testCases[i].numPublicKeys {
t.Fatalf("At cycle: %d expected public keys to be of length %d but was: %d", testSet.testCases[i].cycle, testSet.testCases[i].numPublicKeys, len(publicKeysSamples[i]))
actualPubKeysLen := len(publicKeysSamples[i])
if actualPubKeysLen < tc.numPublicKeys {
t.Fatalf(
"For key: %s at cycle: %d expected public keys to be at least of length %d but was: %d",
testSet.namedKey.name,
tc.cycle,
tc.numPublicKeys,
actualPubKeysLen,
)
}
}

if err := storage.Delete(ctx, namedKeyConfigPath+testSet.namedKey.name); err != nil {
t.Fatalf("deleting from in mem storage failed")
}
}
}

Expand Down

0 comments on commit f0fc03e

Please sign in to comment.