From 7584f4a55f9b0ae6f6ff45d7933ba407f43a77f8 Mon Sep 17 00:00:00 2001 From: Prachi Pendse Date: Tue, 19 Mar 2024 20:16:00 -0700 Subject: [PATCH] fix: store constraint status audit results in sorted order (#3293) Signed-off-by: Prachi Pendse --- pkg/audit/manager.go | 216 +++++++++++++++++++++++--------------- pkg/audit/manager_test.go | 65 ++++++++++++ 2 files changed, 198 insertions(+), 83 deletions(-) diff --git a/pkg/audit/manager.go b/pkg/audit/manager.go index 16bb6a3f487..8859c429149 100644 --- a/pkg/audit/manager.go +++ b/pkg/audit/manager.go @@ -1,6 +1,7 @@ package audit import ( + "container/heap" "context" "encoding/json" "errors" @@ -9,7 +10,6 @@ import ( "io" "os" "path" - "sort" "strconv" "strings" "time" @@ -69,7 +69,7 @@ var ( apiCacheDir = flag.String("api-cache-dir", defaultAPICacheDir, "The directory where audit from api server cache are stored, defaults to /tmp/audit") auditConnection = flag.String("audit-connection", defaultConnection, "Connection name for publishing audit violation messages. Defaults to audit-connection") auditChannel = flag.String("audit-channel", defaultChannel, "Channel name for publishing audit violation messages. Defaults to audit-channel") - emptyAuditResults []updateListEntry + emptyAuditResults = newLimitQueue(0) logStatsAudit = flag.Bool("log-stats-audit", false, "(alpha) log stats metrics for the audit run") ) @@ -126,47 +126,98 @@ type PubsubMsg struct { ResourceLabels map[string]string `json:"resourceLabels,omitempty"` } -// updateListEntry holds the information necessary to update the -// audit results in the `status` field of the constraint template. -// Adding data to this struct has a large impact on memory usage. -type updateListEntry struct { - group string - version string - kind string - namespace string - name string - msg string - enforcementAction util.EnforcementAction -} +// A max PriorityQueue implements heap.Interface and holds StatusViolation. 
+type SVQueue []*StatusViolation + +func (svq SVQueue) Len() int { return len(svq) } -// ByGVKNNMsg implements sort.Interface based on the group, version, kind, name, namespace, and msg fields. -type byGVKNNMsg []updateListEntry +// Implements sort.Interface based on the group, version, kind, namespace, name, message and enforcement action fields. +// For Pop to give us the highest priority, use greater than here. +func (svq SVQueue) Less(i, j int) bool { + if svq[i].Group != svq[j].Group { + return svq[i].Group > svq[j].Group + } + if svq[i].Version != svq[j].Version { + return svq[i].Version > svq[j].Version + } + if svq[i].Kind != svq[j].Kind { + return svq[i].Kind > svq[j].Kind + } + if svq[i].Namespace != svq[j].Namespace { + return svq[i].Namespace > svq[j].Namespace + } + if svq[i].Name != svq[j].Name { + return svq[i].Name > svq[j].Name + } + if svq[i].Message != svq[j].Message { + return svq[i].Message > svq[j].Message + } + return svq[i].EnforcementAction > svq[j].EnforcementAction +} -func (a byGVKNNMsg) Len() int { - return len(a) +func (svq SVQueue) Swap(i, j int) { + svq[i], svq[j] = svq[j], svq[i] } -func (a byGVKNNMsg) Less(i, j int) bool { - if a[i].group != a[j].group { - return a[i].group < a[j].group +func (svq *SVQueue) Push(x any) { + sv, ok := x.(*StatusViolation) + if !ok { + return } - if a[i].version != a[j].version { - return a[i].version < a[j].version + *svq = append(*svq, sv) +} + +func (svq *SVQueue) Pop() any { + old := *svq + n := len(old) + sv := old[n-1] + old[n-1] = nil + *svq = old[:n-1] + return sv +} + +// LimitQueue implements logic to ensure priority queue len <= limit in order to provide performance guarantees on heap methods. 
+type LimitQueue struct { + limit int + svq SVQueue +} + +func newLimitQueue(l int) *LimitQueue { + lq := LimitQueue{ + limit: l, + svq: make(SVQueue, 0, l), } - if a[i].kind != a[j].kind { - return a[i].kind < a[j].kind + heap.Init(&lq.svq) + return &lq +} + +func (lq *LimitQueue) Len() int { return lq.svq.Len() } + +func (lq *LimitQueue) Push(x *StatusViolation) { + heap.Push(&lq.svq, x) + for lq.svq.Len() > lq.limit { + heap.Pop(&lq.svq) } - if a[i].namespace != a[j].namespace { - return a[i].namespace < a[j].namespace +} + +func (lq *LimitQueue) Pop() *StatusViolation { + if lq.Len() == 0 { + return &StatusViolation{} } - if a[i].name != a[j].name { - return a[i].name < a[j].name + sv, ok := heap.Pop(&lq.svq).(*StatusViolation) + if !ok { + return &StatusViolation{} } - return a[i].msg < a[j].msg + return sv } -func (a byGVKNNMsg) Swap(i, j int) { - a[i], a[j] = a[j], a[i] +func (lq *LimitQueue) Peek() *StatusViolation { + if lq.Len() == 0 { + return nil + } + sv := lq.Pop() + lq.Push(sv) + return sv } // nsCache is used for caching namespaces and their labels. 
@@ -264,7 +315,7 @@ func (am *Manager) audit(ctx context.Context) error { return nil } - updateLists := make(map[util.KindVersionName][]updateListEntry) + updateLists := make(map[util.KindVersionName]*LimitQueue) totalViolationsPerConstraint := make(map[util.KindVersionName]int64) totalViolationsPerEnforcementAction := make(map[util.EnforcementAction]int64) // resetting total violations per enforcement action @@ -292,8 +343,10 @@ func (am *Manager) audit(ctx context.Context) error { // log constraints with violations for gvknn := range updateLists { - ar := updateLists[gvknn][0] - logConstraint(am.log, &gvknn, ar.enforcementAction, totalViolationsPerConstraint[gvknn]) + ar := updateLists[gvknn].Peek() + if ar != nil { + logConstraint(am.log, &gvknn, ar.EnforcementAction, totalViolationsPerConstraint[gvknn]) + } } for k, v := range totalViolationsPerEnforcementAction { @@ -312,7 +365,7 @@ func (am *Manager) audit(ctx context.Context) error { func (am *Manager) auditResources( ctx context.Context, constraintsGVK []schema.GroupVersionKind, - updateLists map[util.KindVersionName][]updateListEntry, + updateLists map[util.KindVersionName]*LimitQueue, totalViolationsPerConstraint map[util.KindVersionName]int64, totalViolationsPerEnforcementAction map[util.EnforcementAction]int64, timestamp string, @@ -601,7 +654,7 @@ func nsMapFromObjs(objs []unstructured.Unstructured) (map[string]*corev1.Namespa } func (am *Manager) reviewObjects(ctx context.Context, kind string, folderCount int, nsCache *nsCache, - updateLists map[util.KindVersionName][]updateListEntry, + updateLists map[util.KindVersionName]*LimitQueue, totalViolationsPerConstraint map[util.KindVersionName]int64, totalViolationsPerEnforcementAction map[util.EnforcementAction]int64, timestamp string, @@ -804,57 +857,64 @@ func (am *Manager) getAllConstraintKinds() ([]schema.GroupVersionKind, error) { } func (am *Manager) addAuditResponsesToUpdateLists( - updateLists map[util.KindVersionName][]updateListEntry, + 
updateLists map[util.KindVersionName]*LimitQueue, res []Result, totalViolationsPerConstraint map[util.KindVersionName]int64, totalViolationsPerEnforcementAction map[util.EnforcementAction]int64, timestamp string, ) { for _, r := range res { - key := util.GetUniqueKey(*r.Constraint) + constraint := r.Constraint + key := util.GetUniqueKey(*constraint) + keyQueue, ok := updateLists[key] + if !ok { + keyQueue = newLimitQueue(int(*constraintViolationsLimit)) + updateLists[key] = keyQueue + } + totalViolationsPerConstraint[key]++ - details := r.Metadata["details"] + ea := util.EnforcementAction(r.EnforcementAction) + totalViolationsPerEnforcementAction[ea]++ gvk := r.obj.GroupVersionKind() namespace := r.obj.GetNamespace() name := r.obj.GetName() - uid := r.obj.GetUID() - rv := r.obj.GetResourceVersion() - ea := util.EnforcementAction(r.EnforcementAction) - - // append audit results only if it is below violations limit - if uint(len(updateLists[key])) < *constraintViolationsLimit { - msg := r.Msg - if len(msg) > msgSize { - msg = truncateString(msg, msgSize) - } - entry := updateListEntry{ - group: gvk.Group, - version: gvk.Version, - kind: gvk.Kind, - namespace: namespace, - name: name, - msg: msg, - enforcementAction: ea, - } - updateLists[key] = append(updateLists[key], entry) + msg := r.Msg + if len(msg) > msgSize { + msg = truncateString(msg, msgSize) + } + action := string(ea) + violation := &StatusViolation{ + Group: gvk.Group, + Version: gvk.Version, + Kind: gvk.Kind, + Namespace: namespace, + Name: name, + Message: msg, + EnforcementAction: action, } + // since keyQueue is a LimitQueue, it guarantees len <= limit after a push. + // the limit on size bounds each Push() at O(log limit) time, independent of the total number of violations. 
+ keyQueue.Push(violation) - totalViolationsPerEnforcementAction[ea]++ - logViolation(am.log, r.Constraint, ea, gvk, namespace, name, r.Msg, details, r.obj.GetLabels()) + details := r.Metadata["details"] + labels := r.obj.GetLabels() + logViolation(am.log, constraint, ea, gvk, namespace, name, msg, details, labels) if *pubsubController.PubsubEnabled { - err := am.pubsubSystem.Publish(context.Background(), *auditConnection, *auditChannel, violationMsg(r.Constraint, ea, gvk, namespace, name, r.Msg, details, r.obj.GetLabels(), timestamp)) + err := am.pubsubSystem.Publish(context.Background(), *auditConnection, *auditChannel, violationMsg(constraint, ea, gvk, namespace, name, msg, details, labels, timestamp)) if err != nil { am.log.Error(err, "pubsub audit Publishing") } } if *emitAuditEvents { - emitEvent(r.Constraint, timestamp, ea, gvk, namespace, name, rv, r.Msg, am.gkNamespace, uid, am.eventRecorder) + uid := r.obj.GetUID() + rv := r.obj.GetResourceVersion() + emitEvent(constraint, timestamp, ea, gvk, namespace, name, rv, msg, am.gkNamespace, uid, am.eventRecorder) } } } -func (am *Manager) writeAuditResults(ctx context.Context, constraintsGVKs []schema.GroupVersionKind, updateLists map[util.KindVersionName][]updateListEntry, timestamp string, totalViolations map[util.KindVersionName]int64) { +func (am *Manager) writeAuditResults(ctx context.Context, constraintsGVKs []schema.GroupVersionKind, updateLists map[util.KindVersionName]*LimitQueue, timestamp string, totalViolations map[util.KindVersionName]int64) { // if there is a previous reporting thread, close it before starting a new one if am.ucloop != nil { // this is closing the previous audit reporting thread @@ -891,27 +951,17 @@ func (am *Manager) skipExcludedNamespace(obj *unstructured.Unstructured) (bool, return isNamespaceExcluded, err } -func (ucloop *updateConstraintLoop) updateConstraintStatus(ctx context.Context, instance *unstructured.Unstructured, auditResults []updateListEntry, timestamp string, 
totalViolations int64) error { +func (ucloop *updateConstraintLoop) updateConstraintStatus(ctx context.Context, instance *unstructured.Unstructured, auditResults *LimitQueue, timestamp string, totalViolations int64) error { constraintName := instance.GetName() ucloop.log.Info("updating constraint status", "constraintName", constraintName) - // sort audit results - sort.Sort(byGVKNNMsg(auditResults)) - // create constraint status violations var statusViolations []interface{} - for i := range auditResults { - ar := auditResults[i] // avoid large shallow copy in range loop - // append statusViolations for this constraint until constraintViolationsLimit has reached + for auditResults.Len() > 0 { if uint(len(statusViolations)) < *constraintViolationsLimit { - statusViolations = append(statusViolations, StatusViolation{ - Group: ar.group, - Version: ar.version, - Kind: ar.kind, - Name: ar.name, - Namespace: ar.namespace, - Message: ar.msg, - EnforcementAction: string(ar.enforcementAction), - }) + // Append the maximum statusViolation for this constraint in sort order until constraintViolationsLimit is reached. + statusViolations = append(statusViolations, auditResults.Pop()) + } else { + break // end early if statusViolations is full. 
} } raw, err := json.Marshal(statusViolations) @@ -976,7 +1026,7 @@ type updateConstraintLoop struct { client client.Client stop chan struct{} stopped chan struct{} - ul map[util.KindVersionName][]updateListEntry + ul map[util.KindVersionName]*LimitQueue ts string tv map[util.KindVersionName]int64 log logr.Logger @@ -1090,7 +1140,7 @@ func logFinish(l logr.Logger) { ) } -func logConstraint(l logr.Logger, gvknn *util.KindVersionName, enforcementAction util.EnforcementAction, totalViolations int64) { +func logConstraint(l logr.Logger, gvknn *util.KindVersionName, enforcementAction string, totalViolations int64) { l.Info( "audit results for constraint", logging.EventType, "constraint_audited", diff --git a/pkg/audit/manager_test.go b/pkg/audit/manager_test.go index 6ab50afc5b3..b74d534d5fd 100644 --- a/pkg/audit/manager_test.go +++ b/pkg/audit/manager_test.go @@ -1,6 +1,7 @@ package audit import ( + "container/heap" "context" "os" "reflect" @@ -24,6 +25,70 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client/fake" ) +func Test_SVQueue(t *testing.T) { + sv1 := &StatusViolation{ + Group: "rbac.authorization.k8s.io", + Version: "v1", + Kind: "ClusterRoleBinding", + } + sv2 := &StatusViolation{ + Group: "authorization.k8s.io", + Version: "v1", + Kind: "SubjectAccessReview", + } + sv3 := &StatusViolation{ + Group: "rbac.authorization.k8s.io", + Version: "v1", + Kind: "RoleBinding", + } + + svq := make(SVQueue, 0, 3) + heap.Init(&svq) + // Push into queue in unordered fashion, expect length to be correct, and pop in sort order. 
+ heap.Push(&svq, sv1) + heap.Push(&svq, sv2) + heap.Push(&svq, sv3) + require.EqualValues(t, svq.Len(), 3) + require.EqualValues(t, heap.Pop(&svq), sv3) + require.EqualValues(t, heap.Pop(&svq), sv1) + require.EqualValues(t, heap.Pop(&svq), sv2) + require.EqualValues(t, svq.Len(), 0) +} + +func Test_LimitQueue(t *testing.T) { + sv1 := &StatusViolation{ + Group: "rbac.authorization.k8s.io", + Version: "v1", + Kind: "ClusterRoleBinding", + } + sv2 := &StatusViolation{ + Group: "authorization.k8s.io", + Version: "v1", + Kind: "SubjectAccessReview", + } + sv3 := &StatusViolation{ + Group: "rbac.authorization.k8s.io", + Version: "v1", + Kind: "RoleBinding", + } + + lq := newLimitQueue(2) + // Push into queue in unordered fashion, expect length to stay <= 2, peek the max object, and pop in sort order. + lq.Push(sv1) + lq.Push(sv2) + lq.Push(sv3) + require.EqualValues(t, lq.Len(), 2) + require.EqualValues(t, lq.Peek(), sv1) + require.EqualValues(t, lq.Pop(), sv1) + require.EqualValues(t, lq.Pop(), sv2) + require.EqualValues(t, lq.Len(), 0) + // Ensure that Peek does not add a nil element if the queue is empty. + lq.Peek() + require.EqualValues(t, lq.Len(), 0) + // Ensure that Pop returns an empty StatusViolation (not nil) if the queue is empty. + require.EqualValues(t, lq.Pop(), &StatusViolation{}) +} + func Test_auditFromCache(t *testing.T) { podToReview := fakes.Pod(fakes.WithNamespace("test-namespace-1")) podGVK := podToReview.GroupVersionKind()