Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add retry logic to COUNT queries #1062

Merged
merged 1 commit into from
Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@
import com.google.api.core.SettableApiFuture;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.StatusCode;
import com.google.api.gax.rpc.StreamController;
import com.google.cloud.Timestamp;
import com.google.cloud.firestore.v1.FirestoreSettings;
import com.google.firestore.v1.RunAggregationQueryRequest;
import com.google.firestore.v1.RunAggregationQueryResponse;
import com.google.firestore.v1.RunQueryRequest;
import com.google.firestore.v1.StructuredAggregationQuery;
import com.google.firestore.v1.Value;
import com.google.protobuf.ByteString;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -68,25 +71,68 @@ public ApiFuture<AggregateQuerySnapshot> get() {

@Nonnull
ApiFuture<AggregateQuerySnapshot> get(@Nullable final ByteString transactionId) {
RunAggregationQueryRequest request = toProto(transactionId);
AggregateQueryResponseObserver responseObserver = new AggregateQueryResponseObserver();
AggregateQueryResponseDeliverer responseDeliverer =
new AggregateQueryResponseDeliverer(
transactionId, /* startTimeNanos= */ query.rpcContext.getClock().nanoTime());
runQuery(responseDeliverer);
return responseDeliverer.getFuture();
}

private void runQuery(AggregateQueryResponseDeliverer responseDeliverer) {
RunAggregationQueryRequest request = toProto(responseDeliverer.getTransactionId());
AggregateQueryResponseObserver responseObserver =
new AggregateQueryResponseObserver(responseDeliverer);
ServerStreamingCallable<RunAggregationQueryRequest, RunAggregationQueryResponse> callable =
query.rpcContext.getClient().runAggregationQueryCallable();

query.rpcContext.streamRequest(request, responseObserver, callable);
}

private final class AggregateQueryResponseDeliverer {

@Nullable private final ByteString transactionId;
private final long startTimeNanos;
private final SettableApiFuture<AggregateQuerySnapshot> future = SettableApiFuture.create();
private final AtomicBoolean isFutureCompleted = new AtomicBoolean(false);

AggregateQueryResponseDeliverer(@Nullable ByteString transactionId, long startTimeNanos) {
this.transactionId = transactionId;
this.startTimeNanos = startTimeNanos;
}

ApiFuture<AggregateQuerySnapshot> getFuture() {
return future;
}

@Nullable
ByteString getTransactionId() {
return transactionId;
}

long getStartTimeNanos() {
return startTimeNanos;
}

void deliverResult(long count, Timestamp readTime) {
if (isFutureCompleted.compareAndSet(false, true)) {
future.set(new AggregateQuerySnapshot(AggregateQuery.this, readTime, count));
}
}

return responseObserver.getFuture();
void deliverError(Throwable throwable) {
if (isFutureCompleted.compareAndSet(false, true)) {
future.setException(throwable);
}
}
}

private final class AggregateQueryResponseObserver
implements ResponseObserver<RunAggregationQueryResponse> {

private final SettableApiFuture<AggregateQuerySnapshot> future = SettableApiFuture.create();
private final AtomicBoolean isFutureNotified = new AtomicBoolean(false);
private final AggregateQueryResponseDeliverer responseDeliverer;
private StreamController streamController;

SettableApiFuture<AggregateQuerySnapshot> getFuture() {
return future;
AggregateQueryResponseObserver(AggregateQueryResponseDeliverer responseDeliverer) {
this.responseDeliverer = responseDeliverer;
}

@Override
Expand All @@ -96,14 +142,10 @@ public void onStart(StreamController streamController) {

@Override
public void onResponse(RunAggregationQueryResponse response) {
// Ignore subsequent response messages. The RunAggregationQuery RPC returns a stream of
// responses (rather than just a single response); however, only the first response of the
// stream is actually used. Any more responses are technically errors, but since the Future
// will have already been notified, we just drop any unexpected responses.
if (!isFutureNotified.compareAndSet(false, true)) {
return;
}
// Close the stream to avoid it dangling, since we're not expecting any more responses.
streamController.cancel();

// Extract the count and read time from the RunAggregationQueryResponse.
Timestamp readTime = Timestamp.fromProto(response.getReadTime());
Value value = response.getResult().getAggregateFieldsMap().get(ALIAS_COUNT);
if (value == null) {
Expand All @@ -118,19 +160,30 @@ public void onResponse(RunAggregationQueryResponse response) {
}
long count = value.getIntegerValue();

future.set(new AggregateQuerySnapshot(AggregateQuery.this, readTime, count));

// Close the stream to avoid it dangling, since we're not expecting any more responses.
streamController.cancel();
// Deliver the result; even though the `RunAggregationQuery` RPC is a "streaming" RPC, meaning
// that `onResponse()` can be called multiple times, it _should_ only be called once for count
// queries. But even if it is called more than once, `responseDeliverer` will drop superfluous
// results.
responseDeliverer.deliverResult(count, readTime);
}

@Override
public void onError(Throwable throwable) {
if (!isFutureNotified.compareAndSet(false, true)) {
return;
if (shouldRetry(throwable)) {
runQuery(responseDeliverer);
} else {
responseDeliverer.deliverError(throwable);
}
}

future.setException(throwable);
private boolean shouldRetry(Throwable throwable) {
Set<StatusCode.Code> retryableCodes =
FirestoreSettings.newBuilder().runAggregationQuerySettings().getRetryableCodes();
return query.shouldRetryQuery(
throwable,
responseDeliverer.getTransactionId(),
responseDeliverer.getStartTimeNanos(),
retryableCodes);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1672,27 +1672,15 @@ public void onComplete() {
}

boolean shouldRetry(DocumentSnapshot lastDocument, Throwable t) {
if (transactionId != null) {
// Transactional queries are retried via the transaction runner.
return false;
}

if (lastDocument == null) {
// Only retry if we have received a single result. Retries for RPCs with initial
// failure are handled by Google Gax, which also implements backoff.
return false;
}

if (!isRetryableError(t)) {
return false;
}

if (rpcContext.getTotalRequestTimeout().isZero()) {
return true;
}

Duration duration = Duration.ofNanos(rpcContext.getClock().nanoTime() - startTimeNanos);
return duration.compareTo(rpcContext.getTotalRequestTimeout()) < 0;
Set<StatusCode.Code> retryableCodes =
FirestoreSettings.newBuilder().runQuerySettings().getRetryableCodes();
return shouldRetryQuery(t, transactionId, startTimeNanos, retryableCodes);
}
};

Expand Down Expand Up @@ -1831,21 +1819,42 @@ private <T> ImmutableList<T> append(ImmutableList<T> existingList, T newElement)
}

/** Verifies whether the given exception is retryable based on the RunQuery configuration. */
private boolean isRetryableError(Throwable throwable) {
private boolean isRetryableError(Throwable throwable, Set<StatusCode.Code> retryableCodes) {
if (!(throwable instanceof FirestoreException)) {
return false;
}
Set<StatusCode.Code> codes =
FirestoreSettings.newBuilder().runQuerySettings().getRetryableCodes();
Status status = ((FirestoreException) throwable).getStatus();
for (StatusCode.Code code : codes) {
for (StatusCode.Code code : retryableCodes) {
if (code.equals(StatusCode.Code.valueOf(status.getCode().name()))) {
return true;
}
}
return false;
}

/** Returns whether a query that failed in the given scenario should be retried. */
boolean shouldRetryQuery(
Throwable throwable,
@Nullable ByteString transactionId,
long startTimeNanos,
Set<StatusCode.Code> retryableCodes) {
if (transactionId != null) {
// Transactional queries are retried via the transaction runner.
return false;
}

if (!isRetryableError(throwable, retryableCodes)) {
return false;
}

if (rpcContext.getTotalRequestTimeout().isZero()) {
return true;
}

Duration duration = Duration.ofNanos(rpcContext.getClock().nanoTime() - startTimeNanos);
return duration.compareTo(rpcContext.getTotalRequestTimeout()) < 0;
}

/**
* Returns a query that counts the documents in the result set of this query.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;

import com.google.api.core.ApiClock;
import com.google.api.core.ApiFuture;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStreamingCallable;
Expand All @@ -39,14 +41,15 @@
import com.google.firestore.v1.RunAggregationQueryRequest;
import com.google.firestore.v1.RunAggregationQueryResponse;
import com.google.firestore.v1.StructuredQuery;
import io.grpc.Status;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Matchers;
import org.mockito.Mockito;
import org.mockito.Spy;
import org.mockito.runners.MockitoJUnitRunner;
import org.threeten.bp.Duration;
Expand All @@ -58,7 +61,7 @@ public class QueryCountTest {
private final FirestoreImpl firestoreMock =
new FirestoreImpl(
FirestoreOptions.newBuilder().setProjectId("test-project").build(),
Mockito.mock(FirestoreRpc.class));
mock(FirestoreRpc.class));

@Captor private ArgumentCaptor<RunAggregationQueryRequest> runAggregationQuery;

Expand Down Expand Up @@ -211,6 +214,117 @@ public void aggregateQuerySnapshotGetQueryShouldReturnCorrectValue() throws Exce

AggregateQuery aggregateQuery = query.count();
AggregateQuerySnapshot snapshot = aggregateQuery.get().get();

assertThat(snapshot.getQuery()).isSameInstanceAs(aggregateQuery);
}

@Test
public void shouldNotRetryIfExceptionIsNotFirestoreException() {
doAnswer(aggregationQueryResponse(new NotFirestoreException()))
.doAnswer(aggregationQueryResponse())
.when(firestoreMock)
.streamRequest(
runAggregationQuery.capture(),
streamObserverCapture.capture(),
Matchers.<ServerStreamingCallable>any());

ApiFuture<AggregateQuerySnapshot> future = query.count().get();

assertThrows(ExecutionException.class, future::get);
}

@Test
public void shouldRetryIfExceptionIsFirestoreExceptionWithRetryableStatus() throws Exception {
doAnswer(aggregationQueryResponse(new FirestoreException("reason", Status.INTERNAL)))
.doAnswer(aggregationQueryResponse(42))
.when(firestoreMock)
.streamRequest(
runAggregationQuery.capture(),
streamObserverCapture.capture(),
Matchers.<ServerStreamingCallable>any());

ApiFuture<AggregateQuerySnapshot> future = query.count().get();
AggregateQuerySnapshot snapshot = future.get();

assertThat(snapshot.getCount()).isEqualTo(42);
}

@Test
public void shouldNotRetryIfExceptionIsFirestoreExceptionWithNonRetryableStatus() {
doReturn(Duration.ZERO).when(firestoreMock).getTotalRequestTimeout();
doAnswer(aggregationQueryResponse(new FirestoreException("reason", Status.INVALID_ARGUMENT)))
.doAnswer(aggregationQueryResponse())
.when(firestoreMock)
.streamRequest(
runAggregationQuery.capture(),
streamObserverCapture.capture(),
Matchers.<ServerStreamingCallable>any());

ApiFuture<AggregateQuerySnapshot> future = query.count().get();

assertThrows(ExecutionException.class, future::get);
}

@Test
public void
shouldRetryIfExceptionIsFirestoreExceptionWithRetryableStatusWithInfiniteTimeoutWindow()
throws Exception {
doReturn(Duration.ZERO).when(firestoreMock).getTotalRequestTimeout();
doAnswer(aggregationQueryResponse(new FirestoreException("reason", Status.INTERNAL)))
.doAnswer(aggregationQueryResponse(42))
.when(firestoreMock)
.streamRequest(
runAggregationQuery.capture(),
streamObserverCapture.capture(),
Matchers.<ServerStreamingCallable>any());

ApiFuture<AggregateQuerySnapshot> future = query.count().get();
AggregateQuerySnapshot snapshot = future.get();

assertThat(snapshot.getCount()).isEqualTo(42);
}

@Test
public void shouldRetryIfExceptionIsFirestoreExceptionWithRetryableStatusWithinTimeoutWindow()
throws Exception {
doReturn(Duration.ofDays(999)).when(firestoreMock).getTotalRequestTimeout();
doAnswer(aggregationQueryResponse(new FirestoreException("reason", Status.INTERNAL)))
.doAnswer(aggregationQueryResponse(42))
.when(firestoreMock)
.streamRequest(
runAggregationQuery.capture(),
streamObserverCapture.capture(),
Matchers.<ServerStreamingCallable>any());

ApiFuture<AggregateQuerySnapshot> future = query.count().get();
AggregateQuerySnapshot snapshot = future.get();

assertThat(snapshot.getCount()).isEqualTo(42);
}

@Test
public void
shouldNotRetryIfExceptionIsFirestoreExceptionWithRetryableStatusBeyondTimeoutWindow() {
ApiClock clockMock = mock(ApiClock.class);
doReturn(clockMock).when(firestoreMock).getClock();
doReturn(TimeUnit.SECONDS.toNanos(10))
.doReturn(TimeUnit.SECONDS.toNanos(20))
.doReturn(TimeUnit.SECONDS.toNanos(30))
.when(clockMock)
.nanoTime();
doReturn(Duration.ofSeconds(5)).when(firestoreMock).getTotalRequestTimeout();
doAnswer(aggregationQueryResponse(new FirestoreException("reason", Status.INTERNAL)))
.doAnswer(aggregationQueryResponse(42))
.when(firestoreMock)
.streamRequest(
runAggregationQuery.capture(),
streamObserverCapture.capture(),
Matchers.<ServerStreamingCallable>any());

ApiFuture<AggregateQuerySnapshot> future = query.count().get();

assertThrows(ExecutionException.class, future::get);
}

private static final class NotFirestoreException extends Exception {}
}