Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to media repository storage module #11311

Merged
merged 2 commits into from
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11311.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to storage classes.
1 change: 0 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ exclude = (?x)
|synapse/storage/databases/main/events_forward_extremities.py
|synapse/storage/databases/main/events_worker.py
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/media_repository.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
|synapse/storage/databases/main/presence.py
Expand Down
8 changes: 4 additions & 4 deletions synapse/rest/media/v1/preview_url_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.oembed import OEmbedProvider
from synapse.types import JsonDict
from synapse.types import JsonDict, UserID
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
Expand Down Expand Up @@ -231,7 +231,7 @@ async def _async_render_GET(self, request: SynapseRequest) -> None:
og = await make_deferred_yieldable(observable.observe())
respond_with_json_bytes(request, 200, og, send_cors=True)

async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes:
"""Check the db, and download the URL and build a preview

Args:
Expand Down Expand Up @@ -360,7 +360,7 @@ async def _do_preview(self, url: str, user: str, ts: int) -> bytes:

return jsonog.encode("utf8")

async def _download_url(self, url: str, user: str) -> MediaInfo:
async def _download_url(self, url: str, user: UserID) -> MediaInfo:
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
Expand Down Expand Up @@ -450,7 +450,7 @@ async def _download_url(self, url: str, user: str) -> MediaInfo:
)

async def _precache_image_url(
self, user: str, media_info: MediaInfo, og: JsonDict
self, user: UserID, media_info: MediaInfo, og: JsonDict
) -> None:
"""
Pre-cache the image (if one exists) for posterity
Expand Down
141 changes: 84 additions & 57 deletions synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.types import JsonDict, UserID

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -46,7 +61,12 @@ class MediaSortOrder(Enum):


class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self.db_pool.updates.register_background_index_update(
Expand Down Expand Up @@ -102,13 +122,15 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self._drop_media_index_without_method,
)

async def _drop_media_index_without_method(self, progress, batch_size):
async def _drop_media_index_without_method(
self, progress: JsonDict, batch_size: int
) -> int:
"""background update handler which removes the old constraints.

Note that this is only run on postgres.
"""

def f(txn):
def f(txn: LoggingTransaction) -> None:
txn.execute(
"ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
)
Expand All @@ -126,7 +148,12 @@ def f(txn):
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""

def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname

Expand Down Expand Up @@ -174,7 +201,9 @@ async def get_local_media_by_user_paginate(
plus the total count of all the user's media
"""

def get_local_media_by_user_paginate_txn(txn):
def get_local_media_by_user_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[Dict[str, Any]], int]:

# Set ordering
order_by_column = MediaSortOrder(order_by).value
Expand All @@ -184,14 +213,14 @@ def get_local_media_by_user_paginate_txn(txn):
else:
order = "ASC"

args = [user_id]
args: List[Union[str, int]] = [user_id]
sql = """
SELECT COUNT(*) as total_media
FROM local_media_repository
WHERE user_id = ?
"""
txn.execute(sql, args)
count = txn.fetchone()[0]
count = txn.fetchone()[0] # type: ignore[index]

sql = """
SELECT
Expand Down Expand Up @@ -268,7 +297,7 @@ async def get_local_media_before(
)
sql += sql_keep

def _get_local_media_before_txn(txn):
def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (before_ts, before_ts, size_gt))
return [row[0] for row in txn]

Expand All @@ -278,13 +307,13 @@ def _get_local_media_before_txn(txn):

async def store_local_media(
self,
media_id,
media_type,
time_now_ms,
upload_name,
media_length,
user_id,
url_cache=None,
media_id: str,
media_type: str,
time_now_ms: int,
upload_name: Optional[str],
media_length: int,
user_id: UserID,
url_cache: Optional[str] = None,
) -> None:
await self.db_pool.simple_insert(
"local_media_repository",
Expand Down Expand Up @@ -315,7 +344,7 @@ async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
None if the URL isn't cached.
"""

def get_url_cache_txn(txn):
def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
# get the most recently cached result (relative to the given ts)
sql = (
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
Expand Down Expand Up @@ -359,7 +388,7 @@ def get_url_cache_txn(txn):

async def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
):
) -> None:
await self.db_pool.simple_insert(
"local_media_repository_url_cache",
{
Expand Down Expand Up @@ -390,13 +419,13 @@ async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]

async def store_local_thumbnail(
self,
media_id,
thumbnail_width,
thumbnail_height,
thumbnail_type,
thumbnail_method,
thumbnail_length,
):
media_id: str,
thumbnail_width: int,
thumbnail_height: int,
thumbnail_type: str,
thumbnail_method: str,
thumbnail_length: int,
) -> None:
await self.db_pool.simple_upsert(
table="local_media_repository_thumbnails",
keyvalues={
Expand Down Expand Up @@ -430,14 +459,14 @@ async def get_cached_remote_media(

async def store_cached_remote_media(
self,
origin,
media_id,
media_type,
media_length,
time_now_ms,
upload_name,
filesystem_id,
):
origin: str,
media_id: str,
media_type: str,
media_length: int,
time_now_ms: int,
upload_name: Optional[str],
filesystem_id: str,
) -> None:
await self.db_pool.simple_insert(
"remote_media_cache",
{
Expand All @@ -458,7 +487,7 @@ async def update_cached_last_access_time(
local_media: Iterable[str],
remote_media: Iterable[Tuple[str, str]],
time_ms: int,
):
) -> None:
"""Updates the last access time of the given media

Args:
Expand All @@ -467,7 +496,7 @@ async def update_cached_last_access_time(
time_ms: Current time in milliseconds
"""

def update_cache_txn(txn):
def update_cache_txn(txn: LoggingTransaction) -> None:
sql = (
"UPDATE remote_media_cache SET last_access_ts = ?"
" WHERE media_origin = ? AND media_id = ?"
Expand All @@ -488,7 +517,7 @@ def update_cache_txn(txn):

txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))

return await self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
)

Expand Down Expand Up @@ -542,15 +571,15 @@ async def get_remote_media_thumbnail(

async def store_remote_media_thumbnail(
self,
origin,
media_id,
filesystem_id,
thumbnail_width,
thumbnail_height,
thumbnail_type,
thumbnail_method,
thumbnail_length,
):
origin: str,
media_id: str,
filesystem_id: str,
thumbnail_width: int,
thumbnail_height: int,
thumbnail_type: str,
thumbnail_method: str,
thumbnail_length: int,
) -> None:
await self.db_pool.simple_upsert(
table="remote_media_cache_thumbnails",
keyvalues={
Expand All @@ -566,7 +595,7 @@ async def store_remote_media_thumbnail(
desc="store_remote_media_thumbnail",
)

async def get_remote_media_before(self, before_ts):
async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]:
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
Expand Down Expand Up @@ -602,26 +631,24 @@ async def get_expired_url_cache(self, now_ts: int) -> List[str]:
" LIMIT 500"
)

def _get_expired_url_cache_txn(txn):
def _get_expired_url_cache_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]

return await self.db_pool.runInteraction(
"get_expired_url_cache", _get_expired_url_cache_txn
)

async def delete_url_cache(self, media_ids):
async def delete_url_cache(self, media_ids: Collection[str]) -> None:
if len(media_ids) == 0:
return

sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"

def _delete_url_cache_txn(txn):
def _delete_url_cache_txn(txn: LoggingTransaction) -> None:
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])

return await self.db_pool.runInteraction(
"delete_url_cache", _delete_url_cache_txn
)
await self.db_pool.runInteraction("delete_url_cache", _delete_url_cache_txn)

async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
sql = (
Expand All @@ -631,19 +658,19 @@ async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
" LIMIT 500"
)

def _get_url_cache_media_before_txn(txn):
def _get_url_cache_media_before_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]

return await self.db_pool.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn
)

async def delete_url_cache_media(self, media_ids):
async def delete_url_cache_media(self, media_ids: Collection[str]) -> None:
if len(media_ids) == 0:
return

def _delete_url_cache_media_txn(txn):
def _delete_url_cache_media_txn(txn: LoggingTransaction) -> None:
sql = "DELETE FROM local_media_repository WHERE media_id = ?"

txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
Expand All @@ -652,6 +679,6 @@ def _delete_url_cache_media_txn(txn):

txn.execute_batch(sql, [(media_id,) for media_id in media_ids])

return await self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
)