From 62834a22a613554ea83215cc0485719207bcb164 Mon Sep 17 00:00:00 2001 From: Johan Fylling Date: Tue, 28 May 2024 10:16:58 +0200 Subject: [PATCH] Asserting `every` domain is a collection type before evaluation (#6763) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixing an issue where a non-collection `every`-domain didn’t fail evaluation. Removing a possible attack surface, where an attacker with the ability to craft portions of the input document could replace a value of an expected collection type, known to be processed by an `every`-statement, with a non-collection value and thereby cause the policy to accept a query that should otherwise be rejected. Fixes: #6762 Signed-off-by: Johan Fylling --- internal/compiler/wasm/wasm.go | 11 ++ internal/planner/planner.go | 34 ++++ ir/ir.go | 7 + test/cases/testdata/every/every.yaml | 93 ++++++++++- .../testdata/every/non_iterable_domain.yaml | 151 ++++++++++++++++++ topdown/eval.go | 54 ++++--- 6 files changed, 329 insertions(+), 21 deletions(-) create mode 100644 test/cases/testdata/every/non_iterable_domain.yaml diff --git a/internal/compiler/wasm/wasm.go b/internal/compiler/wasm/wasm.go index 3984d8662f..b827ebb91b 100644 --- a/internal/compiler/wasm/wasm.go +++ b/internal/compiler/wasm/wasm.go @@ -1139,6 +1139,17 @@ func (c *Compiler) compileBlock(block *ir.Block) ([]instruction.Instruction, err instrs = append(instrs, instruction.Br{Index: 0}) break } + case *ir.IsSetStmt: + if loc, ok := stmt.Source.Value.(ir.Local); ok { + instrs = append(instrs, instruction.GetLocal{Index: c.local(loc)}) + instrs = append(instrs, instruction.Call{Index: c.function(opaValueType)}) + instrs = append(instrs, instruction.I32Const{Value: opaTypeSet}) + instrs = append(instrs, instruction.I32Ne{}) + instrs = append(instrs, instruction.BrIf{Index: 0}) + } else { + instrs = append(instrs, instruction.Br{Index: 0}) + break + } case *ir.IsUndefinedStmt: instrs = append(instrs, 
instruction.GetLocal{Index: c.local(stmt.Source)}) instrs = append(instrs, instruction.I32Const{Value: 0}) diff --git a/internal/planner/planner.go b/internal/planner/planner.go index a6405291e7..df1d66b85e 100644 --- a/internal/planner/planner.go +++ b/internal/planner/planner.go @@ -893,6 +893,40 @@ func (p *Planner) planExprEvery(e *ast.Expr, iter planiter) error { }) err := p.planTerm(every.Domain, func() error { + // Assert that the domain is a collection type: + // block outer + // block a + // isArray + // br 1: break outer, and continue + // block b + // isObject + // br 1: break outer, and continue + // block c + // isSet + // br 1: break outer, and continue + // br 1: invalid domain, break every + + aBlock := &ir.Block{} + p.appendStmtToBlock(&ir.IsArrayStmt{Source: p.ltarget}, aBlock) + p.appendStmtToBlock(&ir.BreakStmt{Index: 1}, aBlock) + + bBlock := &ir.Block{} + p.appendStmtToBlock(&ir.IsObjectStmt{Source: p.ltarget}, bBlock) + p.appendStmtToBlock(&ir.BreakStmt{Index: 1}, bBlock) + + cBlock := &ir.Block{} + p.appendStmtToBlock(&ir.IsSetStmt{Source: p.ltarget}, cBlock) + p.appendStmtToBlock(&ir.BreakStmt{Index: 1}, cBlock) + + outerBlock := &ir.BlockStmt{Blocks: []*ir.Block{ + { + Stmts: []ir.Stmt{ + &ir.BlockStmt{Blocks: []*ir.Block{aBlock, bBlock, cBlock}}, + &ir.BreakStmt{Index: 1}}, + }, + }} + p.appendStmt(outerBlock) + return p.planScan(every.Key, func(ir.Local) error { p.appendStmt(&ir.ResetLocalStmt{ Target: cond1, diff --git a/ir/ir.go b/ir/ir.go index d98b96e2c8..c07670704e 100644 --- a/ir/ir.go +++ b/ir/ir.go @@ -364,6 +364,13 @@ type IsObjectStmt struct { Location } +// IsSetStmt represents a dynamic type check on a local variable. +type IsSetStmt struct { + Source Operand `json:"source"` + + Location +} + // IsDefinedStmt represents a check of whether a local variable is defined. 
type IsDefinedStmt struct { Source Local `json:"source"` diff --git a/test/cases/testdata/every/every.yaml b/test/cases/testdata/every/every.yaml index b2d1da96a6..d5419e8145 100644 --- a/test/cases/testdata/every/every.yaml +++ b/test/cases/testdata/every/every.yaml @@ -9,7 +9,50 @@ cases: p { every x in [] { x != x } } - note: every/empty domain + note: every/empty domain (array) + query: data.test.p = x + want_result: + - x: true + - data: + modules: + - | + package test + import future.keywords.every + + p { + every x in set() { x != x } + } + note: every/empty domain (set) + query: data.test.p = x + want_result: + - x: true + - data: + modules: + - | + package test + import future.keywords.every + + p { + every x in {} { x != x } + } + note: every/empty domain (object) + query: data.test.p = x + want_result: + - x: true + - data: + modules: + - | + package test + import future.keywords.every + + l[1] { + false + } + + p { + every x in l { x != x } + } + note: every/empty domain (partial rule ref) query: data.test.p = x want_result: - x: true @@ -22,7 +65,19 @@ cases: p { every _ in input { true } } - note: every/domain undefined + note: every/domain undefined (input) + query: data.test.p = x + want_result: [] + - data: + modules: + - | + package test + import future.keywords.every + + p { + every _ in data.foo { true } + } + note: every/domain undefined (data ref) query: data.test.p = x want_result: [] - data: @@ -57,6 +112,40 @@ cases: package test import future.keywords.every + p { + every k, v in {1, 2} { k == v } + } + note: every/simple key/val (set) + query: data.test.p = x + want_result: + - x: true + - data: + modules: + - | + package test + import future.keywords.every + + l[1] { + true + } + + l[2] { + true + } + + p { + every k, v in l { k == v } + } + note: every/simple key/val (partial rule ref) + query: data.test.p = x + want_result: + - x: true + - data: + modules: + - | + package test + import future.keywords.every + p { i := 10 every k, v in 
[1, 2] { k+v != i } diff --git a/test/cases/testdata/every/non_iterable_domain.yaml b/test/cases/testdata/every/non_iterable_domain.yaml new file mode 100644 index 0000000000..6b0d57b725 --- /dev/null +++ b/test/cases/testdata/every/non_iterable_domain.yaml @@ -0,0 +1,151 @@ +--- +cases: + - note: "every/non-iter domain: int" + modules: + - | + package test + import future.keywords.every + + default p := 1 + + p := 2 { + every v in 42 { v > 1 } + } + query: data.test.p = x + want_result: + - x: 1 + - note: "every/non-iter domain: string" + modules: + - | + package test + import future.keywords.every + + default p := 1 + + p := 2 { + every v in "foobar" { v > 1 } + } + query: data.test.p = x + want_result: + - x: 1 + - note: "every/non-iter domain: bool" + modules: + - | + package test + import future.keywords.every + + default p := 1 + + p := 2 { + every v in true { v > 1 } + } + query: data.test.p = x + want_result: + - x: 1 + - note: "every/non-iter domain: null" + modules: + - | + package test + import future.keywords.every + + default p := 1 + + p := 2 { + every v in null { v > 1 } + } + query: data.test.p = x + want_result: + - x: 1 + - note: "every/non-iter domain: built-in call" + modules: + - | + package test + import future.keywords.every + + default p := 1 + + p := 2 { + every v in floor(13.37) { v > 1 } + } + query: data.test.p = x + want_result: + - x: 1 + - note: "every/non-iter domain: function call" + modules: + - | + package test + import future.keywords.every + + default p := 1 + + p := 2 { + every v in foo(1, 2) { v > 1 } + } + + foo(a, b) := a + b + query: data.test.p = x + want_result: + - x: 1 + - note: "every/non-iter domain: rule ref" + modules: + - | + package test + import future.keywords.every + + default p := 1 + + p := 2 { + every v in q { v > 1 } + } + + q := 1 + query: data.test.p = x + want_result: + - x: 1 + - note: "every/non-iter domain: data int" + modules: + - | + package test + import future.keywords.every + + default p := 1 + + 
p := 2 { + every v in data.iterate_me { v > 1 } + } + query: data.test.p = x + data: + iterate_me: 1 + want_result: + - x: 1 + - note: "every/non-iter domain: input int" + modules: + - | + package test + import future.keywords.every + + default p := 1 + + p := 2 { + every v in input.iterate_me { v > 1 } + } + query: data.test.p = x + input: + iterate_me: 1 + want_result: + - x: 1 + - note: "every/non-iter domain: input int (1st level)" + modules: + - | + package test + import future.keywords.every + + default p := 1 + + p := 2 { + every v in input { v > 1 } + } + query: data.test.p = x + input: 1 + want_result: + - x: 1 diff --git a/topdown/eval.go b/topdown/eval.go index 9e93bbc09e..f0ce3b2a60 100644 --- a/topdown/eval.go +++ b/topdown/eval.go @@ -407,15 +407,9 @@ func (e *eval) evalStep(iter evalIterator) error { }) case *ast.Every: eval := evalEvery{ - e: e, - expr: expr, - generator: ast.NewBody( - ast.Equality.Expr( - ast.RefTerm(terms.Domain, terms.Key).SetLocation(terms.Domain.Location), - terms.Value, - ).SetLocation(terms.Domain.Location), - ), - body: terms.Body, + Every: terms, + e: e, + expr: expr, } err = eval.eval(func() error { defined = true @@ -3390,19 +3384,32 @@ func (e evalTerm) save(iter unifyIterator) error { } type evalEvery struct { - e *eval - expr *ast.Expr - generator ast.Body - body ast.Body + *ast.Every + e *eval + expr *ast.Expr } func (e evalEvery) eval(iter unifyIterator) error { // unknowns in domain or body: save the expression, PE its body - if e.e.unknown(e.generator, e.e.bindings) || e.e.unknown(e.body, e.e.bindings) { + if e.e.unknown(e.Domain, e.e.bindings) || e.e.unknown(e.Body, e.e.bindings) { return e.save(iter) } - domain := e.e.closure(e.generator) + if pd := e.e.bindings.Plug(e.Domain); pd != nil { + if !isIterableValue(pd.Value) { + e.e.traceFail(e.expr) + return nil + } + } + + generator := ast.NewBody( + ast.Equality.Expr( + ast.RefTerm(e.Domain, e.Key).SetLocation(e.Domain.Location), + e.Value, + 
).SetLocation(e.Domain.Location), + ) + + domain := e.e.closure(generator) all := true // all generator evaluations yield one successful body evaluation domain.traceEnter(e.expr) @@ -3413,14 +3420,14 @@ func (e evalEvery) eval(iter unifyIterator) error { // This would do extra work, like iterating needlessly if domain was a large array. return nil } - body := child.closure(e.body) + body := child.closure(e.Body) body.findOne = true - body.traceEnter(e.body) + body.traceEnter(e.Body) done := false err := body.eval(func(*eval) error { - body.traceExit(e.body) + body.traceExit(e.Body) done = true - body.traceRedo(e.body) + body.traceRedo(e.Body) return nil }) if !done { @@ -3446,6 +3453,15 @@ func (e evalEvery) eval(iter unifyIterator) error { return nil } +// isIterableValue returns true if the AST value is an iterable type. +func isIterableValue(x ast.Value) bool { + switch x.(type) { + case *ast.Array, ast.Object, ast.Set: + return true + } + return false +} + func (e *evalEvery) save(iter unifyIterator) error { return e.e.saveExpr(e.plug(e.expr), e.e.bindings, iter) }