Skip to content
Permalink
598176de32
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
4732 lines (4131 sloc) 128 KB
// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package ast
import (
"fmt"
"io"
"sort"
"strconv"
"strings"
"github.com/open-policy-agent/opa/internal/debug"
"github.com/open-policy-agent/opa/internal/gojsonschema"
"github.com/open-policy-agent/opa/metrics"
"github.com/open-policy-agent/opa/types"
"github.com/open-policy-agent/opa/util"
)
// CompileErrorLimitDefault is the default number of errors a compiler will
// allow before exiting. NewCompiler uses it as the initial error limit; it can
// be overridden with SetErrorLimit.
const CompileErrorLimitDefault = 10
// errLimitReached is a location-less CompileErr carrying the message
// "error limit reached".
// NOTE(review): presumably reported once the compiler's error limit is
// exceeded; the code that uses it is outside this view.
var errLimitReached = NewError(CompileErr, nil, "error limit reached")
// Compiler contains the state of a compilation process. Instances are created
// with NewCompiler and configured through the chainable With*/Set* methods
// before Compile is called.
type Compiler struct {
// Errors contains errors that occurred during the compilation process.
// If there are one or more errors, the compilation process is considered
// "failed".
Errors Errors
// Modules contains the compiled modules. The compiled modules are the
// output of the compilation process. If the compilation process failed,
// there is no guarantee about the state of the modules.
Modules map[string]*Module
// ModuleTree organizes the modules into a tree where each node is keyed by
// an element in the module's package path. E.g., given modules containing
// the following package directives: "a", "a.b", "a.c", and "a.b", the
// resulting module tree would be:
//
//  root
//    |
//    +--- data (no modules)
//           |
//           +--- a (1 module)
//                |
//                +--- b (2 modules)
//                |
//                +--- c (1 module)
//
ModuleTree *ModuleTreeNode
// RuleTree organizes rules into a tree where each node is keyed by an
// element in the rule's path. The rule path is the concatenation of the
// containing package and the stringified rule name. E.g., given the
// following module:
//
//  package ex
//  p[1] { true }
//  p[2] { true }
//  q = true
//
// the resulting rule tree would be:
//
//  root
//    |
//    +--- data (no rules)
//           |
//           +--- ex (no rules)
//                |
//                +--- p (2 rules)
//                |
//                +--- q (1 rule)
RuleTree *TreeNode
// Graph contains dependencies between rules. An edge (u,v) is added to the
// graph if rule 'u' refers to the virtual document defined by 'v'.
Graph *Graph
// TypeEnv holds type information for values inferred by the compiler.
TypeEnv *TypeEnv
// RewrittenVars is a mapping of variables that have been rewritten
// with the key being the generated name and value being the original.
RewrittenVars map[Var]Var
// localvargen generates names for compiler-introduced local variables.
// NOTE(review): presumably set by the InitLocalVarGen stage; that code is
// outside this view.
localvargen *localVarGenerator
// moduleLoader is the optional lazy module loader set by WithModuleLoader.
moduleLoader ModuleLoader
// ruleIndices maps a rule set's path (a Ref) to the index built for it by
// the BuildRuleIndices stage.
ruleIndices *util.HashMap
// stages is the ordered list of built-in compilation stages, populated by
// NewCompiler.
stages []struct {
name string
metricName string
f func()
}
// maxErrs is the error limit set by SetErrorLimit; zero or a negative
// number indicates no limit.
maxErrs int
sorted []string // list of sorted module names
// pathExists is the optional callback set by WithPathConflictsCheck.
pathExists func([]string) (bool, error)
// after holds caller-registered stages, keyed by the name of the stage
// they should run after (see WithStageAfter).
after map[string][]CompilerStageDefinition
// metrics is the optional metrics sink set by WithMetrics; counters are
// skipped when it is nil.
metrics metrics.Metrics
capabilities *Capabilities // user-supplied capabilities
builtins map[string]*Builtin // universe of built-in functions
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
enablePrintStatements bool // indicates if print statements are enabled; when false (default), calls to print() are erased at compile-time
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
initialized bool // indicates if init() has been called
debug debug.Debug // emits debug information produced during compilation
schemaSet *SchemaSet // user-supplied schemas for input and data documents
inputType types.Type // global input type retrieved from schema set
annotationSet *AnnotationSet // hierarchical set of annotations
strict bool // enforce strict compilation checks
}
// CompilerStage defines the interface for stages in the compiler. A stage is a
// function that receives the Compiler and returns an *Error.
// NOTE(review): presumably a nil return means the stage succeeded; the code
// that invokes stages is outside this view.
type CompilerStage func(*Compiler) *Error
// CompilerStageDefinition defines a compiler stage that callers can register
// with Compiler.WithStageAfter.
type CompilerStageDefinition struct {
// Name identifies the stage.
Name string
// MetricName is the metric key associated with the stage.
// NOTE(review): presumably used to time the stage when metrics are set.
MetricName string
// Stage is the function that implements the stage.
Stage CompilerStage
}
// RulesOptions defines the options for retrieving rules by Ref from the
// compiler (see GetRulesDynamicWithOpts). The zero value excludes hidden
// modules.
type RulesOptions struct {
// IncludeHiddenModules determines if the result contains hidden modules,
// currently only the "system" namespace, i.e. "data.system.*".
IncludeHiddenModules bool
}
// QueryContext contains contextual information for running an ad-hoc query.
//
// Ad-hoc queries can be run in the context of a package and imports may be
// included to provide concise access to data.
type QueryContext struct {
// Package is the package the query runs in; may be nil.
Package *Package
// Imports are the imports made available to the query.
Imports []*Import
}
// NewQueryContext allocates an empty QueryContext (no package, no imports).
func NewQueryContext() *QueryContext {
	return new(QueryContext)
}
// WithPackage stores pkg as the context's package and returns the context so
// calls can be chained. A nil receiver is tolerated: a fresh QueryContext is
// allocated and returned instead.
func (qc *QueryContext) WithPackage(pkg *Package) *QueryContext {
	target := qc
	if target == nil {
		target = NewQueryContext()
	}
	target.Package = pkg
	return target
}
// WithImports stores imports as the context's import list and returns the
// context so calls can be chained. A nil receiver is tolerated: a fresh
// QueryContext is allocated and returned instead.
func (qc *QueryContext) WithImports(imports []*Import) *QueryContext {
	target := qc
	if target == nil {
		target = NewQueryContext()
	}
	target.Imports = imports
	return target
}
// Copy returns a deep copy of qc: the package (when set) and every import are
// copied via their own Copy methods. A nil receiver yields nil. The returned
// Imports slice is always non-nil, even when qc has no imports.
func (qc *QueryContext) Copy() *QueryContext {
	if qc == nil {
		return nil
	}

	dup := *qc
	if qc.Package != nil {
		dup.Package = qc.Package.Copy()
	}

	imports := make([]*Import, 0, len(qc.Imports))
	for _, imp := range qc.Imports {
		imports = append(imports, imp.Copy())
	}
	dup.Imports = imports

	return &dup
}
// QueryCompiler defines the interface for compiling ad-hoc queries. Instances
// are obtained from Compiler.QueryCompiler; the With* methods return a
// QueryCompiler so that configuration calls can be chained.
type QueryCompiler interface {
// Compile should be called to compile ad-hoc queries. The return value is
// the compiled version of the query.
Compile(q Body) (Body, error)
// TypeEnv returns the type environment built after running type checking
// on the query.
TypeEnv() *TypeEnv
// WithContext sets the QueryContext on the QueryCompiler. Subsequent calls
// to Compile will take the QueryContext into account.
WithContext(qctx *QueryContext) QueryCompiler
// WithEnablePrintStatements enables print statements in queries compiled
// with the QueryCompiler.
WithEnablePrintStatements(yes bool) QueryCompiler
// WithUnsafeBuiltins sets the built-in functions to treat as unsafe and not
// allow inside of queries. By default the query compiler inherits the
// compiler's unsafe built-in functions. This function allows callers to
// override that set. If an empty (non-nil) map is provided, all built-ins
// are allowed.
WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler
// WithStageAfter registers a stage to run during query compilation after
// the named stage.
WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler
// RewrittenVars maps generated vars in the compiled query to vars from the
// parsed query. For example, given the query "input := 1" the rewritten
// query would be "__local0__ = 1". The mapping would then be {__local0__: input}.
RewrittenVars() map[Var]Var
// ComprehensionIndex returns an index data structure for the given comprehension
// term. If no index is found, returns nil.
ComprehensionIndex(term *Term) *ComprehensionIndex
}
// QueryCompilerStage defines the interface for stages in the query compiler.
// A stage receives the QueryCompiler and the query body and returns a body
// together with an error.
type QueryCompilerStage func(QueryCompiler, Body) (Body, error)
// QueryCompilerStageDefinition defines a QueryCompiler stage that callers can
// register with QueryCompiler.WithStageAfter.
type QueryCompilerStageDefinition struct {
// Name identifies the stage.
Name string
// MetricName is the metric key associated with the stage.
// NOTE(review): presumably used to time the stage when metrics are set.
MetricName string
// Stage is the function that implements the stage.
Stage QueryCompilerStage
}
// NewCompiler returns a new empty compiler.
//
// The returned compiler has empty module and rewritten-variable maps, a rule
// index map keyed by Ref, the default error limit (CompileErrorLimitDefault),
// a debug sink that discards output, and empty module and rule trees. It also
// carries the ordered table of built-in compilation stages.
func NewCompiler() *Compiler {
c := &Compiler{
Modules: map[string]*Module{},
RewrittenVars: map[Var]Var{},
// Rule indices are keyed by Ref, so the hash map is given Ref equality
// and Ref hashing.
ruleIndices: util.NewHashMap(func(a, b util.T) bool {
r1, r2 := a.(Ref), b.(Ref)
return r1.Equal(r2)
}, func(x util.T) int {
return x.(Ref).Hash()
}),
maxErrs: CompileErrorLimitDefault,
after: map[string][]CompilerStageDefinition{},
unsafeBuiltinsMap: map[string]struct{}{},
deprecatedBuiltinsMap: map[string]struct{}{},
comprehensionIndices: map[*Term]*ComprehensionIndex{},
debug: debug.Discard(),
}
// Start from empty trees; the SetModuleTree and SetRuleTree stages are
// listed below.
c.ModuleTree = NewModuleTree(nil)
c.RuleTree = NewRuleTree(c.ModuleTree)
// Each entry is {stage name, metric name, stage function}. The order of
// the entries is significant; see the inline notes.
c.stages = []struct {
name string
metricName string
f func()
}{
{"CheckDuplicateImports", "compile_stage_check_duplicate_imports", c.checkDuplicateImports},
{"CheckKeywordOverrides", "compile_stage_check_keyword_overrides", c.checkKeywordOverrides},
// Reference resolution should run first as it may be used to lazily
// load additional modules. If any stages run before resolution, they
// need to be re-run after resolution.
{"ResolveRefs", "compile_stage_resolve_refs", c.resolveAllRefs},
{"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree},
{"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree},
// The local variable generator must be initialized after references are
// resolved and the dynamic module loader has run but before subsequent
// stages that need to generate variables.
{"InitLocalVarGen", "compile_stage_init_local_var_gen", c.initLocalVarGen},
{"RewriteLocalVars", "compile_stage_rewrite_local_vars", c.rewriteLocalVars},
{"CheckVoidCalls", "compile_stage_check_void_calls", c.checkVoidCalls},
{"RewritePrintCalls", "compile_stage_rewrite_print_calls", c.rewritePrintCalls},
{"RewriteExprTerms", "compile_stage_rewrite_expr_terms", c.rewriteExprTerms},
{"SetGraph", "compile_stage_set_graph", c.setGraph},
{"RewriteComprehensionTerms", "compile_stage_rewrite_comprehension_terms", c.rewriteComprehensionTerms},
{"RewriteRefsInHead", "compile_stage_rewrite_refs_in_head", c.rewriteRefsInHead},
{"RewriteWithValues", "compile_stage_rewrite_with_values", c.rewriteWithModifiers},
{"CheckRuleConflicts", "compile_stage_check_rule_conflicts", c.checkRuleConflicts},
{"CheckUndefinedFuncs", "compile_stage_check_undefined_funcs", c.checkUndefinedFuncs},
{"CheckSafetyRuleHeads", "compile_stage_check_safety_rule_heads", c.checkSafetyRuleHeads},
{"CheckSafetyRuleBodies", "compile_stage_check_safety_rule_bodies", c.checkSafetyRuleBodies},
{"RewriteEquals", "compile_stage_rewrite_equals", c.rewriteEquals},
{"RewriteDynamicTerms", "compile_stage_rewrite_dynamic_terms", c.rewriteDynamicTerms},
{"CheckRecursion", "compile_stage_check_recursion", c.checkRecursion},
{"SetAnnotationSet", "compile_stage_set_annotationset", c.setAnnotationSet},
{"CheckTypes", "compile_stage_check_types", c.checkTypes}, // must be run after CheckRecursion
// NOTE(review): the next two metric names say "compile_state" rather
// than "compile_stage". This looks like a typo, but renaming would
// change the emitted metric keys; confirm before changing.
{"CheckUnsafeBuiltins", "compile_state_check_unsafe_builtins", c.checkUnsafeBuiltins},
{"CheckDeprecatedBuiltins", "compile_state_check_deprecated_builtins", c.checkDeprecatedBuiltins},
{"BuildRuleIndices", "compile_stage_rebuild_indices", c.buildRuleIndices},
{"BuildComprehensionIndices", "compile_stage_rebuild_comprehension_indices", c.buildComprehensionIndices},
}
return c
}
// SetErrorLimit configures how many errors the compiler tolerates before it
// gives up. A limit of zero or less means "no limit". The compiler is returned
// so calls can be chained.
func (c *Compiler) SetErrorLimit(limit int) *Compiler {
	c.maxErrs = limit
	return c
}
// WithEnablePrintStatements controls whether print() calls inside compiled
// modules are kept. When print statements are not enabled, calls to print()
// are erased at compile-time. The compiler is returned so calls can be
// chained.
func (c *Compiler) WithEnablePrintStatements(yes bool) *Compiler {
	c.enablePrintStatements = yes
	return c
}
// WithPathConflictsCheck turns on base-virtual document conflict detection.
// fn reports whether a given path already exists; the compiler uses it to
// check that rules do not overlap with existing paths. The compiler is
// returned so calls can be chained.
func (c *Compiler) WithPathConflictsCheck(fn func([]string) (bool, error)) *Compiler {
	c.pathExists = fn
	return c
}
// WithStageAfter queues stage to run during compilation after the stage named
// by after. Several stages may be registered for the same name; they are kept
// in registration order. The name is not validated here. The compiler is
// returned so calls can be chained.
func (c *Compiler) WithStageAfter(after string, stage CompilerStageDefinition) *Compiler {
	registered := c.after[after]
	c.after[after] = append(registered, stage)
	return c
}
// WithMetrics installs the metrics.Metrics instance used to profile this
// Compiler. The compiler is returned so calls can be chained.
func (c *Compiler) WithMetrics(metrics metrics.Metrics) *Compiler {
	c.metrics = metrics
	return c
}
// WithCapabilities sets the capabilities to enable during compilation.
//
// Capabilities let the caller specify the set of built-in functions available
// to the policy; in the future they may also restrict access to other language
// features. They allow callers to check whether policies are compatible with a
// particular version of OPA. If policies are compiled for a specific version
// of OPA, there is no guarantee that _this_ version of OPA can evaluate them
// successfully.
//
// The compiler is returned so calls can be chained.
func (c *Compiler) WithCapabilities(capabilities *Capabilities) *Compiler {
	c.capabilities = capabilities
	return c
}
// Capabilities reports the capabilities that are enabled for compilation, as
// stored on the compiler.
func (c *Compiler) Capabilities() *Capabilities {
	return c.capabilities
}
// WithDebug directs debug messages to sink. A nil sink is ignored and leaves
// the current debug destination untouched. The compiler is returned so calls
// can be chained.
func (c *Compiler) WithDebug(sink io.Writer) *Compiler {
	if sink == nil {
		return c
	}
	c.debug = debug.New(sink)
	return c
}
// WithBuiltins replaces the compiler's custom built-in functions with a copy
// of builtins, so later changes to the caller's map are not observed. The
// compiler is returned so calls can be chained.
//
// Deprecated: use WithCapabilities instead.
func (c *Compiler) WithBuiltins(builtins map[string]*Builtin) *Compiler {
	custom := make(map[string]*Builtin)
	for name, bi := range builtins {
		custom[name] = bi
	}
	c.customBuiltins = custom
	return c
}
// WithUnsafeBuiltins adds every name in unsafeBuiltins to the compiler's set
// of unsafe built-ins. Names added by earlier calls are kept. The compiler is
// returned so calls can be chained.
//
// Deprecated: use WithCapabilities instead.
func (c *Compiler) WithUnsafeBuiltins(unsafeBuiltins map[string]struct{}) *Compiler {
	for builtinName := range unsafeBuiltins {
		c.unsafeBuiltinsMap[builtinName] = struct{}{}
	}
	return c
}
// WithStrict switches strict mode on or off in the compiler. The compiler is
// returned so calls can be chained.
func (c *Compiler) WithStrict(strict bool) *Compiler {
	c.strict = strict
	return c
}
// QueryCompiler runs the compiler's init step and then returns a new
// QueryCompiler bound to this compiler.
func (c *Compiler) QueryCompiler() QueryCompiler {
	c.init()
	qc := newQueryCompiler(c)
	return qc
}
// Compile runs the compilation process on the input modules. The compiled
// version of the modules and associated data structures are stored on the
// compiler. If the compilation process fails for any reason, the compiler will
// contain a slice of errors.
//
// Each input module is copied before compilation, so the caller's modules are
// not modified. The module map and the sorted list of module names are both
// rebuilt from scratch on every call, so calling Compile again on the same
// compiler does not leave stale or duplicate names from an earlier call in the
// sorted list.
func (c *Compiler) Compile(modules map[string]*Module) {
	c.init()

	c.Modules = make(map[string]*Module, len(modules))
	// Reset rather than append to any previous contents: stages iterate over
	// c.sorted and index c.Modules with each name.
	c.sorted = make([]string, 0, len(modules))

	for name, module := range modules {
		c.Modules[name] = module.Copy()
		c.sorted = append(c.sorted, name)
	}

	// Map iteration order is random; sort for a deterministic stage order.
	sort.Strings(c.sorted)

	c.compile()
}
// WithSchemas installs schemas as the compiler's schema set. The compiler is
// returned so calls can be chained.
func (c *Compiler) WithSchemas(schemas *SchemaSet) *Compiler {
	c.schemaSet = schemas
	return c
}
// Failed reports whether at least one compilation error has been recorded on
// the compiler.
func (c *Compiler) Failed() bool {
	return len(c.Errors) != 0
}
// ComprehensionIndex looks up the index built for the comprehension term. The
// index specifies how to index comprehension results so that callers do not
// have to recompute the comprehension more than once. The lookup is by term
// pointer; nil is returned when no index exists for term.
func (c *Compiler) ComprehensionIndex(term *Term) *ComprehensionIndex {
	index := c.comprehensionIndices[term]
	return index
}
// GetArity returns the number of args taken by the function that ref refers
// to. Built-in functions are checked first, using the built-in's declaration;
// otherwise the arity is taken from the first rule found by an exact ruleset
// lookup. If ref matches neither a built-in nor any rule, -1 is returned.
func (c *Compiler) GetArity(ref Ref) int {
	bi := c.builtins[ref.String()]
	if bi != nil {
		return len(bi.Decl.Args())
	}

	if rules := c.GetRulesExact(ref); len(rules) > 0 {
		return len(rules[0].Head.Args)
	}

	return -1
}
// GetRulesExact returns the rules stored at exactly the rule-tree node that
// ref leads to. If any element of ref has no matching child, nil is returned.
//
// E.g., given the following module:
//
//	package a.b.c
//
//	p[k] = v { ... }    # rule1
//	p[k1] = v1 { ... }  # rule2
//
// The following calls yield the rules on the right.
//
//	GetRulesExact("data.a.b.c.p")   => [rule1, rule2]
//	GetRulesExact("data.a.b.c.p.x") => nil
//	GetRulesExact("data.a.b.c")     => nil
func (c *Compiler) GetRulesExact(ref Ref) (rules []*Rule) {
	current := c.RuleTree
	for i := range ref {
		current = current.Child(ref[i].Value)
		if current == nil {
			return nil
		}
	}
	return extractRules(current.Values)
}
// GetRulesForVirtualDocument returns the rules that produce the virtual
// document ref refers to. The rule tree is walked along ref and the walk stops
// at the first node that holds rules, so a ref that extends past a rule set
// still resolves to that rule set.
//
// E.g., given the following module:
//
//	package a.b.c
//
//	p[k] = v { ... }    # rule1
//	p[k1] = v1 { ... }  # rule2
//
// The following calls yield the rules on the right.
//
//	GetRulesForVirtualDocument("data.a.b.c.p")   => [rule1, rule2]
//	GetRulesForVirtualDocument("data.a.b.c.p.x") => [rule1, rule2]
//	GetRulesForVirtualDocument("data.a.b.c")     => nil
func (c *Compiler) GetRulesForVirtualDocument(ref Ref) (rules []*Rule) {
	current := c.RuleTree
	for i := range ref {
		current = current.Child(ref[i].Value)
		if current == nil {
			return nil
		}
		if len(current.Values) != 0 {
			// First node carrying rules: the rest of ref points into the
			// document those rules produce.
			return extractRules(current.Values)
		}
	}
	return extractRules(current.Values)
}
// GetRulesWithPrefix returns every rule stored at or below the rule-tree node
// that ref leads to. Hidden child nodes (and everything below them) are
// skipped. If any element of ref has no matching child, nil is returned.
//
// E.g., given the following module:
//
//	package a.b.c
//
//	p[x] = y { ... }  # rule1
//	p[k] = v { ... }  # rule2
//	q { ... }         # rule3
//
// The following calls yield the rules on the right.
//
//	GetRulesWithPrefix("data.a.b.c.p")   => [rule1, rule2]
//	GetRulesWithPrefix("data.a.b.c.p.a") => nil
//	GetRulesWithPrefix("data.a.b.c")     => [rule1, rule2, rule3]
func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) {
	start := c.RuleTree
	for i := range ref {
		start = start.Child(ref[i].Value)
		if start == nil {
			return nil
		}
	}

	// collect appends the rules at n, then recurses into the visible children.
	var collect func(n *TreeNode)
	collect = func(n *TreeNode) {
		rules = append(rules, extractRules(n.Values)...)
		for _, child := range n.Children {
			if !child.Hide {
				collect(child)
			}
		}
	}
	collect(start)

	return rules
}
// extractRules converts a slice of tree-node values into a slice of rules.
// Every element must be a *Rule; any other dynamic type makes the type
// assertion panic. An empty input yields a nil slice.
func extractRules(s []util.T) (rules []*Rule) {
	for i := range s {
		rules = append(rules, s[i].(*Rule))
	}
	return rules
}
// GetRules returns the rules referred to by ref: the union (without
// duplicates) of GetRulesForVirtualDocument(ref) and GetRulesWithPrefix(ref).
// The result is gathered from a map, so its order is unspecified. When nothing
// matches, nil is returned.
//
// E.g., given the following module:
//
//	package a.b.c
//
//	p[x] = y { q[x] = y; ... }  # rule1
//	q[x] = y { ... }            # rule2
//
// The following calls yield the rules on the right.
//
//	GetRules("data.a.b.c.p")   => [rule1]
//	GetRules("data.a.b.c.p.x") => [rule1]
//	GetRules("data.a.b.c.q")   => [rule2]
//	GetRules("data.a.b.c")     => [rule1, rule2]
//	GetRules("data.a.b.d")     => nil
func (c *Compiler) GetRules(ref Ref) (rules []*Rule) {
	seen := map[*Rule]struct{}{}

	sources := [][]*Rule{
		c.GetRulesForVirtualDocument(ref),
		c.GetRulesWithPrefix(ref),
	}
	for _, found := range sources {
		for _, r := range found {
			seen[r] = struct{}{}
		}
	}

	for r := range seen {
		rules = append(rules, r)
	}
	return rules
}
// GetRulesDynamic returns the rules that ref could refer to. It is equivalent
// to GetRulesDynamicWithOpts with zero-value options, i.e. hidden modules are
// excluded.
//
// Deprecated: use GetRulesDynamicWithOpts
func (c *Compiler) GetRulesDynamic(ref Ref) []*Rule {
	var defaults RulesOptions
	return c.GetRulesDynamicWithOpts(ref, defaults)
}
// GetRulesDynamicWithOpts returns the rules that ref could refer to.
//
// Statically known parts of the ref are used to narrow down the candidate
// rules; for dynamic parts every child of the rule tree is tried, so in the
// most general case the result is an over-approximation. The result is
// gathered from a set, so it has no duplicates and its order is unspecified.
//
// E.g., given the following modules:
//
//	package a.b.c
//
//	r1 = 1  # rule1
//
// and:
//
//	package a.d.c
//
//	r2 = 2  # rule2
//
// The following calls yield the rules on the right.
//
//	GetRulesDynamicWithOpts("data.a[x].c[y]", opts) => [rule1, rule2]
//	GetRulesDynamicWithOpts("data.a[x].c.r2", opts) => [rule2]
//	GetRulesDynamicWithOpts("data.a.b[x][y]", opts) => [rule1]
//
// The RulesOptions parameter controls whether hidden modules are included.
// With
//
//	package system.main
//
//	r3 = 3  # rule3
//
// we'd get this result:
//
//	GetRulesDynamicWithOpts("data[x]", RulesOptions{IncludeHiddenModules: true}) => [rule1, rule2, rule3]
//
// Without the option, rule3 would be excluded.
func (c *Compiler) GetRulesDynamicWithOpts(ref Ref, opts RulesOptions) []*Rule {
	found := map[*Rule]struct{}{}

	var visit func(n *TreeNode, pos int)
	visit = func(n *TreeNode, pos int) {
		switch {
		case pos >= len(ref):
			// End of the reference: collect everything under this "prefix".
			// Returning true from the callback stops the descent, which is
			// how hidden subtrees are skipped.
			n.DepthFirst(func(descendant *TreeNode) bool {
				insertRules(found, descendant.Values)
				return !opts.IncludeHiddenModules && descendant.Hide
			})

		case pos == 0 || IsConstant(ref[pos].Value):
			// The head of the ref is always grounded. If another part of the
			// ref is grounded too, the exact child can be looked up; a miss
			// means nothing can match.
			child := n.Child(ref[pos].Value)
			if child == nil {
				return
			}
			if len(child.Values) > 0 {
				// Rules at this position are what the ref refers to; take
				// them and stop here.
				insertRules(found, child.Values)
				return
			}
			// Otherwise keep walking from the child.
			visit(child, pos+1)

		default:
			// Dynamic term: it could refer to any child, so try them all.
			for _, child := range n.Children {
				if child.Hide && !opts.IncludeHiddenModules {
					continue
				}
				insertRules(found, child.Values)
				visit(child, pos+1)
			}
		}
	}
	visit(c.RuleTree, 0)

	result := make([]*Rule, 0, len(found))
	for r := range found {
		result = append(result, r)
	}
	return result
}
// insertRules adds every value in rules to set. Each value must be a *Rule;
// any other dynamic type makes the type assertion panic.
func insertRules(set map[*Rule]struct{}, rules []util.T) {
	for i := range rules {
		r := rules[i].(*Rule)
		set[r] = struct{}{}
	}
}
// RuleIndex returns the RuleIndex that was built for the rule set at path, or
// nil when no index is stored for that path. The path must refer to the rule
// set exactly: given a rule set at data.a.b.c.p, the refs data.a.b.c.p.x and
// data.a.b.c would not return the index built for it.
func (c *Compiler) RuleIndex(path Ref) RuleIndex {
	if stored, ok := c.ruleIndices.Get(path); ok {
		return stored.(RuleIndex)
	}
	return nil
}
// PassesTypeCheck reports whether body type-checks without errors. A fresh
// type checker is configured with the compiler's schema set and input type,
// and body is checked against the compiler's type environment.
func (c *Compiler) PassesTypeCheck(body Body) bool {
	checker := newTypeChecker().WithSchemaSet(c.schemaSet).WithInputType(c.inputType)
	_, errs := checker.CheckBody(c.TypeEnv, body)
	return len(errs) == 0
}
// ModuleLoader defines the interface that callers can implement to enable lazy
// loading of modules during compilation. The loader receives the modules
// resolved so far and returns newly parsed modules to add, or an error. See
// WithModuleLoader for how the compiler drives it.
type ModuleLoader func(resolved map[string]*Module) (parsed map[string]*Module, err error)
// WithModuleLoader installs f as the compiler's ModuleLoader and returns the
// compiler so calls can be chained.
//
// The compiler will invoke the ModuleLoader after resolving all references in
// the current set of input modules. The ModuleLoader can return a new
// collection of parsed modules to include in the compilation process. This
// repeats until the ModuleLoader returns an empty collection or an error. If
// an error is returned, compilation stops immediately.
func (c *Compiler) WithModuleLoader(f ModuleLoader) *Compiler {
	c.moduleLoader = f
	return c
}
// counterAdd adds n to the metrics counter called name. It does nothing when
// no metrics instance has been set on the compiler.
func (c *Compiler) counterAdd(name string, n uint64) {
	if c.metrics != nil {
		c.metrics.Counter(name).Add(n)
	}
}
// buildRuleIndices visits every node of the rule tree and, for each node that
// holds rules, tries to build an index over those rules. When the build
// reports success the index is stored under the path of the node's first rule.
func (c *Compiler) buildRuleIndices() {
	// virtualRef reports whether the ground prefix of ref is virtual
	// according to the rule tree.
	virtualRef := func(ref Ref) bool {
		return isVirtual(c.RuleTree, ref.GroundPrefix())
	}

	c.RuleTree.DepthFirst(func(node *TreeNode) bool {
		if len(node.Values) > 0 {
			index := newBaseDocEqIndex(virtualRef)
			rules := extractRules(node.Values)
			if index.Build(rules) {
				c.ruleIndices.Put(rules[0].Path(), index)
			}
		}
		// Always continue into the children.
		return false
	})
}
// buildComprehensionIndices walks the rules of every module, in sorted module
// order, and builds comprehension indices for each rule body. The candidate
// variables passed to the builder are the rule's head arguments plus the
// reserved variables. The number returned by the builder is added to the
// comprehension-index-build counter.
func (c *Compiler) buildComprehensionIndices() {
	for _, moduleName := range c.sorted {
		module := c.Modules[moduleName]
		WalkRules(module, func(rule *Rule) bool {
			vars := rule.Head.Args.Vars()
			vars.Update(ReservedVars)
			built := buildComprehensionIndices(c.debug, c.GetArity, vars, c.RewrittenVars, rule.Body, c.comprehensionIndices)
			c.counterAdd(compileStageComprehensionIndexBuild, built)
			return false
		})
	}
}
// checkRecursion ensures that there are no recursive definitions, i.e., that
// the Graph contains no cycles. Every rule in the rule tree, and every rule in
// its else-chain, is checked for a path back to itself.
func (c *Compiler) checkRecursion() {
	// Rules are compared by pointer identity.
	sameRule := func(a, b util.T) bool {
		return a.(*Rule) == b.(*Rule)
	}

	c.RuleTree.DepthFirst(func(node *TreeNode) bool {
		for _, value := range node.Values {
			for r := value.(*Rule); r != nil; r = r.Else {
				c.checkSelfPath(r.Loc(), sameRule, r, r)
			}
		}
		return false
	})
}
// checkSelfPath searches the dependency graph depth-first for a path from a to
// b, using eq to compare nodes. If a non-empty path is found, a recursion
// error located at loc is recorded; the message lists the nodes on the path
// joined by " -> ".
func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b util.T) {
	path := util.DFSPath(NewGraphTraversal(c.Graph), eq, a, b)
	if len(path) == 0 {
		return
	}

	names := make([]string, 0, len(path))
	for _, step := range path {
		names = append(names, astNodeToString(step))
	}

	c.err(NewError(RecursionErr, loc, "rule %v is recursive: %v", astNodeToString(a), strings.Join(names, " -> ")))
}
// astNodeToString returns the head name of x, which must be a *Rule. Any other
// dynamic type panics with "not reached".
func astNodeToString(x interface{}) string {
	if rule, ok := x.(*Rule); ok {
		return string(rule.Head.Name)
	}
	panic("not reached")
}
// checkRuleConflicts ensures that rule definitions are not in conflict. It
// performs three checks and records an error for each violation:
//
//  1. For every rule-tree node that holds rules: a rule declared with
//     assignment must be the only rule at that node; all rules at the node
//     must share the same document kind and arity; and at most one of them
//     may be a default rule. Only the first matching problem per node is
//     reported.
//  2. If a path-conflict callback was set (WithPathConflictsCheck), the
//     errors returned by CheckPathConflicts are recorded.
//  3. For every module, a rule whose name equals the key of a child package
//     node conflicts with each module in that child package.
func (c *Compiler) checkRuleConflicts() {
	c.RuleTree.DepthFirst(func(node *TreeNode) bool {
		if len(node.Values) == 0 {
			return false
		}

		kinds := map[DocKind]struct{}{}
		defaultRules := 0
		arities := map[int]struct{}{}
		declared := false

		for _, rule := range node.Values {
			r := rule.(*Rule)
			kinds[r.Head.DocKind()] = struct{}{}
			arities[len(r.Head.Args)] = struct{}{}
			if r.Head.Assign {
				declared = true
			}
			if r.Default {
				defaultRules++
			}
		}

		name := Var(node.Key.(String))

		if declared && len(node.Values) > 1 {
			c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "rule named %v redeclared at %v", name, node.Values[1].(*Rule).Loc()))
		} else if len(kinds) > 1 || len(arities) > 1 {
			c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules named %v found", name))
		} else if defaultRules > 1 {
			c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules named %s found", name))
		}

		return false
	})

	if c.pathExists != nil {
		for _, err := range CheckPathConflicts(c, c.pathExists) {
			c.err(err)
		}
	}

	c.ModuleTree.DepthFirst(func(node *ModuleTreeNode) bool {
		for _, mod := range node.Modules {
			for _, rule := range mod.Rules {
				if childNode, ok := node.Children[String(rule.Head.Name)]; ok {
					for _, childMod := range childNode.Modules {
						// Pass the format and its arguments straight to
						// NewError. Pre-formatting the text and handing it
						// over as the format string would re-interpret any
						// '%' in the package path or location as a verb.
						c.err(NewError(TypeErr, mod.Package.Loc(), "%v conflicts with rule defined at %v", childMod.Package, rule.Loc()))
					}
				}
			}
		}
		return false
	})
}
// checkUndefinedFuncs runs the undefined-function check over every module, in
// sorted module order, and records each resulting error on the compiler.
func (c *Compiler) checkUndefinedFuncs() {
	for _, moduleName := range c.sorted {
		module := c.Modules[moduleName]
		errs := checkUndefinedFuncs(c.TypeEnv, module, c.GetArity, c.RewrittenVars)
		for _, e := range errs {
			c.err(e)
		}
	}
}
// checkUndefinedFuncs walks every expression under x and checks each call
// expression against the arity reported by the arity callback:
//
//   - A negative arity means the function is unknown: an "undefined function"
//     error is produced.
//   - For generated expressions (an output var was added), the operand count
//     must be arity+1 unless the expression is an equality.
//   - For all other expressions, the operand count must be arity or arity+1
//     (with or without an output var).
//
// Before an error is built, the operator ref is passed through rwVars so the
// message shows the original variable names. The walk does not descend into an
// expression that produced an error.
func checkUndefinedFuncs(env *TypeEnv, x interface{}, arity func(Ref) int, rwVars map[Var]Var) Errors {
	var errs Errors

	WalkExprs(x, func(expr *Expr) bool {
		if !expr.IsCall() {
			return false
		}

		ref := expr.Operator()
		want := arity(ref)

		if want < 0 {
			ref = rewriteVarsInRef(rwVars)(ref)
			errs = append(errs, NewError(TypeErr, expr.Loc(), "undefined function %v", ref))
			return true
		}

		got := len(expr.Operands())

		if expr.Generated { // an output var was added
			if expr.IsEquality() || got == want+1 {
				return false
			}
			ref = rewriteVarsInRef(rwVars)(ref)
			errs = append(errs, arityMismatchError(env, ref, expr, want, got-1))
			return true
		}

		// either output var or not
		if got == want || got == want+1 {
			return false
		}
		ref = rewriteVarsInRef(rwVars)(ref)
		errs = append(errs, arityMismatchError(env, ref, expr, want, got))
		return true
	})

	return errs
}
// arityMismatchError builds the error for a call to f with the wrong number of
// arguments. exp is the expected arity and act is the number of arguments
// actually supplied.
//
// When the type environment knows f as a function type, a richer argument
// error listing the operand types and the declared argument types is returned.
// Otherwise a plain message is used, with "argument" in the singular when act
// is exactly 1.
func arityMismatchError(env *TypeEnv, f Ref, expr *Expr, exp, act int) *Error {
	if fn, ok := env.Get(f).(*types.Function); ok { // generate richer error for built-in functions
		operands := expr.Operands()
		have := make([]types.Type, len(operands))
		for i := range operands {
			have[i] = env.Get(operands[i])
		}
		return newArgError(expr.Loc(), f, "arity mismatch", have, fn.FuncArgs())
	}

	if act == 1 {
		return NewError(TypeErr, expr.Loc(), "function %v has arity %d, got %d argument", f, exp, act)
	}
	return NewError(TypeErr, expr.Loc(), "function %v has arity %d, got %d arguments", f, exp, act)
}
// checkSafetyRuleBodies ensures that variables appearing in negated
// expressions or in non-target positions of built-in expressions will be bound
// when the rule is evaluated from left to right, re-ordering the body as
// necessary. For each rule, the reserved variables and the head arguments are
// treated as already safe, and the rule body is replaced by the result of the
// check.
func (c *Compiler) checkSafetyRuleBodies() {
	for _, moduleName := range c.sorted {
		module := c.Modules[moduleName]
		WalkRules(module, func(rule *Rule) bool {
			bound := ReservedVars.Copy()
			bound.Update(rule.Head.Args.Vars())
			rule.Body = c.checkBodySafety(bound, rule.Body)
			return false
		})
	}
}
// checkBodySafety tries to reorder b so that it is safe given the already-safe
// variables in safe. If unsafe variables remain, the corresponding errors are
// recorded on the compiler and the original body b is returned unchanged;
// otherwise the reordered body is returned.
func (c *Compiler) checkBodySafety(safe VarSet, b Body) Body {
	reordered, unsafe := reorderBodyForSafety(c.builtins, c.GetArity, safe, b)

	errs := safetyErrorSlice(unsafe, c.RewrittenVars)
	if len(errs) == 0 {
		return reordered
	}

	for i := range errs {
		c.err(errs[i])
	}
	return b
}
// SafetyCheckVisitorParams defines the AST visitor parameters to use for collecting
// variables during the safety check. This has to be exported because it's relied on
// by the copy propagation implementation in topdown.
//
// Both SkipRefCallHead and SkipClosures are enabled; within this file the
// params are used by checkSafetyRuleHeads to collect the variables of a rule
// body.
var SafetyCheckVisitorParams = VarVisitorParams{
SkipRefCallHead: true,
SkipClosures: true,
}
// checkSafetyRuleHeads ensures that variables appearing in the head of a rule
// also appear in the body. Head arguments count as safe. Each remaining head
// variable is mapped back to its original name through RewrittenVars and, if
// that name is not a generated one, reported as an unsafe variable at the
// rule's location.
func (c *Compiler) checkSafetyRuleHeads() {
	for _, moduleName := range c.sorted {
		module := c.Modules[moduleName]
		WalkRules(module, func(rule *Rule) bool {
			bound := rule.Body.Vars(SafetyCheckVisitorParams)
			bound.Update(rule.Head.Args.Vars())

			for v := range rule.Head.Vars().Diff(bound) {
				reported := v
				if original, ok := c.RewrittenVars[v]; ok {
					reported = original
				}
				if reported.IsGenerated() {
					continue
				}
				c.err(NewError(UnsafeVarErr, rule.Loc(), "var %v is unsafe", reported))
			}
			return false
		})
	}
}
// compileSchema compiles goSchema, a schema held as a Go value, into a
// *gojsonschema.Schema. allowNet is handed to gojsonschema.SetAllowNet before
// anything else, including the nil check on goSchema. A nil goSchema is
// rejected with an error, and a compile failure is returned wrapped.
func compileSchema(goSchema interface{}, allowNet []string) (*gojsonschema.Schema, error) {
	gojsonschema.SetAllowNet(allowNet)

	loader := gojsonschema.NewSchemaLoader()

	if goSchema == nil {
		return nil, fmt.Errorf("no schema as input to compile")
	}

	compiled, err := loader.Compile(gojsonschema.NewGoLoader(goSchema))
	if err != nil {
		return nil, fmt.Errorf("unable to compile the schema: %w", err)
	}
	return compiled, nil
}
// mergeSchemas merges the given sub-schemas into the first one and returns it.
// With no arguments it returns (nil, nil).
//
// First, every schema that has property children but lacks the "object" type
// gets that type added; otherwise, a schema that has item children but lacks
// the "array" type gets "array" added. Then each remaining schema is folded
// into the first:
//
//   - Differing type strings are an error.
//   - For object schemas that both have property children, the children are
//     appended to the result.
//   - For array schemas that both have item children, item children beyond
//     the result's current length are appended, and the types of item
//     children at the same position must match.
//
// Note that the first schema is modified in place and is the returned value.
func mergeSchemas(schemas ...*gojsonschema.SubSchema) (*gojsonschema.SubSchema, error) {
	if len(schemas) == 0 {
		return nil, nil
	}

	// Make the implied "object"/"array" type explicit on every schema.
	for _, s := range schemas {
		switch {
		case len(s.PropertiesChildren) > 0:
			if !s.Types.Contains("object") {
				if err := s.Types.Add("object"); err != nil {
					return nil, fmt.Errorf("unable to set the type in schemas")
				}
			}
		case len(s.ItemsChildren) > 0:
			if !s.Types.Contains("array") {
				if err := s.Types.Add("array"); err != nil {
					return nil, fmt.Errorf("unable to set the type in schemas")
				}
			}
		}
	}

	merged := schemas[0]
	for _, other := range schemas[1:] {
		if merged.Types.String() != other.Types.String() {
			return nil, fmt.Errorf("unable to merge these schemas: type mismatch: %v and %v", merged.Types.String(), other.Types.String())
		}

		switch {
		case merged.Types.Contains("object") && len(merged.PropertiesChildren) > 0 && other.Types.Contains("object") && len(other.PropertiesChildren) > 0:
			merged.PropertiesChildren = append(merged.PropertiesChildren, other.PropertiesChildren...)

		case merged.Types.Contains("array") && len(merged.ItemsChildren) > 0 && other.Types.Contains("array") && len(other.ItemsChildren) > 0:
			for j := 0; j < len(other.ItemsChildren); j++ {
				// Extend the result when the other schema has an item at a
				// position the result does not have yet.
				if j >= len(merged.ItemsChildren) && j < len(other.ItemsChildren) {
					merged.ItemsChildren = append(merged.ItemsChildren, other.ItemsChildren[j])
				}
				if merged.ItemsChildren[j].Types.String() != other.ItemsChildren[j].Types.String() {
					return nil, fmt.Errorf("unable to merge these schemas")
				}
			}
		}
	}

	return merged, nil
}
// parseSchema converts a compiled JSON schema into a types.Type.
//
// schema must be a *gojsonschema.SubSchema; any other dynamic type yields an
// error naming the type that was received. The schema is resolved in this
// order:
//
//  1. A $ref is followed and the referenced schema is parsed instead.
//  2. anyOf yields the union (types.Or) of every alternative. If the schema
//     also declares a type, that core type is added to the union only when it
//     is an object with static properties and no dynamic properties.
//  3. allOf is merged with mergeSchemas (and, for object/array schemas, with
//     the enclosing schema) and the merged schema is parsed.
//  4. An explicitly typed schema maps to the matching boolean, string,
//     number, object or array type.
//  5. An untyped schema that has properties or items is assumed to be an
//     object or array respectively. Note that this records the assumed type
//     on subSchema.Types before re-parsing.
//  6. Anything else is types.A.
//
// Errors from nested schemas are wrapped with %w so the root cause is kept.
func parseSchema(schema interface{}) (types.Type, error) {
	subSchema, ok := schema.(*gojsonschema.SubSchema)
	if !ok {
		// subSchema is a nil pointer when the assertion fails, so report the
		// dynamic type of the value that was actually passed in.
		return nil, fmt.Errorf("unexpected schema type %T", schema)
	}

	// Handle referenced schemas, returns directly when a $ref is found
	if subSchema.RefSchema != nil {
		return parseSchema(subSchema.RefSchema)
	}

	// Handle anyOf
	if subSchema.AnyOf != nil {
		var orType types.Type

		// If there is a core schema, find its type first
		if subSchema.Types.IsTyped() {
			// Parse a shallow copy with AnyOf cleared so the recursion does
			// not re-enter this branch.
			copySchema := *subSchema
			copySchemaRef := &copySchema
			copySchemaRef.AnyOf = nil
			coreType, err := parseSchema(copySchemaRef)
			if err != nil {
				return nil, fmt.Errorf("unexpected schema type %v: %w", subSchema, err)
			}

			// Only add Object type with static props to orType
			if objType, ok := coreType.(*types.Object); ok {
				if objType.StaticProperties() != nil && objType.DynamicProperties() == nil {
					orType = types.Or(orType, coreType)
				}
			}
		}

		// Iterate through every property of AnyOf and add it to orType
		for _, pSchema := range subSchema.AnyOf {
			newtype, err := parseSchema(pSchema)
			if err != nil {
				return nil, fmt.Errorf("unexpected schema type %v: %w", pSchema, err)
			}
			orType = types.Or(newtype, orType)
		}

		return orType, nil
	}

	// Handle allOf
	if subSchema.AllOf != nil {
		subSchemaArray := subSchema.AllOf
		allOfResult, err := mergeSchemas(subSchemaArray...)
		if err != nil {
			return nil, err
		}

		if subSchema.Types.IsTyped() {
			bothObjects := subSchema.Types.Contains("object") && allOfResult.Types.Contains("object")
			bothArrays := subSchema.Types.Contains("array") && allOfResult.Types.Contains("array")
			if bothObjects || bothArrays {
				objectOrArrayResult, err := mergeSchemas(allOfResult, subSchema)
				if err != nil {
					return nil, err
				}
				return parseSchema(objectOrArrayResult)
			} else if subSchema.Types.String() != allOfResult.Types.String() {
				return nil, fmt.Errorf("unable to merge these schemas")
			}
		}
		return parseSchema(allOfResult)
	}

	if subSchema.Types.IsTyped() {
		switch {
		case subSchema.Types.Contains("boolean"):
			return types.B, nil

		case subSchema.Types.Contains("string"):
			return types.S, nil

		case subSchema.Types.Contains("integer") || subSchema.Types.Contains("number"):
			return types.N, nil

		case subSchema.Types.Contains("object"):
			if len(subSchema.PropertiesChildren) > 0 {
				staticProps := make([]*types.StaticProperty, 0, len(subSchema.PropertiesChildren))
				for _, pSchema := range subSchema.PropertiesChildren {
					newtype, err := parseSchema(pSchema)
					if err != nil {
						return nil, fmt.Errorf("unexpected schema type %v: %w", pSchema, err)
					}
					staticProps = append(staticProps, types.NewStaticProperty(pSchema.Property, newtype))
				}
				return types.NewObject(staticProps, nil), nil
			}
			return types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), nil

		case subSchema.Types.Contains("array"):
			if len(subSchema.ItemsChildren) > 0 {
				if subSchema.ItemsChildrenIsSingleSchema {
					// A single item schema applies to every element.
					iSchema := subSchema.ItemsChildren[0]
					newtype, err := parseSchema(iSchema)
					if err != nil {
						return nil, fmt.Errorf("unexpected schema type %v: %w", iSchema, err)
					}
					return types.NewArray(nil, newtype), nil
				}
				// Otherwise each item schema describes one position.
				newTypes := make([]types.Type, 0, len(subSchema.ItemsChildren))
				for _, iSchema := range subSchema.ItemsChildren {
					newtype, err := parseSchema(iSchema)
					if err != nil {
						return nil, fmt.Errorf("unexpected schema type %v: %w", iSchema, err)
					}
					newTypes = append(newTypes, newtype)
				}
				return types.NewArray(newTypes, nil), nil
			}
			return types.NewArray(nil, types.A), nil
		}
	}

	// Assume types if not specified in schema
	if len(subSchema.PropertiesChildren) > 0 {
		if err := subSchema.Types.Add("object"); err == nil {
			return parseSchema(subSchema)
		}
	} else if len(subSchema.ItemsChildren) > 0 {
		if err := subSchema.Types.Add("array"); err == nil {
			return parseSchema(subSchema)
		}
	}

	return types.A, nil
}
// setAnnotationSet builds the annotation set for the compiler's modules and
// stores it on the compiler. Modules are passed to BuildAnnotationSet in the
// order of c.sorted so that errors are reported in a stable order.
func (c *Compiler) setAnnotationSet() {
	mods := make([]*Module, len(c.sorted))
	for i, modName := range c.sorted {
		mods[i] = c.Modules[modName]
	}

	annotations, buildErrs := BuildAnnotationSet(mods)
	for _, e := range buildErrs {
		c.err(e)
	}

	c.annotationSet = annotations
}
// checkTypes type checks every rule, in the order produced by sorting the
// compiler's graph. Errors are recorded on the compiler and the resulting
// TypeEnv replaces c.TypeEnv.
func (c *Compiler) checkTypes() {
	// Recursion is caught in earlier step, so this cannot fail.
	ordered, _ := c.Graph.Sort()

	typeEnv, typeErrs := newTypeChecker().
		WithSchemaSet(c.schemaSet).
		WithInputType(c.inputType).
		WithVarRewriter(rewriteVarsInRef(c.RewrittenVars)).
		CheckTypes(c.TypeEnv, ordered, c.annotationSet)

	for _, e := range typeErrs {
		c.err(e)
	}

	c.TypeEnv = typeEnv
}
// checkUnsafeBuiltins runs the unsafe built-in check against every module, in
// c.sorted order, and records any errors on the compiler.
func (c *Compiler) checkUnsafeBuiltins() {
	for _, modName := range c.sorted {
		for _, e := range checkUnsafeBuiltins(c.unsafeBuiltinsMap, c.Modules[modName]) {
			c.err(e)
		}
	}
}
// checkDeprecatedBuiltins runs the deprecated built-in check against every
// module, in c.sorted order, and records any errors on the compiler.
func (c *Compiler) checkDeprecatedBuiltins() {
	for _, modName := range c.sorted {
		for _, e := range checkDeprecatedBuiltins(c.deprecatedBuiltinsMap, c.Modules[modName], c.strict) {
			c.err(e)
		}
	}
}
// runStage invokes f. When the compiler has a metrics object, the call is
// timed under the timer named metricName.
func (c *Compiler) runStage(metricName string, f func()) {
	if c.metrics == nil {
		f()
		return
	}

	c.metrics.Timer(metricName).Start()
	defer c.metrics.Timer(metricName).Stop()
	f()
}
// runStageAfter invokes the stage s with the compiler and returns its error.
// When the compiler has a metrics object, the call is timed under the timer
// named metricName.
func (c *Compiler) runStageAfter(metricName string, s CompilerStage) *Error {
	if c.metrics == nil {
		return s(c)
	}

	c.metrics.Timer(metricName).Start()
	defer c.metrics.Timer(metricName).Stop()
	return s(c)
}
// compile runs every built-in stage in order, followed by the stages that
// were registered to run after it. Compilation stops as soon as a built-in
// stage leaves the compiler in a failed state.
//
// A panic carrying errLimitReached (raised by c.err) is swallowed; any other
// panic is re-raised.
func (c *Compiler) compile() {
	defer func() {
		r := recover()
		if r == nil || r == errLimitReached {
			return
		}
		panic(r)
	}()

	for _, stage := range c.stages {
		c.runStage(stage.metricName, stage.f)
		if c.Failed() {
			return
		}

		for _, extra := range c.after[stage.name] {
			if stageErr := c.runStageAfter(extra.MetricName, extra.Stage); stageErr != nil {
				c.err(stageErr)
			}
		}
	}
}
// init prepares the compiler for use and is a no-op after the first call.
//
// It defaults the capabilities to those of this version, builds the built-in
// lookup table from the capabilities and the custom built-ins (custom
// built-ins win on name clashes), records deprecated built-ins when strict
// mode is enabled, loads the root input schema if one is present in the
// schema set, and creates the initial type environment.
func (c *Compiler) init() {
	if c.initialized {
		return
	}

	if c.capabilities == nil {
		c.capabilities = CapabilitiesForThisVersion()
	}

	c.builtins = make(map[string]*Builtin, len(c.capabilities.Builtins)+len(c.customBuiltins))

	for _, builtin := range c.capabilities.Builtins {
		c.builtins[builtin.Name] = builtin
		if c.strict && builtin.IsDeprecated() {
			c.deprecatedBuiltinsMap[builtin.Name] = struct{}{}
		}
	}

	for customName, custom := range c.customBuiltins {
		c.builtins[customName] = custom
	}

	// Load the global input schema if one was provided.
	if c.schemaSet != nil {
		if rootSchema := c.schemaSet.Get(SchemaRootRef); rootSchema != nil {
			inputType, loadErr := loadSchema(rootSchema, c.capabilities.AllowNet)
			if loadErr == nil {
				c.inputType = inputType
			} else {
				c.err(NewError(TypeErr, nil, loadErr.Error()))
			}
		}
	}

	c.TypeEnv = newTypeChecker().
		WithSchemaSet(c.schemaSet).
		WithInputType(c.inputType).
		Env(c.builtins)

	c.initialized = true
}
// err records a compile error. If a positive error limit is configured and
// has already been reached, errLimitReached is recorded instead of err and
// the function panics with errLimitReached to abort compilation.
func (c *Compiler) err(err *Error) {
	limitReached := c.maxErrs > 0 && len(c.Errors) >= c.maxErrs
	if !limitReached {
		c.Errors = append(c.Errors, err)
		return
	}

	c.Errors = append(c.Errors, errLimitReached)
	panic(errLimitReached)
}
// getExports returns a hash map keyed by package path (Ref) whose values are
// the names ([]Var) of the rules declared by the modules in that package.
// Modules are visited in c.sorted order, so names from several modules that
// share a package are concatenated in that order.
func (c *Compiler) getExports() *util.HashMap {

	rules := util.NewHashMap(func(a, b util.T) bool {
		// Compare the two keys with each other. Comparing a key with itself
		// would make every pair of keys look equal to the map.
		return a.(Ref).Equal(b.(Ref))
	}, func(v util.T) int {
		return v.(Ref).Hash()
	})

	for _, name := range c.sorted {
		mod := c.Modules[name]

		// Start from the names already collected for this package, if any.
		rv, ok := rules.Get(mod.Package.Path)
		if !ok {
			rv = []Var{}
		}
		rvs := rv.([]Var)

		for _, rule := range mod.Rules {
			rvs = append(rvs, rule.Head.Name)
		}
		rules.Put(mod.Package.Path, rvs)
	}

	return rules
}
// GetAnnotationSet returns the annotation set held by the compiler, i.e. the
// value that setAnnotationSet stored in c.annotationSet.
func (c *Compiler) GetAnnotationSet() *AnnotationSet {
return c.annotationSet
}
// checkDuplicateImports reports, in strict mode only, every import whose name
// has already been introduced by an earlier import in the same module.
func (c *Compiler) checkDuplicateImports() {
	if !c.strict {
		return
	}

	for _, modName := range c.sorted {
		seen := make(map[Var]*Import)

		for _, imp := range c.Modules[modName].Imports {
			alias := imp.Name()

			earlier, duplicate := seen[alias]
			if !duplicate {
				seen[alias] = imp
				continue
			}

			c.err(NewError(CompileErr, imp.Location, "import must not shadow %v", earlier))
		}
	}
}
// checkKeywordOverrides runs the keyword override check against every module,
// in c.sorted order, and records any errors on the compiler.
func (c *Compiler) checkKeywordOverrides() {
	for _, modName := range c.sorted {
		for _, e := range checkKeywordOverrides(c.Modules[modName], c.strict) {
			c.err(e)
		}
	}
}
// checkKeywordOverrides returns an error for every rule whose name, and every
// assignment whose first operand, is contained in RootDocumentRefs. The check
// is only performed in strict mode; otherwise nil is returned.
func checkKeywordOverrides(node interface{}, strict bool) Errors {
	if !strict {
		return nil
	}

	// shadowsRoot reports whether name collides with a root document.
	shadowsRoot := func(name string) bool {
		return RootDocumentRefs.Contains(RefTerm(VarTerm(name)))
	}

	found := Errors{}

	WalkRules(node, func(rule *Rule) bool {
		if name := rule.Head.Name.String(); shadowsRoot(name) {
			found = append(found, NewError(CompileErr, rule.Location, "rules must not shadow %v (use a different rule name)", name))
		}
		return true
	})

	WalkExprs(node, func(expr *Expr) bool {
		if !expr.IsAssignment() {
			return false
		}
		if name := expr.Operand(0).String(); shadowsRoot(name) {
			found = append(found, NewError(CompileErr, expr.Location, "variables must not shadow %v (use a different variable name)", name))
		}
		return false
	})

	return found
}
// resolveAllRefs resolves references in expressions to their fully qualified values.
//
// For instance, given the following module:
//
//	package a.b
//	import data.foo.bar
//	p[x] { bar[_] = x }
//
// The reference "bar[_]" would be resolved to "data.foo.bar[_]".
//
// In strict mode an error is reported for each import (future imports
// excepted) whose entry in the module's globals was not marked as used. Once a
// module has been processed its imports are cleared.
//
// If a module loader is configured it is called with the current modules; the
// modules it returns are copied into c.Modules and resolveAllRefs runs again.
// This repeats until the loader fails or returns no modules.
func (c *Compiler) resolveAllRefs() {
// Rule names exported by each package, keyed by package path.
rules := c.getExports()
for _, name := range c.sorted {
mod := c.Modules[name]
var ruleExports []Var
if x, ok := rules.Get(mod.Package.Path); ok {
ruleExports = x.([]Var)
}
// Names that resolve to something outside the local scope: the package's
// exported rules plus the module's imports.
globals := getGlobals(mod.Package, ruleExports, mod.Imports)
WalkRules(mod, func(rule *Rule) bool {
err := resolveRefsInRule(globals, rule)
if err != nil {
c.err(NewError(CompileErr, rule.Location, err.Error()))
}
return false
})
if c.strict { // check for unused imports
for _, imp := range mod.Imports {
path := imp.Path.Value.(Ref)
if FutureRootDocument.Equal(path[0]) {
continue // ignore future imports
}
for v, u := range globals {
if v.Equal(imp.Name()) && !u.used {
c.err(NewError(CompileErr, imp.Location, "%s unused", imp.String()))
}
}
}
}
// Once imports have been resolved, they are no longer needed.
mod.Imports = nil
}
if c.moduleLoader != nil {
parsed, err := c.moduleLoader(c.Modules)
if err != nil {
c.err(NewError(CompileErr, nil, err.Error()))
return
}
// No new modules: resolution is complete.
if len(parsed) == 0 {
return
}
for id, module := range parsed {
c.Modules[id] = module.Copy()
c.sorted = append(c.sorted, id)
}
// Keep the module order sorted by name, then resolve again so the newly
// loaded modules are processed as well.
sort.Strings(c.sorted)
c.resolveAllRefs()
}
}
// initLocalVarGen creates the compiler's local variable generator from the
// sorted module names and the module set, and stores it in c.localvargen for
// use by the later rewrite stages.
func (c *Compiler) initLocalVarGen() {
c.localvargen = newLocalVarGeneratorForModuleSet(c.sorted, c.Modules)
}
// rewriteComprehensionTerms applies the comprehension term rewrite to every
// module, in c.sorted order, sharing one equality factory backed by the
// compiler's local variable generator. The rewrite's results are discarded.
func (c *Compiler) rewriteComprehensionTerms() {
	factory := newEqualityFactory(c.localvargen)
	for _, modName := range c.sorted {
		_, _ = rewriteComprehensionTerms(factory, c.Modules[modName]) // ignore error
	}
}
// rewriteExprTerms rewrites the expression terms in the head and body of
// every rule of every module, in c.sorted order.
func (c *Compiler) rewriteExprTerms() {
	rewriteRule := func(rule *Rule) bool {
		rewriteExprTermsInHead(c.localvargen, rule)
		rule.Body = rewriteExprTermsInBody(c.localvargen, rule.Body)
		return false
	}

	for _, modName := range c.sorted {
		WalkRules(c.Modules[modName], rewriteRule)
	}
}
// checkVoidCalls runs the void call check against every module, in c.sorted
// order, using the compiler's type environment, and records any errors.
func (c *Compiler) checkVoidCalls() {
	for _, modName := range c.sorted {
		for _, e := range checkVoidCalls(c.TypeEnv, c.Modules[modName]) {
			c.err(e)
		}
	}
}
// rewritePrintCalls handles print calls in every module. When print
// statements are disabled the calls are erased. Otherwise every body found in
// each rule's head and body is rewritten, treating the rule's argument
// variables and the reserved variables as safe; errors are recorded on the
// compiler.
func (c *Compiler) rewritePrintCalls() {
	if !c.enablePrintStatements {
		for _, modName := range c.sorted {
			erasePrintCalls(c.Modules[modName])
		}
		return
	}

	rewriteRule := func(r *Rule) bool {
		safe := r.Head.Args.Vars()
		safe.Update(ReservedVars)

		rewriteBody := func(b Body) bool {
			for _, e := range rewritePrintCalls(c.localvargen, c.GetArity, safe, b) {
				c.err(e)
			}
			return false
		}

		WalkBodies(r.Head, rewriteBody)
		WalkBodies(r.Body, rewriteBody)
		return false
	}

	for _, modName := range c.sorted {
		WalkRules(c.Modules[modName], rewriteRule)
	}
}
// checkVoidCalls reports every call term found while walking the terms of x
// whose operator has a function type with a nil result in env, i.e. a void
// function call that is being used as a value. The only void functions in
// Rego are specific built-ins like print().
func checkVoidCalls(env *TypeEnv, x interface{}) Errors {
	var found Errors

	WalkTerms(x, func(term *Term) bool {
		call, isCall := term.Value.(Call)
		if !isCall {
			return false
		}

		fn, isFunc := env.Get(call[0]).(*types.Function)
		if isFunc && fn.Result() == nil {
			found = append(found, NewError(TypeErr, term.Loc(), "%v used as value", call))
		}
		return false
	})

	return found
}
// rewritePrintCalls will rewrite the body so that print operands are captured
// in local variables and their evaluation occurs within a comprehension.
// Wrapping the terms inside of a comprehension ensures that undefined values do
// not short-circuit evaluation.
//
// For example, given the following print statement:
//
//	print("the value of x is:", input.x)
//
// The expression would be rewritten to:
//
//	print({__local0__ | __local0__ = "the value of x is:"}, {__local1__ | __local1__ = input.x})
//
// gen supplies the fresh local variables, getArity is forwarded to
// outputVarsForBody, and globals is the set of variables that are safe on
// entry to body. The body is modified in place. The returned errors describe
// print operands that use undeclared variables; nil is returned on success.
func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals VarSet, body Body) Errors {

	var errs Errors

	// Visit comprehension bodies recursively to ensure print statements inside
	// those bodies only close over variables that are safe.
	for i := range body {
		if !ContainsClosures(body[i]) {
			continue
		}

		safe := outputVarsForBody(body[:i], getArity, globals)
		safe.Update(globals)

		WalkClosures(body[i], func(x interface{}) bool {
			// Accumulate rather than overwrite: an expression can hold several
			// closures and a later, error-free closure must not discard the
			// errors found in an earlier one.
			switch x := x.(type) {
			case *SetComprehension:
				errs = append(errs, rewritePrintCalls(gen, getArity, safe, x.Body)...)
			case *ArrayComprehension:
				errs = append(errs, rewritePrintCalls(gen, getArity, safe, x.Body)...)
			case *ObjectComprehension:
				errs = append(errs, rewritePrintCalls(gen, getArity, safe, x.Body)...)
			case *Every:
				safe.Update(x.KeyValueVars())
				errs = append(errs, rewritePrintCalls(gen, getArity, safe, x.Body)...)
			}
			return true
		})

		if len(errs) > 0 {
			return errs
		}
	}

	for i := range body {

		if !isPrintCall(body[i]) {
			continue
		}

		var errs Errors
		safe := outputVarsForBody(body[:i], getArity, globals)
		safe.Update(globals)
		args := body[i].Operands()

		// Every variable used by an operand must be safe at this point.
		for j := range args {
			vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams)
			vis.Walk(args[j])
			unsafe := vis.Vars().Diff(safe)
			for _, v := range unsafe.Sorted() {
				errs = append(errs, NewError(CompileErr, args[j].Loc(), "var %v is undeclared", v))
			}
		}

		if len(errs) > 0 {
			return errs
		}

		// Capture each operand in its own set comprehension.
		arr := NewArray()

		for j := range args {
			x := NewTerm(gen.Generate()).SetLocation(args[j].Loc())
			capture := Equality.Expr(x, args[j]).SetLocation(args[j].Loc())
			arr = arr.Append(SetComprehensionTerm(x, NewBody(capture)).SetLocation(args[j].Loc()))
		}

		// Replace the print call with a call to the internal print built-in
		// that takes the captured operands as a single array.
		body.Set(NewExpr([]*Term{
			NewTerm(InternalPrint.Ref()).SetLocation(body[i].Loc()),
			NewTerm(arr).SetLocation(body[i].Loc()),
		}).SetLocation(body[i].Loc()), i)
	}

	return nil
}
func erasePrintCalls(node interface{}) {
NewGenericVisitor(func(x interface{}) bool {
switch x := x.(type) {
case *Rule:
x.Body = erasePrintCallsInBody(x.Body)
case *ArrayComprehension:
x.Body = erasePrintCallsInBody(x.Body)
case *SetComprehension:
x.Body = erasePrintCallsInBody(x.Body)
case *ObjectComprehension:
x.Body = erasePrintCallsInBody(x.Body)
case *Every:
x.Body = erasePrintCallsInBody(x.Body)
}
return false
}).Walk(node)
}
// erasePrintCallsInBody returns x unchanged when it contains no print call.
// Otherwise it returns a new body holding the expressions of x that are not
// print calls, after erasing print calls nested inside each expression. If
// nothing remains, the result is a single `true` expression located at x.
func erasePrintCallsInBody(x Body) Body {
	if !containsPrintCall(x) {
		return x
	}

	var kept Body
	for _, expr := range x {
		// Recursively visit any comprehensions contained in this expression.
		erasePrintCalls(expr)

		if isPrintCall(expr) {
			continue
		}
		kept.Append(expr)
	}

	if len(kept) > 0 {
		return kept
	}

	// A body must not be empty, so stand in a trivially true expression.
	kept.Append(NewExpr(BooleanTerm(true).SetLocation(x.Loc())).SetLocation(x.Loc()))
	return kept
}
// containsPrintCall reports whether any expression visited while walking x is
// a print call.
func containsPrintCall(x Body) bool {
	hasPrint := false
	WalkExprs(x, func(expr *Expr) bool {
		hasPrint = hasPrint || isPrintCall(expr)
		return hasPrint
	})
	return hasPrint
}
// isPrintCall reports whether x is a call whose operator is the print
// built-in.
func isPrintCall(x *Expr) bool {
	if !x.IsCall() {
		return false
	}
	return x.Operator().Equal(Print.Ref())
}
// rewriteRefsInHead will rewrite rules so that the head does not contain any
// terms that require evaluation (e.g., refs or comprehensions). If the key,
// value or an argument contains one or more of these terms, it will be moved
// into the body and assigned to a new variable. The new variable will replace
// the original term in the head.
//
// For instance, given the following rule:
//
//	p[{"foo": data.foo[i]}] { i < 100 }
//
// The rule would be re-written as:
//
//	p[__local0__] { i < 100; __local0__ = {"foo": data.foo[i]} }
func (c *Compiler) rewriteRefsInHead() {
	factory := newEqualityFactory(c.localvargen)

	// hoist returns the term to leave in the head: the term itself when it
	// needs no evaluation, otherwise the fresh variable of an equality that is
	// appended to the rule body.
	hoist := func(rule *Rule, term *Term) *Term {
		if !requiresEval(term) {
			return term
		}
		eq := factory.Generate(term)
		replacement := eq.Operand(0)
		rule.Body.Append(eq)
		return replacement
	}

	rewriteRule := func(rule *Rule) bool {
		rule.Head.Key = hoist(rule, rule.Head.Key)
		rule.Head.Value = hoist(rule, rule.Head.Value)
		for i := range rule.Head.Args {
			rule.Head.Args[i] = hoist(rule, rule.Head.Args[i])
		}
		return false
	}

	for _, modName := range c.sorted {
		WalkRules(c.Modules[modName], rewriteRule)
	}
}
// rewriteEquals applies the package-level rewriteEquals to every module, in
// c.sorted order.
func (c *Compiler) rewriteEquals() {
	for _, modName := range c.sorted {
		rewriteEquals(c.Modules[modName])
	}
}
// rewriteDynamicTerms replaces the body of every rule of every module with
// the result of rewriteDynamics, sharing one equality factory backed by the
// compiler's local variable generator.
func (c *Compiler) rewriteDynamicTerms() {
	factory := newEqualityFactory(c.localvargen)

	rewriteRule := func(rule *Rule) bool {
		rule.Body = rewriteDynamics(factory, rule.Body)
		return false
	}

	for _, modName := range c.sorted {
		WalkRules(c.Modules[modName], rewriteRule)
	}
}
// rewriteLocalVars rewrites the locally declared variables of every rule of
// every module to generated variables. For each rule it, in order:
//
//  1. rewrites declarations nested in comprehensions in the rule head,
//  2. rewrites the rule's argument variables,
//  3. rewrites declarations in the rule body,
//  4. records every original -> generated mapping in c.RewrittenVars, and
//  5. replaces head variables that refer to variables declared in the body.
//
// Errors produced along the way are recorded on the compiler.
func (c *Compiler) rewriteLocalVars() {
for _, name := range c.sorted {
mod := c.Modules[name]
gen := c.localvargen
WalkRules(mod, func(rule *Rule) bool {
// Rewrite assignments contained in head of rule. Assignments can
// occur in rule head if they're inside a comprehension. Note,
// assigned vars in comprehensions in the head will be rewritten
// first to preserve scoping rules. For example:
//
// p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 }
//
// This behaviour is consistent with scoping inside the body. For example:
//
// p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] }
nestedXform := &rewriteNestedHeadVarLocalTransform{
gen: gen,
RewrittenVars: c.RewrittenVars,
strict: c.strict,
}
NewGenericVisitor(nestedXform.Visit).Walk(rule.Head)
for _, err := range nestedXform.errs {
c.err(err)
}
// Rewrite assignments in body.
// Variables appearing in the head key or value are passed to the body
// rewrite as already "used".
used := NewVarSet()
if rule.Head.Key != nil {
used.Update(rule.Head.Key.Vars())
}
if rule.Head.Value != nil {
used.Update(rule.Head.Value.Vars())
}
// The same stack is shared by the argument rewrite and the body rewrite.
stack := newLocalDeclaredVars()
c.rewriteLocalArgVars(gen, stack, rule)
body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body, c.strict)
for _, err := range errs {
c.err(err)
}
// For rewritten vars use the collection of all variables that
// were in the stack at some point in time.
for k, v := range stack.rewritten {
c.RewrittenVars[k] = v
}
rule.Body = body
// Rewrite vars in head that refer to locally declared vars in the body.
localXform := rewriteHeadVarLocalTransform{declared: declared}
for i := range rule.Head.Args {
rule.Head.Args[i], _ = transformTerm(localXform, rule.Head.Args[i])
}
if rule.Head.Key != nil {
rule.Head.Key, _ = transformTerm(localXform, rule.Head.Key)
}
if rule.Head.Value != nil {
rule.Head.Value, _ = transformTerm(localXform, rule.Head.Value)
}
return false
})
}
}
// rewriteNestedHeadVarLocalTransform holds the state used by its Visit method
// to rewrite variables declared inside comprehensions in a rule head.
type rewriteNestedHeadVarLocalTransform struct {
// gen is the local variable generator passed to the comprehension rewrites.
gen *localVarGenerator
// errs accumulates the errors returned by the comprehension rewrites.
errs Errors
// RewrittenVars receives every mapping recorded in the declared-variable
// stack used during a visit.
RewrittenVars map[Var]Var
// strict is forwarded to the comprehension rewrites.
strict bool
}
// Visit is the callback handed to a generic visitor. It only acts on terms:
//
//   - objects and sets are replaced by a copy whose elements have been
//     visited recursively (the copy is needed because the elements are
//     rewritten in place),
//   - comprehensions have their declared variables rewritten, with any errors
//     collected in xform.errs.
//
// In all of those cases it returns true so the caller does not descend into
// the term again. For any other value it returns false. Mappings recorded
// while handling a term are copied into xform.RewrittenVars.
func (xform *rewriteNestedHeadVarLocalTransform) Visit(x interface{}) bool {
if term, ok := x.(*Term); ok {
stop := false
// A fresh stack is used for every visited term.
stack := newLocalDeclaredVars()
switch x := term.Value.(type) {
case *object:
cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) {
kcpy := k.Copy()
NewGenericVisitor(xform.Visit).Walk(kcpy)
vcpy := v.Copy()
NewGenericVisitor(xform.Visit).Walk(vcpy)
return kcpy, vcpy, nil
})
term.Value = cpy
stop = true
case *set:
cpy, _ := x.Map(func(v *Term) (*Term, error) {
vcpy := v.Copy()
NewGenericVisitor(xform.Visit).Walk(vcpy)
return vcpy, nil
})
term.Value = cpy
stop = true
case *ArrayComprehension:
xform.errs = rewriteDeclaredVarsInArrayComprehension(xform.gen, stack, x, xform.errs, xform.strict)
stop = true
case *SetComprehension:
xform.errs = rewriteDeclaredVarsInSetComprehension(xform.gen, stack, x, xform.errs, xform.strict)
stop = true
case *ObjectComprehension:
xform.errs = rewriteDeclaredVarsInObjectComprehension(xform.gen, stack, x, xform.errs, xform.strict)
stop = true
}
// Publish whatever this term's rewrite recorded.
for k, v := range stack.rewritten {
xform.RewrittenVars[k] = v
}
return stop
}
return false
}
// rewriteHeadVarLocalTransform replaces variables with their rewritten
// counterparts; see its Transform method.
type rewriteHeadVarLocalTransform struct {
// declared maps a variable to the variable that replaces it.
declared map[Var]Var
}
// Transform returns the replacement recorded in xform.declared when x is a
// variable that has one. Every other value is returned unchanged. The error
// is always nil.
func (xform rewriteHeadVarLocalTransform) Transform(x interface{}) (interface{}, error) {
	name, isVar := x.(Var)
	if !isVar {
		return x, nil
	}

	if replacement, found := xform.declared[name]; found {
		return replacement, nil
	}

	return x, nil
}
// rewriteLocalArgVars walks each argument in the head of rule with a
// ruleArgLocalRewriter built from gen and stack, then records the errors the
// rewriter collected on the compiler.
func (c *Compiler) rewriteLocalArgVars(gen *localVarGenerator, stack *localDeclaredVars, rule *Rule) {
	rewriter := &ruleArgLocalRewriter{
		gen:   gen,
		stack: stack,
	}

	for _, arg := range rule.Head.Args {
		Walk(rewriter, arg)
	}

	for _, e := range rewriter.errs {
		c.err(e)
	}
}
// ruleArgLocalRewriter is a visitor that rewrites the variables found in rule
// arguments to generated variables; see its Visit method.
type ruleArgLocalRewriter struct {
// stack tracks which variables have been declared and what they were
// rewritten to.
stack *localDeclaredVars
// gen produces the generated variables.
gen *localVarGenerator
// errs collects the errors encountered while visiting.
errs []*Error
}
// Visit rewrites the variables inside a rule argument. Non-term values are
// skipped over (the visitor is returned so that walking continues). For terms:
//
//   - a variable is replaced in place by its generated variable, which is
//     created and inserted into the stack on first sight,
//   - an object is replaced by a copy whose values have been walked (keys are
//     kept as they are),
//   - scalars, comprehensions and sets are left untouched,
//   - a call is reported as an error,
//   - anything else (refs and arrays) is descended into.
//
// A nil return value stops the walk from descending into the term.
func (vis *ruleArgLocalRewriter) Visit(x interface{}) Visitor {
t, ok := x.(*Term)
if !ok {
return vis
}
switch v := t.Value.(type) {
case Var:
gv, ok := vis.stack.Declared(v)
if ok {
// Already declared: mark it as seen and reuse the existing variable.
vis.stack.Seen(v)
} else {
gv = vis.gen.Generate()
vis.stack.Insert(v, gv, argVar)
}
t.Value = gv
return nil
case *object:
if cpy, err := v.Map(func(k, v *Term) (*Term, *Term, error) {
// Walk a copy of the value so the rewrite does not modify the
// original object's terms.
vcpy := v.Copy()
Walk(vis, vcpy)
return k, vcpy, nil
}); err != nil {
vis.errs = append(vis.errs, NewError(CompileErr, t.Location, err.Error()))
} else {
t.Value = cpy
}
return nil
case Null, Boolean, Number, String, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Set:
// Scalars are no-ops. Comprehensions are handled above. Sets must not
// contain variables.
return nil
case Call:
vis.errs = append(vis.errs, NewError(CompileErr, t.Location, "rule arguments cannot contain calls"))
return nil
default:
// Recurse on refs and arrays. Any embedded
// variables can be rewritten.
return vis
}
}
// rewriteWithModifiers transforms every module, in c.sorted order, replacing
// each body with the result of rewriteWithModifiersInBody. An error from the
// body rewrite is recorded on the compiler and the body returned by the
// rewrite is used regardless.
func (c *Compiler) rewriteWithModifiers() {
	factory := newEqualityFactory(c.localvargen)

	rewriteBody := func(x interface{}) (interface{}, error) {
		original, isBody := x.(Body)
		if !isBody {
			return x, nil
		}

		rewritten, rewriteErr := rewriteWithModifiersInBody(c, factory, original)
		if rewriteErr != nil {
			c.err(rewriteErr)
		}
		return rewritten, nil
	}

	for _, modName := range c.sorted {
		transformer := NewGenericTransformer(rewriteBody)
		_, _ = Transform(transformer, c.Modules[modName]) // ignore error
	}
}
// setModuleTree builds the module tree from the compiler's modules and stores
// it in c.ModuleTree.
func (c *Compiler) setModuleTree() {
c.ModuleTree = NewModuleTree(c.Modules)
}
// setRuleTree builds the rule tree from the compiler's module tree and stores
// it in c.RuleTree. It reads c.ModuleTree, so setModuleTree must have run
// first.
func (c *Compiler) setRuleTree() {
c.RuleTree = NewRuleTree(c.ModuleTree)
}
// setGraph builds the graph for the compiler's modules and stores it in
// c.Graph. Rules for a ref are looked up with GetRulesDynamicWithOpts,
// including those in hidden modules.
func (c *Compiler) setGraph() {
	c.Graph = NewGraph(c.Modules, func(ref Ref) []*Rule {
		return c.GetRulesDynamicWithOpts(ref, RulesOptions{IncludeHiddenModules: true})
	})
}
// queryCompiler compiles query bodies against an existing Compiler. It is the
// type returned by newQueryCompiler.
type queryCompiler struct {
// compiler supplies the settings and state the query stages read (strict
// mode, built-ins, type environment, schemas, metrics, ...).
compiler *Compiler
// qctx is the optional query context set via WithContext.
qctx *QueryContext
// typeEnv is the type environment produced by the checkTypes stage.
typeEnv *TypeEnv
// rewritten holds the variable mappings produced by the rewriteLocalVars stage.
rewritten map[Var]Var
// after maps a built-in stage name to the extra stages to run after it.
after map[string][]QueryCompilerStageDefinition
// unsafeBuiltins, when non-nil, is used by the checkUnsafeBuiltins stage
// instead of the compiler's own set.
unsafeBuiltins map[string]struct{}
// comprehensionIndices holds the indices built for the compiled query.
comprehensionIndices map[*Term]*ComprehensionIndex
// enablePrintStatements selects whether print calls are rewritten (true)
// or erased (false).
enablePrintStatements bool
}
// newQueryCompiler returns a QueryCompiler bound to compiler. It starts
// without a query context and with empty stage and comprehension index maps.
func newQueryCompiler(compiler *Compiler) QueryCompiler {
	return &queryCompiler{
		compiler:             compiler,
		after:                make(map[string][]QueryCompilerStageDefinition),
		comprehensionIndices: make(map[*Term]*ComprehensionIndex),
	}
}
// WithEnablePrintStatements sets whether print calls in the query are
// rewritten (true) or erased (false) by the print call stage. It returns the
// receiver to allow chaining.
func (qc *queryCompiler) WithEnablePrintStatements(yes bool) QueryCompiler {
qc.enablePrintStatements = yes
return qc
}
// WithContext sets the query context used by Compile. It returns the receiver
// to allow chaining.
func (qc *queryCompiler) WithContext(qctx *QueryContext) QueryCompiler {
qc.qctx = qctx
return qc
}
// WithStageAfter registers stage to run after the built-in stage named after.
// Stages registered for the same name run in registration order. It returns
// the receiver to allow chaining.
func (qc *queryCompiler) WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler {
qc.after[after] = append(qc.after[after], stage)
return qc
}
// WithUnsafeBuiltins sets the unsafe built-ins to check the query against.
// A non-nil set takes the place of the compiler's own set. It returns the
// receiver to allow chaining.
func (qc *queryCompiler) WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler {
qc.unsafeBuiltins = unsafe
return qc
}
// RewrittenVars returns the variable mappings recorded by the most recent
// rewriteLocalVars stage. The returned map is the compiler's own, not a copy.
func (qc *queryCompiler) RewrittenVars() map[Var]Var {
return qc.rewritten
}
// ComprehensionIndex returns the index recorded for term, looking first in
// the query compiler's own indices and then in those of the underlying
// compiler. It returns nil when neither has an entry.
func (qc *queryCompiler) ComprehensionIndex(term *Term) *ComprehensionIndex {
	if index, found := qc.comprehensionIndices[term]; found {
		return index
	}

	if index, found := qc.compiler.comprehensionIndices[term]; found {
		return index
	}

	return nil
}
// runStage invokes the stage function s with qctx and query and returns its
// result. When the underlying compiler has a metrics object, the call is
// timed under the timer named metricName.
func (qc *queryCompiler) runStage(metricName string, qctx *QueryContext, query Body, s func(*QueryContext, Body) (Body, error)) (Body, error) {
	if qc.compiler.metrics == nil {
		return s(qctx, query)
	}

	qc.compiler.metrics.Timer(metricName).Start()
	defer qc.compiler.metrics.Timer(metricName).Stop()
	return s(qctx, query)
}
// runStageAfter invokes the user-supplied stage s with the query compiler and
// query and returns its result. When the underlying compiler has a metrics
// object, the call is timed under the timer named metricName.
func (qc *queryCompiler) runStageAfter(metricName string, query Body, s QueryCompilerStage) (Body, error) {
	if qc.compiler.metrics == nil {
		return s(qc, query)
	}

	qc.compiler.metrics.Timer(metricName).Start()
	defer qc.compiler.metrics.Timer(metricName).Stop()
	return s(qc, query)
}
// Compile runs the query compilation stages over a copy of query and returns
// the compiled body. Each built-in stage is followed by the stages registered
// for it with WithStageAfter. Compilation stops at the first stage that
// returns an error; that error is passed through applyErrorLimit before being
// returned. An empty query is rejected.
func (qc *queryCompiler) Compile(query Body) (Body, error) {
	if len(query) == 0 {
		return nil, Errors{NewError(CompileErr, nil, "empty query cannot be compiled")}
	}

	compiled := query.Copy()

	// The order of the stages is significant.
	builtinStages := []struct {
		name       string
		metricName string
		f          func(*QueryContext, Body) (Body, error)
	}{
		{"CheckKeywordOverrides", "query_compile_stage_check_keyword_overrides", qc.checkKeywordOverrides},
		{"ResolveRefs", "query_compile_stage_resolve_refs", qc.resolveRefs},
		{"RewriteLocalVars", "query_compile_stage_rewrite_local_vars", qc.rewriteLocalVars},
		{"CheckVoidCalls", "query_compile_stage_check_void_calls", qc.checkVoidCalls},
		{"RewritePrintCalls", "query_compile_stage_rewrite_print_calls", qc.rewritePrintCalls},
		{"RewriteExprTerms", "query_compile_stage_rewrite_expr_terms", qc.rewriteExprTerms},
		{"RewriteComprehensionTerms", "query_compile_stage_rewrite_comprehension_terms", qc.rewriteComprehensionTerms},
		{"RewriteWithValues", "query_compile_stage_rewrite_with_values", qc.rewriteWithModifiers},
		{"CheckUndefinedFuncs", "query_compile_stage_check_undefined_funcs", qc.checkUndefinedFuncs},
		{"CheckSafety", "query_compile_stage_check_safety", qc.checkSafety},
		{"RewriteDynamicTerms", "query_compile_stage_rewrite_dynamic_terms", qc.rewriteDynamicTerms},
		{"CheckTypes", "query_compile_stage_check_types", qc.checkTypes},
		{"CheckUnsafeBuiltins", "query_compile_stage_check_unsafe_builtins", qc.checkUnsafeBuiltins},
		{"CheckDeprecatedBuiltins", "query_compile_stage_check_deprecated_builtins", qc.checkDeprecatedBuiltins},
		{"BuildComprehensionIndex", "query_compile_stage_build_comprehension_index", qc.buildComprehensionIndices},
	}

	// Stages work on a copy of the query context.
	ctxCopy := qc.qctx.Copy()

	for _, stage := range builtinStages {
		var stageErr error

		compiled, stageErr = qc.runStage(stage.metricName, ctxCopy, compiled, stage.f)
		if stageErr != nil {
			return nil, qc.applyErrorLimit(stageErr)
		}

		for _, extra := range qc.after[stage.name] {
			compiled, stageErr = qc.runStageAfter(extra.MetricName, compiled, extra.Stage)
			if stageErr != nil {
				return nil, qc.applyErrorLimit(stageErr)
			}
		}
	}

	return compiled, nil
}
// TypeEnv returns the type environment stored by the checkTypes stage of the
// most recent compilation.
func (qc *queryCompiler) TypeEnv() *TypeEnv {
return qc.typeEnv
}
// applyErrorLimit enforces the compiler's error limit on err. If err is an
// Errors value holding more than maxErrs entries (and maxErrs is positive),
// the result holds the first maxErrs entries followed by errLimitReached.
// Any other error is returned unchanged.
//
// The truncated result is built in a fresh backing array, so the Errors value
// passed in is never modified.
func (qc *queryCompiler) applyErrorLimit(err error) error {
	errs, ok := err.(Errors)
	if !ok {
		return err
	}

	limit := qc.compiler.maxErrs
	if limit <= 0 || len(errs) <= limit {
		return err
	}

	// Cap the capacity at limit: a plain errs[:limit] still has spare
	// capacity, so append would overwrite errs[limit] in the caller's slice.
	return append(errs[:limit:limit], errLimitReached)
}
// checkKeywordOverrides is the query stage that runs the keyword override
// check on body, using the compiler's strict setting. It returns body
// unchanged when no errors are found.
func (qc *queryCompiler) checkKeywordOverrides(_ *QueryContext, body Body) (Body, error) {
	errs := checkKeywordOverrides(body, qc.compiler.strict)
	if len(errs) == 0 {
		return body, nil
	}
	return nil, errs
}
// resolveRefs is the query stage that resolves the refs in body against the
// globals implied by the query context: the rules exported by the context's
// package plus the context's imports. Without a query context, or without a
// package and imports, no globals are used. The context's imports are cleared
// once the globals have been computed. Variables declared in body are ignored
// during resolution.
func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error) {

	var globals map[Var]*usedRef

	if qctx != nil {
		queryPkg := qctx.Package

		// Query compiler ought to generate a package if one was not provided and one or more imports were provided.
		// The generated package name could even be an empty string to avoid conflicts (it doesn't have to be valid syntactically)
		if queryPkg == nil && len(qctx.Imports) > 0 {
			queryPkg = &Package{Path: RefTerm(VarTerm("")).Value.(Ref)}
		}

		if queryPkg != nil {
			var exported []Var
			if names, found := qc.compiler.getExports().Get(queryPkg.Path); found {
				exported = names.([]Var)
			}

			globals = getGlobals(qctx.Package, exported, qctx.Imports)
			qctx.Imports = nil
		}
	}

	declared := &declaredVarStack{declaredVars(body)}

	return resolveRefsInBody(globals, declared, body), nil
}
// rewriteComprehensionTerms is the query stage that applies the comprehension
// term rewrite to body, using a local variable generator created for the
// query.
func (qc *queryCompiler) rewriteComprehensionTerms(_ *QueryContext, body Body) (Body, error) {
	factory := newEqualityFactory(newLocalVarGenerator("q", body))

	rewritten, err := rewriteComprehensionTerms(factory, body)
	if err != nil {
		return nil, err
	}

	return rewritten.(Body), nil
}
// rewriteDynamicTerms is the query stage that returns the result of
// rewriteDynamics for body, using a local variable generator created for the
// query.
func (qc *queryCompiler) rewriteDynamicTerms(_ *QueryContext, body Body) (Body, error) {
	factory := newEqualityFactory(newLocalVarGenerator("q", body))
	return rewriteDynamics(factory, body), nil
}
// rewriteExprTerms is the query stage that returns the result of
// rewriteExprTermsInBody for body, using a local variable generator created
// for the query.
func (qc *queryCompiler) rewriteExprTerms(_ *QueryContext, body Body) (Body, error) {
	return rewriteExprTermsInBody(newLocalVarGenerator("q", body), body), nil
}
// rewriteLocalVars is the query stage that rewrites the local variables of
// body. On success it replaces qc.rewritten with a copy of every mapping the
// rewrite recorded and returns the rewritten body.
func (qc *queryCompiler) rewriteLocalVars(_ *QueryContext, body Body) (Body, error) {
	stack := newLocalDeclaredVars()

	rewrittenBody, _, errs := rewriteLocalVars(newLocalVarGenerator("q", body), stack, nil, body, qc.compiler.strict)
	if len(errs) > 0 {
		return nil, errs
	}

	mappings := make(map[Var]Var, len(stack.rewritten))
	for original, generated := range stack.rewritten {
		// The vars returned during the rewrite will include all seen vars,
		// even if they're not declared with an assignment operation. We don't
		// want to include these inside the rewritten set though.
		mappings[original] = generated
	}
	qc.rewritten = mappings

	return rewrittenBody, nil
}
// rewritePrintCalls is the query stage that handles print calls in body. With
// print statements enabled, the calls are rewritten in place, treating only
// the reserved variables as safe on entry. Otherwise the calls are erased.
func (qc *queryCompiler) rewritePrintCalls(_ *QueryContext, body Body) (Body, error) {
	if qc.enablePrintStatements {
		errs := rewritePrintCalls(newLocalVarGenerator("q", body), qc.compiler.GetArity, ReservedVars, body)
		if len(errs) > 0 {
			return nil, errs
		}
		return body, nil
	}

	return erasePrintCallsInBody(body), nil
}
// checkVoidCalls is the query stage that runs the void call check on body
// against the compiler's type environment. It returns body unchanged when no
// errors are found.
func (qc *queryCompiler) checkVoidCalls(_ *QueryContext, body Body) (Body, error) {
	errs := checkVoidCalls(qc.compiler.TypeEnv, body)
	if len(errs) == 0 {
		return body, nil
	}
	return nil, errs
}
// checkUndefinedFuncs is the query stage that runs the undefined function
// check on body, passing along the query's rewritten variables. It returns
// body unchanged when no errors are found.
func (qc *queryCompiler) checkUndefinedFuncs(_ *QueryContext, body Body) (Body, error) {
	errs := checkUndefinedFuncs(qc.compiler.TypeEnv, body, qc.compiler.GetArity, qc.rewritten)
	if len(errs) == 0 {
		return body, nil
	}
	return nil, errs
}
// checkSafety is the query stage that reorders body for safety, starting from
// a copy of the reserved variables as the safe set. It returns the reordered
// body, or the safety errors derived from the variables that remain unsafe.
func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) {
	safeVars := ReservedVars.Copy()

	ordered, stillUnsafe := reorderBodyForSafety(qc.compiler.builtins, qc.compiler.GetArity, safeVars, body)

	errs := safetyErrorSlice(stillUnsafe, qc.RewrittenVars())
	if len(errs) > 0 {
		return nil, errs
	}

	return ordered, nil
}
// checkTypes is the query stage that type checks body against the compiler's
// type environment. The resulting environment is stored in qc.typeEnv whether
// or not errors were found. It returns body unchanged when there are none.
func (qc *queryCompiler) checkTypes(_ *QueryContext, body Body) (Body, error) {
	env, errs := newTypeChecker().
		WithSchemaSet(qc.compiler.schemaSet).
		WithInputType(qc.compiler.inputType).
		WithVarRewriter(rewriteVarsInRef(qc.rewritten, qc.compiler.RewrittenVars)).
		CheckBody(qc.compiler.TypeEnv, body)

	qc.typeEnv = env

	if len(errs) > 0 {
		return nil, errs
	}
	return body, nil
}
// checkUnsafeBuiltins is the query stage that runs the unsafe built-in check
// on body. The set given to WithUnsafeBuiltins is used when it is non-nil;
// otherwise the compiler's set is used.
func (qc *queryCompiler) checkUnsafeBuiltins(_ *QueryContext, body Body) (Body, error) {
	blocked := qc.unsafeBuiltins
	if blocked == nil {
		blocked = qc.compiler.unsafeBuiltinsMap
	}

	if errs := checkUnsafeBuiltins(blocked, body); len(errs) > 0 {
		return nil, errs
	}

	return body, nil
}
// checkDeprecatedBuiltins is the query stage that runs the deprecated
// built-in check on body, using the compiler's deprecated set and strict
// setting. It returns body unchanged when no errors are found.
func (qc *queryCompiler) checkDeprecatedBuiltins(_ *QueryContext, body Body) (Body, error) {
	if errs := checkDeprecatedBuiltins(qc.compiler.deprecatedBuiltinsMap, body, qc.compiler.strict); len(errs) > 0 {
		return nil, errs
	}
	return body, nil
}
// rewriteWithModifiers is the query stage that returns the result of
// rewriteWithModifiersInBody for body, using a local variable generator
// created for the query. A failure is returned as a single-element Errors
// value.
func (qc *queryCompiler) rewriteWithModifiers(_ *QueryContext, body Body) (Body, error) {
	gen := newLocalVarGenerator("q", body)
	factory := newEqualityFactory(gen)

	rewritten, rewriteErr := rewriteWithModifiersInBody(qc.compiler, factory, body)
	if rewriteErr != nil {
		return nil, Errors{rewriteErr}
	}

	return rewritten, nil
}
// buildComprehensionIndices is the query stage that records comprehension
// indices for body in qc.comprehensionIndices. The body is returned unchanged
// and the stage never fails.
func (qc *queryCompiler) buildComprehensionIndices(_ *QueryContext, body Body) (Body, error) {
	// NOTE(tsandall): The query compiler does not have a metrics object so we
	// cannot record index metrics currently.
	rewritten := qc.RewrittenVars()
	_ = buildComprehensionIndices(qc.compiler.debug, qc.compiler.GetArity, ReservedVars, rewritten, body, qc.comprehensionIndices)

	return body, nil
}
// ComprehensionIndex specifies how the comprehension term can be indexed. The keys
// tell the evaluator what variables to use for indexing. In the future, the index
// could be expanded with more information that would allow the evaluator to index
// a larger fragment of comprehensions (e.g., by closing over variables in the outer
// query.)
type ComprehensionIndex struct {
// Term is the term the index belongs to; indices are stored keyed by it.
Term *Term
// Keys are the terms to index on.
Keys []*Term
}
// String returns the index keys formatted as "<keys: [...]>". A nil receiver
// yields the empty string.
func (ci *ComprehensionIndex) String() string {
	if ci != nil {
		return fmt.Sprintf("<keys: %v>", NewArray(ci.Keys...))
	}
	return ""
}
// buildComprehensionIndices walks the bodies found in node and stores a
// ComprehensionIndex in result, keyed by the index's term, for every
// expression that getComprehensionIndex can index. It returns the number of
// indices stored.
//
// candidates is the initial set of variables that may serve as index keys. It
// is copied, so the caller's set is not modified.
func buildComprehensionIndices(dbg debug.Debug, arity func(Ref) int, candidates VarSet, rwVars map[Var]Var, node interface{}, result map[*Term]*ComprehensionIndex) uint64 {
	var built uint64

	available := candidates.Copy()
	varParams := VarVisitorParams{SkipClosures: true, SkipRefCallHead: true}

	WalkBodies(node, func(b Body) bool {
		for _, expr := range b {
			if index := getComprehensionIndex(dbg, arity, available, rwVars, expr); index != nil {
				result[index.Term] = index
				built++
			}

			// Any variables appearing in the expressions leading up to the comprehension
			// are fair-game to be used as index keys.
			available.Update(expr.Vars(varParams))
		}
		return false
	})

	return built
}
// getComprehensionIndex returns the index for expr if expr has the form
// <var> = <comprehension> and the comprehension can be indexed on at least
// one of the candidate variables. It returns nil otherwise. rwVars is only
// used to print original variable names in the debug output.
func getComprehensionIndex(dbg debug.Debug, arity func(Ref) int, candidates VarSet, rwVars map[Var]Var, expr *Expr) *ComprehensionIndex {

	// Ignore everything except <var> = <comprehension> expressions. No debug
	// message is printed here: negation and with modifiers are assumed to be
	// known hindrances to comprehension indexing.
	if !expr.IsEquality() || expr.Negated || len(expr.With) > 0 {
		return nil
	}

	// Extract the comprehension term from the expression.
	left, right := expr.Operand(0), expr.Operand(1)
	_, leftIsVar := left.Value.(Var)
	_, rightIsVar := right.Value.(Var)

	var term *Term
	switch {
	case leftIsVar && IsComprehension(right.Value):
		term = right
	case rightIsVar && IsComprehension(left.Value):
		term = left
	default:
		// no debug for this, it's the ordinary "nothing to do here" case
		return nil
	}

	// Ignore comprehensions that contain expressions that close over variables
	// in the outer body if those variables are not also output variables in the
	// comprehension body. In other words, ignore comprehensions that we cannot
	// safely evaluate without bindings from the outer body. For example:
	//
	// x = [1]
	// [true | data.y[z] = x] # safe to evaluate w/o outer body
	// [true | data.y[z] = x[0]] # NOT safe to evaluate because 'x' would be unsafe.
	//
	// By identifying output variables in the body we also know what to index on by
	// intersecting with candidate variables from the outer query.
	//
	// For example:
	//
	// x = data.foo[_]
	// _ = [y | data.bar[y] = x] # index on 'x'
	//
	// This query goes from O(data.foo*data.bar) to O(data.foo+data.bar).
	var body Body
	switch c := term.Value.(type) {
	case *ArrayComprehension:
		body = c.Body
	case *SetComprehension:
		body = c.Body
	case *ObjectComprehension:
		body = c.Body
	}

	outputs := outputVarsForBody(body, arity, ReservedVars)
	if unsafe := body.Vars(SafetyCheckVisitorParams).Diff(outputs).Diff(ReservedVars); len(unsafe) > 0 {
		dbg.Printf("%s: comprehension index: unsafe vars: %v", expr.Location, unsafe)
		return nil
	}

	// Similarly, ignore comprehensions that contain references with output variables
	// that intersect with the candidates. Indexing these comprehensions could worsen
	// performance.
	regression := newComprehensionIndexRegressionCheckVisitor(candidates)
	regression.Walk(body)
	if regression.worse {
		dbg.Printf("%s: comprehension index: output vars intersect candidates", expr.Location)
		return nil
	}

	// Check if any nested comprehensions close over candidates. If any intersection is found
	// the comprehension cannot be cached because it would require closing over the candidates
	// which the evaluator does not support today.
	nested := newComprehensionIndexNestedCandidateVisitor(candidates)
	nested.Walk(body)
	if nested.found {
		dbg.Printf("%s: comprehension index: nested comprehensions close over candidates", expr.Location)
		return nil
	}

	// Make a sorted set of variable names that will serve as the index key set.
	// Sort to ensure deterministic indexing. In future this could be relaxed
	// if we can decide that one ordering is better than another. If the set is
	// empty, there is no indexing to do.
	indexVars := candidates.Intersect(outputs)
	if len(indexVars) == 0 {
		dbg.Printf("%s: comprehension index: no index vars", expr.Location)
		return nil
	}

	keys := make([]*Term, 0, len(indexVars))
	for v := range indexVars {
		keys = append(keys, NewTerm(v))
	}
	sort.Slice(keys, func(a, b int) bool {
		return keys[a].Value.Compare(keys[b].Value) < 0
	})

	// For the debug message, show rewritten variables under the name they are
	// mapped to in rwVars.
	shown := make([]*Term, len(keys))
	for i, k := range keys {
		shown[i] = k
		if orig, ok := rwVars[k.Value.(Var)]; ok {
			shown[i] = NewTerm(orig)
		}
	}
	dbg.Printf("%s: comprehension index: built with keys: %v", expr.Location, shown)

	return &ComprehensionIndex{Term: term, Keys: keys}
}
// comprehensionIndexRegressionCheckVisitor walks a comprehension body and
// flags it as "worse" when a ref's output variables (or a blacklisted
// built-in's checked operand) contain a candidate that has not been seen
// earlier in the walk.
type comprehensionIndexRegressionCheckVisitor struct {
// candidates are the variables that could be used as index keys.
candidates VarSet
// seen collects the variables visited so far.
seen VarSet
// worse is set once an intersection is found; the walk stops afterwards.
worse bool
}
// comprehensionIndexBlacklist maps a built-in function name to the operand
// position that the regression check inspects for candidate variables. The
// position stored is the number of arguments declared by the built-in.
//
// TODO(tsandall): Improve this so that users can either supply this list explicitly
// or the information is maintained on the built-in function declaration. What we really
// need to know is whether the built-in function allows callers to push down output
// values or not. It's unlikely that anything outside of OPA does this today so this
// solution is fine for now.
var comprehensionIndexBlacklist = map[string]int{
WalkBuiltin.Name: len(WalkBuiltin.Decl.Args()),
}
// newComprehensionIndexRegressionCheckVisitor returns a visitor that checks
// against the given candidates, starting with an empty set of seen variables.
func newComprehensionIndexRegressionCheckVisitor(candidates VarSet) *comprehensionIndexRegressionCheckVisitor {
	vis := new(comprehensionIndexRegressionCheckVisitor)
	vis.candidates = candidates
	vis.seen = NewVarSet()
	return vis
}
// Walk runs the regression check over x using a generic visitor.
func (vis *comprehensionIndexRegressionCheckVisitor) Walk(x interface{}) {
	walker := NewGenericVisitor(vis.visit)
	walker.Walk(x)
}
// visit is the per-node callback for Walk. It returns true (stop descending)
// once the body has been flagged as worse, and for nested comprehensions.
func (vis *comprehensionIndexRegressionCheckVisitor) visit(x interface{}) bool {
	if vis.worse {
		return true
	}

	switch node := x.(type) {
	case *Expr:
		// For blacklisted built-ins, check the operand at the recorded position.
		args := node.Operands()
		pos := comprehensionIndexBlacklist[node.Operator().String()]
		if pos > 0 && pos < len(args) {
			vis.assertEmptyIntersection(args[pos].Vars())
		}
	case Ref:
		vis.assertEmptyIntersection(node.OutputVars())
	case Var:
		vis.seen.Add(node)
	case *ArrayComprehension, *SetComprehension, *ObjectComprehension:
		// Always skip comprehensions. We do not have to visit their bodies here.
		return true
	}

	return vis.worse
}
// assertEmptyIntersection flags the visitor as worse if vs contains a
// candidate variable that has not been seen yet.
func (vis *comprehensionIndexRegressionCheckVisitor) assertEmptyIntersection(vs VarSet) {
	for v := range vs {
		if !vis.candidates.Contains(v) || vis.seen.Contains(v) {
			continue
		}
		vis.worse = true
		return
	}
}
// comprehensionIndexNestedCandidateVisitor walks a comprehension body looking
// for nested comprehensions that contain any of the candidate variables.
type comprehensionIndexNestedCandidateVisitor struct {
// candidates are the variables that could be used as index keys.
candidates VarSet
// found is set when a nested comprehension contains a candidate.
found bool
}
// newComprehensionIndexNestedCandidateVisitor returns a visitor that looks
// for the given candidates inside nested comprehensions.
func newComprehensionIndexNestedCandidateVisitor(candidates VarSet) *comprehensionIndexNestedCandidateVisitor {
	vis := new(comprehensionIndexNestedCandidateVisitor)
	vis.candidates = candidates
	return vis
}
// Walk runs the nested-candidate check over x using a generic visitor.
func (vis *comprehensionIndexNestedCandidateVisitor) Walk(x interface{}) {
	walker := NewGenericVisitor(vis.visit)
	walker.Walk(x)
}
// visit is the per-node callback for Walk. For each comprehension value it
// collects the variables inside it (skipping ref heads) and records whether
// any of them is a candidate. It returns true to stop descending once a
// match was found or after a comprehension has been inspected.
func (vis *comprehensionIndexNestedCandidateVisitor) visit(x interface{}) bool {
	if vis.found {
		return true
	}

	v, isValue := x.(Value)
	if !isValue || !IsComprehension(v) {
		return false
	}

	collector := NewVarVisitor().WithParams(VarVisitorParams{SkipRefHead: true})
	collector.Walk(v)
	shared := collector.Vars().Intersect(vis.candidates)
	vis.found = len(shared) > 0
	return true
}
// ModuleTreeNode represents a node in the module tree. The module
// tree is keyed by the package path.
type ModuleTreeNode struct {
// Key is the package path element this node stands for (unset on the root).
Key Value
// Modules are the modules whose package path ends at this node.
Modules []*Module
// Children maps the next package path element to its node.
Children map[Value]*ModuleTreeNode
// Hide is set by NewModuleTree for the node at path index 1 whose key
// equals SystemDocumentKey.
Hide bool
}
// NewModuleTree builds the module tree for mods and returns its root. Each
// module is attached to the node reached by following the elements of its
// package path from the root, creating nodes on the way as needed.
func NewModuleTree(mods map[string]*Module) *ModuleTreeNode {
	root := &ModuleTreeNode{Children: map[Value]*ModuleTreeNode{}}

	for _, mod := range mods {
		cursor := root
		for depth, part := range mod.Package.Path {
			child, exists := cursor.Children[part.Value]
			if !exists {
				child = &ModuleTreeNode{
					Key:      part.Value,
					Children: map[Value]*ModuleTreeNode{},
					// The system document directly below the root is hidden.
					Hide: depth == 1 && part.Value.Compare(SystemDocumentKey) == 0,
				}
				cursor.Children[part.Value] = child
			}
			cursor = child
		}
		cursor.Modules = append(cursor.Modules, mod)
	}

	return root
}
// Size reports how many modules are attached to n and all of its
// descendants.
func (n *ModuleTreeNode) Size() int {
	total := len(n.Modules)
	for _, child := range n.Children {
		total += child.Size()
	}
	return total
}
// DepthFirst calls f on n and then, unless f returned true, recurses into
// each child of n. Children are visited in map iteration order.
func (n *ModuleTreeNode) DepthFirst(f func(node *ModuleTreeNode) bool) {
	if f(n) {
		return
	}
	for _, child := range n.Children {
		child.DepthFirst(f)
	}
}
// TreeNode represents a node in the rule tree. The rule tree is keyed by
// rule path.
type TreeNode struct {
// Key is the path element this node stands for.
Key Value
// Values holds the rules of a rule set; NewRuleTree sets it on leaf nodes only.
Values []util.T
// Children maps the next path element to its node; nil on rule-set leaves.
Children map[Value]*TreeNode
// Sorted lists the child keys ordered by Compare.
Sorted []Value
// Hide is copied from the corresponding ModuleTreeNode.
Hide bool
}
// NewRuleTree returns a new TreeNode that represents the root
// of the rule tree populated with the given rules.
//
// Rules of the modules attached to mtree are grouped by rule name; each
// group becomes a leaf child. Each child of mtree becomes a child built
// recursively. Sorted holds all child keys ordered by Compare.
func NewRuleTree(mtree *ModuleTreeNode) *TreeNode {

	ruleSets := map[String][]util.T{}

	// Build rule sets for this package.
	for _, mod := range mtree.Modules {
		for _, rule := range mod.Rules {
			key := String(rule.Head.Name)
			ruleSets[key] = append(ruleSets[key], rule)
		}
	}

	// Both the rule sets and the sub-packages contribute one entry each, so
	// size the containers for the sum up front to avoid regrowth.
	size := len(ruleSets) + len(mtree.Children)
	children := make(map[Value]*TreeNode, size)
	sorted := make([]Value, 0, size)

	// Each rule set becomes a leaf node.
	for key, rules := range ruleSets {
		sorted = append(sorted, key)
		children[key] = &TreeNode{
			Key:      key,
			Children: nil,
			Values:   rules,
		}
	}

	// Each module in subpackage becomes child node.
	for key, child := range mtree.Children {
		sorted = append(sorted, key)
		children[child.Key] = NewRuleTree(child)
	}

	sort.Slice(sorted, func(i, j int) bool {
		return sorted[i].Compare(sorted[j]) < 0
	})

	return &TreeNode{
		Key:      mtree.Key,
		Values:   nil,
		Children: children,
		Sorted:   sorted,
		Hide:     mtree.Hide,
	}
}
// Size reports how many rules are stored on n and all of its descendants.
func (n *TreeNode) Size() int {
	total := len(n.Values)
	for _, child := range n.Children {
		total += child.Size()
	}
	return total
}
// Child looks up the child of n keyed by k. Only String and Var keys are
// looked up; any other kind of key yields nil.
func (n *TreeNode) Child(k Value) *TreeNode {
	_, isString := k.(String)
	_, isVar := k.(Var)
	if !isString && !isVar {
		return nil
	}
	return n.Children[k]
}
// DepthFirst calls f on n and then, unless f returned true, recurses into
// each child of n. Children are visited in map iteration order.
func (n *TreeNode) DepthFirst(f func(node *TreeNode) bool) {
	if f(n) {
		return
	}
	for _, child := range n.Children {
		child.DepthFirst(f)
	}
}
// Graph represents the graph of dependencies between rules.
type Graph struct {
// adj maps a node to the set of nodes it depends on.
adj map[util.T]map[util.T]struct{}
// radj is the reverse of adj: it maps a node to the set of nodes that depend on it.
radj map[util.T]map[util.T]struct{}
// nodes is the set of all nodes in the graph.
nodes map[util.T]struct{}
// sorted caches the result of a successful Sort.
sorted []util.T
}
// NewGraph returns a new Graph based on modules. The list function must return
// the rules referred to directly by the ref.
//
// Every rule visited by WalkRules becomes a node. For each ref found while
// walking a rule, an edge is added from that rule to every rule returned by
// list for the ref, and to each rule in those rules' else chains.
func NewGraph(modules map[string]*Module, list func(Ref) []*Rule) *Graph {
graph := &Graph{
adj: map[util.T]map[util.T]struct{}{},
radj: map[util.T]map[util.T]struct{}{},
nodes: map[util.T]struct{}{},
sorted: nil,
}
// Create visitor to walk a rule AST and add edges to the rule graph for
// each dependency.
vis := func(a *Rule) *GenericVisitor {
// stop becomes true after the first *Rule (a itself) has been visited, so
// any *Rule met later in the same walk is not descended into.
stop := false
return NewGenericVisitor(func(x interface{}) bool {
switch x := x.(type) {
case Ref:
for _, b := range list(x) {
// Depend on the referenced rule and on every rule in its else chain.
for node := b; node != nil; node = node.Else {
graph.addDependency(a, node)
}
}
case *Rule:
if stop {
// Do not recurse into else clauses (which will be handled
// by the outer visitor.)
return true
}
stop = true
}
return false
})
}
// Walk over all rules, add them to graph, and build adjacency lists.
for _, module := range modules {
WalkRules(module, func(a *Rule) bool {
graph.addNode(a)
vis(a).Walk(a)
return false
})
}
return graph
}
// Dependencies returns the set of rules that x depends on. The returned map
// is the graph's own adjacency set (nil if x has no recorded dependencies).
func (g *Graph) Dependencies(x util.T) map[util.T]struct{} {
	deps := g.adj[x]
	return deps
}
// Dependents returns the set of rules that depend on x. The returned map is
// the graph's own reverse adjacency set (nil if nothing depends on x).
func (g *Graph) Dependents(x util.T) map[util.T]struct{} {
	dependents := g.radj[x]
	return dependents
}
// Sort returns the nodes ordered so that every node appears after the nodes
// it depends on. A successful result is cached on the graph and reused by
// later calls. If a cycle is found, sorted is nil and ok is false.
func (g *Graph) Sort() (sorted []util.T, ok bool) {
	if g.sorted == nil {
		s := &graphSort{
			sorted: make([]util.T, 0, len(g.nodes)),
			deps:   g.Dependencies,
			marked: map[util.T]struct{}{},
			temp:   map[util.T]struct{}{},
		}
		for n := range g.nodes {
			if !s.Visit(n) {
				return nil, false
			}
		}
		g.sorted = s.sorted
	}
	return g.sorted, true
}
// addDependency records that u depends on v. Both nodes are added to the
// graph if missing, v is added to u's adjacency set and u to v's reverse
// adjacency set.
func (g *Graph) addDependency(u util.T, v util.T) {
	for _, n := range []util.T{u, v} {
		if _, known := g.nodes[n]; !known {
			g.addNode(n)
		}
	}

	// link inserts to into index[from], creating the set on first use.
	link := func(index map[util.T]map[util.T]struct{}, from, to util.T) {
		set, ok := index[from]
		if !ok {
			set = map[util.T]struct{}{}
			index[from] = set
		}
		set[to] = struct{}{}
	}

	link(g.adj, u, v)
	link(g.radj, v, u)
}
// addNode inserts n into the graph's node set.
func (g *Graph) addNode(n util.T) {
	var member struct{}
	g.nodes[n] = member
}
// graphSort holds the state of a depth-first topological sort.
type graphSort struct {
// sorted accumulates nodes in the order they finish being visited.
sorted []util.T
// deps returns the set of nodes a node depends on.
deps func(util.T) map[util.T]struct{}
// marked holds nodes that have been fully visited.
marked map[util.T]struct{}
// temp holds nodes currently being visited; meeting one again means a cycle.
temp map[util.T]struct{}
}
// Marked reports whether node has already been fully visited.
func (s *graphSort) Marked(node util.T) bool {
	_, done := s.marked[node]
	return done
}
// Visit performs a depth-first visit of node and its dependencies, appending
// each node to the sorted list after its dependencies. It returns false if a
// node is reached again while it is still being visited, i.e. on a cycle.
func (s *graphSort) Visit(node util.T) (ok bool) {
	if _, inProgress := s.temp[node]; inProgress {
		return false
	}
	if s.Marked(node) {
		return true
	}

	s.temp[node] = struct{}{}
	for dep := range s.deps(node) {
		if !s.Visit(dep) {
			return false
		}
	}

	s.marked[node] = struct{}{}
	delete(s.temp, node)
	s.sorted = append(s.sorted, node)
	return true
}
// GraphTraversal is a Traversal that understands the dependency graph
type GraphTraversal struct {
// graph is the dependency graph being traversed.
graph *Graph
// visited is the set of nodes passed to Visited so far.
visited map[util.T]struct{}
}
// NewGraphTraversal returns a Traversal for the dependency graph with an
// empty visited set.
func NewGraphTraversal(graph *Graph) *GraphTraversal {
	t := new(GraphTraversal)
	t.graph = graph
	t.visited = map[util.T]struct{}{}
	return t
}
// Edges lists all dependency connections for a given node. The result is a
// fresh, non-nil slice holding the nodes x depends on, in map iteration
// order.
func (g *GraphTraversal) Edges(x util.T) []util.T {
	deps := g.graph.Dependencies(x)
	// The final length is known, so allocate once instead of growing.
	r := make([]util.T, 0, len(deps))
	for v := range deps {
		r = append(r, v)
	}
	return r
}
// Visited reports whether u was visited before this call and marks u as
// visited.
func (g *GraphTraversal) Visited(u util.T) bool {
	_, seen := g.visited[u]
	if !seen {
		g.visited[u] = struct{}{}
	}
	return seen
}
// unsafePair couples an expression with the set of its unsafe variables.
type unsafePair struct {
Expr *Expr
Vars VarSet
}
// unsafeVarLoc couples an unsafe variable with the location chosen for it
// by unsafeVars.Vars.
type unsafeVarLoc struct {
Var Var
Loc *Location
}
// unsafeVars maps an expression to the set of variables that are unsafe in it.
type unsafeVars map[*Expr]VarSet
// Add records v as unsafe in e, creating the set for e on first use.
func (vs unsafeVars) Add(e *Expr, v Var) {
	existing, ok := vs[e]
	if !ok {
		vs[e] = VarSet{v: struct{}{}}
		return
	}
	existing[v] = struct{}{}
}
// Set replaces the unsafe variable set recorded for e with s. The set is
// stored as-is, not copied.
func (vs unsafeVars) Set(e *Expr, s VarSet) {
vs[e] = s
}
// Update merges o into vs: for each expression in o, its variables are
// added to the set recorded for that expression in vs.
func (vs unsafeVars) Update(o unsafeVars) {
	for expr, vars := range o {
		dst, ok := vs[expr]
		if !ok {
			dst = VarSet{}
			vs[expr] = dst
		}
		dst.Update(vars)
	}
}
// Vars returns one entry per unsafe variable, each paired with a location of
// an expression it is unsafe in, sorted by location.
func (vs unsafeVars) Vars() (result []unsafeVarLoc) {

	first := map[Var]*Location{}

	// If var appears in multiple sets then pick first by location.
	for expr, vars := range vs {
		for v := range vars {
			if first[v].Compare(expr.Location) > 0 {
				first[v] = expr.Location
			}
		}
	}

	for v, loc := range first {
		result = append(result, unsafeVarLoc{Var: v, Loc: loc})
	}

	sort.Slice(result, func(a, b int) bool {
		return result[a].Loc.Compare(result[b].Loc) < 0
	})

	return result
}
// Slice returns the contents of vs as (expression, variables) pairs in map
// iteration order. The result is nil when vs is empty.
func (vs unsafeVars) Slice() (result []unsafePair) {
	for expr, vars := range vs {
		pair := unsafePair{Expr: expr, Vars: vars}
		result = append(result, pair)
	}
	return result
}
// reorderBodyForSafety returns a copy of the body ordered such that
// left to right evaluation of the body will not encounter unbound variables
// in input positions or negated expressions.
//
// Expressions are added to the re-ordered body as soon as they are considered
// safe. If multiple expressions become safe in the same pass, they are added
// in their original order. This results in minimal re-ordering of the body.
//
// If the body cannot be reordered to ensure safety, the second return value
// contains a mapping of expressions to unsafe variables in those expressions.
func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) {
// First order expressions containing closures after the expressions that
// output the variables they close over. Give up if that is not possible.
body, unsafe := reorderBodyForClosures(arity, globals, body)
if len(unsafe) != 0 {
return nil, unsafe
}
reordered := Body{}
safe := VarSet{}
// Seed the state: globals are safe from the start, every other variable is
// recorded as unsafe for the expression it appears in.
for _, e := range body {
for v := range e.Vars(SafetyCheckVisitorParams) {
if globals.Contains(v) {
safe.Add(v)
} else {
unsafe.Add(e, v)
}
}
}
// Repeat passes over the body until a pass adds no expression.
for {
n := len(reordered)
for _, e := range body {
if reordered.Contains(e) {
continue
}
// The output variables of e become safe, then any of e's unsafe variables
// that are now safe are dropped.
safe.Update(outputVarsForExpr(e, arity, safe))
for v := range unsafe[e] {
if safe.Contains(v) {
delete(unsafe[e], v)
}
}
// e is placed once none of its variables remain unsafe.
if len(unsafe[e]) == 0 {
delete(unsafe, e)
reordered.Append(e)
}
}
if len(reordered) == n {
break
}
}
// Recursively visit closures and perform the safety checks on them.
// Update the globals at each expression to include the variables that could
// be closed over.
g := globals.Copy()
for i, e := range reordered {
if i > 0 {
g.Update(reordered[i-1].Vars(SafetyCheckVisitorParams))
}
// The transformer shares g and unsafe, so unsafe variables found inside
// closures are added to the map returned below.
xform := &bodySafetyTransformer{
builtins: builtins,
arity: arity,
current: e,
globals: g,
unsafe: unsafe,
}
NewGenericVisitor(xform.Visit).Walk(e)
}
return reordered, unsafe
}
// bodySafetyTransformer reorders the bodies of closures (comprehensions and
// every expressions) found inside a single expression for safety.
type bodySafetyTransformer struct {
// builtins and arity are passed through to reorderBodyForSafety.
builtins map[string]*Builtin
arity func(Ref) int
// current is the expression that unsafe term variables are reported against.
current *Expr
// globals are the variables treated as safe inside the closure bodies.
globals VarSet
// unsafe collects the unsafe variables found while transforming.
unsafe unsafeVars
}
// Visit is the generic visitor callback. It reorders the bodies of
// comprehensions and every expressions in place and returns true for the
// nodes it has handled itself so that the caller does not descend into them.
func (xform *bodySafetyTransformer) Visit(x interface{}) bool {
switch term := x.(type) {
case *Term:
switch x := term.Value.(type) {
case *object:
// Rebuild the object from copies of its keys and values, each walked with
// this transformer, and replace the term's value with the rebuilt object.
cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) {
kcpy := k.Copy()
NewGenericVisitor(xform.Visit).Walk(kcpy)
vcpy := v.Copy()
NewGenericVisitor(xform.Visit).Walk(vcpy)
return kcpy, vcpy, nil
})
term.Value = cpy
return true
case *set:
// Same as for objects: rebuild the set from walked copies of its elements.
cpy, _ := x.Map(func(v *Term) (*Term, error) {
vcpy := v.Copy()
NewGenericVisitor(xform.Visit).Walk(vcpy)
return vcpy, nil
})
term.Value = cpy
return true
case *ArrayComprehension:
xform.reorderArrayComprehensionSafety(x)
return true
case *ObjectComprehension:
xform.reorderObjectComprehensionSafety(x)
return true
case *SetComprehension:
xform.reorderSetComprehensionSafety(x)
return true
}
case *Expr:
if ev, ok := term.Terms.(*Every); ok {
// The key/value variables of the every expression are added to the
// transformer's globals before its body is reordered. There are no term
// variables to check, hence the empty set.
xform.globals.Update(ev.KeyValueVars())
ev.Body = xform.reorderComprehensionSafety(NewVarSet(), ev.Body)
return true
}
}
return false
}
// reorderComprehensionSafety reorders a closure body for safety. tv holds
// the variables of the closure's term(s); any of them that appear neither in
// the body nor in the globals are reported as unsafe for the current
// expression. If the body cannot be reordered, the unsafe variables are
// recorded and the original body is returned.
func (xform *bodySafetyTransformer) reorderComprehensionSafety(tv VarSet, body Body) Body {
	visible := body.Vars(SafetyCheckVisitorParams)
	visible.Update(xform.globals)
	for v := range tv.Diff(visible) {
		xform.unsafe.Add(xform.current, v)
	}

	reordered, unsafe := reorderBodyForSafety(xform.builtins, xform.arity, xform.globals, body)
	if len(unsafe) > 0 {
		xform.unsafe.Update(unsafe)
		return body
	}
	return reordered
}
// reorderArrayComprehensionSafety reorders the body of ac in place, checking
// the variables of its term.
func (xform *bodySafetyTransformer) reorderArrayComprehensionSafety(ac *ArrayComprehension) {
	termVars := ac.Term.Vars()
	ac.Body = xform.reorderComprehensionSafety(termVars, ac.Body)
}
// reorderObjectComprehensionSafety reorders the body of oc in place,
// checking the variables of both its key and its value.
func (xform *bodySafetyTransformer) reorderObjectComprehensionSafety(oc *ObjectComprehension) {
	termVars := oc.Key.Vars()
	valueVars := oc.Value.Vars()
	termVars.Update(valueVars)
	oc.Body = xform.reorderComprehensionSafety(termVars, oc.Body)
}
// reorderSetComprehensionSafety reorders the body of sc in place, checking
// the variables of its term.
func (xform *bodySafetyTransformer) reorderSetComprehensionSafety(sc *SetComprehension) {
	termVars := sc.Term.Vars()
	sc.Body = xform.reorderComprehensionSafety(termVars, sc.Body)
}
// reorderBodyForClosures returns a copy of the body ordered such that
// expressions (such as array comprehensions) that close over variables are ordered
// after other expressions that contain the same variable in an output position.
//
// Expressions that cannot be placed are left out of the returned body and
// reported in the second return value, mapped to the closed-over variables
// that no placed expression outputs.
func reorderBodyForClosures(arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) {

	reordered := Body{}
	unsafe := unsafeVars{}

	// The variables of the input body do not change while reordering, so
	// collect them once rather than for every expression on every pass.
	bodyVars := body.Vars(SafetyCheckVisitorParams)

	for {
		n := len(reordered)

		for _, e := range body {
			if reordered.Contains(e) {
				continue
			}

			// Collect vars that are contained in closures within this
			// expression.
			vs := VarSet{}
			WalkClosures(e, func(x interface{}) bool {
				vis := &VarVisitor{vars: vs}
				if ev, ok := x.(*Every); ok {
					// For every expressions only the body is inspected.
					vis.Walk(ev.Body)
					return true
				}
				vis.Walk(x)
				return true
			})

			// Compute vars that are closed over from the body but not yet
			// contained in the output position of an expression in the reordered
			// body. These vars are considered unsafe.
			cv := vs.Intersect(bodyVars).Diff(globals)
			uv := cv.Diff(outputVarsForBody(reordered, arity, globals))

			if len(uv) == 0 {
				reordered = append(reordered, e)
				delete(unsafe, e)
			} else {
				unsafe.Set(e, uv)
			}
		}

		// Stop once a full pass places no further expression.
		if len(reordered) == n {
			break
		}
	}

	return reordered, unsafe
}
// OutputVarsFromBody returns all variables which are the "output" for
// the given body. For safety checks this means that they would be
// made safe by the body. Function arities are looked up on the compiler c.
func OutputVarsFromBody(c *Compiler, body Body, safe VarSet) VarSet {
	arity := c.GetArity
	return outputVarsForBody(body, arity, safe)
}
// outputVarsForBody returns the variables made safe by evaluating body left
// to right starting from safe. Variables already in safe are not part of the
// result, and safe itself is not modified.
func outputVarsForBody(body Body, arity func(Ref) int, safe VarSet) VarSet {
	acc := safe.Copy()
	for i := range body {
		acc.Update(outputVarsForExpr(body[i], arity, acc))
	}
	return acc.Diff(safe)
}
// OutputVarsFromExpr returns all variables which are the "output" for
// the given expression. For safety checks this means that they would be
// made safe by the expr. Function arities are looked up on the compiler c.
func OutputVarsFromExpr(c *Compiler, expr *Expr, safe VarSet) VarSet {
	arity := c.GetArity
	return outputVarsForExpr(expr, arity, safe)
}
// outputVarsForExpr returns the variables that expr makes safe, given the
// already safe variables. Negated expressions, expressions whose with
// modifiers contain unsafe variables, and calls whose operator is not a ref
// or has a negative arity produce no outputs. It panics on an unknown kind
// of expression.
func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet {

	// Negated expressions must be safe.
	if expr.Negated {
		return VarSet{}
	}

	// With modifier inputs must be safe.
	for _, with := range expr.With {
		collector := NewVarVisitor().WithParams(SafetyCheckVisitorParams)
		collector.Walk(with)
		if unsafe := collector.Vars().Diff(safe); len(unsafe) > 0 {
			return VarSet{}
		}
	}

	switch terms := expr.Terms.(type) {
	case *Term:
		return outputVarsForTerms(expr, safe)
	case []*Term:
		if expr.IsEquality() {
			return outputVarsForExprEq(expr, safe)
		}
		if operator, ok := terms[0].Value.(Ref); ok {
			if n := arity(operator); n >= 0 {
				return outputVarsForExprCall(expr, n, safe, terms)
			}
		}
		return VarSet{}
	case *Every:
		return outputVarsForTerms(terms.Domain, safe)
	default:
		panic("illegal expression")
	}
}
// outputVarsForExprEq returns the variables made safe by the equality expr:
// the output variables of its terms plus those bound by unifying its two
// operands, minus the variables already in safe. If expr does not have a
// valid operand count, safe itself is returned.
func outputVarsForExprEq(expr *Expr, safe VarSet) VarSet {
	if validEqAssignArgCount(expr) {
		bound := outputVarsForTerms(expr, safe)
		bound.Update(safe)
		unified := Unify(bound, expr.Operand(0), expr.Operand(1))
		bound.Update(unified)
		return bound.Diff(safe)
	}
	return safe
}
// outputVarsForExprCall returns the variables made safe by the call expr.
// terms holds the operator followed by the operands; the first arity
// operands are inputs and anything after them is treated as output. If an
// input contains a variable that is neither safe nor an output of the
// expression's terms, the call produces no outputs.
func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) VarSet {

	output := outputVarsForTerms(expr, safe)

	// Index of the first output term (operator + arity inputs come first).
	firstOutput := arity + 1
	if len(terms) <= firstOutput {
		return output
	}

	params := VarVisitorParams{
		SkipClosures:   true,
		SkipSets:       true,
		SkipObjectKeys: true,
		SkipRefHead:    true,
	}

	inputs := NewVarVisitor().WithParams(params)
	inputs.Walk(Args(terms[:firstOutput]))
	if unsafe := inputs.Vars().Diff(output).Diff(safe); len(unsafe) > 0 {
		return VarSet{}
	}

	outputs := NewVarVisitor().WithParams(params)
	outputs.Walk(Args(terms[firstOutput:]))
	output.Update(outputs.vars)
	return output
}
// outputVarsForTerms returns the output variables of the safe refs found in
// the terms under expr. Comprehensions and refs that are not safe are not
// descended into.
func outputVarsForTerms(expr interface{}, safe VarSet) VarSet {
	output := VarSet{}
	WalkTerms(expr, func(t *Term) bool {
		switch v := t.Value.(type) {
		case *SetComprehension, *ArrayComprehension, *ObjectComprehension:
			return true
		case Ref:
			if isRefSafe(v, safe) {
				output.Update(v.OutputVars())
				return false
			}
			return true
		default:
			return false
		}
	})
	return output
}
// equalityFactory generates equality expressions that bind a term to a
// freshly generated local variable.
type equalityFactory struct {
// gen supplies the local variable names.
gen *localVarGenerator
}
// newEqualityFactory returns a factory that draws variable names from gen.
func newEqualityFactory(gen *localVarGenerator) *equalityFactory {
	return &equalityFactory{gen: gen}
}
// Generate returns a generated equality expression of the form
// <new local var> = other. Both the new variable term and the expression
// take the location of other.
func (f *equalityFactory) Generate(other *Term) *Expr {
	local := NewTerm(f.gen.Generate()).SetLocation(other.Location)
	eq := Equality.Expr(local, other)
	eq.Generated = true
	eq.Location = other.Location
	return eq
}
// localVarGenerator generates variable names of the form
// __local<suffix><n>__ that do not collide with a set of excluded variables.
type localVarGenerator struct {
// exclude holds the variable names that must not be generated.
exclude VarSet
// suffix is inserted between "__local" and the counter.
suffix string
// next is the counter value used for the next candidate name.
next int
}
// newLocalVarGeneratorForModuleSet returns a generator (with no suffix) that
// excludes every variable found in the modules named by sorted.
func newLocalVarGeneratorForModuleSet(sorted []string, modules map[string]*Module) *localVarGenerator {
	taken := NewVarSet()
	collector := &VarVisitor{vars: taken}
	for i := range sorted {
		collector.Walk(modules[sorted[i]])
	}
	return &localVarGenerator{exclude: taken}
}
// newLocalVarGenerator returns a generator using the given suffix that
// excludes every variable found in node.
func newLocalVarGenerator(suffix string, node interface{}) *localVarGenerator {
	taken := NewVarSet()
	collector := &VarVisitor{vars: taken}
	collector.Walk(node)
	return &localVarGenerator{suffix: suffix, exclude: taken}
}
// Generate returns the next variable name of the form __local<suffix><n>__
// that is not in the exclude set. The counter advances for every candidate
// tried, so a name is never handed out twice.
func (l *localVarGenerator) Generate() Var {
	for {
		name := "__local" + l.suffix + strconv.Itoa(l.next) + "__"
		l.next++
		if candidate := Var(name); !l.exclude.Contains(candidate) {
			return candidate
		}
	}
}
// getGlobals returns the names that can be resolved to full references
// inside a package: each rule name maps to the package path extended by the
// rule name, and each import name maps to the import's path. Imports are
// applied last and therefore win over rules of the same name.
func getGlobals(pkg *Package, rules []Var, imports []*Import) map[Var]*usedRef {

	globals := map[Var]*usedRef{}

	// Populate globals with exports within the package.
	for _, name := range rules {
		path := append(append(Ref{}, pkg.Path...), &Term{Value: String(name)})
		globals[name] = &usedRef{ref: path}
	}

	// Populate globals with imports.
	for _, imp := range imports {
		globals[imp.Name()] = &usedRef{ref: imp.Path.Value.(Ref)}
	}

	return globals
}
// requiresEval reports whether x is non-nil and contains refs or
// comprehensions.
func requiresEval(x *Term) bool {
	return x != nil && (ContainsRefs(x) || ContainsComprehensions(x))
}
// resolveRef returns a copy of ref in which variables naming a global (and
// not shadowed by a declared variable in ignore) are replaced by the
// global's full reference. A global at the head of ref replaces the head
// with the full reference's terms; elsewhere the full reference is inserted
// as a single nested term. Composite elements are resolved recursively.
// Globals that get substituted are marked as used.
func resolveRef(globals map[Var]*usedRef, ignore *declaredVarStack, ref Ref) Ref {
	resolved := Ref{}
	for pos, elem := range ref {
		switch val := elem.Value.(type) {
		case Var:
			g, isGlobal := globals[val]
			if !isGlobal || ignore.Contains(val) {
				resolved = append(resolved, elem)
				continue
			}
			fqn := g.ref.Copy()
			for j := range fqn {
				fqn[j].SetLocation(elem.Location)
			}
			if pos == 0 {
				resolved = fqn
			} else {
				resolved = append(resolved, NewTerm(fqn).SetLocation(elem.Location))
			}
			g.used = true
		case Ref, *Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
			resolved = append(resolved, resolveRefsInTerm(globals, ignore, elem))
		default:
			resolved = append(resolved, elem)
		}
	}
	return resolved
}
// usedRef is a global's full reference together with a flag recording
// whether the global has been substituted during resolution.
type usedRef struct {
ref Ref
used bool
}
// resolveRefsInRule resolves global names in the key, value and body of rule
// in place. Variables collected from the rule's args and the variables
// declared in its body are not resolved. An error is returned if an arg is a
// root document ref.
func resolveRefsInRule(globals map[Var]*usedRef, rule *Rule) error {
ignore := &declaredVarStack{}
vars := NewVarSet()
// vis is declared before it is assigned because the callback below refers to it.
var vis *GenericVisitor
var err error
// Walk args to collect vars and transform body so that callers can shadow
// root documents.
vis = NewGenericVisitor(func(x interface{}) bool {
if err != nil {
return true
}
switch x := x.(type) {
case Var:
vars.Add(x)
// Object keys cannot be pattern matched so only walk values.
// NOTE(review): this case does not return true, so the callback falls
// through to `return false` below. If the generic visitor then descends
// into the object, keys are visited as well — confirm against
// GenericVisitor.Walk.
case *object:
x.Foreach(func(k, v *Term) {
vis.Walk(v)
})
// Skip terms that could contain vars that cannot be pattern matched.
case Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
return true
case *Term:
if _, ok := x.Value.(Ref); ok {
if RootDocumentRefs.Contains(x) {
// We could support args named input, data, etc. however
// this would require rewriting terms in the head and body.
// Preventing root document shadowing is simpler, and
// arguably, will prevent confusing names from being used.
err = fmt.Errorf("args must not shadow %v (use a different variable name)", x)
return true
}
}
}
return false
})
vis.Walk(rule.Head.Args)
if err != nil {
return err
}
// Arg vars are pushed first, then the vars declared in the body.
ignore.Push(vars)
ignore.Push(declaredVars(rule.Body))
if rule.Head.Key != nil {
rule.Head.Key = resolveRefsInTerm(globals, ignore, rule.Head.Key)
}
if rule.Head.Value != nil {
rule.Head.Value = resolveRefsInTerm(globals, ignore, rule.Head.Value)
}
rule.Body = resolveRefsInBody(globals, ignore, rule.Body)
return nil
}
// resolveRefsInBody returns a new body holding the resolved version of each
// expression in body, in the same order.
func resolveRefsInBody(globals map[Var]*usedRef, ignore *declaredVarStack, body Body) Body {
	resolved := make(Body, len(body))
	for i := range body {
		resolved[i] = resolveRefsInExpr(globals, ignore, body[i])
	}
	return resolved
}
// resolveRefsInExpr returns a shallow copy of expr whose terms have been
// replaced by resolved terms. The with modifiers of the copy are resolved by
// assigning to their Target and Value fields.
func resolveRefsInExpr(globals map[Var]*usedRef, ignore *declaredVarStack, expr *Expr) *Expr {
cpy := *expr
switch ts := expr.Terms.(type) {
case *Term:
cpy.Terms = resolveRefsInTerm(globals, ignore, ts)
case []*Term:
buf := make([]*Term, len(ts))
for i := 0; i < len(ts); i++ {
buf[i] = resolveRefsInTerm(globals, ignore, ts[i])
}
cpy.Terms = buf
case *SomeDecl:
// Only the call form of the first symbol is resolved; any other
// declaration is kept as copied.
if val, ok := ts.Symbols[0].Value.(Call); ok {
cpy.Terms = &SomeDecl{Symbols: []*Term{CallTerm(resolveRefsInTermSlice(globals, ignore, val)...)}}
}
case *Every:
// The key and value vars of the every expression are pushed as declared
// vars while its domain and body are resolved.
locals := NewVarSet()
if ts.Key != nil {
locals.Update(ts.Key.Vars())
}
locals.Update(ts.Value.Vars())
ignore.Push(locals)
cpy.Terms = &Every{
Key: ts.Key.Copy(), // TODO(sr): do more?
Value: ts.Value.Copy(), // TODO(sr): do more?
Domain: resolveRefsInTerm(globals, ignore, ts.Domain),
Body: resolveRefsInBody(globals, ignore, ts.Body),
}
ignore.Pop()
}
// NOTE(review): cpy is a shallow copy, so cpy.With is the same slice as
// expr.With. If its elements are pointers, these assignments also change
// the with modifiers of the original expression — confirm this is intended.
for _, w := range cpy.With {
w.Target = resolveRefsInTerm(globals, ignore, w.Target)
w.Value = resolveRefsInTerm(globals, ignore, w.Value)
}
return &cpy
}
// resolveRefsInTerm returns term with global names resolved. A variable
// naming a global (and not shadowed by a declared variable in ignore) is
// replaced by a new term holding the global's full reference; composite
// values are resolved recursively into a shallow copy of term. Terms of any
// other kind are returned unchanged. While a comprehension is resolved, the
// variables declared in its body are pushed onto ignore.
func resolveRefsInTerm(globals map[Var]*usedRef, ignore *declaredVarStack, term *Term) *Term {

	// withValue returns a shallow copy of term carrying the given value.
	withValue := func(v Value) *Term {
		dup := *term
		dup.Value = v
		return &dup
	}

	switch v := term.Value.(type) {
	case Var:
		g, ok := globals[v]
		if !ok || ignore.Contains(v) {
			return term
		}
		fqn := g.ref.Copy()
		for i := range fqn {
			fqn[i].SetLocation(term.Location)
		}
		g.used = true
		return NewTerm(fqn).SetLocation(term.Location)
	case Ref:
		return withValue(resolveRef(globals, ignore, v))
	case *object:
		resolved, _ := v.Map(func(k, val *Term) (*Term, *Term, error) {
			return resolveRefsInTerm(globals, ignore, k), resolveRefsInTerm(globals, ignore, val), nil
		})
		return withValue(resolved)
	case *Array:
		return withValue(NewArray(resolveRefsInTermArray(globals, ignore, v)...))
	case Call:
		return withValue(Call(resolveRefsInTermSlice(globals, ignore, v)))
	case Set:
		resolved, _ := v.Map(func(e *Term) (*Term, error) {
			return resolveRefsInTerm(globals, ignore, e), nil
		})
		return withValue(resolved)
	case *ArrayComprehension:
		ignore.Push(declaredVars(v.Body))
		ac := &ArrayComprehension{
			Term: resolveRefsInTerm(globals, ignore, v.Term),
			Body: resolveRefsInBody(globals, ignore, v.Body),
		}
		ignore.Pop()
		return withValue(ac)
	case *ObjectComprehension:
		ignore.Push(declaredVars(v.Body))
		oc := &ObjectComprehension{
			Key:   resolveRefsInTerm(globals, ignore, v.Key),
			Value: resolveRefsInTerm(globals, ignore, v.Value),
			Body:  resolveRefsInBody(globals, ignore, v.Body),
		}
		ignore.Pop()
		return withValue(oc)
	case *SetComprehension:
		ignore.Push(declaredVars(v.Body))
		sc := &SetComprehension{
			Term: resolveRefsInTerm(globals, ignore, v.Term),
			Body: resolveRefsInBody(globals, ignore, v.Body),
		}
		ignore.Pop()
		return withValue(sc)
	default:
		return term
	}
}
// resolveRefsInTermArray returns a new slice holding the resolved version of
// each element of the array, in order.
func resolveRefsInTermArray(globals map[Var]*usedRef, ignore *declaredVarStack, terms *Array) []*Term {
	resolved := make([]*Term, terms.Len())
	for i := range resolved {
		resolved[i] = resolveRefsInTerm(globals, ignore, terms.Elem(i))
	}
	return resolved
}
// resolveRefsInTermSlice returns a new slice holding the resolved version of
// each term, in order.
func resolveRefsInTermSlice(globals map[Var]*usedRef, ignore *declaredVarStack, terms []*Term) []*Term {
	resolved := make([]*Term, len(terms))
	for i, t := range terms {
		resolved[i] = resolveRefsInTerm(globals, ignore, t)
	}
	return resolved
}
// declaredVarStack is a stack of declared variable sets, one per scope, with
// the innermost scope last.
type declaredVarStack []VarSet
// Contains reports whether v is declared in any scope on the stack,
// searching from the innermost scope outwards.
func (s declaredVarStack) Contains(v Var) bool {
	for depth := len(s); depth > 0; depth-- {
		if _, declared := s[depth-1][v]; declared {
			return true
		}
	}
	return false
}
// Add declares v in the innermost scope. The stack must not be empty.
func (s declaredVarStack) Add(v Var) {
	top := s[len(s)-1]
	top.Add(v)
}
// Push makes vs the new innermost scope.
func (s *declaredVarStack) Push(vs VarSet) {
	grown := append(*s, vs)
	*s = grown
}
// Pop removes the innermost scope. The stack must not be empty.
func (s *declaredVarStack) Pop() {
	*s = (*s)[:len(*s)-1]
}
// declaredVars returns the variables declared under x: the variables on the
// left-hand side of assignments and the variables introduced by some
// declarations. Comprehensions are not descended into.
func declaredVars(x interface{}) VarSet {
	declared := NewVarSet()

	// record adds v to the result; it is used as a WalkVars callback.
	record := func(v Var) bool {
		declared.Add(v)
		return false
	}

	visit := func(node interface{}) bool {
		switch node := node.(type) {
		case *ArrayComprehension, *SetComprehension, *ObjectComprehension:
			return true
		case *Expr:
			if node.IsAssignment() && validEqAssignArgCount(node) {
				WalkVars(node.Operand(0), record)
				return false
			}
			decl, isSome := node.Terms.(*SomeDecl)
			if !isSome {
				return false
			}
			for _, sym := range decl.Symbols {
				switch val := sym.Value.(type) {
				case Var:
					declared.Add(val)
				case Call:
					args := val[1:]
					if len(args) == 3 { // some x, y in xs
						WalkVars(args[1], record)
					}
					// some x in xs
					WalkVars(args[0], record)
				}
			}
		}
		return false
	}

	NewGenericVisitor(visit).Walk(x)
	return declared
}
// rewriteComprehensionTerms will rewrite comprehensions so that the term part
// is bound to a variable in the body. This allows any type of term to be used
// in the term part (even if the term requires evaluation.)
//
// For instance, given the following comprehension:
//
// [x[0] | x = y[_]; y = [1,2,3]]
//
// The comprehension would be rewritten as:
//
// [__local0__ | x = y[_]; y = [1,2,3]; __local0__ = x[0]]
//
// For object comprehensions the key is rewritten before the value.
func rewriteComprehensionTerms(f *equalityFactory, node interface{}) (interface{}, error) {

	// bind returns the term to use in the head of a comprehension. If t
	// requires evaluation, an equality binding t to a new local variable is
	// appended to body and the variable is returned; otherwise t is returned.
	bind := func(t *Term, body *Body) *Term {
		if !requiresEval(t) {
			return t
		}
		eq := f.Generate(t)
		local := eq.Operand(0)
		body.Append(eq)
		return local
	}

	return TransformComprehensions(node, func(x interface{}) (Value, error) {
		switch c := x.(type) {
		case *ArrayComprehension:
			c.Term = bind(c.Term, &c.Body)
			return c, nil
		case *SetComprehension:
			c.Term = bind(c.Term, &c.Body)
			return c, nil
		case *ObjectComprehension:
			c.Key = bind(c.Key, &c.Body)
			c.Value = bind(c.Value, &c.Body)
			return c, nil
		}
		panic("illegal type")
	})
}
// rewriteEquals will rewrite exprs under x as unification calls instead of ==
// calls. For example:
//
// data.foo == data.bar is rewritten as data.foo = data.bar
//
// This stage should only run after the safety check (since == is a built-in
// with no outputs, so the inputs must not be marked as safe.)
//
// This stage is not executed by the query compiler by default because when
// callers specify == instead of = they expect to receive a true/false/undefined
// result back whereas with = the result is only ever true/undefined. For
// partial evaluation cases we do want to rewrite == to = to simplify the
// result.
//
// Only == calls with exactly two operands are rewritten, and they are
// rewritten in place.
func rewriteEquals(x interface{}) {
	doubleEq := Equal.Ref()
	unifyOp := Equality.Ref()

	rewrite := func(node interface{}) (interface{}, error) {
		expr, isExpr := node.(*Expr)
		if !isExpr || !expr.IsCall() {
			return node, nil
		}
		if expr.Operator().Equal(doubleEq) && len(expr.Operands()) == 2 {
			expr.SetOperator(NewTerm(unifyOp))
		}
		return node, nil
	}

	_, _ = Transform(NewGenericTransformer(rewrite), x) // ignore error
}
// rewriteDynamics will rewrite the body so that dynamic terms (i.e., refs and
// comprehensions) are bound to vars earlier in the query. This translation
// results in eager evaluation.
//
// For instance, given the following query:
//
// foo(data.bar) = 1
//
// The rewritten version will be:
//
// __local0__ = data.bar; foo(__local0__) = 1
//
// Each expression is dispatched by kind: equality, call, every, or (for
// anything else) single term.
func rewriteDynamics(f *equalityFactory, body Body) Body {
	rewritten := make(Body, 0, len(body))
	for _, expr := range body {
		if expr.IsEquality() {
			rewritten = rewriteDynamicsEqExpr(f, expr, rewritten)
		} else if expr.IsCall() {
			rewritten = rewriteDynamicsCallExpr(f, expr, rewritten)
		} else if expr.IsEvery() {
			rewritten = rewriteDynamicsEveryExpr(f, expr, rewritten)
		} else {
			rewritten = rewriteDynamicsTermExpr(f, expr, rewritten)
		}
	}
	return rewritten
}
// appendExpr appends expr to body via Body.Append and returns the extended
// body, so callers can use it in assignment or return position.
func appendExpr(body Body, expr *Expr) Body {
	extended := body
	extended.Append(expr)
	return extended
}
// rewriteDynamicsEqExpr rewrites the dynamic terms in both operands of an
// equality expression and appends the expression to result. Expressions with
// an invalid operand count are appended unchanged.
func rewriteDynamicsEqExpr(f *equalityFactory, expr *Expr, result Body) Body {
	if validEqAssignArgCount(expr) {
		operands := expr.Terms.([]*Term)
		// Index 0 holds the operator; 1 and 2 are the two operands.
		for _, i := range [...]int{1, 2} {
			result, operands[i] = rewriteDynamicsInTerm(expr, f, operands[i], result)
		}
	}
	return appendExpr(result, expr)
}
// rewriteDynamicsCallExpr rewrites every operand of a call expression (the
// operator at index 0 is left alone) and then appends the expression to
// result.
func rewriteDynamicsCallExpr(f *equalityFactory, expr *Expr, result Body) Body {
	args := expr.Terms.([]*Term)
	for i := range args {
		if i == 0 {
			continue // operator
		}
		result, args[i] = rewriteDynamicsOne(expr, f, args[i], result)
	}
	return appendExpr(result, expr)
}
// rewriteDynamicsEveryExpr rewrites the domain of an every expression in the
// enclosing body (bindings are appended to result) and rewrites the every
// body as a separate, nested body. The expression is then appended to result.
func rewriteDynamicsEveryExpr(f *equalityFactory, expr *Expr, result Body) Body {
	every := expr.Terms.(*Every)

	var domain *Term
	result, domain = rewriteDynamicsOne(expr, f, every.Domain, result)
	every.Domain = domain

	every.Body = rewriteDynamics(f, every.Body)
	return appendExpr(result, expr)
}
// rewriteDynamicsTermExpr rewrites the dynamic terms of a single-term
// expression, stores the resulting term back on the expression, and appends
// the expression to result.
func rewriteDynamicsTermExpr(f *equalityFactory, expr *Expr, result Body) Body {
	var rewritten *Term
	result, rewritten = rewriteDynamicsInTerm(expr, f, expr.Terms.(*Term), result)
	expr.Terms = rewritten
	return appendExpr(result, expr)
}
// rewriteDynamicsInTerm rewrites the dynamic parts nested inside term without
// hoisting term itself when it is a ref or a comprehension:
//
//   - for a ref, every operand after the head is rewritten;
//   - for a comprehension, only its body is rewritten (as a nested body);
//   - any other term is handed to rewriteDynamicsOne.
//
// It returns the (possibly extended) result body and the term to use in place
// of the input.
func rewriteDynamicsInTerm(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
	switch val := term.Value.(type) {
	case Ref:
		for i, operand := range val {
			if i == 0 {
				continue // ref head
			}
			result, val[i] = rewriteDynamicsOne(original, f, operand, result)
		}
		return result, term
	case *ArrayComprehension:
		val.Body = rewriteDynamics(f, val.Body)
		return result, term
	case *SetComprehension:
		val.Body = rewriteDynamics(f, val.Body)
		return result, term
	case *ObjectComprehension:
		val.Body = rewriteDynamics(f, val.Body)
		return result, term
	default:
		return rewriteDynamicsOne(original, f, term, result)
	}
}
// rewriteDynamicsOne rewrites a single term so that anything dynamic in it is
// bound to a generated local by an expression appended to result:
//
//   - refs have their operands rewritten and are then bound to a local;
//   - arrays are rewritten element by element, in place;
//   - objects and sets are rebuilt from their rewritten members;
//   - comprehensions have their body rewritten and are then bound to a local;
//   - any other term is returned unchanged.
//
// Generated bindings inherit the with modifiers of the original expression.
// The returned term is the one callers should use in place of the input.
func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
	// hoistComprehension rewrites the comprehension body stored at *body,
	// appends the binding for the whole comprehension term to result, and
	// returns the generated local.
	hoistComprehension := func(body *Body) (Body, *Term) {
		rewritten, binding := rewriteDynamicsComprehensionBody(original, f, *body, term)
		*body = rewritten
		result.Append(binding)
		return result, result[len(result)-1].Operand(0)
	}

	switch val := term.Value.(type) {
	case Ref:
		for i := 1; i < len(val); i++ {
			var operand *Term
			result, operand = rewriteDynamicsOne(original, f, val[i], result)
			val[i] = operand
		}
		binding := f.Generate(term)
		binding.With = original.With
		result.Append(binding)
		return result, result[len(result)-1].Operand(0)
	case *Array:
		for i := 0; i < val.Len(); i++ {
			var elem *Term
			result, elem = rewriteDynamicsOne(original, f, val.Elem(i), result)
			val.set(i, elem)
		}
		return result, term
	case *object:
		rebuilt := NewObject()
		val.Foreach(func(k, v *Term) {
			result, k = rewriteDynamicsOne(original, f, k, result)
			result, v = rewriteDynamicsOne(original, f, v, result)
			rebuilt.Insert(k, v)
		})
		return result, NewTerm(rebuilt).SetLocation(term.Location)
	case Set:
		rebuilt := NewSet()
		for _, elem := range val.Slice() {
			var rewritten *Term
			result, rewritten = rewriteDynamicsOne(original, f, elem, result)
			rebuilt.Add(rewritten)
		}
		return result, NewTerm(rebuilt).SetLocation(term.Location)
	case *ArrayComprehension:
		return hoistComprehension(&val.Body)
	case *SetComprehension:
		return hoistComprehension(&val.Body)
	case *ObjectComprehension:
		return hoistComprehension(&val.Body)
	default:
		return result, term
	}
}
// rewriteDynamicsComprehensionBody rewrites the body of a comprehension and
// generates the equality that binds the comprehension term to a fresh local.
// The body is rewritten before the binding is generated, so locals created
// for the body are allocated first. The binding inherits the with modifiers
// of the original expression; appending it to a body is left to the caller.
func rewriteDynamicsComprehensionBody(original *Expr, f *equalityFactory, body Body, term *Term) (Body, *Expr) {
	rewritten := rewriteDynamics(f, body)

	binding := f.Generate(term)
	binding.With = original.With

	return rewritten, binding
}
// rewriteExprTermsInHead expands the terms in the rule head (args first, then
// key, then value). Supporting expressions produced by the expansion are
// appended to the rule body and the head terms are replaced by the expansion
// outputs.
func rewriteExprTermsInHead(gen *localVarGenerator, rule *Rule) {
	// expand moves the support expressions for term into the rule body and
	// returns the term that takes its place in the head.
	expand := func(term *Term) *Term {
		support, output := expandExprTerm(gen, term)
		for _, expr := range support {
			rule.Body.Append(expr)
		}
		return output
	}

	for i, arg := range rule.Head.Args {
		rule.Head.Args[i] = expand(arg)
	}
	if key := rule.Head.Key; key != nil {
		rule.Head.Key = expand(key)
	}
	if value := rule.Head.Value; value != nil {
		rule.Head.Value = expand(value)
	}
}
// rewriteExprTermsInBody returns a new body in which every expression of
// body has been replaced by its expansion (see expandExpr): the supporting
// expressions followed by the rewritten expression itself.
func rewriteExprTermsInBody(gen *localVarGenerator, body Body) Body {
	expanded := make(Body, 0, len(body))
	for _, expr := range body {
		for _, e := range expandExpr(gen, expr) {
			expanded.Append(e)
		}
	}
	return expanded
}
// expandExpr expands the call terms nested inside expr into separate
// supporting expressions. It returns those expressions followed by expr
// itself (mutated in place to refer to the generated locals).
//
// With modifier values are expanded first; their supporting expressions are
// emitted without modifiers. Supporting expressions generated for the terms
// of expr inherit the with modifiers of expr, so they are evaluated under the
// same mocked documents as the expression they were extracted from.
//
// For every statements the domain is always bound to a local in the
// enclosing body: a call domain is expanded like any other call, any other
// domain is bound by a generated equality. The every body is rewritten as a
// nested body.
func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
	for i := range expr.With {
		extras, value := expandExprTerm(gen, expr.With[i].Value)
		expr.With[i].Value = value
		result = append(result, extras...)
	}

	// inheritWith copies the with modifiers of expr onto generated support
	// expressions. Nothing is touched when expr carries no modifiers.
	inheritWith := func(extras []*Expr) {
		if len(expr.With) == 0 {
			return
		}
		for i := range extras {
			extras[i].With = expr.With
		}
	}

	switch terms := expr.Terms.(type) {
	case *Term:
		extras, term := expandExprTerm(gen, terms)
		inheritWith(extras)
		result = append(result, extras...)
		expr.Terms = term
		result = append(result, expr)
	case []*Term:
		// Index 0 is the operator; only operands are expanded.
		for i := 1; i < len(terms); i++ {
			var extras []*Expr
			extras, terms[i] = expandExprTerm(gen, terms[i])
			inheritWith(extras)
			result = append(result, extras...)
		}
		result = append(result, expr)
	case *Every:
		var extras []*Expr
		if _, ok := terms.Domain.Value.(Call); ok {
			extras, terms.Domain = expandExprTerm(gen, terms.Domain)
			// The extracted call computes the domain, so it has to run under
			// the same with modifiers as the every statement, just like the
			// generated equality in the branch below.
			inheritWith(extras)
		} else {
			term := NewTerm(gen.Generate()).SetLocation(terms.Domain.Location)
			eq := Equality.Expr(term, terms.Domain).SetLocation(terms.Domain.Location)
			eq.Generated = true
			eq.With = expr.With
			extras = append(extras, eq)
			terms.Domain = term
		}
		terms.Body = rewriteExprTermsInBody(gen, terms.Body)
		result = append(result, extras...)
		result = append(result, expr)
	}
	return result
}
// expandExprTerm expands the calls nested in term. It returns the supporting
// expressions that must be evaluated before the term is used, together with
// the term to use in place of the input:
//
//   - a call has its arguments expanded and is replaced by a generated local
//     that captures the call result;
//   - refs and arrays are expanded in place;
//   - objects and sets are rebuilt from their expanded members;
//   - comprehensions keep their supporting expressions inside their own body,
//     so nothing is returned to the caller for them;
//   - any other term is returned unchanged with no support.
func expandExprTerm(gen *localVarGenerator, term *Term) (support []*Expr, output *Term) {
	output = term

	// absorb expands a comprehension head term, appends the supporting
	// expressions to the comprehension body, and returns the expanded term.
	absorb := func(head *Term, body *Body) *Term {
		extras, expanded := expandExprTerm(gen, head)
		for _, e := range extras {
			body.Append(e)
		}
		return expanded
	}

	switch v := term.Value.(type) {
	case Call:
		// Arguments are expanded before the local for the call result is
		// generated. Index 0 is the operator.
		for i := 1; i < len(v); i++ {
			extras, arg := expandExprTerm(gen, v[i])
			v[i] = arg
			support = append(support, extras...)
		}
		output = NewTerm(gen.Generate()).SetLocation(term.Location)
		call := v.MakeExpr(output).SetLocation(term.Location)
		call.Generated = true
		support = append(support, call)
	case Ref:
		support = expandExprRef(gen, v)
	case *Array:
		support = expandExprTermArray(gen, v)
	case *object:
		expanded, _ := v.Map(func(key, val *Term) (*Term, *Term, error) {
			keyExtras, newKey := expandExprTerm(gen, key)
			valExtras, newVal := expandExprTerm(gen, val)
			support = append(support, keyExtras...)
			support = append(support, valExtras...)
			return newKey, newVal, nil
		})
		output = NewTerm(expanded).SetLocation(term.Location)
	case Set:
		expanded, _ := v.Map(func(elem *Term) (*Term, error) {
			extras, newElem := expandExprTerm(gen, elem)
			support = append(support, extras...)
			return newElem, nil
		})
		output = NewTerm(expanded).SetLocation(term.Location)
	case *ArrayComprehension:
		v.Term = absorb(v.Term, &v.Body)
		v.Body = rewriteExprTermsInBody(gen, v.Body)
	case *SetComprehension:
		v.Term = absorb(v.Term, &v.Body)
		v.Body = rewriteExprTermsInBody(gen, v.Body)
	case *ObjectComprehension:
		v.Key = absorb(v.Key, &v.Body)
		v.Value = absorb(v.Value, &v.Body)
		v.Body = rewriteExprTermsInBody(gen, v.Body)
	}

	return support, output
}
// expandExprRef expands all terms of the ref v in place and returns the
// supporting expressions.
//
// After the regular expansion, the ref subject (the first element) is
// rewritten to support indirect references. For example:
//
//	[1, 2, 3][i]
//
// becomes
//
//	__local_var = [1, 2, 3]
//	__local_var[i]
//
// Only composite, comprehension, and call subjects are affected; the other
// elements of the ref are left as produced by the regular expansion.
func expandExprRef(gen *localVarGenerator, v []*Term) (support []*Expr) {
	// Start with a normal expansion of every term in the ref.
	support = expandExprTermSlice(gen, v)

	switch v[0].Value.(type) {
	case *Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
		binding := newEqualityFactory(gen).Generate(v[0])
		support = append(support, binding)
		v[0] = binding.Operand(0)
	}

	return support
}
// expandExprTermArray expands every element of arr, storing the expanded
// element back into arr, and returns the accumulated supporting expressions
// in element order.
func expandExprTermArray(gen *localVarGenerator, arr *Array) (support []*Expr) {
	for idx := 0; idx < arr.Len(); idx++ {
		extras, expanded := expandExprTerm(gen, arr.Elem(idx))
		arr.set(idx, expanded)
		support = append(support, extras...)
	}
	return support
}
// expandExprTermSlice expands every term of v, overwriting each slot with its
// expanded term, and returns the accumulated supporting expressions in slice
// order.
func expandExprTermSlice(gen *localVarGenerator, v []*Term) (support []*Expr) {
	for i, t := range v {
		extras, expanded := expandExprTerm(gen, t)
		v[i] = expanded
		support = append(support, extras...)
	}
	return support
}
// localDeclaredVars tracks the variables encountered while rewriting local
// variable declarations. It keeps a stack of scopes (one declaredVarSet per
// scope, the last element being the innermost) plus a mapping that covers
// every rewritten variable regardless of scope.
type localDeclaredVars struct {
// vars is the scope stack; Push appends a fresh scope and Pop removes the
// innermost one.
vars []*declaredVarSet
// rewritten contains a mapping of *all* user-defined variables
// that have been rewritten whereas vars contains the state
// from the current query (not any nested queries, and all vars
// seen). Keys are the generated names and values are the original
// user-defined names.
rewritten map[Var]Var
}
// varOccurrence describes how a variable was first encountered in a scope.
type varOccurrence int
const (
// newVar is the zero value; it is what a scope reports for a variable it
// has no record of.
newVar varOccurrence = iota
// argVar marks a variable that cannot be redeclared because it is an
// argument (redeclaring it yields an "arg ... redeclared" error).
argVar
// seenVar marks a variable that was referenced without having been
// declared or assigned.
seenVar
// assignedVar marks a variable introduced on the left-hand side of an
// assignment.
assignedVar
// declaredVar marks a variable introduced by a some declaration or by the
// key/value of an every statement.
declaredVar
)
// declaredVarSet holds the variable bookkeeping for a single scope.
type declaredVarSet struct {
// vs maps a variable as written by the user to the name it is rewritten
// to (the variable itself when it was only seen, not declared).
vs map[Var]Var
// reverse is the inverse of vs: rewritten name back to the original.
reverse map[Var]Var
// occurrence records how each original variable was first encountered.
occurrence map[Var]varOccurrence
// count records how many times each original variable has been seen; it
// starts at 1 on insertion.
count map[Var]int
}
// newDeclaredVarSet returns an empty scope with all of its maps initialized
// and ready for writes.
func newDeclaredVarSet() *declaredVarSet {
	scope := new(declaredVarSet)
	scope.vs = make(map[Var]Var)
	scope.reverse = make(map[Var]Var)
	scope.occurrence = make(map[Var]varOccurrence)
	scope.count = make(map[Var]int)
	return scope
}
// newLocalDeclaredVars returns a tracker whose stack already contains one
// empty scope and whose rewritten-variable map is empty.
func newLocalDeclaredVars() *localDeclaredVars {
	stack := new(localDeclaredVars)
	stack.vars = []*declaredVarSet{newDeclaredVarSet()}
	stack.rewritten = make(map[Var]Var)
	return stack
}
// Push opens a new, empty innermost scope.
func (s *localDeclaredVars) Push() {
	fresh := newDeclaredVarSet()
	s.vars = append(s.vars, fresh)
}
// Pop removes the innermost scope from the stack and returns it. The stack
// must not be empty.
func (s *localDeclaredVars) Pop() *declaredVarSet {
	last := len(s.vars) - 1
	top := s.vars[last]
	s.vars = s.vars[:last]
	return top
}
// Peek returns the innermost scope without removing it. The stack must not
// be empty.
func (s localDeclaredVars) Peek() *declaredVarSet {
	top := len(s.vars) - 1
	return s.vars[top]
}
// Insert records in the innermost scope that x is rewritten to y and was
// encountered as the given occurrence. The seen-count of x is reset to 1.
func (s localDeclaredVars) Insert(x, y Var, occurrence varOccurrence) {
	scope := s.Peek()
	scope.vs[x] = y
	scope.reverse[y] = x
	scope.occurrence[x] = occurrence
	scope.count[x] = 1

	if x.Equal(y) {
		return
	}

	// The variable has been rewritten (y is the generated name), so remember
	// it in the map of rewritten vars. Generated names are assumed to be
	// unique for the compilation.
	s.rewritten[y] = x
}
// Declared looks x up in the scope stack, innermost scope first, and returns
// the name it maps to in the first scope that knows it. ok is false (and y is
// the zero Var) when no scope has a record of x.
func (s localDeclaredVars) Declared(x Var) (y Var, ok bool) {
	for depth := len(s.vars) - 1; depth >= 0; depth-- {
		if mapped, found := s.vars[depth].vs[x]; found {
			return mapped, true
		}
	}
	return y, false
}
// Occurrence returns how x was encountered in the current (innermost) scope.
// The result is newVar when the scope has no record of x.
func (s localDeclaredVars) Occurrence(x Var) varOccurrence {
	return s.Peek().occurrence[x]
}
// GlobalOccurrence searches every scope, innermost first, and returns the
// occurrence recorded for x by the first scope that knows it. The boolean is
// false (and the occurrence is newVar) when x has not occurred in any scope.
func (s localDeclaredVars) GlobalOccurrence(x Var) (varOccurrence, bool) {
	for depth := len(s.vars) - 1; depth >= 0; depth-- {
		occ, found := s.vars[depth].occurrence[x]
		if found {
			return occ, true
		}
	}
	return newVar, false
}
// Seen marks x as seen by incrementing its counter in the innermost scope
// that already counts it. When no scope counts x yet, the counter is started
// at 1 in the current scope.
func (s localDeclaredVars) Seen(x Var) {
	for depth := len(s.vars) - 1; depth >= 0; depth-- {
		counts := s.vars[depth].count
		if n, found := counts[x]; found {
			counts[x] = n + 1
			return
		}
	}
	s.Peek().count[x] = 1
}
// Count returns how many times x has been seen according to the innermost
// scope that counts it, or 0 when no scope does.
func (s localDeclaredVars) Count(x Var) int {
	for depth := len(s.vars) - 1; depth >= 0; depth-- {
		n, found := s.vars[depth].count[x]
		if found {
			return n
		}
	}
	return 0
}
// rewriteLocalVars rewrites bodies to remove assignment/declaration
// expressions. For example:
//
//	a := 1; p[a]
//
// is rewritten to:
//
//	__local0__ = 1; p[__local0__]
//
// During rewriting, assignees are validated to prevent use before
// declaration.
//
// The innermost scope of stack is popped before returning. The results are
// the rewritten body, that scope's mapping from original to rewritten
// variable names, and any errors raised along the way.
func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, strict bool) (Body, map[Var]Var, Errors) {
	rewritten, errs := rewriteDeclaredVarsInBody(g, stack, used, body, nil, strict)
	return rewritten, stack.Pop().vs, errs
}
// rewriteDeclaredVarsInBody rewrites the declared variables of every
// statement in body and returns the rewritten body together with the
// accumulated errors. Statements whose rewrite yields no expression (nil)
// are dropped. Once all statements are processed, the current scope is
// checked for unused assigned variables (strict mode only) and unused
// declared variables; both checks report at the location of the first
// statement of body, so body must not be empty.
func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, errs Errors, strict bool) (Body, Errors) {
	var cpy Body

	for _, stmt := range body {
		// Pick the rewriter matching the statement kind; plain expressions
		// use the generic one.
		rewrite := rewriteDeclaredVarsInExpr
		switch {
		case stmt.IsAssignment():
			rewrite = rewriteDeclaredAssignment
		case stmt.IsSome():
			rewrite = rewriteSomeDeclStatement
		case stmt.IsEvery():
			rewrite = rewriteEveryStatement
		}

		var expr *Expr
		expr, errs = rewrite(g, stack, stmt, errs, strict)
		if expr != nil {
			cpy.Append(expr)
		}
	}

	// If the body only contained a var statement it will be empty at this
	// point. Append true to the body to ensure that it's non-empty (zero
	// length bodies are not supported.)
	if len(cpy) == 0 {
		cpy.Append(NewExpr(BooleanTerm(true)))
	}

	loc := body[0].Loc()
	errs = checkUnusedAssignedVars(loc, stack, used, errs, strict)
	errs = checkUnusedDeclaredVars(loc, stack, used, cpy, errs)
	return cpy, errs
}
// checkUnusedAssignedVars appends an "assigned var ... unused" error, located
// at loc, for every variable assigned in the current scope that was never
// used afterwards. The check only runs in strict mode and is skipped when
// errs already contains errors. Variables listed in used count as used.
func checkUnusedAssignedVars(loc *Location, stack *localDeclaredVars, used VarSet, errs Errors, strict bool) Errors {
	if len(errs) > 0 || !strict {
		return errs
	}

	scope := stack.Peek()

	// Collect the rewritten names of assigned, non-wildcard variables that
	// were seen only once. The assignment itself accounts for that one
	// sighting, so a used variable has a count above 1 in the same, or a
	// nested, scope.
	candidates := NewVarSet()
	for name, occ := range scope.occurrence {
		if occ != assignedVar || name.IsWildcard() {
			continue
		}
		if stack.Count(name) > 1 {
			continue
		}
		candidates.Add(scope.vs[name])
	}

	// Translate the explicitly used variables to their rewritten names so
	// they can be compared against the candidates.
	usedRewritten := NewVarSet()
	for name := range used {
		gv, ok := stack.Declared(name)
		if !ok {
			gv = name
		}
		usedRewritten.Add(gv)
	}

	for _, gv := range candidates.Diff(usedRewritten).Sorted() {
		errs = append(errs, NewError(CompileErr, loc, "assigned var %v unused", scope.reverse[gv]))
	}
	return errs
}
// checkUnusedDeclaredVars appends a "declared var ... unused" error, located
// at loc, for every variable declared in the current scope that appears
// neither in the rewritten body cpy nor in used. Variables whose original
// name is itself generated are not reported.
func checkUnusedDeclaredVars(loc *Location, stack *localDeclaredVars, used VarSet, cpy Body, errs Errors) Errors {
	// NOTE(tsandall): Do not generate more errors if there are existing
	// declaration errors.
	if len(errs) > 0 {
		return errs
	}

	scope := stack.Peek()

	declared := NewVarSet()
	for name, occ := range scope.occurrence {
		if occ == declaredVar {
			declared.Add(scope.vs[name])
		}
	}

	// Everything mentioned by the rewritten body counts as referenced, and so
	// does every explicitly used variable (under its rewritten name when it
	// has one).
	referenced := cpy.Vars(VarVisitorParams{})
	for name := range used {
		if gv, ok := stack.Declared(name); ok {
			referenced.Add(gv)
		} else {
			referenced.Add(name)
		}
	}

	for _, gv := range declared.Diff(referenced).Diff(used).Sorted() {
		if orig := scope.reverse[gv]; !orig.IsGenerated() {
			errs = append(errs, NewError(CompileErr, loc, "declared var %v unused", orig))
		}
	}
	return errs
}
// rewriteEveryStatement rewrites an every statement on a copy of expr. The
// domain is rewritten in the enclosing scope. The key and value variables are
// then declared in a new scope that also covers the every body; the scope is
// popped when the function returns. A missing key is replaced with a
// generated local. On a declaration error the returned expression is nil.
func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
	e := expr.Copy()
	every := e.Terms.(*Every)

	errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Domain, errs, strict)

	stack.Push()
	defer stack.Pop()

	// declare replaces the variable held by t with a freshly declared local.
	// Wildcards are left untouched.
	declare := func(t *Term) error {
		v := t.Value.(Var)
		if v.IsWildcard() {
			return nil
		}
		gv, err := rewriteDeclaredVar(g, stack, v, declaredVar)
		if err != nil {
			return err
		}
		t.Value = gv
		return nil
	}

	if every.Key == nil {
		// No key given: add a dummy local.
		every.Key = NewTerm(g.Generate())
	} else if err := declare(every.Key); err != nil {
		return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error()))
	}

	// The value is always present.
	if err := declare(every.Value); err != nil {
		return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error()))
	}

	used := NewVarSet()
	every.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, every.Body, errs, strict)

	return rewriteDeclaredVarsInExpr(g, stack, e, errs, strict)
}
// rewriteSomeDeclStatement processes a some declaration on a copy of expr.
//
// Plain variable symbols are declared in the current scope and produce no
// expression: when the declaration only lists variables, the returned
// expression is nil and the caller drops the statement.
//
// A call symbol (the membership form, presumably `some ... in ...` — confirm
// against the parser) is turned into a unification of the value with a ref
// into the container, its output variables are declared, and the rewritten
// expression is returned immediately, so any symbols after it are not
// visited.
//
// On a declaration error the returned expression is nil.
func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
e := expr.Copy()
decl := e.Terms.(*SomeDecl)
for i := range decl.Symbols {
switch v := decl.Symbols[i].Value.(type) {
case Var:
// Only the declaration matters here; the generated name is picked up
// later when references to v are rewritten.
if _, err := rewriteDeclaredVar(g, stack, v, declaredVar); err != nil {
return nil, append(errs, NewError(CompileErr, decl.Loc(), err.Error()))
}
case Call:
// v[0] is the operator; the remaining terms are (key, value, container)
// for member3 and (value, container) for member.
var key, val, container *Term
switch len(v) {
case 4: // member3
key = v[1]
val = v[2]
container = v[3]
case 3: // member
// No key was given, so a generated local stands in for it.
key = NewTerm(g.Generate())
val = v[1]
container = v[2]
}
// Build container[key]: extend the ref when the container already is
// one, otherwise start a new ref with the container as its subject.
var rhs *Term
switch c := container.Value.(type) {
case Ref:
rhs = RefTerm(append(c, key)...)
default:
rhs = RefTerm(container, key)
}
// Replace the declaration with the unification val = container[key].
e.Terms = []*Term{
RefTerm(VarTerm(Equality.Name)), val, rhs,
}
// Declare the output variables of the new unification; the container's
// variables are passed to outputVarsForExprEq (presumably so they are
// treated as already bound — confirm against that helper).
for _, v0 := range outputVarsForExprEq(e, container.Vars()).Sorted() {
if _, err := rewriteDeclaredVar(g, stack, v0, declaredVar); err != nil {
return nil, append(errs, NewError(CompileErr, decl.Loc(), err.Error()))
}
}
return rewriteDeclaredVarsInExpr(g, stack, e, errs, strict)
}
}
return nil, errs
}
// rewriteDeclaredVarsInExpr walks expr and rewrites references to declared
// variables in place. Terms are handled by rewriteDeclaredVarsInTerm, which
// also decides whether the walk descends further. For with modifiers only the
// value term is rewritten and the modifier is not descended into. It returns
// expr along with the accumulated errors.
func rewriteDeclaredVarsInExpr(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
	visit := func(node interface{}) bool {
		switch n := node.(type) {
		case *Term:
			var stop bool
			stop, errs = rewriteDeclaredVarsInTerm(g, stack, n, errs, strict)
			return stop
		case *With:
			_, errs = rewriteDeclaredVarsInTerm(g, stack, n.Value, errs, strict)
			return true
		}
		return false
	}

	NewGenericVisitor(visit).Walk(expr)
	return expr, errs
}
// rewriteDeclaredAssignment rewrites an assignment expression in place: the
// variables on the left-hand side are declared and replaced with generated
// names, and, if that produced no new errors, the operator is replaced with
// unification. Negated assignments are rejected with an error. Expressions
// with an invalid operand count are returned untouched. The expression is
// always returned, together with the accumulated errors.
func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
if expr.Negated {
errs = append(errs, NewError(CompileErr, expr.Location, "cannot assign vars inside negated expression"))
return expr, errs
}
// Remember how many errors existed so new ones can be detected below.
numErrsBefore := len(errs)
if !validEqAssignArgCount(expr) {
return expr, errs
}
// Rewrite terms on right hand side capture seen vars and recursively
// process comprehensions before left hand side is processed. Also
// rewrite with modifier.
errs = rewriteDeclaredVarsInTermRecursive(g, stack, expr.Operand(1), errs, strict)
for _, w := range expr.With {
errs = rewriteDeclaredVarsInTermRecursive(g, stack, w.Value, errs, strict)
}
// Rewrite vars on left hand side with unique names. Catch redeclaration
// and invalid term types here.
// NOTE(review): the boolean returned by vis presumably tells WalkTerms to
// stop descending into the term — confirm against WalkTerms.
var vis func(t *Term) bool
vis = func(t *Term) bool {
switch v := t.Value.(type) {
case Var:
// Declare the assignee and substitute the generated name.
if gv, err := rewriteDeclaredVar(g, stack, v, assignedVar); err != nil {
errs = append(errs, NewError(CompileErr, t.Location, err.Error()))
} else {
t.Value = gv
}
return true
case *Array:
// Let the walk continue into the array elements.
return false
case *object:
// Only the object values are walked for assignees; keys are skipped.
v.Foreach(func(_, v *Term) {
WalkTerms(v, vis)
})
return true
case Ref:
// A bare root document ref is treated like a variable: its head var
// is declared and the whole term is replaced with the generated name.
if RootDocumentRefs.Contains(t) {
if gv, err := rewriteDeclaredVar(g, stack, v[0].Value.(Var), assignedVar); err != nil {
errs = append(errs, NewError(CompileErr, t.Location, err.Error()))
} else {
t.Value = gv
}
return true
}
}
// Any other term (including non-root refs) cannot be assigned to.
errs = append(errs, NewError(CompileErr, t.Location, "cannot assign to %v", TypeName(t.Value)))
return true
}
WalkTerms(expr.Operand(0), vis)
// Only switch the operator to unification when the left-hand side was
// rewritten without errors; the new operator keeps the old location.
if len(errs) == numErrsBefore {
loc := expr.Operator()[0].Location
expr.SetOperator(RefTerm(VarTerm(Equality.Name).SetLocation(loc)).SetLocation(loc))
}
return expr, errs
}
// rewriteDeclaredVarsInTerm rewrites references to declared variables in a
// single term, in place. The boolean result is used by the callers as the
// "stop" flag of their walk: true means the term has been fully handled here
// and must not be descended into, false means the walk should continue into
// the term's children.
func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors, strict bool) (bool, Errors) {
switch v := term.Value.(type) {
case Var:
if gv, ok := stack.Declared(v); ok {
// Known variable: substitute its mapped name and count the use.
term.Value = gv
stack.Seen(v)
} else if stack.Occurrence(v) == newVar {
// First sighting of an undeclared variable: record it as seen (mapped
// to itself) so a later declaration of the same name can be rejected.
stack.Insert(v, v, seenVar)
}
case Ref:
if RootDocumentRefs.Contains(term) {
x := v[0].Value.(Var)
// A root document name that has been declared or assigned in some scope
// (i.e. its occurrence is anything but seenVar) refers to that local.
if occ, ok := stack.GlobalOccurrence(x); ok && occ != seenVar {
gv, _ := stack.Declared(x)
term.Value = gv
}
return true, errs
}
// Other refs are rewritten by walking their terms.
return false, errs
case *object:
// Keys are copied before being rewritten; values are rewritten in place.
// The term then receives the object returned by Map.
cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
kcpy := k.Copy()
errs = rewriteDeclaredVarsInTermRecursive(g, stack, kcpy, errs, strict)
errs = rewriteDeclaredVarsInTermRecursive(g, stack, v, errs, strict)
return kcpy, v, nil
})
term.Value = cpy
case Set:
// Elements are copied before being rewritten; the term then receives the
// set returned by Map.
cpy, _ := v.Map(func(elem *Term) (*Term, error) {
elemcpy := elem.Copy()
errs = rewriteDeclaredVarsInTermRecursive(g, stack, elemcpy, errs, strict)
return elemcpy, nil
})
term.Value = cpy
case *ArrayComprehension:
errs = rewriteDeclaredVarsInArrayComprehension(g, stack, v, errs, strict)
case *SetComprehension:
errs = rewriteDeclaredVarsInSetComprehension(g, stack, v, errs, strict)
case *ObjectComprehension:
errs = rewriteDeclaredVarsInObjectComprehension(g, stack, v, errs, strict)
default:
// Nothing to rewrite at this level; let the walk continue.
return false, errs
}
return true, errs
}
// rewriteDeclaredVarsInTermRecursive walks every node under term and rewrites
// references to declared variables in place. Terms are handled by
// rewriteDeclaredVarsInTerm, which also decides whether the walk descends
// further. For with modifiers only the value term is rewritten and the
// modifier is not descended into. It returns the accumulated errors.
func rewriteDeclaredVarsInTermRecursive(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors, strict bool) Errors {
	visit := func(node Node) bool {
		switch n := node.(type) {
		case *With:
			_, errs = rewriteDeclaredVarsInTerm(g, stack, n.Value, errs, strict)
			return true
		case *Term:
			var stop bool
			stop, errs = rewriteDeclaredVarsInTerm(g, stack, n, errs, strict)
			return stop
		}
		return false
	}

	WalkNodes(term, visit)
	return errs
}
// rewriteDeclaredVarsInArrayComprehension rewrites an array comprehension in
// a scope of its own. The body is rewritten first, with the variables of the
// head term counted as used for the unused-variable checks, and the head term
// is rewritten afterwards so it picks up names declared in the body.
func rewriteDeclaredVarsInArrayComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ArrayComprehension, errs Errors, strict bool) Errors {
	headVars := NewVarSet()
	headVars.Update(v.Term.Vars())

	stack.Push()
	v.Body, errs = rewriteDeclaredVarsInBody(g, stack, headVars, v.Body, errs, strict)
	errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs, strict)
	stack.Pop()

	return errs
}
// rewriteDeclaredVarsInSetComprehension rewrites a set comprehension in a
// scope of its own. The body is rewritten first, with the variables of the
// head term counted as used for the unused-variable checks, and the head term
// is rewritten afterwards so it picks up names declared in the body.
func rewriteDeclaredVarsInSetComprehension(g *localVarGenerator, stack *localDeclaredVars, v *SetComprehension, errs Errors, strict bool) Errors {
	headVars := NewVarSet()
	headVars.Update(v.Term.Vars())

	stack.Push()
	v.Body, errs = rewriteDeclaredVarsInBody(g, stack, headVars, v.Body, errs, strict)
	errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs, strict)
	stack.Pop()

	return errs
}
// rewriteDeclaredVarsInObjectComprehension rewrites an object comprehension
// in a scope of its own. The body is rewritten first, with the variables of
// the key and value terms counted as used for the unused-variable checks. The
// key and then the value are rewritten afterwards so they pick up names
// declared in the body.
func rewriteDeclaredVarsInObjectComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ObjectComprehension, errs Errors, strict bool) Errors {
	headVars := NewVarSet()
	headVars.Update(v.Key.Vars())
	headVars.Update(v.Value.Vars())

	stack.Push()
	v.Body, errs = rewriteDeclaredVarsInBody(g, stack, headVars, v.Body, errs, strict)
	errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Key, errs, strict)
	errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Value, errs, strict)
	stack.Pop()

	return errs
}
// rewriteDeclaredVar declares v in the current scope with the given
// occurrence and returns the generated name it is rewritten to. An error is
// returned, and nothing is generated or recorded, when the current scope has
// already referenced, assigned, or declared v, or when v is an argument.
func rewriteDeclaredVar(g *localVarGenerator, stack *localDeclaredVars, v Var, occ varOccurrence) (gv Var, err error) {
	var conflict string

	switch stack.Occurrence(v) {
	case seenVar:
		conflict = "var %v referenced above"
	case assignedVar:
		conflict = "var %v assigned above"
	case declaredVar:
		conflict = "var %v declared above"
	case argVar:
		conflict = "arg %v redeclared"
	default:
		// Not yet known in this scope: allocate a fresh name and record it.
		gv = g.Generate()
		stack.Insert(v, gv, occ)
		return gv, nil
	}

	return gv, fmt.Errorf(conflict, v)
}
// rewriteWithModifiersInBody rewrites body so that with modifiers do not
// contain terms that require evaluation as values. The first invalid with
// modifier target aborts the rewrite: the error is returned along with a nil
// body.
func rewriteWithModifiersInBody(c *Compiler, f *equalityFactory, body Body) (Body, *Error) {
	var result Body

	for _, expr := range body {
		rewritten, err := rewriteWithModifier(c, f, expr)
		if err != nil {
			return nil, err
		}

		// An empty rewrite means the expression needed no changes and is
		// carried over as is.
		if len(rewritten) == 0 {
			result.Append(expr)
			continue
		}

		for _, e := range rewritten {
			result.Append(e)
		}
	}

	return result, nil
}
// rewriteWithModifier validates the target of every with modifier on expr and
// replaces each modifier value that requires evaluation with a generated
// local. The result holds the generated bindings followed by expr itself; it
// is empty when no modifier had to be rewritten. An invalid target yields an
// error and no expressions.
func rewriteWithModifier(c *Compiler, f *equalityFactory, expr *Expr) ([]*Expr, *Error) {
	var bindings []*Expr

	for i := range expr.With {
		if err := validateTarget(c, expr.With[i].Target); err != nil {
			return nil, err
		}

		if !requiresEval(expr.With[i].Value) {
			continue
		}

		eq := f.Generate(expr.With[i].Value)
		bindings = append(bindings, eq)
		expr.With[i].Value = eq.Operand(0)
	}

	// Nothing was rewritten, so there is nothing to report back.
	if len(bindings) == 0 {
		return bindings, nil
	}

	// The expression was modified above, so it follows the bindings it now
	// depends on.
	return append(bindings, expr), nil
}
// validateTarget checks that term is a legal target for a with modifier. The
// target must be a ref rooted at the input or data document. For data refs
// the rule tree is consulted as well: no proper prefix of the ref may land on
// a node that holds rules (that would partially replace a virtual document),
// and the node reached by the full ref must not hold functions. It returns
// nil when the target is acceptable.
func validateTarget(c *Compiler, term *Term) *Error {
	isData := isDataRef(term)

	if !isData && !isInputRef(term) {
		return NewError(TypeErr, term.Location, "with keyword target must reference existing %v or %v", InputRootDocument, DefaultRootDocument)
	}

	if !isData {
		return nil
	}

	ref := term.Value.(Ref)
	last := len(ref) - 1

	// Descend along every element but the last, stopping early where the
	// rule tree has no matching child.
	node := c.RuleTree
	for i := 0; i < last; i++ {
		child := node.Child(ref[i].Value)
		if child == nil {
			break
		}
		if len(child.Values) > 0 {
			return NewError(CompileErr, term.Loc(), "with keyword cannot partially replace virtual document(s)")
		}
		node = child
	}

	if node == nil {
		return nil
	}

	target := node.Child(ref[last].Value)
	if target == nil {
		return nil
	}

	for _, value := range target.Values {
		if len(value.(*Rule).Head.Args) > 0 {
			return NewError(CompileErr, term.Loc(), "with keyword cannot replace functions")
		}
	}

	return nil
}
// isInputRef reports whether term is a ref that has InputRootRef as a prefix.
func isInputRef(term *Term) bool {
	ref, ok := term.Value.(Ref)
	return ok && ref.HasPrefix(InputRootRef)
}
// isDataRef reports whether term is a ref that has DefaultRootRef as a
// prefix.
func isDataRef(term *Term) bool {
	ref, ok := term.Value.(Ref)
	return ok && ref.HasPrefix(DefaultRootRef)
}
// isVirtual follows ref through the tree rooted at node. It reports false as
// soon as an element of ref has no matching child, and true as soon as a
// visited child holds values. When the whole ref can be followed without
// either happening (including for an empty ref) the result is true.
func isVirtual(node *TreeNode, ref Ref) bool {
	for _, elem := range ref {
		node = node.Child(elem.Value)
		if node == nil {
			return false
		}
		if len(node.Values) > 0 {
			return true
		}
	}
	return true
}
// safetyErrorSlice converts the unsafe variables into user-facing errors.
// Variables are first mapped back to their original names via rewritten. One
// "var ... is unsafe" error is produced per non-generated variable; names
// that match a future keyword get a hint about the missing import. When no
// such error is produced, the unsafe variables are all generated ones and the
// offending expressions are reported instead, ordered by location.
func safetyErrorSlice(unsafe unsafeVars, rewritten map[Var]Var) (result Errors) {
	if len(unsafe) == 0 {
		return result
	}

	for _, pair := range unsafe.Vars() {
		name := pair.Var
		if orig, ok := rewritten[name]; ok {
			name = orig
		}
		if name.IsGenerated() {
			continue
		}

		msg := "var %v is unsafe"
		if _, isKeyword := futureKeywords[string(name)]; isKeyword {
			msg = "var %[1]v is unsafe (hint: `import future.keywords.%[1]v` to import a future keyword)"
		}
		result = append(result, NewError(UnsafeVarErr, pair.Loc, msg, name))
	}

	if len(result) > 0 {
		return result
	}

	// If the expression contains unsafe generated variables, report which
	// expressions are unsafe instead of the variables that are unsafe (since
	// the latter are not meaningful to the user.)
	pairs := unsafe.Slice()
	sort.Slice(pairs, func(i, j int) bool {
		return pairs[i].Expr.Location.Compare(pairs[j].Expr.Location) < 0
	})

	// Report at most one error per generated variable: an expression is only
	// reported when it contributes a generated variable not seen before.
	seen := NewVarSet()
	for _, pair := range pairs {
		known := len(seen)
		for v := range pair.Vars {
			if v.IsGenerated() {
				seen.Add(v)
			}
		}
		if len(seen) <= known {
			continue
		}
		result = append(result, NewError(UnsafeVarErr, pair.Expr.Location, "expression is unsafe"))
	}

	return result
}
// checkUnsafeBuiltins walks every expression under node and returns a type
// error for each call whose operator is listed in unsafeBuiltinsMap. The
// result is an empty, non-nil slice when nothing is found.
func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node interface{}) Errors {
	errs := Errors{}

	visit := func(expr *Expr) bool {
		if !expr.IsCall() {
			return false
		}
		name := expr.Operator().String()
		if _, unsafe := unsafeBuiltinsMap[name]; unsafe {
			errs = append(errs, NewError(TypeErr, expr.Loc(), "unsafe built-in function calls in expression: %v", name))
		}
		return false
	}

	WalkExprs(node, visit)
	return errs
}
// checkDeprecatedBuiltins walks every expression under node and returns a
// type error for each call whose operator is listed in deprecatedBuiltinsMap.
// Outside of strict mode the check is skipped and nil is returned; in strict
// mode the result is an empty, non-nil slice when nothing is found.
func checkDeprecatedBuiltins(deprecatedBuiltinsMap map[string]struct{}, node interface{}, strict bool) Errors {
	// Early out; deprecatedBuiltinsMap is only populated in strict-mode.
	if !strict {
		return nil
	}

	errs := Errors{}

	visit := func(expr *Expr) bool {
		if !expr.IsCall() {
			return false
		}
		name := expr.Operator().String()
		if _, deprecated := deprecatedBuiltinsMap[name]; deprecated {
			errs = append(errs, NewError(TypeErr, expr.Loc(), "deprecated built-in function calls in expression: %v", name))
		}
		return false
	}

	WalkExprs(node, visit)
	return errs
}
// rewriteVarsInRef returns a varRewriter that substitutes the variables of a
// ref according to the given maps. The maps are consulted in argument order
// and the first one containing a variable wins; variables found in none of
// them are kept as they are.
func rewriteVarsInRef(vars ...map[Var]Var) varRewriter {
	substitute := func(v Var) (Value, error) {
		for _, mapping := range vars {
			if replacement, ok := mapping[v]; ok {
				return replacement, nil
			}
		}
		return v, nil
	}

	return func(node Ref) Ref {
		rewritten, _ := TransformVars(node, substitute)
		return rewritten.(Ref)
	}
}