Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Implement MSC3983 to proxy /keys/claim queries to appservices.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Mar 27, 2023
1 parent 6cd7f9f commit 17377a0
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 5 deletions.
1 change: 1 addition & 0 deletions changelog.d/15314.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental support for passing One Time Key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983)).
5 changes: 5 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
"msc3202_transaction_extensions", False
)

# MSC3983: Proxying OTK claim requests to exclusive ASes.
self.msc3983_appservice_otk_claims: bool = experimental.get(
"msc3983_appservice_otk_claims", False
)

# MSC3706 (server-side support for partial state in /send_join responses)
# Synapse will always serve partial state responses to requests using the stable
# query parameter `omit_members`. If this flag is set, Synapse will also serve
Expand Down
72 changes: 71 additions & 1 deletion synapse/handlers/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)

from prometheus_client import Counter

Expand Down Expand Up @@ -829,3 +838,64 @@ async def _check_user_exists(self, user_id: str) -> bool:
if unknown_user:
return await self.query_user_exists(user_id)
return True

async def claim_e2e_one_time_keys(
self, query: Iterable[Tuple[str, str, str]]
) -> Tuple[
Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
]:
"""Claim one time keys from application services.
Args:
query: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
A tuple of:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
A copy of the input which has not been fulfilled (either because
they are not appservice users or the appservice does not support
providing OTKs).
"""
services = self.store.get_app_services()

# Partition the users by appservice.
query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
missing = []
for user_id, device, algorithm in query:
if not self.store.get_if_app_services_interested_in_user(user_id):
missing.append((user_id, device, algorithm))
continue

# Find the associated appservice.
for service in services:
if service.is_exclusive_user(user_id):
query_by_appservice.setdefault(service.id, []).append(
(user_id, device, algorithm)
)
continue

# Query each service in parallel.
results = await make_deferred_yieldable(
defer.DeferredList(
[
run_in_background(
self.appservice_api.claim_client_keys,
# We know this must be an app service.
self.store.get_app_service_by_id(service_id), # type: ignore[arg-type]
service_query,
)
for service_id, service_query in query_by_appservice.items()
],
consumeErrors=True,
)
)

# Patch together the results.
claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
for success, result in results:
if success:
claimed_keys.append(result[0])
missing.extend(result[1])

return claimed_keys, missing
25 changes: 22 additions & 3 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple

Expand Down Expand Up @@ -53,6 +52,7 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
self._appservice_handler = hs.get_application_service_handler()
self.is_mine = hs.is_mine
self.clock = hs.get_clock()

Expand Down Expand Up @@ -88,6 +88,10 @@ def __init__(self, hs: "HomeServer"):
max_count=10,
)

self._query_appservices_for_otks = (
hs.config.experimental.msc3983_appservice_otk_claims
)

@trace
@cancellable
async def query_devices(
Expand Down Expand Up @@ -548,7 +552,8 @@ async def claim_local_one_time_keys(
"""Claim one time keys for local users.
1. Attempt to claim OTKs from the database.
2. Attempt to fetch fallback keys from the database.
2. Ask application services if they provide OTKs.
3. Attempt to fetch fallback keys from the database.
Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm).
Expand All @@ -559,10 +564,21 @@ async def claim_local_one_time_keys(

otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)

# If the application services have not provided any keys via the C-S
# API, query it directly.
if self._query_appservices_for_otks:
# Query the appservices for any OTKs.
(
appservice_results,
not_found,
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
else:
appservice_results = []

# For any *still* remaining users, try fall-back keys.
fallback_results = await self.store.claim_e2e_fallback_keys(not_found)

return (otk_results, fallback_results)
return (otk_results, *appservice_results, fallback_results)

@trace
async def claim_one_time_keys(
Expand Down Expand Up @@ -593,6 +609,9 @@ async def claim_one_time_keys(
for key_id, key in keys.items():
json_result.setdefault(user_id, {})[device_id] = {key_id: key}

# Remote failures.
failures: Dict[str, JsonDict] = {}

@trace
async def claim_client_keys(destination: str) -> None:
set_tag("destination", destination)
Expand Down
76 changes: 75 additions & 1 deletion tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,24 @@

from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
from synapse.appservice import ApplicationService
from synapse.handlers.device import DeviceHandler
from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.types import JsonDict
from synapse.util import Clock

from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config


class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=mock.Mock())
self.appservice_api = mock.Mock()
return self.setup_test_homeserver(
federation_client=mock.Mock(), application_service_api=self.appservice_api
)

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_keys_handler()
Expand Down Expand Up @@ -941,3 +947,71 @@ def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> Non

# The two requests to the local homeserver should be identical.
self.assertEqual(response_1, response_2)

@override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}})
def test_query_appservice(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id_1 = "xyz"
fallback_key = {"alg1:k1": "fallback_key1"}
device_id_2 = "abc"
otk = {"alg1:k2": "key2"}

# Inject an appservice interested in this user.
appservice = ApplicationService(
token="i_am_an_app_service",
id="1234",
namespaces={"users": [{"regex": r"@boris:*", "exclusive": True}]},
# Note: this user does not have to match the regex above
sender="@as_main:test",
)
self.hs.get_datastores().main.services_cache = [appservice]
self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
[appservice]
)

# Setup a response, but only for device 2.
self.appservice_api.claim_client_keys.return_value = make_awaitable(
({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1")])
)

# we shouldn't have any unused fallback keys yet
res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
)
self.assertEqual(res, [])

self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id_1,
{"fallback_keys": fallback_key},
)
)

# we should now have an unused alg1 key
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
)
self.assertEqual(fallback_res, ["alg1"])

# claiming an OTK when no OTKs are available should ask the appservice, then
# query the fallback keys.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{
"one_time_keys": {
local_user: {device_id_1: "alg1", device_id_2: "alg1"}
}
},
timeout=None,
)
)
self.assertEqual(
claim_res,
{
"failures": {},
"one_time_keys": {
local_user: {device_id_1: fallback_key, device_id_2: otk}
},
},
)

0 comments on commit 17377a0

Please sign in to comment.