Skip to content

Commit

Permalink
Update nil signing key handling
Browse files Browse the repository at this point in the history
- bypass setting ExpireAt if signing key is nil in rotate
- return err if singing key is nil in signPayload
  • Loading branch information
fairclothjm committed Jan 24, 2022
1 parent 9a9eba6 commit 7fc38ac
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 84 deletions.
53 changes: 26 additions & 27 deletions vault/identity_store_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,9 @@ func (k *namedKey) generateAndSetNextKey(ctx context.Context, logger hclog.Logge
}

func (k *namedKey) signPayload(payload []byte) (string, error) {
if k.SigningKey == nil {
return "", fmt.Errorf("signing key is nil")
}
signingKey := jose.SigningKey{Key: k.SigningKey, Algorithm: jose.SignatureAlgorithm(k.Algorithm)}
signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{})
if err != nil {
Expand Down Expand Up @@ -1491,12 +1494,22 @@ func (i *IdentityStore) pathOIDCIntrospect(ctx context.Context, req *logical.Req
// namedKey.rotate(overrides) performs a key rotation on a namedKey.
// verification_ttl can be overridden with an overrideVerificationTTL value >= 0
func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.Storage, overrideVerificationTTL time.Duration) error {
if k.SigningKey == nil {
logger.Debug("nil signing key detected on rotation")
err := k.generateAndSetKey(ctx, logger, s)
if err != nil {
return err
verificationTTL := k.VerificationTTL
if overrideVerificationTTL >= 0 {
verificationTTL = overrideVerificationTTL
}

now := time.Now()
if k.SigningKey != nil {
// set the previous public key's expiry time
for _, key := range k.KeyRing {
if key.KeyID == k.SigningKey.KeyID {
key.ExpireAt = now.Add(verificationTTL)
break
}
}
} else {
logger.Debug("nil signing key detected on rotation")
}

if k.NextSigningKey == nil {
Expand All @@ -1509,20 +1522,6 @@ func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.St
}
}

verificationTTL := k.VerificationTTL
if overrideVerificationTTL >= 0 {
verificationTTL = overrideVerificationTTL
}

now := time.Now()
// set the previous public key's expiry time
for _, key := range k.KeyRing {
if key.KeyID == k.SigningKey.KeyID {
key.ExpireAt = now.Add(verificationTTL)
break
}
}

// do the rotation
k.SigningKey = k.NextSigningKey
k.NextRotation = now.Add(k.RotationPeriod)
Expand Down Expand Up @@ -1714,21 +1713,21 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor
return now, err
}

namedKeys, err := s.List(ctx, namedKeyConfigPath)
keyNames, err := s.List(ctx, namedKeyConfigPath)
if err != nil {
return now, err
}

usedKeys := make([]string, 0)

for _, k := range namedKeys {
entry, err := s.Get(ctx, namedKeyConfigPath+k)
for _, keyName := range keyNames {
entry, err := s.Get(ctx, namedKeyConfigPath+keyName)
if err != nil {
return now, err
}

if entry == nil {
i.Logger().Warn("could not find key to update", "key", k)
i.Logger().Warn("could not find key to update", "key", keyName)
continue
}

Expand All @@ -1741,14 +1740,14 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor
keyRing := key.KeyRing
var keyringUpdated bool

for i := 0; i < len(keyRing); i++ {
k := keyRing[i]
for j := 0; j < len(keyRing); j++ {
k := keyRing[j]
if !k.ExpireAt.IsZero() && k.ExpireAt.Before(now) {
keyRing[i] = keyRing[len(keyRing)-1]
keyRing[j] = keyRing[len(keyRing)-1]
keyRing = keyRing[:len(keyRing)-1]

keyringUpdated = true
i--
j--
continue
}

Expand Down
206 changes: 149 additions & 57 deletions vault/identity_store_oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@ package vault

import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"strconv"
"strings"
"testing"
"time"

"github.com/go-test/deep"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/framework"
Expand Down Expand Up @@ -893,99 +891,193 @@ func TestOIDC_SignIDToken(t *testing.T) {
}
}

// TestOIDC_SignIDToken_NilSigningKey
func TestOIDC_SignIDToken_NilSigningKey(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
ctx := namespace.RootContext(nil)
// storage := &logical.InmemStorage{}

// Create and load an entity, an entity is required to generate an ID token
testEntity := &identity.Entity{
Name: "test-entity-name",
ID: "test-entity-id",
BucketKey: "test-entity-bucket-key",
}

txn := c.identityStore.db.Txn(true)
defer txn.Abort()
err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true)
if err != nil {
t.Fatal(err)
}
txn.Commit()

// Create a test key "test-key" with a nil SigningKey
namedKey := &namedKey{
name: "test-key",
AllowedClientIDs: []string{"*"},
Algorithm: "RS256",
VerificationTTL: 60 * time.Second,
RotationPeriod: 60 * time.Second,
KeyRing: nil,
SigningKey: nil,
NextSigningKey: nil,
NextRotation: time.Now(),
}
s := c.router.MatchingStorageByAPIPath(ctx, "identity/oidc")
if err := namedKey.generateAndSetNextKey(ctx, hclog.NewNullLogger(), s); err != nil {
t.Fatalf("failed to set next signing key")
}
// Store namedKey
entry, _ := logical.StorageEntryJSON(namedKeyConfigPath+namedKey.name, namedKey)
if err := s.Put(ctx, entry); err != nil {
t.Fatalf("writing to in mem storage failed")
}

// Create a test role "test-role" -- expect no warning
resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/role/test-role",
Operation: logical.CreateOperation,
Data: map[string]interface{}{
"key": "test-key",
"ttl": "1m",
},
Storage: s,
})
expectSuccess(t, resp, err)
if resp != nil {
t.Fatalf("was expecting a nil response but instead got: %#v", resp)
}

// Generate a token against the role "test-role" -- should fail
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/token/test-role",
Operation: logical.ReadOperation,
Storage: s,
EntityID: "test-entity-id",
})
expectError(t, resp, err)
// validate error message
expectedStrings := map[string]interface{}{
"error signing OIDC token: signing key is nil": true,
}
expectStrings(t, []string{err.Error()}, expectedStrings)
}

// TestOIDC_PeriodicFunc tests timing logic for running key
// rotations and expiration actions.
func TestOIDC_PeriodicFunc(t *testing.T) {
// Prepare a storage to run through periodicFunc
c, _, _ := TestCoreUnsealed(t)
ctx := namespace.RootContext(nil)

// Prepare a dummy signing key
key, _ := rsa.GenerateKey(rand.Reader, 2048)
id, _ := uuid.GenerateUUID()
jwk := &jose.JSONWebKey{
Key: key,
KeyID: id,
Algorithm: "RS256",
Use: "sig",
}

cyclePeriod := 2 * time.Second

testSets := []struct {
namedKey *namedKey
testCases []struct {
cycle int
numKeys int
numPublicKeys int
}
namedKey *namedKey
expectedKeyCount int
setSigningKey bool
setNextSigningKey bool
cycle []int
}{
{
// don't set NextSigningKey to ensure its non-existence can be handled
&namedKey{
namedKey: &namedKey{
name: "test-key",
Algorithm: "RS256",
VerificationTTL: 1 * cyclePeriod,
RotationPeriod: 1 * cyclePeriod,
KeyRing: nil,
SigningKey: jwk,
SigningKey: nil,
NextSigningKey: nil,
NextRotation: time.Now(),
},
[]struct {
cycle int
numKeys int
numPublicKeys int
}{
{1, 2, 2},
{2, 2, 2},
{3, 2, 2},
{4, 2, 2},
{5, 2, 2},
{6, 2, 2},
{7, 2, 2},
},
expectedKeyCount: 3,
setSigningKey: true,
setNextSigningKey: true,
cycle: []int{1, 2, 3, 4},
},
{
// don't set SigningKey to ensure its non-existence can be handled
&namedKey{
namedKey: &namedKey{
name: "test-key-nil-signing-key",
Algorithm: "RS256",
VerificationTTL: 1 * cyclePeriod,
RotationPeriod: 1 * cyclePeriod,
KeyRing: append([]*expireableKey{}, &expireableKey{KeyID: id}),
KeyRing: nil,
SigningKey: nil,
NextSigningKey: jwk,
NextSigningKey: nil,
NextRotation: time.Now(),
},
[]struct {
cycle int
numKeys int
numPublicKeys int
}{
{1, 2, 2},
expectedKeyCount: 2,
setSigningKey: false,
setNextSigningKey: true,
cycle: []int{1, 2},
},
{
// don't set NextSigningKey to ensure its non-existence can be handled
namedKey: &namedKey{
name: "test-key-nil-next-signing-key",
Algorithm: "RS256",
VerificationTTL: 1 * cyclePeriod,
RotationPeriod: 1 * cyclePeriod,
KeyRing: nil,
SigningKey: nil,
NextSigningKey: nil,
NextRotation: time.Now(),
},
expectedKeyCount: 2,
setSigningKey: true,
setNextSigningKey: false,
cycle: []int{1, 2},
},
{
// don't set keys to ensure non-existence can be handled
namedKey: &namedKey{
name: "test-key-nil-signing-and-next-signing-key",
Algorithm: "RS256",
VerificationTTL: 1 * cyclePeriod,
RotationPeriod: 1 * cyclePeriod,
KeyRing: nil,
SigningKey: nil,
NextSigningKey: nil,
NextRotation: time.Now(),
},
expectedKeyCount: 2,
setSigningKey: false,
setNextSigningKey: false,
cycle: []int{1, 2},
},
}

for _, testSet := range testSets {
// Store namedKey
storage := c.router.MatchingStorageByAPIPath(ctx, "identity/oidc")
if testSet.setSigningKey {
if err := testSet.namedKey.generateAndSetKey(ctx, hclog.NewNullLogger(), storage); err != nil {
t.Fatalf("failed to set signing key")
}
}
if testSet.setNextSigningKey {
if err := testSet.namedKey.generateAndSetNextKey(ctx, hclog.NewNullLogger(), storage); err != nil {
t.Fatalf("failed to set next signing key")
}
}
// Store namedKey
entry, _ := logical.StorageEntryJSON(namedKeyConfigPath+testSet.namedKey.name, testSet.namedKey)
if err := storage.Put(ctx, entry); err != nil {
t.Fatalf("writing to in mem storage failed")
}

currentCycle := 1
numCases := len(testSet.testCases)
lastCycle := testSet.testCases[numCases-1].cycle
numCases := len(testSet.cycle)
lastCycle := testSet.cycle[numCases-1]
namedKeySamples := make([]*logical.StorageEntry, numCases)
publicKeysSamples := make([][]string, numCases)

i := 0
// var start time.Time
for currentCycle <= lastCycle {
c.identityStore.oidcPeriodicFunc(ctx)
if currentCycle == testSet.testCases[i].cycle {
if currentCycle == testSet.cycle[i] {
namedKeyEntry, _ := storage.Get(ctx, namedKeyConfigPath+testSet.namedKey.name)
publicKeysEntry, _ := storage.List(ctx, publicKeysConfigPath)
namedKeySamples[i] = namedKeyEntry
Expand All @@ -1005,25 +1097,25 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
}

// measure collected samples
for i, tc := range testSet.testCases {
for i, cycle := range testSet.cycle {
namedKeySamples[i].DecodeJSON(&testSet.namedKey)
actualKeyRingLen := len(testSet.namedKey.KeyRing)
if actualKeyRingLen < tc.numKeys {
t.Fatalf(
if actualKeyRingLen < testSet.expectedKeyCount {
t.Errorf(
"For key: %s at cycle: %d expected namedKey's KeyRing to be at least of length %d but was: %d",
testSet.namedKey.name,
tc.cycle,
tc.numKeys,
cycle,
testSet.expectedKeyCount,
actualKeyRingLen,
)
}
actualPubKeysLen := len(publicKeysSamples[i])
if actualPubKeysLen < tc.numPublicKeys {
t.Fatalf(
if actualPubKeysLen < testSet.expectedKeyCount {
t.Errorf(
"For key: %s at cycle: %d expected public keys to be at least of length %d but was: %d",
testSet.namedKey.name,
tc.cycle,
tc.numPublicKeys,
cycle,
testSet.expectedKeyCount,
actualPubKeysLen,
)
}
Expand Down

0 comments on commit 7fc38ac

Please sign in to comment.