Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Force token cache to support audit annotations #90140

Merged
merged 1 commit into from
Jun 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ go_test(
deps = [
"//staging/src/k8s.io/apimachinery/pkg/util/clock:go_default_library",
"//staging/src/k8s.io/apimachinery/pkg/util/uuid:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/apis/audit:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/audit:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/authentication/authenticator:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/authentication/user:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/endpoints/request:go_default_library",
"//vendor/github.com/google/go-cmp/cmp:go_default_library",
"//vendor/github.com/google/uuid:go_default_library",
],
Expand All @@ -37,7 +40,10 @@ go_library(
"//staging/src/k8s.io/apimachinery/pkg/api/errors:go_default_library",
"//staging/src/k8s.io/apimachinery/pkg/util/cache:go_default_library",
"//staging/src/k8s.io/apimachinery/pkg/util/clock:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/apis/audit:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/audit:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/authentication/authenticator:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/endpoints/request:go_default_library",
"//staging/src/k8s.io/component-base/metrics:go_default_library",
"//staging/src/k8s.io/component-base/metrics/legacyregistry:go_default_library",
"//vendor/golang.org/x/sync/singleflight:go_default_library",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ import (

apierrors "k8s.io/apimachinery/pkg/api/errors"
utilclock "k8s.io/apimachinery/pkg/util/clock"
auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/authentication/authenticator"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/klog/v2"
)

Expand All @@ -47,6 +50,16 @@ type cacheRecord struct {
resp *authenticator.Response
ok bool
err error

// this cache assumes token authn has no side-effects or temporal dependence.
// neither of these are true for audit annotations set via AddAuditAnnotation.
//
// for audit annotations, the assumption is that for some period of time (cache TTL),
// all requests with the same API audiences and the same bearer token result in the
// same annotations. This may not be true if the authenticator sets an annotation
// based on the current time, but that may be okay since cache TTLs are generally
// small (seconds).
annotations map[string]string
}

type cachedTokenAuthenticator struct {
Expand Down Expand Up @@ -109,6 +122,17 @@ func newWithClock(authenticator authenticator.Token, cacheErrs bool, successTTL,

// AuthenticateToken implements authenticator.Token
func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token string) (*authenticator.Response, bool, error) {
record := a.doAuthenticateToken(ctx, token)
if !record.ok || record.err != nil {
return nil, false, record.err
}
for key, value := range record.annotations {
audit.AddAuditAnnotation(ctx, key, value)
}
return record.resp, true, nil
}

func (a *cachedTokenAuthenticator) doAuthenticateToken(ctx context.Context, token string) *cacheRecord {
doneAuthenticating := stats.authenticating()

auds, audsOk := authenticator.AudiencesFrom(ctx)
Expand All @@ -117,39 +141,40 @@ func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token
if record, ok := a.cache.get(key); ok {
// Record cache hit
doneAuthenticating(true)
return record.resp, record.ok, record.err
return record
}

// Record cache miss
doneBlocking := stats.blocking()
defer doneBlocking()
defer doneAuthenticating(false)

type lookup struct {
resp *authenticator.Response
ok bool
}
c := a.group.DoChan(key, func() (val interface{}, _ error) {
// always use one place to read and write the output of AuthenticateToken
record := &cacheRecord{}

c := a.group.DoChan(key, func() (val interface{}, err error) {
doneFetching := stats.fetching()
// We're leaving the request handling stack so we need to handle crashes
// ourselves. Log a stack trace and return a 500 if something panics.
defer func() {
if r := recover(); r != nil {
err = errAuthnCrash
// make sure to always return a record
record.err = errAuthnCrash
val = record

// Same as stdlib http server code. Manually allocate stack
// trace buffer size to prevent excessively large logs
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
klog.Errorf("%v\n%s", r, buf)
}
doneFetching(err == nil)
doneFetching(record.err == nil)
}()

// Check again for a cached record. We may have raced with a fetch.
if record, ok := a.cache.get(key); ok {
return lookup{record.resp, record.ok}, record.err
return record, nil
}

// Detach the context because the lookup may be shared by multiple callers,
Expand All @@ -161,29 +186,35 @@ func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token
ctx = authenticator.WithAudiences(ctx, auds)
}

resp, ok, err := a.authenticator.AuthenticateToken(ctx, token)
if !a.cacheErrs && err != nil {
return nil, err
// since this is shared work between multiple requests, we have no way of knowing if any
// particular request supports audit annotations. thus we always attempt to record them.
ev := &auditinternal.Event{Level: auditinternal.LevelMetadata}
ctx = request.WithAuditEvent(ctx, ev)

record.resp, record.ok, record.err = a.authenticator.AuthenticateToken(ctx, token)
record.annotations = ev.Annotations

if !a.cacheErrs && record.err != nil {
return record, nil
}

switch {
case ok && a.successTTL > 0:
a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.successTTL)
case !ok && a.failureTTL > 0:
a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.failureTTL)
case record.ok && a.successTTL > 0:
a.cache.set(key, record, a.successTTL)
case !record.ok && a.failureTTL > 0:
a.cache.set(key, record, a.failureTTL)
}
return lookup{resp, ok}, err

return record, nil
})

select {
case result := <-c:
if result.Err != nil {
return nil, false, result.Err
}
lookup := result.Val.(lookup)
return lookup.resp, lookup.ok, nil
// we always set Val and never set Err
return result.Val.(*cacheRecord)
case <-ctx.Done():
return nil, false, ctx.Err()
// fake a record on context cancel
return &cacheRecord{err: ctx.Err()}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@ import (
"github.com/google/go-cmp/cmp"
utilclock "k8s.io/apimachinery/pkg/util/clock"
"k8s.io/apimachinery/pkg/util/uuid"
auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/authentication/authenticator"
"k8s.io/apiserver/pkg/authentication/user"
"k8s.io/apiserver/pkg/endpoints/request"
)

func TestCachedTokenAuthenticator(t *testing.T) {
Expand Down Expand Up @@ -274,6 +277,144 @@ func TestSharedLookup(t *testing.T) {
})
}

func TestCachedAuditAnnotations(t *testing.T) {
snorlax := &authenticator.Response{User: &user.DefaultInfo{Name: "snorlax"}}

t.Run("annotations from cache", func(t *testing.T) {
var lookups uint32
c := make(chan struct{})
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
<-c
atomic.AddUint32(&lookups, 1)
audit.AddAuditAnnotation(ctx, "snorlax", "rocks")
audit.AddAuditAnnotation(ctx, "pandas", "are amazing")
return snorlax, true, nil
}), false, time.Minute, 0)

allAnnotations := make(chan map[string]string, 10)
defer close(allAnnotations)

var wg sync.WaitGroup
for i := 0; i < cap(allAnnotations); i++ {
wg.Add(1)
go func() {
defer wg.Done()

// exercise both ways of tracking audit annotations
r := mathrand.New(mathrand.NewSource(mathrand.Int63()))
randomChoice := r.Int()%2 == 0
ctx := context.Background()

if randomChoice {
ctx = audit.WithAuditAnnotations(ctx)
} else {
ctx = request.WithAuditEvent(ctx, &auditinternal.Event{Level: auditinternal.LevelMetadata})
}

_, _, _ = a.AuthenticateToken(ctx, "token")

if randomChoice {
allAnnotations <- extractAnnotations(ctx)
} else {
allAnnotations <- request.AuditEventFrom(ctx).Annotations
}
}()
}

// no good way to make sure that all the callers are queued so we sleep.
time.Sleep(1 * time.Second)
close(c)
wg.Wait()

want := map[string]string{"snorlax": "rocks", "pandas": "are amazing"}
for i := 0; i < cap(allAnnotations); i++ {
annotations := <-allAnnotations
if diff := cmp.Diff(want, annotations); diff != "" {
t.Errorf("%d: unexpected annotations (-want +got): %s", i, diff)
}
}

if queued := len(allAnnotations); queued != 0 {
t.Errorf("expected all annoations to be processed: %d", queued)
}

if lookups > 3 {
t.Errorf("unexpected number of lookups: got=%d, wanted less than 3", lookups)
}
})

t.Run("annotations do not change during cache TTL", func(t *testing.T) {
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
audit.AddAuditAnnotation(ctx, "timestamp", time.Now().String())
return snorlax, true, nil
}), false, time.Minute, 0)

allAnnotations := make([]map[string]string, 0, 10)

for i := 0; i < cap(allAnnotations); i++ {
ctx := audit.WithAuditAnnotations(context.Background())
_, _, _ = a.AuthenticateToken(ctx, "token")
allAnnotations = append(allAnnotations, extractAnnotations(ctx))
}

if len(allAnnotations) != cap(allAnnotations) {
t.Errorf("failed to process all annotations")
}

want := allAnnotations[0]
if ok := len(want) == 1 && len(want["timestamp"]) > 0; !ok {
t.Errorf("invalid annotations: %v", want)
}

for i, annotations := range allAnnotations[1:] {
if diff := cmp.Diff(want, annotations); diff != "" {
t.Errorf("%d: unexpected annotations (-want +got): %s", i, diff)
}
}
})

t.Run("different tokens can have different annotations", func(t *testing.T) {
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
audit.AddAuditAnnotation(ctx, "timestamp", time.Now().String())
return snorlax, true, nil
}), false, time.Minute, 0)

ctx1 := audit.WithAuditAnnotations(context.Background())
_, _, _ = a.AuthenticateToken(ctx1, "token1")
annotations1 := extractAnnotations(ctx1)

// guarantee different now times
time.Sleep(time.Second)

ctx2 := audit.WithAuditAnnotations(context.Background())
_, _, _ = a.AuthenticateToken(ctx2, "token2")
annotations2 := extractAnnotations(ctx2)

if ok := len(annotations1) == 1 && len(annotations1["timestamp"]) > 0; !ok {
t.Errorf("invalid annotations 1: %v", annotations1)
}
if ok := len(annotations2) == 1 && len(annotations2["timestamp"]) > 0; !ok {
t.Errorf("invalid annotations 2: %v", annotations2)
}

if annotations1["timestamp"] == annotations2["timestamp"] {
t.Errorf("annotations should have different timestamp value: %v", annotations1)
}
})
}

func extractAnnotations(ctx context.Context) map[string]string {
annotationsSlice := reflect.ValueOf(ctx).Elem().FieldByName("val").Elem().Elem()
annotations := map[string]string{}
for i := 0; i < annotationsSlice.Len(); i++ {
annotation := annotationsSlice.Index(i)
key := annotation.FieldByName("key").String()
val := annotation.FieldByName("value").String()
annotations[key] = val
}
return annotations
}

func BenchmarkCachedTokenAuthenticator(b *testing.B) {
tokenCount := []int{100, 500, 2500, 12500, 62500}
threadCount := []int{1, 16, 256}
Expand Down Expand Up @@ -318,6 +459,8 @@ func (s *singleBenchmark) makeTokens() {
s.tokenToAuds = map[string]authenticator.Audiences{}
s.tokens = []string{}

rr := mathrand.New(mathrand.NewSource(mathrand.Int63()))

for i := 0; i < s.tokenCount; i++ {
tok := fmt.Sprintf("%v-%v", jwtToken, i)
r := cacheRecord{
Expand All @@ -327,14 +470,23 @@ func (s *singleBenchmark) makeTokens() {
}
// make different combinations of audience, failures, denies for the tokens.
auds := []string{}
for i := 0; i < mathrand.Intn(4); i++ {
for i := 0; i < rr.Intn(4); i++ {
auds = append(auds, string(uuid.NewUUID()))
}
choice := mathrand.Float64()
choice := rr.Float64()
switch {
case choice < 0.9:
r.ok = true
r.err = nil

// add some realistic annotations on ~20% of successful authentications
if f := rr.Float64(); f < 0.2 {
r.annotations = map[string]string{
"audience.authentication.kubernetes.io": "e8357258-88b1-11ea-bc55-0242ac130003",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

embed the float64 in an annotation with Sprint.

"namespace.authentication.kubernetes.io": "kube-system",
"float.authentication.kubernetes.io": fmt.Sprint(f),
}
}
case choice < 0.99:
r.ok = false
r.err = nil
Expand All @@ -355,6 +507,9 @@ func (s *singleBenchmark) lookup(ctx context.Context, token string) (*authentica
if !ok {
panic("test setup problem")
}
for key, val := range r.annotations {
audit.AddAuditAnnotation(ctx, key, val)
}
return r.resp, r.ok, r.err
}

Expand Down