Skip to content

Commit

Permalink
Move parsing module logic from local.AddTemplate to MapModules, shari…
Browse files Browse the repository at this point in the history
…ng with client.CreateCRD

Signed-off-by: Becky Huang <beckyhd@google.com>
  • Loading branch information
becky-hd committed Jan 7, 2022
1 parent aac50c0 commit 4a4a9a2
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 17 deletions.
5 changes: 5 additions & 0 deletions constraint/pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"strings"
"sync"

"github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers/local"

"github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers"
"github.com/open-policy-agent/frameworks/constraint/pkg/client/regolib"
constraintlib "github.com/open-policy-agent/frameworks/constraint/pkg/core/constraints"
Expand Down Expand Up @@ -249,6 +251,9 @@ func (c *Client) CreateCRD(templ *templates.ConstraintTemplate) (*apiextensions.
if err != nil {
return nil, err
}
if _, _, err = local.MapModules(templ, c.allowedDataFields); err != nil {
return nil, err
}
return artifacts.crd, nil
}

Expand Down
4 changes: 2 additions & 2 deletions constraint/pkg/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,7 @@ violation[msg] {msg := "always"}`,
},
},
want: nil,
wantErr: ErrInvalidConstraintTemplate,
wantErr: local.ErrInvalidConstraintTemplate,
},
{
name: "empty rego package",
Expand All @@ -1265,7 +1265,7 @@ violation[msg] {msg := "always"}`,
},
},
want: nil,
wantErr: ErrInvalidConstraintTemplate,
wantErr: local.ErrInvalidConstraintTemplate,
},
{
name: "multiple targets",
Expand Down
42 changes: 27 additions & 15 deletions constraint/pkg/client/drivers/local/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -498,10 +498,9 @@ func (d *driver) Dump(ctx context.Context) (string, error) {
return string(b), nil
}

// AddTemplate implements drivers.Driver.
func (d *driver) AddTemplate(templ *templates.ConstraintTemplate) error {
func MapModules(templ *templates.ConstraintTemplate, extern []string) (string, []string, error) {
if err := validateTargets(templ); err != nil {
return nil
return "", nil, err
}
targetSpec := templ.Spec.Targets[0]
targetHandler := targetSpec.Target
Expand All @@ -511,45 +510,45 @@ func (d *driver) AddTemplate(templ *templates.ConstraintTemplate) error {
rr, err := regorewriter.New(
regorewriter.NewPackagePrefixer(libPrefix),
[]string{"data.lib"},
d.externs)
allowedDataFields(extern))
if err != nil {
return fmt.Errorf("creating rego rewriter: %w", err)
return "", nil, fmt.Errorf("creating rego rewriter: %w", err)
}

namePrefix := createTemplatePath(targetHandler, kind)
entryPoint, err := parseModule(namePrefix, templ.Spec.Targets[0].Rego)
if err != nil {
return fmt.Errorf("%w: %v", ErrInvalidConstraintTemplate, err)
return "", nil, fmt.Errorf("%w: %v", ErrInvalidConstraintTemplate, err)
}

if entryPoint == nil {
return fmt.Errorf("%w: failed to parse module for unknown reason",
return "", nil, fmt.Errorf("%w: failed to parse module for unknown reason",
ErrInvalidConstraintTemplate)
}

if err = rewriteModulePackage(namePrefix, entryPoint); err != nil {
return err
return "", nil, err
}

req := map[string]struct{}{"violation": {}}

if err = requireModuleRules(entryPoint, req); err != nil {
return fmt.Errorf("%w: invalid rego: %v",
return "", nil, fmt.Errorf("%w: invalid rego: %v",
ErrInvalidConstraintTemplate, err)
}

rr.AddEntryPointModule(namePrefix, entryPoint)
for idx, libSrc := range targetSpec.Libs {
libPath := fmt.Sprintf(`%s["lib_%d"]`, libPrefix, idx)
if err = rr.AddLib(libPath, libSrc); err != nil {
return fmt.Errorf("%w: %v",
return "", nil, fmt.Errorf("%w: %v",
ErrInvalidConstraintTemplate, err)
}
}

sources, err := rr.Rewrite()
if err != nil {
return fmt.Errorf("%w: %v",
return "", nil, fmt.Errorf("%w: %v",
ErrInvalidConstraintTemplate, err)
}

Expand All @@ -563,9 +562,18 @@ func (d *driver) AddTemplate(templ *templates.ConstraintTemplate) error {
return nil
})
if err != nil {
return fmt.Errorf("%w: %v",
return "", nil, fmt.Errorf("%w: %v",
ErrInvalidConstraintTemplate, err)
}
return namePrefix, mods, nil
}

// AddTemplate implements drivers.Driver.
func (d *driver) AddTemplate(templ *templates.ConstraintTemplate) error {
namePrefix, mods, err := MapModules(templ, d.externs)
if err != nil {
return err
}
if err = d.PutModules(namePrefix, mods); err != nil {
return fmt.Errorf("%w: %v", ErrCompile, err)
}
Expand Down Expand Up @@ -667,10 +675,14 @@ func validateTargets(templ *templates.ConstraintTemplate) error {
}
}

func (d *driver) AddExterns(allowedDataFields []string) {
func (d *driver) AddExterns(fields []string) {
d.externs = fields
}

func allowedDataFields(fields []string) []string {
var externs []string
for _, field := range allowedDataFields {
for _, field := range fields {
externs = append(externs, fmt.Sprintf("data.%s", field))
}
d.externs = externs
return externs
}

0 comments on commit 4a4a9a2

Please sign in to comment.