Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve static type checking #333

Merged
merged 5 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/333.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve static type checking.
34 changes: 34 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[mypy]
plugins = mypy_zope:plugin
check_untyped_defs = True
disallow_untyped_defs = True
show_error_codes = True
show_traceback = True
mypy_path = stubs
Expand Down Expand Up @@ -43,3 +44,36 @@ ignore_missing_imports = True

[mypy-pywebpush]
ignore_missing_imports = True

[mypy-sygnal.helper.*]
disallow_untyped_defs = False

[mypy-sygnal.notifications]
disallow_untyped_defs = False

[mypy-sygnal.http]
disallow_untyped_defs = False

[mypy-sygnal.sygnal]
disallow_untyped_defs = False

[mypy-tests.asyncio_test_helpers]
disallow_untyped_defs = False

[mypy-tests.test_http]
disallow_untyped_defs = False

[mypy-tests.test_httpproxy_asyncio]
disallow_untyped_defs = False

[mypy-tests.test_httpproxy_twisted]
disallow_untyped_defs = False

[mypy-tests.test_pushgateway_api_v1]
disallow_untyped_defs = False

[mypy-tests.testutils]
disallow_untyped_defs = False

[mypy-tests.twisted_test_helpers]
disallow_untyped_defs = False
4 changes: 2 additions & 2 deletions sygnal/apnstruncate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
]


def json_encode(payload) -> bytes:
def json_encode(payload: Dict[str, Any]) -> bytes:
return json.dumps(payload, ensure_ascii=False).encode()


Expand Down Expand Up @@ -115,7 +115,7 @@ def _choppables_for_aps(aps: Dict[str, Any]) -> List[Choppable]:
def _choppable_get(
aps: Dict[str, Any],
choppable: Choppable,
):
) -> str:
if choppable[0] == "alert":
return aps["alert"]
elif choppable[0] == "alert.body":
Expand Down
4 changes: 2 additions & 2 deletions sygnal/gcmpushkin.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ async def create(
return cls(name, sygnal, config)

async def _perform_http_request(
self, body: Dict, headers: Dict[AnyStr, List[AnyStr]]
self, body: Dict[str, Any], headers: Dict[AnyStr, List[AnyStr]]
) -> Tuple[IResponse, str]:
"""
Perform an HTTP request to the FCM server with the body and headers
Expand Down Expand Up @@ -208,7 +208,7 @@ async def _request_dispatch(
self,
n: Notification,
log: NotificationLoggerAdapter,
body: dict,
body: Dict[str, Any],
headers: Dict[AnyStr, List[AnyStr]],
pushkeys: List[str],
span: Span,
Expand Down
27 changes: 14 additions & 13 deletions tests/test_apns.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
from unittest.mock import MagicMock, patch

from aioapns.common import NotificationResult, PushType
Expand Down Expand Up @@ -56,7 +57,7 @@


class ApnsTestCase(testutils.TestCase):
def setUp(self):
def setUp(self) -> None:
self.apns_mock_class = patch("sygnal.apnspushkin.APNs").start()
self.apns_mock = MagicMock()
self.apns_mock_class.return_value = self.apns_mock
Expand All @@ -82,7 +83,7 @@ def get_test_pushkin(self, name: str) -> ApnsPushkin:
assert isinstance(test_pushkin, ApnsPushkin)
return test_pushkin

def config_setup(self, config):
def config_setup(self, config: Dict[str, Any]) -> None:
super().config_setup(config)
config["apps"][PUSHKIN_ID] = {"type": "apns", "certfile": TEST_CERTFILE_PATH}
config["apps"][PUSHKIN_ID_WITH_PUSH_TYPE] = {
Expand All @@ -91,7 +92,7 @@ def config_setup(self, config):
"push_type": "alert",
}

def test_payload_truncation(self):
def test_payload_truncation(self) -> None:
"""
Tests that APNS message bodies will be truncated to fit the limits of
APNS.
Expand All @@ -114,7 +115,7 @@ def test_payload_truncation(self):

self.assertLessEqual(len(apnstruncate.json_encode(payload)), 240)

def test_payload_truncation_test_validity(self):
def test_payload_truncation_test_validity(self) -> None:
"""
This tests that L{test_payload_truncation_success} is a valid test
by showing that not limiting the truncation size would result in a
Expand All @@ -138,7 +139,7 @@ def test_payload_truncation_test_validity(self):

self.assertGreater(len(apnstruncate.json_encode(payload)), 200)

def test_expected(self):
def test_expected(self) -> None:
"""
Tests the expected case: a good response from APNS means we pass on
a good response to the homeserver.
Expand Down Expand Up @@ -177,7 +178,7 @@ def test_expected(self):

self.assertEqual({"rejected": []}, resp)

def test_expected_event_id_only_with_default_payload(self):
def test_expected_event_id_only_with_default_payload(self) -> None:
"""
Tests the expected fallback case: a good response from APNS means we pass on
a good response to the homeserver.
Expand Down Expand Up @@ -214,7 +215,7 @@ def test_expected_event_id_only_with_default_payload(self):

self.assertEqual({"rejected": []}, resp)

def test_expected_badge_only_with_default_payload(self):
def test_expected_badge_only_with_default_payload(self) -> None:
"""
Tests the expected fallback case: a good response from APNS means we pass on
a good response to the homeserver.
Expand Down Expand Up @@ -243,7 +244,7 @@ def test_expected_badge_only_with_default_payload(self):

self.assertEqual({"rejected": []}, resp)

def test_expected_full_with_default_payload(self):
def test_expected_full_with_default_payload(self) -> None:
"""
Tests the expected fallback case: a good response from APNS means we pass on
a good response to the homeserver.
Expand Down Expand Up @@ -285,7 +286,7 @@ def test_expected_full_with_default_payload(self):

self.assertEqual({"rejected": []}, resp)

def test_misconfigured_payload_is_rejected(self):
def test_misconfigured_payload_is_rejected(self) -> None:
"""Test that a malformed default_payload causes pushkey to be rejected"""

resp = self._request(
Expand All @@ -294,7 +295,7 @@ def test_misconfigured_payload_is_rejected(self):

self.assertEqual({"rejected": ["badpayload"]}, resp)

def test_rejection(self):
def test_rejection(self) -> None:
"""
Tests the rejection case: a rejection response from APNS leads to us
passing on a rejection to the homeserver.
Expand All @@ -312,7 +313,7 @@ def test_rejection(self):
self.assertEqual(1, method.call_count)
self.assertEqual({"rejected": ["spqr"]}, resp)

def test_no_retry_on_4xx(self):
def test_no_retry_on_4xx(self) -> None:
"""
Test that we don't retry when we get a 4xx error but do not mark as
rejected.
Expand All @@ -330,7 +331,7 @@ def test_no_retry_on_4xx(self):
self.assertEqual(1, method.call_count)
self.assertEqual(502, resp)

def test_retry_on_5xx(self):
def test_retry_on_5xx(self) -> None:
"""
Test that we DO retry when we get a 5xx error and do not mark as
rejected.
Expand All @@ -348,7 +349,7 @@ def test_retry_on_5xx(self):
self.assertGreater(method.call_count, 1)
self.assertEqual(502, resp)

def test_expected_with_push_type(self):
def test_expected_with_push_type(self) -> None:
"""
Tests the expected case: a good response from APNS means we pass on
a good response to the homeserver.
Expand Down
23 changes: 12 additions & 11 deletions tests/test_apnstruncate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@

import string
import unittest
from typing import Any, Dict

from sygnal.apnstruncate import json_encode, truncate


def simplestring(length, offset=0):
def simplestring(length: int, offset: int = 0) -> str:
"""
Deterministically generates a string.
Args:
Expand All @@ -41,7 +42,7 @@ def simplestring(length, offset=0):
)


def sillystring(length, offset=0):
def sillystring(length: int, offset: int = 0) -> str:
"""
Deterministically generates a string
Args:
Expand All @@ -55,15 +56,15 @@ def sillystring(length, offset=0):
return "".join([chars[(i + offset) % len(chars)] for i in range(length)])


def payload_for_aps(aps):
def payload_for_aps(aps: Dict[str, Any]) -> Dict[str, Any]:
"""
Returns the APNS payload for an 'aps' dictionary.
"""
return {"aps": aps}


class TruncateTestCase(unittest.TestCase):
def test_dont_truncate(self):
def test_dont_truncate(self) -> None:
"""
Tests that truncation is not performed if unnecessary.
"""
Expand All @@ -72,7 +73,7 @@ def test_dont_truncate(self):
aps = {"alert": txt}
self.assertEqual(txt, truncate(payload_for_aps(aps), 256)["aps"]["alert"])

def test_truncate_alert(self):
def test_truncate_alert(self) -> None:
"""
Tests that the 'alert' string field will be truncated when needed.
"""
Expand All @@ -83,7 +84,7 @@ def test_truncate_alert(self):
txt[:5], truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]
)

def test_truncate_alert_body(self):
def test_truncate_alert_body(self) -> None:
"""
Tests that the 'alert' 'body' field will be truncated when needed.
"""
Expand All @@ -95,7 +96,7 @@ def test_truncate_alert_body(self):
truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]["body"],
)

def test_truncate_loc_arg(self):
def test_truncate_loc_arg(self) -> None:
"""
Tests that the 'alert' 'loc-args' field will be truncated when needed.
(Tests with one loc arg)
Expand All @@ -108,7 +109,7 @@ def test_truncate_loc_arg(self):
truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]["loc-args"][0],
)

def test_truncate_loc_args(self):
def test_truncate_loc_args(self) -> None:
"""
Tests that the 'alert' 'loc-args' field will be truncated when needed.
(Tests with two loc args)
Expand All @@ -130,7 +131,7 @@ def test_truncate_loc_args(self):
],
)

def test_python_unicode_support(self):
def test_python_unicode_support(self) -> None:
"""
Tests Python's unicode support :-
a one character unicode string should have a length of one, even if it's one
Expand All @@ -146,7 +147,7 @@ def test_python_unicode_support(self):
)
self.fail(msg)

def test_truncate_string_with_multibyte(self):
def test_truncate_string_with_multibyte(self) -> None:
"""
Tests that truncation works as expected on strings containing one
multibyte character.
Expand All @@ -160,7 +161,7 @@ def test_truncate_string_with_multibyte(self):
txt[:17], truncate(payload_for_aps(aps), overhead + 20)["aps"]["alert"]
)

def test_truncate_multibyte(self):
def test_truncate_multibyte(self) -> None:
"""
Tests that truncation works as expected on strings containing only
multibyte characters.
Expand Down
19 changes: 13 additions & 6 deletions tests/test_concurrency_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from sygnal.notifications import ConcurrencyLimitedPushkin
from typing import TYPE_CHECKING, Any, Dict, List

from sygnal.notifications import ConcurrencyLimitedPushkin, Device, Notification
from sygnal.utils import twisted_sleep

from tests.testutils import TestCase

if TYPE_CHECKING:
from sygnal.notifications import NotificationContext

DEVICE_GCM1_EXAMPLE = {
"app_id": "com.example.gcm",
"pushkey": "spqrg",
Expand All @@ -36,7 +41,9 @@


class SlowConcurrencyLimitedDummyPushkin(ConcurrencyLimitedPushkin):
async def _dispatch_notification_unlimited(self, n, device, context):
async def dispatch_notification(
self, n: Notification, device: Device, context: "NotificationContext"
) -> List[str]:
"""
We will deliver the notification to the mighty nobody
and we will take one second to do it, because we are slow!
Expand All @@ -46,7 +53,7 @@ async def _dispatch_notification_unlimited(self, n, device, context):


class ConcurrencyLimitTestCase(TestCase):
def config_setup(self, config):
def config_setup(self, config: Dict[str, Any]) -> None:
super().config_setup(config)
config["apps"]["com.example.gcm"] = {
"type": "tests.test_concurrency_limit.SlowConcurrencyLimitedDummyPushkin",
Expand All @@ -57,15 +64,15 @@ def config_setup(self, config):
"inflight_request_limit": 1,
}

def test_passes_under_limit_one(self):
def test_passes_under_limit_one(self) -> None:
"""
Tests that a push notification succeeds if it is under the limit.
"""
resp = self._request(self._make_dummy_notification([DEVICE_GCM1_EXAMPLE]))

self.assertEqual(resp, {"rejected": []})

def test_passes_under_limit_multiple_no_interfere(self):
def test_passes_under_limit_multiple_no_interfere(self) -> None:
"""
Tests that 2 push notifications succeed if they are to different
pushkins (so do not hit a per-pushkin limit).
Expand All @@ -76,7 +83,7 @@ def test_passes_under_limit_multiple_no_interfere(self):

self.assertEqual(resp, {"rejected": []})

def test_fails_when_limit_hit(self):
def test_fails_when_limit_hit(self) -> None:
"""
Tests that 1 of 2 push notifications fail if they are to the same pushkins
(so do hit the per-pushkin limit of 1).
Expand Down
Loading