Skip to content

Commit

Permalink
Add LimitReader to drop excess packets
Browse files Browse the repository at this point in the history
This includes code from #1052 and add a tests

See for further discussion: #997

Signed-off-by: Miek Gieben <miek@miek.nl>
  • Loading branch information
miekg committed Jan 9, 2020
1 parent 6c0c4e6 commit 2d40eb3
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 0 deletions.
118 changes: 118 additions & 0 deletions limit_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package dns

import (
"net"
"runtime"
"time"
)

// LimitReader can be used to limit the intake of new packets. Currently it works by checking the number of
// active goroutines. If we have too many it will refuse any new messages. Note that this check is performed every 100th packet.
//
// LimitReader can be used as a DecorateReader in a server as:
//
// server := &Server{/* various options */}
// server.DecorateReader = func(r Reader) Reader { return &LimitReader{Reader: r, MaxGoroutines: 10000} }
//
type LimitReader struct {
Reader

// MaxGoroutines is the maxium number of goroutines we're willing to tolerate.
MaxGoroutines int

upkts int
tpkts int
}

// ReadUDP implements Reader.
func (r *LimitReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
m, s, err := r.Reader.ReadUDP(conn, timeout)
if err != nil {
return nil, nil, err
}

r.upkts++
if r.upkts%thisManyPackets != 0 {
return m, s, nil
}

// well below.
numgo := runtime.NumGoroutine()
if numgo <= r.MaxGoroutines {
return m, s, nil
}

err = refusePacketUDP(conn, m, s)
return m, s, err
}

// ReadTCP implements Reader.
func (r *LimitReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
m, err := r.Reader.ReadTCP(conn, timeout)
if err != nil {
return nil, err
}

r.tpkts++
if r.tpkts%100 != 0 {
return m, nil
}

numgo := runtime.NumGoroutine()
if numgo <= r.MaxGoroutines/2 {
return m, nil
}

err = refusePacketTCP(conn, m)
return m, err
}

// ErrPacketRefuse is an error that is returned when a packet is refused by the server.
var ErrPacketRefuse = refuseError{}

type refuseError struct{}

// These implement the net.Error interface.
func (refuseError) Error() string { return "dns: refusing packet" }
func (refuseError) Timeout() bool { return false }
func (refuseError) Temporary() bool { return true }

func refusePacketUDP(conn *net.UDPConn, m []byte, s *SessionUDP) error {
dh, _, err := unpackMsgHdr(m, 0)
if err != nil {
return nil
}

msg := new(Msg)
msg.setHdr(dh)
msg.Rcode = RcodeRefused

m, err = msg.Pack()
if err != nil {
return err
}

_, err = WriteToSessionUDP(conn, m, s)
return err
}

func refusePacketTCP(conn net.Conn, m []byte) error {
dh, _, err := unpackMsgHdr(m, 0)
if err != nil {
return nil
}

msg := new(Msg)
msg.setHdr(dh)
msg.Rcode = RcodeRefused

m, err = msg.Pack()
if err != nil {
return err
}

_, err = conn.Write(m)
return err
}

const thisManyPackets = 100
43 changes: 43 additions & 0 deletions limit_reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package dns

import (
"net"
"sync"
"testing"
"time"
)

func TestLimitReader(t *testing.T) {
pc, err := net.ListenPacket("udp", ":0")
if err != nil {
t.Fatal(err)
}
server := &Server{PacketConn: pc, ReadTimeout: time.Second * 2, WriteTimeout: time.Second * 2}
server.DecorateReader = func(r Reader) Reader {
return &LimitReader{Reader: r, MaxGoroutines: 0}
}
HandleFunc("example.org.", HelloServer)
defer server.Shutdown()

waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock

fin := make(chan error, 1)
go func() {
fin <- server.ActivateAndServe()
pc.Close()
}()

waitLock.Lock()

c := new(Client)
m := new(Msg).SetQuestion("example.org.", TypeTXT)

for i := 0; i < thisManyPackets; i++ {
r, _, _ := c.Exchange(m, pc.LocalAddr().String())
if i == thisManyPackets && r.Rcode != RcodeRefused {
t.Errorf("expected rcode %d, got %d", RcodeRefused, r.Rcode)
}
}
}

0 comments on commit 2d40eb3

Please sign in to comment.