Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions lib/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,21 @@ import (
// // Non-canonical keys.
// rate_limit(h, 'X-RateLimit', false, false, duration('1s'), 1)
func Limit(policy map[string]LimitPolicy) cel.EnvOption {
return cel.Lib(limitLib{policies: policy})
return LimitWithApply(policy, nil)
}

func LimitWithApply(policy map[string]LimitPolicy, apply func(map[string]any, http.Header) map[string]any) cel.EnvOption {
if apply == nil {
apply = func(m map[string]any, _ http.Header) map[string]any { return m }
}
return cel.Lib(limitLib{policies: policy, apply: apply})
}

type LimitPolicy func(header http.Header, window time.Duration) map[string]interface{}

type limitLib struct {
policies map[string]LimitPolicy
apply func(map[string]any, http.Header) map[string]any
}

func (l limitLib) CompileOptions() []cel.EnvOption {
Expand All @@ -103,7 +111,7 @@ func (l limitLib) CompileOptions() []cel.EnvOption {
"map_dyn_rate_limit_string_bool_bool_duration_int",
[]*cel.Type{mapStringDyn, cel.StringType, cel.BoolType, cel.BoolType, cel.DurationType, cel.IntType},
mapStringDyn,
cel.FunctionBinding(catch(translatePolicy)),
cel.FunctionBinding(catch(l.translateLimits)),
),
),
}
Expand Down Expand Up @@ -138,7 +146,7 @@ func (l limitLib) translatePolicy(args ...ref.Val) ref.Val {
if err != nil {
return types.NewErr("%s", err)
}
return types.DefaultTypeAdapter.NativeToValue(translate(h, window.Duration))
return types.DefaultTypeAdapter.NativeToValue(l.apply(translate(h, window.Duration), h))
}

func mapStrings(val ref.Val) (map[string][]string, error) {
Expand Down Expand Up @@ -394,7 +402,7 @@ func (p policy) details(q int) (window, burst int, err error) {
return window, burst, nil
}

func translatePolicy(args ...ref.Val) ref.Val {
func (l limitLib) translateLimits(args ...ref.Val) ref.Val {
if len(args) != 6 {
return types.NewErr("no such overload")
}
Expand Down Expand Up @@ -426,7 +434,7 @@ func translatePolicy(args ...ref.Val) ref.Val {
if !ok {
return types.ValOrErr(burst, "no such overload for burst: %s", args[4].Type())
}
p := limitPolicy(h, string(prefix), bool(canonical), bool(delta), window.Duration, int(burst))
p := l.apply(limitPolicy(h, string(prefix), bool(canonical), bool(delta), window.Duration, int(burst)), h)
return types.DefaultTypeAdapter.NativeToValue(p)
}

Expand Down
103 changes: 100 additions & 3 deletions mito.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"regexp"
runtimedebug "runtime/debug"
"strings"
"time"

"github.com/goccy/go-yaml"
"github.com/google/cel-go/cel"
Expand All @@ -48,6 +49,7 @@ import (
"golang.org/x/oauth2/clientcredentials"
"golang.org/x/oauth2/endpoints"
"golang.org/x/oauth2/google"
"golang.org/x/time/rate"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
Expand Down Expand Up @@ -92,6 +94,8 @@ func Main() int {
cel.OptionalTypes(cel.OptionalTypesVersion(lib.OptionalTypesVersion)),
ext.TwoVarComprehensions(ext.TwoVarComprehensionsVersion(lib.OptionalTypesVersion)),
}
ctx := context.Background()
limit := rate.NewLimiter(1, 1)
if *cfgPath != "" {
f, err := os.Open(*cfgPath)
if err != nil {
Expand Down Expand Up @@ -140,6 +144,7 @@ func Main() int {
}
var client *http.Client
httpOptions := lib.HTTPOptions{
Limiter: limit,
Headers: cfg.HTTPHeaders,
MaxBodySize: cfg.MaxBodySize,
}
Expand All @@ -161,7 +166,6 @@ func Main() int {
}
}
if client != nil || !httpOptions.IsZero() {
ctx := context.Background()
libMap["http"] = lib.HTTPWithContextOpts(ctx, traceReqs(setClientInsecure(client, *insecure), *logTrace, *maxTraceBody), httpOptions)
}
if *maxExecutions == -1 && cfg.MaxExecutions != nil {
Expand All @@ -170,8 +174,12 @@ func Main() int {

}
if libMap["http"] == nil {
libMap["http"] = lib.HTTP(traceReqs(setClientInsecure(nil, *insecure), *logTrace, *maxTraceBody), nil, nil)
libMap["http"] = lib.HTTPWithContextOpts(ctx, traceReqs(setClientInsecure(nil, *insecure), *logTrace, *maxTraceBody), lib.HTTPOptions{Limiter: limit})
}
libMap["limit"] = lib.LimitWithApply(limitPolicies, func(m map[string]any, h http.Header) map[string]any {
handleRateLimit(m, h, limit)
return m
})
if libMap["xml"] == nil {
var err error
libMap["xml"], err = lib.XML(nil, nil)
Expand Down Expand Up @@ -263,6 +271,95 @@ func Main() int {
return 0
}

func handleRateLimit(rateLimit map[string]interface{}, header http.Header, limiter *rate.Limiter) (waitUntil time.Time) {
if _, ok := rateLimit["error"]; ok {
// The error field should be a string, but we won't quibble here.
return waitUntil
}

limit, ok := getLimit("rate", rateLimit)
if !ok {
return waitUntil
}

var burst int
b := rateLimit["burst"]
switch b := b.(type) {
case int:
burst = b
case int64:
burst = int(b)
case float64:
burst = int(b)
default:
}
if burst < 1 {
// Make sure we can make at least one new request, even if we fail
// to get a non-zero rate.Limit. We could set to zero for the case
// that limit=rate.Inf, but that detail is not important.
burst = 1
}

// Process reset if we need to wait until reset to avoid a request against a zero quota.
if limit <= 0 {
w, ok := rateLimit["reset"]
if ok {
switch w := w.(type) {
case time.Time:
waitUntil = w
next, ok := getLimit("next", rateLimit)
if !ok {
return waitUntil
}
limiter.SetLimitAt(waitUntil, next)
limiter.SetBurstAt(waitUntil, burst)
case string:
t, err := time.Parse(time.RFC3339, w)
if err != nil {
return waitUntil
}
waitUntil = t
next, ok := getLimit("next", rateLimit)
if !ok {
return waitUntil
}
limiter.SetLimitAt(waitUntil, next)
limiter.SetBurstAt(waitUntil, burst)
default:
}
}
return waitUntil
}

limiter.SetLimit(limit)
limiter.SetBurst(burst)
return waitUntil
}

func getLimit(which string, rateLimit map[string]interface{}) (limit rate.Limit, ok bool) {
r, ok := rateLimit[which]
if !ok {
return limit, false
}
switch r := r.(type) {
case rate.Limit:
limit = r
case int:
limit = rate.Limit(r)
case int64:
limit = rate.Limit(r)
case float64:
limit = rate.Limit(r)
case string:
if !strings.EqualFold(strings.TrimPrefix(r, "+"), "inf") && !strings.EqualFold(strings.TrimPrefix(r, "+"), "infinity") {
return limit, false
}
limit = rate.Inf
default:
}
return limit, true
}

func authsCount(auth *rc.AuthConfig) int {
var n int
if auth.Basic != nil {
Expand Down Expand Up @@ -359,7 +456,7 @@ var (
"file": lib.File(mimetypes),
"mime": lib.MIME(mimetypes),
"http": nil, // This will be populated by Main.
"limit": lib.Limit(limitPolicies),
"limit": nil, // This will be populated by Main.
"strings": lib.Strings(),
"printf": lib.Printf(),
"xml": nil, // This will be populated by Main.
Expand Down
Loading