Skip to content

Commit

Permalink
proxy: add Dial (with context)
Browse files Browse the repository at this point in the history
The existing API does not allow client code to take advantage of Dialer
implementations that implement DialContext receivers. This a familiar
API, see net.Dialer.

Fixes golang/go#27874
Fixes golang/go#19354
Fixes golang/go#17759
Fixes golang/go#13455
  • Loading branch information
dweomer committed May 2, 2019
1 parent 9ce7a69 commit b0a3727
Show file tree
Hide file tree
Showing 7 changed files with 254 additions and 18 deletions.
54 changes: 54 additions & 0 deletions proxy/dial.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package proxy

import (
"context"
"net"
)

// A ContextDialer dials using a context.
type ContextDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

// Dial works like DialContext on net.Dialer but using a dialer returned by FromEnvironment.
//
// The passed ctx is only used for returning the Conn, not the lifetime of the Conn.
//
// Custom dialers (registered via RegisterDialerType) that do not implement ContextDialer
// can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout.
//
// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed.
func Dial(ctx context.Context, network, address string) (net.Conn, error) {
d := FromEnvironment()
if xd, ok := d.(ContextDialer); ok {
return xd.DialContext(ctx, network, address)
}
return dialContext(ctx, d, network, address)
}

// WARNING: this can leak a goroutine for as long as the underlying Dialer implementation takes to timeout
// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed.
func dialContext(ctx context.Context, d Dialer, network, address string) (net.Conn, error) {
var (
conn net.Conn
done = make(chan struct{}, 1)
err error
)
go func() {
conn, err = d.Dial(network, address)
close(done)
if conn != nil && ctx.Err() != nil {
conn.Close()
}
}()
select {
case <-ctx.Done():
err = ctx.Err()
case <-done:
}
return conn, err
}
131 changes: 131 additions & 0 deletions proxy/dial_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package proxy

import (
"context"
"fmt"
"net"
"os"
"testing"
"time"

"golang.org/x/net/internal/sockstest"
)

func TestDial(t *testing.T) {
ResetProxyEnv()
t.Run("DirectWithCancel", func(t *testing.T) {
defer ResetProxyEnv()
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer l.Close()
_, port, err := net.SplitHostPort(l.Addr().String())
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port))
if err != nil {
t.Fatal(err)
}
c.Close()
})
t.Run("DirectWithTimeout", func(t *testing.T) {
defer ResetProxyEnv()
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer l.Close()
_, port, err := net.SplitHostPort(l.Addr().String())
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port))
if err != nil {
t.Fatal(err)
}
c.Close()
})
t.Run("DirectWithTimeoutExceeded", func(t *testing.T) {
defer ResetProxyEnv()
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer l.Close()
_, port, err := net.SplitHostPort(l.Addr().String())
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
time.Sleep(time.Millisecond)
defer cancel()
c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port))
if err == nil {
defer c.Close()
t.Fatal("failed to timeout")
}
})
t.Run("SOCKS5", func(t *testing.T) {
defer ResetProxyEnv()
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
if err != nil {
t.Fatal(err)
}
defer s.Close()
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
t.Fatal(err)
}
c, err := Dial(context.Background(), s.TargetAddr().Network(), s.TargetAddr().String())
if err != nil {
t.Fatal(err)
}
c.Close()
})
t.Run("SOCKS5WithTimeout", func(t *testing.T) {
defer ResetProxyEnv()
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
if err != nil {
t.Fatal(err)
}
defer s.Close()
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
c, err := Dial(ctx, s.TargetAddr().Network(), s.TargetAddr().String())
if err != nil {
t.Fatal(err)
}
c.Close()
})
t.Run("SOCKS5WithTimeoutExceeded", func(t *testing.T) {
defer ResetProxyEnv()
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
if err != nil {
t.Fatal(err)
}
defer s.Close()
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
time.Sleep(time.Millisecond)
defer cancel()
c, err := Dial(ctx, s.TargetAddr().Network(), s.TargetAddr().String())
if err == nil {
defer c.Close()
t.Fatal("failed to timeout")
}
})
}
8 changes: 8 additions & 0 deletions proxy/direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package proxy

import (
"context"
"net"
)

Expand All @@ -13,6 +14,13 @@ type direct struct{}
// Direct is a direct proxy: one that makes network connections directly.
var Direct = direct{}

// Dial directly invokes net.Dial with the supplied parameters.
func (direct) Dial(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
}

// DialContext instantiates a net.Dialer and invokes its DialContext receiver with the supplied parameters.
func (direct) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, network, addr)
}
15 changes: 15 additions & 0 deletions proxy/per_host.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package proxy

import (
"context"
"net"
"strings"
)
Expand Down Expand Up @@ -41,6 +42,20 @@ func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) {
return p.dialerForRequest(host).Dial(network, addr)
}

// DialContext connects to the address addr on the given network through either
// defaultDialer or bypass.
func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.Conn, err error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
d := p.dialerForRequest(host)
if x, ok := d.(ContextDialer); ok {
return x.DialContext(ctx, network, addr)
}
return dialContext(ctx, d, network, addr)
}

func (p *PerHost) dialerForRequest(host string) Dialer {
if ip := net.ParseIP(host); ip != nil {
for _, net := range p.bypassNetworks {
Expand Down
53 changes: 37 additions & 16 deletions proxy/per_host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package proxy

import (
"context"
"errors"
"net"
"reflect"
Expand All @@ -21,10 +22,6 @@ func (r *recordingProxy) Dial(network, addr string) (net.Conn, error) {
}

func TestPerHost(t *testing.T) {
var def, bypass recordingProxy
perHost := NewPerHost(&def, &bypass)
perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")

expectedDef := []string{
"example.com:123",
"1.2.3.4:123",
Expand All @@ -39,17 +36,41 @@ func TestPerHost(t *testing.T) {
"[1000::]:123",
}

for _, addr := range expectedDef {
perHost.Dial("tcp", addr)
}
for _, addr := range expectedBypass {
perHost.Dial("tcp", addr)
}
t.Run("Dial", func(t *testing.T) {
var def, bypass recordingProxy
perHost := NewPerHost(&def, &bypass)
perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")
for _, addr := range expectedDef {
perHost.Dial("tcp", addr)
}
for _, addr := range expectedBypass {
perHost.Dial("tcp", addr)
}

if !reflect.DeepEqual(expectedDef, def.addrs) {
t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
}
if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
}
if !reflect.DeepEqual(expectedDef, def.addrs) {
t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
}
if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
}
})

t.Run("DialContext", func(t *testing.T) {
var def, bypass recordingProxy
perHost := NewPerHost(&def, &bypass)
perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")
for _, addr := range expectedDef {
perHost.DialContext(context.Background(), "tcp", addr)
}
for _, addr := range expectedBypass {
perHost.DialContext(context.Background(), "tcp", addr)
}

if !reflect.DeepEqual(expectedDef, def.addrs) {
t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
}
if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
}
})
}
1 change: 1 addition & 0 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
)

// A Dialer is a means to establish a connection.
// Custom dialers should also implement ContextDialer.
type Dialer interface {
// Dial connects to the given address via the proxy.
Dial(network, addr string) (c net.Conn, err error)
Expand Down
10 changes: 8 additions & 2 deletions proxy/socks5.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@ import (
func SOCKS5(network, address string, auth *Auth, forward Dialer) (Dialer, error) {
d := socks.NewDialer(network, address)
if forward != nil {
d.ProxyDial = func(_ context.Context, network string, address string) (net.Conn, error) {
return forward.Dial(network, address)
if f, ok := forward.(ContextDialer); ok {
d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) {
return f.DialContext(ctx, network, address)
}
} else {
d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) {
return dialContext(ctx, forward, network, address)
}
}
}
if auth != nil {
Expand Down

0 comments on commit b0a3727

Please sign in to comment.