From 2a07330648d2f827f92a95b4c9976d0c4f971ee6 Mon Sep 17 00:00:00 2001 From: Ville Vesilehto Date: Wed, 3 Dec 2025 14:29:38 +0200 Subject: [PATCH] fix(builtin): limit recursion depth Add builtin.MaxDepth (default 10k) to prevent stack overflows when processing deeply nested or cyclic structures in builtin functions. The functions flatten, min, max, mean, and median now return a "recursion depth exceeded" error instead of crashing the runtime. Signed-off-by: Ville Vesilehto --- builtin/builtin.go | 18 +++++--- builtin/builtin_test.go | 97 +++++++++++++++++++++++++++++++++++++++++ builtin/lib.go | 33 ++++++++++---- 3 files changed, 134 insertions(+), 14 deletions(-) diff --git a/builtin/builtin.go b/builtin/builtin.go index c23daf468..4fe356fbf 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -3,6 +3,7 @@ package builtin import ( "encoding/base64" "encoding/json" + "errors" "fmt" "reflect" "sort" @@ -16,6 +17,10 @@ import ( var ( Index map[string]int Names []string + + // MaxDepth limits the recursion depth for nested structures. + MaxDepth = 10000 + ErrorMaxDepth = errors.New("recursion depth exceeded") ) func init() { @@ -377,7 +382,7 @@ var Builtins = []*Function{ { Name: "max", Func: func(args ...any) (any, error) { - return minMax("max", runtime.Less, args...) + return minMax("max", runtime.Less, 0, args...) }, Validate: func(args []reflect.Type) (reflect.Type, error) { return validateAggregateFunc("max", args) @@ -386,7 +391,7 @@ var Builtins = []*Function{ { Name: "min", Func: func(args ...any) (any, error) { - return minMax("min", runtime.More, args...) + return minMax("min", runtime.More, 0, args...) }, Validate: func(args []reflect.Type) (reflect.Type, error) { return validateAggregateFunc("min", args) @@ -395,7 +400,7 @@ var Builtins = []*Function{ { Name: "mean", Func: func(args ...any) (any, error) { - count, sum, err := mean(args...) + count, sum, err := mean(0, args...) if err != nil { return nil, err } @@ -411,7 +416,7 @@ var Builtins = []*Function{ { Name: "median", Func: func(args ...any) (any, error) { - values, err := median(args...) + values, err := median(0, args...) if err != nil { return nil, err } @@ -940,7 +945,10 @@ var Builtins = []*Function{ if v.Kind() != reflect.Array && v.Kind() != reflect.Slice { return nil, size, fmt.Errorf("cannot flatten %s", v.Kind()) } - ret := flatten(v) + ret, err := flatten(v, 0) + if err != nil { + return nil, 0, err + } size = uint(len(ret)) return ret, size, nil }, diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index 6ca1e8fdd..a5dabbbbe 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -722,3 +722,100 @@ func TestBuiltin_with_deref(t *testing.T) { }) } } + +func TestBuiltin_flatten_recursion(t *testing.T) { + var s []any + s = append(s, &s) // s contains a pointer to itself + + env := map[string]any{ + "arr": s, + } + + program, err := expr.Compile("flatten(arr)", expr.Env(env)) + require.NoError(t, err) + + _, err = expr.Run(program, env) + require.Error(t, err) + assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error()) +} + +func TestBuiltin_flatten_recursion_slice(t *testing.T) { + s := make([]any, 1) + s[0] = s + + env := map[string]any{ + "arr": s, + } + + program, err := expr.Compile("flatten(arr)", expr.Env(env)) + require.NoError(t, err) + + _, err = expr.Run(program, env) + require.Error(t, err) + assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error()) +} + +func TestBuiltin_numerical_recursion(t *testing.T) { + s := make([]any, 1) + s[0] = s + + env := map[string]any{ + "arr": s, + } + + tests := []string{ + "max(arr)", + "min(arr)", + "mean(arr)", + "median(arr)", + } + + for _, input := range tests { + t.Run(input, func(t *testing.T) { + program, err := expr.Compile(input, expr.Env(env)) + require.NoError(t, err) + + _, err = expr.Run(program, env) + require.Error(t, err) + assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error()) + }) + } +} + +func TestBuiltin_recursion_custom_max_depth(t *testing.T) { + originalMaxDepth := builtin.MaxDepth + defer func() { + builtin.MaxDepth = originalMaxDepth + }() + + // Set a small depth limit + builtin.MaxDepth = 2 + + // Create a deeply nested array (depth 5) + // [1, [2, [3, [4, [5]]]]] + arr := []any{1, []any{2, []any{3, []any{4, []any{5}}}}} + + env := map[string]any{ + "arr": arr, + } + + t.Run("flatten exceeds max depth", func(t *testing.T) { + program, err := expr.Compile("flatten(arr)", expr.Env(env)) + require.NoError(t, err) + + _, err = expr.Run(program, env) + require.Error(t, err) + assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error()) + }) + + t.Run("flatten within max depth", func(t *testing.T) { + // Depth 2: [1, [2]] + shallowArr := []any{1, []any{2}} + envShallow := map[string]any{"arr": shallowArr} + program, err := expr.Compile("flatten(arr)", expr.Env(envShallow)) + require.NoError(t, err) + + _, err = expr.Run(program, envShallow) + require.NoError(t, err) + }) +} diff --git a/builtin/lib.go b/builtin/lib.go index 6f6a3b6cd..07a029b2f 100644 --- a/builtin/lib.go +++ b/builtin/lib.go @@ -253,7 +253,10 @@ func String(arg any) any { return fmt.Sprintf("%v", arg) } -func minMax(name string, fn func(any, any) bool, args ...any) (any, error) { +func minMax(name string, fn func(any, any) bool, depth int, args ...any) (any, error) { + if depth > MaxDepth { + return nil, ErrorMaxDepth + } var val any for _, arg := range args { rv := reflect.ValueOf(arg) @@ -261,7 +264,7 @@ func minMax(name string, fn func(any, any) bool, args ...any) (any, error) { case reflect.Array, reflect.Slice: size := rv.Len() for i := 0; i < size; i++ { - elemVal, err := minMax(name, fn, rv.Index(i).Interface()) + elemVal, err := minMax(name, fn, depth+1, rv.Index(i).Interface()) if err != nil { return nil, err } @@ -294,7 +297,10 @@ func minMax(name string, fn func(any, any) bool, args ...any) (any, error) { return val, nil } -func mean(args ...any) (int, float64, error) { +func mean(depth int, args ...any) (int, float64, error) { + if depth > MaxDepth { + return 0, 0, ErrorMaxDepth + } var total float64 var count int @@ -304,7 +310,7 @@ func mean(args ...any) (int, float64, error) { case reflect.Array, reflect.Slice: size := rv.Len() for i := 0; i < size; i++ { - elemCount, elemSum, err := mean(rv.Index(i).Interface()) + elemCount, elemSum, err := mean(depth+1, rv.Index(i).Interface()) if err != nil { return 0, 0, err } @@ -327,7 +333,10 @@ func mean(args ...any) (int, float64, error) { return count, total, nil } -func median(args ...any) ([]float64, error) { +func median(depth int, args ...any) ([]float64, error) { + if depth > MaxDepth { + return nil, ErrorMaxDepth + } var values []float64 for _, arg := range args { @@ -336,7 +345,7 @@ func median(args ...any) ([]float64, error) { case reflect.Array, reflect.Slice: size := rv.Len() for i := 0; i < size; i++ { - elems, err := median(rv.Index(i).Interface()) + elems, err := median(depth+1, rv.Index(i).Interface()) if err != nil { return nil, err } @@ -355,18 +364,24 @@ func median(args ...any) ([]float64, error) { return values, nil } -func flatten(arg reflect.Value) []any { +func flatten(arg reflect.Value, depth int) ([]any, error) { + if depth > MaxDepth { + return nil, ErrorMaxDepth + } ret := []any{} for i := 0; i < arg.Len(); i++ { v := deref.Value(arg.Index(i)) if v.Kind() == reflect.Array || v.Kind() == reflect.Slice { - x := flatten(v) + x, err := flatten(v, depth+1) + if err != nil { + return nil, err + } ret = append(ret, x...) } else { ret = append(ret, v.Interface()) } } - return ret + return ret, nil } func get(params ...any) (out any, err error) {