diff --git a/constraint/cmd/rewrite-compatibility/main.go b/constraint/cmd/rewrite-compatibility/main.go index a31f5def4..25c089971 100644 --- a/constraint/cmd/rewrite-compatibility/main.go +++ b/constraint/cmd/rewrite-compatibility/main.go @@ -66,7 +66,8 @@ func compileSrcs( libs []string, newPkgPrefix string, oldRoot string, - newRoot string) error { + newRoot string, +) error { if len(cts) == 0 && len(libs) == 0 { return fmt.Errorf("must specify --ct or --lib or both") } diff --git a/constraint/pkg/client/client_addtemplate_bench_test.go b/constraint/pkg/client/client_addtemplate_bench_test.go index e318139c9..c56f42298 100644 --- a/constraint/pkg/client/client_addtemplate_bench_test.go +++ b/constraint/pkg/client/client_addtemplate_bench_test.go @@ -290,6 +290,7 @@ func BenchmarkClient_AddTemplate_Parallel(b *testing.B) { for i := range cts { cts[i] = makeConstraintTemplate(i, tc.module, tc.libs...) } + b.ResetTimer() for i := 0; i < b.N; i++ { b.StopTimer() diff --git a/constraint/pkg/client/drivers/local/compilers.go b/constraint/pkg/client/drivers/local/compilers.go index d2c2c37ff..908513a9e 100644 --- a/constraint/pkg/client/drivers/local/compilers.go +++ b/constraint/pkg/client/drivers/local/compilers.go @@ -116,37 +116,29 @@ func (d *Compilers) list() map[string]map[string]*ast.Compiler { return result } -type TargetModule struct { - Rego string - Libs []string -} - // parseConstraintTemplate validates the rego in template target by parsing // rego modules. -func parseConstraintTemplate(templ *templates.ConstraintTemplate, externs []string) (map[string]TargetModule, error) { +func parseConstraintTemplate(templ *templates.ConstraintTemplate, externs []string) (map[string][]*ast.Module, error) { rr, err := regorewriter.New(regorewriter.NewPackagePrefixer(templateLibPrefix), []string{libRoot}, externs) if err != nil { return nil, fmt.Errorf("creating rego rewriter: %w", err) } - mods := make(map[string]TargetModule) + mods := make(map[string][]*ast.Module) for _, target := range templ.Spec.Targets { targetMods, err := parseConstraintTemplateTarget(rr, target) if err != nil { return nil, err } - mods[target.Target] = TargetModule{ - Rego: target.Rego, - Libs: targetMods, - } + mods[target.Target] = targetMods } return mods, nil } -func parseConstraintTemplateTarget(rr *regorewriter.RegoRewriter, targetSpec templates.Target) ([]string, error) { - entryPoint, err := parseModule(targetSpec.Rego) +func parseConstraintTemplateTarget(rr *regorewriter.RegoRewriter, targetSpec templates.Target) ([]*ast.Module, error) { + entryPoint, err := parseModule(templatePath, targetSpec.Rego) if err != nil { return nil, fmt.Errorf("%w: %v", clienterrors.ErrInvalidConstraintTemplate, err) } @@ -170,7 +162,14 @@ func parseConstraintTemplateTarget(rr *regorewriter.RegoRewriter, targetSpec tem rr.AddEntryPointModule(templatePath, entryPoint) for idx, libSrc := range targetSpec.Libs { libPath := fmt.Sprintf(`%s["lib_%d"]`, templateLibPrefix, idx) - if err = rr.AddLib(libPath, libSrc); err != nil { + + m, err := parseModule(libPath, libSrc) + if err != nil { + return nil, fmt.Errorf("%w: %v", + clienterrors.ErrInvalidConstraintTemplate, err) + } + + if err = rr.AddLib(libPath, m); err != nil { return nil, fmt.Errorf("%w: %v", clienterrors.ErrInvalidConstraintTemplate, err) } @@ -182,44 +181,28 @@ func parseConstraintTemplateTarget(rr *regorewriter.RegoRewriter, targetSpec tem clienterrors.ErrInvalidConstraintTemplate, err) } - var mods []string - err = sources.ForEachModule(func(m *regorewriter.Module) error { - content, err2 := m.Content() - if err2 != nil { - return err2 - } - mods = append(mods, string(content)) - return nil - }) - - if err != nil { - return nil, fmt.Errorf("%w: %v", - clienterrors.ErrInvalidConstraintTemplate, err) + var mods []*ast.Module + for _, m := range sources.EntryPoints { + mods = append(mods, m.Module) + } + for _, m := range sources.Libs { + mods = append(mods, m.Module) } return mods, nil } -func compileTemplateTarget(module TargetModule, capabilities *ast.Capabilities, printEnabled bool) (*ast.Compiler, error) { +func compileTemplateTarget(module []*ast.Module, capabilities *ast.Capabilities, printEnabled bool) (*ast.Compiler, error) { compiler := ast.NewCompiler(). WithCapabilities(capabilities). WithEnablePrintStatements(printEnabled) - modules := make(map[string]*ast.Module) - - builtinModule, err := ast.ParseModule(hookModulePath, hookModule) - if err != nil { - return nil, fmt.Errorf("%w: %v", clienterrors.ErrParse, err) - } - modules[hookModulePath] = builtinModule + modules := make(map[string]*ast.Module, len(module)+1) + modules[hookModulePath] = hookModule - for i, lib := range module.Libs { + for i, lib := range module { libPath := fmt.Sprintf("%s%d", templatePath, i) - libModule, err := ast.ParseModule(libPath, lib) - if err != nil { - return nil, fmt.Errorf("%w: %v", clienterrors.ErrParse, err) - } - modules[libPath] = libModule + modules[libPath] = lib } compiler.Compile(modules) @@ -229,3 +212,18 @@ func compileTemplateTarget(module TargetModule, capabilities *ast.Capabilities, return compiler, nil } + +// parseModule parses the module and also fails empty modules. +func parseModule(path, rego string) (*ast.Module, error) { + module, err := ast.ParseModule(path, rego) + if err != nil { + return nil, err + } + + if module == nil { + return nil, fmt.Errorf("%w: module %q is empty", + clienterrors.ErrInvalidModule, templatePath) + } + + return module, nil +} diff --git a/constraint/pkg/client/drivers/local/driver.go b/constraint/pkg/client/drivers/local/driver.go index 79e8c75f3..9b7d0368b 100644 --- a/constraint/pkg/client/drivers/local/driver.go +++ b/constraint/pkg/client/drivers/local/driver.go @@ -310,21 +310,6 @@ func (d *Driver) Dump(ctx context.Context) (string, error) { return string(b), nil } -// parseModule parses the module and also fails empty modules. -func parseModule(rego string) (*ast.Module, error) { - module, err := ast.ParseModule(templatePath, rego) - if err != nil { - return nil, err - } - - if module == nil { - return nil, fmt.Errorf("%w: module %q is empty", - clienterrors.ErrInvalidModule, templatePath) - } - - return module, nil -} - // rewriteModulePackage rewrites the module's package path to path. func rewriteModulePackage(module *ast.Module) error { pathParts := ast.Ref([]*ast.Term{ast.VarTerm(templatePath)}) diff --git a/constraint/pkg/client/drivers/local/rego.go b/constraint/pkg/client/drivers/local/rego.go index 6ea7a211c..2760bd129 100644 --- a/constraint/pkg/client/drivers/local/rego.go +++ b/constraint/pkg/client/drivers/local/rego.go @@ -1,5 +1,7 @@ package local +import "github.com/open-policy-agent/opa/ast" + const ( // templatePath is the path the Template's Rego code is stored. // Must match the "data.xxx.violation[r]" path in hookModule. @@ -16,7 +18,7 @@ const ( // This removes boilerplate that would otherwise need to be present in every // Template's Rego code. The violation's response is written to a standard // location we can read from to see if any violations occurred. - hookModule = ` + hookModuleRego = ` package hooks # Determine if the object under review violates any passed Constraints. @@ -45,3 +47,13 @@ violation[response] { ` ) + +var hookModule *ast.Module + +func init() { + var err error + hookModule, err = parseModule(hookModulePath, hookModuleRego) + if err != nil { + panic(err) + } +} diff --git a/constraint/pkg/client/drivers/to_result.go b/constraint/pkg/client/drivers/to_result.go index cdc80e2d5..7f26d1835 100644 --- a/constraint/pkg/client/drivers/to_result.go +++ b/constraint/pkg/client/drivers/to_result.go @@ -37,47 +37,42 @@ func ToResults(constraints map[ConstraintKey]*unstructured.Unstructured, resultS func ToResult(constraints map[ConstraintKey]*unstructured.Unstructured, r rego.Result) (*types.Result, error) { result := &types.Result{} - resultMapBinding, found := r.Bindings["result"] + resultMap, found, err := unstructured.NestedMap(r.Bindings, "result") + if err != nil { + return nil, fmt.Errorf("extracting result binding: %v", err) + } + if !found { return nil, errors.New("no binding for result") } - resultMap, ok := resultMapBinding.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("result binding was %T but want %T", - resultMapBinding, map[string]interface{}{}) + message, found, err := unstructured.NestedString(resultMap, "msg") + if err != nil { + return nil, fmt.Errorf("extracting message binding: %v", err) } - messageBinding, found := resultMap["msg"] if !found { return nil, errors.New("no binding for msg") } - message, ok := messageBinding.(string) - if !ok { - return nil, fmt.Errorf("message binding was %T but want %T", - messageBinding, "") - } result.Msg = message result.Metadata = map[string]interface{}{ "details": resultMap["details"], } - keyBinding, found := resultMap["key"] - if !found { - return nil, errors.New("no binding for Constraint key") + keyMap, found, err := unstructured.NestedStringMap(resultMap, "key") + if err != nil { + return nil, fmt.Errorf("extracting key binding: %v", err) } - keyMap, ok := keyBinding.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("key binding was %T but want %T", - keyBinding, map[string]interface{}{}) + if !found { + return nil, errors.New("no binding for Constraint key") } key := ConstraintKey{ - Kind: keyMap["kind"].(string), - Name: keyMap["name"].(string), + Kind: keyMap["kind"], + Name: keyMap["name"], } constraint := constraints[key] diff --git a/constraint/pkg/client/e2e_test.go b/constraint/pkg/client/e2e_test.go index 71fa6bec8..72a38bef4 100644 --- a/constraint/pkg/client/e2e_test.go +++ b/constraint/pkg/client/e2e_test.go @@ -207,7 +207,7 @@ func TestClient_Review(t *testing.T) { toReview: handlertest.NewReview("", "foo", "qux"), wantResults: []*types.Result{{ Target: handlertest.TargetName, - Msg: `template0:7: eval_conflict_error: functions must not produce multiple outputs for same inputs`, + Msg: `template:8: eval_conflict_error: functions must not produce multiple outputs for same inputs`, EnforcementAction: constraints.EnforcementActionDeny, Constraint: cts.MakeConstraint(t, clienttest.KindRuntimeError, "constraint"), }}, diff --git a/constraint/pkg/client/rego_helpers.go b/constraint/pkg/client/rego_helpers.go deleted file mode 100644 index 126d2d8d7..000000000 --- a/constraint/pkg/client/rego_helpers.go +++ /dev/null @@ -1,48 +0,0 @@ -package client - -import ( - "fmt" - "sort" - - "github.com/open-policy-agent/opa/ast" -) - -// ParseModule parses the module and also fails empty modules. -func ParseModule(path, rego string) (*ast.Module, error) { - module, err := ast.ParseModule(path, rego) - if err != nil { - return nil, err - } - - if module == nil { - return nil, fmt.Errorf("%w: module %q is empty", - ErrInvalidModule, path) - } - - return module, nil -} - -// RequireModuleRules makes sure the module contains all of the specified -// requiredRules. -func RequireModuleRules(module *ast.Module, requiredRules map[string]struct{}) error { - ruleSets := make(map[string]struct{}, len(module.Rules)) - for _, rule := range module.Rules { - ruleSets[string(rule.Head.Name)] = struct{}{} - } - - var missing []string - for name := range requiredRules { - _, ok := ruleSets[name] - if !ok { - missing = append(missing, name) - } - } - sort.Strings(missing) - - if len(missing) > 0 { - return fmt.Errorf("%w: missing required rules: %v", - ErrInvalidModule, missing) - } - - return nil -} diff --git a/constraint/pkg/client/rego_helpers_test.go b/constraint/pkg/client/rego_helpers_test.go deleted file mode 100644 index c0cbeca91..000000000 --- a/constraint/pkg/client/rego_helpers_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package client_test - -import ( - "testing" - - "github.com/open-policy-agent/frameworks/constraint/pkg/client" -) - -type regoTestCase struct { - Name string - Rego string - Path string - ErrorExpected bool - ExpectedRego string - ArityExpected int - RequiredRules map[string]struct{} -} - -func TestRequireRules(t *testing.T) { - tc := []regoTestCase{ - { - Name: "No Required Rules", - Rego: `package hello`, - ErrorExpected: false, - }, - { - Name: "Bad Rego", - Rego: `package hello {dangling bracket`, - ErrorExpected: true, - }, - { - Name: "Required Rule", - Rego: `package hello r{1 == 1}`, - RequiredRules: map[string]struct{}{"r": {}}, - ErrorExpected: false, - }, - { - Name: "Required Rule Extras", - Rego: `package hello r[v]{v == 1} q{3 == 3}`, - RequiredRules: map[string]struct{}{"r": {}}, - ErrorExpected: false, - }, - { - Name: "Required Rule Multiple", - Rego: `package hello r[v]{v == 1} q{3 == 3}`, - RequiredRules: map[string]struct{}{"r": {}, "q": {}}, - ErrorExpected: false, - }, - { - Name: "Required Rule Missing", - Rego: `package hello`, - RequiredRules: map[string]struct{}{"r": {}}, - ErrorExpected: true, - }, - } - for _, tt := range tc { - t.Run(tt.Name, func(t *testing.T) { - mod, err := client.ParseModule("foo", tt.Rego) - if err == nil { - err = client.RequireModuleRules(mod, tt.RequiredRules) - } - - if (err == nil) && tt.ErrorExpected { - t.Fatalf("err = nil; want non-nil") - } - if (err != nil) && !tt.ErrorExpected { - t.Fatalf("err = %q; want nil", err) - } - }) - } -} diff --git a/constraint/pkg/regorewriter/regorewriter.go b/constraint/pkg/regorewriter/regorewriter.go index dfaf5cdad..9e465fa10 100644 --- a/constraint/pkg/regorewriter/regorewriter.go +++ b/constraint/pkg/regorewriter/regorewriter.go @@ -61,12 +61,7 @@ func New(pt PackageTransformer, libs []string, externs []string) (*RegoRewriter, } // add is the internal method for parsing a module and entering it into the bookkeeping. -func (r *RegoRewriter) add(path, src string, slice *[]*Module) error { - m, err := ast.ParseModule(path, src) - if err != nil { - return fmt.Errorf("%w: %v", ErrInvalidModule, err) - } - +func (r *RegoRewriter) add(path string, m *ast.Module, slice *[]*Module) error { r.addModule(path, m, slice) return nil } @@ -81,13 +76,13 @@ func (r *RegoRewriter) AddEntryPointModule(path string, m *ast.Module) { // AddEntryPoint adds a base source which will not have it's package path rewritten. These correspond // to the rego that will be populated into a ConstraintTemplate with the 'violation' rule. -func (r *RegoRewriter) AddEntryPoint(path, src string) error { - return r.add(path, src, &r.entryPoints) +func (r *RegoRewriter) AddEntryPoint(path string, m *ast.Module) error { + return r.add(path, m, &r.entryPoints) } // AddLib adds a library source which will have the package path updated. -func (r *RegoRewriter) AddLib(path, src string) error { - return r.add(path, src, &r.libs) +func (r *RegoRewriter) AddLib(path string, m *ast.Module) error { + return r.add(path, m, &r.libs) } // addTestDir adds a test dir inside one of the provided paths. @@ -128,7 +123,13 @@ func (r *RegoRewriter) addFileFromFs(path string, slice *[]*Module) error { if err != nil { return fmt.Errorf("%w: %v", ErrReadingFile, err) } - return r.add(path, string(bytes), slice) + + m, err := ast.ParseModule(path, string(bytes)) + if err != nil { + return err + } + + return r.add(path, m, slice) } // addPathFromFs adds a module from the local filesystem. diff --git a/constraint/pkg/regorewriter/regorewriter_test.go b/constraint/pkg/regorewriter/regorewriter_test.go index 246c3711d..22fd8e244 100644 --- a/constraint/pkg/regorewriter/regorewriter_test.go +++ b/constraint/pkg/regorewriter/regorewriter_test.go @@ -112,12 +112,22 @@ func (tc *RegoRewriterTestcase) Run(t *testing.T) { t.Fatalf("Failed to create %s", err) } for path, content := range tc.baseSrcs { - if err := rr.AddEntryPoint(path, content); err != nil { + m, err := ast.ParseModule(path, content) + if err != nil { + t.Fatal(err) + } + + if err := rr.AddEntryPoint(path, m); err != nil { t.Fatalf("unexpected error during AddEntryPoint: %s", err) } } for path, content := range tc.libSrcs { - if err := rr.AddLib(path, content); err != nil { + m, err := ast.ParseModule(path, content) + if err != nil { + t.Fatal(err) + } + + if err := rr.AddLib(path, m); err != nil { t.Logf("unexpected error during AddLib %v", err) return } @@ -330,13 +340,6 @@ test_ok { } `, }, - { - name: "add invalid rego", - path: "invalidrego", - src: `package lib.rego -something invalid`, - wantError: ErrInvalidModule, - }, } for _, tc := range tcs { @@ -345,7 +348,13 @@ something invalid`, if err != nil { t.Fatalf("Failed to create RegoRewriter %q", err) } - if gotErr := rr.AddEntryPoint(tc.path, tc.src); !errors.Is(gotErr, tc.wantError) { + + m, err := ast.ParseModule(tc.path, tc.src) + if err != nil { + t.Fatal(err) + } + + if gotErr := rr.AddEntryPoint(tc.path, m); !errors.Is(gotErr, tc.wantError) { t.Errorf("got AddEntryPoint() error = %q, want %v", gotErr, tc.wantError) } }) @@ -474,7 +483,13 @@ violation[{"msg": msg}] { if err != nil { t.Fatalf("Failed to create RegoRewriter %s", err) } - if err := rr.AddEntryPoint("path", tc.content); err != nil { + + m, err := ast.ParseModule("path", tc.content) + if err != nil { + t.Fatal(err) + } + + if err := rr.AddEntryPoint("path", m); err != nil { t.Fatalf("failed to add base source %q", err) return } @@ -597,7 +612,13 @@ is_foo(name) { if err != nil { t.Fatalf("Failed to create RegoRewriter %s", err) } - if err := rr.AddLib("path", tc.content); err != nil { + + m, err := ast.ParseModule("path", tc.content) + if err != nil { + t.Fatal(err) + } + + if err := rr.AddLib("path", m); err != nil { t.Fatalf("failed to add lib source %q", err) } sources, err := rr.Rewrite()