Skip to content

Commit

Permalink
fix: stop / start stream after filter mismatch (#502)
Browse files Browse the repository at this point in the history
~~Based on branch for PR #500 -- I will rebase after that PR merges.~~

Closes #367.
    
Supersedes PR #497.
  • Loading branch information
tseaver committed Jan 11, 2022
1 parent 74d8171 commit a256752
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 65 deletions.
8 changes: 2 additions & 6 deletions google/cloud/firestore_v1/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,5 @@ def on_snapshot(collection_snapshot, changes, read_time):
# Terminate this watch
collection_watch.unsubscribe()
"""
return Watch.for_query(
self._query(),
callback,
document.DocumentSnapshot,
document.DocumentReference,
)
query = self._query()
return Watch.for_query(query, callback, document.DocumentSnapshot)
2 changes: 1 addition & 1 deletion google/cloud/firestore_v1/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,4 +489,4 @@ def on_snapshot(document_snapshot, changes, read_time):
# Terminate this watch
doc_watch.unsubscribe()
"""
return Watch.for_document(self, callback, DocumentSnapshot, DocumentReference)
return Watch.for_document(self, callback, DocumentSnapshot)
4 changes: 1 addition & 3 deletions google/cloud/firestore_v1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,7 @@ def on_snapshot(docs, changes, read_time):
# Terminate this watch
query_watch.unsubscribe()
"""
return Watch.for_query(
self, callback, document.DocumentSnapshot, document.DocumentReference
)
return Watch.for_query(self, callback, document.DocumentSnapshot)

@staticmethod
def _get_collection_reference_class() -> Type[
Expand Down
67 changes: 34 additions & 33 deletions google/cloud/firestore_v1/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ def __init__(
comparator,
snapshot_callback,
document_snapshot_cls,
document_reference_cls,
):
"""
Args:
Expand All @@ -192,35 +191,21 @@ def __init__(
read_time (string): The ISO 8601 time at which this
snapshot was obtained.
document_snapshot_cls: instance of DocumentSnapshot
document_reference_cls: instance of DocumentReference
document_snapshot_cls: factory for instances of DocumentSnapshot
"""
self._document_reference = document_reference
self._firestore = firestore
self._api = firestore._firestore_api
self._targets = target
self._comparator = comparator
self.DocumentSnapshot = document_snapshot_cls
self.DocumentReference = document_reference_cls
self._document_snapshot_cls = document_snapshot_cls
self._snapshot_callback = snapshot_callback
self._api = firestore._firestore_api
self._closing = threading.Lock()
self._closed = False
self._set_documents_pfx(firestore._database_string)

self.resume_token = None

rpc_request = self._get_rpc_request

self._rpc = ResumableBidiRpc(
start_rpc=self._api._transport.listen,
should_recover=_should_recover,
should_terminate=_should_terminate,
initial_request=rpc_request,
metadata=self._firestore._rpc_metadata,
)

self._rpc.add_done_callback(self._on_rpc_done)

# Initialize state for on_snapshot
# The sorted tree of QueryDocumentSnapshots as sent in the last
# snapshot. We only look at the keys.
Expand All @@ -242,17 +227,29 @@ def __init__(
# aren't docs.
self.has_pushed = False

self._init_stream()

def _init_stream(self):

rpc_request = self._get_rpc_request

self._rpc = ResumableBidiRpc(
start_rpc=self._api._transport.listen,
should_recover=_should_recover,
should_terminate=_should_terminate,
initial_request=rpc_request,
metadata=self._firestore._rpc_metadata,
)

self._rpc.add_done_callback(self._on_rpc_done)

# The server assigns and updates the resume token.
self._consumer = BackgroundConsumer(self._rpc, self.on_snapshot)
self._consumer.start()

@classmethod
def for_document(
cls,
document_ref,
snapshot_callback,
document_snapshot_cls,
document_reference_cls,
cls, document_ref, snapshot_callback, document_snapshot_cls,
):
"""
Creates a watch snapshot listener for a document. snapshot_callback
Expand All @@ -276,13 +273,10 @@ def for_document(
document_watch_comparator,
snapshot_callback,
document_snapshot_cls,
document_reference_cls,
)

@classmethod
def for_query(
cls, query, snapshot_callback, document_snapshot_cls, document_reference_cls,
):
def for_query(cls, query, snapshot_callback, document_snapshot_cls):
parent_path, _ = query._parent._parent_info()
query_target = Target.QueryTarget(
parent=parent_path, structured_query=query._to_protobuf()
Expand All @@ -295,12 +289,13 @@ def for_query(
query._comparator,
snapshot_callback,
document_snapshot_cls,
document_reference_cls,
)

def _get_rpc_request(self):
if self.resume_token is not None:
self._targets["resume_token"] = self.resume_token
else:
self._targets.pop("resume_token", None)

return ListenRequest(
database=self._firestore._database_string, add_target=self._targets
Expand Down Expand Up @@ -490,7 +485,7 @@ def on_snapshot(self, proto):
document_name = self._strip_document_pfx(document.name)
document_ref = self._firestore.document(document_name)

snapshot = self.DocumentSnapshot(
snapshot = self._document_snapshot_cls(
reference=document_ref,
data=data,
exists=True,
Expand Down Expand Up @@ -520,11 +515,17 @@ def on_snapshot(self, proto):
elif which == "filter":
_LOGGER.debug("on_snapshot: filter update")
if pb.filter.count != self._current_size():
# We need to remove all the current results.
# First, shut down current stream
_LOGGER.info("Filter mismatch -- restarting stream.")
thread = threading.Thread(
name=_RPC_ERROR_THREAD_NAME, target=self.close,
)
thread.start()
thread.join() # wait for shutdown to complete
# Then, remove all the current results.
self._reset_docs()
# The filter didn't match, so re-issue the query.
# TODO: reset stream method?
# self._reset_stream();
# Finally, restart stream.
self._init_stream()

else:
_LOGGER.debug("UNKNOWN TYPE. UHOH")
Expand Down
11 changes: 4 additions & 7 deletions tests/unit/v1/test_cross_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ def test_listen_testprotos(test_proto): # pragma: NO COVER
# 'docs' (list of 'google.firestore_v1.Document'),
# 'changes' (list lof local 'DocChange', and 'read_time' timestamp.
from google.cloud.firestore_v1 import Client
from google.cloud.firestore_v1 import DocumentReference
from google.cloud.firestore_v1 import DocumentSnapshot
from google.cloud.firestore_v1 import Watch
import google.auth.credentials
Expand All @@ -226,6 +225,9 @@ def test_listen_testprotos(test_proto): # pragma: NO COVER

credentials = mock.Mock(spec=google.auth.credentials.Credentials)
client = Client(project="project", credentials=credentials)
# conformance data has db string as this
db_str = "projects/projectID/databases/(default)"
client._database_string_internal = db_str
with mock.patch("google.cloud.firestore_v1.watch.ResumableBidiRpc"):
with mock.patch("google.cloud.firestore_v1.watch.BackgroundConsumer"):
# conformance data sets WATCH_TARGET_ID to 1
Expand All @@ -237,12 +239,7 @@ def callback(keys, applied_changes, read_time):

collection = DummyCollection(client=client)
query = DummyQuery(parent=collection)
watch = Watch.for_query(
query, callback, DocumentSnapshot, DocumentReference
)
# conformance data has db string as this
db_str = "projects/projectID/databases/(default)"
watch._firestore._database_string_internal = db_str
watch = Watch.for_query(query, callback, DocumentSnapshot)

wrapped_responses = [
firestore.ListenResponse.wrap(proto) for proto in testcase.responses
Expand Down
18 changes: 3 additions & 15 deletions tests/unit/v1/test_watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ def snapshot_callback(*args):
comparator=comparator,
snapshot_callback=snapshot_callback,
document_snapshot_cls=DummyDocumentSnapshot,
document_reference_cls=DummyDocumentReference,
)


Expand Down Expand Up @@ -224,16 +223,11 @@ def snapshot_callback(*args): # pragma: NO COVER
snapshots.append(args)

docref = DummyDocumentReference()
snapshot_class_instance = DummyDocumentSnapshot
document_reference_class_instance = DummyDocumentReference

with mock.patch("google.cloud.firestore_v1.watch.ResumableBidiRpc"):
with mock.patch("google.cloud.firestore_v1.watch.BackgroundConsumer"):
inst = Watch.for_document(
docref,
snapshot_callback,
snapshot_class_instance,
document_reference_class_instance,
docref, snapshot_callback, document_snapshot_cls=DummyDocumentSnapshot,
)

inst._consumer.start.assert_called_once_with()
Expand All @@ -246,8 +240,6 @@ def test_watch_for_query(snapshots):
def snapshot_callback(*args): # pragma: NO COVER
snapshots.append(args)

snapshot_class_instance = DummyDocumentSnapshot
document_reference_class_instance = DummyDocumentReference
client = DummyFirestore()
parent = DummyCollection(client)
query = DummyQuery(parent=parent)
Expand All @@ -258,8 +250,7 @@ def snapshot_callback(*args): # pragma: NO COVER
inst = Watch.for_query(
query,
snapshot_callback,
snapshot_class_instance,
document_reference_class_instance,
document_snapshot_cls=DummyDocumentSnapshot,
)

inst._consumer.start.assert_called_once_with()
Expand All @@ -278,8 +269,6 @@ def test_watch_for_query_nested(snapshots):
def snapshot_callback(*args): # pragma: NO COVER
snapshots.append(args)

snapshot_class_instance = DummyDocumentSnapshot
document_reference_class_instance = DummyDocumentReference
client = DummyFirestore()
root = DummyCollection(client)
grandparent = DummyDocument("document", parent=root)
Expand All @@ -292,8 +281,7 @@ def snapshot_callback(*args): # pragma: NO COVER
inst = Watch.for_query(
query,
snapshot_callback,
snapshot_class_instance,
document_reference_class_instance,
document_snapshot_cls=DummyDocumentSnapshot,
)

inst._consumer.start.assert_called_once_with()
Expand Down

0 comments on commit a256752

Please sign in to comment.