Skip to content

Commit

Permalink
internal/wasm/sdk: Support for Set in input/output
Browse files Browse the repository at this point in the history
Instead of round tripping through JSON we use the newer value parse
and dump helpers along with the value stringer and ast term parser
to round preserve the rego typing.

Signed-off-by: Patrick East <east.patrick@gmail.com>
  • Loading branch information
patrick-east authored and tsandall committed Nov 6, 2020
1 parent af85ba6 commit 76f232c
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 110 deletions.
16 changes: 0 additions & 16 deletions internal/wasm/sdk/examples/basic/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,6 @@ func main() {

fmt.Printf("Policy 1 result: %v\n", result)

resultBool, err := opa.EvalBool(ctx, rego, entrypointID, &input)
if err != nil {
fmt.Printf("error: %v\n", err)
return
}

fmt.Printf("Policy 1 boolean result: %v\n", resultBool)

// Update the policy on the fly.

policy, err = ioutil.ReadFile(path.Join(directory, "example-2.wasm"))
Expand Down Expand Up @@ -112,12 +104,4 @@ func main() {
}

fmt.Printf("Policy 2 result: %v\n", result)

resultBool, err = opa.EvalBool(ctx, rego, entrypointID, &input)
if err != nil {
fmt.Printf("error: %v\n", err)
return
}

fmt.Printf("Policy 2 boolean result: %v\n", resultBool)
}
39 changes: 1 addition & 38 deletions internal/wasm/sdk/opa/opa.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type OPA struct {

// Result holds the evaluation result.
type Result struct {
Result interface{}
Result []byte
}

// EntrypointID is used by Eval() to determine which compiled entrypoint should
Expand Down Expand Up @@ -199,40 +199,3 @@ func (o *OPA) Entrypoints(ctx context.Context) (map[string]EntrypointID, error)

return instance.Entrypoints(), nil
}

// EvalBool evaluates the boolean policy with the given input. The
// possible error values returned are as with Eval with addition of
// ErrUndefined indicating an undefined policy decision and
// ErrNonBoolean indicating a non-boolean policy decision.
// Deprecated: Use Eval instead.
func EvalBool(ctx context.Context, o *OPA, entrypoint EntrypointID, input *interface{}) (bool, error) {
rs, err := o.Eval(ctx, EvalOpts{
Entrypoint: entrypoint,
Input: input,
})
if err != nil {
return false, err
}

r, ok := rs.Result.([]interface{})
if !ok || len(r) == 0 {
return false, ErrUndefined
}

m, ok := r[0].(map[string]interface{})
if !ok || len(m) != 1 {
return false, ErrNonBoolean
}

var b bool
for _, v := range m {
b, ok = v.(bool)
break
}

if !ok {
return false, ErrNonBoolean
}

return b, nil
}
48 changes: 25 additions & 23 deletions internal/wasm/sdk/opa/opa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package opa_test
import (
"context"
"fmt"
"reflect"
"testing"

"github.com/open-policy-agent/opa/ast"
Expand Down Expand Up @@ -37,17 +36,17 @@ func TestOPA(t *testing.T) {
Policy: `a = true`,
Query: "data.p.a = x",
Evals: []Eval{
Eval{Result: `[{"x": true}]`},
Eval{Result: `[{"x": true}]`},
Eval{Result: `{{"x": true}}`},
Eval{Result: `{{"x": true}}`},
},
},
{
Description: "Only input changing",
Policy: `a = input`,
Query: "data.p.a = x",
Evals: []Eval{
Eval{Input: "false", Result: `[{"x": false}]`},
Eval{Input: "true", Result: `[{"x": true}]`},
Eval{Input: "false", Result: `{{"x": false}}`},
Eval{Input: "true", Result: `{{"x": true}}`},
},
},
{
Expand All @@ -56,8 +55,8 @@ func TestOPA(t *testing.T) {
Query: "data.p.a = x",
Data: `{"q": false}`,
Evals: []Eval{
Eval{Result: `[{"x": false}]`},
Eval{NewData: `{"q": true}`, Result: `[{"x": true}]`},
Eval{Result: `{{"x": false}}`},
Eval{NewData: `{"q": true}`, Result: `{{"x": true}}`},
},
},
{
Expand All @@ -66,8 +65,8 @@ func TestOPA(t *testing.T) {
Query: "data.p.a = x",
Data: `{"q": false, "r": true}`,
Evals: []Eval{
Eval{Result: `[{"x": false}]`},
Eval{NewPolicy: `a = data.r`, Result: `[{"x": true}]`},
Eval{Result: `{{"x": false}}`},
Eval{NewPolicy: `a = data.r`, Result: `{{"x": true}}`},
},
},
{
Expand All @@ -76,25 +75,25 @@ func TestOPA(t *testing.T) {
Query: "data.p.a = x",
Data: `{"q": 0, "r": 1}`,
Evals: []Eval{
Eval{Result: `[{"x": 0}]`},
Eval{NewPolicy: `a = data.r`, NewData: `{"q": 2, "r": 3}`, Result: `[{"x": 3}]`},
Eval{Result: `{{"x": 0}}`},
Eval{NewPolicy: `a = data.r`, NewData: `{"q": 2, "r": 3}`, Result: `{{"x": 3}}`},
},
},
{
Description: "Builtins",
Policy: `a = count(data.q) + sum(data.q)`,
Query: "data.p.a = x",
Evals: []Eval{
Eval{NewData: `{"q": []}`, Result: `[{"x": 0}]`},
Eval{NewData: `{"q": [1, 2]}`, Result: `[{"x": 5}]`},
Eval{NewData: `{"q": []}`, Result: `{{"x": 0}}`},
Eval{NewData: `{"q": [1, 2]}`, Result: `{{"x": 5}}`},
},
},
{
Description: "Undefined decision",
Policy: `a = true`,
Query: "data.p.b = x",
Evals: []Eval{
Eval{Result: `[]`},
Eval{Result: `set()`},
},
},
}
Expand Down Expand Up @@ -140,13 +139,14 @@ func TestOPA(t *testing.T) {
}
}

result, err := instance.Eval(context.Background(), opa.EvalOpts{Input: parseJSON(eval.Input)})
r, err := instance.Eval(context.Background(), opa.EvalOpts{Input: parseJSON(eval.Input)})
if err != nil {
t.Errorf(err.Error())
t.Fatalf(err.Error())
}

if !reflect.DeepEqual(*parseJSON(eval.Result), result.Result) {
t.Errorf("Incorrect result: %v", result.Result)
expected := ast.MustParseTerm(eval.Result)
if !ast.MustParseTerm(string(r.Result)).Equal(expected) {
t.Errorf("\nExpected: %v\nGot: %v\n", expected, string(r.Result))
}
}

Expand Down Expand Up @@ -203,18 +203,20 @@ func TestNamedEntrypoint(t *testing.T) {
t.Fatalf("Unexpected error: %s", err)
}

exp := `[{"result":7}]`
if !reflect.DeepEqual(*parseJSON(exp), a.Result) {
t.Fatalf("Expected result for 'test/a' to be %s, got: %s", exp, string(util.MustMarshalJSON(a.Result)))
exp := ast.MustParseTerm(`{{"result":7}}`)
actual := ast.MustParseTerm(string(a.Result))
if !actual.Equal(exp) {
t.Fatalf("Expected result for 'test/a' to be %s, got: %s", exp, actual)
}

b, err := instance.Eval(ctx, opa.EvalOpts{Entrypoint: eps["test/b"]})
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}

if !reflect.DeepEqual(*parseJSON(exp), b.Result) {
t.Fatalf("Expected result for 'test/b' to be %s, got: %s", exp, string(util.MustMarshalJSON(b.Result)))
actual = ast.MustParseTerm(string(b.Result))
if !actual.Equal(exp) {
t.Fatalf("Expected result for 'test/b' to be %s, got: %s", exp, actual)
}
}

Expand Down
24 changes: 20 additions & 4 deletions internal/wasm/sdk/opa/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ type vm struct {
heapTopSet func(...interface{}) (wasm.Value, error)
jsonDump func(...interface{}) (wasm.Value, error)
jsonParse func(...interface{}) (wasm.Value, error)
valueDump func(...interface{}) (wasm.Value, error)
valueParse func(...interface{}) (wasm.Value, error)
malloc func(...interface{}) (wasm.Value, error)
}

Expand Down Expand Up @@ -96,6 +98,8 @@ func newVM(policy []byte, data []byte, memoryMin, memoryMax uint32) (*vm, error)
heapTopSet: i.Exports["opa_heap_top_set"],
jsonDump: i.Exports["opa_json_dump"],
jsonParse: i.Exports["opa_json_parse"],
valueDump: i.Exports["opa_value_dump"],
valueParse: i.Exports["opa_value_parse"],
malloc: i.Exports["opa_malloc"],
}

Expand Down Expand Up @@ -170,7 +174,9 @@ func newVM(policy []byte, data []byte, memoryMin, memoryMax uint32) (*vm, error)
return v, nil
}

func (i *vm) Eval(ctx context.Context, entrypoint EntrypointID, input *interface{}, metrics metrics.Metrics) (interface{}, error) {
// Eval performs an evaluation of the specified entrypoint, with any provided
// input, and returns the resulting value dumped to a string.
func (i *vm) Eval(ctx context.Context, entrypoint EntrypointID, input *interface{}, metrics metrics.Metrics) ([]byte, error) {
metrics.Timer("wasm_vm_eval").Start()
defer metrics.Timer("wasm_vm_eval").Stop()

Expand Down Expand Up @@ -247,12 +253,22 @@ func (i *vm) Eval(ctx context.Context, entrypoint EntrypointID, input *interface
return nil, err
}

result, err := i.fromRegoJSON(resultAddr.ToI32(), false)
serialized, err := i.valueDump(resultAddr)
if err != nil {
return nil, err
}

data := i.memory.Data()[serialized.ToI32():]
n := bytes.IndexByte(data, 0)
if n < 0 {
n = 0
}

metrics.Timer("wasm_vm_eval_prepare_result").Stop()

// Skip free'ing input and result JSON as the heap will be reset next round anyway.

return result, err
return data[0:n], err
}

func (i *vm) SetPolicyData(policy []byte, data []byte) error {
Expand Down Expand Up @@ -432,7 +448,7 @@ func (i *vm) toRegoJSON(v interface{}, free bool) (int32, error) {
p := pos.ToI32()
copy(i.memory.Data()[p:p+n], raw)

addr, err := i.jsonParse(p, n)
addr, err := i.valueParse(p, n)
if err != nil {
return 0, err
}
Expand Down
60 changes: 31 additions & 29 deletions resolver/wasm/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,6 @@ func (r *Resolver) Close() {
// Eval performs an evaluation using the provided input and the Wasm module
// associated with this Resolver instance.
func (r *Resolver) Eval(ctx context.Context, input resolver.Input) (resolver.Result, error) {
var inp *interface{}

if input.Input != nil {
x, err := ast.JSON(input.Input.Value)
if err != nil {
return resolver.Result{}, err
}
inp = &x
}

v := r.entrypointIDs.Get(input.Ref)
if v == nil {
return resolver.Result{}, fmt.Errorf("unknown entrypoint %s", input.Ref)
Expand All @@ -105,8 +95,14 @@ func (r *Resolver) Eval(ctx context.Context, input resolver.Input) (resolver.Res
return resolver.Result{}, fmt.Errorf("internal error: invalid entrypoint id %s", numValue)
}

var in *interface{}
if input.Input != nil {
var str interface{} = []byte(input.Input.String())
in = &str
}

opts := opa.EvalOpts{
Input: inp,
Input: in,
Entrypoint: opa.EntrypointID(epID),
Metrics: input.Metrics,
}
Expand All @@ -118,43 +114,49 @@ func (r *Resolver) Eval(ctx context.Context, input resolver.Input) (resolver.Res
result, err := getResult(out)
if err != nil {
return resolver.Result{}, err
} else if result == nil {
return resolver.Result{}, nil
}

v, err = ast.InterfaceToValue(*result)
if err != nil {
return resolver.Result{}, err
}

return resolver.Result{Value: v}, nil
return resolver.Result{Value: result}, nil
}

// SetData will update the external data for the Wasm instance.
func (r *Resolver) SetData(data interface{}) error {
return r.o.SetData(data)
}

func getResult(rs *opa.Result) (*interface{}, error) {
func getResult(evalResult *opa.Result) (ast.Value, error) {

r, ok := rs.Result.([]interface{})
parsed, err := ast.ParseTerm(string(evalResult.Result))
if err != nil {
return nil, fmt.Errorf("failed to parse wasm result: %s", err)
}

resultSet, ok := parsed.Value.(ast.Set)
if !ok {
return nil, fmt.Errorf("illegal result set type")
return nil, fmt.Errorf("illegal result type")
}

if len(r) == 0 {
if resultSet.Len() == 0 {
return nil, nil
}

m, ok := r[0].(map[string]interface{})
if !ok || len(m) != 1 {
if resultSet.Len() > 1 {
return nil, fmt.Errorf("illegal result type")
}

result, ok := m["result"]
if !ok {
return nil, fmt.Errorf("missing value")
var obj ast.Object
err = resultSet.Iter(func(term *ast.Term) error {
obj, ok = term.Value.(ast.Object)
if !ok || obj.Len() != 1 {
return fmt.Errorf("illegal result type")
}
return nil
})
if err != nil {
return nil, err
}

return &result, nil
result := obj.Get(ast.StringTerm("result"))

return result.Value, nil
}

0 comments on commit 76f232c

Please sign in to comment.