Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve typing in user_directory files #10891

Merged
merged 5 commits into from Sep 24, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10891.misc
@@ -0,0 +1 @@
Improve type hinting in the user_directory code.
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 2 additions & 0 deletions mypy.ini
Expand Up @@ -85,9 +85,11 @@ files =
tests/handlers/test_room_summary.py,
tests/handlers/test_send_email.py,
tests/handlers/test_sync.py,
tests/handlers/test_user_directory.py,
tests/rest/client/test_login.py,
tests/rest/client/test_auth.py,
tests/storage/test_state.py,
tests/storage/test_user_directory.py,
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py

Expand Down
124 changes: 89 additions & 35 deletions synapse/storage/databases/main/user_directory.py
Expand Up @@ -14,14 +14,28 @@

import logging
import re
from typing import Any, Dict, Iterable, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
cast,
)

if TYPE_CHECKING:
from synapse.server import HomeServer

from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id, get_localpart_from_id
from synapse.storage.types import Connection
from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached

logger = logging.getLogger(__name__)
Expand All @@ -36,7 +50,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(
self,
database: DatabasePool,
db_conn: Connection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self.server_name = hs.hostname
Expand All @@ -57,10 +76,12 @@ def __init__(self, database: DatabasePool, db_conn, hs):
"populate_user_directory_cleanup", self._populate_user_directory_cleanup
)

async def _populate_user_directory_createtables(self, progress, batch_size):
async def _populate_user_directory_createtables(
self, progress: JsonDict, batch_size: int
) -> int:

# Get all the rooms that we want to process.
def _make_staging_area(txn):
def _make_staging_area(txn: LoggingTransaction) -> None:
sql = (
"CREATE TABLE IF NOT EXISTS "
+ TEMP_TABLE
Expand Down Expand Up @@ -110,16 +131,20 @@ def _make_staging_area(txn):
)
return 1

async def _populate_user_directory_cleanup(self, progress, batch_size):
async def _populate_user_directory_cleanup(
self,
progress: JsonDict,
batch_size: int,
) -> int:
"""
Update the user directory stream position, then clean up the old tables.
"""
position = await self.db_pool.simple_select_one_onecol(
TEMP_TABLE + "_position", None, "position"
TEMP_TABLE + "_position", {}, "position"
)
await self.update_user_directory_stream_pos(position)

def _delete_staging_area(txn):
def _delete_staging_area(txn: LoggingTransaction) -> None:
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
Expand All @@ -133,18 +158,32 @@ def _delete_staging_area(txn):
)
return 1

async def _populate_user_directory_process_rooms(self, progress, batch_size):
async def _populate_user_directory_process_rooms(
self, progress: JsonDict, batch_size: int
) -> int:
"""
Rescan the state of all rooms so we can track

- who's in a public room;
- which local users share a private room with other users (local
and remote); and
- who should be in the user_directory.

Args:
progress (dict)
batch_size (int): Maximum number of state events to process
per cycle.

Returns:
number of events processed.
"""
# If we don't have progress filed, delete everything.
if not progress:
await self.delete_all_from_user_dir()

def _get_next_batch(txn):
def _get_next_batch(
txn: LoggingTransaction,
) -> Optional[Sequence[Tuple[str, int]]]:
# Only fetch 250 rooms, so we don't fetch too many at once, even
# if those 250 rooms have less than batch_size state events.
sql = """
Expand All @@ -155,15 +194,17 @@ def _get_next_batch(txn):
TEMP_TABLE + "_rooms",
)
txn.execute(sql)
rooms_to_work_on = txn.fetchall()
rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())

if not rooms_to_work_on:
return None

# Get how many are left to process, so we can give status on how
# far we are in processing
txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
progress["remaining"] = txn.fetchone()[0]
result = txn.fetchone()
assert result is not None
progress["remaining"] = result[0]

return rooms_to_work_on

Expand Down Expand Up @@ -261,29 +302,33 @@ def _get_next_batch(txn):

return processed_event_count

async def _populate_user_directory_process_users(self, progress, batch_size):
async def _populate_user_directory_process_users(
self, progress: JsonDict, batch_size: int
) -> int:
"""
Add all local users to the user directory.
"""

def _get_next_batch(txn):
def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]:
sql = "SELECT user_id FROM %s LIMIT %s" % (
TEMP_TABLE + "_users",
str(batch_size),
)
txn.execute(sql)
users_to_work_on = txn.fetchall()
user_result = cast(List[Tuple[str]], txn.fetchall())

if not users_to_work_on:
if not user_result:
return None

users_to_work_on = [x[0] for x in users_to_work_on]
users_to_work_on = [x[0] for x in user_result]

# Get how many are left to process, so we can give status on how
# far we are in processing
sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users"
txn.execute(sql)
progress["remaining"] = txn.fetchone()[0]
count_result = txn.fetchone()
assert count_result is not None
progress["remaining"] = count_result[0]

return users_to_work_on

Expand Down Expand Up @@ -324,7 +369,7 @@ def _get_next_batch(txn):

return len(users_to_work_on)

async def is_room_world_readable_or_publicly_joinable(self, room_id):
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
"""Check if the room is either world_readable or publicly joinable"""

# Create a state filter that only queries join and history state event
Expand Down Expand Up @@ -368,7 +413,7 @@ async def update_profile_in_user_dir(
if not isinstance(avatar_url, str):
avatar_url = None

def _update_profile_in_user_dir_txn(txn):
def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_upsert_txn(
txn,
table="user_directory",
Expand Down Expand Up @@ -435,7 +480,7 @@ async def add_users_who_share_private_room(
for user_id, other_user_id in user_id_tuples
],
value_names=(),
value_values=None,
value_values=(),
desc="add_users_who_share_room",
)

Expand All @@ -454,14 +499,14 @@ async def add_users_in_public_rooms(
key_names=["user_id", "room_id"],
key_values=[(user_id, room_id) for user_id in user_ids],
value_names=(),
value_values=None,
value_values=(),
desc="add_users_in_public_rooms",
)

async def delete_all_from_user_dir(self) -> None:
"""Delete the entire user directory"""

def _delete_all_from_user_dir_txn(txn):
def _delete_all_from_user_dir_txn(txn: LoggingTransaction) -> None:
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_public_rooms")
Expand All @@ -473,7 +518,7 @@ def _delete_all_from_user_dir_txn(txn):
)

@cached()
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
Expand All @@ -497,7 +542,12 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(
self,
database: DatabasePool,
db_conn: Connection,
hs: "HomeServer",
) -> None:
super().__init__(database, db_conn, hs)

self._prefer_local_users_in_search = (
Expand All @@ -506,7 +556,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):
self._server_name = hs.config.server.server_name

async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
def _remove_from_user_dir_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
Expand All @@ -532,7 +582,7 @@ def _remove_from_user_dir_txn(txn):
"remove_from_user_dir", _remove_from_user_dir_txn
)

async def get_users_in_dir_due_to_room(self, room_id):
async def get_users_in_dir_due_to_room(self, room_id: str) -> Set[str]:
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
Expand Down Expand Up @@ -565,7 +615,7 @@ async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
room_id
"""

def _remove_user_who_share_room_txn(txn):
def _remove_user_who_share_room_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
Expand All @@ -586,7 +636,7 @@ def _remove_user_who_share_room_txn(txn):
"remove_user_who_share_room", _remove_user_who_share_room_txn
)

async def get_user_dir_rooms_user_is_in(self, user_id):
async def get_user_dir_rooms_user_is_in(self, user_id: str) -> List[str]:
"""
Returns the rooms that a user is in.

Expand Down Expand Up @@ -628,7 +678,9 @@ async def get_shared_rooms_for_users(
A set of room IDs that the users share.
"""

def _get_shared_rooms_for_users_txn(txn):
def _get_shared_rooms_for_users_txn(
txn: LoggingTransaction,
) -> List[Dict[str, str]]:
txn.execute(
"""
SELECT p1.room_id
Expand Down Expand Up @@ -669,7 +721,9 @@ async def get_user_directory_stream_pos(self) -> Optional[int]:
desc="get_user_directory_stream_pos",
)

async def search_user_dir(self, user_id, search_term, limit):
async def search_user_dir(
self, user_id: str, search_term: str, limit: int
) -> JsonDict:
"""Searches for users in directory

Returns:
Expand Down Expand Up @@ -705,7 +759,7 @@ async def search_user_dir(self, user_id, search_term, limit):
# We allow manipulating the ranking algorithm by injecting statements
# based on config options.
additional_ordering_statements = []
ordering_arguments = ()
ordering_arguments: Tuple[str, ...] = ()

if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
Expand Down Expand Up @@ -811,7 +865,7 @@ async def search_user_dir(self, user_id, search_term, limit):
return {"limited": limited, "results": results}


def _parse_query_sqlite(search_term):
def _parse_query_sqlite(search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
We use this so that we can add prefix matching, which isn't something
Expand All @@ -826,7 +880,7 @@ def _parse_query_sqlite(search_term):
return " & ".join("(%s* OR %s)" % (result, result) for result in results)


def _parse_query_postgres(search_term):
def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
We use this so that we can add prefix matching, which isn't something
Expand Down
5 changes: 3 additions & 2 deletions tests/handlers/test_user_directory.py
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
from unittest.mock import Mock

from twisted.internet import defer
Expand Down Expand Up @@ -285,7 +286,7 @@ def _compress_shared(self, shared):
r.add((i["user_id"], i["other_user_id"], i["room_id"]))
return r

def get_users_in_public_rooms(self):
def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
r = self.get_success(
self.store.db_pool.simple_select_list(
"users_in_public_rooms", None, ("user_id", "room_id")
Expand All @@ -296,7 +297,7 @@ def get_users_in_public_rooms(self):
retval.append((i["user_id"], i["room_id"]))
return retval

def get_users_who_share_private_rooms(self):
def get_users_who_share_private_rooms(self) -> List[Tuple[str, str, str]]:
return self.get_success(
self.store.db_pool.simple_select_list(
"users_who_share_private_rooms",
Expand Down