Skip to content

Commit

Permalink
Re-auth if signal got send again after added to cache
Browse files Browse the repository at this point in the history
This change will re-trigger the auth mechanism if a signal got send but
the entry was already in cache.
It adds a 1 second backoff time to allow for the backend map to finish
updating, which is why it was added in the first place.

Signed-off-by: Maartje Eyskens <maartje@eyskens.me>
Signed-off-by: Maartje Eyskens <maartje.eyskens@isovalent.com>
  • Loading branch information
meyskens authored and borkmann committed Aug 25, 2023
1 parent 0eba8bc commit c45dbb9
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 37 deletions.
10 changes: 10 additions & 0 deletions pkg/auth/authmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ type authMap interface {
All() (map[authKey]authInfo, error)
}

type authMapCacher interface {
authMap
GetCacheInfo(key authKey) (authInfoCache, error)
}

type authKey struct {
localIdentity identity.NumericIdentity
remoteIdentity identity.NumericIdentity
Expand All @@ -35,6 +40,11 @@ type authInfo struct {
expiration time.Time
}

type authInfoCache struct {
authInfo
storedAt time.Time
}

func (r authInfo) String() string {
return fmt.Sprintf("expiration=%s", r.expiration)
}
26 changes: 19 additions & 7 deletions pkg/auth/authmap_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package auth
import (
"errors"
"fmt"
"time"

"github.com/cilium/ebpf"
"github.com/sirupsen/logrus"
Expand All @@ -17,15 +18,15 @@ import (
type authMapCache struct {
logger logrus.FieldLogger
authmap authMap
cacheEntries map[authKey]authInfo
cacheEntries map[authKey]authInfoCache
cacheEntriesMutex lock.RWMutex
}

func newAuthMapCache(logger logrus.FieldLogger, authmap authMap) *authMapCache {
return &authMapCache{
logger: logger,
authmap: authmap,
cacheEntries: map[authKey]authInfo{},
cacheEntries: map[authKey]authInfoCache{},
}
}

Expand All @@ -35,18 +36,23 @@ func (r *authMapCache) All() (map[authKey]authInfo, error) {

result := make(map[authKey]authInfo)
for k, v := range r.cacheEntries {
result[k] = v
result[k] = v.authInfo
}
return maps.Clone(result), nil
}

func (r *authMapCache) Get(key authKey) (authInfo, error) {
info, err := r.GetCacheInfo(key)
return info.authInfo, err
}

func (r *authMapCache) GetCacheInfo(key authKey) (authInfoCache, error) {
r.cacheEntriesMutex.RLock()
defer r.cacheEntriesMutex.RUnlock()

info, ok := r.cacheEntries[key]
if !ok {
return authInfo{}, fmt.Errorf("failed to get auth info for key: %s", key)
return authInfoCache{}, fmt.Errorf("failed to get auth info for key: %s", key)
}
return info, nil
}
Expand All @@ -59,7 +65,10 @@ func (r *authMapCache) Update(key authKey, info authInfo) error {
return err
}

r.cacheEntries[key] = info
r.cacheEntries[key] = authInfoCache{
authInfo: info,
storedAt: time.Now(),
}

return nil
}
Expand Down Expand Up @@ -88,7 +97,7 @@ func (r *authMapCache) DeleteIf(predicate func(key authKey, info authInfo) bool)
defer r.cacheEntriesMutex.Unlock()

for k, v := range r.cacheEntries {
if predicate(k, v) {
if predicate(k, v.authInfo) {
// delete every entry individually to keep the cache in sync in case of an error
if err := r.authmap.Delete(k); err != nil {
if !errors.Is(err, ebpf.ErrKeyNotExist) {
Expand All @@ -114,7 +123,10 @@ func (r *authMapCache) restoreCache() error {
return fmt.Errorf("failed to load all auth map entries: %w", err)
}
for k, v := range all {
r.cacheEntries[k] = v
r.cacheEntries[k] = authInfoCache{
authInfo: v,
storedAt: time.Now(),
}
}

r.logger.
Expand Down
29 changes: 18 additions & 11 deletions pkg/auth/authmap_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func Test_authMapCache_restoreCache(t *testing.T) {
},
},
},
cacheEntries: map[authKey]authInfo{},
cacheEntries: map[authKey]authInfoCache{},
}

err := am.restoreCache()
Expand All @@ -52,14 +52,15 @@ func Test_authMapCache_allReturnsCopy(t *testing.T) {
authmap: &fakeAuthMap{
entries: map[authKey]authInfo{},
},
cacheEntries: map[authKey]authInfo{
cacheEntries: map[authKey]authInfoCache{
{
localIdentity: 1,
remoteIdentity: 2,
remoteNodeID: 10,
authType: policy.AuthTypeDisabled,
}: {
expiration: time.Now().Add(10 * time.Minute),
authInfo: authInfo{time.Now().Add(10 * time.Minute)},
storedAt: time.Now().Add(-10 * time.Minute),
},
},
}
Expand Down Expand Up @@ -96,30 +97,33 @@ func Test_authMapCache_Delete(t *testing.T) {
am := authMapCache{
logger: logrus.New(),
authmap: fakeMap,
cacheEntries: map[authKey]authInfo{
cacheEntries: map[authKey]authInfoCache{
{
localIdentity: 1,
remoteIdentity: 2,
remoteNodeID: 10,
authType: policy.AuthTypeDisabled,
}: {
expiration: time.Now().Add(10 * time.Minute),
authInfo: authInfo{time.Now().Add(10 * time.Minute)},
storedAt: time.Now().Add(-10 * time.Minute),
},
{
localIdentity: 3,
remoteIdentity: 2,
remoteNodeID: 10,
authType: policy.AuthTypeDisabled,
}: {
expiration: time.Now().Add(10 * time.Minute),
authInfo: authInfo{time.Now().Add(10 * time.Minute)},
storedAt: time.Now().Add(-10 * time.Minute),
},
{
localIdentity: 4,
remoteIdentity: 2,
remoteNodeID: 10,
authType: policy.AuthTypeDisabled,
}: {
expiration: time.Now().Add(10 * time.Minute),
authInfo: authInfo{time.Now().Add(10 * time.Minute)},
storedAt: time.Now().Add(-10 * time.Minute),
},
},
}
Expand Down Expand Up @@ -171,30 +175,33 @@ func Test_authMapCache_DeleteIf(t *testing.T) {
am := authMapCache{
logger: logrus.New(),
authmap: fakeMap,
cacheEntries: map[authKey]authInfo{
cacheEntries: map[authKey]authInfoCache{
{
localIdentity: 1,
remoteIdentity: 2,
remoteNodeID: 10,
authType: policy.AuthTypeDisabled,
}: {
expiration: time.Now().Add(10 * time.Minute),
authInfo: authInfo{time.Now().Add(10 * time.Minute)},
storedAt: time.Now().Add(-10 * time.Minute),
},
{
localIdentity: 3,
remoteIdentity: 2,
remoteNodeID: 10,
authType: policy.AuthTypeDisabled,
}: {
expiration: time.Now().Add(10 * time.Minute),
authInfo: authInfo{time.Now().Add(10 * time.Minute)},
storedAt: time.Now().Add(-10 * time.Minute),
},
{
localIdentity: 4,
remoteIdentity: 2,
remoteNodeID: 10,
authType: policy.AuthTypeDisabled,
}: {
expiration: time.Now().Add(10 * time.Minute),
authInfo: authInfo{time.Now().Add(10 * time.Minute)},
storedAt: time.Now().Add(-10 * time.Minute),
},
},
}
Expand Down
18 changes: 11 additions & 7 deletions pkg/auth/cell.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,27 @@ var Cell = cell.Module(
newAlwaysFailAuthHandler,
),
cell.Config(config{
MeshAuthEnabled: true,
MeshAuthQueueSize: 1024,
MeshAuthGCInterval: 5 * time.Minute,
MeshAuthEnabled: true,
MeshAuthQueueSize: 1024,
MeshAuthGCInterval: 5 * time.Minute,
MeshAuthSignalBackoffDuration: 1 * time.Second, // this default is based on the default TCP retransmission timeout
}),
cell.Config(MutualAuthConfig{}),
)

type config struct {
MeshAuthEnabled bool
MeshAuthQueueSize int
MeshAuthGCInterval time.Duration
MeshAuthEnabled bool
MeshAuthQueueSize int
MeshAuthGCInterval time.Duration
MeshAuthSignalBackoffDuration time.Duration
}

func (r config) Flags(flags *pflag.FlagSet) {
flags.Bool("mesh-auth-enabled", r.MeshAuthEnabled, "Enable authentication processing & garbage collection (beta)")
flags.Int("mesh-auth-queue-size", r.MeshAuthQueueSize, "Queue size for the auth manager")
flags.Duration("mesh-auth-gc-interval", r.MeshAuthGCInterval, "Interval in which auth entries are attempted to be garbage collected")
flags.Duration("mesh-auth-signal-backoff-duration", r.MeshAuthSignalBackoffDuration, "Time to wait betweeen two authentication required signals in case of a cache mismatch")
flags.MarkHidden("mesh-auth-signal-backoff-duration")
}

type authManagerParams struct {
Expand Down Expand Up @@ -93,7 +97,7 @@ func registerAuthManager(params authManagerParams) (*AuthManager, error) {
mapWriter := newAuthMapWriter(params.Logger, params.AuthMap)
mapCache := newAuthMapCache(params.Logger, mapWriter)

mgr, err := newAuthManager(params.Logger, params.AuthHandlers, mapCache, params.NodeIDHandler)
mgr, err := newAuthManager(params.Logger, params.AuthHandlers, mapCache, params.NodeIDHandler, params.Config.MeshAuthSignalBackoffDuration)
if err != nil {
return nil, fmt.Errorf("failed to create auth manager: %w", err)
}
Expand Down
21 changes: 14 additions & 7 deletions pkg/auth/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ func (key signalAuthKey) String() string {
}

type AuthManager struct {
logger logrus.FieldLogger
nodeIDHandler types.NodeIDHandler
authHandlers map[policy.AuthType]authHandler
authmap authMap
logger logrus.FieldLogger
nodeIDHandler types.NodeIDHandler
authHandlers map[policy.AuthType]authHandler
authmap authMapCacher
authSignalBackoffTime time.Duration

mutex lock.Mutex
pending map[authKey]struct{}
Expand All @@ -56,7 +57,7 @@ type authResponse struct {
expirationTime time.Time
}

func newAuthManager(logger logrus.FieldLogger, authHandlers []authHandler, authmap authMap, nodeIDHandler types.NodeIDHandler) (*AuthManager, error) {
func newAuthManager(logger logrus.FieldLogger, authHandlers []authHandler, authmap authMapCacher, nodeIDHandler types.NodeIDHandler, authSignalBackoffTime time.Duration) (*AuthManager, error) {
ahs := map[policy.AuthType]authHandler{}
for _, ah := range authHandlers {
if ah == nil {
Expand All @@ -75,6 +76,7 @@ func newAuthManager(logger logrus.FieldLogger, authHandlers []authHandler, authm
nodeIDHandler: nodeIDHandler,
pending: make(map[authKey]struct{}),
handleAuthenticationFunc: handleAuthentication,
authSignalBackoffTime: authSignalBackoffTime,
}, nil
}

Expand Down Expand Up @@ -130,10 +132,15 @@ func handleAuthentication(a *AuthManager, k authKey, reAuth bool) {
// Check if the auth is actually required, as we might have
// updated the authmap since the datapath issued the auth
// required signal.
if i, err := a.authmap.Get(key); err == nil && i.expiration.After(time.Now()) {
// If the entry was cached more than authSignalBackoffTime
// it will authenticate again, this is to make sure that
// we re-authenticate if the authmap was updated by an
// external source.
if i, err := a.authmap.GetCacheInfo(key); err == nil && i.expiration.After(time.Now()) && time.Now().Before(i.storedAt.Add(a.authSignalBackoffTime)) {
a.logger.
WithField("key", key).
Debug("Already authenticated, skipping authentication")
WithField("storedAt", i.storedAt).
Debugf("Already authenticated in the past %s, skipping authentication", a.authSignalBackoffTime.String())
return
}
}
Expand Down
19 changes: 14 additions & 5 deletions pkg/auth/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func Test_newAuthManager_clashingAuthHandlers(t *testing.T) {
&alwaysFailAuthHandler{},
}

am, err := newAuthManager(logrus.New(), authHandlers, nil, nil)
am, err := newAuthManager(logrus.New(), authHandlers, nil, nil, time.Second)
assert.ErrorContains(t, err, "multiple handlers for auth type: test-always-fail")
assert.Nil(t, am)
}
Expand All @@ -38,7 +38,7 @@ func Test_newAuthManager(t *testing.T) {
&fakeAuthHandler{},
}

am, err := newAuthManager(logrus.New(), authHandlers, nil, nil)
am, err := newAuthManager(logrus.New(), authHandlers, nil, nil, time.Second)
assert.NoError(t, err)
assert.NotNil(t, am)

Expand Down Expand Up @@ -100,6 +100,7 @@ func Test_authManager_authenticate(t *testing.T) {
2: "172.18.0.2",
3: "172.18.0.3",
}),
time.Second,
)

assert.NoError(t, err)
Expand All @@ -115,7 +116,7 @@ func Test_authManager_authenticate(t *testing.T) {
func Test_authManager_handleAuthRequest(t *testing.T) {
authHandlers := []authHandler{newAlwaysPassAuthHandler(logrus.New())}

am, err := newAuthManager(logrus.New(), authHandlers, nil, nil)
am, err := newAuthManager(logrus.New(), authHandlers, nil, nil, time.Second)
assert.NoError(t, err)
assert.NotNil(t, am)

Expand All @@ -137,7 +138,7 @@ func Test_authManager_handleCertificateRotationEvent_Error(t *testing.T) {
failGet: true,
}

am, err := newAuthManager(logrus.New(), authHandlers, aMap, nil)
am, err := newAuthManager(logrus.New(), authHandlers, aMap, nil, time.Second)
assert.NoError(t, err)
assert.NotNil(t, am)

Expand All @@ -155,7 +156,7 @@ func Test_authManager_handleCertificateRotationEvent(t *testing.T) {
},
}

am, err := newAuthManager(logrus.New(), authHandlers, aMap, nil)
am, err := newAuthManager(logrus.New(), authHandlers, aMap, nil, time.Second)
assert.NoError(t, err)
assert.NotNil(t, am)

Expand Down Expand Up @@ -262,6 +263,14 @@ func (r *fakeAuthMap) All() (map[authKey]authInfo, error) {
return r.entries, nil
}

func (r *fakeAuthMap) GetCacheInfo(key authKey) (authInfoCache, error) {
v, err := r.Get(key)

return authInfoCache{
authInfo: v,
}, err
}

func (r *fakeAuthMap) Get(key authKey) (authInfo, error) {
if r.failGet {
return authInfo{}, errors.New("failed to get entry")
Expand Down

0 comments on commit c45dbb9

Please sign in to comment.