Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Prevent watch stream from emitting events after close. #1471

Merged
merged 12 commits into from
Nov 14, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ private void sendBatchLocked(final BulkCommitBatch batch, final boolean flush) {
bulkWriterExecutor);
} else {
long delayMs = rateLimiter.getNextRequestDelayMs(batch.getMutationsSize());
logger.log(Level.FINE, String.format("Backing off for %d seconds", delayMs / 1000));
logger.log(Level.FINE, () -> String.format("Backing off for %d seconds", delayMs / 1000));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't directly related to PR. But since I was adding logging...

I noticed this log was building a string, even if log level is not fine. A lambda allows lazy evaluation, such that production environments don't suffer work when logging is set to a more coarse level.

bulkWriterExecutor.schedule(
() -> {
synchronized (lock) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package com.google.cloud.firestore;

import com.google.api.gax.rpc.BidiStreamObserver;
import com.google.api.gax.rpc.ClientStream;
import com.google.api.gax.rpc.StreamController;
import com.google.firestore.v1.ListenRequest;
import java.util.function.Function;
import java.util.logging.Logger;

public class SuppressibleBidiStream<RequestT, ResponseT> implements BidiStreamObserver<RequestT, ResponseT> {

private final ClientStream<ListenRequest> stream;
private final BidiStreamObserver<RequestT, ResponseT> delegate;
private boolean silence = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the class is now call Suppressbile.., this should be isSuppressed?

private static final Logger LOGGER = Logger.getLogger(Watch.class.getName());

SuppressibleBidiStream(
BidiStreamObserver<RequestT, ResponseT> responseObserverT,
Function<BidiStreamObserver<RequestT, ResponseT>, ClientStream<ListenRequest>> streamSupplier
) {
this.delegate = responseObserverT;
stream = streamSupplier.apply(this);
}

public void send(ListenRequest request) {
LOGGER.info(stream.toString());
stream.send(request);
}

public void close() {
LOGGER.info(stream.toString());
stream.closeSend();
}

public void closeAndSilence() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, use closeAndSuppress instead.

LOGGER.info(stream.toString());
silence = true;
stream.closeSend();
}

@Override
public void onReady(ClientStream<RequestT> stream) {
if (silence) {
LOGGER.info(String.format("Silenced: %s", stream));
} else {
delegate.onReady(stream);
}
}

@Override
public void onStart(StreamController controller) {
if (silence) {
LOGGER.info(String.format("Silenced: %s", stream));
} else {
delegate.onStart(controller);
}
}

@Override
public void onResponse(ResponseT response) {
if (silence) {
LOGGER.info(String.format("Silenced: %s", stream));
} else {
delegate.onResponse(response);
}
}

@Override
public void onError(Throwable t) {
if (silence) {
LOGGER.info(String.format("Silenced: %s", stream));
} else {
delegate.onError(t);
}
}

@Override
public void onComplete() {
if (silence) {
LOGGER.info(String.format("Silenced: %s", stream));
} else {
delegate.onComplete();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Logger;
import javax.annotation.Nullable;

/**
Expand All @@ -73,7 +74,7 @@ class Watch implements BidiStreamObserver<ListenRequest, ListenResponse> {
private final ExponentialRetryAlgorithm backoff;
private final Target target;
private TimedAttemptSettings nextAttempt;
private ClientStream<ListenRequest> stream;
private SuppressibleBidiStream<ListenRequest, ListenResponse> stream;

/** The sorted tree of DocumentSnapshots as sent in the last snapshot. */
private DocumentSet documentSet;
Expand Down Expand Up @@ -115,6 +116,8 @@ static class ChangeSet {
List<QueryDocumentSnapshot> updates = new ArrayList<>();
}

private static final Logger LOGGER = Logger.getLogger(Watch.class.getName());

/**
* @param firestore The Firestore Database client.
* @param query The query that is used to order the document snapshots returned by this watch.
Expand Down Expand Up @@ -246,7 +249,10 @@ && affectsTarget(change.getTargetIdsList(), WATCH_TARGET_ID)) {
changeMap.put(ResourcePath.create(listenResponse.getDocumentRemove().getDocument()), null);
break;
case FILTER:
if (listenResponse.getFilter().getCount() != currentSize()) {
int filterCount = listenResponse.getFilter().getCount();
int currentSize = currentSize();
if (filterCount != currentSize) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inlining currentSize() seems better IMO?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currentSize() hides a lot of computation:

  private int currentSize() {
    ChangeSet changeSet = extractChanges(Timestamp.now());
    return documentSet.size() + changeSet.adds.size() - changeSet.deletes.size();
  }
  private ChangeSet extractChanges(Timestamp readTime) {
    ChangeSet changeSet = new ChangeSet();

    for (Entry<ResourcePath, Document> change : changeMap.entrySet()) {
      if (change.getValue() == null) {
        if (documentSet.contains(change.getKey())) {
          changeSet.deletes.add(documentSet.getDocument(change.getKey()));
        }
        continue;
      }

      QueryDocumentSnapshot snapshot =
          QueryDocumentSnapshot.fromDocument(firestore, readTime, change.getValue());

      if (documentSet.contains(change.getKey())) {
        changeSet.updates.add(snapshot);
      } else {
        changeSet.adds.add(snapshot);
      }
    }

    return changeSet;
  }

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I meant if (filterCount != currentSize())

LOGGER.info(() -> String.format("filter: count mismatch filter count %d != current size %d", filterCount, currentSize));
// We need to remove all the current results.
resetDocs();
// The filter didn't match, so re-issue the query.
Expand Down Expand Up @@ -297,7 +303,7 @@ ListenerRegistration runWatch(
.execute(
() -> {
synchronized (Watch.this) {
stream.closeSend();
stream.close();
stream = null;
}
});
Expand All @@ -318,7 +324,7 @@ private void resetDocs() {
resumeToken = null;

for (DocumentSnapshot snapshot : documentSet) {
// Mark each document as deleted. If documents are not deleted, they will be send again by
// Mark each document as deleted. If documents are not deleted, they will be sent again by
// the server.
changeMap.put(snapshot.getReference().getResourcePath(), null);
}
Expand All @@ -329,7 +335,7 @@ private void resetDocs() {
/** Closes the stream and calls onError() if the stream is still active. */
private void closeStream(final Throwable throwable) {
if (stream != null) {
stream.closeSend();
stream.closeAndSilence();
stream = null;
}

Expand Down Expand Up @@ -371,7 +377,7 @@ private void maybeReopenStream(Throwable throwable) {
/** Helper to restart the outgoing stream to the backend. */
private void resetStream() {
if (stream != null) {
stream.closeSend();
stream.closeAndSilence();
stream = null;
}

Expand All @@ -398,7 +404,10 @@ private void initStream() {
nextAttempt = backoff.createNextAttempt(nextAttempt);

Tracing.getTracer().getCurrentSpan().addAnnotation(TraceUtil.SPAN_NAME_LISTEN);
stream = firestore.streamRequest(Watch.this, firestore.getClient().listenCallable());
stream = new SuppressibleBidiStream<>(
Watch.this,
observer -> firestore.streamRequest(observer, firestore.getClient().listenCallable())
);

ListenRequest.Builder request = ListenRequest.newBuilder();
request.setDatabase(firestore.getDatabaseName());
Expand Down Expand Up @@ -459,6 +468,7 @@ private void pushSnapshot(final Timestamp readTime, ByteString nextResumeToken)
if (!hasPushed || !changes.isEmpty()) {
final QuerySnapshot querySnapshot =
QuerySnapshot.withChanges(query, readTime, documentSet, changes);
LOGGER.info(querySnapshot.toString());
userCallbackExecutor.execute(() -> listener.onEvent(querySnapshot, null));
hasPushed = true;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.google.cloud.firestore;

import static org.mockito.Mockito.doCallRealMethod;

import com.google.api.gax.rpc.BidiStreamObserver;
import com.google.api.gax.rpc.BidiStreamingCallable;
import com.google.firestore.v1.ListenRequest;
import com.google.firestore.v1.ListenResponse;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;

public final class FirestoreSpy {

public final FirestoreImpl spy;
public final ArgumentCaptor<BidiStreamObserver<ListenRequest, ListenResponse>> streamRequestBidiStreamObserverCaptor;

public FirestoreSpy(Firestore firestore) {
spy = Mockito.spy((FirestoreImpl) firestore);
streamRequestBidiStreamObserverCaptor = ArgumentCaptor.forClass(BidiStreamObserver.class);
doCallRealMethod()
.when(spy)
.streamRequest(
streamRequestBidiStreamObserverCaptor.capture(),
ArgumentMatchers.<BidiStreamingCallable>any()
);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@
import javax.annotation.Nullable;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
Expand All @@ -84,6 +86,10 @@

@RunWith(MockitoJUnitRunner.class)
public class WatchTest {

@Rule
public Timeout timeout = new Timeout(1, TimeUnit.SECONDS);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While working through problem, I introduced bug that made unit tests hang. To make sure tests completed, I added timeout to fail test if they don't run within 1 second.


/** The Target ID used by the Java Firestore SDK. */
private static final int TARGET_ID = 0x1;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.google.cloud.firestore.Firestore;
import com.google.cloud.firestore.FirestoreOptions;
import com.google.cloud.firestore.FirestoreSpy;
import com.google.common.base.Preconditions;
import java.util.logging.Level;
import java.util.logging.Logger;
Expand All @@ -30,6 +31,7 @@
public abstract class ITBaseTest {
private static final Logger logger = Logger.getLogger(ITBaseTest.class.getName());
protected Firestore firestore;
protected FirestoreSpy firestoreSpy;

@Before
public void before() {
Expand All @@ -53,5 +55,14 @@ public void after() throws Exception {
"Error instantiating Firestore. Check that the service account credentials were properly set.");
firestore.close();
firestore = null;
firestoreSpy = null;
}

public FirestoreSpy useFirestoreSpy() {
if (firestoreSpy == null) {
firestoreSpy = new FirestoreSpy(firestore);
firestore = firestoreSpy.spy;
}
return firestoreSpy;
}
}