diff --git a/google/cloud/internal/oauth2_authorized_user_credentials.cc b/google/cloud/internal/oauth2_authorized_user_credentials.cc index 49d97addda092..01783727111d1 100644 --- a/google/cloud/internal/oauth2_authorized_user_credentials.cc +++ b/google/cloud/internal/oauth2_authorized_user_credentials.cc @@ -103,8 +103,7 @@ AuthorizedUserCredentials::AuthorizedUserCredentials( StatusOr> AuthorizedUserCredentials::AuthorizationHeader() { std::unique_lock lock(mu_); - return refreshing_creds_.AuthorizationHeader(current_time_fn_(), - [this] { return Refresh(); }); + return refreshing_creds_.AuthorizationHeader([this] { return Refresh(); }); } StatusOr diff --git a/google/cloud/internal/oauth2_compute_engine_credentials.cc b/google/cloud/internal/oauth2_compute_engine_credentials.cc index d6b1f5446b73e..a1c3059a1cd31 100644 --- a/google/cloud/internal/oauth2_compute_engine_credentials.cc +++ b/google/cloud/internal/oauth2_compute_engine_credentials.cc @@ -101,8 +101,7 @@ ComputeEngineCredentials::ComputeEngineCredentials( StatusOr> ComputeEngineCredentials::AuthorizationHeader() { std::unique_lock lock(mu_); - return refreshing_creds_.AuthorizationHeader(current_time_fn_(), - [this] { return Refresh(); }); + return refreshing_creds_.AuthorizationHeader([this] { return Refresh(); }); } std::string ComputeEngineCredentials::AccountEmail() const { diff --git a/google/cloud/internal/oauth2_refreshing_credentials_wrapper.cc b/google/cloud/internal/oauth2_refreshing_credentials_wrapper.cc index 66421bc7d22af..ece3699cf8c3a 100644 --- a/google/cloud/internal/oauth2_refreshing_credentials_wrapper.cc +++ b/google/cloud/internal/oauth2_refreshing_credentials_wrapper.cc @@ -20,28 +20,28 @@ namespace cloud { namespace oauth2_internal { GOOGLE_CLOUD_CPP_INLINE_NAMESPACE_BEGIN -bool RefreshingCredentialsWrapper::IsExpired( - std::chrono::system_clock::time_point now) const { - return now > (temporary_token_.expiration_time - - GoogleOAuthAccessTokenExpirationSlack()); +RefreshingCredentialsWrapper::RefreshingCredentialsWrapper( + CurrentTimeFn current_time_fn) + : current_time_fn_(std::move(current_time_fn)) {} + +bool RefreshingCredentialsWrapper::IsExpired() const { + return current_time_fn_() > (temporary_token_.expiration_time - + GoogleOAuthAccessTokenExpirationSlack()); } -bool RefreshingCredentialsWrapper::IsValid( - std::chrono::system_clock::time_point now) const { +bool RefreshingCredentialsWrapper::IsValid() const { return !temporary_token_.token.second.empty() && - now <= temporary_token_.expiration_time; + current_time_fn_() <= temporary_token_.expiration_time; } -bool RefreshingCredentialsWrapper::NeedsRefresh( - std::chrono::system_clock::time_point now) const { - return temporary_token_.token.second.empty() || IsExpired(now); +bool RefreshingCredentialsWrapper::NeedsRefresh() const { + return temporary_token_.token.second.empty() || IsExpired(); } StatusOr> RefreshingCredentialsWrapper::AuthorizationHeader( - std::chrono::system_clock::time_point now, RefreshFunctor refresh_fn) const { - if (!NeedsRefresh(now)) return temporary_token_.token; + if (!NeedsRefresh()) return temporary_token_.token; // If successful refreshing token, return it. Otherwise, return the current // token if it still has time left on it. If no valid token can be returned, @@ -51,7 +51,7 @@ RefreshingCredentialsWrapper::AuthorizationHeader( temporary_token_ = *std::move(new_token); return temporary_token_.token; } - if (IsValid(std::chrono::system_clock::now())) return temporary_token_.token; + if (IsValid()) return temporary_token_.token; return new_token.status(); } diff --git a/google/cloud/internal/oauth2_refreshing_credentials_wrapper.h b/google/cloud/internal/oauth2_refreshing_credentials_wrapper.h index 7c17c9885de80..793b9fb53febd 100644 --- a/google/cloud/internal/oauth2_refreshing_credentials_wrapper.h +++ b/google/cloud/internal/oauth2_refreshing_credentials_wrapper.h @@ -40,6 +40,17 @@ class RefreshingCredentialsWrapper { using RefreshFunctor = absl::FunctionRef< StatusOr()>; + using CurrentTimeFn = + std::function()>; + + /** + * Creates an instance of RefreshingCredentialsWrapper. + * + * @param current_time_fn a dependency injection point to fetch the current + * time. This should generally not be overridden except for testing. + */ + explicit RefreshingCredentialsWrapper( + CurrentTimeFn current_time_fn = std::chrono::system_clock::now); /** * Returns an Authorization header obtained by invoking `refresh_fn`. @@ -48,13 +59,12 @@ class RefreshingCredentialsWrapper { * or may not be called. */ StatusOr> AuthorizationHeader( - std::chrono::system_clock::time_point now, RefreshFunctor refresh_fn) const; /** * Returns whether the current access token should be considered valid. */ - bool IsValid(std::chrono::system_clock::time_point now) const; + bool IsValid() const; private: /** @@ -68,10 +78,11 @@ class RefreshingCredentialsWrapper { * may still return false. This helps prevent the case where an access token * expires between when it is obtained and when it is used. */ - bool IsExpired(std::chrono::system_clock::time_point now) const; - bool NeedsRefresh(std::chrono::system_clock::time_point now) const; + bool IsExpired() const; + bool NeedsRefresh() const; mutable TemporaryToken temporary_token_; + CurrentTimeFn current_time_fn_; }; GOOGLE_CLOUD_CPP_INLINE_NAMESPACE_END diff --git a/google/cloud/internal/oauth2_refreshing_credentials_wrapper_test.cc b/google/cloud/internal/oauth2_refreshing_credentials_wrapper_test.cc index 334b54b86bbff..bf425c0f9f2df 100644 --- a/google/cloud/internal/oauth2_refreshing_credentials_wrapper_test.cc +++ b/google/cloud/internal/oauth2_refreshing_credentials_wrapper_test.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "google/cloud/internal/oauth2_refreshing_credentials_wrapper.h" +#include "google/cloud/internal/oauth2_credential_constants.h" #include "google/cloud/testing_util/status_matchers.h" #include @@ -23,37 +24,48 @@ GOOGLE_CLOUD_CPP_INLINE_NAMESPACE_BEGIN namespace { using ::google::cloud::testing_util::IsOk; +using ::std::chrono::minutes; +using ::std::chrono::seconds; using ::testing::Eq; +using ::testing::MockFunction; using ::testing::Not; TEST(RefreshingCredentialsWrapper, IsValid) { + auto now = std::chrono::system_clock::now(); + MockFunction + mock_current_time_fn; + EXPECT_CALL(mock_current_time_fn, Call()).WillOnce([&] { return now; }); + + RefreshingCredentialsWrapper w(mock_current_time_fn.AsStdFunction()); std::pair const auth_token = std::make_pair("Authorization", "foo"); - RefreshingCredentialsWrapper w; auto refresh_fn = [&]() -> StatusOr { RefreshingCredentialsWrapper::TemporaryToken token; token.token = auth_token; - token.expiration_time = - std::chrono::system_clock::now() + std::chrono::minutes(60); + token.expiration_time = now + minutes(60); return token; }; - auto token = - w.AuthorizationHeader(std::chrono::system_clock::now(), refresh_fn); - EXPECT_TRUE(w.IsValid(std::chrono::system_clock::now())); + auto token = w.AuthorizationHeader(refresh_fn); + EXPECT_TRUE(w.IsValid()); } TEST(RefreshingCredentialsWrapper, IsNotValid) { std::pair const auth_token = std::make_pair("Authorization", "foo"); RefreshingCredentialsWrapper w; - EXPECT_FALSE(w.IsValid(std::chrono::system_clock::now())); + EXPECT_FALSE(w.IsValid()); } TEST(RefreshingCredentialsWrapper, RefreshTokenSuccess) { + auto now = std::chrono::system_clock::now(); + MockFunction + mock_current_time_fn; + EXPECT_CALL(mock_current_time_fn, Call()).WillOnce([&] { return now; }); + + RefreshingCredentialsWrapper w(mock_current_time_fn.AsStdFunction()); std::pair const auth_token = std::make_pair("Authorization", "foo"); - RefreshingCredentialsWrapper w; // Test that we only call the refresh_fn on the first call to // AuthorizationHeader. @@ -65,79 +77,89 @@ TEST(RefreshingCredentialsWrapper, RefreshTokenSuccess) { [&]() -> StatusOr { RefreshingCredentialsWrapper::TemporaryToken token; token.token = auth_token; - token.expiration_time = - std::chrono::system_clock::now() + std::chrono::minutes(60); + token.expiration_time = now + minutes(60); return token; }); - auto token = w.AuthorizationHeader(std::chrono::system_clock::now(), - mock_refresh_fn.AsStdFunction()); + auto token = w.AuthorizationHeader(mock_refresh_fn.AsStdFunction()); ASSERT_THAT(token, IsOk()); EXPECT_THAT(*token, Eq(auth_token)); - token = w.AuthorizationHeader(std::chrono::system_clock::now(), - mock_refresh_fn.AsStdFunction()); + token = w.AuthorizationHeader(mock_refresh_fn.AsStdFunction()); ASSERT_THAT(token, IsOk()); EXPECT_THAT(*token, Eq(auth_token)); } TEST(RefreshingCredentialsWrapper, RefreshTokenFailure) { - RefreshingCredentialsWrapper w; auto refresh_fn = [&]() -> StatusOr { return Status(StatusCode::kInvalidArgument, {}, {}); }; - auto token = - w.AuthorizationHeader(std::chrono::system_clock::now(), refresh_fn); + RefreshingCredentialsWrapper w; + auto token = w.AuthorizationHeader(refresh_fn); EXPECT_THAT(token, Not(IsOk())); EXPECT_THAT(token.status().code(), Eq(StatusCode::kInvalidArgument)); } TEST(RefreshingCredentialsWrapper, RefreshTokenFailureValidToken) { + auto now = std::chrono::system_clock::now(); + auto expire_time = now + minutes(60); + + MockFunction + mock_current_time_fn; + EXPECT_CALL(mock_current_time_fn, Call()) + .WillOnce([&] { + return expire_time + GoogleOAuthAccessTokenExpirationSlack() + + seconds(10); + }) + .WillOnce([&] { return now; }); + std::pair const auth_token = std::make_pair("Authorization", "foo"); - RefreshingCredentialsWrapper w; + RefreshingCredentialsWrapper w(mock_current_time_fn.AsStdFunction()); auto refresh_fn = [&]() -> StatusOr { RefreshingCredentialsWrapper::TemporaryToken token; token.token = auth_token; - token.expiration_time = - std::chrono::system_clock::now() + std::chrono::seconds(60); + token.expiration_time = expire_time; return token; }; - auto token = - w.AuthorizationHeader(std::chrono::system_clock::now(), refresh_fn); + auto token = w.AuthorizationHeader(refresh_fn); auto failing_refresh_fn = [&]() -> StatusOr { return Status(StatusCode::kInvalidArgument, {}, {}); }; - token = w.AuthorizationHeader(std::chrono::system_clock::now(), - failing_refresh_fn); + token = w.AuthorizationHeader(failing_refresh_fn); ASSERT_THAT(token, IsOk()); EXPECT_THAT(*token, Eq(auth_token)); } TEST(RefreshingCredentialsWrapper, RefreshTokenFailureInvalidToken) { + auto now = std::chrono::system_clock::now(); + MockFunction + mock_current_time_fn; + auto expire_time = now + minutes(60); + EXPECT_CALL(mock_current_time_fn, Call()).Times(2).WillRepeatedly([&] { + return expire_time + GoogleOAuthAccessTokenExpirationSlack() + seconds(10); + }); + std::pair const auth_token = std::make_pair("Authorization", "foo"); - RefreshingCredentialsWrapper w; + RefreshingCredentialsWrapper w(mock_current_time_fn.AsStdFunction()); auto refresh_fn = [&]() -> StatusOr { RefreshingCredentialsWrapper::TemporaryToken token; token.token = auth_token; - token.expiration_time = - std::chrono::system_clock::now() - std::chrono::seconds(3600); + token.expiration_time = expire_time; return token; }; - auto token = - w.AuthorizationHeader(std::chrono::system_clock::now(), refresh_fn); + auto token = w.AuthorizationHeader(refresh_fn); auto failing_refresh_fn = [&]() -> StatusOr { return Status(StatusCode::kInvalidArgument, {}, {}); }; - token = w.AuthorizationHeader(std::chrono::system_clock::now(), - failing_refresh_fn); + token = w.AuthorizationHeader(failing_refresh_fn); EXPECT_THAT(token, Not(IsOk())); EXPECT_THAT(token.status().code(), Eq(StatusCode::kInvalidArgument)); } diff --git a/google/cloud/internal/oauth2_service_account_credentials.cc b/google/cloud/internal/oauth2_service_account_credentials.cc index 51d7aaa422a9a..63918e9721428 100644 --- a/google/cloud/internal/oauth2_service_account_credentials.cc +++ b/google/cloud/internal/oauth2_service_account_credentials.cc @@ -185,8 +185,7 @@ ServiceAccountCredentials::ServiceAccountCredentials( StatusOr> ServiceAccountCredentials::AuthorizationHeader() { std::unique_lock lock(mu_); - return refreshing_creds_.AuthorizationHeader(current_time_fn_(), - [this] { return Refresh(); }); + return refreshing_creds_.AuthorizationHeader([this] { return Refresh(); }); } StatusOr> ServiceAccountCredentials::SignBlob(