diff --git a/retryafter/retryafter.go b/retryafter/retryafter.go index 2b81486..57ec814 100644 --- a/retryafter/retryafter.go +++ b/retryafter/retryafter.go @@ -14,6 +14,8 @@ import ( "time" ) +var now = time.Now + // Parse parses the backoff time specified in the Retry-After header if present. // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After. // @@ -26,7 +28,7 @@ func Parse(retryAfter string, fallback time.Duration) time.Duration { if retryAfter == "" { return fallback } else if t, err := time.Parse(http.TimeFormat, retryAfter); err == nil { - return time.Until(t) + return t.Sub(now()) } else if seconds, err := strconv.Atoi(retryAfter); err == nil { return time.Duration(seconds) * time.Second } diff --git a/retryafter/retryafter_test.go b/retryafter/retryafter_test.go new file mode 100644 index 0000000..3234e0c --- /dev/null +++ b/retryafter/retryafter_test.go @@ -0,0 +1,52 @@ +// Copyright (c) 2021 Dillon Dixon +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package retryafter + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBackoffFromResponse(t *testing.T) { + currentTime := time.Now().Truncate(time.Second) + now = func() time.Time { + return currentTime + } + + defaultBackoff := time.Duration(123) + + for name, tt := range map[string]struct { + headerValue string + expected time.Duration + }{ + "AsDate": { + headerValue: currentTime.In(time.UTC).Add(5 * time.Hour).Format(http.TimeFormat), + expected: time.Duration(5) * time.Hour, + }, + "AsSeconds": { + headerValue: "12345", + expected: time.Duration(12345) * time.Second, + }, + "Missing": { + headerValue: "", + expected: defaultBackoff, + }, + "Bad": { + headerValue: "invalid", + expected: defaultBackoff, + }, + } { + t.Run(name, func(t *testing.T) { + parsed := Parse(tt.headerValue, defaultBackoff) + assert.Equal(t, tt.expected, parsed) + }) + } +}