From a42b9d840a9acb13f1921a6379d5fda8d536bdee Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Thu, 24 Oct 2019 14:27:53 +0200 Subject: [PATCH] Let DNS transports take interfaces in input --- internal/dialerapi/dialerapi.go | 7 ++ internal/dnsconf/dnsconf.go | 8 +- .../dnstransport/dnsovertcp/dnsovertcp.go | 34 ++++++-- .../dnsovertcp/dnsovertcp_test.go | 78 +++++++++++++++---- .../dnstransport/dnsoverudp/dnsoverudp.go | 14 ++-- .../dnsoverudp/dnsoverudp_test.go | 61 +++++++++------ internal/oodns/oodns_test.go | 8 +- model/model.go | 18 +++++ 8 files changed, 164 insertions(+), 64 deletions(-) diff --git a/internal/dialerapi/dialerapi.go b/internal/dialerapi/dialerapi.go index e25a08f..87585bd 100644 --- a/internal/dialerapi/dialerapi.go +++ b/internal/dialerapi/dialerapi.go @@ -89,6 +89,13 @@ func (d *Dialer) DialContext( // DialTLS is like Dial, but creates TLS connections. func (d *Dialer) DialTLS(network, address string) (net.Conn, error) { ctx := context.Background() + return d.DialTLSContext(ctx, network, address) +} + +// DialTLSContext is like DialTLS, but with context +func (d *Dialer) DialTLSContext( + ctx context.Context, network, address string, +) (net.Conn, error) { conn, onlyhost, _, err := d.DialContextEx(ctx, network, address, false) if err != nil { return nil, err diff --git a/internal/dnsconf/dnsconf.go b/internal/dnsconf/dnsconf.go index c4d7a0f..078321a 100644 --- a/internal/dnsconf/dnsconf.go +++ b/internal/dnsconf/dnsconf.go @@ -74,19 +74,21 @@ func NewResolver( // We need a child dialer here to avoid an endless loop where the // dialer will ask us to resolve, we'll tell the dialer to dial, it // will ask us to resolve, ... - dialerapi.NewDialer(dialer.Beginning, dialer.Handler).DialTLS, + dnsovertcp.NewTLSDialerAdapter( + dialerapi.NewDialer(dialer.Beginning, dialer.Handler), + ), withPort(address, "853"), ) } else if network == "tcp" { transport = dnsovertcp.NewTransport( // Same rationale as above: avoid possible endless loop - dialerapi.NewDialer(dialer.Beginning, dialer.Handler).Dial, + dialerapi.NewDialer(dialer.Beginning, dialer.Handler), withPort(address, "53"), ) } else if network == "udp" { transport = dnsoverudp.NewTransport( // Same rationale as above: avoid possible endless loop - dialerapi.NewDialer(dialer.Beginning, dialer.Handler).Dial, + dialerapi.NewDialer(dialer.Beginning, dialer.Handler), withPort(address, "53"), ) } diff --git a/internal/dnstransport/dnsovertcp/dnsovertcp.go b/internal/dnstransport/dnsovertcp/dnsovertcp.go index d05f142..a7d588d 100644 --- a/internal/dnstransport/dnsovertcp/dnsovertcp.go +++ b/internal/dnstransport/dnsovertcp/dnsovertcp.go @@ -10,6 +10,7 @@ import ( "time" "github.com/m-lab/go/rtx" + "github.com/ooni/netx/model" ) // Transport is a DNS over TCP/TLS model.DNSRoundTripper. @@ -17,24 +18,21 @@ import ( // As a known bug, this implementation always creates a new connection // for each incoming query, thus increasing the response delay. type Transport struct { - dial func(network, address string) (net.Conn, error) + dialer model.Dialer address string } // NewTransport creates a new Transport -func NewTransport( - dial func(network, address string) (net.Conn, error), - address string, -) *Transport { +func NewTransport(dialer model.Dialer, address string) *Transport { return &Transport{ - dial: dial, + dialer: dialer, address: address, } } // RoundTrip sends a request and receives a response. func (t *Transport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { - conn, err := t.dial("tcp", t.address) + conn, err := t.dialer.DialContext(ctx, "tcp", t.address) if err != nil { return nil, err } @@ -70,3 +68,25 @@ func (t *Transport) doWithConn(conn net.Conn, query []byte) (reply []byte, err e rtx.PanicOnError(err, "io.ReadFull failed") return reply, nil } + +// TLSDialerAdapter makes a TLSDialer look like a Dialer +type TLSDialerAdapter struct { + dialer model.TLSDialer +} + +// NewTLSDialerAdapter creates a new TLSDialerAdapter +func NewTLSDialerAdapter(dialer model.TLSDialer) *TLSDialerAdapter { + return &TLSDialerAdapter{dialer: dialer} +} + +// Dial dials a new connection +func (d *TLSDialerAdapter) Dial(network, address string) (net.Conn, error) { + return d.dialer.DialTLS(network, address) +} + +// DialContext is like Dial but with context +func (d *TLSDialerAdapter) DialContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + return d.dialer.DialTLSContext(ctx, network, address) +} diff --git a/internal/dnstransport/dnsovertcp/dnsovertcp_test.go b/internal/dnstransport/dnsovertcp/dnsovertcp_test.go index 8c4e7fd..1050a26 100644 --- a/internal/dnstransport/dnsovertcp/dnsovertcp_test.go +++ b/internal/dnstransport/dnsovertcp/dnsovertcp_test.go @@ -11,44 +11,50 @@ import ( "github.com/miekg/dns" ) -func dialTLS(config *tls.Config) func(network, address string) (net.Conn, error) { - return func(network, address string) (net.Conn, error) { - return tls.Dial(network, address, config) - } +type tlsdialer struct { + config *tls.Config +} + +func (d *tlsdialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) } -func dialTCP(network, address string) (net.Conn, error) { - return net.Dial(network, address) +func (d *tlsdialer) DialContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + return tls.Dial(network, address, d.config) } func TestIntegrationSuccessTLS(t *testing.T) { // "Dial interprets a nil configuration as equivalent to // the zero configuration; see the documentation of Config // for the defaults." - transport := NewTransport(dialTLS(nil), "dns.quad9.net:853") + transport := NewTransport(&tlsdialer{}, "dns.quad9.net:853") if err := threeRounds(transport); err != nil { t.Fatal(err) } } func TestIntegrationSuccessTCP(t *testing.T) { - transport := NewTransport(dialTCP, "9.9.9.9:53") + transport := NewTransport(&net.Dialer{}, "9.9.9.9:53") if err := threeRounds(transport); err != nil { t.Fatal(err) } } func TestIntegrationLookupHostError(t *testing.T) { - transport := NewTransport(dialTCP, "antani.local") + transport := NewTransport(&net.Dialer{}, "antani.local") if err := roundTrip(transport, "ooni.io."); err == nil { t.Fatal("expected an error here") } } func TestIntegrationCustomTLSConfig(t *testing.T) { - transport := NewTransport(dialTLS(&tls.Config{ - MinVersion: tls.VersionTLS10, - }), "dns.quad9.net:853") + transport := NewTransport(&tlsdialer{ + config: &tls.Config{ + MinVersion: tls.VersionTLS10, + }, + }, "dns.quad9.net:853") if err := roundTrip(transport, "ooni.io."); err != nil { t.Fatal(err) } @@ -57,9 +63,7 @@ func TestIntegrationCustomTLSConfig(t *testing.T) { func TestUnitRoundTripWithConnFailure(t *testing.T) { // fakeconn will fail in the SetDeadline, therefore we will have // an immediate error and we expect all errors the be alike - transport := NewTransport(func(network, address string) (net.Conn, error) { - return &fakeconn{}, nil - }, "8.8.8.8:53") + transport := NewTransport(&fakeconnDialer{}, "8.8.8.8:53") query := make([]byte, 1<<10) reply, err := transport.doWithConn(&fakeconn{}, query) if err == nil { @@ -100,6 +104,20 @@ func roundTrip(transport *Transport, domain string) error { return query.Unpack(data) } +type fakeconnDialer struct { + fakeconn fakeconn +} + +func (d *fakeconnDialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func (d *fakeconnDialer) DialContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + return &d.fakeconn, nil +} + type fakeconn struct{} func (fakeconn) Read(b []byte) (n int, err error) { @@ -128,3 +146,33 @@ func (fakeconn) SetReadDeadline(t time.Time) (err error) { func (fakeconn) SetWriteDeadline(t time.Time) (err error) { return } + +func TestTLSDialerAdapter(t *testing.T) { + fake := &fakeTLSDialer{} + adapter := NewTLSDialerAdapter(fake) + adapter.Dial("tcp", "www.google.com:443") + if !fake.calledDialTLS { + t.Fatal("redirection to DialTLS not working") + } + adapter.DialContext(context.Background(), "tcp", "www.google.com:443") + if !fake.calledDialTLSContext { + t.Fatal("redirection to DialTLSContext not working") + } +} + +type fakeTLSDialer struct { + calledDialTLS bool + calledDialTLSContext bool +} + +func (d *fakeTLSDialer) DialTLS(network, address string) (net.Conn, error) { + d.calledDialTLS = true + return nil, errors.New("mocked error") +} + +func (d *fakeTLSDialer) DialTLSContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + d.calledDialTLSContext = true + return nil, errors.New("mocked error") +} diff --git a/internal/dnstransport/dnsoverudp/dnsoverudp.go b/internal/dnstransport/dnsoverudp/dnsoverudp.go index 014b2f4..af1a1b1 100644 --- a/internal/dnstransport/dnsoverudp/dnsoverudp.go +++ b/internal/dnstransport/dnsoverudp/dnsoverudp.go @@ -3,30 +3,28 @@ package dnsoverudp import ( "context" - "net" "time" + + "github.com/ooni/netx/model" ) // Transport is a DNS over UDP model.DNSRoundTripper. type Transport struct { - dial func(network, address string) (net.Conn, error) + dialer model.Dialer address string } // NewTransport creates a new Transport -func NewTransport( - dial func(network, address string) (net.Conn, error), - address string, -) *Transport { +func NewTransport(dialer model.Dialer, address string) *Transport { return &Transport{ - dial: dial, + dialer: dialer, address: address, } } // RoundTrip sends a request and receives a response. func (t *Transport) RoundTrip(ctx context.Context, query []byte) (reply []byte, err error) { - conn, err := t.dial("udp", t.address) + conn, err := t.dialer.DialContext(ctx, "udp", t.address) if err != nil { return } diff --git a/internal/dnstransport/dnsoverudp/dnsoverudp_test.go b/internal/dnstransport/dnsoverudp/dnsoverudp_test.go index fed1e6c..c763091 100644 --- a/internal/dnstransport/dnsoverudp/dnsoverudp_test.go +++ b/internal/dnstransport/dnsoverudp/dnsoverudp_test.go @@ -8,13 +8,11 @@ import ( "time" "github.com/miekg/dns" - "github.com/ooni/netx/handlers" - "github.com/ooni/netx/internal/connx" ) func TestIntegrationSuccessWithAddress(t *testing.T) { transport := NewTransport( - net.Dial, "9.9.9.9:53", + &net.Dialer{}, "9.9.9.9:53", ) err := threeRounds(transport) if err != nil { @@ -24,7 +22,7 @@ func TestIntegrationSuccessWithAddress(t *testing.T) { func TestIntegrationSuccessWithDomain(t *testing.T) { transport := NewTransport( - net.Dial, "dns.quad9.net:53", + &net.Dialer{}, "dns.quad9.net:53", ) err := threeRounds(transport) if err != nil { @@ -34,11 +32,8 @@ func TestIntegrationSuccessWithDomain(t *testing.T) { func TestIntegrationDialFailure(t *testing.T) { transport := NewTransport( - net.Dial, "9.9.9.9:53", + &failingDialer{}, "9.9.9.9:53", ) - transport.dial = func(network, address string) (net.Conn, error) { - return nil, errors.New("mocked error") - } err := threeRounds(transport) if err == nil { t.Fatal("expected an error here") @@ -47,16 +42,12 @@ func TestIntegrationDialFailure(t *testing.T) { func TestIntegrationSetDeadlineError(t *testing.T) { transport := NewTransport( - net.Dial, "9.9.9.9:53", - ) - transport.dial = func(network, address string) (net.Conn, error) { - return &connx.MeasuringConn{ - Conn: fakeconn{ + &fakeconnDialer{ + fakeconn: fakeconn{ setDeadlineError: errors.New("mocked error"), }, - Handler: handlers.NoHandler, - }, nil - } + }, "9.9.9.9:53", + ) err := threeRounds(transport) if err == nil { t.Fatal("expected an error here") @@ -65,16 +56,12 @@ func TestIntegrationSetDeadlineError(t *testing.T) { func TestIntegrationWriteError(t *testing.T) { transport := NewTransport( - net.Dial, "9.9.9.9:53", - ) - transport.dial = func(network, address string) (net.Conn, error) { - return &connx.MeasuringConn{ - Conn: fakeconn{ + &fakeconnDialer{ + fakeconn: fakeconn{ writeError: errors.New("mocked error"), }, - Handler: handlers.NoHandler, - }, nil - } + }, "9.9.9.9:53", + ) err := threeRounds(transport) if err == nil { t.Fatal("expected an error here") @@ -111,6 +98,32 @@ func roundTrip(transport *Transport, domain string) error { return query.Unpack(data) } +type failingDialer struct{} + +func (d *failingDialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func (d *failingDialer) DialContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + return nil, errors.New("mocked error") +} + +type fakeconnDialer struct { + fakeconn fakeconn +} + +func (d *fakeconnDialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func (d *fakeconnDialer) DialContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + return &d.fakeconn, nil +} + type fakeconn struct { setDeadlineError error writeError error diff --git a/internal/oodns/oodns_test.go b/internal/oodns/oodns_test.go index 8b02b87..08176d2 100644 --- a/internal/oodns/oodns_test.go +++ b/internal/oodns/oodns_test.go @@ -2,7 +2,6 @@ package oodns import ( "context" - "crypto/tls" "errors" "net" "testing" @@ -15,12 +14,7 @@ import ( ) func newtransport() model.DNSRoundTripper { - return dnsovertcp.NewTransport( - func(network, address string) (net.Conn, error) { - return tls.Dial(network, address, nil) - }, - "dns.quad9.net:853", - ) + return dnsovertcp.NewTransport(&net.Dialer{}, "dns.quad9.net:53") } func TestLookupAddr(t *testing.T) { diff --git a/model/model.go b/model/model.go index e41302c..dda60a0 100644 --- a/model/model.go +++ b/model/model.go @@ -243,3 +243,21 @@ type DNSRoundTripper interface { // RoundTrip sends a DNS query and receives the reply. RoundTrip(ctx context.Context, query []byte) (reply []byte, err error) } + +// Dialer is a dialer for network connections. +type Dialer interface { + // Dial dials a new connection + Dial(network, address string) (net.Conn, error) + + // DialContext is like Dial but with context + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// TLSDialer is a dialer for TLS connections. +type TLSDialer interface { + // DialTLS dials a new TLS connection + DialTLS(network, address string) (net.Conn, error) + + // DialTLSContext is like DialTLS but with context + DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) +}