agent: handle dependencies between cached leases during persistent storage restore
tomhjp committed Oct 7, 2021
1 parent 7136280 commit d53e739
Showing 3 changed files with 98 additions and 56 deletions.
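The restore logic added below coordinates dependent cache entries with one channel per entry, closed once that entry has been written back to the in-memory cache; an entry that depends on another simply blocks on its parent's channel first. A minimal standalone sketch of that pattern, with made-up entry and ID names rather than Vault's types:

package main

import (
    "fmt"
    "sync"
)

// entry stands in for a cached index; parentID is the ID of the entry it
// depends on (empty when there is no parent), mirroring the new
// RequestTokenIndexID link in the diff below.
type entry struct {
    id       string
    parentID string
}

func main() {
    entries := []entry{
        {id: "token-1"},                       // an auth/token entry
        {id: "secret-1", parentID: "token-1"}, // a secret lease issued by token-1
    }

    // One channel per entry, closed when that entry has been restored.
    done := make(map[string]chan struct{}, len(entries))
    for _, e := range entries {
        done[e.id] = make(chan struct{})
    }

    var wg sync.WaitGroup
    for _, e := range entries {
        wg.Add(1)
        go func(e entry) {
            defer wg.Done()
            defer close(done[e.id]) // release any children, even if this restore fails

            // Wait for the parent to be restored first, if it is part of the batch.
            if ch, ok := done[e.parentID]; ok {
                <-ch
            }
            fmt.Println("restored", e.id)
        }(e)
    }
    wg.Wait()
}

Because each channel is closed via defer, a waiting child is released even if its parent's restore fails, and wg.Wait() returns as long as the dependency graph is acyclic.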
9 changes: 0 additions & 9 deletions command/agent/cache/cacheboltdb/bolt.go
@@ -149,15 +149,6 @@ func (b *BoltStorage) Set(ctx context.Context, id string, plaintext []byte, inde
})
}

func getBucketIDs(b *bolt.Bucket) ([][]byte, error) {
ids := [][]byte{}
err := b.ForEach(func(k, v []byte) error {
ids = append(ids, k)
return nil
})
return ids, err
}

// Delete an index (token or lease) by id from bolt storage
func (b *BoltStorage) Delete(id string) error {
return b.db.Update(func(tx *bolt.Tx) error {
3 changes: 3 additions & 0 deletions command/agent/cache/cachememdb/index.go
@@ -60,6 +60,9 @@ type Index struct {
// RequestToken is the token used in the request
RequestToken string

// RequestTokenIndexID is the ID of the RequestToken's entry in the cache
RequestTokenIndexID string

// RequestHeader is the header used in the request
RequestHeader http.Header

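The new RequestTokenIndexID field records which cached token entry a lease was created under, so the restore code can wait for that specific entry before rebuilding the lease's renewal context. A rough illustration of the relationship; the ID, path, and token values are invented, and the import paths are assumed to mirror the file paths in this commit:

package main

import (
    "fmt"

    "github.com/hashicorp/vault/command/agent/cache/cacheboltdb"
    "github.com/hashicorp/vault/command/agent/cache/cachememdb"
)

func main() {
    // Parent: the cached entry for the auth request that produced a token.
    parent := &cachememdb.Index{
        ID:          "1d3f5a...", // index IDs are hashes; truncated here for readability
        RequestPath: "/v1/auth/approle/login",
        Type:        cacheboltdb.AuthLeaseType,
    }

    // Child: a secret lease obtained with that token. RequestTokenIndexID
    // points back at the parent's cache entry by its ID.
    child := &cachememdb.Index{
        ID:                  "9b7c2e...",
        RequestPath:         "/v1/database/creds/readonly",
        Type:                cacheboltdb.SecretLeaseType,
        RequestToken:        "s.exampletoken",
        RequestTokenIndexID: parent.ID,
    }

    fmt.Println(child.RequestTokenIndexID == parent.ID) // true
}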
142 changes: 95 additions & 47 deletions command/agent/cache/lease_cache.go
@@ -356,6 +356,7 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,

index.Lease = secret.LeaseID
index.LeaseToken = req.Token
index.RequestTokenIndexID = entry.ID

index.Type = cacheboltdb.SecretLeaseType

@@ -381,6 +382,7 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
parentCtx = entry.RenewCtxInfo.Ctx

index.TokenParent = req.Token
index.RequestTokenIndexID = entry.ID
}

renewCtxInfo = c.createCtxInfo(parentCtx)
@@ -966,6 +968,11 @@ func (c *LeaseCache) Flush() error {
return nil
}

type indexAndChannel struct {
index *cachememdb.Index
ch chan struct{}
}

// Restore loads the cachememdb from the persistent storage passed in. Loads
// tokens first, since restoring a lease's renewal context and watcher requires
// looking up the token in the cachememdb.
@@ -982,26 +989,85 @@ func (c *LeaseCache) Restore(ctx context.Context, storage *cacheboltdb.BoltStora
}
}

// Then process auth leases
authLeases, err := storage.GetByType(ctx, cacheboltdb.AuthLeaseType)
if err != nil {
errors = multierror.Append(errors, err)
} else {
if err := c.restoreLeases(authLeases); err != nil {
var wg sync.WaitGroup
leasesMap := make(map[string]indexAndChannel)

// Fetch all the auth and secret leases we'll process upfront.
// Doing so allows us to process dependencies between the leases if there are any.
// Any lease that depends on another can wait for its parent to be restored before
// being processed itself.
// This algorithm requires that each index ID is globally unique across buckets,
// and that the graph of dependencies is acyclic.
for _, leaseType := range []string{cacheboltdb.AuthLeaseType, cacheboltdb.SecretLeaseType} {
leases, err := storage.GetByType(ctx, leaseType)
if err != nil {
errors = multierror.Append(errors, err)
} else {
for _, lease := range leases {
newIndex, err := cachememdb.Deserialize(lease)
if err != nil {
errors = multierror.Append(errors, err)
continue
}
if _, exists := leasesMap[newIndex.ID]; exists {
// This will only happen if we get a SHA256 hash collision,
// but handle it gracefully by reporting the failure.
errors = multierror.Append(errors, fmt.Errorf("failed to restore lease with id=%s, path=%s due to hash collision", newIndex.ID, newIndex.RequestPath))
continue
}
wg.Add(1)
leasesMap[newIndex.ID] = indexAndChannel{
index: newIndex,
ch: make(chan struct{}),
}
}
}
}

// Then process secret leases
secretLeases, err := storage.GetByType(ctx, cacheboltdb.SecretLeaseType)
if err != nil {
errors = multierror.Append(errors, err)
} else {
if err := c.restoreLeases(secretLeases); err != nil {
errors = multierror.Append(errors, err)
errorsCh := make(chan error)
go func() {
for {
select {
case err, ok := <-errorsCh:
if !ok {
return
}
errors = multierror.Append(errors, err)
}
}
}()

// Now restore the auth and secret leases.
for id, lease := range leasesMap {
go func(id string, lease indexAndChannel) {
defer wg.Done()
defer close(lease.ch)

c.logger.Trace("processing lease", "id", id)
// Check if this lease has already expired
expired, err := c.hasExpired(time.Now().UTC(), lease.index)
if err != nil {
c.logger.Warn("failed to check if lease is expired", "id", id, "error", err)
}
if expired {
return
}

if err := c.restoreLeaseRenewCtx(lease.index, leasesMap); err != nil {
errorsCh <- err
return
}
if err := c.db.Set(lease.index); err != nil {
errorsCh <- err
return
}
c.logger.Trace("restored lease", "id", id, "path", lease.index.RequestPath)
}(id, lease)
}

wg.Wait()
close(errorsCh)

return errors.ErrorOrNil()
}
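
The worker goroutines in Restore above send failures on errorsCh rather than appending to the shared *multierror.Error themselves, leaving a single collector goroutine as the only writer to that value. A standalone sketch of that fan-out-plus-collector shape, with an invented worker loop and an extra done channel for the collector:

package main

import (
    "fmt"
    "sync"

    "github.com/hashicorp/go-multierror"
)

func main() {
    var (
        wg   sync.WaitGroup
        merr *multierror.Error
    )
    errorsCh := make(chan error)
    collectorDone := make(chan struct{})

    // Single collector: the only goroutine that appends to merr.
    go func() {
        defer close(collectorDone)
        for err := range errorsCh {
            merr = multierror.Append(merr, err)
        }
    }()

    // Fan out the work; each worker only ever sends on errorsCh.
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(i int) {
            defer wg.Done()
            if i%2 == 0 {
                errorsCh <- fmt.Errorf("worker %d failed", i)
            }
        }(i)
    }

    wg.Wait()
    close(errorsCh) // no more sends once all workers are done
    <-collectorDone // let the collector drain what is left

    fmt.Println(merr.ErrorOrNil())
}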

@@ -1014,6 +1080,7 @@ func (c *LeaseCache) restoreTokens(tokens [][]byte) error {
errors = multierror.Append(errors, err)
continue
}

newIndex.RenewCtxInfo = c.createCtxInfo(nil)
if err := c.db.Set(newIndex); err != nil {
errors = multierror.Append(errors, err)
@@ -1025,42 +1092,9 @@ func (c *LeaseCache) restoreTokens(tokens [][]byte) error {
return errors.ErrorOrNil()
}

func (c *LeaseCache) restoreLeases(leases [][]byte) error {
var errors *multierror.Error

for _, lease := range leases {
newIndex, err := cachememdb.Deserialize(lease)
if err != nil {
errors = multierror.Append(errors, err)
continue
}

// Check if this lease has already expired
expired, err := c.hasExpired(time.Now().UTC(), newIndex)
if err != nil {
c.logger.Warn("failed to check if lease is expired", "id", newIndex.ID, "error", err)
}
if expired {
continue
}

if err := c.restoreLeaseRenewCtx(newIndex); err != nil {
errors = multierror.Append(errors, err)
continue
}
if err := c.db.Set(newIndex); err != nil {
errors = multierror.Append(errors, err)
continue
}
c.logger.Trace("restored lease", "id", newIndex.ID, "path", newIndex.RequestPath)
}

return errors.ErrorOrNil()
}

// restoreLeaseRenewCtx re-creates a RenewCtx for an index object and starts
// the watcher go routine
func (c *LeaseCache) restoreLeaseRenewCtx(index *cachememdb.Index) error {
func (c *LeaseCache) restoreLeaseRenewCtx(index *cachememdb.Index, channels map[string]indexAndChannel) error {
if index.Response == nil {
return fmt.Errorf("cached response was nil for %s", index.ID)
}
@@ -1081,6 +1115,13 @@ func (c *LeaseCache) restoreLeaseRenewCtx(index *cachememdb.Index) error {
var renewCtxInfo *cachememdb.ContextInfo
switch {
case secret.LeaseID != "":
if parent, ok := channels[index.RequestTokenIndexID]; ok {
c.logger.Trace("waiting for parent token to restore", "id", index.RequestTokenIndexID)
select {
case <-parent.ch:
}
c.logger.Trace("parent token restored", "id", index.RequestTokenIndexID)
}
entry, err := c.db.Get(cachememdb.IndexNameToken, index.RequestToken)
if err != nil {
return err
@@ -1096,6 +1137,13 @@ func (c *LeaseCache) restoreLeaseRenewCtx(index *cachememdb.Index) error {
case secret.Auth != nil:
var parentCtx context.Context
if !secret.Auth.Orphan {
if parent, ok := channels[index.RequestTokenIndexID]; ok {
c.logger.Trace("waiting for parent token to restore", "id", index.RequestTokenIndexID)
select {
case <-parent.ch:
}
c.logger.Trace("parent token restored", "id", index.RequestTokenIndexID)
}
entry, err := c.db.Get(cachememdb.IndexNameToken, index.RequestToken)
if err != nil {
return err

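Restore's comment requires the dependency graph between leases to be acyclic, and the parent waits in restoreLeaseRenewCtx are the reason: each child blocks until its parent's channel is closed, so a cycle would leave two goroutines waiting on each other and the deferred close calls would never run. A deliberately broken, self-contained sketch of that failure mode:

package main

import "sync"

func main() {
    a := make(chan struct{})
    b := make(chan struct{})

    var wg sync.WaitGroup
    wg.Add(2)
    go func() { defer wg.Done(); defer close(a); <-b }() // "a" waits on "b"
    go func() { defer wg.Done(); defer close(b); <-a }() // "b" waits on "a"
    wg.Wait() // never returns; the Go runtime reports a deadlock
}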