Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Split fetching device keys and signatures into two transactions
Browse files Browse the repository at this point in the history
I think this is simpler (and moves stuff out of the db threads)
  • Loading branch information
richvdh committed Sep 3, 2020
1 parent 6f6f371 commit 7337c48
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 44 deletions.
1 change: 1 addition & 0 deletions changelog.d/8233.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.
108 changes: 64 additions & 44 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
Expand All @@ -45,8 +46,9 @@ class DeviceKeyLookupResult:
# key) and "signatures" (a signature of the structure by the ed25519 key)
key_json = attr.ib(type=Optional[str])

# cross-signing sigs
signatures = attr.ib(type=Optional[Dict], default=None)
# cross-signing sigs on this device.
# dict from (signing user_id)->(signing device_id)->sig
signatures = attr.ib(type=Optional[Dict[str, Dict[str, str]]], factory=dict)


class EndToEndKeyWorkerStore(SQLBaseStore):
Expand Down Expand Up @@ -154,22 +156,57 @@ async def get_e2e_device_keys_and_signatures(

result = await self.db_pool.runInteraction(
"get_e2e_device_keys",
self._get_e2e_device_keys_and_signatures_txn,
self._get_e2e_device_keys_txn,
query_list,
include_all_devices,
include_deleted_devices,
)

# get the (user_id, device_id) tuples to look up cross-signatures for
signature_query = (
(user_id, device_id)
for user_id, dev in result.items()
for device_id, d in dev.items()
if d is not None
)

for batch in batch_iter(signature_query, 50):
cross_sigs_result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_for_devices_txn,
batch,
)

# add each cross-signing signature to the correct device in the result dict.
for row in cross_sigs_result:
signing_user_id = row["user_id"]
signing_key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_device_id = row["target_device_id"]
signature = row["signature"]

target_device_result = result[target_user_id][target_device_id]
target_device_signatures = target_device_result.signatures

signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
)
signing_user_signatures[signing_key_id] = signature

log_kv(result)
return result

def _get_e2e_device_keys_and_signatures_txn(
def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Get information on devices from the database
The results include the device's keys and self-signatures, but *not* any
cross-signing signatures which have been added subsequently (for which, see
get_e2e_device_keys_and_signatures)
"""
query_clauses = []
query_params = []
signature_query_clauses = []
signature_query_params = []

if include_all_devices is False:
include_deleted_devices = False
Expand All @@ -180,20 +217,12 @@ def _get_e2e_device_keys_and_signatures_txn(
for (user_id, device_id) in query_list:
query_clause = "user_id = ?"
query_params.append(user_id)
signature_query_clause = "target_user_id = ?"
signature_query_params.append(user_id)

if device_id is not None:
query_clause += " AND device_id = ?"
query_params.append(device_id)
signature_query_clause += " AND target_device_id = ?"
signature_query_params.append(device_id)

signature_query_clause += " AND user_id = ?"
signature_query_params.append(user_id)

query_clauses.append(query_clause)
signature_query_clauses.append(signature_query_clause)

sql = (
"SELECT user_id, device_id, "
Expand Down Expand Up @@ -221,41 +250,32 @@ def _get_e2e_device_keys_and_signatures_txn(
for user_id, device_id in deleted_devices:
result.setdefault(user_id, {})[device_id] = None

# get signatures on the device
signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
" OR ".join("(" + q + ")" for q in signature_query_clauses)
)

txn.execute(signature_sql, signature_query_params)
rows = self.db_pool.cursor_to_dict(txn)

# add each cross-signing signature to the correct device in the result dict.
for row in rows:
signing_user_id = row["user_id"]
signing_key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_device_id = row["target_device_id"]
signature = row["signature"]

target_user_result = result.get(target_user_id)
if not target_user_result:
continue
return result

target_device_result = target_user_result.get(target_device_id)
if not target_device_result:
# note that target_device_result will be None for deleted devices.
continue
def _get_e2e_cross_signing_signatures_for_devices_txn(
self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
) -> List[Dict]:
"""Get cross-signing signatures for a given list of devices
target_device_signatures = target_device_result.signatures
if target_device_signatures is None:
target_device_signatures = target_device_result.signatures = {}
Returns signatures made by the owner of the devices. Each entry in the result
is a dict containing the fields from the database ('user_id', 'key_id',
'target_user_id', 'target_device_id', 'signature').
"""
signature_query_clauses = []
signature_query_params = []

signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
for (user_id, device_id) in device_query:
signature_query_clauses.append(
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
)
signing_user_signatures[signing_key_id] = signature
signature_query_params.extend([user_id, device_id, user_id])

return result
signature_sql = "SELECT * FROM e2e_cross_signing_signatures WHERE %s" % (
" OR ".join("(" + q + ")" for q in signature_query_clauses)
)

txn.execute(signature_sql, signature_query_params)
return self.db_pool.cursor_to_dict(txn)

async def get_e2e_one_time_keys(
self, user_id: str, device_id: str, key_ids: List[str]
Expand Down

0 comments on commit 7337c48

Please sign in to comment.