Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix other test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
hughns committed Jun 5, 2023
1 parent fba9000 commit 327555d
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 13 deletions.
3 changes: 3 additions & 0 deletions tests/config/test_oauth_delegation.py
Expand Up @@ -61,6 +61,9 @@ def setUp(self) -> None:
**default_config("test"),
"public_baseurl": BASE_URL,
"enable_registration": False,
"login_via_existing_session": {
"enabled": False,
},
"experimental_features": {
"msc3861": {
"enabled": True,
Expand Down
3 changes: 3 additions & 0 deletions tests/handlers/test_oauth_delegation.py
Expand Up @@ -115,6 +115,9 @@ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
config["disable_registration"] = True
config["login_via_existing_session"] = {
"enabled": False,
}
config["experimental_features"] = {
"msc3861": {
"enabled": True,
Expand Down
44 changes: 33 additions & 11 deletions tests/handlers/test_password_providers.py
Expand Up @@ -36,15 +36,20 @@
from tests.unittest import override_config

# Login flows we expect to appear in the list after the normal ones.
ADDITIONAL_LOGIN_FLOWS = [
ADDITIONAL_LOGIN_FLOWS: List[Dict] = [
{"type": "m.login.application_service"},
{"type": "m.login.token", "get_login_token": True},
]

# a mock instance which the dummy auth providers delegate to, so we can see what's going
# on
mock_password_provider = Mock()


def sort_flows(flows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
return sorted(flows, key=lambda f: f["type"])


class LegacyPasswordOnlyAuthProvider:
"""A legacy password_provider which only implements `check_password`."""

Expand Down Expand Up @@ -184,7 +189,10 @@ def test_password_only_auth_progiver_login_legacy(self) -> None:
def password_only_auth_provider_login_test_body(self) -> None:
# login flows should only have m.login.password
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
self.assertEqual(
sort_flows(flows),
sort_flows([{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS),
)

# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(True)
Expand Down Expand Up @@ -365,7 +373,7 @@ def password_auth_disabled_test_body(self) -> None:
"""password auth doesn't work if it's disabled across the board"""
# login flows should be empty
flows = self._get_login_flows()
self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS)
self.assertEqual(sort_flows(flows), sort_flows(ADDITIONAL_LOGIN_FLOWS))

# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("u", "p")
Expand All @@ -386,9 +394,11 @@ def custom_auth_provider_login_test_body(self) -> None:
# (password must come first, because reasons)
flows = self._get_login_flows()
self.assertEqual(
flows,
[{"type": "m.login.password"}, {"type": "test.login_type"}]
+ ADDITIONAL_LOGIN_FLOWS,
sort_flows(flows),
sort_flows(
[{"type": "m.login.password"}, {"type": "test.login_type"}]
+ ADDITIONAL_LOGIN_FLOWS
),
)

# login with missing param should be rejected
Expand Down Expand Up @@ -519,7 +529,10 @@ def custom_auth_password_disabled_test_body(self) -> None:
self.register_user("localuser", "localpass")

flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
self.assertEqual(
sort_flows(flows),
sort_flows([{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS),
)

# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
Expand Down Expand Up @@ -554,7 +567,10 @@ def custom_auth_password_disabled_localdb_enabled_test_body(self) -> None:
self.register_user("localuser", "localpass")

flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
self.assertEqual(
sort_flows(flows),
sort_flows([{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS),
)

# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
Expand Down Expand Up @@ -585,7 +601,10 @@ def password_custom_auth_password_disabled_login_test_body(self) -> None:
self.register_user("localuser", "localpass")

flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
self.assertEqual(
sort_flows(flows),
sort_flows([{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS),
)

# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
Expand Down Expand Up @@ -690,7 +709,10 @@ def custom_auth_no_local_user_fallback_test_body(self) -> None:
self.register_user("localuser", "localpass")

flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
self.assertEqual(
sort_flows(flows),
sort_flows([{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS),
)

# password login shouldn't work and should be rejected with a 400
# ("unknown login type")
Expand Down Expand Up @@ -928,7 +950,7 @@ def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict:
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
return channel.json_body

def _get_login_flows(self) -> JsonDict:
def _get_login_flows(self) -> List[JsonDict]:
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
return channel.json_body["flows"]
Expand Down
2 changes: 2 additions & 0 deletions tests/rest/admin/test_jwks.py
Expand Up @@ -45,6 +45,7 @@ def test_empty_jwks(self) -> None:
@override_config(
{
"disable_registration": True,
"login_via_existing_session": {"enabled": False},
"experimental_features": {
"msc3861": {
"enabled": True,
Expand All @@ -65,6 +66,7 @@ def test_empty_jwks_for_msc3861_client_secret_post(self) -> None:
@override_config(
{
"disable_registration": True,
"login_via_existing_session": {"enabled": False},
"experimental_features": {
"msc3861": {
"enabled": True,
Expand Down
2 changes: 1 addition & 1 deletion tests/rest/client/test_capabilities.py
Expand Up @@ -187,6 +187,7 @@ def test_get_does_include_msc3244_fields_when_enabled(self) -> None:
for room_version in details["support"]:
self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version))

@override_config({"login_via_existing_session": {"enabled": False}})
def test_get_get_token_login_fields_when_disabled(self) -> None:
"""By default login via an existing session is disabled."""
access_token = self.get_success(
Expand All @@ -201,7 +202,6 @@ def test_get_get_token_login_fields_when_disabled(self) -> None:
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertFalse(capabilities["m.get_login_token"]["enabled"])

@override_config({"login_via_existing_session": {"enabled": True}})
def test_get_get_token_login_fields_when_enabled(self) -> None:
access_token = self.get_success(
self.auth_handler.create_access_token_for_user_id(
Expand Down
2 changes: 1 addition & 1 deletion tests/rest/client/test_login.py
Expand Up @@ -446,6 +446,7 @@ def test_require_approval(self) -> None:
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
)

@override_config({"login_via_existing_session": {"enabled": False}})
def test_get_login_flows_with_login_via_existing_disabled(self) -> None:
"""GET /login should return m.login.token without get_login_token"""
channel = self.make_request("GET", "/_matrix/client/r0/login")
Expand All @@ -454,7 +455,6 @@ def test_get_login_flows_with_login_via_existing_disabled(self) -> None:
flows = {flow["type"]: flow for flow in channel.json_body["flows"]}
self.assertNotIn("m.login.token", flows)

@override_config({"login_via_existing_session": {"enabled": True}})
def test_get_login_flows_with_login_via_existing_enabled(self) -> None:
"""GET /login should return m.login.token with get_login_token true"""
channel = self.make_request("GET", "/_matrix/client/r0/login")
Expand Down
1 change: 1 addition & 0 deletions tests/rest/test_well_known.py
Expand Up @@ -119,6 +119,7 @@ def test_server_well_known_disabled(self) -> None:
},
},
"disable_registration": True,
"login_via_existing_session": {"enabled": False},
}
)
def test_client_well_known_msc3861_oauth_delegation(self) -> None:
Expand Down

0 comments on commit 327555d

Please sign in to comment.