Skip to content

Commit

Permalink
feat: Add retry logic to COUNT queries (#1062)
Browse files Browse the repository at this point in the history
  • Loading branch information
dconeybe authored and cherylEnkidu committed Dec 11, 2023
1 parent b4bf79f commit 50b7478
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@
import com.google.api.core.SettableApiFuture;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.StatusCode;
import com.google.api.gax.rpc.StreamController;
import com.google.cloud.Timestamp;
import com.google.cloud.firestore.v1.FirestoreSettings;
import com.google.firestore.v1.RunAggregationQueryRequest;
import com.google.firestore.v1.RunAggregationQueryResponse;
import com.google.firestore.v1.RunQueryRequest;
import com.google.firestore.v1.StructuredAggregationQuery;
import com.google.firestore.v1.Value;
import com.google.protobuf.ByteString;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -68,25 +71,68 @@ public ApiFuture<AggregateQuerySnapshot> get() {

@Nonnull
ApiFuture<AggregateQuerySnapshot> get(@Nullable final ByteString transactionId) {
RunAggregationQueryRequest request = toProto(transactionId);
AggregateQueryResponseObserver responseObserver = new AggregateQueryResponseObserver();
AggregateQueryResponseDeliverer responseDeliverer =
new AggregateQueryResponseDeliverer(
transactionId, /* startTimeNanos= */ query.rpcContext.getClock().nanoTime());
runQuery(responseDeliverer);
return responseDeliverer.getFuture();
}

private void runQuery(AggregateQueryResponseDeliverer responseDeliverer) {
RunAggregationQueryRequest request = toProto(responseDeliverer.getTransactionId());
AggregateQueryResponseObserver responseObserver =
new AggregateQueryResponseObserver(responseDeliverer);
ServerStreamingCallable<RunAggregationQueryRequest, RunAggregationQueryResponse> callable =
query.rpcContext.getClient().runAggregationQueryCallable();

query.rpcContext.streamRequest(request, responseObserver, callable);
}

private final class AggregateQueryResponseDeliverer {

@Nullable private final ByteString transactionId;
private final long startTimeNanos;
private final SettableApiFuture<AggregateQuerySnapshot> future = SettableApiFuture.create();
private final AtomicBoolean isFutureCompleted = new AtomicBoolean(false);

AggregateQueryResponseDeliverer(@Nullable ByteString transactionId, long startTimeNanos) {
this.transactionId = transactionId;
this.startTimeNanos = startTimeNanos;
}

ApiFuture<AggregateQuerySnapshot> getFuture() {
return future;
}

@Nullable
ByteString getTransactionId() {
return transactionId;
}

long getStartTimeNanos() {
return startTimeNanos;
}

void deliverResult(long count, Timestamp readTime) {
if (isFutureCompleted.compareAndSet(false, true)) {
future.set(new AggregateQuerySnapshot(AggregateQuery.this, readTime, count));
}
}

return responseObserver.getFuture();
void deliverError(Throwable throwable) {
if (isFutureCompleted.compareAndSet(false, true)) {
future.setException(throwable);
}
}
}

private final class AggregateQueryResponseObserver
implements ResponseObserver<RunAggregationQueryResponse> {

private final SettableApiFuture<AggregateQuerySnapshot> future = SettableApiFuture.create();
private final AtomicBoolean isFutureNotified = new AtomicBoolean(false);
private final AggregateQueryResponseDeliverer responseDeliverer;
private StreamController streamController;

SettableApiFuture<AggregateQuerySnapshot> getFuture() {
return future;
AggregateQueryResponseObserver(AggregateQueryResponseDeliverer responseDeliverer) {
this.responseDeliverer = responseDeliverer;
}

@Override
Expand All @@ -96,14 +142,10 @@ public void onStart(StreamController streamController) {

@Override
public void onResponse(RunAggregationQueryResponse response) {
// Ignore subsequent response messages. The RunAggregationQuery RPC returns a stream of
// responses (rather than just a single response); however, only the first response of the
// stream is actually used. Any more responses are technically errors, but since the Future
// will have already been notified, we just drop any unexpected responses.
if (!isFutureNotified.compareAndSet(false, true)) {
return;
}
// Close the stream to avoid it dangling, since we're not expecting any more responses.
streamController.cancel();

// Extract the count and read time from the RunAggregationQueryResponse.
Timestamp readTime = Timestamp.fromProto(response.getReadTime());
Value value = response.getResult().getAggregateFieldsMap().get(ALIAS_COUNT);
if (value == null) {
Expand All @@ -118,19 +160,30 @@ public void onResponse(RunAggregationQueryResponse response) {
}
long count = value.getIntegerValue();

future.set(new AggregateQuerySnapshot(AggregateQuery.this, readTime, count));

// Close the stream to avoid it dangling, since we're not expecting any more responses.
streamController.cancel();
// Deliver the result; even though the `RunAggregationQuery` RPC is a "streaming" RPC, meaning
// that `onResponse()` can be called multiple times, it _should_ only be called once for count
// queries. But even if it is called more than once, `responseDeliverer` will drop superfluous
// results.
responseDeliverer.deliverResult(count, readTime);
}

@Override
public void onError(Throwable throwable) {
if (!isFutureNotified.compareAndSet(false, true)) {
return;
if (shouldRetry(throwable)) {
runQuery(responseDeliverer);
} else {
responseDeliverer.deliverError(throwable);
}
}

future.setException(throwable);
private boolean shouldRetry(Throwable throwable) {
Set<StatusCode.Code> retryableCodes =
FirestoreSettings.newBuilder().runAggregationQuerySettings().getRetryableCodes();
return query.shouldRetryQuery(
throwable,
responseDeliverer.getTransactionId(),
responseDeliverer.getStartTimeNanos(),
retryableCodes);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1673,27 +1673,15 @@ public void onComplete() {
}

boolean shouldRetry(DocumentSnapshot lastDocument, Throwable t) {
if (transactionId != null) {
// Transactional queries are retried via the transaction runner.
return false;
}

if (lastDocument == null) {
// Only retry if we have received a single result. Retries for RPCs with initial
// failure are handled by Google Gax, which also implements backoff.
return false;
}

if (!isRetryableError(t)) {
return false;
}

if (rpcContext.getTotalRequestTimeout().isZero()) {
return true;
}

Duration duration = Duration.ofNanos(rpcContext.getClock().nanoTime() - startTimeNanos);
return duration.compareTo(rpcContext.getTotalRequestTimeout()) < 0;
Set<StatusCode.Code> retryableCodes =
FirestoreSettings.newBuilder().runQuerySettings().getRetryableCodes();
return shouldRetryQuery(t, transactionId, startTimeNanos, retryableCodes);
}
};

Expand Down Expand Up @@ -1832,21 +1820,42 @@ private <T> ImmutableList<T> append(ImmutableList<T> existingList, T newElement)
}

/** Verifies whether the given exception is retryable based on the RunQuery configuration. */
private boolean isRetryableError(Throwable throwable) {
private boolean isRetryableError(Throwable throwable, Set<StatusCode.Code> retryableCodes) {
if (!(throwable instanceof FirestoreException)) {
return false;
}
Set<StatusCode.Code> codes =
FirestoreSettings.newBuilder().runQuerySettings().getRetryableCodes();
Status status = ((FirestoreException) throwable).getStatus();
for (StatusCode.Code code : codes) {
for (StatusCode.Code code : retryableCodes) {
if (code.equals(StatusCode.Code.valueOf(status.getCode().name()))) {
return true;
}
}
return false;
}

/** Returns whether a query that failed in the given scenario should be retried. */
boolean shouldRetryQuery(
Throwable throwable,
@Nullable ByteString transactionId,
long startTimeNanos,
Set<StatusCode.Code> retryableCodes) {
if (transactionId != null) {
// Transactional queries are retried via the transaction runner.
return false;
}

if (!isRetryableError(throwable, retryableCodes)) {
return false;
}

if (rpcContext.getTotalRequestTimeout().isZero()) {
return true;
}

Duration duration = Duration.ofNanos(rpcContext.getClock().nanoTime() - startTimeNanos);
return duration.compareTo(rpcContext.getTotalRequestTimeout()) < 0;
}

/**
* Returns a query that counts the documents in the result set of this query.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;

import com.google.api.core.ApiClock;
import com.google.api.core.ApiFuture;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStreamingCallable;
Expand All @@ -39,14 +41,15 @@
import com.google.firestore.v1.RunAggregationQueryRequest;
import com.google.firestore.v1.RunAggregationQueryResponse;
import com.google.firestore.v1.StructuredQuery;
import io.grpc.Status;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Matchers;
import org.mockito.Mockito;
import org.mockito.Spy;
import org.mockito.runners.MockitoJUnitRunner;
import org.threeten.bp.Duration;
Expand All @@ -58,7 +61,7 @@ public class QueryCountTest {
private final FirestoreImpl firestoreMock =
new FirestoreImpl(
FirestoreOptions.newBuilder().setProjectId("test-project").build(),
Mockito.mock(FirestoreRpc.class));
mock(FirestoreRpc.class));

@Captor private ArgumentCaptor<RunAggregationQueryRequest> runAggregationQuery;

Expand Down Expand Up @@ -211,6 +214,117 @@ public void aggregateQuerySnapshotGetQueryShouldReturnCorrectValue() throws Exce

AggregateQuery aggregateQuery = query.count();
AggregateQuerySnapshot snapshot = aggregateQuery.get().get();

assertThat(snapshot.getQuery()).isSameInstanceAs(aggregateQuery);
}

@Test
public void shouldNotRetryIfExceptionIsNotFirestoreException() {
doAnswer(aggregationQueryResponse(new NotFirestoreException()))
.doAnswer(aggregationQueryResponse())
.when(firestoreMock)
.streamRequest(
runAggregationQuery.capture(),
streamObserverCapture.capture(),
Matchers.<ServerStreamingCallable>any());

ApiFuture<AggregateQuerySnapshot> future = query.count().get();

assertThrows(ExecutionException.class, future::get);
}

@Test
public void shouldRetryIfExceptionIsFirestoreExceptionWithRetryableStatus() throws Exception {
doAnswer(aggregationQueryResponse(new FirestoreException("reason", Status.INTERNAL)))
.doAnswer(aggregationQueryResponse(42))
.when(firestoreMock)
.streamRequest(
runAggregationQuery.capture(),
streamObserverCapture.capture(),
Matchers.<ServerStreamingCallable>any());

ApiFuture<AggregateQuerySnapshot> future = query.count().get();
AggregateQuerySnapshot snapshot = future.get();

assertThat(snapshot.getCount()).isEqualTo(42);
}

@Test
public void shouldNotRetryIfExceptionIsFirestoreExceptionWithNonRetryableStatus() {
doReturn(Duration.ZERO).when(firestoreMock).getTotalRequestTimeout();
doAnswer(aggregationQueryResponse(new FirestoreException("reason", Status.INVALID_ARGUMENT)))
.doAnswer(aggregationQueryResponse())
.when(firestoreMock)
.streamRequest(
runAggregationQuery.capture(),
streamObserverCapture.capture(),
Matchers.<ServerStreamingCallable>any());

ApiFuture<AggregateQuerySnapshot> future = query.count().get();

assertThrows(ExecutionException.class, future::get);
}

@Test
public void
shouldRetryIfExceptionIsFirestoreExceptionWithRetryableStatusWithInfiniteTimeoutWindow()
throws Exception {
doReturn(Duration.ZERO).when(firestoreMock).getTotalRequestTimeout();
doAnswer(aggregationQueryResponse(new FirestoreException("reason", Status.INTERNAL)))
.doAnswer(aggregationQueryResponse(42))
.when(firestoreMock)
.streamRequest(
runAggregationQuery.capture(),
streamObserverCapture.capture(),
Matchers.<ServerStreamingCallable>any());

ApiFuture<AggregateQuerySnapshot> future = query.count().get();
AggregateQuerySnapshot snapshot = future.get();

assertThat(snapshot.getCount()).isEqualTo(42);
}

@Test
public void shouldRetryIfExceptionIsFirestoreExceptionWithRetryableStatusWithinTimeoutWindow()
throws Exception {
doReturn(Duration.ofDays(999)).when(firestoreMock).getTotalRequestTimeout();
doAnswer(aggregationQueryResponse(new FirestoreException("reason", Status.INTERNAL)))
.doAnswer(aggregationQueryResponse(42))
.when(firestoreMock)
.streamRequest(
runAggregationQuery.capture(),
streamObserverCapture.capture(),
Matchers.<ServerStreamingCallable>any());

ApiFuture<AggregateQuerySnapshot> future = query.count().get();
AggregateQuerySnapshot snapshot = future.get();

assertThat(snapshot.getCount()).isEqualTo(42);
}

@Test
public void
shouldNotRetryIfExceptionIsFirestoreExceptionWithRetryableStatusBeyondTimeoutWindow() {
ApiClock clockMock = mock(ApiClock.class);
doReturn(clockMock).when(firestoreMock).getClock();
doReturn(TimeUnit.SECONDS.toNanos(10))
.doReturn(TimeUnit.SECONDS.toNanos(20))
.doReturn(TimeUnit.SECONDS.toNanos(30))
.when(clockMock)
.nanoTime();
doReturn(Duration.ofSeconds(5)).when(firestoreMock).getTotalRequestTimeout();
doAnswer(aggregationQueryResponse(new FirestoreException("reason", Status.INTERNAL)))
.doAnswer(aggregationQueryResponse(42))
.when(firestoreMock)
.streamRequest(
runAggregationQuery.capture(),
streamObserverCapture.capture(),
Matchers.<ServerStreamingCallable>any());

ApiFuture<AggregateQuerySnapshot> future = query.count().get();

assertThrows(ExecutionException.class, future::get);
}

private static final class NotFirestoreException extends Exception {}
}

0 comments on commit 50b7478

Please sign in to comment.