Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add missing type hints to test.util.caches #14529

Merged
merged 7 commits into from Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14529.misc
@@ -0,0 +1 @@
Add missing type hints.
11 changes: 6 additions & 5 deletions mypy.ini
Expand Up @@ -59,11 +59,6 @@ exclude = (?x)
|tests/server_notices/test_resource_limits_server_notices.py
|tests/test_state.py
|tests/test_terms_auth.py
|tests/util/caches/test_cached_call.py
|tests/util/caches/test_deferred_cache.py
|tests/util/caches/test_descriptors.py
|tests/util/caches/test_response_cache.py
|tests/util/caches/test_ttlcache.py
|tests/util/test_async_helpers.py
|tests/util/test_batching_queue.py
|tests/util/test_dict_cache.py
Expand Down Expand Up @@ -133,6 +128,12 @@ disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client]
disallow_untyped_defs = True

[mypy-tests.util.caches.*]
disallow_untyped_defs = True

[mypy-tests.util.caches.test_descriptors]
disallow_untyped_defs = False

[mypy-tests.utils]
disallow_untyped_defs = True

Expand Down
23 changes: 12 additions & 11 deletions tests/util/caches/test_cached_call.py
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import NoReturn
from unittest.mock import Mock

from twisted.internet import defer
Expand All @@ -23,14 +24,14 @@


class CachedCallTestCase(TestCase):
def test_get(self):
def test_get(self) -> None:
"""
Happy-path test case: makes a couple of calls and makes sure they behave
correctly
"""
d = Deferred()
d: "Deferred[int]" = Deferred()

async def f():
async def f() -> int:
return await d

slow_call = Mock(side_effect=f)
Expand All @@ -43,7 +44,7 @@ async def f():
# now fire off a couple of calls
completed_results = []

async def r():
async def r() -> None:
res = await cached_call.get()
completed_results.append(res)

Expand All @@ -69,12 +70,12 @@ async def r():
self.assertEqual(r3, 123)
slow_call.assert_not_called()

def test_fast_call(self):
def test_fast_call(self) -> None:
"""
Test the behaviour when the underlying function completes immediately
"""

async def f():
async def f() -> int:
return 12

fast_call = Mock(side_effect=f)
Expand All @@ -92,12 +93,12 @@ async def f():


class RetryOnExceptionCachedCallTestCase(TestCase):
def test_get(self):
def test_get(self) -> None:
# set up the RetryOnExceptionCachedCall around a function which will fail
# (after a while)
d = Deferred()
d: "Deferred[int]" = Deferred()

async def f1():
async def f1() -> NoReturn:
await d
raise ValueError("moo")

Expand All @@ -110,7 +111,7 @@ async def f1():
# now fire off a couple of calls
completed_results = []

async def r():
async def r() -> None:
try:
await cached_call.get()
except Exception as e1:
Expand All @@ -137,7 +138,7 @@ async def r():
# to the getter
d = Deferred()

async def f2():
async def f2() -> int:
return await d

slow_call.reset_mock()
Expand Down
61 changes: 31 additions & 30 deletions tests/util/caches/test_deferred_cache.py
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from functools import partial
from typing import List, Tuple

from twisted.internet import defer

Expand All @@ -22,28 +23,28 @@


class DeferredCacheTestCase(TestCase):
def test_empty(self):
cache = DeferredCache("test")
def test_empty(self) -> None:
cache: DeferredCache[str, int] = DeferredCache("test")
with self.assertRaises(KeyError):
cache.get("foo")

def test_hit(self):
cache = DeferredCache("test")
def test_hit(self) -> None:
cache: DeferredCache[str, int] = DeferredCache("test")
cache.prefill("foo", 123)

self.assertEqual(self.successResultOf(cache.get("foo")), 123)

def test_hit_deferred(self):
cache = DeferredCache("test")
origin_d = defer.Deferred()
def test_hit_deferred(self) -> None:
cache: DeferredCache[str, int] = DeferredCache("test")
origin_d: "defer.Deferred[int]" = defer.Deferred()
set_d = cache.set("k1", origin_d)

# get should return an incomplete deferred
get_d = cache.get("k1")
self.assertFalse(get_d.called)

# add a callback that will make sure that the set_d gets called before the get_d
def check1(r):
def check1(r: str) -> str:
self.assertTrue(set_d.called)
return r

Expand All @@ -55,16 +56,16 @@ def check1(r):
self.assertEqual(self.successResultOf(set_d), 99)
self.assertEqual(self.successResultOf(get_d), 99)

def test_callbacks(self):
def test_callbacks(self) -> None:
"""Invalidation callbacks are called at the right time"""
cache = DeferredCache("test")
cache: DeferredCache[str, int] = DeferredCache("test")
callbacks = set()

# start with an entry, with a callback
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))

# now replace that entry with a pending result
origin_d = defer.Deferred()
origin_d: "defer.Deferred[int]" = defer.Deferred()
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))

# ... and also make a get request
Expand All @@ -89,15 +90,15 @@ def test_callbacks(self):
cache.prefill("k1", 30)
self.assertEqual(callbacks, {"set", "get"})

def test_set_fail(self):
cache = DeferredCache("test")
def test_set_fail(self) -> None:
cache: DeferredCache[str, int] = DeferredCache("test")
callbacks = set()

# start with an entry, with a callback
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))

# now replace that entry with a pending result
origin_d = defer.Deferred()
origin_d: defer.Deferred = defer.Deferred()
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))

# ... and also make a get request
Expand Down Expand Up @@ -126,9 +127,9 @@ def test_set_fail(self):
cache.prefill("k1", 30)
self.assertEqual(callbacks, {"prefill", "get2"})

def test_get_immediate(self):
cache = DeferredCache("test")
d1 = defer.Deferred()
def test_get_immediate(self) -> None:
cache: DeferredCache[str, int] = DeferredCache("test")
d1: "defer.Deferred[int]" = defer.Deferred()
cache.set("key1", d1)

# get_immediate should return default
Expand All @@ -142,27 +143,27 @@ def test_get_immediate(self):
v = cache.get_immediate("key1", 1)
self.assertEqual(v, 2)

def test_invalidate(self):
cache = DeferredCache("test")
def test_invalidate(self) -> None:
cache: DeferredCache[Tuple[str], int] = DeferredCache("test")
cache.prefill(("foo",), 123)
cache.invalidate(("foo",))

with self.assertRaises(KeyError):
cache.get(("foo",))

def test_invalidate_all(self):
cache = DeferredCache("testcache")
def test_invalidate_all(self) -> None:
cache: DeferredCache[str, str] = DeferredCache("testcache")

callback_record = [False, False]

def record_callback(idx):
def record_callback(idx: int) -> None:
callback_record[idx] = True

# add a couple of pending entries
d1 = defer.Deferred()
d1: "defer.Deferred[str]" = defer.Deferred()
cache.set("key1", d1, partial(record_callback, 0))

d2 = defer.Deferred()
d2: "defer.Deferred[str]" = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))

# lookup should return pending deferreds
Expand Down Expand Up @@ -193,8 +194,8 @@ def record_callback(idx):
with self.assertRaises(KeyError):
cache.get("key1", None)

def test_eviction(self):
cache = DeferredCache(
def test_eviction(self) -> None:
cache: DeferredCache[int, str] = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False
)

Expand All @@ -208,8 +209,8 @@ def test_eviction(self):
cache.get(2)
cache.get(3)

def test_eviction_lru(self):
cache = DeferredCache(
def test_eviction_lru(self) -> None:
cache: DeferredCache[int, str] = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False
)

Expand All @@ -227,8 +228,8 @@ def test_eviction_lru(self):
cache.get(1)
cache.get(3)

def test_eviction_iterable(self):
cache = DeferredCache(
def test_eviction_iterable(self) -> None:
cache: DeferredCache[int, List[str]] = DeferredCache(
"test",
max_entries=3,
apply_cache_factor_from_config=False,
Expand Down
22 changes: 14 additions & 8 deletions tests/util/caches/test_descriptors.py
Expand Up @@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Iterable, Set, Tuple
from typing import Iterable, Set, Tuple, cast
from unittest import mock

from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError, Deferred
from twisted.internet.interfaces import IReactorTime

from synapse.api.errors import SynapseError
from synapse.logging.context import (
Expand All @@ -37,8 +38,8 @@


def run_on_reactor():
d = defer.Deferred()
reactor.callLater(0, d.callback, 0)
d: "Deferred[int]" = defer.Deferred()
cast(IReactorTime, reactor).callLater(0, d.callback, 0)
return make_deferred_yieldable(d)


Expand Down Expand Up @@ -224,7 +225,8 @@ def fn(self, arg1):
callbacks: Set[str] = set()

# set off an asynchronous request
obj.result = origin_d = defer.Deferred()
origin_d: Deferred = defer.Deferred()
obj.result = origin_d

d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
self.assertFalse(d1.called)
Expand Down Expand Up @@ -262,7 +264,7 @@ def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when
using the cache."""

complete_lookup = defer.Deferred()
complete_lookup: Deferred = defer.Deferred()

class Cls:
@descriptors.cached()
Expand Down Expand Up @@ -772,10 +774,14 @@ def fn(self, arg1, arg2):

@descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1, arg2):
assert current_context().name == "c1"
context = current_context()
assert isinstance(context, LoggingContext)
assert context.name == "c1"
# we want this to behave like an asynchronous function
await run_on_reactor()
assert current_context().name == "c1"
context = current_context()
assert isinstance(context, LoggingContext)
assert context.name == "c1"
return self.mock(args1, arg2)

with LoggingContext("c1") as c1:
Expand Down Expand Up @@ -834,7 +840,7 @@ def list_fn(self, args1) -> "Deferred[dict]":
return self.mock(args1)

obj = Cls()
deferred_result = Deferred()
deferred_result: "Deferred[dict]" = Deferred()
obj.mock.return_value = deferred_result

# start off several concurrent lookups of the same key
Expand Down