Skip to content

Commit

Permalink
Adding strict flag to opa eval (#5228)
Browse files Browse the repository at this point in the history
You can now enable strict mode with `opa eval` by passing `--strict` (`-S`).

Fixes #5182.

Signed-off-by: Peter Macdonald <macdonald.peter90@gmail.com>
  • Loading branch information
Parsifal-M authored Nov 15, 2022
1 parent 73bbda3 commit 77b6b3f
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 4 deletions.
12 changes: 10 additions & 2 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ type QueryCompiler interface {
// ComprehensionIndex returns an index data structure for the given comprehension
// term. If no index is found, returns nil.
ComprehensionIndex(term *Term) *ComprehensionIndex

// WithStrict enables strict mode for the query compiler.
WithStrict(strict bool) QueryCompiler
}

// QueryCompilerStage defines the interface for stages in the query compiler.
Expand Down Expand Up @@ -411,10 +414,10 @@ func (c *Compiler) ParsedModules() map[string]*Module {
return c.parsedModules
}

// QueryCompiler returns a new QueryCompiler object.
func (c *Compiler) QueryCompiler() QueryCompiler {
c.init()
return newQueryCompiler(c)
c0 := *c
return newQueryCompiler(&c0)
}

// Compile runs the compilation process on the input modules. The compiled
Expand Down Expand Up @@ -2422,6 +2425,11 @@ func newQueryCompiler(compiler *Compiler) QueryCompiler {
return qc
}

func (qc *queryCompiler) WithStrict(strict bool) QueryCompiler {
qc.compiler.WithStrict(strict)
return qc
}

func (qc *queryCompiler) WithEnablePrintStatements(yes bool) QueryCompiler {
qc.enablePrintStatements = yes
return qc
Expand Down
6 changes: 6 additions & 0 deletions cmd/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ type evalCommandParams struct {
timeout time.Duration
optimizationLevel int
entrypoints repeatedStringFlag
strict bool
}

func newEvalCommandParams() evalCommandParams {
Expand Down Expand Up @@ -319,6 +320,7 @@ access.
addSchemaFlags(evalCommand.Flags(), params.schema)
addTargetFlag(evalCommand.Flags(), params.target)
addCountFlag(evalCommand.Flags(), &params.count, "benchmark")
addStrictFlag(evalCommand.Flags(), &params.strict, false)

RootCommand.AddCommand(evalCommand)
}
Expand Down Expand Up @@ -626,6 +628,10 @@ func setupEval(args []string, params evalCommandParams) (*evalContext, error) {
regoArgs = append(regoArgs, rego.Capabilities(params.capabilities.C))
}

if params.strict {
regoArgs = append(regoArgs, rego.Strict(params.strict))
}

evalCtx := &evalContext{
params: params,
metrics: m,
Expand Down
232 changes: 232 additions & 0 deletions cmd/eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1110,3 +1110,235 @@ time.clock(input.y, time.clock(input.x))
})
}
}

func TestPolicyWithStrictFlag(t *testing.T) {
testsShouldError := []struct {
note string
policy string
query string
expectedCode string
expectedMessage string
}{
{
note: "strict mode should error on duplicate imports",
policy: `package x
import future.keywords.if
import future.keywords.if
foo = 2`,
query: "data.foo",
expectedCode: "rego_compile_error",
expectedMessage: "import must not shadow import future.keywords.if",
},
{
note: "strict mode should error on unused imports",
policy: `package x
import future.keywords.if
import data.foo
foo = 2`,
query: "data.foo",
expectedCode: "rego_compile_error",
expectedMessage: "import data.foo unused",
},
{
note: "strict mode should error when reserved vars data or input is used",
policy: `package x
import future.keywords.if
data if { x = 1}`,
query: "data.foo",
expectedCode: "rego_compile_error",
expectedMessage: "rules must not shadow data (use a different rule name)",
},
}

for _, tc := range testsShouldError {
t.Run(tc.note, func(t *testing.T) {

files := map[string]string{
"test.rego": tc.policy,
}

test.WithTempFS(files, func(path string) {
params := newEvalCommandParams()
params.strict = true

_ = params.dataPaths.Set(filepath.Join(path, "test.rego"))

var buf bytes.Buffer
_, err := eval([]string{tc.query}, params, &buf)
if err == nil {
t.Fatal("expected error, got nil")
}
var output presentation.Output
if err := util.NewJSONDecoder(&buf).Decode(&output); err != nil {
t.Fatal(err)
}

if code := output.Errors[0].Code; code != tc.expectedCode {
t.Errorf("expected code '%v', got '%v'", tc.expectedCode, code)
}
if msg := output.Errors[0].Message; msg != tc.expectedMessage {
t.Errorf("expected message '%v', got '%v'", tc.expectedMessage, msg)
}
})
})
}

testsShouldPass := []struct {
note string
policy string
query string
}{
{
note: "This should not error as it is valid",
policy: `package x
import future.keywords.if
foo = 2`,
query: "data.foo",
},
{
note: "Strict mode should not validate the query, only the policy, this should not error",
policy: `package x
import future.keywords.if
foo = 2`,
query: "x := data.x.foo",
},
}
for _, tc := range testsShouldPass {
t.Run(tc.note, func(t *testing.T) {

files := map[string]string{
"test.rego": tc.policy,
}

test.WithTempFS(files, func(path string) {
params := newEvalCommandParams()
params.strict = true

var buf bytes.Buffer
_, err := eval([]string{tc.query}, params, &buf)
if err != nil {
t.Errorf("Should not error, got error: '%v'", err)
}
})
})
}

}

func TestBundleWithStrictFlag(t *testing.T) {
testsShouldError := []struct {
note string
policy string
query string
expectedCode string
expectedMessage string
}{
{
note: "strict mode should error on duplicate imports in this bundle",
policy: `package x
import future.keywords.if
import future.keywords.if
foo = 2`,
query: "data.foo",
expectedCode: "rego_compile_error",
expectedMessage: "import must not shadow import future.keywords.if",
},
{
note: "strict mode should error on unused imports in this bundle",
policy: `package x
import future.keywords.if
import data.foo
foo = 2`,
query: "data.foo",
expectedCode: "rego_compile_error",
expectedMessage: "import data.foo unused",
},
{
note: "strict mode should error when reserved vars data or input is used in this bundle",
policy: `package x
import future.keywords.if
data if { x = 1}`,
query: "data.foo",
expectedCode: "rego_compile_error",
expectedMessage: "rules must not shadow data (use a different rule name)",
},
}

for _, tc := range testsShouldError {
t.Run(tc.note, func(t *testing.T) {

files := map[string]string{
"test.rego": tc.policy,
}

test.WithTempFS(files, func(path string) {
params := newEvalCommandParams()
if err := params.bundlePaths.Set(path); err != nil {
t.Fatal(err)
}
params.strict = true

var buf bytes.Buffer
_, err := eval([]string{tc.query}, params, &buf)
if err == nil {
t.Fatal("expected error, got nil")
}
var output presentation.Output
if err := util.NewJSONDecoder(&buf).Decode(&output); err != nil {
t.Fatal(err)
}

if code := output.Errors[0].Code; code != tc.expectedCode {
t.Errorf("expected code '%v', got '%v'", tc.expectedCode, code)
}
if msg := output.Errors[0].Message; msg != tc.expectedMessage {
t.Errorf("expected message '%v', got '%v'", tc.expectedMessage, msg)
}
})
})
}

testsShouldPass := []struct {
note string
policy string
query string
}{
{
note: "This bundle should not error as it is valid",
policy: `package x
import future.keywords.if
foo = 2`,
query: "data.foo",
},
{
note: "Strict mode should not validate the query, only the policy, this bundle should not error",
policy: `package x
import future.keywords.if
foo = 2`,
query: "x := data.x.foo",
},
}
for _, tc := range testsShouldPass {
t.Run(tc.note, func(t *testing.T) {

files := map[string]string{
"test.rego": tc.policy,
}

test.WithTempFS(files, func(path string) {
params := newEvalCommandParams()
if err := params.bundlePaths.Set(path); err != nil {
t.Fatal(err)
}
params.strict = true

var buf bytes.Buffer
_, err := eval([]string{tc.query}, params, &buf)
if err != nil {
t.Errorf("Should not error, got error: '%v'", err)
}
})
})
}

}
1 change: 1 addition & 0 deletions docs/content/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,7 @@ opa eval <query> [flags]
-s, --schema string set schema file path or directory path
--shallow-inlining disable inlining of rules that depend on unknowns
--stdin read query from stdin
-S, --strict enable compiler strict mode
-I, --stdin-input read input document from stdin
--strict-builtin-errors treat built-in function errors as fatal
-t, --target {rego,wasm} set the runtime to exercise (default rego)
Expand Down
14 changes: 12 additions & 2 deletions rego/rego.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ type Rego struct {
printHook print.Hook
enablePrintStatements bool
distributedTacingOpts tracing.Options
strict bool
}

// Function represents a built-in function that is callable in Rego.
Expand Down Expand Up @@ -1107,6 +1108,13 @@ func EnablePrintStatements(yes bool) func(r *Rego) {
}
}

// Strict enables or disables strict-mode in the compiler
func Strict(yes bool) func(r *Rego) {
return func(r *Rego) {
r.strict = yes
}
}

// New returns a new Rego object.
func New(options ...func(r *Rego)) *Rego {

Expand All @@ -1130,7 +1138,8 @@ func New(options ...func(r *Rego)) *Rego {
WithDebug(r.dump).
WithSchemas(r.schemaSet).
WithCapabilities(r.capabilities).
WithEnablePrintStatements(r.enablePrintStatements)
WithEnablePrintStatements(r.enablePrintStatements).
WithStrict(r.strict)
}

if r.store == nil {
Expand Down Expand Up @@ -1910,7 +1919,8 @@ func (r *Rego) compileQuery(query ast.Body, imports []*ast.Import, m metrics.Met
qc := r.compiler.QueryCompiler().
WithContext(qctx).
WithUnsafeBuiltins(r.unsafeBuiltins).
WithEnablePrintStatements(r.enablePrintStatements)
WithEnablePrintStatements(r.enablePrintStatements).
WithStrict(false)

for _, extra := range extras {
qc = qc.WithStageAfter(extra.after, extra.stage)
Expand Down

0 comments on commit 77b6b3f

Please sign in to comment.