Skip to content
This repository was archived by the owner on Mar 6, 2020. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions internal/dialerapi/dialerapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ func (d *Dialer) DialContext(
// DialTLS is like Dial, but creates TLS connections.
func (d *Dialer) DialTLS(network, address string) (net.Conn, error) {
ctx := context.Background()
return d.DialTLSContext(ctx, network, address)
}

// DialTLSContext is like DialTLS, but with context
func (d *Dialer) DialTLSContext(
ctx context.Context, network, address string,
) (net.Conn, error) {
conn, onlyhost, _, err := d.DialContextEx(ctx, network, address, false)
if err != nil {
return nil, err
Expand Down
8 changes: 5 additions & 3 deletions internal/dnsconf/dnsconf.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,21 @@ func NewResolver(
// We need a child dialer here to avoid an endless loop where the
// dialer will ask us to resolve, we'll tell the dialer to dial, it
// will ask us to resolve, ...
dialerapi.NewDialer(dialer.Beginning, dialer.Handler).DialTLS,
dnsovertcp.NewTLSDialerAdapter(
dialerapi.NewDialer(dialer.Beginning, dialer.Handler),
),
withPort(address, "853"),
)
} else if network == "tcp" {
transport = dnsovertcp.NewTransport(
// Same rationale as above: avoid possible endless loop
dialerapi.NewDialer(dialer.Beginning, dialer.Handler).Dial,
dialerapi.NewDialer(dialer.Beginning, dialer.Handler),
withPort(address, "53"),
)
} else if network == "udp" {
transport = dnsoverudp.NewTransport(
// Same rationale as above: avoid possible endless loop
dialerapi.NewDialer(dialer.Beginning, dialer.Handler).Dial,
dialerapi.NewDialer(dialer.Beginning, dialer.Handler),
withPort(address, "53"),
)
}
Expand Down
34 changes: 27 additions & 7 deletions internal/dnstransport/dnsovertcp/dnsovertcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,29 @@ import (
"time"

"github.com/m-lab/go/rtx"
"github.com/ooni/netx/model"
)

// Transport is a DNS over TCP/TLS model.DNSRoundTripper.
//
// As a known bug, this implementation always creates a new connection
// for each incoming query, thus increasing the response delay.
type Transport struct {
dial func(network, address string) (net.Conn, error)
dialer model.Dialer
address string
}

// NewTransport creates a new Transport
func NewTransport(
dial func(network, address string) (net.Conn, error),
address string,
) *Transport {
func NewTransport(dialer model.Dialer, address string) *Transport {
return &Transport{
dial: dial,
dialer: dialer,
address: address,
}
}

// RoundTrip sends a request and receives a response.
func (t *Transport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
conn, err := t.dial("tcp", t.address)
conn, err := t.dialer.DialContext(ctx, "tcp", t.address)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -70,3 +68,25 @@ func (t *Transport) doWithConn(conn net.Conn, query []byte) (reply []byte, err e
rtx.PanicOnError(err, "io.ReadFull failed")
return reply, nil
}

// TLSDialerAdapter makes a TLSDialer look like a Dialer
type TLSDialerAdapter struct {
dialer model.TLSDialer
}

// NewTLSDialerAdapter creates a new TLSDialerAdapter
func NewTLSDialerAdapter(dialer model.TLSDialer) *TLSDialerAdapter {
return &TLSDialerAdapter{dialer: dialer}
}

// Dial dials a new connection
func (d *TLSDialerAdapter) Dial(network, address string) (net.Conn, error) {
return d.dialer.DialTLS(network, address)
}

// DialContext is like Dial but with context
func (d *TLSDialerAdapter) DialContext(
ctx context.Context, network, address string,
) (net.Conn, error) {
return d.dialer.DialTLSContext(ctx, network, address)
}
78 changes: 63 additions & 15 deletions internal/dnstransport/dnsovertcp/dnsovertcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,44 +11,50 @@ import (
"github.com/miekg/dns"
)

func dialTLS(config *tls.Config) func(network, address string) (net.Conn, error) {
return func(network, address string) (net.Conn, error) {
return tls.Dial(network, address, config)
}
type tlsdialer struct {
config *tls.Config
}

func (d *tlsdialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}

func dialTCP(network, address string) (net.Conn, error) {
return net.Dial(network, address)
func (d *tlsdialer) DialContext(
ctx context.Context, network, address string,
) (net.Conn, error) {
return tls.Dial(network, address, d.config)
}

func TestIntegrationSuccessTLS(t *testing.T) {
// "Dial interprets a nil configuration as equivalent to
// the zero configuration; see the documentation of Config
// for the defaults."
transport := NewTransport(dialTLS(nil), "dns.quad9.net:853")
transport := NewTransport(&tlsdialer{}, "dns.quad9.net:853")
if err := threeRounds(transport); err != nil {
t.Fatal(err)
}
}

func TestIntegrationSuccessTCP(t *testing.T) {
transport := NewTransport(dialTCP, "9.9.9.9:53")
transport := NewTransport(&net.Dialer{}, "9.9.9.9:53")
if err := threeRounds(transport); err != nil {
t.Fatal(err)
}
}

func TestIntegrationLookupHostError(t *testing.T) {
transport := NewTransport(dialTCP, "antani.local")
transport := NewTransport(&net.Dialer{}, "antani.local")
if err := roundTrip(transport, "ooni.io."); err == nil {
t.Fatal("expected an error here")
}
}

func TestIntegrationCustomTLSConfig(t *testing.T) {
transport := NewTransport(dialTLS(&tls.Config{
MinVersion: tls.VersionTLS10,
}), "dns.quad9.net:853")
transport := NewTransport(&tlsdialer{
config: &tls.Config{
MinVersion: tls.VersionTLS10,
},
}, "dns.quad9.net:853")
if err := roundTrip(transport, "ooni.io."); err != nil {
t.Fatal(err)
}
Expand All @@ -57,9 +63,7 @@ func TestIntegrationCustomTLSConfig(t *testing.T) {
func TestUnitRoundTripWithConnFailure(t *testing.T) {
// fakeconn will fail in the SetDeadline, therefore we will have
// an immediate error and we expect all errors the be alike
transport := NewTransport(func(network, address string) (net.Conn, error) {
return &fakeconn{}, nil
}, "8.8.8.8:53")
transport := NewTransport(&fakeconnDialer{}, "8.8.8.8:53")
query := make([]byte, 1<<10)
reply, err := transport.doWithConn(&fakeconn{}, query)
if err == nil {
Expand Down Expand Up @@ -100,6 +104,20 @@ func roundTrip(transport *Transport, domain string) error {
return query.Unpack(data)
}

type fakeconnDialer struct {
fakeconn fakeconn
}

func (d *fakeconnDialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}

func (d *fakeconnDialer) DialContext(
ctx context.Context, network, address string,
) (net.Conn, error) {
return &d.fakeconn, nil
}

type fakeconn struct{}

func (fakeconn) Read(b []byte) (n int, err error) {
Expand Down Expand Up @@ -128,3 +146,33 @@ func (fakeconn) SetReadDeadline(t time.Time) (err error) {
func (fakeconn) SetWriteDeadline(t time.Time) (err error) {
return
}

func TestTLSDialerAdapter(t *testing.T) {
fake := &fakeTLSDialer{}
adapter := NewTLSDialerAdapter(fake)
adapter.Dial("tcp", "www.google.com:443")
if !fake.calledDialTLS {
t.Fatal("redirection to DialTLS not working")
}
adapter.DialContext(context.Background(), "tcp", "www.google.com:443")
if !fake.calledDialTLSContext {
t.Fatal("redirection to DialTLSContext not working")
}
}

type fakeTLSDialer struct {
calledDialTLS bool
calledDialTLSContext bool
}

func (d *fakeTLSDialer) DialTLS(network, address string) (net.Conn, error) {
d.calledDialTLS = true
return nil, errors.New("mocked error")
}

func (d *fakeTLSDialer) DialTLSContext(
ctx context.Context, network, address string,
) (net.Conn, error) {
d.calledDialTLSContext = true
return nil, errors.New("mocked error")
}
14 changes: 6 additions & 8 deletions internal/dnstransport/dnsoverudp/dnsoverudp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,28 @@ package dnsoverudp

import (
"context"
"net"
"time"

"github.com/ooni/netx/model"
)

// Transport is a DNS over UDP model.DNSRoundTripper.
type Transport struct {
dial func(network, address string) (net.Conn, error)
dialer model.Dialer
address string
}

// NewTransport creates a new Transport
func NewTransport(
dial func(network, address string) (net.Conn, error),
address string,
) *Transport {
func NewTransport(dialer model.Dialer, address string) *Transport {
return &Transport{
dial: dial,
dialer: dialer,
address: address,
}
}

// RoundTrip sends a request and receives a response.
func (t *Transport) RoundTrip(ctx context.Context, query []byte) (reply []byte, err error) {
conn, err := t.dial("udp", t.address)
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
if err != nil {
return
}
Expand Down
61 changes: 37 additions & 24 deletions internal/dnstransport/dnsoverudp/dnsoverudp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@ import (
"time"

"github.com/miekg/dns"
"github.com/ooni/netx/handlers"
"github.com/ooni/netx/internal/connx"
)

func TestIntegrationSuccessWithAddress(t *testing.T) {
transport := NewTransport(
net.Dial, "9.9.9.9:53",
&net.Dialer{}, "9.9.9.9:53",
)
err := threeRounds(transport)
if err != nil {
Expand All @@ -24,7 +22,7 @@ func TestIntegrationSuccessWithAddress(t *testing.T) {

func TestIntegrationSuccessWithDomain(t *testing.T) {
transport := NewTransport(
net.Dial, "dns.quad9.net:53",
&net.Dialer{}, "dns.quad9.net:53",
)
err := threeRounds(transport)
if err != nil {
Expand All @@ -34,11 +32,8 @@ func TestIntegrationSuccessWithDomain(t *testing.T) {

func TestIntegrationDialFailure(t *testing.T) {
transport := NewTransport(
net.Dial, "9.9.9.9:53",
&failingDialer{}, "9.9.9.9:53",
)
transport.dial = func(network, address string) (net.Conn, error) {
return nil, errors.New("mocked error")
}
err := threeRounds(transport)
if err == nil {
t.Fatal("expected an error here")
Expand All @@ -47,16 +42,12 @@ func TestIntegrationDialFailure(t *testing.T) {

func TestIntegrationSetDeadlineError(t *testing.T) {
transport := NewTransport(
net.Dial, "9.9.9.9:53",
)
transport.dial = func(network, address string) (net.Conn, error) {
return &connx.MeasuringConn{
Conn: fakeconn{
&fakeconnDialer{
fakeconn: fakeconn{
setDeadlineError: errors.New("mocked error"),
},
Handler: handlers.NoHandler,
}, nil
}
}, "9.9.9.9:53",
)
err := threeRounds(transport)
if err == nil {
t.Fatal("expected an error here")
Expand All @@ -65,16 +56,12 @@ func TestIntegrationSetDeadlineError(t *testing.T) {

func TestIntegrationWriteError(t *testing.T) {
transport := NewTransport(
net.Dial, "9.9.9.9:53",
)
transport.dial = func(network, address string) (net.Conn, error) {
return &connx.MeasuringConn{
Conn: fakeconn{
&fakeconnDialer{
fakeconn: fakeconn{
writeError: errors.New("mocked error"),
},
Handler: handlers.NoHandler,
}, nil
}
}, "9.9.9.9:53",
)
err := threeRounds(transport)
if err == nil {
t.Fatal("expected an error here")
Expand Down Expand Up @@ -111,6 +98,32 @@ func roundTrip(transport *Transport, domain string) error {
return query.Unpack(data)
}

type failingDialer struct{}

func (d *failingDialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}

func (d *failingDialer) DialContext(
ctx context.Context, network, address string,
) (net.Conn, error) {
return nil, errors.New("mocked error")
}

type fakeconnDialer struct {
fakeconn fakeconn
}

func (d *fakeconnDialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}

func (d *fakeconnDialer) DialContext(
ctx context.Context, network, address string,
) (net.Conn, error) {
return &d.fakeconn, nil
}

type fakeconn struct {
setDeadlineError error
writeError error
Expand Down
Loading