Skip to content

Commit fdc5aee

Browse files
committed
Independent oauth preflight provider check
1 parent 38b43f3 commit fdc5aee

File tree

3 files changed

+44
-39
lines changed

3 files changed

+44
-39
lines changed

plain-oauth/plain/oauth/models.py

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from plain.auth import get_user_model
55
from plain.exceptions import ValidationError
66
from plain.models import transaction
7-
from plain.models.db import IntegrityError, OperationalError, ProgrammingError
8-
from plain.preflight import PreflightResult
7+
from plain.models.db import IntegrityError
98
from plain.runtime import SettingsReference
109
from plain.utils import timezone
1110

@@ -15,9 +14,6 @@
1514
from .providers import OAuthToken, OAuthUser
1615

1716

18-
# TODO preflight check for deploy that ensures all provider keys in db are also in settings?
19-
20-
2117
@models.register_model
2218
class OAuthConnection(models.Model):
2319
created_at = models.DateTimeField(auto_now_add=True)
@@ -154,35 +150,3 @@ def connect(
154150
connection.save()
155151

156152
return connection
157-
158-
@classmethod
159-
def preflight(cls) -> list[PreflightResult]:
160-
"""
161-
A system check for ensuring that provider_keys in the database are also present in settings.
162-
"""
163-
errors = super().preflight()
164-
165-
from .providers import get_provider_keys
166-
167-
try:
168-
keys_in_db = set(
169-
cls.query.values_list("provider_key", flat=True).distinct()
170-
)
171-
except (OperationalError, ProgrammingError):
172-
# Check runs on plain migrate, and the table may not exist yet
173-
# or it may not be installed on the particular database intentionally
174-
return errors
175-
176-
keys_in_settings = set(get_provider_keys())
177-
178-
if keys_in_db - keys_in_settings:
179-
errors.append(
180-
PreflightResult(
181-
fix="The following OAuth providers are in the database but not in the settings: {}. Add these providers to your OAUTH_LOGIN_PROVIDERS setting or remove the corresponding OAuthConnection records.".format(
182-
", ".join(keys_in_db - keys_in_settings)
183-
),
184-
id="oauth.provider_in_db_not_in_settings",
185-
)
186-
)
187-
188-
return errors
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from plain.models import OperationalError, ProgrammingError
2+
from plain.preflight import PreflightCheck, PreflightResult, register_check
3+
4+
5+
@register_check(name="oauth.provider_keys")
6+
class CheckOAuthProviderKeys(PreflightCheck):
7+
"""
8+
Check for OAuth provider keys in the database that are not present in settings.
9+
"""
10+
11+
def run(self) -> list[PreflightResult]:
12+
from .models import OAuthConnection
13+
from .providers import get_provider_keys
14+
15+
errors = []
16+
17+
try:
18+
keys_in_db = set(
19+
OAuthConnection.query.values_list("provider_key", flat=True).distinct()
20+
)
21+
except (OperationalError, ProgrammingError):
22+
# Check runs on plain migrate, and the table may not exist yet
23+
# or it may not be installed on the particular database intentionally
24+
return errors
25+
26+
keys_in_settings = set(get_provider_keys())
27+
28+
if keys_in_db - keys_in_settings:
29+
errors.append(
30+
PreflightResult(
31+
fix="The following OAuth providers are in the database but not in the settings: {}. Add these providers to your OAUTH_LOGIN_PROVIDERS setting or remove the corresponding OAuthConnection records.".format(
32+
", ".join(keys_in_db - keys_in_settings)
33+
),
34+
id="oauth.provider_settings_missing",
35+
)
36+
)
37+
38+
return errors

plain-oauth/tests/test_checks.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from plain.auth import get_user_model
22
from plain.oauth.models import OAuthConnection
3+
from plain.oauth.preflight import CheckOAuthProviderKeys
34

45

56
def test_oauth_provider_keys_check_pass(db, settings):
@@ -23,7 +24,8 @@ def test_oauth_provider_keys_check_pass(db, settings):
2324
access_token="test",
2425
)
2526

26-
errors = OAuthConnection.preflight()
27+
check = CheckOAuthProviderKeys()
28+
errors = check.run()
2729
assert len(errors) == 0
2830

2931

@@ -54,7 +56,8 @@ def test_oauth_provider_keys_check_fail(db, settings):
5456
access_token="test",
5557
)
5658

57-
errors = OAuthConnection.preflight()
59+
check = CheckOAuthProviderKeys()
60+
errors = check.run()
5861
assert len(errors) == 1
5962
assert (
6063
errors[0].fix

0 commit comments

Comments
 (0)