diff --git a/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java b/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java index 3012990f3..114588d8b 100644 --- a/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java +++ b/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java @@ -41,6 +41,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; +import com.google.common.util.concurrent.AbstractFuture; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; @@ -60,7 +61,6 @@ import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; -import java.util.concurrent.Future; import javax.annotation.Nullable; /** Base type for Credentials using OAuth2. */ @@ -77,7 +77,7 @@ public class OAuth2Credentials extends Credentials { // byte[] is serializable, so the lock variable can be final @VisibleForTesting final Object lock = new byte[0]; private volatile OAuthValue value = null; - @VisibleForTesting transient ListenableFutureTask refreshTask; + @VisibleForTesting transient RefreshTask refreshTask; // Change listeners are not serialized private transient List changeListeners; @@ -258,16 +258,7 @@ public OAuthValue call() throws Exception { } }); - task.addListener( - new Runnable() { - @Override - public void run() { - finishRefreshAsync(task); - } - }, - MoreExecutors.directExecutor()); - - refreshTask = task; + refreshTask = new RefreshTask(task, new RefreshTaskListener(task)); return new AsyncRefreshResult(refreshTask, true); } @@ -290,7 +281,7 @@ private void finishRefreshAsync(ListenableFuture finishedTask) { } catch (Exception e) { // noop } finally { - if (this.refreshTask == finishedTask) { + if (this.refreshTask != null && this.refreshTask.getTask() == finishedTask) { this.refreshTask = null; } } @@ -307,7 +298,7 @@ private void finishRefreshAsync(ListenableFuture finishedTask) { * thread of whatever executor the async call used. This doesn't affect correctness and is * extremely unlikely. */ - private static T unwrapDirectFuture(Future future) throws IOException { + private static T unwrapDirectFuture(ListenableFuture future) throws IOException { try { return future.get(); } catch (InterruptedException e) { @@ -567,10 +558,10 @@ public void onFailure(Throwable throwable) { * task is newly created, it is the caller's responsibility to execute it. */ static class AsyncRefreshResult { - private final ListenableFutureTask task; + private final RefreshTask task; private final boolean isNew; - AsyncRefreshResult(ListenableFutureTask task, boolean isNew) { + AsyncRefreshResult(RefreshTask task, boolean isNew) { this.task = task; this.isNew = isNew; } @@ -582,6 +573,57 @@ void executeIfNew(Executor executor) { } } + @VisibleForTesting + class RefreshTaskListener implements Runnable { + private ListenableFutureTask task; + + RefreshTaskListener(ListenableFutureTask task) { + this.task = task; + } + + @Override + public void run() { + finishRefreshAsync(task); + } + } + + class RefreshTask extends AbstractFuture implements Runnable { + private final ListenableFutureTask task; + private final RefreshTaskListener listener; + + RefreshTask(ListenableFutureTask task, RefreshTaskListener listener) { + this.task = task; + this.listener = listener; + + // Update Credential state first + task.addListener(listener, MoreExecutors.directExecutor()); + + // Then notify the world + Futures.addCallback( + task, + new FutureCallback() { + @Override + public void onSuccess(OAuthValue result) { + RefreshTask.this.set(result); + } + + @Override + public void onFailure(Throwable t) { + RefreshTask.this.setException(t); + } + }, + MoreExecutors.directExecutor()); + } + + public ListenableFutureTask getTask() { + return this.task; + } + + public void run() { + task.run(); + } + } + public static class Builder { private AccessToken accessToken; diff --git a/oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java b/oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java index 10a7ee9d4..bc6046ea9 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java @@ -47,6 +47,8 @@ import com.google.auth.http.AuthHttpConstants; import com.google.auth.oauth2.GoogleCredentialsTest.MockTokenServerTransportFactory; import com.google.auth.oauth2.OAuth2Credentials.OAuthValue; +import com.google.auth.oauth2.OAuth2Credentials.RefreshTask; +import com.google.auth.oauth2.OAuth2Credentials.RefreshTaskListener; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ListenableFutureTask; @@ -58,6 +60,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Date; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; @@ -590,7 +593,7 @@ public AccessToken refreshAccessToken() { creds.getRequestMetadata(CALL_URI, realExecutor, callback); TestUtils.assertContainsBearerToken(callback.metadata, ACCESS_TOKEN); assertNotNull(creds.refreshTask); - ListenableFutureTask refreshTask = creds.refreshTask; + RefreshTask refreshTask = creds.refreshTask; // Fast forward to expiration, which will hang cause the callback to hang testClock.setCurrentTime(clientExpired.toEpochMilli()); @@ -873,6 +876,91 @@ public void serialize() throws IOException, ClassNotFoundException { assertSame(deserializedCredentials.clock, Clock.SYSTEM); } + @Test + public void updateTokenValueBeforeWake() throws IOException, InterruptedException { + final SettableFuture refreshedTokenFuture = SettableFuture.create(); + AccessToken refreshedToken = new AccessToken("2/MkSJoj1xsli0AccessToken_NKPY2", null); + refreshedTokenFuture.set(refreshedToken); + + final ListenableFutureTask task = + ListenableFutureTask.create( + new Callable() { + @Override + public OAuthValue call() throws Exception { + return OAuthValue.create(refreshedToken, new HashMap<>()); + } + }); + + OAuth2Credentials creds = + new OAuth2Credentials() { + @Override + public AccessToken refreshAccessToken() { + synchronized (this) { + // Wake up the main thread. This is done now because the child thread (t) is known to + // have the refresh task. Now we want the main thread to wake up and create a future + // in order to wait for the refresh to complete. + this.notify(); + } + RefreshTaskListener listener = + new RefreshTaskListener(task) { + @Override + public void run() { + try { + // Sleep before setting accessToken to new accessToken. Refresh should not + // complete before this, and the accessToken is `null` until it is. + Thread.sleep(300); + super.run(); + } catch (Exception e) { + fail("Unexpected error. Exception: " + e); + } + } + }; + + this.refreshTask = new RefreshTask(task, listener); + + try { + // Sleep for 100 milliseconds to give parent thread time to create a refresh future. + Thread.sleep(100); + return refreshedTokenFuture.get(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + }; + + Thread t = + new Thread( + new Runnable() { + @Override + public void run() { + try { + creds.refresh(); + assertNotNull(creds.getAccessToken()); + } catch (Exception e) { + fail("Unexpected error. Exception: " + e); + } + } + }); + t.start(); + + synchronized (creds) { + // Grab a lock on creds object. This thread (the main thread) will wait here until the child + // thread (t) calls `notify` on the creds object. + creds.wait(); + } + + AccessToken token = creds.getAccessToken(); + assertNull(token); + + creds.refresh(); + token = creds.getAccessToken(); + // Token should never be NULL after a refresh that succeeded. + // Previously the token could be NULL due to an internal race condition between the future + // completing and the task listener updating the value of the access token. + assertNotNull(token); + t.join(); + } + private void waitForRefreshTaskCompletion(OAuth2Credentials credentials) throws TimeoutException, InterruptedException { for (int i = 0; i < 100; i++) {