diff --git a/xfr.go b/xfr.go index 05b3c5add..91080109e 100644 --- a/xfr.go +++ b/xfr.go @@ -1,6 +1,7 @@ package dns import ( + "crypto/tls" "fmt" "time" ) @@ -20,6 +21,7 @@ type Transfer struct { TsigProvider TsigProvider // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations. TsigSecret map[string]string // Secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) tsigTimersOnly bool + TLS *tls.Config // TLS config. If Xfr over TLS will be attempted } func (t *Transfer) tsigProvider() TsigProvider { @@ -57,7 +59,11 @@ func (t *Transfer) In(q *Msg, a string) (env chan *Envelope, err error) { } if t.Conn == nil { - t.Conn, err = DialTimeout("tcp", a, timeout) + if t.TLS != nil { + t.Conn, err = DialTimeoutWithTLS("tcp-tls", a, t.TLS, timeout) + } else { + t.Conn, err = DialTimeout("tcp", a, timeout) + } if err != nil { return nil, err } diff --git a/xfr_test.go b/xfr_test.go index f6c5e98cc..04801a2ec 100644 --- a/xfr_test.go +++ b/xfr_test.go @@ -1,6 +1,7 @@ package dns import ( + "crypto/tls" "testing" "time" ) @@ -87,6 +88,27 @@ func TestSingleEnvelopeXfr(t *testing.T) { axfrTestingSuite(t, addrstr) } +func TestSingleEnvelopeXfrTLS(t *testing.T) { + HandleFunc("miek.nl.", SingleEnvelopeXfrServer) + defer HandleRemove("miek.nl.") + + cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) + if err != nil { + t.Fatalf("unable to build certificate: %v", err) + } + + tlsConfig := tls.Config{ + Certificates: []tls.Certificate{cert}, + } + s, addrstr, _, err := RunLocalTLSServer(":0", &tlsConfig) + if err != nil { + t.Fatalf("unable to run test server: %s", err) + } + defer s.Shutdown() + + axfrTestingSuiteTLS(t, addrstr) +} + func TestMultiEnvelopeXfr(t *testing.T) { HandleFunc("miek.nl.", MultipleEnvelopeXfrServer) defer HandleRemove("miek.nl.") @@ -131,6 +153,38 @@ func axfrTestingSuite(t *testing.T, addrstr string) { } } +func axfrTestingSuiteTLS(t *testing.T, addrstr string) { + tr := new(Transfer) + m := new(Msg) + m.SetAxfr("miek.nl.") + + tr.TLS = &tls.Config{ + InsecureSkipVerify: true, + } + c, err := tr.In(m, addrstr) + if err != nil { + t.Fatal("failed to zone transfer in", err) + } + + var records []RR + for msg := range c { + if msg.Error != nil { + t.Fatal(msg.Error) + } + records = append(records, msg.RR...) + } + + if len(records) != len(xfrTestData) { + t.Fatalf("bad axfr: expected %v, got %v", records, xfrTestData) + } + + for i, rr := range records { + if !IsDuplicate(rr, xfrTestData[i]) { + t.Fatalf("bad axfr: expected %v, got %v", records, xfrTestData) + } + } +} + func axfrTestingSuiteWithCustomTsig(t *testing.T, addrstr string, provider TsigProvider) { tr := new(Transfer) m := new(Msg)