From 3500990517a2d62e31d62914703c2f6a9e6d3f9f Mon Sep 17 00:00:00 2001 From: Jim Kalafut Date: Tue, 17 Dec 2019 10:43:38 -0800 Subject: [PATCH] Fix identity token panic during invalidation (#8015) * Fix identity token crash during invalidation * Check for nil namespace * Fix test * Add nil check test * Check OIDC cache errors --- vault/identity_store.go | 4 +- vault/identity_store_oidc.go | 135 ++++++++++++++++++++++++------ vault/identity_store_oidc_test.go | 32 +++++-- 3 files changed, 138 insertions(+), 33 deletions(-) diff --git a/vault/identity_store.go b/vault/identity_store.go index 8c1bc31739948..9b6fb7cddf520 100644 --- a/vault/identity_store.go +++ b/vault/identity_store.go @@ -315,7 +315,9 @@ func (i *IdentityStore) Invalidate(ctx context.Context, key string) { return case strings.HasPrefix(key, oidcTokensPrefix): - i.oidcCache.Flush(nil) + if err := i.oidcCache.Flush(noNamespace); err != nil { + i.logger.Error("error flushing oidc cache", "error", err) + } } } diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index b88d01ce583f6..c36c5c35c1810 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -8,6 +8,7 @@ import ( "crypto/rsa" "encoding/base64" "encoding/json" + "errors" "fmt" "net/url" "strings" @@ -90,6 +91,8 @@ type oidcCache struct { c *cache.Cache } +var errNilNamespace = errors.New("nil namespace in oidc cache request") + const ( issuerPath = "identity/oidc" oidcTokensPrefix = "oidc_tokens/" @@ -111,7 +114,7 @@ var supportedAlgs = []string{ } // pseudo-namespace for cache items that don't belong to any real namespace. -var nilNamespace = &namespace.Namespace{ID: "__NIL_NAMESPACE"} +var noNamespace = &namespace.Namespace{ID: "__NO_NAMESPACE"} func oidcPaths(i *IdentityStore) []*framework.Path { return []*framework.Path{ @@ -370,7 +373,9 @@ func (i *IdentityStore) pathOIDCUpdateConfig(ctx context.Context, req *logical.R return nil, err } - i.oidcCache.Flush(ns) + if err := i.oidcCache.Flush(ns); err != nil { + return nil, err + } return resp, nil } @@ -381,7 +386,12 @@ func (i *IdentityStore) getOIDCConfig(ctx context.Context, s logical.Storage) (* return nil, err } - if v, ok := i.oidcCache.Get(ns, "config"); ok { + v, ok, err := i.oidcCache.Get(ns, "config") + if err != nil { + return nil, err + } + + if ok { return v.(*oidcConfig), nil } @@ -404,7 +414,9 @@ func (i *IdentityStore) getOIDCConfig(ctx context.Context, s logical.Storage) (* c.effectiveIssuer += "/v1/" + ns.Path + issuerPath - i.oidcCache.SetDefault(ns, "config", &c) + if err := i.oidcCache.SetDefault(ns, "config", &c); err != nil { + return nil, err + } return &c, nil } @@ -416,8 +428,6 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica return nil, err } - defer i.oidcCache.Flush(ns) - name := d.Get("name").(string) i.oidcLock.Lock() @@ -494,6 +504,10 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica } } + if err := i.oidcCache.Flush(ns); err != nil { + return nil, err + } + // store named key entry, err := logical.StorageEntryJSON(namedKeyConfigPath+name, key) if err != nil { @@ -590,7 +604,9 @@ func (i *IdentityStore) pathOIDCDeleteKey(ctx context.Context, req *logical.Requ return nil, err } - i.oidcCache.Flush(ns) + if err := i.oidcCache.Flush(ns); err != nil { + return nil, err + } return nil, nil } @@ -645,7 +661,9 @@ func (i *IdentityStore) pathOIDCRotateKey(ctx context.Context, req *logical.Requ return nil, err } - i.oidcCache.Flush(ns) + if err := i.oidcCache.Flush(ns); err != nil { + return nil, err + } return nil, nil } @@ -683,7 +701,12 @@ func (i *IdentityStore) pathOIDCGenerateToken(ctx context.Context, req *logical. var key *namedKey - if keyRaw, found := i.oidcCache.Get(ns, "namedKeys/"+role.Key); found { + keyRaw, found, err := i.oidcCache.Get(ns, "namedKeys/"+role.Key) + if err != nil { + return nil, err + } + + if found { key = keyRaw.(*namedKey) } else { entry, _ := req.Storage.Get(ctx, namedKeyConfigPath+role.Key) @@ -695,7 +718,9 @@ func (i *IdentityStore) pathOIDCGenerateToken(ctx context.Context, req *logical. return nil, err } - i.oidcCache.SetDefault(ns, "namedKeys/"+role.Key, key) + if err := i.oidcCache.SetDefault(ns, "namedKeys/"+role.Key, key); err != nil { + return nil, err + } } // Validate that the role is allowed to sign with its key (the key could have been updated) if !strutil.StrListContains(key.AllowedClientIDs, "*") && !strutil.StrListContains(key.AllowedClientIDs, role.ClientID) { @@ -923,7 +948,10 @@ func (i *IdentityStore) pathOIDCCreateUpdateRole(ctx context.Context, req *logic return nil, err } - i.oidcCache.Flush(ns) + if err := i.oidcCache.Flush(ns); err != nil { + return nil, err + } + return nil, nil } @@ -994,7 +1022,12 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ return nil, err } - if v, ok := i.oidcCache.Get(ns, "discoveryResponse"); ok { + v, ok, err := i.oidcCache.Get(ns, "discoveryResponse") + if err != nil { + return nil, err + } + + if ok { data = v.([]byte) } else { c, err := i.getOIDCConfig(ctx, req.Storage) @@ -1015,7 +1048,9 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ return nil, err } - i.oidcCache.SetDefault(ns, "discoveryResponse", data) + if err := i.oidcCache.SetDefault(ns, "discoveryResponse", data); err != nil { + return nil, err + } } resp := &logical.Response{ @@ -1040,7 +1075,12 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical return nil, err } - if v, ok := i.oidcCache.Get(ns, "jwksResponse"); ok { + v, ok, err := i.oidcCache.Get(ns, "jwksResponse") + if err != nil { + return nil, err + } + + if ok { data = v.([]byte) } else { jwks, err := i.generatePublicJWKS(ctx, req.Storage) @@ -1053,7 +1093,9 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical return nil, err } - i.oidcCache.SetDefault(ns, "jwksResponse", data) + if err := i.oidcCache.SetDefault(ns, "jwksResponse", data); err != nil { + return nil, err + } } resp := &logical.Response{ @@ -1072,7 +1114,12 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical return nil, err } if len(keys) > 0 { - if v, ok := i.oidcCache.Get(nilNamespace, "nextRun"); ok { + v, ok, err := i.oidcCache.Get(noNamespace, "nextRun") + if err != nil { + return nil, err + } + + if ok { now := time.Now() expireAt := v.(time.Time) if expireAt.After(now) { @@ -1311,7 +1358,12 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag return nil, err } - if jwksRaw, ok := i.oidcCache.Get(ns, "jwks"); ok { + jwksRaw, ok, err := i.oidcCache.Get(ns, "jwks") + if err != nil { + return nil, err + } + + if ok { return jwksRaw.(*jose.JSONWebKeySet), nil } @@ -1336,7 +1388,9 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag jwks.Keys = append(jwks.Keys, *key) } - i.oidcCache.SetDefault(ns, "jwks", jwks) + if err := i.oidcCache.SetDefault(ns, "jwks", jwks); err != nil { + return nil, err + } return jwks, nil } @@ -1435,7 +1489,9 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor } if didUpdate { - i.oidcCache.Flush(ns) + if err := i.oidcCache.Flush(ns); err != nil { + i.Logger().Error("error flushing oidc cache", "error", err) + } } return nextExpiration, nil @@ -1501,7 +1557,13 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { nsPaths := i.listNamespacePaths() - if v, ok := i.oidcCache.Get(nilNamespace, "nextRun"); ok { + v, ok, err := i.oidcCache.Get(noNamespace, "nextRun") + if err != nil { + i.Logger().Error("error reading oidc cache", "err", err) + return + } + + if ok { nextRun = v.(time.Time) } @@ -1531,7 +1593,9 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { i.Logger().Warn("error expiring OIDC public keys", "err", err) } - i.oidcCache.Flush(nilNamespace) + if err := i.oidcCache.Flush(noNamespace); err != nil { + i.Logger().Error("error flushing oidc cache", "err", err) + } // re-run at the soonest expiration or rotation time if nextRotation.Before(nextRun) { @@ -1542,7 +1606,9 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { nextRun = nextExpiration } } - i.oidcCache.SetDefault(nilNamespace, "nextRun", nextRun) + if err := i.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil { + i.Logger().Error("error setting oidc cache", "err", err) + } } } @@ -1556,20 +1622,35 @@ func (c *oidcCache) nskey(ns *namespace.Namespace, key string) string { return fmt.Sprintf("v0:%s:%s", ns.ID, key) } -func (c *oidcCache) Get(ns *namespace.Namespace, key string) (interface{}, bool) { - return c.c.Get(c.nskey(ns, key)) +func (c *oidcCache) Get(ns *namespace.Namespace, key string) (interface{}, bool, error) { + if ns == nil { + return nil, false, errNilNamespace + } + v, found := c.c.Get(c.nskey(ns, key)) + return v, found, nil } -func (c *oidcCache) SetDefault(ns *namespace.Namespace, key string, obj interface{}) { +func (c *oidcCache) SetDefault(ns *namespace.Namespace, key string, obj interface{}) error { + if ns == nil { + return errNilNamespace + } c.c.SetDefault(c.nskey(ns, key), obj) + + return nil } -func (c *oidcCache) Flush(ns *namespace.Namespace) { +func (c *oidcCache) Flush(ns *namespace.Namespace) error { + if ns == nil { + return errNilNamespace + } + for itemKey := range c.c.Items() { - if isTargetNamespacedKey(itemKey, []string{nilNamespace.ID, ns.ID}) { + if isTargetNamespacedKey(itemKey, []string{noNamespace.ID, ns.ID}) { c.c.Delete(itemKey) } } + + return nil } // isTargetNamespacedKey returns true for a properly constructed namespaced key (::) diff --git a/vault/identity_store_oidc_test.go b/vault/identity_store_oidc_test.go index 384e5695dee17..4d214c7269c88 100644 --- a/vault/identity_store_oidc_test.go +++ b/vault/identity_store_oidc_test.go @@ -619,7 +619,7 @@ func TestOIDC_PeriodicFunc(t *testing.T) { currentCycle = currentCycle + 1 // sleep until we are in the next cycle - where a next run will happen - v, _ := c.identityStore.oidcCache.Get(nilNamespace, "nextRun") + v, _, _ := c.identityStore.oidcCache.Get(noNamespace, "nextRun") nextRun := v.(time.Time) now := time.Now() diff := nextRun.Sub(now) @@ -1012,7 +1012,7 @@ func TestOIDC_isTargetNamespacedKey(t *testing.T) { func TestOIDC_Flush(t *testing.T) { c := newOIDCCache() ns := []*namespace.Namespace{ - nilNamespace, //ns[0] is nilNamespace + noNamespace, //ns[0] is nilNamespace &namespace.Namespace{ID: "ns1"}, &namespace.Namespace{ID: "ns2"}, } @@ -1021,7 +1021,9 @@ func TestOIDC_Flush(t *testing.T) { populateNs := func() { for i := range ns { for _, val := range []string{"keyA", "keyB", "keyC"} { - c.SetDefault(ns[i], val, struct{}{}) + if err := c.SetDefault(ns[i], val, struct{}{}); err != nil { + t.Fatal(err) + } } } } @@ -1052,17 +1054,37 @@ func TestOIDC_Flush(t *testing.T) { // flushing ns1 should flush ns1 and nilNamespace but not ns2 populateNs() - c.Flush(ns[1]) + if err := c.Flush(ns[1]); err != nil { + t.Fatal(err) + } items := c.c.Items() verify(items, []*namespace.Namespace{ns[2]}, []*namespace.Namespace{ns[0], ns[1]}) // flushing nilNamespace should flush nilNamespace but not ns1 or ns2 populateNs() - c.Flush(ns[0]) + if err := c.Flush(ns[0]); err != nil { + t.Fatal(err) + } items = c.c.Items() verify(items, []*namespace.Namespace{ns[1], ns[2]}, []*namespace.Namespace{ns[0]}) } +func TestOIDC_CacheNamespaceNilCheck(t *testing.T) { + cache := newOIDCCache() + + if _, _, err := cache.Get(nil, "foo"); err == nil { + t.Fatal("expected error, got nil") + } + + if err := cache.SetDefault(nil, "foo", 42); err == nil { + t.Fatal("expected error, got nil") + } + + if err := cache.Flush(nil); err == nil { + t.Fatal("expected error, got nil") + } +} + // some helpers func expectSuccess(t *testing.T, resp *logical.Response, err error) { t.Helper()