Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor DNS check #116

Merged
merged 2 commits into from
Feb 11, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
155 changes: 117 additions & 38 deletions acme/dns_challenge.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@ import (
"errors"
"fmt"
"log"
"net"
"strings"
"time"

"github.com/miekg/dns"
)

type preCheckDNSFunc func(domain, fqdn string) bool
type preCheckDNSFunc func(fqdn, value string) (bool, error)

var preCheckDNS preCheckDNSFunc = checkDNS
var preCheckDNS preCheckDNSFunc = checkDnsPropagation

var preCheckDNSFallbackCount = 5
var recursiveNameserver = "google-public-dns-a.google.com"

// DNS01Record returns a DNS record which will fulfill the `dns-01` challenge
func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) {
Expand Down Expand Up @@ -60,50 +61,125 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error {
}
}()

fqdn, _, _ := DNS01Record(domain, keyAuth)
fqdn, value, _ := DNS01Record(domain, keyAuth)

preCheckDNS(domain, fqdn)
logf("[INFO][%s] Checking DNS record propagation...", domain)

err = waitFor(30, 2, func() (bool, error) {
return preCheckDNS(fqdn, value)
})
if err != nil {
return err
}

return s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
}

func checkDNS(domain, fqdn string) bool {
// check if the expected DNS entry was created. If not wait for some time and try again.
m := new(dns.Msg)
m.SetQuestion(domain+".", dns.TypeSOA)
c := new(dns.Client)
in, _, err := c.Exchange(m, "google-public-dns-a.google.com:53")
// checkDnsPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
func checkDnsPropagation(fqdn, value string) (bool, error) {
// Initial attempt to resolve at the recursive NS
r, err := dnsQuery(fqdn, dns.TypeTXT, recursiveNameserver, true)
if err != nil {
return false
return false, err
}
if r.Rcode != dns.RcodeSuccess {
return false, fmt.Errorf("Could not resolve %s -> %s", fqdn, dns.RcodeToString[r.Rcode])
}

var authorativeNS string
for _, answ := range in.Answer {
soa := answ.(*dns.SOA)
authorativeNS = soa.Ns
// If we see a CNAME here then use the alias
for _, rr := range r.Answer {
if cn, ok := rr.(*dns.CNAME); ok {
if cn.Hdr.Name == fqdn {
fqdn = cn.Target
break
}
}
}

fallbackCnt := 0
for fallbackCnt < preCheckDNSFallbackCount {
m.SetQuestion(fqdn, dns.TypeTXT)
in, _, err = c.Exchange(m, authorativeNS+":53")
authoritativeNss, err := lookupNameservers(fqdn)
if err != nil {
return false, err
}

return checkAuthoritativeNss(fqdn, value, authoritativeNss)
}

// checkAuthoritativeNss queries each of the given nameservers for the expected TXT record.
func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) {
for _, ns := range nameservers {
r, err := dnsQuery(fqdn, dns.TypeTXT, ns, false)
if err != nil {
return false
return false, err
}

if len(in.Answer) > 0 {
return true
if r.Rcode != dns.RcodeSuccess {
return false, fmt.Errorf("NS %s returned %s for %s", ns, dns.RcodeToString[r.Rcode], fqdn)
}

fallbackCnt++
if fallbackCnt >= preCheckDNSFallbackCount {
return false
var found bool
for _, rr := range r.Answer {
if txt, ok := rr.(*dns.TXT); ok {
if strings.Join(txt.Txt, "") == value {
found = true
break
}
}
}

time.Sleep(time.Second * time.Duration(fallbackCnt))
if !found {
return false, fmt.Errorf("NS %s did not return the expected TXT record", ns)
}
}

return false
return true, nil
}

// dnsQuery sends a DNS query to the given nameserver.
func dnsQuery(fqdn string, rtype uint16, nameserver string, recursive bool) (in *dns.Msg, err error) {
m := new(dns.Msg)
m.SetQuestion(fqdn, rtype)
m.SetEdns0(4096, false)

if !recursive {
m.RecursionDesired = false
}

in, err = dns.Exchange(m, net.JoinHostPort(nameserver, "53"))
if err == dns.ErrTruncated {
tcp := &dns.Client{Net: "tcp"}
in, _, err = tcp.Exchange(m, nameserver)
}

return
}

// lookupNameservers returns the authoritative nameservers for the given fqdn.
func lookupNameservers(fqdn string) ([]string, error) {
var authoritativeNss []string

r, err := dnsQuery(fqdn, dns.TypeNS, recursiveNameserver, true)
if err != nil {
return nil, err
}

for _, rr := range r.Answer {
if ns, ok := rr.(*dns.NS); ok {
authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns))
}
}

if len(authoritativeNss) > 0 {
return authoritativeNss, nil
}

// Strip of the left most label to get the parent domain.
offset, _ := dns.NextLabel(fqdn, 0)
next := fqdn[offset:]
if dns.CountLabel(next) < 2 {
return nil, fmt.Errorf("Could not determine authoritative nameservers")
}

return lookupNameservers(next)
}

// toFqdn converts the name into a fqdn appending a trailing dot.
Expand All @@ -124,22 +200,25 @@ func unFqdn(name string) string {
return name
}

// waitFor polls the given function 'f', once per second, up to 'timeout' seconds.
func waitFor(timeout int, f func() (bool, error)) error {
start := time.Now().Second()
// waitFor polls the given function 'f', once every 'interval' seconds, up to 'timeout' seconds.
func waitFor(timeout, interval int, f func() (bool, error)) error {
var lastErr string
timeup := time.After(time.Duration(timeout) * time.Second)
for {
time.Sleep(1 * time.Second)

if delta := time.Now().Second() - start; delta >= timeout {
return fmt.Errorf("Time limit exceeded (%d seconds)", delta)
select {
case <-timeup:
return fmt.Errorf("Time limit exceeded. Last error: %s", lastErr)
default:
}

stop, err := f()
if err != nil {
return err
}
if stop {
return nil
}
if err != nil {
lastErr = err.Error()
}

time.Sleep(time.Duration(interval) * time.Second)
}
}
2 changes: 1 addition & 1 deletion acme/dns_challenge_route53.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (r *DNSProviderRoute53) changeRecord(action, fqdn, value string, ttl int) e
return err
}

return waitFor(90, func() (bool, error) {
return waitFor(90, 5, func() (bool, error) {
status, err := r.client.GetChange(resp.ChangeInfo.ID)
if err != nil {
return false, err
Expand Down
141 changes: 138 additions & 3 deletions acme/dns_challenge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,75 @@ import (
"net/http"
"net/http/httptest"
"os"
"reflect"
"sort"
"strings"
"testing"
"time"
)

var lookupNameserversTestsOK = []struct {
fqdn string
nss []string
}{
{"books.google.com.ng.",
[]string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."},
},
{"www.google.com.",
[]string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."},
},
{"physics.georgetown.edu.",
[]string{"ns1.georgetown.edu.", "ns2.georgetown.edu.", "ns3.georgetown.edu."},
},
}

var lookupNameserversTestsErr = []struct {
fqdn string
error string
}{
// invalid tld
{"_null.n0n0.",
"Could not determine authoritative nameservers",
},
// invalid domain
{"_null.com.",
"Could not determine authoritative nameservers",
},
}

var checkAuthoritativeNssTests = []struct {
fqdn, value string
ns []string
ok bool
}{
// TXT RR w/ expected value
{"8.8.8.8.asn.routeviews.org.", "151698.8.8.024", []string{"asnums.routeviews.org."},
true,
},
// No TXT RR
{"ns1.google.com.", "", []string{"ns2.google.com."},
false,
},
}

var checkAuthoritativeNssTestsErr = []struct {
fqdn, value string
ns []string
error string
}{
// TXT RR /w unexpected value
{"8.8.8.8.asn.routeviews.org.", "fe01=", []string{"asnums.routeviews.org."},
"did not return the expected TXT record",
},
// No TXT RR
{"ns1.google.com.", "fe01=", []string{"ns2.google.com."},
"did not return the expected TXT record",
},
}

func TestDNSValidServerResponse(t *testing.T) {
preCheckDNS = func(domain, fqdn string) bool {
return true
preCheckDNS = func(fqdn, value string) (bool, error) {
return true, nil
}
privKey, _ := generatePrivateKey(rsakey, 512)

Expand All @@ -39,7 +101,80 @@ func TestDNSValidServerResponse(t *testing.T) {
}

func TestPreCheckDNS(t *testing.T) {
if !preCheckDNS("api.letsencrypt.org", "acme-staging.api.letsencrypt.org") {
ok, err := preCheckDNS("acme-staging.api.letsencrypt.org", "fe01=")
if err != nil || !ok {
t.Errorf("preCheckDNS failed for acme-staging.api.letsencrypt.org")
}
}

func TestLookupNameserversOK(t *testing.T) {
for _, tt := range lookupNameserversTestsOK {
nss, err := lookupNameservers(tt.fqdn)
if err != nil {
t.Fatalf("#%s: got %q; want nil", tt.fqdn, err)
}

sort.Strings(nss)
sort.Strings(tt.nss)

if !reflect.DeepEqual(nss, tt.nss) {
t.Errorf("#%s: got %v; want %v", tt.fqdn, nss, tt.nss)
}
}
}

func TestLookupNameserversErr(t *testing.T) {
for _, tt := range lookupNameserversTestsErr {
_, err := lookupNameservers(tt.fqdn)
if err == nil {
t.Fatalf("#%s: expected %q (error); got <nil>", tt.fqdn, tt.error)
}

if !strings.Contains(err.Error(), tt.error) {
t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err)
continue
}
}
}

func TestCheckAuthoritativeNss(t *testing.T) {
for _, tt := range checkAuthoritativeNssTests {
ok, _ := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns)
if ok != tt.ok {
t.Errorf("#%s: got %t; want %t", tt.fqdn, tt.ok)
}
}
}

func TestCheckAuthoritativeNssErr(t *testing.T) {
for _, tt := range checkAuthoritativeNssTestsErr {
_, err := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns)
if err == nil {
t.Fatalf("#%s: expected %q (error); got <nil>", tt.fqdn, tt.error)
}
if !strings.Contains(err.Error(), tt.error) {
t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err)
continue
}
}
}

func TestWaitForTimeout(t *testing.T) {
c := make(chan error)
go func() {
err := waitFor(3, 1, func() (bool, error) {
return false, nil
})
c <- err
}()

timeout := time.After(4 * time.Second)
select {
case <-timeout:
t.Fatal("timeout exceeded")
case err := <-c:
if err == nil {
t.Errorf("expected timeout error; got <nil>", err)
}
}
}