diff --git a/ast/check.go b/ast/check.go index b4075c8b78..73eacfacd2 100644 --- a/ast/check.go +++ b/ast/check.go @@ -231,37 +231,30 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) { f := types.NewFunction(args, cpy.Get(rule.Head.Value)) - // Union with existing. - exist := env.tree.Get(path) - tpe = types.Or(exist, f) - + tpe = f } else { switch rule.Head.RuleKind() { case SingleValue: typeV := cpy.Get(rule.Head.Value) - if last := path[len(path)-1]; !last.IsGround() { - - // e.g. store object[string: whatever] at data.p.q.r, not data.p.q.r[x] + if !path.IsGround() { + // e.g. store object[string: whatever] at data.p.q.r, not data.p.q.r[x] or data.p.q.r[x].y[z] + objPath := path.DynamicSuffix() path = path.GroundPrefix() - typeK := cpy.Get(last) - if typeK != nil && typeV != nil { - exist := env.tree.Get(path) - typeV = types.Or(types.Values(exist), typeV) - typeK = types.Or(types.Keys(exist), typeK) - tpe = types.NewObject(nil, types.NewDynamicProperty(typeK, typeV)) + var err error + tpe, err = nestedObject(cpy, objPath, typeV) + if err != nil { + tc.err([]*Error{NewError(TypeErr, rule.Head.Location, err.Error())}) + tpe = nil } } else { if typeV != nil { - exist := env.tree.Get(path) - tpe = types.Or(typeV, exist) + tpe = typeV } } case MultiValue: typeK := cpy.Get(rule.Head.Key) if typeK != nil { - exist := env.tree.Get(path) - typeK = types.Or(types.Keys(exist), typeK) tpe = types.NewSet(typeK) } } @@ -272,6 +265,32 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) { } } +// nestedObject creates a nested structure of object types, where each term on path corresponds to a level in the +// nesting. Each term in the path only contributes to the dynamic portion of its corresponding object. +func nestedObject(env *TypeEnv, path Ref, tpe types.Type) (types.Type, error) { + if len(path) == 0 { + return tpe, nil + } + + k := path[0] + typeV, err := nestedObject(env, path[1:], tpe) + if err != nil { + return nil, err + } + if typeV == nil { + return nil, nil + } + + var dynamicProperty *types.DynamicProperty + typeK := env.Get(k) + if typeK == nil { + return nil, nil + } + dynamicProperty = types.NewDynamicProperty(typeK, typeV) + + return types.NewObject(nil, dynamicProperty), nil +} + func (tc *typeChecker) checkExpr(env *TypeEnv, expr *Expr) *Error { if err := tc.checkExprWith(env, expr, 0); err != nil { return err diff --git a/ast/check_test.go b/ast/check_test.go index ae8e27fcbd..8127dd1680 100644 --- a/ast/check_test.go +++ b/ast/check_test.go @@ -358,6 +358,16 @@ func TestCheckInferenceRules(t *testing.T) { {`overlap`, `p.q2.a = input.a { true }`}, {`overlap`, `p.q2[56] = input.a { true }`}, } + ruleset3 := [][2]string{ + {`simple`, `p.q[r][s] = 42 { x = ["a", "b"]; r = x[s] }`}, + {`mixed`, `p.q[r].s[t] = 42 { x = ["a", "b"]; r = x[t] }`}, + {`overrides`, `p.q[r] = "foo" { x = ["a", "b"]; r = x[_] }`}, + {`overrides`, `p.q.r[s] = 42 { x = ["a", "b"]; x[s] }`}, + {`overrides`, `p.q[r].s = true { x = [true, false]; r = x[_] }`}, + {`overrides_static`, `p.q[r].a = "foo" { r = "bar"; s = "baz" }`}, + {`overrides_static`, `p.q[r].b = 42 { r = "bar" }`}, + {`overrides_static`, `p.q[r].c = true { r = "bar" }`}, + } tests := []struct { note string @@ -549,6 +559,104 @@ func TestCheckInferenceRules(t *testing.T) { types.NewDynamicProperty(types.Any{types.N, types.S}, types.Any{types.B, types.N, types.S}), ), }, + { + note: "general ref-rules, only vars in obj-path, complete obj access", + rules: ruleset3, + ref: "data.simple.p.q", + expected: types.NewObject( + []*types.StaticProperty{}, + types.NewDynamicProperty(types.S, + types.NewObject( + []*types.StaticProperty{}, + types.NewDynamicProperty(types.N, types.N), + ), + ), + ), + }, + { + note: "general ref-rules, only vars in obj-path, intermediate obj access", + rules: ruleset3, + ref: "data.simple.p.q.b", + expected: types.NewObject( + []*types.StaticProperty{}, + types.NewDynamicProperty(types.N, types.N), + ), + }, + { + note: "general ref-rules, only vars in obj-path, leaf access", + rules: ruleset3, + ref: "data.simple.p.q.b[1]", + expected: types.N, + }, + { + note: "general ref-rules, vars and constants in obj-path, complete obj access", + rules: ruleset3, + ref: "data.mixed.p.q", + expected: types.NewObject( + []*types.StaticProperty{}, + types.NewDynamicProperty(types.S, + types.NewObject(nil, + types.NewDynamicProperty(types.S, types.NewObject(nil, + types.NewDynamicProperty(types.N, types.N))), + ), + ), + ), + }, + { + note: "general ref-rules, key overrides, complete obj access", + rules: ruleset3, + ref: "data.overrides.p.q", + expected: types.NewObject(nil, types.NewDynamicProperty( + types.Or(types.B, types.S), + types.Any{ + types.S, + types.NewObject(nil, types.NewDynamicProperty( + types.Any{types.N, types.S}, + types.Any{types.B, types.N})), + }, + ), + ), + }, + { + note: "general ref-rules, multiple static key overrides, complete obj access", + rules: ruleset3, + ref: "data.overrides_static.p.q", + expected: types.NewObject( + []*types.StaticProperty{}, + types.NewDynamicProperty(types.S, + types.NewObject( + nil, + types.NewDynamicProperty(types.S, types.Any{types.B, types.N, types.S}), + ), + ), + ), + }, + { + note: "general ref-rules, multiple static key overrides, intermediate obj access", + rules: ruleset3, + ref: "data.overrides_static.p.q.foo", + expected: types.NewObject(nil, + types.NewDynamicProperty(types.S, types.Any{types.B, types.N, types.S}), + ), + }, + { + note: "general ref-rules, multiple static key overrides, leaf access (a)", + rules: ruleset3, + ref: "data.overrides_static.p.q.foo.a", + expected: types.Any{types.B, types.N, types.S}, // Dynamically build object types don't have static properties, so even though we "know" the 'a' key has a string value, we've lost this information. + }, + { + note: "general ref-rules, multiple static key overrides, leaf access (b)", + rules: ruleset3, + ref: "data.overrides_static.p.q.bar.b", + expected: types.Any{types.B, types.N, types.S}, + }, + { + note: "general ref-rules, multiple static key overrides, leaf access (c)", + rules: ruleset3, + ref: "data.overrides_static.p.q.baz.c", + expected: types.Any{types.B, types.N, types.S}, + }, } for _, tc := range tests { diff --git a/ast/compile.go b/ast/compile.go index 967f4ea8c2..112571e7c4 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -7,6 +7,7 @@ package ast import ( "fmt" "io" + "os" "sort" "strconv" "strings" @@ -81,6 +82,28 @@ type Compiler struct { // +--- b // | // +--- c (1 rule) + // + // Another example with general refs containing vars at arbitrary locations: + // + // package ex + // a.b[x].d { x := "c" } # R1 + // a.b.c[x] { x := "d" } # R2 + // a.b[x][y] { x := "c"; y := "d" } # R3 + // p := true # R4 + // + // root + // | + // +--- data (no rules) + // | + // +--- ex (no rules) + // | + // +--- a + // | | + // | +--- b (R1, R3) + // | | + // | +--- c (R2) + // | + // +--- p (R4) RuleTree *TreeNode // Graph contains dependencies between rules. An edge (u,v) is added to the @@ -123,6 +146,7 @@ type Compiler struct { keepModules bool // whether to keep the unprocessed, parse modules (below) parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker + generalRuleRefsEnabled bool } // CompilerStage defines the interface for stages in the compiler. @@ -311,6 +335,8 @@ func NewCompiler() *Compiler { {"BuildComprehensionIndices", "compile_stage_rebuild_comprehension_indices", c.buildComprehensionIndices}, } + _, c.generalRuleRefsEnabled = os.LookupEnv("EXPERIMENTAL_GENERAL_RULE_REFS") + return c } @@ -788,19 +814,23 @@ func (c *Compiler) buildRuleIndices() { return false } rules := extractRules(node.Values) - hasNonGroundKey := false + hasNonGroundRef := false for _, r := range rules { - if ref := r.Head.Ref(); len(ref) > 1 { - if !ref[len(ref)-1].IsGround() { - hasNonGroundKey = true - } - } + hasNonGroundRef = !r.Head.Ref().IsGround() } - if hasNonGroundKey { - // collect children: as of now, this cannot go deeper than one level, - // so we grab those, and abort the DepthFirst processing for this branch - for _, n := range node.Children { - rules = append(rules, extractRules(n.Values)...) + if hasNonGroundRef { + // Collect children to ensure that all rules within the extent of a rule with a general ref + // are found on the same index. E.g. the following rules should be indexed under data.a.b.c: + // + // package a + // b.c[x].e := 1 { x := input.x } + // b.c.d := 2 + // b.c.d2.e[x] := 3 { x := input.x } + for _, child := range node.Children { + child.DepthFirst(func(c *TreeNode) bool { + rules = append(rules, extractRules(c.Values)...) + return false + }) } } @@ -810,7 +840,7 @@ func (c *Compiler) buildRuleIndices() { if index.Build(rules) { c.ruleIndices.Put(rules[0].Ref().GroundPrefix(), index) } - return hasNonGroundKey // currently, we don't allow those branches to go deeper + return hasNonGroundRef // currently, we don't allow those branches to go deeper }) } @@ -870,10 +900,11 @@ func (c *Compiler) checkRuleConflicts() { kinds := make(map[RuleKind]struct{}, len(node.Values)) defaultRules := 0 + completeRules := 0 + partialRules := 0 arities := make(map[int]struct{}, len(node.Values)) name := "" - var singleValueConflicts []Ref - var multiValueConflicts []Ref + var conflicts []Ref for _, rule := range node.Values { r := rule.(*Rule) @@ -885,70 +916,99 @@ func (c *Compiler) checkRuleConflicts() { defaultRules++ } - // Single-value rules may not have any other rules in their extent: these pairs are invalid: + // Single-value rules may not have any other rules in their extent. + // Rules with vars in their ref are allowed to have rules inside their extent. + // Only the ground portion (terms before the first var term) of a rule's ref is considered when determining + // whether it's inside the extent of another (c.RuleTree is organized this way already). + // These pairs are invalid: // // data.p.q.r { true } # data.p.q is { "r": true } // data.p.q.r.s { true } // - // data.p.q[r] { r := input.r } # data.p.q could be { "r": true } - // data.p.q.r.s { true } + // data.p.q.r { true } + // data.p.q.r[s].t { s = input.key } // - // data.p[r] := x { r = input.key; x = input.bar } - // data.p.q[r] := x { r = input.key; x = input.bar } - // But this is allowed: + // + // data.p.q.r { true } + // data.p.q[r].s.t { r = input.key } + // + // data.p[r] := x { r = input.key; x = input.bar } + // data.p.q[r] := x { r = input.key; x = input.bar } + // + // data.p.q[r] { r := input.r } + // data.p.q.r.s { true } + // // data.p.q[r] = 1 { r := "r" } // data.p.q.s = 2 + // + // data.p[q][r] { q := input.q; r := input.r } + // data.p.q.r { true } + // + // data.p.q[r] { r := input.r } + // data.p[q].r { q := input.q } + // + // data.p.q[r][s] { r := input.r; s := input.s } + // data.p[q].r.s { q := input.q } - if r.Head.RuleKind() == SingleValue && len(node.Children) > 0 { - if len(ref) > 1 && !ref[len(ref)-1].IsGround() { // p.q[x] and p.q.s.t => check grandchildren - for _, c := range node.Children { - grandchildrenFound := false - - if len(c.Values) > 0 { - childRules := extractRules(c.Values) - for _, childRule := range childRules { - childRef := childRule.Ref() - if childRule.Head.RuleKind() == SingleValue && !childRef[len(childRef)-1].IsGround() { - // The child is a partial object rule, so it's effectively "generating" grandchildren. - grandchildrenFound = true - break + if c.generalRuleRefsEnabled { + if r.Ref().IsGround() && len(node.Children) > 0 { + conflicts = node.flattenChildren() + } + } else { // TODO: Remove when general rule refs are enabled by default. + if r.Head.RuleKind() == SingleValue && len(node.Children) > 0 { + if len(ref) > 1 && !ref[len(ref)-1].IsGround() { // p.q[x] and p.q.s.t => check grandchildren + for _, c := range node.Children { + grandchildrenFound := false + + if len(c.Values) > 0 { + childRules := extractRules(c.Values) + for _, childRule := range childRules { + childRef := childRule.Ref() + if childRule.Head.RuleKind() == SingleValue && !childRef[len(childRef)-1].IsGround() { + // The child is a partial object rule, so it's effectively "generating" grandchildren. + grandchildrenFound = true + break + } } } - } - if len(c.Children) > 0 { - grandchildrenFound = true - } + if len(c.Children) > 0 { + grandchildrenFound = true + } - if grandchildrenFound { - singleValueConflicts = node.flattenChildren() - break + if grandchildrenFound { + conflicts = node.flattenChildren() + break + } } + } else { // p.q.s and p.q.s.t => any children are in conflict + conflicts = node.flattenChildren() } - } else { // p.q.s and p.q.s.t => any children are in conflict - singleValueConflicts = node.flattenChildren() } - } - // Multi-value rules may not have any other rules in their extent; e.g.: - // - // data.p[v] { v := ... } - // data.p.q := 42 # In direct conflict with data.p[v], which is constructing a set and cannot have values assigned to a sub-path. + // Multi-value rules may not have any other rules in their extent; e.g.: + // + // data.p[v] { v := ... } + // data.p.q := 42 # In direct conflict with data.p[v], which is constructing a set and cannot have values assigned to a sub-path. - if r.Head.RuleKind() == MultiValue && len(node.Children) > 0 { - multiValueConflicts = node.flattenChildren() + if r.Head.RuleKind() == MultiValue && len(node.Children) > 0 { + conflicts = node.flattenChildren() + } + } + + if r.Head.RuleKind() == SingleValue && r.Head.Ref().IsGround() { + completeRules++ + } else { + partialRules++ } } switch { - case singleValueConflicts != nil: - c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "single-value rule %v conflicts with %v", name, singleValueConflicts)) - - case multiValueConflicts != nil: - c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multi-value rule %v conflicts with %v", name, multiValueConflicts)) + case conflicts != nil: + c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "rule %v conflicts with %v", name, conflicts)) - case len(kinds) > 1 || len(arities) > 1: + case len(kinds) > 1 || len(arities) > 1 || (completeRules >= 1 && partialRules >= 1): c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules %v found", name)) case defaultRules > 1: @@ -1697,13 +1757,12 @@ func (c *Compiler) rewriteRuleHeadRefs() { } for i := 1; i < len(ref); i++ { - // NOTE(sr): In the first iteration, non-string values in the refs are forbidden + // NOTE: Unless enabled via the EXPERIMENTAL_GENERAL_RULE_REFS env var, non-string values in the refs are forbidden // except for the last position, e.g. // OK: p.q.r[s] // NOT OK: p[q].r.s - // TODO(sr): This is stricter than necessary. We could allow any non-var values there, - // but we'll also have to adjust the type tree, for example. - if i != len(ref)-1 { // last + // TODO: Remove when general rule refs are enabled by default. + if !c.generalRuleRefsEnabled && i != len(ref)-1 { // last if _, ok := ref[i].Value.(String); !ok { c.err(NewError(TypeErr, rule.Loc(), "rule head must only contain string terms (except for last): %v", ref[i])) continue diff --git a/ast/compile_test.go b/ast/compile_test.go index ca7752e510..bb8ca866a4 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -467,6 +467,7 @@ func toRef(s string) Ref { } func TestCompilerCheckRuleHeadRefs(t *testing.T) { + t.Setenv("EXPERIMENTAL_GENERAL_RULE_REFS", "true") tests := []struct { note string @@ -480,7 +481,6 @@ func TestCompilerCheckRuleHeadRefs(t *testing.T) { `package x p.q[i].r = 1 { i := 10 }`, ), - err: "rego_type_error: rule head must only contain string terms (except for last): i", }, { note: "valid: ref is single-value rule with var key", @@ -559,7 +559,6 @@ func TestCompilerCheckRuleHeadRefs(t *testing.T) { `package x p.q[arr[0]].r { i := 10 }`, ), - err: "rego_type_error: rule head must only contain string terms (except for last): arr[0]", }, { note: "invalid: non-string in ref (not last position)", @@ -567,7 +566,6 @@ func TestCompilerCheckRuleHeadRefs(t *testing.T) { `package x p.q[10].r { true }`, ), - err: "rego_type_error: rule head must only contain string terms (except for last): 10", }, { note: "valid: multi-value with var key", @@ -609,6 +607,64 @@ func TestCompilerCheckRuleHeadRefs(t *testing.T) { } } +// TODO: Remove when general rule refs are enabled by default. +func TestCompilerCheckRuleHeadRefsWithGeneralRuleRefsDisabled(t *testing.T) { + + tests := []struct { + note string + modules []*Module + expected *Rule + err string + }{ + { + note: "ref contains var", + modules: modules( + `package x + p.q[i].r = 1 { i := 10 }`, + ), + err: "rego_type_error: rule head must only contain string terms (except for last): i", + }, + { + note: "invalid: ref in ref", + modules: modules( + `package x + p.q[arr[0]].r { i := 10 }`, + ), + err: "rego_type_error: rule head must only contain string terms (except for last): arr[0]", + }, + { + note: "invalid: non-string in ref (not last position)", + modules: modules( + `package x + p.q[10].r { true }`, + ), + err: "rego_type_error: rule head must only contain string terms (except for last): 10", + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + mods := make(map[string]*Module, len(tc.modules)) + for i, m := range tc.modules { + mods[fmt.Sprint(i)] = m + } + c := NewCompiler() + c.Modules = mods + compileStages(c, c.rewriteRuleHeadRefs) + if tc.err != "" { + assertCompilerErrorStrings(t, c, []string{tc.err}) + } else { + if len(c.Errors) > 0 { + t.Fatalf("expected no errors, got %v", c.Errors) + } + if tc.expected != nil { + assertRulesEqual(t, tc.expected, mods["0"].Rules[0]) + } + } + }) + } +} + func TestRuleTreeWithDotsInHeads(t *testing.T) { // TODO(sr): multi-val with var key in ref @@ -1813,6 +1869,9 @@ p { true }`, import future.keywords bar.baz contains "quz" if true`, + "mod8.rego": `package badrules.complete_partial +p := 1 +p[r] := 2 { r := "foo" }`, }) c.WithPathConflictsCheck(func(path []string) (bool, error) { @@ -1833,6 +1892,7 @@ bar.baz contains "quz" if true`, "rego_type_error: conflicting rules data.badrules.arity.g found", "rego_type_error: conflicting rules data.badrules.arity.p.q.h found", "rego_type_error: conflicting rules data.badrules.arity.p.q.i found", + "rego_type_error: conflicting rules data.badrules.complete_partial.p[r] found", "rego_type_error: conflicting rules data.badrules.p[x] found", "rego_type_error: conflicting rules data.badrules.q found", "rego_type_error: multiple default rules data.badrules.defkw.foo found", @@ -1880,6 +1940,7 @@ func TestCompilerCheckRuleConflictsDefaultFunction(t *testing.T) { } func TestCompilerCheckRuleConflictsDotsInRuleHeads(t *testing.T) { + t.Setenv("EXPERIMENTAL_GENERAL_RULE_REFS", "true") tests := []struct { note string @@ -1966,7 +2027,7 @@ func TestCompilerCheckRuleConflictsDotsInRuleHeads(t *testing.T) { p.q.r { true }`, `package pkg p.q.r.s { true }`), - err: "rego_type_error: single-value rule data.pkg.p.q.r conflicts with [data.pkg.p.q.r.s]", + err: "rego_type_error: rule data.pkg.p.q.r conflicts with [data.pkg.p.q.r.s]", }, { note: "single-value with other rule overlap", @@ -1975,7 +2036,15 @@ func TestCompilerCheckRuleConflictsDotsInRuleHeads(t *testing.T) { p.q.r { true } p.q.r.s { true } p.q.r.t { true }`), - err: "rego_type_error: single-value rule data.pkg.p.q.r conflicts with [data.pkg.p.q.r.s data.pkg.p.q.r.t]", + err: "rego_type_error: rule data.pkg.p.q.r conflicts with [data.pkg.p.q.r.s data.pkg.p.q.r.t]", + }, + { + note: "single-value with other partial object (same ref) overlap", + modules: modules( + `package pkg + p.q := 1 + p.q[r] := 2 { r := "foo" }`), + err: "rego_type_error: conflicting rules data.pkg.p.q[r] foun", }, { note: "single-value with other rule overlap, unknown key", @@ -1984,16 +2053,22 @@ func TestCompilerCheckRuleConflictsDotsInRuleHeads(t *testing.T) { p.q[r] = x { r = input.key; x = input.foo } p.q.r.s = x { true } `), - err: "rego_type_error: single-value rule data.pkg.p.q[r] conflicts with [data.pkg.p.q.r.s]", }, { - note: "single-value partial object with other partial object rule overlap, unknown keys (regression test for #5855)", + note: "single-value with other rule overlap, unknown ref var and key", + modules: modules( + `package pkg + p.q[r][s] = x { r = input.key1; s = input.key2; x = input.foo } + p.q.r.s.t = x { true } + `), + }, + { + note: "single-value partial object with other partial object rule overlap, unknown keys (regression test for #5855; invalidated by multi-var refs)", modules: modules( `package pkg p[r] := x { r = input.key; x = input.bar } p.q[r] := x { r = input.key; x = input.bar } `), - err: "rego_type_error: single-value rule data.pkg.p[r] conflicts with [data.pkg.p.q[r]]", }, { note: "single-value partial object with other partial object (implicit 'true' value) rule overlap, unknown keys", @@ -2002,7 +2077,6 @@ func TestCompilerCheckRuleConflictsDotsInRuleHeads(t *testing.T) { p[r] := x { r = input.key; x = input.bar } p.q[r] { r = input.key } `), - err: "rego_type_error: single-value rule data.pkg.p[r] conflicts with [data.pkg.p.q[r]]", }, { note: "single-value partial object with multi-value rule (ref head) overlap, unknown key", @@ -2037,7 +2111,7 @@ func TestCompilerCheckRuleConflictsDotsInRuleHeads(t *testing.T) { p[v] { v := ["a", "b"][_] } p.q := 42 `), - err: "rego_type_error: multi-value rule data.pkg.p conflicts with [data.pkg.p.q]", + err: "rego_type_error: rule data.pkg.p conflicts with [data.pkg.p.q]", }, { note: "multi-value rule with other rule (ref) overlap", @@ -2046,7 +2120,117 @@ func TestCompilerCheckRuleConflictsDotsInRuleHeads(t *testing.T) { p[v] { v := ["a", "b"][_] } p.q.r { true } `), - err: "rego_type_error: multi-value rule data.pkg.p conflicts with [data.pkg.p.q.r]", + err: "rego_type_error: rule data.pkg.p conflicts with [data.pkg.p.q.r]", + }, + { + note: "multi-value rule (dots in head) with other rule (ref) overlap", + modules: modules( + `package pkg + import future.keywords + p.q contains v { v := ["a", "b"][_] } + p.q.r { true } + `), + err: "rule data.pkg.p.q conflicts with [data.pkg.p.q.r]", + }, + { + note: "multi-value rule (dots and var in head) with other rule (ref) overlap", + modules: modules( + `package pkg + import future.keywords + p[q] contains v { v := ["a", "b"][_] } + p.q.r { true } + `), + }, + { + note: "function with other rule (ref) overlap", + modules: modules( + `package pkg + p(x) := x + p.q.r { true } + `), + err: "rego_type_error: rule data.pkg.p conflicts with [data.pkg.p.q.r]", + }, + { + note: "function with other rule (ref) overlap", + modules: modules( + `package pkg + p(x) := x + p.q.r { true } + `), + err: "rego_type_error: rule data.pkg.p conflicts with [data.pkg.p.q.r]", + }, + { + note: "function (ref) with other rule (ref) overlap", + modules: modules( + `package pkg + p.q(x) := x + p.q.r { true } + `), + err: "rego_type_error: rule data.pkg.p.q conflicts with [data.pkg.p.q.r]", + }, + } + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + mods := make(map[string]*Module, len(tc.modules)) + for i, m := range tc.modules { + mods[fmt.Sprint(i)] = m + } + c := NewCompiler() + c.Modules = mods + compileStages(c, c.checkRuleConflicts) + if tc.err != "" { + assertCompilerErrorStrings(t, c, []string{tc.err}) + } else { + assertCompilerErrorStrings(t, c, []string{}) + } + }) + } +} + +// TODO: Remove when general rule refs are enabled by default. +func TestGeneralRuleRefsDisabled(t *testing.T) { + // EXPERIMENTAL_GENERAL_RULE_REFS env var not set + + tests := []struct { + note string + modules []*Module + err string + }{ + { + note: "single-value with other rule overlap, unknown key", + modules: modules( + `package pkg + p.q[r] = x { r = input.key; x = input.foo } + p.q.r.s = x { true } + `), + err: "rego_type_error: rule data.pkg.p.q[r] conflicts with [data.pkg.p.q.r.s]", + }, + { + note: "single-value with other rule overlap, unknown ref var and key", + modules: modules( + `package pkg + p.q[r][s] = x { r = input.key1; s = input.key2; x = input.foo } + p.q.r.s.t = x { true } + `), + err: "rego_type_error: rule head must only contain string terms (except for last): r", + }, + { + note: "single-value partial object with other partial object rule overlap, unknown keys (regression test for #5855; invalidated by multi-var refs)", + modules: modules( + `package pkg + p[r] := x { r = input.key; x = input.bar } + p.q[r] := x { r = input.key; x = input.bar } + `), + err: "rego_type_error: rule data.pkg.p[r] conflicts with [data.pkg.p.q[r]]", + }, + { + note: "single-value partial object with other partial object (implicit 'true' value) rule overlap, unknown keys", + modules: modules( + `package pkg + p[r] := x { r = input.key; x = input.bar } + p.q[r] { r = input.key } + `), + err: "rego_type_error: rule data.pkg.p[r] conflicts with [data.pkg.p.q[r]]", }, } for _, tc := range tests { diff --git a/ast/env.go b/ast/env.go index 21c56392b1..784a34c047 100644 --- a/ast/env.go +++ b/ast/env.go @@ -357,7 +357,8 @@ func (n *typeTreeNode) Insert(path Ref, tpe types.Type, env *TypeEnv) { } // mergeTypes merges the types of 'a' and 'b'. If both are sets, their 'of' types are joined with an types.Or. -// If both are objects, the key and value types of their dynamic properties are joined with types.Or:s. +// If both are objects, the key types of their dynamic properties are joined with types.Or:s, and their value types +// are recursively merged (using mergeTypes). // If 'a' and 'b' are both objects, and at least one of them have static properties, they are joined // with an types.Or, instead of being merged. // If 'a' is an Any containing an Object, and 'b' is an Object (or vice versa); AND both objects have no @@ -381,9 +382,10 @@ func mergeTypes(a, b types.Type) types.Type { aDynProps := a.DynamicProperties() bDynProps := bObj.DynamicProperties() - return types.NewObject(nil, types.NewDynamicProperty( + dynProps := types.NewDynamicProperty( types.Or(aDynProps.Key, bDynProps.Key), - types.Or(aDynProps.Value, bDynProps.Value))) + mergeTypes(aDynProps.Value, bDynProps.Value)) + return types.NewObject(nil, dynProps) } else if bAny, ok := b.(types.Any); ok && len(a.StaticProperties()) == 0 { // If a is an object type with no static components ... for _, t := range bAny { diff --git a/ast/parser.go b/ast/parser.go index 3337a964e4..540d7b80c1 100644 --- a/ast/parser.go +++ b/ast/parser.go @@ -11,6 +11,7 @@ import ( "io" "math/big" "net/url" + "os" "regexp" "sort" "strconv" @@ -96,13 +97,14 @@ func (e *parsedTermCacheItem) String() string { // ParserOptions defines the options for parsing Rego statements. type ParserOptions struct { - Capabilities *Capabilities - ProcessAnnotation bool - AllFutureKeywords bool - FutureKeywords []string - SkipRules bool - JSONOptions *JSONOptions - unreleasedKeywords bool // TODO(sr): cleanup + Capabilities *Capabilities + ProcessAnnotation bool + AllFutureKeywords bool + FutureKeywords []string + SkipRules bool + JSONOptions *JSONOptions + unreleasedKeywords bool // TODO(sr): cleanup + generalRuleRefsEnabled bool } // JSONOptions defines the options for JSON operations, @@ -140,6 +142,7 @@ func NewParser() *Parser { s: &state{}, po: ParserOptions{}, } + _, p.po.generalRuleRefsEnabled = os.LookupEnv("EXPERIMENTAL_GENERAL_RULE_REFS") return p } @@ -624,7 +627,7 @@ func (p *Parser) parseRules() []*Rule { return []*Rule{&rule} } - if usesContains && !rule.Head.Reference.IsGround() { + if !p.po.generalRuleRefsEnabled && usesContains && !rule.Head.Reference.IsGround() { p.error(p.s.Loc(), "multi-value rules need ground refs") return nil } @@ -701,7 +704,7 @@ func (p *Parser) parseRules() []*Rule { } if p.s.tok == tokens.Else { - if r := rule.Head.Ref(); len(r) > 1 && !r[len(r)-1].Value.IsGround() { + if r := rule.Head.Ref(); len(r) > 1 && !r.IsGround() { p.error(p.s.Loc(), "else keyword cannot be used on rules with variables in head") return nil } diff --git a/ast/parser_test.go b/ast/parser_test.go index 11d522fd9c..b4cb9951a9 100644 --- a/ast/parser_test.go +++ b/ast/parser_test.go @@ -2536,6 +2536,14 @@ else := 2 rule: ` a.b[x] := 1 if false else := 2 +`, + err: "else keyword cannot be used on rules with variables in head", + }, + { + note: "single-value general ref head with var", + rule: ` +a.b[x].c := 1 if false +else := 2 `, err: "else keyword cannot be used on rules with variables in head", }, diff --git a/ast/policy.go b/ast/policy.go index caf756f6aa..2822b82d87 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -975,6 +975,12 @@ func (head *Head) SetLoc(loc *Location) { head.Location = loc } +func (head *Head) HasDynamicRef() bool { + pos := head.Reference.Dynamic() + // Ref is dynamic if it has one non-constant term that isn't the first or last term or if it's a partial set rule. + return pos > 0 && (pos < len(head.Reference)-1 || head.RuleKind() == MultiValue) +} + // Copy returns a deep copy of a. func (a Args) Copy() Args { cpy := Args{} diff --git a/ast/term.go b/ast/term.go index c4f560f6e2..e89366aa74 100644 --- a/ast/term.go +++ b/ast/term.go @@ -1043,6 +1043,14 @@ func (ref Ref) GroundPrefix() Ref { return prefix } +func (ref Ref) DynamicSuffix() Ref { + i := ref.Dynamic() + if i < 0 { + return nil + } + return ref[i:] +} + // IsGround returns true if all of the parts of the Ref are ground. func (ref Ref) IsGround() bool { if len(ref) == 0 { diff --git a/docs/content/policy-language.md b/docs/content/policy-language.md index 09f0a608f3..043115b110 100644 --- a/docs/content/policy-language.md +++ b/docs/content/policy-language.md @@ -992,6 +992,120 @@ data.example ```live:eg/ref_heads:output ``` +#### General References + +Any term, except the very first, in a rule head's reference can be a variable. These variables can be assigned within the rule, just as for any other partial rule, to dynamically construct a nested collection of objects. + +{{< danger >}} +General refs in rule heads is an experimental feature, and can be enabled by setting the `EXPERIMENTAL_GENERAL_RULE_REFS` environment variable. + +This feature is currently not supported for Wasm and IR. +{{< /danger >}} + +Data: + +```json +{ + "users": [ + { + "id": "alice", + "role": "employee", + "country": "USA" + }, + { + "id": "bob", + "role": "customer", + "country": "USA" + }, + { + "id": "dora", + "role": "admin", + "country": "Sweden" + } + ], + "admins": [ + { + "id": "charlie" + } + ] +} +``` + +Module: + +```rego +package example + +import future.keywords + +# A partial object rule that converts a list of users to a mapping by "role" and then "id". +users_by_role[role][id] := user if { + some user in data.users + id := user.id + role := user.role +} + +# Partial rule with an explicit "admin" key override +users_by_role.admin[id] := user if { + some user in data.admins + id := user.id +} + +# Leaf entries can be partial sets +users_by_country[country] contains user.id if { + some user in data.users + country := user.country +} +``` + +Query: + +``` +data.example +``` + +Output: + +```json +{ + "users_by_country": { + "Sweden": [ + "dora" + ], + "USA": [ + "alice", + "bob" + ] + }, + "users_by_role": { + "admin": { + "charlie": { + "id": "charlie" + }, + "dora": { + "country": "Sweden", + "id": "dora", + "role": "admin" + } + }, + "customer": { + "bob": { + "country": "USA", + "id": "bob", + "role": "customer" + } + }, + "employee": { + "alice": { + "country": "USA", + "id": "alice", + "role": "employee" + } + } + } +} +``` + ### Functions Rego supports user-defined functions that can be called with the same semantics as [Built-in Functions](#built-in-functions). They have access to both the [the data Document](../philosophy/#the-opa-document-model) and [the input Document](../philosophy/#the-opa-document-model). diff --git a/format/format.go b/format/format.go index 582ce27344..edd287ca83 100644 --- a/format/format.go +++ b/format/format.go @@ -461,7 +461,7 @@ func (w *writer) writeElse(rule *ast.Rule, o fmtOpts, comments []*ast.Comment) [ func (w *writer) writeHead(head *ast.Head, isDefault, isExpandedConst bool, o fmtOpts, comments []*ast.Comment) []*ast.Comment { ref := head.Ref() - if head.Key != nil && head.Value == nil { + if head.Key != nil && head.Value == nil && !head.HasDynamicRef() { ref = ref.GroundPrefix() } if o.refHeads || len(ref) == 1 { diff --git a/format/format_test.go b/format/format_test.go index 8d363210ff..0c545fb557 100644 --- a/format/format_test.go +++ b/format/format_test.go @@ -78,6 +78,8 @@ func TestFormatSourceError(t *testing.T) { } func TestFormatSource(t *testing.T) { + t.Setenv("EXPERIMENTAL_GENERAL_RULE_REFS", "true") + regoFiles, err := filepath.Glob("testfiles/*.rego") if err != nil { panic(err) diff --git a/format/testfiles/test_ref_heads.rego b/format/testfiles/test_ref_heads.rego index a55386ab59..f67663a4d1 100644 --- a/format/testfiles/test_ref_heads.rego +++ b/format/testfiles/test_ref_heads.rego @@ -11,3 +11,10 @@ q[1] = y if true r[x] if x := 10 p.q.r[x] if x := 10 p.q.r[2] if true + +g[h].i[j].k { true } +g[h].i[j].k { h := 1; j = 2 } +g[3].i[j].k = x { j := 3; x = 4 } +g[h].i[j].k[l] if { true } +g[h].i[j].k[l] contains x { x = "foo" } +g[h].i[j].k[l] contains x { h := 5; j := 6; l = 7; x = "foo" } diff --git a/format/testfiles/test_ref_heads.rego.formatted b/format/testfiles/test_ref_heads.rego.formatted index 5f23e3681b..c5fad5f043 100644 --- a/format/testfiles/test_ref_heads.rego.formatted +++ b/format/testfiles/test_ref_heads.rego.formatted @@ -17,3 +17,26 @@ r[x] if x := 10 p.q.r[x] if x := 10 p.q.r[2] = true + +g[h].i[j].k = true + +g[h].i[j].k if { + h := 1 + j = 2 +} + +g[3].i[j].k = x if { + j := 3 + x = 4 +} + +g[h].i[j].k[l] = true + +g[h].i[j].k[l] contains x if x = "foo" + +g[h].i[j].k[l] contains x if { + h := 5 + j := 6 + l = 7 + x = "foo" +} diff --git a/internal/wasm/sdk/test/e2e/exceptions.yaml b/internal/wasm/sdk/test/e2e/exceptions.yaml index 664e29b20f..951a438bd4 100644 --- a/internal/wasm/sdk/test/e2e/exceptions.yaml +++ b/internal/wasm/sdk/test/e2e/exceptions.yaml @@ -2,4 +2,14 @@ "data/toplevel integer": "https://github.com/open-policy-agent/opa/issues/3711" "data/nested integer": "https://github.com/open-policy-agent/opa/issues/3711" "withkeyword/function: indirect call, arity 1, replacement is value that needs eval (array comprehension)": "https://github.com/open-policy-agent/opa/issues/5311" -"withkeyword/builtin: indirect call, arity 1, replacement is value that needs eval (array comprehension)": "https://github.com/open-policy-agent/opa/issues/5311" \ No newline at end of file +"withkeyword/builtin: indirect call, arity 1, replacement is value that needs eval (array comprehension)": "https://github.com/open-policy-agent/opa/issues/5311" +"refheads/general, single var": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet" +"refheads/general, multiple vars": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet" +"refheads/general, deep query": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet" +"refheads/general, overlapping rule, no conflict": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet" +"refheads/general, overlapping rule, conflict": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet" +"refheads/general, set leaf": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet" +"refheads/general, set leaf, deep query": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet" +"refheads/general, input var": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet" +"refheads/general, external non-ground var": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet" +"refheads/general, multiple result-set entries": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet" \ No newline at end of file diff --git a/test/cases/cases.go b/test/cases/cases.go index a614d75e8e..1b4a0f5a0b 100644 --- a/test/cases/cases.go +++ b/test/cases/cases.go @@ -44,6 +44,7 @@ type TestCase struct { WantError *string `json:"want_error,omitempty"` // expect query error message (overrides error code) SortBindings bool `json:"sort_bindings,omitempty"` // indicates that binding values should be treated as sets StrictError bool `json:"strict_error,omitempty"` // indicates that the error depends on strict builtin error mode + Env map[string]string `json:"env,omitempty"` // environment variables to be set during the test } // Load returns a set of built-in test cases. diff --git a/test/cases/testdata/refheads/test-generic-refs.yaml b/test/cases/testdata/refheads/test-generic-refs.yaml new file mode 100644 index 0000000000..74fb221f6e --- /dev/null +++ b/test/cases/testdata/refheads/test-generic-refs.yaml @@ -0,0 +1,174 @@ +cases: + - note: 'refheads/general, single var' + env: + EXPERIMENTAL_GENERAL_RULE_REFS: "true" + modules: + - | + package test + + p[q].r := i { q := ["a", "b", "c"][i] } + query: data.test.p = x + want_result: + - x: + a: + r: 0 + b: + r: 1 + c: + r: 2 + - note: 'refheads/general, multiple vars' + env: + EXPERIMENTAL_GENERAL_RULE_REFS: "true" + modules: + - | + package test + + p[q][r] { q := ["a", "b", "c"][r] } + query: data.test.p = x + want_result: + - x: + a: + 0: true + b: + 1: true + c: + 2: true + - note: 'refheads/general, deep query' + env: + EXPERIMENTAL_GENERAL_RULE_REFS: "true" + modules: + - | + package test + + p[q][r] { q := ["a", "b", "c"][r] } + query: data.test.p.b = x + want_result: + - x: + 1: true + - note: 'refheads/general, overlapping rule, no conflict' + env: + EXPERIMENTAL_GENERAL_RULE_REFS: "true" + modules: + - | + package test + + p[q].r := i { q := ["a", "b", "c"][i] } + p.a.r := 0 + query: data.test.p = x + want_result: + - x: + a: + r: 0 + b: + r: 1 + c: + r: 2 + - note: 'refheads/general, overlapping rule, conflict' + env: + EXPERIMENTAL_GENERAL_RULE_REFS: "true" + modules: + - | + package test + + p[q].r := i { q := ["a", "b", "c"][i] } + p.a.r := 42 + query: data.test.p = x + want_error: eval_conflict_error + - note: 'refheads/general, set leaf' + env: + EXPERIMENTAL_GENERAL_RULE_REFS: "true" + modules: + - | + package test + import future.keywords + + p[q].r contains s { + x := ["a", "b", "c"] + q := x[_] + s := x[_] + q != s + } + + p.b.r contains "foo" + query: data.test.p = x + want_result: + - x: + a: + r: [ "b", "c" ] + b: + r: [ "a", "c", "foo" ] + c: + r: [ "a", "b" ] + - note: 'refheads/general, set leaf, deep query' + env: + EXPERIMENTAL_GENERAL_RULE_REFS: "true" + modules: + - | + package test + import future.keywords + + p[q].r contains s { + x := ["a", "b", "c"] + q := x[_] + s := x[_] + q != s + } + + p.b.r contains "foo" + query: data.test.p.b.r.c = x + want_result: + - x: "c" + - note: 'refheads/general, input var' + env: + EXPERIMENTAL_GENERAL_RULE_REFS: "true" + modules: + - | + package test + + p[input.x].r := "foo" + query: data.test.p = x + input: + x: "bar" + want_result: + - x: + bar: + r: "foo" + - note: 'refheads/general, external non-ground var' + env: + EXPERIMENTAL_GENERAL_RULE_REFS: "true" + modules: + - | + package test + + a := [x | x := input.x[_]] + b := input.y + + p[a[b]].r[s] := i { + s := a[i] + } + query: data.test.p = x + input: + x: [ "foo", "bar", "baz" ] + "y": 1 + want_result: + - x: + bar: + r: + foo: 0 + bar: 1 + baz: 2 + - note: 'refheads/general, multiple result-set entries' + env: + EXPERIMENTAL_GENERAL_RULE_REFS: "true" + modules: + - | + package test + + p.q[r].s := 1 { r := "foo" } + p.q[r].s := 2 { r := "bar" } + query: data.test.p.q[i].s = x + want_result: + - i: foo + x: 1 + - i: bar + x: 2 \ No newline at end of file diff --git a/topdown/eval.go b/topdown/eval.go index c973ab8fcf..292c8fd8ca 100644 --- a/topdown/eval.go +++ b/topdown/eval.go @@ -23,6 +23,8 @@ type evalIterator func(*eval) error type unifyIterator func() error +type unifyRefIterator func(pos int) error + type queryIDFactory struct { curr uint64 } @@ -2278,6 +2280,14 @@ func (e evalVirtual) eval(iter unifyIterator) error { switch ir.Kind { case ast.MultiValue: + var empty *ast.Term + if ir.OnlyGroundRefs { + // rule ref contains no vars, so we're building a set + empty = ast.SetTerm() + } else { + // rule ref contains vars, so we're building an object containing a set leaf + empty = ast.ObjectTerm() + } eval := evalVirtualPartial{ e: e.e, ref: e.ref, @@ -2287,12 +2297,10 @@ func (e evalVirtual) eval(iter unifyIterator) error { bindings: e.bindings, rterm: e.rterm, rbindings: e.rbindings, - empty: ast.SetTerm(), + empty: empty, } return eval.eval(iter) case ast.SingleValue: - // NOTE(sr): If we allow vars in others than the last position of a ref, we need - // to start reworking things here if ir.OnlyGroundRefs { eval := evalVirtualComplete{ e: e.e, @@ -2359,9 +2367,31 @@ func (e evalVirtualPartial) eval(iter unifyIterator) error { return e.evalEachRule(iter, unknown) } +// returns the maximum length a ref can be without being longer than the longest rule ref in rules. +func maxRefLength(rules []*ast.Rule, ceil int) int { + var l int + for _, r := range rules { + rl := len(r.Ref()) + if r.Head.RuleKind() == ast.MultiValue { + rl = rl + 1 + } + if rl >= ceil { + return ceil + } else if rl > l { + l = rl + } + } + return l +} + func (e evalVirtualPartial) evalEachRule(iter unifyIterator, unknown bool) error { - if e.e.unknown(e.ref[e.pos+1], e.bindings) { + if e.ir.Empty() { + return nil + } + + m := maxRefLength(e.ir.Rules, len(e.ref)) + if e.e.unknown(e.ref[e.pos+1:m], e.bindings) { for _, rule := range e.ir.Rules { if err := e.evalOneRulePostUnify(iter, rule); err != nil { return err @@ -2387,12 +2417,24 @@ func (e evalVirtualPartial) evalEachRule(iter unifyIterator, unknown bool) error } result := e.empty + for _, rule := range e.ir.Rules { - if err := e.evalOneRulePreUnify(iter, rule, hint, result, unknown); err != nil { + result, err = e.evalOneRulePreUnify(iter, rule, hint, result, unknown) + if err != nil { return err } } + if hint.key != nil { + if v, err := result.Value.Find(hint.key[e.pos+1:]); err == nil && v != nil { + e.e.virtualCache.Put(hint.key, ast.NewTerm(v)) + } + } + + if !unknown { + return e.evalTerm(iter, e.pos+1, result, e.bindings) + } + return nil } @@ -2428,7 +2470,7 @@ func (e evalVirtualPartial) evalAllRulesNoCache(rules []*ast.Rule) (*ast.Term, e err := child.eval(func(*eval) error { child.traceExit(rule) var err error - result, _, err = e.reduce(rule.Head, child.bindings, result) + result, _, err = e.reduce(rule, child.bindings, result) if err != nil { return err } @@ -2445,9 +2487,18 @@ func (e evalVirtualPartial) evalAllRulesNoCache(rules []*ast.Rule) (*ast.Term, e return result, nil } -func (e evalVirtualPartial) evalOneRulePreUnify(iter unifyIterator, rule *ast.Rule, hint evalVirtualPartialCacheHint, result *ast.Term, unknown bool) error { +func wrapInObjects(leaf *ast.Term, ref ast.Ref) *ast.Term { + // We build the nested objects leaf-to-root to preserve ground:ness + if len(ref) == 0 { + return leaf + } + key := ref[0] + val := wrapInObjects(leaf, ref[1:]) + return ast.ObjectTerm(ast.Item(key, val)) +} + +func (e evalVirtualPartial) evalOneRulePreUnify(iter unifyIterator, rule *ast.Rule, hint evalVirtualPartialCacheHint, result *ast.Term, unknown bool) (*ast.Term, error) { - key := e.ref[e.pos+1] child := e.e.child(rule.Body) child.traceEnter(rule) @@ -2457,67 +2508,89 @@ func (e evalVirtualPartial) evalOneRulePreUnify(iter unifyIterator, rule *ast.Ru if headKey == nil { headKey = rule.Head.Reference[len(rule.Head.Reference)-1] } - err := child.biunify(headKey, key, child.bindings, e.bindings, func() error { + + // Walk the dynamic portion of rule ref and key to unify vars + err := child.biunifyRuleHead(e.pos+1, e.ref, rule, e.bindings, child.bindings, func(pos int) error { defined = true return child.eval(func(child *eval) error { + child.traceExit(rule) + term := rule.Head.Value if term == nil { term = headKey } - if hint.key != nil { - result := child.bindings.Plug(term) - e.e.virtualCache.Put(hint.key, result) - } + if unknown { + term, termbindings := child.bindings.apply(term) + + if rule.Head.RuleKind() == ast.MultiValue { + term = ast.SetTerm(term) + } + + objRef := rule.Ref()[e.pos+1:] + term = wrapInObjects(term, objRef) - // NOTE(tsandall): if the rule set depends on any unknowns then do - // not perform the duplicate check because evaluation of the ruleset - // may not produce a definitive result. This is a bit strict--we - // could improve by skipping only when saves occur. - if !unknown { + err := e.evalTerm(iter, e.pos+1, term, termbindings) + if err != nil { + return err + } + } else { var dup bool var err error - result, dup, err = e.reduce(rule.Head, child.bindings, result) + result, dup, err = e.reduce(rule, child.bindings, result) if err != nil { return err - } else if dup { + } else if !unknown && dup { child.traceDuplicate(rule) return nil } } - child.traceExit(rule) - term, termbindings := child.bindings.apply(term) - err := e.evalTerm(iter, e.pos+2, term, termbindings) - if err != nil { - return err - } - child.traceRedo(rule) + return nil }) }) if err != nil { - return err + return nil, err } - // TODO(tsandall): why are we tracing here? this looks wrong. if !defined { child.traceFail(rule) } - return nil + return result, nil } -func (e evalVirtualPartial) evalOneRulePostUnify(iter unifyIterator, rule *ast.Rule) error { - headKey := rule.Head.Key - if headKey == nil { - headKey = rule.Head.Reference[len(rule.Head.Reference)-1] +func (e *eval) biunifyRuleHead(pos int, ref ast.Ref, rule *ast.Rule, refBindings, ruleBindings *bindings, iter unifyRefIterator) error { + return e.biunifyDynamicRef(pos, ref, rule.Ref(), refBindings, ruleBindings, func(pos int) error { + // FIXME: Is there a simpler, more robust way of figuring out that we should biunify the rule key? + if rule.Head.RuleKind() == ast.MultiValue && pos < len(ref) && len(rule.Ref()) <= len(ref) { + headKey := rule.Head.Key + if headKey == nil { + headKey = rule.Head.Reference[len(rule.Head.Reference)-1] + } + return e.biunify(ref[pos], headKey, refBindings, ruleBindings, func() error { + return iter(pos + 1) + }) + } + return iter(pos) + }) +} + +func (e *eval) biunifyDynamicRef(pos int, a, b ast.Ref, b1, b2 *bindings, iter unifyRefIterator) error { + if pos >= len(a) || pos >= len(b) { + return iter(pos) } - key := e.ref[e.pos+1] + return e.biunify(a[pos], b[pos], b1, b2, func() error { + return e.biunifyDynamicRef(pos+1, a, b, b1, b2, iter) + }) +} + +func (e evalVirtualPartial) evalOneRulePostUnify(iter unifyIterator, rule *ast.Rule) error { child := e.e.child(rule.Body) child.traceEnter(rule) @@ -2525,7 +2598,7 @@ func (e evalVirtualPartial) evalOneRulePostUnify(iter unifyIterator, rule *ast.R err := child.eval(func(child *eval) error { defined = true - return e.e.biunify(headKey, key, child.bindings, e.bindings, func() error { + return e.e.biunifyRuleHead(e.pos+1, e.ref, rule, e.bindings, child.bindings, func(pos int) error { return e.evalOneRuleContinue(iter, rule, child) }) }) @@ -2551,7 +2624,15 @@ func (e evalVirtualPartial) evalOneRuleContinue(iter unifyIterator, rule *ast.Ru } term, termbindings := child.bindings.apply(term) - err := e.evalTerm(iter, e.pos+2, term, termbindings) + + if rule.Head.RuleKind() == ast.MultiValue { + term = ast.SetTerm(term) + } + + objRef := rule.Ref()[e.pos+1:] + term = wrapInObjects(term, objRef) + + err := e.evalTerm(iter, e.pos+1, term, termbindings) if err != nil { return err } @@ -2633,10 +2714,10 @@ func (e evalVirtualPartial) partialEvalSupportRule(rule *ast.Rule, path ast.Ref) ruleRef[i] = child.bindings.plugNamespaced(ruleRef[i], e.e.caller.bindings) } head.Reference = ruleRef - if head.Name.Equal(ast.Var("")) { + if head.Name.Equal(ast.Var("")) && (len(ruleRef) == 1 || (len(ruleRef) == 2 && head.Key == nil)) { head.Name = ruleRef[0].Value.(ast.Var) } - if len(ruleRef) > 1 && head.Key == nil { + if len(ruleRef) == 2 && head.Key == nil { head.Key = ruleRef[len(ruleRef)-1] } } @@ -2681,6 +2762,7 @@ func (e evalVirtualPartial) evalCache(iter unifyIterator) (evalVirtualPartialCac var hint evalVirtualPartialCacheHint if e.e.unknown(e.ref[:e.pos+1], e.bindings) { + // FIXME: Return empty hint if unknowns in any e.ref elem overlapping with applicable rule refs? return hint, nil } @@ -2692,17 +2774,29 @@ func (e evalVirtualPartial) evalCache(iter unifyIterator) (evalVirtualPartialCac plugged := e.bindings.Plug(e.ref[e.pos+1]) - if plugged.IsGround() { - hint.key = append(e.plugged[:e.pos+1], plugged) + if _, ok := plugged.Value.(ast.Var); ok { + hint.full = true + hint.key = e.plugged[:e.pos+1] + e.e.instr.counterIncr(evalOpVirtualCacheMiss) + return hint, nil + } + + m := maxRefLength(e.ir.Rules, len(e.ref)) + + for i := e.pos + 1; i < m; i++ { + plugged = e.bindings.Plug(e.ref[i]) + + if !plugged.IsGround() { + break + } + + hint.key = append(e.plugged[:i], plugged) if cached, _ := e.e.virtualCache.Get(hint.key); cached != nil { e.e.instr.counterIncr(evalOpVirtualCacheHit) hint.hit = true - return hint, e.evalTerm(iter, e.pos+2, cached, e.bindings) + return hint, e.evalTerm(iter, i+1, cached, e.bindings) } - } else if _, ok := plugged.Value.(ast.Var); ok { - hint.full = true - hint.key = e.plugged[:e.pos+1] } e.e.instr.counterIncr(evalOpVirtualCacheMiss) @@ -2710,26 +2804,81 @@ func (e evalVirtualPartial) evalCache(iter unifyIterator) (evalVirtualPartialCac return hint, nil } -func (e evalVirtualPartial) reduce(head *ast.Head, b *bindings, result *ast.Term) (*ast.Term, bool, error) { +func getNestedObject(ref ast.Ref, rootObj *ast.Object, b *bindings, l *ast.Location) (*ast.Object, error) { + current := rootObj + for _, term := range ref { + key := b.Plug(term) + if child := (*current).Get(key); child != nil { + if val, ok := child.Value.(ast.Object); ok { + current = &val + } else { + return nil, objectDocKeyConflictErr(l) + } + } else { + child := ast.NewObject() + (*current).Insert(key, ast.NewTerm(child)) + current = &child + } + } + + return current, nil +} + +func (e evalVirtualPartial) reduce(rule *ast.Rule, b *bindings, result *ast.Term) (*ast.Term, bool, error) { var exists bool + head := rule.Head switch v := result.Value.(type) { - case ast.Set: // MultiValue + case ast.Set: key := b.Plug(head.Key) exists = v.Contains(key) v.Add(key) - case ast.Object: // SingleValue - key := head.Reference[len(head.Reference)-1] // NOTE(sr): multiple vars in ref heads need to deal with this better - key = b.Plug(key) - value := b.Plug(head.Value) - if curr := v.Get(key); curr != nil { - if !curr.Equal(value) { - return nil, false, objectDocKeyConflictErr(head.Location) - } - exists = true + case ast.Object: + // data.p.q[r].s.t := 42 {...} + // |----|-| + // ^ ^ + // | leafKey + // objPath + fullPath := rule.Ref() + objPath := fullPath[e.pos+1 : len(fullPath)-1] // the portion of the ref that generates nested objects + leafKey := b.Plug(fullPath[len(fullPath)-1]) // the portion of the ref that is the deepest nested key for the value + + leafObj, err := getNestedObject(objPath, &v, b, head.Location) + if err != nil { + return nil, false, err + } + + if kind := head.RuleKind(); kind == ast.SingleValue { + // We're inserting into an object + val := b.Plug(head.Value) + + if curr := (*leafObj).Get(leafKey); curr != nil { + if !curr.Equal(val) { + return nil, false, objectDocKeyConflictErr(head.Location) + } + exists = true + } else { + (*leafObj).Insert(leafKey, val) + } } else { - v.Insert(key, value) + // We're inserting into a set + var set *ast.Set + if leaf := (*leafObj).Get(leafKey); leaf != nil { + if s, ok := leaf.Value.(ast.Set); ok { + set = &s + } else { + return nil, false, objectDocKeyConflictErr(head.Location) + } + } else { + s := ast.NewSet() + (*leafObj).Insert(leafKey, ast.NewTerm(s)) + set = &s + } + + key := b.Plug(head.Key) + exists = (*set).Contains(key) + (*set).Add(key) } } @@ -2956,6 +3105,7 @@ func (e evalVirtualComplete) partialEvalSupportRule(rule *ast.Rule, path ast.Ref s := ref[len(ref)-1].Value.(ast.String) name = ast.Var(s) } + // TODO: Do we need to deal with general refs here? head := ast.NewHead(name, nil, child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings)) if !e.e.inliningControl.shallow { diff --git a/topdown/eval_test.go b/topdown/eval_test.go index ace27b65c2..383f3a5ce5 100644 --- a/topdown/eval_test.go +++ b/topdown/eval_test.go @@ -6,6 +6,8 @@ package topdown import ( "context" + "encoding/json" + "strings" "testing" "github.com/open-policy-agent/opa/ast" @@ -245,6 +247,9 @@ func TestContainsNestedRefOrCall(t *testing.T) { } func TestTopdownVirtualCache(t *testing.T) { + // TODO: break out into separate tests + t.Setenv("EXPERIMENTAL_GENERAL_RULE_REFS", "true") + ctx := context.Background() store := inmem.New() @@ -325,6 +330,160 @@ func TestTopdownVirtualCache(t *testing.T) { hit: 1, miss: 1, }, + { + note: "partial object: query into object value", + module: `package p + s["foo"] = { "x": 42, "y": 43 } { true } + s["bar"] = { "x": 42, "y": 43 } { true }`, + query: `data.p.s["foo"].x = x; data.p.s["foo"].y`, + hit: 1, + miss: 1, + exp: 42, + }, + { + note: "partial object: simple, general ref", + module: `package p + s.t[u].v = true { x = ["foo", "bar"]; u = x[_] }`, + query: `data.p.s.t["foo"].v = x; data.p.s.t["foo"].v`, + hit: 1, + miss: 1, + exp: true, + }, + { + note: "partial object: simple, general ref, multiple vars", + module: `package p + s.t[u].v[w] = true { x = ["foo", "bar"]; u = x[_]; y = ["do", "re"]; w = y[_] }`, + query: `data.p.s.t = x; data.p.s.t`, + hit: 1, + miss: 1, + exp: map[string]interface{}{ + "foo": map[string]interface{}{ + "v": map[string]interface{}{ + "do": true, + "re": true, + }, + }, + "bar": map[string]interface{}{ + "v": map[string]interface{}{ + "do": true, + "re": true, + }, + }, + }, + }, + { + note: "partial object: simple, general ref, multiple vars (2)", + module: `package p + s.t[u].v[w] = true { x = ["foo", "bar"]; u = x[_]; y = ["do", "re"]; w = y[_] }`, + query: `data.p.s.t.foo = x; data.p.s.t["foo"]`, + hit: 1, + miss: 1, + exp: map[string]interface{}{ + "v": map[string]interface{}{ + "do": true, + "re": true, + }, + }, + }, + { + note: "partial object: simple, general ref, multiple vars (3)", + module: `package p + s.t[u].v[w] = true { x = ["foo", "bar"]; u = x[_]; y = ["do", "re"]; w = y[_] }`, + query: `data.p.s.t.foo.v = x; data.p.s.t["foo"].v`, + hit: 1, + miss: 1, + exp: map[string]interface{}{ + "do": true, + "re": true, + }, + }, + { + note: "partial object: simple, general ref, multiple vars (4)", + module: `package p + s.t[u].v[w] = true { x = ["foo", "bar"]; u = x[_]; y = ["do", "re"]; w = y[_] }`, + query: `data.p.s.t.foo.v.re = x; data.p.s.t["foo"].v["re"]`, + hit: 1, + miss: 1, + exp: true, + }, + { + note: "partial object: simple, general ref, miss", + module: `package p + s.t[u].v[w] = true { x = ["foo", "bar"]; u = x[_]; y = ["do", "re"]; w = y[_] }`, + query: `data.p.s.t.foo.v.re = x; data.p.s.t.foo.v.do`, + hit: 0, + miss: 2, + exp: true, + }, + { + note: "partial object: simple, general ref, miss (2)", + module: `package p + s.t[u].v[w] = i { x = ["foo", "bar"]; u = x[_]; y = ["do", "re"]; w = y[i] }`, + query: `data.p.s.t.foo.v.re = x; data.p.s.t.foo.v.do; data.p.s.t.foo.v.re`, + hit: 1, + miss: 2, + exp: 1, + }, + { + note: "partial object: simple, general ref, miss (3)", + module: `package p + s.t[u].v[w] = i { x = ["foo", "bar"]; u = x[_]; y = ["do", "re"]; w = y[i] }`, + query: `data.p.s.t.foo.v.re = x; data.p.s.t.foo.v.do; data.p.s.t.bar.v.re`, + hit: 0, + miss: 3, + exp: 1, + }, + { + note: "partial object: simple, general ref, miss (3)", + module: `package p + s.t[u].v[w] = i { x = ["foo", "bar"]; u = x[_]; y = ["do", "re"]; w = y[i] }`, + query: `data.p.s.t.foo.v.re = x; data.p.s.t.foo.v.do; data.p.s.t.bar.v.re; data.p.s.t.foo.v.do`, + hit: 1, + miss: 3, + exp: 1, + }, + { + note: "partial object: simple, general ref, miss (4)", + module: `package p + s.t[u].v[w] = i { x = ["foo", "bar"]; u = x[_]; y = ["do", "re"]; w = y[i] }`, + query: `data.p.s.t.foo = x; data.p.s.t.foo.v.do`, + hit: 1, + miss: 1, + exp: map[string]interface{}{ + "v": map[string]interface{}{ + "do": 0, + "re": 1, + }, + }, + }, + { + note: "partial object: simple, general ref, miss (5)", + module: `package p + s.t[u].v[w] = i { x = ["foo", "bar"]; u = x[_]; y = ["do", "re"]; w = y[i] }`, + query: `data.p.s.t.foo; data.p.s.t.foo.v.do = x`, + hit: 1, + miss: 1, + exp: 0, + }, + { + note: "partial object: simple, general ref, miss (6)", + module: `package p + s.t[u].v[w] = i { x = ["foo", "bar"]; u = x[_]; y = ["do", "re"]; w = y[i] }`, + query: `data.p.s.t.foo.v.do = x; data.p.s.t.foo`, + hit: 0, // Note: Could we be smart in query term eval order to gain an extra hit here? + miss: 2, + exp: 0, + }, + { + note: "partial object: simple, query into value", + module: `package p + s["foo"].t = { "x": 42, "y": 43 } { true } + s["bar"].t = { "x": 42, "y": 43 } { true }`, + query: `data.p.s["foo"].t.x = x; data.p.s["foo"].t.x`, + hit: 1, + miss: 1, + exp: 42, + }, { note: "partial set: simple", module: `package p @@ -443,3 +602,803 @@ func TestTopdownVirtualCache(t *testing.T) { }) } } + +func TestPartialRule(t *testing.T) { + t.Setenv("EXPERIMENTAL_GENERAL_RULE_REFS", "true") + + ctx := context.Background() + store := inmem.New() + + tests := []struct { + note string + module string + query string + exp string + expErr string + }{ + { + note: "partial set", + module: `package test + p[v] { + v := [1, 2, 3][_] + } + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": [1, 2, 3]}}}]`, + }, + { + note: "partial object", + module: `package test + p[i] := v { + v := [1, 2, 3][i] + } + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"0": 1, "1": 2, "2": 3}}}}]`, + }, + { + note: "partial object (const key)", + module: `package test + p["foo"] := v { + v := 42 + } + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"foo": 42}}}}]`, + }, + { + note: "ref head", + module: `package test + p.foo := v { + v := 42 + } + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"foo": 42}}}}]`, + }, + { + note: "partial object (ref head)", + module: `package test + p.q.r[i] := v { + v := ["a", "b", "c"][i] + } + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"q": {"r": {"0": "a", "1": "b", "2": "c"}}}}}}]`, + }, + { + note: "partial object (ref head), query to obj root", + module: `package test + p.q.r[i] := v { + v := ["a", "b", "c"][i] + } + `, + query: `data.test.p.q.r = x`, + exp: `[{"x": {"0": "a", "1": "b", "2": "c"}}]`, + }, + { + note: "partial object (ref head), query to obj root, enumerating keys", + module: `package test + p.q.r[i] := v { + v := ["a", "b", "c"][i] + } + `, + query: `data.test.p.q.r[x]`, + // NOTE: $_term_0_0 wildcard var is filtered from eval result output + exp: `[{"x": 0, "$_term_0_0": "a"}, {"x": 1, "$_term_0_0": "b"}, {"x": 2, "$_term_0_0": "c"}]`, + }, + { + note: "partial object (ref head), implicit 'true' value", + module: `package test + p.q.r[v] { + v := [1, 2, 3][_] + } + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"q": {"r": {"1": true, "2": true, "3": true}}}}}}]`, + }, + { + note: "partial set (ref head)", + module: `package test + import future.keywords + p.q contains v if { + v := [1, 2, 3][_] + } + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"q": [1, 2, 3]}}}}]`, + }, + { + note: "partial set (general ref head)", + module: `package test + import future.keywords + p[j] contains v if { + v := [1, 2, 3][_] + j := ["a", "b", "c"][_] + } + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]}}}}]`, + }, + { + note: "partial set (general ref head, static suffix)", + module: `package test + import future.keywords + p[q].r contains v if { + q := "foo" + v := [1, 2, 3][_] + } + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"foo": {"r": [1, 2, 3]}}}}}]`, + }, + { + note: "partial object (general ref head, multiple vars)", + module: `package test + p.q[x].r[i] := v { + some i + v := [1, 2, 3][i] + x := ["a", "b", "c"][_] + } + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"q": {"a": {"r": {"0": 1, "1": 2, "2": 3}}, "b": {"r": {"0": 1, "1": 2, "2": 3}}, "c": {"r": {"0": 1, "1": 2, "2": 3}}}}}}}]`, + }, + { + note: "partial object (general ref head, multiple vars) #2", + module: `package test + p[j].foo[i] := v { + v := [1, 2, 3][i] + j := ["a", "b", "c"][_] + } + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"a": {"foo": {"0": 1, "1": 2, "2": 3}}, "b": {"foo": {"0": 1, "1": 2, "2": 3}}, "c": {"foo": {"0": 1, "1": 2, "2": 3}}}}}}]`, + }, + { + note: "partial set (multiple vars in general ref head)", + module: `package test + import future.keywords + p[j][i] contains v if { + v := [1, 2, 3][_] + j := ["a", "b", "c"][_] + i := "foo" + } + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"a": {"foo": [1, 2, 3]}, "b": {"foo": [1, 2, 3]}, "c": {"foo": [1, 2, 3]}}}}}]`, + }, + // Overlapping rules + { + note: "partial object with overlapping rule (defining key/value in object)", + module: `package test + foo.bar[i] := v { + v := ["a", "b", "c"][i] + } + foo.bar.baz := 42 + `, + query: `data = x`, + exp: `[{"x": {"test": {"foo": {"bar": {"0": "a", "1": "b", "2": "c", "baz": 42}}}}}]`, + }, + { + note: "partial object with overlapping rule (dee ref on overlap)", + module: `package test + p[k] := 1 { + k := "foo" + } + p.q.r.s.t := 42 + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"foo": 1, "q": {"r": {"s": {"t": 42}}}}}}}]`, + }, + { + note: "partial object with overlapping rule (dee ref on overlap; conflict)", + module: `package test + p[k] := 1 { + k := "q" + } + p.q.r.s.t := 42 + `, + query: `data = x`, + expErr: "eval_conflict_error: object keys must be unique", + }, + { + note: "partial object with overlapping rule (key conflict)", + module: `package test + foo.bar[k] := v { + k := "a" + v := 43 + } + foo.bar["a"] := 42 + `, + query: `data = x`, + expErr: "eval_conflict_error: object keys must be unique", + }, + { + note: "partial object generating conflicting nested keys (different nested object depth)", + module: `package test + p.q.r { + true + } + p.q[r].s.t { + r := "foo" + }`, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"q": {"foo": {"s": {"t": true}}, "r": true}}}}}]`, + }, + { + note: "partial object generating conflicting nested keys (different nested object depth; key conflict)", + module: `package test + p.q[k].s := 1 { + k := "r" + } + p.q[k].s.t := 1 { + k := "r" + }`, + query: `data = x`, + expErr: "eval_conflict_error: object keys must be unique", + }, + { + note: "partial object (overlapping rules producing same values)", + module: `package test + p.foo.bar[i] := v { + v := ["a", "b", "c"][i] + } + p.foo[i][j] := v { + i := "bar" + v := ["a", "b", "c"][j] + } + p[q][i][j] := v { + q := "foo" + i := "bar" + v := ["a", "b", "c"][j] + } + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"foo": {"bar": {"0": "a", "1": "b", "2": "c"}}}}}}]`, + }, + { + note: "partial object (overlapping rules, same depth, producing non-conflicting keys)", + module: `package test + p.foo[i].bar := v { + v := ["a", "b", "c"][i] + } + p.foo.bar[i] := v { + v := ["a", "b", "c"][i] + } + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"foo": { + "0": {"bar": "a"}, + "1": {"bar": "b"}, + "2": {"bar": "c"}, + "bar": {"0": "a", "1": "b", "2": "c"}}}}}}]`, + }, + // Intersections with object values + { + note: "partial object NOT intersecting with object value of other rule", + module: `package test + p.foo := {"bar": {"baz": 1}} + p[k] := 2 {k := "other"} + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"foo": {"bar": {"baz": 1}}, "other": 2}}}}]`, + }, + { + note: "partial object NOT intersecting with object value of other rule (nested object merge along rule refs)", + module: `package test + p.foo.bar := {"baz": 1} # p.foo.bar == {"baz": 1} + p[k].bar2 := v {k := "foo"; v := {"other": 2}} # p.foo.bar2 == {"other": 2} + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"foo": {"bar": {"baz": 1}, "bar2": {"other": 2}}}}}}]`, + }, + { + note: "partial object intersecting with object value of other rule (not merging otherwise conflict-free obj values)", + module: `package test + p.foo := {"bar": {"baz": 1}} # p == {"foo": {"bar": {"baz": 1}}} + p[k] := v {k := "foo"; v := {"bar": {"other": 2}}} # p == {"foo": {"bar": {"other": 2}}} + `, + query: `data = x`, + expErr: "eval_conflict_error: object keys must be unique", // conflict on key "bar" which is inside rule values, which may not be modified by other rule + }, + { + note: "partial object rules with overlapping known ref vars (no eval-time conflict)", + module: `package test + p[k].r1 := 1 { k := "q" } + p[k].r2 := 2 { k := "q" } + `, + query: `data = x`, + exp: `[{"x": {"test": {"p": {"q": {"r1": 1, "r2": 2}}}}}]`, + }, + { + note: "partial object rules with overlapping known ref vars (eval-time conflict)", + module: `package test + p[k].r := 1 { k := "q" } + p[k].r := 2 { k := "q" } + `, + query: `data = x`, + expErr: "eval_conflict_error: object keys must be unique", + }, + { + note: "partial object rules with overlapping known ref vars, non-overlapping object type values (eval-time conflict)", + module: `package test + p[k].r := {"s1": 1} { k := "q" } + p[k].r := {"s2": 2} { k := "q" } + `, + query: `data = x`, + expErr: "eval_conflict_error: object keys must be unique", + }, + // Deep queries + { + note: "deep query into partial object (ref head)", + module: `package test + p.q[r] := 1 { r := "foo" } + `, + query: `data.test.p.q.foo = x`, + exp: `[{"x": 1}]`, + }, + { + note: "deep query into partial object (ref head) and object value", + module: `package test + p.q[r] := x { + r := "foo" + x := {"bar": {"baz": 1}} + } + `, + query: `data.test.p.q.foo.bar = x`, + exp: `[{"x": {"baz": 1}}]`, + }, + { + note: "deep query into partial object starting-point (general ref head) up to array value", + module: `package test + p.q[r].s[t].u := x { + obj := { + "foo": { + "do": ["a", "b", "c"], + "re": ["d", "e", "f"], + }, + "bar": { + "mi": ["g", "h", "i"], + "fa": ["j", "k", "l"], + } + } + x := obj[r][t] + } + `, + query: `data.test.p.q = x`, + exp: `[{"x": {"bar": {"s": {"fa": {"u": ["j", "k", "l"]}, "mi": {"u": ["g", "h", "i"]}}}, "foo": {"s": {"do": {"u": ["a", "b", "c"]}, "re": {"u": ["d", "e", "f"]}}}}}]`, + }, + { + note: "deep query into partial object mid-point (general ref head) up to array value", + module: `package test + p.q[r].s[t].u := x { + obj := { + "foo": { + "do": ["a", "b", "c"], + "re": ["d", "e", "f"], + }, + "bar": { + "mi": ["g", "h", "i"], + "fa": ["j", "k", "l"], + } + } + x := obj[r][t] + } + `, + query: `data.test.p.q.bar.s = x`, + exp: `[{"x": {"fa": {"u": ["j", "k", "l"]}, "mi": {"u": ["g", "h", "i"]}}}]`, + }, + { + note: "deep query into partial object (general ref head) up to array value", + module: `package test + p.q[r].s[t].u := x { + obj := { + "foo": { + "do": ["a", "b", "c"], + "re": ["d", "e", "f"], + }, + "bar": { + "mi": ["g", "h", "i"], + "fa": ["j", "k", "l"], + } + } + x := obj[r][t] + } + `, + query: `data.test.p.q.bar.s.mi.u = x`, + exp: `[{"x": ["g", "h", "i"]}]`, + }, + { + note: "deep query into partial object (general ref head) and array value", + module: `package test + p.q[r].s[t].u := x { + obj := { + "foo": { + "do": ["a", "b", "c"], + "re": ["d", "e", "f"], + }, + "bar": { + "mi": ["g", "h", "i"], + "fa": ["j", "k", "l"], + } + } + x := obj[r][t] + } + `, + query: `data.test.p.q.foo.s.re.u[1] = x`, + exp: `[{"x": "e"}]`, + }, + { + note: "query up to (ref head), but not into partial set", + module: `package test + import future.keywords + p.q.r contains s { {"foo", "bar", "bax"}[s] } + `, + query: `data.test.p = x`, + exp: `[{"x": {"q": {"r": ["bar", "bax", "foo"]}}}]`, + }, + { + note: "deep query up to (ref mid-point), but not into partial set", + module: `package test + import future.keywords + p.q.r contains s { {"foo", "bar", "bax"}[s] } + `, + query: `data.test.p.q = x`, + exp: `[{"x": {"r": ["bar", "bax", "foo"]}}]`, + }, + { + note: "deep query up to (ref tail), but not into partial set", + module: `package test + import future.keywords + p.q.r contains s { {"foo", "bar", "bax"}[s] } + `, + query: `data.test.p.q.r = x`, + exp: `[{"x": ["bar", "bax", "foo"]}]`, + }, + { + note: "deep query into partial set", + module: `package test + import future.keywords + p.q contains r { {"foo", "bar", "bax"}[r] } + `, + query: `data.test.p.q.foo = x`, + exp: `[{"x": "foo"}]`, + }, + { // enumeration + note: "deep query into partial object and object value, full depth, enumeration on object value", + module: `package test + p.q[r] := x { + r := ["foo", "bar"][_] + x := {"s": {"do": 0, "re": 1, "mi": 2}} + } + `, + query: `data.test.p.q.bar.s[y] = z`, + exp: `[{"y": "do", "z": 0}, {"y": "re", "z": 1}, {"y": "mi", "z": 2}]`, + }, + { // enumeration + note: "deep query into partial object and object value, full depth, enumeration on rule path and object value", + module: `package test + p.q[r] := x { + r := ["foo", "bar"][_] + x := {"s": {"do": 0, "re": 1, "mi": 2}} + } + `, + query: `data.test.p.q[x].s[y] = z`, + exp: `[{"x": "foo", "y": "do", "z": 0}, {"x": "foo", "y": "re", "z": 1}, {"x": "foo", "y": "mi", "z": 2}, {"x": "bar", "y": "do", "z": 0}, {"x": "bar", "y": "re", "z": 1}, {"x": "bar", "y": "mi", "z": 2}]`, + }, + { + note: "deep query into partial object (ref head) and set value", + module: `package test + import future.keywords + p.q contains t { + {"do", "re", "mi"}[t] + } + `, + query: `data.test.p.q.re = x`, + exp: `[{"x": "re"}]`, + }, + { + note: "deep query into partial object (general ref head) and set value", + module: `package test + import future.keywords + p.q[r] contains t { + r := ["foo", "bar"][_] + {"do", "re", "mi"}[t] + } + `, + query: `data.test.p.q.foo.re = x`, + exp: `[{"x": "re"}]`, + }, + { + note: "deep query into partial object (general ref head, static tail) and set value", + module: `package test + import future.keywords + p.q[r].s contains t { + r := ["foo", "bar"][_] + {"do", "re", "mi"}[t] + } + `, + query: `data.test.p.q.foo.s.re = x`, + exp: `[{"x": "re"}]`, + }, + { + note: "deep query into general ref to set value", + module: `package test + import future.keywords + p.q[r].s contains t { + r := ["foo", "bar"][_] + t := ["do", "re", "mi"][_] + } + `, + query: `data.test.p.q.foo.s = x`, + exp: `[{"x": ["do", "mi", "re"]}]`, // FIXME: set ordering makes this test brittle + }, + { + note: "deep query into general ref to object value", + module: `package test + p.q[r].s[t] := u { + r := ["foo", "bar"][_] + t := ["do", "re", "mi"][u] + } + `, + query: `data.test.p.q.foo.s = x`, + exp: `[{"x": {"do": 0, "re": 1, "mi": 2}}]`, + }, + { + note: "deep query into general ref enumerating set values", + module: `package test + import future.keywords + p.q[r].s contains t { + r := ["foo", "bar"][_] + {"do", "re", "mi"}[t] + } + `, + query: `data.test.p.q.foo.s[x]`, + // NOTE: $_term_0_0 wildcard var is filtered from eval result output + exp: `[{"$_term_0_0": "do", "x": "do"}, {"$_term_0_0": "re", "x": "re"}, {"$_term_0_0": "mi", "x": "mi"}]`, + }, + { + note: "deep query into partial object and object value, non-tail var", + module: `package test + p.q[r].s := x { + r := "foo" + x := {"bar": {"baz": 1}} + } + `, + query: `data.test.p.q.foo.s.bar = x`, + exp: `[{"x": {"baz": 1}}]`, + }, + { + note: "deep query into partial object, on first var in ref", + module: `package test + p.q[r].s := 1 { r := "foo" } + `, + query: `data.test.p.q.foo = x`, + exp: `[{"x": {"s": 1}}]`, + }, + { + note: "deep query into partial object, beyond first var in ref", + module: `package test + p.q[r].s := 1 { r := "foo" } + `, + query: `data.test.p.q.foo.s = x`, + exp: `[{"x": 1}]`, + }, + { + note: "deep query into partial object, shallow rule ref", + module: `package test + p.q[r][s] := 1 { r := "foo"; s := "bar" } + `, + query: `data.test.p.q.foo = x`, + exp: `[{"x": {"bar": 1}}]`, + }, + { + note: "deep query into partial object, shallow rule ref, multiple keys", + module: `package test + p.q[r][s] := t { l := ["do", "re", "mi"]; r := "foo"; s := l[t] } + `, + query: `data.test.p.q.foo = x`, + exp: `[{"x": {"do": 0, "re": 1, "mi": 2}}]`, + }, + { + note: "deep query into partial object, beyond first var in ref, multiple vars", + module: `package test + p.q[r][s] := 1 { r := "foo"; s := "bar" } + `, + query: `data.test.p.q.foo.bar = x`, + exp: `[{"x": 1}]`, + }, + { + note: "deep query into partial object, beyond first var in ref, multiple vars", + module: `package test + p.q[r][s].t := 1 { r := "foo"; s := "bar" } + `, + query: `data.test.p.q.foo.bar = x`, + exp: `[{"x": {"t": 1}}]`, + }, + { + note: "deep query to partial object, overlapping rules (key override), no dynamic ref", + module: `package test + p.q[r] := 1 { r := "foo" } + p.q.r := 2 + `, + query: `data.test.p.q = x`, + exp: `[{"x": {"foo": 1, "r": 2}}]`, + }, + { + note: "deep query into partial object, overlapping rules (key override), no dynamic ref", + module: `package test + p.q[r] := 1 { r := "foo" } + p.q.r := 2 + `, + query: `data.test.p.q.r = x`, + exp: `[{"x": 2}]`, + }, + { + note: "deep query into partial object, overlapping rules, no dynamic ref", + module: `package test + p.q[r] := 1 { r := "foo" } + p.q[r] := 2 { r := "bar" } + `, + query: `data.test.p.q.foo = x`, + exp: `[{"x": 1}]`, + }, + { + note: "deep query into partial object, overlapping rules with same key/value, no dynamic ref", + module: `package test + p.q[r] := 1 { r := "foo" } + p.q[r] := 1 { r := "foo" } + `, + query: `data.test.p.q.foo = x`, + exp: `[{"x": 1}]`, + }, + { + note: "deep query into partial object, overlapping rules, dynamic ref", + module: `package test + p.q[r].s := 1 { r := "r" } + p.q.r[s] := 2 { s := "foo" } + `, + query: `data.test.p.q.r = x`, + exp: `[{"x": {"s": 1, "foo": 2}}]`, + }, + { + note: "deep query into partial object, overlapping rules with same key/value, dynamic ref", + module: `package test + p.q[r].s := 1 { r := "r" } + p.q.r[s] := 1 { s := "s" } + `, + query: `data.test.p.q.r = x`, + exp: `[{"x": {"s": 1}}]`, + }, + // Multiple results (enumeration) + { + note: "shallow query into general ref, key enumeration", + module: `package test + p.q[r].s[t] := u { + r := ["a", "b", "c"][_] + t := ["d", "e", "f"][u] + }`, + query: `data.test.p.q[x] = y`, + exp: `[{"x": "a", "y": {"s": {"d": 0, "e": 1, "f": 2}}}, + {"x": "b", "y": {"s": {"d": 0, "e": 1, "f": 2}}}, + {"x": "c", "y": {"s": {"d": 0, "e": 1, "f": 2}}}]`, + }, + { + note: "query to partial object, overlapping rules, dynamic ref, key enumeration", + module: `package test + p.q[r].s := 1 { r := "foo" } + p.q[r].s := 2 { r := "bar" } + `, + query: `data.test.p.q[i] = x`, + exp: `[{"i": "bar", "x": {"s": 2}}, {"i": "foo", "x": {"s": 1}}]`, + }, + { + note: "deep query into partial object, overlapping rules, dynamic ref, key enumeration", + module: `package test + p.q[r].s := 1 { r := "foo" } + p.q[r].s := 2 { r := "bar" } + `, + query: `data.test.p.q[i].s = x`, + exp: `[{"i": "bar", "x": 2}, {"i": "foo", "x": 1}]`, + }, + // Errors + { + note: "partial object generating conflicting keys", + module: `package test + p[k] := x { + k := "foo" + x := [1, 2][_] + }`, + query: `data = x`, + expErr: "eval_conflict_error: object keys must be unique", + }, + { + note: "partial object (ref head) generating conflicting keys (dots in head)", + module: `package test + p.q[k] := x { + k := "foo" + x := [1, 2][_] + }`, + query: `data = x`, + expErr: "eval_conflict_error: object keys must be unique", + }, + { + note: "partial object (general ref head) generating conflicting nested keys", + module: `package test + p.q[k].s := x { + k := "foo" + x := [1, 2][_] + }`, + query: `data = x`, + expErr: "eval_conflict_error: object keys must be unique", + }, + { + note: "partial object (general ref head) generating conflicting ref vars", + module: `package test + p.q[k].s := x { + k := ["foo", "foo"][x] + }`, + query: `data = x`, + expErr: "eval_conflict_error: object keys must be unique", + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + compiler := compileModules([]string{tc.module}) + txn := storage.NewTransactionOrDie(ctx, store) + defer store.Abort(ctx, txn) + + query := NewQuery(ast.MustParseBody(tc.query)). + WithCompiler(compiler). + WithStore(store). + WithTransaction(txn) + + qrs, err := query.Run(ctx) + if tc.expErr != "" { + if err == nil { + t.Fatalf("Expected error %v but got result: %v", tc.expErr, qrs) + } + if exp, act := tc.expErr, err.Error(); !strings.Contains(act, exp) { + t.Fatalf("Expected error %v but got: %v", exp, act) + } + } else { + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + var exp []map[string]interface{} + _ = json.Unmarshal([]byte(tc.exp), &exp) + if expLen, act := len(exp), len(qrs); expLen != act { + t.Fatalf("expected %d query result:\n\n%+v,\n\ngot %d query results:\n\n%+v", expLen, exp, act, qrs) + } + testAssertResultSet(t, exp, qrs, false) + } + }) + } +} + +// TODO: Remove when general rule refs are enabled by default. +func TestGeneralRuleRefsFeatureFlag(t *testing.T) { + module := ast.MustParseModule(`package test + p[q].r { q := "q" }`) + mods := map[string]*ast.Module{ + "": module, + } + c := ast.NewCompiler() + c.Compile(mods) + + if !strings.Contains(c.Errors.Error(), "rego_type_error: rule head must only contain string terms (except for last)") { + t.Fatal("Expected error but got:", c.Errors) + } + + t.Setenv("EXPERIMENTAL_GENERAL_RULE_REFS", "true") + + c = ast.NewCompiler() + c.Compile(mods) + + if c.Errors != nil { + t.Fatal("Unexpected error:", c.Errors) + } +} diff --git a/topdown/exported_test.go b/topdown/exported_test.go index 675132d423..82b6f52576 100644 --- a/topdown/exported_test.go +++ b/topdown/exported_test.go @@ -49,6 +49,10 @@ type opt func(*Query) *Query func testRun(t *testing.T, tc cases.TestCase, opts ...opt) { + for k, v := range tc.Env { + t.Setenv(k, v) + } + ctx := context.Background() modules := map[string]string{} diff --git a/topdown/topdown_partial_test.go b/topdown/topdown_partial_test.go index 689745f42d..126c7273a4 100644 --- a/topdown/topdown_partial_test.go +++ b/topdown/topdown_partial_test.go @@ -18,6 +18,8 @@ import ( ) func TestTopDownPartialEval(t *testing.T) { + // TODO: break out into separate tests + t.Setenv("EXPERIMENTAL_GENERAL_RULE_REFS", "true") tests := []struct { note string @@ -137,7 +139,7 @@ func TestTopDownPartialEval(t *testing.T) { }, wantQueries: []string{}, }, - { + { // TODO: duplicate for general refs? note: "iterate rules: partial object", query: `data.test.p[x] = input.x`, modules: []string{ @@ -152,7 +154,7 @@ func TestTopDownPartialEval(t *testing.T) { `"d" = input.x; x = "c"`, }, }, - { + { // TODO: duplicate for general refs? note: "iterate rules: partial set", query: `input.x = x; data.test.p[x]`, modules: []string{ @@ -190,7 +192,7 @@ func TestTopDownPartialEval(t *testing.T) { `x = input; x[j] = z; y = [x]; i = 0`, }, }, - { + { // TODO: duplicate for general refs? note: "single term: save", query: `input.x = x; data.test.p[x]`, modules: []string{ @@ -222,6 +224,185 @@ func TestTopDownPartialEval(t *testing.T) { `x = input.x`, }, }, + { + note: "reference: partial object, general ref", + query: "data.test.p[x].q.foo = 1", + modules: []string{ + `package test + p[a].q = {b: c} { a = input.a; b = "foo"; c = 1 } + p[a].q = {b: c} { a = input.b; b = "bar"; c = 2 }`, + }, + wantQueries: []string{ + `x = input.a`, + }, + }, + { + note: "reference: partial object, general ref (2)", + query: "data.test.p[x].q.foo", + modules: []string{ + `package test + p[a].q = {b: c} { a = input.a; b = "foo"; c = 1 } + p[a].q = {b: c} { a = input.b; b = "bar"; c = 2 }`, + }, + wantQueries: []string{ + `x = input.a`, + }, + }, + { + note: "reference: partial object, general ref (3)", + query: "data.test.p[x].q", + modules: []string{ + `package test + p[a].q = {b: c} { a = input.a; b = "foo"; c = 1 } + p[a].q = {b: c} { a = input.b; b = "bar"; c = 2 }`, + }, + wantQueries: []string{ + `x = input.a`, + `x = input.b`, + }, + }, + { + note: "reference: partial object, general ref (4)", + query: "data.test.p[x].q[y]", + modules: []string{ + `package test + p[a].q = {b: c} { a = input.a; b = "foo"; c = 1 } + p[a].q = {b: c} { a = input.b; b = "bar"; c = 2 }`, + }, + wantQueries: []string{ + `x = input.a; y = "foo"`, + `x = input.b; y = "bar"`, + }, + }, + { + note: "reference: partial object, general ref (5)", + query: "data.test.p[x].q[y] = z", + modules: []string{ + `package test + p[a].q = {b: c} { a = input.a; b = "foo"; c = 1 } + p[a].q = {b: c} { a = input.b; b = "bar"; c = 2 }`, + }, + wantQueries: []string{ + `x = input.a; y = "foo"; z = 1`, + `x = input.b; y = "bar"; z = 2`, + }, + }, + { + note: "reference: partial object, general ref (6)", + query: "data.test.p = z", + modules: []string{ + `package test + p[a].q = {b: c} { a = input.a; b = "foo"; c = 1 } + p[a].q = {b: c} { a = input.b; b = "bar"; c = 2 }`, + }, + wantQueries: []string{ + `data.partial.test.p = z`, + }, + wantSupport: []string{ + `package partial.test + p[a2].q = {"foo": 1} { a2 = input.a } + p[a1].q = {"bar": 2} { a1 = input.b }`, + }, + }, + { + note: "reference: partial object, general ref (7)", + query: "data.test.p[x] = z", + modules: []string{ + `package test + p[a].q = {b: c} { a = input.a; b = "foo"; c = 1 } + p[a].q = {b: c} { a = input.b; b = "bar"; c = 2 }`, + }, + wantQueries: []string{ + `x = input.a; z = {"q": {"foo": 1}}`, + `x = input.b; z = {"q": {"bar": 2}}`, + }, + }, + { + note: "reference: partial object, general ref (8)", + query: "data.test.p = z", + modules: []string{ + `package test + p[a].q = {b: c} { a = input.a; b = "foo"; c = 1 } + p[a].q = {b: c} { a = input.b; b = "bar"; c = 2 } + p.foo.r = a { a = "baz" } + p.foo.s = a { a = input.c }`, + }, + wantQueries: []string{ + `data.partial.test.p = z`, + }, + wantSupport: []string{ + `package partial.test + p[a4].q = {"foo": 1} { a4 = input.a } + p[a3].q = {"bar": 2} { a3 = input.b } + p.foo.r = "baz" { true } + p.foo.s = a2 { a2 = input.c }`, + }, + }, + { + note: "reference: partial object, general ref (9)", + query: "data.test.p[x] = z", + modules: []string{ + `package test + p[a].q = {b: c} { a = input.a; b = "foo"; c = 1 } + p[a].q = {b: c} { a = input.b; b = "bar"; c = 2 } + p.foo.r = a { a = "baz" } + p.foo.s = a { a = input.c }`, + }, + wantQueries: []string{ + `x = input.a; z = {"q": {"foo": 1}}`, + `x = input.b; z = {"q": {"bar": 2}}`, + `x = "foo"; z = {"r": "baz"}`, + `z = {"s": input.c}; x = "foo"`, + }, + }, + { + note: "reference: partial object, general ref, multiple vars", + query: `data.test.p = x`, + modules: []string{ + `package test + p[q].r[s] := v { v := "foo"; q := 42; s := "bar" } + p[q].r[s].t := v { v := input.x; q := input.y; s := "baz" }`, + }, + wantQueries: []string{ + `data.partial.test.p = x`, + }, + wantSupport: []string{ + `package partial.test + p[42].r.bar = "foo" { true } + p[__local4__2].r.baz.t = __local3__2 { __local3__2 = input.x; __local4__2 = input.y }`, + }, + }, + { + note: "reference: partial object, general ref, multiple vars (2)", + query: `data.test.p[42] = x`, + modules: []string{ + `package test + p[q].r[s] := v { v := "foo"; q := 42; s := "bar" } + p[q].r[s].t := v { v := input.x; q := input.y; s := "baz" }`, + }, + wantQueries: []string{ + `x = {"r": {"bar": "foo"}}`, + `42 = input.y; x = {"r": {"baz": {"t": input.x}}}`, + }, + }, + { + note: "reference: partial object, general ref, multiple vars (2) (shallow)", + query: `data.test.p[42] = x`, + shallow: true, + modules: []string{ + `package test + #p[q].r[s] := v { v := "foo"; q := 42; s := "bar" } + #p[q].r[s].t := v { v := input.x; q := input.y; s := "baz" } + p[q][r][s].t := v { v := input.x; q := input.y; s := input.z; r := "known" }`, + }, + wantQueries: []string{ + `data.partial.test.p[42] = x`, + }, + wantSupport: []string{ + `package partial.test + p[__local1__1].known[__local2__1].t = __local0__1 { __local0__1 = input.x; __local1__1 = input.y; __local2__1 = input.z }`, + }, + }, { note: "reference: partial set", query: "data.test.p[x].foo = 1", @@ -234,6 +415,267 @@ func TestTopDownPartialEval(t *testing.T) { `1 = input.x; x = {"foo": 1}`, }, }, + { + note: "reference: partial set, general ref", + query: "data.test.p[x][y].foo = 1", + modules: []string{ + `package test + import future.keywords.contains + p[x] contains y { y = {a: b}; a = "foo"; b = input.x; x := 42 } + p[x] contains y { y = {a: b}; a = "bar"; b = input.x; x := input.y }`, + }, + wantQueries: []string{ + `1 = input.x; y = {"foo": 1}; x = 42`, + }, + }, + { + note: "reference: partial set, general ref (2)", + query: "data.test.p[x][y].bar = 1", + modules: []string{ + `package test + import future.keywords.contains + p[x] contains y { y = {a: b}; a = "foo"; b = input.x; x = 42 } + p[x] contains y { y = {a: b}; a = "bar"; b = input.x; x = input.y }`, + }, + wantQueries: []string{ + `1 = input.x; x = input.y; y = {"bar": 1}`, + }, + }, + { + note: "reference: partial set, general ref (3)", + query: "data.test.p[42][y].foo = 1", + modules: []string{ + `package test + import future.keywords.contains + p[x] contains y { y = {a: b}; a = "foo"; b = input.x; x := 42 } + p[x] contains y { y = {a: b}; a = "bar"; b = input.x; x := input.y }`, + }, + wantQueries: []string{ + `1 = input.x; y = {"foo": 1}`, + }, + }, + { + note: "reference: partial set, general ref (4)", + query: `data.test.p[x][y] = {"foo": 1}`, + modules: []string{ + `package test + import future.keywords.contains + p[x] contains y { y = {a: b}; a = "foo"; b = input.x; x := 42 } + p[x] contains y { y = {a: b}; a = "bar"; b = input.x; x := input.y }`, + }, + wantQueries: []string{ + `1 = input.x; y = {"foo": 1}; x = 42`, + }, + }, + { + note: "reference: partial set, general ref (5)", + query: `data.test.p[x] = {{"foo": 1}}`, + modules: []string{ + `package test + import future.keywords.contains + p[x] contains y { y = {a: b}; a = "foo"; b = input.x; x := 42 } + p[x] contains y { y = {a: b}; a = "bar"; b = input.x; x := input.y }`, + }, + wantQueries: []string{ + `{{"foo": input.x}} = {{"foo": 1}}; x = 42`, // `1 = input.x; x = 42` would be a more precise optimization (?) + }, + }, + { + note: "reference: partial set, general ref (6)", + query: `data.test.p`, + modules: []string{ + `package test + import future.keywords.contains + p[x] contains y { y = {a: b}; a = "foo"; b = input.x; x := 42 } + p[x] contains y { y = {a: b}; a = "bar"; b = input.x; x := input.y }`, + }, + wantQueries: []string{ + `data.partial.test.p`, + }, + wantSupport: []string{ + `package partial.test + import future.keywords.contains + p[42] contains {"foo": b1} { b1 = input.x } + p[__local1__2] contains {"bar": b2} { b2 = input.x; __local1__2 = input.y }`, + }, + }, + { + note: "reference: partial set, general ref (7)", + query: `data.test.p = x`, + modules: []string{ + `package test + import future.keywords.contains + p[x] contains y { y = {a: b}; a = "foo"; b = input.x; x := 42 } + p[x] contains y { y = {a: b}; a = "bar"; b = input.x; x := input.y }`, + }, + wantQueries: []string{ + `data.partial.test.p = x`, + }, + wantSupport: []string{ + `package partial.test + import future.keywords.contains + p[42] contains {"foo": b1} { b1 = input.x } + p[__local1__2] contains {"bar": b2} { b2 = input.x; __local1__2 = input.y }`, + }, + }, + { + note: "reference: partial set, general ref (8)", + query: `data.test.p = x`, + modules: []string{ + `package test + import future.keywords.contains + p[x].r contains y { y = {a: b}; a = "foo"; b = input.x; x := 42 } + p[x].r contains y { y = {a: b}; a = "bar"; b = input.x; x := input.y }`, + }, + wantQueries: []string{ + `data.partial.test.p = x`, + }, + wantSupport: []string{ + `package partial.test + import future.keywords.contains + p[42].r contains {"foo": b1} { b1 = input.x } + p[__local1__2].r contains {"bar": b2} { b2 = input.x; __local1__2 = input.y }`, + }, + }, + { + note: "reference: partial set, general ref, multiple vars", + query: `data.test.p = x`, + modules: []string{ + `package test + import future.keywords.contains + p[q].r[s] contains x { x = "foo"; q := 42; s = "bar" } + p[q].r[s].t contains x { x = input.x; q := input.y; s = "baz" }`, + }, + wantQueries: []string{ + `data.partial.test.p = x`, + }, + wantSupport: []string{ + `package partial.test + import future.keywords.contains + p[42].r.bar contains "foo" { true } + p[__local1__2].r.baz.t contains x2 { x2 = input.x; __local1__2 = input.y }`, + }, + }, + { + note: "reference: partial set, general ref, multiple vars (2)", + query: `data.test.p[42] = x`, + modules: []string{ + `package test + import future.keywords.contains + p[q].r[s] contains v { v := "foo"; q := 42; s := "bar" } + p[q].r[s].t contains v { v := input.x; q := input.y; s := "baz" }`, + }, + wantQueries: []string{ + `x = {"r": {"bar": {"foo"}}}`, + `42 = input.y; x = {"r": {"baz": {"t": {input.x}}}}`, + }, + }, + { + note: "reference: partial set, general ref, multiple vars (3)", + query: `data.test.p.foo = x`, + modules: []string{ + `package test + import future.keywords.contains + p[q].r[s] contains x { x = "foo"; q := 42; s = "bar" } + p[q].r[s].t contains x { x = input.x; q := input.y; s = "baz" }`, + }, + wantQueries: []string{ + `"foo" = input.y; x = {"r": {"baz": {"t": {input.x}}}}`, + }, + }, + { + note: "reference: partial object, unknown in query ref", + query: "data.test.p[input.x]", + modules: []string{ + `package test + p[q].r[s] = v { q = {"foo", "bar"}[s]; v = "baz" } + p.q.r.s := 1`, + }, + wantQueries: []string{ + `"foo" = input.x`, + `"bar" = input.x`, + `"q" = input.x`, + }, + }, + { + note: "reference: partial object, unknown in query ref (2)", + query: "data.test.p.foo.r[input.x]", + modules: []string{ + `package test + p[q].r[s] = v { q = {"foo", "bar"}[s]; v = "baz" } + p.q.r.s := 1`, + }, + wantQueries: []string{ + `"foo" = input.x`, + }, + }, + { + note: "reference: partial object, unknown in query ref (3)", + query: "data.test.p[input.x].r[input.y]", + modules: []string{ + `package test + p[q].r[s] = v { q = {"foo", "bar"}[s]; v = "baz" } + p.q.r.s := 1`, + }, + wantQueries: []string{ + `"foo" = input.x; "foo" = input.y`, + `"bar" = input.x; "bar" = input.y`, + `"q" = input.x; "s" = input.y`, + }, + }, + { + note: "reference: partial object, unknown in query ref (4)", + query: "data.test.p[x].r[y][input.x]", + modules: []string{ + `package test + p[q].r[s] = {v: w} { q = {"foo", "bar"}[s]; v = "baz"; w = "bax" } + p.q.r.s := {1: 2}`, + }, + wantQueries: []string{ + `"baz" = input.x; x = "foo"; y = "foo"`, + `"baz" = input.x; x = "bar"; y = "bar"`, + `1 = input.x; x = "q"; y = "s"`, + }, + }, + { + note: "reference: partial object, unknown in query ref (5)", + query: "data.test.p[x].r[y][input.x] = input.y", + modules: []string{ + `package test + p[q].r[s] = {v: w} { q = {"foo", "bar"}[s]; v = "baz"; w = "bax" } + p.q.r.s := {1: 2}`, + }, + wantQueries: []string{ + `"baz" = input.x; "bax" = input.y; x = "foo"; y = "foo"`, + `"baz" = input.x; "bax" = input.y; x = "bar"; y = "bar"`, + `1 = input.x; 2 = input.y; x = "q"; y = "s"`, + }, + }, + { + note: "reference: partial object, unknown in query ref (6)", + query: `data.test.p[x].r[y][input.x] = "bax"`, + modules: []string{ + `package test + p[q].r[s] = {v: w} { q = {"foo", "bar"}[s]; v = "baz"; w = "bax" } + p.q.r.s := {1: 2}`, + }, + wantQueries: []string{ + `"baz" = input.x; x = "foo"; y = "foo"`, + `"baz" = input.x; x = "bar"; y = "bar"`, + }, + }, + { + note: "reference: partial object, unknown in query ref (7)", + query: `data.test.p[x].r[y][input.x] = 2`, + modules: []string{ + `package test + p[q].r[s] = {v: w} { q = {"foo", "bar"}[s]; v = "baz"; w = "bax" } + p.q.r.s := {1: 2}`, + }, + wantQueries: []string{ + `1 = input.x; x = "q"; y = "s"`, + }, + }, { note: "reference: complete", query: "data.test.p = 1", @@ -246,6 +688,18 @@ func TestTopDownPartialEval(t *testing.T) { `input.x = 1`, }, }, + { + note: "reference: complete, ref head", + query: "data.test.p.q = 1", + modules: []string{ + `package test + + p.q = x { input.x = x }`, + }, + wantQueries: []string{ + `input.x = 1`, + }, + }, { note: "reference: complete: suffix", query: "data.test.p = true", @@ -301,6 +755,38 @@ func TestTopDownPartialEval(t *testing.T) { `y.bar = 1; z1 = input.foo[y]`, }, }, + { + note: "reference: ref head: from query", + query: "data.test.p.q[y] = 1", + modules: []string{ + `package test + + p.q[x] = 1 { + input.foo[x] = z + x.bar = 1 + } + `, + }, + wantQueries: []string{ + `y.bar = 1; z1 = input.foo[y]`, + }, + }, + { + note: "reference: general ref head: from query", + query: "data.test.p.q[y].s = 1", + modules: []string{ + `package test + + p.q[x].s = 1 { + input.foo[x] = z + x.bar = 1 + } + `, + }, + wantQueries: []string{ + `y.bar = 1; z1 = input.foo[y]`, + }, + }, { note: "reference: head: applied", query: "data.test.p = true", @@ -317,11 +803,12 @@ func TestTopDownPartialEval(t *testing.T) { x.b = 2 }`, }, + // FIXME: is this a problem? wantQueries: []string{` - input[x_term_1_01] - x_term_1_01.b = 2 - x_term_1_01 - x_term_1_01.a = 1 + input[x_ref_01] + x_ref_01.b = 2 + x_ref_01 + x_ref_01.a = 1 `}, }, { @@ -381,6 +868,28 @@ func TestTopDownPartialEval(t *testing.T) { `input.x = "foo"; x = "foo"; y = 2`, }, }, + { + note: "namespace: partial object, ref head", + query: "input.x = x; data.test.p.q[x] = y; y = 2", + modules: []string{ + `package test + p.q[y] = x { y = "foo"; x = 2 }`, + }, + wantQueries: []string{ + `input.x = "foo"; x = "foo"; y = 2`, + }, + }, + { + note: "namespace: partial object, general ref head", + query: "input.x = x; input.y = y; data.test.p.q[x][y] = z; z = 2", + modules: []string{ + `package test + p.q[x][y] = z { x = "foo"; y = "bar"; z = 2 }`, + }, + wantQueries: []string{ + `input.x = "foo"; input.y = "bar"; x = "foo"; y = "bar"; z = 2`, + }, + }, { note: "namespace: embedding", query: "data.test.p = x", @@ -1144,6 +1653,23 @@ func TestTopDownPartialEval(t *testing.T) { p[x2] { input.x = x2 } `}, }, + { + note: "automatic shallow inlining: full extent: partial set, general ref head", + query: "data.test.p.q = x", + modules: []string{ + `package test + import future.keywords.contains + p.q contains x { input.x = x } + p.q[r].s contains t { input.r = r; input.t = t }`, + }, + wantQueries: []string{`data.partial.test.p.q = x`}, + wantSupport: []string{` + package partial.test.p + import future.keywords.contains + q[x2] { input.x = x2 } + q[r1].s contains t1 { input.r = r1; input.t = t1 } + `}, + }, { note: "automatic shallow inlining: full extent: partial object", query: "data.test.p = x", @@ -1159,6 +1685,21 @@ func TestTopDownPartialEval(t *testing.T) { p[x2] = y2 { x2 = input.x; y2 = input.y } `}, }, + { + note: "automatic shallow inlining: full extent: partial object, general ref head", + query: "data.test.p.q = x", + modules: []string{ + `package test + p.q[x] = y { x = input.x; y = input.y } + p.q[r].s[t] = y { r = input.r; t = input.t; y = input.y }`, + }, + wantQueries: []string{`data.partial.test.p.q = x`}, + wantSupport: []string{` + package partial.test.p + q[x2] = y2 { x2 = input.x; y2 = input.y } + q[r1].s[t1] = y1 { r1 = input.r; t1 = input.t; y1 = input.y } + `}, + }, { note: "automatic shallow inlining: full extent: no solutions", query: "data.test.p = x", @@ -1174,19 +1715,27 @@ func TestTopDownPartialEval(t *testing.T) { query: "data.test[x] = y", modules: []string{ `package test + import future.keywords.contains s[x] { x = input.x } + s2[x].u contains y { x = input.x; y = input.y } p[x] = y { x = input.x; y = input.y } + p2[x].r[y] = z { x = input.x; y = input.y; z = input.z } r = x { x = input.x }`, }, wantQueries: []string{ `data.partial.test.s = y; x = "s"`, + `data.partial.test.s2 = y; x = "s2"`, `data.partial.test.p = y; x = "p"`, + `data.partial.test.p2 = y; x = "p2"`, `y = input.x; x = "r"`, }, wantSupport: []string{` package partial.test + import future.keywords.contains p[x1] = y1 { x1 = input.x; y1 = input.y } - s[x3] { x3 = input.x } + p2[x2].r[y2] = z2 { x2 = input.x; y2 = input.y; z2 = input.z } + s[x4] { x4 = input.x } + s2[x5].u contains y5 { x5 = input.x; y5 = input.y } `}, }, { @@ -3329,6 +3878,16 @@ func TestTopDownPartialEval(t *testing.T) { }`}, wantQueries: []string{`"bar" = input.a; "baz" = input.b`}, }, + { + note: "general ref heads: \"triple\" unification, single-value rule", + query: "data.test.foo[input.a][input.b][input.c]", + modules: []string{`package test + foo.bar[baz][bax] { + baz := "baz" + bax := "bax" + }`}, + wantQueries: []string{`"bar" = input.a; "baz" = input.b; "bax" = input.c`}, + }, { // https://github.com/open-policy-agent/opa/issues/6027 note: "ref heads: \"double\" unification, multi-value rule", query: "data.test.foo[input.a][input.b]", @@ -3339,6 +3898,17 @@ func TestTopDownPartialEval(t *testing.T) { }`}, wantQueries: []string{`"bar" = input.a; "baz" = input.b`}, }, + { + note: "general ref heads: \"triple\" unification, multi-value rule", + query: "data.test.foo[input.a][input.b][input.c]", + modules: []string{`package test + import future.keywords.contains + foo.bar[baz] contains bax { + baz := "baz" + bax := "bax" + }`}, + wantQueries: []string{`"bar" = input.a; "baz" = input.b; "bax" = input.c`}, + }, { note: "ref heads: unknown rule value", query: "data.test.p.q[x]", diff --git a/types/types.go b/types/types.go index bc2a50b4c1..96d5140d7c 100644 --- a/types/types.go +++ b/types/types.go @@ -430,6 +430,77 @@ func (t *Object) Select(name interface{}) Type { return nil } +func (t *Object) Merge(other Type) *Object { + if otherObj, ok := other.(*Object); ok { + return mergeObjects(t, otherObj) + } + + var typeK Type + var typeV Type + dynProps := t.DynamicProperties() + if dynProps != nil { + typeK = Or(Keys(other), dynProps.Key) + typeV = Or(Values(other), dynProps.Value) + dynProps = NewDynamicProperty(typeK, typeV) + } else { + typeK = Keys(other) + typeV = Values(other) + if typeK != nil && typeV != nil { + dynProps = NewDynamicProperty(typeK, typeV) + } + } + + return NewObject(t.StaticProperties(), dynProps) +} + +func mergeObjects(a, b *Object) *Object { + var dynamicProps *DynamicProperty + if a.dynamic != nil && b.dynamic != nil { + typeK := Or(a.dynamic.Key, b.dynamic.Key) + var typeV Type + aObj, aIsObj := a.dynamic.Value.(*Object) + bObj, bIsObj := b.dynamic.Value.(*Object) + if aIsObj && bIsObj { + typeV = mergeObjects(aObj, bObj) + } else { + typeV = Or(a.dynamic.Value, b.dynamic.Value) + } + dynamicProps = NewDynamicProperty(typeK, typeV) + } else if a.dynamic != nil { + dynamicProps = a.dynamic + } else { + dynamicProps = b.dynamic + } + + staticPropsMap := make(map[interface{}]Type) + + for _, sp := range a.static { + staticPropsMap[sp.Key] = sp.Value + } + + for _, sp := range b.static { + currV := staticPropsMap[sp.Key] + if currV != nil { + currVObj, currVIsObj := currV.(*Object) + spVObj, spVIsObj := sp.Value.(*Object) + if currVIsObj && spVIsObj { + staticPropsMap[sp.Key] = mergeObjects(currVObj, spVObj) + } else { + staticPropsMap[sp.Key] = Or(currV, sp.Value) + } + } else { + staticPropsMap[sp.Key] = sp.Value + } + } + + staticProps := make([]*StaticProperty, 0, len(staticPropsMap)) + for k, v := range staticPropsMap { + staticProps = append(staticProps, NewStaticProperty(k, v)) + } + + return NewObject(staticProps, dynamicProps) +} + // Any represents a dynamic type. type Any []Type