forked from ethereum/go-ethereum
-
Notifications
You must be signed in to change notification settings - Fork 0
/
limiter.go
398 lines (368 loc) · 11.4 KB
/
limiter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
// Copyright 2021 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package utils
import (
"sync"
"github.com/drinkcoffee/l2geth/p2p/enode"
"golang.org/x/exp/slices"
)
// maxSelectionWeight is the largest selection weight that may be assigned to an
// individual node queue or address group.
const maxSelectionWeight = 1_000_000_000
// Limiter protects a network request serving mechanism from denial-of-service attacks.
// It limits the total amount of resources used for serving requests while ensuring that
// the most valuable connections always have a reasonable chance of being served.
//
// The mutable state below is accessed with lock held; cond is bound to lock and is
// used to wake up the processing loop.
type Limiter struct {
	lock sync.Mutex
	cond *sync.Cond
	quit bool // set by Stop; queued and future requests are rejected

	nodes     map[enode.ID]*nodeQueue  // node queues keyed by node ID
	addresses map[string]*addressGroup // address groups keyed by network address

	addressSelect *WeightedRandomSelect // picks address groups
	valueSelect   *WeightedRandomSelect // picks node queues by value weight

	maxValue     float64 // largest node value seen so far (used for normalization)
	maxCost      uint    // largest request cost seen so far (used for normalization)
	sumCost      uint    // summed cost of all queued requests
	sumCostLimit uint    // requests are dropped when sumCost exceeds this limit

	selectAddressNext bool // alternates between address- and value-based selection
}
// nodeQueue represents queued requests coming from a single node ID.
type nodeQueue struct {
	queue   []request // always nil if penaltyCost != 0
	id      enode.ID
	address string  // key of the address group this queue belongs to
	value   float64 // node value supplied by Add

	// current selection weights in the address/value selectors
	flatWeight  uint64
	valueWeight uint64

	sumCost     uint // summed cost of requests queued by the node
	penaltyCost uint // cumulative cost of dropped requests since last processed request
	groupIndex  int  // position in addressGroup.nodes, -1 when not in a group
}
// addressGroup is a group of node IDs that have sent their last requests from the same
// network address.
type addressGroup struct {
	nodes      []*nodeQueue
	nodeSelect *WeightedRandomSelect // picks a member node queue

	sumFlatWeight uint64 // sum of the members' flatWeight
	groupWeight   uint64 // average member flatWeight (sumFlatWeight / len(nodes))
}
// request represents an incoming request scheduled for processing.
//
// NOTE: requests are built with unkeyed composite literals (request{process, cost}),
// so the field order is significant.
type request struct {
	// process receives a completion channel from the processing loop when the request
	// is served; it is closed without a value if the request is rejected or dropped.
	process chan chan struct{}
	// cost is the estimated serving cost of the request (Add raises 0 to 1).
	cost uint
}
// flatWeight is the weight function used by the per-address node selectors. It
// returns the node queue's flat selection weight, which distributes weights equally
// between each active network address.
func flatWeight(item interface{}) uint64 {
	nq := item.(*nodeQueue)
	return nq.flatWeight
}
// add appends the node queue to the address group, records its slot in nq.groupIndex,
// recomputes the group's average weight and registers the node in nodeSelect. It
// panics if nq is already a member of an address group. The caller is responsible for
// adding the group to the address map and the address selector if it was not there
// before.
func (ag *addressGroup) add(nq *nodeQueue) {
	if nq.groupIndex != -1 {
		panic("added node queue is already in an address group")
	}
	nq.groupIndex = len(ag.nodes)
	ag.nodes = append(ag.nodes, nq)
	ag.sumFlatWeight += nq.flatWeight
	ag.groupWeight = ag.sumFlatWeight / uint64(len(ag.nodes))
	ag.nodeSelect.Update(nq)
}
// update sets the flat selection weight of a member node queue, refreshes the group's
// average weight and the node's entry in nodeSelect. It panics if nq is not a member
// of this group. The caller must update the group's weight in the address selector.
func (ag *addressGroup) update(nq *nodeQueue, weight uint64) {
	idx := nq.groupIndex
	isMember := idx != -1 && idx < len(ag.nodes) && ag.nodes[idx] == nq
	if !isMember {
		panic("updated node queue is not in this address group")
	}
	previous := nq.flatWeight
	nq.flatWeight = weight
	// sumFlatWeight contains the previous weight; swap it for the new one.
	ag.sumFlatWeight = ag.sumFlatWeight - previous + weight
	ag.groupWeight = ag.sumFlatWeight / uint64(len(ag.nodes))
	ag.nodeSelect.Update(nq)
}
// remove removes the node queue from the address group. It is the caller's responsibility
// to remove the address group from the address map if it is empty.
//
// Removal is O(1): the last member is moved into the vacated slot and its groupIndex is
// fixed up. The now-unused tail slot of the backing array is cleared so the removed
// node queue (and the requests it may still reference) can be garbage collected. It
// panics if nq is not a member of this group.
func (ag *addressGroup) remove(nq *nodeQueue) {
	if nq.groupIndex == -1 || nq.groupIndex >= len(ag.nodes) || ag.nodes[nq.groupIndex] != nq {
		panic("removed node queue is not in this address group")
	}
	last := len(ag.nodes) - 1
	if nq.groupIndex != last {
		moved := ag.nodes[last]
		ag.nodes[nq.groupIndex] = moved
		moved.groupIndex = nq.groupIndex
	}
	nq.groupIndex = -1
	// Drop the stale pointer before truncating; otherwise the backing array keeps it alive.
	ag.nodes[last] = nil
	ag.nodes = ag.nodes[:last]
	ag.sumFlatWeight -= nq.flatWeight
	if last >= 1 {
		ag.groupWeight = ag.sumFlatWeight / uint64(last)
	} else {
		ag.groupWeight = 0
	}
	ag.nodeSelect.Remove(nq)
}
// choose picks one of the group's member node queues using nodeSelect. A nil result
// from the selector makes the type assertion panic.
func (ag *addressGroup) choose() *nodeQueue {
	picked := ag.nodeSelect.Choose()
	return picked.(*nodeQueue)
}
// NewLimiter creates a new Limiter that limits the summed cost of queued requests to
// sumCostLimit and starts its request processing goroutine. Stop makes that goroutine
// exit.
func NewLimiter(sumCostLimit uint) *Limiter {
	byGroupWeight := func(item interface{}) uint64 { return item.(*addressGroup).groupWeight }
	byValueWeight := func(item interface{}) uint64 { return item.(*nodeQueue).valueWeight }

	l := &Limiter{
		nodes:         make(map[enode.ID]*nodeQueue),
		addresses:     make(map[string]*addressGroup),
		addressSelect: NewWeightedRandomSelect(byGroupWeight),
		valueSelect:   NewWeightedRandomSelect(byValueWeight),
		sumCostLimit:  sumCostLimit,
	}
	l.cond = sync.NewCond(&l.lock)
	go l.processLoop()
	return l
}
// selectionWeights calculates the selection weights of a node for both the address and
// the value selector. The selection weight depends on the next request cost or the
// summed cost of recently dropped requests.
//
// As a side effect it records the largest positive value and the largest cost seen so
// far in l.maxValue and l.maxCost; these are used to normalize value and reqCost.
//
// flatWeight is maxSelectionWeight for requests costing at most 1/1000 of maxCost and
// decreases inversely with the relative cost above that; it is never below 1.
// valueWeight is flatWeight scaled by the normalized value in [0, 1]. Non-positive and
// NaN values produce a valueWeight of 0.
func (l *Limiter) selectionWeights(reqCost uint, value float64) (flatWeight, valueWeight uint64) {
	if value > 0 {
		if value > l.maxValue {
			l.maxValue = value
		}
		// normalize value to <= 1
		value /= l.maxValue
	} else {
		// Negative and NaN values carry no value-based priority. Converting a negative
		// or NaN float to uint64 is implementation-dependent in Go (a huge number on
		// amd64), so clamp here rather than hand out an arbitrary weight.
		value = 0
	}
	if reqCost > l.maxCost {
		l.maxCost = reqCost
	}
	relCost := float64(reqCost) / float64(l.maxCost)
	var f float64
	if relCost <= 0.001 {
		f = 1
	} else {
		f = 0.001 / relCost
	}
	f *= maxSelectionWeight
	flatWeight, valueWeight = uint64(f), uint64(f*value)
	if flatWeight == 0 {
		flatWeight = 1
	}
	return flatWeight, valueWeight
}
// Add adds a new request to the node queue belonging to the given id. Value belongs
// to the requesting node. A higher value gives the request a higher chance of being
// served quickly in case of heavy load or a DDoS attack. Cost is a rough estimate
// of the serving cost of the request. A lower cost also gives the request a
// better chance.
//
// The returned channel receives a completion channel once the request is served. It
// is closed instead if the limiter has been stopped, if the node is currently
// penalized, or if the request is dropped later. A reqCost of 0 is treated as 1.
func (l *Limiter) Add(id enode.ID, address string, value float64, reqCost uint) chan chan struct{} {
	l.lock.Lock()
	defer l.lock.Unlock()

	process := make(chan chan struct{}, 1)
	if l.quit {
		close(process)
		return process
	}
	if reqCost == 0 {
		reqCost = 1
	}
	req := request{process, reqCost}

	nq, known := l.nodes[id]
	switch {
	case known && nq.queue == nil:
		// The node is waiting on a penalty: add to the penalty cost and drop the request.
		nq.penaltyCost += reqCost
		l.update(nq)
		close(process)
		return process

	case known:
		nq.queue = append(nq.queue, req)
		nq.sumCost += reqCost
		nq.value = value
		if nq.address != address {
			// Known id sending from a new address: move it to that address group.
			l.removeFromGroup(nq)
			l.addToGroup(nq, address)
		}

	default:
		nq = &nodeQueue{
			queue:      []request{req},
			id:         id,
			value:      value,
			sumCost:    reqCost,
			groupIndex: -1,
		}
		nq.flatWeight, nq.valueWeight = l.selectionWeights(reqCost, value)
		if len(l.nodes) == 0 {
			// The processing loop may be waiting for work.
			l.cond.Signal()
		}
		l.nodes[id] = nq
		if nq.valueWeight != 0 {
			l.valueSelect.Update(nq)
		}
		l.addToGroup(nq, address)
	}

	l.sumCost += reqCost
	if l.sumCost > l.sumCostLimit {
		l.dropRequests()
	}
	return process
}
// update recomputes the selection weights of the node queue and refreshes it in its
// address group, the address selector and the value selector. The weights are based
// on the cost of the next queued request, or on the accumulated penalty cost when the
// node is penalized (nil queue).
func (l *Limiter) update(nq *nodeQueue) {
	cost := nq.penaltyCost
	if nq.queue != nil {
		cost = nq.queue[0].cost
	}
	fw, vw := l.selectionWeights(cost, nq.value)

	group := l.addresses[nq.address]
	group.update(nq, fw)
	l.addressSelect.Update(group)

	nq.valueWeight = vw
	l.valueSelect.Update(nq)
}
// addToGroup puts the node queue into the address group for address, creating and
// registering the group on first use, and refreshes the group in the address selector.
func (l *Limiter) addToGroup(nq *nodeQueue, address string) {
	group := l.addresses[address]
	if group == nil {
		group = &addressGroup{nodeSelect: NewWeightedRandomSelect(flatWeight)}
		l.addresses[address] = group
	}
	nq.address = address
	group.add(nq)
	l.addressSelect.Update(group)
}
// removeFromGroup takes the node queue out of its address group. It deletes the group
// from the address map once the group is empty, and refreshes the group's weight in
// the address selector.
func (l *Limiter) removeFromGroup(nq *nodeQueue) {
	addr := nq.address
	group := l.addresses[addr]
	group.remove(nq)
	if len(group.nodes) == 0 {
		delete(l.addresses, addr)
	}
	l.addressSelect.Update(group)
}
// remove detaches the node queue from the limiter completely. It takes the queue out
// of its address group, out of the value selector (when it has a value weight) and
// out of the nodes map.
func (l *Limiter) remove(nq *nodeQueue) {
	l.removeFromGroup(nq)
	if nq.valueWeight > 0 {
		l.valueSelect.Remove(nq)
	}
	delete(l.nodes, nq.id)
}
// choose selects the next node queue to process. Selection alternates between the
// address selector and the value selector. The address selector is also used whenever
// the value selector is empty. It returns nil if no node queue could be chosen.
func (l *Limiter) choose() *nodeQueue {
	if l.valueSelect.IsEmpty() || l.selectAddressNext {
		group, ok := l.addressSelect.Choose().(*addressGroup)
		if ok {
			l.selectAddressNext = false
			return group.choose()
		}
	}
	l.selectAddressNext = true
	picked, _ := l.valueSelect.Choose().(*nodeQueue)
	return picked
}
// processLoop processes requests sequentially. NewLimiter runs it in its own
// goroutine; it returns once it observes l.quit (set by Stop).
//
// The loop holds l.lock except while a request is being served, so Add and Stop can
// run while a request is in progress.
func (l *Limiter) processLoop() {
	l.lock.Lock()
	defer l.lock.Unlock()
	for {
		if l.quit {
			// Reject every request that is still waiting in a queue.
			for _, nq := range l.nodes {
				for _, request := range nq.queue {
					close(request.process)
				}
			}
			return
		}
		nq := l.choose()
		if nq == nil {
			// Nothing to process: sleep until Add or Stop signals l.cond.
			l.cond.Wait()
			continue
		}
		if nq.queue != nil {
			// Pop the next request and release its cost before serving it.
			request := nq.queue[0]
			nq.queue = nq.queue[1:]
			nq.sumCost -= request.cost
			l.sumCost -= request.cost
			l.lock.Unlock()
			// Hand the requester a fresh channel (process has a buffer of one, so this
			// send does not block) and wait until the requester sends on or closes it.
			ch := make(chan struct{})
			request.process <- ch
			<-ch
			l.lock.Lock()
			// NOTE(review): while the lock was released, Add may have called
			// dropRequests and penalized this node (queue set to nil). That case also
			// takes the remove branch below, which discards the accumulated penalty
			// cost — confirm this is intended.
			if len(nq.queue) > 0 {
				l.update(nq)
			} else {
				l.remove(nq)
			}
		} else {
			// penalized queue removed, next request will be added to a clean queue
			l.remove(nq)
		}
	}
}
// Stop stops the processing loop. All queued and future requests are rejected: the
// loop closes the process channels of queued requests, and Add closes them right away.
func (l *Limiter) Stop() {
	l.lock.Lock()
	l.quit = true
	l.cond.Signal()
	l.lock.Unlock()
}
// dropListItem is a drop candidate considered by dropRequests. It pairs a node queue
// that has queued requests with its drop priority. The priority is the node's
// selection weight divided by its queued cost. Candidates with the lowest priority
// are meant to be dropped first.
type dropListItem struct {
	nq       *nodeQueue
	priority float64
}
// dropRequests selects the nodes with the highest queued request cost to selection
// weight ratio and drops their queued requests. The empty node queues stay in the
// selectors with a low selection weight in order to penalize these nodes.
//
// A node's weight is its flat share (one unit per address, split evenly between the
// node IDs of that address) plus its share of the summed node value. Candidates are
// dropped in ascending order of weight/queued cost. Dropping stops as soon as the
// total queued cost is at most half of sumCostLimit.
func (l *Limiter) dropRequests() {
	var (
		sumValue float64
		list     []dropListItem
	)
	for _, nq := range l.nodes {
		sumValue += nq.value
	}
	for _, nq := range l.nodes {
		if nq.sumCost == 0 {
			// nothing queued (e.g. already penalized), nothing to drop
			continue
		}
		w := 1 / float64(len(l.addresses)*len(l.addresses[nq.address].nodes))
		if sumValue > 0 {
			w += nq.value / sumValue
		}
		list = append(list, dropListItem{
			nq:       nq,
			priority: w / float64(nq.sumCost),
		})
	}
	// Ascending by priority, so the worst weight/cost ratio comes first. SortFunc needs
	// a proper three-way comparison (negative, zero, positive); without the positive
	// case the resulting order is undefined.
	slices.SortFunc(list, func(a, b dropListItem) int {
		switch {
		case a.priority < b.priority:
			return -1
		case a.priority > b.priority:
			return 1
		default:
			return 0
		}
	})
	for _, item := range list {
		for _, request := range item.nq.queue {
			close(request.process)
		}
		// make the queue penalized; no more requests are accepted until the node is
		// selected based on the penalty cost which is the cumulative cost of all dropped
		// requests. This ensures that sending excess requests is always penalized
		// and incentivizes the sender to stop for a while if no replies are received.
		item.nq.queue = nil
		item.nq.penaltyCost = item.nq.sumCost
		l.sumCost -= item.nq.sumCost // penalty costs are not counted in sumCost
		item.nq.sumCost = 0
		l.update(item.nq)
		if l.sumCost <= l.sumCostLimit/2 {
			return
		}
	}
}