-
Notifications
You must be signed in to change notification settings - Fork 4
/
limit.go
132 lines (98 loc) · 3.06 KB
/
limit.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package ratelimit
import (
"context"
cache "github.com/Code-Hex/go-generics-cache"
"github.com/ansel1/merry/v2"
"github.com/bcc-code/bcc-media-platform/backend/remotecache"
"github.com/bcc-code/bcc-media-platform/backend/user"
"github.com/bcc-code/bcc-media-platform/backend/utils"
"github.com/gin-gonic/gin"
"time"
)
// RateLimit is the counter value stored in limitCache for each rate-limited key.
type RateLimit struct {
	// Increment is the number of requests recorded for the key. Callers read
	// it, compare it with their limit, and store back the incremented value.
	Increment int
}

// limitCache holds the in-process RateLimit counters used by Middleware and
// Endpoint, keyed by client identity (prefixed with the endpoint name for
// Endpoint).
var limitCache = cache.New[string, RateLimit]()
// Middleware applies a global request limit to anonymous callers.
//
// Authenticated users pass straight through. Each anonymous client, identified
// by getUniqueKeyForCtx (the X-Forwarded-For header, or gin's ClientIP when the
// header is missing or empty), may make up to rateLimit requests. Further
// requests are aborted with HTTP 429 and {"error": "Too many requests"}.
// Counters live in the process-local limitCache and expire five minutes after
// the last request that was let through.
func Middleware() gin.HandlerFunc {
	const rateLimit = 10000

	return func(ctx *gin.Context) {
		u := user.GetFromCtx(ctx)
		if !u.Anonymous {
			return
		}

		// Use the same key derivation as Endpoint/Remote. Keying on the raw
		// X-Forwarded-For header alone made every client without that header
		// share the single "" counter, so one caller could lock out all of them.
		key := getUniqueKeyForCtx(ctx)

		limit, _ := limitCache.Get(key) // a missing entry yields the zero RateLimit
		if limit.Increment >= rateLimit {
			ctx.AbortWithStatusJSON(429, map[string]string{
				"error": "Too many requests",
			})
			return
		}

		limit.Increment++
		// Set re-arms the expiration on every allowed request, so the count
		// resets only after five idle minutes, not every five minutes.
		limitCache.Set(key, limit, cache.WithExpiration(5*time.Minute))
	}
}
// getUniqueKeyForCtx returns the identity used to bucket rate-limit counters:
//   - authenticated users: their profile ID;
//   - anonymous users: the raw X-Forwarded-For header, or gin's ClientIP when
//     that header is empty.
//
// NOTE(review): X-Forwarded-For is client-controlled and used verbatim here, so
// a caller that varies it per request gets a fresh bucket each time. Whether
// that is exploitable depends on whether the fronting proxy overwrites or
// appends to the header; confirm before relying on this for abuse protection.
// NOTE(review): assumes authenticated users always have a profile in the
// context; confirm GetProfileFromCtx cannot return nil here.
func getUniqueKeyForCtx(ginCtx *gin.Context) string {
	if u := user.GetFromCtx(ginCtx); !u.Anonymous {
		profile := user.GetProfileFromCtx(ginCtx)
		return profile.ID.String()
	}

	if forwarded := ginCtx.Request.Header.Get("X-Forwarded-For"); forwarded != "" {
		return forwarded
	}
	return ginCtx.ClientIP()
}
// Endpoint enforces a per-client request limit on a single endpoint using the
// process-local limitCache.
//
// When anonymousOnly is true, authenticated users are never limited. Otherwise
// the counter is keyed by endpoint plus getUniqueKeyForCtx. Once it has reached
// rateLimit, a merry error carrying HTTP code 429 is returned and the counter
// is left untouched. Each allowed call increments the counter and re-arms its
// five-minute expiration.
//
// NOTE(review): the Get/Set pair is not atomic, so concurrent requests for the
// same key can occasionally lose an increment.
func Endpoint(ctx context.Context, endpoint string, rateLimit int, anonymousOnly bool) error {
	ginCtx, _ := utils.GinCtx(ctx)
	if u := user.GetFromCtx(ginCtx); anonymousOnly && !u.Anonymous {
		return nil
	}

	cacheKey := endpoint + ":" + getUniqueKeyForCtx(ginCtx)

	current, _ := limitCache.Get(cacheKey) // missing entry -> zero counter
	if current.Increment < rateLimit {
		current.Increment++
		limitCache.Set(cacheKey, current, cache.WithExpiration(5*time.Minute))
		return nil
	}
	return merry.New("Rate limit exceeded", merry.WithUserMessage("Too many requests"), merry.WithHTTPCode(429))
}
// Remote enforces a per-client request limit on endpoint using the shared
// remote cache, so the limit applies across every instance using that cache.
//
// When anonymousOnly is true, authenticated users are never limited. The
// counter lives under "ratelimit:<endpoint>:<client key>", where the client key
// comes from getUniqueKeyForCtx.
//
// Windowing:
//   - A fixed one-minute window starts with the first request counted for the key.
//   - Once more than rateLimit requests have been counted in that window, calls
//     return a merry error carrying HTTP code 429 until the key expires.
//   - Rejected calls are still counted but do not extend the window.
//
// Errors from the remote cache are returned unchanged.
func Remote(ctx context.Context, remoteClient *remotecache.Client, endpoint string, rateLimit int, anonymousOnly bool) error {
	ginCtx, _ := utils.GinCtx(ctx)
	u := user.GetFromCtx(ginCtx)
	if anonymousOnly && !u.Anonymous {
		return nil
	}

	const window = time.Minute

	key := getUniqueKeyForCtx(ginCtx)
	cacheKey := "ratelimit:" + endpoint + ":" + key
	client := remoteClient.Client()

	// INCR is atomic on the server. A separate GET + SET let concurrent requests
	// (possibly on different instances) read the same value and overwrite each
	// other's increments.
	count, err := client.Incr(ctx, cacheKey).Result()
	if err != nil {
		return err
	}

	if count == 1 {
		// Arm the TTL only when the key is created. Re-setting it on every
		// request turns the window into a sliding one that never resets while
		// the client stays active.
		if expireErr := client.Expire(ctx, cacheKey, window).Err(); expireErr != nil {
			// Without a TTL the counter would never reset; drop it so the next
			// request starts a fresh window. Best effort: report the original error.
			_ = client.Del(ctx, cacheKey).Err()
			return expireErr
		}
	}

	if count > int64(rateLimit) {
		return merry.New("Rate limit exceeded", merry.WithUserMessage("Too many requests"), merry.WithHTTPCode(429))
	}
	return nil
}
// Clear deletes the remote rate-limit counter kept for the current client on
// endpoint (key "ratelimit:<endpoint>:<client key>", the same key Remote uses).
// It applies to any caller, anonymous or not. A remotecache.Nil result is
// treated as success; any other cache error is returned.
func Clear(ctx context.Context, remoteClient *remotecache.Client, endpoint string) error {
	ginCtx, _ := utils.GinCtx(ctx)
	cacheKey := "ratelimit:" + endpoint + ":" + getUniqueKeyForCtx(ginCtx)

	if _, err := remoteClient.Client().Del(ctx, cacheKey).Result(); err != nil && err != remotecache.Nil {
		return err
	}
	return nil
}