Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Use HTTPStatus constants in place of literals in tests. (#13297)
Browse files Browse the repository at this point in the history
  • Loading branch information
dklimpel committed Jul 15, 2022
1 parent 7b67e93 commit 96cf81e
Show file tree
Hide file tree
Showing 9 changed files with 308 additions and 238 deletions.
1 change: 1 addition & 0 deletions changelog.d/13297.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `HTTPStatus` constants in place of literals in tests.
5 changes: 3 additions & 2 deletions tests/federation/test_complexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from http import HTTPStatus
from unittest.mock import Mock

from synapse.api.errors import Codes, SynapseError
Expand Down Expand Up @@ -50,7 +51,7 @@ def test_complexity_simple(self):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.assertEqual(200, channel.code)
self.assertEqual(HTTPStatus.OK, channel.code)
complexity = channel.json_body["v1"]
self.assertTrue(complexity > 0, complexity)

Expand All @@ -62,7 +63,7 @@ def test_complexity_simple(self):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.assertEqual(200, channel.code)
self.assertEqual(HTTPStatus.OK, channel.code)
complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23)

Expand Down
11 changes: 6 additions & 5 deletions tests/federation/test_federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from http import HTTPStatus

from parameterized import parameterized

Expand Down Expand Up @@ -58,7 +59,7 @@ def test_bad_request(self, query_content):
"/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
query_content,
)
self.assertEqual(400, channel.code, channel.result)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")


Expand Down Expand Up @@ -119,7 +120,7 @@ def test_needs_to_be_in_room(self):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/v1/state/%s?event_id=xyz" % (room_1,)
)
self.assertEqual(403, channel.code, channel.result)
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")


Expand Down Expand Up @@ -153,7 +154,7 @@ def _make_join(self, user_id) -> JsonDict:
f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
f"?ver={DEFAULT_ROOM_VERSION}",
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
return channel.json_body

def test_send_join(self):
Expand All @@ -171,7 +172,7 @@ def test_send_join(self):
f"/_matrix/federation/v2/send_join/{self._room_id}/x",
content=join_event_dict,
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

# we should get complete room state back
returned_state = [
Expand Down Expand Up @@ -226,7 +227,7 @@ def test_send_join_partial_state(self):
f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
content=join_event_dict,
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

# expect a reduced room state
returned_state = [
Expand Down
5 changes: 3 additions & 2 deletions tests/federation/transport/test_knocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from http import HTTPStatus
from typing import Dict, List

from synapse.api.constants import EventTypes, JoinRules, Membership
Expand Down Expand Up @@ -255,7 +256,7 @@ def test_room_state_returned_when_knocking(self):
RoomVersions.V7.identifier,
),
)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)

# Note: We don't expect the knock membership event to be sent over federation as
# part of the stripped room state, as the knocking homeserver already has that
Expand Down Expand Up @@ -293,7 +294,7 @@ def test_room_state_returned_when_knocking(self):
% (room_id, signed_knock_event.event_id),
signed_knock_event_json,
)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)

# Check that we got the stripped room state in return
room_state_events = channel.json_body["knock_state_events"]
Expand Down
41 changes: 21 additions & 20 deletions tests/handlers/test_password_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Tests for the password_auth_provider interface"""

from http import HTTPStatus
from typing import Any, Type, Union
from unittest.mock import Mock

Expand Down Expand Up @@ -188,14 +189,14 @@ def password_only_auth_provider_login_test_body(self):
# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
mock_password_provider.reset_mock()

# login with mxid should work too
channel = self._send_password_login("@u:bz", "p")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:bz", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
mock_password_provider.reset_mock()
Expand All @@ -204,7 +205,7 @@ def password_only_auth_provider_login_test_body(self):
# in these cases, but at least we can guard against the API changing
# unexpectedly
channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with(
"@ USER🙂NAME :test", " pASS😢word "
Expand Down Expand Up @@ -258,10 +259,10 @@ def local_user_fallback_login_test_body(self):
# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 403, channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)

channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@localuser:test", channel.json_body["user_id"])

@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
Expand Down Expand Up @@ -382,7 +383,7 @@ def password_auth_disabled_test_body(self):

# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_password.assert_not_called()

@override_config(legacy_providers_config(LegacyCustomAuthProvider))
Expand All @@ -406,14 +407,14 @@ def custom_auth_provider_login_test_body(self):

# login with missing param should be rejected
channel = self._send_login("test.login_type", "u")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()

mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
Expand All @@ -427,7 +428,7 @@ def custom_auth_provider_login_test_body(self):
("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
Expand Down Expand Up @@ -510,7 +511,7 @@ def custom_auth_provider_callback_test_body(self):
("@user:bz", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
Expand Down Expand Up @@ -549,7 +550,7 @@ def custom_auth_password_disabled_test_body(self):

# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()

@override_config(
Expand Down Expand Up @@ -584,7 +585,7 @@ def custom_auth_password_disabled_localdb_enabled_test_body(self):

# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()

@override_config(
Expand Down Expand Up @@ -615,7 +616,7 @@ def password_custom_auth_password_disabled_login_test_body(self):

# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_password.assert_not_called()

Expand Down Expand Up @@ -646,13 +647,13 @@ def password_custom_auth_password_disabled_ui_auth_test_body(self):
("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
tok1 = channel.json_body["access_token"]

channel = self._send_login(
"test.login_type", "localuser", test_field="", device_id="dev2"
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)

# make the initial request which returns a 401
channel = self._delete_device(tok1, "dev2")
Expand Down Expand Up @@ -721,7 +722,7 @@ def custom_auth_no_local_user_fallback_test_body(self):
# password login shouldn't work and should be rejected with a 400
# ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)

def test_on_logged_out(self):
"""Tests that the on_logged_out callback is called when the user logs out."""
Expand Down Expand Up @@ -884,7 +885,7 @@ def _test_3pid_allowed(self, username: str, registration: bool):
},
access_token=tok,
)
self.assertEqual(channel.code, 403, channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
self.assertEqual(
channel.json_body["errcode"],
Codes.THREEPID_DENIED,
Expand All @@ -906,7 +907,7 @@ def _test_3pid_allowed(self, username: str, registration: bool):
},
access_token=tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertIn("sid", channel.json_body)

m.assert_called_once_with("email", "bar@test.com", registration)
Expand Down Expand Up @@ -949,12 +950,12 @@ def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict:
"register",
{"auth": {"session": session, "type": LoginType.DUMMY}},
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
return channel.json_body

def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
return channel.json_body["flows"]

def _send_password_login(self, user: str, password: str) -> FakeChannel:
Expand Down
16 changes: 8 additions & 8 deletions tests/rest/admin/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,7 +1379,7 @@ def test_create_server_admin(self) -> None:
content=body,
)

self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
Expand Down Expand Up @@ -1434,7 +1434,7 @@ def test_create_user(self) -> None:
content=body,
)

self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
Expand Down Expand Up @@ -1512,7 +1512,7 @@ def test_create_user_mau_limit_reached_active_admin(self) -> None:
content={"password": "abc123", "admin": False},
)

self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])

Expand Down Expand Up @@ -1550,7 +1550,7 @@ def test_create_user_mau_limit_reached_passive_admin(self) -> None:
)

# Admin user is not blocked by mau anymore
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])

Expand Down Expand Up @@ -1585,7 +1585,7 @@ def test_create_user_email_notif_for_new_users(self) -> None:
content=body,
)

self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
Expand Down Expand Up @@ -1626,7 +1626,7 @@ def test_create_user_email_no_notif_for_new_users(self) -> None:
content=body,
)

self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
Expand Down Expand Up @@ -1666,7 +1666,7 @@ def test_create_user_email_notif_for_new_users_with_msisdn_threepid(self) -> Non
content=body,
)

self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"])
self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"])
Expand Down Expand Up @@ -2407,7 +2407,7 @@ def test_accidental_deactivation_prevention(self) -> None:
content={"password": "abc123"},
)

self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])

Expand Down
Loading

0 comments on commit 96cf81e

Please sign in to comment.