diff --git a/acme/dns_challenge.go b/acme/dns_challenge.go index f34fcccc06..63bd08bc9f 100644 --- a/acme/dns_challenge.go +++ b/acme/dns_challenge.go @@ -6,17 +6,18 @@ import ( "errors" "fmt" "log" + "net" "strings" "time" "github.com/miekg/dns" ) -type preCheckDNSFunc func(domain, fqdn string) bool +type preCheckDNSFunc func(domain, fqdn, value string) error -var preCheckDNS preCheckDNSFunc = checkDNS +var preCheckDNS preCheckDNSFunc = checkDnsPropagation -var preCheckDNSFallbackCount = 5 +var recursionMaxDepth = 10 // DNS01Record returns a DNS record which will fulfill the `dns-01` challenge func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) { @@ -60,50 +61,141 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error { } }() - fqdn, _, _ := DNS01Record(domain, keyAuth) + fqdn, value, _ := DNS01Record(domain, keyAuth) - preCheckDNS(domain, fqdn) + logf("[INFO][%s] Checking DNS record propagation...", domain) + + if err = preCheckDNS(domain, fqdn, value); err != nil { + return err + } return s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth}) } -func checkDNS(domain, fqdn string) bool { - // check if the expected DNS entry was created. If not wait for some time and try again. - m := new(dns.Msg) - m.SetQuestion(domain+".", dns.TypeSOA) - c := new(dns.Client) - in, _, err := c.Exchange(m, "google-public-dns-a.google.com:53") +// checkDnsPropagation checks if the expected DNS entry has been propagated to +// all authoritative nameservers. If not it waits and retries for some time. +func checkDnsPropagation(domain, fqdn, value string) error { + authoritativeNss, err := lookupNameservers(toFqdn(domain)) if err != nil { - return false + return err } - var authorativeNS string - for _, answ := range in.Answer { - soa := answ.(*dns.SOA) - authorativeNS = soa.Ns + if err = waitFor(30, 2, func() (bool, error) { + return checkAuthoritativeNss(fqdn, value, authoritativeNss) + }); err != nil { + return err } - fallbackCnt := 0 - for fallbackCnt < preCheckDNSFallbackCount { - m.SetQuestion(fqdn, dns.TypeTXT) - in, _, err = c.Exchange(m, authorativeNS+":53") + return nil +} + +// checkAuthoritativeNss checks whether a TXT record with fqdn and value exists on every given nameserver. +func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) { + for _, ns := range nameservers { + r, err := dnsQuery(fqdn, dns.TypeTXT, net.JoinHostPort(ns, "53")) if err != nil { - return false + return false, err } - if len(in.Answer) > 0 { - return true + if r.Rcode != dns.RcodeSuccess { + return false, fmt.Errorf("%s returned RCode %s", ns, dns.RcodeToString[r.Rcode]) } - fallbackCnt++ - if fallbackCnt >= preCheckDNSFallbackCount { - return false + var found bool + for _, rr := range r.Answer { + if txt, ok := rr.(*dns.TXT); ok { + if strings.Join(txt.Txt, "") == value { + found = true + break + } + } + } + if !found { + return false, fmt.Errorf("%s did not return the expected TXT record", ns) } + } + + return true, nil +} - time.Sleep(time.Second * time.Duration(fallbackCnt)) +// dnsQuery directly queries the given authoritative nameserver. +func dnsQuery(fqdn string, rtype uint16, nameserver string) (in *dns.Msg, err error) { + m := new(dns.Msg) + m.SetQuestion(fqdn, rtype) + m.SetEdns0(4096, false) + m.RecursionDesired = false + + in, err = dns.Exchange(m, nameserver) + if err == dns.ErrTruncated { + tcp := &dns.Client{Net: "tcp"} + in, _, err = tcp.Exchange(m, nameserver) } - return false + return +} + +// lookupNameservers returns the authoritative nameservers for the given domain name. +func lookupNameservers(fqdn string) ([]string, error) { + var referralNameservers []string + + // We start recursion at the gTLD origin + // so we don't have to manage root hints + labels := dns.SplitDomainName(fqdn) + tld := labels[len(labels)-1] + nss, err := net.LookupNS(tld) + if err != nil { + return nil, fmt.Errorf("Could not resolve TLD %s %v", tld, err) + } + + for _, ns := range nss { + referralNameservers = append(referralNameservers, ns.Host) + } + + // Follow the referrals until we hit the authoritative NS + for depth := 0; depth < recursionMaxDepth; depth++ { + var r *dns.Msg + var err error + + for _, ns := range referralNameservers { + r, err = dnsQuery(fqdn, dns.TypeNS, net.JoinHostPort(ns, "53")) + if err != nil { + continue + } + + if r.Rcode == dns.RcodeSuccess { + break + } + + if r.Rcode == dns.RcodeNameError { + return nil, fmt.Errorf("Could not resolve NXDOMAIN %s", fqdn) + } + } + + if r == nil { + break + } + + if r.Authoritative { + // We got an authoritative reply, which means that the + // last referral holds the authoritative nameservers. + return referralNameservers, nil + } + + referralNameservers = nil + + for _, rr := range r.Ns { + if ns, ok := rr.(*dns.NS); ok { + referralNameservers = append(referralNameservers, strings.ToLower(ns.Ns)) + } + } + + // No referrals to follow + if len(referralNameservers) == 0 { + break + } + } + + return nil, fmt.Errorf("Could not determine nameservers for %s", fqdn) } // toFqdn converts the name into a fqdn appending a trailing dot. @@ -124,22 +216,25 @@ func unFqdn(name string) string { return name } -// waitFor polls the given function 'f', once per second, up to 'timeout' seconds. -func waitFor(timeout int, f func() (bool, error)) error { - start := time.Now().Second() +// waitFor polls the given function 'f', once every 'interval' seconds, up to 'timeout' seconds. +func waitFor(timeout, interval int, f func() (bool, error)) error { + var lastErr string + timeup := time.After(time.Duration(timeout) * time.Second) for { - time.Sleep(1 * time.Second) - - if delta := time.Now().Second() - start; delta >= timeout { - return fmt.Errorf("Time limit exceeded (%d seconds)", delta) + select { + case <-timeup: + return fmt.Errorf("Time limit exceeded. Last error: %s", lastErr) + default: } stop, err := f() - if err != nil { - return err - } if stop { return nil } + if err != nil { + lastErr = err.Error() + } + + time.Sleep(time.Duration(interval) * time.Second) } } diff --git a/acme/dns_challenge_route53.go b/acme/dns_challenge_route53.go index 28ed0259a6..491117e2b8 100644 --- a/acme/dns_challenge_route53.go +++ b/acme/dns_challenge_route53.go @@ -68,7 +68,7 @@ func (r *DNSProviderRoute53) changeRecord(action, fqdn, value string, ttl int) e return err } - return waitFor(90, func() (bool, error) { + return waitFor(90, 5, func() (bool, error) { status, err := r.client.GetChange(resp.ChangeInfo.ID) if err != nil { return false, err diff --git a/acme/dns_challenge_test.go b/acme/dns_challenge_test.go index 0af40f71fd..ccc5890134 100644 --- a/acme/dns_challenge_test.go +++ b/acme/dns_challenge_test.go @@ -6,13 +6,79 @@ import ( "net/http" "net/http/httptest" "os" + "reflect" + "sort" + "strings" "testing" "time" ) +var lookupNameserversTestsOK = []struct { + fqdn string + nss []string +}{ + {"books.google.com.ng.", + []string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."}, + }, + {"www.google.com.", + []string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."}, + }, + {"physics.georgetown.edu.", + []string{"ns1.georgetown.edu.", "ns2.georgetown.edu.", "ns3.georgetown.edu."}, + }, +} + +var lookupNameserversTestsErr = []struct { + fqdn string + error string +}{ + // invalid tld + {"_null.n0n0.", + "Could not resolve TLD", + }, + // invalid domain + {"_null.com.", + "Could not resolve NXDOMAIN", + }, + // invalid subdomain + {"_null.google.com.", + "Could not resolve NXDOMAIN", + }, +} + +var checkAuthoritativeNssTests = []struct { + fqdn, value string + ns []string + ok bool +}{ + // TXT RR w/ expected value + {"8.8.8.8.asn.routeviews.org.", "151698.8.8.024", []string{"asnums.routeviews.org."}, + true, + }, + // No TXT RR + {"ns1.google.com.", "", []string{"ns2.google.com."}, + false, + }, +} + +var checkAuthoritativeNssTestsErr = []struct { + fqdn, value string + ns []string + error string +}{ + // TXT RR /w unexpected value + {"8.8.8.8.asn.routeviews.org.", "fe01=", []string{"asnums.routeviews.org."}, + "did not return the expected TXT record", + }, + // No TXT RR + {"ns1.google.com.", "fe01=", []string{"ns2.google.com."}, + "did not return the expected TXT record", + }, +} + func TestDNSValidServerResponse(t *testing.T) { - preCheckDNS = func(domain, fqdn string) bool { - return true + preCheckDNS = func(domain, fqdn, value string) error { + return nil } privKey, _ := generatePrivateKey(rsakey, 512) @@ -39,7 +105,80 @@ func TestDNSValidServerResponse(t *testing.T) { } func TestPreCheckDNS(t *testing.T) { - if !preCheckDNS("api.letsencrypt.org", "acme-staging.api.letsencrypt.org") { + err := preCheckDNS("api.letsencrypt.org", "acme-staging.api.letsencrypt.org", "fe01=") + if err != nil { t.Errorf("preCheckDNS failed for acme-staging.api.letsencrypt.org") } } + +func TestLookupNameserversOK(t *testing.T) { + for _, tt := range lookupNameserversTestsOK { + nss, err := lookupNameservers(tt.fqdn) + if err != nil { + t.Fatalf("#%s: got %q; want nil", tt.fqdn, err) + } + + sort.Strings(nss) + sort.Strings(tt.nss) + + if !reflect.DeepEqual(nss, tt.nss) { + t.Errorf("#%s: got %v; want %v", tt.fqdn, nss, tt.nss) + } + } +} + +func TestLookupNameserversErr(t *testing.T) { + for _, tt := range lookupNameserversTestsErr { + _, err := lookupNameservers(tt.fqdn) + if err == nil { + t.Fatalf("#%s: expected %q (error); got ", tt.fqdn, tt.error) + } + + if !strings.Contains(err.Error(), tt.error) { + t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err) + continue + } + } +} + +func TestCheckAuthoritativeNss(t *testing.T) { + for _, tt := range checkAuthoritativeNssTests { + ok, _ := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns) + if ok != tt.ok { + t.Errorf("#%s: got %t; want %t", tt.fqdn, tt.ok) + } + } +} + +func TestCheckAuthoritativeNssErr(t *testing.T) { + for _, tt := range checkAuthoritativeNssTestsErr { + _, err := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns) + if err == nil { + t.Fatalf("#%s: expected %q (error); got ", tt.fqdn, tt.error) + } + if !strings.Contains(err.Error(), tt.error) { + t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err) + continue + } + } +} + +func TestWaitForTimeout(t *testing.T) { + c := make(chan error) + go func() { + err := waitFor(3, 1, func() (bool, error) { + return false, nil + }) + c <- err + }() + + timeout := time.After(4 * time.Second) + select { + case <-timeout: + t.Fatal("timeout exceeded") + case err := <-c: + if err == nil { + t.Errorf("expected timeout error; got ", err) + } + } +}