diff --git a/constraint/pkg/client/drivers/interface.go b/constraint/pkg/client/drivers/interface.go index 5663ac20b..55f1185cc 100644 --- a/constraint/pkg/client/drivers/interface.go +++ b/constraint/pkg/client/drivers/interface.go @@ -4,7 +4,6 @@ import ( "context" "github.com/open-policy-agent/frameworks/constraint/pkg/core/templates" - "github.com/open-policy-agent/frameworks/constraint/pkg/types" ) diff --git a/constraint/pkg/client/drivers/local/local.go b/constraint/pkg/client/drivers/local/local.go index ebff7e275..76e5e6b97 100644 --- a/constraint/pkg/client/drivers/local/local.go +++ b/constraint/pkg/client/drivers/local/local.go @@ -195,7 +195,7 @@ func (d *Driver) PutModule(name string, src string) error { return err } -// PutModules upserts a number of modules under a given prefix. +// putModules upserts a number of modules under a given prefix. func (d *Driver) putModules(namePrefix string, srcs []string) error { if err := d.checkModuleSetName(namePrefix); err != nil { return err diff --git a/constraint/pkg/client/drivers/local/local_unit_test.go b/constraint/pkg/client/drivers/local/local_unit_test.go index 0847f4fad..45d048a57 100644 --- a/constraint/pkg/client/drivers/local/local_unit_test.go +++ b/constraint/pkg/client/drivers/local/local_unit_test.go @@ -6,6 +6,9 @@ import ( "sort" "testing" + "github.com/open-policy-agent/frameworks/constraint/pkg/core/templates" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/open-policy-agent/opa/ast" @@ -39,6 +42,8 @@ fooisbar[msg] { msg := "input.foo is bar" } ` + MockTemplate string = "MockConstraintTemplate" + MockTargetHandler string = "foo" ) func TestDriver_PutModule(t *testing.T) { @@ -434,6 +439,144 @@ func TestDriver_DeleteModules(t *testing.T) { } } +func TestDriver_AddTemplates(t *testing.T) { + testCases := []struct { + name string + rego string + targetHandler string + externs []string + + wantErr error + wantModules []string + }{ + { + name: "no target", + wantErr: ErrInvalidConstraintTemplate, + wantModules: nil, + }, + { + name: "rego missing violation", + targetHandler: MockTargetHandler, + rego: Module, + wantErr: ErrInvalidConstraintTemplate, + wantModules: nil, + }, + { + name: "valid template", + targetHandler: MockTargetHandler, + rego: ` +package something + +violation[msg] {msg := "always"}`, + wantModules: []string{toModuleSetName(createTemplatePath(MockTargetHandler, MockTemplate), 0)}, + }, + { + name: "inventory disallowed template", + targetHandler: MockTargetHandler, + rego: `package something + +violation[{"msg": "msg"}] { + data.inventory = "something_else" +}`, + wantErr: ErrInvalidConstraintTemplate, + }, + { + name: "inventory allowed template", + targetHandler: MockTargetHandler, + rego: `package something + +violation[{"msg": "msg"}] { + data.inventory = "something_else" +}`, + externs: []string{"data.inventory"}, + wantErr: nil, + wantModules: []string{toModuleSetName(createTemplatePath(MockTargetHandler, MockTemplate), 0)}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + d := New() + dr, ok := d.(*Driver) + if !ok { + t.Fatalf("got New() type = %T, want %T", dr, &Driver{}) + } + dr.SetExterns(tc.externs) + tmpl := createTemplate(tc.targetHandler, tc.rego) + gotErr := dr.AddTemplate(tmpl) + if !errors.Is(gotErr, tc.wantErr) { + t.Fatalf("got AddTemplate() error = %v, want %v", gotErr, tc.wantErr) + } + + gotModules := make([]string, 0, len(dr.modules)) + for gotModule := range dr.modules { + gotModules = append(gotModules, gotModule) + } + sort.Strings(gotModules) + + if diff := cmp.Diff(tc.wantModules, gotModules, cmpopts.EquateEmpty()); diff != "" { + t.Error(diff) + } + }) + } +} + +func TestDriver_RemoveTemplates(t *testing.T) { + testCases := []struct { + name string + rego string + targetHandler string + externs []string + wantErr error + }{ + { + name: "valid template", + targetHandler: MockTargetHandler, + rego: ` +package something + +violation[msg] {msg := "always"}`, + }, + { + name: "inventory allowed template", + targetHandler: MockTargetHandler, + rego: `package something + +violation[{"msg": "msg"}] { + data.inventory = "something_else" +}`, + externs: []string{"data.inventory"}, + wantErr: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + d := New() + dr, ok := d.(*Driver) + if !ok { + t.Fatalf("got New() type = %T, want %T", dr, &Driver{}) + } + dr.SetExterns(tc.externs) + tmpl := createTemplate(tc.targetHandler, tc.rego) + gotErr := dr.AddTemplate(tmpl) + if !errors.Is(gotErr, tc.wantErr) { + t.Fatalf("got AddTemplate() error = %v, want %v", gotErr, tc.wantErr) + } + if len(dr.modules) == 0 { + t.Errorf("driver failed to add module") + } + gotErr = dr.RemoveTemplate(context.Background(), tmpl) + if gotErr != nil { + t.Errorf("err = %v; want nil", gotErr) + } + if len(dr.modules) != 0 { + t.Errorf("driver has module = %v; want nil", len(dr.modules)) + } + }) + } +} + func TestDriver_PutData(t *testing.T) { testCases := []struct { name string @@ -835,3 +978,32 @@ type readErrorStorage struct { func (s *readErrorStorage) Read(_ context.Context, _ storage.Transaction, _ storage.Path) (interface{}, error) { return nil, errors.New("error writing data") } + +func createTemplate(targetHandler, rego string) *templates.ConstraintTemplate { + tmpl := &templates.ConstraintTemplate{ + TypeMeta: v1.TypeMeta{}, + ObjectMeta: v1.ObjectMeta{ + Name: "mockconstrainttemplate", + }, + Spec: templates.ConstraintTemplateSpec{ + CRD: templates.CRD{ + Spec: templates.CRDSpec{ + Names: templates.Names{ + Kind: MockTemplate, + ShortNames: nil, + }, + }, + }, + }, + } + if targetHandler == "" && rego == "" { + return tmpl + } + tmpl.Spec.Targets = []templates.Target{ + { + Target: targetHandler, + Rego: rego, + }, + } + return tmpl +}