diff --git a/changelog.d/7839.docker b/changelog.d/7839.docker new file mode 100644 index 000000000000..cdf3c9631c27 --- /dev/null +++ b/changelog.d/7839.docker @@ -0,0 +1 @@ +Base docker image on Debian Buster rather than Alpine Linux. Contributed by @maquis196. diff --git a/changelog.d/7842.feature b/changelog.d/7842.feature new file mode 100644 index 000000000000..727deb01c9a7 --- /dev/null +++ b/changelog.d/7842.feature @@ -0,0 +1 @@ +Add an admin API to list the users in a room. Contributed by Awesome Technologies Innovationslabor GmbH. diff --git a/changelog.d/7849.misc b/changelog.d/7849.misc new file mode 100644 index 000000000000..e3296418c112 --- /dev/null +++ b/changelog.d/7849.misc @@ -0,0 +1 @@ +Consistently use `db_to_json` to convert from database values to JSON objects. diff --git a/changelog.d/7855.feature b/changelog.d/7855.feature new file mode 100644 index 000000000000..2b6a9f0e71c5 --- /dev/null +++ b/changelog.d/7855.feature @@ -0,0 +1 @@ +Add experimental support for running multiple pusher workers. diff --git a/changelog.d/7858.misc b/changelog.d/7858.misc new file mode 100644 index 000000000000..8f0fc2de7429 --- /dev/null +++ b/changelog.d/7858.misc @@ -0,0 +1 @@ +The default value of `filter_timeline_limit` was changed from -1 (no limit) to 100. diff --git a/changelog.d/7859.bugfix b/changelog.d/7859.bugfix new file mode 100644 index 000000000000..19cff4b0616b --- /dev/null +++ b/changelog.d/7859.bugfix @@ -0,0 +1 @@ +Fix a bug which allowed empty rooms to be rejoined over federation. diff --git a/changelog.d/7860.misc b/changelog.d/7860.misc new file mode 100644 index 000000000000..fdd48b955cc5 --- /dev/null +++ b/changelog.d/7860.misc @@ -0,0 +1 @@ +Convert _base, profile, and _receipts handlers to async/await. diff --git a/changelog.d/7861.misc b/changelog.d/7861.misc new file mode 100644 index 000000000000..ada616c62ffa --- /dev/null +++ b/changelog.d/7861.misc @@ -0,0 +1 @@ +Optimise queueing of inbound replication commands. diff --git a/changelog.d/7866.bugfix b/changelog.d/7866.bugfix new file mode 100644 index 000000000000..6b5c3c4ecabe --- /dev/null +++ b/changelog.d/7866.bugfix @@ -0,0 +1 @@ +Fix 'Unable to find a suitable guest user ID' error when using multiple client_reader workers. diff --git a/changelog.d/7868.misc b/changelog.d/7868.misc new file mode 100644 index 000000000000..eadef5e4c206 --- /dev/null +++ b/changelog.d/7868.misc @@ -0,0 +1 @@ +Convert synapse.app and federation client to async/await. diff --git a/changelog.d/7869.feature b/changelog.d/7869.feature new file mode 100644 index 000000000000..1982049a52ee --- /dev/null +++ b/changelog.d/7869.feature @@ -0,0 +1 @@ +Add experimental support for moving typing off master. diff --git a/changelog.d/7871.misc b/changelog.d/7871.misc new file mode 100644 index 000000000000..4d398a9f3af2 --- /dev/null +++ b/changelog.d/7871.misc @@ -0,0 +1 @@ +Convert device handler to async/await. diff --git a/changelog.d/7872.bugfix b/changelog.d/7872.bugfix new file mode 100644 index 000000000000..b21f8e1f147d --- /dev/null +++ b/changelog.d/7872.bugfix @@ -0,0 +1 @@ +Fix a long standing bug where the tracing of async functions with opentracing was broken. diff --git a/changelog.d/7880.bugfix b/changelog.d/7880.bugfix new file mode 100644 index 000000000000..356add099689 --- /dev/null +++ b/changelog.d/7880.bugfix @@ -0,0 +1 @@ +Fix "TypeError in `synapse.notifier`" exceptions. 
diff --git a/changelog.d/7881.misc b/changelog.d/7881.misc new file mode 100644 index 000000000000..67991170990d --- /dev/null +++ b/changelog.d/7881.misc @@ -0,0 +1 @@ +Change "unknown room version" logging from 'error' to 'warning'. diff --git a/changelog.d/7882.misc b/changelog.d/7882.misc new file mode 100644 index 000000000000..90027493351a --- /dev/null +++ b/changelog.d/7882.misc @@ -0,0 +1 @@ +Stop using `device_max_stream_id` table and just use `device_inbox.stream_id`. diff --git a/changelog.d/7885.doc b/changelog.d/7885.doc new file mode 100644 index 000000000000..cbe9de408298 --- /dev/null +++ b/changelog.d/7885.doc @@ -0,0 +1 @@ +Provide instructions on using `register_new_matrix_user` via docker. diff --git a/changelog.d/7888.misc b/changelog.d/7888.misc new file mode 100644 index 000000000000..5328d2dcca84 --- /dev/null +++ b/changelog.d/7888.misc @@ -0,0 +1 @@ +Remove Ubuntu Eoan from the list of `.deb` packages that we build as it is now end-of-life. Contributed by @gary-kim. diff --git a/changelog.d/7889.doc b/changelog.d/7889.doc new file mode 100644 index 000000000000..d91f62fd390f --- /dev/null +++ b/changelog.d/7889.doc @@ -0,0 +1 @@ +Change the sample config postgres user section to use `synapse_user` instead of `synapse` to align with the documentation. \ No newline at end of file diff --git a/changelog.d/7890.misc b/changelog.d/7890.misc new file mode 100644 index 000000000000..8c127084bc7e --- /dev/null +++ b/changelog.d/7890.misc @@ -0,0 +1 @@ +Fix typo in generated config file. Contributed by @ThiefMaster. diff --git a/changelog.d/7892.misc b/changelog.d/7892.misc new file mode 100644 index 000000000000..ef4cfa04fd62 --- /dev/null +++ b/changelog.d/7892.misc @@ -0,0 +1 @@ +Import ABC from `collections.abc` for Python 3.10 compatibility. diff --git a/changelog.d/7895.bugfix b/changelog.d/7895.bugfix new file mode 100644 index 000000000000..1ae7f8ca7c2e --- /dev/null +++ b/changelog.d/7895.bugfix @@ -0,0 +1 @@ +Fix deprecation warning due to invalid escape sequences. \ No newline at end of file diff --git a/changelog.d/7897.misc b/changelog.d/7897.misc new file mode 100644 index 000000000000..77772533fd94 --- /dev/null +++ b/changelog.d/7897.misc @@ -0,0 +1,2 @@ +Remove unused functions `time_function`, `trace_function`, `get_previous_frames` +and `get_previous_frame` from `synapse.logging.utils` module. \ No newline at end of file diff --git a/changelog.d/7912.misc b/changelog.d/7912.misc new file mode 100644 index 000000000000..d619590070a1 --- /dev/null +++ b/changelog.d/7912.misc @@ -0,0 +1 @@ +Convert `RoomListHandler` to async/await. diff --git a/changelog.d/7914.misc b/changelog.d/7914.misc new file mode 100644 index 000000000000..710553249cc0 --- /dev/null +++ b/changelog.d/7914.misc @@ -0,0 +1 @@ +Lint the `contrib/` directory in CI and linting scripts, add `synctl` to the linting script for consistency with CI. diff --git a/changelog.d/7919.misc b/changelog.d/7919.misc new file mode 100644 index 000000000000..addaa35183ca --- /dev/null +++ b/changelog.d/7919.misc @@ -0,0 +1 @@ +Use Element CSS and logo in notification emails when app name is Element. diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index 48da410d9462..77422f5e5de8 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -17,9 +17,6 @@ """ Starts a synapse client console. 
""" from __future__ import print_function -from twisted.internet import reactor, defer, threads -from http import TwistedHttpClient - import argparse import cmd import getpass @@ -28,12 +25,14 @@ import sys import time import urllib -import urlparse +from http import TwistedHttpClient -import nacl.signing import nacl.encoding +import nacl.signing +import urlparse +from signedjson.sign import SignatureVerifyException, verify_signed_json -from signedjson.sign import verify_signed_json, SignatureVerifyException +from twisted.internet import defer, reactor, threads CONFIG_JSON = "cmdclient_config.json" @@ -493,7 +492,7 @@ def do_list(self, line): "list messages from=END&to=START&limit=3" """ args = self._parse(line, ["type", "roomid", "qp"]) - if not "type" in args or not "roomid" in args: + if "type" not in args or "roomid" not in args: print("Must specify type and room ID.") return if args["type"] not in ["members", "messages"]: @@ -508,7 +507,7 @@ def do_list(self, line): try: key_value = key_value_str.split("=") qp[key_value[0]] = key_value[1] - except: + except Exception: print("Bad query param: %s" % key_value) return @@ -585,7 +584,7 @@ def do_raw(self, line): parsed_url = urlparse.urlparse(args["path"]) qp.update(urlparse.parse_qs(parsed_url.query)) args["path"] = parsed_url.path - except: + except Exception: pass reactor.callFromThread( @@ -772,10 +771,10 @@ def main(server_url, identity_server_url, username, token, config_path): syn_cmd.config = json.load(config) try: http_client.verbose = "on" == syn_cmd.config["verbose"] - except: + except Exception: pass print("Loaded config from %s" % config_path) - except: + except Exception: pass # Twisted-specific: Runs the command processor in Twisted's event loop diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py index 0e101d2be56b..e2534ee584ff 100644 --- a/contrib/cmdclient/http.py +++ b/contrib/cmdclient/http.py @@ -14,14 +14,14 @@ # limitations under the License. 
from __future__ import print_function -from twisted.web.client import Agent, readBody -from twisted.web.http_headers import Headers -from twisted.internet import defer, reactor - -from pprint import pformat import json import urllib +from pprint import pformat + +from twisted.internet import defer, reactor +from twisted.web.client import Agent, readBody +from twisted.web.http_headers import Headers class HttpClient(object): diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py index 3bbbcfa1b44e..a84ec4ecaefc 100644 --- a/contrib/experiments/test_messaging.py +++ b/contrib/experiments/test_messaging.py @@ -28,27 +28,24 @@ """ -from synapse.federation import ReplicationHandler - -from synapse.federation.units import Pdu - -from synapse.util import origin_from_ucid - -from synapse.app.homeserver import SynapseHomeServer - -# from synapse.logging.utils import log_function - -from twisted.internet import reactor, defer -from twisted.python import log - import argparse +import curses.wrapper import json import logging import os import re import cursesio -import curses.wrapper + +from twisted.internet import defer, reactor +from twisted.python import log + +from synapse.app.homeserver import SynapseHomeServer +from synapse.federation import ReplicationHandler +from synapse.federation.units import Pdu +from synapse.util import origin_from_ucid + +# from synapse.logging.utils import log_function logger = logging.getLogger("example") @@ -75,7 +72,7 @@ def on_line(self, line): """ try: - m = re.match("^join (\S+)$", line) + m = re.match(r"^join (\S+)$", line) if m: # The `sender` wants to join a room. (room_name,) = m.groups() @@ -84,7 +81,7 @@ def on_line(self, line): # self.print_line("OK.") return - m = re.match("^invite (\S+) (\S+)$", line) + m = re.match(r"^invite (\S+) (\S+)$", line) if m: # `sender` wants to invite someone to a room room_name, invitee = m.groups() @@ -93,7 +90,7 @@ def on_line(self, line): # self.print_line("OK.") return - m = re.match("^send (\S+) (.*)$", line) + m = re.match(r"^send (\S+) (.*)$", line) if m: # `sender` wants to message a room room_name, body = m.groups() @@ -102,7 +99,7 @@ def on_line(self, line): # self.print_line("OK.") return - m = re.match("^backfill (\S+)$", line) + m = re.match(r"^backfill (\S+)$", line) if m: # we want to backfill a room (room_name,) = m.groups() @@ -201,16 +198,6 @@ def on_receive_pdu(self, pdu): % (pdu.context, pdu.pdu_type, json.dumps(pdu.content)) ) - # def on_state_change(self, pdu): - ##self.output.print_line("#%s (state) %s *** %s" % - ##(pdu.context, pdu.state_key, pdu.pdu_type) - ##) - - # if "joinee" in pdu.content: - # self._on_join(pdu.context, pdu.content["joinee"]) - # elif "invitee" in pdu.content: - # self._on_invite(pdu.origin, pdu.context, pdu.content["invitee"]) - def _on_message(self, pdu): """ We received a message """ @@ -314,7 +301,7 @@ def backfill(self, room_name, limit=5): return self.replication_layer.backfill(dest, room_name, limit) def _get_room_remote_servers(self, room_name): - return [i for i in self.joined_rooms.setdefault(room_name).servers] + return list(self.joined_rooms.setdefault(room_name).servers) def _get_or_create_room(self, room_name): return self.joined_rooms.setdefault(room_name, Room(room_name)) @@ -334,7 +321,7 @@ def main(stdscr): user = args.user server_name = origin_from_ucid(user) - ## Set up logging ## + # Set up logging root_logger = logging.getLogger() @@ -354,7 +341,7 @@ def main(stdscr): observer = log.PythonLoggingObserver() observer.start() 
- ## Set up synapse server + # Set up synapse server curses_stdio = cursesio.CursesStdIO(stdscr) input_output = InputOutput(curses_stdio, user) @@ -368,16 +355,16 @@ def main(stdscr): input_output.set_home_server(hs) - ## Add input_output logger + # Add input_output logger io_logger = IOLoggerHandler(input_output) io_logger.setFormatter(formatter) root_logger.addHandler(io_logger) - ## Start! ## + # Start! try: port = int(server_name.split(":")[1]) - except: + except Exception: port = 12345 app_hs.get_http_server().start_listening(port) diff --git a/contrib/graph/graph.py b/contrib/graph/graph.py index 92736480ebab..de33fac1c70f 100644 --- a/contrib/graph/graph.py +++ b/contrib/graph/graph.py @@ -1,5 +1,13 @@ from __future__ import print_function +import argparse +import cgi +import datetime +import json + +import pydot +import urllib2 + # Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,15 +23,6 @@ # limitations under the License. -import sqlite3 -import pydot -import cgi -import json -import datetime -import argparse -import urllib2 - - def make_name(pdu_id, origin): return "%s@%s" % (pdu_id, origin) @@ -33,7 +32,7 @@ def make_graph(pdus, room, filename_prefix): node_map = {} origins = set() - colors = set(("red", "green", "blue", "yellow", "purple")) + colors = {"red", "green", "blue", "yellow", "purple"} for pdu in pdus: origins.add(pdu.get("origin")) @@ -49,7 +48,7 @@ def make_graph(pdus, room, filename_prefix): try: c = colors.pop() color_map[o] = c - except: + except Exception: print("Run out of colours!") color_map[o] = "black" diff --git a/contrib/graph/graph2.py b/contrib/graph/graph2.py index 4619f0e3c18e..0980231e4a01 100644 --- a/contrib/graph/graph2.py +++ b/contrib/graph/graph2.py @@ -13,12 +13,13 @@ # limitations under the License. -import sqlite3 -import pydot +import argparse import cgi -import json import datetime -import argparse +import json +import sqlite3 + +import pydot from synapse.events import FrozenEvent from synapse.util.frozenutils import unfreeze @@ -98,7 +99,7 @@ def make_graph(db_name, room_id, file_prefix, limit): for prev_id, _ in event.prev_events: try: end_node = node_map[prev_id] - except: + except Exception: end_node = pydot.Node(name=prev_id, label="<%s>" % (prev_id,)) node_map[prev_id] = end_node diff --git a/contrib/graph/graph3.py b/contrib/graph/graph3.py index 31546385208b..91db98e7efcb 100644 --- a/contrib/graph/graph3.py +++ b/contrib/graph/graph3.py @@ -1,5 +1,15 @@ from __future__ import print_function +import argparse +import cgi +import datetime + +import pydot +import simplejson as json + +from synapse.events import FrozenEvent +from synapse.util.frozenutils import unfreeze + # Copyright 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,16 +25,6 @@ # limitations under the License. 
-import pydot -import cgi -import simplejson as json -import datetime -import argparse - -from synapse.events import FrozenEvent -from synapse.util.frozenutils import unfreeze - - def make_graph(file_name, room_id, file_prefix, limit): print("Reading lines") with open(file_name) as f: @@ -106,7 +106,7 @@ def make_graph(file_name, room_id, file_prefix, limit): for prev_id, _ in event.prev_events: try: end_node = node_map[prev_id] - except: + except Exception: end_node = pydot.Node(name=prev_id, label="<%s>" % (prev_id,)) node_map[prev_id] = end_node diff --git a/contrib/jitsimeetbridge/jitsimeetbridge.py b/contrib/jitsimeetbridge/jitsimeetbridge.py index 67fb2cd1a7a5..69aa74bd34d0 100644 --- a/contrib/jitsimeetbridge/jitsimeetbridge.py +++ b/contrib/jitsimeetbridge/jitsimeetbridge.py @@ -12,15 +12,15 @@ """ from __future__ import print_function -import gevent -import grequests -from BeautifulSoup import BeautifulSoup import json -import urllib import subprocess import time -# ACCESS_TOKEN="" # +import gevent +import grequests +from BeautifulSoup import BeautifulSoup + +ACCESS_TOKEN = "" MATRIXBASE = "https://matrix.org/_matrix/client/api/v1/" MYUSERNAME = "@davetest:matrix.org" diff --git a/contrib/scripts/kick_users.py b/contrib/scripts/kick_users.py index f57e6e7d2599..372dbd9e4f32 100755 --- a/contrib/scripts/kick_users.py +++ b/contrib/scripts/kick_users.py @@ -1,10 +1,12 @@ #!/usr/bin/env python from __future__ import print_function -from argparse import ArgumentParser + import json -import requests import sys import urllib +from argparse import ArgumentParser + +import requests try: raw_input diff --git a/docker/Dockerfile b/docker/Dockerfile index 093e89af6c56..8b3a4246a5fe 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -16,35 +16,31 @@ ARG PYTHON_VERSION=3.7 ### ### Stage 0: builder ### -FROM docker.io/python:${PYTHON_VERSION}-alpine3.11 as builder +FROM docker.io/python:${PYTHON_VERSION}-slim as builder # install the OS build deps -RUN apk add \ - build-base \ - libffi-dev \ - libjpeg-turbo-dev \ - libwebp-dev \ - libressl-dev \ - libxslt-dev \ - linux-headers \ - postgresql-dev \ - zlib-dev -# build things which have slow build steps, before we copy synapse, so that -# the layer can be cached. -# -# (we really just care about caching a wheel here, as the "pip install" below -# will install them again.) +RUN apt-get update && apt-get install -y \ + build-essential \ + libpq-dev \ + && rm -rf /var/lib/apt/lists/* +# Build dependencies that are not available as wheels, to speed up rebuilds RUN pip install --prefix="/install" --no-warn-script-location \ - cryptography \ - msgpack-python \ - pillow \ - pynacl + frozendict \ + jaeger-client \ + opentracing \ + prometheus-client \ + psycopg2 \ + pycparser \ + pyrsistent \ + pyyaml \ + simplejson \ + threadloop \ + thrift # now install synapse and all of the python deps to /install. 
- COPY synapse /synapse/synapse/ COPY scripts /synapse/scripts/ COPY MANIFEST.in README.rst setup.py synctl /synapse/ @@ -56,20 +52,13 @@ RUN pip install --prefix="/install" --no-warn-script-location \ ### Stage 1: runtime ### -FROM docker.io/python:${PYTHON_VERSION}-alpine3.11 +FROM docker.io/python:${PYTHON_VERSION}-slim -# xmlsec is required for saml support -RUN apk add --no-cache --virtual .runtime_deps \ - libffi \ - libjpeg-turbo \ - libwebp \ - libressl \ - libxslt \ - libpq \ - zlib \ - su-exec \ - tzdata \ - xmlsec +RUN apt-get update && apt-get install -y \ + libpq5 \ + xmlsec1 \ + gosu \ + && rm -rf /var/lib/apt/lists/* COPY --from=builder /install /usr/local COPY ./docker/start.py /start.py diff --git a/docker/README.md b/docker/README.md index 8c337149ca71..008a9ff70865 100644 --- a/docker/README.md +++ b/docker/README.md @@ -94,6 +94,21 @@ The following environment variables are supported in run mode: * `UID`, `GID`: the user and group id to run Synapse as. Defaults to `991`, `991`. * `TZ`: the [timezone](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) the container will run with. Defaults to `UTC`. +## Generating an (admin) user + +After Synapse is running, you may wish to create a user via `register_new_matrix_user`. + +This requires a `registration_shared_secret` to be set in your config file. Synapse +must be restarted to pick up this change. + +You can then call the script: + +``` +docker exec -it synapse register_new_matrix_user http://localhost:8008 -c /data/homeserver.yaml --help +``` + +Remember to remove the `registration_shared_secret` and restart if you no longer need it. + ## TLS support The default configuration exposes a single HTTP port: http://localhost:8008. It diff --git a/docker/start.py b/docker/start.py index 2a25c9380e34..9f081341581b 100755 --- a/docker/start.py +++ b/docker/start.py @@ -120,7 +120,7 @@ def generate_config_from_template(config_dir, config_path, environ, ownership): if ownership is not None: subprocess.check_output(["chown", "-R", ownership, "/data"]) - args = ["su-exec", ownership] + args + args = ["gosu", ownership] + args subprocess.check_output(args) @@ -172,8 +172,8 @@ def run_generate_config(environ, ownership): # make sure that synapse has perms to write to the data dir. subprocess.check_output(["chown", ownership, data_dir]) - args = ["su-exec", ownership] + args - os.execv("/sbin/su-exec", args) + args = ["gosu", ownership] + args + os.execv("/usr/sbin/gosu", args) else: os.execv("/usr/local/bin/python", args) @@ -189,7 +189,7 @@ def main(args, environ): ownership = "{}:{}".format(desired_uid, desired_gid) if ownership is None: - log("Will not perform chmod/su-exec as UserID already matches request") + log("Will not perform chmod/gosu as UserID already matches request") # In generate mode, generate a configuration and missing keys, then exit if mode == "generate": @@ -236,8 +236,8 @@ def main(args, environ): args = ["python", "-m", synapse_worker, "--config-path", config_path] if ownership is not None: - args = ["su-exec", ownership] + args - os.execv("/sbin/su-exec", args) + args = ["gosu", ownership] + args + os.execv("/usr/sbin/gosu", args) else: os.execv("/usr/local/bin/python", args) diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 3f26adc16caa..15b83e98248b 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -319,11 +319,43 @@ Response: } ``` +# Room Members API + +The Room Members admin API allows server admins to get a list of all members of a room.
+ +The response includes the following fields: + +* `members` - A list of all the members that are present in the room, represented by their IDs. +* `total` - Total number of members in the room. + +## Usage + +A standard request: + +``` +GET /_synapse/admin/v1/rooms/<room_id>/members + +{} +``` + +Response: + +``` +{ + "members": [ + "@foo:matrix.org", + "@bar:matrix.org", + "@foobar:matrix.org" + ], + "total": 3 +} +``` + # Delete Room API The Delete Room admin API allows server admins to remove rooms from server and block these rooms. -It is a combination and improvement of "[Shutdown room](shutdown_room.md)" +It is a combination and improvement of "[Shutdown room](shutdown_room.md)" and "[Purge room](purge_room.md)" API. Shuts down a room. Moves all local users and room aliases automatically to a diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md index 131990001ae9..7bfb96eff623 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md @@ -38,6 +38,11 @@ the reverse proxy and the homeserver. server { listen 443 ssl; listen [::]:443 ssl; + + # For the federation port + listen 8448 ssl default_server; + listen [::]:8448 ssl default_server; + server_name matrix.example.com; location /_matrix { @@ -48,17 +53,6 @@ server { client_max_body_size 10M; } } - -server { - listen 8448 ssl default_server; - listen [::]:8448 ssl default_server; - server_name example.com; - - location / { - proxy_pass http://localhost:8008; - proxy_set_header X-Forwarded-For $remote_addr; - } -} ``` **NOTE**: Do not add a path after the port in `proxy_pass`, otherwise nginx will diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 5ed44e8a3ac1..e21864047ac4 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -102,7 +102,9 @@ pid_file: DATADIR/homeserver.pid #gc_thresholds: [700, 10, 10] # Set the limit on the returned events in the timeline in the get -# and sync operations. The default value is -1, means no upper limit. +# and sync operations. The default value is 100. -1 means no upper limit. +# +# Uncomment the following to increase the limit to 5000. # #filter_timeline_limit: 5000 @@ -146,7 +148,7 @@ pid_file: DATADIR/homeserver.pid # names: a list of names of HTTP resources. See below for a list of # valid resource names. # -# compress: set to true to enable HTTP comression for this resource. +# compress: set to true to enable HTTP compression for this resource. # # additional_resources: Only valid for an 'http' listener. A map of # additional endpoints which should be loaded via dynamic modules.
@@ -751,7 +753,7 @@ caches: #database: # name: psycopg2 # args: -# user: synapse +# user: synapse_user # password: secretpassword # database: synapse # host: localhost diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages index e6f4bd1dcadf..d055cf32877d 100755 --- a/scripts-dev/build_debian_packages +++ b/scripts-dev/build_debian_packages @@ -24,7 +24,6 @@ DISTS = ( "debian:sid", "ubuntu:xenial", "ubuntu:bionic", - "ubuntu:eoan", "ubuntu:focal", ) diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index 66b056885879..064799365832 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -11,7 +11,7 @@ if [ $# -ge 1 ] then files=$* else - files="synapse tests scripts-dev scripts" + files="synapse tests scripts-dev scripts contrib synctl" fi echo "Linting these locations: $files" diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 0ebffb04a50d..b21b8d573db6 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -49,6 +49,7 @@ from synapse.storage.data_stores.main.media_repository import ( from synapse.storage.data_stores.main.profile import ProfileStore from synapse.storage.data_stores.main.registration import ( RegistrationBackgroundUpdateStore, + find_max_generated_user_id_localpart, ) from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore @@ -624,8 +625,10 @@ class Porter(object): ) ) - # Step 5. Do final post-processing + # Step 5. Set up sequences + self.progress.set_state("Setting up sequence generators") await self._setup_state_group_id_seq() + await self._setup_user_id_seq() self.progress.done() except Exception as e: @@ -795,6 +798,13 @@ class Porter(object): return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r) + def _setup_user_id_seq(self): + def r(txn): + next_id = find_max_generated_user_id_localpart(txn) + 1 + txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,)) + + return self.postgres_store.db.runInteraction("setup_user_id_seq", r) + ############################################## # The following is simply UI stuff diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index e90695f026f6..c1b76d827b3c 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -21,7 +21,7 @@ from typing_extensions import ContextManager -from twisted.internet import address, defer, reactor +from twisted.internet import address, reactor import synapse import synapse.events @@ -111,6 +111,7 @@ RoomSendEventRestServlet, RoomStateEventRestServlet, RoomStateRestServlet, + RoomTypingRestServlet, ) from synapse.rest.client.v1.voip import VoipRestServlet from synapse.rest.client.v2_alpha import groups, sync, user_directory @@ -374,9 +375,8 @@ def _user_syncing(): return _user_syncing() - @defer.inlineCallbacks - def notify_from_replication(self, states, stream_id): - parties = yield get_interested_parties(self.store, states) + async def notify_from_replication(self, states, stream_id): + parties = await get_interested_parties(self.store, states) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -386,8 +386,7 @@ def notify_from_replication(self, states, stream_id): users=users_to_states.keys(), ) - @defer.inlineCallbacks - def process_replication_rows(self, token, rows): + async def process_replication_rows(self, token, rows): states = [ UserPresenceState( row.user_id, @@ -405,7 +404,7 @@ def 
process_replication_rows(self, token, rows): self.user_to_current_state[state.user_id] = state stream_id = token - yield self.notify_from_replication(states, stream_id) + await self.notify_from_replication(states, stream_id) def get_currently_syncing_users_for_replication(self) -> Iterable[str]: return [ @@ -451,37 +450,6 @@ async def bump_presence_active_time(self, user): await self._bump_active_client(user_id=user_id) -class GenericWorkerTyping(object): - def __init__(self, hs): - self._latest_room_serial = 0 - self._reset() - - def _reset(self): - """ - Reset the typing handler's data caches. - """ - # map room IDs to serial numbers - self._room_serials = {} - # map room IDs to sets of users currently typing - self._room_typing = {} - - def process_replication_rows(self, token, rows): - if self._latest_room_serial > token: - # The master has gone backwards. To prevent inconsistent data, just - # clear everything. - self._reset() - - # Set the latest serial token to whatever the server gave us. - self._latest_room_serial = token - - for row in rows: - self._room_serials[row.room_id] = token - self._room_typing[row.room_id] = row.user_ids - - def get_current_token(self) -> int: - return self._latest_room_serial - - class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly # rather than going via the correct worker. @@ -558,6 +526,7 @@ def _listen_http(self, listener_config: ListenerConfig): KeyUploadServlet(self).register(resource) AccountDataServlet(self).register(resource) RoomAccountDataServlet(self).register(resource) + RoomTypingRestServlet(self).register(resource) sync.register_servlets(self, resource) events.register_servlets(self, resource) @@ -669,9 +638,6 @@ def build_replication_data_handler(self): def build_presence_handler(self): return GenericWorkerPresence(self) - def build_typing_handler(self): - return GenericWorkerTyping(self) - class GenericWorkerReplicationHandler(ReplicationDataHandler): def __init__(self, hs): diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 09291d86add8..ec7401f91130 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -483,8 +483,7 @@ def stopService(self): _stats_process = [] -@defer.inlineCallbacks -def phone_stats_home(hs, stats, stats_process=_stats_process): +async def phone_stats_home(hs, stats, stats_process=_stats_process): logger.info("Gathering stats for reporting") now = int(hs.get_clock().time()) uptime = int(now - hs.start_time) @@ -522,28 +521,28 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): stats["python_version"] = "{}.{}.{}".format( version.major, version.minor, version.micro ) - stats["total_users"] = yield hs.get_datastore().count_all_users() + stats["total_users"] = await hs.get_datastore().count_all_users() - total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users() + total_nonbridged_users = await hs.get_datastore().count_nonbridged_users() stats["total_nonbridged_users"] = total_nonbridged_users - daily_user_type_results = yield hs.get_datastore().count_daily_user_type() + daily_user_type_results = await hs.get_datastore().count_daily_user_type() for name, count in daily_user_type_results.items(): stats["daily_user_type_" + name] = count - room_count = yield hs.get_datastore().get_room_count() + room_count = await hs.get_datastore().get_room_count() stats["total_room_count"] = room_count - stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() - stats["monthly_active_users"] = 
yield hs.get_datastore().count_monthly_users() - stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms() - stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() + stats["daily_active_users"] = await hs.get_datastore().count_daily_users() + stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users() + stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms() + stats["daily_messages"] = await hs.get_datastore().count_daily_messages() - r30_results = yield hs.get_datastore().count_r30_users() + r30_results = await hs.get_datastore().count_r30_users() for name, count in r30_results.items(): stats["r30_users_" + name] = count - daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() + daily_sent_messages = await hs.get_datastore().count_daily_sent_messages() stats["daily_sent_messages"] = daily_sent_messages stats["cache_factor"] = hs.config.caches.global_factor stats["event_cache_size"] = hs.config.caches.event_cache_size @@ -558,7 +557,7 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) try: - yield hs.get_proxied_http_client().put_json( + await hs.get_proxied_http_client().put_json( hs.config.report_stats_endpoint, stats ) except Exception as e: diff --git a/synapse/config/_base.py b/synapse/config/_base.py index f2830c609dd5..34a2370e679d 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -19,10 +19,12 @@ import errno import os from collections import OrderedDict +from hashlib import sha256 from io import open as io_open from textwrap import dedent -from typing import Any, MutableMapping, Optional +from typing import Any, List, MutableMapping, Optional +import attr import yaml @@ -718,4 +720,36 @@ def find_config_files(search_paths): return config_files -__all__ = ["Config", "RootConfig"] +@attr.s +class ShardedWorkerHandlingConfig: + """Algorithm for choosing which instance is responsible for handling some + sharded work. + + For example, the federation senders use this to determine which instance + handles sending to a given destination (which is used as the `key` + below). + """ + + instances = attr.ib(type=List[str]) + + def should_handle(self, instance_name: str, key: str) -> bool: + """Whether this instance is responsible for handling the given key. + """ + + # If zero or one instances are defined we always return true. + if not self.instances or len(self.instances) == 1: + return True + + # We shard by taking the hash, modulo it by the number of instances and + # then checking whether this instance matches the instance at that + # index. + # + # (Technically this introduces some bias and is not entirely uniform, + # but since the hash is so large the bias is ridiculously small). + dest_hash = sha256(key.encode("utf8")).digest() + dest_int = int.from_bytes(dest_hash, byteorder="little") + remainder = dest_int % (len(self.instances)) + return self.instances[remainder] == instance_name + + +__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 9e576060d4df..eb911e8f9f45 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -137,3 +137,8 @@ class Config: def read_config_files(config_files: List[str]): ... def find_config_files(search_paths: List[str]): ...
+ +class ShardedWorkerHandlingConfig: + instances: List[str] + def __init__(self, instances: List[str]) -> None: ... + def should_handle(self, instance_name: str, key: str) -> bool: ... diff --git a/synapse/config/database.py b/synapse/config/database.py index 1064c2697b30..62bccd9ef52f 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -55,7 +55,7 @@ #database: # name: psycopg2 # args: -# user: synapse +# user: synapse_user # password: secretpassword # database: synapse # host: localhost diff --git a/synapse/config/federation.py b/synapse/config/federation.py index 7782ab4c9d6b..82ff9664de54 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -13,42 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from hashlib import sha256 -from typing import List, Optional +from typing import Optional -import attr from netaddr import IPSet -from ._base import Config, ConfigError - - -@attr.s -class ShardedFederationSendingConfig: - """Algorithm for choosing which federation sender instance is responsible - for which destionation host. - """ - - instances = attr.ib(type=List[str]) - - def should_send_to(self, instance_name: str, destination: str) -> bool: - """Whether this instance is responsible for sending transcations for - the given host. - """ - - # If multiple federation senders are not defined we always return true. - if not self.instances or len(self.instances) == 1: - return True - - # We shard by taking the hash, modulo it by the number of federation - # senders and then checking whether this instance matches the instance - # at that index. - # - # (Technically this introduces some bias and is not entirely uniform, but - # since the hash is so large the bias is ridiculously small). - dest_hash = sha256(destination.encode("utf8")).digest() - dest_int = int.from_bytes(dest_hash, byteorder="little") - remainder = dest_int % (len(self.instances)) - return self.instances[remainder] == instance_name +from ._base import Config, ConfigError, ShardedWorkerHandlingConfig class FederationConfig(Config): @@ -61,7 +30,7 @@ def read_config(self, config, **kwargs): self.send_federation = config.get("send_federation", True) federation_sender_instances = config.get("federation_sender_instances") or [] - self.federation_shard_config = ShardedFederationSendingConfig( + self.federation_shard_config = ShardedWorkerHandlingConfig( federation_sender_instances ) diff --git a/synapse/config/push.py b/synapse/config/push.py index 6f2b3a7faa35..a1f3752c8ab4 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config +from ._base import Config, ShardedWorkerHandlingConfig class PushConfig(Config): @@ -24,6 +24,9 @@ def read_config(self, config, **kwargs): push_config = config.get("push", {}) self.push_include_content = push_config.get("include_content", True) + pusher_instances = config.get("pusher_instances") or [] + self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances) + # There was a a 'redact_content' setting but mistakenly read from the # 'email'section'. Check for the flag in the 'push' section, and log, # but do not honour it to avoid nasty surprises when people upgrade. 
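The `ShardedWorkerHandlingConfig` added in `synapse/config/_base.py` above (and now shared by the federation-sender and pusher configs) assigns each key to exactly one worker by hashing the key and reducing it modulo the number of configured instances. A minimal standalone sketch of that assignment rule follows; the worker names are hypothetical and not part of this patch.

```python
# Illustrative sketch of the sharding rule used by
# ShardedWorkerHandlingConfig.should_handle; not part of the patch itself.
from hashlib import sha256
from typing import List


def pick_instance(instances: List[str], key: str) -> str:
    """Deterministically map a key (e.g. a destination server) to an instance."""
    if len(instances) <= 1:
        # Zero or one configured instances: a single instance handles everything.
        return instances[0] if instances else "master"
    # Hash the key to a large integer and take it modulo the instance count.
    # Every worker computes the same answer, so exactly one claims each key.
    dest_int = int.from_bytes(sha256(key.encode("utf8")).digest(), "little")
    return instances[dest_int % len(instances)]


# Every worker evaluates the same function, so each destination is owned
# by exactly one sender:
workers = ["sender1", "sender2", "sender3"]
for dest in ("remote-a.example.org", "remote-b.example.org"):
    print(dest, "->", pick_instance(workers, dest))
```

Because the SHA-256 output is vastly larger than any realistic instance count, the modulo bias mentioned in the code comment is negligible in practice.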
diff --git a/synapse/config/server.py b/synapse/config/server.py index 9f406e471efb..35687f427e08 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -207,7 +207,7 @@ def read_config(self, config, **kwargs): # errors when attempting to search for messages. self.enable_search = config.get("enable_search", True) - self.filter_timeline_limit = config.get("filter_timeline_limit", -1) + self.filter_timeline_limit = config.get("filter_timeline_limit", 100) # Whether we should block invites sent to users on this server # (other than those sent by local server admins) @@ -699,7 +699,9 @@ def generate_config_section( #gc_thresholds: [700, 10, 10] # Set the limit on the returned events in the timeline in the get - # and sync operations. The default value is -1, means no upper limit. + # and sync operations. The default value is 100. -1 means no upper limit. + # + # Uncomment the following to increase the limit to 5000. # #filter_timeline_limit: 5000 @@ -743,7 +745,7 @@ def generate_config_section( # names: a list of names of HTTP resources. See below for a list of # valid resource names. # - # compress: set to true to enable HTTP comression for this resource. + # compress: set to true to enable HTTP compression for this resource. # # additional_resources: Only valid for an 'http' listener. A map of # additional endpoints which should be loaded via dynamic modules. diff --git a/synapse/config/workers.py b/synapse/config/workers.py index dbc661630c1e..2574cd3aa170 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -34,9 +34,11 @@ class WriterLocations: Attributes: events: The instance that writes to the event and backfill streams. + typing: The instance that writes to the typing stream. events = attr.ib(default="master", type=str) + typing = attr.ib(default="master", type=str) class WorkerConfig(Config): @@ -93,16 +95,15 @@ def read_config(self, config, **kwargs): writers = config.get("stream_writers") or {} self.writers = WriterLocations(**writers) - # Check that the configured writer for events also appears in + # Check that the configured writer for events and typing also appears in # `instance_map`. - if ( - self.writers.events != "master" - and self.writers.events not in self.instance_map - ): - raise ConfigError( - "Instance %r is configured to write events but does not appear in `instance_map` config." - % (self.writers.events,) - ) + for stream in ("events", "typing"): + instance = getattr(self.writers, stream) + if instance != "master" and instance not in self.instance_map: + raise ConfigError( + "Instance %r is configured to write %s but does not appear in `instance_map` config." + % (instance, stream) + ) def read_arguments(self, args): # We support a bunch of command line arguments that override options in diff --git a/synapse/events/utils.py b/synapse/events/utils.py index f6b507977f99..11f0d34ec8f7 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
-import collections +import collections.abc import re from typing import Any, Mapping, Union @@ -424,7 +424,7 @@ def copy_power_levels_contents( Raises: TypeError if the input does not look like a valid power levels event content """ - if not isinstance(old_power_levels, collections.Mapping): + if not isinstance(old_power_levels, collections.abc.Mapping): raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,)) power_levels = {} @@ -434,7 +434,7 @@ def copy_power_levels_contents( power_levels[k] = v continue - if isinstance(v, collections.Mapping): + if isinstance(v, collections.abc.Mapping): power_levels[k] = h = {} for k1, v1 in v.items(): # we should only have one level of nesting diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index a37cc9cb4a9f..994e6c8d5a0d 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -374,29 +374,26 @@ async def _check_sigs_and_hash_and_fetch( """ deferreds = self._check_sigs_and_hashes(room_version, pdus) - @defer.inlineCallbacks - def handle_check_result(pdu: EventBase, deferred: Deferred): + async def handle_check_result(pdu: EventBase, deferred: Deferred): try: - res = yield make_deferred_yieldable(deferred) + res = await make_deferred_yieldable(deferred) except SynapseError: res = None if not res: # Check local db. - res = yield self.store.get_event( + res = await self.store.get_event( pdu.event_id, allow_rejected=True, allow_none=True ) if not res and pdu.origin != origin: try: - res = yield defer.ensureDeferred( - self.get_pdu( - destinations=[pdu.origin], - event_id=pdu.event_id, - room_version=room_version, - outlier=outlier, - timeout=10000, - ) + res = await self.get_pdu( + destinations=[pdu.origin], + event_id=pdu.event_id, + room_version=room_version, + outlier=outlier, + timeout=10000, ) except SynapseError: pass @@ -995,24 +992,25 @@ async def forward_third_party_invite(self, destinations, room_id, event_dict): raise RuntimeError("Failed to send to any server.") - @defer.inlineCallbacks - def get_room_complexity(self, destination, room_id): + async def get_room_complexity( + self, destination: str, room_id: str + ) -> Optional[dict]: """ Fetch the complexity of a remote room from another server. Args: - destination (str): The remote server - room_id (str): The room ID to ask about. + destination: The remote server + room_id: The room ID to ask about. Returns: - Deferred[dict] or Deferred[None]: Dict contains the complexity - metric versions, while None means we could not fetch the complexity. + Dict contains the complexity metric versions, while None means we + could not fetch the complexity. """ try: - complexity = yield self.transport_layer.get_room_complexity( + complexity = await self.transport_layer.get_room_complexity( destination=destination, room_id=room_id ) - defer.returnValue(complexity) + return complexity except CodeMessageException as e: # We didn't manage to get it -- probably a 404. We are okay if other # servers don't give it to us. @@ -1029,4 +1027,4 @@ def get_room_complexity(self, destination, room_id): # If we don't manage to find it, return None. It's not an error if a # server doesn't give it to us. 
- defer.returnValue(None) + return None diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8c53330c4999..23625ba995e4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -15,7 +15,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Match, + Optional, + Tuple, + Union, +) from canonicaljson import json from prometheus_client import Counter, Histogram @@ -56,6 +67,9 @@ from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache +if TYPE_CHECKING: + from synapse.server import HomeServer + # when processing incoming transactions, we try to handle multiple rooms in # parallel, up to this limit. TRANSACTION_CONCURRENCY_LIMIT = 10 @@ -768,11 +782,30 @@ class FederationHandlerRegistry(object): query type for incoming federation traffic. """ - def __init__(self): - self.edu_handlers = {} - self.query_handlers = {} + def __init__(self, hs: "HomeServer"): + self.config = hs.config + self.http_client = hs.get_simple_http_client() + self.clock = hs.get_clock() + self._instance_name = hs.get_instance_name() - def register_edu_handler(self, edu_type: str, handler: Callable[[str, dict], None]): + # These are safe to load in monolith mode, but will explode if we try + # and use them. However we have guards before we use them to ensure that + # we don't route to ourselves, and in monolith mode that will always be + # the case. + self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs) + self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs) + + self.edu_handlers = ( + {} + ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]] + self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]] + + # Map from type to instance name that we should route EDU handling to. + self._edu_type_to_instance = {} # type: Dict[str, str] + + def register_edu_handler( + self, edu_type: str, handler: Callable[[str, dict], Awaitable[None]] + ): """Sets the handler callable that will be used to handle an incoming federation EDU of the given type. @@ -809,66 +842,56 @@ def register_query_handler( self.query_handlers[query_type] = handler + def register_instance_for_edu(self, edu_type: str, instance_name: str): + """Register that the EDU handler is on a different instance than master. 
+ """ + self._edu_type_to_instance[edu_type] = instance_name + async def on_edu(self, edu_type: str, origin: str, content: dict): + if not self.config.use_presence and edu_type == "m.presence": + return + + # Check if we have a handler on this instance handler = self.edu_handlers.get(edu_type) - if not handler: - logger.warning("No handler registered for EDU type %s", edu_type) + if handler: + with start_active_span_from_edu(content, "handle_edu"): + try: + await handler(origin, content) + except SynapseError as e: + logger.info("Failed to handle edu %r: %r", edu_type, e) + except Exception: + logger.exception("Failed to handle edu %r", edu_type) return - with start_active_span_from_edu(content, "handle_edu"): + # Check if we can route it somewhere else that isn't us + route_to = self._edu_type_to_instance.get(edu_type, "master") + if route_to != self._instance_name: try: - await handler(origin, content) + await self._send_edu( + instance_name=route_to, + edu_type=edu_type, + origin=origin, + content=content, + ) except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) except Exception: logger.exception("Failed to handle edu %r", edu_type) - - def on_query(self, query_type: str, args: dict) -> defer.Deferred: - handler = self.query_handlers.get(query_type) - if not handler: - logger.warning("No handler registered for query type %s", query_type) - raise NotFoundError("No handler for Query type '%s'" % (query_type,)) - - return handler(args) - - -class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): - """A FederationHandlerRegistry for worker processes. - - When receiving EDU or queries it will check if an appropriate handler has - been registered on the worker, if there isn't one then it calls off to the - master process. - """ - - def __init__(self, hs): - self.config = hs.config - self.http_client = hs.get_simple_http_client() - self.clock = hs.get_clock() - - self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs) - self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs) - - super(ReplicationFederationHandlerRegistry, self).__init__() - - async def on_edu(self, edu_type: str, origin: str, content: dict): - """Overrides FederationHandlerRegistry - """ - if not self.config.use_presence and edu_type == "m.presence": return - handler = self.edu_handlers.get(edu_type) - if handler: - return await super(ReplicationFederationHandlerRegistry, self).on_edu( - edu_type, origin, content - ) - - return await self._send_edu(edu_type=edu_type, origin=origin, content=content) + # Oh well, let's just log and move on. + logger.warning("No handler registered for EDU type %s", edu_type) async def on_query(self, query_type: str, args: dict): - """Overrides FederationHandlerRegistry - """ handler = self.query_handlers.get(query_type) if handler: return await handler(args) - return await self._get_query_client(query_type=query_type, args=args) + # Check if we can route it somewhere else that isn't us + if self._instance_name == "master": + return await self._get_query_client(query_type=query_type, args=args) + + # Uh oh, no handler! Let's raise an exception so the request returns an + # error. 
+ logger.warning("No handler registered for query type %s", query_type) + raise NotFoundError("No handler for Query type '%s'" % (query_type,)) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 4b63a0755fc9..b328a4df096c 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -197,7 +197,7 @@ async def handle_event(event: EventBase) -> None: destinations = { d for d in destinations - if self._federation_shard_config.should_send_to( + if self._federation_shard_config.should_handle( self._instance_name, d ) } @@ -335,7 +335,7 @@ def send_read_receipt(self, receipt: ReadReceipt): d for d in domains if d != self.server_name - and self._federation_shard_config.should_send_to(self._instance_name, d) + and self._federation_shard_config.should_handle(self._instance_name, d) ] if not domains: return @@ -441,7 +441,7 @@ def send_presence_to_destinations( for destination in destinations: if destination == self.server_name: continue - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): continue @@ -460,7 +460,7 @@ def _process_presence_inner(self, states: List[UserPresenceState]): if destination == self.server_name: continue - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): continue @@ -486,7 +486,7 @@ def build_and_send_edu( logger.info("Not sending EDU to ourselves") return - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): return @@ -507,7 +507,7 @@ def send_edu(self, edu: Edu, key: Optional[Hashable]): edu: edu to send key: clobbering key for this edu """ - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, edu.destination ): return @@ -523,7 +523,7 @@ def send_device_messages(self, destination: str): logger.warning("Not sending device update to ourselves") return - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): return @@ -541,7 +541,7 @@ def wake_destination(self, destination: str): logger.warning("Not waking up ourselves") return - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): return diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 6402136e8abf..343674178327 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -78,7 +78,7 @@ def __init__( self._federation_shard_config = hs.config.federation.federation_shard_config self._should_send_on_this_instance = True - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): # We don't raise an exception here to avoid taking out any other diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 1478ee03a5ad..5da69c2c4985 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -20,8 +20,6 @@ import re from typing import Optional, Tuple, Type -from twisted.internet.defer import maybeDeferred - import synapse from 
synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.room_versions import RoomVersions @@ -796,12 +794,8 @@ async def on_GET(self, origin, content, query): # zero is a special value which corresponds to no limit. limit = None - data = await maybeDeferred( - self.handler.get_local_public_room_list, - limit, - since_token, - network_tuple=network_tuple, - from_federation=True, + data = await self.handler.get_local_public_room_list( + limit, since_token, network_tuple=network_tuple, from_federation=True ) return 200, data diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 6a4944467ac1..ba2bf998008f 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - import synapse.state import synapse.storage import synapse.types @@ -66,8 +64,7 @@ def __init__(self, hs): self.event_builder_factory = hs.get_event_builder_factory() - @defer.inlineCallbacks - def ratelimit(self, requester, update=True, is_admin_redaction=False): + async def ratelimit(self, requester, update=True, is_admin_redaction=False): """Ratelimits requests. Args: @@ -99,7 +96,7 @@ def ratelimit(self, requester, update=True, is_admin_redaction=False): burst_count = self._rc_message.burst_count # Check if there is a per user override in the DB. - override = yield self.store.get_ratelimit_for_user(user_id) + override = await self.store.get_ratelimit_for_user(user_id) if override: # If overridden with a null Hz then ratelimiting has been entirely # disabled for the user diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 31346b56c366..db417d60deb4 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,9 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, Optional - -from twisted.internet import defer +from typing import Any, Dict, List, Optional from synapse.api import errors from synapse.api.constants import EventTypes @@ -57,21 +55,20 @@ def __init__(self, hs): self._auth_handler = hs.get_auth_handler() @trace - @defer.inlineCallbacks - def get_devices_by_user(self, user_id): + async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]: """ Retrieve the given user's devices Args: - user_id (str): + user_id: The user ID to query for devices. Returns: - defer.Deferred: list[dict[str, X]]: info on each device + info on each device """ set_tag("user_id", user_id) - device_map = yield self.store.get_devices_by_user(user_id) + device_map = await self.store.get_devices_by_user(user_id) - ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None) + ips = await self.store.get_last_client_ip_by_device(user_id, device_id=None) devices = list(device_map.values()) for device in devices: @@ -81,24 +78,23 @@ def get_devices_by_user(self, user_id): return devices @trace - @defer.inlineCallbacks - def get_device(self, user_id, device_id): + async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: """ Retrieve the given device Args: - user_id (str): - device_id (str): + user_id: The user to get the device from + device_id: The device to fetch. 
Returns: - defer.Deferred: dict[str, X]: info on the device + info on the device Raises: errors.NotFoundError: if the device was not found """ try: - device = yield self.store.get_device(user_id, device_id) + device = await self.store.get_device(user_id, device_id) except errors.StoreError: raise errors.NotFoundError - ips = yield self.store.get_last_client_ip_by_device(user_id, device_id) + ips = await self.store.get_last_client_ip_by_device(user_id, device_id) _update_device_from_client_ips(device, ips) set_tag("device", device) @@ -106,10 +102,9 @@ def get_device(self, user_id, device_id): return device - @measure_func("device.get_user_ids_changed") @trace - @defer.inlineCallbacks - def get_user_ids_changed(self, user_id, from_token): + @measure_func("device.get_user_ids_changed") + async def get_user_ids_changed(self, user_id, from_token): """Get list of users that have had the devices updated, or have newly joined a room, that `user_id` may be interested in. @@ -120,13 +115,13 @@ def get_user_ids_changed(self, user_id, from_token): set_tag("user_id", user_id) set_tag("from_token", from_token) - now_room_key = yield self.store.get_room_events_max_id() + now_room_key = await self.store.get_room_events_max_id() - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) # First we check if any devices have changed for users that we share # rooms with. - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id ) @@ -135,14 +130,14 @@ def get_user_ids_changed(self, user_id, from_token): # Always tell the user about their own devices tracked_users.add(user_id) - changed = yield self.store.get_users_whose_devices_changed( + changed = await self.store.get_users_whose_devices_changed( from_token.device_list_key, tracked_users ) # Then work out if any users have since joined rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) - member_events = yield self.store.get_membership_changes_for_user( + member_events = await self.store.get_membership_changes_for_user( user_id, from_token.room_key, now_room_key ) rooms_changed.update(event.room_id for event in member_events) @@ -152,7 +147,7 @@ def get_user_ids_changed(self, user_id, from_token): possibly_changed = set(changed) possibly_left = set() for room_id in rooms_changed: - current_state_ids = yield self.store.get_current_state_ids(room_id) + current_state_ids = await self.store.get_current_state_ids(room_id) # The user may have left the room # TODO: Check if they actually did or if we were just invited. @@ -166,7 +161,7 @@ def get_user_ids_changed(self, user_id, from_token): # Fetch the current state at the time. try: - event_ids = yield self.store.get_forward_extremeties_for_room( + event_ids = await self.store.get_forward_extremeties_for_room( room_id, stream_ordering=stream_ordering ) except errors.StoreError: @@ -192,7 +187,7 @@ def get_user_ids_changed(self, user_id, from_token): continue # mapping from event_id -> state_dict - prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids) + prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids) # Check if we've joined the room? If so we just blindly add all the users to # the "possibly changed" users. 
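The device-handler hunks above all follow the same mechanical conversion: an `@defer.inlineCallbacks` generator that `yield`s Deferreds becomes a native coroutine that `await`s them. A minimal sketch of the two shapes, using a hypothetical `FakeStore` whose getter returns an already-fired Deferred (the names are illustrative, not part of this diff):

    from twisted.internet import defer

    class FakeStore:
        # Illustrative stand-in for a storage method that returns a Deferred.
        def get_device(self, user_id, device_id):
            return defer.succeed({"user_id": user_id, "device_id": device_id})

    @defer.inlineCallbacks
    def get_device_old(store, user_id, device_id):
        # Old style: the generator suspends at each `yield` until the Deferred fires.
        device = yield store.get_device(user_id, device_id)
        return device

    async def get_device_new(store, user_id, device_id):
        # New style: Deferreds are awaitable, so the coroutine awaits them directly.
        device = await store.get_device(user_id, device_id)
        return device

    # Twisted callers still receive a Deferred by wrapping the coroutine:
    d = defer.ensureDeferred(get_device_new(FakeStore(), "@alice:example.com", "DEV1"))
    d.addCallback(print)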
@@ -238,11 +233,10 @@ def get_user_ids_changed(self, user_id, from_token): return result - @defer.inlineCallbacks - def on_federation_query_user_devices(self, user_id): - stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") - self_signing_key = yield self.store.get_e2e_cross_signing_key( + async def on_federation_query_user_devices(self, user_id): + stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id) + master_key = await self.store.get_e2e_cross_signing_key(user_id, "master") + self_signing_key = await self.store.get_e2e_cross_signing_key( user_id, "self_signing" ) @@ -271,8 +265,7 @@ def __init__(self, hs): hs.get_distributor().observe("user_left_room", self.user_left_room) - @defer.inlineCallbacks - def check_device_registered( + async def check_device_registered( self, user_id, device_id, initial_device_display_name=None ): """ @@ -290,13 +283,13 @@ def check_device_registered( str: device id (generated if none was supplied) """ if device_id is not None: - new_device = yield self.store.store_device( + new_device = await self.store.store_device( user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, ) if new_device: - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) return device_id # if the device id is not specified, we'll autogen one, but loop a few @@ -304,33 +297,29 @@ def check_device_registered( attempts = 0 while attempts < 5: device_id = stringutils.random_string(10).upper() - new_device = yield self.store.store_device( + new_device = await self.store.store_device( user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, ) if new_device: - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) return device_id attempts += 1 raise errors.StoreError(500, "Couldn't generate a device ID.") @trace - @defer.inlineCallbacks - def delete_device(self, user_id, device_id): + async def delete_device(self, user_id: str, device_id: str) -> None: """ Delete the given device Args: - user_id (str): - device_id (str): - - Returns: - defer.Deferred: + user_id: The user to delete the device from. + device_id: The device to delete. 
""" try: - yield self.store.delete_device(user_id, device_id) + await self.store.delete_device(user_id, device_id) except errors.StoreError as e: if e.code == 404: # no match @@ -342,49 +331,40 @@ def delete_device(self, user_id, device_id): else: raise - yield defer.ensureDeferred( - self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id - ) + await self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id ) - yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) + await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) @trace - @defer.inlineCallbacks - def delete_all_devices_for_user(self, user_id, except_device_id=None): + async def delete_all_devices_for_user( + self, user_id: str, except_device_id: Optional[str] = None + ) -> None: """Delete all of the user's devices Args: - user_id (str): - except_device_id (str|None): optional device id which should not - be deleted - - Returns: - defer.Deferred: + user_id: The user to remove all devices from + except_device_id: optional device id which should not be deleted """ - device_map = yield self.store.get_devices_by_user(user_id) + device_map = await self.store.get_devices_by_user(user_id) device_ids = list(device_map) if except_device_id is not None: device_ids = [d for d in device_ids if d != except_device_id] - yield self.delete_devices(user_id, device_ids) + await self.delete_devices(user_id, device_ids) - @defer.inlineCallbacks - def delete_devices(self, user_id, device_ids): + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """ Delete several devices Args: - user_id (str): - device_ids (List[str]): The list of device IDs to delete - - Returns: - defer.Deferred: + user_id: The user to delete devices from. + device_ids: The list of device IDs to delete """ try: - yield self.store.delete_devices(user_id, device_ids) + await self.store.delete_devices(user_id, device_ids) except errors.StoreError as e: if e.code == 404: # no match @@ -397,28 +377,22 @@ def delete_devices(self, user_id, device_ids): # Delete access tokens and e2e keys for each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: - yield defer.ensureDeferred( - self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id - ) + await self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id ) - yield self.store.delete_e2e_keys_by_device( + await self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id ) - yield self.notify_device_update(user_id, device_ids) + await self.notify_device_update(user_id, device_ids) - @defer.inlineCallbacks - def update_device(self, user_id, device_id, content): + async def update_device(self, user_id: str, device_id: str, content: dict) -> None: """ Update the given device Args: - user_id (str): - device_id (str): - content (dict): body of update request - - Returns: - defer.Deferred: + user_id: The user to update devices of. + device_id: The device to update. + content: body of update request """ # Reject a new displayname which is too long. 
@@ -431,10 +405,10 @@ def update_device(self, user_id, device_id, content): ) try: - yield self.store.update_device( + await self.store.update_device( user_id, device_id, new_display_name=new_display_name ) - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) except errors.StoreError as e: if e.code == 404: raise errors.NotFoundError() @@ -443,12 +417,15 @@ def update_device(self, user_id, device_id, content): @trace @measure_func("notify_device_update") - @defer.inlineCallbacks - def notify_device_update(self, user_id, device_ids): + async def notify_device_update(self, user_id, device_ids): """Notify that a user's device(s) has changed. Pokes the notifier, and remote servers if the user is local. """ - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + if not device_ids: + # No changes to notify about, so this is a no-op. + return + + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id ) @@ -459,20 +436,24 @@ def notify_device_update(self, user_id, device_ids): set_tag("target_hosts", hosts) - position = yield self.store.add_device_change_to_streams( + position = await self.store.add_device_change_to_streams( user_id, device_ids, list(hosts) ) + if not position: + # This should only happen if there are no updates, so we bail. + return + for device_id in device_ids: logger.debug( "Notifying about update %r/%r, ID: %r", user_id, device_id, position ) - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) # specify the user ID too since the user should always get their own device list # updates, even if they aren't in any rooms. - yield self.notifier.on_new_event( + self.notifier.on_new_event( "device_list_key", position, users=[user_id], rooms=room_ids ) @@ -484,29 +465,29 @@ def notify_device_update(self, user_id, device_ids): self.federation_sender.send_device_messages(host) log_kv({"message": "sent device update to host", "host": host}) - @defer.inlineCallbacks - def notify_user_signature_update(self, from_user_id, user_ids): + async def notify_user_signature_update( + self, from_user_id: str, user_ids: List[str] + ) -> None: """Notify a user that they have made new signatures of other users. Args: - from_user_id (str): the user who made the signature - user_ids (list[str]): the users IDs that have new signatures + from_user_id: the user who made the signature + user_ids: the users IDs that have new signatures """ - position = yield self.store.add_user_signature_change_to_streams( + position = await self.store.add_user_signature_change_to_streams( from_user_id, user_ids ) self.notifier.on_new_event("device_list_key", position, users=[from_user_id]) - @defer.inlineCallbacks - def user_left_room(self, user, room_id): + async def user_left_room(self, user, room_id): user_id = user.to_string() - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We no longer share rooms with this user, so we'll no longer # receive device updates. Mark this in DB. 
- yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id) + await self.store.mark_remote_user_device_list_as_unsubscribed(user_id) def _update_device_from_client_ips(device, client_ips): @@ -549,8 +530,7 @@ def __init__(self, hs, device_handler): ) @trace - @defer.inlineCallbacks - def incoming_device_list_update(self, origin, edu_content): + async def incoming_device_list_update(self, origin, edu_content): """Called on incoming device list update from federation. Responsible for parsing the EDU and adding to pending updates list. """ @@ -583,7 +563,7 @@ def incoming_device_list_update(self, origin, edu_content): ) return - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We don't share any rooms with this user. Ignore update, as we # probably won't get any further updates. @@ -608,14 +588,13 @@ def incoming_device_list_update(self, origin, edu_content): (device_id, stream_id, prev_ids, edu_content) ) - yield self._handle_device_updates(user_id) + await self._handle_device_updates(user_id) @measure_func("_incoming_device_list_update") - @defer.inlineCallbacks - def _handle_device_updates(self, user_id): + async def _handle_device_updates(self, user_id): "Actually handle pending updates." - with (yield self._remote_edu_linearizer.queue(user_id)): + with (await self._remote_edu_linearizer.queue(user_id)): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: # This can happen since we batch updates @@ -632,7 +611,7 @@ def _handle_device_updates(self, user_id): # Given a list of updates we check if we need to resync. This # happens if we've missed updates. - resync = yield self._need_to_do_resync(user_id, pending_updates) + resync = await self._need_to_do_resync(user_id, pending_updates) if logger.isEnabledFor(logging.INFO): logger.info( @@ -643,16 +622,16 @@ def _handle_device_updates(self, user_id): ) if resync: - yield self.user_device_resync(user_id) + await self.user_device_resync(user_id) else: # Simply update the single device, since we know that is the only # change (because of the single prev_id matching the current cache) for device_id, stream_id, prev_ids, content in pending_updates: - yield self.store.update_remote_device_list_cache_entry( + await self.store.update_remote_device_list_cache_entry( user_id, device_id, content, stream_id ) - yield self.device_handler.notify_device_update( + await self.device_handler.notify_device_update( user_id, [device_id for device_id, _, _, _ in pending_updates] ) @@ -660,14 +639,13 @@ def _handle_device_updates(self, user_id): stream_id for _, stream_id, _, _ in pending_updates ) - @defer.inlineCallbacks - def _need_to_do_resync(self, user_id, updates): + async def _need_to_do_resync(self, user_id, updates): """Given a list of updates for a user figure out if we need to do a full resync, or whether we have enough data that we can just apply the delta. """ seen_updates = self._seen_updates.get(user_id, set()) - extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id) + extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id) logger.debug("Current extremity for %r: %r", user_id, extremity) @@ -692,8 +670,7 @@ def _need_to_do_resync(self, user_id, updates): return False @trace - @defer.inlineCallbacks - def _maybe_retry_device_resync(self): + async def _maybe_retry_device_resync(self): """Retry to resync device lists that are out of sync, except if another retry is in progress. 
""" @@ -705,12 +682,12 @@ def _maybe_retry_device_resync(self): # we don't send too many requests. self._resync_retry_in_progress = True # Get all of the users that need resyncing. - need_resync = yield self.store.get_user_ids_requiring_device_list_resync() + need_resync = await self.store.get_user_ids_requiring_device_list_resync() # Iterate over the set of user IDs. for user_id in need_resync: try: # Try to resync the current user's devices list. - result = yield self.user_device_resync( + result = await self.user_device_resync( user_id=user_id, mark_failed_as_stale=False, ) @@ -734,16 +711,17 @@ def _maybe_retry_device_resync(self): # Allow future calls to retry resyncinc out of sync device lists. self._resync_retry_in_progress = False - @defer.inlineCallbacks - def user_device_resync(self, user_id, mark_failed_as_stale=True): + async def user_device_resync( + self, user_id: str, mark_failed_as_stale: bool = True + ) -> Optional[dict]: """Fetches all devices for a user and updates the device cache with them. Args: - user_id (str): The user's id whose device_list will be updated. - mark_failed_as_stale (bool): Whether to mark the user's device list as stale + user_id: The user's id whose device_list will be updated. + mark_failed_as_stale: Whether to mark the user's device list as stale if the attempt to resync failed. Returns: - Deferred[dict]: a dict with device info as under the "devices" in the result of this + A dict with device info as under the "devices" in the result of this request: https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid """ @@ -752,12 +730,12 @@ def user_device_resync(self, user_id, mark_failed_as_stale=True): # Fetch all devices for the user. origin = get_domain_from_id(user_id) try: - result = yield self.federation.query_user_devices(origin, user_id) + result = await self.federation.query_user_devices(origin, user_id) except NotRetryingDestination: if mark_failed_as_stale: # Mark the remote user's device list as stale so we know we need to retry # it later. - yield self.store.mark_remote_user_device_cache_as_stale(user_id) + await self.store.mark_remote_user_device_cache_as_stale(user_id) return except (RequestSendFailed, HttpResponseException) as e: @@ -768,7 +746,7 @@ def user_device_resync(self, user_id, mark_failed_as_stale=True): if mark_failed_as_stale: # Mark the remote user's device list as stale so we know we need to retry # it later. - yield self.store.mark_remote_user_device_cache_as_stale(user_id) + await self.store.mark_remote_user_device_cache_as_stale(user_id) # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list @@ -792,7 +770,7 @@ def user_device_resync(self, user_id, mark_failed_as_stale=True): if mark_failed_as_stale: # Mark the remote user's device list as stale so we know we need to retry # it later. - yield self.store.mark_remote_user_device_cache_as_stale(user_id) + await self.store.mark_remote_user_device_cache_as_stale(user_id) return log_kv({"result": result}) @@ -833,25 +811,24 @@ def user_device_resync(self, user_id, mark_failed_as_stale=True): stream_id, ) - yield self.store.update_remote_device_list_cache(user_id, devices, stream_id) + await self.store.update_remote_device_list_cache(user_id, devices, stream_id) device_ids = [device["device_id"] for device in devices] # Handle cross-signing keys. 
- cross_signing_device_ids = yield self.process_cross_signing_key_update( + cross_signing_device_ids = await self.process_cross_signing_key_update( user_id, master_key, self_signing_key, ) device_ids = device_ids + cross_signing_device_ids - yield self.device_handler.notify_device_update(user_id, device_ids) + await self.device_handler.notify_device_update(user_id, device_ids) # We clobber the seen updates since we've re-synced from a given # point. self._seen_updates[user_id] = {stream_id} - defer.returnValue(result) + return result - @defer.inlineCallbacks - def process_cross_signing_key_update( + async def process_cross_signing_key_update( self, user_id: str, master_key: Optional[Dict[str, Any]], @@ -872,14 +849,14 @@ def process_cross_signing_key_update( device_ids = [] if master_key: - yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + await self.store.set_e2e_cross_signing_key(user_id, "master", master_key) _, verify_key = get_verify_key_from_cross_signing_key(master_key) # verify_key is a VerifyKey from signedjson, which uses # .version to denote the portion of the key ID after the # algorithm and colon, which is the device ID device_ids.append(verify_key.version) if self_signing_key: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 52499c679d22..1178af692015 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -19,7 +19,7 @@ import itertools import logging -from collections import Container +from collections.abc import Container from http import HTTPStatus from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union @@ -44,6 +44,7 @@ FederationDeniedError, FederationError, HttpResponseException, + NotFoundError, RequestSendFailed, SynapseError, ) @@ -1442,10 +1443,20 @@ async def on_make_join_request( ) raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - event_content = {"membership": Membership.JOIN} - + # checking the room version will check that we've actually heard of the room + # (and return a 404 otherwise) room_version = await self.store.get_room_version_id(room_id) + # now check that we are *still* in the room + is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) + if not is_in_room: + logger.info( + "Got /make_join request for room %s we are no longer in", room_id, + ) + raise NotFoundError("Not an active room on this server") + + event_content = {"membership": Membership.JOIN} + builder = self.event_builder_factory.new( room_version, { diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index da206e1ec112..c47764a4ce22 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -488,11 +488,15 @@ def create_event( try: if "displayname" not in content: - displayname = yield profile.get_displayname(target) + displayname = yield defer.ensureDeferred( + profile.get_displayname(target) + ) if displayname is not None: content["displayname"] = displayname if "avatar_url" not in content: - avatar_url = yield profile.get_avatar_url(target) + avatar_url = yield defer.ensureDeferred( + profile.get_avatar_url(target) + ) if avatar_url is not None: content["avatar_url"] = avatar_url except Exception as e: diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 
dd8979e75040..acecb9c5dbfb 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -15,10 +15,8 @@ # limitations under the License. import logging -from typing import List - -from six.moves import range +from typing import List from signedjson.sign import sign_json from twisted.internet import defer, reactor @@ -145,16 +143,15 @@ def _replicate_host_profile_batch(self, host, batchnum): ) raise - @defer.inlineCallbacks - def get_profile(self, user_id): + async def get_profile(self, user_id): target_user = UserID.from_string(user_id) if self.hs.is_mine(target_user): try: - displayname = yield self.store.get_profile_displayname( + displayname = await self.store.get_profile_displayname( target_user.localpart ) - avatar_url = yield self.store.get_profile_avatar_url( + avatar_url = await self.store.get_profile_avatar_url( target_user.localpart ) except StoreError as e: @@ -165,7 +162,7 @@ def get_profile(self, user_id): return {"displayname": displayname, "avatar_url": avatar_url} else: try: - result = yield self.federation.make_query( + result = await self.federation.make_query( destination=target_user.domain, query_type="profile", args={"user_id": user_id}, @@ -177,8 +174,7 @@ def get_profile(self, user_id): except HttpResponseException as e: raise e.to_synapse_error() - @defer.inlineCallbacks - def get_profile_from_cache(self, user_id): + async def get_profile_from_cache(self, user_id): """Get the profile information from our local cache. If the user is ours then the profile information will always be correct. Otherwise, it may be out of date/missing. """ target_user = UserID.from_string(user_id) if self.hs.is_mine(target_user): try: - displayname = yield self.store.get_profile_displayname( + displayname = await self.store.get_profile_displayname( target_user.localpart ) - avatar_url = yield self.store.get_profile_avatar_url( + avatar_url = await self.store.get_profile_avatar_url( target_user.localpart ) except StoreError as e: @@ -199,14 +195,13 @@ def get_profile_from_cache(self, user_id): return {"displayname": displayname, "avatar_url": avatar_url} else: - profile = yield self.store.get_from_remote_profile_cache(user_id) + profile = await self.store.get_from_remote_profile_cache(user_id) return profile or {} - @defer.inlineCallbacks - def get_displayname(self, target_user): + async def get_displayname(self, target_user): if self.hs.is_mine(target_user): try: - displayname = yield self.store.get_profile_displayname( + displayname = await self.store.get_profile_displayname( target_user.localpart ) except StoreError as e: @@ -217,7 +212,7 @@ def get_displayname(self, target_user): return displayname else: try: - result = yield self.federation.make_query( + result = await self.federation.make_query( destination=target_user.domain, query_type="profile", args={"user_id": target_user.to_string(), "field": "displayname"}, @@ -334,11 +329,10 @@ def set_active( # start a profile replication push run_in_background(self._replicate_profiles) - @defer.inlineCallbacks - def get_avatar_url(self, target_user): + async def get_avatar_url(self, target_user): if self.hs.is_mine(target_user): try: - avatar_url = yield self.store.get_profile_avatar_url( + avatar_url = await self.store.get_profile_avatar_url( target_user.localpart ) except StoreError as e: @@ -348,7 +342,7 @@ def get_avatar_url(self, target_user): return avatar_url else: try: - result = yield self.federation.make_query(
destination=target_user.domain, query_type="profile", args={"user_id": target_user.to_string(), "field": "avatar_url"}, @@ -455,8 +449,7 @@ def _validate_and_parse_media_id_from_avatar_url(self, mxc): raise SynapseError(400, "Invalid avatar URL '%s' supplied" % mxc) return avatar_pieces[-1] - @defer.inlineCallbacks - def on_profile_query(self, args): + async def on_profile_query(self, args): user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -466,12 +459,12 @@ def on_profile_query(self, args): response = {} try: if just_field is None or just_field == "displayname": - response["displayname"] = yield self.store.get_profile_displayname( + response["displayname"] = await self.store.get_profile_displayname( user.localpart ) if just_field is None or just_field == "avatar_url": - response["avatar_url"] = yield self.store.get_profile_avatar_url( + response["avatar_url"] = await self.store.get_profile_avatar_url( user.localpart ) except StoreError as e: @@ -506,8 +499,7 @@ async def _update_join_states(self, requester, target_user): "Failed to update join event for room %s - %s", room_id, str(e) ) - @defer.inlineCallbacks - def check_profile_query_allowed(self, target_user, requester=None): + async def check_profile_query_allowed(self, target_user, requester=None): """Checks whether a profile query is allowed. If the 'require_auth_for_profile_requests' config flag is set to True and a 'requester' is provided, the query is only allowed if the two users @@ -539,8 +531,8 @@ def check_profile_query_allowed(self, target_user, requester=None): return try: - requester_rooms = yield self.store.get_rooms_for_user(requester.to_string()) - target_user_rooms = yield self.store.get_rooms_for_user( + requester_rooms = await self.store.get_rooms_for_user(requester.to_string()) + target_user_rooms = await self.store.get_rooms_for_user( target_user.to_string() ) @@ -573,25 +565,24 @@ def _start_update_remote_profile_cache(self): "Update remote profile", self._update_remote_profile_cache ) - @defer.inlineCallbacks - def _update_remote_profile_cache(self): + async def _update_remote_profile_cache(self): """Called periodically to check profiles of remote users we haven't checked in a while. 
""" - entries = yield self.store.get_remote_profile_cache_entries_that_expire( + entries = await self.store.get_remote_profile_cache_entries_that_expire( last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS ) for user_id, displayname, avatar_url in entries: - is_subscribed = yield self.store.is_subscribed_remote_profile_for_user( + is_subscribed = await self.store.is_subscribed_remote_profile_for_user( user_id ) if not is_subscribed: - yield self.store.maybe_delete_remote_profile_cache(user_id) + await self.store.maybe_delete_remote_profile_cache(user_id) continue try: - profile = yield self.federation.make_query( + profile = await self.federation.make_query( destination=get_domain_from_id(user_id), query_type="profile", args={"user_id": user_id}, @@ -600,7 +591,7 @@ def _update_remote_profile_cache(self): except Exception: logger.exception("Failed to get avatar_url") - yield self.store.update_remote_profile_cache( + await self.store.update_remote_profile_cache( user_id, displayname, avatar_url ) continue @@ -609,4 +600,4 @@ def _update_remote_profile_cache(self): new_avatar = profile.get("avatar_url") # We always hit update to update the last_check timestamp - yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar) + await self.store.update_remote_profile_cache(user_id, new_name, new_avatar) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 8bc100db42b8..f922d8a54545 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -14,8 +14,6 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.handlers._base import BaseHandler from synapse.types import ReadReceipt, get_domain_from_id from synapse.util.async_helpers import maybe_awaitable @@ -129,15 +127,14 @@ class ReceiptEventSource(object): def __init__(self, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_new_events(self, from_key, room_ids, **kwargs): + async def get_new_events(self, from_key, room_ids, **kwargs): from_key = int(from_key) - to_key = yield self.get_current_key() + to_key = self.get_current_key() if from_key == to_key: return [], to_key - events = yield self.store.get_linearized_receipts_for_rooms( + events = await self.store.get_linearized_receipts_for_rooms( room_ids, from_key=from_key, to_key=to_key ) @@ -146,8 +143,7 @@ def get_new_events(self, from_key, room_ids, **kwargs): def get_current_key(self, direction="f"): return self.store.get_max_receipt_stream_id() - @defer.inlineCallbacks - def get_pagination_rows(self, user, config, key): + async def get_pagination_rows(self, user, config, key): to_key = int(config.from_key) if config.to_key: @@ -155,8 +151,8 @@ def get_pagination_rows(self, user, config, key): else: from_key = None - room_ids = yield self.store.get_rooms_for_user(user.to_string()) - events = yield self.store.get_linearized_receipts_for_rooms( + room_ids = await self.store.get_rooms_for_user(user.to_string()) + events = await self.store.get_linearized_receipts_for_rooms( room_ids, from_key=from_key, to_key=to_key ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index f223630d437d..d00b9dc5374b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -28,7 +28,6 @@ ) from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester -from synapse.util.async_helpers import Linearizer from ._base import BaseHandler @@ -51,14 +50,7 @@ def __init__(self, hs): 
self.http_client = hs.get_simple_http_client() self.identity_handler = self.hs.get_handlers().identity_handler self.ratelimiter = hs.get_registration_ratelimiter() - - self._next_generated_user_id = None - self.macaroon_gen = hs.get_macaroon_generator() - - self._generate_user_id_linearizer = Linearizer( - name="_generate_user_id_linearizer" - ) self._server_notices_mxid = hs.config.server_notices_mxid self._show_in_user_directory = self.hs.config.show_users_in_user_directory @@ -239,7 +231,7 @@ async def register_user( if fail_count > 10: raise SynapseError(500, "Unable to find a suitable guest user ID") - localpart = await self._generate_user_id() + localpart = await self.store.generate_user_id() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() self.check_user_id_not_appservice_exclusive(user_id) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 5e05be6181ad..5dd7b2839194 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -20,12 +20,10 @@ import msgpack from unpaddedbase64 import decode_base64, encode_base64 -from twisted.internet import defer - from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, HttpResponseException from synapse.types import ThirdPartyInstanceID -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached from synapse.util.caches.response_cache import ResponseCache from ._base import BaseHandler @@ -47,7 +45,7 @@ def __init__(self, hs): hs, "remote_room_list", timeout_ms=30 * 1000 ) - def get_local_public_room_list( + async def get_local_public_room_list( self, limit=None, since_token=None, @@ -72,7 +70,7 @@ def get_local_public_room_list( API """ if not self.enable_room_list_search: - return defer.succeed({"chunk": [], "total_room_count_estimate": 0}) + return {"chunk": [], "total_room_count_estimate": 0} logger.info( "Getting public room list: limit=%r, since=%r, search=%r, network=%r", @@ -87,7 +85,7 @@ def get_local_public_room_list( # appservice specific lists. 
logger.info("Bypassing cache as search request.") - return self._get_public_room_list( + return await self._get_public_room_list( limit, since_token, search_filter, @@ -96,7 +94,7 @@ def get_local_public_room_list( ) key = (limit, since_token, network_tuple) - return self.response_cache.wrap( + return await self.response_cache.wrap( key, self._get_public_room_list, limit, @@ -105,8 +103,7 @@ def get_local_public_room_list( from_federation=from_federation, ) - @defer.inlineCallbacks - def _get_public_room_list( + async def _get_public_room_list( self, limit: Optional[int] = None, since_token: Optional[str] = None, @@ -145,7 +142,7 @@ def _get_public_room_list( # we request one more than wanted to see if there are more pages to come probing_limit = limit + 1 if limit is not None else None - results = yield self.store.get_largest_public_rooms( + results = await self.store.get_largest_public_rooms( network_tuple, search_filter, probing_limit, @@ -221,44 +218,44 @@ def build_room_entry(room): response["chunk"] = results - response["total_room_count_estimate"] = yield self.store.count_public_rooms( + response["total_room_count_estimate"] = await self.store.count_public_rooms( network_tuple, ignore_non_federatable=from_federation ) return response - @cachedInlineCallbacks(num_args=1, cache_context=True) - def generate_room_entry( + @cached(num_args=1, cache_context=True) + async def generate_room_entry( self, - room_id, - num_joined_users, + room_id: str, + num_joined_users: int, cache_context, - with_alias=True, - allow_private=False, - ): + with_alias: bool = True, + allow_private: bool = False, + ) -> Optional[dict]: """Returns the entry for a room Args: - room_id (str): The room's ID. - num_joined_users (int): Number of users in the room. + room_id: The room's ID. + num_joined_users: Number of users in the room. cache_context: Information for cached responses. - with_alias (bool): Whether to return the room's aliases in the result. - allow_private (bool): Whether invite-only rooms should be shown. + with_alias: Whether to return the room's aliases in the result. + allow_private: Whether invite-only rooms should be shown. Returns: - Deferred[dict|None]: Returns a room entry as a dictionary, or None if this + Returns a room entry as a dictionary, or None if this room was determined not to be shown publicly. """ result = {"room_id": room_id, "num_joined_members": num_joined_users} if with_alias: - aliases = yield self.store.get_aliases_for_room( + aliases = await self.store.get_aliases_for_room( room_id, on_invalidate=cache_context.invalidate ) if aliases: result["aliases"] = aliases - current_state_ids = yield self.store.get_current_state_ids( + current_state_ids = await self.store.get_current_state_ids( room_id, on_invalidate=cache_context.invalidate ) @@ -266,7 +263,7 @@ def generate_room_entry( # We're not in the room, so may as well bail out here. return result - event_map = yield self.store.get_events( + event_map = await self.store.get_events( [ event_id for key, event_id in current_state_ids.items() @@ -336,8 +333,7 @@ def generate_room_entry( return result - @defer.inlineCallbacks - def get_remote_public_room_list( + async def get_remote_public_room_list( self, server_name, limit=None, @@ -356,7 +352,7 @@ def get_remote_public_room_list( # to a locally-filtered search if we must. 
try: - res = yield self._get_remote_list_cached( + res = await self._get_remote_list_cached( server_name, limit=limit, since_token=since_token, @@ -381,7 +377,7 @@ def get_remote_public_room_list( limit = None since_token = None - res = yield self._get_remote_list_cached( + res = await self._get_remote_list_cached( server_name, limit=limit, since_token=since_token, @@ -400,7 +396,7 @@ def get_remote_public_room_list( return res - def _get_remote_list_cached( + async def _get_remote_list_cached( self, server_name, limit=None, @@ -412,7 +408,7 @@ def _get_remote_list_cached( repl_layer = self.hs.get_federation_client() if search_filter: # We can't cache when asking for search - return repl_layer.get_public_rooms( + return await repl_layer.get_public_rooms( server_name, limit=limit, since_token=since_token, @@ -428,7 +424,7 @@ def _get_remote_list_cached( include_all_networks, third_party_instance_id, ) - return self.remote_response_cache.wrap( + return await self.remote_response_cache.wrap( key, repl_layer.get_public_rooms, server_name, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 846ddbdc6cef..a86ac0150e05 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -15,15 +15,19 @@ import logging from collections import namedtuple -from typing import List, Tuple +from typing import TYPE_CHECKING, List, Set, Tuple from synapse.api.errors import AuthError, SynapseError -from synapse.logging.context import run_in_background +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.tcp.streams import TypingStream from synapse.types import UserID, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -39,48 +43,48 @@ FEDERATION_PING_INTERVAL = 40 * 1000 -class TypingHandler(object): - def __init__(self, hs): +class FollowerTypingHandler: + """A typing handler on a different process than the writer that is updated + via replication. 
+ """ + + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.server_name = hs.config.server_name - self.auth = hs.get_auth() - self.is_mine_id = hs.is_mine_id - self.notifier = hs.get_notifier() - self.state = hs.get_state_handler() - - self.hs = hs - self.clock = hs.get_clock() - self.wheel_timer = WheelTimer(bucket_size=5000) + self.is_mine_id = hs.is_mine_id - self.federation = hs.get_federation_sender() + self.federation = None + if hs.should_send_federation(): + self.federation = hs.get_federation_sender() - hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu) + if hs.config.worker.writers.typing != hs.get_instance_name(): + hs.get_federation_registry().register_instance_for_edu( + "m.typing", hs.config.worker.writers.typing, + ) - hs.get_distributor().observe("user_left_room", self.user_left_room) + # map room IDs to serial numbers + self._room_serials = {} + # map room IDs to sets of users currently typing + self._room_typing = {} - self._member_typing_until = {} # clock time we expect to stop self._member_last_federation_poke = {} - + self.wheel_timer = WheelTimer(bucket_size=5000) self._latest_room_serial = 0 - self._reset() - - # caches which room_ids changed at which serials - self._typing_stream_change_cache = StreamChangeCache( - "TypingStreamChangeCache", self._latest_room_serial - ) self.clock.looping_call(self._handle_timeouts, 5000) def _reset(self): - """ - Reset the typing handler's data caches. + """Reset the typing handler's data caches. """ # map room IDs to serial numbers self._room_serials = {} # map room IDs to sets of users currently typing self._room_typing = {} + self._member_last_federation_poke = {} + self.wheel_timer = WheelTimer(bucket_size=5000) + def _handle_timeouts(self): logger.debug("Checking for typing timeouts") @@ -89,30 +93,140 @@ def _handle_timeouts(self): members = set(self.wheel_timer.fetch(now)) for member in members: - if not self.is_typing(member): - # Nothing to do if they're no longer typing - continue - - until = self._member_typing_until.get(member, None) - if not until or until <= now: - logger.info("Timing out typing for: %s", member.user_id) - self._stopped_typing(member) - continue - - # Check if we need to resend a keep alive over federation for this - # user. - if self.hs.is_mine_id(member.user_id): - last_fed_poke = self._member_last_federation_poke.get(member, None) - if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: - run_in_background(self._push_remote, member=member, typing=True) - - # Add a paranoia timer to ensure that we always have a timer for - # each person typing. - self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) + self._handle_timeout_for_member(now, member) + + def _handle_timeout_for_member(self, now: int, member: RoomMember): + if not self.is_typing(member): + # Nothing to do if they're no longer typing + return + + # Check if we need to resend a keep alive over federation for this + # user. + if self.federation and self.is_mine_id(member.user_id): + last_fed_poke = self._member_last_federation_poke.get(member, None) + if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: + run_as_background_process( + "typing._push_remote", self._push_remote, member=member, typing=True + ) + + # Add a paranoia timer to ensure that we always have a timer for + # each person typing. 
+ self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) def is_typing(self, member): return member.user_id in self._room_typing.get(member.room_id, []) + async def _push_remote(self, member, typing): + if not self.federation: + return + + try: + users = await self.store.get_users_in_room(member.room_id) + self._member_last_federation_poke[member] = self.clock.time_msec() + + now = self.clock.time_msec() + self.wheel_timer.insert( + now=now, obj=member, then=now + FEDERATION_PING_INTERVAL + ) + + for domain in {get_domain_from_id(u) for u in users}: + if domain != self.server_name: + logger.debug("sending typing update to %s", domain) + self.federation.build_and_send_edu( + destination=domain, + edu_type="m.typing", + content={ + "room_id": member.room_id, + "user_id": member.user_id, + "typing": typing, + }, + key=member, + ) + except Exception: + logger.exception("Error pushing typing notif to remotes") + + def process_replication_rows( + self, token: int, rows: List[TypingStream.TypingStreamRow] + ): + """Should be called whenever we receive updates for typing stream. + """ + + if self._latest_room_serial > token: + # The master has gone backwards. To prevent inconsistent data, just + # clear everything. + self._reset() + + # Set the latest serial token to whatever the server gave us. + self._latest_room_serial = token + + for row in rows: + self._room_serials[row.room_id] = token + + prev_typing = set(self._room_typing.get(row.room_id, [])) + now_typing = set(row.user_ids) + self._room_typing[row.room_id] = row.user_ids + + run_as_background_process( + "_handle_change_in_typing", + self._handle_change_in_typing, + row.room_id, + prev_typing, + now_typing, + ) + + async def _handle_change_in_typing( + self, room_id: str, prev_typing: Set[str], now_typing: Set[str] + ): + """Process a change in typing of a room from replication, sending EDUs + for any local users. 
+ """ + for user_id in now_typing - prev_typing: + if self.is_mine_id(user_id): + await self._push_remote(RoomMember(room_id, user_id), True) + + for user_id in prev_typing - now_typing: + if self.is_mine_id(user_id): + await self._push_remote(RoomMember(room_id, user_id), False) + + def get_current_token(self): + return self._latest_room_serial + + +class TypingWriterHandler(FollowerTypingHandler): + def __init__(self, hs): + super().__init__(hs) + + assert hs.config.worker.writers.typing == hs.get_instance_name() + + self.auth = hs.get_auth() + self.notifier = hs.get_notifier() + + self.hs = hs + + hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu) + + hs.get_distributor().observe("user_left_room", self.user_left_room) + + self._member_typing_until = {} # clock time we expect to stop + + # caches which room_ids changed at which serials + self._typing_stream_change_cache = StreamChangeCache( + "TypingStreamChangeCache", self._latest_room_serial + ) + + def _handle_timeout_for_member(self, now: int, member: RoomMember): + super()._handle_timeout_for_member(now, member) + + if not self.is_typing(member): + # Nothing to do if they're no longer typing + return + + until = self._member_typing_until.get(member, None) + if not until or until <= now: + logger.info("Timing out typing for: %s", member.user_id) + self._stopped_typing(member) + return + async def started_typing(self, target_user, auth_user, room_id, timeout): target_user_id = target_user.to_string() auth_user_id = auth_user.to_string() @@ -179,35 +293,11 @@ def _stopped_typing(self, member): def _push_update(self, member, typing): if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. - run_in_background(self._push_remote, member, typing) - - self._push_update_local(member=member, typing=typing) - - async def _push_remote(self, member, typing): - try: - users = await self.store.get_users_in_room(member.room_id) - self._member_last_federation_poke[member] = self.clock.time_msec() - - now = self.clock.time_msec() - self.wheel_timer.insert( - now=now, obj=member, then=now + FEDERATION_PING_INTERVAL + run_as_background_process( + "typing._push_remote", self._push_remote, member, typing ) - for domain in {get_domain_from_id(u) for u in users}: - if domain != self.server_name: - logger.debug("sending typing update to %s", domain) - self.federation.build_and_send_edu( - destination=domain, - edu_type="m.typing", - content={ - "room_id": member.room_id, - "user_id": member.user_id, - "typing": typing, - }, - key=member, - ) - except Exception: - logger.exception("Error pushing typing notif to remotes") + self._push_update_local(member=member, typing=typing) async def _recv_edu(self, origin, content): room_id = content["room_id"] @@ -304,8 +394,11 @@ async def get_all_typing_updates( return rows, current_id, limited - def get_current_token(self): - return self._latest_room_serial + def process_replication_rows( + self, token: int, rows: List[TypingStream.TypingStreamRow] + ): + # The writing process should never get updates from replication. 
+ raise Exception("Typing writer instance got typing info over replication") class TypingNotificationEventSource(object): diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index c6c0e623c16e..21015175758c 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -733,37 +733,54 @@ def decorator(func): _opname = opname if opname else func.__name__ - @wraps(func) - def _trace_inner(*args, **kwargs): - if opentracing is None: - return func(*args, **kwargs) + if inspect.iscoroutinefunction(func): - scope = start_active_span(_opname) - scope.__enter__() + @wraps(func) + async def _trace_inner(*args, **kwargs): + if opentracing is None: + return await func(*args, **kwargs) - try: - result = func(*args, **kwargs) - if isinstance(result, defer.Deferred): + with start_active_span(_opname) as scope: + try: + return await func(*args, **kwargs) + except Exception: + scope.span.set_tag(tags.ERROR, True) + raise - def call_back(result): - scope.__exit__(None, None, None) - return result + else: + # The other case here handles both sync functions and those + # decorated with inlineCallbacks. + @wraps(func) + def _trace_inner(*args, **kwargs): + if opentracing is None: + return func(*args, **kwargs) - def err_back(result): - scope.span.set_tag(tags.ERROR, True) - scope.__exit__(None, None, None) - return result + scope = start_active_span(_opname) + scope.__enter__() + + try: + result = func(*args, **kwargs) + if isinstance(result, defer.Deferred): + + def call_back(result): + scope.__exit__(None, None, None) + return result - result.addCallbacks(call_back, err_back) + def err_back(result): + scope.span.set_tag(tags.ERROR, True) + scope.__exit__(None, None, None) + return result - else: - scope.__exit__(None, None, None) + result.addCallbacks(call_back, err_back) + + else: + scope.__exit__(None, None, None) - return result + return result - except Exception as e: - scope.__exit__(type(e), None, e.__traceback__) - raise + except Exception as e: + scope.__exit__(type(e), None, e.__traceback__) + raise return _trace_inner diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py index 99049bb5d8f3..fea774e2e524 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py @@ -14,9 +14,7 @@ # limitations under the License. -import inspect import logging -import time from functools import wraps from inspect import getcallargs @@ -74,127 +72,3 @@ def format(value): wrapped.__name__ = func_name return wrapped - - -def time_function(f): - func_name = f.__name__ - - @wraps(f) - def wrapped(*args, **kwargs): - global _TIME_FUNC_ID - id = _TIME_FUNC_ID - _TIME_FUNC_ID += 1 - - start = time.clock() - - try: - _log_debug_as_f(f, "[FUNC START] {%s-%d}", (func_name, id)) - - r = f(*args, **kwargs) - finally: - end = time.clock() - _log_debug_as_f( - f, "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start) - ) - - return r - - return wrapped - - -def trace_function(f): - func_name = f.__name__ - linenum = f.func_code.co_firstlineno - pathname = f.func_code.co_filename - - @wraps(f) - def wrapped(*args, **kwargs): - name = f.__module__ - logger = logging.getLogger(name) - level = logging.DEBUG - - frame = inspect.currentframe() - if frame is None: - raise Exception("Can't get current frame!") - - s = frame.f_back - - to_print = [ - "\t%s:%s %s. 
Args: args=%s, kwargs=%s" - % (pathname, linenum, func_name, args, kwargs) - ] - while s: - if True or s.f_globals["__name__"].startswith("synapse"): - filename, lineno, function, _, _ = inspect.getframeinfo(s) - args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - - to_print.append( - "\t%s:%d %s. Args: %s" % (filename, lineno, function, args_string) - ) - - s = s.f_back - - msg = "\nTraceback for %s:\n" % (func_name,) + "\n".join(to_print) - - record = logging.LogRecord( - name=name, - level=level, - pathname=pathname, - lineno=lineno, - msg=msg, - args=(), - exc_info=None, - ) - - logger.handle(record) - - return f(*args, **kwargs) - - wrapped.__name__ = func_name - return wrapped - - -def get_previous_frames(): - - frame = inspect.currentframe() - if frame is None: - raise Exception("Can't get current frame!") - - s = frame.f_back.f_back - to_return = [] - while s: - if s.f_globals["__name__"].startswith("synapse"): - filename, lineno, function, _, _ = inspect.getframeinfo(s) - args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - - to_return.append( - "{{ %s:%d %s - Args: %s }}" % (filename, lineno, function, args_string) - ) - - s = s.f_back - - return ", ".join(to_return) - - -def get_previous_frame(ignore=[]): - frame = inspect.currentframe() - if frame is None: - raise Exception("Can't get current frame!") - s = frame.f_back.f_back - - while s: - if s.f_globals["__name__"].startswith("synapse"): - if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore): - filename, lineno, function, _, _ = inspect.getframeinfo(s) - args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - - return "{{ %s:%d %s - Args: %s }}" % ( - filename, - lineno, - function, - args_string, - ) - - s = s.f_back - - return None diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index f6a54586815b..2456f12f469d 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -15,13 +15,12 @@ # limitations under the License. import logging -from collections import defaultdict -from threading import Lock -from typing import Dict, Tuple, Union +from typing import TYPE_CHECKING, Dict, Union + +from prometheus_client import Gauge from twisted.internet import defer -from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException from synapse.push.emailpusher import EmailPusher @@ -29,9 +28,18 @@ from synapse.push.pusher import PusherFactory from synapse.util.async_helpers import concurrently_execute +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) +synapse_pushers = Gauge( + "synapse_pushers", "Number of active synapse pushers", ["kind", "app_id"] +) + + class PusherPool: """ The pusher pool. This is responsible for dispatching notifications of new events to @@ -47,36 +55,20 @@ class PusherPool: Pusher.on_new_receipts are not expected to return deferreds. """ - def __init__(self, _hs): - self.hs = _hs - self.pusher_factory = PusherFactory(_hs) - self._should_start_pushers = _hs.config.start_pushers + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.pusher_factory = PusherFactory(hs) + self._should_start_pushers = hs.config.start_pushers self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() + # We shard the handling of push notifications by user ID. 
+ self._pusher_shard_config = hs.config.push.pusher_shard_config + self._instance_name = hs.get_instance_name() + # map from user id to app_id:pushkey to pusher self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]] - # a lock for the pushers dict, since `count_pushers` is called from an different - # and we otherwise get concurrent modification errors - self._pushers_lock = Lock() - - def count_pushers(): - results = defaultdict(int) # type: Dict[Tuple[str, str], int] - with self._pushers_lock: - for pushers in self.pushers.values(): - for pusher in pushers.values(): - k = (type(pusher).__name__, pusher.app_id) - results[k] += 1 - return results - - LaterGauge( - name="synapse_pushers", - desc="the number of active pushers", - labels=["kind", "app_id"], - caller=count_pushers, - ) - def start(self): """Starts the pushers off in a background process. """ @@ -104,6 +96,7 @@ def add_pusher( Returns: Deferred[EmailPusher|HttpPusher] """ + time_now_msec = self.clock.time_msec() # we try to create the pusher just to validate the config: it @@ -176,6 +169,9 @@ def remove_pushers_by_access_token(self, user_id, access_tokens): access_tokens (Iterable[int]): access token *ids* to remove pushers for """ + if not self._pusher_shard_config.should_handle(self._instance_name, user_id): + return + tokens = set(access_tokens) for p in (yield self.store.get_pushers_by_user_id(user_id)): if p["access_token"] in tokens: @@ -237,6 +233,9 @@ def start_pusher_by_id(self, app_id, pushkey, user_id): if not self._should_start_pushers: return + if not self._pusher_shard_config.should_handle(self._instance_name, user_id): + return + resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) pusher_dict = None @@ -275,6 +274,11 @@ def _start_pusher(self, pusherdict): Returns: Deferred[EmailPusher|HttpPusher] """ + if not self._pusher_shard_config.should_handle( + self._instance_name, pusherdict["user_name"] + ): + return + try: p = self.pusher_factory.create_pusher(pusherdict) except PusherConfigException as e: @@ -298,11 +302,12 @@ def _start_pusher(self, pusherdict): appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"]) - with self._pushers_lock: - byuser = self.pushers.setdefault(pusherdict["user_name"], {}) - if appid_pushkey in byuser: - byuser[appid_pushkey].on_stop() - byuser[appid_pushkey] = p + byuser = self.pushers.setdefault(pusherdict["user_name"], {}) + if appid_pushkey in byuser: + byuser[appid_pushkey].on_stop() + byuser[appid_pushkey] = p + + synapse_pushers.labels(type(p).__name__, p.app_id).inc() # Check if there *may* be push to process. 
We do this as this check is a # lot cheaper to do than actually fetching the exact rows we need to @@ -330,9 +335,10 @@ def remove_pusher(self, app_id, pushkey, user_id): if appid_pushkey in byuser: logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) - byuser[appid_pushkey].on_stop() - with self._pushers_lock: - del byuser[appid_pushkey] + pusher = byuser.pop(appid_pushkey) + pusher.on_stop() + + synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() yield self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index bd394f6b0059..a8a16dbc711c 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -26,7 +26,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): def __init__(self, database: Database, db_conn, hs): super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( - db_conn, "device_max_stream_id", "stream_id" + db_conn, "device_inbox", "stream_id" ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 80f5df60f902..f88e0a2e404e 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -14,9 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) from prometheus_client import Counter +from typing_extensions import Deque from twisted.internet.protocol import ReconnectingClientFactory @@ -42,8 +54,8 @@ EventsStream, FederationStream, Stream, + TypingStream, ) -from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) @@ -61,6 +73,12 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") +# the type of the entries in _command_queues_by_stream +_StreamCommandQueue = Deque[ + Tuple[Union[RdataCommand, PositionCommand], AbstractConnection] +] + + class ReplicationCommandHandler: """Handles incoming commands from replication as well as sending commands back out to connections. @@ -96,6 +114,14 @@ def __init__(self, hs): continue + if isinstance(stream, TypingStream): + # Only add TypingStream as a source on the instance in charge of + # typing. + if hs.config.worker.writers.typing == hs.get_instance_name(): + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. if hs.config.worker_app is not None: continue @@ -107,10 +133,6 @@ def __init__(self, hs): self._streams_to_replicate.append(stream) - self._position_linearizer = Linearizer( - "replication_position", clock=self._clock - ) - # Map of stream name to batched updates. See RdataCommand for info on # how batching works. self._pending_batches = {} # type: Dict[str, List[Any]] @@ -122,10 +144,6 @@ def __init__(self, hs): # outgoing replication commands to.) self._connections = [] # type: List[AbstractConnection] - # For each connection, the incoming stream names that are coming from - # that connection. 
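Two asides on the constructor changes above. First, the TypingStream special case introduces the single-writer routing used throughout this PR: one nominated instance sources the stream and every other worker follows it. Reduced to its essence (a sketch built from the config attributes in the hunk above, not a new API):

    def is_typing_writer(hs) -> bool:
        # Exactly one instance is nominated to write typing updates; all
        # other workers consume them over replication instead.
        return hs.config.worker.writers.typing == hs.get_instance_name()

Second, the per-connection stream map being removed here is not dropped for good: it is recreated further down, after the new command-queue structures, with a comment clarifying that it tracks which streams have received a POSITION.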
-        self._streams_by_connection = {}  # type: Dict[AbstractConnection, Set[str]]
-
         LaterGauge(
             "synapse_replication_tcp_resource_total_connections",
             "",
@@ -133,6 +151,32 @@ def __init__(self, hs):
             lambda: len(self._connections),
         )
 
+        # When POSITION or RDATA commands arrive, we stick them in a queue and process
+        # them in order in a separate background process.
+
+        # the streams which are currently being processed by _add_command_to_stream_queue
+        self._processing_streams = set()  # type: Set[str]
+
+        # for each stream, a queue of commands that are awaiting processing, and the
+        # connection that they arrived on.
+        self._command_queues_by_stream = {
+            stream_name: _StreamCommandQueue() for stream_name in self._streams
+        }
+
+        # For each connection, the incoming stream names that have received a POSITION
+        # from that connection.
+        self._streams_by_connection = {}  # type: Dict[AbstractConnection, Set[str]]
+
+        LaterGauge(
+            "synapse_replication_tcp_command_queue",
+            "Number of inbound RDATA/POSITION commands queued for processing",
+            ["stream_name"],
+            lambda: {
+                (stream_name,): len(queue)
+                for stream_name, queue in self._command_queues_by_stream.items()
+            },
+        )
+
         self._is_master = hs.config.worker_app is None
 
         self._federation_sender = None
@@ -143,6 +187,64 @@ def __init__(self, hs):
         if self._is_master:
             self._server_notices_sender = hs.get_server_notices_sender()
 
+    async def _add_command_to_stream_queue(
+        self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
+    ) -> None:
+        """Queue the given received command for processing
+
+        Adds the given command to the per-stream queue, and processes the queue if
+        necessary
+        """
+        stream_name = cmd.stream_name
+        queue = self._command_queues_by_stream.get(stream_name)
+        if queue is None:
+            logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
+            return
+
+        # if we're already processing this stream, stick the new command in the
+        # queue, and we're done.
+        if stream_name in self._processing_streams:
+            queue.append((cmd, conn))
+            return
+
+        # otherwise, process the new command.
+
+        # arguably we should start off a new background process here, but nothing
+        # will be too upset if we don't return for ages, so let's save the overhead
+        # and use the existing logcontext.
+
+        self._processing_streams.add(stream_name)
+        try:
+            # might as well skip the queue for this one, since it must be empty
+            assert not queue
+            await self._process_command(cmd, conn, stream_name)
+
+            # now process any other commands that have built up while we were
+            # dealing with that one.
+            while queue:
+                cmd, conn = queue.popleft()
+                try:
+                    await self._process_command(cmd, conn, stream_name)
+                except Exception:
+                    logger.exception("Failed to handle command %s", cmd)
+
+        finally:
+            self._processing_streams.discard(stream_name)
+
+    async def _process_command(
+        self,
+        cmd: Union[PositionCommand, RdataCommand],
+        conn: AbstractConnection,
+        stream_name: str,
+    ) -> None:
+        if isinstance(cmd, PositionCommand):
+            await self._process_position(stream_name, conn, cmd)
+        elif isinstance(cmd, RdataCommand):
+            await self._process_rdata(stream_name, conn, cmd)
+        else:
+            # This shouldn't be possible
+            raise Exception("Unrecognised command %s in stream queue" % (cmd.NAME,))
+
     def start_replication(self, hs):
         """Helper method to start a replication connection to the remote
         server using TCP.
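The queueing machinery above is an instance of a general pattern: serialise work per key, let distinct keys interleave, and drain anything that arrived while you were busy. A minimal standalone sketch of the same logic (a hypothetical class, not part of the PR; it needs no locks only because a single reactor never runs two callbacks concurrently):

    from collections import deque
    from typing import Any, Deque, Dict, Set

    class PerStreamSerialiser:
        """Sketch: process each stream's commands strictly in arrival order."""

        def __init__(self, stream_names):
            self._queues = {n: deque() for n in stream_names}  # type: Dict[str, Deque[Any]]
            self._processing = set()  # type: Set[str]

        async def submit(self, stream_name, cmd, process):
            """Run `process` (an async callable) over `cmd`, in stream order."""
            queue = self._queues[stream_name]
            if stream_name in self._processing:
                # Someone is already draining this stream: just park the command.
                queue.append(cmd)
                return

            self._processing.add(stream_name)
            try:
                await process(cmd)
                # Drain whatever built up while we were awaiting.
                while queue:
                    await process(queue.popleft())
            finally:
                self._processing.discard(stream_name)

The on_RDATA and on_POSITION rewrites below reduce to exactly this submit step.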
@@ -276,63 +378,71 @@ async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() - try: - row = STREAMS_MAP[stream_name].parse_row(cmd.row) - except Exception: - logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) - raise - - # We linearize here for two reasons: + # We put the received command into a queue here for two reasons: # 1. so we don't try and concurrently handle multiple rows for the # same stream, and # 2. so we don't race with getting a POSITION command and fetching # missing RDATA. - with await self._position_linearizer.queue(cmd.stream_name): - # make sure that we've processed a POSITION for this stream *on this - # connection*. (A POSITION on another connection is no good, as there - # is no guarantee that we have seen all the intermediate updates.) - sbc = self._streams_by_connection.get(conn) - if not sbc or stream_name not in sbc: - # Let's drop the row for now, on the assumption we'll receive a - # `POSITION` soon and we'll catch up correctly then. - logger.debug( - "Discarding RDATA for unconnected stream %s -> %s", - stream_name, - cmd.token, - ) - return - - if cmd.token is None: - # I.e. this is part of a batch of updates for this stream (in - # which case batch until we get an update for the stream with a non - # None token). - self._pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self._pending_batches.pop(stream_name, []) - rows.append(row) - - stream = self._streams.get(stream_name) - if not stream: - logger.error("Got RDATA for unknown stream: %s", stream_name) - return - - # Find where we previously streamed up to. - current_token = stream.current_token(cmd.instance_name) - - # Discard this data if this token is earlier than the current - # position. Note that streams can be reset (in which case you - # expect an earlier token), but that must be preceded by a - # POSITION command. - if cmd.token <= current_token: - logger.debug( - "Discarding RDATA from stream %s at position %s before previous position %s", - stream_name, - cmd.token, - current_token, - ) - else: - await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) + + await self._add_command_to_stream_queue(conn, cmd) + + async def _process_rdata( + self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand + ) -> None: + """Process an RDATA command + + Called after the command has been popped off the queue of inbound commands + """ + try: + row = STREAMS_MAP[stream_name].parse_row(cmd.row) + except Exception as e: + raise Exception( + "Failed to parse RDATA: %r %r" % (stream_name, cmd.row) + ) from e + + # make sure that we've processed a POSITION for this stream *on this + # connection*. (A POSITION on another connection is no good, as there + # is no guarantee that we have seen all the intermediate updates.) + sbc = self._streams_by_connection.get(conn) + if not sbc or stream_name not in sbc: + # Let's drop the row for now, on the assumption we'll receive a + # `POSITION` soon and we'll catch up correctly then. + logger.debug( + "Discarding RDATA for unconnected stream %s -> %s", + stream_name, + cmd.token, + ) + return + + if cmd.token is None: + # I.e. this is part of a batch of updates for this stream (in + # which case batch until we get an update for the stream with a non + # None token). 
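+            # (Illustration rather than a trace from this PR: a three-row
+            # batch arrives on the wire as
+            #     RDATA events master batch row1
+            #     RDATA events master batch row2
+            #     RDATA events master 54321 row3
+            # where the literal token "batch" is parsed to a None cmd.token,
+            # so the first two rows are parked below until the closing
+            # command carries the real stream position.)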
+ self._pending_batches.setdefault(stream_name, []).append(row) + return + + # Check if this is the last of a batch of updates + rows = self._pending_batches.pop(stream_name, []) + rows.append(row) + + stream = self._streams[stream_name] + + # Find where we previously streamed up to. + current_token = stream.current_token(cmd.instance_name) + + # Discard this data if this token is earlier than the current + # position. Note that streams can be reset (in which case you + # expect an earlier token), but that must be preceded by a + # POSITION command. + if cmd.token <= current_token: + logger.debug( + "Discarding RDATA from stream %s at position %s before previous position %s", + stream_name, + cmd.token, + current_token, + ) + else: + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -358,67 +468,65 @@ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line()) - stream_name = cmd.stream_name - stream = self._streams.get(stream_name) - if not stream: - logger.error("Got POSITION for unknown stream: %s", stream_name) - return + await self._add_command_to_stream_queue(conn, cmd) - # We protect catching up with a linearizer in case the replication - # connection reconnects under us. - with await self._position_linearizer.queue(stream_name): - # We're about to go and catch up with the stream, so remove from set - # of connected streams. - for streams in self._streams_by_connection.values(): - streams.discard(stream_name) - - # We clear the pending batches for the stream as the fetching of the - # missing updates below will fetch all rows in the batch. - self._pending_batches.pop(stream_name, []) - - # Find where we previously streamed up to. - current_token = stream.current_token(cmd.instance_name) - - # If the position token matches our current token then we're up to - # date and there's nothing to do. Otherwise, fetch all updates - # between then and now. - missing_updates = cmd.token != current_token - while missing_updates: - logger.info( - "Fetching replication rows for '%s' between %i and %i", - stream_name, - current_token, - cmd.token, - ) - ( - updates, - current_token, - missing_updates, - ) = await stream.get_updates_since( - cmd.instance_name, current_token, cmd.token - ) + async def _process_position( + self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand + ) -> None: + """Process a POSITION command - # TODO: add some tests for this + Called after the command has been popped off the queue of inbound commands + """ + stream = self._streams[stream_name] - # Some streams return multiple rows with the same stream IDs, - # which need to be processed in batches. + # We're about to go and catch up with the stream, so remove from set + # of connected streams. + for streams in self._streams_by_connection.values(): + streams.discard(stream_name) - for token, rows in _batch_updates(updates): - await self.on_rdata( - stream_name, - cmd.instance_name, - token, - [stream.parse_row(row) for row in rows], - ) + # We clear the pending batches for the stream as the fetching of the + # missing updates below will fetch all rows in the batch. + self._pending_batches.pop(stream_name, []) - logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + # Find where we previously streamed up to. 
+ current_token = stream.current_token(cmd.instance_name) - # We've now caught up to position sent to us, notify handler. - await self._replication_data_handler.on_position( - cmd.stream_name, cmd.instance_name, cmd.token + # If the position token matches our current token then we're up to + # date and there's nothing to do. Otherwise, fetch all updates + # between then and now. + missing_updates = cmd.token != current_token + while missing_updates: + logger.info( + "Fetching replication rows for '%s' between %i and %i", + stream_name, + current_token, + cmd.token, ) + (updates, current_token, missing_updates) = await stream.get_updates_since( + cmd.instance_name, current_token, cmd.token + ) + + # TODO: add some tests for this + + # Some streams return multiple rows with the same stream IDs, + # which need to be processed in batches. + + for token, rows in _batch_updates(updates): + await self.on_rdata( + stream_name, + cmd.instance_name, + token, + [stream.parse_row(row) for row in rows], + ) + + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + + # We've now caught up to position sent to us, notify handler. + await self._replication_data_handler.on_position( + cmd.stream_name, cmd.instance_name, cmd.token + ) - self._streams_by_connection.setdefault(conn, set()).add(stream_name) + self._streams_by_connection.setdefault(conn, set()).add(stream_name) async def on_REMOTE_SERVER_UP( self, conn: AbstractConnection, cmd: RemoteServerUpCommand diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 9076bbe9f134..7a42de3f7d24 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -294,11 +294,12 @@ class TypingStream(Stream): def __init__(self, hs): typing_handler = hs.get_typing_handler() - if hs.config.worker_app is None: - # on the master, query the typing handler + writer_instance = hs.config.worker.writers.typing + if writer_instance == hs.get_instance_name(): + # On the writer, query the typing handler update_function = typing_handler.get_all_typing_updates else: - # Query master process + # Query the typing writer process update_function = make_http_update_function(hs, self.NAME) super().__init__( diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 1c2a4cce7f8e..16c63ff4eca8 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import heapq -from collections import Iterable +from collections.abc import Iterable from typing import List, Tuple, Type import attr diff --git a/synapse/res/templates/mail-Element.css b/synapse/res/templates/mail-Element.css new file mode 100644 index 000000000000..6a3e36eda124 --- /dev/null +++ b/synapse/res/templates/mail-Element.css @@ -0,0 +1,7 @@ +.header { + border-bottom: 4px solid #e4f7ed ! important; +} + +.notif_link a, .footer a { + color: #76CFA6 ! 
important;
+}
diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html
index 6b94d8c367c9..d87311f659a5 100644
--- a/synapse/res/templates/notice_expiry.html
+++ b/synapse/res/templates/notice_expiry.html
@@ -22,6 +22,8 @@
                 [Riot]
             {% elif app_name == "Vector" %}
                 [Vector]
+            {% elif app_name == "Element" %}
+                [Element]
             {% else %}
                 [matrix]
             {% endif %}
diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html
index 019506e5fbc7..a2dfeb9e9f78 100644
--- a/synapse/res/templates/notif_mail.html
+++ b/synapse/res/templates/notif_mail.html
@@ -22,6 +22,8 @@
                 [Riot]
             {% elif app_name == "Vector" %}
                 [Vector]
+            {% elif app_name == "Element" %}
+                [Element]
             {% else %}
                 [matrix]
             {% endif %}
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index dc373bc5a352..1c88c93f3836 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -38,6 +38,7 @@
     DeleteRoomRestServlet,
     JoinRoomAliasServlet,
     ListRoomRestServlet,
+    RoomMembersRestServlet,
     RoomRestServlet,
     ShutdownRoomRestServlet,
 )
@@ -201,6 +202,7 @@ def register_servlets(hs, http_server):
     register_servlets_for_client_rest_resource(hs, http_server)
     ListRoomRestServlet(hs).register(http_server)
     RoomRestServlet(hs).register(http_server)
+    RoomMembersRestServlet(hs).register(http_server)
     DeleteRoomRestServlet(hs).register(http_server)
     JoinRoomAliasServlet(hs).register(http_server)
     PurgeRoomServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 544be4706034..b8c95d045a74 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -231,6 +231,31 @@ async def on_GET(self, request, room_id):
         return 200, ret
 
 
+class RoomMembersRestServlet(RestServlet):
+    """
+    Get the list of members of a room.
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
+
+    def __init__(self, hs):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request, room_id):
+        await assert_requester_is_admin(self.auth, request)
+
+        ret = await self.store.get_room(room_id)
+        if not ret:
+            raise NotFoundError("Room not found")
+
+        members = await self.store.get_users_in_room(room_id)
+        ret = {"members": members, "total": len(members)}
+
+        return 200, ret
+
+
 class JoinRoomAliasServlet(RestServlet):
 
     PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index c5a84af04779..1a3398316d1d 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -818,9 +818,18 @@ def __init__(self, hs):
         self.typing_handler = hs.get_typing_handler()
         self.auth = hs.get_auth()
 
+        # If we're not on the typing writer instance we should scream if we get
+ self._is_typing_writer = ( + hs.config.worker.writers.typing == hs.get_instance_name() + ) + async def on_PUT(self, request, room_id, user_id): requester = await self.auth.get_user_by_req(request) + if not self._is_typing_writer: + raise Exception("Got /typing request on instance that is not typing writer") + room_id = urlparse.unquote(room_id) target_user = UserID.from_string(urlparse.unquote(user_id)) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index bc11b4dda4ab..b21538766df8 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -22,6 +22,7 @@ from synapse.api.errors import InteractiveAuthIncompleteError from synapse.api.urls import CLIENT_API_PREFIX +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -51,7 +52,15 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False): return patterns -def set_timeline_upper_limit(filter_json, filter_timeline_limit): +def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None: + """ + Enforces a maximum limit of a timeline query. + + Params: + filter_json: The timeline query to modify. + filter_timeline_limit: The maximum limit to allow, passing -1 will + disable enforcing a maximum limit. + """ if filter_timeline_limit < 0: return # no upper limits timeline = filter_json.get("room", {}).get("timeline", {}) diff --git a/synapse/server.py b/synapse/server.py index f838a03d71dc..a34d8149ff29 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -44,7 +44,6 @@ from synapse.federation.federation_server import ( FederationHandlerRegistry, FederationServer, - ReplicationFederationHandlerRegistry, ) from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.sender import FederationSender @@ -84,7 +83,7 @@ from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler -from synapse.handlers.typing import TypingHandler +from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler from synapse.handlers.user_directory import UserDirectoryHandler from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient @@ -380,7 +379,10 @@ def build_presence_handler(self): return PresenceHandler(self) def build_typing_handler(self): - return TypingHandler(self) + if self.config.worker.writers.typing == self.get_instance_name(): + return TypingWriterHandler(self) + else: + return FollowerTypingHandler(self) def build_sync_handler(self): return SyncHandler(self) @@ -536,10 +538,7 @@ def build_room_member_handler(self): return RoomMemberMasterHandler(self) def build_federation_registry(self): - if self.config.worker_app: - return ReplicationFederationHandlerRegistry(self) - else: - return FederationHandlerRegistry() + return FederationHandlerRegistry(self) def build_server_notices_manager(self): if self.config.worker_app: diff --git a/synapse/server.pyi b/synapse/server.pyi index cd50c721b82a..90a673778f8e 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -148,3 +148,5 @@ class HomeServer(object): self, ) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient: pass + def should_send_federation(self) -> bool: + pass diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index bfce541ca7ad..985a04286961 100644 --- a/synapse/storage/_base.py +++ 
b/synapse/storage/_base.py @@ -100,8 +100,8 @@ def db_to_json(db_content): if isinstance(db_content, memoryview): db_content = db_content.tobytes() - # Decode it to a Unicode string before feeding it to json.loads, so we - # consistenty get a Unicode-containing object out. + # Decode it to a Unicode string before feeding it to json.loads, since + # Python 3.5 does not support deserializing bytes. if isinstance(db_content, (bytes, bytearray)): db_content = db_content.decode("utf8") diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 59f3394b0a0f..018826ef6947 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -249,7 +249,10 @@ async def _do_background_update(self, desired_duration_ms: float) -> int: retcol="progress_json", ) - progress = json.loads(progress_json) + # Avoid a circular import. + from synapse.storage._base import db_to_json + + progress = db_to_json(progress_json) time_start = self._clock.time_msec() items_updated = await update_handler(progress, batch_size) diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 4b4763c70172..932458f651eb 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -128,7 +128,7 @@ def __init__(self, database: Database, db_conn, hs): db_conn, "presence_stream", "stream_id" ) self._device_inbox_id_gen = StreamIdGenerator( - db_conn, "device_max_stream_id", "stream_id" + db_conn, "device_inbox", "stream_id" ) self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index b58f04d00dff..33cc372dfd7e 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -22,7 +22,7 @@ from twisted.internet import defer -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -77,7 +77,7 @@ def get_account_data_for_user_txn(txn): ) global_account_data = { - row["account_data_type"]: json.loads(row["content"]) for row in rows + row["account_data_type"]: db_to_json(row["content"]) for row in rows } rows = self.db.simple_select_list_txn( @@ -90,7 +90,7 @@ def get_account_data_for_user_txn(txn): by_room = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) - room_data[row["account_data_type"]] = json.loads(row["content"]) + room_data[row["account_data_type"]] = db_to_json(row["content"]) return global_account_data, by_room @@ -113,7 +113,7 @@ def get_global_account_data_by_type_for_user(self, data_type, user_id): ) if result: - return json.loads(result) + return db_to_json(result) else: return None @@ -137,7 +137,7 @@ def get_account_data_for_room_txn(txn): ) return { - row["account_data_type"]: json.loads(row["content"]) for row in rows + row["account_data_type"]: db_to_json(row["content"]) for row in rows } return self.db.runInteraction( @@ -170,7 +170,7 @@ def get_account_data_for_room_and_type_txn(txn): allow_none=True, ) - return json.loads(content_json) if content_json else None + return db_to_json(content_json) if content_json else None return self.db.runInteraction( 
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn @@ -255,7 +255,7 @@ def get_updated_account_data_for_user_txn(txn): txn.execute(sql, (user_id, stream_id)) - global_account_data = {row[0]: json.loads(row[1]) for row in txn} + global_account_data = {row[0]: db_to_json(row[1]) for row in txn} sql = ( "SELECT room_id, account_data_type, content FROM room_account_data" @@ -267,7 +267,7 @@ def get_updated_account_data_for_user_txn(txn): account_data_by_room = {} for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) - room_account_data[row[1]] = json.loads(row[2]) + room_account_data[row[1]] = db_to_json(row[2]) return global_account_data, account_data_by_room diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index 7a1fe8cdd249..56659fed37d9 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -22,7 +22,7 @@ from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.database import Database @@ -303,7 +303,7 @@ def _get_oldest_unsent_txn(txn): if not entry: return None - event_ids = json.loads(entry["event_ids"]) + event_ids = db_to_json(entry["event_ids"]) events = yield self.get_events_as_list(event_ids) diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index d313b9705f79..da297b31fbbe 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache @@ -65,7 +65,7 @@ def get_new_messages_for_device_txn(txn): messages = [] for row in txn: stream_pos = row[0] - messages.append(json.loads(row[1])) + messages.append(db_to_json(row[1])) if len(messages) < limit: stream_pos = current_stream_id return messages, stream_pos @@ -173,7 +173,7 @@ def get_new_messages_for_remote_destination_txn(txn): messages = [] for row in txn: stream_pos = row[0] - messages.append(json.loads(row[1])) + messages.append(db_to_json(row[1])) if len(messages) < limit: log_kv({"message": "Set stream position to current position"}) stream_pos = current_stream_id @@ -424,9 +424,6 @@ def add_messages_txn(txn, now_ms, stream_id): def _add_messages_to_local_device_inbox_txn( self, txn, stream_id, messages_by_user_then_device ): - sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?" 
- txn.execute(sql, (stream_id, stream_id)) - local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): messages_json_for_user = {} diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 343cf9a2d5f2..45581a65004e 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -577,7 +577,7 @@ def get_users_whose_signatures_changed(self, user_id, from_key): rows = yield self.db.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) - return {user for row in rows for user in json.loads(row[0])} + return {user for row in rows for user in db_to_json(row[0])} else: return set() diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index 23f4570c4b3e..615364f01837 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json +from canonicaljson import json from twisted.internet import defer from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json class EndToEndRoomKeyStore(SQLBaseStore): @@ -148,7 +148,7 @@ def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): "forwarded_count": row["forwarded_count"], # is_verified must be returned to the client as a boolean "is_verified": bool(row["is_verified"]), - "session_data": json.loads(row["session_data"]), + "session_data": db_to_json(row["session_data"]), } return sessions @@ -222,7 +222,7 @@ def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys): "first_message_index": row[2], "forwarded_count": row[3], "is_verified": row[4], - "session_data": json.loads(row[5]), + "session_data": db_to_json(row[5]), } return ret @@ -319,7 +319,7 @@ def _get_e2e_room_keys_version_info_txn(txn): keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, retcols=("version", "algorithm", "auth_data", "etag"), ) - result["auth_data"] = json.loads(result["auth_data"]) + result["auth_data"] = db_to_json(result["auth_data"]) result["version"] = str(result["version"]) if result["etag"] is None: result["etag"] = 0 diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 6c3cff82e1e4..317c07a8297c 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -366,7 +366,7 @@ def _get_bare_e2e_cross_signing_keys_bulk_txn( for row in rows: user_id = row["user_id"] key_type = row["keytype"] - key = json.loads(row["keydata"]) + key = db_to_json(row["keydata"]) user_info = result.setdefault(user_id, {}) user_info[key_type] = key diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index bc9f4f08eac4..504babaa7e18 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, SQLBaseStore +from synapse.storage._base import 
LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage.database import Database from synapse.util.caches.descriptors import cachedInlineCallbacks @@ -58,7 +58,7 @@ def _deserialize_action(actions, is_highlight): """Custom deserializer for actions. This allows us to "compress" common actions """ if actions: - return json.loads(actions) + return db_to_json(actions) if is_highlight: return DEFAULT_HIGHLIGHT_ACTION diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 66f01aad84ec..6f2e0d15cc0d 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -20,7 +20,6 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple import attr -from canonicaljson import json from prometheus_client import Counter from twisted.internet import defer @@ -32,7 +31,7 @@ from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.logging.utils import log_function -from synapse.storage._base import make_in_list_sql_clause +from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.data_stores.main.search import SearchEntry from synapse.storage.database import Database, LoggingTransaction from synapse.storage.util.id_generators import StreamIdGenerator @@ -236,7 +235,7 @@ def _get_events_which_are_prevs_txn(txn, batch): ) txn.execute(sql + clause, args) - results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) + results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed")) for chunk in batch_iter(event_ids, 100): yield self.db.runInteraction( @@ -297,7 +296,7 @@ def _get_prevs_before_rejected_txn(txn, batch): if prev_event_id in existing_prevs: continue - soft_failed = json.loads(metadata).get("soft_failed") + soft_failed = db_to_json(metadata).get("soft_failed") if soft_failed or rejected: to_recursively_check.append(prev_event_id) existing_prevs.add(prev_event_id) @@ -583,7 +582,7 @@ def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str): txn.execute(sql, (room_id, EventTypes.Create, "")) row = txn.fetchone() if row: - event_json = json.loads(row[0]) + event_json = db_to_json(row[0]) content = event_json.get("content", {}) creator = content.get("creator") room_version_id = content.get("room_version", RoomVersions.V1.identifier) diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index 62d28f44dc97..663c94b24fc8 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -15,12 +15,10 @@ import logging -from canonicaljson import json - from twisted.internet import defer from synapse.api.constants import EventContentFields -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database logger = logging.getLogger(__name__) @@ -125,7 +123,7 @@ def reindex_txn(txn): for row in rows: try: event_id = row[1] - event_json = json.loads(row[2]) + event_json = db_to_json(row[2]) sender = event_json["sender"] content = event_json["content"] @@ -208,7 +206,7 @@ def reindex_search_txn(txn): for row in ev_rows: event_id = row["event_id"] - event_json = json.loads(row["json"]) + event_json = db_to_json(row["json"]) try: origin_server_ts = event_json["origin_server_ts"] except 
(KeyError, AttributeError): @@ -317,7 +315,7 @@ def _cleanup_extremities_bg_update_txn(txn): soft_failed = False if metadata: - soft_failed = json.loads(metadata).get("soft_failed") + soft_failed = db_to_json(metadata).get("soft_failed") if soft_failed or rejected: soft_failed_events_to_lookup.add(event_id) @@ -358,7 +356,7 @@ def _cleanup_extremities_bg_update_txn(txn): graph[event_id] = {prev_event_id} - soft_failed = json.loads(metadata).get("soft_failed") + soft_failed = db_to_json(metadata).get("soft_failed") if soft_failed or rejected: soft_failed_events_to_lookup.add(event_id) else: @@ -543,7 +541,7 @@ def _event_store_labels_txn(txn): last_row_event_id = "" for (event_id, event_json_raw) in results: try: - event_json = json.loads(event_json_raw) + event_json = db_to_json(event_json_raw) self.db.simple_insert_many_txn( txn=txn, diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 01cad7d4faa2..e812c67078de 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -21,7 +21,6 @@ from collections import namedtuple from typing import List, Optional, Tuple -from canonicaljson import json from constantly import NamedConstant, Names from twisted.internet import defer @@ -40,7 +39,7 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id @@ -611,8 +610,8 @@ def _get_events_from_db(self, event_ids, allow_rejected=False): if not allow_rejected and rejected_reason: continue - d = json.loads(row["json"]) - internal_metadata = json.loads(row["internal_metadata"]) + d = db_to_json(row["json"]) + internal_metadata = db_to_json(row["internal_metadata"]) format_version = row["format_version"] if format_version is None: @@ -640,7 +639,7 @@ def _get_events_from_db(self, event_ids, allow_rejected=False): else: room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) if not room_version: - logger.error( + logger.warning( "Event %s in room %s has unknown room version %s", event_id, d["room_id"], diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index 4fb9f9850c79..01ff561e1a61 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json # The category ID for the "default" category. 
We don't store as null in the # database to avoid the fun of null != null @@ -197,7 +197,7 @@ def _get_rooms_for_summary_txn(txn): categories = { row[0]: { "is_public": row[1], - "profile": json.loads(row[2]), + "profile": db_to_json(row[2]), "order": row[3], } for row in txn @@ -221,7 +221,7 @@ def get_group_categories(self, group_id): return { row["category_id"]: { "is_public": row["is_public"], - "profile": json.loads(row["profile"]), + "profile": db_to_json(row["profile"]), } for row in rows } @@ -235,7 +235,7 @@ def get_group_category(self, group_id, category_id): desc="get_group_category", ) - category["profile"] = json.loads(category["profile"]) + category["profile"] = db_to_json(category["profile"]) return category @@ -251,7 +251,7 @@ def get_group_roles(self, group_id): return { row["role_id"]: { "is_public": row["is_public"], - "profile": json.loads(row["profile"]), + "profile": db_to_json(row["profile"]), } for row in rows } @@ -265,7 +265,7 @@ def get_group_role(self, group_id, role_id): desc="get_group_role", ) - role["profile"] = json.loads(role["profile"]) + role["profile"] = db_to_json(role["profile"]) return role @@ -333,7 +333,7 @@ def _get_users_for_summary_txn(txn): roles = { row[0]: { "is_public": row[1], - "profile": json.loads(row[2]), + "profile": db_to_json(row[2]), "order": row[3], } for row in txn @@ -462,7 +462,7 @@ def get_remote_attestation(self, group_id, user_id): now = int(self._clock.time_msec()) if row and now < row["valid_until_ms"]: - return json.loads(row["attestation_json"]) + return db_to_json(row["attestation_json"]) return None @@ -489,7 +489,7 @@ def _get_all_groups_for_user_txn(txn): "group_id": row[0], "type": row[1], "membership": row[2], - "content": json.loads(row[3]), + "content": db_to_json(row[3]), } for row in txn ] @@ -519,7 +519,7 @@ def _get_groups_changes_for_user_txn(txn): "group_id": group_id, "membership": membership, "type": gtype, - "content": json.loads(content_json), + "content": db_to_json(content_json), } for group_id, membership, gtype, content_json in txn ] @@ -567,7 +567,7 @@ def _get_all_groups_changes_txn(txn): """ txn.execute(sql, (last_id, current_id, limit)) updates = [ - (stream_id, (group_id, user_id, gtype, json.loads(content_json))) + (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) for stream_id, group_id, user_id, gtype, content_json in txn ] diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index f6e78ca5903f..d181488db710 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -24,7 +24,7 @@ from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.pusher import PusherWorkerStore @@ -43,8 +43,8 @@ def _load_rules(rawrules, enabled_map): ruleslist = [] for rawrule in rawrules: rule = dict(rawrule) - rule["conditions"] = json.loads(rawrule["conditions"]) - rule["actions"] = json.loads(rawrule["actions"]) + rule["conditions"] = db_to_json(rawrule["conditions"]) + rule["actions"] = db_to_json(rawrule["actions"]) rule["default"] = False ruleslist.append(rule) diff --git 
a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index 546101624094..e18f1ca87c86 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -17,11 +17,11 @@ import logging from typing import Iterable, Iterator, List, Tuple -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList logger = logging.getLogger(__name__) @@ -36,7 +36,7 @@ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]: for r in rows: dataJson = r["data"] try: - r["data"] = json.loads(dataJson) + r["data"] = db_to_json(dataJson) except Exception as e: logger.warning( "Invalid JSON in data for pusher %d: %s, %s", diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index 8f5505bd674f..1d723f2d347e 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -22,7 +22,7 @@ from twisted.internet import defer -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.async_helpers import ObservableDeferred @@ -203,7 +203,7 @@ def f(txn): for row in rows: content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ row["user_id"] - ] = json.loads(row["data"]) + ] = db_to_json(row["data"]) return [{"type": "m.receipt", "room_id": room_id, "content": content}] @@ -260,7 +260,7 @@ def f(txn): event_entry = room_event["content"].setdefault(row["event_id"], {}) receipt_type = event_entry.setdefault(row["receipt_type"], {}) - receipt_type[row["user_id"]] = json.loads(row["data"]) + receipt_type[row["user_id"]] = db_to_json(row["data"]) results = { room_id: [results[room_id]] if room_id in results else [] @@ -329,7 +329,7 @@ def get_all_updated_receipts_txn(txn): """ txn.execute(sql, (last_id, current_id, limit)) - updates = [(r[0], r[1:5] + (json.loads(r[5]),)) for r in txn] + updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn] limited = False upper_bound = current_id diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index efb1a4fb4ce9..e1b6cded6532 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -27,6 +27,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import Database +from synapse.storage.types import Cursor +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -42,6 +44,10 @@ def __init__(self, database: Database, db_conn, hs): self.config = hs.config self.clock = hs.get_clock() + self._user_id_seq = build_sequence_generator( + database.engine, find_max_generated_user_id_localpart, "user_id_seq", + ) + @cached() def get_user_by_id(self, user_id): return self.db.simple_select_one( @@ -561,39 +567,17 @@ def _count_users(txn): ret = yield 
self.db.runInteraction("count_real_users", _count_users) return ret - @defer.inlineCallbacks - def find_next_generated_user_id_localpart(self): - """ - Gets the localpart of the next generated user ID. + async def generate_user_id(self) -> str: + """Generate a suitable localpart for a guest user - Generated user IDs are integers, so we find the largest integer user ID - already taken and return that plus one. + Returns: a (hopefully) free localpart """ - - def _find_next_generated_user_id(txn): - # We bound between '@0' and '@a' to avoid pulling the entire table - # out. - txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'") - - regex = re.compile(r"^@(\d+):") - - max_found = 0 - - for (user_id,) in txn: - match = regex.search(user_id) - if match: - max_found = max(int(match.group(1)), max_found) - - return max_found + 1 - - return ( - ( - yield self.db.runInteraction( - "find_next_generated_user_id", _find_next_generated_user_id - ) - ) + next_id = await self.db.runInteraction( + "generate_user_id", self._user_id_seq.get_next_id_txn ) + return str(next_id) + async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]: """Returns user id from threepid @@ -1653,3 +1637,26 @@ def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): keyvalues={"user_id": user_id}, values={"expiration_ts_ms": expiration_ts, "email_sent": False}, ) + + +def find_max_generated_user_id_localpart(cur: Cursor) -> int: + """ + Gets the localpart of the max current generated user ID. + + Generated user IDs are integers, so we find the largest integer user ID + already taken and return that. + """ + + # We bound between '@0' and '@a' to avoid pulling the entire table + # out. + cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'") + + regex = re.compile(r"^@(\d+):") + + max_found = 0 + + for (user_id,) in cur: + match = regex.search(user_id) + if match: + max_found = max(int(match.group(1)), max_found) + return max_found diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 47f98ba421d7..93b6380f13aa 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -28,7 +28,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.data_stores.main.search import SearchStore from synapse.storage.database import Database, LoggingTransaction from synapse.types import ThirdPartyInstanceID @@ -693,7 +693,7 @@ def _get_media_mxcs_in_room_txn(self, txn, room_id): next_token = None for stream_ordering, content_json in txn: next_token = stream_ordering - event_json = json.loads(content_json) + event_json = db_to_json(content_json) content = event_json["content"] content_url = content.get("url") thumbnail_url = content.get("info", {}).get("thumbnail_url") @@ -938,7 +938,7 @@ def _background_insert_retention_txn(txn): if not row["json"]: retention_policy = {} else: - ev = json.loads(row["json"]) + ev = db_to_json(row["json"]) retention_policy = ev["content"] self.db.simple_insert_txn( @@ -994,7 +994,7 @@ def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction): updates = [] for room_id, event_json in txn: - event_dict = json.loads(event_json) + event_dict = db_to_json(event_json) room_version_id = event_dict.get("content", 
{}).get( "room_version", RoomVersions.V1.identifier ) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 44bab65eac27..29765890ee82 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -17,8 +17,6 @@ import logging from typing import Iterable, List, Set -from canonicaljson import json - from twisted.internet import defer from synapse.api.constants import EventTypes, Membership @@ -27,6 +25,7 @@ from synapse.storage._base import ( LoggingTransaction, SQLBaseStore, + db_to_json, make_in_list_sql_clause, ) from synapse.storage.data_stores.main.events_worker import EventsWorkerStore @@ -938,7 +937,7 @@ def add_membership_profile_txn(txn): event_id = row["event_id"] room_id = row["room_id"] try: - event_json = json.loads(row["json"]) + event_json = db_to_json(row["json"]) content = event_json["content"] except Exception: continue diff --git a/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py new file mode 100644 index 000000000000..2011f6bcebc2 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py @@ -0,0 +1,34 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Adds a postgres SEQUENCE for generating guest user IDs. 
+""" + +from synapse.storage.data_stores.main.registration import ( + find_max_generated_user_id_localpart, +) +from synapse.storage.engines import PostgresEngine + + +def run_create(cur, database_engine, *args, **kwargs): + if not isinstance(database_engine, PostgresEngine): + return + + next_id = find_max_generated_user_id_localpart(cur) + 1 + cur.execute("CREATE SEQUENCE user_id_seq START WITH %s", (next_id,)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index a8381dc5778d..d52228297c28 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -17,12 +17,10 @@ import re from collections import namedtuple -from canonicaljson import json - from twisted.internet import defer from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -157,7 +155,7 @@ def reindex_search_txn(txn): stream_ordering = row["stream_ordering"] origin_server_ts = row["origin_server_ts"] try: - event_json = json.loads(row["json"]) + event_json = db_to_json(row["json"]) content = event_json["content"] except Exception: continue diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py index 290317fd9457..bd7227773aee 100644 --- a/synapse/storage/data_stores/main/tags.py +++ b/synapse/storage/data_stores/main/tags.py @@ -21,6 +21,7 @@ from twisted.internet import defer +from synapse.storage._base import db_to_json from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore from synapse.util.caches.descriptors import cached @@ -49,7 +50,7 @@ def tags_by_room(rows): tags_by_room = {} for row in rows: room_tags = tags_by_room.setdefault(row["room_id"], {}) - room_tags[row["tag"]] = json.loads(row["content"]) + room_tags[row["tag"]] = db_to_json(row["content"]) return tags_by_room return deferred @@ -180,7 +181,7 @@ def get_tags_for_room(self, user_id, room_id): retcols=("tag", "content"), desc="get_tags_for_room", ).addCallback( - lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows} + lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows} ) diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py index 4c044b1a1549..5f1b919748a6 100644 --- a/synapse/storage/data_stores/main/ui_auth.py +++ b/synapse/storage/data_stores/main/ui_auth.py @@ -12,13 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import json from typing import Any, Dict, Optional, Union import attr +from canonicaljson import json from synapse.api.errors import StoreError -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.types import JsonDict from synapse.util import stringutils as stringutils @@ -118,7 +118,7 @@ async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData: desc="get_ui_auth_session", ) - result["clientdict"] = json.loads(result["clientdict"]) + result["clientdict"] = db_to_json(result["clientdict"]) return UIAuthSessionData(session_id, **result) @@ -168,7 +168,7 @@ async def get_completed_ui_auth_stages( retcols=("stage_type", "result"), desc="get_completed_ui_auth_stages", ): - results[row["stage_type"]] = json.loads(row["result"]) + results[row["stage_type"]] = db_to_json(row["result"]) return results @@ -224,7 +224,7 @@ def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: A ) # Update it and add it back to the database. - serverdict = json.loads(result["serverdict"]) + serverdict = db_to_json(result["serverdict"]) serverdict[key] = value self.db.simple_update_one_txn( @@ -254,7 +254,7 @@ async def get_ui_auth_session_data( desc="get_ui_auth_session_data", ) - serverdict = json.loads(result["serverdict"]) + serverdict = db_to_json(result["serverdict"]) return serverdict.get(key, default) diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index 5db9f2013568..128c09a2cffb 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -24,6 +24,8 @@ from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.database import Database from synapse.storage.state import StateFilter +from synapse.storage.types import Cursor +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import StateMap from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -92,6 +94,14 @@ def __init__(self, database: Database, db_conn, hs): "*stateGroupMembersCache*", 500000, ) + def get_max_state_group_txn(txn: Cursor): + txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") + return txn.fetchone()[0] + + self._state_group_seq_gen = build_sequence_generator( + self.database_engine, get_max_state_group_txn, "state_group_id_seq" + ) + @cached(max_entries=10000, iterable=True) def get_state_group_delta(self, state_group): """Given a state group try to return a previous group and a delta between @@ -386,7 +396,7 @@ def _store_state_group_txn(txn): # AFAIK, this can never happen raise Exception("current_state_ids cannot be None") - state_group = self.database_engine.get_next_state_group_id(txn) + state_group = self._state_group_seq_gen.get_next_id_txn(txn) self.db.simple_insert_txn( txn, diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index ab0bbe4bd364..908cbc79e322 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -91,12 +91,6 @@ def is_connection_closed(self, conn: ConnectionType) -> bool: def lock_table(self, txn, table: str) -> None: ... - @abc.abstractmethod - def get_next_state_group_id(self, txn) -> int: - """Returns an int that can be used as a new state_group ID - """ - ... 
- @property @abc.abstractmethod def server_version(self) -> str: diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index a31588080dd1..ff39281f8599 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -154,12 +154,6 @@ def is_connection_closed(self, conn): def lock_table(self, txn, table): txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,)) - def get_next_state_group_id(self, txn): - """Returns an int that can be used as a new state_group ID - """ - txn.execute("SELECT nextval('state_group_id_seq')") - return txn.fetchone()[0] - @property def server_version(self): """Returns a string giving the server version. For example: '8.1.5' diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 215a94944287..8a0f8c89d173 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -96,19 +96,6 @@ def is_connection_closed(self, conn): def lock_table(self, txn, table): return - def get_next_state_group_id(self, txn): - """Returns an int that can be used as a new state_group ID - """ - # We do application locking here since if we're using sqlite then - # we are a single process synapse. - with self._current_state_group_id_lock: - if self._current_state_group_id is None: - txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") - self._current_state_group_id = txn.fetchone()[0] - - self._current_state_group_id += 1 - return self._current_state_group_id - @property def server_version(self): """Gets a string giving the server version. For example: '3.22.0' diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index f89ce0bed2a8..787cebfbec75 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -21,6 +21,7 @@ from typing_extensions import Deque from synapse.storage.database import Database, LoggingTransaction +from synapse.storage.util.sequence import PostgresSequenceGenerator class IdGenerator(object): @@ -247,7 +248,6 @@ def __init__( ): self._db = db self._instance_name = instance_name - self._sequence_name = sequence_name # We lock as some functions may be called from DB threads. self._lock = threading.Lock() @@ -260,6 +260,8 @@ def __init__( # should be less than the minimum of this set (if not empty). self._unfinished_ids = set() # type: Set[int] + self._sequence_gen = PostgresSequenceGenerator(sequence_name) + def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str ) -> Dict[str, int]: @@ -283,9 +285,7 @@ def _load_current_ids( return current_positions def _load_next_id_txn(self, txn): - txn.execute("SELECT nextval(?)", (self._sequence_name,)) - (next_id,) = txn.fetchone() - return next_id + return self._sequence_gen.get_next_id_txn(txn) async def get_next(self): """ diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py new file mode 100644 index 000000000000..63dfea422032 --- /dev/null +++ b/synapse/storage/util/sequence.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+import threading
+from typing import Callable, Optional
+
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
+
+
+class SequenceGenerator(metaclass=abc.ABCMeta):
+    """A class which generates a unique sequence of integers"""
+
+    @abc.abstractmethod
+    def get_next_id_txn(self, txn: Cursor) -> int:
+        """Gets the next ID in the sequence"""
+        ...
+
+
+class PostgresSequenceGenerator(SequenceGenerator):
+    """An implementation of SequenceGenerator which uses a postgres sequence"""
+
+    def __init__(self, sequence_name: str):
+        self._sequence_name = sequence_name
+
+    def get_next_id_txn(self, txn: Cursor) -> int:
+        txn.execute("SELECT nextval(?)", (self._sequence_name,))
+        return txn.fetchone()[0]
+
+
+GetFirstCallbackType = Callable[[Cursor], int]
+
+
+class LocalSequenceGenerator(SequenceGenerator):
+    """An implementation of SequenceGenerator which uses local locking
+
+    This only works reliably if there are no other worker processes generating IDs at
+    the same time.
+    """
+
+    def __init__(self, get_first_callback: GetFirstCallbackType):
+        """
+        Args:
+            get_first_callback: a callback which is called on the first call to
+                get_next_id_txn; should return the current maximum ID
+        """
+        # The callback. This is cleared after it is called, so that it can be GCed.
+        self._callback = get_first_callback  # type: Optional[GetFirstCallbackType]
+
+        # The current max value, or None if we haven't looked in the DB yet.
+        self._current_max_id = None  # type: Optional[int]
+        self._lock = threading.Lock()
+
+    def get_next_id_txn(self, txn: Cursor) -> int:
+        # We do application locking here since if we're using sqlite then
+        # we are a single process synapse.
+        with self._lock:
+            if self._current_max_id is None:
+                assert self._callback is not None
+                self._current_max_id = self._callback(txn)
+                self._callback = None
+
+            self._current_max_id += 1
+            return self._current_max_id
+
+
+def build_sequence_generator(
+    database_engine: BaseDatabaseEngine,
+    get_first_callback: GetFirstCallbackType,
+    sequence_name: str,
+) -> SequenceGenerator:
+    """Get the best impl of SequenceGenerator available
+
+    This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on
+    sqlite.
+
+    Args:
+        database_engine: the database engine we are connected to
+        get_first_callback: a callback which returns the current maximum ID;
+            used if we're on sqlite.
+        sequence_name: the name of a postgres sequence to use.
+    """
+    if isinstance(database_engine, PostgresEngine):
+        return PostgresSequenceGenerator(sequence_name)
+    else:
+        return LocalSequenceGenerator(get_first_callback)
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index da20523b7092..22a857a30616 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,10 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
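To make the new module's contract concrete, here is a hypothetical exercise of the `LocalSequenceGenerator` defined above, using a stub in place of a real cursor (the generator only hands the cursor through to the first-call callback):

from synapse.storage.util.sequence import LocalSequenceGenerator


class StubCursor:
    # Stands in for a DB cursor; only seen by the callback below.
    pass


def get_current_max(txn):
    # In real use this would run something like
    # "SELECT COALESCE(max(id), 0) FROM state_groups".
    return 100


gen = LocalSequenceGenerator(get_current_max)
txn = StubCursor()
assert gen.get_next_id_txn(txn) == 101  # callback consulted exactly once
assert gen.get_next_id_txn(txn) == 102  # thereafter the local counter is used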
- +import inspect import logging from twisted.internet import defer +from twisted.internet.defer import Deferred, fail, succeed +from twisted.python import failure from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -79,6 +81,28 @@ def fire(self, name, *args, **kwargs): run_as_background_process(name, self.signals[name].fire, *args, **kwargs) +def maybeAwaitableDeferred(f, *args, **kw): + """ + Invoke a function that may or may not return a Deferred or an Awaitable. + + This is a modified version of twisted.internet.defer.maybeDeferred. + """ + try: + result = f(*args, **kw) + except Exception: + return fail(failure.Failure(captureVars=Deferred.debug)) + + if isinstance(result, Deferred): + return result + # Handle the additional case of an awaitable being returned. + elif inspect.isawaitable(result): + return defer.ensureDeferred(result) + elif isinstance(result, failure.Failure): + return fail(result) + else: + return succeed(result) + + class Signal(object): """A Signal is a dispatch point that stores a list of callables as observers of it. @@ -122,7 +146,7 @@ def eb(failure): ), ) - return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) + return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb) deferreds = [run_in_background(do, o) for o in self.observers] diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 08c86e92b86e..2e2b40a4264b 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -17,7 +17,7 @@ import random import re import string -from collections import Iterable +from collections.abc import Iterable from synapse.api.errors import Codes, SynapseError diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 62b47f65747e..6aa322bf3ac8 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -142,10 +142,8 @@ def test_delete_device(self): self.get_success(self.handler.delete_device(user1, "abc")) # check the device was deleted - res = self.handler.get_device(user1, "abc") - self.pump() - self.assertIsInstance( - self.failureResultOf(res).value, synapse.api.errors.NotFoundError + self.get_failure( + self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError ) # we'd like to check the access token was invalidated, but that's a @@ -180,10 +178,9 @@ def test_update_device_too_long_display_name(self): def test_update_unknown_device(self): update = {"display_name": "new_display"} - res = self.handler.update_device("user_id", "unknown_device_id", update) - self.pump() - self.assertIsInstance( - self.failureResultOf(res).value, synapse.api.errors.NotFoundError + self.get_failure( + self.handler.update_device("user_id", "unknown_device_id", update), + synapse.api.errors.NotFoundError, ) def _record_users(self): diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index cdd093ffa878..210ddcbb882f 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -334,10 +334,12 @@ def test_self_signing_key_doesnt_show_up_as_device(self): res = None try: - yield self.hs.get_device_handler().check_device_registered( - user_id=local_user, - device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", - initial_device_display_name="new display name", + yield defer.ensureDeferred( + self.hs.get_device_handler().check_device_registered( + user_id=local_user, + 
device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", + initial_device_display_name="new display name", + ) ) except errors.SynapseError as e: res = e.code diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index a1f4bde3476b..42a236aa58e6 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -70,7 +70,9 @@ def register_query_handler(query_type, handler): def test_get_my_name(self): yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1) - displayname = yield self.handler.get_displayname(self.frank) + displayname = yield defer.ensureDeferred( + self.handler.get_displayname(self.frank) + ) self.assertEquals("Frank", displayname) @@ -138,7 +140,9 @@ def test_get_other_name(self): {"displayname": "Alice"} ) - displayname = yield self.handler.get_displayname(self.alice) + displayname = yield defer.ensureDeferred( + self.handler.get_displayname(self.alice) + ) self.assertEquals(displayname, "Alice") self.mock_federation.make_query.assert_called_with( @@ -152,8 +156,10 @@ def test_get_other_name(self): def test_incoming_fed_query(self): yield self.store.set_profile_displayname("caroline", "Caroline", 1) - response = yield self.query_handlers["profile"]( - {"user_id": "@caroline:test", "field": "displayname"} + response = yield defer.ensureDeferred( + self.query_handlers["profile"]( + {"user_id": "@caroline:test", "field": "displayname"} + ) ) self.assertEquals({"displayname": "Caroline"}, response) @@ -163,8 +169,7 @@ def test_get_my_avatar(self): yield self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png", 1 ) - - avatar_url = yield self.handler.get_avatar_url(self.frank) + avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank)) self.assertEquals("http://my.server/me.png", avatar_url) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py new file mode 100644 index 000000000000..2bdc6edbb14f --- /dev/null +++ b/tests/replication/test_pusher_shard.py @@ -0,0 +1,193 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+
+class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
+    """Checks pusher sharding works
+    """
+
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        # Register a user who sends a message that we'll get notified about
+        self.other_user_id = self.register_user("otheruser", "pass")
+        self.other_access_token = self.login("otheruser", "pass")
+
+    def default_config(self):
+        conf = super().default_config()
+        conf["start_pushers"] = False
+        return conf
+
+    def _create_pusher_and_send_msg(self, localpart):
+        # Create a user that will get push notifications
+        user_id = self.register_user(localpart, "pass")
+        access_token = self.login(localpart, "pass")
+
+        # Register a pusher
+        user_dict = self.get_success(
+            self.hs.get_datastore().get_user_by_access_token(access_token)
+        )
+        token_id = user_dict["token_id"]
+
+        self.get_success(
+            self.hs.get_pusherpool().add_pusher(
+                user_id=user_id,
+                access_token=token_id,
+                kind="http",
+                app_id="m.http",
+                app_display_name="HTTP Push Notifications",
+                device_display_name="pushy push",
+                pushkey="a@example.com",
+                lang=None,
+                data={"url": "https://push.example.com/push"},
+            )
+        )
+
+        self.pump()
+
+        # Create a room
+        room = self.helper.create_room_as(user_id, tok=access_token)
+
+        # The other user joins
+        self.helper.join(
+            room=room, user=self.other_user_id, tok=self.other_access_token
+        )
+
+        # The other user sends some messages
+        response = self.helper.send(room, body="Hi!", tok=self.other_access_token)
+        event_id = response["event_id"]
+
+        return event_id
+
+    def test_send_push_single_worker(self):
+        """Test that pushes are sent when using a dedicated pusher worker.
+        """
+        http_client_mock = Mock(spec_set=["post_json_get_json"])
+        http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+            {}
+        )
+
+        self.make_worker_hs(
+            "synapse.app.pusher",
+            {"start_pushers": True},
+            proxied_http_client=http_client_mock,
+        )
+
+        event_id = self._create_pusher_and_send_msg("user")
+
+        # Advance time a bit, so the pusher will register something has happened
+        self.pump()
+
+        http_client_mock.post_json_get_json.assert_called_once()
+        self.assertEqual(
+            http_client_mock.post_json_get_json.call_args[0][0],
+            "https://push.example.com/push",
+        )
+        self.assertEqual(
+            event_id,
+            http_client_mock.post_json_get_json.call_args[0][1]["notification"][
+                "event_id"
+            ],
+        )
+
+    def test_send_push_multiple_workers(self):
+        """Test that pushes are sent when using sharded pusher workers.
+ """ + http_client_mock1 = Mock(spec_set=["post_json_get_json"]) + http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( + {} + ) + + self.make_worker_hs( + "synapse.app.pusher", + { + "start_pushers": True, + "worker_name": "pusher1", + "pusher_instances": ["pusher1", "pusher2"], + }, + proxied_http_client=http_client_mock1, + ) + + http_client_mock2 = Mock(spec_set=["post_json_get_json"]) + http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( + {} + ) + + self.make_worker_hs( + "synapse.app.pusher", + { + "start_pushers": True, + "worker_name": "pusher2", + "pusher_instances": ["pusher1", "pusher2"], + }, + proxied_http_client=http_client_mock2, + ) + + # We choose a user name that we know should go to pusher1. + event_id = self._create_pusher_and_send_msg("user2") + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + http_client_mock1.post_json_get_json.assert_called_once() + http_client_mock2.post_json_get_json.assert_not_called() + self.assertEqual( + http_client_mock1.post_json_get_json.call_args[0][0], + "https://push.example.com/push", + ) + self.assertEqual( + event_id, + http_client_mock1.post_json_get_json.call_args[0][1]["notification"][ + "event_id" + ], + ) + + http_client_mock1.post_json_get_json.reset_mock() + http_client_mock2.post_json_get_json.reset_mock() + + # Now we choose a user name that we know should go to pusher2. + event_id = self._create_pusher_and_send_msg("user4") + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + http_client_mock1.post_json_get_json.assert_not_called() + http_client_mock2.post_json_get_json.assert_called_once() + self.assertEqual( + http_client_mock2.post_json_get_json.call_args[0][0], + "https://push.example.com/push", + ) + self.assertEqual( + event_id, + http_client_mock2.post_json_get_json.call_args[0][1]["notification"][ + "event_id" + ], + ) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index a80537c4fcd5..946f06d151f7 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1136,6 +1136,52 @@ def test_single_room(self): self.assertEqual(room_id_1, channel.json_body["room_id"]) + def test_room_members(self): + """Test that room members can be requested correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # Have another user join the room + user_1 = self.register_user("foo", "pass") + user_tok_1 = self.login("foo", "pass") + self.helper.join(room_id_1, user_1, tok=user_tok_1) + + # Have another user join the room + user_2 = self.register_user("bar", "pass") + user_tok_2 = self.login("bar", "pass") + self.helper.join(room_id_1, user_2, tok=user_tok_2) + self.helper.join(room_id_2, user_2, tok=user_tok_2) + + # Have another user join the room + user_3 = self.register_user("foobar", "pass") + user_tok_3 = self.login("foobar", "pass") + self.helper.join(room_id_2, user_3, tok=user_tok_3) + + url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertCountEqual( + ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"] + ) + self.assertEqual(channel.json_body["total"], 3) + 
+ url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertCountEqual( + ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"] + ) + self.assertEqual(channel.json_body["total"], 3) + class JoinAliasRoomTestCase(unittest.HomeserverTestCase): diff --git a/tests/test_federation.py b/tests/test_federation.py index 89dcc58b9950..87a16d7d7aa2 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -173,7 +173,7 @@ def query_user_devices(destination, user_id): # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. store = self.homeserver.get_datastore() - store.get_rooms_for_user = Mock(return_value=["!someroom:test"]) + store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"])) # Manually inject a fake device list update. We need this update to include at # least one prev_id so that the user's device list will need to be retried. @@ -218,23 +218,26 @@ def test_cross_signing_keys_retry(self): # Register mock device list retrieval on the federation client. federation_client = self.homeserver.get_federation_client() federation_client.query_user_devices = Mock( - return_value={ - "user_id": remote_user_id, - "stream_id": 1, - "devices": [], - "master_key": { + return_value=succeed( + { "user_id": remote_user_id, - "usage": ["master"], - "keys": {"ed25519:" + remote_master_key: remote_master_key}, - }, - "self_signing_key": { - "user_id": remote_user_id, - "usage": ["self_signing"], - "keys": { - "ed25519:" + remote_self_signing_key: remote_self_signing_key + "stream_id": 1, + "devices": [], + "master_key": { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, }, - }, - } + "self_signing_key": { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + + remote_self_signing_key: remote_self_signing_key + }, + }, + } + ) ) # Resync the device list. diff --git a/tox.ini b/tox.ini index e5aef3c062cf..8a506a38189f 100644 --- a/tox.ini +++ b/tox.ini @@ -127,7 +127,7 @@ deps = black==19.10b0 commands = python -m black --check --diff . - /bin/sh -c "flake8 synapse tests scripts scripts-dev synctl {env:PEP8SUFFIX:}" + /bin/sh -c "flake8 synapse tests scripts scripts-dev contrib synctl {env:PEP8SUFFIX:}" {toxinidir}/scripts-dev/config-lint.sh [testenv:check_isort]