Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Implement MSC3882 revision 1
Browse files Browse the repository at this point in the history
  • Loading branch information
hughns committed Apr 4, 2023
1 parent 6d10337 commit 308ce68
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 13 deletions.
5 changes: 5 additions & 0 deletions synapse/rest/client/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"enabled": self.config.experimental.msc3664_enabled,
}

if self.config.experimental.msc3882_enabled:
response["capabilities"]["org.matrix.msc3882.get_login_token"] = {
"enabled": True,
}

return HTTPStatus.OK, response


Expand Down
16 changes: 14 additions & 2 deletions synapse/rest/client/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ def __init__(self, hs: "HomeServer"):
and hs.config.experimental.msc3866.require_approval_for_new_accounts
)

# Whether MSC3882 get login token is enabled.
self._get_login_token_enabled = hs.config.experimental.msc3882_enabled

self.auth = hs.get_auth()

self.clock = hs.get_clock()
Expand Down Expand Up @@ -145,7 +148,12 @@ def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
# to SSO.
flows.append({"type": LoginRestServlet.CAS_TYPE})

if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
if (
self.cas_enabled
or self.saml2_enabled
or self.oidc_enabled
or self._get_login_token_enabled
):
flows.append(
{
"type": LoginRestServlet.SSO_TYPE,
Expand All @@ -163,7 +171,11 @@ def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
# don't know how to implement, since they (currently) will always
# fall back to the fallback API if they don't understand one of the
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
tokenTypeFlow: Dict[str, Any] = {"type": LoginRestServlet.TOKEN_TYPE}
# If MSC3882 is enabled we advertise the get_login_token flag.
if self._get_login_token_enabled:
tokenTypeFlow["org.matrix.msc3882.get_login_token"] = True
flows.append(tokenTypeFlow)

flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types())

Expand Down
10 changes: 5 additions & 5 deletions synapse/rest/client/login_token_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class LoginTokenRequestServlet(RestServlet):
Request:
POST /login/token HTTP/1.1
POST /login/get_token HTTP/1.1
Content-Type: application/json
{}
Expand All @@ -43,12 +43,12 @@ class LoginTokenRequestServlet(RestServlet):
HTTP/1.1 200 OK
{
"login_token": "ABDEFGH",
"expires_in": 3600,
"expires_in_ms": 3600000,
}
"""

PATTERNS = client_patterns(
"/org.matrix.msc3882/login/token$", releases=[], v1=False, unstable=True
"/org.matrix.msc3882/login/get_token$", releases=[], v1=False, unstable=True
)

def __init__(self, hs: "HomeServer"):
Expand Down Expand Up @@ -77,15 +77,15 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

login_token = await self.auth_handler.create_login_token_for_user_id(
user_id=requester.user.to_string(),
auth_provider_id="org.matrix.msc3882.login_token_request",
auth_provider_id="org.matrix.msc3882.get_login_token",
duration_ms=self.token_timeout,
)

return (
200,
{
"login_token": login_token,
"expires_in": self.token_timeout // 1000,
"expires_in_ms": self.token_timeout,
},
)

Expand Down
2 changes: 0 additions & 2 deletions synapse/rest/client/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,6 @@ def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
"fi.mau.msc2815": self.config.experimental.msc2815_enabled,
# Adds a ping endpoint for appservices to check HS->AS connection
"fi.mau.msc2659": self.config.experimental.msc2659_enabled,
# Adds support for login token requests as per MSC3882
"org.matrix.msc3882": self.config.experimental.msc3882_enabled,
# Adds support for remotely enabling/disabling pushers, as per MSC3881
"org.matrix.msc3881": self.config.experimental.msc3881_enabled,
# Adds support for filtering /messages by event relation.
Expand Down
30 changes: 30 additions & 0 deletions tests/rest/client/test_capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,33 @@ def test_get_does_include_msc3244_fields_when_enabled(self) -> None:
self.assertGreater(len(details["support"]), 0)
for room_version in details["support"]:
self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version))

def test_get_does_not_include_msc3882_fields_when_disabled(self) -> None:
access_token = self.get_success(
self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)

channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]

self.assertEqual(channel.code, HTTPStatus.OK)
self.assertTrue(
"org.matrix.msc3882.get_login_token" not in capabilities
or not capabilities["org.matrix.msc3882.get_login_token"]["enabled"]
)

@override_config({"experimental_features": {"msc3882_enabled": True}})
def test_get_does_include_msc3882_fields_when_enabled(self) -> None:
access_token = self.get_success(
self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)

channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]

self.assertEqual(channel.code, HTTPStatus.OK)
self.assertTrue(capabilities["org.matrix.msc3882.get_login_token"]["enabled"])
23 changes: 23 additions & 0 deletions tests/rest/client/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,29 @@ def test_require_approval(self) -> None:
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
)

def test_get_login_flows_with_msc3882_disabled(self) -> None:
"""GET /login should return m.login.token without get_login_token true"""
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)

flows = {flow["type"]: flow for flow in channel.json_body["flows"]}
self.assertTrue(
"m.login.token" not in flows
or "org.matrix.msc3882.get_login_token" not in flows["m.login.token"]
or not flows["m.login.token"]["org.matrix.msc3882.get_login_token"]
)

@override_config({"experimental_features": {"msc3882_enabled": True}})
def test_get_login_flows_with_msc3882_enabled(self) -> None:
"""GET /login should return m.login.token without get_login_token true"""
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)

print(channel.json_body)

flows = {flow["type"]: flow for flow in channel.json_body["flows"]}
self.assertTrue(flows["m.login.token"]["org.matrix.msc3882.get_login_token"])


@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):
Expand Down
8 changes: 4 additions & 4 deletions tests/rest/client/test_login_token_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tests import unittest
from tests.unittest import override_config

endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/get_token"


class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_uia_on(self) -> None:

channel = self.make_request("POST", endpoint, uia, access_token=token)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["expires_in"], 300)
self.assertEqual(channel.json_body["expires_in_ms"], 300000)

login_token = channel.json_body["login_token"]

Expand All @@ -103,7 +103,7 @@ def test_uia_off(self) -> None:

channel = self.make_request("POST", endpoint, {}, access_token=token)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["expires_in"], 300)
self.assertEqual(channel.json_body["expires_in_ms"], 300000)

login_token = channel.json_body["login_token"]

Expand All @@ -130,4 +130,4 @@ def test_expires_in(self) -> None:

channel = self.make_request("POST", endpoint, {}, access_token=token)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["expires_in"], 15)
self.assertEqual(channel.json_body["expires_in_ms"], 15000)

0 comments on commit 308ce68

Please sign in to comment.