Skip to content
This repository was archived by the owner on May 6, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion google/cloud/ndb/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,37 @@ def make_call(self):

def future_info(self, key):
"""Generate info string for Future."""
return "GlobalWatch.delete({})".format(key)
return "GlobalCache.watch({})".format(key)


def global_unwatch(key):
"""End optimistic transaction with global cache.

Indicates that value for the key wasn't found in the database, so there will not be
a future call to :func:`global_compare_and_swap`, and we no longer need to watch
this key.

Args:
key (bytes): The key to unwatch.

Returns:
tasklets.Future: Eventual result will be ``None``.
"""
batch = _batch.get_batch(_GlobalCacheUnwatchBatch)
return batch.add(key)


class _GlobalCacheUnwatchBatch(_GlobalCacheWatchBatch):
"""Batch for global cache unwatch requests. """

def make_call(self):
"""Call :method:`GlobalCache.unwatch`."""
cache = context_module.get_context().global_cache
return cache.unwatch(self.keys)

def future_info(self, key):
"""Generate info string for Future."""
return "GlobalCache.unwatch({})".format(key)


def global_compare_and_swap(key, value, expires=None):
Expand Down
13 changes: 9 additions & 4 deletions google/cloud/ndb/_datastore_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,15 @@ def lookup(key, options):
entity_pb = yield batch.add(key)

# Do not cache misses
if use_global_cache and not key_locked and entity_pb is not _NOT_FOUND:
expires = context._global_cache_timeout(key, options)
serialized = entity_pb.SerializeToString()
yield _cache.global_compare_and_swap(cache_key, serialized, expires=expires)
if use_global_cache and not key_locked:
if entity_pb is not _NOT_FOUND:
expires = context._global_cache_timeout(key, options)
serialized = entity_pb.SerializeToString()
yield _cache.global_compare_and_swap(
cache_key, serialized, expires=expires
)
else:
yield _cache.global_unwatch(cache_key)

raise tasklets.Return(entity_pb)

Expand Down
32 changes: 32 additions & 0 deletions google/cloud/ndb/global_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,19 @@ def watch(self, keys):
"""
raise NotImplementedError

@abc.abstractmethod
def unwatch(self, keys):
"""End an optimistic transaction for the given keys.

Indicates that value for the key wasn't found in the database, so there will not
be a future call to :meth:`compare_and_swap`, and we no longer need to watch
this key.

Arguments:
keys (List[bytes]): The keys to watch.
"""
raise NotImplementedError

@abc.abstractmethod
def compare_and_swap(self, items, expires=None):
"""Like :meth:`set` but using an optimistic transaction.
Expand Down Expand Up @@ -160,6 +173,11 @@ def watch(self, keys):
for key in keys:
self._watch_keys[key] = self.cache.get(key)

def unwatch(self, keys):
"""Implements :meth:`GlobalCache.unwatch`."""
for key in keys:
self._watch_keys.pop(key, None)

def compare_and_swap(self, items, expires=None):
"""Implements :meth:`GlobalCache.compare_and_swap`."""
if expires:
Expand Down Expand Up @@ -239,6 +257,13 @@ def watch(self, keys):
for key in keys:
self.pipes[key] = holder

def unwatch(self, keys):
"""Implements :meth:`GlobalCache.watch`."""
for key in keys:
holder = self.pipes.pop(key, None)
if holder:
holder.pipe.reset()

def compare_and_swap(self, items, expires=None):
"""Implements :meth:`GlobalCache.compare_and_swap`."""
pipes = {}
Expand Down Expand Up @@ -391,6 +416,13 @@ def watch(self, keys):
for key, (value, caskey) in self.client.gets_many(keys).items():
caskeys[key] = caskey

def unwatch(self, keys):
"""Implements :meth:`GlobalCache.unwatch`."""
keys = [self._key(key) for key in keys]
caskeys = self.caskeys
for key in keys:
caskeys.pop(key, None)

def compare_and_swap(self, items, expires=None):
"""Implements :meth:`GlobalCache.compare_and_swap`."""
caskeys = self.caskeys
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/test__cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,31 @@ def test_add_and_idle_and_done_callbacks(in_context):
assert future2.result() is None


@mock.patch("google.cloud.ndb._cache._batch")
def test_global_unwatch(_batch):
batch = _batch.get_batch.return_value
assert _cache.global_unwatch(b"key") is batch.add.return_value
_batch.get_batch.assert_called_once_with(_cache._GlobalCacheUnwatchBatch)
batch.add.assert_called_once_with(b"key")


class Test_GlobalCacheUnwatchBatch:
@staticmethod
def test_add_and_idle_and_done_callbacks(in_context):
cache = mock.Mock()

batch = _cache._GlobalCacheUnwatchBatch({})
future1 = batch.add(b"foo")
future2 = batch.add(b"bar")

with in_context.new(global_cache=cache).use():
batch.idle_callback()

cache.unwatch.assert_called_once_with([b"foo", b"bar"])
assert future1.result() is None
assert future2.result() is None


class Test_global_compare_and_swap:
@staticmethod
@mock.patch("google.cloud.ndb._cache._batch")
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test__datastore_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ class SomeKind(model.Model):
assert future.result() is _api._NOT_FOUND

assert global_cache.get([cache_key]) == [_cache._LOCKED]
assert len(global_cache._watch_keys) == 0


class Test_LookupBatch:
Expand Down
46 changes: 46 additions & 0 deletions tests/unit/test_global_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def delete(self, keys):
def watch(self, keys):
return super(MockImpl, self).watch(keys)

def unwatch(self, keys):
return super(MockImpl, self).unwatch(keys)

def compare_and_swap(self, items, expires=None):
return super(MockImpl, self).compare_and_swap(items, expires=expires)

Expand All @@ -63,6 +66,11 @@ def test_watch(self):
with pytest.raises(NotImplementedError):
cache.watch(b"foo")

def test_unwatch(self):
cache = self.make_one()
with pytest.raises(NotImplementedError):
cache.unwatch(b"foo")

def test_compare_and_swap(self):
cache = self.make_one()
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -147,6 +155,16 @@ def test_watch_compare_and_swap_with_expires(time):
result = cache.get([b"one", b"two", b"three"])
assert result == [None, b"hamburgers", None]

@staticmethod
def test_watch_unwatch():
cache = global_cache._InProcessGlobalCache()
result = cache.watch([b"one", b"two", b"three"])
assert result is None

result = cache.unwatch([b"one", b"two", b"three"])
assert result is None
assert cache._watch_keys == {}


class TestRedisCache:
@staticmethod
Expand Down Expand Up @@ -225,6 +243,23 @@ def test_watch(uuid):
"bar": global_cache._Pipeline(pipe, "abc123"),
}

@staticmethod
def test_unwatch():
redis = mock.Mock(spec=())
cache = global_cache.RedisCache(redis)
pipe1 = mock.Mock(spec=("reset",))
pipe2 = mock.Mock(spec=("reset",))
cache._pipes.pipes = {
"ay": global_cache._Pipeline(pipe1, "abc123"),
"be": global_cache._Pipeline(pipe1, "abc123"),
"see": global_cache._Pipeline(pipe2, "def456"),
"dee": global_cache._Pipeline(pipe2, "def456"),
"whatevs": global_cache._Pipeline(None, "himom!"),
}

cache.unwatch(["ay", "be", "see", "dee", "nuffin"])
assert cache.pipes == {"whatevs": global_cache._Pipeline(None, "himom!")}

@staticmethod
def test_compare_and_swap():
redis = mock.Mock(spec=())
Expand Down Expand Up @@ -450,6 +485,17 @@ def test_watch():
key2: b"1",
}

@staticmethod
def test_unwatch():
client = mock.Mock(spec=())
cache = global_cache.MemcacheCache(client)
key2 = cache._key(b"two")
cache.caskeys[key2] = b"5"
cache.caskeys["whatevs"] = b"6"
cache.unwatch([b"one", b"two"])

assert cache.caskeys == {"whatevs": b"6"}

@staticmethod
def test_compare_and_swap():
client = mock.Mock(spec=("cas",))
Expand Down