diff --git a/compile.go b/compile.go index a4f489c..1dc1be4 100644 --- a/compile.go +++ b/compile.go @@ -53,6 +53,9 @@ func CopyCompileConfig(origin *CompileConfig) *CompileConfig { for k, v := range origin.CostsMap { conf.CostsMap[k] = v } + for _, op := range origin.StatelessOperators { + conf.StatelessOperators = append(conf.StatelessOperators, op) + } return conf } @@ -103,11 +106,12 @@ var ( func NewCompileConfig(opts ...CompileOption) *CompileConfig { conf := &CompileConfig{ - ConstantMap: make(map[string]Value), - SelectorMap: make(map[string]SelectorKey), - OperatorMap: make(map[string]Operator), - CompileOptions: make(map[Option]bool), - CostsMap: make(map[string]int), + ConstantMap: make(map[string]Value), + SelectorMap: make(map[string]SelectorKey), + OperatorMap: make(map[string]Operator), + CompileOptions: make(map[Option]bool), + CostsMap: make(map[string]int), + StatelessOperators: []string{}, } for _, opt := range opts { opt(conf) @@ -125,6 +129,8 @@ type CompileConfig struct { // compile options CompileOptions map[Option]bool + + StatelessOperators []string } func (cc *CompileConfig) getCosts(nodeType uint8, nodeName string) int { @@ -383,21 +389,27 @@ func isStatelessOp(c *CompileConfig, n *node) (bool, Operator) { return false, nil } - s, ok := n.value.(string) + op, ok := n.value.(string) if !ok { return false, nil } - // by default, we only do constant folding on builtin operators - if _, exist := c.OperatorMap[s]; exist { - return false, nil + // builtinOperators are all stateless functions + fn, exist := builtinOperators[op] + if exist { + return true, fn } - fn, exist := builtinOperators[s] // should be stateless function - if !exist { - return false, nil + for _, so := range c.StatelessOperators { + if so == op { + if fn = c.OperatorMap[op]; fn != nil { + return true, fn + } + break + } } - return true, fn + + return false, fn } func optimizeFastEvaluation(cc *CompileConfig, root *astNode) { diff --git a/compile_test.go b/compile_test.go index f375a00..aa2ac36 100644 --- a/compile_test.go +++ b/compile_test.go @@ -17,6 +17,7 @@ func TestCopyCompileConfig(t *testing.T) { assertNotNil(t, res.ConstantMap) assertNotNil(t, res.SelectorMap) assertNotNil(t, res.CompileOptions) + assertNotNil(t, res.StatelessOperators) res = CopyCompileConfig(&CompileConfig{}) assertNotNil(t, res) @@ -24,6 +25,7 @@ func TestCopyCompileConfig(t *testing.T) { assertNotNil(t, res.ConstantMap) assertNotNil(t, res.SelectorMap) assertNotNil(t, res.CompileOptions) + assertNotNil(t, res.StatelessOperators) cc := &CompileConfig{ ConstantMap: map[string]Value{ @@ -60,6 +62,49 @@ func TestCopyCompileConfig(t *testing.T) { age := int64(time.Now().Sub(birthTime) / timeYear) return age < 18, nil }, + "max": func(_ *Ctx, param []Value) (Value, error) { + const op = "max" + if len(param) < 2 { + return nil, ParamsCountError(op, 2, len(param)) + } + + var m int64 + for i, p := range param { + v, ok := p.(int64) + if !ok { + return nil, ParamTypeError(op, typeInt, p) + } + if i == 0 { + m = v + } else { + if v > m { + m = v + } + } + } + return m, nil + }, + "to_set": func(_ *Ctx, params []Value) (Value, error) { + if len(params) != 1 { + return nil, ParamsCountError("to_set", 1, len(params)) + } + switch list := params[0].(type) { + case []int64: + set := make(map[int64]struct{}, len(list)) + for _, i := range list { + set[i] = empty + } + return set, nil + case []string: + set := make(map[string]struct{}, len(list)) + for _, s := range list { + set[s] = empty + } + return set, nil + default: + return nil, ParamTypeError("to_set", "slice", list) + } + }, }, CostsMap: map[string]int{ "selector": 10, @@ -69,6 +114,9 @@ func TestCopyCompileConfig(t *testing.T) { Reordering: true, ConstantFolding: false, }, + // max & to_set are both stateless operators + // but is_child is not, because it varies with time + StatelessOperators: []string{"max", "to_set"}, } res = CopyCompileConfig(cc) @@ -76,6 +124,7 @@ func TestCopyCompileConfig(t *testing.T) { assertEquals(t, res.SelectorMap, cc.SelectorMap) assertEquals(t, res.CompileOptions, cc.CompileOptions) assertEquals(t, res.CostsMap, cc.CostsMap) + assertEquals(t, res.StatelessOperators, cc.StatelessOperators) assertEquals(t, len(res.OperatorMap), len(cc.OperatorMap)) for s := range cc.OperatorMap { diff --git a/engine.go b/engine.go index c1a1f74..faeaa04 100644 --- a/engine.go +++ b/engine.go @@ -61,7 +61,7 @@ const ( type Event struct { CurtIdx int16 EventType EventType - NodeValue interface{} + NodeValue Value Stack []Value Data interface{} } diff --git a/engine_test.go b/engine_test.go index e60a6d9..a8a6e4e 100644 --- a/engine_test.go +++ b/engine_test.go @@ -780,6 +780,18 @@ func TestExpr_TryEval(t *testing.T) { "F": false, }, }, + { + want: false, + optimizeLevel: disable, + s: ` +(and F + (= 0 0) + (!= 0 0))`, + valMap: map[string]interface{}{ + "F": false, + "T": true, + }, + }, { want: true, optimizeLevel: disable, @@ -1292,6 +1304,72 @@ func TestReportEvent(t *testing.T) { assertEquals(t, events, []Value{int64(1), int64(2), "+"}) } +func TestStatelessOperators(t *testing.T) { + cc := &CompileConfig{ + OperatorMap: map[string]Operator{ + "to_set": func(_ *Ctx, params []Value) (Value, error) { + if len(params) != 1 { + return nil, ParamsCountError("to_set", 1, len(params)) + } + switch list := params[0].(type) { + case []int64: + set := make(map[int64]struct{}, len(list)) + for _, i := range list { + set[i] = empty + } + return set, nil + case []string: + set := make(map[string]struct{}, len(list)) + for _, s := range list { + set[s] = empty + } + return set, nil + default: + return nil, ParamTypeError("to_set", "list", list) + } + }, + }, + SelectorMap: map[string]SelectorKey{ + "num": SelectorKey(1), + }, + StatelessOperators: []string{"to_set"}, + } + + s := ` + (in + num + (to_set + (2 3 5 7 11 13 17 19 23 29 31 37 41 + 43 47 53 59 61 67 71 73 79 83 89 97 + 101 103 107 109 113 127 131 137 139 + 149 151 157 163 167 173 179 181 191 + 193 197 199 211 223 227 229 233 239 + 241 251 257 263 269 271 277 281 283 + 293 307 311 313 317 331 337 347 349 + 353 359 367 373 379 383 389 397 401 + 409 419 421 431 433 439 443 449 457 + 461 463 467 479 487 491 499 503 509 + 521 523 541 547 557 563 569 571 577 + 587 593 599 601 607 613 617 619 631 + 641 643 647 653 659 661 673 677 683 + 691 701 709 719 727 733 739 743 751 + 757 761 769 773 787 797 809 811 821 + 823 827 829 839 853 857 859 863 877 + 881 883 887 907 911 919 929 937 941 + 947 953 967 971 977 983 991 997))) +` + + expr, err := Compile(cc, s) + assertNil(t, err) + res, err := expr.EvalBool(NewCtxWithMap(cc, map[string]interface{}{ + "num": 499, + })) + + assertNil(t, err) + assertEquals(t, res, true) + assertEquals(t, len(expr.nodes), 3) +} + func assertEquals(t *testing.T, got, want any, msg ...any) { if !reflect.DeepEqual(got, want) { t.Fatalf("assertEquals failed, got: %+v, want: %+v, msg: %+v", got, want, msg) diff --git a/example_test.go b/example_test.go deleted file mode 100644 index 5853f53..0000000 --- a/example_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package eval - -import "fmt" - -func ExampleEval() { - output, err := Eval(`(+ 1 v1)`, map[string]interface{}{ - "v1": 1, - }) - if err != nil { - fmt.Printf("err: %v", err) - return - } - - fmt.Printf("%v", output) - - // Output: 2 -} - -func ExampleEval_infix() { - - expr := `1 + v2 * (v3 + v5) / v4 + abs(-6 - v1) - max(1, 3, 2, abs(-8))` - - vals := map[string]interface{}{ - "abs": Operator(func(_ *Ctx, params []Value) (Value, error) { - if len(params) != 1 { - return nil, ParamsCountError("abs", 2, len(params)) - } - i, ok := params[0].(int64) - if !ok { - return nil, ParamTypeError("abs", "int", params[0]) - } - if i < 0 { - return -i, nil - } - return i, nil - }), - "max": func(_ *Ctx, params []Value) (Value, error) { - if len(params) == 0 { - return nil, ParamsCountError("max", 1, len(params)) - } - - var res int64 - for i, v := range params { - i64, ok := v.(int64) - if !ok { - return nil, ParamTypeError("max", "int64", v) - } - if i == 0 { - res = i64 - continue - } - if res < i64 { - res = i64 - } - } - return res, nil - }, - "v1": 1, - "v2": 2, - "v3": 3, - "v4": 4, - "v5": 5, - } - cc := NewCompileConfig(EnableInfixNotation, RegisterVals(vals)) - - output, err := Eval(expr, vals, cc) - if err != nil { - fmt.Printf("err: %v", err) - return - } - - fmt.Printf("%v", output) - - // Output: 4 -} diff --git a/operator.go b/operator.go index 9a3fa59..9832f0c 100644 --- a/operator.go +++ b/operator.go @@ -348,29 +348,36 @@ func listIn(_ *Ctx, params []Value) (Value, error) { } switch v := params[0].(type) { case string: - list, ok := params[1].([]string) - if !ok { - return nil, ParamTypeError(op, typeStrList, params[1]) - } - for _, s := range list { - if s == v { - return true, nil + switch coll := params[1].(type) { + case []string: + for _, i := range coll { + if i == v { + return true, nil + } } + return false, nil + case map[string]struct{}: + _, exist := coll[v] + return exist, nil + default: + return nil, ParamTypeError(op, typeStrList, params[1]) } - return false, nil case int64: - switch list := params[1].(type) { + switch coll := params[1].(type) { case []int64: - for _, i := range list { + for _, i := range coll { if i == v { return true, nil } } return false, nil case []string: // the empty list is parsed to a string list - if len(list) == 0 { + if len(coll) == 0 { return false, nil } + case map[int64]struct{}: + _, exist := coll[v] + return exist, nil } return nil, ParamTypeError(op, typeIntList, params[1]) } diff --git a/parse.go b/parse.go index 4025993..549da79 100644 --- a/parse.go +++ b/parse.go @@ -266,8 +266,7 @@ func (p *parser) setLeafNodeParsers() { // For infix expressions only lists with brackets are supported fns = append(fns, p.parseList(lBracket, rBracket)) } else { - // For prefix expressions, lists with brackets or parentheses both are supported - fns = append(fns, p.parseList(lBracket, rBracket), p.parseList(lParen, rParen)) + fns = append(fns, p.parseList(lParen, rParen)) } p.leafNodeParser = fns @@ -281,19 +280,24 @@ func (p *parser) check() error { (p.tokens[0].typ != lParen || p.tokens[last].typ != rParen) { return p.parenUnmatchedErr(0) } + // check parentheses + var parenCnt int + var inBracket bool - var parenCnt int // check parentheses for i, t := range p.tokens { switch t.typ { case lParen: parenCnt++ case rParen: parenCnt-- - case comma: - if prefixNotation { // commas can be used in infix expressions only + case comma, lBracket, rBracket: + if prefixNotation { // commas, brackets can be used in infix expressions only return p.unknownTokenError(t) } + default: + continue } + if parenCnt < 0 { return p.parenUnmatchedErr(t.pos) } @@ -301,6 +305,16 @@ func (p *parser) check() error { if prefixNotation && parenCnt == 0 && i != last { return p.parenUnmatchedErr(t.pos) } + + if inBracket && t.typ != rBracket { + return p.invalidExprErr(t.pos) + } + + if t.typ == lBracket { + inBracket = true + } else if t.typ == rBracket { + inBracket = false + } } if parenCnt != 0 { diff --git a/parse_test.go b/parse_test.go index 014cbcd..74ed03f 100644 --- a/parse_test.go +++ b/parse_test.go @@ -1076,6 +1076,13 @@ func TestParseAstTree(t *testing.T) { }, }, }, + { + expr: `(1 2)`, + ast: verifyNode{ + tpy: constant, + data: []int64{1, 2}, + }, + }, // return an error when expr use operator at selector position {