Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport 1.3: Fix identity token panic during invalidation #8043

Merged
merged 1 commit into from Dec 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion vault/identity_store.go
Expand Up @@ -315,7 +315,9 @@ func (i *IdentityStore) Invalidate(ctx context.Context, key string) {
return

case strings.HasPrefix(key, oidcTokensPrefix):
i.oidcCache.Flush(nil)
if err := i.oidcCache.Flush(noNamespace); err != nil {
i.logger.Error("error flushing oidc cache", "error", err)
}
}
}

Expand Down
135 changes: 108 additions & 27 deletions vault/identity_store_oidc.go
Expand Up @@ -8,6 +8,7 @@ import (
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/url"
"strings"
Expand Down Expand Up @@ -90,6 +91,8 @@ type oidcCache struct {
c *cache.Cache
}

var errNilNamespace = errors.New("nil namespace in oidc cache request")

const (
issuerPath = "identity/oidc"
oidcTokensPrefix = "oidc_tokens/"
Expand All @@ -111,7 +114,7 @@ var supportedAlgs = []string{
}

// pseudo-namespace for cache items that don't belong to any real namespace.
var nilNamespace = &namespace.Namespace{ID: "__NIL_NAMESPACE"}
var noNamespace = &namespace.Namespace{ID: "__NO_NAMESPACE"}

func oidcPaths(i *IdentityStore) []*framework.Path {
return []*framework.Path{
Expand Down Expand Up @@ -370,7 +373,9 @@ func (i *IdentityStore) pathOIDCUpdateConfig(ctx context.Context, req *logical.R
return nil, err
}

i.oidcCache.Flush(ns)
if err := i.oidcCache.Flush(ns); err != nil {
return nil, err
}

return resp, nil
}
Expand All @@ -381,7 +386,12 @@ func (i *IdentityStore) getOIDCConfig(ctx context.Context, s logical.Storage) (*
return nil, err
}

if v, ok := i.oidcCache.Get(ns, "config"); ok {
v, ok, err := i.oidcCache.Get(ns, "config")
if err != nil {
return nil, err
}

if ok {
return v.(*oidcConfig), nil
}

Expand All @@ -404,7 +414,9 @@ func (i *IdentityStore) getOIDCConfig(ctx context.Context, s logical.Storage) (*

c.effectiveIssuer += "/v1/" + ns.Path + issuerPath

i.oidcCache.SetDefault(ns, "config", &c)
if err := i.oidcCache.SetDefault(ns, "config", &c); err != nil {
return nil, err
}

return &c, nil
}
Expand All @@ -416,8 +428,6 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
return nil, err
}

defer i.oidcCache.Flush(ns)

name := d.Get("name").(string)

i.oidcLock.Lock()
Expand Down Expand Up @@ -494,6 +504,10 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
}
}

if err := i.oidcCache.Flush(ns); err != nil {
return nil, err
}

// store named key
entry, err := logical.StorageEntryJSON(namedKeyConfigPath+name, key)
if err != nil {
Expand Down Expand Up @@ -590,7 +604,9 @@ func (i *IdentityStore) pathOIDCDeleteKey(ctx context.Context, req *logical.Requ
return nil, err
}

i.oidcCache.Flush(ns)
if err := i.oidcCache.Flush(ns); err != nil {
return nil, err
}

return nil, nil
}
Expand Down Expand Up @@ -645,7 +661,9 @@ func (i *IdentityStore) pathOIDCRotateKey(ctx context.Context, req *logical.Requ
return nil, err
}

i.oidcCache.Flush(ns)
if err := i.oidcCache.Flush(ns); err != nil {
return nil, err
}

return nil, nil
}
Expand Down Expand Up @@ -683,7 +701,12 @@ func (i *IdentityStore) pathOIDCGenerateToken(ctx context.Context, req *logical.

var key *namedKey

if keyRaw, found := i.oidcCache.Get(ns, "namedKeys/"+role.Key); found {
keyRaw, found, err := i.oidcCache.Get(ns, "namedKeys/"+role.Key)
if err != nil {
return nil, err
}

if found {
key = keyRaw.(*namedKey)
} else {
entry, _ := req.Storage.Get(ctx, namedKeyConfigPath+role.Key)
Expand All @@ -695,7 +718,9 @@ func (i *IdentityStore) pathOIDCGenerateToken(ctx context.Context, req *logical.
return nil, err
}

i.oidcCache.SetDefault(ns, "namedKeys/"+role.Key, key)
if err := i.oidcCache.SetDefault(ns, "namedKeys/"+role.Key, key); err != nil {
return nil, err
}
}
// Validate that the role is allowed to sign with its key (the key could have been updated)
if !strutil.StrListContains(key.AllowedClientIDs, "*") && !strutil.StrListContains(key.AllowedClientIDs, role.ClientID) {
Expand Down Expand Up @@ -923,7 +948,10 @@ func (i *IdentityStore) pathOIDCCreateUpdateRole(ctx context.Context, req *logic
return nil, err
}

i.oidcCache.Flush(ns)
if err := i.oidcCache.Flush(ns); err != nil {
return nil, err
}

return nil, nil
}

Expand Down Expand Up @@ -994,7 +1022,12 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ
return nil, err
}

if v, ok := i.oidcCache.Get(ns, "discoveryResponse"); ok {
v, ok, err := i.oidcCache.Get(ns, "discoveryResponse")
if err != nil {
return nil, err
}

if ok {
data = v.([]byte)
} else {
c, err := i.getOIDCConfig(ctx, req.Storage)
Expand All @@ -1015,7 +1048,9 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ
return nil, err
}

i.oidcCache.SetDefault(ns, "discoveryResponse", data)
if err := i.oidcCache.SetDefault(ns, "discoveryResponse", data); err != nil {
return nil, err
}
}

resp := &logical.Response{
Expand All @@ -1040,7 +1075,12 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
return nil, err
}

if v, ok := i.oidcCache.Get(ns, "jwksResponse"); ok {
v, ok, err := i.oidcCache.Get(ns, "jwksResponse")
if err != nil {
return nil, err
}

if ok {
data = v.([]byte)
} else {
jwks, err := i.generatePublicJWKS(ctx, req.Storage)
Expand All @@ -1053,7 +1093,9 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
return nil, err
}

i.oidcCache.SetDefault(ns, "jwksResponse", data)
if err := i.oidcCache.SetDefault(ns, "jwksResponse", data); err != nil {
return nil, err
}
}

resp := &logical.Response{
Expand All @@ -1072,7 +1114,12 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
return nil, err
}
if len(keys) > 0 {
if v, ok := i.oidcCache.Get(nilNamespace, "nextRun"); ok {
v, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
if err != nil {
return nil, err
}

if ok {
now := time.Now()
expireAt := v.(time.Time)
if expireAt.After(now) {
Expand Down Expand Up @@ -1311,7 +1358,12 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag
return nil, err
}

if jwksRaw, ok := i.oidcCache.Get(ns, "jwks"); ok {
jwksRaw, ok, err := i.oidcCache.Get(ns, "jwks")
if err != nil {
return nil, err
}

if ok {
return jwksRaw.(*jose.JSONWebKeySet), nil
}

Expand All @@ -1336,7 +1388,9 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag
jwks.Keys = append(jwks.Keys, *key)
}

i.oidcCache.SetDefault(ns, "jwks", jwks)
if err := i.oidcCache.SetDefault(ns, "jwks", jwks); err != nil {
return nil, err
}

return jwks, nil
}
Expand Down Expand Up @@ -1435,7 +1489,9 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor
}

if didUpdate {
i.oidcCache.Flush(ns)
if err := i.oidcCache.Flush(ns); err != nil {
i.Logger().Error("error flushing oidc cache", "error", err)
}
}

return nextExpiration, nil
Expand Down Expand Up @@ -1501,7 +1557,13 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {

nsPaths := i.listNamespacePaths()

if v, ok := i.oidcCache.Get(nilNamespace, "nextRun"); ok {
v, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
if err != nil {
i.Logger().Error("error reading oidc cache", "err", err)
return
}

if ok {
nextRun = v.(time.Time)
}

Expand Down Expand Up @@ -1531,7 +1593,9 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
i.Logger().Warn("error expiring OIDC public keys", "err", err)
}

i.oidcCache.Flush(nilNamespace)
if err := i.oidcCache.Flush(noNamespace); err != nil {
i.Logger().Error("error flushing oidc cache", "err", err)
}

// re-run at the soonest expiration or rotation time
if nextRotation.Before(nextRun) {
Expand All @@ -1542,7 +1606,9 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
nextRun = nextExpiration
}
}
i.oidcCache.SetDefault(nilNamespace, "nextRun", nextRun)
if err := i.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil {
i.Logger().Error("error setting oidc cache", "err", err)
}
}
}

Expand All @@ -1556,20 +1622,35 @@ func (c *oidcCache) nskey(ns *namespace.Namespace, key string) string {
return fmt.Sprintf("v0:%s:%s", ns.ID, key)
}

func (c *oidcCache) Get(ns *namespace.Namespace, key string) (interface{}, bool) {
return c.c.Get(c.nskey(ns, key))
func (c *oidcCache) Get(ns *namespace.Namespace, key string) (interface{}, bool, error) {
if ns == nil {
return nil, false, errNilNamespace
}
v, found := c.c.Get(c.nskey(ns, key))
return v, found, nil
}

func (c *oidcCache) SetDefault(ns *namespace.Namespace, key string, obj interface{}) {
func (c *oidcCache) SetDefault(ns *namespace.Namespace, key string, obj interface{}) error {
if ns == nil {
return errNilNamespace
}
c.c.SetDefault(c.nskey(ns, key), obj)

return nil
}

func (c *oidcCache) Flush(ns *namespace.Namespace) {
func (c *oidcCache) Flush(ns *namespace.Namespace) error {
if ns == nil {
return errNilNamespace
}

for itemKey := range c.c.Items() {
if isTargetNamespacedKey(itemKey, []string{nilNamespace.ID, ns.ID}) {
if isTargetNamespacedKey(itemKey, []string{noNamespace.ID, ns.ID}) {
c.c.Delete(itemKey)
}
}

return nil
}

// isTargetNamespacedKey returns true for a properly constructed namespaced key (<version>:<nsID>:<key>)
Expand Down
32 changes: 27 additions & 5 deletions vault/identity_store_oidc_test.go
Expand Up @@ -619,7 +619,7 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
currentCycle = currentCycle + 1

// sleep until we are in the next cycle - where a next run will happen
v, _ := c.identityStore.oidcCache.Get(nilNamespace, "nextRun")
v, _, _ := c.identityStore.oidcCache.Get(noNamespace, "nextRun")
nextRun := v.(time.Time)
now := time.Now()
diff := nextRun.Sub(now)
Expand Down Expand Up @@ -1012,7 +1012,7 @@ func TestOIDC_isTargetNamespacedKey(t *testing.T) {
func TestOIDC_Flush(t *testing.T) {
c := newOIDCCache()
ns := []*namespace.Namespace{
nilNamespace, //ns[0] is nilNamespace
noNamespace, //ns[0] is nilNamespace
&namespace.Namespace{ID: "ns1"},
&namespace.Namespace{ID: "ns2"},
}
Expand All @@ -1021,7 +1021,9 @@ func TestOIDC_Flush(t *testing.T) {
populateNs := func() {
for i := range ns {
for _, val := range []string{"keyA", "keyB", "keyC"} {
c.SetDefault(ns[i], val, struct{}{})
if err := c.SetDefault(ns[i], val, struct{}{}); err != nil {
t.Fatal(err)
}
}
}
}
Expand Down Expand Up @@ -1052,17 +1054,37 @@ func TestOIDC_Flush(t *testing.T) {

// flushing ns1 should flush ns1 and nilNamespace but not ns2
populateNs()
c.Flush(ns[1])
if err := c.Flush(ns[1]); err != nil {
t.Fatal(err)
}
items := c.c.Items()
verify(items, []*namespace.Namespace{ns[2]}, []*namespace.Namespace{ns[0], ns[1]})

// flushing nilNamespace should flush nilNamespace but not ns1 or ns2
populateNs()
c.Flush(ns[0])
if err := c.Flush(ns[0]); err != nil {
t.Fatal(err)
}
items = c.c.Items()
verify(items, []*namespace.Namespace{ns[1], ns[2]}, []*namespace.Namespace{ns[0]})
}

func TestOIDC_CacheNamespaceNilCheck(t *testing.T) {
cache := newOIDCCache()

if _, _, err := cache.Get(nil, "foo"); err == nil {
t.Fatal("expected error, got nil")
}

if err := cache.SetDefault(nil, "foo", 42); err == nil {
t.Fatal("expected error, got nil")
}

if err := cache.Flush(nil); err == nil {
t.Fatal("expected error, got nil")
}
}

// some helpers
func expectSuccess(t *testing.T, resp *logical.Response, err error) {
t.Helper()
Expand Down