Skip to content

Commit

Permalink
Add interface that makes it possible to add a custom stage to the com…
Browse files Browse the repository at this point in the history
…piler.

Signed-off-by: Ashutosh Narkar <anarkar4387@gmail.com>
  • Loading branch information
ashutosh-narkar authored and tsandall committed Mar 8, 2019
1 parent 6b20b49 commit 9444351
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
18 changes: 17 additions & 1 deletion ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,12 @@ type Compiler struct {
maxErrs int
sorted []string // list of sorted module names
pathExists func([]string) (bool, error)
after map[string][]CompilerStage
}

// CompilerStage defines the interface for stages in the compiler.
type CompilerStage func(*Compiler) *Error

// QueryContext contains contextual information for running an ad-hoc query.
//
// Ad-hoc queries can be run in the context of a package and imports may be
Expand Down Expand Up @@ -191,6 +195,7 @@ func NewCompiler() *Compiler {
return x.(Ref).Hash()
}),
maxErrs: CompileErrorLimitDefault,
after: map[string][]CompilerStage{},
}

c.ModuleTree = NewModuleTree(nil)
Expand Down Expand Up @@ -243,6 +248,13 @@ func (c *Compiler) WithPathConflictsCheck(fn func([]string) (bool, error)) *Comp
return c
}

// WithStageAfter registers a stage to run during compilation after
// the named stage.
func (c *Compiler) WithStageAfter(after string, stage CompilerStage) *Compiler {
c.after[after] = append(c.after[after], stage)
return c
}

// QueryCompiler returns a new QueryCompiler object.
func (c *Compiler) QueryCompiler() QueryCompiler {
return newQueryCompiler(c)
Expand Down Expand Up @@ -635,8 +647,12 @@ func (c *Compiler) compile() {
if s.f(); c.Failed() {
return
}
for _, s := range c.after[s.name] {
if err := s(c); err != nil {
c.err(err)
}
}
}

}

func (c *Compiler) err(err *Error) {
Expand Down
54 changes: 54 additions & 0 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,36 @@ func TestCompilerExample(t *testing.T) {
assertNotFailed(t, c)
}

func TestCompilerWithStageAfterBad(t *testing.T) {
c := NewCompiler().WithStageAfter("CheckRecursion", checkUnsafeFunctionCall)
modules := map[string]*Module{
"test": MustParseModule(`package test
r = opa.runtime()
`),
}

c.Compile(modules)

if !c.Failed() {
t.Errorf("Expected compilation error")
}
}

func TestCompilerWithStageAfterGood(t *testing.T) {
c := NewCompiler().WithStageAfter("CheckRecursion", checkUnsafeFunctionCall)
modules := map[string]*Module{
"test": MustParseModule(`package test
r = 1
`),
}

c.Compile(modules)

if c.Failed() {
t.Errorf("Unexpected compilation error %v", c.Errors)
}
}

func TestCompilerFunctions(t *testing.T) {
tests := []struct {
note string
Expand Down Expand Up @@ -2351,6 +2381,30 @@ func assertNotFailed(t *testing.T, c *Compiler) {
}
}

func checkUnsafeFunctionCall(c *Compiler) *Error {
var unsafeBuiltinsMap = map[string]bool{OPARuntime.Name: true}

x := c.Modules["test"]
unsafeOperators := []string{}
WalkNodes(x, func(x Node) bool {
if expr, ok := x.(*Expr); ok {
if expr.IsCall() {
operator := expr.Operator().String()
if _, ok := unsafeBuiltinsMap[operator]; ok {
unsafeOperators = append(unsafeOperators, operator)
}
}
return false
}
return false
})

if len(unsafeOperators) > 0 {
return NewError(CompileErr, &Location{}, "unsafe built-in function call in module: %v", strings.Join(unsafeOperators, ","))
}
return nil
}

func getCompilerWithParsedModules(mods map[string]string) *Compiler {

parsed := map[string]*Module{}
Expand Down

0 comments on commit 9444351

Please sign in to comment.