forked from rs/dnstrace
-
Notifications
You must be signed in to change notification settings - Fork 7
/
client.go
260 lines (239 loc) · 5.51 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
package client
import (
"errors"
"net"
"strings"
"sync"
"time"
"github.com/miekg/dns"
)
// Client is a DNS client capable of performing parallel requests.
//
// Construct it with New so that both caches are initialized.
type Client struct {
	// The embedded dns.Client performs every individual exchange
	// (c.Exchange in ParallelQuery).
	dns.Client

	// DCache records delegations: RecursiveQuery reads the servers for the
	// name being resolved from it (DCache.Get) and adds the NS servers of
	// every delegation it receives (DCache.Add, keyed by the NS owner name).
	DCache DelegationCache

	// LCache maps server host names to their resolved addresses; it is
	// filled from glue records and from lookupHost results.
	LCache LookupCache
}
// ResponseType classifies the fastest response of one RecursiveQuery
// iteration.
type ResponseType int

// The values are assigned with iota, so their declaration order is
// significant.
const (
	// ResponseTypeUnknown is the zero value: the response has not been
	// classified.
	ResponseTypeUnknown ResponseType = iota

	// ResponseTypeDelegation means the authority section carries NS records
	// for a zone deeper than the current one; resolution continues there.
	ResponseTypeDelegation

	// ResponseTypeCNAME means the answer section contained a CNAME whose
	// target is followed in the next iteration.
	ResponseTypeCNAME

	// ResponseTypeFinal means the answer holds a record of the queried name
	// and type, or the response neither answers nor delegates further
	// (e.g. NOERROR with an empty answer).
	ResponseTypeFinal
)
// Response stores a DNS response.
//
// ParallelQuery fills every field; lookupHost reuses the type to carry only
// Msg, RTT and Err of a RecursiveQuery.
type Response struct {
	// Server is the name server the query was sent to.
	Server Server

	// Addr is the server address that was queried (ParallelQuery joins it
	// with port 53).
	Addr string

	// Msg is the reply message; it may be nil when Err is set.
	Msg *dns.Msg

	// RTT is the round-trip time reported for the exchange.
	RTT time.Duration

	// Err is the error returned by the exchange, if any.
	Err error
}
// Responses is a list of responses, such as the one returned by
// ParallelQuery (in completion order).
type Responses []Response
// Fastest returns the fastest successful response, or nil when no response
// succeeded.
//
// Latency is measured as the exchange RTT plus the time spent resolving the
// server's address (Server.LookupRTT). Responses carrying an error or no
// message are ignored. On a tie the earliest response in rs wins. The
// returned value is a copy, so modifying it does not affect rs.
func (rs Responses) Fastest() *Response {
	best := -1
	for i := range rs {
		r := &rs[i]
		if r.Err != nil || r.Msg == nil {
			continue
		}
		if best < 0 || r.RTT+r.Server.LookupRTT < rs[best].RTT+rs[best].Server.LookupRTT {
			best = i
		}
	}
	if best < 0 {
		// Honor the documented contract instead of returning a zero Response.
		return nil
	}
	fr := rs[best]
	return &fr
}
// Tracer holds optional callbacks through which RecursiveQuery reports its
// progress. Nil callbacks are skipped.
type Tracer struct {
	// GotIntermediaryResponse is called once per iteration that obtained a
	// response. It receives:
	//   - i: the iteration number, starting at 1.
	//   - m: a copy of the query that was sent.
	//   - rs: all responses received.
	//   - rtype: the classification of the fastest one.
	// Setting it also makes RecursiveQuery cache every NS of a delegation
	// instead of only the first.
	GotIntermediaryResponse func(i int, m *dns.Msg, rs Responses, rtype ResponseType)

	// FollowingCNAME is called when RecursiveQuery follows a CNAME. domain
	// is the CNAME's owner name and target is the name resolved next.
	FollowingCNAME func(domain, target string)
}
// New returns a Client whose embedded dns.Client has its zero-value
// configuration and whose caches are set to fresh DelegationCache{} and
// LookupCache{} values.
func New() Client {
	// Returned as a literal (not via a local variable) so no cache value is
	// copied after construction.
	return Client{
		Client: dns.Client{},
		LCache: LookupCache{},
		DCache: DelegationCache{},
	}
}
// ParallelQuery performs an exchange of m with every address of every server
// in parallel (on port 53) and returns all responses once every exchange has
// completed. Responses are in completion order.
func (c *Client) ParallelQuery(m *dns.Msg, servers []Server) Responses {
	rc := make(chan Response)
	cnt := 0
	for _, s := range servers {
		for _, addr := range s.Addrs {
			cnt++
			// Each exchange gets its own copy of m, taken here before the
			// goroutine starts. miekg/dns may write to a message while packing
			// it (e.g. the extended-RCODE bits of an EDNS0 OPT record), so
			// sharing one *dns.Msg between concurrent exchanges is a data race.
			go func(s Server, addr string, qm *dns.Msg) {
				r := Response{
					Server: s,
					Addr:   addr,
				}
				r.Msg, r.RTT, r.Err = c.Exchange(qm, net.JoinHostPort(addr, "53"))
				rc <- r
			}(s, addr, m.Copy())
		}
	}
	// Exactly cnt goroutines were started and each sends once, so draining
	// cnt values leaves no goroutine blocked.
	rs := make(Responses, 0, cnt)
	for ; cnt > 0; cnt-- {
		rs = append(rs, <-rc)
	}
	return rs
}
// domainEqual reports whether d1 and d2 denote the same domain name,
// comparing their fully-qualified forms after lower-casing, so a missing
// trailing dot or a difference in letter case does not matter.
func domainEqual(d1, d2 string) bool {
	normalize := func(d string) string {
		return strings.ToLower(dns.Fqdn(d))
	}
	return normalize(d1) == normalize(d2)
}
// RecursiveQuery performs a recursive query by querying all the available name
// servers to gather statistics.
//
// Each iteration does the following:
//   - Takes the servers for the current name from c.DCache.
//   - Resolves, concurrently via lookupHost, any server that has no
//     addresses.
//   - Sends the query to every server address (ParallelQuery).
//   - Classifies the fastest successful response as a final answer, a CNAME
//     to follow, or a delegation to a deeper zone.
//
// Delegations are stored in c.DCache and c.LCache for the next iteration.
// When tracer.GotIntermediaryResponse is set, it is called once per
// classified response and every NS of a delegation is cached; otherwise only
// the first NS is cached. tracer.FollowingCNAME, if set, is called when a
// CNAME is followed.
//
// Only m.Question[0] is resolved; m itself is not modified.
//
// The returned rtt is the sum, over iterations, of the fastest response's
// RTT plus the time spent resolving that server's address. An error is
// returned in three cases:
//   - m has no question.
//   - No server produced a usable response.
//   - No final answer was reached within the iteration limit.
func (c *Client) RecursiveQuery(m *dns.Msg, tracer Tracer) (r *dns.Msg, rtt time.Duration, err error) {
	if len(m.Question) == 0 {
		return nil, 0, errors.New("query has no question")
	}
	// Work on a copy: Question[0].Name is rewritten when following CNAMEs.
	m = m.Copy()
	qname := m.Question[0].Name
	qtype := m.Question[0].Qtype
	zone := "."
	for i := 1; i < 100; i++ {
		_, servers := c.DCache.Get(qname)

		// Resolve the addresses of servers that came without any. Each
		// goroutine writes only its own servers[idx] element.
		var wg sync.WaitGroup
		for idx := range servers {
			if len(servers[idx].Addrs) != 0 {
				continue
			}
			wg.Add(1)
			go func(s *Server) {
				defer wg.Done()
				lm := m.Copy()
				lm.SetQuestion(s.Name, 0) // qtypes are set by lookupHost
				var lerr error
				s.Addrs, s.LookupRTT, lerr = c.lookupHost(lm)
				if lerr != nil {
					s.LookupErr = lerr
				}
			}(&servers[idx])
		}
		wg.Wait()

		m.Question[0].Name = qname
		rs := c.ParallelQuery(m, servers)

		fr := rs.Fastest()
		var resp *dns.Msg
		if fr != nil {
			resp = fr.Msg
		}
		if resp == nil {
			if len(rs) > 0 {
				// Never report "no message" together with a nil error.
				err = rs[0].Err
				if err == nil {
					err = errors.New("no response")
				}
				return rs[0].Msg, rtt + rs[0].RTT, err
			}
			return nil, rtt, errors.New("no response")
		}
		rtt += fr.Server.LookupRTT + fr.RTT

		// Classify the answer: a record of the queried name/type ends the
		// resolution; a CNAME switches resolution to its target, restarting
		// from the root zone.
		var rtype ResponseType
		var cname string
		for _, rr := range resp.Answer {
			h := rr.Header()
			if domainEqual(h.Name, qname) && h.Rrtype == qtype {
				rtype = ResponseTypeFinal
				break
			}
			if cn, ok := rr.(*dns.CNAME); ok && h.Rrtype == dns.TypeCNAME {
				cname = h.Name
				qname = cn.Target
				zone = "."
				rtype = ResponseTypeCNAME
			}
		}

		// Without an answer, an NS record for a zone longer (deeper) than
		// the current one is a delegation.
		if rtype == ResponseTypeUnknown {
			for _, rr := range resp.Ns {
				if ns, ok := rr.(*dns.NS); ok && len(ns.Header().Name) > len(zone) {
					rtype = ResponseTypeDelegation
					zone = ns.Header().Name
					break
				}
			}
			if rtype == ResponseTypeUnknown {
				// NOERROR / empty: nothing more to follow.
				rtype = ResponseTypeFinal
			}
		}

		if rtype == ResponseTypeDelegation {
			for _, rr := range resp.Ns {
				ns, ok := rr.(*dns.NS)
				if !ok {
					continue // skip non-NS records such as DS
				}
				// Collect glue addresses for this name server.
				var addrs []string
				for _, extra := range resp.Extra {
					if !domainEqual(extra.Header().Name, ns.Ns) {
						continue
					}
					switch a := extra.(type) {
					case *dns.A:
						addrs = append(addrs, a.A.String())
					case *dns.AAAA:
						addrs = append(addrs, a.AAAA.String())
					}
				}
				s := Server{
					Name:    ns.Ns,
					HasGlue: len(addrs) > 0,
					TTL:     ns.Header().Ttl,
					Addrs:   addrs,
				}
				c.DCache.Add(ns.Header().Name, s)
				c.LCache.Set(s.Name, s.Addrs)
				if tracer.GotIntermediaryResponse == nil {
					// If not traced, only take first NS.
					break
				}
			}
		}

		if tracer.GotIntermediaryResponse != nil {
			tracer.GotIntermediaryResponse(i, m.Copy(), rs, rtype)
		}
		switch rtype {
		case ResponseTypeCNAME:
			if tracer.FollowingCNAME != nil {
				tracer.FollowingCNAME(cname, qname)
			}
		case ResponseTypeFinal:
			return resp, rtt, nil
		}
	}
	// Previously this returned a nil message with a nil error, which callers
	// could not distinguish from success.
	return nil, rtt, errors.New("no final answer within the maximum number of iterations")
}
// lookupHost resolves the A and AAAA addresses of m.Question[0].Name.
//
// A non-empty entry in c.LCache is returned directly with a zero rtt.
// Otherwise one untraced RecursiveQuery per qtype (A and AAAA) runs
// concurrently. The addresses found in both answers are stored in c.LCache
// and returned. rtt is the longer of the two query times, because the
// queries run in parallel. If either query fails, its error is returned
// with no addresses and the cache is left untouched.
func (c *Client) lookupHost(m *dns.Msg) (addrs []string, rtt time.Duration, err error) {
	qname := m.Question[0].Name
	if cached := c.LCache.Get(qname); len(cached) > 0 {
		return cached, 0, nil
	}

	qtypes := []uint16{dns.TypeA, dns.TypeAAAA}
	// Buffered so that returning early on an error cannot leave the other
	// goroutine blocked forever on its send.
	rs := make(chan Response, len(qtypes))
	for _, qtype := range qtypes {
		qm := m.Copy()
		qm.Question[0].Qtype = qtype
		go func(qm *dns.Msg) {
			r, qrtt, qerr := c.RecursiveQuery(qm, Tracer{})
			rs <- Response{
				Msg: r,
				Err: qerr,
				RTT: qrtt,
			}
		}(qm)
	}

	for range qtypes {
		r := <-rs
		if r.Err != nil {
			// Report the query's own error (the named result err is nil here).
			return nil, 0, r.Err
		}
		if r.RTT > rtt {
			rtt = r.RTT // keep the longer of the two parallel queries
		}
		if r.Msg == nil {
			continue
		}
		for _, rr := range r.Msg.Answer {
			switch rr := rr.(type) {
			case *dns.A:
				addrs = append(addrs, rr.A.String())
			case *dns.AAAA:
				addrs = append(addrs, rr.AAAA.String())
			}
		}
	}
	c.LCache.Set(qname, addrs)
	return addrs, rtt, nil
}