From a3f60f0f756680fe58b39a227e69250aee804edc Mon Sep 17 00:00:00 2001 From: Jan Broer Date: Wed, 3 Feb 2016 05:03:03 +0100 Subject: [PATCH] Refactor DNS check * Implements a (semi-)recursive query mechanism to find all authoritative nameservers for the zone enclosing the domain. This is necessary because Google Public DNS does sometimes not return any NS at all or those of a different zone (e.g. when the domain is a CNAME pointing elsewhere). * All authoritative nameservers are then queried for the TXT record to make sure that it has been completely propagated. * The waitFor utility function now takes an interval parameter and returns the last seen error on timeout. --- acme/dns_challenge.go | 171 ++++++++++++++++++++++++++-------- acme/dns_challenge_route53.go | 2 +- acme/dns_challenge_test.go | 145 +++++++++++++++++++++++++++- 3 files changed, 276 insertions(+), 42 deletions(-) diff --git a/acme/dns_challenge.go b/acme/dns_challenge.go index f34fcccc06..63bd08bc9f 100644 --- a/acme/dns_challenge.go +++ b/acme/dns_challenge.go @@ -6,17 +6,18 @@ import ( "errors" "fmt" "log" + "net" "strings" "time" "github.com/miekg/dns" ) -type preCheckDNSFunc func(domain, fqdn string) bool +type preCheckDNSFunc func(domain, fqdn, value string) error -var preCheckDNS preCheckDNSFunc = checkDNS +var preCheckDNS preCheckDNSFunc = checkDnsPropagation -var preCheckDNSFallbackCount = 5 +var recursionMaxDepth = 10 // DNS01Record returns a DNS record which will fulfill the `dns-01` challenge func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) { @@ -60,50 +61,141 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error { } }() - fqdn, _, _ := DNS01Record(domain, keyAuth) + fqdn, value, _ := DNS01Record(domain, keyAuth) - preCheckDNS(domain, fqdn) + logf("[INFO][%s] Checking DNS record propagation...", domain) + + if err = preCheckDNS(domain, fqdn, value); err != nil { + return err + } return s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth}) } -func checkDNS(domain, fqdn string) bool { - // check if the expected DNS entry was created. If not wait for some time and try again. - m := new(dns.Msg) - m.SetQuestion(domain+".", dns.TypeSOA) - c := new(dns.Client) - in, _, err := c.Exchange(m, "google-public-dns-a.google.com:53") +// checkDnsPropagation checks if the expected DNS entry has been propagated to +// all authoritative nameservers. If not it waits and retries for some time. +func checkDnsPropagation(domain, fqdn, value string) error { + authoritativeNss, err := lookupNameservers(toFqdn(domain)) if err != nil { - return false + return err } - var authorativeNS string - for _, answ := range in.Answer { - soa := answ.(*dns.SOA) - authorativeNS = soa.Ns + if err = waitFor(30, 2, func() (bool, error) { + return checkAuthoritativeNss(fqdn, value, authoritativeNss) + }); err != nil { + return err } - fallbackCnt := 0 - for fallbackCnt < preCheckDNSFallbackCount { - m.SetQuestion(fqdn, dns.TypeTXT) - in, _, err = c.Exchange(m, authorativeNS+":53") + return nil +} + +// checkAuthoritativeNss checks whether a TXT record with fqdn and value exists on every given nameserver. +func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) { + for _, ns := range nameservers { + r, err := dnsQuery(fqdn, dns.TypeTXT, net.JoinHostPort(ns, "53")) if err != nil { - return false + return false, err } - if len(in.Answer) > 0 { - return true + if r.Rcode != dns.RcodeSuccess { + return false, fmt.Errorf("%s returned RCode %s", ns, dns.RcodeToString[r.Rcode]) } - fallbackCnt++ - if fallbackCnt >= preCheckDNSFallbackCount { - return false + var found bool + for _, rr := range r.Answer { + if txt, ok := rr.(*dns.TXT); ok { + if strings.Join(txt.Txt, "") == value { + found = true + break + } + } + } + if !found { + return false, fmt.Errorf("%s did not return the expected TXT record", ns) } + } + + return true, nil +} - time.Sleep(time.Second * time.Duration(fallbackCnt)) +// dnsQuery directly queries the given authoritative nameserver. +func dnsQuery(fqdn string, rtype uint16, nameserver string) (in *dns.Msg, err error) { + m := new(dns.Msg) + m.SetQuestion(fqdn, rtype) + m.SetEdns0(4096, false) + m.RecursionDesired = false + + in, err = dns.Exchange(m, nameserver) + if err == dns.ErrTruncated { + tcp := &dns.Client{Net: "tcp"} + in, _, err = tcp.Exchange(m, nameserver) } - return false + return +} + +// lookupNameservers returns the authoritative nameservers for the given domain name. +func lookupNameservers(fqdn string) ([]string, error) { + var referralNameservers []string + + // We start recursion at the gTLD origin + // so we don't have to manage root hints + labels := dns.SplitDomainName(fqdn) + tld := labels[len(labels)-1] + nss, err := net.LookupNS(tld) + if err != nil { + return nil, fmt.Errorf("Could not resolve TLD %s %v", tld, err) + } + + for _, ns := range nss { + referralNameservers = append(referralNameservers, ns.Host) + } + + // Follow the referrals until we hit the authoritative NS + for depth := 0; depth < recursionMaxDepth; depth++ { + var r *dns.Msg + var err error + + for _, ns := range referralNameservers { + r, err = dnsQuery(fqdn, dns.TypeNS, net.JoinHostPort(ns, "53")) + if err != nil { + continue + } + + if r.Rcode == dns.RcodeSuccess { + break + } + + if r.Rcode == dns.RcodeNameError { + return nil, fmt.Errorf("Could not resolve NXDOMAIN %s", fqdn) + } + } + + if r == nil { + break + } + + if r.Authoritative { + // We got an authoritative reply, which means that the + // last referral holds the authoritative nameservers. + return referralNameservers, nil + } + + referralNameservers = nil + + for _, rr := range r.Ns { + if ns, ok := rr.(*dns.NS); ok { + referralNameservers = append(referralNameservers, strings.ToLower(ns.Ns)) + } + } + + // No referrals to follow + if len(referralNameservers) == 0 { + break + } + } + + return nil, fmt.Errorf("Could not determine nameservers for %s", fqdn) } // toFqdn converts the name into a fqdn appending a trailing dot. @@ -124,22 +216,25 @@ func unFqdn(name string) string { return name } -// waitFor polls the given function 'f', once per second, up to 'timeout' seconds. -func waitFor(timeout int, f func() (bool, error)) error { - start := time.Now().Second() +// waitFor polls the given function 'f', once every 'interval' seconds, up to 'timeout' seconds. +func waitFor(timeout, interval int, f func() (bool, error)) error { + var lastErr string + timeup := time.After(time.Duration(timeout) * time.Second) for { - time.Sleep(1 * time.Second) - - if delta := time.Now().Second() - start; delta >= timeout { - return fmt.Errorf("Time limit exceeded (%d seconds)", delta) + select { + case <-timeup: + return fmt.Errorf("Time limit exceeded. Last error: %s", lastErr) + default: } stop, err := f() - if err != nil { - return err - } if stop { return nil } + if err != nil { + lastErr = err.Error() + } + + time.Sleep(time.Duration(interval) * time.Second) } } diff --git a/acme/dns_challenge_route53.go b/acme/dns_challenge_route53.go index 28ed0259a6..491117e2b8 100644 --- a/acme/dns_challenge_route53.go +++ b/acme/dns_challenge_route53.go @@ -68,7 +68,7 @@ func (r *DNSProviderRoute53) changeRecord(action, fqdn, value string, ttl int) e return err } - return waitFor(90, func() (bool, error) { + return waitFor(90, 5, func() (bool, error) { status, err := r.client.GetChange(resp.ChangeInfo.ID) if err != nil { return false, err diff --git a/acme/dns_challenge_test.go b/acme/dns_challenge_test.go index 0af40f71fd..ccc5890134 100644 --- a/acme/dns_challenge_test.go +++ b/acme/dns_challenge_test.go @@ -6,13 +6,79 @@ import ( "net/http" "net/http/httptest" "os" + "reflect" + "sort" + "strings" "testing" "time" ) +var lookupNameserversTestsOK = []struct { + fqdn string + nss []string +}{ + {"books.google.com.ng.", + []string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."}, + }, + {"www.google.com.", + []string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."}, + }, + {"physics.georgetown.edu.", + []string{"ns1.georgetown.edu.", "ns2.georgetown.edu.", "ns3.georgetown.edu."}, + }, +} + +var lookupNameserversTestsErr = []struct { + fqdn string + error string +}{ + // invalid tld + {"_null.n0n0.", + "Could not resolve TLD", + }, + // invalid domain + {"_null.com.", + "Could not resolve NXDOMAIN", + }, + // invalid subdomain + {"_null.google.com.", + "Could not resolve NXDOMAIN", + }, +} + +var checkAuthoritativeNssTests = []struct { + fqdn, value string + ns []string + ok bool +}{ + // TXT RR w/ expected value + {"8.8.8.8.asn.routeviews.org.", "151698.8.8.024", []string{"asnums.routeviews.org."}, + true, + }, + // No TXT RR + {"ns1.google.com.", "", []string{"ns2.google.com."}, + false, + }, +} + +var checkAuthoritativeNssTestsErr = []struct { + fqdn, value string + ns []string + error string +}{ + // TXT RR /w unexpected value + {"8.8.8.8.asn.routeviews.org.", "fe01=", []string{"asnums.routeviews.org."}, + "did not return the expected TXT record", + }, + // No TXT RR + {"ns1.google.com.", "fe01=", []string{"ns2.google.com."}, + "did not return the expected TXT record", + }, +} + func TestDNSValidServerResponse(t *testing.T) { - preCheckDNS = func(domain, fqdn string) bool { - return true + preCheckDNS = func(domain, fqdn, value string) error { + return nil } privKey, _ := generatePrivateKey(rsakey, 512) @@ -39,7 +105,80 @@ func TestDNSValidServerResponse(t *testing.T) { } func TestPreCheckDNS(t *testing.T) { - if !preCheckDNS("api.letsencrypt.org", "acme-staging.api.letsencrypt.org") { + err := preCheckDNS("api.letsencrypt.org", "acme-staging.api.letsencrypt.org", "fe01=") + if err != nil { t.Errorf("preCheckDNS failed for acme-staging.api.letsencrypt.org") } } + +func TestLookupNameserversOK(t *testing.T) { + for _, tt := range lookupNameserversTestsOK { + nss, err := lookupNameservers(tt.fqdn) + if err != nil { + t.Fatalf("#%s: got %q; want nil", tt.fqdn, err) + } + + sort.Strings(nss) + sort.Strings(tt.nss) + + if !reflect.DeepEqual(nss, tt.nss) { + t.Errorf("#%s: got %v; want %v", tt.fqdn, nss, tt.nss) + } + } +} + +func TestLookupNameserversErr(t *testing.T) { + for _, tt := range lookupNameserversTestsErr { + _, err := lookupNameservers(tt.fqdn) + if err == nil { + t.Fatalf("#%s: expected %q (error); got ", tt.fqdn, tt.error) + } + + if !strings.Contains(err.Error(), tt.error) { + t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err) + continue + } + } +} + +func TestCheckAuthoritativeNss(t *testing.T) { + for _, tt := range checkAuthoritativeNssTests { + ok, _ := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns) + if ok != tt.ok { + t.Errorf("#%s: got %t; want %t", tt.fqdn, tt.ok) + } + } +} + +func TestCheckAuthoritativeNssErr(t *testing.T) { + for _, tt := range checkAuthoritativeNssTestsErr { + _, err := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns) + if err == nil { + t.Fatalf("#%s: expected %q (error); got ", tt.fqdn, tt.error) + } + if !strings.Contains(err.Error(), tt.error) { + t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err) + continue + } + } +} + +func TestWaitForTimeout(t *testing.T) { + c := make(chan error) + go func() { + err := waitFor(3, 1, func() (bool, error) { + return false, nil + }) + c <- err + }() + + timeout := time.After(4 * time.Second) + select { + case <-timeout: + t.Fatal("timeout exceeded") + case err := <-c: + if err == nil { + t.Errorf("expected timeout error; got ", err) + } + } +}