Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

client: fix "unix" scheme handling for some corner cases #4021

Merged
merged 11 commits into from
Nov 30, 2020
9 changes: 5 additions & 4 deletions internal/grpcutil/target.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ func ParseTarget(target string, skipUnixColonParsing bool) (ret resolver.Target)
ret.Scheme, ret.Endpoint, ok = split2(target, "://")
if !ok {
if strings.HasPrefix(target, "unix:") && !skipUnixColonParsing {
// Handle the "unix:[path]" case, because splitting on :// only
// handles the "unix://[/absolute/path]" case. Only handle if the
// dialer is nil, to avoid a behavior change with custom dialers.
// Handle the "unix:[local/path]" and "unix:[/absolute/path]" cases,
// because splitting on :// only handles the
// "unix://[/absolute/path]" case. Only handle if the dialer is nil,
// to avoid a behavior change with custom dialers.
return resolver.Target{Scheme: "unix", Endpoint: target[len("unix:"):]}
}
return resolver.Target{Endpoint: target}
Expand All @@ -61,7 +62,7 @@ func ParseTarget(target string, skipUnixColonParsing bool) (ret resolver.Target)
}
if ret.Scheme == "unix" {
// Add the "/" back in the unix case, so the unix resolver receives the
// actual endpoint.
// actual endpoint in the "unix://[/absolute/path]" case.
ret.Endpoint = "/" + ret.Endpoint
}
return ret
Expand Down
12 changes: 8 additions & 4 deletions internal/grpcutil/target_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,21 @@ func TestParseTargetString(t *testing.T) {
// If we can only parse part of the target.
{targetStr: "://", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "://"}},
{targetStr: "unix://domain", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix://domain"}},
{targetStr: "unix://a/b/c", want: resolver.Target{Scheme: "unix", Authority: "a", Endpoint: "/b/c"}},
{targetStr: "a:b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a:b"}},
{targetStr: "a/b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a/b"}},
{targetStr: "a:/b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a:/b"}},
{targetStr: "a//b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a//b"}},
{targetStr: "a://b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a://b"}},

// Unix cases without custom dialer.
// unix:[local_path] and unix:[/absolute] have different behaviors with
// a custom dialer, to prevent behavior changes with custom dialers.
{targetStr: "unix:domain", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "domain"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:domain"}},
{targetStr: "unix:/domain", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/domain"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:/domain"}},
// unix:[local_path], unix:[/absolute], and unix://[/absolute] have different
// behaviors with a custom dialer, to prevent behavior changes with custom dialers.
{targetStr: "unix:a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "a/b/c"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:a/b/c"}},
{targetStr: "unix:/a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/a/b/c"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:/a/b/c"}},
{targetStr: "unix:///a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/a/b/c"}},

{targetStr: "passthrough:///unix:///a/b/c", want: resolver.Target{Scheme: "passthrough", Authority: "", Endpoint: "unix:///a/b/c"}},
} {
got := ParseTarget(test.targetStr, false)
if got != test.want {
Expand Down
5 changes: 5 additions & 0 deletions internal/resolver/unix/unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
package unix

import (
"fmt"

"google.golang.org/grpc/internal/transport/networktype"
"google.golang.org/grpc/resolver"
)
Expand All @@ -29,6 +31,9 @@ const scheme = "unix"
type builder struct{}

func (*builder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) {
if target.Authority != "" {
return nil, fmt.Errorf("invalid (non-empty) authority: %v", target.Authority)
}
cc.UpdateState(resolver.State{Addresses: []resolver.Address{networktype.Set(resolver.Address{Addr: target.Endpoint}, "unix")}})
return &nopResolver{}, nil
}
Expand Down
21 changes: 15 additions & 6 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,26 @@ type http2Client struct {
}

func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, useProxy bool, grpcUA string) (net.Conn, error) {
address := addr.Addr
networkType, ok := networktype.Get(addr)
if fn != nil {
return fn(ctx, addr.Addr)
if networkType == "unix" {
// For backward compatibility, if the user dialed "unix:///path",
// the passthrough resolver would be used and the user's custom
// dialer would see "unix:///path". Since the unix resolver is used
// and the address is now "/path", prepend "unix://" so the user's
// custom dialer sees the same address.
return fn(ctx, "unix://"+address)
}
return fn(ctx, address)
}
networkType := "tcp"
if n, ok := networktype.Get(addr); ok {
networkType = n
if !ok {
networkType, address = parseDialTarget(address)
}
if networkType == "tcp" && useProxy {
return proxyDial(ctx, addr.Addr, grpcUA)
return proxyDial(ctx, address, grpcUA)
}
return (&net.Dialer{}).DialContext(ctx, networkType, addr.Addr)
return (&net.Dialer{}).DialContext(ctx, networkType, address)
}

func isTemporary(err error) bool {
Expand Down
29 changes: 29 additions & 0 deletions internal/transport/http_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"math"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -598,3 +599,31 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, maxHeaderList
f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
return f
}

// parseDialTarget returns the network and address to pass to dialer.
func parseDialTarget(target string) (string, string) {
net := "tcp"
m1 := strings.Index(target, ":")
m2 := strings.Index(target, ":/")
// handle unix:addr which will fail with url.Parse
if m1 >= 0 && m2 < 0 {
if n := target[0:m1]; n == "unix" {
return n, target[m1+1:]
}
}
if m2 >= 0 {
t, err := url.Parse(target)
if err != nil {
return net, target
}
scheme := t.Scheme
addr := t.Path
if scheme == "unix" {
if addr == "" {
addr = t.Host
}
return scheme, addr
}
}
return net, target
}
29 changes: 29 additions & 0 deletions internal/transport/http_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,32 @@ func (s) TestDecodeHeaderH2ErrCode(t *testing.T) {
})
}
}

func (s) TestParseDialTarget(t *testing.T) {
for _, test := range []struct {
target, wantNet, wantAddr string
}{
{"unix:a", "unix", "a"},
{"unix:a/b/c", "unix", "a/b/c"},
{"unix:/a", "unix", "/a"},
{"unix:/a/b/c", "unix", "/a/b/c"},
{"unix://a", "unix", "a"},
{"unix://a/b/c", "unix", "/b/c"},
{"unix:///a", "unix", "/a"},
{"unix:///a/b/c", "unix", "/a/b/c"},
{"unix:etcd:0", "unix", "etcd:0"},
{"unix:///tmp/unix-3", "unix", "/tmp/unix-3"},
{"unix://domain", "unix", "domain"},
{"unix://etcd:0", "unix", "etcd:0"},
{"unix:///etcd:0", "unix", "/etcd:0"},
{"passthrough://unix://domain", "tcp", "passthrough://unix://domain"},
{"https://google.com:443", "tcp", "https://google.com:443"},
{"dns:///google.com", "tcp", "dns:///google.com"},
{"/unix/socket/address", "tcp", "/unix/socket/address"},
} {
gotNet, gotAddr := parseDialTarget(test.target)
if gotNet != test.wantNet || gotAddr != test.wantAddr {
t.Errorf("parseDialTarget(%q) = %s, %s want %s, %s", test.target, gotNet, gotAddr, test.wantNet, test.wantAddr)
}
}
}
97 changes: 47 additions & 50 deletions test/authority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,35 +80,46 @@ func runUnixTest(t *testing.T, address, target, expectedAuthority string, dialer
}
}

type authorityTest struct {
name string
address string
target string
authority string
dialTargetWant string
}

var authorityTests = []authorityTest{
{
name: "UnixRelative",
address: "sock.sock",
target: "unix:sock.sock",
authority: "localhost",
},
{
name: "UnixAbsolute",
address: "/tmp/sock.sock",
target: "unix:/tmp/sock.sock",
authority: "localhost",
},
{
name: "UnixAbsoluteAlternate",
address: "/tmp/sock.sock",
target: "unix:///tmp/sock.sock",
authority: "localhost",
},
{
name: "UnixPassthrough",
address: "/tmp/sock.sock",
target: "passthrough:///unix:///tmp/sock.sock",
authority: "unix:///tmp/sock.sock",
dialTargetWant: "unix:///tmp/sock.sock",
},
}

// TestUnix does end to end tests with the various supported unix target
// formats, ensuring that the authority is set to localhost in every case.
// formats, ensuring that the authority is set as expected.
func (s) TestUnix(t *testing.T) {
tests := []struct {
name string
address string
target string
authority string
}{
{
name: "UnixRelative",
address: "sock.sock",
target: "unix:sock.sock",
authority: "localhost",
},
{
name: "UnixAbsolute",
address: "/tmp/sock.sock",
target: "unix:/tmp/sock.sock",
authority: "localhost",
},
{
name: "UnixAbsoluteAlternate",
address: "/tmp/sock.sock",
target: "unix:///tmp/sock.sock",
authority: "localhost",
},
}
for _, test := range tests {
for _, test := range authorityTests {
t.Run(test.name, func(t *testing.T) {
runUnixTest(t, test.address, test.target, test.authority, nil)
})
Expand All @@ -119,30 +130,14 @@ func (s) TestUnix(t *testing.T) {
// formats, ensuring that the target sent to the dialer does NOT have the
// "unix:" prefix stripped.
func (s) TestUnixCustomDialer(t *testing.T) {
tests := []struct {
name string
address string
target string
authority string
}{
{
name: "UnixRelative",
address: "sock.sock",
target: "unix:sock.sock",
authority: "localhost",
},
{
name: "UnixAbsolute",
address: "/tmp/sock.sock",
target: "unix:/tmp/sock.sock",
authority: "localhost",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
for _, test := range authorityTests {
t.Run(test.name+"WithDialer", func(t *testing.T) {
if test.dialTargetWant == "" {
test.dialTargetWant = test.target
}
dialer := func(ctx context.Context, address string) (net.Conn, error) {
if address != test.target {
return nil, fmt.Errorf("expected target %v in custom dialer, instead got %v", test.target, address)
if address != test.dialTargetWant {
return nil, fmt.Errorf("expected target %v in custom dialer, instead got %v", test.dialTargetWant, address)
}
address = address[len("unix:"):]
return (&net.Dialer{}).DialContext(ctx, "unix", address)
Expand All @@ -152,6 +147,8 @@ func (s) TestUnixCustomDialer(t *testing.T) {
}
}

// TestColonPortAuthority does an end to end test with the target for grpc.Dial
// being ":[port]". Ensures authority is "localhost:[port]".
func (s) TestColonPortAuthority(t *testing.T) {
expectedAuthority := ""
var authorityMu sync.Mutex
Expand Down