Skip to content

Commit

Permalink
only use one host string for QUIC DialContext()
Browse files Browse the repository at this point in the history
  • Loading branch information
kelmenhorst committed Jan 11, 2021
1 parent ca2e86c commit 3b21732
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 51 deletions.
2 changes: 1 addition & 1 deletion netx/httptransport/http3transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type QUICWrapperDialer struct {

// Dial implements QUICDialer.Dial
func (d QUICWrapperDialer) Dial(network, host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
return d.Dialer.DialContext(context.Background(), network, "", host, tlsCfg, cfg)
return d.Dialer.DialContext(context.Background(), network, host, tlsCfg, cfg)
}

// HTTP3Transport is a httptransport.RoundTripper using the http3 protocol.
Expand Down
11 changes: 7 additions & 4 deletions netx/quicdialer/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@ type DNSDialer struct {
Resolver Resolver
}

// TODO(bassosimone): figure out what `addr` is used for?

// DialContext implements ContextDialer.DialContext
func (d DNSDialer) DialContext(
ctx context.Context, network, addr string, host string,
ctx context.Context, network, host string,
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
onlyhost, onlyport, err := net.SplitHostPort(host)
if err != nil {
return nil, err
}
// TODO(kelmenhorst): Should this be somewhere else?
// failure if tlsCfg is nil but that should not happen
if tlsCfg.ServerName == "" {
tlsCfg.ServerName = onlyhost
}
ctx = dialid.WithDialID(ctx)
var addrs []string
addrs, err = d.LookupHost(ctx, onlyhost)
Expand All @@ -37,7 +40,7 @@ func (d DNSDialer) DialContext(
for _, addr := range addrs {
target := net.JoinHostPort(addr, onlyport)
sess, err := d.Dialer.DialContext(
ctx, network, target, host, tlsCfg, cfg)
ctx, network, target, tlsCfg, cfg)
if err == nil {
return sess, nil
}
Expand Down
43 changes: 16 additions & 27 deletions netx/quicdialer/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,40 @@ func (r MockableResolver) LookupHost(ctx context.Context, host string) ([]string
return r.Addresses, r.Err
}
func TestDNSDialerSuccess(t *testing.T) {
tlsConf := &tls.Config{NextProtos: []string{"h3-29"}}
tlsConf := &tls.Config{
NextProtos: []string{"h3-29"},
ServerName: "www.google.com",
}
dialer := quicdialer.DNSDialer{
Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}}
sess, err := dialer.DialContext(
context.Background(), "udp", "", "www.google.com:443",
context.Background(), "udp", "www.google.com:443",
tlsConf, &quic.Config{})
if err != nil {
t.Fatal("unexpected error")
t.Fatal("unexpected error", err)
}
if sess == nil {
t.Fatal("non nil sess expected")
}
}

func TestDNSDialerNoPort(t *testing.T) {
tlsConf := &tls.Config{NextProtos: []string{"h3-29"}}
tlsConf := &tls.Config{
NextProtos: []string{"h3-29"},
ServerName: "www.google.com",
}
dialer := quicdialer.DNSDialer{
Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}}
sess, err := dialer.DialContext(
context.Background(), "udp", "", "antani.ooni.nu",
context.Background(), "udp", "www.google.com",
tlsConf, &quic.Config{})
if err == nil {
t.Fatal("expected an error here")
}
if sess != nil {
t.Fatal("expected a nil sess here")
}
if err.Error() != "address antani.ooni.nu: missing port in address" {
if err.Error() != "address www.google.com: missing port in address" {
t.Fatal("not the error we expected")
}
}
Expand All @@ -74,7 +80,7 @@ func TestDNSDialerLookupHostFailure(t *testing.T) {
Err: expected,
}}
sess, err := dialer.DialContext(
context.Background(), "udp", "", "dns.google.com:853",
context.Background(), "udp", "dns.google.com:853",
tlsConf, &quic.Config{})
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
Expand All @@ -89,7 +95,7 @@ func TestDNSDialerInvalidPort(t *testing.T) {
dialer := quicdialer.DNSDialer{
Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}}
sess, err := dialer.DialContext(
context.Background(), "udp", "", "www.google.com:0",
context.Background(), "udp", "www.google.com:0",
tlsConf, &quic.Config{})
if err == nil {
t.Fatal("expected an error here")
Expand All @@ -108,7 +114,7 @@ func TestDNSDialerInvalidPortSyntax(t *testing.T) {
dialer := quicdialer.DNSDialer{
Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}}
sess, err := dialer.DialContext(
context.Background(), "udp", "", "www.google.com:port",
context.Background(), "udp", "www.google.com:port",
tlsConf, &quic.Config{})
if err == nil {
t.Fatal("expected an error here")
Expand All @@ -121,31 +127,14 @@ func TestDNSDialerInvalidPortSyntax(t *testing.T) {
}
}

func TestDNSDialerNilTLSConf(t *testing.T) {
dialer := quicdialer.DNSDialer{
Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}}
sess, err := dialer.DialContext(
context.Background(), "udp", "", "www.google.com:443",
nil /* should cause failure */, &quic.Config{})
if err == nil {
t.Fatal("expected an error here")
}
if sess != nil {
t.Fatal("expected nil sess")
}
if err.Error() != "quic: tls.Config not set" {
t.Fatal("not the error we expected")
}
}

func TestDNSDialerDialEarlyFails(t *testing.T) {
tlsConf := &tls.Config{NextProtos: []string{"h3-29"}}
expected := errors.New("mocked DialEarly error")

dialer := quicdialer.DNSDialer{
Resolver: new(net.Resolver), Dialer: MockDialer{Err: expected}}
sess, err := dialer.DialContext(
context.Background(), "udp", "", "www.google.com:443",
context.Background(), "udp", "www.google.com:443",
tlsConf, &quic.Config{})
if err == nil {
t.Fatal("expected an error here")
Expand Down
4 changes: 2 additions & 2 deletions netx/quicdialer/errorwrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ type ErrorWrapperDialer struct {

// DialContext implements ContextDialer.DialContext
func (d ErrorWrapperDialer) DialContext(
ctx context.Context, network string, addr string, host string,
ctx context.Context, network string, host string,
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
dialID := dialid.ContextDialID(ctx)
sess, err := d.Dialer.DialContext(ctx, network, addr, host, tlsCfg, cfg)
sess, err := d.Dialer.DialContext(ctx, network, host, tlsCfg, cfg)
err = errorx.SafeErrWrapperBuilder{
// ConnID does not make any sense if we've failed and the error
// does not make any sense (and is nil) if we succeded.
Expand Down
10 changes: 6 additions & 4 deletions netx/quicdialer/errorwrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestErrorWrapperFailure(t *testing.T) {
d := quicdialer.ErrorWrapperDialer{
Dialer: MockDialer{Sess: nil, Err: io.EOF}}
sess, err := d.DialContext(
ctx, "udp", "", "www.google.com:443", &tls.Config{}, &quic.Config{})
ctx, "udp", "www.google.com:443", &tls.Config{}, &quic.Config{})
if sess != nil {
t.Fatal("expected a nil sess here")
}
Expand Down Expand Up @@ -46,10 +46,12 @@ func errorWrapperCheckErr(t *testing.T, err error, op string) {

func TestErrorWrapperSuccess(t *testing.T) {
ctx := dialid.WithDialID(context.Background())
tlsConf := &tls.Config{NextProtos: []string{"h3-29"}}
tlsConf := &tls.Config{
NextProtos: []string{"h3-29"},
ServerName: "www.google.com",
}
d := quicdialer.ErrorWrapperDialer{Dialer: quicdialer.SystemDialer{}}
sess, err := d.DialContext(ctx, "udp", "216.58.212.164:443",
"www.google.com:443", tlsConf, &quic.Config{})
sess, err := d.DialContext(ctx, "udp", "216.58.212.164:443", tlsConf, &quic.Config{})
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion netx/quicdialer/quicdialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

// ContextDialer is a dialer for QUIC using Context.
type ContextDialer interface {
DialContext(ctx context.Context, network, addr string, host string,
DialContext(ctx context.Context, network, host string,
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
}

Expand Down
4 changes: 2 additions & 2 deletions netx/quicdialer/saver.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type HandshakeSaver struct {
}

// DialContext implements ContextDialer.DialContext
func (h HandshakeSaver) DialContext(ctx context.Context, network string, addr string, host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
func (h HandshakeSaver) DialContext(ctx context.Context, network string, host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
start := time.Now()
// TODO(bassosimone): in the future we probably want to also save
// information about what versions we're willing to accept.
Expand All @@ -33,7 +33,7 @@ func (h HandshakeSaver) DialContext(ctx context.Context, network string, addr st
TLSServerName: tlsCfg.ServerName,
Time: start,
})
sess, err := h.Dialer.DialContext(ctx, network, addr, host, tlsCfg, cfg)
sess, err := h.Dialer.DialContext(ctx, network, host, tlsCfg, cfg)
stop := time.Now()
if err != nil {
h.Saver.Write(trace.Event{
Expand Down
12 changes: 6 additions & 6 deletions netx/quicdialer/saver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,21 @@ type MockDialer struct {
Err error
}

func (d MockDialer) DialContext(ctx context.Context, network, addr string, host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
func (d MockDialer) DialContext(ctx context.Context, network, host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
if d.Dialer != nil {
return d.Dialer.DialContext(ctx, network, addr, host, tlsCfg, cfg)
return d.Dialer.DialContext(ctx, network, host, tlsCfg, cfg)
}
return d.Sess, d.Err
}

func TestSaverConnDialSuccess(t *testing.T) {
tlsConf := &tls.Config{
NextProtos: []string{"h3-29"},
ServerName: "www.google.com",
}
saver := &trace.Saver{}
systemdialer := quicdialer.SystemDialer{Saver: saver}

sess, err := systemdialer.DialContext(context.Background(), "udp", "216.58.212.164:443", "www.google.com:443", tlsConf, &quic.Config{})
sess, err := systemdialer.DialContext(context.Background(), "udp", "216.58.212.164:443", tlsConf, &quic.Config{})
if err != nil {
t.Fatal("unexpected error", err)
}
Expand Down Expand Up @@ -86,7 +86,7 @@ func TestHandshakeSaverSuccess(t *testing.T) {
Saver: saver,
}

sess, err := dlr.DialContext(context.Background(), "udp", "216.58.212.164:443", "www.google.com:443", tlsConf, &quic.Config{})
sess, err := dlr.DialContext(context.Background(), "udp", "216.58.212.164:443", tlsConf, &quic.Config{})
if err != nil {
t.Fatal("unexpected error", err)
}
Expand Down Expand Up @@ -142,7 +142,7 @@ func TestHandshakeSaverHostNameError(t *testing.T) {
Saver: saver,
}

sess, err := dlr.DialContext(context.Background(), "udp", "216.58.212.164:443", "www.google.com:443", tlsConf, &quic.Config{})
sess, err := dlr.DialContext(context.Background(), "udp", "216.58.212.164:443", tlsConf, &quic.Config{})
if err == nil {
t.Fatal("expected an error here")
}
Expand Down
4 changes: 2 additions & 2 deletions netx/quicdialer/system.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ type SystemDialer struct {
}

// DialContext implements ContextDialer.DialContext
func (d SystemDialer) DialContext(ctx context.Context, network string, addr string,
func (d SystemDialer) DialContext(ctx context.Context, network string,
host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
onlyhost, onlyport, err := net.SplitHostPort(addr)
onlyhost, onlyport, err := net.SplitHostPort(host)
port, err := strconv.Atoi(onlyport)
if err != nil {
return nil, err
Expand Down
6 changes: 4 additions & 2 deletions netx/quicdialer/system_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ import (
func TestSystemDialerSuccess(t *testing.T) {
tlsConf := &tls.Config{
NextProtos: []string{"h3-29"},
ServerName: "www.google.com",
}
var systemdialer quicdialer.SystemDialer

sess, err := systemdialer.DialContext(context.Background(), "udp", "216.58.212.164:443", "www.google.com:443", tlsConf, &quic.Config{})
sess, err := systemdialer.DialContext(context.Background(), "udp", "216.58.212.164:443", tlsConf, &quic.Config{})
if err != nil {
t.Fatal("unexpected error", err)
}
Expand All @@ -33,12 +34,13 @@ func TestSystemDialerSuccessWithReadWrite(t *testing.T) {
}
tlsConf := &tls.Config{
NextProtos: []string{"h3-29"},
ServerName: "www.google.com",
}
saver := &trace.Saver{}
systemdialer := quicdialer.SystemDialer{
Saver: saver,
}
_, err := systemdialer.DialContext(context.Background(), "udp", "216.58.212.164:443", "www.google.com:443", tlsConf, &quic.Config{})
_, err := systemdialer.DialContext(context.Background(), "udp", "216.58.212.164:443", tlsConf, &quic.Config{})
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit 3b21732

Please sign in to comment.