Skip to content

Commit 50463b0

Browse files
committed
Type check the tests, fix a few things
1 parent 4bee399 commit 50463b0

File tree

11 files changed

+184
-99
lines changed

11 files changed

+184
-99
lines changed

plain-admin/tests/test_admin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from app.users.models import User
1+
from app.users.models import User # type: ignore[import-untyped]
22

33
from plain.test import Client
44

plain-oauth/provider_examples/bitbucket.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
1+
from __future__ import annotations
2+
13
import datetime
4+
from typing import TYPE_CHECKING, Any
25

36
import requests
47

58
from plain.oauth.providers import OAuthProvider, OAuthToken, OAuthUser
69
from plain.utils import timezone
710

11+
if TYPE_CHECKING:
12+
from plain.http import Request
13+
814

915
class BitbucketOAuthProvider(OAuthProvider):
1016
authorization_url = "https://bitbucket.org/site/oauth2/authorize"
1117

12-
def _get_token(self, request_data):
18+
def _get_token(self, request_data: dict[str, Any]) -> OAuthToken:
1319
response = requests.post(
1420
"https://bitbucket.org/site/oauth2/access_token",
1521
auth=(self.get_client_id(), self.get_client_secret()),
@@ -27,7 +33,7 @@ def _get_token(self, request_data):
2733
+ datetime.timedelta(seconds=data["expires_in"]),
2834
)
2935

30-
def get_oauth_token(self, *, code, request):
36+
def get_oauth_token(self, *, code: str, request: Request) -> OAuthToken:
3137
return self._get_token(
3238
{
3339
"grant_type": "authorization_code",
@@ -36,15 +42,15 @@ def get_oauth_token(self, *, code, request):
3642
}
3743
)
3844

39-
def refresh_oauth_token(self, *, oauth_token):
45+
def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
4046
return self._get_token(
4147
{
4248
"grant_type": "refresh_token",
4349
"refresh_token": oauth_token.refresh_token,
4450
}
4551
)
4652

47-
def get_oauth_user(self, *, oauth_token):
53+
def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
4854
response = requests.get(
4955
"https://api.bitbucket.org/2.0/user",
5056
headers={

plain-oauth/provider_examples/github.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1+
from __future__ import annotations
2+
13
import datetime
4+
from typing import TYPE_CHECKING, Any
25

36
import requests
47

58
from plain.oauth.exceptions import OAuthError
69
from plain.oauth.providers import OAuthProvider, OAuthToken, OAuthUser
710
from plain.utils import timezone
811

12+
if TYPE_CHECKING:
13+
from plain.http import Request
14+
915

1016
class GitHubOAuthProvider(OAuthProvider):
1117
authorization_url = "https://github.com/login/oauth/authorize"
@@ -14,7 +20,7 @@ class GitHubOAuthProvider(OAuthProvider):
1420
github_user_url = "https://api.github.com/user"
1521
github_emails_url = "https://api.github.com/user/emails"
1622

17-
def _get_token(self, request_data):
23+
def _get_token(self, request_data: dict[str, Any]) -> OAuthToken:
1824
response = requests.post(
1925
self.github_token_url,
2026
headers={
@@ -45,7 +51,7 @@ def _get_token(self, request_data):
4551

4652
return oauth_token
4753

48-
def get_oauth_token(self, *, code, request):
54+
def get_oauth_token(self, *, code: str, request: Request) -> OAuthToken:
4955
return self._get_token(
5056
{
5157
"client_id": self.get_client_id(),
@@ -54,7 +60,7 @@ def get_oauth_token(self, *, code, request):
5460
}
5561
)
5662

57-
def refresh_oauth_token(self, *, oauth_token):
63+
def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
5864
return self._get_token(
5965
{
6066
"client_id": self.get_client_id(),
@@ -64,7 +70,7 @@ def refresh_oauth_token(self, *, oauth_token):
6470
}
6571
)
6672

67-
def get_oauth_user(self, *, oauth_token):
73+
def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
6874
response = requests.get(
6975
self.github_user_url,
7076
headers={

plain-oauth/provider_examples/gitlab.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any
4+
15
import requests
26

37
from plain.oauth.providers import OAuthProvider, OAuthToken, OAuthUser
48

9+
if TYPE_CHECKING:
10+
from plain.http import Request
11+
512

613
class GitLabOAuthProvider(OAuthProvider):
714
authorization_url = "https://gitlab.com/oauth/authorize"
815

9-
def _get_token(self, request_data):
16+
def _get_token(self, request_data: dict[str, Any]) -> OAuthToken:
1017
request_data["client_id"] = self.get_client_id()
1118
request_data["client_secret"] = self.get_client_secret()
1219
response = requests.post(
@@ -24,7 +31,7 @@ def _get_token(self, request_data):
2431
# expires_in is missing in response?
2532
)
2633

27-
def get_oauth_token(self, *, code, request):
34+
def get_oauth_token(self, *, code: str, request: Request) -> OAuthToken:
2835
return self._get_token(
2936
{
3037
"grant_type": "authorization_code",
@@ -33,15 +40,15 @@ def get_oauth_token(self, *, code, request):
3340
}
3441
)
3542

36-
def refresh_oauth_token(self, *, oauth_token):
43+
def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
3744
return self._get_token(
3845
{
3946
"grant_type": "refresh_token",
4047
"refresh_token": oauth_token.refresh_token,
4148
}
4249
)
4350

44-
def get_oauth_user(self, *, oauth_token):
51+
def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
4552
response = requests.get(
4653
"https://gitlab.com/api/v4/user",
4754
headers={

plain-worker/plain/worker/parameters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def _serialize_value(value: Any) -> Any:
167167
return value
168168

169169
@staticmethod
170-
def from_json(data: dict[str, Any]) -> tuple[list[Any], dict[str, Any]]:
170+
def from_json(data: dict[str, Any]) -> tuple[tuple[Any, ...], dict[str, Any]]:
171171
args = []
172172
for arg in data["args"]:
173173
deserialized = JobParameters._deserialize_value(arg)
@@ -178,7 +178,7 @@ def from_json(data: dict[str, Any]) -> tuple[list[Any], dict[str, Any]]:
178178
deserialized = JobParameters._deserialize_value(value)
179179
kwargs[key] = deserialized
180180

181-
return args, kwargs
181+
return tuple(args), kwargs
182182

183183
@staticmethod
184184
def _deserialize_value(value: Any) -> Any:

plain-worker/tests/test_parameters.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ def test_datetime_parameter_serialization():
6464
# Test round trip
6565
assert DateTimeParameter.deserialize(datetime_serialized) == test_datetime
6666
assert DateTimeParameter.deserialize(datetime_tz_serialized) == test_datetime_tz
67-
assert DateTimeParameter.deserialize(datetime_tz_serialized).tzinfo == datetime.UTC
67+
deserialized_tz = DateTimeParameter.deserialize(datetime_tz_serialized)
68+
assert deserialized_tz is not None
69+
assert deserialized_tz.tzinfo == datetime.UTC
6870

6971

7072
def test_job_parameters_integration():
@@ -74,7 +76,7 @@ def test_job_parameters_integration():
7476

7577
# Test args and kwargs
7678
serialized = JobParameters.to_json(
77-
[42, "hello", test_date],
79+
(42, "hello", test_date),
7880
{"name": "test", "scheduled_at": test_datetime, "count": 5},
7981
)
8082

@@ -129,7 +131,7 @@ def test_model_parameter_formats():
129131

130132
def test_round_trip_integrity():
131133
"""Test that multiple serialization cycles preserve data."""
132-
original_args = [datetime.date(2024, 1, 15), "string", 42]
134+
original_args = (datetime.date(2024, 1, 15), "string", 42)
133135
original_kwargs = {
134136
"dt": datetime.datetime(2024, 1, 15, 10, 30, 45, 123456),
135137
"num": 100,

plain/plain/internal/handlers/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
if TYPE_CHECKING:
1919
from collections.abc import Callable
2020

21-
from plain.http import Request, Response
21+
from plain.http import Request, Response, ResponseBase
2222
from plain.urls import ResolverMatch
2323

2424
logger = logging.getLogger("plain.request")
@@ -72,7 +72,7 @@ def load_middleware(self) -> None:
7272
# as a flag for initialization being complete.
7373
self._middleware_chain = handler
7474

75-
def get_response(self, request: Request) -> Response:
75+
def get_response(self, request: Request) -> ResponseBase:
7676
"""Return a Response object for the given Request."""
7777

7878
span_attributes = {
@@ -124,7 +124,7 @@ def get_response(self, request: Request) -> Response:
124124
)
125125
return response
126126

127-
def _get_response(self, request: Request) -> Response:
127+
def _get_response(self, request: Request) -> ResponseBase:
128128
"""
129129
Resolve and call the view, then apply view, exception, and
130130
template_response middleware. This method is everything that happens

plain/plain/internal/handlers/wsgi.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from collections.abc import Callable, Iterable
1717
from typing import Any
1818

19+
from plain.http import ResponseBase
20+
1921
_slashes_re = _lazy_re_compile(rb"/+")
2022

2123

@@ -141,7 +143,7 @@ def __call__(
141143
self,
142144
environ: dict[str, Any],
143145
start_response: Callable[[str, list[tuple[str, str]]], Any],
144-
) -> Iterable[bytes]:
146+
) -> ResponseBase | Iterable[bytes]:
145147
signals.request_started.send(sender=self.__class__, environ=environ)
146148
request = WSGIRequest(environ)
147149
response = self.get_response(request)

0 commit comments

Comments
 (0)