From 40402b943e64bc333228d91bae84bb6ae84477ce Mon Sep 17 00:00:00 2001 From: Dan Kortschak Date: Wed, 10 Sep 2025 09:45:24 +0930 Subject: [PATCH] mito,mito/lib: implement immediate rate limit application Previously, the application of rate limits was left entirely to the host code after the CEL program evaluation has completed. This leaves the possibility for a collection of work items to deplete the rate limiter's token bucket in the middle of a CEL program evaluation. This change adds the possibility of adding a hook for application of rate limit values on completion of a call to rate_limit, allowing more fine rate limits to be respected more dynamically. --- lib/limit.go | 18 ++++++--- mito.go | 103 +++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 113 insertions(+), 8 deletions(-) diff --git a/lib/limit.go b/lib/limit.go index ccdd648..3c59211 100644 --- a/lib/limit.go +++ b/lib/limit.go @@ -81,13 +81,21 @@ import ( // // Non-canonical keys. // rate_limit(h, 'X-RateLimit', false, false, duration('1s'), 1) func Limit(policy map[string]LimitPolicy) cel.EnvOption { - return cel.Lib(limitLib{policies: policy}) + return LimitWithApply(policy, nil) +} + +func LimitWithApply(policy map[string]LimitPolicy, apply func(map[string]any, http.Header) map[string]any) cel.EnvOption { + if apply == nil { + apply = func(m map[string]any, _ http.Header) map[string]any { return m } + } + return cel.Lib(limitLib{policies: policy, apply: apply}) } type LimitPolicy func(header http.Header, window time.Duration) map[string]interface{} type limitLib struct { policies map[string]LimitPolicy + apply func(map[string]any, http.Header) map[string]any } func (l limitLib) CompileOptions() []cel.EnvOption { @@ -103,7 +111,7 @@ func (l limitLib) CompileOptions() []cel.EnvOption { "map_dyn_rate_limit_string_bool_bool_duration_int", []*cel.Type{mapStringDyn, cel.StringType, cel.BoolType, cel.BoolType, cel.DurationType, cel.IntType}, mapStringDyn, - cel.FunctionBinding(catch(translatePolicy)), + cel.FunctionBinding(catch(l.translateLimits)), ), ), } @@ -138,7 +146,7 @@ func (l limitLib) translatePolicy(args ...ref.Val) ref.Val { if err != nil { return types.NewErr("%s", err) } - return types.DefaultTypeAdapter.NativeToValue(translate(h, window.Duration)) + return types.DefaultTypeAdapter.NativeToValue(l.apply(translate(h, window.Duration), h)) } func mapStrings(val ref.Val) (map[string][]string, error) { @@ -394,7 +402,7 @@ func (p policy) details(q int) (window, burst int, err error) { return window, burst, nil } -func translatePolicy(args ...ref.Val) ref.Val { +func (l limitLib) translateLimits(args ...ref.Val) ref.Val { if len(args) != 6 { return types.NewErr("no such overload") } @@ -426,7 +434,7 @@ func translatePolicy(args ...ref.Val) ref.Val { if !ok { return types.ValOrErr(burst, "no such overload for burst: %s", args[4].Type()) } - p := limitPolicy(h, string(prefix), bool(canonical), bool(delta), window.Duration, int(burst)) + p := l.apply(limitPolicy(h, string(prefix), bool(canonical), bool(delta), window.Duration, int(burst)), h) return types.DefaultTypeAdapter.NativeToValue(p) } diff --git a/mito.go b/mito.go index fd2095c..a0f87a4 100644 --- a/mito.go +++ b/mito.go @@ -37,6 +37,7 @@ import ( "regexp" runtimedebug "runtime/debug" "strings" + "time" "github.com/goccy/go-yaml" "github.com/google/cel-go/cel" @@ -48,6 +49,7 @@ import ( "golang.org/x/oauth2/clientcredentials" "golang.org/x/oauth2/endpoints" "golang.org/x/oauth2/google" + "golang.org/x/time/rate" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/structpb" @@ -92,6 +94,8 @@ func Main() int { cel.OptionalTypes(cel.OptionalTypesVersion(lib.OptionalTypesVersion)), ext.TwoVarComprehensions(ext.TwoVarComprehensionsVersion(lib.OptionalTypesVersion)), } + ctx := context.Background() + limit := rate.NewLimiter(1, 1) if *cfgPath != "" { f, err := os.Open(*cfgPath) if err != nil { @@ -140,6 +144,7 @@ func Main() int { } var client *http.Client httpOptions := lib.HTTPOptions{ + Limiter: limit, Headers: cfg.HTTPHeaders, MaxBodySize: cfg.MaxBodySize, } @@ -161,7 +166,6 @@ func Main() int { } } if client != nil || !httpOptions.IsZero() { - ctx := context.Background() libMap["http"] = lib.HTTPWithContextOpts(ctx, traceReqs(setClientInsecure(client, *insecure), *logTrace, *maxTraceBody), httpOptions) } if *maxExecutions == -1 && cfg.MaxExecutions != nil { @@ -170,8 +174,12 @@ func Main() int { } if libMap["http"] == nil { - libMap["http"] = lib.HTTP(traceReqs(setClientInsecure(nil, *insecure), *logTrace, *maxTraceBody), nil, nil) + libMap["http"] = lib.HTTPWithContextOpts(ctx, traceReqs(setClientInsecure(nil, *insecure), *logTrace, *maxTraceBody), lib.HTTPOptions{Limiter: limit}) } + libMap["limit"] = lib.LimitWithApply(limitPolicies, func(m map[string]any, h http.Header) map[string]any { + handleRateLimit(m, h, limit) + return m + }) if libMap["xml"] == nil { var err error libMap["xml"], err = lib.XML(nil, nil) @@ -263,6 +271,95 @@ func Main() int { return 0 } +func handleRateLimit(rateLimit map[string]interface{}, header http.Header, limiter *rate.Limiter) (waitUntil time.Time) { + if _, ok := rateLimit["error"]; ok { + // The error field should be a string, but we won't quibble here. + return waitUntil + } + + limit, ok := getLimit("rate", rateLimit) + if !ok { + return waitUntil + } + + var burst int + b := rateLimit["burst"] + switch b := b.(type) { + case int: + burst = b + case int64: + burst = int(b) + case float64: + burst = int(b) + default: + } + if burst < 1 { + // Make sure we can make at least one new request, even if we fail + // to get a non-zero rate.Limit. We could set to zero for the case + // that limit=rate.Inf, but that detail is not important. + burst = 1 + } + + // Process reset if we need to wait until reset to avoid a request against a zero quota. + if limit <= 0 { + w, ok := rateLimit["reset"] + if ok { + switch w := w.(type) { + case time.Time: + waitUntil = w + next, ok := getLimit("next", rateLimit) + if !ok { + return waitUntil + } + limiter.SetLimitAt(waitUntil, next) + limiter.SetBurstAt(waitUntil, burst) + case string: + t, err := time.Parse(time.RFC3339, w) + if err != nil { + return waitUntil + } + waitUntil = t + next, ok := getLimit("next", rateLimit) + if !ok { + return waitUntil + } + limiter.SetLimitAt(waitUntil, next) + limiter.SetBurstAt(waitUntil, burst) + default: + } + } + return waitUntil + } + + limiter.SetLimit(limit) + limiter.SetBurst(burst) + return waitUntil +} + +func getLimit(which string, rateLimit map[string]interface{}) (limit rate.Limit, ok bool) { + r, ok := rateLimit[which] + if !ok { + return limit, false + } + switch r := r.(type) { + case rate.Limit: + limit = r + case int: + limit = rate.Limit(r) + case int64: + limit = rate.Limit(r) + case float64: + limit = rate.Limit(r) + case string: + if !strings.EqualFold(strings.TrimPrefix(r, "+"), "inf") && !strings.EqualFold(strings.TrimPrefix(r, "+"), "infinity") { + return limit, false + } + limit = rate.Inf + default: + } + return limit, true +} + func authsCount(auth *rc.AuthConfig) int { var n int if auth.Basic != nil { @@ -359,7 +456,7 @@ var ( "file": lib.File(mimetypes), "mime": lib.MIME(mimetypes), "http": nil, // This will be populated by Main. - "limit": lib.Limit(limitPolicies), + "limit": nil, // This will be populated by Main. "strings": lib.Strings(), "printf": lib.Printf(), "xml": nil, // This will be populated by Main.