Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add login spam checker API #15838

Merged
merged 4 commits into from Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15838.feature
@@ -0,0 +1 @@
Add spam checker module API for logins.
36 changes: 36 additions & 0 deletions docs/modules/spam_checker_callbacks.md
Expand Up @@ -348,6 +348,42 @@ callback returns `False`, Synapse falls through to the next one. The value of th
callback that does not return `False` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback.


### `check_login_for_spam`

_First introduced in Synapse v1.87.0_

```python
async def check_login_for_spam(
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
MatMaul marked this conversation as resolved.
Show resolved Hide resolved
auth_provider_id: Optional[str] = None,
) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes"]
```

Called when a user logs in.

The arguments passed to this callback are:

* `user_id`: The user ID the user is logging in with
* `device_id`: The device ID the user is re-logging into.
* `initial_display_name`: The device display name, if any.
* `request_info`: A collection of tuples, which first item is a user agent, and which
second item is an IP address. These user agents and IP addresses are the ones that were
used during the login process.
* `auth_provider_id`: The identifier of the SSO authentication provider, if any.

If multiple modules implement this callback, they will be considered in order. If a
callback returns `synapse.module_api.NOT_SPAM`, Synapse falls through to the next one.
The value of the first callback that does not return `synapse.module_api.NOT_SPAM` will
be used. If this happens, Synapse will not call any of the subsequent implementations of
this callback.

*Note:* This will not be called when a user registers.


## Example

The example below is a module that implements the spam checker callback
Expand Down
11 changes: 11 additions & 0 deletions synapse/http/site.py
Expand Up @@ -521,6 +521,11 @@ def get_client_ip_if_available(self) -> str:
else:
return self.getClientAddress().host

def request_info(self) -> "RequestInfo":
h = self.getHeader(b"User-Agent")
user_agent = h.decode("ascii", "replace") if h else None
return RequestInfo(user_agent=user_agent, ip=self.get_client_ip_if_available())


class XForwardedForRequest(SynapseRequest):
"""Request object which honours proxy headers
Expand Down Expand Up @@ -661,3 +666,9 @@ def request_factory(channel: HTTPChannel, queued: bool) -> Request:

def log(self, request: SynapseRequest) -> None:
pass


@attr.s(auto_attribs=True, frozen=True, slots=True)
class RequestInfo:
user_agent: Optional[str]
ip: str
3 changes: 3 additions & 0 deletions synapse/module_api/__init__.py
Expand Up @@ -80,6 +80,7 @@
)
from synapse.module_api.callbacks.spamchecker_callbacks import (
CHECK_EVENT_FOR_SPAM_CALLBACK,
CHECK_LOGIN_FOR_SPAM_CALLBACK,
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
CHECK_USERNAME_FOR_SPAM_CALLBACK,
Expand Down Expand Up @@ -302,6 +303,7 @@ def register_spam_checker_callbacks(
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Registers callbacks for spam checking capabilities.

Expand All @@ -319,6 +321,7 @@ def register_spam_checker_callbacks(
check_username_for_spam=check_username_for_spam,
check_registration_for_spam=check_registration_for_spam,
check_media_file_for_spam=check_media_file_for_spam,
check_login_for_spam=check_login_for_spam,
)

def register_account_validity_callbacks(
Expand Down
80 changes: 80 additions & 0 deletions synapse/module_api/callbacks/spamchecker_callbacks.py
Expand Up @@ -196,6 +196,26 @@
]
],
]
CHECK_LOGIN_FOR_SPAM_CALLBACK = Callable[
[
str,
Optional[str],
Optional[str],
Collection[Tuple[Optional[str], str]],
Optional[str],
],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
]
],
]


def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
Expand Down Expand Up @@ -315,6 +335,7 @@ def __init__(self, hs: "synapse.server.HomeServer") -> None:
self._check_media_file_for_spam_callbacks: List[
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK
] = []
self._check_login_for_spam_callbacks: List[CHECK_LOGIN_FOR_SPAM_CALLBACK] = []

def register_callbacks(
self,
Expand All @@ -335,6 +356,7 @@ def register_callbacks(
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Register callbacks from module for each hook."""
if check_event_for_spam is not None:
Expand Down Expand Up @@ -378,6 +400,9 @@ def register_callbacks(
if check_media_file_for_spam is not None:
self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)

if check_login_for_spam is not None:
self._check_login_for_spam_callbacks.append(check_login_for_spam)

@trace
async def check_event_for_spam(
self, event: "synapse.events.EventBase"
Expand Down Expand Up @@ -819,3 +844,58 @@ async def check_media_file_for_spam(
return synapse.api.errors.Codes.FORBIDDEN, {}

return self.NOT_SPAM

async def check_login_for_spam(
self,
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
auth_provider_id: Optional[str] = None,
) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
"""Checks if we should allow the given registration request.

Args:
user_id: The request user ID
request_info: List of tuples of user agent and IP that
were used during the registration process.
auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
"cas". If any. Note this does not include users registered
via a password provider.

Returns:
Enum for how the request should be handled
"""

for callback in self._check_login_for_spam_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(
callback(
user_id,
device_id,
initial_display_name,
request_info,
auth_provider_id,
)
)
# Normalize return values to `Codes` or `"NOT_SPAM"`.
if res is self.NOT_SPAM:
continue
elif isinstance(res, synapse.api.errors.Codes):
return res, {}
elif (
isinstance(res, tuple)
and len(res) == 2
and isinstance(res[0], synapse.api.errors.Codes)
and isinstance(res[1], dict)
):
return res
else:
logger.warning(
"Module returned invalid value, rejecting login as spam"
)
return synapse.api.errors.Codes.FORBIDDEN, {}

return self.NOT_SPAM
52 changes: 48 additions & 4 deletions synapse/rest/client/login.py
Expand Up @@ -50,7 +50,7 @@
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.http.site import RequestInfo, SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
Expand Down Expand Up @@ -114,6 +114,7 @@ def __init__(self, hs: "HomeServer"):
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self._sso_handler = hs.get_sso_handler()
self._spam_checker = hs.get_module_api_callbacks().spam_checker

self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
Expand Down Expand Up @@ -197,6 +198,8 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
self._refresh_tokens_enabled and client_requested_refresh_token
)

request_info = request.request_info()

try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
requester = await self.auth.get_user_by_req(request)
Expand All @@ -216,6 +219,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
login_submission,
appservice,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
elif (
self.jwt_enabled
Expand All @@ -227,6 +231,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
result = await self._do_jwt_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
await self._address_ratelimiter.ratelimit(
Expand All @@ -235,6 +240,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
result = await self._do_token_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
else:
await self._address_ratelimiter.ratelimit(
Expand All @@ -243,6 +249,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
result = await self._do_other_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
Expand All @@ -265,6 +272,8 @@ async def _do_appservice_login(
login_submission: JsonDict,
appservice: ApplicationService,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
Expand Down Expand Up @@ -300,10 +309,15 @@ async def _do_appservice_login(
# The user represented by an appservice's configured sender_localpart
# is not actually created in Synapse.
should_check_deactivated=qualified_user_id != appservice.sender,
request_info=request_info,
)

async def _do_other_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
self,
login_submission: JsonDict,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
MatMaul marked this conversation as resolved.
Show resolved Hide resolved
) -> LoginResponse:
"""Handle non-token/saml/jwt logins

Expand Down Expand Up @@ -333,6 +347,7 @@ async def _do_other_login(
login_submission,
callback,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
return result

Expand All @@ -347,6 +362,8 @@ async def _complete_login(
should_issue_refresh_token: bool = False,
auth_provider_session_id: Optional[str] = None,
should_check_deactivated: bool = True,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
Expand All @@ -371,6 +388,7 @@ async def _complete_login(

This exists purely for appservice's configured sender_localpart
which doesn't have an associated user in the database.
request_info: The user agent/IP address of the user.

Returns:
Dictionary of account information after successful login.
Expand Down Expand Up @@ -417,6 +435,22 @@ async def _complete_login(
)

initial_display_name = login_submission.get("initial_device_display_name")
spam_check = await self._spam_checker.check_login_for_spam(
user_id,
device_id=device_id,
initial_display_name=initial_display_name,
request_info=[(request_info.user_agent, request_info.ip)],
auth_provider_id=auth_provider_id,
)
if spam_check != self._spam_checker.NOT_SPAM:
logger.info("Blocking login due to spam checker")
raise SynapseError(
403,
msg="Login was blocked by the server",
errcode=spam_check[0],
additional_fields=spam_check[1],
)

(
device_id,
access_token,
Expand Down Expand Up @@ -451,7 +485,11 @@ async def _complete_login(
return result

async def _do_token_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
self,
login_submission: JsonDict,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""
Handle token login.
Expand All @@ -474,10 +512,15 @@ async def _do_token_login(
auth_provider_id=res.auth_provider_id,
should_issue_refresh_token=should_issue_refresh_token,
auth_provider_session_id=res.auth_provider_session_id,
request_info=request_info,
)

async def _do_jwt_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
self,
login_submission: JsonDict,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""
Handle the custom JWT login.
Expand All @@ -496,6 +539,7 @@ async def _do_jwt_login(
login_submission,
create_non_existent_users=True,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)


Expand Down