/
throttling.go
214 lines (186 loc) · 6.67 KB
/
throttling.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
package handler
import (
"encoding/json"
"fmt"
"net/http"
"sync/atomic"
"time"
"unsafe"
log "github.com/cihub/seelog"
"github.com/hailo-platform/H2O/api-proxy/errors"
"github.com/hailo-platform/H2O/api-proxy/session"
)
const (
	// How often the synchroniser checks-in with the API throttling service
	// @TODO: Should this come from config?
	synchronisationInterval = 5 * time.Second
	// Initial capacity hint for a freshly-made bucketBufferT (used both at
	// construction time and on every synchroniser buffer swap).
	defaultBufferSize = 5000
)

type (
	// throttledBucketsT is the set of bucket keys currently being throttled.
	// Membership (not the bool value) is what anyThrottled checks.
	throttledBucketsT map[string]bool
	// bucketBufferT accumulates inbound per-bucket request counts. Values are
	// *uint64 so existing entries can be bumped with atomic.AddUint64 without
	// rewriting the map itself.
	bucketBufferT map[string]*uint64
)
// A ThrottlingHandler is a decorator around another HTTP handler that implements our API throttling behaviour. It
// buckets inbound requests, records statistics about request volume to each bucket, and throttles full buckets. If a
// request is throttled, it will proceed no further up the handler chain.
type ThrottlingHandler struct {
	// Handler is the wrapped handler invoked for requests that are not throttled.
	Handler http.Handler
	// This is seriously nasty, but we need atomic pointer operations here, so we need unsafe pointers.
	// /me dies inside
	throttledBuckets unsafe.Pointer // *throttledBucketsT: buckets to throttle; read by anyThrottled, replaced by synchroniser
	bucketBuffer     unsafe.Pointer // *bucketBufferT: inbound per-bucket request count buffer; swapped for a fresh one each sync
	ingesterChan     chan string    // inbound requests to be added to the buffer (non-blocking send from ServeHTTP)
	srv              *HailoServer   // owning server; its Tomb governs the worker goroutines' lifetimes
}
// NewThrottlingHandler constructs a ThrottlingHandler wrapping h, seeds its
// buffer and throttled-bucket maps, and launches the background ingester and
// synchroniser goroutines on the server's tomb so they die with the server.
func NewThrottlingHandler(h http.Handler, srv *HailoServer) *ThrottlingHandler {
	initialThrottled := make(throttledBucketsT)
	initialBuffer := make(bucketBufferT, defaultBufferSize)

	handler := &ThrottlingHandler{
		Handler:          h,
		throttledBuckets: unsafe.Pointer(&initialThrottled),
		bucketBuffer:     unsafe.Pointer(&initialBuffer),
		ingesterChan:     make(chan string, 500000),
		srv:              srv,
	}

	srv.Tomb.Go(handler.ingesterWorker)
	srv.Tomb.Go(handler.synchroniser)
	return handler
}
// ingesterWorker takes bucket names from the ingesterChan and increments the appropriate record in the bucketBuffer
func (t *ThrottlingHandler) ingesterWorker() error {
	// Note that if ingesterWorker() is mid-flight incrementing a value in the buffer when synchroniser() swaps the
	// buffer for a new one, the single increment will be lost. This could be mitigated by comparing the pointer again
	// at the end of the work loop, but it's really not a big deal for one request to fall through the cracks.
	for {
		select {
		case <-t.srv.Tomb.Dying():
			// Die in response to tomb death
			log.Tracef("[Throttler:ingesterWorker] Dying in response to server tomb death")
			return nil
		case bucketKey, ok := <-t.ingesterChan:
			if !ok {
				// Die in response to channel closure
				log.Tracef("[Throttler:ingesterWorker] Dying in response to channel closure")
				return nil
			}
			// Snapshot the current buffer pointer; all reads/writes below are
			// against this snapshot, which may be swapped out underneath us.
			loadedP := atomic.LoadPointer(&t.bucketBuffer)
			buf := *(*bucketBufferT)(loadedP)
			if valPtr, ok := buf[bucketKey]; !ok {
				// As adding a value to a map is non-atomic, we have no choice but to copy the map and add the new key,
				// retrying if the map was switched under us while this was in-process
				// @TODO: Replace with OB's concurrent hash table impl
				for {
					// Copy-on-write: build a fresh map containing the old
					// entries plus the new key with an initial count of 1.
					newBuf := make(bucketBufferT, len(buf)+1)
					for k, v := range buf {
						newBuf[k] = v
					}
					_one := uint64(1)
					newBuf[bucketKey] = &_one
					if atomic.CompareAndSwapPointer(&t.bucketBuffer, loadedP, (unsafe.Pointer)(&newBuf)) {
						// CAS was uncontended; we have successfully added the value to the map
						break
					}
					// CAS was contended; go again with a reloaded pointer
					loadedP = atomic.LoadPointer(&t.bucketBuffer)
					buf = *(*bucketBufferT)(loadedP)
				}
			} else {
				// Key already present: bump its counter in place; safe because
				// the value is a *uint64 incremented atomically.
				atomic.AddUint64(valPtr, uint64(1))
			}
		}
	}
}
// synchroniser periodically sends the bucketBuffer to the API throttling service, and updates throttledBuckets
// accordingly. When the bucketBuffer is sent, it is replaced (atomically) with a new buffer.
func (t *ThrottlingHandler) synchroniser() error {
	tick := time.NewTicker(synchronisationInterval)
	defer tick.Stop()
	for {
		select {
		case <-t.srv.Tomb.Dying():
			// Die in response to tomb death
			log.Tracef("[Throttler:synchroniser] Dying in response to server tomb death")
			return nil
		case _, ok := <-tick.C:
			if !ok {
				// Die in response to channel closure
				log.Tracef("[Throttler:synchroniser] Dying in response to channel closure")
				return nil
			}
			// Swap the buffer for a shiny new one; we now exclusively own the
			// old buffer (modulo the benign race documented in ingesterWorker).
			newBuf := make(bucketBufferT, defaultBufferSize)
			bufP := (*bucketBufferT)(atomic.SwapPointer(&t.bucketBuffer, (unsafe.Pointer)(&newBuf)))
			// DO NOT bail here; we still need to retrieve buckets to be throttled even if no increments to report
			log.Debugf("[Throttler:synchroniser] Reporting %d increments", len(*bufP))
			start := time.Now()
			// reportIncrements is defined elsewhere in this file; it returns the
			// set of buckets the throttling service wants us to throttle.
			tbP, err := reportIncrements(bufP)
			var tb throttledBucketsT
			if err != nil {
				log.Errorf("[Throttler:synchroniser] Failed to report increments in %s: %s", time.Since(start).String(),
					err.Error())
				// Reset the throttled bucket list (don't throttle anything in the failure case)
				tb = make(throttledBucketsT, 0)
			} else {
				log.Debugf("[Throttler:synchroniser] Successfully reported increments in %s",
					time.Since(start).String())
				tb = *tbP
				log.Debugf("[Throttler:synchroniser] Got %d buckets to throttle", len(tb))
			}
			// Publish the new throttle set for anyThrottled's readers.
			atomic.StorePointer(&(t.throttledBuckets), (unsafe.Pointer)(&tb))
		}
	}
}
// buckets returns the throttling bucket keys that this request falls into.
// Currently requests are bucketed by session ID only; a request without a
// session yields an empty (but non-nil) slice.
func (t *ThrottlingHandler) buckets(r *http.Request) []string {
	result := make([]string, 0, 1)

	// Session ID bucket
	sessId := session.SessionId(r)
	if sessId != "" {
		result = append(result, fmt.Sprintf("sessId:%s", sessId))
	}

	return result
}
// anyThrottled reports whether any of the given bucket keys is present in the
// current throttle set (published atomically by synchroniser). A nil set —
// possible before the first synchronisation — throttles nothing.
func (t *ThrottlingHandler) anyThrottled(bucks []string) bool {
	setP := (*throttledBucketsT)(atomic.LoadPointer(&t.throttledBuckets))
	if setP == nil {
		return false
	}
	set := *setP
	for _, bucket := range bucks {
		if _, found := set[bucket]; found {
			return true
		}
	}
	return false
}
// ServeHTTP implements http.Handler. It records an increment for each bucket
// the request falls into, then either rejects the request with a 429 JSON
// error (if any of its buckets is currently throttled) or passes it up the
// handler chain.
func (t *ThrottlingHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	// Record increments to the relevant buckets. The send is non-blocking: if
	// the ingestion channel is full we drop the increment (and log) rather
	// than stall the request path. (The previous `break` after the log call
	// was a no-op — inside a select it exits only the select, not the loop —
	// so it has been removed.)
	bucks := t.buckets(r)
	for _, b := range bucks {
		select {
		case t.ingesterChan <- b:
		default:
			log.Warn("[Throttler:ServeHTTP] Could not add bucket to ingestion buffer")
		}
	}

	if t.anyThrottled(bucks) {
		// The request should be throttled: emit a 429 with a JSON error body
		// and do not invoke the wrapped handler.
		rw.Header().Set("Content-Type", "application/json; charset=utf-8")
		rw.WriteHeader(http.StatusTooManyRequests)
		b, marshalErr := json.Marshal(errors.ErrorBody{
			Status:     false,
			Payload:    "Client error: rate limit exceeded",
			Number:     429,
			DottedCode: "com.hailocab.api.throttled",
			Context:    nil,
		})
		if marshalErr != nil {
			// Body marshalling failed; the 429 status has already been sent.
			log.Warn("[Throttler:ServeHTTP] Error marshalling error")
			return
		}
		rw.Write(b)
		return
	}

	t.Handler.ServeHTTP(rw, r)
}