From f7c5e6a76259ec101fac0b1a6f62c05cc74c3d0c Mon Sep 17 00:00:00 2001 From: Hong Truong <41143+and1truong@users.noreply.github.com> Date: Wed, 6 Mar 2024 09:49:11 +1000 Subject: [PATCH] DNS resolving with timeout (#6917) --- internal/resolver/dns/dns_resolver.go | 29 ++++++++----- internal/resolver/dns/dns_resolver_test.go | 43 +++++++++++++++++++ .../resolver/dns/fake_net_resolver_test.go | 5 ++- resolver/dns/dns_resolver.go | 18 ++++++++ 4 files changed, 84 insertions(+), 11 deletions(-) diff --git a/internal/resolver/dns/dns_resolver.go b/internal/resolver/dns/dns_resolver.go index b66dcb21327..abab35e250e 100644 --- a/internal/resolver/dns/dns_resolver.go +++ b/internal/resolver/dns/dns_resolver.go @@ -45,6 +45,13 @@ import ( // addresses from SRV records. Must not be changed after init time. var EnableSRVLookups = false +// ResolvingTimeout specifies the maximum duration for a DNS resolution request. +// If the timeout expires before a response is received, the request will be canceled. +// +// It is recommended to set this value at application startup. Avoid modifying this variable +// after initialization as it's not thread-safe for concurrent modification. +var ResolvingTimeout = 30 * time.Second + var logger = grpclog.Component("dns") func init() { @@ -221,18 +228,18 @@ func (d *dnsResolver) watcher() { } } -func (d *dnsResolver) lookupSRV() ([]resolver.Address, error) { +func (d *dnsResolver) lookupSRV(ctx context.Context) ([]resolver.Address, error) { if !EnableSRVLookups { return nil, nil } var newAddrs []resolver.Address - _, srvs, err := d.resolver.LookupSRV(d.ctx, "grpclb", "tcp", d.host) + _, srvs, err := d.resolver.LookupSRV(ctx, "grpclb", "tcp", d.host) if err != nil { err = handleDNSError(err, "SRV") // may become nil return nil, err } for _, s := range srvs { - lbAddrs, err := d.resolver.LookupHost(d.ctx, s.Target) + lbAddrs, err := d.resolver.LookupHost(ctx, s.Target) if err != nil { err = handleDNSError(err, "A") // may become nil if err == nil { @@ -269,8 +276,8 @@ func handleDNSError(err error, lookupType string) error { return err } -func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult { - ss, err := d.resolver.LookupTXT(d.ctx, txtPrefix+d.host) +func (d *dnsResolver) lookupTXT(ctx context.Context) *serviceconfig.ParseResult { + ss, err := d.resolver.LookupTXT(ctx, txtPrefix+d.host) if err != nil { if envconfig.TXTErrIgnore { return nil @@ -297,8 +304,8 @@ func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult { return d.cc.ParseServiceConfig(sc) } -func (d *dnsResolver) lookupHost() ([]resolver.Address, error) { - addrs, err := d.resolver.LookupHost(d.ctx, d.host) +func (d *dnsResolver) lookupHost(ctx context.Context) ([]resolver.Address, error) { + addrs, err := d.resolver.LookupHost(ctx, d.host) if err != nil { err = handleDNSError(err, "A") return nil, err @@ -316,8 +323,10 @@ func (d *dnsResolver) lookupHost() ([]resolver.Address, error) { } func (d *dnsResolver) lookup() (*resolver.State, error) { - srv, srvErr := d.lookupSRV() - addrs, hostErr := d.lookupHost() + ctx, cancel := context.WithTimeout(d.ctx, ResolvingTimeout) + defer cancel() + srv, srvErr := d.lookupSRV(ctx) + addrs, hostErr := d.lookupHost(ctx) if hostErr != nil && (srvErr != nil || len(srv) == 0) { return nil, hostErr } @@ -327,7 +336,7 @@ func (d *dnsResolver) lookup() (*resolver.State, error) { state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv}) } if !d.disableServiceConfig { - state.ServiceConfig = d.lookupTXT() + state.ServiceConfig = d.lookupTXT(ctx) } return &state, nil } diff --git a/internal/resolver/dns/dns_resolver_test.go b/internal/resolver/dns/dns_resolver_test.go index 1244edcb61c..498cf5b83e2 100644 --- a/internal/resolver/dns/dns_resolver_test.go +++ b/internal/resolver/dns/dns_resolver_test.go @@ -39,6 +39,7 @@ import ( dnsinternal "google.golang.org/grpc/internal/resolver/dns/internal" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" + dnspublic "google.golang.org/grpc/resolver/dns" "google.golang.org/grpc/serviceconfig" _ "google.golang.org/grpc" // To initialize internal.ParseServiceConfig @@ -1215,3 +1216,45 @@ func (s) TestReportError(t *testing.T) { } } } + +// Override the default dns.ResolvingTimeout with a test duration. +func overrideResolveTimeoutDuration(t *testing.T, dur time.Duration) { + t.Helper() + + origDur := dns.ResolvingTimeout + dnspublic.SetResolvingTimeout(dur) + + t.Cleanup(func() { dnspublic.SetResolvingTimeout(origDur) }) +} + +// Test verifies that the DNS resolver gets timeout error when net.Resolver +// takes too long to resolve a target. +func (s) TestResolveTimeout(t *testing.T) { + // Set DNS resolving timeout duration to 7ms + timeoutDur := 7 * time.Millisecond + overrideResolveTimeoutDuration(t, timeoutDur) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // We are trying to resolve hostname which takes infinity time to resolve. + const target = "infinity" + + // Define a testNetResolver with lookupHostCh, an unbuffered channel, + // so we can block the resolver until reaching timeout. + tr := &testNetResolver{ + lookupHostCh: testutils.NewChannelWithSize(0), + hostLookupTable: map[string][]string{target: {"1.2.3.4"}}, + } + overrideNetResolver(t, tr) + + _, _, errCh := buildResolverWithTestClientConn(t, target) + select { + case <-ctx.Done(): + t.Fatal("Timeout when waiting for the DNS resolver to timeout") + case err := <-errCh: + if err == nil || !strings.Contains(err.Error(), "context deadline exceeded") { + t.Fatalf(`Expected to see Timeout error; got: %v`, err) + } + } +} diff --git a/internal/resolver/dns/fake_net_resolver_test.go b/internal/resolver/dns/fake_net_resolver_test.go index a3be31607b3..adc8b3d4e18 100644 --- a/internal/resolver/dns/fake_net_resolver_test.go +++ b/internal/resolver/dns/fake_net_resolver_test.go @@ -41,7 +41,9 @@ type testNetResolver struct { func (tr *testNetResolver) LookupHost(ctx context.Context, host string) ([]string, error) { if tr.lookupHostCh != nil { - tr.lookupHostCh.Send(nil) + if err := tr.lookupHostCh.SendContext(ctx, nil); err != nil { + return nil, err + } } tr.mu.Lock() @@ -50,6 +52,7 @@ func (tr *testNetResolver) LookupHost(ctx context.Context, host string) ([]strin if addrs, ok := tr.hostLookupTable[host]; ok { return addrs, nil } + return nil, &net.DNSError{ Err: "hostLookup error", Name: host, diff --git a/resolver/dns/dns_resolver.go b/resolver/dns/dns_resolver.go index 14aa6f20ae0..b54a3a3225d 100644 --- a/resolver/dns/dns_resolver.go +++ b/resolver/dns/dns_resolver.go @@ -24,10 +24,28 @@ package dns import ( + "time" + "google.golang.org/grpc/internal/resolver/dns" "google.golang.org/grpc/resolver" ) +// SetResolvingTimeout sets the maximum duration for DNS resolution requests. +// +// This function affects the global timeout used by all channels using the DNS +// name resolver scheme. +// +// It must be called only at application startup, before any gRPC calls are +// made. Modifying this value after initialization is not thread-safe. +// +// The default value is 30 seconds. Setting the timeout too low may result in +// premature timeouts during resolution, while setting it too high may lead to +// unnecessary delays in service discovery. Choose a value appropriate for your +// specific needs and network environment. +func SetResolvingTimeout(timeout time.Duration) { + dns.ResolvingTimeout = timeout +} + // NewBuilder creates a dnsBuilder which is used to factory DNS resolvers. // // Deprecated: import grpc and use resolver.Get("dns") instead.