Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Enforce validity period on server_keys for fed requests. #5321

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/5321.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Ensure that we have an up-to-date copy of the signing key when validating incoming federation requests.
159 changes: 104 additions & 55 deletions synapse/crypto/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

import logging
from collections import defaultdict

import six
from six import raise_from
Expand Down Expand Up @@ -70,6 +71,9 @@ class VerifyKeyRequest(object):

json_object(dict): The JSON object to verify.

minimum_valid_until_ts (int): time at which we require the signing key to
be valid. (0 implies we don't care)

deferred(Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched. The deferreds' callbacks are run with no
Expand All @@ -82,22 +86,25 @@ class VerifyKeyRequest(object):
server_name = attr.ib()
key_ids = attr.ib()
json_object = attr.ib()
deferred = attr.ib()
minimum_valid_until_ts = attr.ib()
deferred = attr.ib(factory=defer.Deferred)


class KeyLookupError(ValueError):
pass


class Keyring(object):
def __init__(self, hs):
def __init__(self, hs, key_fetchers=None):
self.clock = hs.get_clock()

self._key_fetchers = (
StoreKeyFetcher(hs),
PerspectivesKeyFetcher(hs),
ServerKeyFetcher(hs),
)
if key_fetchers is None:
key_fetchers = (
StoreKeyFetcher(hs),
PerspectivesKeyFetcher(hs),
ServerKeyFetcher(hs),
)
self._key_fetchers = key_fetchers

# map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download
Expand All @@ -106,20 +113,38 @@ def __init__(self, hs):
# These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {}

def verify_json_for_server(self, server_name, json_object):
def verify_json_for_server(self, server_name, json_object, validity_time):
"""Verify that a JSON object has been signed by a given server

Args:
server_name (str): name of the server which must have signed this object

json_object (dict): object to be checked

validity_time (int): timestamp at which we require the signing key to
be valid. (0 implies we don't care)

Returns:
Deferred[None]: completes if the the object was correctly signed, otherwise
errbacks with an error
"""
req = server_name, json_object, validity_time

return logcontext.make_deferred_yieldable(
self.verify_json_objects_for_server([(server_name, json_object)])[0]
self.verify_json_objects_for_server((req,))[0]
)

def verify_json_objects_for_server(self, server_and_json):
"""Bulk verifies signatures of json objects, bulk fetching keys as
necessary.

Args:
server_and_json (list): List of pairs of (server_name, json_object)
server_and_json (iterable[Tuple[str, dict, int]):
Iterable of triplets of (server_name, json_object, validity_time)
validity_time is a timestamp at which the signing key must be valid.

Returns:
List<Deferred>: for each input pair, a deferred indicating success
List<Deferred[None]>: for each input triplet, a deferred indicating success
or failure to verify each json object's signature for the given
server_name. The deferreds run their callbacks in the sentinel
logcontext.
Expand All @@ -128,12 +153,12 @@ def verify_json_objects_for_server(self, server_and_json):
verify_requests = []
handle = preserve_fn(_handle_key_deferred)

def process(server_name, json_object):
def process(server_name, json_object, validity_time):
"""Process an entry in the request list

Given a (server_name, json_object) pair from the request list,
adds a key request to verify_requests, and returns a deferred which will
complete or fail (in the sentinel context) when verification completes.
Given a (server_name, json_object, validity_time) triplet from the request
list, adds a key request to verify_requests, and returns a deferred which
will complete or fail (in the sentinel context) when verification completes.
"""
key_ids = signature_ids(json_object, server_name)

Expand All @@ -148,7 +173,7 @@ def process(server_name, json_object):

# add the key request to the queue, but don't start it off yet.
verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, defer.Deferred()
server_name, key_ids, json_object, validity_time
)
verify_requests.append(verify_request)

Expand All @@ -160,8 +185,8 @@ def process(server_name, json_object):
return handle(verify_request)

results = [
process(server_name, json_object)
for server_name, json_object in server_and_json
process(server_name, json_object, validity_time)
for server_name, json_object, validity_time in server_and_json
]

if verify_requests:
Expand Down Expand Up @@ -298,8 +323,12 @@ def do_iterations():
verify_request.deferred.errback(
SynapseError(
401,
"No key for %s with id %s"
% (verify_request.server_name, verify_request.key_ids),
"No key for %s with ids in %s (min_validity %i)"
% (
verify_request.server_name,
verify_request.key_ids,
verify_request.minimum_valid_until_ts,
),
Codes.UNAUTHORIZED,
)
)
Expand All @@ -325,16 +354,20 @@ def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
Any successfully-completed requests will be reomved from the list.
"""
# dict[str, set(str)]: keys to fetch for each server
missing_keys = {}
# dict[str, dict[str, int]]: keys to fetch.
# server_name -> key_id -> min_valid_ts
missing_keys = defaultdict(dict)

for verify_request in remaining_requests:
# any completed requests should already have been removed
assert not verify_request.deferred.called
missing_keys.setdefault(verify_request.server_name, set()).update(
verify_request.key_ids
)
keys_for_server = missing_keys[verify_request.server_name]
for key_id in verify_request.key_ids:
current_min_ts = keys_for_server.get(key_id, -1)
if current_min_ts < verify_request.minimum_valid_until_ts:
keys_for_server[key_id] = verify_request.minimum_valid_until_ts
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can haz comment pls? I'm struggling to follow the logic here. We're taking the maximum minimum_valid_untl_ts? Maybe this can be written as:

for key_id in verify_request.key_ids:
   current_min_ts = keys_for_server.get(key_id, -1)
   keys_for_server[key_id] = max(keys_for_server[key_id], current_min_ts)

?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the remote server respond with the key with the latest valid_until_ts even if that is less than the requested if it can't find a later one? If not will that cause problems where the key may have been valid for some of the key requests?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the remote server respond with the key with the latest valid_until_ts even if that is less than the requested if it can't find a later one? If not will that cause problems where the key may have been valid for some of the key requests?

Hum, apparently a notary server will not respond with such a key. And yes, it probably will. I'll try and get that changed on the notary server impl before landing this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can haz comment pls?

done

Maybe this can be written as:

I'm not entirely convinced it's clearer, but have tweaked it anyway.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hum, apparently a notary server will not respond with such a key.

This is apparently incorrect, as tested by matrix-org/sytest#620.


results = yield fetcher.get_keys(missing_keys.items())
results = yield fetcher.get_keys(missing_keys)

completed = list()
for verify_request in remaining_requests:
Expand All @@ -344,25 +377,34 @@ def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
# complete this VerifyKeyRequest.
result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
key = result_keys.get(key_id)
if key:
with PreserveLoggingContext():
verify_request.deferred.callback(
(server_name, key_id, key.verify_key)
)
completed.append(verify_request)
break
fetch_key_result = result_keys.get(key_id)
if not fetch_key_result:
# we didn't get a result for this key
continue

if (
fetch_key_result.valid_until_ts
< verify_request.minimum_valid_until_ts
):
# key was not valid at this point
continue

with PreserveLoggingContext():
verify_request.deferred.callback(
(server_name, key_id, fetch_key_result.verify_key)
)
completed.append(verify_request)
break

remaining_requests.difference_update(completed)


class KeyFetcher(object):
def get_keys(self, server_name_and_key_ids):
def get_keys(self, keys_to_fetch):
"""
Args:
server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
Note that the iterables may be iterated more than once.
keys_to_fetch (dict[str, dict[str, int]]):
the keys to be fetched. server_name -> key_id -> min_valid_ts

Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
Expand All @@ -378,13 +420,15 @@ def __init__(self, hs):
self.store = hs.get_datastore()

@defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids):
def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""

keys_to_fetch = (
(server_name, key_id)
for server_name, key_ids in server_name_and_key_ids
for key_id in key_ids
for server_name, keys_for_server in keys_to_fetch.items()
for key_id in keys_for_server.keys()
)

res = yield self.store.get_server_verify_keys(keys_to_fetch)
keys = {}
for (server_name, key_id), key in res.items():
Expand Down Expand Up @@ -517,14 +561,14 @@ def __init__(self, hs):
self.perspective_servers = self.config.perspectives

@defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids):
def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""

@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
result = yield self.get_server_verify_key_v2_indirect(
server_name_and_key_ids, perspective_name, perspective_keys
keys_to_fetch, perspective_name, perspective_keys
)
defer.returnValue(result)
except KeyLookupError as e:
Expand Down Expand Up @@ -558,13 +602,15 @@ def get_key(perspective_name, perspective_keys):

@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(
self, server_names_and_key_ids, perspective_name, perspective_keys
self, keys_to_fetch, perspective_name, perspective_keys
):
"""
Args:
server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
keys_to_fetch (dict[str, dict[str, int]]):
the keys to be fetched. server_name -> key_id -> min_valid_ts

perspective_name (str): name of the notary server to query for the keys

perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
notary server

Expand All @@ -578,22 +624,21 @@ def get_server_verify_key_v2_indirect(
"""
logger.info(
"Requesting keys %s from notary server %s",
server_names_and_key_ids,
keys_to_fetch.items(),
perspective_name,
)
# TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating
# an incoming request.

try:
query_response = yield self.client.post_json(
destination=perspective_name,
path="/_matrix/key/v2/query",
data={
u"server_keys": {
server_name: {
key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids
key_id: {u"minimum_valid_until_ts": min_valid_ts}
for key_id, min_valid_ts in server_keys.items()
}
for server_name, key_ids in server_names_and_key_ids
for server_name, server_keys in keys_to_fetch.items()
}
},
long_retries=True,
Expand Down Expand Up @@ -703,15 +748,18 @@ def __init__(self, hs):
self.client = hs.get_http_client()

@defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids):
def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""
# TODO make this more resilient
results = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.get_server_verify_key_v2_direct, server_name, key_ids
self.get_server_verify_key_v2_direct,
server_name,
server_keys.keys(),
)
for server_name, key_ids in server_name_and_key_ids
for server_name, server_keys in keys_to_fetch.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
Expand All @@ -730,6 +778,7 @@ def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {} # type: dict[str, FetchKeyResult]

for requested_key_id in key_ids:
# we may have found this key as a side-effect of asking for another.
if requested_key_id in keys:
continue

Expand Down
4 changes: 2 additions & 2 deletions synapse/federation/federation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
]

more_deferreds = keyring.verify_json_objects_for_server([
(p.sender_domain, p.redacted_pdu_json)
(p.sender_domain, p.redacted_pdu_json, 0)
for p in pdus_to_check_sender
])

Expand Down Expand Up @@ -298,7 +298,7 @@ def sender_err(e, pdu_to_check):
]

more_deferreds = keyring.verify_json_objects_for_server([
(get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json)
(get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json, 0)
for p in pdus_to_check_event_id
])

Expand Down
4 changes: 3 additions & 1 deletion synapse/federation/transport/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class NoAuthenticationError(AuthenticationError):

class Authenticator(object):
def __init__(self, hs):
self._clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self.store = hs.get_datastore()
Expand All @@ -102,6 +103,7 @@ def __init__(self, hs):
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
def authenticate_request(self, request, content):
now = self._clock.time_msec()
json_request = {
"method": request.method.decode('ascii'),
"uri": request.uri.decode('ascii'),
Expand Down Expand Up @@ -138,7 +140,7 @@ def authenticate_request(self, request, content):
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)

yield self.keyring.verify_json_for_server(origin, json_request)
yield self.keyring.verify_json_for_server(origin, json_request, now)

logger.info("Request from %s", origin)
request.authenticated_entity = origin
Expand Down
5 changes: 3 additions & 2 deletions synapse/groups/attestations.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,11 @@ def verify_attestation(self, attestation, group_id, user_id, server_name=None):

# TODO: We also want to check that *new* attestations that people give
# us to store are valid for at least a little while.
if valid_until_ms < self.clock.time_msec():
now = self.clock.time_msec()
if valid_until_ms < now:
raise SynapseError(400, "Attestation expired")

yield self.keyring.verify_json_for_server(server_name, attestation)
yield self.keyring.verify_json_for_server(server_name, attestation, now)

def create_attestation(self, group_id, user_id):
"""Create an attestation for the group_id and user_id with default
Expand Down
Loading