forked from mattermost/mattermost
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ratelimit.go
131 lines (108 loc) · 3.56 KB
/
ratelimit.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// Copyright (c) 2018-present Mattermost, Inc. All Rights Reserved.
// See License.txt for license information.
package app
import (
"fmt"
"math"
"net/http"
"strconv"
"strings"
"github.com/mattermost/mattermost-server/mlog"
"github.com/mattermost/mattermost-server/model"
"github.com/mattermost/mattermost-server/utils"
"github.com/pkg/errors"
throttled "gopkg.in/throttled/throttled.v2"
"gopkg.in/throttled/throttled.v2/store/memstore"
)
// RateLimiter decides whether a request should be throttled. It wraps a
// throttled GCRA rate limiter and remembers which request attributes are
// combined into the rate-limiting key (see GenerateKey).
type RateLimiter struct {
	// throttledRateLimiter is the underlying limiter consulted for every key.
	throttledRateLimiter *throttled.GCRARateLimiter

	// useAuth keys requests by their authentication token when one is present;
	// useIP keys requests by the client IP address (also used as the fallback
	// when useAuth is set but no token is found).
	useAuth, useIP bool

	// header, when non-empty, names a request header whose lower-cased value is
	// appended to the key.
	header string
}
// NewRateLimiter builds a RateLimiter from the given rate limit settings.
//
// A throttled memstore sized by settings.MemoryStoreSize backs a GCRA limiter
// whose quota is settings.PerSec requests per second with a burst of
// settings.MaxBurst. The VaryByUser, VaryByRemoteAddr and VaryByHeader settings
// are copied onto the result so GenerateKey knows how to build keys.
//
// The pointer-valued settings are dereferenced without nil checks, so they must
// already be populated. A wrapped error is returned when either the store or
// the limiter cannot be created.
func NewRateLimiter(settings *model.RateLimitSettings) (*RateLimiter, error) {
	keyStore, err := memstore.New(*settings.MemoryStoreSize)
	if err != nil {
		return nil, errors.Wrap(err, utils.T("api.server.start_server.rate_limiting_memory_store"))
	}

	limiter, err := throttled.NewGCRARateLimiter(keyStore, throttled.RateQuota{
		MaxRate:  throttled.PerSec(*settings.PerSec),
		MaxBurst: *settings.MaxBurst,
	})
	if err != nil {
		return nil, errors.Wrap(err, utils.T("api.server.start_server.rate_limiting_rate_limiter"))
	}

	rl := &RateLimiter{
		throttledRateLimiter: limiter,
		useAuth:              *settings.VaryByUser,
		useIP:                *settings.VaryByRemoteAddr,
		header:               settings.VaryByHeader,
	}
	return rl, nil
}
// GenerateKey derives the rate-limiting key for request r.
//
// The key starts with exactly one of:
//   - the authentication token, when auth-based limiting is enabled and the
//     request carries a token;
//   - the client IP address, when IP-based limiting is enabled and either
//     auth-based limiting is disabled or no token was found;
//   - nothing, otherwise.
//
// If a vary-by header is configured, its lower-cased value is appended to that
// prefix. The result may be the empty string.
func (rl *RateLimiter) GenerateKey(r *http.Request) string {
	var key string

	// The token is only parsed when auth-based limiting is enabled.
	tokenFound := false
	if rl.useAuth {
		token, location := ParseAuthTokenFromRequest(r)
		if location != TokenLocationNotFound {
			key = token
			tokenFound = true
		}
	}

	// Plain IP keying, or the fallback when no authentication token was found.
	if !tokenFound && rl.useIP {
		key = utils.GetIpAddress(r)
	}

	// Most of the time this does not need to be configured: per the original
	// author's note, utils.GetIpAddress already tries the most common headers.
	if rl.header == "" {
		return key
	}
	return key + strings.ToLower(r.Header.Get(rl.header))
}
// RateLimitWriter charges one request against key and reports whether the
// request was rate limited.
//
// On every successful check the rate limit headers are written to w. When the
// request is limited, a 429 "limit exceeded" response is also written to w and
// true is returned; the caller must not serve the request further.
//
// If the underlying limiter returns an error, the failure is logged and false
// is returned, i.e. the request is allowed through and no headers are set.
func (rl *RateLimiter) RateLimitWriter(key string, w http.ResponseWriter) bool {
	limited, result, err := rl.throttledRateLimiter.RateLimit(key, 1)
	if err != nil {
		mlog.Critical("Internal server error when rate limiting. Rate Limiting broken. Error:" + err.Error())
		return false
	}

	setRateLimitHeaders(w, result)
	if !limited {
		return false
	}

	// NOTE(review): key can be an authentication token (see GenerateKey), so
	// this line may write a credential to the log — confirm this is acceptable.
	mlog.Error(fmt.Sprintf("Denied due to throttling settings code=429 key=%v", key))
	http.Error(w, "limit exceeded", http.StatusTooManyRequests)
	return true
}
// UserIdRateLimit applies rate limiting keyed by userId, but only when
// auth-based limiting is enabled. It returns true when the request was limited
// (in which case RateLimitWriter has already written the 429 response to w) and
// false otherwise, including when auth-based limiting is disabled.
func (rl *RateLimiter) UserIdRateLimit(userId string, w http.ResponseWriter) bool {
	// Short-circuit: the limiter is not consulted at all unless useAuth is set.
	return rl.useAuth && rl.RateLimitWriter(userId, w)
}
// RateLimitHandler returns an http.Handler that rate limits each request before
// delegating to wrappedHandler. The key comes from GenerateKey; when the
// request is limited, RateLimitWriter has already produced the response and
// wrappedHandler is not invoked.
func (rl *RateLimiter) RateLimitHandler(wrappedHandler http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		if rl.RateLimitWriter(rl.GenerateKey(r), w) {
			return
		}
		wrappedHandler.ServeHTTP(w, r)
	}
	return http.HandlerFunc(serve)
}
// setRateLimitHeaders writes the rate limit state in context to the response
// headers of w. Copied from https://github.com/throttled/throttled http.go
//
// Headers written (each only when its source value is non-negative):
//   - X-RateLimit-Limit:     context.Limit
//   - X-RateLimit-Remaining: context.Remaining
//   - X-RateLimit-Reset:     context.ResetAfter, rounded up to whole seconds
//   - Retry-After:           context.RetryAfter, rounded up to whole seconds
//
// Values are appended with Header.Add, so calling this more than once for the
// same response results in repeated header values.
func setRateLimitHeaders(w http.ResponseWriter, context throttled.RateLimitResult) {
	addInt := func(name string, value int) {
		w.Header().Add(name, strconv.Itoa(value))
	}

	if context.Limit >= 0 {
		addInt("X-RateLimit-Limit", context.Limit)
	}

	if context.Remaining >= 0 {
		addInt("X-RateLimit-Remaining", context.Remaining)
	}

	if reset := context.ResetAfter; reset >= 0 {
		addInt("X-RateLimit-Reset", int(math.Ceil(reset.Seconds())))
	}

	if retry := context.RetryAfter; retry >= 0 {
		addInt("Retry-After", int(math.Ceil(retry.Seconds())))
	}
}