Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ func CopyCompileConfig(origin *CompileConfig) *CompileConfig {
for k, v := range origin.CostsMap {
conf.CostsMap[k] = v
}
for _, op := range origin.StatelessOperators {
conf.StatelessOperators = append(conf.StatelessOperators, op)
}
return conf
}

Expand Down Expand Up @@ -103,11 +106,12 @@ var (

func NewCompileConfig(opts ...CompileOption) *CompileConfig {
conf := &CompileConfig{
ConstantMap: make(map[string]Value),
SelectorMap: make(map[string]SelectorKey),
OperatorMap: make(map[string]Operator),
CompileOptions: make(map[Option]bool),
CostsMap: make(map[string]int),
ConstantMap: make(map[string]Value),
SelectorMap: make(map[string]SelectorKey),
OperatorMap: make(map[string]Operator),
CompileOptions: make(map[Option]bool),
CostsMap: make(map[string]int),
StatelessOperators: []string{},
}
for _, opt := range opts {
opt(conf)
Expand All @@ -125,6 +129,8 @@ type CompileConfig struct {

// compile options
CompileOptions map[Option]bool

StatelessOperators []string
}

func (cc *CompileConfig) getCosts(nodeType uint8, nodeName string) int {
Expand Down Expand Up @@ -383,21 +389,27 @@ func isStatelessOp(c *CompileConfig, n *node) (bool, Operator) {
return false, nil
}

s, ok := n.value.(string)
op, ok := n.value.(string)
if !ok {
return false, nil
}

// by default, we only do constant folding on builtin operators
if _, exist := c.OperatorMap[s]; exist {
return false, nil
// builtinOperators are all stateless functions
fn, exist := builtinOperators[op]
if exist {
return true, fn
}

fn, exist := builtinOperators[s] // should be stateless function
if !exist {
return false, nil
for _, so := range c.StatelessOperators {
if so == op {
if fn = c.OperatorMap[op]; fn != nil {
return true, fn
}
break
}
}
return true, fn

return false, fn
}

func optimizeFastEvaluation(cc *CompileConfig, root *astNode) {
Expand Down
49 changes: 49 additions & 0 deletions compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ func TestCopyCompileConfig(t *testing.T) {
assertNotNil(t, res.ConstantMap)
assertNotNil(t, res.SelectorMap)
assertNotNil(t, res.CompileOptions)
assertNotNil(t, res.StatelessOperators)

res = CopyCompileConfig(&CompileConfig{})
assertNotNil(t, res)
assertNotNil(t, res.OperatorMap)
assertNotNil(t, res.ConstantMap)
assertNotNil(t, res.SelectorMap)
assertNotNil(t, res.CompileOptions)
assertNotNil(t, res.StatelessOperators)

cc := &CompileConfig{
ConstantMap: map[string]Value{
Expand Down Expand Up @@ -60,6 +62,49 @@ func TestCopyCompileConfig(t *testing.T) {
age := int64(time.Now().Sub(birthTime) / timeYear)
return age < 18, nil
},
"max": func(_ *Ctx, param []Value) (Value, error) {
const op = "max"
if len(param) < 2 {
return nil, ParamsCountError(op, 2, len(param))
}

var m int64
for i, p := range param {
v, ok := p.(int64)
if !ok {
return nil, ParamTypeError(op, typeInt, p)
}
if i == 0 {
m = v
} else {
if v > m {
m = v
}
}
}
return m, nil
},
"to_set": func(_ *Ctx, params []Value) (Value, error) {
if len(params) != 1 {
return nil, ParamsCountError("to_set", 1, len(params))
}
switch list := params[0].(type) {
case []int64:
set := make(map[int64]struct{}, len(list))
for _, i := range list {
set[i] = empty
}
return set, nil
case []string:
set := make(map[string]struct{}, len(list))
for _, s := range list {
set[s] = empty
}
return set, nil
default:
return nil, ParamTypeError("to_set", "slice", list)
}
},
},
CostsMap: map[string]int{
"selector": 10,
Expand All @@ -69,13 +114,17 @@ func TestCopyCompileConfig(t *testing.T) {
Reordering: true,
ConstantFolding: false,
},
// max & to_set are both stateless operators
// but is_child is not, because it varies with time
StatelessOperators: []string{"max", "to_set"},
}

res = CopyCompileConfig(cc)
assertEquals(t, res.ConstantMap, cc.ConstantMap)
assertEquals(t, res.SelectorMap, cc.SelectorMap)
assertEquals(t, res.CompileOptions, cc.CompileOptions)
assertEquals(t, res.CostsMap, cc.CostsMap)
assertEquals(t, res.StatelessOperators, cc.StatelessOperators)

assertEquals(t, len(res.OperatorMap), len(cc.OperatorMap))
for s := range cc.OperatorMap {
Expand Down
2 changes: 1 addition & 1 deletion engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ const (
type Event struct {
CurtIdx int16
EventType EventType
NodeValue interface{}
NodeValue Value
Stack []Value
Data interface{}
}
Expand Down
78 changes: 78 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,18 @@ func TestExpr_TryEval(t *testing.T) {
"F": false,
},
},
{
want: false,
optimizeLevel: disable,
s: `
(and F
(= 0 0)
(!= 0 0))`,
valMap: map[string]interface{}{
"F": false,
"T": true,
},
},
{
want: true,
optimizeLevel: disable,
Expand Down Expand Up @@ -1292,6 +1304,72 @@ func TestReportEvent(t *testing.T) {
assertEquals(t, events, []Value{int64(1), int64(2), "+"})
}

func TestStatelessOperators(t *testing.T) {
cc := &CompileConfig{
OperatorMap: map[string]Operator{
"to_set": func(_ *Ctx, params []Value) (Value, error) {
if len(params) != 1 {
return nil, ParamsCountError("to_set", 1, len(params))
}
switch list := params[0].(type) {
case []int64:
set := make(map[int64]struct{}, len(list))
for _, i := range list {
set[i] = empty
}
return set, nil
case []string:
set := make(map[string]struct{}, len(list))
for _, s := range list {
set[s] = empty
}
return set, nil
default:
return nil, ParamTypeError("to_set", "list", list)
}
},
},
SelectorMap: map[string]SelectorKey{
"num": SelectorKey(1),
},
StatelessOperators: []string{"to_set"},
}

s := `
(in
num
(to_set
(2 3 5 7 11 13 17 19 23 29 31 37 41
43 47 53 59 61 67 71 73 79 83 89 97
101 103 107 109 113 127 131 137 139
149 151 157 163 167 173 179 181 191
193 197 199 211 223 227 229 233 239
241 251 257 263 269 271 277 281 283
293 307 311 313 317 331 337 347 349
353 359 367 373 379 383 389 397 401
409 419 421 431 433 439 443 449 457
461 463 467 479 487 491 499 503 509
521 523 541 547 557 563 569 571 577
587 593 599 601 607 613 617 619 631
641 643 647 653 659 661 673 677 683
691 701 709 719 727 733 739 743 751
757 761 769 773 787 797 809 811 821
823 827 829 839 853 857 859 863 877
881 883 887 907 911 919 929 937 941
947 953 967 971 977 983 991 997)))
`

expr, err := Compile(cc, s)
assertNil(t, err)
res, err := expr.EvalBool(NewCtxWithMap(cc, map[string]interface{}{
"num": 499,
}))

assertNil(t, err)
assertEquals(t, res, true)
assertEquals(t, len(expr.nodes), 3)
}

func assertEquals(t *testing.T, got, want any, msg ...any) {
if !reflect.DeepEqual(got, want) {
t.Fatalf("assertEquals failed, got: %+v, want: %+v, msg: %+v", got, want, msg)
Expand Down
75 changes: 0 additions & 75 deletions example_test.go

This file was deleted.

29 changes: 18 additions & 11 deletions operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,29 +348,36 @@ func listIn(_ *Ctx, params []Value) (Value, error) {
}
switch v := params[0].(type) {
case string:
list, ok := params[1].([]string)
if !ok {
return nil, ParamTypeError(op, typeStrList, params[1])
}
for _, s := range list {
if s == v {
return true, nil
switch coll := params[1].(type) {
case []string:
for _, i := range coll {
if i == v {
return true, nil
}
}
return false, nil
case map[string]struct{}:
_, exist := coll[v]
return exist, nil
default:
return nil, ParamTypeError(op, typeStrList, params[1])
}
return false, nil
case int64:
switch list := params[1].(type) {
switch coll := params[1].(type) {
case []int64:
for _, i := range list {
for _, i := range coll {
if i == v {
return true, nil
}
}
return false, nil
case []string: // the empty list is parsed to a string list
if len(list) == 0 {
if len(coll) == 0 {
return false, nil
}
case map[int64]struct{}:
_, exist := coll[v]
return exist, nil
}
return nil, ParamTypeError(op, typeIntList, params[1])
}
Expand Down
Loading