Skip to content

Commit 6336456

Browse files
authored
fix: openID provider validation flow (#2186)
* fix: openID provider validation flow * remove test cleanup
1 parent 59db85d commit 6336456

File tree

7 files changed

+100
-8
lines changed

7 files changed

+100
-8
lines changed

Diff for: flask_appbuilder/security/manager.py

+8
Original file line numberDiff line numberDiff line change
@@ -1447,6 +1447,14 @@ def _has_view_access(
14471447
# If it's not a builtin role check against database store roles
14481448
return self.exist_permission_on_roles(view_name, permission_name, db_role_ids)
14491449

1450+
def get_oid_identity_url(self, provider_name: str) -> Optional[str]:
1451+
"""
1452+
Returns the OIDC identity provider URL
1453+
"""
1454+
for provider in self.openid_providers:
1455+
if provider.get("name") == provider_name:
1456+
return provider.get("url")
1457+
14501458
def get_user_roles(self, user) -> List[object]:
14511459
"""
14521460
Get current user roles, if user is not authenticated returns the public role

Diff for: flask_appbuilder/security/views.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -565,8 +565,12 @@ def login_handler(self):
565565
form = LoginForm_oid()
566566
if form.validate_on_submit():
567567
session["remember_me"] = form.remember_me.data
568+
identity_url = self.appbuilder.sm.get_oid_identity_url(form.openid.data)
569+
if identity_url is None:
570+
flash(as_unicode(self.invalid_login_message), "warning")
571+
return redirect(self.appbuilder.get_url_for_login)
568572
return self.appbuilder.sm.oid.try_login(
569-
form.openid.data,
573+
identity_url,
570574
ask_for=self.oid_ask_for,
571575
ask_for_optional=self.oid_ask_for_optional,
572576
)

Diff for: flask_appbuilder/templates/appbuilder/general/security/login_oid.html

+2-6
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,9 @@
3636
<label class="hidden control-label" id="label-username"
3737
for="username">{{ _("Enter your OpenID Username") }}:</label>
3838
{{ form.username(size = 80, class = "hidden form-control", autofocus = true) }}
39-
</div>
40-
</div>
41-
<div class="control-group">
42-
<div class="controls">
4339
<label class="checkbox" for="remember_me">
44-
{{ form.remember_me }} Remember Me
4540
</label>
41+
{{ form.remember_me }} Remember Me
4642
</div>
4743
</div>
4844
<input
@@ -133,7 +129,7 @@
133129
{% for pr in providers %}
134130
document.getElementById("btn-oid-provider-{{ pr.name }}")
135131
.addEventListener("click", function () {
136-
set_openid("{{ pr.url | safe }}", "{{ pr.name }}");
132+
set_openid("{{ pr.name | safe }}", "{{ pr.name }}");
137133
});
138134
{% endfor %}
139135
document.getElementById("btn-oid-before-submit")

Diff for: tests/config_oid.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import os
2+
3+
from flask_appbuilder.security.manager import AUTH_OID
4+
5+
basedir = os.path.abspath(os.path.dirname(__file__))
6+
7+
SQLALCHEMY_DATABASE_URI = os.environ.get(
8+
"SQLALCHEMY_DATABASE_URI"
9+
) or "sqlite:///" + os.path.join(basedir, "app.db")
10+
11+
SECRET_KEY = "thisismyscretkey"
12+
13+
AUTH_TYPE = AUTH_OID
14+
15+
OPENID_PROVIDERS = [
16+
{"name": "Google", "url": "https://www.google.com/accounts/o8/id"},
17+
{"name": "Yahoo", "url": "https://me.yahoo.com"},
18+
{"name": "AOL", "url": "http://openid.aol.com/<username>"},
19+
{"name": "Flickr", "url": "http://www.flickr.com/<username>"},
20+
{"name": "OpenStack", "url": "https://openstackid.org/"},
21+
]
22+
23+
WTF_CSRF_ENABLED = False
24+
25+
# Will allow user self registration
26+
AUTH_USER_REGISTRATION = True
27+
28+
# The default user self registration role for all users
29+
AUTH_USER_REGISTRATION_ROLE = "Admin"

Diff for: tests/test_mvc_oauth.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def get(self, item):
2626
return UserInfoReponseMock()
2727

2828

29-
class APICSRFTestCase(FABTestCase):
29+
class MVCOAuthTestCase(FABTestCase):
3030
def setUp(self):
3131
from flask import Flask
3232
from flask_wtf import CSRFProtect

Diff for: tests/test_mvc_oid.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from unittest.mock import MagicMock
2+
3+
from flask_appbuilder import SQLA
4+
from tests.base import FABTestCase
5+
6+
7+
class MVCOIDTestCase(FABTestCase):
8+
def setUp(self):
9+
from flask import Flask
10+
from flask_appbuilder import AppBuilder
11+
12+
self.app = Flask(__name__)
13+
self.app.config.from_object("tests.config_oid")
14+
self.db = SQLA(self.app)
15+
self.appbuilder = AppBuilder(self.app, self.db.session)
16+
17+
def test_oid_login_get(self):
18+
"""
19+
OID: Test login get
20+
"""
21+
self.appbuilder.sm.oid.try_login = MagicMock(return_value="Login ok")
22+
23+
with self.app.test_client() as client:
24+
response = client.get("/login/")
25+
self.assertEqual(response.status_code, 200)
26+
for provider in self.app.config["OPENID_PROVIDERS"]:
27+
self.assertIn(provider["name"], response.data.decode("utf-8"))
28+
29+
def test_oid_login_post(self):
30+
"""
31+
OID: Test login post with a valid provider
32+
"""
33+
self.appbuilder.sm.oid.try_login = MagicMock(return_value="Login ok")
34+
35+
with self.app.test_client() as client:
36+
response = client.post("/login/", data=dict(openid="OpenStack"))
37+
self.assertEqual(response.status_code, 200)
38+
self.assertEqual(response.data, b"Login ok")
39+
self.appbuilder.sm.oid.try_login.assert_called_with(
40+
"https://openstackid.org/", ask_for=["email"], ask_for_optional=[]
41+
)
42+
43+
def test_oid_login_post_invalid_provider(self):
44+
"""
45+
OID: Test login post with an invalid provider
46+
"""
47+
self.appbuilder.sm.oid.try_login = MagicMock(return_value="Not Ok")
48+
49+
with self.app.test_client() as client:
50+
response = client.post("/login/", data=dict(openid="DoesNotExist"))
51+
self.assertEqual(response.status_code, 302)
52+
self.assertEqual(response.location, "/login/")
53+
self.appbuilder.sm.oid.try_login.assert_not_called()

Diff for: tests/test_security_api.py

+2
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,8 @@ def setUp(self):
444444
if hasattr(b, "datamodel") and b.datamodel.session is not None:
445445
b.datamodel.session = self.db.session
446446

447+
self.create_default_users(self.appbuilder)
448+
447449
def tearDown(self):
448450
self.appbuilder.session.close()
449451
engine = self.appbuilder.session.get_bind(mapper=None, clause=None)

0 commit comments

Comments
 (0)