Skip to content

Commit

Permalink
fix: Resolve race condition reported in #692 (#1031)
Browse files Browse the repository at this point in the history
Co-authored-by: Johan Blumenberg <johan.blumenberg@gmail.com>
Co-authored-by: Igor Berntein <igorbernstein@google.com>
  • Loading branch information
3 people committed Oct 18, 2022
1 parent 43874fc commit 87a6606
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 17 deletions.
74 changes: 58 additions & 16 deletions oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java
Expand Up @@ -41,6 +41,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.util.concurrent.AbstractFuture;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
Expand All @@ -60,7 +61,6 @@
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import javax.annotation.Nullable;

/** Base type for Credentials using OAuth2. */
Expand All @@ -77,7 +77,7 @@ public class OAuth2Credentials extends Credentials {
// byte[] is serializable, so the lock variable can be final
@VisibleForTesting final Object lock = new byte[0];
private volatile OAuthValue value = null;
@VisibleForTesting transient ListenableFutureTask<OAuthValue> refreshTask;
@VisibleForTesting transient RefreshTask refreshTask;

// Change listeners are not serialized
private transient List<CredentialsChangedListener> changeListeners;
Expand Down Expand Up @@ -258,16 +258,7 @@ public OAuthValue call() throws Exception {
}
});

task.addListener(
new Runnable() {
@Override
public void run() {
finishRefreshAsync(task);
}
},
MoreExecutors.directExecutor());

refreshTask = task;
refreshTask = new RefreshTask(task, new RefreshTaskListener(task));

return new AsyncRefreshResult(refreshTask, true);
}
Expand All @@ -290,7 +281,7 @@ private void finishRefreshAsync(ListenableFuture<OAuthValue> finishedTask) {
} catch (Exception e) {
// noop
} finally {
if (this.refreshTask == finishedTask) {
if (this.refreshTask != null && this.refreshTask.getTask() == finishedTask) {
this.refreshTask = null;
}
}
Expand All @@ -307,7 +298,7 @@ private void finishRefreshAsync(ListenableFuture<OAuthValue> finishedTask) {
* thread of whatever executor the async call used. This doesn't affect correctness and is
* extremely unlikely.
*/
private static <T> T unwrapDirectFuture(Future<T> future) throws IOException {
private static <T> T unwrapDirectFuture(ListenableFuture<T> future) throws IOException {
try {
return future.get();
} catch (InterruptedException e) {
Expand Down Expand Up @@ -567,10 +558,10 @@ public void onFailure(Throwable throwable) {
* task is newly created, it is the caller's responsibility to execute it.
*/
static class AsyncRefreshResult {
private final ListenableFutureTask<OAuthValue> task;
private final RefreshTask task;
private final boolean isNew;

AsyncRefreshResult(ListenableFutureTask<OAuthValue> task, boolean isNew) {
AsyncRefreshResult(RefreshTask task, boolean isNew) {
this.task = task;
this.isNew = isNew;
}
Expand All @@ -582,6 +573,57 @@ void executeIfNew(Executor executor) {
}
}

@VisibleForTesting
class RefreshTaskListener implements Runnable {
private ListenableFutureTask<OAuthValue> task;

RefreshTaskListener(ListenableFutureTask<OAuthValue> task) {
this.task = task;
}

@Override
public void run() {
finishRefreshAsync(task);
}
}

class RefreshTask extends AbstractFuture<OAuthValue> implements Runnable {
private final ListenableFutureTask<OAuthValue> task;
private final RefreshTaskListener listener;

RefreshTask(ListenableFutureTask<OAuthValue> task, RefreshTaskListener listener) {
this.task = task;
this.listener = listener;

// Update Credential state first
task.addListener(listener, MoreExecutors.directExecutor());

// Then notify the world
Futures.addCallback(
task,
new FutureCallback<OAuthValue>() {
@Override
public void onSuccess(OAuthValue result) {
RefreshTask.this.set(result);
}

@Override
public void onFailure(Throwable t) {
RefreshTask.this.setException(t);
}
},
MoreExecutors.directExecutor());
}

public ListenableFutureTask<OAuthValue> getTask() {
return this.task;
}

public void run() {
task.run();
}
}

public static class Builder {

private AccessToken accessToken;
Expand Down
Expand Up @@ -47,6 +47,8 @@
import com.google.auth.http.AuthHttpConstants;
import com.google.auth.oauth2.GoogleCredentialsTest.MockTokenServerTransportFactory;
import com.google.auth.oauth2.OAuth2Credentials.OAuthValue;
import com.google.auth.oauth2.OAuth2Credentials.RefreshTask;
import com.google.auth.oauth2.OAuth2Credentials.RefreshTaskListener;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.ListenableFutureTask;
Expand All @@ -58,6 +60,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
Expand Down Expand Up @@ -590,7 +593,7 @@ public AccessToken refreshAccessToken() {
creds.getRequestMetadata(CALL_URI, realExecutor, callback);
TestUtils.assertContainsBearerToken(callback.metadata, ACCESS_TOKEN);
assertNotNull(creds.refreshTask);
ListenableFutureTask<OAuthValue> refreshTask = creds.refreshTask;
RefreshTask refreshTask = creds.refreshTask;

// Fast forward to expiration, which will hang cause the callback to hang
testClock.setCurrentTime(clientExpired.toEpochMilli());
Expand Down Expand Up @@ -873,6 +876,91 @@ public void serialize() throws IOException, ClassNotFoundException {
assertSame(deserializedCredentials.clock, Clock.SYSTEM);
}

@Test
public void updateTokenValueBeforeWake() throws IOException, InterruptedException {
final SettableFuture<AccessToken> refreshedTokenFuture = SettableFuture.create();
AccessToken refreshedToken = new AccessToken("2/MkSJoj1xsli0AccessToken_NKPY2", null);
refreshedTokenFuture.set(refreshedToken);

final ListenableFutureTask<OAuthValue> task =
ListenableFutureTask.create(
new Callable<OAuthValue>() {
@Override
public OAuthValue call() throws Exception {
return OAuthValue.create(refreshedToken, new HashMap<>());
}
});

OAuth2Credentials creds =
new OAuth2Credentials() {
@Override
public AccessToken refreshAccessToken() {
synchronized (this) {
// Wake up the main thread. This is done now because the child thread (t) is known to
// have the refresh task. Now we want the main thread to wake up and create a future
// in order to wait for the refresh to complete.
this.notify();
}
RefreshTaskListener listener =
new RefreshTaskListener(task) {
@Override
public void run() {
try {
// Sleep before setting accessToken to new accessToken. Refresh should not
// complete before this, and the accessToken is `null` until it is.
Thread.sleep(300);
super.run();
} catch (Exception e) {
fail("Unexpected error. Exception: " + e);
}
}
};

this.refreshTask = new RefreshTask(task, listener);

try {
// Sleep for 100 milliseconds to give parent thread time to create a refresh future.
Thread.sleep(100);
return refreshedTokenFuture.get();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
};

Thread t =
new Thread(
new Runnable() {
@Override
public void run() {
try {
creds.refresh();
assertNotNull(creds.getAccessToken());
} catch (Exception e) {
fail("Unexpected error. Exception: " + e);
}
}
});
t.start();

synchronized (creds) {
// Grab a lock on creds object. This thread (the main thread) will wait here until the child
// thread (t) calls `notify` on the creds object.
creds.wait();
}

AccessToken token = creds.getAccessToken();
assertNull(token);

creds.refresh();
token = creds.getAccessToken();
// Token should never be NULL after a refresh that succeeded.
// Previously the token could be NULL due to an internal race condition between the future
// completing and the task listener updating the value of the access token.
assertNotNull(token);
t.join();
}

private void waitForRefreshTaskCompletion(OAuth2Credentials credentials)
throws TimeoutException, InterruptedException {
for (int i = 0; i < 100; i++) {
Expand Down

0 comments on commit 87a6606

Please sign in to comment.