Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve typing in user_directory files #10891

Merged
merged 5 commits into from Sep 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10891.misc
@@ -0,0 +1 @@
Improve type hinting in the user directory code.
2 changes: 2 additions & 0 deletions mypy.ini
Expand Up @@ -85,9 +85,11 @@ files =
tests/handlers/test_room_summary.py,
tests/handlers/test_send_email.py,
tests/handlers/test_sync.py,
tests/handlers/test_user_directory.py,
tests/rest/client/test_login.py,
tests/rest/client/test_auth.py,
tests/storage/test_state.py,
tests/storage/test_user_directory.py,
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py

Expand Down
124 changes: 89 additions & 35 deletions synapse/storage/databases/main/user_directory.py
Expand Up @@ -14,14 +14,28 @@

import logging
import re
from typing import Any, Dict, Iterable, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
cast,
)

if TYPE_CHECKING:
from synapse.server import HomeServer

from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id, get_localpart_from_id
from synapse.storage.types import Connection
from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached

logger = logging.getLogger(__name__)
Expand All @@ -36,7 +50,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(
self,
database: DatabasePool,
db_conn: Connection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self.server_name = hs.hostname
Expand All @@ -57,10 +76,12 @@ def __init__(self, database: DatabasePool, db_conn, hs):
"populate_user_directory_cleanup", self._populate_user_directory_cleanup
)

async def _populate_user_directory_createtables(self, progress, batch_size):
async def _populate_user_directory_createtables(
self, progress: JsonDict, batch_size: int
) -> int:

# Get all the rooms that we want to process.
def _make_staging_area(txn):
def _make_staging_area(txn: LoggingTransaction) -> None:
sql = (
"CREATE TABLE IF NOT EXISTS "
+ TEMP_TABLE
Expand Down Expand Up @@ -110,16 +131,20 @@ def _make_staging_area(txn):
)
return 1

async def _populate_user_directory_cleanup(self, progress, batch_size):
async def _populate_user_directory_cleanup(
self,
progress: JsonDict,
batch_size: int,
) -> int:
"""
Update the user directory stream position, then clean up the old tables.
"""
position = await self.db_pool.simple_select_one_onecol(
TEMP_TABLE + "_position", None, "position"
TEMP_TABLE + "_position", {}, "position"
)
await self.update_user_directory_stream_pos(position)

def _delete_staging_area(txn):
def _delete_staging_area(txn: LoggingTransaction) -> None:
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
Expand All @@ -133,18 +158,32 @@ def _delete_staging_area(txn):
)
return 1

async def _populate_user_directory_process_rooms(self, progress, batch_size):
async def _populate_user_directory_process_rooms(
self, progress: JsonDict, batch_size: int
) -> int:
"""
Rescan the state of all rooms so we can track

- who's in a public room;
- which local users share a private room with other users (local
and remote); and
- who should be in the user_directory.

Args:
progress (dict)
batch_size (int): Maximum number of state events to process
per cycle.

Returns:
number of events processed.
"""
# If no progress has been recorded yet, delete everything.
if not progress:
await self.delete_all_from_user_dir()

def _get_next_batch(txn):
def _get_next_batch(
txn: LoggingTransaction,
) -> Optional[Sequence[Tuple[str, int]]]:
# Only fetch 250 rooms, so we don't fetch too many at once, even
# if those 250 rooms have less than batch_size state events.
sql = """
Expand All @@ -155,15 +194,17 @@ def _get_next_batch(txn):
TEMP_TABLE + "_rooms",
)
txn.execute(sql)
rooms_to_work_on = txn.fetchall()
rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())

if not rooms_to_work_on:
return None

# Get how many are left to process, so we can give status on how
# far we are in processing
txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
progress["remaining"] = txn.fetchone()[0]
result = txn.fetchone()
assert result is not None
progress["remaining"] = result[0]

return rooms_to_work_on

Expand Down Expand Up @@ -261,29 +302,33 @@ def _get_next_batch(txn):

return processed_event_count

async def _populate_user_directory_process_users(self, progress, batch_size):
async def _populate_user_directory_process_users(
self, progress: JsonDict, batch_size: int
) -> int:
"""
Add all local users to the user directory.
"""

def _get_next_batch(txn):
def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]:
sql = "SELECT user_id FROM %s LIMIT %s" % (
TEMP_TABLE + "_users",
str(batch_size),
)
txn.execute(sql)
users_to_work_on = txn.fetchall()
user_result = cast(List[Tuple[str]], txn.fetchall())

if not users_to_work_on:
if not user_result:
return None

users_to_work_on = [x[0] for x in users_to_work_on]
users_to_work_on = [x[0] for x in user_result]

# Get how many are left to process, so we can give status on how
# far we are in processing
sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users"
txn.execute(sql)
progress["remaining"] = txn.fetchone()[0]
count_result = txn.fetchone()
assert count_result is not None
progress["remaining"] = count_result[0]

return users_to_work_on

Expand Down Expand Up @@ -324,7 +369,7 @@ def _get_next_batch(txn):

return len(users_to_work_on)

async def is_room_world_readable_or_publicly_joinable(self, room_id):
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
"""Check if the room is either world_readable or publically joinable"""

# Create a state filter that only queries join and history state event
Expand Down Expand Up @@ -368,7 +413,7 @@ async def update_profile_in_user_dir(
if not isinstance(avatar_url, str):
avatar_url = None

def _update_profile_in_user_dir_txn(txn):
def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_upsert_txn(
txn,
table="user_directory",
Expand Down Expand Up @@ -435,7 +480,7 @@ async def add_users_who_share_private_room(
for user_id, other_user_id in user_id_tuples
],
value_names=(),
value_values=None,
value_values=(),
desc="add_users_who_share_room",
)

Expand All @@ -454,14 +499,14 @@ async def add_users_in_public_rooms(
key_names=["user_id", "room_id"],
key_values=[(user_id, room_id) for user_id in user_ids],
value_names=(),
value_values=None,
value_values=(),
desc="add_users_in_public_rooms",
)

async def delete_all_from_user_dir(self) -> None:
"""Delete the entire user directory"""

def _delete_all_from_user_dir_txn(txn):
def _delete_all_from_user_dir_txn(txn: LoggingTransaction) -> None:
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_public_rooms")
Expand All @@ -473,7 +518,7 @@ def _delete_all_from_user_dir_txn(txn):
)

@cached()
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
Expand All @@ -497,7 +542,12 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(
self,
database: DatabasePool,
db_conn: Connection,
hs: "HomeServer",
) -> None:
super().__init__(database, db_conn, hs)

self._prefer_local_users_in_search = (
Expand All @@ -506,7 +556,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):
self._server_name = hs.config.server.server_name

async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
def _remove_from_user_dir_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
Expand All @@ -532,7 +582,7 @@ def _remove_from_user_dir_txn(txn):
"remove_from_user_dir", _remove_from_user_dir_txn
)

async def get_users_in_dir_due_to_room(self, room_id):
async def get_users_in_dir_due_to_room(self, room_id: str) -> Set[str]:
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
Expand Down Expand Up @@ -565,7 +615,7 @@ async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
room_id
"""

def _remove_user_who_share_room_txn(txn):
def _remove_user_who_share_room_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
Expand All @@ -586,7 +636,7 @@ def _remove_user_who_share_room_txn(txn):
"remove_user_who_share_room", _remove_user_who_share_room_txn
)

async def get_user_dir_rooms_user_is_in(self, user_id):
async def get_user_dir_rooms_user_is_in(self, user_id: str) -> List[str]:
"""
Returns the rooms that a user is in.

Expand Down Expand Up @@ -628,7 +678,9 @@ async def get_shared_rooms_for_users(
A set of room IDs that the users share.
"""

def _get_shared_rooms_for_users_txn(txn):
def _get_shared_rooms_for_users_txn(
txn: LoggingTransaction,
) -> List[Dict[str, str]]:
txn.execute(
"""
SELECT p1.room_id
Expand Down Expand Up @@ -669,7 +721,9 @@ async def get_user_directory_stream_pos(self) -> Optional[int]:
desc="get_user_directory_stream_pos",
)

async def search_user_dir(self, user_id, search_term, limit):
async def search_user_dir(
self, user_id: str, search_term: str, limit: int
) -> JsonDict:
"""Searches for users in directory

Returns:
Expand Down Expand Up @@ -705,7 +759,7 @@ async def search_user_dir(self, user_id, search_term, limit):
# We allow manipulating the ranking algorithm by injecting statements
# based on config options.
additional_ordering_statements = []
ordering_arguments = ()
ordering_arguments: Tuple[str, ...] = ()

if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
Expand Down Expand Up @@ -811,7 +865,7 @@ async def search_user_dir(self, user_id, search_term, limit):
return {"limited": limited, "results": results}


def _parse_query_sqlite(search_term):
def _parse_query_sqlite(search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to the database.
We use this so that we can add prefix matching, which isn't something
Expand All @@ -826,7 +880,7 @@ def _parse_query_sqlite(search_term):
return " & ".join("(%s* OR %s)" % (result, result) for result in results)


def _parse_query_postgres(search_term):
def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to the database.
We use this so that we can add prefix matching, which isn't something
Expand Down
5 changes: 3 additions & 2 deletions tests/handlers/test_user_directory.py
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
from unittest.mock import Mock

from twisted.internet import defer
Expand Down Expand Up @@ -285,7 +286,7 @@ def _compress_shared(self, shared):
r.add((i["user_id"], i["other_user_id"], i["room_id"]))
return r

def get_users_in_public_rooms(self):
def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
r = self.get_success(
self.store.db_pool.simple_select_list(
"users_in_public_rooms", None, ("user_id", "room_id")
Expand All @@ -296,7 +297,7 @@ def get_users_in_public_rooms(self):
retval.append((i["user_id"], i["room_id"]))
return retval

def get_users_who_share_private_rooms(self):
def get_users_who_share_private_rooms(self) -> List[Tuple[str, str, str]]:
return self.get_success(
self.store.db_pool.simple_select_list(
"users_who_share_private_rooms",
Expand Down