diff --git a/internal/bytecounter/conn.go b/internal/bytecounter/conn.go index 954009ddb3..3f37dffdbd 100644 --- a/internal/bytecounter/conn.go +++ b/internal/bytecounter/conn.go @@ -1,5 +1,9 @@ package bytecounter +// +// Code to wrap a net.Conn +// + import "net" // Conn wraps a network connection and counts bytes. diff --git a/internal/bytecounter/context.go b/internal/bytecounter/context.go index 4521e97743..64b9d2a824 100644 --- a/internal/bytecounter/context.go +++ b/internal/bytecounter/context.go @@ -1,5 +1,9 @@ package bytecounter +// +// Implicit byte counting based on context +// + import ( "context" "net" diff --git a/internal/bytecounter/bytecounter.go b/internal/bytecounter/counter.go similarity index 94% rename from internal/bytecounter/bytecounter.go rename to internal/bytecounter/counter.go index fbc8f0d9a3..6518bbbafc 100644 --- a/internal/bytecounter/bytecounter.go +++ b/internal/bytecounter/counter.go @@ -1,7 +1,9 @@ -// Package bytecounter contains code to track the number of -// bytes sent and received by a probe. package bytecounter +// +// Implementation of Counter +// + import "github.com/ooni/probe-cli/v3/internal/atomicx" // Counter counts bytes sent and received. diff --git a/internal/bytecounter/bytecounter_test.go b/internal/bytecounter/counter_test.go similarity index 100% rename from internal/bytecounter/bytecounter_test.go rename to internal/bytecounter/counter_test.go diff --git a/internal/bytecounter/doc.go b/internal/bytecounter/doc.go new file mode 100644 index 0000000000..df78813d36 --- /dev/null +++ b/internal/bytecounter/doc.go @@ -0,0 +1,3 @@ +// Package bytecounter contains code to track the number of +// bytes sent and received by a probe. +package bytecounter diff --git a/internal/bytecounter/http.go b/internal/bytecounter/http.go new file mode 100644 index 0000000000..0032ae019a --- /dev/null +++ b/internal/bytecounter/http.go @@ -0,0 +1,101 @@ +package bytecounter + +import ( + "io" + "net/http" + + "github.com/ooni/probe-cli/v3/internal/model" +) + +// HTTPTransport is a model.HTTPTransport that counts bytes. +type HTTPTransport struct { + HTTPTransport model.HTTPTransport + Counter *Counter +} + +// NewHTTPTransport creates a new byte-counting-aware HTTP transport. +func NewHTTPTransport(txp model.HTTPTransport, counter *Counter) model.HTTPTransport { + return &HTTPTransport{ + HTTPTransport: txp, + Counter: counter, + } +} + +var _ model.HTTPTransport = &HTTPTransport{} + +// CloseIdleConnections implements model.HTTPTransport.CloseIdleConnections. +func (txp *HTTPTransport) CloseIdleConnections() { + txp.HTTPTransport.CloseIdleConnections() +} + +// RoundTrip implements model.HTTPTRansport.RoundTrip +func (txp *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if req.Body != nil { + req.Body = &httpBodyWrapper{ + account: txp.Counter.CountBytesSent, + rc: req.Body, + } + } + txp.estimateRequestMetadata(req) + resp, err := txp.HTTPTransport.RoundTrip(req) + if err != nil { + return nil, err + } + txp.estimateResponseMetadata(resp) + resp.Body = &httpBodyWrapper{ + account: txp.Counter.CountBytesReceived, + rc: resp.Body, + } + return resp, nil +} + +// Network implements model.HTTPTransport.Network. +func (txp *HTTPTransport) Network() string { + return txp.HTTPTransport.Network() +} + +func (txp *HTTPTransport) estimateRequestMetadata(req *http.Request) { + txp.Counter.CountBytesSent(len(req.Method)) + txp.Counter.CountBytesSent(len(req.URL.String())) + for key, values := range req.Header { + for _, value := range values { + txp.Counter.CountBytesSent(len(key)) + txp.Counter.CountBytesSent(len(": ")) + txp.Counter.CountBytesSent(len(value)) + txp.Counter.CountBytesSent(len("\r\n")) + } + } + txp.Counter.CountBytesSent(len("\r\n")) +} + +func (txp *HTTPTransport) estimateResponseMetadata(resp *http.Response) { + txp.Counter.CountBytesReceived(len(resp.Status)) + for key, values := range resp.Header { + for _, value := range values { + txp.Counter.CountBytesReceived(len(key)) + txp.Counter.CountBytesReceived(len(": ")) + txp.Counter.CountBytesReceived(len(value)) + txp.Counter.CountBytesReceived(len("\r\n")) + } + } + txp.Counter.CountBytesReceived(len("\r\n")) +} + +type httpBodyWrapper struct { + account func(int) + rc io.ReadCloser +} + +var _ io.ReadCloser = &httpBodyWrapper{} + +func (r *httpBodyWrapper) Read(p []byte) (int, error) { + count, err := r.rc.Read(p) + if count > 0 { + r.account(count) + } + return count, err +} + +func (r *httpBodyWrapper) Close() error { + return r.rc.Close() +} diff --git a/internal/bytecounter/http_test.go b/internal/bytecounter/http_test.go new file mode 100644 index 0000000000..d83eaa04b8 --- /dev/null +++ b/internal/bytecounter/http_test.go @@ -0,0 +1,162 @@ +package bytecounter + +import ( + "context" + "errors" + "io" + "net/http" + "strings" + "testing" + + "github.com/ooni/probe-cli/v3/internal/model/mocks" + "github.com/ooni/probe-cli/v3/internal/netxlite" +) + +func TestHTTPTransport(t *testing.T) { + t.Run("RoundTrip", func(t *testing.T) { + t.Run("failure", func(t *testing.T) { + counter := New() + txp := &HTTPTransport{ + Counter: counter, + HTTPTransport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF + }, + }, + } + req, err := http.NewRequest( + "POST", "https://www.google.com", strings.NewReader("AAAAAA")) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "antani-browser/1.0.0") + resp, err := txp.RoundTrip(req) + if !errors.Is(err, io.EOF) { + t.Fatal("not the error we expected") + } + if resp != nil { + t.Fatal("expected nil response here") + } + if counter.Sent.Load() != 62 { + t.Fatal("expected 62 bytes sent", counter.Sent.Load()) + } + if counter.Received.Load() != 0 { + t.Fatal("expected zero bytes received", counter.Received.Load()) + } + }) + + t.Run("success", func(t *testing.T) { + counter := New() + txp := &HTTPTransport{ + Counter: counter, + HTTPTransport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Body: io.NopCloser(strings.NewReader("1234567")), + Header: http.Header{ + "Server": []string{"antani/0.1.0"}, + }, + Status: "200 OK", + StatusCode: http.StatusOK, + } + return resp, nil + }, + }, + } + req, err := http.NewRequest( + "POST", "https://www.google.com", strings.NewReader("AAAAAA")) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "antani-browser/1.0.0") + resp, err := txp.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + data, err := netxlite.ReadAllContext(context.Background(), resp.Body) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if string(data) != "1234567" { + t.Fatal("expected a different body here") + } + if counter.Sent.Load() != 62 { + t.Fatal("expected 62 bytes sent", counter.Sent.Load()) + } + if counter.Received.Load() != 37 { + t.Fatal("expected 37 bytes received", counter.Received.Load()) + } + }) + + t.Run("success with EOF", func(t *testing.T) { + counter := New() + txp := &HTTPTransport{ + Counter: counter, + HTTPTransport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Body: io.NopCloser(&mocks.Reader{ + MockRead: func(b []byte) (int, error) { + if len(b) < 1 { + panic("should not happen") + } + b[0] = 'A' + return 1, io.EOF // we want code to be robust to this + }, + }), + Header: http.Header{ + "Server": []string{"antani/0.1.0"}, + }, + Status: "200 OK", + StatusCode: http.StatusOK, + } + return resp, nil + }, + }, + } + client := &http.Client{Transport: txp} + resp, err := client.Get("https://www.google.com") + if err != nil { + t.Fatal(err) + } + data, err := netxlite.ReadAllContext(context.Background(), resp.Body) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if string(data) != "A" { + t.Fatal("expected a different body here") + } + }) + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + child := &mocks.HTTPTransport{ + MockCloseIdleConnections: func() { + called = true + }, + } + counter := New() + txp := NewHTTPTransport(child, counter) + txp.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) + + t.Run("Network", func(t *testing.T) { + expected := "antani" + child := &mocks.HTTPTransport{ + MockNetwork: func() string { + return expected + }, + } + counter := New() + txp := NewHTTPTransport(child, counter) + if network := txp.Network(); network != expected { + t.Fatal("unexpected network", network) + } + }) +} diff --git a/internal/engine/netx/dialer/proxy.go b/internal/engine/netx/dialer/proxy.go index 384a9030db..54d29424f8 100644 --- a/internal/engine/netx/dialer/proxy.go +++ b/internal/engine/netx/dialer/proxy.go @@ -1,56 +1,5 @@ package dialer -import ( - "context" - "errors" - "net" - "net/url" +import "github.com/ooni/probe-cli/v3/internal/netxlite" - "github.com/ooni/probe-cli/v3/internal/model" - "golang.org/x/net/proxy" -) - -// proxyDialer is a dialer that uses a proxy. If the ProxyURL is not configured, this -// dialer is a passthrough for the next Dialer in chain. Otherwise, it will internally -// create a SOCKS5 dialer that will connect to the proxy using the underlying Dialer. -type proxyDialer struct { - model.Dialer - ProxyURL *url.URL -} - -// ErrProxyUnsupportedScheme indicates we don't support a protocol scheme. -var ErrProxyUnsupportedScheme = errors.New("proxy: unsupported scheme") - -// DialContext implements Dialer.DialContext -func (d *proxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - url := d.ProxyURL - if url == nil { - return d.Dialer.DialContext(ctx, network, address) - } - if url.Scheme != "socks5" { - return nil, ErrProxyUnsupportedScheme - } - // the code at proxy/socks5.go never fails; see https://git.io/JfJ4g - child, _ := proxy.SOCKS5( - network, url.Host, nil, &proxyDialerWrapper{d.Dialer}) - return d.dial(ctx, child, network, address) -} - -func (d *proxyDialer) dial( - ctx context.Context, child proxy.Dialer, network, address string) (net.Conn, error) { - cd := child.(proxy.ContextDialer) // will work - return cd.DialContext(ctx, network, address) -} - -// proxyDialerWrapper is required because SOCKS5 expects a Dialer.Dial type but internally -// it checks whether DialContext is available and prefers that. So, we need to use this -// structure to cast our inner Dialer the way in which SOCKS5 likes it. -// -// See https://git.io/JfJ4g. -type proxyDialerWrapper struct { - model.Dialer -} - -func (d *proxyDialerWrapper) Dial(network, address string) (net.Conn, error) { - panic(errors.New("proxyDialerWrapper.Dial should not be called directly")) -} +type proxyDialer = netxlite.MaybeProxyDialer diff --git a/internal/engine/netx/dialer/proxy_test.go b/internal/engine/netx/dialer/proxy_test.go deleted file mode 100644 index d4c82e517a..0000000000 --- a/internal/engine/netx/dialer/proxy_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package dialer - -import ( - "context" - "errors" - "io" - "net" - "net/url" - "testing" - - "github.com/ooni/probe-cli/v3/internal/model/mocks" -) - -func TestProxyDialerDialContextNoProxyURL(t *testing.T) { - expected := errors.New("mocked error") - d := &proxyDialer{ - Dialer: &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { - return nil, expected - }, - }, - } - conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443") - if !errors.Is(err, expected) { - t.Fatal(err) - } - if conn != nil { - t.Fatal("conn is not nil") - } -} - -func TestProxyDialerDialContextInvalidScheme(t *testing.T) { - d := &proxyDialer{ - ProxyURL: &url.URL{Scheme: "antani"}, - } - conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443") - if !errors.Is(err, ErrProxyUnsupportedScheme) { - t.Fatal("not the error we expected") - } - if conn != nil { - t.Fatal("conn is not nil") - } -} - -func TestProxyDialerDialContextWithEOF(t *testing.T) { - const expect = "10.0.0.1:9050" - d := &proxyDialer{ - Dialer: &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { - if address != expect { - return nil, errors.New("unexpected address") - } - return nil, io.EOF - }, - }, - ProxyURL: &url.URL{Scheme: "socks5", Host: expect}, - } - conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443") - if !errors.Is(err, io.EOF) { - t.Fatal("not the error we expected") - } - if conn != nil { - t.Fatal("conn is not nil") - } -} - -func TestProxyDialWrapperPanics(t *testing.T) { - d := &proxyDialerWrapper{} - err := func() (rv error) { - defer func() { - if r := recover(); r != nil { - rv = r.(error) - } - }() - d.Dial("tcp", "10.0.0.1:1234") - return - }() - if err.Error() != "proxyDialerWrapper.Dial should not be called directly" { - t.Fatal("unexpected result", err) - } -} diff --git a/internal/engine/netx/httptransport/bytecounter.go b/internal/engine/netx/httptransport/bytecounter.go index d950d2854d..99b1f5e9f2 100644 --- a/internal/engine/netx/httptransport/bytecounter.go +++ b/internal/engine/netx/httptransport/bytecounter.go @@ -1,74 +1,5 @@ package httptransport -import ( - "io" - "net/http" +import "github.com/ooni/probe-cli/v3/internal/bytecounter" - "github.com/ooni/probe-cli/v3/internal/bytecounter" - "github.com/ooni/probe-cli/v3/internal/model" -) - -// ByteCountingTransport is a RoundTripper that counts bytes. -type ByteCountingTransport struct { - model.HTTPTransport - Counter *bytecounter.Counter -} - -// RoundTrip implements RoundTripper.RoundTrip -func (txp ByteCountingTransport) RoundTrip(req *http.Request) (*http.Response, error) { - if req.Body != nil { - req.Body = byteCountingBody{ - ReadCloser: req.Body, Account: txp.Counter.CountBytesSent} - } - txp.estimateRequestMetadata(req) - resp, err := txp.HTTPTransport.RoundTrip(req) - if err != nil { - return nil, err - } - txp.estimateResponseMetadata(resp) - resp.Body = byteCountingBody{ - ReadCloser: resp.Body, Account: txp.Counter.CountBytesReceived} - return resp, nil -} - -func (txp ByteCountingTransport) estimateRequestMetadata(req *http.Request) { - txp.Counter.CountBytesSent(len(req.Method)) - txp.Counter.CountBytesSent(len(req.URL.String())) - for key, values := range req.Header { - for _, value := range values { - txp.Counter.CountBytesSent(len(key)) - txp.Counter.CountBytesSent(len(": ")) - txp.Counter.CountBytesSent(len(value)) - txp.Counter.CountBytesSent(len("\r\n")) - } - } - txp.Counter.CountBytesSent(len("\r\n")) -} - -func (txp ByteCountingTransport) estimateResponseMetadata(resp *http.Response) { - txp.Counter.CountBytesReceived(len(resp.Status)) - for key, values := range resp.Header { - for _, value := range values { - txp.Counter.CountBytesReceived(len(key)) - txp.Counter.CountBytesReceived(len(": ")) - txp.Counter.CountBytesReceived(len(value)) - txp.Counter.CountBytesReceived(len("\r\n")) - } - } - txp.Counter.CountBytesReceived(len("\r\n")) -} - -type byteCountingBody struct { - io.ReadCloser - Account func(int) -} - -func (r byteCountingBody) Read(p []byte) (int, error) { - count, err := r.ReadCloser.Read(p) - if count > 0 { - r.Account(count) - } - return count, err -} - -var _ model.HTTPTransport = ByteCountingTransport{} +type ByteCountingTransport = bytecounter.HTTPTransport diff --git a/internal/engine/netx/httptransport/bytecounter_test.go b/internal/engine/netx/httptransport/bytecounter_test.go deleted file mode 100644 index 2c0c72ffba..0000000000 --- a/internal/engine/netx/httptransport/bytecounter_test.go +++ /dev/null @@ -1,129 +0,0 @@ -package httptransport_test - -import ( - "context" - "errors" - "io" - "net/http" - "strings" - "testing" - - "github.com/ooni/probe-cli/v3/internal/bytecounter" - "github.com/ooni/probe-cli/v3/internal/engine/netx/httptransport" - "github.com/ooni/probe-cli/v3/internal/netxlite" -) - -func TestByteCounterFailure(t *testing.T) { - counter := bytecounter.New() - txp := httptransport.ByteCountingTransport{ - Counter: counter, - HTTPTransport: httptransport.FakeTransport{ - Err: io.EOF, - }, - } - client := &http.Client{Transport: txp} - req, err := http.NewRequest( - "POST", "https://www.google.com", strings.NewReader("AAAAAA")) - if err != nil { - t.Fatal(err) - } - req.Header.Set("User-Agent", "antani-browser/1.0.0") - resp, err := client.Do(req) - if !errors.Is(err, io.EOF) { - t.Fatal("not the error we expected") - } - if resp != nil { - t.Fatal("expected nil response here") - } - if counter.Sent.Load() != 68 { - t.Fatal("expected around 68 bytes sent") - } - if counter.Received.Load() != 0 { - t.Fatal("expected zero bytes received") - } -} - -func TestByteCounterSuccess(t *testing.T) { - counter := bytecounter.New() - txp := httptransport.ByteCountingTransport{ - Counter: counter, - HTTPTransport: httptransport.FakeTransport{ - Resp: &http.Response{ - Body: io.NopCloser(strings.NewReader("1234567")), - Header: http.Header{ - "Server": []string{"antani/0.1.0"}, - }, - Status: "200 OK", - StatusCode: http.StatusOK, - }, - }, - } - client := &http.Client{Transport: txp} - req, err := http.NewRequest( - "POST", "https://www.google.com", strings.NewReader("AAAAAA")) - if err != nil { - t.Fatal(err) - } - req.Header.Set("User-Agent", "antani-browser/1.0.0") - resp, err := client.Do(req) - if err != nil { - t.Fatal(err) - } - data, err := netxlite.ReadAllContext(context.Background(), resp.Body) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() - if string(data) != "1234567" { - t.Fatal("expected a different body here") - } - if counter.Sent.Load() != 68 { - t.Fatal("expected around 68 bytes sent") - } - if counter.Received.Load() != 37 { - t.Fatal("expected zero around 37 bytes received") - } -} - -func TestByteCounterSuccessWithEOF(t *testing.T) { - counter := bytecounter.New() - txp := httptransport.ByteCountingTransport{ - Counter: counter, - HTTPTransport: httptransport.FakeTransport{ - Resp: &http.Response{ - Body: bodyReaderWithEOF{}, - Header: http.Header{ - "Server": []string{"antani/0.1.0"}, - }, - Status: "200 OK", - StatusCode: http.StatusOK, - }, - }, - } - client := &http.Client{Transport: txp} - resp, err := client.Get("https://www.google.com") - if err != nil { - t.Fatal(err) - } - data, err := netxlite.ReadAllContext(context.Background(), resp.Body) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() - if string(data) != "A" { - t.Fatal("expected a different body here") - } -} - -type bodyReaderWithEOF struct{} - -func (bodyReaderWithEOF) Read(p []byte) (int, error) { - if len(p) < 1 { - panic("should not happen") - } - p[0] = 'A' - return 1, io.EOF // we want code to be robust to this -} -func (bodyReaderWithEOF) Close() error { - return nil -} diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 2efdb137a2..5153b74573 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -202,7 +202,7 @@ func NewHTTPTransport(config Config) model.HTTPTransport { TLSConfig: config.TLSConfig}) if config.ByteCounter != nil { - txp = httptransport.ByteCountingTransport{ + txp = &httptransport.ByteCountingTransport{ Counter: config.ByteCounter, HTTPTransport: txp} } if config.Logger != nil { diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index e8b54a1f23..9ff50d2c59 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -473,7 +473,7 @@ func TestNewWithByteCounter(t *testing.T) { txp := netx.NewHTTPTransport(netx.Config{ ByteCounter: counter, }) - bctxp, ok := txp.(httptransport.ByteCountingTransport) + bctxp, ok := txp.(*httptransport.ByteCountingTransport) if !ok { t.Fatal("not the transport we expected") } diff --git a/internal/engine/session.go b/internal/engine/session.go index b0e3bdbc43..9a4cc9559e 100644 --- a/internal/engine/session.go +++ b/internal/engine/session.go @@ -14,10 +14,10 @@ import ( "github.com/ooni/probe-cli/v3/internal/bytecounter" "github.com/ooni/probe-cli/v3/internal/engine/geolocate" "github.com/ooni/probe-cli/v3/internal/engine/internal/sessionresolver" - "github.com/ooni/probe-cli/v3/internal/engine/netx" "github.com/ooni/probe-cli/v3/internal/engine/probeservices" "github.com/ooni/probe-cli/v3/internal/kvstore" "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/platform" "github.com/ooni/probe-cli/v3/internal/tunnel" "github.com/ooni/probe-cli/v3/internal/version" @@ -191,20 +191,19 @@ func NewSession(ctx context.Context, config SessionConfig) (*Session, error) { } } sess.proxyURL = proxyURL - httpConfig := netx.Config{ - ByteCounter: sess.byteCounter, - BogonIsError: true, - Logger: sess.logger, - ProxyURL: proxyURL, - } sess.resolver = &sessionresolver.Resolver{ ByteCounter: sess.byteCounter, KVStore: config.KVStore, Logger: sess.logger, ProxyURL: proxyURL, } - httpConfig.FullResolver = sess.resolver - sess.httpDefaultTransport = netx.NewHTTPTransport(httpConfig) + dialer := netxlite.NewDialerWithResolver(sess.logger, sess.resolver) + dialer = netxlite.NewMaybeProxyDialer(dialer, proxyURL) + handshaker := netxlite.NewTLSHandshakerStdlib(sess.logger) + tlsDialer := netxlite.NewTLSDialer(dialer, handshaker) + txp := netxlite.NewHTTPTransport(sess.logger, dialer, tlsDialer) + txp = bytecounter.NewHTTPTransport(txp, sess.byteCounter) + sess.httpDefaultTransport = txp return sess, nil } diff --git a/internal/netxlite/maybeproxy.go b/internal/netxlite/maybeproxy.go new file mode 100644 index 0000000000..6aaf09cf37 --- /dev/null +++ b/internal/netxlite/maybeproxy.go @@ -0,0 +1,70 @@ +package netxlite + +import ( + "context" + "errors" + "net" + "net/url" + + "github.com/ooni/probe-cli/v3/internal/model" + "golang.org/x/net/proxy" +) + +// MaybeProxyDialer is a dialer that may use a proxy. If the ProxyURL is not configured, +// this dialer is a passthrough for the next Dialer in chain. Otherwise, it will internally +// create a SOCKS5 dialer that will connect to the proxy using the underlying Dialer. +type MaybeProxyDialer struct { + Dialer model.Dialer + ProxyURL *url.URL +} + +// NewMaybeProxyDialer creates a new NewMaybeProxyDialer. +func NewMaybeProxyDialer(dialer model.Dialer, proxyURL *url.URL) *MaybeProxyDialer { + return &MaybeProxyDialer{ + Dialer: dialer, + ProxyURL: proxyURL, + } +} + +var _ model.Dialer = &MaybeProxyDialer{} + +// CloseIdleConnections implements Dialer.CloseIdleConnections. +func (d *MaybeProxyDialer) CloseIdleConnections() { + d.Dialer.CloseIdleConnections() +} + +// ErrProxyUnsupportedScheme indicates we don't support a protocol scheme. +var ErrProxyUnsupportedScheme = errors.New("proxy: unsupported scheme") + +// DialContext implements Dialer.DialContext. +func (d *MaybeProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + url := d.ProxyURL + if url == nil { + return d.Dialer.DialContext(ctx, network, address) + } + if url.Scheme != "socks5" { + return nil, ErrProxyUnsupportedScheme + } + // the code at proxy/socks5.go never fails; see https://git.io/JfJ4g + child, _ := proxy.SOCKS5(network, url.Host, nil, &proxyDialerWrapper{d.Dialer}) + return d.dial(ctx, child, network, address) +} + +func (d *MaybeProxyDialer) dial( + ctx context.Context, child proxy.Dialer, network, address string) (net.Conn, error) { + cd := child.(proxy.ContextDialer) // will work + return cd.DialContext(ctx, network, address) +} + +// proxyDialerWrapper is required because SOCKS5 expects a Dialer.Dial type but internally +// it checks whether DialContext is available and prefers that. So, we need to use this +// structure to cast our inner Dialer the way in which SOCKS5 likes it. +// +// See https://git.io/JfJ4g. +type proxyDialerWrapper struct { + model.Dialer +} + +func (d *proxyDialerWrapper) Dial(network, address string) (net.Conn, error) { + panic(errors.New("proxyDialerWrapper.Dial should not be called directly")) +} diff --git a/internal/netxlite/maybeproxy_test.go b/internal/netxlite/maybeproxy_test.go new file mode 100644 index 0000000000..af3a502b51 --- /dev/null +++ b/internal/netxlite/maybeproxy_test.go @@ -0,0 +1,104 @@ +package netxlite + +import ( + "context" + "errors" + "io" + "net" + "net/url" + "testing" + + "github.com/ooni/probe-cli/v3/internal/model/mocks" +) + +func TestMaybeProxyDialer(t *testing.T) { + t.Run("DialContext", func(t *testing.T) { + t.Run("missing proxy URL", func(t *testing.T) { + expected := errors.New("mocked error") + d := &MaybeProxyDialer{ + Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, expected + }}, + ProxyURL: nil, + } + conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443") + if !errors.Is(err, expected) { + t.Fatal(err) + } + if conn != nil { + t.Fatal("conn is not nil") + } + }) + + t.Run("invalid scheme", func(t *testing.T) { + child := &mocks.Dialer{} + URL := &url.URL{Scheme: "antani"} + d := NewMaybeProxyDialer(child, URL) + conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443") + if !errors.Is(err, ErrProxyUnsupportedScheme) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("conn is not nil") + } + }) + + t.Run("underlying dial fails with EOF", func(t *testing.T) { + const expect = "10.0.0.1:9050" + d := &MaybeProxyDialer{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + if address != expect { + return nil, errors.New("unexpected address") + } + return nil, io.EOF + }, + }, + ProxyURL: &url.URL{ + Scheme: "socks5", + Host: expect, + }, + } + conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443") + if !errors.Is(err, io.EOF) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("conn is not nil") + } + }) + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + child := &mocks.Dialer{ + MockCloseIdleConnections: func() { + called = true + }, + } + URL := &url.URL{} + dialer := NewMaybeProxyDialer(child, URL) + dialer.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) + + t.Run("proxyDialerWrapper", func(t *testing.T) { + t.Run("Dial panics", func(t *testing.T) { + d := &proxyDialerWrapper{} + err := func() (rv error) { + defer func() { + if r := recover(); r != nil { + rv = r.(error) + } + }() + d.Dial("tcp", "10.0.0.1:1234") + return + }() + if err.Error() != "proxyDialerWrapper.Dial should not be called directly" { + t.Fatal("unexpected result", err) + } + }) + }) +}