-
Notifications
You must be signed in to change notification settings - Fork 2k
/
wait.go
441 lines (381 loc) · 13 KB
/
wait.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
// +skip_license_check
/*
This file contains portions of code directly taken from the 'xenolf/lego' project.
A copy of the license for this code can be found in the file named LICENSE in
this directory.
*/
package util
import (
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/miekg/dns"
logf "github.com/jetstack/cert-manager/pkg/logs"
)
// preCheckDNSFunc is the signature of a DNS propagation check: given the
// record name, the expected value, the nameservers to use and whether
// authoritative nameservers should be consulted, it reports whether the record
// is visible.
type preCheckDNSFunc func(fqdn, value string, nameservers []string,
useAuthoritative bool) (bool, error)
// dnsQueryFunc mirrors the signature of DNSQuery so that the query
// implementation can be swapped out.
type dnsQueryFunc func(fqdn string, rtype uint16, nameservers []string, recursive bool) (in *dns.Msg, err error)
var (
// PreCheckDNS checks DNS propagation before notifying ACME that
// the DNS challenge is ready.
PreCheckDNS preCheckDNSFunc = checkDNSPropagation
// dnsQuery is used to be able to mock DNSQuery
dnsQuery dnsQueryFunc = DNSQuery
// fqdnToZoneLock guards fqdnToZone.
fqdnToZoneLock sync.RWMutex
// fqdnToZone caches the zone that FindZoneByFqdn discovered for each fqdn.
fqdnToZone = map[string]string{}
)
// defaultResolvConf is the resolver configuration file read to populate
// RecursiveNameservers.
const defaultResolvConf = "/etc/resolv.conf"
// issueTag is the CAA property tag compared against when matching non-wildcard
// issuance (see matchCAA).
const issueTag = "issue"
// issuewildTag is the CAA property tag compared against when matching wildcard
// issuance (see matchCAA).
const issuewildTag = "issuewild"
// defaultNameservers is the fallback list handed to getNameservers for use when
// the resolver configuration yields no servers.
var defaultNameservers = []string{
"8.8.8.8:53",
"8.8.4.4:53",
}
// RecursiveNameservers is initialised from defaultResolvConf via
// getNameservers, falling back to defaultNameservers.
var RecursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers)
// DNSTimeout is used to override the default DNS timeout of 10 seconds.
var DNSTimeout = 10 * time.Second
// getNameservers loads the resolver configuration found at path and returns
// its servers in host:port form. Any entry that net.SplitHostPort cannot parse
// is joined with port 53. If the file cannot be loaded, or it lists no
// servers, defaults is returned unchanged.
func getNameservers(path string, defaults []string) []string {
	config, err := dns.ClientConfigFromFile(path)
	if err != nil {
		return defaults
	}
	if len(config.Servers) == 0 {
		return defaults
	}

	resolved := make([]string, 0, len(config.Servers))
	for _, addr := range config.Servers {
		hostPort := addr
		// ensure all servers have a port number
		if _, _, splitErr := net.SplitHostPort(addr); splitErr != nil {
			hostPort = net.JoinHostPort(addr, "53")
		}
		resolved = append(resolved, hostPort)
	}
	return resolved
}
// followCNAMEs resolves the chain of CNAME records starting at fqdn and returns
// the first name in the chain that has no CNAME record of its own.
//
// A query error is returned as-is. A response whose rcode is not NOERROR ends
// the walk and yields fqdn unchanged. fqdnChain is bookkeeping for the
// recursion: it holds the names already visited, and an error is returned when
// a CNAME points back at one of them.
func followCNAMEs(fqdn string, nameservers []string, fqdnChain ...string) (string, error) {
	resp, err := dnsQuery(fqdn, dns.TypeCNAME, nameservers, true)
	if err != nil {
		return "", err
	}
	if resp.Rcode != dns.RcodeSuccess {
		return fqdn, nil
	}

	for _, answer := range resp.Answer {
		cname, isCNAME := answer.(*dns.CNAME)
		// Only a CNAME owned by the name we asked about is relevant.
		if !isCNAME || cname.Hdr.Name != fqdn {
			continue
		}
		target := cname.Target
		logf.V(logf.DebugLevel).Infof("Updating FQDN: %s with its CNAME: %s", fqdn, target)

		// Refuse to revisit a name already seen earlier in the chain.
		for _, visited := range fqdnChain {
			if visited == target {
				return "", fmt.Errorf("Found recursive CNAME record to %q when looking up %q", target, fqdn)
			}
		}
		return followCNAMEs(target, nameservers, append(fqdnChain, fqdn)...)
	}
	return fqdn, nil
}
// checkDNSPropagation reports whether a TXT record with the given value is
// visible for fqdn after following any CNAME chain.
//
// When useAuthoritative is false the supplied nameservers are queried
// directly. Otherwise the authoritative nameservers of the resolved name are
// looked up and each of them is queried on port 53.
func checkDNSPropagation(fqdn, value string, nameservers []string,
	useAuthoritative bool) (bool, error) {
	resolved, err := followCNAMEs(fqdn, nameservers)
	if err != nil {
		return false, err
	}

	if !useAuthoritative {
		return checkAuthoritativeNss(resolved, value, nameservers)
	}

	nsHosts, err := lookupNameservers(resolved, nameservers)
	if err != nil {
		return false, err
	}
	// lookupNameservers yields bare host names; attach the DNS port.
	for idx := range nsHosts {
		nsHosts[idx] = net.JoinHostPort(nsHosts[idx], "53")
	}
	return checkAuthoritativeNss(resolved, value, nsHosts)
}
// checkAuthoritativeNss asks every nameserver in turn for the TXT records of
// fqdn and returns true only when each one serves a record whose joined text
// equals value. The first server that lacks the record yields (false, nil); a
// query failure or an rcode other than NOERROR/NXDOMAIN yields an error.
func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) {
	for _, server := range nameservers {
		resp, err := DNSQuery(fqdn, dns.TypeTXT, []string{server}, true)
		if err != nil {
			return false, err
		}

		switch resp.Rcode {
		case dns.RcodeSuccess, dns.RcodeNameError:
			// NXDomain response is not really an error, just waiting for
			// propagation to happen.
		default:
			return false, fmt.Errorf("NS %s returned %s for %s", server, dns.RcodeToString[resp.Rcode], fqdn)
		}

		logf.V(logf.DebugLevel).Infof("Looking up TXT records for %q", fqdn)

		matched := false
		for _, answer := range resp.Answer {
			txt, isTXT := answer.(*dns.TXT)
			if isTXT && strings.Join(txt.Txt, "") == value {
				matched = true
				break
			}
		}
		if !matched {
			return false, nil
		}
	}
	return true, nil
}
// DNSQuery sends a question for fqdn/rtype to the supplied nameservers and
// returns the first response obtained without error.
//
// Each nameserver must include a port, which facilitates testing against a mock
// DNS server. The request is attempted len(nameservers)+1 times, rotating
// through the list. Every attempt goes over UDP first and is repeated over TCP
// when the UDP answer is truncated or the UDP read times out. When recursive is
// false the Recursion Desired flag is cleared on the question.
//
// An error is returned when nameservers is empty, or when every attempt fails
// (in which case the error of the last attempt is returned).
func DNSQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (in *dns.Msg, err error) {
	// Without this guard the modulo below would divide by zero and panic.
	if len(nameservers) == 0 {
		return nil, fmt.Errorf("cannot query %q: no nameservers were provided", fqdn)
	}

	m := new(dns.Msg)
	m.SetQuestion(fqdn, rtype)
	m.SetEdns0(4096, false)
	if !recursive {
		m.RecursionDesired = false
	}

	udp := &dns.Client{Net: "udp", Timeout: DNSTimeout}
	tcp := &dns.Client{Net: "tcp", Timeout: DNSTimeout}

	// Will retry the request based on the number of servers (n+1)
	for i := 1; i <= len(nameservers)+1; i++ {
		ns := nameservers[i%len(nameservers)]
		in, _, err = udp.Exchange(m, ns)

		truncated := in != nil && in.Truncated
		udpReadTimeout := err != nil &&
			strings.HasPrefix(err.Error(), "read udp") &&
			strings.HasSuffix(err.Error(), "i/o timeout")
		if truncated || udpReadTimeout {
			logf.V(logf.DebugLevel).Infof("UDP dns lookup failed, retrying with TCP: %v", err)
			// If the TCP request succeeds, the err will reset to nil
			in, _, err = tcp.Exchange(m, ns)
		}

		if err == nil {
			break
		}
	}
	return in, err
}
// ValidateCAA checks whether the CAA records that apply to domain allow
// issuance by at least one of the issuers listed in issuerID.
//
// Starting at domain and moving up one label at a time, it queries the
// authoritative nameservers for CAA records, following CNAMEs (at most 8 hops
// per label). The first label that yields any CAA records decides the outcome.
// nil is returned when those records match (see matchCAA) or when the root is
// reached without finding any CAA record. An error is returned when a lookup
// fails, when a server answers with an rcode other than NOERROR/NXDOMAIN, or
// when the records found do not authorise any of the issuers.
//
// iswildcard selects wildcard matching rules; nameservers are the seed
// resolvers used to discover authoritative servers and to follow CNAMEs.
func ValidateCAA(domain string, issuerID []string, iswildcard bool, nameservers []string) error {
	// see https://tools.ietf.org/html/rfc6844#section-4
	// for more information about how CAA lookup is performed
	fqdn := ToFqdn(domain)

	issuerSet := make(map[string]bool, len(issuerID))
	for _, s := range issuerID {
		issuerSet[s] = true
	}

	var caas []*dns.CAA
	for {
		// follow at most 8 cnames per label
		queryDomain := fqdn
		var msg *dns.Msg
		var err error
		for i := 0; i < 8; i++ {
			// usually, we should be able to just ask the local recursive
			// nameserver for CAA records, but some setups will return SERVFAIL
			// on unknown types like CAA. Instead, ask the authoritative server
			var authNS []string
			authNS, err = lookupNameservers(queryDomain, nameservers)
			if err != nil {
				return fmt.Errorf("Could not validate CAA record: %s", err)
			}
			for j, ans := range authNS {
				authNS[j] = net.JoinHostPort(ans, "53")
			}
			msg, err = DNSQuery(queryDomain, dns.TypeCAA, authNS, false)
			if err != nil {
				return fmt.Errorf("Could not validate CAA record: %s", err)
			}
			// domain may not exist, which is fine. It will fail HTTP01 checks
			// but DNS01 checks will create a proper domain
			if msg.Rcode == dns.RcodeNameError {
				break
			}
			if msg.Rcode != dns.RcodeSuccess {
				return fmt.Errorf("Could not validate CAA: Unexpected response code '%s' for %s",
					dns.RcodeToString[msg.Rcode], domain)
			}

			// Assign with '=' so the next iteration queries the CNAME target.
			// Using ':=' here would declare a new, loop-local queryDomain and
			// the outer one would never advance.
			oldQuery := queryDomain
			queryDomain, err = followCNAMEs(oldQuery, nameservers)
			if err != nil {
				return fmt.Errorf("while trying to follow CNAMEs for domain %s using nameservers %v: %w", oldQuery, nameservers, err)
			}
			if queryDomain == oldQuery {
				break
			}
		}

		// we have a response that's not a CNAME. It might be empty.
		// if it is, go up a label and ask again
		for _, rr := range msg.Answer {
			caa, ok := rr.(*dns.CAA)
			if !ok {
				continue
			}
			caas = append(caas, caa)
		}
		// once we've found any CAA records, we use these CAAs
		if len(caas) != 0 {
			break
		}

		index := strings.Index(fqdn, ".")
		if index == -1 {
			// A non-empty fqdn always ends in a dot, so this is not expected to
			// be reachable; report it as an error rather than crashing the
			// caller.
			return fmt.Errorf("Could not validate CAA record: malformed domain name %q", domain)
		}
		fqdn = fqdn[index+1:]
		if len(fqdn) == 0 {
			// we reached the root with no CAA, don't bother asking
			return nil
		}
	}

	if !matchCAA(caas, issuerSet, iswildcard) {
		// TODO(dmo): better error message
		return fmt.Errorf("CAA record does not match issuer")
	}
	return nil
}
// matchCAA reports whether the given CAA records authorise one of the issuers
// in issuerIDs.
//
// For a non-wildcard request, any "issue" record whose value is in issuerIDs
// is a match. For a wildcard request, "issuewild" records take priority: if at
// least one is present, the result is true only when some "issuewild" record
// names an issuer in issuerIDs, regardless of any "issue" records. If no
// "issuewild" record is present, the "issue" records decide.
//
// All records are examined before deciding, so the outcome does not depend on
// the order in which the records were returned.
func matchCAA(caas []*dns.CAA, issuerIDs map[string]bool, iswildcard bool) bool {
	var issueMatch, sawIssuewild, issuewildMatch bool
	for _, caa := range caas {
		switch caa.Tag {
		case issuewildTag:
			// issuewild records are only relevant to wildcard requests.
			if iswildcard {
				sawIssuewild = true
				issuewildMatch = issuewildMatch || issuerIDs[caa.Value]
			}
		case issueTag:
			issueMatch = issueMatch || issuerIDs[caa.Value]
		}
	}

	if iswildcard && sawIssuewild {
		return issuewildMatch
	}
	return issueMatch
}
// lookupNameservers finds the zone that contains fqdn and returns the
// lower-cased host names from that zone's NS records, as answered by the seed
// nameservers. It fails when the zone cannot be determined, when the NS query
// errors, or when the answer holds no NS record.
func lookupNameservers(fqdn string, nameservers []string) ([]string, error) {
	logf.V(logf.DebugLevel).Infof("Searching fqdn %q using seed nameservers [%s]", fqdn, strings.Join(nameservers, ", "))

	zone, err := FindZoneByFqdn(fqdn, nameservers)
	if err != nil {
		return nil, fmt.Errorf("Could not determine the zone for %q: %v", fqdn, err)
	}

	resp, err := DNSQuery(zone, dns.TypeNS, nameservers, true)
	if err != nil {
		return nil, err
	}

	var found []string
	for _, answer := range resp.Answer {
		nsRecord, isNS := answer.(*dns.NS)
		if !isNS {
			continue
		}
		found = append(found, strings.ToLower(nsRecord.Ns))
	}

	if len(found) == 0 {
		return nil, fmt.Errorf("Could not determine authoritative nameservers for %q", fqdn)
	}
	logf.V(logf.DebugLevel).Infof("Returning authoritative nameservers [%s]", strings.Join(found, ", "))
	return found, nil
}
// FindZoneByFqdn returns the zone apex for fqdn. It asks the nameservers for
// the SOA record of fqdn and then of each parent domain in turn, stopping at
// the first response that carries a SOA record in its answer section.
// Successful results are cached per fqdn and served from the cache on later
// calls.
//
// An error is returned when a query fails, when a server answers with an rcode
// other than NOERROR or NXDOMAIN, or when no SOA record is found at any level.
func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) {
	// Serve from the cache when this fqdn has been resolved before.
	fqdnToZoneLock.RLock()
	cachedZone, isCached := fqdnToZone[fqdn]
	fqdnToZoneLock.RUnlock()
	if isCached {
		logf.V(logf.DebugLevel).Infof("Returning cached zone record %q for fqdn %q", cachedZone, fqdn)
		return cachedZone, nil
	}

	// Walk from the full name towards the root, one label at a time. For
	// example, with this tree:
	//
	//   example.com.                                 <- SOA is here.
	//   └── foo.example.com.
	//       └── _acme-challenge.foo.example.com.     <- Starting point.
	//
	// the queries and answers would be:
	//
	//   _acme-challenge.foo.example.com.   NXDOMAIN
	//   foo.example.com.                   NXDOMAIN
	//   example.com.                       NOERROR along with the SOA
	for _, start := range dns.Split(fqdn) {
		candidate := fqdn[start:]

		resp, err := DNSQuery(candidate, dns.TypeSOA, nameservers, true)
		if err != nil {
			return "", err
		}

		switch resp.Rcode {
		case dns.RcodeNameError:
			// NXDOMAIN: we have not climbed far enough yet.
			continue
		case dns.RcodeSuccess:
			// Inspect the answer below.
		default:
			// Any other rcode aborts the search.
			return "", fmt.Errorf("When querying the SOA record for the domain '%s' using nameservers %v, rcode was expected to be 'NOERROR' or 'NXDOMAIN', but got '%s'",
				candidate, nameservers, dns.RcodeToString[resp.Rcode])
		}

		// As per RFC 2181, a CNAME record cannot exist at the root of a zone,
		// so a name that answers with a CNAME cannot be the apex.
		if dnsMsgContainsCNAME(resp) {
			continue
		}

		for _, answer := range resp.Answer {
			soa, isSOA := answer.(*dns.SOA)
			if !isSOA {
				continue
			}
			zone := soa.Hdr.Name

			fqdnToZoneLock.Lock()
			fqdnToZone[fqdn] = zone
			fqdnToZoneLock.Unlock()

			logf.V(logf.DebugLevel).Infof("Returning discovered zone record %q for fqdn %q", zone, fqdn)
			return zone, nil
		}
	}

	return "", fmt.Errorf("Could not find the SOA record in the DNS tree for the domain '%s' using nameservers %v", fqdn, nameservers)
}
// dnsMsgContainsCNAME reports whether the answer section of msg holds at least
// one CNAME record.
func dnsMsgContainsCNAME(msg *dns.Msg) bool {
	for i := range msg.Answer {
		switch msg.Answer[i].(type) {
		case *dns.CNAME:
			return true
		}
	}
	return false
}
// ToFqdn returns name with a trailing dot appended. An empty name, or one that
// already ends in a dot, is returned unchanged.
func ToFqdn(name string) string {
	if name != "" && name[len(name)-1] != '.' {
		return name + "."
	}
	return name
}
// UnFqdn returns name with a single trailing dot removed. A name that is empty
// or does not end in a dot is returned unchanged.
func UnFqdn(name string) string {
	last := len(name) - 1
	if last < 0 || name[last] != '.' {
		return name
	}
	return name[:last]
}
// WaitFor calls f repeatedly, sleeping for interval after each unsuccessful
// call, until f reports true or timeout has elapsed.
//
// It returns nil as soon as f returns true (any error returned alongside is
// ignored). Once the timeout has fired it returns an error that includes the
// text of the most recent non-nil error from f. The timeout is only checked
// between calls, so a call to f that is already running is not interrupted.
func WaitFor(timeout, interval time.Duration, f func() (bool, error)) error {
	deadline := time.After(timeout)
	lastErr := ""

	// expired does a non-blocking check on the timeout channel.
	expired := func() bool {
		select {
		case <-deadline:
			return true
		default:
			return false
		}
	}

	for !expired() {
		done, err := f()
		if done {
			return nil
		}
		if err != nil {
			lastErr = err.Error()
		}
		time.Sleep(interval)
	}
	return fmt.Errorf("Time limit exceeded. Last error: %s", lastErr)
}