Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Modify StoreKeyFetcher to read from server_keys_json. #15417

Merged
merged 8 commits into from Apr 20, 2023
1 change: 1 addition & 0 deletions changelog.d/15417.bugfix
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
@@ -0,0 +1 @@
Fix a long-standing bug where cached key results which were directly fetched would not be properly re-used.
2 changes: 1 addition & 1 deletion synapse/crypto/keyring.py
Expand Up @@ -510,7 +510,7 @@ async def _fetch_keys(
for key_id in queue_value.key_ids
)

res = await self.store.get_server_signature_keys(key_ids_to_fetch)
res = await self.store.get_server_keys_json(key_ids_to_fetch)
keys: Dict[str, Dict[str, FetchKeyResult]] = {}
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
Expand Down
80 changes: 76 additions & 4 deletions synapse/storage/databases/main/keys.py
Expand Up @@ -14,10 +14,12 @@
# limitations under the License.

import itertools
import json
import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple

from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
Expand Down Expand Up @@ -63,10 +65,12 @@ def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
"""Processes a batch of keys to fetch, and adds the result to `keys`."""

# batch_iter always returns tuples so it's safe to do len(batch)
sql = (
"SELECT server_name, key_id, verify_key, ts_valid_until_ms "
"FROM server_signature_keys WHERE 1=0"
) + " OR (server_name=? AND key_id=?)" * len(batch)
sql = """
SELECT server_name, key_id, verify_key, ts_valid_until_ms
FROM server_signature_keys WHERE 1=0
""" + " OR (server_name=? AND key_id=?)" * len(
batch
)

txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))

Expand Down Expand Up @@ -181,6 +185,74 @@ async def store_server_keys_json(
desc="store_server_keys_json",
)

# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
self._get_server_keys_json.invalidate((((server_name, key_id),)))

# Per-key cache entry point for `get_server_keys_json`, which names this method
# as its `cached_method_name`. It has no implementation of its own: calling it
# directly raises NotImplementedError. The single parameter is the
# (server_name, key_id) 2-tuple identifying the key.
@cached()
def _get_server_keys_json(
self, server_name_and_key_id: Tuple[str, str]
) -> FetchKeyResult:
raise NotImplementedError()

@cachedList(
cached_method_name="_get_server_keys_json", list_name="server_name_and_key_ids"
)
async def get_server_keys_json(
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
) -> Dict[Tuple[str, str], FetchKeyResult]:
"""
Args:
server_name_and_key_ids:
iterable of (server_name, key-id) tuples to fetch keys for
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a Collection rather than an Iterable? I thought we try to avoid passing iterables to DB queries because they might be exhausted when we come to retry them? (Or is this an Iterable versus Iterator thing?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe? Don't we pass iterables in like everywhere?!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#11569 is what I had in mind.

I'm happy for this to land as-is (since it's no worse and should stop trusted key servers from spamming hosts). Though I would like to better understand if Iterables are still a problem that we should worry about.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #11569 and #11564.

I think what we have is probably fine for now then?


Returns:
A map from (server_name, key_id) -> FetchKeyResult, or None if the
key is unknown
"""
keys = {}

def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
    """Processes a batch of keys to fetch, and adds the result to `keys`.

    Args:
        txn: the database cursor to run the query on.
        batch: the (server_name, key_id) pairs to look up in
            `server_keys_json`.
    """

    # batch_iter always returns tuples so it's safe to do len(batch)
    sql = """
    SELECT server_name, key_id, key_json, ts_valid_until_ms
    FROM server_keys_json WHERE 1=0
    """ + " OR (server_name=? AND key_id=?)" * len(
        batch
    )

    txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))

    for server_name, key_id, key_json_bytes, ts_valid_until_ms in txn:
        if ts_valid_until_ms is None:
            # Old keys may be stored with a ts_valid_until_ms of null,
            # in which case we treat this as if it was set to `0`, i.e.
            # it won't match key requests that define a minimum
            # `ts_valid_until_ms`.
            ts_valid_until_ms = 0

        # Some database drivers (e.g. psycopg2 on Python 3) return binary
        # columns as a memoryview, which json.loads refuses to decode.
        # Convert to bytes first so decoding works on every backend.
        if isinstance(key_json_bytes, memoryview):
            key_json_bytes = key_json_bytes.tobytes()

        # The entire signed JSON response is stored in server_keys_json,
        # fetch out the bits needed.
        key_json = json.loads(key_json_bytes)
        key_base64 = key_json["verify_keys"][key_id]["key"]

        keys[(server_name, key_id)] = FetchKeyResult(
            verify_key=decode_verify_key_bytes(
                key_id, decode_base64(key_base64)
            ),
            valid_until_ts=ts_valid_until_ms,
        )

def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
    """Look up the requested keys in batches of 50 and return `keys`.

    Each batch is handed to `_get_keys`, which adds its rows to `keys`.
    """
    batches = batch_iter(server_name_and_key_ids, 50)
    for chunk in batches:
        _get_keys(txn, chunk)
    return keys

return await self.db_pool.runInteraction("get_server_keys_json", _txn)

async def get_server_keys_json_for_remote(
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
Expand Down
50 changes: 23 additions & 27 deletions tests/crypto/test_keyring.py
Expand Up @@ -190,10 +190,23 @@ def test_verify_json_for_server(self) -> None:
kr = keyring.Keyring(self.hs)

key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_signature_keys(
r = self.hs.get_datastores().main.store_server_keys_json(
"server9",
int(time.time() * 1000),
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
get_key_id(key1),
from_server="test",
ts_now_ms=int(time.time() * 1000),
ts_expires_ms=1000,
# The entire response gets signed & stored, just include the bits we
# care about.
key_json_bytes=canonicaljson.encode_canonical_json(
{
"verify_keys": {
get_key_id(key1): {
"key": encode_verify_key_base64(get_verify_key(key1))
}
}
}
),
)
self.get_success(r)

Expand Down Expand Up @@ -280,10 +293,6 @@ def test_verify_json_for_server_with_null_valid_until_ms(self) -> None:
mock_fetcher = Mock()
mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))

kr = keyring.Keyring(
self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
)

key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_signature_keys(
"server9",
Expand All @@ -298,27 +307,12 @@ def test_verify_json_for_server_with_null_valid_until_ms(self) -> None:
json1: JsonDict = {}
signedjson.sign.sign_json(json1, "server9", key1)

# should fail immediately on an unsigned object
d = kr.verify_json_for_server("server9", {}, 0)
self.get_failure(d, SynapseError)

# should fail on a signed object with a non-zero minimum_valid_until_ms,
# as it tries to refetch the keys and fails.
d = kr.verify_json_for_server("server9", json1, 500)
self.get_failure(d, SynapseError)

# We expect the keyring tried to refetch the key once.
mock_fetcher.get_keys.assert_called_once_with(
"server9", [get_key_id(key1)], 500
)

# should succeed on a signed object with a 0 minimum_valid_until_ms
d = kr.verify_json_for_server(
"server9",
json1,
0,
d = self.hs.get_datastores().main.get_server_signature_keys(
[("server9", get_key_id(key1))]
)
self.get_success(d)
result = self.get_success(d)
self.assertEquals(result[("server9", get_key_id(key1))].valid_until_ts, 0)

def test_verify_json_dedupes_key_requests(self) -> None:
"""Two requests for the same key should be deduped."""
Expand Down Expand Up @@ -464,7 +458,9 @@ async def get_json(destination: str, path: str, **kwargs: Any) -> JsonDict:
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote([lookup_triplet])
self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet]
)
)
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
Expand Down