Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add support for claiming multiple OTKs at once. #15468

Merged
merged 9 commits into from Apr 27, 2023
1 change: 1 addition & 0 deletions changelog.d/15468.misc
@@ -0,0 +1 @@
Support claiming more than one OTK at a time.
29 changes: 20 additions & 9 deletions synapse/appservice/api.py
Expand Up @@ -442,8 +442,10 @@ async def push_bulk(
return False

async def claim_client_keys(
self, service: "ApplicationService", query: List[Tuple[str, str, str]]
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
self, service: "ApplicationService", query: List[Tuple[str, str, str, int]]
) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
"""Claim one time keys from an application service.

Note that any error (including a timeout) is treated as the application
Expand All @@ -469,8 +471,10 @@ async def claim_client_keys(

# Create the expected payload shape.
body: Dict[str, Dict[str, List[str]]] = {}
for user_id, device, algorithm in query:
body.setdefault(user_id, {}).setdefault(device, []).append(algorithm)
for user_id, device, algorithm, count in query:
body.setdefault(user_id, {}).setdefault(device, []).extend(
[algorithm] * count
)

uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
try:
Expand All @@ -493,11 +497,18 @@ async def claim_client_keys(
# or if some are still missing.
#
# TODO This places a lot of faith in the response shape being correct.
missing = [
(user_id, device, algorithm)
for user_id, device, algorithm in query
if algorithm not in response.get(user_id, {}).get(device, [])
]
missing = []
for user_id, device, algorithm, count in query:
# The number of keys responded for this algorithm.
response_count = sum(
clokep marked this conversation as resolved.
Show resolved Hide resolved
key_id.startswith(f"{algorithm}:")
for key_id in response.get(user_id, {}).get(device, {})
)
count -= response_count
# If the appservice responds with fewer keys than requested, then
# consider the request unfulfilled.
if count > 0:
missing.append((user_id, device, algorithm, count))

return response, missing

Expand Down
47 changes: 46 additions & 1 deletion synapse/federation/federation_client.py
Expand Up @@ -235,7 +235,10 @@ async def query_user_devices(
)

async def claim_client_keys(
self, destination: str, content: JsonDict, timeout: Optional[int]
self,
destination: str,
query: Dict[str, Dict[str, Dict[str, int]]],
timeout: Optional[int],
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.

Expand All @@ -247,6 +250,48 @@ async def claim_client_keys(
The JSON object from the response
"""
sent_queries_counter.labels("client_one_time_keys").inc()

# Convert the query with counts into a stable and unstable query and check
# if attempting to claim more than 1 OTK.
content: Dict[str, Dict[str, str]] = {}
unstable_content: Dict[str, Dict[str, List[str]]] = {}
use_unstable = False
for user_id, one_time_keys in query.items():
for device_id, algorithms in one_time_keys.items():
if any(count > 1 for count in algorithms.values()):
use_unstable = True
if algorithms:
# Choose the first algorithm only for the stable query.
content.setdefault(user_id, {})[device_id] = next(iter(algorithms))
# Flatten the map of algorithm -> count to a list repeating
# each algorithm count times for the unstable query.
unstable_content.setdefault(user_id, {})[device_id] = list(
itertools.chain(
*(
itertools.repeat(algorithm, count)
for algorithm, count in algorithms.items()
)
)
)
MatMaul marked this conversation as resolved.
Show resolved Hide resolved

if use_unstable:
try:
return await self.transport_layer.claim_client_keys_unstable(
destination, unstable_content, timeout
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
# fallback to the v1 endpoint. Otherwise, consider it a legitimate error
# and raise.
if not is_unknown_endpoint(e):
raise

logger.debug(
"Couldn't claim client keys with the unstable API, falling back to the v1 API"
)
else:
logger.debug("Skipping unstable claim client keys API")

return await self.transport_layer.claim_client_keys(
destination, content, timeout
)
Expand Down
7 changes: 1 addition & 6 deletions synapse/federation/federation_server.py
Expand Up @@ -1005,13 +1005,8 @@ async def on_query_user_devices(

@trace
async def on_claim_client_keys(
self, origin: str, content: JsonDict, always_include_fallback_keys: bool
self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
) -> Dict[str, Any]:
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))

log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self._e2e_keys_handler.claim_local_one_time_keys(
query, always_include_fallback_keys=always_include_fallback_keys
Expand Down
45 changes: 44 additions & 1 deletion synapse/federation/transport/client.py
Expand Up @@ -669,7 +669,50 @@ async def claim_client_keys(
path = _create_v1_path("/user/keys/claim")

return await self.client.post_json(
destination=destination, path=path, data=query_content, timeout=timeout
destination=destination,
path=path,
data={"one_time_keys": query_content},
timeout=timeout,
)

async def claim_client_keys_unstable(
self, destination: str, query_content: JsonDict, timeout: Optional[int]
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.

Request:
{
"one_time_keys": {
"<user_id>": {
"<device_id>": {"<algorithm>": <count>}
}
}
}

Response:
{
"device_keys": {
clokep marked this conversation as resolved.
Show resolved Hide resolved
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": "<key_base64>"
}
}
}
}

Args:
destination: The server to query.
query_content: The user ids to query.
Returns:
A dict containing the one-time keys.
"""
path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/user/keys/claim")

return await self.client.post_json(
destination=destination,
path=path,
data={"one_time_keys": query_content},
timeout=timeout,
)

async def get_missing_events(
Expand Down
25 changes: 21 additions & 4 deletions synapse/federation/transport/server/federation.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import Counter
from typing import (
TYPE_CHECKING,
Dict,
Expand Down Expand Up @@ -577,16 +578,23 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
# Generate a count for each algorithm, which is hard-coded to 1.
key_query: List[Tuple[str, str, str, int]] = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
key_query.append((user_id, device_id, algorithm, 1))

response = await self.handler.on_claim_client_keys(
origin, content, always_include_fallback_keys=False
key_query, always_include_fallback_keys=False
)
return 200, response


class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
"""
Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
always includes fallback keys in the response.
Identical to the stable endpoint (FederationClientKeysClaimServlet) except
it allows for querying for multiple OTKs at once and always includes fallback
keys in the response.
"""

PREFIX = FEDERATION_UNSTABLE_PREFIX
Expand All @@ -596,8 +604,16 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
# Generate a count for each algorithm.
key_query: List[Tuple[str, str, str, int]] = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithms in device_keys.items():
counts = Counter(algorithms)
for algorithm, count in counts.items():
key_query.append((user_id, device_id, algorithm, count))

response = await self.handler.on_claim_client_keys(
origin, content, always_include_fallback_keys=True
key_query, always_include_fallback_keys=True
)
return 200, response

Expand Down Expand Up @@ -805,6 +821,7 @@ async def on_POST(
FederationClientKeysQueryServlet,
FederationUserDevicesQueryServlet,
FederationClientKeysClaimServlet,
FederationUnstableClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
FederationVersionServlet,
Expand Down
14 changes: 8 additions & 6 deletions synapse/handlers/appservice.py
Expand Up @@ -841,8 +841,10 @@ async def _check_user_exists(self, user_id: str) -> bool:
return True

async def claim_e2e_one_time_keys(
self, query: Iterable[Tuple[str, str, str]]
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
self, query: Iterable[Tuple[str, str, str, int]]
) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
"""Claim one time keys from application services.

Users which are exclusively owned by an application service are sent a
Expand All @@ -863,18 +865,18 @@ async def claim_e2e_one_time_keys(
services = self.store.get_app_services()

# Partition the users by appservice.
query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
query_by_appservice: Dict[str, List[Tuple[str, str, str, int]]] = {}
missing = []
for user_id, device, algorithm in query:
for user_id, device, algorithm, count in query:
if not self.store.get_if_app_services_interested_in_user(user_id):
missing.append((user_id, device, algorithm))
missing.append((user_id, device, algorithm, count))
continue

# Find the associated appservice.
for service in services:
if service.is_exclusive_user(user_id):
query_by_appservice.setdefault(service.id, []).append(
(user_id, device, algorithm)
(user_id, device, algorithm, count)
)
continue

Expand Down
31 changes: 21 additions & 10 deletions synapse/handlers/e2e_keys.py
Expand Up @@ -564,7 +564,7 @@ async def on_federation_query_client_keys(

async def claim_local_one_time_keys(
self,
local_query: List[Tuple[str, str, str]],
local_query: List[Tuple[str, str, str, int]],
always_include_fallback_keys: bool,
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
"""Claim one time keys for local users.
Expand All @@ -581,6 +581,12 @@ async def claim_local_one_time_keys(
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
"""

# Cap the number of OTKs that can be claimed at once to avoid abuse.
local_query = [
(user_id, device_id, algorithm, min(count, 5))
MatMaul marked this conversation as resolved.
Show resolved Hide resolved
for user_id, device_id, algorithm, count in local_query
]

otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)

# If the application services have not provided any keys via the C-S
Expand All @@ -607,7 +613,7 @@ async def claim_local_one_time_keys(
# from the appservice for that user ID / device ID. If it is found,
# check if any of the keys match the requested algorithm & are a
# fallback key.
for user_id, device_id, algorithm in local_query:
for user_id, device_id, algorithm, _count in local_query:
# Check if the appservice responded for this query.
as_result = appservice_results.get(user_id, {}).get(device_id, {})
found_otk = False
Expand All @@ -630,13 +636,17 @@ async def claim_local_one_time_keys(
.get(device_id, {})
.keys()
)
# Note that it doesn't make sense to request more than 1 fallback key
# per (user_id, device_id, algorithm).
fallback_query.append((user_id, device_id, algorithm, mark_as_used))

else:
# All fallback keys get marked as used.
fallback_query = [
# Note that it doesn't make sense to request more than 1 fallback key
# per (user_id, device_id, algorithm).
(user_id, device_id, algorithm, True)
for user_id, device_id, algorithm in not_found
for user_id, device_id, algorithm, count in not_found
]

# For each user that does not have a one-time keys available, see if
Expand All @@ -650,18 +660,19 @@ async def claim_local_one_time_keys(
@trace
async def claim_one_time_keys(
self,
query: Dict[str, Dict[str, Dict[str, str]]],
query: Dict[str, Dict[str, Dict[str, int]]],
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict:
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
local_query: List[Tuple[str, str, str, int]] = []
remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}

for user_id, one_time_keys in query.get("one_time_keys", {}).items():
for user_id, one_time_keys in query.items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
for device_id, algorithm in one_time_keys.items():
local_query.append((user_id, device_id, algorithm))
for device_id, algorithms in one_time_keys.items():
for algorithm, count in algorithms.items():
local_query.append((user_id, device_id, algorithm, count))
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = one_time_keys
Expand Down Expand Up @@ -692,7 +703,7 @@ async def claim_client_keys(destination: str) -> None:
device_keys = remote_queries[destination]
try:
remote_result = await self.federation.claim_client_keys(
destination, {"one_time_keys": device_keys}, timeout=timeout
destination, device_keys, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
Expand Down