diff --git a/constraint/pkg/client/client.go b/constraint/pkg/client/client.go index a7e7f494e..7da1604aa 100644 --- a/constraint/pkg/client/client.go +++ b/constraint/pkg/client/client.go @@ -58,6 +58,8 @@ func createDataPath(target, subpath string) string { // On error, the responses return value will still be populated so that // partial results can be analyzed. func (c *Client) AddData(ctx context.Context, data interface{}) (*types.Responses, error) { + // TODO(#189): Make AddData atomic across all Drivers/Targets. + resp := types.NewResponses() errMap := make(clienterrors.ErrorMap) for target, h := range c.targets { @@ -69,12 +71,40 @@ func (c *Client) AddData(ctx context.Context, data interface{}) (*types.Response if !handled { continue } - if err := c.driver.PutData(ctx, createDataPath(target, relPath), processedData); err != nil { + + var cache handler.Cache + if cacher, ok := h.(handler.Cacher); ok { + cache = cacher.GetCache() + } + + // Add to the target cache first because cache.Remove cannot fail. Thus, we + // can prevent the system from getting into an inconsistent state. + if cache != nil { + err = cache.Add(relPath, processedData) + if err != nil { + // Use a different key than the driver to avoid clobbering errors. + errMap[target] = err + + continue + } + } + + // paths passed to driver must be specific to the target to prevent key + // collisions. + driverPath := createDataPath(target, relPath) + err = c.driver.PutData(ctx, driverPath, processedData) + if err != nil { errMap[target] = err + + if cache != nil { + cache.Remove(relPath) + } continue } + resp.Handled[target] = true } + if len(errMap) == 0 { return resp, nil } @@ -96,15 +126,24 @@ func (c *Client) RemoveData(ctx context.Context, data interface{}) (*types.Respo if !handled { continue } + if _, err := c.driver.DeleteData(ctx, createDataPath(target, relPath)); err != nil { errMap[target] = err continue } resp.Handled[target] = true + + if cacher, ok := h.(handler.Cacher); ok { + cache := cacher.GetCache() + + cache.Remove(relPath) + } } + if len(errMap) == 0 { return resp, nil } + return resp, &errMap } diff --git a/constraint/pkg/client/client_test.go b/constraint/pkg/client/client_test.go index 4693750a3..abe83659e 100644 --- a/constraint/pkg/client/client_test.go +++ b/constraint/pkg/client/client_test.go @@ -1548,3 +1548,178 @@ func TestClient_AddTemplate_Duplicate(t *testing.T) { t.Fatal(diff) } } + +func TestClient_AddData_Cache(t *testing.T) { + tests := []struct { + name string + before map[string]*handlertest.Object + add interface{} + want map[interface{}]interface{} + wantErr error + }{ + { + name: "add invalid type", + before: nil, + add: "foo", + want: nil, + wantErr: &clienterrors.ErrorMap{ + handlertest.HandlerName: handlertest.ErrInvalidType, + }, + }, + { + name: "add invalid Object", + before: nil, + add: &handlertest.Object{ + Namespace: "", + Name: "", + }, + want: nil, + wantErr: &clienterrors.ErrorMap{ + handlertest.HandlerName: handlertest.ErrInvalidObject, + }, + }, + { + name: "add Object", + before: nil, + add: &handlertest.Object{ + Namespace: "foo", + Name: "bar", + }, + want: nil, + wantErr: nil, + }, + { + name: "add Namespace", + before: nil, + add: &handlertest.Object{ + Namespace: "foo", + }, + want: map[interface{}]interface{}{ + "namespace/foo/": &handlertest.Object{ + Namespace: "foo", + }, + }, + wantErr: nil, + }, + { + name: "replace Namespace", + before: map[string]*handlertest.Object{ + "namespace/foo/": { + Namespace: "foo", + Data: "qux", + }, + }, + add: &handlertest.Object{ + Namespace: "foo", + Data: "bar", + }, + want: map[interface{}]interface{}{ + "namespace/foo/": &handlertest.Object{ + Namespace: "foo", + Data: "bar", + }, + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cache := &handlertest.Cache{} + h := &handlertest.Handler{Cache: cache} + + c := clienttest.New(t, client.Targets(h)) + + ctx := context.Background() + for _, v := range tt.before { + _, err := c.AddData(ctx, v) + if err != nil { + t.Fatal(err) + } + } + + _, err := c.AddData(ctx, tt.add) + if !errors.Is(err, tt.wantErr) { + t.Fatalf("got error: %#v,\nwant %#v", err, tt.wantErr) + } + + got := make(map[interface{}]interface{}) + cache.Namespaces.Range(func(key, value interface{}) bool { + got[key] = value + return true + }) + + if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" { + t.Error(diff) + } + }) + } +} + +func TestClient_RemoveData_Cache(t *testing.T) { + tests := []struct { + name string + before map[string]*handlertest.Object + remove interface{} + want map[interface{}]interface{} + wantErr error + }{ + { + name: "remove invalid", + before: nil, + remove: "foo", + want: nil, + wantErr: &clienterrors.ErrorMap{ + handlertest.HandlerName: handlertest.ErrInvalidType, + }, + }, + { + name: "remove nonexistent", + before: nil, + remove: &handlertest.Object{Namespace: "foo"}, + want: nil, + wantErr: nil, + }, + { + name: "remove Namespace", + before: map[string]*handlertest.Object{ + "/namespace/foo": {Namespace: "foo"}, + }, + remove: &handlertest.Object{Namespace: "foo"}, + want: nil, + wantErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cache := &handlertest.Cache{} + h := &handlertest.Handler{Cache: cache} + + c := clienttest.New(t, client.Targets(h)) + + ctx := context.Background() + for _, v := range tt.before { + _, err := c.AddData(ctx, v) + if err != nil { + t.Fatal(err) + } + } + + _, err := c.RemoveData(ctx, tt.remove) + if !errors.Is(err, tt.wantErr) { + t.Fatalf("got error: %v,\nwant %v", err, tt.wantErr) + } + + got := make(map[interface{}]interface{}) + cache.Namespaces.Range(func(key, value interface{}) bool { + got[key] = value + return true + }) + + if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" { + t.Error(diff) + } + }) + } +} diff --git a/constraint/pkg/handler/cache.go b/constraint/pkg/handler/cache.go new file mode 100644 index 000000000..e41c25e68 --- /dev/null +++ b/constraint/pkg/handler/cache.go @@ -0,0 +1,44 @@ +package handler + +// Cacher is a type - usually a Handler - which needs to cache state. +// Handlers only need implement this interface if they have need of a cache. +// Handlers which do not implement Cacher are assumed to be stateless from +// Client's perspective. +type Cacher interface { + // GetCache returns the Cache. If nil, the Cacher is treated as having no + // cache. + GetCache() Cache +} + +// Cache is an interface for Handlers to define which allows them to track +// objects not currently under review. For example, this is required to make +// referential constraints work, or to have Constraint match criteria which +// relies on more than just the object under review. +// +// Implementations must satisfy the per-method requirements for Client to handle +// the Cache properly. +type Cache interface { + // Add inserts a new object into Cache with identifier key. If an object + // already exists, replaces the object at key. + Add(key string, object interface{}) error + + // Remove deletes the object at key from Cache. Deletion succeeds if key + // does not exist. + // Remove always succeeds; if for some reason key cannot be deleted the application + // should panic. + Remove(key string) +} + +type NoCache struct{} + +func (n NoCache) Add(key string, object interface{}) error { + return nil +} + +func (n NoCache) Get(key string) (interface{}, error) { + return nil, nil +} + +func (n NoCache) Remove(key string) {} + +var _ Cache = NoCache{} diff --git a/constraint/pkg/handler/handlertest/cache.go b/constraint/pkg/handler/handlertest/cache.go new file mode 100644 index 000000000..e864cd6a8 --- /dev/null +++ b/constraint/pkg/handler/handlertest/cache.go @@ -0,0 +1,43 @@ +package handlertest + +import ( + "errors" + "fmt" + "sync" + + "github.com/open-policy-agent/frameworks/constraint/pkg/handler" +) + +var ErrInvalidObject = errors.New("invalid object") + +// Cache is a threadsafe Cache for the test Handler which keeps track of +// Namespaces. +type Cache struct { + Namespaces sync.Map +} + +var _ handler.Cache = &Cache{} + +// Add inserts object into Cache if object is a Namespace. +func (c *Cache) Add(key string, object interface{}) error { + obj, ok := object.(*Object) + if !ok { + return fmt.Errorf("%w: got object type %T, want %T", ErrInvalidType, object, &Object{}) + } + + if obj.Name != "" { + return nil + } + + if obj.Namespace == "" { + return fmt.Errorf("%w: must specify one of Name or Namespace", ErrInvalidObject) + } + + c.Namespaces.Store(key, object) + + return nil +} + +func (c *Cache) Remove(key string) { + c.Namespaces.Delete(key) +} diff --git a/constraint/pkg/handler/handlertest/handler.go b/constraint/pkg/handler/handlertest/handler.go index 1b5c15a5a..2e45951db 100644 --- a/constraint/pkg/handler/handlertest/handler.go +++ b/constraint/pkg/handler/handlertest/handler.go @@ -14,6 +14,8 @@ import ( var _ handler.TargetHandler = &Handler{} +var _ handler.Cacher = &Handler{} + // HandlerName is the default handler name. const HandlerName = "test.target" @@ -28,6 +30,8 @@ type Handler struct { // ProcessDataError is the error to return when ProcessData is called. // If nil returns no error. ProcessDataError error + + Cache *Cache } func (h *Handler) GetName() string { @@ -112,13 +116,10 @@ func (h *Handler) ProcessData(obj interface{}) (bool, string, interface{}, error return false, "", nil, nil } - if o.Namespace == "" { - return true, fmt.Sprintf("cluster/%s", o.Name), obj, nil - } - return true, fmt.Sprintf("namespace/%s/%s", o.Namespace, o.Name), obj, nil + return true, o.Key(), obj, nil default: - return false, "", nil, fmt.Errorf("unrecognized type %T, want %T", - obj, &Object{}) + return false, "", nil, fmt.Errorf("%w: got object type %T, want %T", + ErrInvalidType, obj, &Object{}) } } @@ -167,5 +168,13 @@ func (h *Handler) ToMatcher(constraint *unstructured.Unstructured) (constraints. return nil, fmt.Errorf("unable to get spec.matchNamespace: %w", err) } - return Matcher{namespace: ns}, nil + return Matcher{namespace: ns, cache: h.Cache}, nil +} + +func (h *Handler) GetCache() handler.Cache { + if h.Cache == nil { + return handler.NoCache{} + } + + return h.Cache } diff --git a/constraint/pkg/handler/handlertest/matcher.go b/constraint/pkg/handler/handlertest/matcher.go index 7ebca777a..af876c4b6 100644 --- a/constraint/pkg/handler/handlertest/matcher.go +++ b/constraint/pkg/handler/handlertest/matcher.go @@ -1,24 +1,47 @@ package handlertest import ( + "errors" "fmt" "github.com/open-policy-agent/frameworks/constraint/pkg/core/constraints" ) +var ( + ErrNotFound = errors.New("not found") + ErrInvalidType = errors.New("invalid type") +) + +// Matcher is a test matcher which matches Objects with a matching namespace. +// Checks that Namespace exists in cache before proceeding. type Matcher struct { namespace string + cache *Cache } +// Match returns true if the object under review's Namespace matches the Namespace +// the Matcher filters for. If the object's Namespace is not cached in cache, +// returns an error. +// +// Matches all objects if the Matcher has no namespace specified. func (m Matcher) Match(review interface{}) (bool, error) { if m.namespace == "" { return true, nil } + wantNamespace := Object{Namespace: m.namespace} + + key := wantNamespace.Key() + _, exists := m.cache.Namespaces.Load(key) + if !exists { + return false, fmt.Errorf("%w: namespace %q not in cache", + ErrNotFound, m.namespace) + } + reviewObj, ok := review.(*Review) if !ok { - return false, fmt.Errorf("unrecognized type %T, want %T", - review, &Review{}) + return false, fmt.Errorf("%w: unrecognized type %T, want %T", + ErrInvalidType, review, &Review{}) } return m.namespace == reviewObj.Object.Namespace, nil diff --git a/constraint/pkg/handler/handlertest/object.go b/constraint/pkg/handler/handlertest/object.go index 812a5a4c3..4006a002d 100644 --- a/constraint/pkg/handler/handlertest/object.go +++ b/constraint/pkg/handler/handlertest/object.go @@ -1,11 +1,26 @@ package handlertest -// Object is a test object under review. +import "fmt" + +// Object is a test object under review. The idea is to represent objects just +// complex enough to showcase (and test) the features of frameworks's Client, +// Drivers, and Handlers. type Object struct { - Name string `json:"name"` + // Name is the identifier of an Object within the scope of its Namespace + // (if present). If unset, the Object is a special "Namespace" object. + Name string `json:"name"` + + // Namespace is used for Constraints which apply to a subset of Objects. + // If unset, the Object is not scoped to a Namespace. Namespace string `json:"namespace"` // Data is checked by "CheckData" templates. - Data string `json:"data"` - Root interface{} `json:"root"` + Data string `json:"data"` +} + +func (o *Object) Key() string { + if o.Namespace == "" { + return fmt.Sprintf("cluster/%s/%s", o.Namespace, o.Name) + } + return fmt.Sprintf("namespace/%s/%s", o.Namespace, o.Name) }