Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Make get_user_by_id return a proper type
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Sep 14, 2023
1 parent af0d8ce commit 1fdae4e
Show file tree
Hide file tree
Showing 12 changed files with 92 additions and 114 deletions.
2 changes: 1 addition & 1 deletion synapse/api/auth/internal.py
Expand Up @@ -268,7 +268,7 @@ async def get_user_by_access_token(
stored_user = await self.store.get_user_by_id(user_id)
if not stored_user:
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
if not stored_user["is_guest"]:
if not stored_user.is_guest:
raise InvalidClientTokenError(
"Guest access token used for regular user"
)
Expand Down
2 changes: 1 addition & 1 deletion synapse/api/auth/msc3861_delegated.py
Expand Up @@ -300,7 +300,7 @@ async def get_user_by_access_token(
user_id = UserID(username, self._hostname)

# First try to find a user from the username claim
user_info = await self.store.get_userinfo_by_id(user_id=user_id.to_string())
user_info = await self.store.get_user_by_id(user_id=user_id.to_string())
if user_info is None:
# If the user does not exist, we should create it on the fly
# TODO: we could use SCIM to provision users ahead of time and listen
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/account.py
Expand Up @@ -102,7 +102,7 @@ async def _get_local_account_status(self, user_id: UserID) -> JsonDict:
"""
status = {"exists": False}

userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string())
userinfo = await self._main_store.get_user_by_id(user_id.to_string())

if userinfo is not None:
status = {
Expand Down
45 changes: 19 additions & 26 deletions synapse/handlers/admin.py
Expand Up @@ -18,7 +18,7 @@

from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo
from synapse.visibility import filter_events_for_client

if TYPE_CHECKING:
Expand Down Expand Up @@ -57,37 +57,30 @@ async def get_whois(self, user: UserID) -> JsonDict:

async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details"""
user_info_dict = await self._store.get_user_by_id(user.to_string())
if user_info_dict is None:
user_info: Optional[UserInfo] = await self._store.get_user_by_id(
user.to_string()
)
if user_info is None:
return None

# Restrict returned information to a known set of fields. This prevents additional
# fields added to get_user_by_id from modifying Synapse's external API surface.
user_info_to_return = {
"name",
"admin",
"deactivated",
"locked",
"shadow_banned",
"creation_ts",
"appservice_id",
"consent_server_notice_sent",
"consent_version",
"consent_ts",
"user_type",
"is_guest",
user_info_dict = {
"name": user.to_string(),
"admin": user_info.is_admin,
"deactivated": user_info.is_deactivated,
"locked": user_info.locked,
"shadow_banned": user_info.is_shadow_banned,
"creation_ts": user_info.creation_ts,
"appservice_id": user_info.appservice_id,
"consent_server_notice_sent": user_info.consent_server_notice_sent,
"consent_version": user_info.consent_version,
"consent_ts": user_info.consent_ts,
"user_type": user_info.user_type,
"is_guest": user_info.is_guest,
}

if self._msc3866_enabled:
# Only include the approved flag if support for MSC3866 is enabled.
user_info_to_return.add("approved")

# Restrict returned keys to a known set.
user_info_dict = {
key: value
for key, value in user_info_dict.items()
if key in user_info_to_return
}
user_info_dict["approved"] = user_info.approved

# Add additional user metadata
profile = await self._store.get_profileinfo(user)
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/message.py
Expand Up @@ -828,13 +828,13 @@ async def assert_accepted_privacy_policy(self, requester: Requester) -> None:

u = await self.store.get_user_by_id(user_id)
assert u is not None
if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
if u.user_type in (UserTypes.SUPPORT, UserTypes.BOT):
# support and bot users are not required to consent
return
if u["appservice_id"] is not None:
if u.appservice_id is not None:
# users registered by an appservice are exempt
return
if u["consent_version"] == self.config.consent.user_consent_version:
if u.consent_version == self.config.consent.user_consent_version:
return

consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
Expand Down
4 changes: 2 additions & 2 deletions synapse/module_api/__init__.py
Expand Up @@ -572,7 +572,7 @@ async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
Returns:
UserInfo object if a user was found, otherwise None
"""
return await self._store.get_userinfo_by_id(user_id)
return await self._store.get_user_by_id(user_id)

async def get_user_by_req(
self,
Expand Down Expand Up @@ -1878,7 +1878,7 @@ async def put_global(
raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}")

# Ensure the user exists, so we don't just write to users that aren't there.
if await self._store.get_userinfo_by_id(user_id) is None:
if await self._store.get_user_by_id(user_id) is None:
raise ValueError(f"User {user_id} does not exist on this server.")

await self._handler.add_account_data_for_user(user_id, data_type, new_data)
2 changes: 1 addition & 1 deletion synapse/rest/consent/consent_resource.py
Expand Up @@ -129,7 +129,7 @@ async def _async_render_GET(self, request: Request) -> None:
if u is None:
raise NotFoundError("Unknown user")

has_consented = u["consent_version"] == version
has_consented = u.consent_version == version
userhmac = userhmac_bytes.decode("ascii")

try:
Expand Down
6 changes: 3 additions & 3 deletions synapse/server_notices/consent_server_notices.py
Expand Up @@ -79,15 +79,15 @@ async def maybe_send_server_notice_to_user(self, user_id: str) -> None:
if u is None:
return

if u["is_guest"] and not self._send_to_guests:
if u.is_guest and not self._send_to_guests:
# don't send to guests
return

if u["consent_version"] == self._current_consent_version:
if u.consent_version == self._current_consent_version:
# user has already consented
return

if u["consent_server_notice_sent"] == self._current_consent_version:
if u.consent_server_notice_sent == self._current_consent_version:
# we've already sent a notice to the user
return

Expand Down
69 changes: 21 additions & 48 deletions synapse/storage/databases/main/registration.py
Expand Up @@ -16,7 +16,7 @@
import logging
import random
import re
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast

import attr

Expand Down Expand Up @@ -192,8 +192,8 @@ def __init__(
)

@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
"""Deprecated: use get_userinfo_by_id instead"""
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Returns info about the user account, if it exists."""

def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
# We could technically use simple_select_one here, but it would not perform
Expand All @@ -202,7 +202,7 @@ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
txn.execute(
"""
SELECT
name, password_hash, is_guest, admin, consent_version, consent_ts,
name, is_guest, admin, consent_version, consent_ts,
consent_server_notice_sent, appservice_id, creation_ts, user_type,
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
COALESCE(approved, TRUE) AS approved,
Expand All @@ -224,50 +224,23 @@ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
desc="get_user_by_id",
func=get_user_by_id_txn,
)

if row is not None:
# If we're using SQLite our boolean values will be integers. Because we
# present some of this data as is to e.g. server admins via REST APIs, we
# want to make sure we're returning the right type of data.
# Note: when adding a column name to this list, be wary of NULLable columns,
# since NULL values will be turned into False.
boolean_columns = [
"admin",
"deactivated",
"shadow_banned",
"approved",
"locked",
]
for column in boolean_columns:
row[column] = bool(row[column])

return row

async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Get a UserInfo object for a user by user ID.
Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
this method should be cached.
Args:
user_id: The user to fetch user info for.
Returns:
`UserInfo` object if user found, otherwise `None`.
"""
user_data = await self.get_user_by_id(user_id)
if not user_data:
if row is None:
return None

return UserInfo(
appservice_id=user_data["appservice_id"],
consent_server_notice_sent=user_data["consent_server_notice_sent"],
consent_version=user_data["consent_version"],
creation_ts=user_data["creation_ts"],
is_admin=bool(user_data["admin"]),
is_deactivated=bool(user_data["deactivated"]),
is_guest=bool(user_data["is_guest"]),
is_shadow_banned=bool(user_data["shadow_banned"]),
user_id=UserID.from_string(user_data["name"]),
user_type=user_data["user_type"],
appservice_id=row["appservice_id"],
consent_server_notice_sent=row["consent_server_notice_sent"],
consent_version=row["consent_version"],
consent_ts=row["consent_ts"],
creation_ts=row["creation_ts"],
is_admin=bool(row["admin"]),
is_deactivated=bool(row["deactivated"]),
is_guest=bool(row["is_guest"]),
is_shadow_banned=bool(row["shadow_banned"]),
user_id=UserID.from_string(row["name"]),
user_type=row["user_type"],
approved=bool(row["approved"]),
locked=bool(row["locked"]),
)

async def is_trial_user(self, user_id: str) -> bool:
Expand All @@ -285,10 +258,10 @@ async def is_trial_user(self, user_id: str) -> bool:

now = self._clock.time_msec()
days = self.config.server.mau_appservice_trial_days.get(
info["appservice_id"], self.config.server.mau_trial_days
info.appservice_id, self.config.server.mau_trial_days
)
trial_duration_ms = days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
is_trial = (now - info.creation_ts * 1000) < trial_duration_ms
return is_trial

@cached()
Expand Down
8 changes: 7 additions & 1 deletion synapse/types/__init__.py
Expand Up @@ -933,31 +933,37 @@ def get_verify_key_from_cross_signing_key(

@attr.s(auto_attribs=True, frozen=True, slots=True)
class UserInfo:
"""Holds information about a user. Result of get_userinfo_by_id.
"""Holds information about a user. Result of get_user_by_id.
Attributes:
user_id: ID of the user.
appservice_id: Application service ID that created this user.
consent_server_notice_sent: Version of policy documents the user has been sent.
consent_version: Version of policy documents the user has consented to.
consent_ts: Time the user consented
creation_ts: Creation timestamp of the user.
is_admin: True if the user is an admin.
is_deactivated: True if the user has been deactivated.
is_guest: True if the user is a guest user.
is_shadow_banned: True if the user has been shadow-banned.
user_type: User type (None for normal user, 'support' and 'bot' other options).
approved: If the user has been "approved" to register on the server.
locked: Whether the user's account has been locked
"""

user_id: UserID
appservice_id: Optional[int]
consent_server_notice_sent: Optional[str]
consent_version: Optional[str]
consent_ts: Optional[int]
user_type: Optional[str]
creation_ts: int
is_admin: bool
is_deactivated: bool
is_guest: bool
is_shadow_banned: bool
approved: bool
locked: bool


class UserProfile(TypedDict):
Expand Down
12 changes: 9 additions & 3 deletions tests/api/test_auth.py
Expand Up @@ -188,8 +188,11 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self) -> None:
)
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value.
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})

class FakeUserInfo:
is_guest = False

self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
self.store.get_user_by_access_token = AsyncMock(return_value=None)

request = Mock(args={})
Expand Down Expand Up @@ -341,7 +344,10 @@ def test_get_user_from_macaroon(self) -> None:
)

def test_get_guest_user_from_macaroon(self) -> None:
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
class FakeUserInfo:
is_guest = True

self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
self.store.get_user_by_access_token = AsyncMock(return_value=None)

user_id = "@baldrick:matrix.org"
Expand Down

0 comments on commit 1fdae4e

Please sign in to comment.