Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(netxlite): http factory that propagates close-idle-connections #465

Merged
merged 3 commits into from
Sep 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 35 additions & 10 deletions internal/netxlite/http.go
Expand Up @@ -2,7 +2,6 @@ package netxlite

import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
Expand Down Expand Up @@ -67,17 +66,37 @@ func (txp *httpTransportLogger) CloseIdleConnections() {
txp.HTTPTransport.CloseIdleConnections()
}

// NewHTTPTransport creates a new HTTP transport using Go stdlib.
func NewHTTPTransport(dialer Dialer, tlsConfig *tls.Config,
handshaker TLSHandshaker) HTTPTransport {
// httpTransportConnectionsCloser is an HTTPTransport that
// correctly forwards CloseIdleConnections.
type httpTransportConnectionsCloser struct {
HTTPTransport
Dialer
TLSDialer
}

// CloseIdleConnections forwards the CloseIdleConnections calls.
func (txp *httpTransportConnectionsCloser) CloseIdleConnections() {
txp.HTTPTransport.CloseIdleConnections()
txp.Dialer.CloseIdleConnections()
txp.TLSDialer.CloseIdleConnections()
}

// NewHTTPTransport creates a new HTTP transport using the given
// dialer and TLS handshaker to create connections.
//
// We need a TLS handshaker here, as opposed to a TLSDialer, because we
// wrap the dialer we'll use to enforce timeouts for HTTP idle
// connections (see https://github.com/ooni/probe/issues/1609 for more info).
func NewHTTPTransport(dialer Dialer, tlsHandshaker TLSHandshaker) HTTPTransport {
// TODO(bassosimone): here we should copy code living inside the
// websteps prototype to use the oohttp library.
txp := http.DefaultTransport.(*http.Transport).Clone()
// This wrapping ensures that we always have a timeout when we
// are using HTTP; see https://github.com/ooni/probe/issues/1609.
dialer = &httpDialerWithReadTimeout{dialer}
txp.DialContext = dialer.DialContext
txp.DialTLSContext = (&tlsDialer{
Config: tlsConfig,
Dialer: dialer,
TLSHandshaker: handshaker,
}).DialTLSContext
tlsDialer := NewTLSDialer(dialer, tlsHandshaker)
txp.DialTLSContext = tlsDialer.DialTLSContext
// Better for Cloudflare DNS and also better because we have less
// noisy events and we can better understand what happened.
txp.MaxConnsPerHost = 1
Expand All @@ -86,7 +105,13 @@ func NewHTTPTransport(dialer Dialer, tlsConfig *tls.Config,
// back the true headers, such as Content-Length. This change is
// functional to OONI's goal of observing the network.
txp.DisableCompression = true
return txp
txp.ForceAttemptHTTP2 = true
// Ensure we correctly forward CloseIdleConnections.
return &httpTransportConnectionsCloser{
HTTPTransport: txp,
Dialer: dialer,
TLSDialer: tlsDialer,
}
}

// httpDialerWithReadTimeout enforces a read timeout for all HTTP
Expand Down
62 changes: 51 additions & 11 deletions internal/netxlite/http_test.go
Expand Up @@ -2,7 +2,6 @@ package netxlite

import (
"context"
"crypto/tls"
"errors"
"io"
"net"
Expand Down Expand Up @@ -110,34 +109,33 @@ func TestHTTPTransportLoggerCloseIdleConnections(t *testing.T) {
}

func TestHTTPTransportWorks(t *testing.T) {
d := &dialerResolver{
Dialer: defaultDialer,
Resolver: NewResolverSystem(log.Log),
}
th := &tlsHandshakerConfigurable{}
txp := NewHTTPTransport(d, &tls.Config{}, th)
d := NewDialerWithResolver(log.Log, NewResolverSystem(log.Log))
txp := NewHTTPTransport(d, NewTLSHandshakerStdlib(log.Log))
client := &http.Client{Transport: txp}
defer client.CloseIdleConnections()
resp, err := client.Get("https://www.google.com/robots.txt")
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
txp.CloseIdleConnections()
}

func TestHTTPTransportWithFailingDialer(t *testing.T) {
called := &atomicx.Int64{}
expected := errors.New("mocked error")
d := &dialerResolver{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context,
network, address string) (net.Conn, error) {
return nil, expected
},
MockCloseIdleConnections: func() {
called.Add(1)
},
},
Resolver: NewResolverSystem(log.Log),
}
th := &tlsHandshakerConfigurable{}
txp := NewHTTPTransport(d, &tls.Config{}, th)
txp := NewHTTPTransport(d, NewTLSHandshakerStdlib(log.Log))
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com/robots.txt")
if !errors.Is(err, expected) {
Expand All @@ -146,5 +144,47 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) {
if resp != nil {
t.Fatal("expected non-nil response here")
}
txp.CloseIdleConnections()
client.CloseIdleConnections()
if called.Load() < 1 {
t.Fatal("did not propagate CloseIdleConnections")
}
}

func TestNewHTTPTransport(t *testing.T) {
d := &mocks.Dialer{}
th := &mocks.TLSHandshaker{}
txp := NewHTTPTransport(d, th)
txpcc, okay := txp.(*httpTransportConnectionsCloser)
if !okay {
t.Fatal("invalid type")
}
udt, okay := txpcc.Dialer.(*httpDialerWithReadTimeout)
if !okay {
t.Fatal("invalid type")
}
if udt.Dialer != d {
t.Fatal("invalid dialer")
}
if txpcc.TLSDialer.(*tlsDialer).TLSHandshaker != th {
t.Fatal("invalid tls handshaker")
}
htxp, okay := txpcc.HTTPTransport.(*http.Transport)
if !okay {
t.Fatal("invalid type")
}
if !htxp.ForceAttemptHTTP2 {
t.Fatal("invalid ForceAttemptHTTP2")
}
if !htxp.DisableCompression {
t.Fatal("invalid DisableCompression")
}
if htxp.MaxConnsPerHost != 1 {
t.Fatal("invalid MaxConnPerHost")
}
if htxp.DialTLSContext == nil {
t.Fatal("invalid DialTLSContext")
}
if htxp.DialContext == nil {
t.Fatal("invalid DialContext")
}
}
60 changes: 60 additions & 0 deletions internal/netxlite/mocks/tls.go
@@ -0,0 +1,60 @@
package mocks

import (
"context"
"crypto/tls"
"net"
)

// TLSHandshaker is a mockable TLS handshaker.
type TLSHandshaker struct {
MockHandshake func(ctx context.Context, conn net.Conn, config *tls.Config) (
net.Conn, tls.ConnectionState, error)
}

// Handshake calls MockHandshake.
func (th *TLSHandshaker) Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
net.Conn, tls.ConnectionState, error) {
return th.MockHandshake(ctx, conn, config)
}

// TLSConn allows to mock netxlite.TLSConn.
type TLSConn struct {
// Conn is the embedded mockable Conn.
Conn

// MockConnectionState allows to mock the ConnectionState method.
MockConnectionState func() tls.ConnectionState

// MockHandshakeContext allows to mock the HandshakeContext method.
MockHandshakeContext func(ctx context.Context) error
}

// ConnectionState calls MockConnectionState.
func (c *TLSConn) ConnectionState() tls.ConnectionState {
return c.MockConnectionState()
}

// HandshakeContext calls MockHandshakeContext.
func (c *TLSConn) HandshakeContext(ctx context.Context) error {
return c.MockHandshakeContext(ctx)
}

// TLSDialer allows to mock netxlite.TLSDialer.
type TLSDialer struct {
// MockCloseIdleConnections allows to mock the CloseIdleConnections method.
MockCloseIdleConnections func()

// MockDialTLSContext allows to mock the DialTLSContext method.
MockDialTLSContext func(ctx context.Context, network, address string) (net.Conn, error)
}

// CloseIdleConnections calls MockCloseIdleConnections.
func (d *TLSDialer) CloseIdleConnections() {
d.MockCloseIdleConnections()
}

// DialTLSContext calls MockDialTLSContext.
func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.MockDialTLSContext(ctx, network, address)
}
89 changes: 89 additions & 0 deletions internal/netxlite/mocks/tls_test.go
@@ -0,0 +1,89 @@
package mocks

import (
"context"
"crypto/tls"
"errors"
"net"
"reflect"
"testing"
)

func TestTLSHandshakerHandshake(t *testing.T) {
expected := errors.New("mocked error")
conn := &Conn{}
ctx := context.Background()
config := &tls.Config{}
th := &TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn,
config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return nil, tls.ConnectionState{}, expected
},
}
tlsConn, connState, err := th.Handshake(ctx, conn, config)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if !reflect.ValueOf(connState).IsZero() {
t.Fatal("expected zero ConnectionState here")
}
if tlsConn != nil {
t.Fatal("expected nil conn here")
}
}

func TestTLSConnConnectionState(t *testing.T) {
state := tls.ConnectionState{Version: tls.VersionTLS12}
c := &TLSConn{
MockConnectionState: func() tls.ConnectionState {
return state
},
}
out := c.ConnectionState()
if !reflect.DeepEqual(out, state) {
t.Fatal("not the result we expected")
}
}

func TestTLSConnHandshakeContext(t *testing.T) {
expected := errors.New("mocked error")
c := &TLSConn{
MockHandshakeContext: func(ctx context.Context) error {
return expected
},
}
err := c.HandshakeContext(context.Background())
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
}

func TestTLSDialerCloseIdleConnections(t *testing.T) {
var called bool
td := &TLSDialer{
MockCloseIdleConnections: func() {
called = true
},
}
td.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
}

func TestTLSDialerDialTLSContext(t *testing.T) {
expected := errors.New("mocked error")
td := &TLSDialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, expected
},
}
ctx := context.Background()
conn, err := td.DialTLSContext(ctx, "", "")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if conn != nil {
t.Fatal("expected nil conn here")
}
}
28 changes: 0 additions & 28 deletions internal/netxlite/mocks/tlsconn.go

This file was deleted.

35 changes: 0 additions & 35 deletions internal/netxlite/mocks/tlsconn_test.go

This file was deleted.