diff --git a/fetch/fetcher.go b/fetch/fetcher.go index f57c2b5..4a6fc9e 100644 --- a/fetch/fetcher.go +++ b/fetch/fetcher.go @@ -18,20 +18,21 @@ import ( ) const ( - dnsRefreshInterval = 5 * time.Minute - dialTimeout = 30 * time.Second - dialKeepAlive = 30 * time.Second - httpClientTimeout = 5 * time.Minute - maxIdleConns = 100 - maxIdleConnsPerHost = 10 - idleConnTimeout = 90 * time.Second - tlsHandshakeTimeout = 10 * time.Second - defaultMaxRetries = 3 - defaultBaseDelay = 500 * time.Millisecond - backoffBase = 2 - jitterFactor = 0.1 - serverErrThreshold = 500 - maxErrBodySize = 1024 + dnsRefreshInterval = 5 * time.Minute + dialTimeout = 30 * time.Second + dialKeepAlive = 30 * time.Second + httpClientTimeout = 5 * time.Minute + responseHeaderTimeout = 60 * time.Second + maxIdleConns = 100 + maxIdleConnsPerHost = 10 + idleConnTimeout = 90 * time.Second + tlsHandshakeTimeout = 10 * time.Second + defaultMaxRetries = 3 + defaultBaseDelay = 500 * time.Millisecond + backoffBase = 2 + jitterFactor = 0.1 + serverErrThreshold = 500 + maxErrBodySize = 1024 ) var ( @@ -43,7 +44,7 @@ var ( // Artifact contains the response from fetching an upstream artifact. type Artifact struct { Body io.ReadCloser - Size int64 // -1 if unknown + Size int64 // -1 if unknown ContentType string ETag string } @@ -62,6 +63,7 @@ type Fetcher struct { maxRetries int baseDelay time.Duration authFn func(url string) (headerName, headerValue string) + stop chan struct{} } // Option configures a Fetcher. @@ -105,18 +107,23 @@ func WithAuthFunc(fn func(url string) (headerName, headerValue string)) Option { } // NewFetcher creates a new Fetcher with the given options. +// Callers should invoke Close when done to release the DNS refresh goroutine. func NewFetcher(opts ...Option) *Fetcher { - // Create DNS cache with 5 minute refresh interval resolver := &dnscache.Resolver{} + stop := make(chan struct{}) go func() { ticker := time.NewTicker(dnsRefreshInterval) defer ticker.Stop() - for range ticker.C { - resolver.Refresh(true) + for { + select { + case <-ticker.C: + resolver.Refresh(true) + case <-stop: + return + } } }() - // Create custom dialer with DNS caching dialer := &net.Dialer{ Timeout: dialTimeout, KeepAlive: dialKeepAlive, @@ -126,6 +133,7 @@ func NewFetcher(opts ...Option) *Fetcher { client: &http.Client{ Timeout: httpClientTimeout, Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { host, port, err := net.SplitHostPort(addr) if err != nil { @@ -135,24 +143,31 @@ func NewFetcher(opts ...Option) *Fetcher { if err != nil { return nil, err } + var lastErr error for _, ip := range ips { conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ip, port)) if err == nil { return conn, nil } + lastErr = err } - return nil, fmt.Errorf("failed to dial any resolved IP") + if lastErr == nil { + return nil, fmt.Errorf("no IPs resolved for %s", host) + } + return nil, fmt.Errorf("dialing %s: %w", host, lastErr) }, MaxIdleConns: maxIdleConns, MaxIdleConnsPerHost: maxIdleConnsPerHost, IdleConnTimeout: idleConnTimeout, TLSHandshakeTimeout: tlsHandshakeTimeout, + ResponseHeaderTimeout: responseHeaderTimeout, ExpectContinueTimeout: 1 * time.Second, }, }, userAgent: "git-pkgs-proxy/1.0", maxRetries: defaultMaxRetries, baseDelay: defaultBaseDelay, + stop: stop, } for _, opt := range opts { opt(f) @@ -160,6 +175,20 @@ func NewFetcher(opts ...Option) *Fetcher { return f } +// Close stops the Fetcher's background DNS refresh goroutine. +// It is safe to call Close more than once. +func (f *Fetcher) Close() error { + if f.stop == nil { + return nil + } + select { + case <-f.stop: + default: + close(f.stop) + } + return nil +} + // Fetch downloads an artifact from the given URL. // The caller must close the returned Artifact.Body when done. func (f *Fetcher) Fetch(ctx context.Context, url string) (*Artifact, error) { diff --git a/fetch/fetcher_test.go b/fetch/fetcher_test.go index 0d6c29f..bd3cb09 100644 --- a/fetch/fetcher_test.go +++ b/fetch/fetcher_test.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "net/http/httptest" + "reflect" "strings" "testing" "time" @@ -377,3 +378,34 @@ func TestFetchDNSCaching(t *testing.T) { t.Errorf("requestCount = %d, want 3", requestCount) } } + +func TestFetcherHonorsProxyEnvironment(t *testing.T) { + f := NewFetcher() + defer func() { _ = f.Close() }() + + transport, ok := f.client.Transport.(*http.Transport) + if !ok { + t.Fatalf("Transport = %T, want *http.Transport", f.client.Transport) + } + if transport.Proxy == nil { + t.Fatal("Transport.Proxy is nil; HTTP_PROXY/HTTPS_PROXY/NO_PROXY env vars would be ignored") + } + want := reflect.ValueOf(http.ProxyFromEnvironment).Pointer() + got := reflect.ValueOf(transport.Proxy).Pointer() + if got != want { + t.Errorf("Transport.Proxy is not http.ProxyFromEnvironment") + } + if transport.ResponseHeaderTimeout == 0 { + t.Error("Transport.ResponseHeaderTimeout is 0; hung upstreams will only fail at the overall client timeout") + } +} + +func TestFetcherCloseStopsGoroutine(t *testing.T) { + f := NewFetcher() + if err := f.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + if err := f.Close(); err != nil { + t.Errorf("second Close returned error: %v", err) + } +}