Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor DNS check #96

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
171 changes: 133 additions & 38 deletions acme/dns_challenge.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@ import (
"errors"
"fmt"
"log"
"net"
"strings"
"time"

"github.com/miekg/dns"
)

type preCheckDNSFunc func(domain, fqdn string) bool
type preCheckDNSFunc func(domain, fqdn, value string) error

var preCheckDNS preCheckDNSFunc = checkDNS
var preCheckDNS preCheckDNSFunc = checkDnsPropagation

var preCheckDNSFallbackCount = 5
var recursionMaxDepth = 10

// DNS01Record returns a DNS record which will fulfill the `dns-01` challenge
func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) {
Expand Down Expand Up @@ -60,50 +61,141 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error {
}
}()

fqdn, _, _ := DNS01Record(domain, keyAuth)
fqdn, value, _ := DNS01Record(domain, keyAuth)

preCheckDNS(domain, fqdn)
logf("[INFO][%s] Checking DNS record propagation...", domain)

if err = preCheckDNS(domain, fqdn, value); err != nil {
return err
}

return s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
}

func checkDNS(domain, fqdn string) bool {
// check if the expected DNS entry was created. If not wait for some time and try again.
m := new(dns.Msg)
m.SetQuestion(domain+".", dns.TypeSOA)
c := new(dns.Client)
in, _, err := c.Exchange(m, "google-public-dns-a.google.com:53")
// checkDnsPropagation checks if the expected DNS entry has been propagated to
// all authoritative nameservers. If not it waits and retries for some time.
func checkDnsPropagation(domain, fqdn, value string) error {
authoritativeNss, err := lookupNameservers(toFqdn(domain))
if err != nil {
return false
return err
}

var authorativeNS string
for _, answ := range in.Answer {
soa := answ.(*dns.SOA)
authorativeNS = soa.Ns
if err = waitFor(30, 2, func() (bool, error) {
return checkAuthoritativeNss(fqdn, value, authoritativeNss)
}); err != nil {
return err
}

fallbackCnt := 0
for fallbackCnt < preCheckDNSFallbackCount {
m.SetQuestion(fqdn, dns.TypeTXT)
in, _, err = c.Exchange(m, authorativeNS+":53")
return nil
}

// checkAuthoritativeNss checks whether a TXT record with fqdn and value exists on every given nameserver.
func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) {
for _, ns := range nameservers {
r, err := dnsQuery(fqdn, dns.TypeTXT, net.JoinHostPort(ns, "53"))
if err != nil {
return false
return false, err
}

if len(in.Answer) > 0 {
return true
if r.Rcode != dns.RcodeSuccess {
return false, fmt.Errorf("%s returned RCode %s", ns, dns.RcodeToString[r.Rcode])
}

fallbackCnt++
if fallbackCnt >= preCheckDNSFallbackCount {
return false
var found bool
for _, rr := range r.Answer {
if txt, ok := rr.(*dns.TXT); ok {
if strings.Join(txt.Txt, "") == value {
found = true
break
}
}
}
if !found {
return false, fmt.Errorf("%s did not return the expected TXT record", ns)
}
}

return true, nil
}

time.Sleep(time.Second * time.Duration(fallbackCnt))
// dnsQuery directly queries the given authoritative nameserver.
func dnsQuery(fqdn string, rtype uint16, nameserver string) (in *dns.Msg, err error) {
m := new(dns.Msg)
m.SetQuestion(fqdn, rtype)
m.SetEdns0(4096, false)
m.RecursionDesired = false

in, err = dns.Exchange(m, nameserver)
if err == dns.ErrTruncated {
tcp := &dns.Client{Net: "tcp"}
in, _, err = tcp.Exchange(m, nameserver)
}

return false
return
}

// lookupNameservers returns the authoritative nameservers for the given domain name.
func lookupNameservers(fqdn string) ([]string, error) {
var referralNameservers []string

// We start recursion at the gTLD origin
// so we don't have to manage root hints
labels := dns.SplitDomainName(fqdn)
tld := labels[len(labels)-1]
nss, err := net.LookupNS(tld)
if err != nil {
return nil, fmt.Errorf("Could not resolve TLD %s %v", tld, err)
}

for _, ns := range nss {
referralNameservers = append(referralNameservers, ns.Host)
}

// Follow the referrals until we hit the authoritative NS
for depth := 0; depth < recursionMaxDepth; depth++ {
var r *dns.Msg
var err error

for _, ns := range referralNameservers {
r, err = dnsQuery(fqdn, dns.TypeNS, net.JoinHostPort(ns, "53"))
if err != nil {
continue
}

if r.Rcode == dns.RcodeSuccess {
break
}

if r.Rcode == dns.RcodeNameError {
return nil, fmt.Errorf("Could not resolve NXDOMAIN %s", fqdn)
}
}

if r == nil {
break
}

if r.Authoritative {
// We got an authoritative reply, which means that the
// last referral holds the authoritative nameservers.
return referralNameservers, nil
}

referralNameservers = nil

for _, rr := range r.Ns {
if ns, ok := rr.(*dns.NS); ok {
referralNameservers = append(referralNameservers, strings.ToLower(ns.Ns))
}
}

// No referrals to follow
if len(referralNameservers) == 0 {
break
}
}

return nil, fmt.Errorf("Could not determine nameservers for %s", fqdn)
}

// toFqdn converts the name into a fqdn appending a trailing dot.
Expand All @@ -124,22 +216,25 @@ func unFqdn(name string) string {
return name
}

// waitFor polls the given function 'f', once per second, up to 'timeout' seconds.
func waitFor(timeout int, f func() (bool, error)) error {
start := time.Now().Second()
// waitFor polls the given function 'f', once every 'interval' seconds, up to 'timeout' seconds.
func waitFor(timeout, interval int, f func() (bool, error)) error {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just thinking out loud here... but maybe we should add a util.go for stuff like this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was on my mind too! 😄 But maybe we can do the code re-organization later in a separate patch?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I agree... out of scope of this PR :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have any other stuff on your mind that ought to move to the util.go?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, limitReader for example. I'm also thinking about crypto.go as there are multiple functions in there which are mere utility functions.

var lastErr string
timeup := time.After(time.Duration(timeout) * time.Second)
for {
time.Sleep(1 * time.Second)

if delta := time.Now().Second() - start; delta >= timeout {
return fmt.Errorf("Time limit exceeded (%d seconds)", delta)
select {
case <-timeup:
return fmt.Errorf("Time limit exceeded. Last error: %s", lastErr)
default:
}

stop, err := f()
if err != nil {
return err
}
if stop {
return nil
}
if err != nil {
lastErr = err.Error()
}

time.Sleep(time.Duration(interval) * time.Second)
}
}
2 changes: 1 addition & 1 deletion acme/dns_challenge_route53.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (r *DNSProviderRoute53) changeRecord(action, fqdn, value string, ttl int) e
return err
}

return waitFor(90, func() (bool, error) {
return waitFor(90, 5, func() (bool, error) {
status, err := r.client.GetChange(resp.ChangeInfo.ID)
if err != nil {
return false, err
Expand Down
145 changes: 142 additions & 3 deletions acme/dns_challenge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,79 @@ import (
"net/http"
"net/http/httptest"
"os"
"reflect"
"sort"
"strings"
"testing"
"time"
)

var lookupNameserversTestsOK = []struct {
fqdn string
nss []string
}{
{"books.google.com.ng.",
[]string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."},
},
{"www.google.com.",
[]string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."},
},
{"physics.georgetown.edu.",
[]string{"ns1.georgetown.edu.", "ns2.georgetown.edu.", "ns3.georgetown.edu."},
},
}

var lookupNameserversTestsErr = []struct {
fqdn string
error string
}{
// invalid tld
{"_null.n0n0.",
"Could not resolve TLD",
},
// invalid domain
{"_null.com.",
"Could not resolve NXDOMAIN",
},
// invalid subdomain
{"_null.google.com.",
"Could not resolve NXDOMAIN",
},
}

var checkAuthoritativeNssTests = []struct {
fqdn, value string
ns []string
ok bool
}{
// TXT RR w/ expected value
{"8.8.8.8.asn.routeviews.org.", "151698.8.8.024", []string{"asnums.routeviews.org."},
true,
},
// No TXT RR
{"ns1.google.com.", "", []string{"ns2.google.com."},
false,
},
}

var checkAuthoritativeNssTestsErr = []struct {
fqdn, value string
ns []string
error string
}{
// TXT RR /w unexpected value
{"8.8.8.8.asn.routeviews.org.", "fe01=", []string{"asnums.routeviews.org."},
"did not return the expected TXT record",
},
// No TXT RR
{"ns1.google.com.", "fe01=", []string{"ns2.google.com."},
"did not return the expected TXT record",
},
}

func TestDNSValidServerResponse(t *testing.T) {
preCheckDNS = func(domain, fqdn string) bool {
return true
preCheckDNS = func(domain, fqdn, value string) error {
return nil
}
privKey, _ := generatePrivateKey(rsakey, 512)

Expand All @@ -39,7 +105,80 @@ func TestDNSValidServerResponse(t *testing.T) {
}

func TestPreCheckDNS(t *testing.T) {
if !preCheckDNS("api.letsencrypt.org", "acme-staging.api.letsencrypt.org") {
err := preCheckDNS("api.letsencrypt.org", "acme-staging.api.letsencrypt.org", "fe01=")
if err != nil {
t.Errorf("preCheckDNS failed for acme-staging.api.letsencrypt.org")
}
}

func TestLookupNameserversOK(t *testing.T) {
for _, tt := range lookupNameserversTestsOK {
nss, err := lookupNameservers(tt.fqdn)
if err != nil {
t.Fatalf("#%s: got %q; want nil", tt.fqdn, err)
}

sort.Strings(nss)
sort.Strings(tt.nss)

if !reflect.DeepEqual(nss, tt.nss) {
t.Errorf("#%s: got %v; want %v", tt.fqdn, nss, tt.nss)
}
}
}

func TestLookupNameserversErr(t *testing.T) {
for _, tt := range lookupNameserversTestsErr {
_, err := lookupNameservers(tt.fqdn)
if err == nil {
t.Fatalf("#%s: expected %q (error); got <nil>", tt.fqdn, tt.error)
}

if !strings.Contains(err.Error(), tt.error) {
t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err)
continue
}
}
}

func TestCheckAuthoritativeNss(t *testing.T) {
for _, tt := range checkAuthoritativeNssTests {
ok, _ := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns)
if ok != tt.ok {
t.Errorf("#%s: got %t; want %t", tt.fqdn, tt.ok)
}
}
}

func TestCheckAuthoritativeNssErr(t *testing.T) {
for _, tt := range checkAuthoritativeNssTestsErr {
_, err := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns)
if err == nil {
t.Fatalf("#%s: expected %q (error); got <nil>", tt.fqdn, tt.error)
}
if !strings.Contains(err.Error(), tt.error) {
t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err)
continue
}
}
}

func TestWaitForTimeout(t *testing.T) {
c := make(chan error)
go func() {
err := waitFor(3, 1, func() (bool, error) {
return false, nil
})
c <- err
}()

timeout := time.After(4 * time.Second)
select {
case <-timeout:
t.Fatal("timeout exceeded")
case err := <-c:
if err == nil {
t.Errorf("expected timeout error; got <nil>", err)
}
}
}