Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix type hints in typing edu unit tests #14886

Merged
merged 11 commits into from
Jan 26, 2023
14 changes: 12 additions & 2 deletions synapse/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,17 @@
import abc
import functools
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
TypeVar,
Union,
cast,
)

from twisted.internet.interfaces import IOpenSSLContextFactory
from twisted.internet.tcp import Port
Expand Down Expand Up @@ -479,7 +489,7 @@ def get_presence_router(self) -> PresenceRouter:
return PresenceRouter(self)

@cache_in_self
def get_typing_handler(self) -> FollowerTypingHandler:
def get_typing_handler(self) -> Union[TypingWriterHandler, FollowerTypingHandler]:
anoadragon453 marked this conversation as resolved.
Show resolved Hide resolved
if self.get_instance_name() in self.config.worker.writers.typing:
# Use get_typing_writer_handler to ensure that we use the same
# cached version.
Expand Down
32 changes: 22 additions & 10 deletions tests/handlers/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


import json
from typing import Dict
from typing import Dict, List, Set, cast
clokep marked this conversation as resolved.
Show resolved Hide resolved
from unittest.mock import ANY, Mock, call

from twisted.internet import defer
Expand All @@ -24,11 +24,13 @@
from synapse.api.constants import EduTypes
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
from synapse.handlers.typing import TypingWriterHandler
from synapse.server import HomeServer
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock

from tests import unittest
from tests.server import ThreadedMemoryReactorClock
from tests.test_utils import make_awaitable
from tests.unittest import override_config

Expand Down Expand Up @@ -62,7 +64,11 @@ def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes:


class TypingNotificationsTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
def make_homeserver(
self,
reactor: ThreadedMemoryReactorClock,
clokep marked this conversation as resolved.
Show resolved Hide resolved
clock: Clock,
) -> HomeServer:
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
Expand Down Expand Up @@ -93,7 +99,13 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event

self.handler = hs.get_typing_handler()
self.handler = cast(TypingWriterHandler, hs.get_typing_handler())

# hs.get_typing_handler will return a TypingWriterHandler when calling it
# from the main process, and a FollowerTypingHandler on workers.
# We rely on methods only available on the former, so assert we have the
# correct type here.
self.assertIsInstance(self.handler, TypingWriterHandler)
anoadragon453 marked this conversation as resolved.
Show resolved Hide resolved

self.event_source = hs.get_event_sources().sources.typing

Expand Down Expand Up @@ -186,7 +198,7 @@ def test_started_typing_local(self) -> None:
self.assertEqual(self.event_source.get_current_key(), 1)
events = self.get_success(
self.event_source.get_new_events(
user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False
)
)
self.assertEqual(
Expand Down Expand Up @@ -257,7 +269,7 @@ def test_started_typing_remote_recv(self) -> None:
self.assertEqual(self.event_source.get_current_key(), 1)
events = self.get_success(
self.event_source.get_new_events(
user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False
)
)
self.assertEqual(
Expand Down Expand Up @@ -298,7 +310,7 @@ def test_started_typing_remote_recv_not_in_room(self) -> None:
self.event_source.get_new_events(
user=U_APPLE,
from_key=0,
limit=None,
limit=0,
room_ids=[OTHER_ROOM_ID],
is_guest=False,
)
Expand Down Expand Up @@ -351,7 +363,7 @@ def test_stopped_typing(self) -> None:
self.assertEqual(self.event_source.get_current_key(), 1)
events = self.get_success(
self.event_source.get_new_events(
user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False
)
)
self.assertEqual(
Expand Down Expand Up @@ -387,7 +399,7 @@ def test_typing_timeout(self) -> None:
self.event_source.get_new_events(
user=U_APPLE,
from_key=0,
limit=None,
limit=0,
room_ids=[ROOM_ID],
is_guest=False,
)
Expand All @@ -412,7 +424,7 @@ def test_typing_timeout(self) -> None:
self.event_source.get_new_events(
user=U_APPLE,
from_key=1,
limit=None,
limit=0,
room_ids=[ROOM_ID],
is_guest=False,
)
Expand Down Expand Up @@ -447,7 +459,7 @@ def test_typing_timeout(self) -> None:
self.event_source.get_new_events(
user=U_APPLE,
from_key=0,
limit=None,
limit=0,
room_ids=[ROOM_ID],
is_guest=False,
)
Expand Down
5 changes: 4 additions & 1 deletion tests/storage/test_user_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from synapse.storage.roommember import ProfileInfo
from synapse.util import Clock

from tests.server import ThreadedMemoryReactorClock
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config

Expand Down Expand Up @@ -138,7 +139,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
register.register_servlets,
]

def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
def make_homeserver(
self, reactor: ThreadedMemoryReactorClock, clock: Clock
) -> HomeServer:
self.appservice = ApplicationService(
token="i_am_an_app_service",
id="1234",
Expand Down
3 changes: 2 additions & 1 deletion tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from tests.server import (
CustomHeaderType,
FakeChannel,
ThreadedMemoryReactorClock,
get_clock,
make_request,
setup_test_homeserver,
Expand Down Expand Up @@ -360,7 +361,7 @@ def wait_for_background_updates(self) -> None:
store.db_pool.updates.do_next_background_update(False), by=0.1
)

def make_homeserver(self, reactor: MemoryReactor, clock: Clock):
def make_homeserver(self, reactor: ThreadedMemoryReactorClock, clock: Clock):
"""
Make and return a homeserver.

Expand Down