forked from golang/net
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The existing API does not allow client code to take advantage of Dialer implementations that implement DialContext receivers. This a familiar API, see net.Dialer. Fixes golang/go#27874 Fixes golang/go#19354 Fixes golang/go#17759 Fixes golang/go#13455
- Loading branch information
Showing
7 changed files
with
254 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
// Copyright 2019 The Go Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package proxy | ||
|
||
import ( | ||
"context" | ||
"net" | ||
) | ||
|
||
// A ContextDialer dials using a context. | ||
type ContextDialer interface { | ||
DialContext(ctx context.Context, network, address string) (net.Conn, error) | ||
} | ||
|
||
// Dial works like DialContext on net.Dialer but using a dialer returned by FromEnvironment. | ||
// | ||
// The passed ctx is only used for returning the Conn, not the lifetime of the Conn. | ||
// | ||
// Custom dialers (registered via RegisterDialerType) that do not implement ContextDialer | ||
// can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout. | ||
// | ||
// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed. | ||
func Dial(ctx context.Context, network, address string) (net.Conn, error) { | ||
d := FromEnvironment() | ||
if xd, ok := d.(ContextDialer); ok { | ||
return xd.DialContext(ctx, network, address) | ||
} | ||
return dialContext(ctx, d, network, address) | ||
} | ||
|
||
// WARNING: this can leak a goroutine for as long as the underlying Dialer implementation takes to timeout | ||
// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed. | ||
func dialContext(ctx context.Context, d Dialer, network, address string) (net.Conn, error) { | ||
var ( | ||
conn net.Conn | ||
done = make(chan struct{}, 1) | ||
err error | ||
) | ||
go func() { | ||
conn, err = d.Dial(network, address) | ||
close(done) | ||
if conn != nil && ctx.Err() != nil { | ||
conn.Close() | ||
} | ||
}() | ||
select { | ||
case <-ctx.Done(): | ||
err = ctx.Err() | ||
case <-done: | ||
} | ||
return conn, err | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
// Copyright 2019 The Go Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package proxy | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net" | ||
"os" | ||
"testing" | ||
"time" | ||
|
||
"golang.org/x/net/internal/sockstest" | ||
) | ||
|
||
func TestDial(t *testing.T) { | ||
ResetProxyEnv() | ||
t.Run("DirectWithCancel", func(t *testing.T) { | ||
defer ResetProxyEnv() | ||
l, err := net.Listen("tcp", "127.0.0.1:0") | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
defer l.Close() | ||
_, port, err := net.SplitHostPort(l.Addr().String()) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
ctx, cancel := context.WithCancel(context.Background()) | ||
defer cancel() | ||
c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port)) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
c.Close() | ||
}) | ||
t.Run("DirectWithTimeout", func(t *testing.T) { | ||
defer ResetProxyEnv() | ||
l, err := net.Listen("tcp", "127.0.0.1:0") | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
defer l.Close() | ||
_, port, err := net.SplitHostPort(l.Addr().String()) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||
defer cancel() | ||
c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port)) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
c.Close() | ||
}) | ||
t.Run("DirectWithTimeoutExceeded", func(t *testing.T) { | ||
defer ResetProxyEnv() | ||
l, err := net.Listen("tcp", "127.0.0.1:0") | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
defer l.Close() | ||
_, port, err := net.SplitHostPort(l.Addr().String()) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) | ||
time.Sleep(time.Millisecond) | ||
defer cancel() | ||
c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port)) | ||
if err == nil { | ||
defer c.Close() | ||
t.Fatal("failed to timeout") | ||
} | ||
}) | ||
t.Run("SOCKS5", func(t *testing.T) { | ||
defer ResetProxyEnv() | ||
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
defer s.Close() | ||
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil { | ||
t.Fatal(err) | ||
} | ||
c, err := Dial(context.Background(), s.TargetAddr().Network(), s.TargetAddr().String()) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
c.Close() | ||
}) | ||
t.Run("SOCKS5WithTimeout", func(t *testing.T) { | ||
defer ResetProxyEnv() | ||
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
defer s.Close() | ||
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil { | ||
t.Fatal(err) | ||
} | ||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||
defer cancel() | ||
c, err := Dial(ctx, s.TargetAddr().Network(), s.TargetAddr().String()) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
c.Close() | ||
}) | ||
t.Run("SOCKS5WithTimeoutExceeded", func(t *testing.T) { | ||
defer ResetProxyEnv() | ||
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
defer s.Close() | ||
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil { | ||
t.Fatal(err) | ||
} | ||
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) | ||
time.Sleep(time.Millisecond) | ||
defer cancel() | ||
c, err := Dial(ctx, s.TargetAddr().Network(), s.TargetAddr().String()) | ||
if err == nil { | ||
defer c.Close() | ||
t.Fatal("failed to timeout") | ||
} | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters