Skip to content

Commit

Permalink
client: fix "unix" scheme handling for some corner cases (#4021)
Browse files Browse the repository at this point in the history
  • Loading branch information
GarrettGutierrez1 authored and menghanl committed Jan 11, 2021
1 parent 553e4ef commit 39242ac
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 64 deletions.
9 changes: 5 additions & 4 deletions internal/grpcutil/target.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ func ParseTarget(target string, skipUnixColonParsing bool) (ret resolver.Target)
ret.Scheme, ret.Endpoint, ok = split2(target, "://")
if !ok {
if strings.HasPrefix(target, "unix:") && !skipUnixColonParsing {
// Handle the "unix:[path]" case, because splitting on :// only
// handles the "unix://[/absolute/path]" case. Only handle if the
// dialer is nil, to avoid a behavior change with custom dialers.
// Handle the "unix:[local/path]" and "unix:[/absolute/path]" cases,
// because splitting on :// only handles the
// "unix://[/absolute/path]" case. Only handle if the dialer is nil,
// to avoid a behavior change with custom dialers.
return resolver.Target{Scheme: "unix", Endpoint: target[len("unix:"):]}
}
return resolver.Target{Endpoint: target}
Expand All @@ -61,7 +62,7 @@ func ParseTarget(target string, skipUnixColonParsing bool) (ret resolver.Target)
}
if ret.Scheme == "unix" {
// Add the "/" back in the unix case, so the unix resolver receives the
// actual endpoint.
// actual endpoint in the "unix://[/absolute/path]" case.
ret.Endpoint = "/" + ret.Endpoint
}
return ret
Expand Down
12 changes: 8 additions & 4 deletions internal/grpcutil/target_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,21 @@ func TestParseTargetString(t *testing.T) {
// If we can only parse part of the target.
{targetStr: "://", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "://"}},
{targetStr: "unix://domain", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix://domain"}},
{targetStr: "unix://a/b/c", want: resolver.Target{Scheme: "unix", Authority: "a", Endpoint: "/b/c"}},
{targetStr: "a:b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a:b"}},
{targetStr: "a/b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a/b"}},
{targetStr: "a:/b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a:/b"}},
{targetStr: "a//b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a//b"}},
{targetStr: "a://b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a://b"}},

// Unix cases without custom dialer.
// unix:[local_path] and unix:[/absolute] have different behaviors with
// a custom dialer, to prevent behavior changes with custom dialers.
{targetStr: "unix:domain", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "domain"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:domain"}},
{targetStr: "unix:/domain", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/domain"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:/domain"}},
// unix:[local_path], unix:[/absolute], and unix://[/absolute] have different
// behaviors with a custom dialer, to prevent behavior changes with custom dialers.
{targetStr: "unix:a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "a/b/c"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:a/b/c"}},
{targetStr: "unix:/a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/a/b/c"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:/a/b/c"}},
{targetStr: "unix:///a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/a/b/c"}},

{targetStr: "passthrough:///unix:///a/b/c", want: resolver.Target{Scheme: "passthrough", Authority: "", Endpoint: "unix:///a/b/c"}},
} {
got := ParseTarget(test.targetStr, false)
if got != test.want {
Expand Down
5 changes: 5 additions & 0 deletions internal/resolver/unix/unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
package unix

import (
"fmt"

"google.golang.org/grpc/internal/transport/networktype"
"google.golang.org/grpc/resolver"
)
Expand All @@ -29,6 +31,9 @@ const scheme = "unix"
type builder struct{}

func (*builder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) {
if target.Authority != "" {
return nil, fmt.Errorf("invalid (non-empty) authority: %v", target.Authority)
}
cc.UpdateState(resolver.State{Addresses: []resolver.Address{networktype.Set(resolver.Address{Addr: target.Endpoint}, "unix")}})
return &nopResolver{}, nil
}
Expand Down
21 changes: 15 additions & 6 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,17 +139,26 @@ type http2Client struct {
}

func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, useProxy bool, grpcUA string) (net.Conn, error) {
address := addr.Addr
networkType, ok := networktype.Get(addr)
if fn != nil {
return fn(ctx, addr.Addr)
if networkType == "unix" {
// For backward compatibility, if the user dialed "unix:///path",
// the passthrough resolver would be used and the user's custom
// dialer would see "unix:///path". Since the unix resolver is used
// and the address is now "/path", prepend "unix://" so the user's
// custom dialer sees the same address.
return fn(ctx, "unix://"+address)
}
return fn(ctx, address)
}
networkType := "tcp"
if n, ok := networktype.Get(addr); ok {
networkType = n
if !ok {
networkType, address = parseDialTarget(address)
}
if networkType == "tcp" && useProxy {
return proxyDial(ctx, addr.Addr, grpcUA)
return proxyDial(ctx, address, grpcUA)
}
return (&net.Dialer{}).DialContext(ctx, networkType, addr.Addr)
return (&net.Dialer{}).DialContext(ctx, networkType, address)
}

func isTemporary(err error) bool {
Expand Down
29 changes: 29 additions & 0 deletions internal/transport/http_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"math"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -598,3 +599,31 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, maxHeaderList
f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
return f
}

// parseDialTarget returns the network and address to pass to dialer.
func parseDialTarget(target string) (string, string) {
net := "tcp"
m1 := strings.Index(target, ":")
m2 := strings.Index(target, ":/")
// handle unix:addr which will fail with url.Parse
if m1 >= 0 && m2 < 0 {
if n := target[0:m1]; n == "unix" {
return n, target[m1+1:]
}
}
if m2 >= 0 {
t, err := url.Parse(target)
if err != nil {
return net, target
}
scheme := t.Scheme
addr := t.Path
if scheme == "unix" {
if addr == "" {
addr = t.Host
}
return scheme, addr
}
}
return net, target
}
29 changes: 29 additions & 0 deletions internal/transport/http_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,32 @@ func (s) TestDecodeHeaderH2ErrCode(t *testing.T) {
})
}
}

func (s) TestParseDialTarget(t *testing.T) {
for _, test := range []struct {
target, wantNet, wantAddr string
}{
{"unix:a", "unix", "a"},
{"unix:a/b/c", "unix", "a/b/c"},
{"unix:/a", "unix", "/a"},
{"unix:/a/b/c", "unix", "/a/b/c"},
{"unix://a", "unix", "a"},
{"unix://a/b/c", "unix", "/b/c"},
{"unix:///a", "unix", "/a"},
{"unix:///a/b/c", "unix", "/a/b/c"},
{"unix:etcd:0", "unix", "etcd:0"},
{"unix:///tmp/unix-3", "unix", "/tmp/unix-3"},
{"unix://domain", "unix", "domain"},
{"unix://etcd:0", "unix", "etcd:0"},
{"unix:///etcd:0", "unix", "/etcd:0"},
{"passthrough://unix://domain", "tcp", "passthrough://unix://domain"},
{"https://google.com:443", "tcp", "https://google.com:443"},
{"dns:///google.com", "tcp", "dns:///google.com"},
{"/unix/socket/address", "tcp", "/unix/socket/address"},
} {
gotNet, gotAddr := parseDialTarget(test.target)
if gotNet != test.wantNet || gotAddr != test.wantAddr {
t.Errorf("parseDialTarget(%q) = %s, %s want %s, %s", test.target, gotNet, gotAddr, test.wantNet, test.wantAddr)
}
}
}
97 changes: 47 additions & 50 deletions test/authority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,35 +80,46 @@ func runUnixTest(t *testing.T, address, target, expectedAuthority string, dialer
}
}

type authorityTest struct {
name string
address string
target string
authority string
dialTargetWant string
}

var authorityTests = []authorityTest{
{
name: "UnixRelative",
address: "sock.sock",
target: "unix:sock.sock",
authority: "localhost",
},
{
name: "UnixAbsolute",
address: "/tmp/sock.sock",
target: "unix:/tmp/sock.sock",
authority: "localhost",
},
{
name: "UnixAbsoluteAlternate",
address: "/tmp/sock.sock",
target: "unix:///tmp/sock.sock",
authority: "localhost",
},
{
name: "UnixPassthrough",
address: "/tmp/sock.sock",
target: "passthrough:///unix:///tmp/sock.sock",
authority: "unix:///tmp/sock.sock",
dialTargetWant: "unix:///tmp/sock.sock",
},
}

// TestUnix does end to end tests with the various supported unix target
// formats, ensuring that the authority is set to localhost in every case.
// formats, ensuring that the authority is set as expected.
func (s) TestUnix(t *testing.T) {
tests := []struct {
name string
address string
target string
authority string
}{
{
name: "UnixRelative",
address: "sock.sock",
target: "unix:sock.sock",
authority: "localhost",
},
{
name: "UnixAbsolute",
address: "/tmp/sock.sock",
target: "unix:/tmp/sock.sock",
authority: "localhost",
},
{
name: "UnixAbsoluteAlternate",
address: "/tmp/sock.sock",
target: "unix:///tmp/sock.sock",
authority: "localhost",
},
}
for _, test := range tests {
for _, test := range authorityTests {
t.Run(test.name, func(t *testing.T) {
runUnixTest(t, test.address, test.target, test.authority, nil)
})
Expand All @@ -119,30 +130,14 @@ func (s) TestUnix(t *testing.T) {
// formats, ensuring that the target sent to the dialer does NOT have the
// "unix:" prefix stripped.
func (s) TestUnixCustomDialer(t *testing.T) {
tests := []struct {
name string
address string
target string
authority string
}{
{
name: "UnixRelative",
address: "sock.sock",
target: "unix:sock.sock",
authority: "localhost",
},
{
name: "UnixAbsolute",
address: "/tmp/sock.sock",
target: "unix:/tmp/sock.sock",
authority: "localhost",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
for _, test := range authorityTests {
t.Run(test.name+"WithDialer", func(t *testing.T) {
if test.dialTargetWant == "" {
test.dialTargetWant = test.target
}
dialer := func(ctx context.Context, address string) (net.Conn, error) {
if address != test.target {
return nil, fmt.Errorf("expected target %v in custom dialer, instead got %v", test.target, address)
if address != test.dialTargetWant {
return nil, fmt.Errorf("expected target %v in custom dialer, instead got %v", test.dialTargetWant, address)
}
address = address[len("unix:"):]
return (&net.Dialer{}).DialContext(ctx, "unix", address)
Expand All @@ -152,6 +147,8 @@ func (s) TestUnixCustomDialer(t *testing.T) {
}
}

// TestColonPortAuthority does an end to end test with the target for grpc.Dial
// being ":[port]". Ensures authority is "localhost:[port]".
func (s) TestColonPortAuthority(t *testing.T) {
expectedAuthority := ""
var authorityMu sync.Mutex
Expand Down

0 comments on commit 39242ac

Please sign in to comment.