From 3929f465403aa6742a05f2d1c61e918383d8c120 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Thu, 24 Oct 2019 17:27:57 +0200 Subject: [PATCH] dialer: refactor TLS dialing as model.TLSDialer --- internal/dialer/dialer.go | 102 +++-------------- internal/dialer/dialer_test.go | 38 ++---- internal/dialer/tlsdialer/tlsdialer.go | 121 ++++++++++++++++++++ internal/dialer/tlsdialer/tlsdialer_test.go | 98 ++++++++++++++++ 4 files changed, 243 insertions(+), 116 deletions(-) create mode 100644 internal/dialer/tlsdialer/tlsdialer.go create mode 100644 internal/dialer/tlsdialer/tlsdialer_test.go diff --git a/internal/dialer/dialer.go b/internal/dialer/dialer.go index 6cfdb12..1c738ff 100644 --- a/internal/dialer/dialer.go +++ b/internal/dialer/dialer.go @@ -13,6 +13,7 @@ import ( "time" "github.com/ooni/netx/internal/dialer/dialerbase" + "github.com/ooni/netx/internal/dialer/tlsdialer" "github.com/ooni/netx/model" ) @@ -21,12 +22,10 @@ var nextDialID, nextConnID int64 // Dialer defines the dialer API. We implement the most basic form // of DNS, but more advanced resolutions are possible. type Dialer struct { - Beginning time.Time - Handler model.Handler - Resolver model.DNSResolver - StartTLSHandshakeHook func(net.Conn) - TLSConfig *tls.Config - TLSHandshakeTimeout time.Duration + Beginning time.Time + Handler model.Handler + Resolver model.DNSResolver + TLSConfig *tls.Config } // NewDialer creates a new Dialer. @@ -34,11 +33,10 @@ func NewDialer( beginning time.Time, handler model.Handler, ) (d *Dialer) { return &Dialer{ - Beginning: beginning, - Handler: handler, - Resolver: new(net.Resolver), - TLSConfig: &tls.Config{}, - StartTLSHandshakeHook: func(net.Conn) {}, + Beginning: beginning, + Handler: handler, + Resolver: new(net.Resolver), + TLSConfig: &tls.Config{}, } } @@ -71,26 +69,10 @@ func (d *Dialer) DialTLS(network, address string) (net.Conn, error) { func (d *Dialer) DialTLSContext( ctx context.Context, network, address string, ) (net.Conn, error) { - conn, onlyhost, _, connID, err := d.DialContextEx(ctx, network, address, false) - if err != nil { - return nil, err - } - config := d.clonedTLSConfig() - if config.ServerName == "" { - config.ServerName = onlyhost - } - timeout := d.TLSHandshakeTimeout - if timeout <= 0 { - timeout = 10 * time.Second - } - tc, err := d.tlsHandshake(config, timeout, conn, connID) - if err != nil { - conn.Close() - return nil, err - } - // Note that we cannot wrap `tc` because the HTTP code assumes - // a `*tls.Conn` when implementing ALPN. - return tc, nil + dialer := tlsdialer.New( + d.Beginning, d.Handler, d, d.TLSConfig, + ) + return dialer.DialTLSContext(ctx, network, address) } // DialContextEx is an extended DialContext where we may also @@ -151,64 +133,6 @@ func (d *Dialer) DialContextEx( return } -func (d *Dialer) clonedTLSConfig() *tls.Config { - return d.TLSConfig.Clone() -} - -func (d *Dialer) tlsHandshake( - config *tls.Config, timeout time.Duration, conn net.Conn, connID int64, -) (*tls.Conn, error) { - d.StartTLSHandshakeHook(conn) - err := conn.SetDeadline(time.Now().Add(timeout)) - if err != nil { - conn.Close() - return nil, err - } - tc := tls.Client(net.Conn(conn), config) - start := time.Now() - err = tc.Handshake() - stop := time.Now() - state := tc.ConnectionState() - d.Handler.OnMeasurement(model.Measurement{ - TLSHandshake: &model.TLSHandshakeEvent{ - Config: model.TLSConfig{ - NextProtos: config.NextProtos, - ServerName: config.ServerName, - }, - ConnectionState: model.TLSConnectionState{ - CipherSuite: state.CipherSuite, - NegotiatedProtocol: state.NegotiatedProtocol, - NegotiatedProtocolIsMutual: state.NegotiatedProtocolIsMutual, - PeerCertificates: simplifyCerts(state.PeerCertificates), - Version: state.Version, - }, - Duration: stop.Sub(start), - Error: err, - ConnID: connID, - Time: stop.Sub(d.Beginning), - }, - }) - if err != nil { - tc.Close() - return nil, err - } - // The following call fails if the connection is not connected - // which should not be the case at this point. If the connection - // has just been disconnected, we'll notice when doing I/O, so - // it is fine to ignore the return value of SetDeadline. - tc.SetDeadline(time.Time{}) - return tc, nil -} - -func simplifyCerts(in []*x509.Certificate) (out []model.X509Certificate) { - for _, cert := range in { - out = append(out, model.X509Certificate{ - Data: cert.Raw, - }) - } - return -} - // SetCABundle configures the dialer to use a specific CA bundle. func (d *Dialer) SetCABundle(path string) error { cert, err := ioutil.ReadFile(path) diff --git a/internal/dialer/dialer_test.go b/internal/dialer/dialer_test.go index 8186c3a..74361a0 100644 --- a/internal/dialer/dialer_test.go +++ b/internal/dialer/dialer_test.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "crypto/x509" - "net" "testing" "time" @@ -30,6 +29,17 @@ func TestIntegrationDialTLS(t *testing.T) { } func TestIntegrationInvalidAddress(t *testing.T) { + dialer := NewDialer(time.Now(), handlers.NoHandler) + conn, err := dialer.Dial("tcp", "www.google.com") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("expected a nil conn here") + } +} + +func TestIntegrationInvalidAddressTLS(t *testing.T) { dialer := NewDialer(time.Now(), handlers.NoHandler) conn, err := dialer.DialTLS("tcp", "www.google.com") if err == nil { @@ -146,32 +156,6 @@ func TestIntegrationDialInvalidSNI(t *testing.T) { } } -func TestIntegrationTLSHandshakeSetDeadlineError(t *testing.T) { - dialer := NewDialer(time.Now(), handlers.NoHandler) - dialer.StartTLSHandshakeHook = func(c net.Conn) { - c.Close() // close the connection so SetDealine should fail - } - conn, err := dialer.DialTLS("tcp", "ooni.io:443") - if err == nil { - t.Fatal("expected an error here") - } - if conn != nil { - t.Fatal("expected a nil conn here") - } -} - -func TestIntegrationTLSHandshakeTimeout(t *testing.T) { - dialer := NewDialer(time.Now(), handlers.NoHandler) - dialer.TLSHandshakeTimeout = 1 // very small timeout - conn, err := dialer.DialTLS("tcp", "ooni.io:443") - if err == nil { - t.Fatal("expected an error here") - } - if conn != nil { - t.Fatal("expected a nil conn here") - } -} - func TestSetCABundleExisting(t *testing.T) { dialer := NewDialer(time.Now(), handlers.NoHandler) err := dialer.SetCABundle("../../testdata/cacert.pem") diff --git a/internal/dialer/tlsdialer/tlsdialer.go b/internal/dialer/tlsdialer/tlsdialer.go new file mode 100644 index 0000000..3fd3604 --- /dev/null +++ b/internal/dialer/tlsdialer/tlsdialer.go @@ -0,0 +1,121 @@ +// Package tlsdialer contains the TLS dialer +package tlsdialer + +import ( + "context" + "crypto/tls" + "crypto/x509" + "net" + "time" + + "github.com/ooni/netx/internal/dialer/connx" + "github.com/ooni/netx/model" +) + +// TLSDialer is the TLS dialer +type TLSDialer struct { + ConnectTimeout time.Duration // default: 30 second + TLSHandshakeTimeout time.Duration // default: 10 second + beginning time.Time + config *tls.Config + dialer model.Dialer + handler model.Handler + setDeadline func(net.Conn, time.Time) error +} + +// New creates a new TLS dialer +func New( + beginning time.Time, + handler model.Handler, + dialer model.Dialer, + config *tls.Config, +) *TLSDialer { + return &TLSDialer{ + ConnectTimeout: 30 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + beginning: beginning, + config: config, + dialer: dialer, + handler: handler, + setDeadline: func(conn net.Conn, t time.Time) error { + return conn.SetDeadline(t) + }, + } +} + +// DialTLS dials a new TLS connection +func (d *TLSDialer) DialTLS(network, address string) (net.Conn, error) { + ctx := context.Background() + return d.DialTLSContext(ctx, network, address) +} + +// DialTLSContext is like DialTLS, but with context +func (d *TLSDialer) DialTLSContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + host, _, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + ctx, cancel := context.WithTimeout(context.Background(), d.ConnectTimeout) + defer cancel() + conn, err := d.dialer.DialContext(ctx, network, address) + if err != nil { + return nil, err + } + config := d.config.Clone() // avoid polluting original config + if config.ServerName == "" { + config.ServerName = host + } + err = d.setDeadline(conn, time.Now().Add(d.TLSHandshakeTimeout)) + if err != nil { + conn.Close() + return nil, err + } + tlsconn := tls.Client(conn, config) + start := time.Now() + err = tlsconn.Handshake() + stop := time.Now() + var connID int64 + if mconn, ok := conn.(*connx.MeasuringConn); ok { + connID = mconn.ID + } + m := model.Measurement{ + TLSHandshake: &model.TLSHandshakeEvent{ + ConnID: connID, + Config: model.TLSConfig{ + NextProtos: config.NextProtos, + ServerName: config.ServerName, + }, + ConnectionState: newConnectionState(tlsconn.ConnectionState()), + Duration: stop.Sub(start), + Error: err, + Time: stop.Sub(d.beginning), + }, + } + conn.SetDeadline(time.Time{}) // clear deadline + d.handler.OnMeasurement(m) + if err != nil { + conn.Close() + return nil, err + } + return tlsconn, err +} + +func newConnectionState(s tls.ConnectionState) model.TLSConnectionState { + return model.TLSConnectionState{ + CipherSuite: s.CipherSuite, + NegotiatedProtocol: s.NegotiatedProtocol, + PeerCertificates: simplifyCerts(s.PeerCertificates), + Version: s.Version, + } +} + +func simplifyCerts(in []*x509.Certificate) (out []model.X509Certificate) { + for _, cert := range in { + out = append(out, model.X509Certificate{ + Data: cert.Raw, + }) + } + return +} diff --git a/internal/dialer/tlsdialer/tlsdialer_test.go b/internal/dialer/tlsdialer/tlsdialer_test.go new file mode 100644 index 0000000..3430653 --- /dev/null +++ b/internal/dialer/tlsdialer/tlsdialer_test.go @@ -0,0 +1,98 @@ +package tlsdialer + +import ( + "crypto/tls" + "errors" + "net" + "testing" + "time" + + "github.com/ooni/netx/handlers" + "github.com/ooni/netx/internal/dialer/dialerbase" + "github.com/ooni/netx/model" +) + +func TestIntegrationSuccess(t *testing.T) { + dialer := newdialer() + conn, err := dialer.DialTLS("tcp", "www.google.com:443") + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("connection is nil") + } + conn.Close() +} + +func TestIntegrationSuccessWithMeasuringConn(t *testing.T) { + dialer := newdialer() + dialer.(*TLSDialer).dialer = dialerbase.New( + time.Now(), handlers.NoHandler, new(net.Dialer), 17, 17, + ) + conn, err := dialer.DialTLS("tcp", "www.google.com:443") + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("connection is nil") + } + conn.Close() +} + +func TestIntegrationFailureSplitHostPort(t *testing.T) { + dialer := newdialer() + conn, err := dialer.DialTLS("tcp", "www.google.com") // missing port + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("connection is not nil") + } +} + +func TestIntegrationFailureConnectTimeout(t *testing.T) { + dialer := newdialer() + dialer.(*TLSDialer).ConnectTimeout = 10 * time.Microsecond + conn, err := dialer.DialTLS("tcp", "www.google.com:443") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("connection is not nil") + } +} + +func TestIntegrationFailureTLSHandshakeTimeout(t *testing.T) { + dialer := newdialer() + dialer.(*TLSDialer).TLSHandshakeTimeout = 10 * time.Microsecond + conn, err := dialer.DialTLS("tcp", "www.google.com:443") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("connection is not nil") + } +} + +func TestIntegrationFailureSetDeadline(t *testing.T) { + dialer := newdialer() + dialer.(*TLSDialer).setDeadline = func(conn net.Conn, t time.Time) error { + return errors.New("mocked error") + } + conn, err := dialer.DialTLS("tcp", "www.google.com:443") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("connection is not nil") + } +} + +func newdialer() model.TLSDialer { + return New( + time.Now(), + handlers.NoHandler, + new(net.Dialer), + new(tls.Config), + ) +}