Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Prevent watch stream from emitting events after close. #1471

Merged
merged 12 commits into from
Nov 14, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ private void sendBatchLocked(final BulkCommitBatch batch, final boolean flush) {
bulkWriterExecutor);
} else {
long delayMs = rateLimiter.getNextRequestDelayMs(batch.getMutationsSize());
logger.log(Level.FINE, String.format("Backing off for %d seconds", delayMs / 1000));
logger.log(Level.FINE, () -> String.format("Backing off for %d seconds", delayMs / 1000));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't directly related to PR. But since I was adding logging...

I noticed this log was building a string, even if log level is not fine. A lambda allows lazy evaluation, such that production environments don't suffer work when logging is set to a more coarse level.

bulkWriterExecutor.schedule(
() -> {
synchronized (lock) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.firestore;

import com.google.api.gax.rpc.BidiStreamObserver;
import com.google.api.gax.rpc.ClientStream;
import com.google.api.gax.rpc.StreamController;
import com.google.firestore.v1.ListenRequest;
import java.util.function.Function;
import java.util.logging.Logger;

final class SuppressibleBidiStream<RequestT, ResponseT>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a class level comment to explain what this class does and why it is neccessary.

implements BidiStreamObserver<RequestT, ResponseT> {

private final ClientStream<ListenRequest> stream;
private final BidiStreamObserver<RequestT, ResponseT> delegate;
private boolean silence = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the class is now call Suppressbile.., this should be isSuppressed?

private static final Logger LOGGER = Logger.getLogger(Watch.class.getName());

SuppressibleBidiStream(
BidiStreamObserver<RequestT, ResponseT> responseObserverT,
Function<BidiStreamObserver<RequestT, ResponseT>, ClientStream<ListenRequest>>
streamSupplier) {
this.delegate = responseObserverT;
stream = streamSupplier.apply(this);
}

public void send(ListenRequest request) {
LOGGER.info(stream.toString());
stream.send(request);
}

public void close() {
LOGGER.info(stream::toString);
stream.closeSend();
}

public void closeAndSilence() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, use closeAndSuppress instead.

LOGGER.info(stream::toString);
silence = true;
stream.closeSend();
}

@Override
public void onReady(ClientStream<RequestT> stream) {
if (silence) {
LOGGER.info(() -> String.format("Silenced: %s", stream));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for the log message.

} else {
delegate.onReady(stream);
}
}

@Override
public void onStart(StreamController controller) {
if (silence) {
LOGGER.info(() -> String.format("Silenced: %s", stream));
} else {
delegate.onStart(controller);
}
}

@Override
public void onResponse(ResponseT response) {
if (silence) {
LOGGER.info(() -> String.format("Silenced: %s", stream));
} else {
delegate.onResponse(response);
}
}

@Override
public void onError(Throwable t) {
if (silence) {
LOGGER.info(() -> String.format("Silenced: %s", stream));
} else {
delegate.onError(t);
}
}

@Override
public void onComplete() {
if (silence) {
LOGGER.info(() -> String.format("Silenced: %s", stream));
} else {
delegate.onComplete();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Logger;
import javax.annotation.Nullable;

/**
Expand All @@ -59,7 +60,7 @@
* It synchronizes on its own instance so it is advisable not to use this class for external
* synchronization.
*/
class Watch implements BidiStreamObserver<ListenRequest, ListenResponse> {
final class Watch implements BidiStreamObserver<ListenRequest, ListenResponse> {
/**
* Target ID used by watch. Watch uses a fixed target id since we only support one target per
* stream. The actual target ID we use is arbitrary.
Expand All @@ -73,7 +74,7 @@ class Watch implements BidiStreamObserver<ListenRequest, ListenResponse> {
private final ExponentialRetryAlgorithm backoff;
private final Target target;
private TimedAttemptSettings nextAttempt;
private ClientStream<ListenRequest> stream;
private SuppressibleBidiStream<ListenRequest, ListenResponse> stream;

/** The sorted tree of DocumentSnapshots as sent in the last snapshot. */
private DocumentSet documentSet;
Expand Down Expand Up @@ -115,6 +116,8 @@ static class ChangeSet {
List<QueryDocumentSnapshot> updates = new ArrayList<>();
}

private static final Logger LOGGER = Logger.getLogger(Watch.class.getName());

/**
* @param firestore The Firestore Database client.
* @param query The query that is used to order the document snapshots returned by this watch.
Expand Down Expand Up @@ -246,7 +249,14 @@ && affectsTarget(change.getTargetIdsList(), WATCH_TARGET_ID)) {
changeMap.put(ResourcePath.create(listenResponse.getDocumentRemove().getDocument()), null);
break;
case FILTER:
if (listenResponse.getFilter().getCount() != currentSize()) {
int filterCount = listenResponse.getFilter().getCount();
int currentSize = currentSize();
if (filterCount != currentSize) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inlining currentSize() seems better IMO?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currentSize() hides a lot of computation:

  private int currentSize() {
    ChangeSet changeSet = extractChanges(Timestamp.now());
    return documentSet.size() + changeSet.adds.size() - changeSet.deletes.size();
  }
  private ChangeSet extractChanges(Timestamp readTime) {
    ChangeSet changeSet = new ChangeSet();

    for (Entry<ResourcePath, Document> change : changeMap.entrySet()) {
      if (change.getValue() == null) {
        if (documentSet.contains(change.getKey())) {
          changeSet.deletes.add(documentSet.getDocument(change.getKey()));
        }
        continue;
      }

      QueryDocumentSnapshot snapshot =
          QueryDocumentSnapshot.fromDocument(firestore, readTime, change.getValue());

      if (documentSet.contains(change.getKey())) {
        changeSet.updates.add(snapshot);
      } else {
        changeSet.adds.add(snapshot);
      }
    }

    return changeSet;
  }

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I meant if (filterCount != currentSize())

LOGGER.info(
() ->
String.format(
"filter: count mismatch filter count %d != current size %d",
filterCount, currentSize));
// We need to remove all the current results.
resetDocs();
// The filter didn't match, so re-issue the query.
Expand Down Expand Up @@ -297,7 +307,7 @@ ListenerRegistration runWatch(
.execute(
() -> {
synchronized (Watch.this) {
stream.closeSend();
stream.close();
stream = null;
}
});
Expand All @@ -318,7 +328,7 @@ private void resetDocs() {
resumeToken = null;

for (DocumentSnapshot snapshot : documentSet) {
// Mark each document as deleted. If documents are not deleted, they will be send again by
// Mark each document as deleted. If documents are not deleted, they will be sent again by
// the server.
changeMap.put(snapshot.getReference().getResourcePath(), null);
}
Expand All @@ -329,7 +339,7 @@ private void resetDocs() {
/** Closes the stream and calls onError() if the stream is still active. */
private void closeStream(final Throwable throwable) {
if (stream != null) {
stream.closeSend();
stream.closeAndSilence();
stream = null;
}

Expand Down Expand Up @@ -371,7 +381,7 @@ private void maybeReopenStream(Throwable throwable) {
/** Helper to restart the outgoing stream to the backend. */
private void resetStream() {
if (stream != null) {
stream.closeSend();
stream.closeAndSilence();
stream = null;
}

Expand All @@ -398,7 +408,12 @@ private void initStream() {
nextAttempt = backoff.createNextAttempt(nextAttempt);

Tracing.getTracer().getCurrentSpan().addAnnotation(TraceUtil.SPAN_NAME_LISTEN);
stream = firestore.streamRequest(Watch.this, firestore.getClient().listenCallable());
stream =
new SuppressibleBidiStream<>(
Watch.this,
observer ->
firestore.streamRequest(
observer, firestore.getClient().listenCallable()));

ListenRequest.Builder request = ListenRequest.newBuilder();
request.setDatabase(firestore.getDatabaseName());
Expand Down Expand Up @@ -459,6 +474,7 @@ private void pushSnapshot(final Timestamp readTime, ByteString nextResumeToken)
if (!hasPushed || !changes.isEmpty()) {
final QuerySnapshot querySnapshot =
QuerySnapshot.withChanges(query, readTime, documentSet, changes);
LOGGER.info(querySnapshot.toString());
userCallbackExecutor.execute(() -> listener.onEvent(querySnapshot, null));
hasPushed = true;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.firestore;

import com.google.api.gax.rpc.BidiStreamObserver;
import com.google.api.gax.rpc.BidiStreamingCallable;
import com.google.api.gax.rpc.ClientStream;

public final class FirestoreSpy {

public final FirestoreImpl spy;
public BidiStreamObserver streamRequestBidiStreamObserver;

public FirestoreSpy(FirestoreOptions firestoreOptions) {
spy =
new FirestoreImpl(firestoreOptions) {
@Override
public <RequestT, ResponseT> ClientStream<RequestT> streamRequest(
BidiStreamObserver<RequestT, ResponseT> responseObserverT,
BidiStreamingCallable<RequestT, ResponseT> callable) {
streamRequestBidiStreamObserver = responseObserverT;
return super.streamRequest(responseObserverT, callable);
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@
import javax.annotation.Nullable;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
Expand All @@ -84,6 +86,9 @@

@RunWith(MockitoJUnitRunner.class)
public class WatchTest {

@Rule public Timeout timeout = new Timeout(1, TimeUnit.SECONDS);

/** The Target ID used by the Java Firestore SDK. */
private static final int TARGET_ID = 0x1;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.google.cloud.firestore.Firestore;
import com.google.cloud.firestore.FirestoreOptions;
import com.google.cloud.firestore.FirestoreSpy;
import com.google.common.base.Preconditions;
import java.util.logging.Level;
import java.util.logging.Logger;
Expand All @@ -30,6 +31,8 @@
public abstract class ITBaseTest {
private static final Logger logger = Logger.getLogger(ITBaseTest.class.getName());
protected Firestore firestore;
protected FirestoreSpy firestoreSpy;
private FirestoreOptions firestoreOptions;

@Before
public void before() {
Expand All @@ -43,7 +46,8 @@ public void before() {
logger.log(Level.INFO, "Integration test using default database.");
}

firestore = optionsBuilder.build().getService();
firestoreOptions = optionsBuilder.build();
firestore = firestoreOptions.getService();
}

@After
Expand All @@ -53,5 +57,15 @@ public void after() throws Exception {
"Error instantiating Firestore. Check that the service account credentials were properly set.");
firestore.close();
firestore = null;
firestoreOptions = null;
firestoreSpy = null;
}

public FirestoreSpy useFirestoreSpy() {
if (firestoreSpy == null) {
firestoreSpy = new FirestoreSpy(firestoreOptions);
firestore = firestoreSpy.spy;
}
return firestoreSpy;
}
}