Skip to content

Commit

Permalink
do rate limit before parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed May 22, 2024
1 parent 884b68d commit 612b476
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 36 deletions.
87 changes: 62 additions & 25 deletions p2p/protocol/autonatv2/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,29 @@ func (as *server) handleDialRequest(s network.Stream) {
defer s.Close()

p := s.Conn().RemotePeer()
r := pbio.NewDelimitedReader(s, maxMsgSize)

var msg pb.Message
w := pbio.NewDelimitedWriter(s)
// Check for rate limit before parsing the request
if !as.limiter.Accept(p) {
msg = pb.Message{
Msg: &pb.Message_DialResponse{
DialResponse: &pb.DialResponse{
Status: pb.DialResponse_E_REQUEST_REJECTED,
},
},
}
if err := w.WriteMsg(&msg); err != nil {
s.Reset()
log.Debugf("failed to write request rejected response to %s: %s", p, err)
return
}
log.Debugf("rejected request from %s: rate limit exceeded", p)
return
}
defer as.limiter.CompleteRequest(p)

r := pbio.NewDelimitedReader(s, maxMsgSize)
if err := r.ReadMsg(&msg); err != nil {
s.Reset()
log.Debugf("failed to read request from %s: %s", p, err)
Expand Down Expand Up @@ -129,7 +150,6 @@ func (as *server) handleDialRequest(s network.Stream) {
break
}

w := pbio.NewDelimitedWriter(s)
// No dialable address
if dialAddr == nil {
msg = pb.Message{
Expand All @@ -148,7 +168,7 @@ func (as *server) handleDialRequest(s network.Stream) {
}

isDialDataRequired := as.dialDataRequestPolicy(s, dialAddr)
if !as.limiter.Accept(p, isDialDataRequired) {
if !as.limiter.AcceptDialDataRequest(p) {
msg = pb.Message{
Msg: &pb.Message_DialResponse{
DialResponse: &pb.DialResponse{
Expand All @@ -164,7 +184,6 @@ func (as *server) handleDialRequest(s network.Stream) {
log.Debugf("rejected request from %s: rate limit exceeded", p)
return
}
defer as.limiter.CompleteRequest(p)

if isDialDataRequired {
if err := getDialData(w, r, &msg, addrIdx); err != nil {
Expand Down Expand Up @@ -271,7 +290,7 @@ type rateLimiter struct {
DialDataRPM int

mu sync.Mutex
reqs []time.Time
reqs []entry
peerReqs map[peer.ID][]time.Time
dialDataReqs []time.Time
// ongoingReqs tracks in progress requests. This is used to disallow multiple concurrent requests by the
Expand All @@ -282,7 +301,12 @@ type rateLimiter struct {
now func() time.Time // for tests
}

func (r *rateLimiter) Accept(p peer.ID, requiresData bool) bool {
type entry struct {
PeerID peer.ID
Time time.Time
}

func (r *rateLimiter) Accept(p peer.ID) bool {
r.mu.Lock()
defer r.mu.Unlock()
if r.peerReqs == nil {
Expand All @@ -291,35 +315,57 @@ func (r *rateLimiter) Accept(p peer.ID, requiresData bool) bool {
}

nw := r.now()
r.cleanup(p, nw)
r.cleanup(nw)

if _, ok := r.ongoingReqs[p]; ok {
return false
}
if len(r.reqs) >= r.RPM || len(r.peerReqs[p]) >= r.PerPeerRPM {
return false
}
if requiresData && len(r.dialDataReqs) >= r.DialDataRPM {
return false
}

r.ongoingReqs[p] = struct{}{}
r.reqs = append(r.reqs, nw)
r.reqs = append(r.reqs, entry{PeerID: p, Time: nw})
r.peerReqs[p] = append(r.peerReqs[p], nw)
if requiresData {
r.dialDataReqs = append(r.dialDataReqs, nw)
return true
}

func (r *rateLimiter) AcceptDialDataRequest(p peer.ID) bool {
r.mu.Lock()
defer r.mu.Unlock()
if r.peerReqs == nil {
r.peerReqs = make(map[peer.ID][]time.Time)
r.ongoingReqs = make(map[peer.ID]struct{})
}
nw := r.now()
r.cleanup(nw)
if len(r.dialDataReqs) >= r.DialDataRPM {
return false
}
r.dialDataReqs = append(r.dialDataReqs, nw)
return true
}

// cleanup removes stale requests.
//
// This is fast enough in rate limited cases and the state is small enough to
// clean up quickly when blocking requests.
func (r *rateLimiter) cleanup(p peer.ID, now time.Time) {
func (r *rateLimiter) cleanup(now time.Time) {
idx := len(r.reqs)
for i, t := range r.reqs {
if now.Sub(t) < time.Minute {
for i, e := range r.reqs {
if now.Sub(e.Time) >= time.Minute {
pi := len(r.peerReqs[e.PeerID])
for j, t := range r.peerReqs[e.PeerID] {
if now.Sub(t) < time.Minute {
pi = j
break
}
}
r.peerReqs[e.PeerID] = r.peerReqs[e.PeerID][pi:]
if len(r.peerReqs[e.PeerID]) == 0 {
delete(r.peerReqs, e.PeerID)
}
} else {
idx = i
break
}
Expand All @@ -334,15 +380,6 @@ func (r *rateLimiter) cleanup(p peer.ID, now time.Time) {
}
}
r.dialDataReqs = r.dialDataReqs[idx:]

idx = len(r.peerReqs[p])
for i, t := range r.peerReqs[p] {
if now.Sub(t) < time.Minute {
idx = i
break
}
}
r.peerReqs[p] = r.peerReqs[p][idx:]
}

func (r *rateLimiter) CompleteRequest(p peer.ID) {
Expand Down
26 changes: 15 additions & 11 deletions p2p/protocol/autonatv2/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,36 +152,40 @@ func TestRateLimiter(t *testing.T) {
cl := test.NewMockClock()
r := rateLimiter{RPM: 3, PerPeerRPM: 2, DialDataRPM: 1, now: cl.Now}

require.True(t, r.Accept("peer1", false))
require.True(t, r.Accept("peer1"))

cl.AdvanceBy(10 * time.Second)
require.False(t, r.Accept("peer1", false)) // first request is still active
require.False(t, r.Accept("peer1")) // first request is still active
r.CompleteRequest("peer1")

require.True(t, r.Accept("peer1", false))
require.True(t, r.Accept("peer1"))
r.CompleteRequest("peer1")

cl.AdvanceBy(10 * time.Second)
require.False(t, r.Accept("peer1", false))
require.False(t, r.Accept("peer1"))

cl.AdvanceBy(10 * time.Second)
require.True(t, r.Accept("peer2", false))
require.True(t, r.Accept("peer2"))
r.CompleteRequest("peer2")

cl.AdvanceBy(10 * time.Second)
require.False(t, r.Accept("peer3", false))
require.False(t, r.Accept("peer3"))

cl.AdvanceBy(21 * time.Second) // first request expired
require.True(t, r.Accept("peer1", false))
require.True(t, r.Accept("peer1"))
r.CompleteRequest("peer1")

cl.AdvanceBy(10 * time.Second)
require.True(t, r.Accept("peer3", true))
require.True(t, r.Accept("peer3"))
r.CompleteRequest("peer3")

cl.AdvanceBy(50 * time.Second)
require.False(t, r.Accept("peer3", true))
require.True(t, r.Accept("peer3"))
r.CompleteRequest("peer3")

cl.AdvanceBy(1 * time.Second)
require.False(t, r.Accept("peer3"))

cl.AdvanceBy(11 * time.Second)
require.True(t, r.Accept("peer3", true))
cl.AdvanceBy(10 * time.Second)
require.True(t, r.Accept("peer3"))
}

0 comments on commit 612b476

Please sign in to comment.