Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-auth if signal got send again after added to cache #27241

Merged
merged 1 commit into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions pkg/auth/authmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ type authMap interface {
All() (map[authKey]authInfo, error)
}

type authMapCacher interface {
authMap
GetCacheInfo(key authKey) (authInfoCache, error)
}

type authKey struct {
localIdentity identity.NumericIdentity
remoteIdentity identity.NumericIdentity
Expand All @@ -35,6 +40,11 @@ type authInfo struct {
expiration time.Time
}

type authInfoCache struct {
authInfo
storedAt time.Time
}

func (r authInfo) String() string {
return fmt.Sprintf("expiration=%s", r.expiration)
}
26 changes: 19 additions & 7 deletions pkg/auth/authmap_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package auth
import (
"errors"
"fmt"
"time"

"github.com/cilium/ebpf"
"github.com/sirupsen/logrus"
Expand All @@ -17,15 +18,15 @@ import (
type authMapCache struct {
logger logrus.FieldLogger
authmap authMap
cacheEntries map[authKey]authInfo
cacheEntries map[authKey]authInfoCache
cacheEntriesMutex lock.RWMutex
}

func newAuthMapCache(logger logrus.FieldLogger, authmap authMap) *authMapCache {
return &authMapCache{
logger: logger,
authmap: authmap,
cacheEntries: map[authKey]authInfo{},
cacheEntries: map[authKey]authInfoCache{},
}
}

Expand All @@ -35,18 +36,23 @@ func (r *authMapCache) All() (map[authKey]authInfo, error) {

result := make(map[authKey]authInfo)
for k, v := range r.cacheEntries {
result[k] = v
result[k] = v.authInfo
}
return maps.Clone(result), nil
}

func (r *authMapCache) Get(key authKey) (authInfo, error) {
info, err := r.GetCacheInfo(key)
return info.authInfo, err
}

func (r *authMapCache) GetCacheInfo(key authKey) (authInfoCache, error) {
r.cacheEntriesMutex.RLock()
defer r.cacheEntriesMutex.RUnlock()

info, ok := r.cacheEntries[key]
if !ok {
return authInfo{}, fmt.Errorf("failed to get auth info for key: %s", key)
return authInfoCache{}, fmt.Errorf("failed to get auth info for key: %s", key)
}
return info, nil
}
Expand All @@ -59,7 +65,10 @@ func (r *authMapCache) Update(key authKey, info authInfo) error {
return err
}

r.cacheEntries[key] = info
r.cacheEntries[key] = authInfoCache{
authInfo: info,
storedAt: time.Now(),
}

return nil
}
Expand Down Expand Up @@ -88,7 +97,7 @@ func (r *authMapCache) DeleteIf(predicate func(key authKey, info authInfo) bool)
defer r.cacheEntriesMutex.Unlock()

for k, v := range r.cacheEntries {
if predicate(k, v) {
if predicate(k, v.authInfo) {
// delete every entry individually to keep the cache in sync in case of an error
if err := r.authmap.Delete(k); err != nil {
if !errors.Is(err, ebpf.ErrKeyNotExist) {
Expand All @@ -114,7 +123,10 @@ func (r *authMapCache) restoreCache() error {
return fmt.Errorf("failed to load all auth map entries: %w", err)
}
for k, v := range all {
r.cacheEntries[k] = v
r.cacheEntries[k] = authInfoCache{
authInfo: v,
storedAt: time.Now(),
}
}

r.logger.
Expand Down
29 changes: 18 additions & 11 deletions pkg/auth/authmap_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func Test_authMapCache_restoreCache(t *testing.T) {
},
},
},
cacheEntries: map[authKey]authInfo{},
cacheEntries: map[authKey]authInfoCache{},
}

err := am.restoreCache()
Expand All @@ -52,14 +52,15 @@ func Test_authMapCache_allReturnsCopy(t *testing.T) {
authmap: &fakeAuthMap{
entries: map[authKey]authInfo{},
},
cacheEntries: map[authKey]authInfo{
cacheEntries: map[authKey]authInfoCache{
{
localIdentity: 1,
remoteIdentity: 2,
remoteNodeID: 10,
authType: policy.AuthTypeDisabled,
}: {
expiration: time.Now().Add(10 * time.Minute),
authInfo: authInfo{time.Now().Add(10 * time.Minute)},
storedAt: time.Now().Add(-10 * time.Minute),
},
},
}
Expand Down Expand Up @@ -96,30 +97,33 @@ func Test_authMapCache_Delete(t *testing.T) {
am := authMapCache{
logger: logrus.New(),
authmap: fakeMap,
cacheEntries: map[authKey]authInfo{
cacheEntries: map[authKey]authInfoCache{
{
localIdentity: 1,
remoteIdentity: 2,
remoteNodeID: 10,
authType: policy.AuthTypeDisabled,
}: {
expiration: time.Now().Add(10 * time.Minute),
authInfo: authInfo{time.Now().Add(10 * time.Minute)},
storedAt: time.Now().Add(-10 * time.Minute),
},
{
localIdentity: 3,
remoteIdentity: 2,
remoteNodeID: 10,
authType: policy.AuthTypeDisabled,
}: {
expiration: time.Now().Add(10 * time.Minute),
authInfo: authInfo{time.Now().Add(10 * time.Minute)},
storedAt: time.Now().Add(-10 * time.Minute),
},
{
localIdentity: 4,
remoteIdentity: 2,
remoteNodeID: 10,
authType: policy.AuthTypeDisabled,
}: {
expiration: time.Now().Add(10 * time.Minute),
authInfo: authInfo{time.Now().Add(10 * time.Minute)},
storedAt: time.Now().Add(-10 * time.Minute),
},
},
}
Expand Down Expand Up @@ -171,30 +175,33 @@ func Test_authMapCache_DeleteIf(t *testing.T) {
am := authMapCache{
logger: logrus.New(),
authmap: fakeMap,
cacheEntries: map[authKey]authInfo{
cacheEntries: map[authKey]authInfoCache{
{
localIdentity: 1,
remoteIdentity: 2,
remoteNodeID: 10,
authType: policy.AuthTypeDisabled,
}: {
expiration: time.Now().Add(10 * time.Minute),
authInfo: authInfo{time.Now().Add(10 * time.Minute)},
storedAt: time.Now().Add(-10 * time.Minute),
},
{
localIdentity: 3,
remoteIdentity: 2,
remoteNodeID: 10,
authType: policy.AuthTypeDisabled,
}: {
expiration: time.Now().Add(10 * time.Minute),
authInfo: authInfo{time.Now().Add(10 * time.Minute)},
storedAt: time.Now().Add(-10 * time.Minute),
},
{
localIdentity: 4,
remoteIdentity: 2,
remoteNodeID: 10,
authType: policy.AuthTypeDisabled,
}: {
expiration: time.Now().Add(10 * time.Minute),
authInfo: authInfo{time.Now().Add(10 * time.Minute)},
storedAt: time.Now().Add(-10 * time.Minute),
},
},
}
Expand Down
18 changes: 11 additions & 7 deletions pkg/auth/cell.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,27 @@ var Cell = cell.Module(
newAlwaysFailAuthHandler,
),
cell.Config(config{
MeshAuthEnabled: true,
MeshAuthQueueSize: 1024,
MeshAuthGCInterval: 5 * time.Minute,
MeshAuthEnabled: true,
MeshAuthQueueSize: 1024,
MeshAuthGCInterval: 5 * time.Minute,
MeshAuthSignalBackoffDuration: 1 * time.Second, // this default is based on the default TCP retransmission timeout
}),
cell.Config(MutualAuthConfig{}),
)

type config struct {
MeshAuthEnabled bool
MeshAuthQueueSize int
MeshAuthGCInterval time.Duration
MeshAuthEnabled bool
MeshAuthQueueSize int
MeshAuthGCInterval time.Duration
MeshAuthSignalBackoffDuration time.Duration
}

func (r config) Flags(flags *pflag.FlagSet) {
flags.Bool("mesh-auth-enabled", r.MeshAuthEnabled, "Enable authentication processing & garbage collection (beta)")
flags.Int("mesh-auth-queue-size", r.MeshAuthQueueSize, "Queue size for the auth manager")
flags.Duration("mesh-auth-gc-interval", r.MeshAuthGCInterval, "Interval in which auth entries are attempted to be garbage collected")
flags.Duration("mesh-auth-signal-backoff-duration", r.MeshAuthSignalBackoffDuration, "Time to wait betweeen two authentication required signals in case of a cache mismatch")
flags.MarkHidden("mesh-auth-signal-backoff-duration")
}

type authManagerParams struct {
Expand Down Expand Up @@ -93,7 +97,7 @@ func registerAuthManager(params authManagerParams) (*AuthManager, error) {
mapWriter := newAuthMapWriter(params.Logger, params.AuthMap)
mapCache := newAuthMapCache(params.Logger, mapWriter)

mgr, err := newAuthManager(params.Logger, params.AuthHandlers, mapCache, params.NodeIDHandler)
mgr, err := newAuthManager(params.Logger, params.AuthHandlers, mapCache, params.NodeIDHandler, params.Config.MeshAuthSignalBackoffDuration)
if err != nil {
return nil, fmt.Errorf("failed to create auth manager: %w", err)
}
Expand Down
21 changes: 14 additions & 7 deletions pkg/auth/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ func (key signalAuthKey) String() string {
}

type AuthManager struct {
logger logrus.FieldLogger
nodeIDHandler types.NodeIDHandler
authHandlers map[policy.AuthType]authHandler
authmap authMap
logger logrus.FieldLogger
nodeIDHandler types.NodeIDHandler
authHandlers map[policy.AuthType]authHandler
authmap authMapCacher
authSignalBackoffTime time.Duration

mutex lock.Mutex
pending map[authKey]struct{}
Expand All @@ -56,7 +57,7 @@ type authResponse struct {
expirationTime time.Time
}

func newAuthManager(logger logrus.FieldLogger, authHandlers []authHandler, authmap authMap, nodeIDHandler types.NodeIDHandler) (*AuthManager, error) {
func newAuthManager(logger logrus.FieldLogger, authHandlers []authHandler, authmap authMapCacher, nodeIDHandler types.NodeIDHandler, authSignalBackoffTime time.Duration) (*AuthManager, error) {
ahs := map[policy.AuthType]authHandler{}
for _, ah := range authHandlers {
if ah == nil {
Expand All @@ -75,6 +76,7 @@ func newAuthManager(logger logrus.FieldLogger, authHandlers []authHandler, authm
nodeIDHandler: nodeIDHandler,
pending: make(map[authKey]struct{}),
handleAuthenticationFunc: handleAuthentication,
authSignalBackoffTime: authSignalBackoffTime,
}, nil
}

Expand Down Expand Up @@ -130,10 +132,15 @@ func handleAuthentication(a *AuthManager, k authKey, reAuth bool) {
// Check if the auth is actually required, as we might have
// updated the authmap since the datapath issued the auth
// required signal.
if i, err := a.authmap.Get(key); err == nil && i.expiration.After(time.Now()) {
// If the entry was cached more than authSignalBackoffTime
// it will authenticate again, this is to make sure that
// we re-authenticate if the authmap was updated by an
// external source.
meyskens marked this conversation as resolved.
Show resolved Hide resolved
if i, err := a.authmap.GetCacheInfo(key); err == nil && i.expiration.After(time.Now()) && time.Now().Before(i.storedAt.Add(a.authSignalBackoffTime)) {
a.logger.
WithField("key", key).
Debug("Already authenticated, skipping authentication")
WithField("storedAt", i.storedAt).
Debugf("Already authenticated in the past %s, skipping authentication", a.authSignalBackoffTime.String())
return
}
}
Expand Down
19 changes: 14 additions & 5 deletions pkg/auth/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func Test_newAuthManager_clashingAuthHandlers(t *testing.T) {
&alwaysFailAuthHandler{},
}

am, err := newAuthManager(logrus.New(), authHandlers, nil, nil)
am, err := newAuthManager(logrus.New(), authHandlers, nil, nil, time.Second)
assert.ErrorContains(t, err, "multiple handlers for auth type: test-always-fail")
assert.Nil(t, am)
}
Expand All @@ -38,7 +38,7 @@ func Test_newAuthManager(t *testing.T) {
&fakeAuthHandler{},
}

am, err := newAuthManager(logrus.New(), authHandlers, nil, nil)
am, err := newAuthManager(logrus.New(), authHandlers, nil, nil, time.Second)
assert.NoError(t, err)
assert.NotNil(t, am)

Expand Down Expand Up @@ -100,6 +100,7 @@ func Test_authManager_authenticate(t *testing.T) {
2: "172.18.0.2",
3: "172.18.0.3",
}),
time.Second,
)

assert.NoError(t, err)
Expand All @@ -115,7 +116,7 @@ func Test_authManager_authenticate(t *testing.T) {
func Test_authManager_handleAuthRequest(t *testing.T) {
authHandlers := []authHandler{newAlwaysPassAuthHandler(logrus.New())}

am, err := newAuthManager(logrus.New(), authHandlers, nil, nil)
am, err := newAuthManager(logrus.New(), authHandlers, nil, nil, time.Second)
assert.NoError(t, err)
assert.NotNil(t, am)

Expand All @@ -137,7 +138,7 @@ func Test_authManager_handleCertificateRotationEvent_Error(t *testing.T) {
failGet: true,
}

am, err := newAuthManager(logrus.New(), authHandlers, aMap, nil)
am, err := newAuthManager(logrus.New(), authHandlers, aMap, nil, time.Second)
assert.NoError(t, err)
assert.NotNil(t, am)

Expand All @@ -155,7 +156,7 @@ func Test_authManager_handleCertificateRotationEvent(t *testing.T) {
},
}

am, err := newAuthManager(logrus.New(), authHandlers, aMap, nil)
am, err := newAuthManager(logrus.New(), authHandlers, aMap, nil, time.Second)
assert.NoError(t, err)
assert.NotNil(t, am)

Expand Down Expand Up @@ -262,6 +263,14 @@ func (r *fakeAuthMap) All() (map[authKey]authInfo, error) {
return r.entries, nil
}

func (r *fakeAuthMap) GetCacheInfo(key authKey) (authInfoCache, error) {
v, err := r.Get(key)

return authInfoCache{
authInfo: v,
}, err
}

func (r *fakeAuthMap) Get(key authKey) (authInfo, error) {
if r.failGet {
return authInfo{}, errors.New("failed to get entry")
Expand Down