Skip to content
This repository was archived by the owner on Mar 6, 2020. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 13 additions & 89 deletions internal/dialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"time"

"github.com/ooni/netx/internal/dialer/dialerbase"
"github.com/ooni/netx/internal/dialer/tlsdialer"
"github.com/ooni/netx/model"
)

Expand All @@ -21,24 +22,21 @@ var nextDialID, nextConnID int64
// Dialer defines the dialer API. We implement the most basic form
// of DNS, but more advanced resolutions are possible.
type Dialer struct {
Beginning time.Time
Handler model.Handler
Resolver model.DNSResolver
StartTLSHandshakeHook func(net.Conn)
TLSConfig *tls.Config
TLSHandshakeTimeout time.Duration
Beginning time.Time
Handler model.Handler
Resolver model.DNSResolver
TLSConfig *tls.Config
}

// NewDialer creates a new Dialer.
func NewDialer(
beginning time.Time, handler model.Handler,
) (d *Dialer) {
return &Dialer{
Beginning: beginning,
Handler: handler,
Resolver: new(net.Resolver),
TLSConfig: &tls.Config{},
StartTLSHandshakeHook: func(net.Conn) {},
Beginning: beginning,
Handler: handler,
Resolver: new(net.Resolver),
TLSConfig: &tls.Config{},
}
}

Expand Down Expand Up @@ -71,26 +69,10 @@ func (d *Dialer) DialTLS(network, address string) (net.Conn, error) {
func (d *Dialer) DialTLSContext(
ctx context.Context, network, address string,
) (net.Conn, error) {
conn, onlyhost, _, connID, err := d.DialContextEx(ctx, network, address, false)
if err != nil {
return nil, err
}
config := d.clonedTLSConfig()
if config.ServerName == "" {
config.ServerName = onlyhost
}
timeout := d.TLSHandshakeTimeout
if timeout <= 0 {
timeout = 10 * time.Second
}
tc, err := d.tlsHandshake(config, timeout, conn, connID)
if err != nil {
conn.Close()
return nil, err
}
// Note that we cannot wrap `tc` because the HTTP code assumes
// a `*tls.Conn` when implementing ALPN.
return tc, nil
dialer := tlsdialer.New(
d.Beginning, d.Handler, d, d.TLSConfig,
)
return dialer.DialTLSContext(ctx, network, address)
}

// DialContextEx is an extended DialContext where we may also
Expand Down Expand Up @@ -151,64 +133,6 @@ func (d *Dialer) DialContextEx(
return
}

func (d *Dialer) clonedTLSConfig() *tls.Config {
return d.TLSConfig.Clone()
}

func (d *Dialer) tlsHandshake(
config *tls.Config, timeout time.Duration, conn net.Conn, connID int64,
) (*tls.Conn, error) {
d.StartTLSHandshakeHook(conn)
err := conn.SetDeadline(time.Now().Add(timeout))
if err != nil {
conn.Close()
return nil, err
}
tc := tls.Client(net.Conn(conn), config)
start := time.Now()
err = tc.Handshake()
stop := time.Now()
state := tc.ConnectionState()
d.Handler.OnMeasurement(model.Measurement{
TLSHandshake: &model.TLSHandshakeEvent{
Config: model.TLSConfig{
NextProtos: config.NextProtos,
ServerName: config.ServerName,
},
ConnectionState: model.TLSConnectionState{
CipherSuite: state.CipherSuite,
NegotiatedProtocol: state.NegotiatedProtocol,
NegotiatedProtocolIsMutual: state.NegotiatedProtocolIsMutual,
PeerCertificates: simplifyCerts(state.PeerCertificates),
Version: state.Version,
},
Duration: stop.Sub(start),
Error: err,
ConnID: connID,
Time: stop.Sub(d.Beginning),
},
})
if err != nil {
tc.Close()
return nil, err
}
// The following call fails if the connection is not connected
// which should not be the case at this point. If the connection
// has just been disconnected, we'll notice when doing I/O, so
// it is fine to ignore the return value of SetDeadline.
tc.SetDeadline(time.Time{})
return tc, nil
}

func simplifyCerts(in []*x509.Certificate) (out []model.X509Certificate) {
for _, cert := range in {
out = append(out, model.X509Certificate{
Data: cert.Raw,
})
}
return
}

// SetCABundle configures the dialer to use a specific CA bundle.
func (d *Dialer) SetCABundle(path string) error {
cert, err := ioutil.ReadFile(path)
Expand Down
38 changes: 11 additions & 27 deletions internal/dialer/dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"testing"
"time"

Expand All @@ -30,6 +29,17 @@ func TestIntegrationDialTLS(t *testing.T) {
}

func TestIntegrationInvalidAddress(t *testing.T) {
dialer := NewDialer(time.Now(), handlers.NoHandler)
conn, err := dialer.Dial("tcp", "www.google.com")
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("expected a nil conn here")
}
}

func TestIntegrationInvalidAddressTLS(t *testing.T) {
dialer := NewDialer(time.Now(), handlers.NoHandler)
conn, err := dialer.DialTLS("tcp", "www.google.com")
if err == nil {
Expand Down Expand Up @@ -146,32 +156,6 @@ func TestIntegrationDialInvalidSNI(t *testing.T) {
}
}

func TestIntegrationTLSHandshakeSetDeadlineError(t *testing.T) {
dialer := NewDialer(time.Now(), handlers.NoHandler)
dialer.StartTLSHandshakeHook = func(c net.Conn) {
c.Close() // close the connection so SetDealine should fail
}
conn, err := dialer.DialTLS("tcp", "ooni.io:443")
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("expected a nil conn here")
}
}

func TestIntegrationTLSHandshakeTimeout(t *testing.T) {
dialer := NewDialer(time.Now(), handlers.NoHandler)
dialer.TLSHandshakeTimeout = 1 // very small timeout
conn, err := dialer.DialTLS("tcp", "ooni.io:443")
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("expected a nil conn here")
}
}

func TestSetCABundleExisting(t *testing.T) {
dialer := NewDialer(time.Now(), handlers.NoHandler)
err := dialer.SetCABundle("../../testdata/cacert.pem")
Expand Down
121 changes: 121 additions & 0 deletions internal/dialer/tlsdialer/tlsdialer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Package tlsdialer contains the TLS dialer
package tlsdialer

import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"time"

"github.com/ooni/netx/internal/dialer/connx"
"github.com/ooni/netx/model"
)

// TLSDialer is the TLS dialer
type TLSDialer struct {
ConnectTimeout time.Duration // default: 30 second
TLSHandshakeTimeout time.Duration // default: 10 second
beginning time.Time
config *tls.Config
dialer model.Dialer
handler model.Handler
setDeadline func(net.Conn, time.Time) error
}

// New creates a new TLS dialer
func New(
beginning time.Time,
handler model.Handler,
dialer model.Dialer,
config *tls.Config,
) *TLSDialer {
return &TLSDialer{
ConnectTimeout: 30 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
beginning: beginning,
config: config,
dialer: dialer,
handler: handler,
setDeadline: func(conn net.Conn, t time.Time) error {
return conn.SetDeadline(t)
},
}
}

// DialTLS dials a new TLS connection
func (d *TLSDialer) DialTLS(network, address string) (net.Conn, error) {
ctx := context.Background()
return d.DialTLSContext(ctx, network, address)
}

// DialTLSContext is like DialTLS, but with context
func (d *TLSDialer) DialTLSContext(
ctx context.Context, network, address string,
) (net.Conn, error) {
host, _, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), d.ConnectTimeout)
defer cancel()
conn, err := d.dialer.DialContext(ctx, network, address)
if err != nil {
return nil, err
}
config := d.config.Clone() // avoid polluting original config
if config.ServerName == "" {
config.ServerName = host
}
err = d.setDeadline(conn, time.Now().Add(d.TLSHandshakeTimeout))
if err != nil {
conn.Close()
return nil, err
}
tlsconn := tls.Client(conn, config)
start := time.Now()
err = tlsconn.Handshake()
stop := time.Now()
var connID int64
if mconn, ok := conn.(*connx.MeasuringConn); ok {
connID = mconn.ID
}
m := model.Measurement{
TLSHandshake: &model.TLSHandshakeEvent{
ConnID: connID,
Config: model.TLSConfig{
NextProtos: config.NextProtos,
ServerName: config.ServerName,
},
ConnectionState: newConnectionState(tlsconn.ConnectionState()),
Duration: stop.Sub(start),
Error: err,
Time: stop.Sub(d.beginning),
},
}
conn.SetDeadline(time.Time{}) // clear deadline
d.handler.OnMeasurement(m)
if err != nil {
conn.Close()
return nil, err
}
return tlsconn, err
}

func newConnectionState(s tls.ConnectionState) model.TLSConnectionState {
return model.TLSConnectionState{
CipherSuite: s.CipherSuite,
NegotiatedProtocol: s.NegotiatedProtocol,
PeerCertificates: simplifyCerts(s.PeerCertificates),
Version: s.Version,
}
}

func simplifyCerts(in []*x509.Certificate) (out []model.X509Certificate) {
for _, cert := range in {
out = append(out, model.X509Certificate{
Data: cert.Raw,
})
}
return
}
Loading