Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Validate input to POST /key/v2/query endpoint. (#16183)
Browse files Browse the repository at this point in the history
To avoid 500 internal server errors with garbage input.
  • Loading branch information
clokep committed Aug 25, 2023
1 parent fcf7a57 commit 8269942
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
1 change: 1 addition & 0 deletions changelog.d/16183.misc
@@ -0,0 +1 @@
Improve error reporting of invalid data passed to `/_matrix/key/v2/query`.
39 changes: 29 additions & 10 deletions synapse/rest/key/v2/remote_key_resource.py
Expand Up @@ -16,6 +16,7 @@
import re
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple

from pydantic import Extra, StrictInt, StrictStr
from signedjson.sign import sign_json

from twisted.web.server import Request
Expand All @@ -24,9 +25,10 @@
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_and_validate_json_object_from_request,
parse_integer,
parse_json_object_from_request,
)
from synapse.rest.models import RequestBodyModel
from synapse.storage.keys import FetchKeyResultForRemote
from synapse.types import JsonDict
from synapse.util import json_decoder
Expand All @@ -38,6 +40,13 @@
logger = logging.getLogger(__name__)


class _KeyQueryCriteriaDataModel(RequestBodyModel):
class Config:
extra = Extra.allow

minimum_valid_until_ts: Optional[StrictInt]


class RemoteKey(RestServlet):
"""HTTP resource for retrieving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
Expand Down Expand Up @@ -96,6 +105,9 @@ class RemoteKey(RestServlet):

CATEGORY = "Federation requests"

class PostBody(RequestBodyModel):
server_keys: Dict[StrictStr, Dict[StrictStr, _KeyQueryCriteriaDataModel]]

def __init__(self, hs: "HomeServer"):
self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastores().main
Expand Down Expand Up @@ -137,24 +149,29 @@ async def on_GET(
)

minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
arguments = {}
if minimum_valid_until_ts is not None:
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
query = {server: {key_id: arguments}}
query = {
server: {
key_id: _KeyQueryCriteriaDataModel(
minimum_valid_until_ts=minimum_valid_until_ts
)
}
}
else:
query = {server: {}}

return 200, await self.query_keys(query, query_remote_on_cache_miss=True)

async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
content = parse_and_validate_json_object_from_request(request, self.PostBody)

query = content["server_keys"]
query = content.server_keys

return 200, await self.query_keys(query, query_remote_on_cache_miss=True)

async def query_keys(
self, query: JsonDict, query_remote_on_cache_miss: bool = False
self,
query: Dict[str, Dict[str, _KeyQueryCriteriaDataModel]],
query_remote_on_cache_miss: bool = False,
) -> JsonDict:
logger.info("Handling query for keys %r", query)

Expand Down Expand Up @@ -196,8 +213,10 @@ async def query_keys(
else:
ts_added_ms = key_result.added_ts
ts_valid_until_ms = key_result.valid_until_ts
req_key = query.get(server_name, {}).get(key_id, {})
req_valid_until = req_key.get("minimum_valid_until_ts")
req_key = query.get(server_name, {}).get(
key_id, _KeyQueryCriteriaDataModel(minimum_valid_until_ts=None)
)
req_valid_until = req_key.minimum_valid_until_ts
if req_valid_until is not None:
if ts_valid_until_ms < req_valid_until:
logger.debug(
Expand Down

0 comments on commit 8269942

Please sign in to comment.