Skip to content

Commit

Permalink
Refactor: Fix race condition on dns message.
Browse files Browse the repository at this point in the history
related to #14
  • Loading branch information
cherrot committed Jul 2, 2019
1 parent f2f292a commit 69ecd76
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 31 deletions.
9 changes: 8 additions & 1 deletion dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (s *Server) Serve(w dns.ResponseWriter, req *dns.Msg) {
cancel()
}()

req.RecursionDesired = true
s.normalizeRequest(req)

trusted := make(chan *dns.Msg, 1)
untrusted := make(chan *dns.Msg, 1)
Expand Down Expand Up @@ -72,6 +72,13 @@ func (s *Server) Serve(w dns.ResponseWriter, req *dns.Msg) {
logger.Debug("SERVING RTT: ", time.Since(start))
}

func (s *Server) normalizeRequest(req *dns.Msg) {
req.RecursionDesired = true
if !s.TCPOnly {
setUDPSize(req, uint16(s.UDPMaxSize))
}
}

func (s *Server) processReply(
ctx context.Context, logger *logrus.Entry, rep *dns.Msg, other <-chan *dns.Msg,
process func(context.Context, *logrus.Entry, *dns.Msg, net.IP, <-chan *dns.Msg) *dns.Msg,
Expand Down
51 changes: 21 additions & 30 deletions lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ func lookupInServers(
if len(servers) == 0 {
return
}
var errChain error
logger := logrus.WithField("question", questionString(&req.Question[0]))

ticker := time.NewTicker(waitInterval)
Expand All @@ -30,13 +29,12 @@ func lookupInServers(
queryNext <- struct{}{}
var wg sync.WaitGroup

doLookup := func(idx int, server string) {
doLookup := func(server string) {
defer wg.Done()
logger := logger.WithField("server", server)

reply, rtt, err := lookup(req, server)
reply, rtt, err := lookup(req.Copy(), server)
if err != nil {
errChain = errors.Wrapf(err, "%d", idx)
queryNext <- struct{}{}
return
}
Expand All @@ -50,23 +48,20 @@ func lookupInServers(
}

LOOP:
for idx, server := range servers {
for _, server := range servers {
select {
case <-ctx.Done():
break LOOP
case <-queryNext:
wg.Add(1)
go doLookup(idx, server)
go doLookup(server)
case <-ticker.C:
wg.Add(1)
go doLookup(idx, server)
go doLookup(server)
}
}

wg.Wait()
if errChain != nil {
logger.WithError(errChain).Error("Error hanppens.")
}
}

// Lookup send a DNS request to the specific server and get its corresponding reply.
Expand All @@ -80,8 +75,6 @@ func (s *Server) Lookup(req *dns.Msg, server string) (reply *dns.Msg, rtt time.D
})

if !s.TCPOnly {
req := req.Copy()
setUDPSize(req, s.UDPMaxSize)
reply, rtt, err = s.UDPCli.Exchange(req, server)
if err != nil {
logger.WithError(err).Error("Fail to send UDP query. Will retry in TCP.")
Expand Down Expand Up @@ -110,19 +103,9 @@ func (s *Server) LookupMutation(req *dns.Msg, server string) (reply *dns.Msg, rt
"question": questionString(&req.Question[0]),
"server": server,
})
// cleanEdns0(req)

var (
udpSize int
buffer []byte
)
if !s.TCPOnly {
req := req.Copy()
udpSize = setUDPSize(req, s.UDPMaxSize)
buffer, err = req.Pack()
} else {
buffer, err = req.Pack()
}
var buffer []byte
buffer, err = req.Pack()
if err != nil {
return nil, 0, errors.Wrap(err, "fail to pack request")
}
Expand All @@ -131,7 +114,8 @@ func (s *Server) LookupMutation(req *dns.Msg, server string) (reply *dns.Msg, rt
t := time.Now()
if !s.TCPOnly {
ddl := t.Add(s.UDPCli.Timeout)
reply, err = rawLookup(s.UDPCli, req.Id, buffer, server, ddl, uint16(udpSize))
udpSize := getUDPSize(req)
reply, err = rawLookup(s.UDPCli, req.Id, buffer, server, ddl, udpSize)
if err != nil {
logger.WithError(err).Error("Fail to send UDP mutation query. Will retry in TCP.")
}
Expand Down Expand Up @@ -176,22 +160,29 @@ func rawLookup(cli *dns.Client, id uint16, req []byte, server string, ddl time.T
return reply, err
}

func setUDPSize(req *dns.Msg, size int) int {
func setUDPSize(req *dns.Msg, size uint16) uint16 {
if size <= dns.MinMsgSize {
return dns.MinMsgSize
}
// https://tools.ietf.org/html/rfc6891#section-6.2.5
if e := req.IsEdns0(); e != nil {
if e.UDPSize() >= uint16(size) {
return int(e.UDPSize())
if e.UDPSize() >= size {
return e.UDPSize()
}
e.SetUDPSize(uint16(size))
e.SetUDPSize(size)
return size
}
req.SetEdns0(uint16(size), false)
req.SetEdns0(size, false)
return size
}

func getUDPSize(req *dns.Msg) uint16 {
if e := req.IsEdns0(); e != nil && e.UDPSize() > dns.MinMsgSize {
return e.UDPSize()
}
return dns.MinMsgSize
}

func cleanEdns0(req *dns.Msg) {
for {
if req.IsEdns0() == nil {
Expand Down

0 comments on commit 69ecd76

Please sign in to comment.