diff --git a/builtin/builtin.go b/builtin/builtin.go index c23daf468..4fe356fbf 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -3,6 +3,7 @@ package builtin import ( "encoding/base64" "encoding/json" + "errors" "fmt" "reflect" "sort" @@ -16,6 +17,10 @@ import ( var ( Index map[string]int Names []string + + // MaxDepth limits the recursion depth for nested structures. + MaxDepth = 10000 + ErrorMaxDepth = errors.New("recursion depth exceeded") ) func init() { @@ -377,7 +382,7 @@ var Builtins = []*Function{ { Name: "max", Func: func(args ...any) (any, error) { - return minMax("max", runtime.Less, args...) + return minMax("max", runtime.Less, 0, args...) }, Validate: func(args []reflect.Type) (reflect.Type, error) { return validateAggregateFunc("max", args) @@ -386,7 +391,7 @@ var Builtins = []*Function{ { Name: "min", Func: func(args ...any) (any, error) { - return minMax("min", runtime.More, args...) + return minMax("min", runtime.More, 0, args...) }, Validate: func(args []reflect.Type) (reflect.Type, error) { return validateAggregateFunc("min", args) @@ -395,7 +400,7 @@ var Builtins = []*Function{ { Name: "mean", Func: func(args ...any) (any, error) { - count, sum, err := mean(args...) + count, sum, err := mean(0, args...) if err != nil { return nil, err } @@ -411,7 +416,7 @@ var Builtins = []*Function{ { Name: "median", Func: func(args ...any) (any, error) { - values, err := median(args...) + values, err := median(0, args...) if err != nil { return nil, err } @@ -940,7 +945,10 @@ var Builtins = []*Function{ if v.Kind() != reflect.Array && v.Kind() != reflect.Slice { return nil, size, fmt.Errorf("cannot flatten %s", v.Kind()) } - ret := flatten(v) + ret, err := flatten(v, 0) + if err != nil { + return nil, 0, err + } size = uint(len(ret)) return ret, size, nil }, diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index 6ca1e8fdd..a5dabbbbe 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -722,3 +722,100 @@ func TestBuiltin_with_deref(t *testing.T) { }) } } + +func TestBuiltin_flatten_recursion(t *testing.T) { + var s []any + s = append(s, &s) // s contains a pointer to itself + + env := map[string]any{ + "arr": s, + } + + program, err := expr.Compile("flatten(arr)", expr.Env(env)) + require.NoError(t, err) + + _, err = expr.Run(program, env) + require.Error(t, err) + assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error()) +} + +func TestBuiltin_flatten_recursion_slice(t *testing.T) { + s := make([]any, 1) + s[0] = s + + env := map[string]any{ + "arr": s, + } + + program, err := expr.Compile("flatten(arr)", expr.Env(env)) + require.NoError(t, err) + + _, err = expr.Run(program, env) + require.Error(t, err) + assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error()) +} + +func TestBuiltin_numerical_recursion(t *testing.T) { + s := make([]any, 1) + s[0] = s + + env := map[string]any{ + "arr": s, + } + + tests := []string{ + "max(arr)", + "min(arr)", + "mean(arr)", + "median(arr)", + } + + for _, input := range tests { + t.Run(input, func(t *testing.T) { + program, err := expr.Compile(input, expr.Env(env)) + require.NoError(t, err) + + _, err = expr.Run(program, env) + require.Error(t, err) + assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error()) + }) + } +} + +func TestBuiltin_recursion_custom_max_depth(t *testing.T) { + originalMaxDepth := builtin.MaxDepth + defer func() { + builtin.MaxDepth = originalMaxDepth + }() + + // Set a small depth limit + builtin.MaxDepth = 2 + + // Create a deeply nested array (depth 5) + // [1, [2, [3, [4, [5]]]]] + arr := []any{1, []any{2, []any{3, []any{4, []any{5}}}}} + + env := map[string]any{ + "arr": arr, + } + + t.Run("flatten exceeds max depth", func(t *testing.T) { + program, err := expr.Compile("flatten(arr)", expr.Env(env)) + require.NoError(t, err) + + _, err = expr.Run(program, env) + require.Error(t, err) + assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error()) + }) + + t.Run("flatten within max depth", func(t *testing.T) { + // Depth 2: [1, [2]] + shallowArr := []any{1, []any{2}} + envShallow := map[string]any{"arr": shallowArr} + program, err := expr.Compile("flatten(arr)", expr.Env(envShallow)) + require.NoError(t, err) + + _, err = expr.Run(program, envShallow) + require.NoError(t, err) + }) +} diff --git a/builtin/lib.go b/builtin/lib.go index 6f6a3b6cd..07a029b2f 100644 --- a/builtin/lib.go +++ b/builtin/lib.go @@ -253,7 +253,10 @@ func String(arg any) any { return fmt.Sprintf("%v", arg) } -func minMax(name string, fn func(any, any) bool, args ...any) (any, error) { +func minMax(name string, fn func(any, any) bool, depth int, args ...any) (any, error) { + if depth > MaxDepth { + return nil, ErrorMaxDepth + } var val any for _, arg := range args { rv := reflect.ValueOf(arg) @@ -261,7 +264,7 @@ func minMax(name string, fn func(any, any) bool, args ...any) (any, error) { case reflect.Array, reflect.Slice: size := rv.Len() for i := 0; i < size; i++ { - elemVal, err := minMax(name, fn, rv.Index(i).Interface()) + elemVal, err := minMax(name, fn, depth+1, rv.Index(i).Interface()) if err != nil { return nil, err } @@ -294,7 +297,10 @@ func minMax(name string, fn func(any, any) bool, args ...any) (any, error) { return val, nil } -func mean(args ...any) (int, float64, error) { +func mean(depth int, args ...any) (int, float64, error) { + if depth > MaxDepth { + return 0, 0, ErrorMaxDepth + } var total float64 var count int @@ -304,7 +310,7 @@ func mean(args ...any) (int, float64, error) { case reflect.Array, reflect.Slice: size := rv.Len() for i := 0; i < size; i++ { - elemCount, elemSum, err := mean(rv.Index(i).Interface()) + elemCount, elemSum, err := mean(depth+1, rv.Index(i).Interface()) if err != nil { return 0, 0, err } @@ -327,7 +333,10 @@ func mean(args ...any) (int, float64, error) { return count, total, nil } -func median(args ...any) ([]float64, error) { +func median(depth int, args ...any) ([]float64, error) { + if depth > MaxDepth { + return nil, ErrorMaxDepth + } var values []float64 for _, arg := range args { @@ -336,7 +345,7 @@ func median(args ...any) ([]float64, error) { case reflect.Array, reflect.Slice: size := rv.Len() for i := 0; i < size; i++ { - elems, err := median(rv.Index(i).Interface()) + elems, err := median(depth+1, rv.Index(i).Interface()) if err != nil { return nil, err } @@ -355,18 +364,24 @@ func median(args ...any) ([]float64, error) { return values, nil } -func flatten(arg reflect.Value) []any { +func flatten(arg reflect.Value, depth int) ([]any, error) { + if depth > MaxDepth { + return nil, ErrorMaxDepth + } ret := []any{} for i := 0; i < arg.Len(); i++ { v := deref.Value(arg.Index(i)) if v.Kind() == reflect.Array || v.Kind() == reflect.Slice { - x := flatten(v) + x, err := flatten(v, depth+1) + if err != nil { + return nil, err + } ret = append(ret, x...) } else { ret = append(ret, v.Interface()) } } - return ret + return ret, nil } func get(params ...any) (out any, err error) {