Skip to content

Commit

Permalink
fix(netxlite): http factory that propagates close-idle-connections (#465
Browse files Browse the repository at this point in the history
)

While there reorganize mocks' tls implementation to use a single file
called tls.go (and tls_test.go) just like netxlite does.

While there write tests ensuring we always add timeouts when we are
making TCP connections (be them TLS or cleartext).

See ooni/probe#1591
  • Loading branch information
bassosimone committed Sep 6, 2021
1 parent 2572376 commit 6df27d9
Show file tree
Hide file tree
Showing 8 changed files with 235 additions and 136 deletions.
45 changes: 35 additions & 10 deletions internal/netxlite/http.go
Expand Up @@ -2,7 +2,6 @@ package netxlite

import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
Expand Down Expand Up @@ -67,17 +66,37 @@ func (txp *httpTransportLogger) CloseIdleConnections() {
txp.HTTPTransport.CloseIdleConnections()
}

// NewHTTPTransport creates a new HTTP transport using Go stdlib.
func NewHTTPTransport(dialer Dialer, tlsConfig *tls.Config,
handshaker TLSHandshaker) HTTPTransport {
// httpTransportConnectionsCloser is an HTTPTransport that
// correctly forwards CloseIdleConnections.
type httpTransportConnectionsCloser struct {
HTTPTransport
Dialer
TLSDialer
}

// CloseIdleConnections forwards the CloseIdleConnections calls.
func (txp *httpTransportConnectionsCloser) CloseIdleConnections() {
txp.HTTPTransport.CloseIdleConnections()
txp.Dialer.CloseIdleConnections()
txp.TLSDialer.CloseIdleConnections()
}

// NewHTTPTransport creates a new HTTP transport using the given
// dialer and TLS handshaker to create connections.
//
// We need a TLS handshaker here, as opposed to a TLSDialer, because we
// wrap the dialer we'll use to enforce timeouts for HTTP idle
// connections (see https://github.com/ooni/probe/issues/1609 for more info).
func NewHTTPTransport(dialer Dialer, tlsHandshaker TLSHandshaker) HTTPTransport {
// TODO(bassosimone): here we should copy code living inside the
// websteps prototype to use the oohttp library.
txp := http.DefaultTransport.(*http.Transport).Clone()
// This wrapping ensures that we always have a timeout when we
// are using HTTP; see https://github.com/ooni/probe/issues/1609.
dialer = &httpDialerWithReadTimeout{dialer}
txp.DialContext = dialer.DialContext
txp.DialTLSContext = (&tlsDialer{
Config: tlsConfig,
Dialer: dialer,
TLSHandshaker: handshaker,
}).DialTLSContext
tlsDialer := NewTLSDialer(dialer, tlsHandshaker)
txp.DialTLSContext = tlsDialer.DialTLSContext
// Better for Cloudflare DNS and also better because we have less
// noisy events and we can better understand what happened.
txp.MaxConnsPerHost = 1
Expand All @@ -86,7 +105,13 @@ func NewHTTPTransport(dialer Dialer, tlsConfig *tls.Config,
// back the true headers, such as Content-Length. This change is
// functional to OONI's goal of observing the network.
txp.DisableCompression = true
return txp
txp.ForceAttemptHTTP2 = true
// Ensure we correctly forward CloseIdleConnections.
return &httpTransportConnectionsCloser{
HTTPTransport: txp,
Dialer: dialer,
TLSDialer: tlsDialer,
}
}

// httpDialerWithReadTimeout enforces a read timeout for all HTTP
Expand Down
62 changes: 51 additions & 11 deletions internal/netxlite/http_test.go
Expand Up @@ -2,7 +2,6 @@ package netxlite

import (
"context"
"crypto/tls"
"errors"
"io"
"net"
Expand Down Expand Up @@ -110,34 +109,33 @@ func TestHTTPTransportLoggerCloseIdleConnections(t *testing.T) {
}

func TestHTTPTransportWorks(t *testing.T) {
d := &dialerResolver{
Dialer: defaultDialer,
Resolver: NewResolverSystem(log.Log),
}
th := &tlsHandshakerConfigurable{}
txp := NewHTTPTransport(d, &tls.Config{}, th)
d := NewDialerWithResolver(log.Log, NewResolverSystem(log.Log))
txp := NewHTTPTransport(d, NewTLSHandshakerStdlib(log.Log))
client := &http.Client{Transport: txp}
defer client.CloseIdleConnections()
resp, err := client.Get("https://www.google.com/robots.txt")
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
txp.CloseIdleConnections()
}

func TestHTTPTransportWithFailingDialer(t *testing.T) {
called := &atomicx.Int64{}
expected := errors.New("mocked error")
d := &dialerResolver{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context,
network, address string) (net.Conn, error) {
return nil, expected
},
MockCloseIdleConnections: func() {
called.Add(1)
},
},
Resolver: NewResolverSystem(log.Log),
}
th := &tlsHandshakerConfigurable{}
txp := NewHTTPTransport(d, &tls.Config{}, th)
txp := NewHTTPTransport(d, NewTLSHandshakerStdlib(log.Log))
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com/robots.txt")
if !errors.Is(err, expected) {
Expand All @@ -146,5 +144,47 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) {
if resp != nil {
t.Fatal("expected non-nil response here")
}
txp.CloseIdleConnections()
client.CloseIdleConnections()
if called.Load() < 1 {
t.Fatal("did not propagate CloseIdleConnections")
}
}

func TestNewHTTPTransport(t *testing.T) {
d := &mocks.Dialer{}
th := &mocks.TLSHandshaker{}
txp := NewHTTPTransport(d, th)
txpcc, okay := txp.(*httpTransportConnectionsCloser)
if !okay {
t.Fatal("invalid type")
}
udt, okay := txpcc.Dialer.(*httpDialerWithReadTimeout)
if !okay {
t.Fatal("invalid type")
}
if udt.Dialer != d {
t.Fatal("invalid dialer")
}
if txpcc.TLSDialer.(*tlsDialer).TLSHandshaker != th {
t.Fatal("invalid tls handshaker")
}
htxp, okay := txpcc.HTTPTransport.(*http.Transport)
if !okay {
t.Fatal("invalid type")
}
if !htxp.ForceAttemptHTTP2 {
t.Fatal("invalid ForceAttemptHTTP2")
}
if !htxp.DisableCompression {
t.Fatal("invalid DisableCompression")
}
if htxp.MaxConnsPerHost != 1 {
t.Fatal("invalid MaxConnPerHost")
}
if htxp.DialTLSContext == nil {
t.Fatal("invalid DialTLSContext")
}
if htxp.DialContext == nil {
t.Fatal("invalid DialContext")
}
}
60 changes: 60 additions & 0 deletions internal/netxlite/mocks/tls.go
@@ -0,0 +1,60 @@
package mocks

import (
"context"
"crypto/tls"
"net"
)

// TLSHandshaker is a mockable TLS handshaker.
type TLSHandshaker struct {
MockHandshake func(ctx context.Context, conn net.Conn, config *tls.Config) (
net.Conn, tls.ConnectionState, error)
}

// Handshake calls MockHandshake.
func (th *TLSHandshaker) Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
net.Conn, tls.ConnectionState, error) {
return th.MockHandshake(ctx, conn, config)
}

// TLSConn allows to mock netxlite.TLSConn.
type TLSConn struct {
// Conn is the embedded mockable Conn.
Conn

// MockConnectionState allows to mock the ConnectionState method.
MockConnectionState func() tls.ConnectionState

// MockHandshakeContext allows to mock the HandshakeContext method.
MockHandshakeContext func(ctx context.Context) error
}

// ConnectionState calls MockConnectionState.
func (c *TLSConn) ConnectionState() tls.ConnectionState {
return c.MockConnectionState()
}

// HandshakeContext calls MockHandshakeContext.
func (c *TLSConn) HandshakeContext(ctx context.Context) error {
return c.MockHandshakeContext(ctx)
}

// TLSDialer allows to mock netxlite.TLSDialer.
type TLSDialer struct {
// MockCloseIdleConnections allows to mock the CloseIdleConnections method.
MockCloseIdleConnections func()

// MockDialTLSContext allows to mock the DialTLSContext method.
MockDialTLSContext func(ctx context.Context, network, address string) (net.Conn, error)
}

// CloseIdleConnections calls MockCloseIdleConnections.
func (d *TLSDialer) CloseIdleConnections() {
d.MockCloseIdleConnections()
}

// DialTLSContext calls MockDialTLSContext.
func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.MockDialTLSContext(ctx, network, address)
}
89 changes: 89 additions & 0 deletions internal/netxlite/mocks/tls_test.go
@@ -0,0 +1,89 @@
package mocks

import (
"context"
"crypto/tls"
"errors"
"net"
"reflect"
"testing"
)

func TestTLSHandshakerHandshake(t *testing.T) {
expected := errors.New("mocked error")
conn := &Conn{}
ctx := context.Background()
config := &tls.Config{}
th := &TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn,
config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return nil, tls.ConnectionState{}, expected
},
}
tlsConn, connState, err := th.Handshake(ctx, conn, config)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if !reflect.ValueOf(connState).IsZero() {
t.Fatal("expected zero ConnectionState here")
}
if tlsConn != nil {
t.Fatal("expected nil conn here")
}
}

func TestTLSConnConnectionState(t *testing.T) {
state := tls.ConnectionState{Version: tls.VersionTLS12}
c := &TLSConn{
MockConnectionState: func() tls.ConnectionState {
return state
},
}
out := c.ConnectionState()
if !reflect.DeepEqual(out, state) {
t.Fatal("not the result we expected")
}
}

func TestTLSConnHandshakeContext(t *testing.T) {
expected := errors.New("mocked error")
c := &TLSConn{
MockHandshakeContext: func(ctx context.Context) error {
return expected
},
}
err := c.HandshakeContext(context.Background())
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
}

func TestTLSDialerCloseIdleConnections(t *testing.T) {
var called bool
td := &TLSDialer{
MockCloseIdleConnections: func() {
called = true
},
}
td.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
}

func TestTLSDialerDialTLSContext(t *testing.T) {
expected := errors.New("mocked error")
td := &TLSDialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, expected
},
}
ctx := context.Background()
conn, err := td.DialTLSContext(ctx, "", "")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if conn != nil {
t.Fatal("expected nil conn here")
}
}
28 changes: 0 additions & 28 deletions internal/netxlite/mocks/tlsconn.go

This file was deleted.

35 changes: 0 additions & 35 deletions internal/netxlite/mocks/tlsconn_test.go

This file was deleted.

0 comments on commit 6df27d9

Please sign in to comment.