Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

client: fix "unix" scheme handling for some corner cases #4021

Merged
merged 11 commits into from
Nov 30, 2020
13 changes: 7 additions & 6 deletions internal/grpcutil/target.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ func ParseTarget(target string, skipUnixColonParsing bool) (ret resolver.Target)
ret.Scheme, ret.Endpoint, ok = split2(target, "://")
if !ok {
if strings.HasPrefix(target, "unix:") && !skipUnixColonParsing {
// Handle the "unix:[path]" case, because splitting on :// only
// handles the "unix://[/absolute/path]" case. Only handle if the
// dialer is nil, to avoid a behavior change with custom dialers.
// Handle the "unix:[local/path]" and "unix:[/absolute/path]" cases,
// because splitting on :// only handles the
// "unix://[/absolute/path]" case. Only handle if the dialer is nil,
// to avoid a behavior change with custom dialers.
return resolver.Target{Scheme: "unix", Endpoint: target[len("unix:"):]}
}
return resolver.Target{Endpoint: target}
Expand All @@ -59,9 +60,9 @@ func ParseTarget(target string, skipUnixColonParsing bool) (ret resolver.Target)
if !ok {
return resolver.Target{Endpoint: target}
}
if ret.Scheme == "unix" {
// Add the "/" back in the unix case, so the unix resolver receives the
// actual endpoint.
if ret.Scheme == "unix" && !skipUnixColonParsing && ret.Authority == "" {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did this need to change? I think this can be reverted.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted this.

// Add the "/" back in the "unix://[/absolute/path]" case, so the unix
// resolver receives the actual endpoint.
ret.Endpoint = "/" + ret.Endpoint
}
return ret
Expand Down
12 changes: 8 additions & 4 deletions internal/grpcutil/target_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,21 @@ func TestParseTargetString(t *testing.T) {
// If we can only parse part of the target.
{targetStr: "://", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "://"}},
{targetStr: "unix://domain", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix://domain"}},
{targetStr: "unix://a/b/c", want: resolver.Target{Scheme: "unix", Authority: "a", Endpoint: "b/c"}},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be Endpoint: "/b/c" since we have the 3-slashes form. It doesn't matter in practice since the authority is non-empty, so it will be an error either way, but that's how it should be.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's "/b/c" now due to the revert of the +"/" change.

{targetStr: "a:b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a:b"}},
{targetStr: "a/b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a/b"}},
{targetStr: "a:/b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a:/b"}},
{targetStr: "a//b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a//b"}},
{targetStr: "a://b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a://b"}},

// Unix cases without custom dialer.
// unix:[local_path] and unix:[/absolute] have different behaviors with
// a custom dialer, to prevent behavior changes with custom dialers.
{targetStr: "unix:domain", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "domain"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:domain"}},
{targetStr: "unix:/domain", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/domain"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:/domain"}},
// unix:[local_path], unix:[/absolute], and unix://[/absolute] have different
// behaviors with a custom dialer, to prevent behavior changes with custom dialers.
{targetStr: "unix:a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "a/b/c"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:a/b/c"}},
{targetStr: "unix:/a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/a/b/c"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:/a/b/c"}},
{targetStr: "unix:///a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/a/b/c"}, wantWithDialer: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "a/b/c"}},

{targetStr: "passthrough:///unix:///a/b/c", want: resolver.Target{Scheme: "passthrough", Authority: "", Endpoint: "unix:///a/b/c"}},
} {
got := ParseTarget(test.targetStr, false)
if got != test.want {
Expand Down
5 changes: 5 additions & 0 deletions internal/resolver/unix/unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
package unix

import (
"fmt"

"google.golang.org/grpc/internal/transport/networktype"
"google.golang.org/grpc/resolver"
)
Expand All @@ -29,6 +31,9 @@ const scheme = "unix"
type builder struct{}

func (*builder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) {
if target.Authority != "" {
return nil, fmt.Errorf("invalid (non-empty) authority: %v", target.Authority)
}
cc.UpdateState(resolver.State{Addresses: []resolver.Address{networktype.Set(resolver.Address{Addr: target.Endpoint}, "unix")}})
return &nopResolver{}, nil
}
Expand Down
17 changes: 12 additions & 5 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,17 +139,24 @@ type http2Client struct {
}

func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, useProxy bool, grpcUA string) (net.Conn, error) {
networkType := "tcp"
address := addr.Addr
n, ok := networktype.Get(addr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

networkType, ok := ... - the "tcp" default is no longer used (see below if !ok).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made this change.

if fn != nil {
return fn(ctx, addr.Addr)
if ok && n == "unix" {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok is redundant - if !ok, n will be ""

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got rid of redundant ok

return fn(ctx, "unix:///"+address)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This deserves a comment. Something like: For backward compatibility, if the user dialed "unix:///path", the passthrough resolver would be used and the user's custom dialer would see "unix:///path". Re-add "unix://" here since we now support the "unix" scheme by default, which strips this prefix.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this comment.

}
return fn(ctx, address)
}
networkType := "tcp"
if n, ok := networktype.Get(addr); ok {
if ok {
networkType = n
} else {
networkType, address = parseDialTarget(address)
}
if networkType == "tcp" && useProxy {
return proxyDial(ctx, addr.Addr, grpcUA)
return proxyDial(ctx, address, grpcUA)
}
return (&net.Dialer{}).DialContext(ctx, networkType, addr.Addr)
return (&net.Dialer{}).DialContext(ctx, networkType, address)
}

func isTemporary(err error) bool {
Expand Down
29 changes: 29 additions & 0 deletions internal/transport/http_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"math"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -598,3 +599,31 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, maxHeaderList
f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
return f
}

// parseDialTarget returns the network and address to pass to dialer.
func parseDialTarget(target string) (string, string) {
net := "tcp"
m1 := strings.Index(target, ":")
m2 := strings.Index(target, ":/")
// handle unix:addr which will fail with url.Parse
if m1 >= 0 && m2 < 0 {
if n := target[0:m1]; n == "unix" {
return n, target[m1+1:]
}
}
if m2 >= 0 {
t, err := url.Parse(target)
if err != nil {
return net, target
}
scheme := t.Scheme
addr := t.Path
if scheme == "unix" {
if addr == "" {
addr = t.Host
}
return scheme, addr
}
}
return net, target
}
29 changes: 29 additions & 0 deletions internal/transport/http_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,32 @@ func (s) TestDecodeHeaderH2ErrCode(t *testing.T) {
})
}
}

func (s) TestParseDialTarget(t *testing.T) {
for _, test := range []struct {
target, wantNet, wantAddr string
}{
{"unix:a", "unix", "a"},
{"unix:a/b/c", "unix", "a/b/c"},
{"unix:/a", "unix", "/a"},
{"unix:/a/b/c", "unix", "/a/b/c"},
{"unix://a", "unix", "a"},
{"unix://a/b/c", "unix", "/b/c"},
{"unix:///a", "unix", "/a"},
{"unix:///a/b/c", "unix", "/a/b/c"},
{"unix:etcd:0", "unix", "etcd:0"},
{"unix:///tmp/unix-3", "unix", "/tmp/unix-3"},
{"unix://domain", "unix", "domain"},
{"unix://etcd:0", "unix", "etcd:0"},
{"unix:///etcd:0", "unix", "/etcd:0"},
{"passthrough://unix://domain", "tcp", "passthrough://unix://domain"},
{"https://google.com:443", "tcp", "https://google.com:443"},
{"dns:///google.com", "tcp", "dns:///google.com"},
{"/unix/socket/address", "tcp", "/unix/socket/address"},
} {
gotNet, gotAddr := parseDialTarget(test.target)
if gotNet != test.wantNet || gotAddr != test.wantAddr {
t.Errorf("parseDialTarget(%q) = %s, %s want %s, %s", test.target, gotNet, gotAddr, test.wantNet, test.wantAddr)
}
}
}