Enforce validity period on server_keys for fed requests. (#5321)
When handling incoming federation requests, make sure that we have an
up-to-date copy of the signing key.

We do not yet enforce the validity period for event signatures.
richvdh committed Jun 3, 2019
1 parent fe2294e commit fec2dcb
Showing 6 changed files with 228 additions and 88 deletions.
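
A minimal sketch of the new calling convention, using invented server names and values (nothing below is taken verbatim from the commit): every signature check now carries a timestamp at which the signing key must still be valid, and 0 keeps the old behaviour of not caring.

import time

# Hedged illustration only: the triplet shape consumed by
# Keyring.verify_json_objects_for_server after this change.
now_ms = int(time.time() * 1000)

# Incoming federation request: the origin's key must be valid right now.
request_check = ("remote.example.com", {"method": "GET", "uri": "/some/path"}, now_ms)

# Event signature check: validity is not yet enforced, so 0 ("don't care").
event_check = ("remote.example.com", {"type": "m.room.member"}, 0)

print(request_check)
print(event_check)
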
1 change: 1 addition & 0 deletions changelog.d/5321.bugfix
@@ -0,0 +1 @@
Ensure that we have an up-to-date copy of the signing key when validating incoming federation requests.
167 changes: 111 additions & 56 deletions synapse/crypto/keyring.py
@@ -15,6 +15,7 @@
# limitations under the License.

import logging
from collections import defaultdict

import six
from six import raise_from
@@ -70,6 +71,9 @@ class VerifyKeyRequest(object):
json_object(dict): The JSON object to verify.
minimum_valid_until_ts (int): time at which we require the signing key to
be valid. (0 implies we don't care)
deferred(Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched. The deferreds' callbacks are run with no
@@ -82,22 +86,25 @@ class VerifyKeyRequest(object):
server_name = attr.ib()
key_ids = attr.ib()
json_object = attr.ib()
deferred = attr.ib()
minimum_valid_until_ts = attr.ib()
deferred = attr.ib(default=attr.Factory(defer.Deferred))


class KeyLookupError(ValueError):
pass


class Keyring(object):
def __init__(self, hs):
def __init__(self, hs, key_fetchers=None):
self.clock = hs.get_clock()

self._key_fetchers = (
StoreKeyFetcher(hs),
PerspectivesKeyFetcher(hs),
ServerKeyFetcher(hs),
)
if key_fetchers is None:
key_fetchers = (
StoreKeyFetcher(hs),
PerspectivesKeyFetcher(hs),
ServerKeyFetcher(hs),
)
self._key_fetchers = key_fetchers

# map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download
@@ -106,20 +113,38 @@ def __init__(self, hs):
# These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {}

def verify_json_for_server(self, server_name, json_object):
def verify_json_for_server(self, server_name, json_object, validity_time):
"""Verify that a JSON object has been signed by a given server
Args:
server_name (str): name of the server which must have signed this object
json_object (dict): object to be checked
validity_time (int): timestamp at which we require the signing key to
be valid. (0 implies we don't care)
Returns:
Deferred[None]: completes if the object was correctly signed, otherwise
errbacks with an error
"""
req = server_name, json_object, validity_time

return logcontext.make_deferred_yieldable(
self.verify_json_objects_for_server([(server_name, json_object)])[0]
self.verify_json_objects_for_server((req,))[0]
)

def verify_json_objects_for_server(self, server_and_json):
"""Bulk verifies signatures of json objects, bulk fetching keys as
necessary.
Args:
server_and_json (list): List of pairs of (server_name, json_object)
server_and_json (iterable[Tuple[str, dict, int]]):
Iterable of triplets of (server_name, json_object, validity_time)
validity_time is a timestamp at which the signing key must be valid.
Returns:
List<Deferred>: for each input pair, a deferred indicating success
List<Deferred[None]>: for each input triplet, a deferred indicating success
or failure to verify each json object's signature for the given
server_name. The deferreds run their callbacks in the sentinel
logcontext.
@@ -128,12 +153,12 @@ def verify_json_objects_for_server(self, server_and_json):
verify_requests = []
handle = preserve_fn(_handle_key_deferred)

def process(server_name, json_object):
def process(server_name, json_object, validity_time):
"""Process an entry in the request list
Given a (server_name, json_object) pair from the request list,
adds a key request to verify_requests, and returns a deferred which will
complete or fail (in the sentinel context) when verification completes.
Given a (server_name, json_object, validity_time) triplet from the request
list, adds a key request to verify_requests, and returns a deferred which
will complete or fail (in the sentinel context) when verification completes.
"""
key_ids = signature_ids(json_object, server_name)

@@ -148,7 +173,7 @@ def process(server_name, json_object):

# add the key request to the queue, but don't start it off yet.
verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, defer.Deferred()
server_name, key_ids, json_object, validity_time
)
verify_requests.append(verify_request)

@@ -160,8 +185,8 @@ def process(server_name, json_object):
return handle(verify_request)

results = [
process(server_name, json_object)
for server_name, json_object in server_and_json
process(server_name, json_object, validity_time)
for server_name, json_object, validity_time in server_and_json
]

if verify_requests:
@@ -298,8 +323,12 @@ def do_iterations():
verify_request.deferred.errback(
SynapseError(
401,
"No key for %s with id %s"
% (verify_request.server_name, verify_request.key_ids),
"No key for %s with ids in %s (min_validity %i)"
% (
verify_request.server_name,
verify_request.key_ids,
verify_request.minimum_valid_until_ts,
),
Codes.UNAUTHORIZED,
)
)
@@ -323,18 +352,28 @@ def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
Args:
fetcher (KeyFetcher): fetcher to use to fetch the keys
remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
Any successfully-completed requests will be reomved from the list.
Any successfully-completed requests will be removed from the list.
"""
# dict[str, set(str)]: keys to fetch for each server
missing_keys = {}
# dict[str, dict[str, int]]: keys to fetch.
# server_name -> key_id -> min_valid_ts
missing_keys = defaultdict(dict)

for verify_request in remaining_requests:
# any completed requests should already have been removed
assert not verify_request.deferred.called
missing_keys.setdefault(verify_request.server_name, set()).update(
verify_request.key_ids
)
keys_for_server = missing_keys[verify_request.server_name]

results = yield fetcher.get_keys(missing_keys.items())
for key_id in verify_request.key_ids:
# If we have several requests for the same key, then we only need to
# request that key once, but we should do so with the greatest
# min_valid_until_ts of the requests, so that we can satisfy all of
# the requests.
keys_for_server[key_id] = max(
keys_for_server.get(key_id, -1),
verify_request.minimum_valid_until_ts
)

results = yield fetcher.get_keys(missing_keys)

completed = list()
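
A standalone, hedged restatement of the merging loop in the hunk above (server names and timestamps are invented): duplicate requests for the same key collapse into one fetch carrying the strictest minimum_valid_until_ts.

from collections import defaultdict

# (server_name, key_id, minimum_valid_until_ts) for several verify requests.
requests = [
    ("remote.example.com", "ed25519:a", 1000),
    ("remote.example.com", "ed25519:a", 5000),  # stricter requirement wins
    ("other.example.org", "ed25519:k1", 0),     # 0 == no validity requirement
]

missing_keys = defaultdict(dict)
for server_name, key_id, min_valid_ts in requests:
    keys_for_server = missing_keys[server_name]
    keys_for_server[key_id] = max(keys_for_server.get(key_id, -1), min_valid_ts)

print(dict(missing_keys))
# {'remote.example.com': {'ed25519:a': 5000}, 'other.example.org': {'ed25519:k1': 0}}
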
for verify_request in remaining_requests:
Expand All @@ -344,25 +383,34 @@ def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
# complete this VerifyKeyRequest.
result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
key = result_keys.get(key_id)
if key:
with PreserveLoggingContext():
verify_request.deferred.callback(
(server_name, key_id, key.verify_key)
)
completed.append(verify_request)
break
fetch_key_result = result_keys.get(key_id)
if not fetch_key_result:
# we didn't get a result for this key
continue

if (
fetch_key_result.valid_until_ts
< verify_request.minimum_valid_until_ts
):
# key was not valid at this point
continue

with PreserveLoggingContext():
verify_request.deferred.callback(
(server_name, key_id, fetch_key_result.verify_key)
)
completed.append(verify_request)
break

remaining_requests.difference_update(completed)
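
The acceptance rule applied above, restated as a tiny self-contained check with illustrative numbers: a fetched key only satisfies a request if it was still valid at the request's minimum_valid_until_ts.

def key_satisfies(valid_until_ts, minimum_valid_until_ts):
    # Mirrors the comparison in _attempt_key_fetches_with_fetcher above.
    return valid_until_ts >= minimum_valid_until_ts

print(key_satisfies(2000, 1500))  # True: key valid beyond the required point
print(key_satisfies(1000, 1500))  # False: key expired before we needed it
print(key_satisfies(1000, 0))     # True: 0 means "no validity requirement"
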


class KeyFetcher(object):
def get_keys(self, server_name_and_key_ids):
def get_keys(self, keys_to_fetch):
"""
Args:
server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
Note that the iterables may be iterated more than once.
keys_to_fetch (dict[str, dict[str, int]]):
the keys to be fetched. server_name -> key_id -> min_valid_ts
Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
@@ -378,13 +426,15 @@ def __init__(self, hs):
self.store = hs.get_datastore()

@defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids):
def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""

keys_to_fetch = (
(server_name, key_id)
for server_name, key_ids in server_name_and_key_ids
for key_id in key_ids
for server_name, keys_for_server in keys_to_fetch.items()
for key_id in keys_for_server.keys()
)

res = yield self.store.get_server_verify_keys(keys_to_fetch)
keys = {}
for (server_name, key_id), key in res.items():
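
A hedged, standalone illustration of the flattening above (made-up keys): StoreKeyFetcher queries the store by (server_name, key_id) pairs, so the per-key minimum validity values are dropped at this point and enforced later by the Keyring.

keys_to_fetch = {
    "remote.example.com": {"ed25519:a": 1559000000000},
    "other.example.org": {"ed25519:k1": 0, "ed25519:k2": 0},
}

pairs = [
    (server_name, key_id)
    for server_name, keys_for_server in keys_to_fetch.items()
    for key_id in keys_for_server
]
print(pairs)
# [('remote.example.com', 'ed25519:a'), ('other.example.org', 'ed25519:k1'),
#  ('other.example.org', 'ed25519:k2')]
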
@@ -508,14 +558,14 @@ def __init__(self, hs):
self.perspective_servers = self.config.perspectives

@defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids):
def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""

@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
result = yield self.get_server_verify_key_v2_indirect(
server_name_and_key_ids, perspective_name, perspective_keys
keys_to_fetch, perspective_name, perspective_keys
)
defer.returnValue(result)
except KeyLookupError as e:
@@ -549,13 +599,15 @@ def get_key(perspective_name, perspective_keys):

@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(
self, server_names_and_key_ids, perspective_name, perspective_keys
self, keys_to_fetch, perspective_name, perspective_keys
):
"""
Args:
server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
keys_to_fetch (dict[str, dict[str, int]]):
the keys to be fetched. server_name -> key_id -> min_valid_ts
perspective_name (str): name of the notary server to query for the keys
perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
notary server
@@ -569,22 +621,21 @@ def get_server_verify_key_v2_indirect(
"""
logger.info(
"Requesting keys %s from notary server %s",
server_names_and_key_ids,
keys_to_fetch.items(),
perspective_name,
)
# TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating
# an incoming request.

try:
query_response = yield self.client.post_json(
destination=perspective_name,
path="/_matrix/key/v2/query",
data={
u"server_keys": {
server_name: {
key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids
key_id: {u"minimum_valid_until_ts": min_valid_ts}
for key_id, min_valid_ts in server_keys.items()
}
for server_name, key_ids in server_names_and_key_ids
for server_name, server_keys in keys_to_fetch.items()
}
},
long_retries=True,
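
For reference, a hedged sketch (invented key ID and timestamp) of the request body that the dict comprehension above produces for the notary's /_matrix/key/v2/query endpoint; the per-key minimum_valid_until_ts replaces the previous hard-coded 0.

import json

keys_to_fetch = {"remote.example.com": {"ed25519:a": 1559000000000}}

body = {
    "server_keys": {
        server_name: {
            key_id: {"minimum_valid_until_ts": min_valid_ts}
            for key_id, min_valid_ts in server_keys.items()
        }
        for server_name, server_keys in keys_to_fetch.items()
    }
}
print(json.dumps(body, indent=2))
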
@@ -694,15 +745,18 @@ def __init__(self, hs):
self.client = hs.get_http_client()

@defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids):
def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""
# TODO make this more resilient
results = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.get_server_verify_key_v2_direct, server_name, key_ids
self.get_server_verify_key_v2_direct,
server_name,
server_keys.keys(),
)
for server_name, key_ids in server_name_and_key_ids
for server_name, server_keys in keys_to_fetch.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
@@ -721,6 +775,7 @@ def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {} # type: dict[str, FetchKeyResult]

for requested_key_id in key_ids:
# we may have found this key as a side-effect of asking for another.
if requested_key_id in keys:
continue

4 changes: 2 additions & 2 deletions synapse/federation/federation_base.py
@@ -265,7 +265,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
]

more_deferreds = keyring.verify_json_objects_for_server([
(p.sender_domain, p.redacted_pdu_json)
(p.sender_domain, p.redacted_pdu_json, 0)
for p in pdus_to_check_sender
])

@@ -298,7 +298,7 @@ def sender_err(e, pdu_to_check):
]

more_deferreds = keyring.verify_json_objects_for_server([
(get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json)
(get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json, 0)
for p in pdus_to_check_event_id
])
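
As the hunks above show, event-signature checks keep a validity_time of 0 for now, matching the commit message ("we do not yet enforce the validity period for event signatures"). A hedged sketch of the triplets they produce, with made-up values:

pdus_to_check = [
    ("example.org", {"event_id": "$abc:example.org", "type": "m.room.message"}),
]

server_and_json = [
    (sender_domain, redacted_pdu_json, 0)  # 0: no key-validity requirement yet
    for sender_domain, redacted_pdu_json in pdus_to_check
]
print(server_and_json)
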

4 changes: 3 additions & 1 deletion synapse/federation/transport/server.py
@@ -94,6 +94,7 @@ class NoAuthenticationError(AuthenticationError):

class Authenticator(object):
def __init__(self, hs):
self._clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self.store = hs.get_datastore()
@@ -102,6 +103,7 @@ def __init__(self, hs):
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
def authenticate_request(self, request, content):
now = self._clock.time_msec()
json_request = {
"method": request.method.decode('ascii'),
"uri": request.uri.decode('ascii'),
@@ -138,7 +140,7 @@ def authenticate_request(self, request, content):
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)

yield self.keyring.verify_json_for_server(origin, json_request)
yield self.keyring.verify_json_for_server(origin, json_request, now)

logger.info("Request from %s", origin)
request.authenticated_entity = origin
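
A hedged sketch (fields simplified, values invented) of the Authenticator change above: incoming federation requests now require the origin's signing key to be valid at the time the request is handled.

import time

now = int(time.time() * 1000)  # stands in for self._clock.time_msec()

json_request = {
    "method": "GET",
    "uri": "/some/federation/path",
    # origin, content and signatures are omitted in this sketch.
}

# In the servlet code above this becomes, roughly:
#   yield self.keyring.verify_json_for_server(origin, json_request, now)
print(now, json_request)
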