-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add LimitReader to drop excess packets
This includes code from #1052 and add a tests See for further discussion: #997 Signed-off-by: Miek Gieben <miek@miek.nl>
- Loading branch information
Showing
2 changed files
with
161 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
package dns | ||
|
||
import ( | ||
"net" | ||
"runtime" | ||
"time" | ||
) | ||
|
||
// LimitReader can be used to limit the intake of new packets. Currently it works by checking the number of | ||
// active goroutines. If we have too many it will refuse any new messages. Note that this check is performed every 100th packet. | ||
// | ||
// LimitReader can be used as a DecorateReader in a server as: | ||
// | ||
// server := &Server{/* various options */} | ||
// server.DecorateReader = func(r Reader) Reader { return &LimitReader{Reader: r, MaxGoroutines: 10000} } | ||
// | ||
type LimitReader struct { | ||
Reader | ||
|
||
// MaxGoroutines is the maxium number of goroutines we're willing to tolerate. | ||
MaxGoroutines int | ||
|
||
upkts int | ||
tpkts int | ||
} | ||
|
||
// ReadUDP implements Reader. | ||
func (r *LimitReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) { | ||
m, s, err := r.Reader.ReadUDP(conn, timeout) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
|
||
r.upkts++ | ||
if r.upkts%thisManyPackets != 0 { | ||
return m, s, nil | ||
} | ||
|
||
// well below. | ||
numgo := runtime.NumGoroutine() | ||
if numgo <= r.MaxGoroutines { | ||
return m, s, nil | ||
} | ||
|
||
err = refusePacketUDP(conn, m, s) | ||
return m, s, err | ||
} | ||
|
||
// ReadTCP implements Reader. | ||
func (r *LimitReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { | ||
m, err := r.Reader.ReadTCP(conn, timeout) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
r.tpkts++ | ||
if r.tpkts%100 != 0 { | ||
return m, nil | ||
} | ||
|
||
numgo := runtime.NumGoroutine() | ||
if numgo <= r.MaxGoroutines/2 { | ||
return m, nil | ||
} | ||
|
||
err = refusePacketTCP(conn, m) | ||
return m, err | ||
} | ||
|
||
// ErrPacketRefuse is an error that is returned when a packet is refused by the server. | ||
var ErrPacketRefuse = refuseError{} | ||
|
||
type refuseError struct{} | ||
|
||
// These implement the net.Error interface. | ||
func (refuseError) Error() string { return "dns: refusing packet" } | ||
func (refuseError) Timeout() bool { return false } | ||
func (refuseError) Temporary() bool { return true } | ||
|
||
func refusePacketUDP(conn *net.UDPConn, m []byte, s *SessionUDP) error { | ||
dh, _, err := unpackMsgHdr(m, 0) | ||
if err != nil { | ||
return nil | ||
} | ||
|
||
msg := new(Msg) | ||
msg.setHdr(dh) | ||
msg.Rcode = RcodeRefused | ||
|
||
m, err = msg.Pack() | ||
if err != nil { | ||
return err | ||
} | ||
|
||
_, err = WriteToSessionUDP(conn, m, s) | ||
return err | ||
} | ||
|
||
func refusePacketTCP(conn net.Conn, m []byte) error { | ||
dh, _, err := unpackMsgHdr(m, 0) | ||
if err != nil { | ||
return nil | ||
} | ||
|
||
msg := new(Msg) | ||
msg.setHdr(dh) | ||
msg.Rcode = RcodeRefused | ||
|
||
m, err = msg.Pack() | ||
if err != nil { | ||
return err | ||
} | ||
|
||
_, err = conn.Write(m) | ||
return err | ||
} | ||
|
||
const thisManyPackets = 100 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package dns | ||
|
||
import ( | ||
"net" | ||
"sync" | ||
"testing" | ||
"time" | ||
) | ||
|
||
func TestLimitReader(t *testing.T) { | ||
pc, err := net.ListenPacket("udp", ":0") | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
server := &Server{PacketConn: pc, ReadTimeout: time.Second * 2, WriteTimeout: time.Second * 2} | ||
server.DecorateReader = func(r Reader) Reader { | ||
return &LimitReader{Reader: r, MaxGoroutines: 0} | ||
} | ||
HandleFunc("example.org.", HelloServer) | ||
defer server.Shutdown() | ||
|
||
waitLock := sync.Mutex{} | ||
waitLock.Lock() | ||
server.NotifyStartedFunc = waitLock.Unlock | ||
|
||
fin := make(chan error, 1) | ||
go func() { | ||
fin <- server.ActivateAndServe() | ||
pc.Close() | ||
}() | ||
|
||
waitLock.Lock() | ||
|
||
c := new(Client) | ||
m := new(Msg).SetQuestion("example.org.", TypeTXT) | ||
|
||
for i := 0; i < thisManyPackets; i++ { | ||
r, _, _ := c.Exchange(m, pc.LocalAddr().String()) | ||
if i == thisManyPackets && r.Rcode != RcodeRefused { | ||
t.Errorf("expected rcode %d, got %d", RcodeRefused, r.Rcode) | ||
} | ||
} | ||
} |